(* Theory CRYSTALS-Kyber.Crypto_Scheme *)
theory Crypto_Scheme
imports Kyber_spec
Compress
Abs_Qr
begin
section ‹$(1-\delta)$-Correctness Proof of the Kyber Crypto Scheme›
context kyber_spec
begin
text ‹In the following the key generation, encryption and decryption algorithms
of Kyber are stated. Here, the variables have the meaning:
\begin{itemize}
\item $A$: matrix, part of Alice's public key
\item $s$: vector, Alice's secret key
\item $t$: the key generated by Alice from $A$ and $s$ in ‹key_gen›
\item $r$: Bob's "secret" key, randomly picked vector
\item $m$: message bits, $m\in \{0,1\}^{256}$
\item $(u,v)$: encrypted message
\item $dt$, $du$, $dv$: the compression parameters for $t$, $u$ and $v$ respectively.
Notice that ‹0 < d < ⌈log_2 q⌉›. The $d$ values are public knowledge.
\item $e$, $e1$ and $e2$: error parameters to obscure the message.
We need to make certain that an eavesdropper cannot distinguish
the encrypted message from uniformly random input.
Notice that $e$ and $e1$ are vectors while $e2$ is a mere element in ‹ℤ_q[X]/(X^n+1).›
\end{itemize}
›
(* Key generation.
   Arguments: dt (compression parameter), A (matrix), s (vector),
   e (error vector).
   Result: the matrix-vector product A *v s plus e, compressed
   component-wise with compress_vec at parameter dt. Per the variable
   list in the preceding text block, this is the value t that Alice
   derives from A and her secret key s. *)
definition key_gen ::
"nat ⇒ (('a qr, 'k) vec, 'k) vec ⇒ ('a qr, 'k) vec ⇒
('a qr, 'k) vec ⇒ ('a qr, 'k) vec" where
"key_gen dt A s e = compress_vec dt (A *v s + e)"
(* Encryption.
   Arguments: t (output of key generation), A (matrix), r (vector),
   e1 (error vector), e2 (error element of the quotient ring),
   dt, du, dv (compression parameters), m (message).
   Result: a pair whose
     - first component is transpose A *v r + e1, compressed with
       compress_vec at parameter du;
     - second component is the scalar product of the decompressed t with r,
       plus e2, plus the message m multiplied by the ring element
       to_module (round (q/2)), compressed with compress_poly at
       parameter dv.
   Note that t is decompressed with dt before use, so dt must be the same
   parameter that was used when t was compressed. *)
definition encrypt ::
"('a qr, 'k) vec ⇒ (('a qr, 'k) vec, 'k) vec ⇒
('a qr, 'k) vec ⇒ ('a qr, 'k) vec ⇒ ('a qr) ⇒
nat ⇒ nat ⇒ nat ⇒ 'a qr ⇒
(('a qr, 'k) vec) * ('a qr)" where
"encrypt t A r e1 e2 dt du dv m =
(compress_vec du ((transpose A) *v r + e1),
compress_poly dv (scalar_product (decompress_vec dt t) r +
e2 + to_module (round((real_of_int q)/2)) * m)) "
(* Decryption.
   Arguments: u, v (the two ciphertext components), s (vector used as the
   secret key), du, dv (compression parameters for u and v).
   Result: decompresses v with dv and u with du, subtracts the scalar
   product of s with the decompressed u from the decompressed v, and
   compresses the difference with compress_poly at parameter 1. *)
definition decrypt ::
"('a qr, 'k) vec ⇒ ('a qr) ⇒ ('a qr, 'k) vec ⇒
nat ⇒ nat ⇒ 'a qr" where
"decrypt u v s du dv = compress_poly 1 ((decompress_poly dv v) -
scalar_product s (decompress_vec du u))"
text ‹Lifting a function to the quotient ring›
(* Lifts an integer function f to a function on the quotient ring.
   The composition is read right to left:
     of_qr                 - take the polynomial representative,
     coeffs                - take its coefficient list,
     map to_int_mod_ring   - convert each coefficient to an integer,
     map f                 - apply f to each integer,
     map of_int_mod_ring   - convert the results back to the mod ring,
     Poly                  - rebuild a polynomial from the list,
     to_qr                 - map it back into the quotient ring.
   f is applied only to the entries of the list returned by coeffs. *)
fun f_int_to_poly :: "(int ⇒ int) ⇒ ('a qr) ⇒ ('a qr)" where
"f_int_to_poly f =
to_qr ∘
Poly ∘
(map of_int_mod_ring) ∘
(map f) ∘
(map to_int_mod_ring) ∘
coeffs ∘
of_qr"
text ‹Error of compression and decompression.›
(* Compression error of a ring element y at parameter d: the difference
   between the round trip decompress_poly d (compress_poly d y) and y
   itself. Equivalently, the round trip equals y plus this error. *)
definition compress_error_poly ::
"nat ⇒ 'a qr ⇒ 'a qr" where
"compress_error_poly d y =
decompress_poly d (compress_poly d y) - y"
(* Vector counterpart of compress_error_poly: the difference between the
   round trip decompress_vec d (compress_vec d y) and the vector y. *)
definition compress_error_vec ::
"nat ⇒ ('a qr, 'k) vec ⇒ ('a qr, 'k) vec" where
"compress_error_vec d y =
decompress_vec d (compress_vec d y) - y"
text ‹Lemmas for scalar product›
(* Additivity of scalar_product in its first argument.
   Proof: unfold the definition and finish with right distributivity
   and the splitting of a sum over an addition (sum.distrib). *)
lemma scalar_product_linear_left:
"scalar_product (a+b) c =
scalar_product a c + scalar_product b (c :: ('a qr, 'k) vec)"
unfolding scalar_product_def
by auto (metis (no_types, lifting) distrib_right sum.cong sum.distrib)
(* Additivity of scalar_product in its second argument.
   Proof: mirrors scalar_product_linear_left, using left distributivity
   instead of right distributivity. *)
lemma scalar_product_linear_right:
"scalar_product a (b+c) =
scalar_product a b + scalar_product a (c :: ('a qr, 'k) vec)"
unfolding scalar_product_def
by auto (metis (no_types, lifting) distrib_left sum.cong sum.distrib)
(* Moves the matrix from one side of the scalar product to the other:
   the scalar product of A *v s with r equals the scalar product of s
   with r v* A.
   Proof: after unfolding scalar_product, matrix_vector_mult and
   vector_matrix_mult, both sides are double sums over the index type;
   the calculation below rearranges one into the other. *)
lemma scalar_product_assoc:
"scalar_product (A *v s) (r :: ('a qr, 'k) vec ) =
scalar_product s (r v* A)"
unfolding scalar_product_def matrix_vector_mult_def
vector_matrix_mult_def
proof auto
(* Step 1: push the outer factor r i into the inner sum over j. *)
have "(∑i∈UNIV. (∑j∈UNIV. (vec_nth (vec_nth A i) j) *
(vec_nth s j)) * (vec_nth r i)) =
(∑i∈UNIV. (∑j∈UNIV. (vec_nth (vec_nth A i) j) *
(vec_nth s j) * (vec_nth r i)))"
by (simp add: sum_distrib_right)
(* Step 2: exchange the order of the two summations. *)
also have "… = (∑j∈UNIV. (∑i∈UNIV. (vec_nth (vec_nth A i) j) *
(vec_nth s j) * (vec_nth r i)))"
using sum.swap .
(* Step 3: commute the factors so that s j comes first in each summand. *)
also have "… = (∑j∈UNIV. (∑i∈UNIV. (vec_nth s j) *
(vec_nth (vec_nth A i) j) * (vec_nth r i)))"
by (metis (no_types, lifting) mult_commute_abs sum.cong)
(* Step 4: pull s j, which does not depend on i, out of the inner sum. *)
also have "… = (∑j∈UNIV. (vec_nth s j) *
(∑i∈UNIV. (vec_nth (vec_nth A i) j) * (vec_nth r i)))"
by (metis (no_types, lifting) mult.assoc sum.cong sum_distrib_left)
(* Final step: the goal has the inner factors in the order r i, A i j;
   commutativity of multiplication closes the remaining gap. *)
finally show "(∑i∈UNIV. (∑j∈UNIV. (vec_nth (vec_nth A i) j) *
(vec_nth s j)) * (vec_nth r i)) = (∑j∈UNIV. (vec_nth s j) *
(∑i∈UNIV. (vec_nth r i) * (vec_nth (vec_nth A i) j)))"
by (simp add: mult.commute)
qed
text ‹Lemma about coeff Poly›
(* If every coefficient of the polynomial x (at every index i) lies in the
   set A, then every element of the coefficient list coeffs x lies in A.
   Proof: simplification with the definition of coeffs. *)
lemma coeffs_in_coeff:
assumes "∀i. poly.coeff x i ∈ A"
shows "set (coeffs x) ⊆ A"
by (simp add: assms coeffs_def image_subsetI)
(* Building a polynomial from a list xs with Poly and reading its
   coefficient list back with coeffs yields only elements of xs. *)
lemma set_coeff_Poly: "set ((coeffs ∘ Poly) xs) ⊆ set xs"
proof -
(* Key fact: stripping elements equal to 0 from xs with strip_while
   cannot introduce elements that were not already in xs. *)
have "x ∈ set (strip_while ((=) 0) xs) ⟹ x ∈ set xs"
for x
by (metis append.assoc append_Cons in_set_conv_decomp
split_strip_while_append)
then show ?thesis by auto
qed
text ‹We now want to show the deterministic correctness of the algorithm.
That means, after choosing the variables correctly, generating the public key, encrypting
and decrypting, we get back the original message.›
(* Deterministic correctness of the scheme.
   Assumptions:
     t_def   - t is produced by key_gen from dt, A, s, e;
     u_v_def - (u,v) is produced by encrypt from t, A, r, e1, e2 and m;
     ct_def, cu_def, cv_def - ct, cu, cv are the compression errors of
               the three values that key_gen and encrypt compress;
     delta   - the combined error term has abs_infty_poly strictly below
               round (q/4);
     m01     - all entries of the coefficient list of m are 0 or 1.
   Conclusion: decrypting (u,v) with s, du, dv returns m.
   Proof outline: express the decryption argument as w + round(q/2) * m,
   where w is the combined error term; bound the abs_infty_poly of
   round(q/2) * (m - m') strictly below 2 * round (q/4), where m' is the
   decryption result; then show that any non-zero coefficient of m - m'
   would contradict that bound. *)
lemma kyber_correct:
fixes A s r e e1 e2 dt du dv ct cu cv t u v
assumes
t_def: "t = key_gen dt A s e"
and u_v_def: "(u,v) = encrypt t A r e1 e2 dt du dv m"
and ct_def: "ct = compress_error_vec dt (A *v s + e)"
and cu_def: "cu = compress_error_vec du
((transpose A) *v r + e1)"
and cv_def: "cv = compress_error_poly dv
(scalar_product (decompress_vec dt t) r + e2 +
to_module (round((real_of_int q)/2)) * m)"
and delta: "abs_infty_poly (scalar_product e r + e2 + cv -
scalar_product s e1 + scalar_product ct r -
scalar_product s cu) < round (real_of_int q / 4)"
and m01: "set ((coeffs ∘ of_qr) m) ⊆ {0,1}"
shows "decrypt u v s du dv = m"
proof -
text ‹First, show that the calculations are performed correctly.›
(* Each decompressed value equals the uncompressed value plus the
   corresponding compression error; this follows directly by unfolding
   the compress_error definitions. *)
have t_correct: "decompress_vec dt t = A *v s + e + ct "
using t_def ct_def unfolding compress_error_vec_def
key_gen_def by simp
have u_correct: "decompress_vec du u =
(transpose A) *v r + e1 + cu"
using u_v_def cu_def unfolding encrypt_def
compress_error_vec_def by simp
have v_correct: "decompress_poly dv v =
scalar_product (decompress_vec dt t) r + e2 +
to_module (round((real_of_int q)/2)) * m + cv"
using u_v_def cv_def unfolding encrypt_def
compress_error_poly_def by simp
(* Variant of v_correct with the decompressed t expanded via t_correct
   and the scalar product split by left linearity. *)
have v_correct': "decompress_poly dv v =
scalar_product (A *v s + e) r + e2 +
to_module (round((real_of_int q)/2)) * m + cv +
scalar_product ct r"
using t_correct v_correct
by (auto simp add: scalar_product_linear_left)
(* Abbreviations for the decompressed key and ciphertext components. *)
let ?t = "decompress_vec dt t"
let ?u = "decompress_vec du u"
let ?v = "decompress_poly dv v"
text ‹Define w as the error term of the message encoding.
Have $\|w\|_{\infty ,q} < \lceil q/4 \rfloor$›
(* w is exactly the term bounded by assumption delta. *)
define w where "w = scalar_product e r + e2 + cv -
scalar_product s e1 + scalar_product ct r -
scalar_product s cu"
have w_length: "abs_infty_poly w < round (real_of_int q / 4)"
unfolding w_def using delta by auto
(* Negating w does not change abs_infty_poly, so the same bound holds
   for -w; this is needed for the triangle inequality further down. *)
moreover have "abs_infty_poly w = abs_infty_poly (-w)"
unfolding abs_infty_poly_def
using neg_mod_plus_minus[OF q_odd q_gt_zero]
using abs_infty_q_def abs_infty_q_minus by auto
ultimately have minus_w_length:
"abs_infty_poly (-w) < round (real_of_int q / 4)"
by auto
(* The argument passed to compress_poly 1 inside decrypt equals the
   scaled message plus the error term w. *)
have vsu: "?v - scalar_product s ?u =
w + to_module (round((real_of_int q)/2)) * m"
unfolding w_def by (auto simp add: u_correct v_correct'
scalar_product_linear_left scalar_product_linear_right
scalar_product_assoc)
text ‹Set m' as the actual result of the decryption.
It remains to show that $m' = m$.›
define m' where "m' = decrypt u v s du dv"
(* Coefficient-wise statement: every coefficient of m' is 0 or 1,
   obtained from the imported fact compress_poly_1. *)
have coeffs_m': "∀i. poly.coeff (of_qr m') i ∈ {0,1}"
unfolding m'_def decrypt_def using compress_poly_1 by auto
text ‹Show $\| v - s^Tu - \lceil q/2 \rfloor m' \|_{\infty, q}
\leq \lceil q/4 \rfloor$›
(* Rewrite round(q/2) * m' as decompress_poly 1 applied to the
   compressed decryption argument, using the imported fact
   decompress_poly_1 from right to left. *)
have "abs_infty_poly (?v - scalar_product s ?u -
to_module (round((real_of_int q)/2)) * m')
= abs_infty_poly (?v - scalar_product s ?u -
decompress_poly 1 (compress_poly 1 (?v - scalar_product s ?u)))"
by (auto simp flip: decompress_poly_1[of m', OF coeffs_m'])
(simp add:m'_def decrypt_def)
(* Bound the compress/decompress round-trip error at parameter 1 with
   the imported fact decompress_compress_poly. *)
also have "… ≤ round (real_of_int q / 4)"
using decompress_compress_poly[of 1 "?v - scalar_product s ?u"]
q_gt_two by fastforce
finally have "abs_infty_poly (?v - scalar_product s ?u -
to_module (round((real_of_int q)/2)) * m') ≤
round (real_of_int q / 4)"
by auto
text ‹Show $\| \lceil q/2 \rfloor (m-m')) \|_{\infty, q} <
2 \lceil q/4 \rfloor $›
(* Substitute vsu, then factor the scaling element out of m - m'. *)
then have "abs_infty_poly (w + to_module
(round((real_of_int q)/2)) * m - to_module
(round((real_of_int q)/2)) * m') ≤ round (real_of_int q / 4)"
using vsu by auto
then have w_mm': "abs_infty_poly (w +
to_module (round((real_of_int q)/2)) * (m - m'))
≤ round (real_of_int q / 4)"
by (smt (verit) add_uminus_conv_diff is_num_normalize(1)
right_diff_distrib')
(* Add and subtract w, then apply the triangle inequality for
   abs_infty_poly to the summands w + round(q/2) * (m - m') and -w. *)
have "abs_infty_poly (to_module
(round((real_of_int q)/2)) * (m - m')) =
abs_infty_poly (w + to_module
(round((real_of_int q)/2)) * (m - m') - w)"
by auto
also have "… ≤ abs_infty_poly
(w + to_module (round((real_of_int q)/2)) * (m - m'))
+ abs_infty_poly (- w)"
using abs_infty_poly_triangle_ineq[of
"w+to_module (round((real_of_int q)/2)) * (m - m')" "-w"]
by auto
(* First summand is at most round (q/4) by w_mm', second is strictly
   below round (q/4) by minus_w_length, so the sum is strictly below
   twice that value. *)
also have "… < 2 * round (real_of_int q / 4)"
using w_mm' minus_w_length by auto
finally have error_lt: "abs_infty_poly (to_module (round((real_of_int q)/2)) * (m - m')) <
2 * round (real_of_int q / 4)"
by auto
text ‹Finally show that $m-m'$ is small enough, ie that it is
an integer smaller than one.
Here, we need that $q \cong 1\mod 4$.›
(* NOTE: this rebinds the fact name coeffs_m' that was introduced earlier
   in this proof for the coefficient-wise statement. From here on the
   name refers to the statement about the coefficient list of m'. *)
have coeffs_m':"set ((coeffs ∘ of_qr) m') ⊆ {0,1}"
proof -
have "compress 1 a ∈ {0,1}" for a
unfolding compress_def by auto
then have "poly.coeff (of_qr (compress_poly 1 a)) i ∈ {0,1}"
for a i
using compress_poly_1 by presburger
then have "set (coeffs (of_qr (compress_poly 1 a))) ⊆ {0,1}"
for a
using coeffs_in_coeff[of "of_qr (compress_poly 1 a)" "{0,1}"]
by simp
then show ?thesis unfolding m'_def decrypt_def by simp
qed
(* Since the coefficients of m and m' are all 0 or 1, every coefficient
   of their difference is -1, 0 or 1 in the mod ring. *)
have coeff_0pm1: "set ((coeffs ∘ of_qr) (m-m')) ⊆
{of_int_mod_ring (-1),0,1}"
proof -
have "poly.coeff (of_qr m) i ∈ {0,1}"
for i using m01 coeff_in_coeffs
by (metis comp_def insertCI le_degree subset_iff
zero_poly.rep_eq)
moreover have "poly.coeff (of_qr m') i ∈ {0,1}" for i
using coeffs_m' coeff_in_coeffs
by (metis comp_def insertCI le_degree subset_iff zero_poly.rep_eq)
ultimately have "poly.coeff (of_qr m - of_qr m') i ∈ {of_int_mod_ring (- 1), 0, 1}" for i
by (metis (no_types, lifting) coeff_diff diff_zero
eq_iff_diff_eq_0 insert_iff of_int_hom.hom_one of_int_minus
of_int_of_int_mod_ring singleton_iff verit_minus_simplify(3))
then have "set (coeffs (of_qr m - of_qr m')) ⊆ {of_int_mod_ring (- 1), 0, 1}"
by (simp add: coeffs_in_coeff)
then show ?thesis using m01 of_qr_diff[of m m'] by simp
qed
(* Proof by contradiction: if the coefficient list of m - m' is not
   contained in {0}, some coefficient is 1 or -1; the imported fact
   abs_infty_poly_ineq_pm_1 then gives a lower bound (error_ge) that
   contradicts the strict upper bound error_lt. *)
have "set ((coeffs ∘ of_qr) (m-m')) ⊆ {0}"
proof (rule ccontr)
assume "¬set ((coeffs ∘ of_qr) (m-m')) ⊆ {0}"
then have "∃i. poly.coeff (of_qr (m-m')) i ∈
{of_int_mod_ring (-1),1}"
using coeff_0pm1
by (smt (z3) coeff_in_coeffs comp_apply insert_iff
leading_coeff_0_iff order_refl
set_coeffs_subset_singleton_0_iff subsetD)
then have error_ge: "abs_infty_poly (to_module
(round((real_of_int q)/2)) * (m-m')) ≥
2 * round (real_of_int q / 4)"
using abs_infty_poly_ineq_pm_1 by simp
show False using error_lt error_ge by simp
qed
(* The difference has no non-zero coefficients; folding m' back to the
   decrypt expression yields the stated conclusion. *)
then show ?thesis by (simp flip: m'_def) (metis to_qr_of_qr)
qed
end
end