(* Theory FNTT_Rings *)
subsection "Fast Number Theoretic Transforms in Rings"
theory FNTT_Rings
imports NTT_Rings "Number_Theoretic_Transform.Butterfly"
begin
context cring begin
text "The following lemma is the essence of Fast Number Theoretic Transforms (FNTTs)."
(* Splitting lemma behind the FNTT. Let n be even, μ a primitive n-th root of unity and
   a a list of length n over the carrier. Then entry j < n of NTT μ a is expressed
   through two NTTs with respect to μ [^] 2: one of the even-indexed entries of a and
   one of the odd-indexed entries of a, both evaluated at index j'.
   Here j' = j in the first half and j' = j - n div 2 in the second half.
   The first two conclusions record the relationship between j and j'. *)
lemma NTT_recursion:
assumes "even n"
assumes "primitive_root n μ"
assumes[simp]: "length a = n"
assumes[simp]: "j < n"
assumes[simp]: "set a ⊆ carrier R"
defines "j' ≡ (if j < n div 2 then j else j - n div 2)"
shows "j' < n div 2" "j = (if j < n div 2 then j' else j' + n div 2)"
and "(NTT μ a) ! j = (NTT (μ [^] (2::nat)) [a ! i. i ← filter even [0..<n]]) ! j'
⊕ μ [^] j ⊗ (NTT (μ [^] (2::nat)) [a ! i. i ← filter odd [0..<n]]) ! j'"
proof -
from assms have "n > 0" by linarith
have[simp]: "μ ∈ carrier R" using ‹primitive_root n μ› unfolding primitive_root_def root_of_unity_def by blast
then have μ_pow_carrier[simp]: "μ [^] i ∈ carrier R" for i :: nat by simp
show "j' < n div 2" unfolding j'_def using ‹j < n› ‹even n› by fastforce
show j'_alt: "j = (if j < n div 2 then j' else j' + n div 2)"
unfolding j'_def by simp
(* The even- and odd-indexed entries of a lie in the carrier. *)
have a_even_carrier[simp]: "a ! (2 * i) ∈ carrier R" if "i < n div 2" for i
using set_subseteqD[OF ‹set a ⊆ carrier R›] assms that by simp
have a_odd_carrier[simp]: "a ! (2 * i + 1) ∈ carrier R" if "i < n div 2" for i
using set_subseteqD[OF ‹set a ⊆ carrier R›] assms that by simp
(* Twiddle factors of the even summands, rewritten as powers of μ [^] 2. j may be
   replaced by j' in the exponent, because adding n to the exponent does not change
   the power of μ (root_of_unity_powers_nat). *)
have μ_pow: "μ [^] (j * (2 * i)) = (μ [^] (2::nat)) [^] (j' * i)" for i
proof -
have "μ [^] (j * (2 * i)) = (μ [^] (j * 2)) [^] i"
using mult.assoc nat_pow_pow[symmetric] by simp
(* Reduce j to j' in the exponent j * 2. *)
also have "μ [^] (j * 2) = μ [^] (j' * 2)"
proof (cases "j < n div 2")
case True
then show ?thesis unfolding j'_def by simp
next
case False
then have "μ [^] (j * 2) = μ [^] (j' * 2 + n)"
using j'_alt by (simp add: ‹even n›)
also have "... = μ [^] (j' * 2)"
using ‹n > 0› ‹primitive_root n μ›
by (intro root_of_unity_powers_nat[of n]) auto
finally show ?thesis .
qed
finally show ?thesis unfolding nat_pow_pow[OF ‹μ ∈ carrier R›]
by (simp add: mult.assoc mult.commute)
qed
(* Main calculation: write the NTT entry as a sum and split it into even and odd summands. *)
have "(NTT μ a) ! j = (⨁i ← [0..<n]. a ! i ⊗ (μ [^] (j * i)))"
using NTT_nth_2[of a n j μ] by simp
also have "... = (⨁i ← [0..<n div 2]. a ! (2 * i) ⊗ (μ [^] (j * (2 * i))))
⊕ (⨁i ← [0..<n div 2]. a ! (2 * i + 1) ⊗ (μ [^] (j * (2 * i + 1))))"
using ‹even n›
by (intro monoid_sum_list_even_odd_split m_closed nat_pow_closed set_subseteqD) simp_all
(* Pull the common factor μ [^] j out of every odd summand. *)
also have "(⨁i ← [0..<n div 2]. a ! (2 * i + 1) ⊗ (μ [^] (j * (2 * i + 1))))
= (⨁i ← [0..<n div 2]. μ [^] j ⊗ (a ! (2 * i + 1) ⊗ (μ [^] (j * (2 * i)))))"
proof (intro monoid_sum_list_cong)
fix i
assume "i ∈ set [0..<n div 2]"
then have[simp]: "i < n div 2" by simp
have "a ! (2 * i + 1) ⊗ (μ [^] (j * (2 * i + 1))) =
a ! (2 * i + 1) ⊗ (μ [^] (j * (2 * i)) ⊗ μ [^] j)"
unfolding distrib_left mult_1_right
unfolding nat_pow_mult[symmetric, OF ‹μ ∈ carrier R›]
by (rule refl)
also have "... = (a ! (2 * i + 1) ⊗ μ [^] (j * (2 * i))) ⊗ μ [^] j"
using a_odd_carrier[OF ‹i < n div 2›]
by (intro m_assoc[symmetric]; simp)
also have "... = μ [^] j ⊗ (a ! (2 * i + 1) ⊗ μ [^] (j * (2 * i)))"
using a_odd_carrier[OF ‹i < n div 2›]
by (intro m_comm; simp)
finally show "a ! (2 * i + 1) ⊗ μ [^] (j * (2 * i + 1)) = ..." .
qed
also have "... = μ [^] j ⊗ (⨁i ← [0..<n div 2]. a ! (2 * i + 1) ⊗ (μ [^] (j * (2 * i))))"
using a_odd_carrier by (intro monoid_sum_list_in_left; simp)
(* Close the chain and rewrite the twiddle factors of both sums with μ_pow. *)
finally have "(NTT μ a) ! j = (⨁i ← [0..<n div 2]. a ! (2 * i) ⊗ (μ [^] (2::nat)) [^] (j' * i))
⊕ μ [^] j ⊗ (⨁i ← [0..<n div 2]. a ! (2 * i + 1) ⊗ (μ [^] (2::nat)) [^] (j' * i))"
unfolding μ_pow .
(* Express the summands via the even- and odd-indexed sublists of a. *)
also have "... = (⨁i ← [0..<n div 2]. [a ! i. i ← filter even [0..<n]] ! i ⊗ (μ [^] (2::nat)) [^] (j' * i))
⊕ μ [^] j ⊗ (⨁i ← [0..<n div 2]. [a ! i. i ← filter odd [0..<n]] ! i ⊗ (μ [^] (2::nat)) [^] (j' * i))"
by (intro_cong "[cong_tag_2 (⊕), cong_tag_2 (⊗)]" more: monoid_sum_list_cong)
(simp_all add: filter_even_nth length_filter_even length_filter_odd filter_odd_nth)
(* Recognize both sums as entries at index j' of NTTs of length n div 2. *)
also have "... = (NTT (μ [^] (2::nat)) [a ! i. i ← filter even [0..<n]]) ! j'
⊕ μ [^] j ⊗ (NTT (μ [^] (2::nat)) [a ! i. i ← filter odd [0..<n]]) ! j'"
by (intro_cong "[cong_tag_2 (⊕), cong_tag_2 (⊗)]" more: NTT_nth_2[symmetric])
(simp_all add: length_filter_even length_filter_odd ‹even n› ‹j' < n div 2›)
finally show "(NTT μ a) ! j = ..." .
qed
(* Special case of NTT_recursion for an index in the first half, j < n div 2, where j' = j.
   The odd part is added with twiddle factor μ [^] j. *)
lemma NTT_recursion_1:
assumes "even n"
assumes "primitive_root n μ"
assumes[simp]: "length a = n"
assumes[simp]: "j < n div 2"
assumes[simp]: "set a ⊆ carrier R"
shows "(NTT μ a) ! j =
(NTT (μ [^] (2::nat)) [a ! i. i ← filter even [0..<n]]) ! j
⊕ μ [^] j ⊗ (NTT (μ [^] (2::nat)) [a ! i. i ← filter odd [0..<n]]) ! j"
proof -
have "j < n" using ‹j < n div 2› by linarith
(* In the first half the auxiliary index j' of NTT_recursion coincides with j. *)
show ?thesis
using NTT_recursion[OF ‹even n› ‹primitive_root n μ› ‹length a = n› ‹j < n› ‹set a ⊆ carrier R›]
using ‹j < n div 2› by presburger
qed
(* Special case of NTT_recursion for an index n div 2 + j in the second half.
   It needs the extra halfway assumption μ [^] (n div 2) = ⊖ 𝟭. Under it, the twiddle
   factor μ [^] (n div 2 + j) equals ⊖ μ [^] j, so the odd part is subtracted instead
   of added. *)
lemma NTT_recursion_2:
assumes "even n"
assumes "primitive_root n μ"
assumes[simp]: "length a = n"
assumes[simp]: "j < n div 2"
assumes[simp]: "set a ⊆ carrier R"
assumes halfway_property: "μ [^] (n div 2) = ⊖ 𝟭"
shows "(NTT μ a) ! (n div 2 + j) =
(NTT (μ [^] (2::nat)) [a ! i. i ← filter even [0..<n]]) ! j
⊖ μ [^] j ⊗ (NTT (μ [^] (2::nat)) [a ! i. i ← filter odd [0..<n]]) ! j"
proof -
from assms have "μ ∈ carrier R" unfolding primitive_root_def root_of_unity_def by simp
then have carrier_1: "μ [^] j ∈ carrier R"
by simp
(* The j-th entry of the NTT of the odd-indexed sublist lies in the carrier. *)
have carrier_2: "NTT (μ [^] (2::nat)) (map ((!) a) (filter odd [0..<n])) ! j ∈ carrier R"
apply (intro NTT_nth_closed[where n = "n div 2"])
subgoal using ‹set a ⊆ carrier R› ‹length a = n› by fastforce
subgoal using ‹μ ∈ carrier R› by simp
subgoal by (simp add: length_filter_odd)
subgoal using ‹j < n div 2› .
done
(* Apply NTT_recursion at index n div 2 + j; its auxiliary index j' is j here. *)
have "n div 2 + j < n" using ‹j < n div 2› ‹even n› by linarith
then have "(NTT μ a) ! (n div 2 + j) =
(NTT (μ [^] (2::nat)) [a ! i. i ← filter even [0..<n]]) ! j
⊕ μ [^] (n div 2 + j) ⊗ (NTT (μ [^] (2::nat)) [a ! i. i ← filter odd [0..<n]]) ! j"
using NTT_recursion[OF ‹even n› ‹primitive_root n μ› ‹length a = n› ‹n div 2 + j < n› ‹set a ⊆ carrier R›]
by simp
(* Use the halfway property to rewrite the twiddle factor as ⊖ μ [^] j. *)
also have "μ [^] (n div 2 + j) = ⊖ (μ [^] j)"
unfolding nat_pow_mult[symmetric, OF ‹μ ∈ carrier R›] halfway_property
by (intro minus_eq_mult_one[symmetric]; simp add: ‹μ ∈ carrier R›)
finally show ?thesis unfolding minus_eq l_minus[OF carrier_1 carrier_2] .
qed
(* Difference of the two butterfly outputs at positions j and n div 2 + j.
   The contributions of the even-indexed sublist cancel, and the odd contribution
   μ [^] j ⊗ ntt2 appears twice. The factor two is written as nat_embedding 2, which the
   proof relates to 𝟭 ⊕ 𝟭. *)
lemma NTT_diffs:
assumes "even n"
assumes "primitive_root n μ"
assumes "length a = n"
assumes "j < n div 2"
assumes "set a ⊆ carrier R"
assumes "μ [^] (n div 2) = ⊖ 𝟭"
shows "NTT μ a ! j ⊖ NTT μ a ! (n div 2 + j) = nat_embedding 2 ⊗ (μ [^] j ⊗ NTT (μ [^] (2::nat)) (map ((!) a) (filter odd [0..<n])) ! j)"
proof -
have[simp]: "μ ∈ carrier R" using ‹primitive_root n μ› unfolding primitive_root_def root_of_unity_def by blast
(* Abbreviations for the j-th entries of the NTTs of the even- and odd-indexed sublists
   of a, together with their carrier membership. *)
define ntt1 where "ntt1 = NTT (μ [^] (2::nat)) (map ((!) a) (filter even [0..<n])) ! j"
have "ntt1 ∈ carrier R" unfolding ntt1_def
apply (intro set_subseteqD[OF NTT_closed] set_subseteqI)
subgoal for i
using set_subseteqD[OF ‹set a ⊆ carrier R›]
by (simp add: filter_even_nth ‹length a = n› ‹even n› length_filter_even)
subgoal by simp
subgoal using assms by (simp add: length_filter_even ‹even n›)
done
define ntt2 where "ntt2 = NTT (μ [^] (2::nat)) (map ((!) a) (filter odd [0..<n])) ! j"
have "ntt2 ∈ carrier R" unfolding ntt2_def
apply (intro set_subseteqD[OF NTT_closed] set_subseteqI)
subgoal for i
using set_subseteqD[OF ‹set a ⊆ carrier R›]
by (simp add: filter_odd_nth ‹length a = n› ‹even n› length_filter_odd)
subgoal by simp
subgoal using assms by (simp add: length_filter_odd ‹even n›)
done
(* Express both outputs with NTT_recursion_1 and NTT_recursion_2, then simplify
   algebraically: the ntt1 terms cancel. *)
have "NTT μ a ! j ⊖ NTT μ a ! (n div 2 + j) =
(ntt1 ⊕ μ [^] j ⊗ ntt2) ⊖ (ntt1 ⊖ μ [^] j ⊗ ntt2)"
apply (intro arg_cong2[where f = "λi j. i ⊖ j"])
unfolding ntt1_def ntt2_def
subgoal by (intro NTT_recursion_1 assms)
subgoal by (intro NTT_recursion_2 assms)
done
also have "... = μ [^] j ⊗ (ntt2 ⊕ ntt2)"
using ‹ntt1 ∈ carrier R› ‹ntt2 ∈ carrier R› nat_pow_closed[OF ‹μ ∈ carrier R›]
by algebra
(* Write the doubling as multiplication by nat_embedding 2 and reorder the factors. *)
also have "... = μ [^] j ⊗ ((𝟭 ⊕ 𝟭) ⊗ ntt2)"
using ‹ntt2 ∈ carrier R› one_closed by algebra
also have "... = μ [^] j ⊗ (nat_embedding 2 ⊗ ntt2)"
by (simp add: numeral_2_eq_2)
also have "... = nat_embedding 2 ⊗ (μ [^] j ⊗ ntt2)"
using nat_pow_closed[OF ‹μ ∈ carrier R›] ‹ntt2 ∈ carrier R› nat_embedding_closed
by algebra
finally show ?thesis unfolding ntt2_def .
qed
text "The following algorithm is adapted from @{theory Number_Theoretic_Transform.Butterfly}"
(* Auxiliary length bound: filtering [0..<l] yields at most l elements.
   It is declared [simp] here, presumably so that the automatic termination proof of
   FNTT below can use it. The simp declaration is removed again right after the
   definition of FNTT. *)
lemma FNTT_term_aux[simp]: "length (filter P [0..<l]) < Suc l"
by (metis diff_zero le_imp_less_Suc length_filter_le length_upt)
(* Recursive radix-2 fast NTT.
   - Lists of length 0, 1 and 2 are handled directly.
   - Otherwise the entries at even and odd indices are transformed recursively with
     μ [^] 2, giving b and c.
   - These are combined into the two halves
       g = [b ! i ⊕ μ [^] i ⊗ c ! i. i < n div 2] and
       h = [b ! i ⊖ μ [^] i ⊗ c ! i. i < n div 2].
   Agreement with NTT is proved below (FNTT_NTT) only for inputs whose length is a
   power of two. *)
fun FNTT :: "'a ⇒ 'a list ⇒ 'a list" where
"FNTT μ [] = []"
| "FNTT μ [x] = [x]"
| "FNTT μ [x, y] = [x ⊕ y, x ⊖ y]"
(* Recursive case: with sequential pattern matching this equation only applies to lists
   with at least three elements. *)
| "FNTT μ a = (let n = length a;
nums1 = [a!i. i ← filter even [0..<n]];
nums2 = [a!i. i ← filter odd [0..<n]];
b = FNTT (μ [^] (2::nat)) nums1;
c = FNTT (μ [^] (2::nat)) nums2;
g = [b!i ⊕ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]];
h = [b!i ⊖ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]]
in g@h)"
(* Remove the termination helper and the defining equations of FNTT from the simp set.
   The proofs below add or unfold FNTT.simps explicitly where they need it. *)
lemmas [simp del] = FNTT_term_aux
declare FNTT.simps[simp del]
(* FNTT preserves the length of inputs whose length is a power of two.
   The power-of-two assumption makes the length n of every input with at least three
   elements even. Hence the two result halves g and h, each of length n div 2, add up
   to n. *)
lemma length_FNTT[simp]:
assumes "length a = 2 ^ k"
shows "length (FNTT μ a) = length a"
using assms
proof (induction rule: FNTT.induct)
case (1 μ)
then show ?case by simp
next
case (2 μ x)
then show ?case by (simp add: FNTT.simps)
next
case (3 μ x y)
then show ?case by (simp add: FNTT.simps)
next
(* Recursive case: n is a power of two greater than one, hence even. *)
case (4 μ a1 a2 a3 as)
define a where "a = a1 # a2 # a3 # as"
define n where "n = length a"
with a_def have "even n" using 4(3)
by (cases "k = 0") simp_all
(* Names for the intermediate results of the recursive FNTT equation. *)
define nums1 where "nums1 = [a!i. i ← filter even [0..<n]]"
define nums2 where "nums2 = [a!i. i ← filter odd [0..<n]]"
define b where "b = FNTT (μ [^] (2::nat)) nums1"
define c where "c = FNTT (μ [^] (2::nat)) nums2"
define g where "g = [b!i ⊕ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]]"
define h where "h = [b!i ⊖ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]]"
note defs = a_def n_def nums1_def nums2_def b_def c_def g_def h_def
have "length (FNTT μ a) = length g + length h"
using defs by (simp add: Let_def FNTT.simps)
also have "... = (n div 2) + (n div 2)" unfolding g_def h_def by simp
also have "... = n" using ‹even n› by fastforce
finally show ?case by (simp only: a_def n_def)
qed
(* Main correctness theorem: FNTT computes the same result as NTT, assuming
   - the input length n is a power of two,
   - μ is a primitive n-th root of unity with μ [^] (n div 2) = ⊖ 𝟭, and
   - all input entries lie in the carrier.
   The proof is by induction along FNTT.induct. The recursive calls are justified with
   primitive_root_recursion, together with the fact that μ [^] 2 again satisfies the
   halfway property for n div 2. *)
theorem FNTT_NTT:
assumes[simp]: "μ ∈ carrier R"
assumes "n = 2 ^ k"
assumes "primitive_root n μ"
assumes halfway_property: "μ [^] (n div 2) = ⊖ 𝟭"
assumes[simp]: "length a = n"
assumes "set a ⊆ carrier R"
shows "FNTT μ a = NTT μ a"
using assms
proof (induction μ a arbitrary: n k rule: FNTT.induct)
(* Empty input. *)
case (1 μ)
then show ?case unfolding NTT_def by simp
next
(* One element: n = 1, hence k = 0. *)
case (2 μ x)
then have "n = 1" by simp
then have "k = 0" using ‹n = 2 ^ k› by simp
moreover have "x ∈ carrier R" using 2 by simp
ultimately show ?case unfolding NTT_def by (simp add: Let_def FNTT.simps)
next
(* Two elements: n = 2, so the halfway property forces μ = ⊖ 𝟭, and NTT reduces to
   the butterfly [x ⊕ y, x ⊖ y]. *)
case (3 μ x y)
then have[simp]: "x ∈ carrier R" "y ∈ carrier R" by simp_all
from 3 have "n = 2" by simp
with ‹μ [^] (n div 2) = ⊖ 𝟭› have "μ [^] (1 :: nat) = ⊖ 𝟭" by simp
then have "μ = ⊖ 𝟭" by (simp add: ‹μ ∈ carrier R›)
have "NTT μ [x, y] = [x ⊕ y, x ⊖ y]"
unfolding NTT_def
apply (simp add: Let_def 3 ‹μ = ⊖ 𝟭›)
using ‹x ∈ carrier R› ‹y ∈ carrier R› by algebra
then show ?case by (simp add: FNTT.simps)
next
(* Recursive case: the input has at least three elements. *)
case (4 μ a1 a2 a3 as)
(* Names for the intermediate results of the recursive FNTT equation. *)
define a where "a = a1 # a2 # a3 # as"
then have[simp]: "length a = n" using 4(7) by simp
define nums1 where "nums1 = [a!i. i ← filter even [0..<n]]"
define nums2 where "nums2 = [a!i. i ← filter odd [0..<n]]"
define b where "b = FNTT (μ [^] (2::nat)) nums1"
define c where "c = FNTT (μ [^] (2::nat)) nums2"
define g where "g = [b!i ⊕ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]]"
then have "length g = n div 2" by simp
define h where "h = [b!i ⊖ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]]"
then have "length h = n div 2" by simp
note defs = a_def nums1_def nums2_def b_def c_def g_def h_def
(* Since the input has more than one element, k > 0. Hence n is even and n div 2 is
   again a power of two. *)
have "k > 0"
using ‹length (a1 # a2 # a3 # as) = n› ‹n = 2 ^ k›
by (cases "k = 0") simp_all
then have "even n" "n div 2 = 2 ^ (k - 1)"
using ‹n = 2 ^ k› by (simp_all add: power_diff)
(* Unfold one step of FNTT. *)
have "FNTT μ (a1 # a2 # a3 # as) = g @ h"
unfolding FNTT.simps
using ‹length (a1 # a2 # a3 # as) = n› by (simp only: Let_def defs)
then have "FNTT μ a = g @ h" using a_def by simp
(* μ [^] 2 satisfies the halfway property for n div 2, as required by the induction
   hypotheses. *)
have recursive_halfway: "(μ [^] (2 :: nat)) [^] (n div 2 div 2) = ⊖ 𝟭"
proof -
have "n ≥ 3"
using ‹length (a1 # a2 # a3 # as) = n› by simp
then have "k ≥ 2" using ‹n = 2 ^ k› by (cases "k ∈ {0, 1}") auto
then have "even (n div 2)" using ‹n div 2 = 2 ^ (k - 1)› by fastforce
then show ?thesis
by (simp add: nat_pow_pow ‹μ ∈ carrier R› ‹μ [^] (n div 2) = ⊖ 𝟭›)
qed
(* Induction hypothesis for the even-indexed half: b is its NTT with respect to μ [^] 2. *)
have "b = NTT (μ [^] (2::nat)) nums1"
unfolding b_def
apply (intro 4(1)[of n nums1 nums2 "n div 2" "k - 1"])
subgoal using ‹length (a1 # a2 # a3 # as) = n› by simp
subgoal using nums1_def a_def by simp
subgoal using nums2_def a_def by simp
subgoal using ‹μ ∈ carrier R› by simp
subgoal using ‹n div 2 = 2 ^ (k - 1)› .
subgoal using primitive_root_recursion ‹even n› ‹primitive_root n μ› by blast
subgoal using recursive_halfway .
subgoal using nums1_def length_filter_even ‹even n› by simp
subgoal
unfolding nums1_def
apply (intro set_subseteqI)
using set_subseteqD[OF ‹set (a1 # a2 # a3 # as) ⊆ carrier R›]
by (simp add: a_def[symmetric] filter_even_nth length_filter_even ‹even n›)
done
(* Induction hypothesis for the odd-indexed half: c is its NTT with respect to μ [^] 2. *)
have "c = NTT (μ [^] (2::nat)) nums2"
unfolding c_def
apply (intro 4(2)[of n nums1 nums2 b "n div 2" "k - 1"])
subgoal using ‹length (a1 # a2 # a3 # as) = n› by simp
subgoal unfolding nums1_def a_def by simp
subgoal unfolding nums2_def a_def by simp
subgoal using b_def .
subgoal using ‹μ ∈ carrier R› by simp
subgoal using ‹n div 2 = 2 ^ (k - 1)› .
subgoal using primitive_root_recursion ‹even n› ‹primitive_root n μ› by blast
subgoal using recursive_halfway .
subgoal unfolding nums2_def using length_filter_odd by simp
subgoal
unfolding nums2_def
apply (intro set_subseteqI)
using set_subseteqD[OF ‹set (a1 # a2 # a3 # as) ⊆ carrier R›]
by (simp add: a_def[symmetric] filter_odd_nth length_filter_odd)
done
(* Compare FNTT and NTT entry by entry. *)
show ?case
proof (intro nth_equalityI)
have[simp]: "length (FNTT μ (a1 # a2 # a3 # as)) = n"
using ‹length (a1 # a2 # a3 # as) = n› ‹n = 2 ^ k› length_FNTT[of "a1 # a2 # a3 # as"]
by blast
then show "length (FNTT μ (a1 # a2 # a3 # as)) = length (NTT μ (a1 # a2 # a3 # as))"
using NTT_length[of μ "a1 # a2 # a3 # as"] ‹length (a1 # a2 # a3 # as) = n› by argo
fix i
assume "i < length (FNTT μ (a1 # a2 # a3 # as))"
then have "i < n" by simp
(* First half: NTT_recursion_1 together with the definition of g.
   Second half: NTT_recursion_2 together with the definition of h. *)
have "FNTT μ a ! i = NTT μ a ! i"
proof (cases "i < n div 2")
case True
then have "NTT μ a ! i =
(NTT (μ [^] (2::nat)) [a ! i. i ← filter even [0..<n]]) ! i
⊕ μ [^] i ⊗ (NTT (μ [^] (2::nat)) [a ! i. i ← filter odd [0..<n]]) ! i"
apply (intro NTT_recursion_1)
using True ‹even n› ‹primitive_root n μ› ‹set (a1 # a2 # a3 # as) ⊆ carrier R› a_def
using ‹μ ∈ carrier R› ‹length (a1 # a2 # a3 # as) = n›
by simp_all
also have "... = (NTT (μ [^] (2::nat)) nums1) ! i
⊕ μ [^] i ⊗ (NTT (μ [^] (2::nat)) nums2) ! i"
unfolding nums1_def nums2_def by blast
also have "... = b ! i ⊕ μ [^] i ⊗ c ! i"
using ‹b = NTT (μ [^] 2) nums1› ‹c = NTT (μ [^] 2) nums2› by blast
also have "... = g ! i"
unfolding g_def using True by simp
also have "... = FNTT μ a ! i"
using ‹FNTT μ a = g @ h› ‹length g = n div 2› True
by (simp add: nth_append)
finally show ?thesis by simp
next
case False
then obtain j where j_def: "i = n div 2 + j" "j < n div 2"
using ‹i < n› ‹even n›
by (metis add_diff_inverse_nat add_self_div_2 div_plus_div_distrib_dvd_right nat_add_left_cancel_less)
have "NTT μ a ! (n div 2 + j) =
(NTT (μ [^] (2::nat)) [a ! i. i ← filter even [0..<n]]) ! j
⊖ μ [^] j ⊗ (NTT (μ [^] (2::nat)) [a ! i. i ← filter odd [0..<n]]) ! j"
apply (intro NTT_recursion_2)
subgoal using ‹even n› .
subgoal using ‹primitive_root n μ› .
subgoal using ‹length (a1 # a2 # a3 # as) = n› a_def by simp
subgoal using j_def by simp
subgoal using ‹set (a1 # a2 # a3 # as) ⊆ carrier R› a_def by simp
subgoal using ‹μ [^] (n div 2) = ⊖ 𝟭› .
done
also have "... = (NTT (μ [^] (2::nat)) nums1) ! j
⊖ μ [^] j ⊗ (NTT (μ [^] (2::nat)) nums2) ! j"
unfolding nums1_def nums2_def by blast
also have "... = b ! j ⊖ μ [^] j ⊗ c ! j"
using ‹b = NTT (μ [^] 2) nums1› ‹c = NTT (μ [^] 2) nums2› by blast
also have "... = h ! j"
unfolding g_def h_def using j_def by simp
also have "... = FNTT μ a ! i"
using ‹FNTT μ a = g @ h› ‹length g = n div 2› j_def
by (simp add: nth_append)
finally show ?thesis using j_def by simp
qed
then show "FNTT μ (a1 # a2 # a3 # as) ! i = NTT μ (a1 # a2 # a3 # as) ! i"
using a_def by simp
qed
qed
end
text "The following is copied from @{theory Number_Theoretic_Transform.Butterfly} and moved outside
of the @{locale butterfly} locale."
(* evens_odds True xs collects the entries of xs at even positions (0, 2, 4, ...).
   evens_odds False xs collects those at odd positions.
   The Boolean flag records whether the current head is at an even position.
   See filter_comprehension_evens_odds for the connection to index-based filtering. *)
fun evens_odds where
"evens_odds _ [] = []"
| "evens_odds True (x#xs)= (x # evens_odds False xs)"
| "evens_odds False (x#xs) = evens_odds True xs"
(* Mapping f over the even indices below Suc g gives f 0 followed by the shifted
   function over the odd indices below g. *)
lemma map_filter_shift: " map f (filter even [0..<Suc g]) =
f 0 # map (λ x. f (x+1)) (filter odd [0..<g])"
by (induction g) auto
(* Mapping f over the odd indices below Suc g equals mapping the shifted function over
   the even indices below g. *)
lemma map_filter_shift': " map f (filter odd [0..<Suc g]) =
map (λ x. f (x+1)) (filter even [0..<g])"
by (induction g) auto
(* evens_odds agrees with the index-based list comprehensions over even and odd
   positions used in FNTT. Both conjuncts are proved simultaneously by induction on xs.
   For x # xs, the even case reduces via map_filter_shift to the odd case for xs, and
   the odd case reduces via map_filter_shift' to the even case for xs. *)
lemma filter_comprehension_evens_odds:
"[xs ! i. i ← filter even [0..<length xs]] = evens_odds True xs ∧
[xs ! i. i ← filter odd [0..<length xs]] = evens_odds False xs "
apply(induction xs)
(* Nil case *)
apply simp
subgoal for x xs
apply rule
(* Even positions of x # xs *)
subgoal
apply(subst evens_odds.simps)
apply(rule trans[of _ "map ((!) (x # xs)) (filter even [0..<Suc (length xs)])"])
subgoal by simp
apply(rule trans[OF map_filter_shift[of "(!) (x # xs)" "length xs"]])
apply simp
done
(* Odd positions of x # xs *)
apply(subst evens_odds.simps)
apply(rule trans[of _ "map ((!) (x # xs)) (filter odd [0..<Suc (length xs)])"])
subgoal by simp
apply(rule trans[OF map_filter_shift'[of "(!) (x # xs)" "length xs"]])
apply simp
done
done
(* Both halves produced by evens_odds are shorter than Suc (length xs).
   Declared [simp], presumably so that the automatic termination proof of FNTT' below
   can use it. *)
lemma FNTT'_termination_aux[simp]: "length (evens_odds True xs) < Suc (length xs)"
"length (evens_odds False xs) < Suc (length xs)"
by (metis filter_comprehension_evens_odds le_imp_less_Suc length_filter_le length_map map_nth)+
text "(End of copy)"
(* evens_odds commutes with map. *)
lemma map_evens_odds: "map f (evens_odds x a) = evens_odds x (map f a)"
by (induction x a rule: evens_odds.induct) simp_all
(* Lengths of the two halves. The even-position half has length a div 2 entries, plus
   one more if length a is odd. The odd-position half always has length a div 2. *)
lemma length_evens_odds:
"length (evens_odds True a) = (if even (length a) then length a div 2 else length a div 2 + 1)"
"length (evens_odds False a) = length a div 2"
using filter_comprehension_evens_odds[of a] length_filter_even[of "length a"] length_filter_odd[of "length a"]
using length_map by (metis, metis)
(* Both halves contain only elements of the original list. *)
lemma set_evens_odds:
"set (evens_odds x a) ⊆ set a"
by (induction x a rule: evens_odds.induct) fastforce+
context cring begin
text "As in @{theory Number_Theoretic_Transform.Butterfly}, we give an abstract algorithm that can
more easily be refined to a verifiably efficient FNTT algorithm."
(* Variant of FNTT that splits the input with evens_odds instead of index-based list
   comprehensions. The recursion and the combination step are otherwise the same as in
   FNTT; FNTT'_FNTT below shows that both functions coincide. *)
fun FNTT' :: "'a ⇒ 'a list ⇒ 'a list" where
"FNTT' μ [] = []"
| "FNTT' μ [x] = [x]"
| "FNTT' μ [x, y] = [x ⊕ y, x ⊖ y]"
(* Recursive case: by sequential pattern matching, only lists with at least three elements. *)
| "FNTT' μ a = (let n = length a;
nums1 = evens_odds True a;
nums2 = evens_odds False a;
b = FNTT' (μ [^] (2::nat)) nums1;
c = FNTT' (μ [^] (2::nat)) nums2;
g = [b!i ⊕ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]];
h = [b!i ⊖ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]]
in g@h)"
(* FNTT' and FNTT coincide on all inputs; no length assumption is needed.
   In the recursive case the two ways of splitting the input agree by
   filter_comprehension_evens_odds. *)
lemma FNTT'_FNTT: "FNTT' μ xs = FNTT μ xs"
apply (induction μ xs rule: FNTT'.induct)
subgoal by (simp add: FNTT.simps)
subgoal by (simp add: FNTT.simps)
subgoal by (simp add: FNTT.simps)
subgoal for μ a1 a2 a3 as
unfolding FNTT.simps FNTT'.simps Let_def
using filter_comprehension_evens_odds[of "a1 # a2 # a3 # as"] by presburger
done
(* Further refinement of FNTT'. The butterfly combination uses map2 over the list of
   twiddle factors [μ [^] i. i ← [0..<n div 2]] instead of indexing b and c with (!).
   FNTT''_FNTT' below shows that it agrees with FNTT' on inputs of power-of-two
   length. *)
fun FNTT'' :: "'a ⇒ 'a list ⇒ 'a list" where
"FNTT'' μ [] = []"
| "FNTT'' μ [x] = [x]"
| "FNTT'' μ [x, y] = [x ⊕ y, x ⊖ y]"
(* Recursive case: by sequential pattern matching, only lists with at least three elements. *)
| "FNTT'' μ a = (let n = length a;
nums1 = evens_odds True a;
nums2 = evens_odds False a;
b = FNTT'' (μ [^] (2::nat)) nums1;
c = FNTT'' (μ [^] (2::nat)) nums2;
g = map2 (⊕) b (map2 (⊗) [μ [^] i. i ← [0..<(n div 2)]] c);
h = map2 (λx y. x ⊖ y) b (map2 (⊗) [μ [^] i. i ← [0..<(n div 2)]] c)
in g@h)"
(* FNTT'' agrees with FNTT' on inputs whose length is a power of two.
   The length assumption ensures that b, c and the twiddle-factor list all have length
   n div 2. Since map2 truncates to the shorter list, this makes the map2-based halves
   g2 and h2 coincide with the index-based halves g1 and h1. *)
lemma FNTT''_FNTT':
assumes "length a = 2 ^ k"
shows "FNTT'' μ a = FNTT' μ a"
using assms
proof (induction μ a arbitrary: k rule: FNTT''.induct)
case (4 μ a1 a2 a3 as)
(* Recursive case: first establish the lengths of all intermediate lists. *)
define a where "a = a1 # a2 # a3 # as"
define n where "n = length a"
then have "n = 2 ^ k" using 4 a_def by simp
then have "k ≥ 2" using n_def a_def by (cases "k = 0"; cases "k = 1") simp_all
then have "even n" using ‹n = 2 ^ k› by simp
have "n div 2 = 2 ^ (k - 1)" using ‹n = 2 ^ k› ‹k ≥ 2› by (simp add: power_diff)
then have "even (n div 2)" using ‹k ≥ 2› by simp
define nums1 where "nums1 = evens_odds True a"
then have "length nums1 = n div 2"
using length_filter_even[of n] filter_comprehension_evens_odds[of a] n_def ‹even n›
by (metis length_map)
define nums2 where "nums2 = evens_odds False a"
then have "length nums2 = n div 2"
using length_filter_odd[of n] filter_comprehension_evens_odds[of a] n_def
by (metis length_map)
define b where "b = FNTT' (μ [^] (2::nat)) nums1"
then have "length b = n div 2"
by (simp add: FNTT'_FNTT ‹length nums1 = n div 2› ‹n div 2 = 2 ^ (k - 1)›)
define c where "c = FNTT' (μ [^] (2::nat)) nums2"
then have "length c = n div 2"
by (simp add: FNTT'_FNTT ‹length nums2 = n div 2› ‹n div 2 = 2 ^ (k - 1)›)
define g1 where "g1 = [b!i ⊕ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]]"
then have "length g1 = n div 2" by simp
define h1 where "h1 = [b!i ⊖ (μ [^] i) ⊗ c!i. i ← [0..<(n div 2)]]"
then have "length h1 = n div 2" by simp
define g2 where "g2 = map2 (⊕) b (map2 (⊗) [μ [^] i. i ← [0..<(n div 2)]] c)"
then have "length g2 = n div 2"
by (simp add: ‹length b = n div 2› ‹length c = n div 2›)
(* The index-based and map2-based combinations agree entry by entry. *)
have "g1 = g2"
apply (intro nth_equalityI)
subgoal by (simp only: ‹length g1 = n div 2› ‹length g2 = n div 2›)
subgoal for i
by (simp add: g1_def g2_def ‹length b = n div 2› ‹length c = n div 2›)
done
define h2 where "h2 = map2 (λx y. x ⊖ y) b (map2 (⊗) [μ [^] i. i ← [0..<(n div 2)]] c)"
then have "length h2 = n div 2"
by (simp add: ‹length b = n div 2› ‹length c = n div 2›)
have "h1 = h2"
apply (intro nth_equalityI)
subgoal by (simp only: ‹length h1 = n div 2› ‹length h2 = n div 2›)
subgoal for i
by (simp add: h1_def h2_def ‹length b = n div 2› ‹length c = n div 2›)
done
(* Induction hypotheses for the two halves, whose lengths are 2 ^ (k - 1). *)
have 1: "FNTT'' (μ [^] (2::nat)) nums1 = FNTT' (μ [^] (2::nat)) nums1"
apply (intro 4(1))
using a_def n_def ‹length (a1 # a2 # a3 # as) = 2 ^ k› ‹length nums1 = n div 2› ‹n div 2 = 2 ^ (k - 1)›
by (simp_all add: nums1_def)
have 2: "FNTT'' (μ [^] (2::nat)) nums2 = FNTT' (μ [^] (2::nat)) nums2"
apply (intro 4(2))
using a_def n_def ‹length (a1 # a2 # a3 # as) = 2 ^ k› ‹length nums2 = n div 2› ‹n div 2 = 2 ^ (k - 1)›
by (simp_all add: nums2_def)
(* Unfold one step of both functions and combine the facts above. *)
show ?case
apply (simp only: FNTT'.simps FNTT''.simps)
apply (simp only: Let_def 1 2 a_def[symmetric] nums1_def[symmetric] nums2_def[symmetric]
b_def[symmetric] c_def[symmetric])
using ‹h1 = h2› ‹g1 = g2› n_def g1_def h1_def g2_def h2_def
by argo
(* The remaining base cases (lists of length 0, 1 and 2) follow from the defining equations. *)
qed simp_all
end
end