(* Theory NTT_Scheme *)

theory NTT_Scheme

imports Crypto_Scheme
  Mod_Ring_Numeral
  "Number_Theoretic_Transform.NTT"

begin
section ‹Number Theoretic Transform for Kyber›

(* Applying strip_while ((=) 0) to a coefficient list does not change the
   polynomial that Poly builds from it. *)
lemma Poly_strip_while:
"Poly (strip_while ((=) 0) x) = Poly x"
by (metis Poly_coeffs coeffs_Poly)



(* Locale extending kyber_spec with the constants used by the NTT below:
     ω, μ          a root of unity and its inverse,
     ψ, ψinv       a square root of ω and its inverse,
     ninv          the inverse of n in the coefficient ring,
     mult_factor   the integer with q = mult_factor * n + 1. *)
locale kyber_ntt = kyber_spec _ _ _ _"TYPE('a :: qr_spec)" "TYPE('k::finite)" +
fixes type_a :: "('a :: qr_spec) itself" 
  and type_k :: "('k ::finite) itself" 
  and ω :: "('a::qr_spec) mod_ring"
  and μ :: "'a mod_ring"
  and ψ :: "'a mod_ring"
  and ψinv :: "'a mod_ring"
  and ninv :: "'a mod_ring"
  and mult_factor :: int
assumes
      (* ω^n = 1, ω is not 1, and no nonzero exponent below n sends ω to 1 *)
      omega_properties: "ω^n = 1" "ω ≠ 1" "(∀ m. ω^m = 1 ∧ m≠0 ⟶ m ≥ n)"
      (* μ is the multiplicative inverse of ω and is not 1 *)
  and mu_properties: "μ * ω = 1" "μ ≠ 1"
      (* ψ is a square root of ω whose n-th power is -1 *)
  and psi_properties: "ψ^2 = ω" "ψ^n = -1"
  and psi_psiinv: "ψ * ψinv = 1"
  and n_ninv: "(of_int_mod_ring n) * ninv = 1"
  and q_split: "q = mult_factor * n + 1"
begin
text ‹Some properties of the roots $\omega$ and $\psi$ and their inverses $\mu$ and $\psi_{inv}$.›
(* μ has the same minimality property as ω: every nonzero exponent m with
   μ^m = 1 satisfies m ≥ n. Derived from μ * ω = 1 and omega_properties(3). *)
lemma mu_prop:
  "(∀ m. μ^m = 1 ∧ m≠0 ⟶ m ≥ n)"
by (metis mu_properties(1) mult.commute mult.right_neutral 
  omega_properties(3) power_mult_distrib power_one)

(* Rule form of mu_prop, with the two hypotheses as separate assumptions. *)
lemma mu_prop':
assumes "μ^m' = 1" "m'≠0" shows "m' ≥ n"
using mu_prop  assms by blast

(* Rule form of omega_properties(3), with the two hypotheses as separate assumptions. *)
lemma omega_prop':
assumes "ω^m' = 1" "m'≠0" shows "m' ≥ n"
using omega_properties(3)  assms by blast

(* Basic facts about ψ:
     (1) ψ^(2*n) = 1,
     (2) every power n*(2*a+1), i.e. an odd multiple of n, gives -1,
     (3) ψ is not 1 (otherwise ω = ψ^2 would be 1). *)
lemma psi_props:
shows "ψ^(2*n) = 1"
      "ψ^(n*(2*a+1)) = -1"
      "ψ≠1"
proof -
  show "ψ^(2*n) = 1" 
  by (simp add: omega_properties(1) power_mult psi_properties)
  show "ψ^(n*(2*a+1)) = -1" 
  by (metis (no_types, lifting) mult.commute mult_1 power_add 
    power_minus1_even power_mult psi_properties(2))
  show "ψ≠1"
  using omega_properties(2) one_power2 psi_properties(1) by blast
qed

(* Equal powers of ψ and ψinv cancel (ψ on the left). *)
lemma psi_inv_exp:
"ψ^i * ψinv ^i = 1"
using left_right_inverse_power psi_psiinv by blast

(* Commuted variant of psi_inv_exp (ψinv on the left). *)
lemma inv_psi_exp:
"ψinv^i * ψ ^i = 1"
by (simp add: mult.commute psi_inv_exp)


(* For i < j, multiplying ψ^j by ψinv^i lowers the exponent to j - i. *)
lemma negative_psi:
assumes "i<j"
shows "ψ^j * ψinv ^i = ψ^(j-i)"
proof -
  (* split ψ^j so that a factor ψ^i sits next to ψinv^i *)
  have j: "ψ^j = ψ^(j-i) * ψ^i" using assms 
  by (metis add.commute le_add_diff_inverse nat_less_le power_add)
  show ?thesis unfolding j
  by (simp add: left_right_inverse_power psi_psiinv)
qed

(* Variant of negative_psi with ψinv on the left and the non-strict bound i ≤ j. *)
lemma negative_psi':
assumes "i≤j"
shows "ψinv^i * ψ ^j = ψ^(j-i)"
proof -
  (* split ψ^j so that a factor ψ^i sits next to ψinv^i *)
  have j: "ψ^j = ψ^i * ψ^(j-i)" using assms 
  by (metis le_add_diff_inverse power_add)
  show ?thesis unfolding j mult.assoc[symmetric] using inv_psi_exp[of i] by simp
qed

(* The square of ψinv is μ, mirroring ψ^2 = ω for the inverses. *)
lemma psiinv_prop: "ψinv^2 = μ"
  by (metis (mono_tags, lifting) mu_properties(1) mult.commute mult_cancel_right
      mult_cancel_right2 power_mult_distrib psi_properties(1) psi_psiinv)

(* Commuted variant of the locale assumption n_ninv. *)
lemma n_ninv':
"ninv * (of_int_mod_ring n) = 1"
using n_ninv 
by (simp add: mult.commute)


text ‹The ‹map2› function for polynomials.›
(* Coefficient-wise combination of two polynomials: applies f to the pairs of
   coefficients of p1 and p2 at indices 0..<nat n and rebuilds a polynomial
   from the resulting list with Poly. *)
definition map2_poly :: "('a mod_ring ⇒ 'a mod_ring ⇒ 'a mod_ring) ⇒ 
    'a mod_ring poly ⇒ 'a mod_ring poly ⇒ 'a mod_ring poly" where 
"map2_poly f p1 p2 = 
  Poly (map2 f (map (poly.coeff p1) [0..<nat n]) (map (poly.coeff p2) [0..<nat n]))"

text ‹Additional lemmas on polynomials.›
(* A polynomial of degree below num is recovered from the list of its first
   num coefficients. *)
lemma Poly_map_coeff:
assumes "degree f < num"
shows "Poly (map (poly.coeff (f)) [0..<num]) = f"
proof (subst poly_eq_iff, safe)
  (* compare coefficients index by index *)
  fix j
  show "poly.coeff (Poly (map (poly.coeff f) [0..<num])) j = poly.coeff f j"
  proof (cases "j<num")
    case True
    (* inside the list: the j-th entry is the j-th coefficient *)
    then show ?thesis
    unfolding coeff_Poly by (subst nth_default_nth, auto)
  next
    case False
    (* beyond the list: both sides are 0 because j exceeds the degree *)
    then have "j>degree f" using assms by auto
    then show ?thesis unfolding coeff_Poly using False
    by (simp add: coeff_eq_0 nth_default_beyond)
  qed
qed

(* A polynomial built from a list of length n is unchanged by reduction
   modulo qr_poly, since its degree is below n. *)
lemma map_upto_n_mod: 
"(Poly (map f [0..<n]) mod qr_poly) = (Poly (map f [0..<n]) :: 'a mod_ring poly)"
proof -
  have "degree (Poly (map f [0::nat..<n])) < n" 
  by (metis Suc_pred' deg_Poly' deg_qr_n deg_qr_pos degree_0 diff_zero le_imp_less_Suc 
    length_map length_upt nat_int)
  then show ?thesis
  by (subst deg_mod_qr_poly, use deg_qr_n in auto)
qed


(* Coefficients of a representative of_qr f vanish at every index i ≥ n,
   because the representative has degree below i. *)
lemma coeff_of_qr_zero:
assumes "i≥n"
shows "poly.coeff (of_qr (f :: 'a qr)) i = 0"
proof -
  have "degree (of_qr f) < i"
    using deg_of_qr deg_qr_n assms order_less_le_trans by auto
  then show ?thesis by (subst coeff_eq_0, auto)
qed

text ‹Definition of NTT on polynomials. 
  In contrast to the ordinary NTT, we use a different exponent on the root of unity $\psi$.›
(* i-th NTT coefficient of g: the sum over j in {0..<n} of the j-th coefficient
   of of_qr g times ψ^(j * (2*i+1)). *)
definition ntt_coeff_poly :: "'a qr ⇒ nat ⇒ 'a mod_ring" where
  "ntt_coeff_poly g i = (∑j∈{0..<n}. (poly.coeff (of_qr g) j) * ψ^(j * (2*i+1)))"

(* List of the NTT coefficients of g at indices 0..<n. *)
definition ntt_coeffs :: "'a qr ⇒ 'a mod_ring list" where
  "ntt_coeffs g = map (ntt_coeff_poly g) [0..<n]"

(* NTT of g as an element of 'a qr: the polynomial with coefficient list
   ntt_coeffs g, mapped back with to_qr. *)
definition ntt_poly :: "'a qr ⇒ 'a qr" where
"ntt_poly g = to_qr (Poly (ntt_coeffs g))"

text ‹Definition of inverse NTT on polynomials.
  The inverse transform is already scaled such that it is the true inverse of the NTT.›
(* i-th coefficient of the inverse NTT of g': ninv times the sum over j in
   {0..<n} of the j-th coefficient of of_qr g' times ψinv^(i * (2*j+1)). *)
definition inv_ntt_coeff_poly :: "'a qr ⇒ nat ⇒ 'a mod_ring" where
  "inv_ntt_coeff_poly g' i = ninv * 
    (∑j∈{0..<n}. (poly.coeff (of_qr g') j) * ψinv^(i*(2*j+1)))"

(* List of the inverse-NTT coefficients of g' at indices 0..<n. *)
definition inv_ntt_coeffs :: "'a qr ⇒ 'a mod_ring list" where
  "inv_ntt_coeffs g' = map (inv_ntt_coeff_poly g') [0..<n]"

(* Inverse NTT of g as an element of 'a qr: the polynomial with coefficient
   list inv_ntt_coeffs g, mapped back with to_qr. *)
definition inv_ntt_poly :: "'a qr ⇒ 'a qr" where
  "inv_ntt_poly g = to_qr (Poly (inv_ntt_coeffs g))"


text ‹Kyber is indeed in the NTT-domain with root of unity $\omega$.
Note that our NTT on polynomials uses a slightly different exponent.
The root of unity $\omega$ defines an alternative NTT in Kyber.›
text ‹We have $7681 = 30*256 + 1$ and $3329 = 13 * 256 + 1$.›
(* Interprets the locale ntt with the parameters nat q, nat n, nat mult_factor,
   ω and μ. Goals 2, 3 and 4 produced by unfold_locales are proved explicitly
   (from q_gt_two, q_split and n_gt_1); the remaining goals are closed by auto
   with CARD_a, nat_int and the ω/μ assumptions. *)
interpretation kyber_ntt: ntt "nat q" "nat n" "nat mult_factor" ω μ
proof (unfold_locales, goal_cases)
  case 2
  then show ?case  using q_gt_two by linarith
next
  case 3
  then show ?case 
    by (smt (verit, del_insts) int_nat_eq mult.commute nat_int_add 
    nat_mult_distrib of_nat_1 q_gt_two q_split zadd_int_left)
next
  case 4
  then show ?case using n_gt_1 by linarith
qed (use CARD_a nat_int in auto simp add: omega_properties mu_properties)


text ‹Multiplication of polynomials in $R_q$ is a negacyclic convolution 
(because we factored by $x^n + 1$, thus $x^n\equiv -1 \mod x^n+1$).
This is the reason why we needed to adapt the exponent in the NTT.›
(* Coefficient-wise product on 'a qr: multiplies the coefficients of the
   representatives of f and g pairwise via map2_poly and maps the result back
   with to_qr.
   NOTE(review): the infix symbol is assumed to be ⋆ - confirm against the
   places where this operator is used. *)
definition qr_mult_coeffs :: "'a qr ⇒ 'a qr ⇒ 'a qr" (infixl "⋆" 70) where
  "qr_mult_coeffs f g = to_qr (map2_poly (*) (of_qr f) (of_qr g))"
 

text ‹The definition of the exponentiation ‹^› only allows for natural exponents, 
thus we need to cheat a bit by introducing ‹conv_sign x›$\equiv (-1)^x$.›
(* Sign of an integer exponent: 1 when x is even, -1 when x is odd. *)
definition conv_sign :: "int ⇒ 'a mod_ring" where
"conv_sign x = (if x mod 2 = 0 then 1 else -1)"

text ‹The definition of the negacyclic convolution.›
(* Negacyclic convolution of f and g. The i-th coefficient (i in 0..<n) is the
   sum over j < n of
     conv_sign ((i - j) div n) * coeff f j * coeff g ((i - j) mod n),
   where the index difference is computed in the integers. *)
definition negacycl_conv :: "'a qr ⇒ 'a qr ⇒ 'a qr" where
"negacycl_conv f g = 
  to_qr (Poly (map 
  (λi. ∑j<n. conv_sign ((int i - int j) div n) *  
    poly.coeff (of_qr f) j * poly.coeff (of_qr g) (nat ((int i - int j) mod n)))
  [0..<n]))"

(* The representative of a negacyclic convolution is already reduced modulo qr_poly. *)
lemma negacycl_conv_mod_qr_poly:
"of_qr (negacycl_conv f g) mod qr_poly = of_qr (negacycl_conv f g)"
unfolding negacycl_conv_def of_qr_to_qr by auto


text ‹Representation of f modulo ‹qr_poly›.›
(* Division with remainder by qr_poly: f is its remainder plus qr_poly times its quotient. *)
lemma mod_div_qr_poly:
"(f :: 'a mod_ring poly) = (f mod qr_poly) + qr_poly * (f div qr_poly)"
by simp

text ‹‹take_deg› returns the first $n$ coefficients of a polynomial.›
(* Polynomial built from the first n entries of the coefficient list of f. *)
definition take_deg :: "nat ⇒ ('b::zero) poly ⇒ 'b poly"  where
"take_deg = (λn. λf. Poly (take n (coeffs f)))"

text ‹‹drop_deg› returns the coefficients of a polynomial starting from the $n$-th coefficient.›
(* Polynomial built from the coefficient list of f with its first n entries dropped. *)
definition drop_deg :: "nat ⇒ ('b::zero) poly ⇒ 'b poly"  where
"drop_deg = (λn. λf. Poly (drop n (coeffs f)))"

text ‹‹take_deg› and ‹drop_deg› return the modulo and divisor representatives.›
(* A polynomial of degree at least n splits into its low part take_deg n f
   plus x^n times its high part drop_deg n f. *)
lemma take_deg_monom_drop_deg:
assumes "degree f ≥ n"
shows "(f :: 'a mod_ring poly) = take_deg n f + (Polynomial.monom 1 n) * drop_deg n f"
proof -
  (* the coefficient list has more than n entries, so take n yields exactly n *)
  have "min (length (coeffs f)) n = n" using assms 
  by (metis bot_nat_0.not_eq_extremum degree_0 le_imp_less_Suc 
    length_coeffs_degree min.absorb1 min.absorb4)
  then show ?thesis
    unfolding take_deg_def drop_deg_def 
    apply (subst Poly_coeffs[of f,symmetric]) 
    apply (subst append_take_drop_id[of n "coeffs f", symmetric])
    apply (subst Poly_append)
    by (auto)
qed

(* Rewrites the split of take_deg_monom_drop_deg in terms of qr_poly:
   f = (take_deg n f - drop_deg n f) + qr_poly * drop_deg n f. *)
lemma split_mod_qr_poly:
assumes "degree f ≥ n"
shows "(f :: 'a mod_ring poly) = take_deg n f - drop_deg n f + qr_poly * drop_deg n f"
proof -
  have "(Polynomial.monom 1 n + 1) * drop_deg n f = 
    Polynomial.monom 1 n *  drop_deg n f + drop_deg n f"
    by (simp add: mult_poly_add_left)
  then show ?thesis 
    apply (subst take_deg_monom_drop_deg[OF assms])
    apply (unfold qr_poly_def qr_poly'_eq of_int_hom.map_poly_hom_add) 
    by auto
qed

text ‹Lemmas on the degrees of ‹take_deg› and ‹drop_deg›.›
(* Dropping the first n coefficients lowers the degree by n (natural-number subtraction). *)
lemma degree_drop_n:
"degree (drop_deg n f) = degree f - n"
unfolding drop_deg_def
by (simp add: degree_eq_length_coeffs)

(* If f has degree below 2*n, its high part drop_deg n f has degree below n. *)
lemma degree_drop_2n:
assumes "degree f < 2*n"
shows "degree (drop_deg n f) < n"
using assms unfolding degree_drop_n by auto

(* The low part take_deg n f always has degree below n. *)
lemma degree_take_n:
"degree (take_deg n f) < n"
unfolding take_deg_def 
by (metis coeff_Poly_eq deg_qr_n deg_qr_pos degree_0 leading_coeff_0_iff 
  nth_default_take of_nat_eq_iff)

(* The product of two representatives has degree below 2*n. *)
lemma deg_mult_of_qr:
"degree (of_qr (f ::'a qr) * of_qr g) < 2 * n"
by (metis add_less_mono deg_of_qr deg_qr_n degree_0 degree_mult_eq 
  mult_2 mult_eq_0_iff nat_int_comparison(1))

text ‹Representation of a polynomial modulo ‹qr_poly› using ‹take_deg› and ‹drop_deg›.›
(* For n ≤ degree f < 2*n, reducing f modulo qr_poly yields the low part
   minus the high part. *)
lemma mod_qr_poly:
assumes "degree f ≥ n" "degree f < 2*n"
shows "(f :: 'a mod_ring poly) mod qr_poly = take_deg n f - drop_deg n f "
proof -
  (* the difference has degree below that of qr_poly, so it is its own remainder *)
  have "degree (take_deg n f - drop_deg n f) < deg_qr TYPE('a)" 
    using degree_diff_le_max[of "take_deg n f" "drop_deg n f"]
     degree_drop_2n[OF assms(2)]  degree_take_n
     by (metis deg_qr_n degree_diff_less nat_int)
  then have "(take_deg n f - drop_deg n f) mod qr_poly =
    take_deg n f - drop_deg n f" by (subst deg_mod_qr_poly, auto)
  then show ?thesis
    by (subst split_mod_qr_poly[OF assms(1)], auto)
qed

text ‹Coefficients of ‹take_deg›, ‹drop_deg› and the modulo representative.›
(* Below index n, take_deg n f has the same coefficients as f. *)
lemma coeff_take_deg:
assumes "i<n"
shows "poly.coeff (take_deg n f) i = poly.coeff (f::'a mod_ring poly) i"
using assms unfolding take_deg_def 
by (simp add: nth_default_coeffs_eq nth_default_take)

(* The i-th coefficient of drop_deg n f is the (i+n)-th coefficient of f (stated for i < n). *)
lemma coeff_drop_deg:
assumes "i<n"
shows "poly.coeff (drop_deg n f) i = poly.coeff (f::'a mod_ring poly) (i+n)"
using assms unfolding drop_deg_def 
by (simp add: nth_default_coeffs_eq nth_default_drop)

(* Coefficient form of mod_qr_poly: for n ≤ degree f < 2*n and i < n, the i-th
   coefficient of f mod qr_poly is coeff f i minus coeff f (i+n). *)
lemma coeff_mod_qr_poly:
assumes "degree (f::'a mod_ring poly) ≥ n" "degree f < 2*n" "i<n"
shows "poly.coeff (f mod qr_poly) i = poly.coeff f i - poly.coeff f (i+n)"
apply (subst mod_qr_poly[OF assms(1) assms(2)]) 
apply (subst coeff_diff)
apply (unfold coeff_take_deg[OF assms(3)] coeff_drop_deg[OF assms(3)])
by auto

text ‹More lemmas on the splitting of sums.›

(* Splits a sum over {..i+n} into the part below n and the part over {n..i+n}. *)
lemma sum_leq_split:
"(∑ia≤i+n. f ia) = (∑ia<n. f ia) + (∑ia∈{n..i+n}. f ia)"
proof -
  have *: "{..i + n} - {..<n} = {n..i + n}" 
  by (metis atLeastLessThanSuc_atLeastAtMost lessThan_Suc_atMost lessThan_minus_lessThan) 
  show ?thesis 
  by (subst sum.subset_diff[of "{..<n}" "{..i+n}" f]) (auto simp add: * add.commute)
qed

(* Set identity: removing {..l1} from {..<l2} leaves the open interval {l1<..<l2}. *)
lemma less_diff:
assumes "l1<l2"
shows "{..<l2} - {..l1} = {l1<..<l2::nat}"
by (metis atLeastSucLessThan_greaterThanLessThan lessThan_Suc_atMost lessThan_minus_lessThan)

(* For l1 < l2, a sum over {..<l2} splits into the part over {..l1} and the
   part over the open interval {l1<..<l2}. *)
lemma sum_less_split:
assumes "l1<(l2::nat)"
shows "sum f {..<l2} = sum f {..l1} + sum f {l1<..<l2}"
by (subst sum.subset_diff[of "{..l1}" "{..<l2}" f]) 
   (auto simp add: assms add.commute order_le_less_trans less_diff[OF assms]) 

(* Integer division of a value in {-b..<0} by b gives -1. *)
lemma div_minus_1:
assumes "(x::int) ∈ {-b..<0}"
shows "x div b = -1" 
using assms 
by (smt (verit, ccfv_SIG) atLeastLessThan_iff div_minus_minus div_pos_neg_trivial)

text ‹A coefficient of polynomial multiplication is a coefficient of the negacyclic convolution.›
(* For i < n, the i-th coefficient of (of_qr f * of_qr g) mod qr_poly equals
   the i-th term of the negacyclic convolution.
   The proof distinguishes whether the degrees of the two representatives add
   up to less than n (no reduction happens) or not (coeff_mod_qr_poly applies). *)
lemma coeff_conv:
fixes f :: "'a qr"
assumes "i<n" 
shows "poly.coeff ((of_qr f) * (of_qr g) mod qr_poly) i = 
    (∑j<n. conv_sign ((int i - int j) div n) * 
      poly.coeff (of_qr f) j * poly.coeff (of_qr g) (nat ((int i - int j) mod n)))"
proof (cases "degree (of_qr f) + degree (of_qr g)<n")
  case True
  (* Case 1: the product has degree below n, so mod qr_poly does nothing. *)
  then have True':"degree ((of_qr f) * (of_qr g)) <n" using degree_mult_le 
  using order_le_less_trans by blast
  have "poly.coeff ((of_qr f) * (of_qr g) mod qr_poly) i = 
    poly.coeff ((of_qr f) * (of_qr g)) i" using True'
  by (metis deg_qr_n degree_qr_poly mod_poly_less nat_int)
  also have "… = (∑ia≤i. poly.coeff (of_qr f) ia * poly.coeff (of_qr g) (i - ia))" 
    unfolding coeff_mult by auto
  (* for ia ≤ i the sign is 1 and the index (i - ia) mod n is just i - ia *)
  also have "… = (∑ia≤i. conv_sign ((int i - int ia) div int n) *
    poly.coeff (of_qr f) ia * 
    poly.coeff (of_qr g) (nat ((int i - int ia) mod n)))" 
  proof -
    have "i-ia = nat ((int i - int ia) mod n)" if "ia ≤ i" for ia
    using assms that by force
    moreover have "conv_sign ((int i - int ia) div int n) = 1" 
      if "ia ≤ i" for ia unfolding conv_sign_def 
      using assms that by force
    ultimately show ?thesis by auto
  qed
  (* extend the summation range to {..<n}: the additional summands are all 0 *)
  also have "… = (∑ia<n. conv_sign ((int i - int ia) div int n) *
    poly.coeff (of_qr f) ia * 
    poly.coeff (of_qr g) (nat ((int i - int ia) mod n)))"
  proof -
    have "poly.coeff (of_qr f) ia *
       poly.coeff (of_qr g) (nat ((int i - int ia) mod int n)) = 0" 
      if "ia ∈ {i<..<n}" for ia 
    proof (subst mult_eq_0_iff, safe, goal_cases)
      case 1
      have deg_g: "nat ((int i - int ia) mod int n) ≤ degree (of_qr g)" 
        using le_degree[OF 1] by auto
      have "ia > degree (of_qr f)" 
      proof (rule ccontr)
        assume "¬ degree (of_qr f) < ia"
        then have as: "ia ≤ degree (of_qr f)" by auto
        then have ni: "nat ((int i - int ia) mod int n) + ia = n + i" 
          using that  by (smt (verit, ccfv_threshold) True deg_g 
          greaterThanLessThan_iff int_nat_eq less_imp_of_nat_less 
          mod_add_self1 mod_pos_pos_trivial of_nat_0_le_iff of_nat_add 
          of_nat_mono)
        (* contradicts the case assumption degree f + degree g < n *)
        have "n + i ≤ degree (of_qr f) + degree (of_qr g)"
          unfolding ni[symmetric] using as deg_g by auto
        then show False using True by auto
      qed
      then show ?case 
      using coeff_eq_0 by blast
    qed
    then show ?thesis
      by (subst sum_less_split[OF ‹i<n›]) (simp add: sum.neutral)
  qed
  finally show ?thesis by blast
next
  case False
  (* Case 2: the product has degree at least n; use
     coeff (p mod qr_poly) i = coeff p i - coeff p (i+n). *)
  then have *: "degree (of_qr f * of_qr g) ≥ n"
  by (metis add.right_neutral add_0 deg_of_qr deg_qr_n degree_0 degree_mult_eq 
      linorder_not_le nat_int)
  (* coefficient i+n of the product, reduced to a sum over {i<..<n} *)
  have "poly.coeff (of_qr f * of_qr g) (i + n) = (∑ia<n.
        poly.coeff (of_qr f) ia * poly.coeff (of_qr g) (i + n - ia))" 
    unfolding coeff_mult using coeff_of_qr_zero 
    by (subst sum_leq_split[of _ i]) (auto)
  also have "… = (∑ia∈{i<..<n}.
        poly.coeff (of_qr f) ia * poly.coeff (of_qr g) (i + n - ia))" 
    using coeff_of_qr_zero by (subst sum_less_split[OF ‹i<n›]) auto
  also have "… = (∑ia∈{i<..<n}.
        poly.coeff (of_qr f) ia * poly.coeff (of_qr g) (nat ((int i - ia) mod n)))"
  proof -
    have "int i - int ia + int n ∈ {0..<n}" if "ia ∈ {i<..<n}" for ia using assms that by auto
    then have "int i + n-ia = (int i-ia) mod n" if "ia∈{i<..<n}" for ia
     using ‹i<n› that by (smt (verit, best) mod_add_self1 mod_rangeE)
    then have "i+n-ia = nat ((int i - ia) mod n)" if "ia∈{i<..<n}" for ia
    by (metis int_minus nat_int of_nat_add that)
    then show ?thesis by fastforce
  qed
  (* on {i<..<n} the sign factor is -1, so inserting it negates the sum *)
  also have "… = - (∑ia∈{i<..<n}. conv_sign ((int i - int ia) div n) *
        poly.coeff (of_qr f) ia * poly.coeff (of_qr g) (nat ((int i - ia) mod n)))"
  proof -
    have negative:"(int i - int ia) ∈ {-n..<0}" if "ia ∈ {i<..<n}" for ia 
      using that by auto
    have "(int i - int ia) div n = -1" if "ia ∈ {i<..<n}" for ia
      using div_minus_1[OF negative[OF that]] .
    then have "conv_sign ((int i - int ia) div n) = -1" if "ia ∈ {i<..<n}" for ia 
      unfolding conv_sign_def using that by auto
    then have *: "(∑ia∈{i<..<n}. foo ia) =
      (∑x∈{i<..<n}. - (conv_sign ((int i - int x) div int n) * foo x))" 
    for foo by auto
    show ?thesis 
    by (subst sum_negf[symmetric], subst *) (simp add: mult.assoc) 
  qed
  finally have i_n: "poly.coeff (of_qr f * of_qr g) (i + n) = 
    - (∑ia∈{i<..<n}. conv_sign ((int i - int ia) div n) *
        poly.coeff (of_qr f) ia * poly.coeff (of_qr g) (nat ((int i - ia) mod n)))"
    by blast
  (* coefficient i of the product: sign factor 1 on {..i} *)
  have i_n': "poly.coeff (of_qr f * of_qr g) i = 
      (∑ia≤i. conv_sign ((int i - int ia) div n) *
        poly.coeff (of_qr f) ia * poly.coeff (of_qr g) (nat ((int i - ia) mod n)))"
  proof -
    have "conv_sign ((int i - int ia) div n) = 1" if "ia ≤i" for ia 
      using that assms conv_sign_def by force
    moreover have "i-ia ∈ {0..<n}" if "ia ≤i" for ia using that assms by auto
    then have "i-ia = (nat ((int i - ia) mod n))" if "ia≤i" for ia
      using assms that by force
    ultimately show ?thesis unfolding coeff_mult 
    using assms less_imp_diff_less mod_less by auto
  qed  
  (* combine both parts into one sum over {..<n} *)
  have calc: "poly.coeff (of_qr f * of_qr g) i - poly.coeff (of_qr f * of_qr g) (i + n) = 
    (∑ia<n. conv_sign ((int i - int ia) div n) *
        poly.coeff (of_qr f) ia * poly.coeff (of_qr g) (nat ((int i - ia) mod n)))" 
    by (subst i_n, subst i_n') 
       (metis (no_types, lifting) assms diff_minus_eq_add sum_less_split)
  show ?thesis unfolding coeff_mod_qr_poly[OF * deg_mult_of_qr assms] calc
    by auto
qed

text ‹Polynomial multiplication in $R_q$ is the negacyclic convolution.›

(* Multiplication on 'a qr coincides with negacycl_conv. Both sides are
   compared coefficient-wise on their representatives: below n via coeff_conv,
   from n on both coefficients are 0. *)
lemma mult_negacycl:
"f * g = negacycl_conv f g"
proof -
  have f_times_g: "f * g = to_qr ((of_qr f) * (of_qr g) mod qr_poly)"
    by (metis of_qr_mult to_qr_of_qr)
  have conv: "poly.coeff ((of_qr f) * (of_qr g) mod qr_poly) i = 
    (∑j<n. conv_sign ((int i - int j) div n) * 
    poly.coeff (of_qr f) j * poly.coeff (of_qr g) (nat ((int i - j) mod n)))" 
  if "i<n" for i using coeff_conv[OF that] by auto
  have "poly.coeff (of_qr (f*g)) i = 
    poly.coeff (of_qr (negacycl_conv f g)) i" for i 
  proof (cases "i<n")
    case True
    then show ?thesis unfolding negacycl_conv_def f_times_g of_qr_to_qr 
      map_upto_n_mod mod_mod_trivial coeff_Poly_eq 
      using conv[OF True] by (subst nth_default_nth[of i], auto)
  next
    case False
    then show ?thesis using coeff_of_qr_zero[of i "f*g"] 
      coeff_of_qr_zero[of i "negacycl_conv f g"] by auto 
  qed 
  then show ?thesis 
    using poly_eq_iff [of "of_qr (f * g)" "of_qr (negacycl_conv f g)"]
    by (metis to_qr_of_qr)
qed

text ‹Additional lemmas on ‹ntt_coeffs›.›

(* The coefficient list produced by ntt_coeffs has at most n entries. *)
lemma length_ntt_coeffs:
"length (ntt_coeffs f) ≤ n"
unfolding ntt_coeffs_def by auto

(* The polynomial built from ntt_coeffs f has degree below n. *)
lemma degree_Poly_ntt_coeffs:
"degree (Poly (ntt_coeffs f)) < n"
using length_ntt_coeffs 
by (smt (verit) deg_Poly' degree_0 degree_take_n diff_diff_cancel 
  diff_is_0_eq le_neq_implies_less less_nat_zero_code nat_le_linear 
  order.strict_trans1 power_0_left power_eq_0_iff)

(* The polynomial built from ntt_coeffs f is unchanged by reduction modulo
   qr_poly (instance of map_upto_n_mod). *)
lemma Poly_ntt_coeffs_mod_qr_poly:
  "Poly (ntt_coeffs f) mod qr_poly = Poly (ntt_coeffs f)"
using map_upto_n_mod ntt_coeffs_def by presburger


(* For an index i below na, nth_default on map f [0..<na] returns f i,
   whatever the default value x is. *)
lemma nth_default_map:
assumes "i<na"
shows "nth_default x (map f [0..<na]) i = f i"
using assms 
by (simp add: nth_default_nth)


(* For j < n, the j-th coefficient of the representative of negacycl_conv f g
   is the defining convolution sum. *)
lemma nth_coeffs_negacycl:
assumes "j<n"
shows "poly.coeff (of_qr (negacycl_conv f g)) j =
  (∑i<n. conv_sign ((int j - int i) div int n) * poly.coeff (of_qr f) i *
   poly.coeff (of_qr g) (nat ((int j - int i) mod int n)))"
unfolding negacycl_conv_def of_qr_to_qr map_upto_n_mod coeff_Poly_eq 
 nth_default_map[OF assms] by auto


text ‹Writing the convolution sign as a conditional if statement.›
(* For x, y < n the convolution sign only depends on whether x - y is
   negative: -1 if it is, 1 otherwise. *)
lemma conv_sign_if:
assumes "x<n" "y<n"
shows "conv_sign ((int x - int y) div int n) = (if int x - int y < 0 then -1 else 1)"
unfolding conv_sign_def 
proof (split if_splits, safe, goal_cases)
  case 1
  (* negative difference: it lies in {-n..<0}, so the quotient is -1 (odd) *)
  then have "int x - int y ∈ {-n..<0}" using assms by simp
  then have "(int x - int y) div int n mod 2 = 1"
    using div_minus_1 by presburger
  then show ?case by auto
next
  case 2
  (* non-negative difference below n: the quotient is 0 (even) *)
  then have "(int x - int y) div int n mod 2 = 0"
  using assms(1) by force
  then show ?case by auto
qed

text ‹The convolution theorem on coefficients.›

(* Convolution theorem on a single coefficient: for l < n, the l-th NTT
   coefficient of f * g is the product of the l-th NTT coefficients of f and g.
   Outline: rewrite f * g as negacycl_conv, swap the two sums, pull out the
   factor belonging to f, and re-index the inner sum via j ↦ (j - i) mod n. *)
lemma ntt_coeff_poly_mult:
assumes "l<n"
shows "ntt_coeff_poly (f*g) l = ntt_coeff_poly f l * ntt_coeff_poly g l"
proof -
  (* f1 x y is the y-th summand of the x-th negacyclic convolution coefficient *)
  define f1 where "f1 = (λx. λ y.
        conv_sign ((int x - int y) div int n) *
        poly.coeff (of_qr f) y *
        poly.coeff (of_qr g) (nat ((int x - int y) mod int n)))"
  have "ntt_coeff_poly (f*g) l = (∑j = 0..<n. poly.coeff (of_qr (negacycl_conv f g)) j *
        ψ^(j*(2*l+1)))" unfolding ntt_coeff_poly_def mult_negacycl by auto
  also have "… = (∑j=0..<n. (∑i<n. f1 j i * ψ^(j*(2*l+1))))"
  proof (subst sum.cong[of "{0..<n}" "{0..<n}" 
      "(λj. poly.coeff (of_qr (negacycl_conv f g)) j * ψ^(j*(2*l+1)))"
      "(λj. (∑i<n. f1 j i * ψ^(j*(2*l+1))))"], 
      goal_cases)
    case (2 j)
    then have "j<n" by auto
    have "poly.coeff (of_qr (negacycl_conv f g)) j * ψ ^ (j * (2 * l + 1)) = 
      (∑na<n. (conv_sign ((int j - int na) div int n) *
         poly.coeff (of_qr f) na * poly.coeff (of_qr g) (nat ((int j - int na) mod int n))) *
        ψ ^ (j * (2 * l + 1)))"
      apply (subst nth_coeffs_negacycl[OF ‹j<n›])
      apply (subst sum_distrib_right)
      by auto
    also have "… = (∑na<n. f1 j na * ψ ^ (j * (2 * l + 1)))"
      unfolding f1_def by auto
    finally show ?case by blast
  qed auto
  (* swap the order of summation *)
  also have "… = (∑i<n. ∑j<n. f1 j i * ψ ^ (j * (2 * l + 1))) "
    by (subst atLeast0LessThan, subst sum.swap, auto) 
  (* insert ψ^(i*(2l+1)) * ψinv^(i*(2l+1)) = 1 and pull the f-part out of the inner sum *)
  also have "… = (∑i<n. poly.coeff (of_qr f) i * ψ ^ (i * (2 * l + 1)) * 
    (∑j<n. poly.coeff (of_qr g) (nat ((int j - int i) mod int n)) *
        (if int j - int i < 0 then -1 else 1) *
         ψinv ^ (i * (2 * l + 1)) * ψ ^ (j * (2 * l + 1))))"
  proof (subst sum.cong[of "{..<n}" "{..<n}" "(λi. (∑j<n. f1 j i * ψ ^ (j * (2 * l + 1))))"
    "(λi. poly.coeff (of_qr f) i * ψ ^ (i * (2 * l + 1)) * 
        (∑j<n. poly.coeff (of_qr g) (nat ((int j - int i) mod int n)) *
        (if int j - int i < 0 then -1 else 1) *
         ψinv ^ (i * (2 * l + 1)) * ψ ^ (j * (2 * l + 1))))"], goal_cases)
    case (2 i)
    then show ?case 
    proof (subst sum_distrib_left, subst sum.cong[of "{..<n}" "{..<n}" 
      "(λj. f1 j i * ψ ^ (j * (2 * l + 1)))"
      "(λj. poly.coeff (of_qr f) i * ψ ^ (i * (2 * l + 1)) *
        (poly.coeff (of_qr g) (nat ((int j - int i) mod int n)) *
        (if int j - int i < 0 then - 1 else 1) *
        ψinv ^ (i * (2 * l + 1)) * ψ ^ (j * (2 * l + 1))))"], goal_cases)
      case (2 j)
      then have *: "conv_sign ((int j - int i) div int n) = 
        (if int j - int i < 0 then - 1 else 1)" using conv_sign_if by auto
      have "f1 j i * ψ ^ (j * (2 * l + 1)) = 
        ψ ^ (i * (2 * l + 1)) * f1 j i * ψinv ^ (i * (2 * l + 1)) * ψ ^ (j * (2 * l + 1))"
        using psi_psiinv
        by (simp add: left_right_inverse_power)
      also have "… = poly.coeff (of_qr f) i * ψ ^ (i * (2 * l + 1)) *
      (poly.coeff (of_qr g) (nat ((int j - int i) mod int n)) *
      (if int j - int i < 0 then - 1 else 1) * ψinv ^ (i * (2 * l + 1)) * ψ ^ (j * (2 * l + 1)))"
        unfolding f1_def mult.assoc  
        by (simp add: "*" mult.left_commute)
      finally show ?case  by blast
    qed auto
  qed auto
  (* re-index the inner sum: x' j i = (j - i) mod n ranges over {..<n} bijectively *)
  also have "… = (∑i<n. poly.coeff (of_qr f) i * ψ ^ (i * (2 * l + 1)) * 
    (∑x<n. poly.coeff (of_qr g) x * ψ ^ (x * (2 * l + 1))))"
  proof -
    define x' where "x' = (λj i. nat ((int j - int i) mod int n))"
    let ?if_inv = "(λi j. (if int j - int i < 0 then - 1 else 1) * 
      ψinv ^ (i * (2 * l + 1)) * ψ ^ (j * (2 * l + 1)))"
    (* sign and the two ψ-powers collapse into a single power of ψ with exponent x' j i *)
    have rewrite: "(if int j - int i < 0 then - 1 else 1) * 
      ψinv ^ (i * (2 * l + 1)) * ψ ^ (j * (2 * l + 1)) = 
      ψ ^ ((x' j i) * (2 * l + 1))"  if "i<n" "j<n" for i j
    proof (cases "int j - int i <0")
      case True
      (* j < i: the factor -1 is absorbed as ψ^(n*(2l+1)) *)
      have lt: "i * (2 * l + 1) < n * (2 * l + 1)" using ‹i<n› 
      by (metis One_nat_def add_gr_0 lessI mult_less_mono1)
      have "?if_inv i j = (-1) * ψinv ^ (i * (2 * l + 1)) * ψ ^ (j * (2 * l + 1))"
        using True by (auto split: if_splits) 
      also have "… = ψ^((n-i+j)* (2 * l + 1))" unfolding psi_props(2)[of l, symmetric] 
         negative_psi[OF lt]
         by (metis comm_semiring_class.distrib diff_mult_distrib power_add)
      also have "… = ψ ^ ((x' j i) * (2 * l + 1))" unfolding x'_def
      by (smt (verit, best) True mod_add_self2 mod_pos_pos_trivial nat_int_add 
        nat_less_le of_nat_0_le_iff of_nat_diff that(1))
      finally show ?thesis by blast
    next
      case False
      (* i ≤ j: no sign, the exponents simply subtract *)
      then have "i≤j" by auto 
      have lt: "i * (2 * l + 1) ≤ j * (2 * l + 1)" using ‹i≤j› 
      using add_gr_0 less_one mult_less_mono1 
      using mult_le_cancel2 by presburger
      have "?if_inv i j = ψinv ^ (i * (2 * l + 1)) * ψ ^ (j * (2 * l + 1))"
        using False by (auto split: if_splits) 
      also have "… = ψ^((j-i)* (2 * l + 1))" 
        using negative_psi'[OF lt]  diff_mult_distrib by presburger
      also have "… = ψ ^ ((x' j i) * (2 * l + 1))" unfolding x'_def
        by (metis ‹i ≤ j› less_imp_diff_less mod_pos_pos_trivial nat_int 
          of_nat_0_le_iff of_nat_diff of_nat_less_iff that(2))
      finally show ?thesis by blast
    qed
    then have "(∑j<n. poly.coeff (of_qr g) (nat ((int j - int i) mod int n)) *
      (if int j - int i < 0 then - 1 else 1) * 
      ψinv ^ (i * (2 * l + 1)) * ψ ^ (j * (2 * l + 1))) = 
      (∑x<n. poly.coeff (of_qr g) x * ψ ^ (x * (2 * l + 1)))"
      (is "(∑j<n. ?left j i) = _") if "i<n" for i
    proof -
      have *: "(∑j<n. ?left j i) = 
        (∑j<n. poly.coeff (of_qr g) (x' j i) * ψ ^ ((x' j i) * (2 * l + 1)))"
        using rewrite[OF that] x'_def
        by (smt (verit, ccfv_SIG) lessThan_iff mult.assoc sum.cong)
      (* j ↦ x' j i maps {..<n} onto {..<n} ... *)
      have eq: "(λj. x' j i) ` {..<n} = {..<n}" unfolding x'_def 
      proof (safe, goal_cases)
        case (1 _ j)
        with n_gt_zero show ?case
          by (simp add: nat_less_iff)
      next
        case (2 x)
        define j where "j = (x+i) mod n"
        have "j∈{..<n}" 
        by (metis j_def lessThan_iff mod_less_divisor n_gt_zero of_nat_0_less_iff)
        moreover have "x = nat ((int j - int i) mod int n)" unfolding j_def
        by (simp add: "2" mod_diff_cong zmod_int) 
        ultimately show ?case by auto
      qed
      (* ... and is injective there, so the sum can be re-indexed *)
      have inj: "inj_on (λj. x' j i) {..<n}" unfolding x'_def inj_on_def 
      proof (safe, goal_cases)
        case (1 x y)
        then have "((int x - int i) mod int n) = ((int y - int i) mod int n)"
          by (meson eq_nat_nat_iff mod_int_pos_iff n_gt_zero)
        then have "int x mod int n = int y mod int n"
          by (smt (z3) mod_diff_cong)
        then show ?case using 1 by auto
      qed
      show ?thesis unfolding * by (subst sum.reindex_cong[OF inj eq[symmetric], 
        of "(λx. poly.coeff (of_qr g) x * ψ ^ (x * (2 * l + 1)))" 
        "(λj. poly.coeff (of_qr g) (x' j i) * ψ ^ (x' j i * (2 * l + 1)))"], auto)
    qed
    then show ?thesis by force
  qed
  (* the inner sum no longer depends on i: factor it out *)
  also have "… = (∑i<n. poly.coeff (of_qr f) i * ψ ^ (i * (2 * l + 1))) * 
    (∑x'<n. poly.coeff (of_qr g) x' * ψ ^ (x' * (2 * l + 1)))"
    unfolding sum_distrib_right by auto
  also have "… = ntt_coeff_poly f l * ntt_coeff_poly g l"
    unfolding ntt_coeff_poly_def atLeast0LessThan by auto
  finally show ?thesis by blast
qed


(* List-level form of ntt_coeff_poly_mult: for i < n, the i-th entry of
   ntt_coeffs (f*g) is the product of the i-th entries for f and g. *)
lemma ntt_coeffs_mult: 
assumes "i<n"
shows "ntt_coeffs (f*g) !i = ntt_coeffs f! i * ntt_coeffs g ! i"
unfolding ntt_coeffs_def using ntt_coeff_poly_mult[OF assms]
by (simp add: assms)

text ‹Steps towards the convolution theorem.›

(* At every index i, nth_default 0 agrees on ntt_coeffs (f * g) and on the
   pairwise product of the coefficient lists of Poly (ntt_coeffs f) and
   Poly (ntt_coeffs g). Inside {0..<n} this is ntt_coeffs_mult; outside both
   sides default to 0. *)
lemma nth_default_ntt_coeff_mult:
"nth_default 0 (ntt_coeffs (f * g)) i =
 nth_default 0 (map2 (*) 
  (map (poly.coeff (Poly (ntt_coeffs f))) [0..<nat (int n)])
  (map (poly.coeff (Poly (ntt_coeffs g))) [0..<nat (int n)])) i"
(is "?left i = ?right i")
proof (cases "i∈{0..<n}")
  case True
  then have l: "?left i = ntt_coeffs (f * g) ! i"
    by (simp add: nth_default_nth ntt_coeffs_def)
  have *: "?right i = (poly.coeff (Poly (ntt_coeffs f)) i) * (poly.coeff (Poly (ntt_coeffs g)) i)"
    using True 
    by (metis (no_types, lifting) coeff_Poly_eq diff_zero length_map length_upt 
      map_nth_default mult_hom.hom_zero nat_int nth_default_map2 ntt_coeffs_def)
  then have r: "?right i = (ntt_coeffs f) ! i * (ntt_coeffs g) ! i"
    unfolding * unfolding coeff_Poly using nth_default_nth
    by (metis True atLeastLessThan_iff diff_zero length_map length_upt ntt_coeffs_def)
  show ?thesis unfolding l r using ntt_coeffs_mult True by auto
next
  case False
  then have "?left i = 0" unfolding ntt_coeffs_def
    by (simp add: nth_default_beyond)
  moreover have "?right i = 0" using False
  by (simp add: nth_default_def)
  ultimately show ?thesis by presburger
qed

(* Polynomial-level form of nth_default_ntt_coeff_mult: the two coefficient
   lists define the same polynomial, shown by comparing coefficients. *)
lemma Poly_ntt_coeffs_mult:
"Poly (ntt_coeffs (f * g)) = Poly (map2 (*) 
  (map (poly.coeff (Poly (ntt_coeffs f))) [0..<nat (int n)])
  (map (poly.coeff (Poly (ntt_coeffs g))) [0..<nat (int n)]))"
apply (intro poly_eqI) apply (unfold coeff_Poly)
using nth_default_ntt_coeff_mult[of f g] by auto


text ‹Convolution theorem for NTT›
(* Convolution theorem: the NTT of a product f * g is the coefficient-wise
   product (qr_mult_coeffs) of the NTTs of f and g.
   The proof is a calculational chain on representatives modulo qr_poly,
   which is then turned into a congruence and lifted to the quotient ring
   via to_qr. *)
lemma ntt_mult:
"ntt_poly (f * g) = qr_mult_coeffs (ntt_poly f) (ntt_poly g)"
proof -
  (* Reducing the coefficient polynomial modulo qr_poly changes nothing. *)
  have "Poly (ntt_coeffs (f*g)) mod qr_poly = 
    Poly (ntt_coeffs (f*g))"
    using Poly_ntt_coeffs_mod_qr_poly by force
  (* Rewrite it as the pointwise product of the two coefficient polynomials. *)
  also have "… = Poly (coeffs (map2_poly (*) (Poly (ntt_coeffs f)) (Poly (ntt_coeffs g))))"
    unfolding map2_poly_def coeffs_Poly Poly_strip_while 
    using Poly_ntt_coeffs_mult by auto
  (* Route both factors through the quotient ring and back. *)
  also have "… = (map2_poly (*) (of_qr (to_qr (Poly (ntt_coeffs f))))
         (of_qr (to_qr (Poly (ntt_coeffs g))))) mod qr_poly"
  unfolding of_qr_to_qr map_poly_def Poly_ntt_coeffs_mod_qr_poly
  by (metis Poly_coeffs Poly_ntt_coeffs_mod_qr_poly calculation)
  (* Equal remainders modulo qr_poly is exactly congruence modulo qr_poly. *)
  finally have "[Poly (ntt_coeffs (f * g)) = 
       (map2_poly (*) (of_qr (to_qr (Poly (ntt_coeffs f))))
         (of_qr (to_qr (Poly (ntt_coeffs g)))))] (mod qr_poly)"
  using cong_def by blast
  (* Congruent representatives have the same image in the quotient ring. *)
  then have "to_qr (Poly (ntt_coeffs (f * g))) =
    to_qr (map2_poly (*) (of_qr (to_qr (Poly (ntt_coeffs f))))
         (of_qr (to_qr (Poly (ntt_coeffs g)))))"
  using of_qr_to_qr by auto
  then show ?thesis
    unfolding ntt_poly_def qr_mult_coeffs_def  
    by auto
qed

text ‹Correctness of NTT on polynomials.›

(* Correctness of the inverse transform: applying inv_ntt_poly to ntt_poly f
   gives back f.

   Proof outline:
   1. rew_sum    - replace the nth_default lookup into the mapped list of
                   forward coefficients by the forward sum itself.
   2. main fact  - for i < n the double sum
                     ∑j. ∑j'. coeff f j' * ψ^(j'*(2j+1)) * ψinv^(i*(2j+1))
                   equals (of_int_mod_ring n) * coeff f i.  After swapping
                   the sums, the inner sum is a geometric sum that is n for
                   j' = i and 0 otherwise.
   3. rew_coeff  - lift the main fact to the coefficient lists.
   4. conclusion - unfold the definitions, cancel n against ninv (n_ninv')
                   and rebuild f from its coefficient list. *)
lemma inv_ntt_poly_correct:
"inv_ntt_poly (ntt_poly f) = f"
proof -
  (* Step 1: for j < n the j-th entry of the mapped list is the forward sum. *)
  have rew_sum: "(∑j = 0..<n. nth_default 0 
    (map (λi. ∑j = 0..<n. poly.coeff (of_qr f) j * ψ ^ (j * (2 * i + 1))) [0..<n])
     j * ψinv ^ (i * (2 * j + 1))) = 
    (∑j = 0..<n. (∑j' = 0..<n. poly.coeff (of_qr f) j' * ψ ^ (j' * (2 * j + 1)))
     * ψinv ^ (i * (2 * j + 1)))" 
    (is "(∑j = 0..<n. ?left j) = (∑j = 0..<n. ?right j)") for i
  proof (subst sum.cong[of "{0..<n}" "{0..<n}" ?left ?right], goal_cases)
    case (2 x)
    then show ?case by (subst nth_default_map[of x n], auto)
  qed auto
  (* Step 2: the double sum collapses to n times the i-th coefficient. *)
  have  "(∑j = 0..<n. ∑j' = 0..<n. poly.coeff (of_qr f) j' * 
    ψ ^ (j' * (2 * j + 1)) * ψinv ^ (i * (2 * j + 1))) = 
    (of_int_mod_ring n) * poly.coeff (of_qr f) i" if "i<n" for i
  proof -
    (* Split the exponents so that the dependence on j' becomes a plain power. *)
    have rew_psi: "ψ ^ (j * (2 * j' + 1)) * ψinv ^ (i * (2 * j' + 1)) =
      ψ ^ j * ψinv ^ i * (ψ ^ (j * 2) * ψinv ^ (i * 2)) ^ j'"
      if "j'<n" "j<n" for j' j
    by (smt (verit, ccfv_threshold) kyber_ntt.exp_rule mult.commute 
      power_add power_mult power_one_right)
    (* Swap the order of summation and pull the constant factors out of the
       inner sum, leaving a geometric sum over j. *)
    have "(∑j = 0..<n. ∑j' = 0..<n. poly.coeff (of_qr f) j' * 
        ψ ^ (j' * (2 * j + 1)) * ψinv ^ (i * (2 * j + 1))) = 
      (∑j' = 0..<n. poly.coeff (of_qr f) j' * ψ^j' * ψinv^i *
        (∑j = 0..<n. (ψ ^ (j' * 2) * ψinv ^ (i * 2))^j))"
    apply (subst sum_distrib_left, subst sum.swap)
     proof (subst sum.cong[of "{0..<n}" "{0..<n}"
      "(λj. ∑ia = 0..<n. poly.coeff (of_qr f) j * ψ ^ (j * (2 * ia + 1)) *
           ψinv ^ (i * (2 * ia + 1)))"
      "(λj. ∑ia = 0..<n. poly.coeff (of_qr f) j * ψ ^ j * ψinv ^ i *
           (ψ ^ (j * 2) * ψinv ^ (i * 2)) ^ ia)"], goal_cases)
      case (2 j)
      then show ?case proof (subst sum.cong[of "{0..<n}" "{0..<n}"
      "(λia. poly.coeff (of_qr f) j * ψ ^ (j * (2 * ia + 1)) *
        ψinv ^ (i * (2 * ia + 1)))"
      "(λia. poly.coeff (of_qr f) j * ψ ^ j * ψinv ^ i *
        (ψ ^ (j * 2) * ψinv ^ (i * 2)) ^ ia)"], goal_cases)
        case (2 j')
        then show ?case using rew_psi[of j' j] by simp
      qed auto
    qed auto
    (* Evaluate the geometric sum: n if j' = i, 0 otherwise. *)
    also have "… = (∑j' = 0..<n. 
      (if j' = i then poly.coeff (of_qr f) j' * ψ^j' * ψinv^i * (of_int_mod_ring n) else 0))" 
    proof (subst sum.cong[of "{0..<n}" "{0..<n}" 
      "(λj'. poly.coeff (of_qr f) j' * ψ ^ j' * ψinv ^ i *
        (∑j=0..<n. (ψ ^ (j' * 2) * ψinv ^ (i * 2))^j))"
      "(λj'. (if j' = i then poly.coeff (of_qr f) j' * ψ^j' * ψinv^i * 
        (of_int_mod_ring n) else 0))"], goal_cases)
      case (2 j')
      then show ?case proof (cases "j' = i")
        case True
        (* Base of the geometric sum is ψ^(2i) * ψinv^(2i) = 1, so the sum is n. *)
        then have "(∑j=0..<n. (ψ ^ (j' * 2) * ψinv ^ (i * 2))^j) = of_int_mod_ring n"
          unfolding True psi_inv_exp 
          by (metis kyber_ntt.sum_rules(5) mult.right_neutral power_one sum.cong)
        then show ?thesis using True by auto
      next
        case False
        (* The base of the geometric sum differs from 1; this is reduced to
           ω^j' * μ^i ≠ 1 (last step below, via psi_properties(1) and
           psiinv_prop) and then to the minimality of n for ω and μ. *)
        have not1: "ψ ^ (j' * 2) * ψinv ^ (i * 2) ≠ 1"
        proof -
          have "ω^j' * μ ^i ≠ 1" 
          proof (cases "j'<i")
            case True
            (* j' < i: the product is a power of μ with exponent 0 < i - j' < n. *)
            have *: "ω^j' * μ ^i = μ^(i-j')" using True 
            by (metis (no_types, lifting) le_add_diff_inverse less_or_eq_imp_le 
              mult.assoc mult_cancel_right2 power_add power_mult psi_inv_exp 
              psi_properties(1) psiinv_prop)
            show ?thesis proof (unfold *, rule ccontr)
              assume "¬ μ ^ (i - j') ≠ 1" 
              then have 1: "μ ^ (i - j') = 1" by auto
              show False using mu_prop'[OF 1] ‹j'≠i›
              using True less_imp_diff_less that diff_is_0_eq leD by blast
            qed
          next
            case False
            (* j' > i: the product is a power of ω with exponent 0 < j' - i < n. *)
            have 2: "ω^j' * μ ^i = ω^(j'-i)" using False
            by (smt (verit) Nat.add_diff_assoc ab_semigroup_mult_class.mult_ac(1) 
              add_diff_cancel_left' left_right_inverse_power linorder_not_less 
              mu_properties(1) mult.commute mult_numeral_1_right numeral_One power_add) 
            show ?thesis proof (unfold 2, rule ccontr)
              assume "¬ ω ^ (j' - i) ≠ 1" 
              then have 1: "ω ^ (j' - i) = 1" by auto
              have "n > j' - i" using ‹j' ∈ {0..<n}› by auto
              then show False using omega_prop'[OF 1] ‹j'≠i›
              using False 
              by (meson diff_is_0_eq leD order_le_imp_less_or_eq)
            qed
          qed
          then show ?thesis
          by (metis mult.commute power_mult psi_properties(1) psiinv_prop) 
        qed
        (* (1 - x) * (geometric sum of x) = 0: after rewriting with
           kyber_ntt.geo_sum the remaining goal is discharged using ω^n = 1
           and μ * ω = 1. *)
        have "(1 - ψ ^ (j' * 2) * ψinv ^ (i * 2)) * 
          (∑j=0..<n. (ψ ^ (j' * 2) * ψinv ^ (i * 2))^j) = 0"
        proof (subst kyber_ntt.geo_sum, goal_cases)
          case 1
          then show ?case using not1 by auto
        next
          case 2
          then show ?case 
          by (metis (no_types, opaque_lifting) cancel_comm_monoid_add_class.diff_cancel 
            mu_properties(1) mult.commute omega_properties(1) power_mult power_mult_distrib 
            power_one psi_properties(1) psiinv_prop)
        qed
        (* Since 1 - x ≠ 0, the geometric sum itself must vanish. *)
        then have "(∑j=0..<n. (ψ ^ (j' * 2) * ψinv ^ (i * 2))^j) = 0" 
          using not1 by auto
        then show ?thesis using False by auto
      qed
    qed auto
    (* Only the summand j' = i survives. *)
    also have "… = poly.coeff (of_qr f) i * ψ^i * ψinv^i * (of_int_mod_ring n)"
      by (subst sum.delta[of "{0..<n}" i], use ‹i<n› in auto)
    (* ψ^i * ψinv^i = 1 by psi_inv_exp. *)
    also have "… = (of_int_mod_ring n) * poly.coeff (of_qr f) i"
      by (simp add: psi_inv_exp)
    finally show ?thesis by blast
  qed
  (* Step 3: the same identity for the whole list of coefficients.
     Note that the inner bound variable n shadows the ring degree n only
     inside the summand; the summation range still refers to the degree. *)
  then have rew_coeff: "(map (λi. ninv * (∑j = 0..<n. ∑n = 0..<n.
    poly.coeff (of_qr f) n * ψ ^ (n * (2 * j + 1)) * ψinv ^ (i * (2 * j + 1)))) [0..<n]) = 
    map (λi. ninv * (of_int_mod_ring (int n) * poly.coeff (of_qr f) i)) [0..<n]"
  unfolding map_eq_conv by auto
  (* Step 4: unfold both transforms, apply the rewrites, cancel ninv * n
     and reassemble f from its coefficients. *)
  show ?thesis unfolding inv_ntt_poly_def ntt_poly_def inv_ntt_coeffs_def ntt_coeffs_def
    inv_ntt_coeff_poly_def ntt_coeff_poly_def of_qr_to_qr map_upto_n_mod coeff_Poly
    apply (subst rew_sum) 
    apply (subst sum_distrib_right) 
    apply (subst rew_coeff)
    apply (subst mult.assoc[symmetric]) 
    apply (subst n_ninv')
    apply (subst mult_1)
    apply (subst Poly_map_coeff)
    subgoal using deg_of_qr deg_qr_n by fastforce
    subgoal unfolding to_qr_of_qr by auto
  done
qed


(* Correctness in the other direction: applying ntt_poly to inv_ntt_poly f
   gives back f.

   The proof mirrors inv_ntt_poly_correct with the roles of ψ and ψinv
   (and of ω and μ) exchanged:
   1. rew_sum    - replace the nth_default lookup into the mapped list of
                   inverse coefficients by the inverse sum itself.
   2. main fact  - for i < n the double sum equals
                   ninv * (of_int_mod_ring n * coeff f i); after swapping the
                   sums the inner sum is a geometric sum that is n for
                   j' = i and 0 otherwise.
   3. rew_coeff  - lift the main fact to the coefficient lists.
   4. conclusion - unfold the definitions, cancel n against ninv (n_ninv')
                   and rebuild f from its coefficient list. *)
lemma ntt_inv_poly_correct:
"ntt_poly (inv_ntt_poly f) = f"
proof -
  (* Step 1: for j < n the j-th entry of the mapped list is the inverse sum. *)
  have rew_sum: "(∑j = 0..<n. nth_default 0 (map (λi. ninv *
    (∑j' = 0..<n. poly.coeff (of_qr f) j' * ψinv ^ (i * (2 * j' + 1)))) [0..<n]) j *
      ψ ^ (j * (2 * i + 1))) = 
    (∑j = 0..<n. ninv * (∑j' = 0..<n. poly.coeff (of_qr f) j' * ψinv ^ (j * (2 * j' + 1)))
     * ψ ^ (j * (2 * i + 1)))" 
    (is "(∑j = 0..<n. ?left j) = (∑j = 0..<n. ?right j)") for i
  proof (subst sum.cong[of "{0..<n}" "{0..<n}" ?left ?right], goal_cases)
    case (2 x)
    then show ?case by (subst nth_default_map[of x n], auto)
  qed auto
  (* Step 2: the double sum collapses to ninv * (n * i-th coefficient).
     The inner bound variable n shadows the ring degree n only inside the
     summand; the summation range still refers to the degree. *)
  have  "(∑j = 0..<n. ∑n = 0..<n. ninv * (poly.coeff (of_qr f) n *
    ψinv ^ (j * (2 * n + 1))) * ψ ^ (j * (2 * i + 1))) = 
    ninv * (of_int_mod_ring (int n) * poly.coeff (of_qr f) i)" if "i<n" for i
  proof -
    (* Combine the two powers into a single power with exponent j'. *)
    have rew_psi: "ψinv ^ (j' * (2 * j + 1)) * ψ ^ (j' * (2 * i + 1)) =
      (ψinv ^ (j * 2) * ψ ^ (i * 2)) ^ j'"
      if "j'<n" "j<n" for j' j
    proof -
      have "ψinv ^ (j' * (2 * j + 1)) * ψ ^ (j' * (2 * i + 1)) =
        ψinv ^ (j' * (2 * j)) * ψ ^ (j' * (2 * i)) * ψinv ^ j'  * ψ ^ j' "
      by (simp add: power_add)
      (* The trailing ψinv^j' * ψ^j' cancels by inv_psi_exp. *)
      also have "… = (ψinv ^ (2 * j) * ψ ^ (2 * i))^ j'"
      by (smt (verit, best) inv_psi_exp kyber_ntt.exp_rule mult.assoc 
        mult.commute mult.right_neutral power_mult) 
      also have "… = (ψinv ^ (j * 2) * ψ ^ (i * 2)) ^ j'"
      by (simp add: mult.commute) 
      finally show ?thesis by blast
    qed
    (* Swap the order of summation and pull the constant factors out of the
       inner sum, leaving a geometric sum over j. *)
    have "(∑j = 0..<n. ∑j' = 0..<n. ninv * (poly.coeff (of_qr f) j' *
    ψinv ^ (j * (2 * j' + 1))) * ψ ^ (j * (2 * i + 1))) = 
      (∑j' = 0..<n. ninv * poly.coeff (of_qr f) j' * 
        (∑j = 0..<n. (ψinv ^ (j' * 2) * ψ ^ (i * 2))^j))"
    apply (subst sum_distrib_left, subst sum.swap, unfold mult.assoc[symmetric])
    proof (subst sum.cong[of "{0..<n}" "{0..<n}"
      "(λj. ∑ia = 0..<n. ninv * poly.coeff (of_qr f) j * ψinv ^ (ia * (2 * j + 1)) *
           ψ ^ (ia * (2 * i + 1)))"
      "(λj. ∑n = 0..<n. ninv * poly.coeff (of_qr f) j *
           (ψinv ^ (j * 2) * ψ ^ (i * 2)) ^ n)"], goal_cases)
      case (2 j)
      then show ?case 
      proof (subst sum.cong[of "{0..<n}" "{0..<n}"
      "(λia. ninv * poly.coeff (of_qr f) j * ψinv ^ (ia * (2 * j + 1)) *
        ψ ^ (ia * (2 * i + 1)))"
      "(λia. ninv * poly.coeff (of_qr f) j *
        (ψinv ^ (j * 2) * ψ ^ (i * 2)) ^ ia)"], goal_cases)
        case (2 j')
        then show ?case using rew_psi[of j' j] by simp
      qed auto
    qed auto
    (* Evaluate the geometric sum: n if j' = i, 0 otherwise. *)
    also have "… = (∑j' = 0..<n. 
      (if j' = i then ninv * poly.coeff (of_qr f) j' * 
        ψinv^j' * ψ^i * (of_int_mod_ring n) else 0))" 
    (is "(∑j' = 0..<n. ?right j') = (∑j' = 0..<n. ?left j')")
    proof (subst sum.cong[of "{0..<n}" "{0..<n}" "?right" "?left"], goal_cases)
      case (2 j')
      then show ?case proof (cases "j' = i")
        case True
        (* Base of the geometric sum is ψinv^(2i) * ψ^(2i) = 1, so the sum is n. *)
        then have "(∑j=0..<n. (ψinv ^ (j' * 2) * ψ ^ (i * 2))^j) = of_int_mod_ring n"
          unfolding True psi_inv_exp 
          by (metis kyber_ntt.sum_const mult.commute mult.right_neutral 
          power_one psi_inv_exp sum.cong) 
        then show ?thesis using True
        by (simp add: inv_psi_exp)
      next
        case False
        (* The base of the geometric sum differs from 1; this is reduced to
           μ^j' * ω^i ≠ 1 (last step below, via psi_properties(1) and
           psiinv_prop) and then to the minimality of n for ω and μ. *)
        have not1: "ψinv ^ (j' * 2) * ψ ^ (i * 2) ≠ 1"
        proof -
          have "μ^j' * ω ^i ≠ 1" 
          proof (cases "j'<i")
            case True
            (* j' < i: the product is a power of ω with exponent 0 < i - j' < n. *)
            have *: "μ^j' * ω ^i = ω^(i-j')" using True
            by (smt (verit, best) add.commute kyber_ntt.omega_properties(1) 
              le_add_diff_inverse left_right_inverse_power less_or_eq_imp_le 
              mu_properties(1) mult.left_commute mult_cancel_right1 power_add)
            show ?thesis proof (unfold *, rule ccontr)
              assume "¬ ω ^ (i - j') ≠ 1" 
              then have 1: "ω ^ (i - j') = 1" by auto
              show False using omega_prop'[OF 1] ‹j'≠i›
              using True less_imp_diff_less that diff_is_0_eq leD by blast
            qed
          next
            case False
            (* j' > i: the product is a power of μ with exponent 0 < j' - i < n. *)
            have 2: "μ^j' * ω ^i = μ^(j'-i)" using False
            by (smt (verit) Nat.add_diff_assoc ab_semigroup_mult_class.mult_ac(1) 
              add_diff_cancel_left' left_right_inverse_power linorder_not_less 
              mu_properties(1) mult.commute mult_numeral_1_right numeral_One power_add) 
            show ?thesis proof (unfold 2, rule ccontr)
              assume "¬ μ ^ (j' - i) ≠ 1" 
              then have 1: "μ ^ (j' - i) = 1" by auto
              have "n > j' - i" using ‹j' ∈ {0..<n}› by auto
              then show False using mu_prop'[OF 1] ‹j'≠i›
              using False 
              by (meson diff_is_0_eq leD order_le_imp_less_or_eq)
            qed
          qed
          then show ?thesis
          by (metis mult.commute power_mult psi_properties(1) psiinv_prop) 
        qed
        (* (1 - x) * (geometric sum of x) = 0: after rewriting with
           kyber_ntt.geo_sum the remaining goal is discharged using ω^n = 1
           and μ * ω = 1. *)
        have "(1 - ψinv ^ (j' * 2) * ψ ^ (i * 2)) * 
          (∑j=0..<n. (ψinv ^ (j' * 2) * ψ ^ (i * 2))^j) = 0"
        proof (subst kyber_ntt.geo_sum, goal_cases)
          case 1
          then show ?case using not1 by auto
        next
          case 2
          then show ?case 
          by (metis (no_types, opaque_lifting) cancel_comm_monoid_add_class.diff_cancel 
            mu_properties(1) mult.commute omega_properties(1) power_mult power_mult_distrib 
            power_one psi_properties(1) psiinv_prop)
        qed
        (* Since 1 - x ≠ 0, the geometric sum itself must vanish. *)
        then have "(∑j=0..<n. (ψinv ^ (j' * 2) * ψ ^ (i * 2))^j) = 0" 
          using not1 by auto
        then show ?thesis using False by auto
      qed
    qed auto
    (* Only the summand j' = i survives. *)
    also have "… = ninv * poly.coeff (of_qr f) i * ψinv^i * ψ^i * (of_int_mod_ring n)"
      by (subst sum.delta[of "{0..<n}" i], use ‹i<n› in auto)
    (* ψinv^i * ψ^i = 1, and the remaining factors are reordered. *)
    also have "… = ninv * ((of_int_mod_ring n) * poly.coeff (of_qr f) i)"
      by (simp add: psi_inv_exp mult.commute)
    finally show ?thesis by blast
  qed
  (* Step 3: the same identity for the whole list of coefficients. *)
  then have rew_coeff: "(map (λi. ∑j = 0..<n. ∑n = 0..<n. ninv * (poly.coeff (of_qr f) n *
    ψinv ^ (j * (2 * n + 1))) * ψ ^ (j * (2 * i + 1))) [0..<n]) = 
    map (λi. ninv * (of_int_mod_ring (int n) * poly.coeff (of_qr f) i)) [0..<n]"
  unfolding map_eq_conv by auto
  (* Step 4: unfold both transforms, apply the rewrites, cancel ninv * n
     and reassemble f from its coefficients. *)
  show ?thesis unfolding inv_ntt_poly_def ntt_poly_def inv_ntt_coeffs_def ntt_coeffs_def
    inv_ntt_coeff_poly_def ntt_coeff_poly_def of_qr_to_qr map_upto_n_mod coeff_Poly
    apply (subst rew_sum)
    apply (subst sum_distrib_left)
    apply (subst sum_distrib_right) 
    apply (subst rew_coeff)
    apply (subst mult.assoc[symmetric]) 
    apply (subst n_ninv')
    apply (subst mult_1)
    apply (subst Poly_map_coeff)
    subgoal using deg_of_qr deg_qr_n by fastforce
    subgoal unfolding to_qr_of_qr by auto
  done
qed

text ‹The multiplication of two polynomials can be computed by the NTT.›

(* Multiplication via the NTT: the product f*g is obtained by transforming
   both factors, multiplying coefficient-wise with qr_mult_coeffs and
   transforming back.
   Proof: ntt_mult (read right-to-left) folds the coefficient-wise product
   into ntt_poly (f*g), and inv_ntt_poly_correct cancels the two transforms. *)
lemma convolution_thm_ntt_poly:
  "f*g = inv_ntt_poly (qr_mult_coeffs (ntt_poly f) (ntt_poly g))"
unfolding ntt_mult[symmetric] inv_ntt_poly_correct by auto



end
end