(* Theory Berlekamp_Zassenhaus.Karatsuba_Multiplication *)
subsection ‹Karatsuba's Multiplication Algorithm for Polynomials›
theory Karatsuba_Multiplication
imports
Polynomial_Interpolation.Missing_Polynomial
begin
(* Recombination identity behind the three-product Karatsuba step.
   If f and g are each written as "monom_mult n high + low", then f * g can be
   assembled from only three products: f1 * g1, (f1 - f0) * (g1 - g0) and
   f0 * g0.
   monom_mult comes from the imported theory; since the proof unfolds
   monom_mult_def together with mult_monom, it is presumably multiplication by
   the monomial x^n (confirm in Missing_Polynomial). *)
lemma karatsuba_main_step: fixes f :: "'a :: comm_ring_1 poly"
assumes f: "f = monom_mult n f1 + f0" and g: "g = monom_mult n g1 + g0"
shows
"monom_mult (n + n) (f1 * g1) + (monom_mult n (f1 * g1 - (f1 - f0) * (g1 - g0) + f0 * g0) + f0 * g0) = f * g"
(* Substitute both decompositions, then let ring normalisation finish. *)
unfolding assms
by (auto simp: field_simps mult_monom monom_mult_def)
(* One-sided variant of the recombination identity: only f is split as
   "monom_mult n f1 + f0", and each part is multiplied with the unsplit g.
   It is used in the correctness proof of karatsuba_main for the branch in
   which only the first list is split. *)
lemma karatsuba_single_sided: fixes f :: "'a :: comm_ring_1 poly"
assumes "f = monom_mult n f1 + f0"
shows "monom_mult n (f1 * g) + f0 * g = f * g"
(* Substitute the decomposition of f, then normalise with the ring laws. *)
unfolding assms by (auto simp: field_simps mult_monom monom_mult_def)
(* split_at n xs cuts xs after its first n elements and returns the pair
   (take n xs, drop n xs).
   [code del] keeps this defining equation out of code generation; the
   recursive equations of lemma split_at_code are registered instead. *)
definition split_at :: "nat ⇒ 'a list ⇒ 'a list × 'a list" where
[code del]: "split_at n xs = (take n xs, drop n xs)"
(* Code equations for split_at: a structural recursion over the list that
   builds the prefix and the remainder together, instead of calling take and
   drop separately. *)
lemma split_at_code[code]:
"split_at n [] = ([],[])"
"split_at n (x # xs) = (if n = 0 then ([], x # xs) else case split_at (n-1) xs of (bef,aft)
⇒ (x # bef, aft))"
(* After unfolding the definition, the equation for [] is closed by force and
   the equation for x # xs by a case distinction on n followed by auto. *)
unfolding split_at_def by (force, cases n, auto)
(* Element-wise subtraction of two lists that may differ in length.
   While both lists have elements, the heads are subtracted.  Once the second
   list is exhausted the rest of the first list is returned unchanged; once
   the first list is exhausted the rest of the second list is negated.
   Lemma coeffs_minus shows that, read through poly_of_list, this is
   polynomial subtraction.
   The order of the equations matters: the induction proof of lemma
   coeffs_minus refers to the cases of coeffs_minus.induct by number. *)
fun coeffs_minus :: "'a :: ab_group_add list ⇒ 'a list ⇒ 'a list" where
"coeffs_minus (x # xs) (y # ys) = ((x - y) # coeffs_minus xs ys)"
| "coeffs_minus xs [] = xs"
| "coeffs_minus [] ys = map uminus ys"
text ‹The following constant determines at which size we will switch to the standard
multiplication algorithm.›
(* Threshold shared by karatsuba_main and karatsuba_mult_poly: as soon as one
   of the two size arguments is at most this value (7), the schoolbook fold is
   used instead of a further Karatsuba split.
   The [termination_simp] attribute makes the equation available to the
   termination prover of the function package.
   NOTE(review): presumably required so that the recursive calls of
   karatsuba_main can be shown to decrease - confirm. *)
definition karatsuba_lower_bound where [termination_simp]: "karatsuba_lower_bound = (7 :: nat)"
(* karatsuba_main f n g m multiplies the polynomials denoted by the lists f
   and g (read through poly_of_list).

   n and m are size arguments that steer the recursion.  karatsuba_mult_poly
   passes the list lengths, but lemma karatsuba_main shows that the result is
   the correct product for arbitrary n and m.

   - If n or m is at most karatsuba_lower_bound: schoolbook multiplication,
     a foldr over g that accumulates "smult a ff + pCons 0 p".
   - Otherwise let n2 = n div 2.
     + If m > n2: both lists are split at n2 into (f0, f1) and (g0, g1),
       where f0/g0 are the first n2 entries and f1/g1 the rest.  Three
       recursive products are computed,
         p1 for f1 and g1,
         p2 for (f1 - f0) and (g1 - g0), using coeffs_minus,
         p3 for f0 and g0,
       and recombined as in lemma karatsuba_main_step.
     + Otherwise only f is split and g is multiplied with each half; the
       two results are recombined as in lemma karatsuba_single_sided. *)
fun karatsuba_main :: "'a :: comm_ring_1 list ⇒ nat ⇒ 'a list ⇒ nat ⇒ 'a poly" where
"karatsuba_main f n g m = (if n ≤ karatsuba_lower_bound ∨ m ≤ karatsuba_lower_bound then
let ff = poly_of_list f in foldr (λa p. smult a ff + pCons 0 p) g 0
else let n2 = n div 2 in
if m > n2 then (case split_at n2 f of
(f0,f1) ⇒ case split_at n2 g of
(g0,g1) ⇒ let
p1 = karatsuba_main f1 (n - n2) g1 (m - n2);
p2 = karatsuba_main (coeffs_minus f1 f0) n2 (coeffs_minus g1 g0) n2;
p3 = karatsuba_main f0 n2 g0 n2
in monom_mult (n2 + n2) p1 + (monom_mult n2 (p1 - p2 + p3) + p3))
else case split_at n2 f of
(f0,f1) ⇒ let
p1 = karatsuba_main f1 (n - n2) g m;
p2 = karatsuba_main f0 n2 g m
in monom_mult n2 p1 + p2)"
(* Remove the recursive equation from the default simpset.  The correctness
   proof of karatsuba_main re-adds a single instance of it locally.
   NOTE(review): presumably done so that the simplifier does not unfold the
   recursive equation over and over - confirm. *)
declare karatsuba_main.simps[simp del]
(* Links split_at on lists with the polynomial decomposition used by the
   recombination lemmas: if split_at n f = (f0, f1), then the polynomial of f
   equals monom_mult n applied to the polynomial of f1, plus the polynomial
   of f0. *)
lemma poly_of_list_split_at: assumes "split_at n f = (f0,f1)"
shows "poly_of_list f = monom_mult n (poly_of_list f1) + poly_of_list f0"
proof -
(* Read the two components back as drop n f and take n f. *)
from assms have id: "f1 = drop n f" "f0 = take n f" unfolding split_at_def by auto
show ?thesis unfolding id
(* Prove equality of polynomials coefficient by coefficient. *)
proof (rule poly_eqI)
fix i
show "coeff (poly_of_list f) i =
coeff (monom_mult n (poly_of_list (drop n f)) + poly_of_list (take n f)) i"
unfolding monom_mult_def coeff_monom_mult coeff_add poly_of_list_def coeff_Poly
(* Four cases: index i below or at least n, and i inside or beyond the list. *)
by (cases "n ≤ i"; cases "i ≥ length f", auto simp: nth_default_nth nth_default_beyond)
qed
qed
(* coeffs_minus on lists corresponds to subtraction of the polynomials that
   the lists denote. *)
lemma coeffs_minus: "poly_of_list (coeffs_minus f1 f0) = poly_of_list f1 - poly_of_list f0"
(* Reduce the polynomial equation to a statement about nth_default 0 at every
   index i. *)
proof (rule poly_eqI, unfold poly_of_list_def coeff_diff coeff_Poly)
fix i
show "nth_default 0 (coeffs_minus f1 f0) i = nth_default 0 f1 i - nth_default 0 f0 i"
(* Follow the recursion of coeffs_minus; the case numbers refer to the order
   of its defining equations. *)
proof (induct f1 f0 arbitrary: i rule: coeffs_minus.induct)
(* Case 1: both lists are non-empty; distinguish i = 0 from a successor. *)
case (1 x xs y ys)
thus ?case by (cases i, auto)
next
(* Case 3: the first list is empty, so the result is the negated second
   list; nth_default is pushed through map uminus. *)
case (3 x xs)
thus ?case unfolding coeffs_minus.simps
by (subst nth_default_map_eq[of uminus 0 0], auto)
(* The remaining case (second list empty) is solved by auto. *)
qed auto
qed
(* Correctness of karatsuba_main: for arbitrary size arguments n and m the
   result is the product of the polynomials denoted by the two lists.
   The proof is by complete induction on n. *)
lemma karatsuba_main: "karatsuba_main f n g m = poly_of_list f * poly_of_list g"
proof (induct n arbitrary: f g m rule: less_induct)
case (less n f g m)
(* karatsuba_main.simps is not in the default simpset; enable only the
   instance for the current arguments. *)
note simp[simp] = karatsuba_main.simps[of f n g m]
show ?case (is "?lhs = ?rhs")
(* Beware of the naming: the distinction is on the proposition
   "threshold test = False".  Hence case False is the branch in which the
   threshold test holds (schoolbook fold), and case True is the recursive
   branch. *)
proof (cases "(n ≤ karatsuba_lower_bound ∨ m ≤ karatsuba_lower_bound) = False")
case False
(* Schoolbook branch: the function folds over the raw list g. *)
hence lhs: "?lhs = foldr (λa p. smult a (poly_of_list f) + pCons 0 p) g 0" by simp
have rhs: "?rhs = poly_of_list g * poly_of_list f" by simp
(* The library product folds over g with its trailing zeros stripped. *)
also have "… = foldr (λa p. smult a (poly_of_list f) + pCons 0 p) (strip_while ((=) 0) g) 0"
unfolding times_poly_def fold_coeffs_def poly_of_list_impl ..
(* Stripping trailing zeros does not change the fold: induction on g, using
   that a fold over a list of zeros yields 0. *)
also have "… = ?lhs" unfolding lhs
proof (induct g)
case (Cons x xs)
have "∀x∈set xs. x = 0 ⟹ foldr (λa p. smult a (Poly f) + pCons 0 p) xs 0 = 0"
by (induct xs, auto)
thus ?case using Cons by (auto simp: cCons_def Cons)
qed auto
finally show ?thesis by simp
next
case True
(* Recursive branch: n exceeds the threshold, so both n div 2 and
   n - n div 2 are strictly smaller than n and the induction hypothesis is
   available for these two sizes. *)
let ?n2 = "n div 2"
have "?n2 < n" "n - ?n2 < n" using True unfolding karatsuba_lower_bound_def by auto
note IH = less[OF this(1)] less[OF this(2)]
(* Name the halves produced by split_at and express the polynomials of f and
   g through them. *)
obtain f1 f0 where f: "split_at ?n2 f = (f0,f1)" by force
obtain g1 g0 where g: "split_at ?n2 g = (g0,g1)" by force
note fsplit = poly_of_list_split_at[OF f]
note gsplit = poly_of_list_split_at[OF g]
(* Unfold one step of the function, replace the recursive calls via IH, and
   close with the one-sided or the three-product identity, depending on
   which inner branch of the definition is taken. *)
show "?lhs = ?rhs" unfolding simp Let_def f g split IH True if_False coeffs_minus
karatsuba_single_sided[OF fsplit] karatsuba_main_step[OF fsplit gsplit] by auto
qed
qed
(* Entry point on polynomials.  With ff = coeffs f, gg = coeffs g and their
   lengths n and m:
   - if n or m is at most karatsuba_lower_bound, the product is computed by a
     schoolbook foldr over the shorter coefficient list (ff when n <= m,
     otherwise gg), multiplying with the other polynomial;
   - otherwise karatsuba_main is called with the longer coefficient list as
     its first list argument (gg when n <= m, otherwise ff).
   Lemma karatsuba_mult_poly shows that the result equals f * g. *)
definition karatsuba_mult_poly :: "'a :: comm_ring_1 poly ⇒ 'a poly ⇒ 'a poly" where
"karatsuba_mult_poly f g = (let ff = coeffs f; gg = coeffs g; n = length ff; m = length gg
in (if n ≤ karatsuba_lower_bound ∨ m ≤ karatsuba_lower_bound then if n ≤ m
then foldr (λa p. smult a g + pCons 0 p) ff 0
else foldr (λa p. smult a f + pCons 0 p) gg 0
else if n ≤ m
then karatsuba_main gg m ff n
else karatsuba_main ff n gg m))"
(* Correctness of the entry point: karatsuba_mult_poly computes the ordinary
   polynomial product.  The proof mirrors the four branches of the
   definition: threshold test (outer) and length comparison (inner). *)
lemma karatsuba_mult_poly: "karatsuba_mult_poly f g = f * g"
proof -
note d = karatsuba_mult_poly_def Let_def
let ?len = "length (coeffs f) ≤ length (coeffs g)"
show ?thesis (is "?lhs = ?rhs")
proof (cases "length (coeffs f) ≤ karatsuba_lower_bound ∨ length (coeffs g) ≤ karatsuba_lower_bound")
(* Outer case True: below the threshold, one of the schoolbook folds is used. *)
case True note outer = this
show ?thesis
proof (cases ?len)
case True
(* Fold over coeffs f: this is the library definition of f * g unfolded. *)
with outer have "?lhs = foldr (λa p. smult a g + pCons 0 p) (coeffs f) 0" unfolding d by auto
also have "… = ?rhs" unfolding times_poly_def fold_coeffs_def by auto
finally show ?thesis .
next
case False
(* Fold over coeffs g: this yields g * f, which equals f * g by
   commutativity. *)
with outer have "?lhs = foldr (λa p. smult a f + pCons 0 p) (coeffs g) 0" unfolding d by auto
also have "… = g * f" unfolding times_poly_def fold_coeffs_def by auto
also have "… = ?rhs" by simp
finally show ?thesis .
qed
next
(* Outer case False: above the threshold, karatsuba_main is called and its
   correctness lemma is applied. *)
case False note outer = this
show ?thesis
proof (cases ?len)
case True
(* Arguments are swapped in this branch, so the call yields g * f. *)
with outer have "?lhs = karatsuba_main (coeffs g) (length (coeffs g)) (coeffs f) (length (coeffs f))"
unfolding d by auto
also have "… = g * f" unfolding karatsuba_main by auto
also have "… = ?rhs" by auto
finally show ?thesis .
next
case False
with outer have "?lhs = karatsuba_main (coeffs f) (length (coeffs f)) (coeffs g) (length (coeffs g))"
unfolding d by auto
also have "… = ?rhs" unfolding karatsuba_main by auto
finally show ?thesis .
qed
qed
qed
(* Code-generation setup: as a code_unfold rule, this equation lets the code
   preprocessor replace polynomial multiplication by karatsuba_mult_poly.
   It follows pointwise from the correctness lemma karatsuba_mult_poly. *)
lemma karatsuba_mult_poly_code_unfold[code_unfold]: "(*) = karatsuba_mult_poly"
by (intro ext) (simp add: karatsuba_mult_poly)
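text ‹The following declaration resolves a conflict between the rewrite rules @{thm karatsuba_mult_poly_code_unfold}
and @{thm monom_mult_unfold}: it additionally registers the latter in a form where multiplication has already
been replaced by Karatsuba multiplication.›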
(* Take monom_mult_unfold at type 'a :: comm_ring_1 poly, rewrite it with
   karatsuba_mult_poly_code_unfold (so that polynomial multiplication is
   replaced by karatsuba_mult_poly), and register the result as a further
   code_unfold rule. *)
lemmas karatsuba_monom_mult_code_unfold[code_unfold] =
monom_mult_unfold[where f = "f :: 'a :: comm_ring_1 poly" for f, unfolded karatsuba_mult_poly_code_unfold]
end