(* Theory Crypto_Scheme_NTT *)
theory Crypto_Scheme_NTT
imports Crypto_Scheme
NTT_Scheme
begin
section ‹Kyber Algorithm using NTT for Fast Multiplication›
hide_type Matrix.vec
context kyber_ntt
begin
(* Multiplication in 'a qr computed through the number theoretic transform.
   Both operands are mapped with ntt_poly and combined with the operator ⋆.
   NOTE(review): ⋆ is presumably the coefficient-wise product in the NTT
   domain; it is defined outside this theory, so confirm in NTT_Scheme.
   The result is mapped back with inv_ntt_poly.
   Notation: f *⇘ntt⇙ g, left-associative, priority 70. *)
definition mult_ntt:: "'a qr ⇒ 'a qr ⇒ 'a qr" (infixl ‹*⇘ntt⇙› 70) where
"mult_ntt f g = inv_ntt_poly (ntt_poly f ⋆ ntt_poly g)"
(* Correctness of NTT multiplication: the ordinary ring product on 'a qr
   coincides with the NTT-based product mult_ntt.
   The proof unfolds mult_ntt_def and uses convolution_thm_ntt_poly, which is
   not proved in this theory.
   The lemma is used by the equivalence lemmas below to switch between
   the two product forms, in both directions via [symmetric]. *)
lemma mult_ntt:
"f*g = f *⇘ntt⇙ g"
unfolding mult_ntt_def using convolution_thm_ntt_poly by auto
(* Scalar (inner) product of two 'k-indexed vectors over 'a qr.
   Each component-wise product is taken with the NTT multiplication
   *⇘ntt⇙, and the products are summed over all indices i of type 'k.
   Notation: v ∙⇘ntt⇙ w, left-associative, priority 70. *)
definition scalar_prod_ntt::
"('a qr, 'k) vec ⇒ ('a qr, 'k) vec ⇒ 'a qr" (infixl ‹∙⇘ntt⇙› 70) where
"scalar_prod_ntt v w =
(∑i∈(UNIV::'k set). (vec_nth v i) *⇘ntt⇙ (vec_nth w i))"
(* The NTT-based scalar product agrees with scalar_product.
   The proof unfolds both definitions and rewrites each summand with the
   lemma mult_ntt. *)
lemma scalar_prod_ntt:
"scalar_product v w = scalar_prod_ntt v w"
unfolding scalar_product_def scalar_prod_ntt_def using mult_ntt by auto
(* Matrix-vector product over 'a qr using NTT multiplication.
   A is a 'k-by-'k matrix represented as a vector of row vectors.
   Entry i of the result is the sum over all j of
   (vec_nth (vec_nth A i) j) *⇘ntt⇙ (vec_nth v j).
   Notation: A ⋅⇘ntt⇙ v, left-associative, priority 70. *)
definition mat_vec_mult_ntt::
"(('a qr, 'k) vec, 'k) vec ⇒ ('a qr, 'k) vec ⇒ ('a qr, 'k) vec" (infixl ‹⋅⇘ntt⇙› 70) where
"mat_vec_mult_ntt A v = vec_lambda (λi.
(∑j∈UNIV. (vec_nth (vec_nth A i) j) *⇘ntt⇙ (vec_nth v j)))"
(* The NTT-based matrix-vector product agrees with the ordinary
   matrix_vector_mult, written A *v v.
   The proof unfolds both definitions and rewrites each product with the
   lemma mult_ntt. *)
lemma mat_vec_mult_ntt:
"A *v v = mat_vec_mult_ntt A v"
unfolding matrix_vector_mult_def mat_vec_mult_ntt_def using mult_ntt by auto
text ‹Refined algorithm using NTT for multiplications›
(* Key generation with NTT multiplication.
   Computes A ⋅⇘ntt⇙ s + e and compresses it with compress_vec, using the
   compression parameter dt.
   Arguments: dt, the matrix A, and the vectors s and e, in that order. *)
definition key_gen_ntt ::
"nat ⇒ (('a qr, 'k) vec, 'k) vec ⇒ ('a qr, 'k) vec ⇒
('a qr, 'k) vec ⇒ ('a qr, 'k) vec" where
"key_gen_ntt dt A s e = compress_vec dt (A ⋅⇘ntt⇙ s + e)"
(* key_gen_ntt computes the same value as key_gen.
   The proof unfolds both definitions and replaces the ordinary
   matrix-vector product via the lemma mat_vec_mult_ntt. *)
lemma key_gen_ntt:
"key_gen_ntt dt A s e = key_gen dt A s e"
unfolding key_gen_ntt_def key_gen_def mat_vec_mult_ntt by auto
(* Encryption with NTT multiplication. It returns a pair (u, v):
   - u is compress_vec du applied to (transpose A) ⋅⇘ntt⇙ r + e1;
   - v is compress_poly dv applied to the sum of
     (decompress_vec dt t) ∙⇘ntt⇙ r, e2, and the message m multiplied
     (via *⇘ntt⇙) by the constant to_module (round (q / 2)).
   Arguments, in order: t, A, r, e1, e2, dt, du, dv, m. *)
definition encrypt_ntt ::
"('a qr, 'k) vec ⇒ (('a qr, 'k) vec, 'k) vec ⇒
('a qr, 'k) vec ⇒ ('a qr, 'k) vec ⇒ ('a qr) ⇒
nat ⇒ nat ⇒ nat ⇒ 'a qr ⇒
(('a qr, 'k) vec) * ('a qr)" where
"encrypt_ntt t A r e1 e2 dt du dv m =
(compress_vec du ((transpose A) ⋅⇘ntt⇙ r + e1),
compress_poly dv ((decompress_vec dt t) ∙⇘ntt⇙ r +
e2 + to_module (round((real_of_int q)/2)) *⇘ntt⇙ m)) "
(* encrypt_ntt computes the same pair as encrypt.
   encrypt_ntt uses all three NTT operations, so the proof rewrites with the
   matrix-vector, scalar-product and multiplication equivalence lemmas. *)
lemma encrypt_ntt:
"encrypt_ntt t A r e1 e2 dt du dv m = encrypt t A r e1 e2 dt du dv m"
unfolding encrypt_ntt_def encrypt_def mat_vec_mult_ntt scalar_prod_ntt mult_ntt by auto
(* Decryption with NTT multiplication.
   It computes (decompress_poly dv v) - s ∙⇘ntt⇙ (decompress_vec du u) and
   compresses the result with compress_poly 1.
   Arguments, in order: u, v, s, du, dv. *)
definition decrypt_ntt ::
"('a qr, 'k) vec ⇒ ('a qr) ⇒ ('a qr, 'k) vec ⇒
nat ⇒ nat ⇒ 'a qr" where
"decrypt_ntt u v s du dv = compress_poly 1 ((decompress_poly dv v) -
s ∙⇘ntt⇙ (decompress_vec du u))"
(* decrypt_ntt computes the same value as decrypt.
   decrypt_ntt only uses the NTT scalar product, so rewriting with the lemma
   scalar_prod_ntt suffices. *)
lemma decrypt_ntt:
"decrypt_ntt u v s du dv = decrypt u v s du dv"
unfolding decrypt_ntt_def decrypt_def scalar_prod_ntt by auto
text ‹$(1-\delta)$-correctness for the refined algorithm›
(* Correctness of the NTT-refined scheme: if the combined error term
   (assumption delta) has abs_infty_poly below round (q / 4), and the
   coefficients of m lie in {0,1} (assumption m01), then decrypting the
   ciphertext (u, v) recovers m.
   The assumptions mirror those of kyber_correct, with every product
   written in its NTT form.
   Proof sketch: unfold the NTT algorithms into their reference versions,
   then turn each NTT operation back into the ordinary one using the
   equivalence lemmas with [symmetric]. The goal then follows from
   kyber_correct.
   NOTE(review): m is not listed under fixes; as a free variable of the
   statement it is still generalized in the resulting theorem. *)
lemma kyber_correct_ntt:
fixes A s r e e1 e2 dt du dv ct cu cv t u v
assumes
t_def: "t = key_gen_ntt dt A s e"
and u_v_def: "(u,v) = encrypt_ntt t A r e1 e2 dt du dv m"
and ct_def: "ct = compress_error_vec dt (A ⋅⇘ntt⇙ s + e)"
and cu_def: "cu = compress_error_vec du
((transpose A) ⋅⇘ntt⇙ r + e1)"
and cv_def: "cv = compress_error_poly dv
((decompress_vec dt t) ∙⇘ntt⇙ r + e2 +
to_module (round((real_of_int q)/2)) *⇘ntt⇙ m)"
and delta: "abs_infty_poly (e ∙⇘ntt⇙ r + e2 + cv -
s ∙⇘ntt⇙ e1 + ct ∙⇘ntt⇙ r -
s ∙⇘ntt⇙ cu) < round (real_of_int q / 4)"
and m01: "set ((coeffs ∘ of_qr) m) ⊆ {0,1}"
shows "decrypt_ntt u v s du dv = m"
using assms unfolding key_gen_ntt encrypt_ntt decrypt_ntt mat_vec_mult_ntt[symmetric]
scalar_prod_ntt[symmetric] mult_ntt[symmetric] using kyber_correct by auto
end
end