(* Theory NTT_Rings *)
section "Number Theoretic Transforms in Rings"
theory NTT_Rings
imports
"Number_Theoretic_Transform.NTT"
Karatsuba.Monoid_Sums
Karatsuba.Karatsuba_Preliminaries
"../Preliminaries/Schoenhage_Strassen_Preliminaries"
"../Preliminaries/Schoenhage_Strassen_Ring_Lemmas"
begin
(* Factorisation of a nonzero natural number a with respect to a prime p:
   if k is the largest exponent such that p ^ k divides a and r = a div p ^ k,
   then a = r * p ^ k and p is coprime to the cofactor r. *)
lemma max_dividing_power_factorization:
fixes a :: nat
assumes "a ≠ 0"
assumes "k = Max {s. p ^ s dvd a}"
assumes "r = a div (p ^ k)"
assumes "prime p"
shows "a = r * p ^ k" "coprime p r"
subgoal
(* First claim: the set of admissible exponents contains 0, hence is nonempty,
   so its maximum k is itself an admissible exponent. *)
proof -
have "p ^ 0 dvd a" by simp
then have "{s. p ^ s dvd a} ≠ {}" by blast
with assms have "p ^ k dvd a"
by (metis Max_in finite_divisor_powers mem_Collect_eq not_prime_unit)
with assms show ?thesis by simp
qed
subgoal
(* Second claim, by contradiction: if p divided r then p ^ (k + 1) would divide a,
   which contradicts the maximality of k. *)
proof (rule ccontr)
assume "¬ coprime p r"
then have "p dvd r" using prime_imp_coprime_nat ‹prime p› by blast
then have "p ^ (k + 1) dvd a" using ‹a = r * p ^ k› by simp
then have "k ≥ k + 1"
using assms Max_ge[of "{s. p ^ s dvd a}" k] Max_in[of "{s. p ^ s dvd a}"]
by (metis Max.coboundedI finite_divisor_powers mem_Collect_eq not_prime_unit)
then show "False" by simp
qed
done
context cring
begin
(* Local interpretations: the group "units_of R" and the multiplicative_subgroup
   locale instantiated with "Units R". Their facts are used below under the
   prefixes units_group and units_subgroup. *)
interpretation units_group: group "units_of R"
by (rule units_group)
interpretation units_subgroup: multiplicative_subgroup R "Units R" "units_of R"
by (rule units_subgroup)
subsection "Roots of Unity"
(* μ is an n-th root of unity iff it lies in the carrier of R and μ [^] n = 𝟭. *)
definition root_of_unity :: "nat ⇒ 'a ⇒ bool" where
"root_of_unity n μ ≡ μ ∈ carrier R ∧ μ [^] n = 𝟭"
(* Introduction rule: a carrier element whose n-th power is 𝟭 is an n-th root of unity. *)
lemma root_of_unityI[intro]:
  "μ ∈ carrier R ⟹ μ [^] n = 𝟭 ⟹ root_of_unity n μ"
  by (unfold root_of_unity_def) simp
(* Destruction rule (simp): the n-th power of an n-th root of unity is 𝟭. *)
lemma root_of_unityD[simp]:
  "root_of_unity n μ ⟹ μ [^] n = 𝟭"
  by (unfold root_of_unity_def) simp
(* Closure (simp): every n-th root of unity is an element of the carrier of R. *)
lemma root_of_unity_closed[simp]:
  "root_of_unity n μ ⟹ μ ∈ carrier R"
  by (unfold root_of_unity_def) simp
context
fixes n :: nat
assumes "n > 0"
begin
(* Every n-th root of unity is a unit of R (n > 0 is the assumption of the
   enclosing context): writing n = Suc n', the power μ [^] n' is a two-sided
   inverse of μ. *)
lemma roots_Units[simp]:
  assumes "root_of_unity n μ"
  shows "μ ∈ Units R"
proof -
  obtain n' where "n = Suc n'"
    using ‹n > 0› gr0_implies_Suc by auto
  hence "𝟭 = μ ⊗ (μ [^] n')"
    using assms nat_pow_Suc2 unfolding root_of_unity_def by auto
  thus ?thesis
    using assms m_comm[of μ "μ [^] n'"] nat_pow_closed[of μ n']
    unfolding Units_def root_of_unity_def by auto
qed
(* Monoid record whose carrier is the set of n-th roots of unity, equipped with
   the multiplication and the one of R. *)
definition roots_of_unity_group where
"roots_of_unity_group ≡ ⦇ carrier = {μ. root_of_unity n μ}, monoid.mult = (⊗), one = 𝟭 ⦈"
(* The n-th roots of unity form a group under the ring multiplication.
   After groupI and unfolding the definitions, simp discharges the routine goals;
   the remaining goal (handled per element x) uses n > 0 from the context. *)
lemma roots_of_unity_group_is_group:
shows "group roots_of_unity_group"
apply (intro groupI)
unfolding roots_of_unity_group_def root_of_unity_def
apply (simp_all add: nat_pow_distrib m_assoc)
subgoal for x
using ‹n > 0›
by (metis Group.nat_pow_Suc Nat.lessE mult.commute nat_pow_closed nat_pow_one nat_pow_pow)
done
(* Local interpretations for the group of n-th roots of unity: as a group
   (prefix root_group) and as a multiplicative_subgroup of R (prefix root_subgroup).
   The two locale obligations are: roots are units, and the carrier of
   roots_of_unity_group is the given set. *)
interpretation root_group : group "roots_of_unity_group"
by (rule roots_of_unity_group_is_group)
interpretation root_subgroup : multiplicative_subgroup R "{μ. root_of_unity n μ}" roots_of_unity_group
apply unfold_locales
subgoal using roots_Units ‹n > 0› by blast
subgoal unfolding roots_of_unity_group_def by simp
done
(* The inverse of an n-th root of unity is again an n-th root of unity.
   Obtained from inverse-closure in the root group and the subgroup facts
   relating its inverse to the inverse in R. *)
lemma root_of_unity_inv:
assumes "root_of_unity n μ"
shows "root_of_unity n (inv μ)"
using assms root_group.inv_closed[of μ] root_subgroup.carrier_M root_subgroup.inv_eq[of μ] by simp
(* For an n-th root of unity (n > 0 from the context), inv μ = μ [^] (n - 1).
   The calculation passes through integer exponents:
   inv μ = μ [^] (-1) = μ [^] n ⊗ μ [^] (-1) = μ [^] (int n - 1) = μ [^] (n - 1). *)
lemma inv_root_of_unity:
assumes "root_of_unity n μ"
shows "inv μ = μ [^] (n - 1)"
proof -
have "μ ∈ Units R" using assms
using roots_Units by blast
(* Rewrite the inverse as the integer power with exponent -1. *)
then have "inv μ = μ [^] (-1 :: int)"
using units_group.int_pow_neg units_subgroup.inv_eq units_subgroup.int_pow_eq
using units_group.int_pow_1 by force
(* Multiply by 𝟭 on the left, then replace 𝟭 by μ [^] n. *)
also have "... = 𝟭 ⊗ μ [^] (-1 :: int)"
apply (intro l_one[symmetric])
using ‹μ ∈ Units R› by (metis Units_inv_closed calculation)
also have "... = μ [^] n ⊗ μ [^] (-1 :: int)"
using assms by simp
also have "... = μ [^] (int n) ⊗ μ [^] (-1 :: int)"
using Units_closed[OF ‹μ ∈ Units R›]
by (simp add: int_pow_int)
(* Combine the two integer exponents. *)
also have "... = μ [^] (int n - 1)"
using units_group.int_pow_mult[of μ] ‹μ ∈ Units R› units_subgroup.int_pow_eq[of μ]
using units_of_mult units_subgroup.carrier_M
by (metis add.commute uminus_add_conv_diff)
(* Since n > 0, int n - 1 is the embedding of the natural number n - 1. *)
also have "... = μ [^] (n - 1)"
using ‹n > 0› Units_closed[OF ‹μ ∈ Units R›]
by (metis Suc_diff_1 add_diff_cancel_left' int_pow_int mult_Suc_right nat_mult_1 of_nat_1 of_nat_add)
finally show ?thesis .
qed
(* For an n-th root of unity μ and 1 ≤ i < n: (inv μ) [^] i = μ [^] (n - i),
   and the exponent n - i lies again in {1..<n}.
   Proof idea: inv μ = μ [^] (n - 1); multiply by μ [^] n = 𝟭, rearrange the
   exponent n + (n - 1) * i as n * i + (n - i), and drop (μ [^] n) [^] i = 𝟭. *)
lemma inv_pow_root_of_unity:
assumes "root_of_unity n μ"
assumes "i ∈ {1..<n}"
shows "(inv μ) [^] i = μ [^] (n - i)" "n - i ∈ {1..<n}"
proof -
have "(inv μ) [^] i = (μ [^] (n - (1::nat))) [^] i"
using assms inv_root_of_unity by algebra
also have "... = μ [^] ((n - 1) * i)"
apply (intro nat_pow_pow) using assms roots_Units Units_closed by blast
also have "... = μ [^] n ⊗ μ [^] ((n - 1) * i)"
using assms root_of_unity_def[of n μ] by fastforce
also have "... = μ [^] (n + (n - 1) * i)"
apply (intro nat_pow_mult) using assms roots_Units Units_closed by blast
(* Pure arithmetic on the exponents, carried out over the integers to avoid
   truncated subtraction on nat. *)
also have "... = μ [^] (n * i + (n - i))"
proof (intro arg_cong[where f = "([^]) μ"])
have "int (n + (n - 1) * i) = int (n * i + (n - i))"
proof -
have "int (n + (n - 1) * i) = int n + int (n - 1) * int i"
by simp
also have "... = int n + (int n - int 1) * int i"
using ‹n > 0› by fastforce
also have "... = int n + int n * int i - int i"
by (simp add: left_diff_distrib')
also have "... = int n * int i + (int n - int i)"
by simp
also have "... = int (n * i) + int (n - i)"
using assms(2) by fastforce
finally show ?thesis by presburger
qed
then show "n + (n - 1) * i = n * i + (n - i)" by presburger
qed
also have "... = (μ [^] n) [^] i ⊗ μ [^] (n - i)"
using nat_pow_mult nat_pow_pow
using assms roots_Units Units_closed by algebra
also have "... = μ [^] (n - i)"
using assms unfolding root_of_unity_def by simp
finally show "(inv μ) [^] i = μ [^] (n - i)" by blast
show "n - i ∈ {1..<n}" using assms by auto
qed
(* Natural-number powers of an n-th root of unity are again n-th roots of unity;
   this is closure under powers in the root group, transferred to R. *)
lemma root_of_unity_nat_pow_closed:
assumes "root_of_unity n μ"
shows "root_of_unity n (μ [^] (m :: nat))"
using assms root_group.nat_pow_closed root_subgroup.nat_pow_eq by simp
(* Natural exponents of an n-th root of unity can be reduced modulo n.
   Write i = s * n + t with t = i mod n; then μ [^] (s * n) = (μ [^] n) [^] s = 𝟭. *)
lemma root_of_unity_powers:
assumes "root_of_unity n μ"
shows "μ [^] i = μ [^] (i mod n)"
proof -
have[simp]: "μ ∈ carrier R" using assms by simp
(* Quotient s and remainder t of the division of i by n. *)
define s t where "s = i div n" "t = i mod n"
then have "i = s * n + t" "t < n" using ‹n > 0› by simp_all
then have "μ [^] i = μ [^] (s * n) ⊗ μ [^] t" by (simp add: nat_pow_mult)
also have "μ [^] (s * n) = (μ [^] n) [^] s" by (simp add: nat_pow_pow mult.commute)
also have "... = 𝟭" using assms by simp
finally show ?thesis using ‹t = i mod n› by simp
qed
(* Integer exponents of an n-th root of unity can be reduced modulo int n.
   Same division-with-remainder argument as for natural exponents, using the
   integer-power laws available for units. *)
lemma root_of_unity_powers_modint:
assumes "root_of_unity n μ"
shows "μ [^] (i :: int) = μ [^] (i mod int n)"
proof -
have "μ ∈ Units R" "μ [^] n = 𝟭" using assms by simp_all
(* Quotient s and remainder t of the division of i by int n. *)
define s t where "s = i div int n" "t = i mod int n"
then have "i = s * int n + t" "t ≥ 0" "t < int n" using ‹n > 0› by simp_all
then have "μ [^] i = μ [^] (s * int n) ⊗ μ [^] t"
using int_pow_mult[OF ‹μ ∈ Units R›] by simp
also have "... = (μ [^] int n) [^] s ⊗ μ [^] t"
by (intro_cong "[cong_tag_2 (⊗)]" more: refl) (simp add: int_pow_pow ‹μ ∈ Units R› mult.commute)
(* Switch from the integer exponent int n to the natural exponent n. *)
also have "... = (μ [^] n) [^] s ⊗ μ [^] t"
apply (intro_cong "[cong_tag_2 (⊗), cong_tag_1 (λi. i [^] s)]" more: refl)
using ‹n > 0› by (simp add: int_pow_int)
(* μ [^] n = 𝟭, and any integer power of 𝟭 is 𝟭. *)
also have "... = μ [^] t"
using int_pow_closed[OF ‹μ ∈ Units R›] Units_closed l_one
by (simp add: ‹μ [^] n = 𝟭› int_pow_one int_pow_closed)
finally show ?thesis unfolding s_t_def .
qed
(* Natural exponents that agree modulo n give the same power of an n-th root
   of unity: reduce both exponents modulo n and compare. *)
lemma root_of_unity_powers_nat:
  assumes "root_of_unity n μ"
  assumes "i mod n = j mod n"
  shows "μ [^] i = μ [^] j"
proof -
  have "μ [^] i = μ [^] (i mod n)"
    using assms(1) by (rule root_of_unity_powers)
  also have "... = μ [^] (j mod n)"
    using assms(2) by simp
  also have "... = μ [^] j"
    using assms(1) by (rule root_of_unity_powers[symmetric])
  finally show ?thesis .
qed
(* Integer exponents that agree modulo int n give the same power of an n-th
   root of unity: reduce both exponents modulo int n and compare. *)
lemma root_of_unity_powers_int:
  assumes "root_of_unity n μ"
  assumes "i mod int n = j mod int n"
  shows "μ [^] i = μ [^] j"
proof -
  have "μ [^] i = μ [^] (i mod int n)"
    using assms(1) by (rule root_of_unity_powers_modint)
  also have "... = μ [^] (j mod int n)"
    using assms(2) by simp
  also have "... = μ [^] j"
    using assms(1) by (rule root_of_unity_powers_modint[symmetric])
  finally show ?thesis .
qed
end
subsection "Primitive Roots"
(* μ is a primitive n-th root of unity iff it is an n-th root of unity and
   no power μ [^] i with 1 ≤ i < n equals 𝟭. *)
definition primitive_root :: "nat ⇒ 'a ⇒ bool" where
"primitive_root n μ ≡ root_of_unity n μ ∧ (∀i ∈ {1..<n}. μ [^] i ≠ 𝟭)"
(* Introduction rule for primitive roots: a carrier element with μ [^] n = 𝟭
   whose powers with exponent strictly between 0 and n differ from 𝟭. *)
lemma primitive_rootI[intro]:
  assumes "μ ∈ carrier R"
  assumes "μ [^] n = 𝟭"
  assumes "⋀i. i > 0 ⟹ i < n ⟹ μ [^] i ≠ 𝟭"
  shows "primitive_root n μ"
  using assms by (unfold primitive_root_def root_of_unity_def) simp
(* Simp rule: every primitive n-th root of unity is in particular an n-th root of unity. *)
lemma primitive_root_is_root_of_unity[simp]:
  "primitive_root n μ ⟹ root_of_unity n μ"
  by (unfold primitive_root_def) simp
(* Halving step: for even n, the square of a primitive n-th root of unity is a
   primitive (n div 2)-th root of unity. The three subgoals are carrier
   membership, (μ [^] 2) [^] (n div 2) = 𝟭, and primitivity. *)
lemma primitive_root_recursion:
assumes "even n"
assumes "primitive_root n μ"
shows "primitive_root (n div 2) (μ [^] (2 :: nat))"
unfolding primitive_root_def root_of_unity_def
apply (intro conjI)
subgoal
using assms(2) unfolding primitive_root_def root_of_unity_def by blast
subgoal
using nat_pow_pow[of μ "2::nat" "n div 2"] assms apply simp
unfolding primitive_root_def root_of_unity_def apply simp
done
subgoal
(* Primitivity: for 1 ≤ i < n div 2 the exponent 2 * i lies in {1..<n},
   so μ [^] (2 * i) ≠ 𝟭 by primitivity of μ. *)
proof
fix i
assume "i ∈ {1..<n div 2}"
then have "2 * i ∈ {1..<n}" using ‹even n› by auto
have "(μ [^] (2::nat)) [^] i = μ [^] (2 * i)"
using assms unfolding primitive_root_def root_of_unity_def by (simp add: nat_pow_pow)
also have "... ≠ 𝟭"
using assms unfolding primitive_root_def using ‹2 * i ∈ {1..<n}› by blast
finally show "(μ [^] (2::nat)) [^] i ≠ 𝟭" .
qed
done
(* For n > 0, the inverse of a primitive n-th root of unity is again a
   primitive n-th root of unity. *)
lemma primitive_root_inv:
assumes "n > 0"
assumes "primitive_root n μ"
shows "primitive_root n (inv μ)"
unfolding primitive_root_def
proof (intro conjI)
(* inv μ is an n-th root of unity. *)
show "root_of_unity n (inv μ)" using assms unfolding primitive_root_def
by (simp add: root_of_unity_inv)
(* No proper power of inv μ equals 𝟭. *)
show "∀i∈{1..<n}. inv μ [^] i ≠ 𝟭" using assms unfolding primitive_root_def
by (metis Group.nat_pow_0 Units_inv_inv bot_nat_0.extremum_strict nat_neq_iff root_of_unity_def root_of_unity_inv roots_Units)
qed
subsection "Number Theoretic Transforms"
(* Number theoretic transform of the list a with respect to μ: with n = length a,
   the i-th output entry (0 ≤ i < n) is the sum over j < n of a ! j ⊗ (μ [^] i) [^] j. *)
definition NTT :: "'a ⇒ 'a list ⇒ 'a list" where
"NTT μ a ≡ let n = length a in [⨁j ← [0..<n]. (a ! j) ⊗ (μ [^] i) [^] j. i ← [0..<n]]"
(* Simp rule: the transform preserves the length of its input list. *)
lemma NTT_length[simp]: "length (NTT μ a) = length a"
unfolding NTT_def by (metis length_map map_nth)
(* Explicit formula for the i-th entry of the transform of a list of length n. *)
lemma NTT_nth:
assumes "length a = n"
assumes "i < n"
shows "NTT μ a ! i = (⨁j ← [0..<n]. (a ! j) ⊗ (μ [^] i) [^] j)"
unfolding NTT_def using assms by auto
(* Variant of NTT_nth with the nested power (μ [^] i) [^] j merged into
   μ [^] (i * j); this needs μ to be in the carrier. *)
lemma NTT_nth_2:
assumes "length a = n"
assumes "i < n"
assumes "μ ∈ carrier R"
shows "NTT μ a ! i = (⨁j ← [0..<n]. (a ! j) ⊗ (μ [^] (i * j)))"
unfolding NTT_nth[OF assms(1) assms(2)]
by (intro monoid_sum_list_cong arg_cong[where f = "(⊗) _"] nat_pow_pow assms(3))
(* Each entry of the transform lies in the carrier of R, provided the input
   entries and μ do. *)
lemma NTT_nth_closed:
assumes "set a ⊆ carrier R"
assumes "μ ∈ carrier R"
assumes "length a = n"
assumes "i < n"
shows "NTT μ a ! i ∈ carrier R"
proof -
have "NTT μ a ! i = (⨁j ← [0..<length a]. (a ! j) ⊗ (μ [^] i) [^] j)"
using NTT_nth assms by blast
(* The sum is closed because every summand is a product of carrier elements. *)
also have "... ∈ carrier R"
by (intro monoid_sum_list_closed m_closed nat_pow_closed assms(2) set_subseteqD[OF assms(1)]) simp
finally show ?thesis .
qed
(* List-level closure: all entries of the transform lie in the carrier of R.
   Follows entrywise from NTT_nth_closed. *)
lemma NTT_closed:
assumes "set a ⊆ carrier R"
assumes "μ ∈ carrier R"
shows "set (NTT μ a) ⊆ carrier R"
using assms NTT_nth_closed[of a μ]
by (intro subsetI)(metis NTT_length in_set_conv_nth)
(* Example (unnamed): 𝟭 is a primitive 1st root of unity; the range {1..<1} of
   proper exponents is empty. *)
lemma "primitive_root 1 𝟭"
unfolding primitive_root_def root_of_unity_def
by simp
(* Example (unnamed): the square of ⊖ 𝟭 is 𝟭. *)
lemma "(⊖ 𝟭) [^] (2::nat) = 𝟭"
by (simp add: numeral_2_eq_2) algebra
(* Example (unnamed): if 𝟭 ⊕ 𝟭 ≠ 𝟬, then ⊖ 𝟭 is a primitive 2nd root of unity.
   The three subgoals are carrier membership, (⊖ 𝟭) [^] 2 = 𝟭, and primitivity. *)
lemma "𝟭 ⊕ 𝟭 ≠ 𝟬 ⟹ primitive_root 2 (⊖ 𝟭)"
unfolding primitive_root_def root_of_unity_def
apply (intro conjI)
subgoal by simp
subgoal by (simp add: numeral_2_eq_2, algebra)
subgoal
(* Primitivity: the only exponent in {1..<2} is 1, and ⊖ 𝟭 = 𝟭 would give
   𝟭 ⊕ 𝟭 = 𝟬, contradicting the hypothesis. *)
proof (standard, rule ccontr)
fix i
assume "𝟭 ⊕ 𝟭 ≠ 𝟬" "i ∈ {1::nat..<2}"
then have "i = 1" by simp
assume "¬ ⊖ 𝟭 [^] i ≠ 𝟭"
then have "⊖ 𝟭 = 𝟭" using ‹i = 1› by simp
then have "𝟭 ⊕ 𝟭 = 𝟬" using l_neg by fastforce
thus False using ‹𝟭 ⊕ 𝟭 ≠ 𝟬› by simp
qed
done
subsubsection "Inversion Rule"
(* Inversion rule: for a primitive n-th root of unity μ satisfying the "good"
   property (the power sums of μ [^] i vanish for 1 ≤ i < n), transforming with
   μ and then with inv μ returns the input list with every entry multiplied by
   nat_embedding n (a constant from the imported preliminaries; presumably the
   image of n in R).
   Proof outline, entrywise for index i:
   1. expand both transforms into a double sum over j and k;
   2. swap the two sums and pull a ! k out of the inner sum;
   3. evaluate the inner sum over j of μ [^] ((k - i) * j): it is nat_embedding n
      if i = k and 𝟬 otherwise (by the "good" property);
   4. collapse the resulting delta sum to the single term with k = i. *)
theorem inversion_rule:
fixes μ :: 'a
fixes n :: nat
assumes "n > 0"
assumes "primitive_root n μ"
assumes good: "⋀i. i ∈ {1..<n} ⟹ (⨁j ← [0..<n]. (μ [^] i) [^] j) = 𝟬"
assumes[simp]: "length a = n"
assumes[simp]: "set a ⊆ carrier R"
shows "NTT (inv μ) (NTT μ a) = map (λx. nat_embedding n ⊗ x) a"
proof (intro nth_equalityI)
have "μ ∈ Units R" using assms unfolding primitive_root_def using roots_Units by blast
then have[simp]: "μ ∈ carrier R" by blast
(* Both sides have the same length. *)
show "length (NTT (inv μ) (NTT μ a)) = length (map ((⊗) (nat_embedding n)) a)" using NTT_length
by simp
(* Entrywise equality at an arbitrary index i < n. *)
fix i
assume "i < length (NTT (inv μ) (NTT μ a))"
then have "i < n" by simp
have[simp]: "inv μ ∈ carrier R" using assms roots_Units unfolding primitive_root_def by blast
then have[simp]: "⋀i :: nat. (inv μ) [^] i ∈ carrier R" by simp
(* Step 1: expand the outer transform ... *)
have 0: "NTT (inv μ) (NTT μ a) ! i = (⨁j ← [0..<n]. (NTT μ a ! j) ⊗ ((inv μ) [^] i) [^] j)"
using NTT_nth
using assms NTT_length ‹i < n› by blast
(* ... and the inner transform, merging all powers of μ into one integer exponent. *)
also have "... = (⨁j ← [0..<n]. (⨁k ← [0..<n]. a ! k ⊗ μ [^] ((int k - int i) * int j)))"
proof (intro monoid_sum_list_cong)
fix j
assume "j ∈ set [0..<n]"
then have[simp]: "j < n" by simp
have nj: "(NTT μ a ! j) = (⨁k ← [0..<n]. a ! k ⊗ (μ [^] j) [^] k)"
using NTT_nth by simp
(* Distribute the factor ((inv μ) [^] i) [^] j into the inner sum. *)
have "... ⊗ ((inv μ) [^] i) [^] j = (⨁k ← [0..<n]. a ! k ⊗ ((μ [^] j) [^] k) ⊗ ((inv μ) [^] i) [^] j)"
apply (intro monoid_sum_list_in_right[symmetric] nat_pow_closed m_closed)
using set_subseteqD[OF assms(5)] by simp_all
also have "... = (⨁k ← [0..<n]. a ! k ⊗ μ [^] ((int k - int i) * int j))"
proof (intro monoid_sum_list_cong)
fix k
assume "k ∈ set [0..<n]"
have "a ! k ⊗ (μ [^] j) [^] k ⊗ (inv μ [^] i) [^] j = a ! k ⊗ ((μ [^] j) [^] k ⊗ (inv μ [^] i) [^] j)"
apply (intro m_assoc nat_pow_closed)
using set_subseteqD[OF assms(5)] ‹k ∈ set [0..<n]› by simp_all
(* Powers of the inverse are negative integer powers of μ. *)
also have "inv μ [^] i = μ [^] (- int i)"
by (metis ‹μ ∈ Units R› cring.units_int_pow_neg int_pow_int is_cring)
also have "((μ [^] j) [^] k ⊗ (μ [^] (- int i)) [^] j) = μ [^] (int j * int k - int i * int j)"
using ‹μ ∈ Units R›
by (simp add: int_pow_int[symmetric] int_pow_pow int_pow_mult)
also have "... = μ [^] ((int k - int i) * int j)"
apply (intro arg_cong[where f = "([^]) _"])
by (simp add: mult.commute right_diff_distrib')
finally show "a ! k ⊗ (μ [^] j) [^] k ⊗ (inv μ [^] i) [^] j = a ! k ⊗ μ [^] ((int k - int i) * int j)"
using ‹inv μ [^] i = μ [^] (- int i)› by argo
qed
finally show "NTT μ a ! j ⊗ (inv μ [^] i) [^] j = monoid_sum_list (λk. a ! k ⊗ μ [^] ((int k - int i) * int j)) [0..<n]"
by (simp add: nj)
qed
(* Step 2: swap the order of summation ... *)
also have "... = (⨁k ← [0..<n]. (⨁j ← [0..<n]. a ! k ⊗ μ [^] ((int k - int i) * int j)))"
apply (intro monoid_sum_list_swap m_closed)
subgoal for j k
using assms by (metis atLeastLessThan_iff atLeastLessThan_upt nth_mem subset_eq)
subgoal for j k
using ‹μ ∈ Units R›
using units_of_int_pow[OF ‹μ ∈ Units R›]
using group.int_pow_closed[OF units_group, of μ]
by (metis Units_closed units_of_carrier)
done
(* ... and pull the factor a ! k out of the inner sum. *)
also have "... = (⨁k ← [0..<n]. a ! k ⊗ (⨁j ← [0..<n]. μ [^] ((int k - int i) * int j)))"
apply (intro monoid_sum_list_cong monoid_sum_list_in_left)
subgoal using set_subseteqD[OF assms(5)] by simp
subgoal for j
by (simp add: Units_closed int_pow_closed ‹μ ∈ Units R›)
done
(* Step 3: evaluate the inner power sum by a case distinction on i versus k. *)
also have "... = (⨁k ← [0..<n]. a ! k ⊗ (if i = k then nat_embedding n else 𝟬))"
proof (intro monoid_sum_list_cong arg_cong[where f = "(⊗) _"])
fix k
assume "k ∈ set [0..<n]"
then have[simp]: "k < n" by simp
consider "i < k" | "i = k" | "i > k" by fastforce
then show "(⨁j ← [0..<n]. μ [^] ((int k - int i) * int j)) = (if i = k then nat_embedding n else 𝟬)"
proof (cases)
(* Case i < k: the exponent difference k - i lies in {1..<n}, so the
   "good" property applies directly. *)
case 1
then have "⋀j. j < n ⟹ μ [^] ((int k - int i) * int j) = (μ [^] (k - i)) [^] j"
proof -
fix j
assume "j < n"
have "(int k - int i) * int j = int ((k - i) * j)" using 1 by auto
then have "μ [^] ((int k - int i) * int j) = μ [^] int ((k - i) * j)"
by argo
also have "... = μ [^] ((k - i) * j)"
by (intro int_pow_int)
also have "... = (μ [^] (k - i)) [^] j"
by (intro nat_pow_pow[symmetric] ‹μ ∈ carrier R›)
finally show "μ [^] ((int k - int i) * int j) = (μ [^] (k - i)) [^] j" .
qed
then have "(⨁j ← [0..<n]. μ [^] ((int k - int i) * int j)) = (⨁j ← [0..<n]. (μ [^] (k - i)) [^] j)"
by (intro monoid_sum_list_cong, simp)
also have "... = 𝟬"
using good[of "k - i"]
proof
show "k - i ∈ {1..<n}" using 1 ‹k < n› by (simp add: less_imp_diff_less)
qed simp
finally show ?thesis using 1 by simp
next
(* Case i = k: every summand is 𝟭, so the sum of n ones is nat_embedding n. *)
case 2
then have "⋀j. j < n ⟹ μ [^] ((int k - int i) * int j) = 𝟭" by simp
then have "(⨁j ← [0..<n]. μ [^] ((int k - int i) * int j)) = nat_embedding n"
using monoid_sum_list_const[of 𝟭 "[0..<n]"]
using monoid_sum_list_cong[of "[0..<n]" "λj. μ [^] ((int k - int i) * int j)" "λj. 𝟭"]
by simp
then show ?thesis using 2 by simp
next
(* Case i > k: the difference is negative; multiply by μ [^] n = 𝟭 to shift the
   exponent to n + k - i, which lies in {1..<n}, then apply the "good" property. *)
case 3
then have "⋀j. j < n ⟹ μ [^] ((int k - int i) * int j) = (μ [^] (n + k - i)) [^] j"
proof -
fix j
assume "j < n"
have "μ [^] ((int k - int i) * int j) = (μ [^] (int k - int i)) [^] j"
using int_pow_pow by (metis ‹μ ∈ Units R› int_pow_int)
also have "... = (μ [^] n ⊗ μ [^] (int k - int i)) [^] j"
proof -
have "μ [^] (int k - int i) ∈ carrier R"
using ‹μ ∈ Units R› int_pow_closed Units_closed by simp
then have "μ [^] (int k - int i) = μ [^] n ⊗ μ [^] (int k - int i)"
using l_one assms(2) unfolding primitive_root_def root_of_unity_def
by presburger
then show ?thesis by simp
qed
also have "... = (μ [^] (int n) ⊗ μ [^] (int k - int i)) [^] j"
by (simp add: int_pow_int)
also have "... = (μ [^] (int n + int k - int i)) [^] j"
using ‹μ ∈ Units R› by (simp add: int_pow_mult add_diff_eq)
finally show "μ [^] ((int k - int i) * int j) = (μ [^] (n + k - i)) [^] j" using 3
by (metis (no_types, opaque_lifting) ‹i < n› diff_cancel2 diff_diff_cancel diff_le_self int_plus int_pow_int less_or_eq_imp_le of_nat_diff)
qed
then have "(⨁j ← [0..<n]. μ [^] ((int k - int i) * int j)) = (⨁j ← [0..<n]. (μ [^] (n + k - i)) [^] j)"
by (intro monoid_sum_list_cong, simp)
also have "... = 𝟬"
using good[of "n + k - i"]
proof
show "n + k - i ∈ {1..<n}" using 3 ‹k < n› ‹i < n› by fastforce
qed simp
finally show ?thesis using 3 by simp
qed
qed
(* Step 4: rewrite the conditional with delta (from the imported preliminaries),
   pull nat_embedding n out of the sum and collapse the delta sum to a ! i. *)
also have "... = (⨁k ← [0..<n]. a ! k ⊗ (nat_embedding n ⊗ delta k i))"
apply (intro monoid_sum_list_cong)
unfolding delta_def
by simp
also have "... = (⨁k ← [0..<n]. nat_embedding n ⊗ (delta k i ⊗ a ! k))"
apply (intro monoid_sum_list_cong)
using m_assoc m_comm delta_closed set_subseteqD[OF assms(5)] nat_embedding_closed by simp
also have "... = nat_embedding n ⊗ (⨁k ← [0..<n]. delta k i ⊗ a ! k)"
using set_subseteqD[OF assms(5)]
by (intro monoid_sum_list_in_left) auto
also have "... = nat_embedding n ⊗ a ! i"
using monoid_sum_list_delta[of n "λk. a ! k" i] ‹i < n› assms
by (metis (no_types, lifting) nth_mem subsetD)
finally show "NTT (inv μ) (NTT μ a) ! i = map ((⊗) (nat_embedding n)) a ! i"
using nth_map ‹i < n› ‹length a = n› NTT_length 0
by simp
qed
(* If μ is a primitive n-th root of unity with vanishing power sums (the "good"
   property), then inv μ is primitive as well and has the same property.
   The second claim rewrites (inv μ) [^] i as μ [^] (n - i), whose exponent is
   again in {1..<n}. *)
lemma inv_good:
  assumes "n > 0"
  assumes "primitive_root n μ"
  assumes good: "⋀i. i ∈ {1..<n} ⟹ (⨁j ← [0..<n]. (μ [^] i) [^] j) = 𝟬"
  shows "primitive_root n (inv μ)"
    "⋀i. i ∈ {1..<n} ⟹ (⨁j ← [0..<n]. ((inv μ) [^] i) [^] j) = 𝟬"
  subgoal using assms by (simp add: primitive_root_inv)
  subgoal for i
  proof -
    assume i_range: "i ∈ {1..<n}"
    hence "n - i ∈ {1..<n}" by auto
    hence sum_zero: "(⨁j ← [0..<n]. (μ [^] (n - i)) [^] j) = 𝟬"
      using assms by blast
    have pow_eq: "μ [^] (n - i) = inv μ [^] i"
      using assms inv_pow_root_of_unity[of n μ i] i_range
      by auto
    from sum_zero pow_eq
    show "(⨁j ← [0..<n]. ((inv μ) [^] i) [^] j) = 𝟬" by simp
  qed
  done
(* If a unit μ satisfies μ [^] i = ⊖ 𝟭, then so does its inverse.
   The proof moves to the group units_of R to commute inverse and power,
   (inv μ) [^] i = inv (μ [^] i), and then uses inv (⊖ 𝟭) = ⊖ 𝟭. *)
lemma inv_halfway_property:
assumes "μ ∈ Units R"
assumes "μ [^] (i::nat) = ⊖ 𝟭"
shows "(inv μ) [^] i = ⊖ 𝟭"
proof -
(* Translate inverse and power from R into the units group. *)
have "(inv μ) [^] i = (inv⇘units_of R⇙ μ) [^] i"
by (intro arg_cong[where f = "λj. j [^] i"] units_of_inv[symmetric] assms(1))
also have "... = (inv⇘units_of R⇙ μ) [^]⇘units_of R⇙ i"
apply (intro units_of_pow[symmetric])
using units_group.Units_inv_Units assms(1) by simp
(* In a group, the power of an inverse is the inverse of the power. *)
also have "... = inv⇘units_of R⇙ (μ [^]⇘units_of R⇙ i)"
apply (intro units_group.nat_pow_inv)
using assms(1) by (simp add: units_of_def)
(* Translate back from the units group to R. *)
also have "... = inv (μ [^]⇘units_of R⇙ i)"
apply (intro units_of_inv)
using assms(1) units_group.nat_pow_closed by (simp add: units_of_def)
also have "... = inv (μ [^] i)"
using units_of_pow assms(1) by simp
finally have "(inv μ) [^] i = inv (μ [^] i)" .
also have "... = inv (⊖ 𝟭)" using assms(2) by simp
also have "... = ⊖ 𝟭" by simp
finally show ?thesis .
qed
(* Auxiliary lemma for the power-of-two case: if η is a primitive m-th root of
   unity with m = 2 ^ j and η [^] (m div 2) = ⊖ 𝟭, then for odd r with
   r * 2 ^ k < m the power sum of η [^] (r * 2 ^ k) over [0..<m] vanishes.
   Proof by induction on k, generalising over η, m and j:
   - k = 0: split the sum into two halves; the second half is the negation of the
     first because (η [^] r) [^] (m div 2) = (⊖ 𝟭) [^] r = ⊖ 𝟭 for odd r.
   - Suc k: η [^] 2 is a primitive (m div 2)-th root with the same halfway
     property, so the induction hypothesis gives a vanishing sum over
     [0..<m div 2]; both halves of the full sum reduce to that sum. *)
lemma sufficiently_good_aux:
assumes "primitive_root m η"
assumes "m = 2 ^ j"
assumes "η [^] (m div 2) = ⊖ 𝟭"
assumes "odd r"
assumes "r * 2 ^ k < m"
shows "(⨁l ← [0..<m]. (η [^] (r * 2 ^ k)) [^] l) = 𝟬"
using assms
proof (induction k arbitrary: η m j)
case 0
then have "root_of_unity m η" by simp
then have "η ∈ carrier R" by simp
(* m = 1 is impossible: it would force r = 0, contradicting odd r. *)
have "j > 0"
proof (rule ccontr)
assume "¬ j > 0"
then have "j = 0" by simp
then have "m = 1" using 0 by simp
then have "r * 2 ^ k = 0" using 0 by simp
then have "r = 0" by simp
then show "False" using ‹odd r› by simp
qed
then have "even m" using 0 by simp
then have "m = m div 2 + m div 2" by auto
(* Split the sum at m div 2 and shift the index of the second half. *)
then have "(⨁l ← [0..<m]. (η [^] (r * 2 ^ 0)) [^] l) = (⨁l ← [0..<m div 2 + m div 2]. (η [^] r) [^] l)"
by simp
also have "... = (⨁l ← [0..<m div 2]. (η [^] r) [^] l) ⊕ (⨁l ← [m div 2..<m div 2 + m div 2]. (η [^] r) [^] l)"
by (intro monoid_sum_list_split[symmetric] nat_pow_closed, rule ‹η ∈ carrier R›)
also have "... = (⨁l ← [0..<m div 2]. (η [^] r) [^] l) ⊕ (⨁l ← [0..<m div 2]. (η [^] r) [^] (m div 2 + l))"
by (intro arg_cong[where f = "(⊕) _"] monoid_sum_list_index_shift_0)
also have "... = (⨁l ← [0..<m div 2]. (η [^] r) [^] l ⊕ (η [^] r) [^] (m div 2 + l))"
by (intro monoid_sum_list_add_in nat_pow_closed; rule ‹η ∈ carrier R›)
(* The shifted term is the negation of the unshifted one. *)
also have "... = (⨁l ← [0..<m div 2]. (η [^] r) [^] l ⊖ (η [^] r) [^] l)"
proof (intro monoid_sum_list_cong)
fix l
have "(η [^] r) [^] (m div 2 + l) = (η [^] r) [^] (m div 2) ⊗ (η [^] r) [^] l"
by (intro nat_pow_mult[symmetric] nat_pow_closed, rule ‹η ∈ carrier R›)
also have "(η [^] r) [^] (m div 2) = (⊖ 𝟭) [^] r"
unfolding nat_pow_pow[OF ‹η ∈ carrier R›] mult.commute[of r _]
by (simp only: nat_pow_pow[symmetric] ‹η ∈ carrier R› ‹η [^] (m div 2) = ⊖ 𝟭›)
also have "... = ⊖ 𝟭" using ‹odd r›
by (simp add: powers_of_negative)
finally have "(η [^] r) [^] (m div 2 + l) = ⊖ ((η [^] r) [^] l)"
using ‹η ∈ carrier R› nat_pow_closed by algebra
then show "(η [^] r) [^] l ⊕ (η [^] r) [^] (m div 2 + l) = (η [^] r) [^] l ⊖ (η [^] r) [^] l"
unfolding minus_eq
by (intro arg_cong[where f = "(⊕) _"])
qed
also have "... = (⨁l ← [0..<m div 2]. 𝟬)"
by (intro monoid_sum_list_cong) (simp add: ‹η ∈ carrier R›)
also have "... = 𝟬" by simp
finally show ?case .
next
case (Suc k)
(* As in the base case, m = 1 is impossible. *)
have "j > 0"
proof (rule ccontr)
assume "¬ j > 0"
then have "j = 0" by simp
then have "m = 1" using Suc by simp
then have "r * 2 ^ k = 0" using Suc by simp
then have "r = 0" by simp
then show "False" using ‹odd r› by simp
qed
then have "even m" using Suc by simp
then have "m = m div 2 + m div 2" by auto
have "root_of_unity m η" using ‹primitive_root m η› by simp
then have "η ∈ carrier R" by simp
from ‹j > 0› obtain j' where "j = Suc j'"
using gr0_implies_Suc by blast
then have "m div 2 = 2 ^ j'" using ‹m = 2 ^ j› by simp
(* m = 2 is impossible too: r * 2 ^ Suc k < 2 contradicts odd r. *)
have "j' > 0"
proof (rule ccontr)
assume "¬ j' > 0"
then have "j' = 0" by simp
then have "m = 2" using ‹m = 2 ^ j› ‹j = Suc j'› by simp
then have "r * 2 ^ Suc k < 2" using Suc by simp
then show "False" using ‹odd r› by simp
qed
then have "even (m div 2)" using ‹m div 2 = 2 ^ j'› by simp
(* Induction hypothesis instantiated with m div 2, η [^] 2 and j'. *)
have IH': "(⨁l ← [0..<m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l) = 𝟬"
apply (intro Suc.IH[of "m div 2" "η [^] (2::nat)" j'])
subgoal using primitive_root_recursion[OF ‹even m›, OF ‹primitive_root m η›] .
subgoal using ‹m = 2 ^ j› ‹j = Suc j'› by simp
subgoal
by (simp add: ‹η ∈ carrier R› nat_pow_pow ‹even (m div 2)› ‹η [^] (m div 2) = ⊖ 𝟭›)
subgoal using ‹odd r› .
subgoal using ‹r * 2 ^ (Suc k) < m› ‹even m› by auto
done
(* Rewrite η [^] (r * 2 ^ Suc k) as (η [^] 2) [^] (r * 2 ^ k). *)
have "(⨁l ← [0..<m]. (η [^] (r * 2 ^ (Suc k))) [^] l) = (⨁l ← [0..<m]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l)"
unfolding nat_pow_pow[OF ‹η ∈ carrier R›]
apply (intro monoid_sum_list_cong arg_cong[where f = "λi. i [^] _"])
apply (intro arg_cong[where f = "([^]) _"])
by simp
(* Split at m div 2; the first half vanishes by IH'. *)
also have "... = (⨁l ← [0..<m div 2 + m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l)"
using ‹m = m div 2 + m div 2› by argo
also have "... = (⨁l ← [0..<m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l) ⊕ (⨁l ← [m div 2..<m div 2 + m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l)"
by (intro monoid_sum_list_split[symmetric] nat_pow_closed, rule ‹η ∈ carrier R›)
also have "... = 𝟬 ⊕ (⨁l ← [m div 2..<m div 2 + m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l)"
using IH' by argo
also have "... = (⨁l ← [m div 2..<m div 2 + m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l)"
by (intro l_zero monoid_sum_list_closed nat_pow_closed, rule ‹η ∈ carrier R›)
(* Shift the second half, factor out the constant power, and apply IH' again. *)
also have "... = (⨁l ← [0..<m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] (m div 2 + l))"
by (intro monoid_sum_list_index_shift_0)
also have "... = (⨁l ← [0..<m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] (m div 2) ⊗ ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l)"
by (intro monoid_sum_list_cong nat_pow_mult[symmetric] nat_pow_closed, rule ‹η ∈ carrier R›)
also have "... = ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] (m div 2) ⊗ (⨁l ← [0..<m div 2]. ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] l)"
by (intro monoid_sum_list_in_left nat_pow_closed; rule ‹η ∈ carrier R›)
also have "... = ((η [^] (2::nat)) [^] (r * 2 ^ k)) [^] (m div 2) ⊗ 𝟬"
using IH' by argo
also have "... = 𝟬"
by (intro r_null nat_pow_closed, rule ‹η ∈ carrier R›)
finally show ?case .
qed
(* Sufficient conditions for the "good" property of a primitive n-th root μ
   (all power sums of μ [^] i with 1 ≤ i < n vanish): either R is a domain,
   or n is a power of two and μ [^] (n div 2) = ⊖ 𝟭. *)
lemma sufficiently_good:
assumes "primitive_root n μ"
assumes "domain R ∨ (n = 2 ^ k ∧ μ [^] (n div 2) = ⊖ 𝟭)"
shows good: "⋀i. i ∈ {1..<n} ⟹ (⨁j ← [0..<n]. (μ [^] i) [^] j) = 𝟬"
proof (cases "domain R")
(* Domain case: by the geometric sum formula,
   (𝟭 ⊖ μ [^] i) ⊗ (power sum) = 𝟭 ⊖ (μ [^] i) [^] n = 𝟬, and the first factor is
   nonzero, so the power sum is 𝟬 by integrality. *)
case True
fix i
assume "i ∈ {1..<n}"
have "root_of_unity n μ" using assms(1) by simp
then have "μ ∈ carrier R" "μ [^] n = 𝟭" by simp_all
have "μ [^] i ≠ 𝟭" using assms(1) ‹i ∈ {1..<n}› unfolding primitive_root_def
by simp
then have "𝟭 ⊖ μ [^] i ≠ 𝟬" using ‹μ ∈ carrier R› by simp
have "(μ [^] i) [^] n = 𝟭"
unfolding nat_pow_pow[OF ‹μ ∈ carrier R›]
using root_of_unity_powers[OF _ ‹root_of_unity n μ›, of "i * n"]
by (cases "n > 0"; simp)
then have "𝟬 = 𝟭 ⊖ (μ [^] i) [^] n" by algebra
also have "... = (𝟭 ⊖ μ [^] i) ⊗ (⨁j ← [0..<n]. (μ [^] i) [^] j)"
by (intro geo_monoid_list_sum[symmetric] nat_pow_closed ‹μ ∈ carrier R›)
finally show "(⨁j ← [0..<n]. (μ [^] i) [^] j) = 𝟬"
using ‹𝟭 ⊖ μ [^] i ≠ 𝟬› True ‹μ ∈ carrier R›
by (metis domain.integral minus_closed monoid_sum_list_closed nat_pow_closed one_closed)
next
(* Power-of-two case: factor i = r * 2 ^ l with r odd and apply
   sufficiently_good_aux. *)
case False
then have "n = 2 ^ k" "μ [^] (n div 2) = ⊖ 𝟭" using assms(2) by auto
have "root_of_unity n μ" using ‹primitive_root n μ› by simp
then have "μ ∈ carrier R" "μ [^] n = 𝟭" by simp_all
fix i
assume "i ∈ {1..<n}"
(* l is the largest exponent with 2 ^ l dividing i; r is the odd cofactor. *)
define l where "l = Max {s. 2 ^ s dvd i}"
define r where "r = i div 2 ^ l"
from ‹i ∈ {1..<n}› have "i ≠ 0" by simp
then have "i = r * 2 ^ l" "odd r" using max_dividing_power_factorization[of i l 2 r]
using l_def r_def coprime_left_2_iff_odd[of r] by simp_all
show "(⨁j ← [0..<n]. (μ [^] i) [^] j) = 𝟬"
apply (simp only: ‹i = r * 2 ^ l›)
apply (intro sufficiently_good_aux[of n μ k r l, OF ‹primitive_root n μ› ‹n = 2 ^ k› ‹μ [^] (n div 2) = ⊖ 𝟭› ‹odd r›])
using ‹i = r * 2 ^ l› ‹i ∈ {1..<n}› by simp
qed
(* Inversion rule with the roles of μ and inv μ exchanged: transforming with
   inv μ first and then with μ also scales the input by nat_embedding n.
   Obtained by applying inversion_rule to inv μ (which is primitive and "good"
   by inv_good) and using inv (inv μ) = μ. *)
corollary inversion_rule_inv:
fixes μ :: 'a
fixes n :: nat
assumes "n > 0"
assumes "primitive_root n μ"
assumes good: "⋀i. i ∈ {1..<n} ⟹ (⨁j ← [0..<n]. (μ [^] i) [^] j) = 𝟬"
assumes[simp]: "length a = n"
assumes[simp]: "set a ⊆ carrier R"
shows "NTT μ (NTT (inv μ) a) = map (λx. nat_embedding n ⊗ x) a"
using assms inv_good[of n μ] inversion_rule[of n "inv μ" a]
using Units_inv_inv[of μ]
using roots_Units[of n μ]
unfolding primitive_root_def
by algebra
subsubsection "Convolution Theorem"
(* Product of two "polynomial evaluations" at an n-th root of unity x:
   the product of the sums of f i ⊗ x [^] i and g i ⊗ x [^] i equals the sum whose
   k-th coefficient is the cyclic convolution sum of f i ⊗ g ((n + k - i) mod n).
   The proof starts from the right-hand side and works back to the product. *)
lemma root_of_unity_power_sum_product:
assumes "root_of_unity n x"
assumes[simp]: "⋀i. i < n ⟹ f i ∈ carrier R"
assumes[simp]: "⋀i. i < n ⟹ g i ∈ carrier R"
shows "(⨁i ← [0..<n]. f i ⊗ x [^] i) ⊗ (⨁i ← [0..<n]. g i ⊗ x [^] i) =
(⨁k ← [0..<n]. (⨁i ← [0..<n]. f i ⊗ g ((n + k - i) mod n)) ⊗ x [^] k)"
proof (cases "n > 0")
(* n = 0: all sums range over the empty list. *)
case False
then have "n = 0" by simp
then show ?thesis by simp
next
case True
have[simp]: "x ∈ carrier R" using ‹root_of_unity n x› by simp
(* Move x [^] k into the inner sum. *)
have "(⨁k ← [0..<n]. (⨁i ← [0..<n]. f i ⊗ g ((n + k - i) mod n)) ⊗ x [^] k) =
(⨁k ← [0..<n]. (⨁i ← [0..<n]. f i ⊗ g ((n + k - i) mod n) ⊗ x [^] k))"
by (intro monoid_sum_list_cong monoid_sum_list_in_right[symmetric] nat_pow_closed m_closed)
simp_all
(* Replace the exponent k by (n + k - i) mod n + i, which is congruent to k
   modulo n, and split the power accordingly. *)
also have "... = (⨁k ← [0..<n]. (⨁i ← [0..<n]. f i ⊗ g ((n + k - i) mod n) ⊗ x [^] ((n + k - i) mod n + i)))"
apply (intro monoid_sum_list_cong arg_cong[where f = "(⊗) _"])
apply (intro root_of_unity_powers_nat[OF ‹n > 0› ‹root_of_unity n x›])
by (simp add: add.commute mod_add_right_eq)
also have "... = (⨁k ← [0..<n]. (⨁i ← [0..<n]. f i ⊗ g ((n + k - i) mod n) ⊗ (x [^] ((n + k - i) mod n) ⊗ x [^] i)))"
by (intro monoid_sum_list_cong arg_cong[where f = "(⊗) _"] nat_pow_mult[symmetric]) simp
(* Regroup the four factors so that each coefficient sits next to its power of x. *)
also have "... = (⨁k ← [0..<n]. (⨁i ← [0..<n]. f i ⊗ x [^] i ⊗ (g ((n + k - i) mod n) ⊗ x [^] ((n + k - i) mod n))))"
proof -
have reorder: "⋀a b c d. ⟦ a ∈ carrier R; b ∈ carrier R; c ∈ carrier R; d ∈ carrier R ⟧ ⟹
a ⊗ b ⊗ (c ⊗ d) = a ⊗ d ⊗ (b ⊗ c)"
using m_comm m_assoc by algebra
show ?thesis
by (intro monoid_sum_list_cong reorder nat_pow_closed) simp_all
qed
(* Swap the sums and pull the factor depending only on i out of the inner sum. *)
also have "... = (⨁i ← [0..<n]. (⨁k ← [0..<n]. f i ⊗ x [^] i ⊗ (g ((n + k - i) mod n) ⊗ x [^] ((n + k - i) mod n))))"
by (intro monoid_sum_list_swap m_closed nat_pow_closed) simp_all
also have "... = (⨁i ← [0..<n]. f i ⊗ x [^] i ⊗ (⨁k ← [0..<n]. (g ((n + k - i) mod n) ⊗ x [^] ((n + k - i) mod n))))"
by (intro monoid_sum_list_cong monoid_sum_list_in_left m_closed nat_pow_closed) simp_all
(* Re-index the inner sum: k ↦ (n + k - i) mod n is a bijection on {0..<n}. *)
also have "... = (⨁i ← [0..<n]. f i ⊗ x [^] i ⊗ (⨁j ← [0..<n]. (g j ⊗ x [^] j)))"
(is "(⨁i ← _. _ ⊗ ?lhs i) = (⨁i ← _. _ ⊗ ?rhs i)")
proof -
have "⋀i. i ∈ set [0..<n] ⟹ ?lhs i = ?rhs i"
proof (intro monoid_sum_list_index_permutation[symmetric] m_closed nat_pow_closed)
fix i
assume "i ∈ set [0..<n]"
have "bij_betw (λia. (n - i + ia) mod n) {0..<n} {0..<n}"
by (intro const_add_mod_bij)
also have "bij_betw (λia. (n - i + ia) mod n) {0..<n} {0..<n} =
bij_betw (λia. (n + ia - i) mod n) {0..<n} {0..<n}"
apply (intro bij_betw_cong)
using ‹i ∈ set [0..<n]› by simp
finally show "bij_betw (λia. (n + ia - i) mod n) (set [0..<n]) (set [0..<n])" by simp
qed simp_all
then show ?thesis
by (intro monoid_sum_list_cong) (intro arg_cong[where f = "(⊗) _"])
qed
(* The inner sum no longer depends on i: factor it out to obtain the product. *)
also have "... = (⨁i ← [0..<n]. f i ⊗ x [^] i) ⊗ (⨁j ← [0..<n]. (g j ⊗ x [^] j))"
by (intro monoid_sum_list_in_right monoid_sum_list_closed) simp_all
finally show ?thesis by argo
qed
context
fixes n :: nat
begin
(* Cyclic convolution of two lists with respect to the fixed length n, written
   a ⋆ b: entry i (for i < n) is the sum over σ < n of
   a ! σ ⊗ b ! ((n + i - σ) mod n). *)
definition cyclic_convolution :: "'a list ⇒ 'a list ⇒ 'a list" (infixl ‹⋆› 70) where
"cyclic_convolution a b ≡ [(⨁σ ← [0..<n]. (a ! σ ⊗ b ! ((n + i - σ) mod n))). i ← [0..<n]]"
(* Simp rule: a cyclic convolution always has exactly n entries. *)
lemma cyclic_convolution_length[simp]:
  "length (a ⋆ b) = n"
  by (unfold cyclic_convolution_def) simp
(* Explicit formula for the i-th entry of a cyclic convolution, for i < n. *)
lemma cyclic_convolution_nth:
  "i < n ⟹ (a ⋆ b) ! i = (⨁σ ← [0..<n]. (a ! σ ⊗ b ! ((n + i - σ) mod n)))"
  by (unfold cyclic_convolution_def) simp
(* The cyclic convolution of two length-n lists over the carrier of R has all
   its entries in the carrier: each entry is a finite sum of products of
   entries of a and b. *)
lemma cyclic_convolution_closed:
  assumes "length a = n" "length b = n"
  assumes "set a ⊆ carrier R" "set b ⊆ carrier R"
  shows "set (a ⋆ b) ⊆ carrier R"
proof (intro set_subseteqI)
  fix k
  assume "k < length (a ⋆ b)"
  hence "k < n" using assms(1) assms(2) by simp
  hence "(a ⋆ b) ! k = (⨁σ ← [0..<n]. (a ! σ ⊗ b ! ((n + k - σ) mod n)))"
    using cyclic_convolution_nth by presburger
  also have "... ∈ carrier R"
    apply (intro monoid_sum_list_closed m_closed)
    subgoal for σ using set_subseteqD[OF assms(3)] assms(1) by simp
    subgoal for σ using set_subseteqD[OF assms(4)] assms(2) by simp
    done
  finally show "(a ⋆ b) ! k ∈ carrier R" .
qed
(* Convolution rule: for lists a, b of length n over the carrier and an n-th
   root of unity μ, the entrywise product of the two transforms equals the
   transform of the cyclic convolution a ⋆ b, at every index i < n. *)
theorem convolution_rule:
assumes "length a = n"
assumes "length b = n"
assumes "set a ⊆ carrier R"
assumes "set b ⊆ carrier R"
assumes "root_of_unity n μ"
assumes "i < n"
shows "NTT μ a ! i ⊗ NTT μ b ! i = NTT μ (a ⋆ b) ! i"
proof (cases "n > 0")
(* n = 0 contradicts i < n. *)
case False
then show ?thesis using ‹i < n› by simp
next
case True
(* Re-establish the root-group interpretations for this n, now that n > 0 is known. *)
then interpret root_group : group "roots_of_unity_group n"
by (rule roots_of_unity_group_is_group)
interpret root_subgroup : multiplicative_subgroup R "{μ. root_of_unity n μ}" "roots_of_unity_group n"
apply unfold_locales
subgoal using roots_Units ‹n > 0› by blast
subgoal unfolding roots_of_unity_group_def[OF ‹n > 0›] by simp
done
have "μ ∈ carrier R" using assms(5) by simp
(* Expand both transforms at index i. *)
have "NTT μ a ! i ⊗ NTT μ b ! i =
(⨁j ← [0..<n]. a ! j ⊗ (μ [^] i) [^] j) ⊗ (⨁j ← [0..<n]. b ! j ⊗ (μ [^] i) [^] j)"
unfolding NTT_nth[OF assms(1) ‹i < n›] NTT_nth[OF assms(2) ‹i < n›] by argo
(* Multiply the two sums using the product lemma at the root of unity μ [^] i. *)
also have "... = (⨁j ← [0..<n]. (⨁k ← [0..<n]. (a ! k) ⊗ (b ! ((n + j - k) mod n))) ⊗ (μ [^] i) [^] j)"
apply (intro root_of_unity_power_sum_product root_of_unity_nat_pow_closed)
using True ‹root_of_unity n μ› set_subseteqD[OF assms(3)] set_subseteqD[OF assms(4)] assms(1) assms(2)
by simp_all
(* Recognise the coefficients as entries of the cyclic convolution. *)
also have "... = (⨁j ← [0..<n]. (a ⋆ b) ! j ⊗ (μ [^] i) [^] j)"
apply (intro monoid_sum_list_cong arg_cong[where f = "λj. j ⊗ _"] cyclic_convolution_nth[symmetric])
by simp
also have "... = NTT μ (a ⋆ b) ! i"
apply (intro NTT_nth[symmetric]) using ‹i < n› by simp_all
finally show ?thesis .
qed
end
end
end