# Theory Kyber_spec

theory Kyber_spec
imports Main "HOL-Computational_Algebra.Computational_Algebra"
"HOL-Computational_Algebra.Polynomial_Factorial"
"Berlekamp_Zassenhaus.Poly_Mod"
"Berlekamp_Zassenhaus.Poly_Mod_Finite_Field"

begin
hide_type Matrix.vec
hide_const Matrix.vec_index

section ‹Type Class for the Quotient Ring $\mathbb{Z}_q[x]/(x^n+1)$.›
text ‹The Kyber algorithms work over the quotient ring $\mathbb{Z}_q[x]/(x^n+1)$
where $q$ is a prime with $q\equiv 1 \mod 4$ and $n$ is a power of $2$.
We encode this quotient ring as a type. In order to do so, we first look at the
finite field $\mathbb{Z}_q$ implemented by ‹('a::prime_card) mod_ring›.
Then we define polynomials using the constructor ‹poly›.
For factoring out $x^n+1$, we define an equivalence relation on the polynomial ring
$\mathbb{Z}_q[x]$ via the modulo operation with modulus $x^n+1$.
Finally, we build the quotient of the equivalence relation using the construction
‹quotient_type›.›
text ‹The quotient ring $\mathbb{Z}_q[x]/(x^n+1)$ was formalized with help from Manuel Eberl.›

text ‹Auxiliary lemmas about integer casts into ‹mod_ring›, followed by a congruence relation
between two integer polynomials modulo $m$.›
(* Casting an integer n into the residue ring 'n mod_ring yields 0 exactly
   when the cardinality of 'n divides n. *)
lemma of_int_mod_ring_eq_0_iff:
"(of_int n :: ('n :: {finite, nontriv} mod_ring)) = 0 ⟷
int (CARD('n)) dvd n"
by transfer auto

(* Two integer casts into 'n mod_ring coincide exactly when the integers are
   congruent modulo CARD('n). *)
lemma of_int_mod_ring_eq_of_int_iff:
"(of_int n :: ('n :: {finite, nontriv} mod_ring)) = of_int m ⟷
[n = m] (mod (int (CARD('n))))"
by transfer (auto simp: cong_def)

(* mod_poly_rel m p q holds when every coefficient of p is congruent to the
   corresponding coefficient of q modulo m. *)
definition mod_poly_rel :: "nat ⇒ int poly ⇒ int poly ⇒ bool" where
"mod_poly_rel m p q ⟷
(∀n. [poly.coeff p n = poly.coeff q n] (mod (int m)))"

(* For the modulus CARD('n), mod_poly_rel is equality of the images of the
   two integer polynomials in 'n mod_ring poly under of_int_poly. *)
lemma mod_poly_rel_altdef:
"mod_poly_rel CARD('n :: nontriv) p q ⟷
(of_int_poly p) = (of_int_poly q :: 'n mod_ring poly)"
by (auto simp: poly_eq_iff mod_poly_rel_def
of_int_mod_ring_eq_of_int_iff)

(* p is a unit modulo m if some integer polynomial r makes the product p * r
   coefficient-wise congruent to 1 modulo m. *)
definition mod_poly_is_unit :: "nat ⇒ int poly ⇒ bool" where
"mod_poly_is_unit m p ⟷ (∃r. mod_poly_rel m (p * r) 1)"

(* For the modulus CARD('n), p is a unit modulo CARD('n) exactly when its cast
   is a unit in 'n mod_ring poly. *)
lemma mod_poly_is_unit_altdef:
"mod_poly_is_unit CARD('n :: nontriv) p ⟷
(of_int_poly p :: 'n mod_ring poly) dvd 1"
proof
assume "mod_poly_is_unit CARD('n) p"
thus "(of_int_poly p :: 'n mod_ring poly) dvd 1"
by (auto simp: mod_poly_is_unit_def dvd_def mod_poly_rel_altdef
of_int_poly_hom.hom_mult)
next
assume "(of_int_poly p :: 'n mod_ring poly) dvd 1"
then obtain q where q: "(of_int_poly p :: 'n mod_ring poly) * q = 1"
by auto
(* Lift the inverse q to an integer polynomial: casting each residue to an
   integer and back is the identity on coefficients. *)
also have "q = of_int_poly (map_poly to_int_mod_ring q)"
by (simp add: poly_eq_iff coeff_map_poly)
(* of_int_poly is multiplicative, so the product is itself a cast of an
   integer polynomial product. *)
also have "of_int_poly p * … =
of_int_poly (p * map_poly to_int_mod_ring q)"
by (simp add: of_int_poly_hom.hom_mult)
finally show "mod_poly_is_unit CARD('n) p"
by (auto simp: mod_poly_is_unit_def mod_poly_rel_altdef)
qed

(* Irreducibility modulo m, phrased with mod_poly_rel and mod_poly_is_unit:
   Q is not congruent to 0, is not a unit, and in every factorisation of Q
   modulo m one of the factors is a unit. *)
definition mod_poly_irreducible :: "nat ⇒ int poly ⇒ bool" where
"mod_poly_irreducible m Q ⟷
¬mod_poly_rel m Q 0 ∧
¬mod_poly_is_unit m Q ∧
(∀a b. mod_poly_rel m Q (a * b) ⟶
mod_poly_is_unit m a ∨ mod_poly_is_unit m b)"

(* Casting the coefficients of p to integers and back into the residue ring
   returns p. *)
lemma of_int_poly_to_int_poly: "of_int_poly (to_int_poly p) = p"
by (simp add: poly_eq_iff coeff_map_poly)

(* For the modulus CARD('n), irreducibility modulo CARD('n) coincides with
   irreducibility of the cast polynomial in 'n mod_ring poly. *)
lemma mod_poly_irreducible_altdef:
"mod_poly_irreducible CARD('n :: nontriv) p ⟷
irreducible (of_int_poly p :: 'n mod_ring poly)"
proof
assume "irreducible (of_int_poly p :: 'n mod_ring poly)"
thus "mod_poly_irreducible CARD('n) p"
by (auto simp: mod_poly_irreducible_def mod_poly_rel_altdef
mod_poly_is_unit_altdef irreducible_def of_int_poly_hom.hom_mult)
next
assume *: "mod_poly_irreducible CARD('n) p"
show "irreducible (of_int_poly p :: 'n mod_ring poly)"
unfolding irreducible_def
proof (intro conjI impI allI)
fix a b assume ab: "(of_int_poly p :: 'n mod_ring poly) = a * b"
(* Lift the factors a and b to integer polynomials and show that their
   product casts back to of_int_poly p. *)
have "of_int_poly (map_poly to_int_mod_ring a *
map_poly to_int_mod_ring b) =
of_int_poly (map_poly to_int_mod_ring a) *
(of_int_poly (map_poly to_int_mod_ring b) :: 'n mod_ring poly)"
by (simp add: of_int_poly_hom.hom_mult)
also have "… = a * b"
by (simp add: of_int_poly_to_int_poly)
also have "… = of_int_poly p"
using ab by simp
finally have "(of_int_poly p :: 'n mod_ring poly) =
of_int_poly (to_int_poly a * to_int_poly b)" ..
(* By irreducibility modulo CARD('n), one lifted factor is a unit. *)
hence "of_int_poly (to_int_poly a) dvd (1 :: 'n mod_ring poly) ∨
of_int_poly (to_int_poly b) dvd (1 :: 'n mod_ring poly)"
using * unfolding mod_poly_irreducible_def mod_poly_rel_altdef
mod_poly_is_unit_altdef by blast
(* Casting the lifted factors back yields a and b themselves. *)
thus "(a dvd (1 :: 'n mod_ring poly)) ∨
(b dvd (1 :: 'n mod_ring poly))"
by (simp add: of_int_poly_to_int_poly)
qed (use * in ‹auto simp: mod_poly_irreducible_def
mod_poly_rel_altdef mod_poly_is_unit_altdef›)
qed

text ‹Type class for the quotient ring $\mathbb{Z}_q[x]/(p)$.
The polynomial $p$ is represented as ‹qr_poly'› (a polynomial over the integers).›

(* Types of prime cardinality q that carry an integer polynomial qr_poly',
   the polynomial factored out of Z_q[x]. *)
class qr_spec = prime_card +
fixes qr_poly' :: "'a itself ⇒ int poly"
(* The leading coefficient must not vanish modulo q, so that the cast of
   qr_poly' keeps its degree (used by degree_qr_poly). *)
assumes not_dvd_lead_coeff_qr_poly' [simp]:
"¬int CARD('a) dvd lead_coeff (qr_poly' TYPE('a))"
and deg_qr'_pos : "degree (qr_poly' TYPE('a)) > 0"

text ‹‹qr_poly› is the respective polynomial in $\mathbb{Z}_q[x]$.›
(* qr_poly' with each integer coefficient cast into 'a mod_ring. *)
definition qr_poly :: "'a :: qr_spec mod_ring poly" where
"qr_poly = of_int_poly (qr_poly' TYPE('a))"

text ‹Function returning the degree of the polynomial that is factored out.›
(* Degree of qr_poly'. The itself argument is ignored; it only fixes the
   type 'a. *)
definition (in qr_spec) deg_qr :: "'a itself ⇒ nat" where
"deg_qr _ = degree (qr_poly' TYPE('a))"

(* deg_qr is, by definition, the degree of qr_poly'. *)
lemma degree_qr_poly':
"degree (qr_poly' TYPE('a :: qr_spec)) = deg_qr (TYPE('a))"
by (simp add: deg_qr_def)

(* Casting an integer polynomial into a ring_1 preserves its degree, provided
   the cast of its leading coefficient is nonzero. *)
lemma degree_of_int_poly':
assumes "of_int (lead_coeff p) ≠ (0 :: 'a :: ring_1)"
shows "degree (of_int_poly p :: 'a poly) = degree p"
proof (intro antisym)
(* Upper bound: coefficients beyond degree p are 0 and stay 0 after the cast. *)
show "degree (of_int_poly p) ≤ degree p"
by (intro degree_le) (auto simp: coeff_eq_0)
(* Lower bound: the cast leading coefficient is nonzero by assumption. *)
show "degree (of_int_poly p :: 'a poly) ≥ degree p"
using assms by (intro le_degree) auto
qed

(* The cast polynomial qr_poly has degree deg_qr.
   NOTE(review): the side condition of degree_of_int_poly' reduces, via
   of_int_mod_ring_eq_0_iff, to the class axiom that CARD('a) does not divide
   the leading coefficient of qr_poly'. auto only finds it if that axiom is a
   simp rule; confirm. *)
lemma degree_qr_poly:
"degree (qr_poly :: 'a :: qr_spec mod_ring poly) = deg_qr (TYPE('a))"
unfolding qr_poly_def
by (subst degree_of_int_poly')
(auto simp: of_int_mod_ring_eq_0_iff degree_qr_poly')

(* deg_qr is positive, by the class axiom deg_qr'_pos. *)
lemma deg_qr_pos : "deg_qr TYPE('a :: qr_spec) > 0"
by (metis deg_qr'_pos degree_qr_poly')

text ‹The factor polynomial is non-zero.›
(* qr_poly has positive degree (deg_qr_pos, degree_qr_poly), hence is nonzero. *)
lemma qr_poly_nz [simp]: "qr_poly ≠ 0"
using deg_qr_pos[where ?'a = 'a] by (auto simp flip: degree_qr_poly)

text ‹Since $p$ has positive degree, reducing modulo $p$ leaves the multiplicative identity $1$ unchanged.›
(* 1 has degree 0, which is below the positive degree of qr_poly, so 1 is
   already reduced modulo qr_poly. The former intermediate fact
   2^1 ≤ 2^deg_qr was unrelated to the goal and has been dropped. *)
lemma one_mod_qr_poly [simp]:
"1 mod (qr_poly :: 'a :: qr_spec mod_ring poly) = 1"
using deg_qr_pos[where ?'a = 'a]
by (metis degree_qr_poly degree_1 mod_poly_less)

text ‹We define the congruence relation on polynomials modulo the polynomial $p=$ ‹qr_poly›.›
(* Congruence modulo qr_poly; the equivalence relation behind the quotient
   type 'a qr. *)
definition qr_rel :: "'a :: qr_spec mod_ring poly ⇒ 'a mod_ring poly ⇒ bool" where
"qr_rel P Q ⟷ [P = Q] (mod qr_poly)"

(* qr_rel is reflexive, symmetric and transitive, since congruence is. *)
lemma equivp_qr_rel: "equivp qr_rel"
by (intro equivpI sympI reflpI transpI)
(auto simp: qr_rel_def cong_sym intro: cong_trans)

text ‹Using this equivalence relation, we can define the quotient ring as a ‹quotient_type›.›
(* 'a qr: polynomials over 'a mod_ring modulo qr_rel.
   NOTE(review): the (overloaded) option is presumably required because
   qr_rel depends on the class parameter qr_poly'; confirm. *)
quotient_type (overloaded) 'a qr = "'a :: qr_spec mod_ring poly" / qr_rel
by (rule equivp_qr_rel)

text ‹Defining the conversion functions.›
(* Maps a polynomial to its class in 'a qr; on representatives it is the
   identity. *)
lift_definition to_qr :: "'a :: qr_spec mod_ring poly ⇒ 'a qr"
is "λx. (x :: 'a mod_ring poly)" .

(* Canonical representative of a class: the remainder modulo qr_poly.
   This is well defined because congruent polynomials have equal remainders
   (cong_def), which is the proof obligation discharged below. *)
lift_definition of_qr :: "'a qr ⇒ 'a :: qr_spec mod_ring poly"
is "λP::'a mod_ring poly. P mod qr_poly"
by (simp add: qr_rel_def cong_def)

text ‹Simplification lemmas on conversion functions.›
(* Round trip: of_qr after to_qr reduces x modulo qr_poly. *)
lemma of_qr_to_qr: "of_qr (to_qr (x)) = x mod qr_poly"
apply (auto simp add: of_qr_def to_qr_def)
by (metis of_qr.abs_eq of_qr.rep_eq)

(* of_qr picks a representative congruent to x modulo qr_poly, so to_qr
   recovers x. *)
lemma to_qr_of_qr: "to_qr (of_qr (x)) = x"
apply (auto simp add: of_qr_def to_qr_def)
by (metis (mono_tags, lifting) Quotient3_abs_rep Quotient3_qr
Quotient3_rel cong_def qr_rel_def mod_mod_trivial)

(* to_qr maps equal polynomials to equal classes. *)
lemma eq_to_qr: "x = y ⟹ to_qr x = to_qr y" by auto

text ‹Type class instantiation for ‹qr› (quotient ring).›
(* Lift the commutative ring structure of 'a mod_ring poly to 'a qr.
   Each lift_definition of an operation must show that the operation respects
   qr_rel, i.e. that congruence modulo qr_poly is preserved (cong_add,
   cong_minus_minus_iff, cong_diff, cong_mult). *)
instantiation qr :: (qr_spec) comm_ring_1
begin

lift_definition zero_qr :: "'a qr" is "0" .

lift_definition one_qr :: "'a qr" is "1" .

lift_definition plus_qr :: "'a qr ⇒ 'a qr ⇒ 'a qr"
is "(+)"
unfolding qr_rel_def using cong_add by blast

lift_definition uminus_qr :: "'a qr ⇒ 'a qr"
is "uminus"
unfolding qr_rel_def  using cong_minus_minus_iff by blast

lift_definition minus_qr :: "'a qr ⇒ 'a qr ⇒ 'a qr"
is "(-)"
unfolding qr_rel_def using cong_diff by blast

lift_definition times_qr :: "'a qr ⇒ 'a qr ⇒ 'a qr"
is "(*)"
unfolding qr_rel_def using cong_mult by blast

instance
proof
(* 0 and 1 lie in different classes; presumably simp uses one_mod_qr_poly
   here. *)
show "0 ≠ (1 :: 'a qr)"
by transfer (simp add: qr_rel_def cong_def)
(* The remaining ring axioms are checked on representatives. *)
qed (transfer; simp add: qr_rel_def algebra_simps; fail)+

end

(* of_qr commutes with 0, 1, negation, addition and subtraction, because
   reduction modulo a polynomial over a field is additive. There is no
   multiplicative counterpart: the product of two remainders need not be
   reduced. *)
lemma of_qr_0 [simp]: "of_qr 0 = 0"
and of_qr_1 [simp]: "of_qr 1 = 1"
and of_qr_uminus [simp]: "of_qr (-p) = -of_qr p"
and of_qr_add [simp]: "of_qr (p + q) = of_qr p + of_qr q"
and of_qr_diff [simp]: "of_qr (p - q) = of_qr p - of_qr q"
by (transfer; simp add: poly_mod_add_left poly_mod_diff_left; fail)+

(* to_qr is a ring homomorphism: it commutes with 0, 1, negation, addition,
   subtraction and multiplication. *)
lemma to_qr_0 [simp]: "to_qr 0 = 0"
and to_qr_1 [simp]: "to_qr 1 = 1"
and to_qr_uminus [simp]: "to_qr (-p) = -to_qr p"
and to_qr_add [simp]: "to_qr (p + q) = to_qr p + to_qr q"
and to_qr_diff [simp]: "to_qr (p - q) = to_qr p - to_qr q"
and to_qr_mult [simp]: "to_qr (p * q) = to_qr p * to_qr q"
by (transfer'; simp; fail)+

(* to_qr commutes with natural-number casts (induction on n). *)
lemma to_qr_of_nat [simp]: "to_qr (of_nat n) = of_nat n"
by (induction n) auto

(* to_qr commutes with integer casts (induction on n). *)
lemma to_qr_of_int [simp]: "to_qr (of_int n) = of_int n"
by (induction n) auto

(* of_qr commutes with natural-number casts (induction on n). *)
lemma of_qr_of_nat [simp]: "of_qr (of_nat n) = of_nat n"
by (induction n) auto

(* of_qr commutes with integer casts (induction on n). *)
lemma of_qr_of_int [simp]: "of_qr (of_int n) = of_int n"
by (induction n) auto

(* The representative chosen by of_qr is 0 exactly for the zero class. *)
lemma of_qr_eq_0_iff [simp]: "of_qr p = 0 ⟷ p = 0"
by transfer (simp add: qr_rel_def cong_def)

(* A polynomial lies in the zero class exactly when qr_poly divides it. *)
lemma to_qr_eq_0_iff:
"to_qr p = 0 ⟷ qr_poly dvd p"
by transfer (auto simp: qr_rel_def cong_def)

text ‹Further lemmas relating equality in the quotient ring to congruence modulo ‹qr_poly›.›

(* Equality of classes in 'a qr is congruence modulo qr_poly. *)
lemma to_qr_eq_iff [simp]:
"to_qr P = (to_qr Q :: 'a :: qr_spec qr) ⟷ [P = Q] (mod qr_poly)"
by transfer (auto simp: qr_rel_def)

text ‹Reduction modulo $x^n + 1$ is injective on polynomials of degree less than $n$;
in particular, this means that the quotient ring has exactly $q^n$ elements.›
(* to_qr is injective on polynomials of degree below deg_qr, since such
   polynomials are their own remainders modulo qr_poly (mod_poly_less). *)
lemma inj_on_to_qr:
"inj_on
(to_qr :: 'a :: qr_spec mod_ring poly ⇒ 'a qr)
{P. degree P < deg_qr TYPE('a)}"
by (intro inj_onI) (auto simp: cong_def mod_poly_less
simp flip: degree_qr_poly)

text ‹The characteristic of the quotient ring is exactly $q$.›

(* An integer cast into 'a qr is 0 exactly when CARD('a) divides it. The
   proof reduces the claim to the constant polynomial, then to the
   coefficient ring 'a mod_ring. *)
lemma of_int_qr_eq_0_iff [simp]:
"of_int n = (0 :: 'a :: qr_spec qr) ⟷ int (CARD('a)) dvd n"
proof -
have "of_int n = (0 :: 'a qr) ⟷ (of_int n :: 'a mod_ring poly) = 0"
by (smt (z3) of_qr_eq_0_iff of_qr_of_int)
(* of_int n is the constant polynomial [:of_int n:]. *)
also have "… ⟷ (of_int n :: 'a mod_ring) = 0"
by (simp add: of_int_poly)
also have "… ⟷ int (CARD('a)) dvd n"
by (simp add: of_int_mod_ring_eq_0_iff)
finally show ?thesis .
qed

(* Integer casts into 'a qr agree exactly when the integers are congruent
   modulo CARD('a); obtained from of_int_qr_eq_0_iff applied to n - m. *)
lemma of_int_qr_eq_of_int_iff:
"of_int n = (of_int m :: 'a :: qr_spec qr) ⟷
[n = m] (mod (int (CARD('a))))"
using of_int_qr_eq_0_iff[of "n - m", where ?'a = 'a]
by (simp del: of_int_qr_eq_0_iff add: cong_iff_dvd_diff)

(* Natural-number version of of_int_qr_eq_of_int_iff: congruence of the
   integer casts is congruence of the natural numbers (cong_int_iff). *)
lemma of_nat_qr_eq_of_nat_iff:
"of_nat n = (of_nat m :: 'a :: qr_spec qr) ⟷
[n = m] (mod CARD('a))"
using of_int_qr_eq_of_int_iff[of "int n" "int m"]
by (simp add: cong_int_iff)

(* Natural-number version of of_int_qr_eq_0_iff. *)
lemma of_nat_qr_eq_0_iff [simp]:
"of_nat n = (0 :: 'a :: qr_spec qr) ⟷ CARD('a) dvd n"
using of_int_qr_eq_0_iff[of "int n"] by simp

section ‹Specification of Kyber›

(* Embeds an integer x as the constant polynomial [x] over 'a mod_ring and
   takes its class in 'a qr. *)
definition to_module :: "int ⇒ 'a ::qr_spec qr" where
"to_module x = to_qr (Poly [of_int_mod_ring x ::'a mod_ring])"

text ‹Properties in the ring ‹'a qr›. A canonical representative has degree less than $n$.›
(* A polynomial of degree below deg_qr is already reduced modulo qr_poly. *)
lemma deg_mod_qr_poly:
assumes "degree x < deg_qr TYPE('a::qr_spec)"
shows "x mod (qr_poly :: 'a mod_ring poly) = x"
using mod_poly_less[of x qr_poly] unfolding deg_qr_def
by (metis assms degree_qr_poly)

(* Consequently, of_qr inverts to_qr on polynomials of degree below deg_qr. *)
lemma of_qr_to_qr':
assumes "degree x < deg_qr TYPE('a::qr_spec)"
shows "of_qr (to_qr x) = (x ::'a mod_ring poly)"
using deg_mod_qr_poly[OF assms] of_qr_to_qr[of x] by simp

(* Every representative produced by of_qr has degree below deg_qr. *)
lemma deg_of_qr:
"degree (of_qr (x ::'a qr)) < deg_qr TYPE('a::qr_spec)"
by (metis deg_qr_pos degree_0 degree_qr_poly degree_mod_less'
qr_poly_nz of_qr.rep_eq)

(* Scalar multiplication by a corresponds, in 'a qr, to multiplication by
   the class of the constant polynomial [a]. *)
lemma to_qr_smult_to_module:
"to_qr (Polynomial.smult a p) = (to_qr (Poly [a])) * (to_qr p)"
by (metis Poly.simps(1) Poly.simps(2) mult.left_neutral
mult_smult_left smult_one to_qr_mult)

(* Reduction modulo qr_poly commutes with scalar multiplication. *)
lemma of_qr_to_qr_smult:
"of_qr (to_qr (Polynomial.smult a p)) =
Polynomial.smult a (of_qr (to_qr p))"
by (simp add: of_qr_to_qr mod_smult_left)

text ‹The following locale collects all parameters used in cryptographic schemes over $R_q$ such as
Kyber and Dilithium.›

(* Parameters of schemes over R_q:
   - q = CARD('a), a prime greater than 2;
   - k = CARD('k);
   - n = 2^n' with n' > 0;
   - the factored-out polynomial qr_poly' is x^n + 1. *)
locale module_spec =
fixes "type_a" :: "('a :: qr_spec) itself"
and "type_k" :: "('k ::finite) itself"
and n q::int and k n'::nat
assumes
n_powr_2: "n = 2 ^ n'" and
n'_gr_0: "n' > 0" and
q_gr_two: "q > 2" and
q_prime : "prime q" and
CARD_a: "int (CARD('a :: qr_spec)) = q" and
CARD_k: "int (CARD('k :: finite)) = k" and
qr_poly'_eq: "qr_poly' TYPE('a) = Polynomial.monom 1 (nat n) + 1"

begin
text ‹Some properties of the modulus $q$.›

(* The bounds on q follow from q > 2 by linear arithmetic; no external SMT
   solver is needed. *)
lemma q_nonzero: "q ≠ 0"
using q_gr_two by linarith

lemma q_gt_zero: "q>0"
using q_gr_two by linarith

lemma q_gt_two: "q>2"
by (fact q_gr_two)

lemma q_odd: "odd q"
using prime_odd_int q_prime q_gr_two by blast

lemma nat_q: "nat q = q"
using q_gt_zero by force

text ‹Some properties of the degree $n$.›

(* n = 2^n' with n' > 0, hence n ≥ 2. *)
lemma n_gt_1: "n > 1"
using n_powr_2 n'_gr_0 by (simp add: one_less_power)

lemma n_nonzero: "n ≠ 0"
using n_gt_1 by auto

lemma n_gt_zero: "n>0"
using n_gt_1 by auto

lemma nat_n: "nat n = n"
using n_gt_zero by force

(* deg (x^(nat n) + 1) = nat n, because the constant 1 has degree 0 < nat n;
   and int (nat n) = n because n > 1. *)
lemma deg_qr_n:
"deg_qr TYPE('a) = n"
unfolding deg_qr_def using qr_poly'_eq n_gt_1
by (simp add: degree_add_eq_left degree_monom_eq)

end

text ‹
We now define a locale for the specification parameters of Kyber as in \cite{kyber}.
The specifications use the parameters:

\begin{tabular}{r l}
$n$ & $=256 = 2^{n'}$\\
$n'$ & $= 8$\\
$q$ & $= 7681$ or $3329$\\
$k$ & $= 3$\\
\end{tabular}

Additionally, we need that $q$ is a prime with the property $q\equiv 1\mod 4$.
›

(* Kyber parameters: module_spec instantiated at the types 'a and 'k, with
   the additional requirement q mod 4 = 1.
   NOTE(review): the extra fixes type_a and type_k add two locale parameters
   that no visible assumption uses; they look redundant with the
   module_spec instantiation. They are kept here because removing them would
   change the locale's parameter list; confirm before removing. *)
locale kyber_spec = module_spec "TYPE ('a ::qr_spec)" "TYPE ('k::finite)" +
fixes type_a :: "('a :: qr_spec) itself"
and type_k :: "('k ::finite) itself"
assumes q_mod_4: "q mod 4 = 1"
begin
end

end