(* Theory CRYSTALS-Kyber.Mod_Plus_Minus *)

theory Mod_Plus_Minus

imports Kyber_spec

begin

(* For odd x, halving x over the reals and rounding down agrees with the
   integer division (x - 1) div 2. *)
lemma odd_half_floor:
  ‹⌊real_of_int x / 2⌋ = (x - 1) div 2› if ‹odd x›
  using that by (rule oddE) simp

section ‹Re-centered Modulo Operation›
text ‹To define the compress and decompress functions, 
  we need a special form of the modulo operation. For a modulus ‹q›, it 
  returns the representative of the equivalence class in ‹(-q div 2, q div 2]›.
  Using these representatives, we ensure that the norm of the 
  representative is as small as possible.›

(* Re-centered remainder. The result is always either m mod b or
   m mod b - b, so it is congruent to m modulo b. Remainders strictly above
   ⌊b/2⌋ are shifted down by b; the others are kept unchanged. *)
definition mod_plus_minus :: "int ⇒ int ⇒ int" 
  (infixl ‹mod+-› 70) where
"m mod+- b = 
  (if m mod b > ⌊b/2⌋ then m mod b - b else m mod b)"

text ‹Range of the (re-centered) modulo operation›
 
(* For a positive modulus, the ordinary remainder lies in {0..b-1}. *)
lemma mod_range: "b>0 ⟹ (a::int) mod (b::int) ∈ {0..b-1}"
using range_mod by auto

(* An integer that already lies in {0..<b} is its own remainder modulo b. *)
lemma mod_rangeE: 
  assumes "(a::int)∈{0..<b}"
  shows "a = a mod b"
using assms by auto


(* For odd positive b and a remainder y mod b strictly above ⌊b/2⌋,
   shifting the remainder down by b (the "subtract b" branch of
   mod_plus_minus) lands in the interval {-⌊b/2⌋..⌊b/2⌋}. *)
lemma half_mod_odd:
  assumes "b > 0" "odd b" "⌊real_of_int b / 2⌋ < y mod b" 
  shows "- ⌊real_of_int b / 2⌋ ≤ y mod b - b"
    "y mod b - b ≤ ⌊real_of_int b / 2⌋"
proof -
  (* Lower bound: for odd b, odd_half_floor rewrites ⌊b/2⌋ as (b - 1) div 2,
     after which the bound is linear arithmetic on the assumptions. *)
  from odd_half_floor [of b]
  show "- ⌊real_of_int b / 2⌋ ≤ y mod b - b"
    using assms by linarith
  (* Upper bound: derived with pos_mod_bound (the remainder is below b). *)
  then have "y mod b ≤ b + ⌊real_of_int b / 2⌋"
    by (smt (verit) ‹b > 0› pos_mod_bound)
  then show "y mod b - b ≤ ⌊real_of_int b / 2⌋"
    by auto
qed

(* For positive b, every remainder y mod b is at least -⌊b/2⌋
   (no oddness assumption is needed). *)
lemma half_mod:
assumes "b>0"
shows "- ⌊real_of_int b / 2⌋ ≤ y mod b"
using assms
by (smt (verit, best) floor_less_zero half_gt_zero mod_int_pos_iff of_int_pos)

(* For odd positive b, mod+- yields values in the symmetric interval
   {-⌊b/2⌋..⌊b/2⌋}. The "subtract b" branch is covered by half_mod_odd;
   the "keep" branch is covered by half_mod together with the branch condition. *)
lemma mod_plus_minus_range_odd: 
  assumes "b>0" "odd b"
  shows "y mod+- b ∈ {-⌊b/2⌋..⌊b/2⌋}"
unfolding mod_plus_minus_def by (auto simp add: half_mod_odd[OF assms] half_mod[OF assms(1)])

(* For odd b, twice ⌊b/2⌋ equals b - 1 and is therefore strictly below b
   (the proof relies on odd_two_times_div_two_succ). *)
lemma odd_smaller_b:
  assumes "odd b" 
  shows "⌊real_of_int b / 2⌋ + ⌊real_of_int b / 2⌋ < b"
using assms 
by (smt (z3) floor_divide_of_int_eq odd_two_times_div_two_succ 
  of_int_hom.hom_add of_int_hom.hom_one)
 

(* Helper for mod_plus_minus_rangeE ("subtract b" branch). If y lies in the
   symmetric interval and its remainder exceeds ⌊b/2⌋, then y is negative
   and equals y mod b - b. *)
lemma mod_plus_minus_rangeE_neg:
  assumes "y ∈ {-⌊real_of_int b/2⌋..⌊real_of_int b/2⌋}"
          "odd b" "b > 0"
           "⌊real_of_int b / 2⌋ < y mod b"
  shows "y = y mod b - b"
proof -
  (* A non-negative y would satisfy y mod b ≤ y ≤ ⌊b/2⌋
     (zmod_le_nonneg_dividend), contradicting the last assumption. *)
  have "y ∈ {-⌊real_of_int b/2⌋..<0}" using assms
  by (meson atLeastAtMost_iff atLeastLessThan_iff linorder_not_le order_trans zmod_le_nonneg_dividend)
  then have "y ∈ {-b..<0}" using assms(2-3)
  by (metis atLeastLessThan_iff floor_divide_of_int_eq int_div_less_self linorder_linear 
    linorder_not_le neg_le_iff_le numeral_code(1) numeral_le_iff of_int_numeral order_trans 
    semiring_norm(69))
  (* For y in {-b..<0}, y + b already lies in {0..<b} and is the remainder. *)
  then have "y mod b = y + b" 
  by (smt (verit) atLeastLessThan_iff mod_add_self2 mod_pos_pos_trivial)
  then show ?thesis by auto
qed

(* Helper for mod_plus_minus_rangeE ("keep" branch). If y lies in the
   symmetric interval and its remainder is at most ⌊b/2⌋, then y is
   non-negative and is its own remainder. *)
lemma mod_plus_minus_rangeE_pos:
  assumes "y ∈ {-⌊real_of_int b/2⌋..⌊real_of_int b/2⌋}"
          "odd b" "b > 0"
          "⌊real_of_int b / 2⌋ ≥ y mod b"
  shows "y = y mod b"
proof -
  have "y ∈ {0..⌊real_of_int b/2⌋}" 
  proof (rule ccontr)
    (* Suppose y were negative. Then y mod b = y + b ≥ b - ⌊b/2⌋, and by
       odd_smaller_b this is strictly above ⌊b/2⌋, contradicting assms(4). *)
    assume "y ∉ {0..⌊real_of_int b / 2⌋} "
    then have "y ∈ {-⌊real_of_int b/2⌋..<0}" using assms(1) by auto
    then have "y ∈ {-b..<0}" using assms(2-3)
    by (metis atLeastLessThan_iff floor_divide_of_int_eq int_div_less_self linorder_linear 
      linorder_not_le neg_le_iff_le numeral_code(1) numeral_le_iff of_int_numeral order_trans 
      semiring_norm(69))
    then have "y mod b = y + b" 
      by (smt (verit) atLeastLessThan_iff mod_add_self2 mod_pos_pos_trivial)
    then have "y mod b ≥ b - ⌊real_of_int b/2⌋" using assms(1) by auto
    then have "y mod b > ⌊real_of_int b/2⌋"
      using assms(2) odd_smaller_b by fastforce
    then show False using assms(4) by auto
  qed
  (* A value in {0..⌊b/2⌋} lies in {0..<b} and so equals its remainder. *)
  then have "y ∈ {0..<b}" using assms(2-3)
  by (metis atLeastAtMost_iff atLeastLessThan_empty atLeastLessThan_iff floor_divide_of_int_eq 
    int_div_less_self linorder_not_le numeral_One numeral_less_iff of_int_numeral semiring_norm(76))
  then show ?thesis by auto
qed

(* For odd positive b, mod+- is the identity on the symmetric interval
   {-⌊b/2⌋..⌊b/2⌋}. The two branches of the definition are handled by
   mod_plus_minus_rangeE_pos and mod_plus_minus_rangeE_neg. *)
lemma mod_plus_minus_rangeE:
  assumes "y ∈ {-⌊real_of_int b/2⌋..⌊real_of_int b/2⌋}"
          "odd b" "b > 0"
  shows "y = y mod+- b"
unfolding mod_plus_minus_def 
using mod_plus_minus_rangeE_pos[OF assms] mod_plus_minus_rangeE_neg[OF assms]
by auto

text ‹Image of $0$.›

(* If x is mapped to 0 by mod+-, then the ordinary remainder x mod b is 0
   as well. The lemma carries no assumption on b. *)
lemma mod_plus_minus_zero:
  assumes "x mod+- b = 0"
  shows "x mod b = 0"
using assms unfolding mod_plus_minus_def 
by (metis eq_iff_diff_eq_0 mod_mod_trivial mod_self)

(* mod+- maps 0 to 0 for positive b. The proof uses only the first
   assumption (b > 0); "odd b" is part of the statement but is not used by
   this proof. *)
lemma mod_plus_minus_zero':
  assumes "b>0" "odd b"
  shows "0 mod+- b = (0::int)" 
using assms(1) mod_plus_minus_def by force

text ‹‹mod+-› with negative values.›

(* For odd positive b, mod+- commutes with negation. *)
lemma neg_mod_plus_minus:
  assumes "odd b"
          "b>0"
  shows "(- x) mod+- b = - (x mod+- b)"
proof -
  (* Step 1: write (-x) mod+- b as -x plus some multiple k*b. *)
  obtain k :: int where k_def: "(-x) mod+- b = (-x)+ k* b" 
  using mod_plus_minus_def
  proof -
    assume a1: "⋀k. - x mod+- b = - x + k * b ⟹ thesis"
    have "∃i. i mod b + - (x + i) = - x mod+- b" 
    by (smt (verit, del_insts) mod_add_self1 mod_plus_minus_def)
    then show ?thesis
      using a1 by (metis (no_types) diff_add_cancel diff_diff_add 
      diff_minus_eq_add minus_diff_eq minus_mult_div_eq_mod 
      mult.commute mult_minus_left)
  qed
  then have "(-x) mod+- b = -(x - k*b)" using k_def by auto
  (* Step 2: x - k*b is the negation of a mod+- value, so it lies in the
     symmetric interval (mod_plus_minus_range_odd) and mod+- fixes it
     (mod_plus_minus_rangeE). *)
  also have "… = - ((x-k*b) mod+- b)"
  proof -
    have range_xkb:"x - k * b ∈ 
      {- ⌊real_of_int b / 2⌋..⌊real_of_int b / 2⌋}"
      using k_def mod_plus_minus_range_odd[OF assms(2) assms(1)]
      by (smt (verit, ccfv_SIG) atLeastAtMost_iff)
    have "x - k*b = (x - k*b) mod+- b" 
      using mod_plus_minus_rangeE[OF range_xkb assms] by auto
    then show ?thesis by auto
  qed
  (* Step 3: subtracting a multiple of b does not change the remainder
     (mod_mult_self1), hence does not change mod+-. *)
  also have "-((x - k*b) mod+- b) = -(x mod+- b)" 
    unfolding mod_plus_minus_def 
    by (smt (verit, best) mod_mult_self1)
  finally show ?thesis by auto
qed

text ‹Representative with ‹mod+-››

(* Every x differs from x mod+- b by a multiple of b. The lemma carries no
   assumption on b. *)
lemma mod_plus_minus_rep_ex:
"∃k. x = k*b + x mod+- b"
unfolding mod_plus_minus_def 
by (split if_splits)(metis add.right_neutral add_diff_eq div_mod_decomp_int 
  eq_iff_diff_eq_0 mod_add_self2)


(* Elimination-rule form of mod_plus_minus_rep_ex: obtain a k with
   x = k*b + x mod+- b. *)
lemma mod_plus_minus_rep: 
  obtains k where "x = k*b + x mod+- b"
using mod_plus_minus_rep_ex by auto

text ‹Multiplication in ‹mod+-››

(* mod+- is compatible with multiplication. Reducing both factors with
   mod+- before multiplying, then reducing again, gives the same result as
   reducing the product directly. In HOL both "*" and mod+- are infixl at
   priority 70, so each side parses as (...) mod+- q. *)
lemma mod_plus_minus_mult: 
  "s*x mod+- q = (s mod+- q) * (x mod+- q) mod+- q"
unfolding mod_plus_minus_def 
by (smt (verit, ccfv_threshold) minus_mod_self2 mod_mult_left_eq mod_mult_right_eq)
end