(* Theory Efficient_Discrete_Sqrt *)

(*
  File:    Efficient_Discrete_Sqrt.thy
  Author:  Markus Großer, Manuel Eberl

  A reasonably efficient algorithm to compute the square root of a natural number (rounded down)
  and to test if a natural number is a perfect square.
*)
theory Efficient_Discrete_Sqrt
imports
  Complex_Main
  "HOL-Computational_Algebra.Computational_Algebra"
  "HOL-Library.Discrete_Functions"
  "HOL-Library.Tree"
  "HOL-Library.IArray"
begin
section ‹Efficient Algorithms for the Square Root on ℕ›

(*
  TODO: This could perhaps be moved somewhere else. There is also probably some overlap
  with Sqrt_Babylonian
*)

subsection ‹A Discrete Variant of Heron's Algorithm›

text ‹
  An algorithm for calculating the discrete square root, taken from 
  Cohen~\<^cite>‹"cohen2010algebraic"›. This algorithm is essentially a discretised variant of
  Heron's method or Newton's method specialised to the square root function.
›

(* The library function floor_sqrt agrees with flooring the real square root.
   Proof idea: r = nat ⌊sqrt n⌋ satisfies r⇧2 ≤ n < (Suc r)⇧2, and
   floor_sqrt_unique then identifies r with floor_sqrt n. *)
lemma sqrt_eq_floor_sqrt: "floor_sqrt n = nat ⌊sqrt n⌋"
proof -
  (* Lower bound: the square of the floored root does not exceed n. *)
  have "real ((nat ⌊sqrt n⌋)⇧2) = (real (nat ⌊sqrt n⌋))⇧2"
    by simp
  also have "… ≤ sqrt (real n) ^ 2"
    by (intro power_mono) auto
  also have "… = real n" by simp
  finally have "(nat ⌊sqrt n⌋)⇧2 ≤ n"
    by (simp only: of_nat_le_iff)
  (* Upper bound: the next natural number already squares to more than n. *)
  moreover have "n < (Suc (nat ⌊sqrt n⌋))⇧2" proof -
    have "(1 + ⌊sqrt n⌋)⇧2 > n"
      using floor_correct[of "sqrt n"] real_le_rsqrt[of "1 + ⌊sqrt n⌋" n]
        of_int_less_iff[of n "(1 + ⌊sqrt n⌋)⇧2"] not_le
      by fastforce
    then show ?thesis
      using le_nat_floor[of "Suc (nat ⌊sqrt n⌋)" "sqrt n"]
        of_nat_le_iff[of "(Suc (nat ⌊sqrt n⌋))⇧2" n] real_le_rsqrt[of _ n] not_le
      by fastforce
  qed
  ultimately show ?thesis using floor_sqrt_unique by fast
qed

(* Discrete Heron/Newton iteration for the square root of n.
   From the current guess x it computes y = (x + n div x) div 2 and recurses
   on y as long as y is strictly smaller than x; otherwise it returns x.
   The recursive call is only made with a strictly smaller first argument. *)
fun newton_sqrt_aux :: "nat ⇒ nat ⇒ nat" where
  "newton_sqrt_aux x n =
     (let y = (x + n div x) div 2
      in if y < x then newton_sqrt_aux y n else x)"

(* Take the unconditional recursive equation out of the simpset (presumably so
   that simp does not keep unfolding it); the conditional rewrite rules in
   newton_sqrt_aux_simps are used instead. *)
declare newton_sqrt_aux.simps [simp del]

(* Conditional unfolding rules for newton_sqrt_aux, one per branch of the `if`:
   - if the Heron step strictly decreases x, continue with the new guess;
   - otherwise x is the result. *)
lemma newton_sqrt_aux_simps:
  "(x + n div x) div 2 < x ⟹ newton_sqrt_aux x n = newton_sqrt_aux ((x + n div x) div 2) n"
  "(x + n div x) div 2 ≥ x ⟹ newton_sqrt_aux x n = x"
  by (subst newton_sqrt_aux.simps; simp add: Let_def)+

(* One Heron step over the reals never drops below the square root:
   the arithmetic mean of t and n/t is at least their geometric mean sqrt n. *)
lemma heron_step_real: "⟦t > 0; n ≥ 0⟧ ⟹ (t + n/t) / 2 ≥ sqrt n"
  using arith_geo_mean_sqrt[of t "n/t"] by simp

(* The integer Heron step equals the floor of the real Heron step.
   The proof rewrites the real expression to a single fraction, turns the
   floor of that fraction into natural-number division, and then splits the
   division by 2 * t into a division by t followed by a division by 2. *)
lemma heron_step_div_eq_floored:
  "(t::nat) > 0 ⟹ (t + (n::nat) div t) div 2 = nat ⌊(t + n/t) / 2⌋"
proof -
  assume "t > 0"
  then have "⌊(t + n/t) / 2⌋ = ⌊(t*t + n) / (2*t)⌋"
    by (simp add: mult_divide_mult_cancel_right[of t "t + n/t" 2, symmetric]
        algebra_simps)
  also have "… = (t*t + n) div (2*t)"
    using floor_divide_of_nat_eq by blast
  also have "… = (t*t + n) div t div 2"
    by (simp add: div_mult2_eq ac_simps)
  also have "… = (t + n div t) div 2"
    by (simp add: ‹0 < t› power2_eq_square)
  finally show ?thesis by simp
qed

(* The integer Heron step never drops below floor_sqrt n.
   Chain: floor_sqrt n is the floored real root (sqrt_eq_floor_sqrt), which is
   at most the floored real Heron step (heron_step_real), which equals the
   integer Heron step (heron_step_div_eq_floored). *)
lemma heron_step: "t > 0 ⟹ (t + n div t) div 2 ≥ floor_sqrt n"
proof -
  assume "t > 0"
  have "floor_sqrt n = nat ⌊sqrt n⌋" by (rule sqrt_eq_floor_sqrt)
  also have "… ≤ nat ⌊(t + n/t) / 2⌋"
    using heron_step_real[of t n] ‹t > 0› by linarith
  also have "… = (t + n div t) div 2"
    using heron_step_div_eq_floored[OF ‹t > 0›] by simp
  finally show ?thesis .
qed

(* Correctness of the iteration: started from any x that is at least
   floor_sqrt n, newton_sqrt_aux returns exactly floor_sqrt n.
   Induction follows the recursion of newton_sqrt_aux. *)
lemma newton_sqrt_aux_correct:
  assumes "x ≥ floor_sqrt n"
  shows   "newton_sqrt_aux x n = floor_sqrt n"
  using assms
proof (induction x n rule: newton_sqrt_aux.induct)
  case (1 x n)
  show ?case
  proof (cases "x = floor_sqrt n")
    (* x is already the root: the step does not decrease, so x is returned. *)
    case True
    then have "(x ^ 2) div x ≤ n div x" by (intro div_le_mono) simp_all
    also have "(x ^ 2) div x = x" by (simp add: power2_eq_square)
    finally have "(x + n div x) div 2 ≥ x" by linarith
    with True show ?thesis by (auto simp: newton_sqrt_aux_simps)
  next
    (* x is strictly above the root: the step strictly decreases but stays
       at or above the root, so the induction hypothesis applies. *)
    case False
    with "1.prems" have x_gt_sqrt: "x > floor_sqrt n" by auto
    with le_floor_sqrt_iff[of x n] have "n < x ^ 2" by simp
    have "x * (n div x) ≤ n" using mult_div_mod_eq[of x n] by linarith
    also have "… < x ^ 2" using le_floor_sqrt_iff[of x n] and x_gt_sqrt by simp
    also have "… = x * x" by (simp add: power2_eq_square)
    finally have "n div x < x" by (subst (asm) mult_less_cancel1) auto
    then have step_decreasing: "(x + n div x) div 2 < x" by linarith
    with x_gt_sqrt have step_ge_sqrt: "(x + n div x) div 2 ≥ floor_sqrt n"
      by (simp add: heron_step)
    from step_decreasing have "newton_sqrt_aux x n = newton_sqrt_aux ((x + n div x) div 2) n"
      by (simp add: newton_sqrt_aux_simps)
    also have "… = floor_sqrt n"
      by (intro "1.IH" step_decreasing step_ge_sqrt) simp_all
    finally show ?thesis .
  qed
qed

(* Square root by Heron iteration, using n itself as the initial guess. *)
definition newton_sqrt :: "nat ⇒ nat" where
  "newton_sqrt n = newton_sqrt_aux n n"

(* Remove the existing code equation floor_sqrt_code so that the [code]
   theorem stated right after this declaration is used for floor_sqrt. *)
declare floor_sqrt_code [code del]

(* Code equation: floor_sqrt n is computed by newton_sqrt, i.e. by Heron
   iteration started at n. The proof instantiates newton_sqrt_aux_correct with
   start value n; floor_sqrt_le supplies the required bound on the start value. *)
theorem Discrete_sqrt_eq_newton_sqrt [code]: "floor_sqrt n = newton_sqrt n"
  unfolding newton_sqrt_def by (simp add: newton_sqrt_aux_correct floor_sqrt_le)


subsection ‹Square Testing›

text ‹
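  Next, we implement an algorithm to determine whether a given natural number is a perfect square,
  as described by Cohen~\<^cite>‹"cohen2010algebraic"›. Essentially, the algorithm first checks
  whether the number is a square modulo 64, 63, 65, and 11, which cheaply rules out many
  non-squares; only if all of these tests succeed does it compute the discrete square root and
  check whether squaring it gives back the number.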
›

(* The residues that squares can take modulo 11 (see q11_upto_def). *)
definition q11 :: "nat set"
  where "q11 = {0, 1, 3, 4, 5, 9}"
(* The residues that squares can take modulo 63 (see q63_upto_def). *)
definition q63 :: "nat set"
  where "q63 = {0, 1, 4, 7, 9, 16, 28, 18, 22, 25, 36, 58, 46, 49, 37, 43}"
(* The residues that squares can take modulo 64 (see q64_upto_def). *)
definition q64 :: "nat set"
  where "q64 = {0, 1, 4, 9, 16, 17, 25, 36, 33, 49, 41, 57}"
(* The residues that squares can take modulo 65 (see q65_upto_def). *)
definition q65 :: "nat set"
  where "q65 = {0, 1, 4, 10, 14, 9, 16, 26, 30, 25, 29, 40, 56, 36, 49, 61, 35, 51, 39, 55, 64}"


(* Lookup table for q11 as an immutable array: for i < 11, entry i is True
   exactly when i is in q11 (see sub_q11_array). *)
definition q11_array where
  "q11_array = IArray [True,True,False,True,True,True,False,False,False,True,False]"

(* Lookup table for q63 as an immutable array: for i < 63, entry i is True
   exactly when i is in q63 (see sub_q63_array). *)
definition q63_array where
  "q63_array = IArray [True,True,False,False,True,False,False,True,False,True,False,False,
     False,False,False,False,True,False,True,False,False,False,True,False,False,True,False,
     False,True,False,False,False,False,False,False,False,True,True,False,False,False,False,
     False,True,False,False,True,False,False,True,False,False,False,False,False,False,False,
     False,True,False,False,False,False,False]"

(* Lookup table for q64 as an immutable array: for i < 64, entry i is True
   exactly when i is in q64 (see sub_q64_array). *)
definition q64_array where
  "q64_array = IArray [True,True,False,False,True,False,False,False,False,True,False,False,
     False,False,False,False,True,True,False,False,False,False,False,False,False,True,False,
     False,False,False,False,False,False,True,False,False,True,False,False,False,False,True,
     False,False,False,False,False,False,False,True,False,False,False,False,False,False,
     False,True,False,False,False,False,False,False, False]"

(* Lookup table for q65 as an immutable array: for i < 65, entry i is True
   exactly when i is in q65 (see sub_q65_array). *)
definition q65_array where
  "q65_array = IArray [True,True,False,False,True,False,False,False,False,True,True,False,
     False,False,True,False,True,False,False,False,False,False,False,False,False,True,True,
     False,False,True,True,False,False,False,False,True,True,False,False,True,True,False,
     False,False,False,False,False,False,False,True,False,True,False,False,False,True,True
     ,False,False,False,False,True,False,False,True,False]"

(* For every index below 11, the table q11_array agrees with membership in q11.
   Proved by enumerating all indices and simplifying each case. *)
lemma sub_q11_array: "i ∈ {..<11} ⟹ IArray.sub q11_array i ⟷ i ∈ q11"
  by (simp add: lessThan_nat_numeral lessThan_Suc q11_def q11_array_def, elim disjE; simp)

(* For every index below 63, the table q63_array agrees with membership in q63.
   Proved by enumerating all indices and simplifying each case. *)
lemma sub_q63_array: "i ∈ {..<63} ⟹ IArray.sub q63_array i ⟷ i ∈ q63"
  by (simp add: lessThan_nat_numeral lessThan_Suc q63_def q63_array_def, elim disjE; simp)

(* For every index below 64, the table q64_array agrees with membership in q64.
   Proved by enumerating all indices and simplifying each case. *)
lemma sub_q64_array: "i ∈ {..<64} ⟹ IArray.sub q64_array i ⟷ i ∈ q64"
  by (simp add: lessThan_nat_numeral lessThan_Suc q64_def q64_array_def, elim disjE; simp)

(* For every index below 65, the table q65_array agrees with membership in q65.
   Proved by enumerating all indices and simplifying each case. *)
lemma sub_q65_array: "i ∈ {..<65} ⟹ IArray.sub q65_array i ⟷ i ∈ q65"
  by (simp add: lessThan_nat_numeral lessThan_Suc q65_def q65_array_def, elim disjE; simp)


(* Membership of a residue modulo 11 in q11 can be decided by a lookup in
   q11_array; the index x mod 11 is always below 11, so sub_q11_array applies. *)
lemma in_q11_code: "x mod 11 ∈ q11 ⟷ IArray.sub q11_array (x mod 11)"
  by (subst sub_q11_array) auto

(* Membership of a residue modulo 63 in q63 can be decided by a lookup in
   q63_array; the index x mod 63 is always below 63, so sub_q63_array applies. *)
lemma in_q63_code: "x mod 63 ∈ q63 ⟷ IArray.sub q63_array (x mod 63)"
  by (subst sub_q63_array) auto

(* Membership of a residue modulo 64 in q64 can be decided by a lookup in
   q64_array; the index x mod 64 is always below 64, so sub_q64_array applies. *)
lemma in_q64_code: "x mod 64 ∈ q64 ⟷ IArray.sub q64_array (x mod 64)"
  by (subst sub_q64_array) auto

(* Membership of a residue modulo 65 in q65 can be decided by a lookup in
   q65_array; the index x mod 65 is always below 65, so sub_q65_array applies. *)
lemma in_q65_code: "x mod 65 ∈ q65 ⟷ IArray.sub q65_array (x mod 65)"
  by (subst sub_q65_array) auto


(* Perfect-square test. The conjuncts are ordered from cheap to expensive:
   residue checks modulo 64, 63, 65 and 11 come first, and only the last
   conjunct computes floor_sqrt n and compares its square with n.
   The residues modulo 63, 65 and 11 are taken from r = n mod 45045,
   where 45045 = 63 * 65 * 11. *)
definition square_test :: "nat ⇒ bool" where
  "square_test n =
    (n mod 64 ∈ q64 ∧ (let r = n mod 45045 in
      r mod 63 ∈ q63 ∧ r mod 65 ∈ q65 ∧ r mod 11 ∈ q11 ∧ n = (floor_sqrt n)⇧2))"

(* Code equation for square_test: every set-membership check is replaced by
   the corresponding array lookup (in_q11_code, in_q63_code, in_q64_code,
   in_q65_code). *)
lemma square_test_code [code]:
  "square_test n =
    (IArray.sub q64_array (n mod 64) ∧ (let r = n mod 45045 in
           IArray.sub q63_array (r mod 63) ∧
           IArray.sub q65_array (r mod 65) ∧
           IArray.sub q11_array (r mod 11) ∧ n = (floor_sqrt n)⇧2))"
    using in_q11_code [symmetric] in_q63_code [symmetric] 
          in_q64_code [symmetric] in_q65_code [symmetric]
  by (simp add: Let_def square_test_def)

(* Any residue of a square modulo m > 0 is already the residue of the square
   of some q' below m (take q' = q mod m). *)
lemma square_mod_lower: "m > 0 ⟹ (q⇧2 :: nat) mod m = a ⟹ ∃q' < m. q'⇧2 mod m = a"
  using mod_less_divisor mod_mod_trivial power_mod by blast

(* q11 is exactly the set of values k⇧2 mod 11 for k below 11;
   checked by enumerating all eleven values of k. *)
lemma q11_upto_def: "q11 = (λk. k⇧2 mod 11) ` {..<11}"
  by (simp add: q11_def lessThan_nat_numeral lessThan_Suc insert_commute)

(* q11 is the set of residues modulo 11 of all squares, not only of squares of
   numbers below 11. The extra direction uses square_mod_lower to replace an
   arbitrary xa by a representative below 11. *)
lemma q11_infinite_def: "q11 = (λk. k⇧2 mod 11) ` {0..}"
  unfolding q11_upto_def image_def proof (auto, goal_cases)
  case (1 xa)
  show ?case
    using square_mod_lower[of 11 xa "xa⇧2 mod 11"]
      ex_nat_less_eq[of 11 "λx. xa⇧2 mod 11 = x⇧2 mod 11"]
    by auto
qed

(* q63 is exactly the set of values k⇧2 mod 63 for k below 63;
   checked by enumerating all values of k. *)
lemma q63_upto_def: "q63 = (λk. k⇧2 mod 63) ` {..<63}"
  by (simp add: q63_def lessThan_nat_numeral lessThan_Suc insert_commute)

(* q63 is the set of residues modulo 63 of all squares, not only of squares of
   numbers below 63. The extra direction uses square_mod_lower to replace an
   arbitrary xa by a representative below 63. *)
lemma q63_infinite_def: "q63 = (λk. k⇧2 mod 63) ` {0..}"
  unfolding q63_upto_def image_def proof (auto, goal_cases)
  case (1 xa)
  show ?case
    using square_mod_lower[of 63 xa "xa⇧2 mod 63"]
      ex_nat_less_eq[of 63 "λx. xa⇧2 mod 63 = x⇧2 mod 63"]
    by auto
qed

(* q64 is exactly the set of values k⇧2 mod 64 for k below 64;
   checked by enumerating all values of k. *)
lemma q64_upto_def: "q64 = (λk. k⇧2 mod 64) ` {..<64}"
  by (simp add: q64_def lessThan_nat_numeral lessThan_Suc insert_commute)

(* q64 is the set of residues modulo 64 of all squares, not only of squares of
   numbers below 64. The extra direction uses square_mod_lower to replace an
   arbitrary xa by a representative below 64. *)
lemma q64_infinite_def: "q64 = (λk. k⇧2 mod 64) ` {0..}"
  unfolding q64_upto_def image_def proof (auto, goal_cases)
  case (1 xa)
  show ?case
    using square_mod_lower[of 64 xa "xa⇧2 mod 64"]
      ex_nat_less_eq[of 64 "λx. xa⇧2 mod 64 = x⇧2 mod 64"]
    by auto
qed

(* q65 is exactly the set of values k⇧2 mod 65 for k below 65;
   checked by enumerating all values of k. *)
lemma q65_upto_def: "q65 = (λk. k⇧2 mod 65) ` {..<65}"
  by (simp add: q65_def lessThan_nat_numeral lessThan_Suc insert_commute)

(* q65 is the set of residues modulo 65 of all squares, not only of squares of
   numbers below 65. The extra direction uses square_mod_lower to replace an
   arbitrary xa by a representative below 65. *)
lemma q65_infinite_def: "q65 = (λk. k⇧2 mod 65) ` {0..}"
  unfolding q65_upto_def image_def proof (auto, goal_cases)
  case (1 xa)
  show ?case
    using square_mod_lower[of 65 xa "xa⇧2 mod 65"]
      ex_nat_less_eq[of 65 "λx. xa⇧2 mod 65 = x⇧2 mod 65"]
    by auto
qed

(* If n is a square, then for every modulus k the residue n mod k is the
   residue of some square (namely of the same q). *)
lemma square_mod_existence:
  fixes n k :: nat
  assumes "∃q. q⇧2 = n"
  shows "∃q. n mod k = q⇧2 mod k"
  using assms by auto

(* square_test decides exactly whether n is a perfect square.
   Case split on is_square n:
   - if n is a square, all four residue checks hold and the final comparison
     with (floor_sqrt n)⇧2 succeeds;
   - if n is not a square, the final comparison fails, so the test is False. *)
theorem square_test_correct: "square_test n ⟷ is_square n"
proof cases
  assume "is_square n"
  hence  rhs: "∃q. q⇧2 = n" by (auto elim: is_nth_powerE)
  note sq_mod = square_mod_existence[OF this]
  have q64_member: "n mod 64 ∈ q64" using sq_mod[of 64]
    unfolding q64_infinite_def image_def by simp
  let ?r = "n mod 45045"
  (* 45045 is divisible by 11, 63 and 65, so reducing n modulo 45045 first
     does not change its residues modulo these three numbers. *)
  have "11 dvd (45045::nat)" "63 dvd (45045::nat)" "65 dvd (45045::nat)" by force+
  then have mod_45045: "?r mod 11 = n mod 11" "?r mod 63 = n mod 63" "?r mod 65 = n mod 65"
    using mod_mod_cancel[of _ 45045 n] by presburger+
  then have "?r mod 11 ∈ q11" "?r mod 63 ∈ q63" "?r mod 65 ∈ q65"
    using sq_mod[of 11] sq_mod[of 63] sq_mod[of 65]
    unfolding q11_infinite_def q63_infinite_def q65_infinite_def image_def mod_45045
    by fast+
  then show ?thesis unfolding square_test_def Let_def using q64_member rhs by auto
next
  assume not_rhs: "¬is_square n"
  hence "∄q. q⇧2 = n" by auto
  then have "(floor_sqrt n)⇧2 ≠ n" by simp
  then show ?thesis unfolding square_test_def by (auto simp: is_nth_power_def)
qed


(* Exact square root as an option: Some (floor_sqrt n) if n is a perfect
   square, None otherwise. *)
definition get_nat_sqrt :: "nat ⇒ nat option" 
  where "get_nat_sqrt n = (if is_square n then Some (floor_sqrt n) else None)"

(* Code equation for get_nat_sqrt: run the four array-based residue checks
   first; only if they all pass, compute x = floor_sqrt n once and return
   Some x when x⇧2 = n. In every other case the result is None. *)
lemma get_nat_sqrt_code [code]:
  "get_nat_sqrt n = 
    (if IArray.sub q64_array (n mod 64) ∧ (let r = n mod 45045 in
           IArray.sub q63_array (r mod 63) ∧
           IArray.sub q65_array (r mod 65) ∧
           IArray.sub q11_array (r mod 11)) then
       (let x = floor_sqrt n in if x⇧2 = n then Some x else None) else None)"
  unfolding get_nat_sqrt_def square_test_correct [symmetric] square_test_def
  using in_q11_code [symmetric] in_q63_code [symmetric] 
        in_q64_code [symmetric] in_q65_code [symmetric]
  by (auto split: if_splits simp: Let_def )

end