# Theory Crypto_Scheme_NTT

theory Crypto_Scheme_NTT

imports Crypto_Scheme
NTT_Scheme

begin

section ‹Kyber Algorithm using NTT for Fast Multiplication›

(* Hide the type name Matrix.vec so that the unqualified type constructor vec
   used below does not resolve to it.
   NOTE(review): presumably this is to make vec refer to the index-typed vector
   type accessed via vec_nth / vec_lambda in this theory -- confirm against the
   imported theories. *)
hide_type Matrix.vec

context kyber_ntt
begin

(* NTT-based multiplication on 'a qr: both factors are transformed with
   ntt_poly, the transformed values are combined with ⋆, and the result is
   mapped back with inv_ntt_poly.
   Introduces the left-associative infix notation *⇘ntt⇙ at priority 70.
   NOTE(review): ⋆ is presumably the componentwise product in the NTT domain;
   its definition is not visible here -- confirm in NTT_Scheme. *)
definition mult_ntt:: "'a qr ⇒ 'a qr ⇒ 'a qr" (infixl "*⇘ntt⇙" 70) where
"mult_ntt f g = inv_ntt_poly (ntt_poly f ⋆ ntt_poly g)"

(* The ordinary product on 'a qr coincides with the NTT-based product.
   Proved by unfolding mult_ntt_def and using convolution_thm_ntt_poly.
   The equation is oriented from the plain product to *⇘ntt⇙, so unfolding this
   lemma replaces plain products by NTT products; use [symmetric] for the
   opposite direction.
   Note: this lemma shares its name with the constant; the defining equation
   itself is mult_ntt_def. *)
lemma mult_ntt:
"f*g = f *⇘ntt⇙ g"
unfolding mult_ntt_def using convolution_thm_ntt_poly by auto

(* Scalar product of two vectors over 'a qr indexed by the type 'k, using the
   NTT-based multiplication: the sum, over every index i in UNIV, of the product
   of the i-th component of v with the i-th component of w.
   Introduces the left-associative infix notation ∙⇘ntt⇙ at priority 70. *)
definition scalar_prod_ntt::
"('a qr, 'k) vec ⇒ ('a qr, 'k) vec ⇒ 'a qr" (infixl "∙⇘ntt⇙" 70) where
"scalar_prod_ntt v w =
(∑i∈(UNIV::'k set). (vec_nth v i) *⇘ntt⇙ (vec_nth w i))"

(* scalar_product and its NTT-based counterpart scalar_prod_ntt agree.
   Proved by unfolding both definitions and rewriting the component products
   with lemma mult_ntt.
   The equation is oriented from scalar_product to scalar_prod_ntt; use
   [symmetric] to rewrite back to scalar_product. *)
lemma scalar_prod_ntt:
"scalar_product v w = scalar_prod_ntt v w"
unfolding scalar_product_def scalar_prod_ntt_def using mult_ntt by auto

(* Matrix-vector product using the NTT-based multiplication.
   The matrix A is given as a vector of vectors. Component i of the result is
   the sum, over every index j, of the product of entry j of the i-th vector of
   A with entry j of v.
   Introduces the left-associative infix notation ⋅⇘ntt⇙ at priority 70. *)
definition mat_vec_mult_ntt::
"(('a qr, 'k) vec, 'k) vec ⇒ ('a qr, 'k) vec ⇒ ('a qr, 'k) vec" (infixl "⋅⇘ntt⇙" 70) where
"mat_vec_mult_ntt A v = vec_lambda (λi.
(∑j∈UNIV. (vec_nth (vec_nth A i) j) *⇘ntt⇙ (vec_nth v j)))"

(* The matrix-vector product A *v v agrees with mat_vec_mult_ntt A v.
   Proved by unfolding matrix_vector_mult_def and mat_vec_mult_ntt_def and
   rewriting the entry products with lemma mult_ntt.
   The equation is oriented from the plain product to the NTT version; use
   [symmetric] to rewrite in the opposite direction. *)
lemma mat_vec_mult_ntt:
"A *v v = mat_vec_mult_ntt A v"
unfolding matrix_vector_mult_def mat_vec_mult_ntt_def using mult_ntt by auto

text ‹Refined algorithm using NTT for multiplications›

(* Key generation with NTT-based multiplication.
   Arguments: dt (a nat passed to compress_vec), the matrix A, and the vectors
   s and e.
   Result: compress_vec dt applied to A ⋅⇘ntt⇙ s + e.
   NOTE(review): presumably s is the secret key, e the error vector and dt the
   compression depth of the public key -- confirm against Crypto_Scheme. *)
definition key_gen_ntt ::
"nat ⇒ (('a qr, 'k) vec, 'k) vec ⇒ ('a qr, 'k) vec ⇒
('a qr, 'k) vec ⇒ ('a qr, 'k) vec" where
"key_gen_ntt dt A s e = compress_vec dt (A ⋅⇘ntt⇙ s + e)"

(* The NTT-based key generation equals key_gen on all arguments.
   Proved by unfolding key_gen_ntt_def and key_gen_def; unfolding lemma
   mat_vec_mult_ntt turns the plain matrix-vector product into its NTT form so
   that both sides coincide. *)
lemma key_gen_ntt:
"key_gen_ntt dt A s e = key_gen dt A s e"
unfolding key_gen_ntt_def key_gen_def mat_vec_mult_ntt by auto

(* Encryption with NTT-based multiplication.
   Arguments, in order: t, A, r, e1, e2, dt, du, dv, m.
   Result: a pair whose
     - first component is compress_vec du applied to
       transpose A ⋅⇘ntt⇙ r + e1, and whose
     - second component is compress_poly dv applied to
       decompress_vec dt t ∙⇘ntt⇙ r + e2 + c *⇘ntt⇙ m,
       where c = to_module (round (real_of_int q / 2)).
   The NTT operators have priority 70 and therefore bind tighter than + in the
   sums above.
   NOTE(review): presumably t is the compressed public key, r/e1/e2 the
   encryption randomness and errors, and m the message -- confirm against
   Crypto_Scheme. *)
definition encrypt_ntt ::
"('a qr, 'k) vec ⇒ (('a qr, 'k) vec, 'k) vec ⇒
('a qr, 'k) vec ⇒ ('a qr, 'k) vec ⇒ ('a qr) ⇒
nat ⇒ nat ⇒ nat ⇒ 'a qr ⇒
(('a qr, 'k) vec) * ('a qr)" where
"encrypt_ntt t A r e1 e2 dt du dv m =
(compress_vec du ((transpose A) ⋅⇘ntt⇙ r + e1),
compress_poly dv ((decompress_vec dt t) ∙⇘ntt⇙ r +
e2 + to_module (round((real_of_int q)/2)) *⇘ntt⇙ m)) "

(* The NTT-based encryption equals encrypt on all arguments.
   Proved by unfolding encrypt_ntt_def and encrypt_def; unfolding the lemmas
   mat_vec_mult_ntt, scalar_prod_ntt and mult_ntt rewrites the plain products
   into their NTT forms so that both sides coincide. *)
lemma encrypt_ntt:
"encrypt_ntt t A r e1 e2 dt du dv m = encrypt t A r e1 e2 dt du dv m"
unfolding encrypt_ntt_def encrypt_def mat_vec_mult_ntt scalar_prod_ntt mult_ntt by auto

(* Decryption with NTT-based multiplication.
   Arguments, in order: u, v, s, du, dv.
   Result: compress_poly 1 applied to
     decompress_poly dv v - s ∙⇘ntt⇙ decompress_vec du u.
   The scalar product has priority 70 and therefore binds tighter than the
   subtraction.
   NOTE(review): presumably (u, v) is the ciphertext pair and s the secret key
   -- confirm against Crypto_Scheme. *)
definition decrypt_ntt ::
"('a qr, 'k) vec ⇒ ('a qr) ⇒ ('a qr, 'k) vec ⇒
nat ⇒ nat ⇒ 'a qr" where
"decrypt_ntt u v s du dv = compress_poly 1 ((decompress_poly dv v) -
s ∙⇘ntt⇙ (decompress_vec du u))"

(* The NTT-based decryption equals decrypt on all arguments.
   Proved by unfolding decrypt_ntt_def and decrypt_def; unfolding lemma
   scalar_prod_ntt rewrites the plain scalar product into its NTT form so that
   both sides coincide. *)
lemma decrypt_ntt:
"decrypt_ntt u v s du dv = decrypt u v s du dv"
unfolding decrypt_ntt_def decrypt_def scalar_prod_ntt by auto

text ‹$(1-\delta)$-correctness for the refined algorithm›

(* Correctness of the NTT-based scheme: decrypting an encryption of m yields m.
   Assumptions:
     t_def    t is the output of key_gen_ntt dt A s e.
     u_v_def  the pair u, v is the output of encrypt_ntt for message m.
     ct_def, cu_def, cv_def
              ct, cu and cv are the values of compress_error_vec /
              compress_error_poly on the same terms that key generation and
              encryption compress.
     delta    abs_infty_poly of the combined error term is strictly below
              round (real_of_int q / 4).
     m01      every coefficient of m lies in the set containing 0 and 1.
   Proof: the equivalence lemmas key_gen_ntt, encrypt_ntt and decrypt_ntt turn
   the NTT algorithms into the plain ones, and the [symmetric] versions of
   mat_vec_mult_ntt, scalar_prod_ntt and mult_ntt turn the remaining NTT
   products in the assumptions back into plain products; the statement then
   follows from kyber_correct. *)
lemma kyber_correct_ntt:
fixes A s r e e1 e2 dt du dv ct cu cv t u v
assumes
t_def:   "t = key_gen_ntt dt A s e"
and u_v_def: "(u,v) = encrypt_ntt t A r e1 e2 dt du dv m"
and ct_def:  "ct = compress_error_vec dt (A ⋅⇘ntt⇙ s + e)"
and cu_def:  "cu = compress_error_vec du
((transpose A) ⋅⇘ntt⇙ r + e1)"
and cv_def:  "cv = compress_error_poly dv
((decompress_vec dt t) ∙⇘ntt⇙ r + e2 +
to_module (round((real_of_int q)/2)) *⇘ntt⇙ m)"
and delta:   "abs_infty_poly (e ∙⇘ntt⇙ r + e2 + cv -
s ∙⇘ntt⇙ e1 + ct ∙⇘ntt⇙ r -
s ∙⇘ntt⇙ cu) < round (real_of_int q / 4)"
and m01:     "set ((coeffs ∘ of_qr) m) ⊆ {0,1}"
shows "decrypt_ntt u v s du dv = m"
using assms unfolding key_gen_ntt encrypt_ntt decrypt_ntt mat_vec_mult_ntt[symmetric]
scalar_prod_ntt[symmetric] mult_ntt[symmetric] using kyber_correct by auto

end
end