(* Theory Number_Theoretic_Transform.NTT *)
theory NTT
imports Preliminary_Lemmas
begin
section ‹Number Theoretic Transform and Inverse Transform›
text ‹\label{NTT}›
(* Locale for the number theoretic transform over 'a mod_ring, where the type 'a
   has sort prime_card. It extends the locale preliminary; the names n and p used
   below are presumably introduced there (not visible in this theory - confirm).
   Parameters:
     ω - satisfies ω^n = 1 and ω ≠ 1, and every non-zero exponent m with ω^m = 1
         is at least n;
     μ - satisfies μ * ω = 1, i.e. μ is the multiplicative inverse of ω. *)
locale ntt = preliminary "TYPE ('a ::prime_card)" +
fixes ω :: "('a::prime_card mod_ring)"
fixes μ :: "('a mod_ring)"
assumes omega_properties: "ω^n = 1" "ω ≠ 1" "(∀ m. ω^m = 1 ∧ m≠0 ⟶ m ≥ n)"
assumes mu_properties: "μ * ω = 1"
begin
(* The inverse μ differs from 1: if μ were 1, then μ * ω = 1 would force ω = 1,
   contradicting the locale assumption ω ≠ 1. *)
lemma mu_properties': "μ ≠ 1"
proof
  assume "μ = 1"
  then have "ω = 1" using mu_properties by simp
  then show False using omega_properties(2) by simp
qed
subsection ‹Definition of $NTT$ and $INTT$›
text ‹\label{NTTdef}›
text ‹
Now we can state an analogue to the $DFT$ on finite fields,
namely the \textit{Number Theoretic Transform}.
First, let us look at an informal definition of $\mathsf{NTT}$~\parencite{ntt_intro}:
\begin{equation*}
\mathsf{NTT}(\vec{x}) =
\begin{pmatrix}
1 & 1 & 1 & 1 & \cdots& 1 \\
1 & \omega & \omega^2 & \omega^3 & \cdots & \omega^{n-1} \\
1 & \omega^2 & \omega^4 & \omega^6 & \cdots & \omega^{2\cdot(n-1)} \\
1 & \omega^3 & \omega^6 & \omega^9 & \cdots & \omega^{3\cdot(n-1)} \\
\vdots & \vdots & \vdots & \vdots & & \vdots \\
1 & \omega^{n-1} & \omega^{2\cdot(n-1)} & \omega^{3\cdot(n-1)} & \cdots & \omega^{(n-1)\cdot(n-1)}
\end{pmatrix} \cdot \vec{x}
\end{equation*}
Or for single vector entries:
\begin{equation*}
\mathsf{NTT}(\vec{x})_i = \sum _{j = 0} ^{n-1} x_j \cdot \omega ^{i\cdot j}
\end{equation*}
›
text ‹Formally:›
(* Entry i of the transform: the sum over all j < n of numbers ! j * ω^(i*j). *)
definition ntt :: "(('a ::prime_card) mod_ring) list ⇒ nat ⇒ 'a mod_ring" where
  "ntt numbers i = sum (λj. numbers ! j * ω ^ (i * j)) {0..<n}"
(* The whole transform: the list of the entries ntt numbers i for i = 0, ..., n - 1. *)
definition NTT :: "(('a ::prime_card) mod_ring) list ⇒ 'a mod_ring list" where
  "NTT numbers = map (ntt numbers) [0..<n]"
text ‹\label{INTTdef}
We define the inverse transform $\mathsf{INTT}$ by matrices:
\begin{equation*}
\mathsf{INTT}(\vec{y}) =
\begin{pmatrix}
1 & 1 & 1 & 1 & \cdots& 1 \\
1 & \mu & \mu^2 & \mu^3 & \cdots & \mu^{n-1} \\
1 & \mu^2 & \mu^4 & \mu^6 & \cdots & \mu^{2\cdot(n-1)} \\
1 & \mu^3 & \mu^6 & \mu^9 & \cdots & \mu^{3\cdot(n-1)} \\
\vdots & \vdots & \vdots & \vdots & & \vdots \\
1 & \mu^{n-1} & \mu^{2\cdot(n-1)} & \mu^{3\cdot(n-1)} & \cdots & \mu^{(n-1)\cdot(n-1)}
\end{pmatrix} \cdot \vec{y}
\end{equation*}
Per component:
\begin{equation*}
%
\mathsf{INTT}(\vec{y})_i = \sum _{j = 0} ^{n-1} y_j \cdot \mu ^{i\cdot j}
%
\end{equation*}
›
(* Entry i of the inverse transform: the sum over all j < n of xs ! j * μ^(i*j).
   No rescaling by the inverse of n is built in. *)
definition intt :: "(('a ::prime_card) mod_ring) list ⇒ nat ⇒ 'a mod_ring" where
  "intt xs i = sum (λj. xs ! j * μ ^ (i * j)) {0..<n}"
(* The whole inverse transform: the list of the entries intt xs i for i = 0, ..., n - 1. *)
definition INTT :: "(('a ::prime_card) mod_ring) list ⇒ 'a mod_ring list" where
  "INTT xs = map (intt xs) [0..<n]"
text ‹Vector length is preserved.›
(* NTT produces a list of length n, since it maps over the index list [0..<n].
   The length assumption is kept as part of the statement. *)
lemma length_NTT:
  assumes n_def: "length numbers = n"
  shows "length (NTT numbers) = n"
  by (simp add: NTT_def)
(* INTT produces a list of length n, since it maps over the index list [0..<n].
   The length assumption is kept as part of the statement. *)
lemma length_INTT:
  assumes n_def: "length numbers = n"
  shows "length (INTT numbers) = n"
  by (simp add: INTT_def)
subsection ‹Correctness Proof of $NTT$ and $INTT$›
text ‹\label{NTTcorr}›
text ‹
We prove $\mathsf{NTT}$ and $\mathsf{INTT}$ correct:
Applying $\mathsf{INTT}$ to $\mathsf{NTT}(\vec{x})$ yields $\vec{x}$ scaled by $n$.
Analogously to the $DFT$, one can get rid of the factor $n$ by a simple rescaling.
First, consider an informal proof sketch using the matrix form:
\begin{equation*}
\begin{split}
\mathsf{INTT}(\mathsf{NTT}(\vec{x})) = \hspace{11cm}\\
%
\begin{pmatrix}
1 & 1 & 1 & \cdots& 1 \\
1 & \mu & \mu^2 & \cdots & \mu^{n-1} \\
1 & \mu^2 & \mu^4 & \cdots & \mu^{2\cdot(n-1)} \\
\vdots & \vdots & \vdots & & \vdots \\
1 & \mu^{n-1} & \mu^{2\cdot(n-1)}& \cdots & \mu^{(n-1)\cdot(n-1)}
\end{pmatrix}
%
\cdot
%
\begin{pmatrix}
1 & 1 & 1 & \cdots& 1 \\
1 & \omega & \omega^2 & \cdots & \omega^{n-1} \\
1 & \omega^2 & \omega^4 & \cdots & \omega^{2\cdot(n-1)} \\
\vdots & \vdots & \vdots & & \vdots \\
1 & \omega^{n-1} & \omega^{2\cdot(n-1)} & \cdots & \omega^{(n-1)\cdot(n-1)}
\end{pmatrix}
\cdot
%
\vec{x}
%
\end{split}
\end{equation*}
A resulting entry is of the following form:
\begin{equation*}
%
\mathsf{INTT}(\mathsf{NTT}(x))_i = \sum _ {j = 0} ^{n-1} (\sum _{k = 0} ^{n-1} \mu^{i\cdot k} \cdot \omega^{j\cdot k}) \cdot x_j
%
\end{equation*}
Now, we analyze the interior sum by cases on $i = j$.
\paragraph \noindent Case $i = j$.
\begin{align*}
\sum _{k = 0} ^{n-1} \mu^{i\cdot k} \cdot \omega^{j\cdot k}
&= \sum _{k = 0} ^{n-1} (\mu \cdot \omega)^{i \cdot k} \\
&= \sum _{k = 0} ^{n-1} 1 ^{i \cdot k} \\
&= n
\end{align*}
Note that $\omega$ and $\mu$ are mutually inverse.
\paragraph \noindent Case $i \neq j$. Wlog assume $i > j$, otherwise replace $\mu$ by $\omega$ and $i -j$ by $j - i$ respectively.
\begin{align*}
\sum _{k = 0} ^{n-1} \mu^{i\cdot k} \cdot \omega^{j\cdot k}
&= \sum _{k = 0} ^{n-1} (\mu \cdot \omega)^{j \cdot k} \cdot \mu^{(i-j) \cdot k} \\
&= \sum _{k = 0} ^{n-1} \mu^{(i-j) \cdot k} \\
&= (1 - \mu ^{(i-j)\cdot n}) \cdot (1 - \mu^{i-j})^{-1} && \text{by lemma on geometric sum}\\
&= (1 - 1^{i-j}) \cdot (1 - \mu^{i-j})^{-1} \\
&= 0
\end{align*}
The geometric sum lemma is applicable because $0 < i - j < n$ implies $\mu^{i-j} \neq 1$, and $\mu^n = 1$ holds since $\mu$ is the inverse of $\omega$.
We conclude that $\sum \limits _ {j = 0} ^{n-1} (\sum \limits _{k = 0} ^{n-1} \mu^{i\cdot k} \cdot \omega^{j\cdot k}) \cdot x_j = n \cdot x_i$.
›
(* Correctness of the transform pair: for a list of length n, applying INTT after
   NTT multiplies every entry by n (viewed as an element of 'a mod_ring).
   Proof outline: entry i of INTT (NTT numbers) is written as a double sum, the
   order of summation is exchanged, and the inner sum over l is evaluated for
   every j: it equals numbers ! j * n for j = i and 0 for j ≠ i. *)
theorem ntt_correct:
assumes n_def: "length numbers = n"
shows "INTT (NTT numbers) = map (λ x. (of_int_mod_ring n) * x ) numbers"
proof-
(* Entry i of the result list is intt applied to the transformed list. *)
have 0:"⋀ i. i < n ⟹ (INTT (NTT numbers)) ! i = intt (NTT numbers) i " using n_def length_NTT
unfolding INTT_def NTT_def intt_def by simp
text ‹Major sublemma.›
have 1:"⋀ i. i < n ⟹intt (NTT numbers) i = (of_int_mod_ring n)*numbers ! i"
proof-
fix i
assume i_assms:"i < n"
text ‹First, simplify by some chains of equations.›
(* Unfold both transforms into a nested sum. *)
hence 1:"intt (NTT numbers) i =
(∑l = 0..<n.
(∑j = 0..<n. numbers ! j * ω ^ (l * j)) * μ ^ (i * l))"
unfolding NTT_def intt_def ntt_def using n_def length_map nth_map by simp
(* Move the factor μ^(i*l) into the inner sum. *)
also have 2:"… =
(∑l = 0..<n.
(∑j = 0..<n. (numbers ! j * ω ^ (l * j)) * μ ^ (i * l)))"
using sum_in by (simp add: sum_distrib_right)
(* Exchange the order of summation. *)
also have 3:"… =
(∑j = 0..<n.
(∑l = 0..<n. (numbers ! j * ω ^ (l * j) * μ ^ (i * l))))" using sum_swap by fast
text ‹The case $i \neq j$ of the informal proof is split into $j < i$ and $i < j$, so we consider three cases. First $j = i$.›
have iisj:"⋀ j. j = i ⟹ (∑l = 0..<n. (numbers ! j * ω ^ (l * j) * μ ^ (i * l))) = (numbers ! j)* (of_int_mod_ring n)"
proof-
fix j
assume "j=i"
hence "⋀ l. l < n ⟹ (numbers ! j * ω ^ (l * j) * μ ^ (i * l))= (numbers ! j)"
by (simp add: left_right_inverse_power mult.commute mu_properties(1))
moreover have "⋀ l. l < n ⟹ numbers ! j * ω ^ (l * j) * μ ^ (i * l) = numbers ! j"
using calculation by blast
text ‹For $j = i$ we have $\omega^{l \cdot j}\cdot \mu^{i \cdot l} = (\omega \cdot \mu)^{i \cdot l} = 1$. Thus, every summand equals $x_j$ and we sum it $n$ times, which gives the goal.›
ultimately show "(∑l = 0..<n. (numbers ! j * ω ^ (l * j) * μ ^ (i * l))) =
(numbers ! j)* (of_int_mod_ring n)"
using n_def sum_const[of "numbers ! j" n] exp_rule[of ω μ] mu_properties(1)
by (metis (no_types, lifting) atLeastLessThan_iff mult.commute sum.cong)
qed
text ‹Case $j < i$.›
have jlsi:"⋀ j. j < i ⟹ (∑l = 0..<n. (numbers ! j * ω ^ (l * j) * μ ^ (i * l))) = 0"
proof-
fix j
assume j_assms:"j < i"
(* Generic regrouping of powers: split b^i into b^j * b^(i-j) and pair b^j with a^j. *)
hence 00:"⋀ (c::('a::prime_card) mod_ring) a b. c * a^j*b^i = (a*b)^j*(c * b^(i-j))"
using algebra_simps
by (smt (z3) le_less ordered_cancel_comm_monoid_diff_class.add_diff_inverse power_add)
text ‹A geometric sum over $\mu^l$ remains.›
have 01:" (∑l = 0..<n. (numbers ! j * ω ^ (l * j) * μ ^ (i * l))) =
(∑l = 0..<n. (numbers ! j * (μ^l)^(i-j)))"
apply(rule sum_eq)
using mu_properties(1) 00 algebra_simps(23)
by (smt (z3) mult.commute mult.left_neutral power_mult power_one)
(* Pull the constant factor numbers ! j out of the sum. *)
have 02:"… = numbers ! j *(∑l = 0..<n. ((μ^l)^(i-j))) "
using sum_in[of "λ l. numbers ! j * (μ ^ l) ^ (i - j)" " numbers ! j" n]
by (simp add: mult_hom.hom_sum)
(* Swap the two exponents so that the sum is a geometric sum in μ^(i-j). *)
moreover have 03:"(∑l = 0..<n. ((μ^l)^(i-j))) =
(∑l = 0..<n. ((μ^(i-j))^l)) "
by(rule sum_eq) (metis mult.commute power_mult)
(* The ratio μ^(i-j) is not 1; this is the side condition used with geo_sum below. *)
have "μ^(i-j) ≠ 1"
proof
assume "μ ^ (i - j) = 1"
hence "ord p (to_int_mod_ring μ) ≤ i-j"
by (simp add: j_assms not_le ord_max)
moreover hence "ord p (to_int_mod_ring ω) ≤ i-j"
by (metis ‹μ ^ (i - j) = 1› diff_is_0_eq exp_rule j_assms leD mult.comm_neutral mult.commute mu_properties(1) ord_max)
moreover hence "i-j < n"
using j_assms i_assms p_fact k_bound n_lst2 by linarith
moreover have "ord p (to_int_mod_ring ω) = n" using omega_properties n_lst2 unfolding ord_def
by (metis (no_types) ‹μ ^ (i - j) = 1› calculation(3) diff_is_0_eq j_assms leD left_right_inverse_power mult.comm_neutral mult_cancel_left mu_properties(1) omega_properties(3) zero_neq_one)
ultimately show False by simp
qed
text ‹Application of the lemma for geometric sums.›
ultimately have "(1-μ^(i-j))*(∑l = 0..<n. ((μ^(i-j))^l)) = (1-(μ^(i-j))^n)"
using geo_sum[of "μ ^ (i - j)" n] by simp
moreover have "(μ^(i-j))^n = 1"
by (metis (no_types) left_right_inverse_power mult.commute mult.right_neutral mu_properties(1) omega_properties(1) power_mult power_one)
text ‹The sum for the current index is 0.›
ultimately have "(∑l = 0..<n. ((μ^(i-j))^l)) = 0"
by (metis ‹μ ^ (i - j) ≠ 1› divisors_zero eq_iff_diff_eq_0)
thus "(∑l = 0..<n. numbers ! j * ω ^ (l * j) * μ ^ (i * l)) = 0" using 01 02 03 by simp
qed
text ‹Case $i < j$.
We also rewrite the whole summation until the lemma for geometric sums is applicable.
From this, we conclude that the term is 0.›
have ilsj:"⋀ j. i < j ∧ j < n ⟹ (∑l = 0..<n. (numbers ! j * ω ^ (l * j) * μ ^ (i * l))) = 0"
proof-
fix j
assume ij_Assm: "i < j ∧ j < n"
(* Generic regrouping of powers: split b^j into b^i * b^(j-i) and pair b^i with a^i. *)
hence 00:"⋀ (c::('a::prime_card) mod_ring) a b. (a*b)^i*(c * b^(j-i)) = c * a^i*b^j "
by (auto simp: field_simps simp flip: power_add)
(* Here the leftover factor is a power of ω, since j is the larger index. *)
have 01:" (∑l = 0..<n. (numbers ! j * ω ^ (l * j) * μ ^ (i * l))) =
(∑l = 0..<n. (numbers ! j * (ω^l)^(j-i))) "
apply(rule sum_eq) subgoal for l
using mu_properties(1) 00[of "ω^l" "μ^l" "numbers ! j "] algebra_simps(23)
by (smt (z3) "00" left_right_inverse_power mult.assoc mult.commute mult.right_neutral power_mult)
done
moreover have 02:"(∑l = 0..<n. (numbers ! j * (ω^l)^(j-i))) =
numbers ! j *(∑l = 0..<n. ( (ω^l)^(j-i))) "
by (simp add: mult_hom.hom_sum)
moreover have 03:"(∑l = 0..<n. ( (ω^l)^(j-i))) =
(∑l = 0..<n. (( (ω^(j-i))^l))) "
by(rule sum_eq) (metis mult.commute power_mult)
(* The ratio ω^(j-i) is not 1: j - i is non-zero and smaller than n, so
   omega_properties(3) rules out ω^(j-i) = 1. *)
have "ω^(j-i) ≠ 1"
proof
assume " ω ^ (j - i) = 1"
hence "ord p (to_int_mod_ring ω) ≤ j-i" using ord_max[of "j-i" ω] ij_Assm by simp
(* Stated as n, in line with the sibling sub-proofs; the locale does not assume
   that ω has order p - 1. *)
moreover have "ord p (to_int_mod_ring ω) = n"
by (meson ‹ω ^ (j - i) = 1› diff_is_0_eq diff_le_self ij_Assm leD le_trans omega_properties(3))
ultimately show False
by (meson ‹ω ^ (j - i) = 1› diff_is_0_eq diff_le_self ij_Assm leD le_trans omega_properties(3))
qed
text ‹Geometric sum.›
ultimately have "(1-ω^(j-i))* (∑l = 0..<n. ((ω^(j-i))^l)) = (1-(ω^(j-i))^n)"
using geo_sum[of "ω ^ (j-i)" n] by simp
moreover have "(ω^(j-i))^n = 1"
by (metis (no_types) mult.commute omega_properties(1) power_mult power_one)
ultimately have "(∑l = 0..<n. ((ω^(j-i))^l)) = 0"
by (metis ‹ω ^ (j - i) ≠ 1› eq_iff_diff_eq_0 no_zero_divisors)
thus "(∑l = 0..<n. numbers ! j * ω ^ (l * j) * μ ^ (i * l)) = 0" using 01 02 03 by simp
qed
text ‹We compose the cases $j <i$, $j = i$ and $j > i$ to a complete summation over index $j$.›
have " (∑j = 0..<i. ∑l = 0..<n. numbers ! j * ω ^ (l * j) * μ ^ (i * l)) = 0" using jlsi by simp
moreover have " (∑j = i..<i+1. ∑l = 0..<n. numbers ! j * ω ^ (l * j) * μ ^ (i * l)) = numbers ! i * (of_int_mod_ring n)" using iisj by simp
moreover have " (∑j = (i+1)..<n. ∑l = 0..<n. numbers ! j * ω ^ (l * j) * μ ^ (i * l)) = 0" using ilsj by simp
ultimately have " (∑j = 0..<n. ∑l = 0..<n. numbers ! j * ω ^ (l * j) * μ ^ (i * l)) =
numbers ! i * (of_int_mod_ring n)" using i_assms sum_split
by (smt (z3) add.commute add.left_neutral int_ops(2) less_imp_of_nat_less of_nat_add of_nat_eq_iff of_nat_less_imp_less)
text ‹Index-wise equality can be shown.›
thus "intt (NTT numbers) i = of_int_mod_ring (int n) * numbers ! i" using 1 2 3
by (metis mult.commute)
qed
(* Entry i of the right-hand side list. *)
have 2: "⋀ i. i < n ⟹ (map ((*) (of_int_mod_ring (int n))) numbers ) ! i = (of_int_mod_ring (int n)) * (numbers ! i)"
by (simp add: n_def)
text ‹We relate index-wise equality to the function definition.›
(* List equality via nth_equalityI: first equal lengths, then equal entries. *)
show ?thesis
apply(rule nth_equalityI)
subgoal my_subgoal
unfolding INTT_def NTT_def
apply (simp add: n_def)
done
subgoal for i
using 0 1 2 n_def algebra_simps my_subgoal length_map
apply auto
done
done
qed
text ‹Now we prove the converse to be true:
$\mathsf{NTT}(\mathsf{INTT}(\vec{x})) = n \cdot \vec{x}$.
The proof proceeds analogously with exchanged roles of $\omega$ and $\mu$.
›
(* Converse of ntt_correct: for a list of length n, applying NTT after INTT
   multiplies every entry by n (viewed as an element of 'a mod_ring).
   The proof has the same structure as ntt_correct with ω and μ exchanged:
   double sum, exchange of summation order, then evaluation of the inner sum
   over l for j = i, j < i and i < j. *)
theorem inv_ntt_correct:
assumes n_def: "length numbers = n"
shows "NTT (INTT numbers) = map (λ x. (of_int_mod_ring n) * x ) numbers"
proof-
(* Entry i of the result list is ntt applied to the inversely transformed list. *)
have 0:"⋀ i. i < n ⟹ (NTT (INTT numbers)) ! i = ntt (INTT numbers) i " using n_def length_NTT
unfolding INTT_def NTT_def intt_def by simp
(* Major sublemma: entry-wise statement of the theorem. *)
have 1:"⋀ i. i < n ⟹ntt (INTT numbers) i = (of_int_mod_ring n)*numbers ! i"
proof-
fix i
assume i_assms:"i < n"
(* Unfold both transforms into a nested sum. *)
hence 1:"ntt (INTT numbers) i =
(∑l = 0..<n.
(∑j = 0..<n. numbers ! j * μ ^ (l * j)) * ω ^ (i * l))"
unfolding INTT_def ntt_def intt_def using n_def length_map nth_map by simp
(* Move the factor ω^(i*l) into the inner sum. *)
hence 2:"… = (∑l = 0..<n.
(∑j = 0..<n. (numbers ! j * μ ^ (l * j)) * ω ^ (i * l)))" using sum_in by simp
(* Exchange the order of summation. *)
have 3:" … =(∑j = 0..<n.
(∑l = 0..<n. (numbers ! j * μ ^ (l * j) * ω ^ (i * l))))" using sum_swap by fast
(* Case j = i: every summand equals numbers ! j, so the sum is numbers ! j * n. *)
have iisj:"⋀ j. j = i ⟹ (∑l = 0..<n. (numbers ! j * μ ^ (l * j) * ω ^ (i * l))) = (numbers ! j)* (of_int_mod_ring n)"
proof-
fix j
assume "j=i"
hence "⋀ l. l < n ⟹ (numbers ! j * μ ^ (l * j) * ω ^ (i * l))= (numbers ! j)"
by (simp add: left_right_inverse_power mult.commute mu_properties(1))
moreover have "⋀ l. l < n ⟹ numbers ! j * μ ^ (l * j) * ω ^ (i * l) = numbers ! j"
using calculation by blast
ultimately show "(∑l = 0..<n. (numbers ! j * μ ^ (l * j) * ω ^ (i * l))) = (numbers ! j)* (of_int_mod_ring n)"
using n_def sum_const[of "numbers ! j" n] exp_rule[of ω μ] mu_properties(1)
by (metis (no_types, lifting) atLeastLessThan_iff mult.commute sum.cong)
qed
(* Case j < i: a geometric sum in ω^(i-j) remains, which is 0. *)
have jlsi:"⋀ j. j < i ⟹ (∑l = 0..<n. (numbers ! j * μ ^ (l * j) * ω ^ (i * l))) = 0"
proof-
fix j
assume j_assms:"j < i"
(* Generic regrouping of powers: split b^i into b^j * b^(i-j) and pair b^j with a^j. *)
hence 00:"⋀ (c::('a::prime_card) mod_ring) a b. c * a^j*b^i = (a*b)^j*(c * b^(i-j))"
using algebra_simps
by (smt (z3) le_less ordered_cancel_comm_monoid_diff_class.add_diff_inverse power_add)
have 01:" (∑l = 0..<n. (numbers ! j * μ ^ (l * j) * ω ^ (i * l))) =
(∑l = 0..<n. (numbers ! j * (ω^l)^(i-j))) "
apply(rule sum_eq)
using mu_properties(1) 00 algebra_simps(23)
by (smt (z3) mult.commute mult.left_neutral power_mult power_one)
(* Pull the constant factor numbers ! j out of the sum. *)
(* NOTE(review): the sum_in instantiation below mentions μ although the sum at hand
   is over powers of ω - looks like a leftover from ntt_correct; confirm whether
   this fact is needed at all. *)
moreover have 02: "…= numbers ! j *(∑l = 0..<n. ((ω^l)^(i-j))) "
using sum_in[of "λ l. numbers ! j * (μ ^ l) ^ (i - j)" " numbers ! j" n]
by (simp add: mult_hom.hom_sum)
(* Swap the two exponents so that the sum is a geometric sum in ω^(i-j). *)
moreover have 03:"(∑l = 0..<n. ((ω^l)^(i-j))) =
(∑l = 0..<n. ((ω^(i-j))^l)) "
by(rule sum_eq) (metis mult.commute power_mult)
(* The ratio ω^(i-j) is not 1; this is the side condition used with geo_sum below. *)
have "ω^(i-j) ≠ 1"
proof
assume "ω ^ (i - j) = 1"
hence "ord p (to_int_mod_ring ω) ≤ i-j"
by (simp add: j_assms not_le ord_max)
moreover have "ord p (to_int_mod_ring ω) = n" using omega_properties n_lst2 unfolding ord_def
by (meson ‹ω ^ (i - j) = 1› diff_is_0_eq diff_le_self i_assms j_assms leD le_trans)
ultimately show False
by (metis i_assms leD less_imp_diff_less)
qed
(* Apply the geometric sum lemma and use ω^n = 1. *)
ultimately have "(1-ω^(i-j))*(∑l = 0..<n. ((ω^(i-j))^l)) = (1-(ω^(i-j))^n)"
using geo_sum[of "ω ^ (i - j)" n] by simp
moreover have "(ω^(i-j))^n = 1"
by (metis (no_types) mult.commute omega_properties(1) power_mult power_one)
ultimately have "(∑l = 0..<n. ((ω^(i-j))^l)) = 0"
by (metis ‹ω ^ (i - j) ≠ 1› divisors_zero eq_iff_diff_eq_0)
thus "(∑l = 0..<n. numbers ! j * μ ^ (l * j) * ω ^ (i * l)) = 0" using 01 02 03 by simp
qed
(* Case i < j: a geometric sum in μ^(j-i) remains, which is 0. *)
have ilsj:"⋀ j. i < j ∧ j < n ⟹ (∑l = 0..<n. (numbers ! j * μ ^ (l * j) * ω ^ (i * l))) = 0"
proof-
fix j
assume ij_Assm: "i < j ∧ j < n"
(* Generic regrouping of powers: split b^j into b^i * b^(j-i) and pair b^i with a^i. *)
hence 00:"⋀ (c::('a::prime_card) mod_ring) a b. (a*b)^i*(c * b^(j-i)) = c * a^i*b^j "
by (simp add: field_simps flip: power_add)
have 01:" (∑l = 0..<n. (numbers ! j * μ ^ (l * j) * ω ^ (i * l))) =
(∑l = 0..<n. (numbers ! j * (μ^l)^(j-i))) "
apply(rule sum_eq) subgoal for l
using mu_properties(1) 00[of "ω^l" "μ^l" "numbers ! j "] algebra_simps(23)
by (smt (z3) "00" left_right_inverse_power mult.assoc mult.commute mult.right_neutral power_mult)
done
(* Pull the constant factor numbers ! j out of the sum. *)
moreover have 02:"(∑l = 0..<n. (numbers ! j * (μ^l)^(j-i))) =
numbers ! j *(∑l = 0..<n. ( (μ^l)^(j-i))) "
by (simp add: mult_hom.hom_sum)
(* Swap the two exponents so that the sum is a geometric sum in μ^(j-i). *)
moreover have 03:"(∑l = 0..<n. ( (μ^l)^(j-i))) =
(∑l = 0..<n. (( (μ^(j-i))^l))) "
by(rule sum_eq) (metis mult.commute power_mult)
(* The ratio μ^(j-i) is not 1; this is the side condition used with geo_sum below. *)
have "μ^(j-i) ≠ 1"
proof
assume "μ ^ (j - i) = 1"
hence "ord p (to_int_mod_ring μ) ≤ j -i "
by (simp add: ij_Assm not_le ord_max)
moreover hence "ord p (to_int_mod_ring ω) ≤ j-i"
by (metis ‹μ ^ (j - i) = 1› diff_is_0_eq exp_rule ij_Assm leD mult.comm_neutral mult.commute mu_properties(1) ord_max)
moreover hence "j-i < n" using ij_Assm i_assms p_fact k_bound n_lst2 by linarith
moreover have "ord p (to_int_mod_ring ω) = n" using omega_properties n_lst2 unfolding ord_def
by (metis (no_types) ‹μ ^ (j-i) = 1› calculation(3) diff_is_0_eq ij_Assm leD left_right_inverse_power mult.comm_neutral mult_cancel_left mu_properties(1) omega_properties(3) zero_neq_one)
ultimately show False by simp
qed
(* Apply the geometric sum lemma and use μ^n = 1. *)
ultimately have "(1-μ^(j-i))* (∑l = 0..<n. ((μ^(j-i))^l)) = (1-(μ^(j-i))^n)"
using geo_sum[of "μ ^ (j-i)" n] by simp
moreover have "(μ^(j-i))^n = 1"
by (metis (no_types) left_right_inverse_power mult.commute mult.right_neutral mu_properties(1) omega_properties(1) power_mult power_one)
ultimately have "(∑l = 0..<n. ((μ^(j-i))^l)) = 0"
by (metis ‹μ ^ (j - i) ≠ 1› eq_iff_diff_eq_0 no_zero_divisors)
thus "(∑l = 0..<n. numbers ! j * μ ^ (l * j) * ω ^ (i * l)) = 0" using 01 02 03 by simp
qed
(* Compose the three index ranges j < i, j = i and i < j into the full sum over j. *)
have " (∑j = 0..<i. ∑l = 0..<n. numbers ! j * μ ^ (l * j) * ω ^ (i * l)) = 0" using jlsi by simp
moreover have " (∑j = i..<i+1. ∑l = 0..<n. numbers ! j * μ ^ (l * j) * ω ^ (i * l)) = numbers ! i * (of_int_mod_ring n)" using iisj by simp
moreover have " (∑j = (i+1)..<n. ∑l = 0..<n. numbers ! j * μ ^ (l * j) * ω ^ (i * l)) = 0" using ilsj by simp
ultimately have " (∑j = 0..<n. ∑l = 0..<n. numbers ! j * μ ^ (l * j) * ω ^ (i * l)) =
numbers ! i * (of_int_mod_ring n)" using i_assms sum_split
by (smt (z3) add.commute add.left_neutral int_ops(2) less_imp_of_nat_less of_nat_add of_nat_eq_iff of_nat_less_imp_less)
(* Chain the rewriting steps 1, 2 and 3 with the evaluated double sum. *)
thus "ntt (INTT numbers) i = of_int_mod_ring (int n) * numbers ! i" using 1 2 3
by (metis mult.commute)
qed
(* Entry i of the right-hand side list. *)
have 2: "⋀ i. i < n ⟹ (map ((*) (of_int_mod_ring (int n))) numbers ) ! i = (of_int_mod_ring (int n)) * (numbers ! i)"
by (simp add: n_def)
(* List equality via nth_equalityI: first equal lengths, then equal entries. *)
show ?thesis
apply(rule nth_equalityI)
subgoal my_little_subgoal
unfolding INTT_def NTT_def
apply (simp add: n_def)
done
subgoal for i
using 0 1 2 n_def algebra_simps my_little_subgoal length_map
apply auto
done
done
qed
end
end