Theory Karatsuba_Sqrt

theory Karatsuba_Sqrt                                   
imports
  Complex_Main
  Karatsuba_Sqrt_Library
begin

subsection ‹Definition of an integer square root with remainder›

definition sqrt_rem :: "nat  nat" where
  "sqrt_rem n = n - floor_sqrt n ^ 2"

lemma sqrt_rem_upper_bound: "sqrt_rem n  2 * floor_sqrt n"
proof -
  define s where "s = floor_sqrt n"
  have "n < (s + 1) ^ 2"
    unfolding s_def using Suc_floor_sqrt_power2_gt[of n] by auto
  hence "n + 1  (s + 1) ^ 2"
    by linarith
  hence "n  s ^ 2 + 2 * s"
    by (simp add: algebra_simps power2_eq_square)
  thus ?thesis
    unfolding s_def sqrt_rem_def by linarith
qed

lemma of_nat_sqrt_rem:
  "(of_nat (sqrt_rem n) :: 'a :: ring_1) = of_nat n - of_nat (floor_sqrt n) ^ 2"
  by (simp add: sqrt_rem_def)

definition sqrt_rem' where "sqrt_rem' n = (floor_sqrt n, sqrt_rem n)"

lemma Discrete_sqrt_code [code]: "floor_sqrt n = fst (sqrt_rem' n)"
  by (simp add: sqrt_rem'_def)

lemma sqrt_rem_code [code]: "sqrt_rem n = snd (sqrt_rem' n)"
  by (simp add: sqrt_rem'_def)



subsection ‹Heron's method›

text ‹
  The method used here is a variant of Heron's method, which is itself essentially Newton's method 
  specialised to square roots. This is already in the AFP under the name ``Babylonian method''.
  However, that entry derives a more general version for n›-th roots and lacks some flexibility
  that is useful for us here, so we instead derive a simple version for the square root
  directly. We will use this version in the base case of the algorithm.

  The starting value must be bigger than $\lfloor \sqrt{n}\rfloor$. We simply use
  $2^{\lceil \frac{1}{2}\log_2 n \rceil}$, which is easy to compute and fairly close to
  $\sqrt{n}$ already so that the Newton iterations converge very quickly.
›

context
  fixes n :: nat
begin

function sqrt_rem_aux :: "nat  nat × nat" where
  "sqrt_rem_aux x =
     (if x2  n then (x, n - x2) else sqrt_rem_aux ((n div x + x) div 2))"
  by auto
termination proof (relation "Wellfounded.measure id")
  fix x assume x: "¬(x2  n)"
  have "n div x * x  n"
    by simp
  also from x have "n < x * x"
    by (simp add: power2_eq_square)
  finally have "n div x < x"
    using x by simp
  hence "(n div x + x) div 2 < x"
    by (subst div_less_iff_less_mult) auto
  thus "((n div x + x) div 2, x)  measure id"
    by simp
qed auto

lemmas [simp del] = sqrt_rem_aux.simps

lemma sqrt_rem_aux_code [code]:
  "sqrt_rem_aux x = (
    let x2 = x*x; r = int n - int x2
    in if r  0 then (x, nat r) else sqrt_rem_aux (drop_bit 1 (n div x + x)))"
  by (subst sqrt_rem_aux.simps)
     (auto simp: Let_def case_prod_unfold power2_eq_square nat_diff_distrib drop_bit_eq_div
           simp flip: of_nat_mult)

lemma sqrt_rem_aux_decompose: "fst (sqrt_rem_aux x) ^ 2 + snd (sqrt_rem_aux x) = n"
  by (induction x rule: sqrt_rem_aux.induct; subst (1 2) sqrt_rem_aux.simps) auto

lemma sqrt_rem_aux_correct:
  assumes "x  floor_sqrt n"
  shows   "fst (sqrt_rem_aux x) = floor_sqrt n"
  using assms
proof (induction x rule: sqrt_rem_aux.induct)
  case (1 x)
  show ?case
  proof (cases "x ^ 2  n")
    case True
    from True have "floor_sqrt n  x"
      by (simp add: le_floor_sqrtI)
    with "1.prems" show ?thesis using True
      by (subst sqrt_rem_aux.simps) auto
  next
    case False
    hence "x > 0"
      by (auto intro!: Nat.gr0I)
    have "0 < (x ^ 2 - n) ^ 2 / (4 * x ^ 2)"
      using x > 0 False by (intro divide_pos_pos) auto
    also have "(x ^ 2 - n) ^ 2 / (4 * x ^ 2) = ((n / x + x) / 2) ^ 2 - n"
      using x > 0 False by (simp add: field_simps power2_eq_square)
    finally have "n < ((n / x + x) / 2) ^ 2"
      by simp
    hence "sqrt n ^ 2 < ((n / x + x) / 2) ^ 2"
      by simp
    hence "sqrt n < (n / x + x) / 2"
      by (rule power_less_imp_less_base) auto
    hence "nat (floor (sqrt n))  nat (floor ((n / x + x) / 2))"
      by linarith
    also have "nat (floor (sqrt n)) = floor_sqrt n"
      by (simp add: floor_sqrt_conv_floor_of_sqrt)
    also have "floor ((n / x + x) / 2) = (n div x + x) div 2"
      using floor_divide_real_eq_div[of 2 "n / x + x"] by (simp add: floor_divide_of_nat_eq)
    finally have "floor_sqrt n  (n div x + x) div 2"
      by simp
    from "1.IH"[OF False this] show ?thesis
      by (subst sqrt_rem_aux.simps) (use False in auto)
  qed
qed

lemma sqrt_rem_aux_correct':
  assumes "x  floor_sqrt n"
  shows   "sqrt_rem_aux x = sqrt_rem' n"
  using sqrt_rem_aux_correct[OF assms] sqrt_rem_aux_decompose[of x]
  by (simp add: sqrt_rem'_def prod_eq_iff sqrt_rem_def)

definition sqrt_rem'_heron :: "nat × nat" where
  "sqrt_rem'_heron = sqrt_rem_aux (push_bit ((ceillog2 n + 1) div 2) 1)"

lemma sqrt_rem'_heron_correct:
  "sqrt_rem'_heron = sqrt_rem' n"
proof (cases "n = 0")
  case True
  show ?thesis unfolding sqrt_rem'_heron_def
    by (rule sqrt_rem_aux_correct') (auto simp: True)
next
  case False
  hence n: "n > 0"
    by auto
  show ?thesis unfolding sqrt_rem'_heron_def
  proof (rule sqrt_rem_aux_correct')
    have "real (floor_sqrt n)  sqrt n"
      by (simp add: floor_sqrt_conv_floor_of_sqrt)
    also have " = 2 powr log 2 (sqrt n)"
      using n by simp
    also have "log 2 (sqrt n) = log 2 n / 2"
      using n by (simp add: log_def ln_sqrt)
    also have "(2::real) powr   2 powr ((ceillog2 n + 1) div 2)"
    proof (intro powr_mono)
      have "log 2 (real n)  real (ceillog2 n)"
        by (simp add: ceillog2_ge_log n)
      also have " / 2  (ceillog2 n + 1) div 2"
        by linarith
      finally show "log 2 n / 2  (ceillog2 n + 1) div 2"
        by - simp_all
    qed auto
    also have " = real (2 ^ ((ceillog2 n + 1) div 2))"
      by (subst powr_realpow) auto
    also have "2 ^ ((ceillog2 n + 1) div 2) = push_bit ((ceillog2 n + 1) div 2) 1"
      by (simp add: push_bit_eq_mult)
    finally show "floor_sqrt n  push_bit ((ceillog2 n + 1) div 2) 1"
      by linarith
  qed
qed

end

lemmas [code] = sqrt_rem'_heron_correct [symmetric]


subsection ‹Main algorithm›

subsubsection ‹Single step›

definition splice_bit where
  "splice_bit i n x = take_bit n (drop_bit i x)"

lemma of_nat_splice_bit:
  "of_nat (splice_bit i n x) =
     splice_bit i n (of_nat x :: 'a :: linordered_euclidean_semiring_bit_operations)"
  by (simp add: splice_bit_def of_nat_take_bit of_nat_drop_bit)

definition karatsuba_sqrt_step where
  "karatsuba_sqrt_step a32 a1 a0 b =
     (let (s, r) = sqrt_rem' a32;
          (q, u) = ((r * b + a1) div (2 * s), (r * b + a1) mod (2 * s));
          s' = int (s * b + q);
          r' = int (u * b + a0) - int (q ^ 2)
      in  if r'  0 then (s', r') else (s' - 1, r' + 2 * s' - 1))"

definition karatsuba_sqrt_step' :: "nat  nat  int × int" where
  "karatsuba_sqrt_step' n k =
     (let (s, r) = map_prod int int (sqrt_rem' (drop_bit (2*k) n));
          (q, u) = divmod_int (push_bit k r + splice_bit k k n) (push_bit 1 s);
          s' = push_bit k s + q;
          r' = push_bit k u + take_bit k n - q ^ 2
      in  if r'  0 then (s', r') else (s' - 1, r' + push_bit 1 s' - 1))"

text ‹
  Note that unlike Zimmerman, we do not have any upper bound on $a_{3}$ since this bound
  turned out to be unnecessary for the correctness of the algorithm. As long as $b^4$ is not much
  smaller than $n$, there is no efficiency problem either, since the step will still strip away
  about half of the bits of $n$.

  The advantage of this is that we do not have to do the ``normalisation'' done by Zimmerman to
  ensure that at least one of the two most significant bits of $a_3$ be set.
›
lemma karatsuba_sqrt_step_correct:
  fixes a32 a1 a0 :: nat
  assumes "a1 < b" "a0 < b" "4 * a32  b ^ 2" "even b"
  defines "n  a32 * b ^ 2 + a1 * b + a0"
  shows   "karatsuba_sqrt_step a32 a1 a0 b =
             map_prod of_nat of_nat (sqrt_rem' n)"
proof -
  define s where "s = floor_sqrt a32"
  define r where "r = sqrt_rem a32"
  define q where "q = (r * b + a1) div (2 * s)"
  define u where "u = (r * b + a1) mod (2 * s)"
  define s' where "s' = int (s * b + q)"
  define r' where "r' = int (u * b + a0) - int (q ^ 2)"
  define s'' where "s'' = (if r'  0 then s' else s' - 1)"
  define r'' where "r'' = (if r'  0 then r' else r'  + 2 * s' - 1)"

  from assms have "b > 0"
    by auto
  have "s > 0"
    using assms by (auto simp: s_def intro!: Nat.gr0I)

  have "b  2 * s"
  proof -
    have "4 * (b div 2) ^ 2 = b ^ 2"
      using even b by (auto elim!: evenE simp: power2_eq_square)
    also have "  4 * a32"
      by fact
    finally have "b div 2  s"
      unfolding s_def by (subst le_floor_sqrt_iff) auto
    thus "b  2 * s"
      using even b by (elim evenE) auto
  qed

  have s'_r': "int n = s' ^ 2 + r'"
  proof -
    have *: "int a1 = int q * (2 * int s) + int u - int r * int b"
      using arg_cong[OF div_mod_decomp[of "r * b + a1" "2 * s"], of int, folded q_def u_def]
      unfolding of_nat_add of_nat_mult by linarith
    have "int n = (int s ^ 2 + int r) * int b ^ 2 + int a1 * int b + int a0"
      by (simp add: n_def s_def r_def of_nat_sqrt_rem algebra_simps power_numeral_reduce)
    also have " = s' ^ 2 + r'"
      by (simp add: power2_eq_square algebra_simps * r'_def s'_def)
    finally show "int n = s' ^ 2 + r'" .
  qed
  hence s''_r'': "int n = s'' ^ 2 + r''"
    by (simp add: s''_def r''_def power2_eq_square algebra_simps)

  have "int n < (s' + 1) ^ 2"
  proof -
    define t where "t = floor_sqrt n - s * b"
    have "s ^ 2 * b ^ 2  a32 * b ^ 2"
      unfolding s_def by (intro mult_right_mono floor_sqrt_power2_le) auto
    also have "  n"
      by (simp add: n_def)
    finally have "(s * b) ^ 2  n"
      by (simp add: power_mult_distrib)
    hence "floor_sqrt n  s * b"
      by (simp add: le_floor_sqrt_iff)
    hence sqrt_n_eq: "floor_sqrt n = s * b + t"
      unfolding t_def by simp

    have "int (2 * s * t * b) = 2 * int s * int b * int t"
      by simp
    also have "2 * int s * int b * int t  2 * int s * int b * int t + int t ^ 2"
      by simp
    also have " = int ((s * b + t) ^ 2) - (int s * int b) ^ 2"
      unfolding of_nat_power of_nat_mult of_nat_add by algebra
    also have "s * b + t = floor_sqrt n"
      by (simp add: sqrt_n_eq)
    also have "floor_sqrt n ^ 2  n"
      by simp
    also have "n - (int s * int b) ^ 2 = int (a1 * b + a0) + (int a32 - int s ^ 2) * int b ^ 2"
      unfolding n_def of_nat_add of_nat_mult of_nat_power by algebra
    also have "int a32 - int s ^ 2 = int r"
      unfolding r_def by (simp add: of_nat_sqrt_rem s_def)
    also have "a0 < b"
      by fact
    also have "int (a1 * b + b) + int r * (int b)2 = int ((a1 + 1 + r * b) * b)"
      by (simp add: algebra_simps power2_eq_square)
    finally have "2 * s * t * b < (a1 + 1 + r * b) * b"
      unfolding of_nat_less_iff by - simp_all
    hence "2 * s * t < a1 + 1 + r * b"
      using b > 0 mult_less_cancel2 by blast
    hence "2 * s * t  r * b + a1"
      by linarith
    hence "t  q"
      unfolding q_def using s > 0
      by (subst less_eq_div_iff_mult_less_eq) (auto simp: algebra_simps)
    with sqrt_n_eq have *: "floor_sqrt n  s * b + q"
      by simp

    have "n < (floor_sqrt n + 1) ^ 2"
        using Suc_floor_sqrt_power2_gt[of n] by simp
    also have "  (s * b + q + 1) ^ 2"
      by (intro power_mono add_mono *) auto
    finally have "int n < int ((s * b + q + 1) ^ 2)"
      by linarith
    thus "int n < (s' + 1) ^ 2"
      by (simp add: algebra_simps s'_def)
  qed

  have "q  r"
  proof -
    have "q  (r * b + a1) div b"
      unfolding q_def using b  2 * s b > 0 by (intro div_le_mono2)
    also have " = r"
      using b > 0 assms by simp
    finally show "q  r" .
  qed

  have "int (q ^ 2) < 2 * s'"
  proof (cases "q = 0")
    case False
    have "q ^ 2  2 * s * b"
      unfolding power2_eq_square
    proof (intro mult_mono)
      show "q  2 * s"
        using q  r sqrt_rem_upper_bound[of a32] unfolding r_def s_def by linarith
    next
      show "q  b"
      proof -
        have "r  2 * s"
          using q  r unfolding r_def s_def using sqrt_rem_upper_bound[of a32] by linarith
        hence "q  (2 * s * b + a1) div (2 * s)"
          unfolding q_def by (intro div_le_mono add_mono mult_right_mono) auto
        also have " = b + a1 div (2 * s)"
          using assms s > 0 by simp
        also have "a1 div (2 * s) = 0"
          using b  2 * s a1 < b by auto
        finally show "q  b" by simp
      qed
    qed auto
    also have "2 * s * b < 2 * (s * b + q)"
      using q  0 by (simp add: algebra_simps)
    also have "int  = 2 * s'"
      by (simp add: s'_def)
    finally show ?thesis by - simp_all
  qed (use s > 0 b > 0 in auto simp: s'_def)

  have "r''  0"
  proof (cases "r'  0")
    case False
    have "r' + 2 * s' > 0"
      unfolding r'_def using int (q ^ 2) < 2 * s' by linarith
    thus ?thesis
      unfolding r''_def by auto
  qed (auto simp: r''_def)

  have "s''  0"
    using 0  r'' unfolding r''_def s''_def s'_def by auto

  have "s'' ^ 2  int n"
  proof -
    have "s'' ^ 2 = int n - r''"
      using s''_r'' by simp
    also have "  int n"
      using r''  0 by simp
    finally show "s'' ^ 2  n" .
  qed

  have "floor_sqrt n = nat s''"
  proof (rule floor_sqrt_unique)
    show "nat s'' ^ 2  n"
      using s'' ^ 2  int n
      by (metis nat_eq_iff2 of_nat_le_of_nat_power_cancel_iff zero_eq_power2 zero_le)
  next
    have "int n < (s'' + 1) ^ 2"
    proof (cases "r'  0")
      case True
      show ?thesis
        using True int n < (s' + 1) ^ 2 by (simp add: s''_def)
    next
      case False
      have "int n < s' ^ 2"
        using False s'_r' by auto
      thus ?thesis using False by (simp add: s''_def)
    qed
    also have "(s'' + 1) ^ 2 = int (Suc (nat s'') ^ 2)"
      using s''  0 by simp
    finally show "n < Suc (nat s'') ^ 2"
      by linarith
  qed
  moreover from this have "int (sqrt_rem n) = r''"
    using s''_r'' s''  0 unfolding of_nat_sqrt_rem by auto
  hence "sqrt_rem n = nat r''"
    by linarith
  moreover have "karatsuba_sqrt_step a32 a1 a0 b = (s'', r'')"
    unfolding karatsuba_sqrt_step_def sqrt_rem'_def n_def s''_def r''_def r'_def s'_def
              r_def s_def u_def q_def Let_def case_prod_unfold 
    by (simp add: divmod_def)
  ultimately show ?thesis using r''  0 s''  0
    by (simp add: n_def sqrt_rem'_def)
qed

lemma karatsuba_sqrt_step'_correct:
  fixes k n :: nat
  assumes k: "k > 0" and bitlen: "int k  (bitlen n + 1) div 4"
  defines "a32  drop_bit (2*k) n"
  defines "a1  splice_bit k k n"
  defines "a0  take_bit k n"
  shows   "karatsuba_sqrt_step' n k = map_prod int int (sqrt_rem' n)"
proof -
  define n' where "n' = drop_bit (2*k) n"
  have less: "a0 < 2 ^ k" "a1 < 2 ^ k"
    by (auto simp: a0_def a1_def splice_bit_def)
  have mod_less: "x mod y < 2 ^ k" if "y  2 ^ k" "y > 0" for x y :: int
  proof -
    have "x mod y < y"
      using that by (intro pos_mod_bound) auto
    also have "  2 ^ k"
      using that by simp
    finally show ?thesis .
  qed

  have n_eq: "n = a32 * 2 ^ (2 * k) + a1 * 2 ^ k + a0"
  proof -
    have "n = push_bit (2*k) (drop_bit (2*k) n) + take_bit (2*k) n"
      by (simp add: bits_ident)
    also have "take_bit (2*k) n = take_bit (2*k) (push_bit k (drop_bit k n) + take_bit k n)"
      by (simp add: bits_ident)
    also have " = push_bit k (splice_bit k k n) + take_bit k n"
      by (subst bit_eq_iff)
         (auto simp: bit_take_bit_iff bit_push_bit_iff bit_disjunctive_add_iff splice_bit_def)
    also have "push_bit (2 * k) (drop_bit (2 * k) n) + (push_bit k (splice_bit k k n) + take_bit k n) = 
                 drop_bit (2 * k) n * 2 ^ (2 * k) + splice_bit k k n * 2 ^ k + take_bit k n"
      by (simp add: push_bit_eq_mult)
    finally show ?thesis by (simp add: a32_def a1_def a0_def)
  qed

  have "a32 > 0"
  proof (rule Nat.gr0I)
    assume "a32 = 0"
    hence "bitlen (int n)  2 * int k"
      by (simp add: a32_def drop_bit_eq_0_iff_nat)
    with bitlen and k > 0 show False
      by linarith
  qed

  have *: "(2 ^ k) ^ 2  4 * a32"
  proof -
    have "int ((2 ^ k) ^ 2) = (2 ^ (2 * k) :: int)"
      by (simp add: power_mult add: mult_ac)
    also have "  int (4 * a32)  bitlen (int a32 * 2 ^ 2)  2 * k + 1"
      by (subst bitlen_ge_iff_power) (auto simp: nat_add_distrib nat_mult_distrib)
    also have "bitlen (int a32 * 2 ^ 2) = bitlen a32 + 2"
      using a32 > 0 by (subst bitlen_pow2) auto
    also have "bitlen (int n)  2 * int k"
      using assms(1,2) by linarith
    hence "bitlen (int a32) = bitlen (int n) - 2 * int k"
      by (simp add: a32_def of_nat_drop_bit bitlen_drop_bit)
    also have "(int (2 * k + 1)  bitlen (int n) - 2 * int k + 2)  True"
      using assms(2) by simp
    finally show ?thesis
      unfolding of_nat_le_iff by simp
  qed

  have "n = a32 * 2 ^ (2 * k) + a1 * 2 ^ k + a0"
    by (simp add: n_eq)
  also have "map_prod int int (sqrt_rem' ) = karatsuba_sqrt_step a32 a1 a0 (2^k)"
    by (subst karatsuba_sqrt_step_correct)
       (use * less k > 0 in auto simp: mult_ac simp flip: power_mult)
  also have "karatsuba_sqrt_step a32 a1 a0 (2^k) = 
           (let (s, r) = map_prod int int (sqrt_rem' a32);
                (q, u) = ((r * 2^k + a1) div (2 * s), (r * 2^k + a1) mod (2 * s));
                s' = s * 2^k + q;
                r' = u * 2^k + a0 - q ^ 2
            in  if r'  0 then (s', r') else (s' - 1, r' + 2 * s' - 1))"
    unfolding karatsuba_sqrt_step_def
    by (simp add: case_prod_unfold Let_def divmod_def zdiv_int zmod_int)
  also have " = karatsuba_sqrt_step' n k"
    unfolding karatsuba_sqrt_step'_def karatsuba_sqrt_step_def
    by (intro Let_cong case_prod_cong arg_cong2[of _ _ _ _ "divmod"]
              arg_cong[of _ _ "map_prod int int"]
              arg_cong[of _ _ sqrt_rem'] arg_cong[of _ _ int]
              arg_cong2[of _ _ _ _ "(-) :: int  _"] refl if_cong
              arg_cong2[of _ _ _ _ Pair] arg_cong2[of _ _ _ _ "(+)"])
        (auto simp: map_prod_def sqrt_rem'_def divmod_def a32_def a1_def a0_def of_nat_splice_bit
                    of_nat_drop_bit of_nat_take_bit divmod_int_def mult_ac push_bit_eq_mult)
  finally show ?thesis ..
qed


subsubsection ‹Full algorithm›

text ‹
  Our algorithm is parameterised with a ``limb size'' and a cutoff.
  The cutoff value describes the threshold for the base case, i.e.\ the size of inputs (in bits)
  for which we fall back to Heron's method.

  The algorithm splits the input number into four parts in such a way that the bit size of the lower
  three parts is a multiple of $2^l$ (where $l$ is the limb size). This may be useful to avoid 
  unnecessary bit shifting, since one one always splits the input number exactly at limb
  boundaries. However, whether this actually helps depends on how bit shifting of
  arbitrary-precision integers is actually implemented in the runtime.

  There is only one rather weak condition on the limb size and cutoff. Which values work best
  must be determined experimentally.
›

locale karatsuba_sqrt =
  fixes cutoff limb_size :: nat
  assumes cutoff: "2 ^ (2 + limb_size)  cutoff + 2"
begin

function karatsuba_sqrt_aux :: "nat  int × int" where
  "karatsuba_sqrt_aux n = (
    let sz = bitlen n
    in  if sz  int cutoff then
          case sqrt_rem'_heron n of (s, r)  (int s, int r)
        else let 
          k = push_bit limb_size (drop_bit (2 + limb_size) (nat (bitlen n + 1)));
          (s, r) = karatsuba_sqrt_aux (drop_bit (2*k) n);
          (q, u) = divmod_int (push_bit k r + splice_bit k k n) (push_bit 1 s);
          s' = push_bit k s + q;
          r' = push_bit k u + take_bit k n - q ^ 2
      in  if r'  0 then (s', r') else (s' - 1, r' + push_bit 1 s' - 1))"          
  by auto
termination proof (relation "measure id", goal_cases)
  case (2 n x k)
  have "2 ^ (2 + limb_size)  cutoff + 2"
    using cutoff by simp
  also have "cutoff + 2 < nat (bitlen (int n) + 2)"
    using 2 by simp
  finally have "2 ^ (2 + limb_size)  nat (bitlen (int n) + 1)"
    by linarith
  hence "k > 0"
    by (auto simp: push_bit_eq_mult drop_bit_eq_div 2(3) nat_add_distrib div_greater_zero_iff)
  hence "2 ^ 0 < (2 ^ (2 * k) :: nat)"
    using 2 by (intro power_strict_increasing Nat.gr0I)
               (auto simp: div_eq_0_iff nat_add_distrib not_le)
  moreover have "n > 0"
    using 2 by (auto intro!: Nat.gr0I)
  ultimately have "drop_bit (2 * k) n < n"
    by (auto simp: drop_bit_eq_div intro!: div_less_dividend)
  thus ?case
    by simp
qed auto

lemmas [simp del] = karatsuba_sqrt_aux.simps

lemma karatsuba_sqrt_aux_correct: "karatsuba_sqrt_aux n = map_prod int int (sqrt_rem' n)"
proof (induction n rule: karatsuba_sqrt_aux.induct)
  case (1 n)
  define sz where "sz = bitlen n"
  show ?case
  proof (cases "sz  cutoff")
    case True
    thus ?thesis
       by (subst karatsuba_sqrt_aux.simps)
          (auto simp: sqrt_rem'_heron_correct sqrt_rem'_def sz_def)
  next
    case False
    define k where "k = push_bit limb_size (drop_bit (2 + limb_size) (nat (bitlen n + 1)))"
    have n_eq: "n = drop_bit (2 * k) n * (2 ^ k)2 + splice_bit k k n * 2 ^ k + take_bit k n"
    proof -
      have "n = push_bit (2*k) (drop_bit (2*k) n) + take_bit (2*k) n"
        by (simp add: bits_ident)
      also have "take_bit (2*k) n = take_bit (2*k) (push_bit k (drop_bit k n) + take_bit k n)"
        by (simp add: bits_ident)
      also have " = push_bit k (splice_bit k k n) + take_bit k n"
        by (subst bit_eq_iff)
           (auto simp: bit_take_bit_iff bit_push_bit_iff bit_disjunctive_add_iff splice_bit_def)
      also have "push_bit (2 * k) (drop_bit (2 * k) n) + (push_bit k (splice_bit k k n) + take_bit k n) = 
                   drop_bit (2 * k) n * (2 ^ k)2 + splice_bit k k n * 2 ^ k + take_bit k n"
        by (simp add: push_bit_eq_mult flip: power_mult)
      finally show ?thesis .
    qed
      
    have "karatsuba_sqrt_aux n = karatsuba_sqrt_step' n k"
      using False "1.IH"[of sz k]
      by (subst karatsuba_sqrt_aux.simps)
         (simp_all add: karatsuba_sqrt_step'_def of_nat_splice_bit
                        of_nat_take_bit of_nat_drop_bit sz_def k_def Let_def)
    also have " = map_prod int int (sqrt_rem' n)"
    proof (subst karatsuba_sqrt_step'_correct)
      have "k  nat (bitlen (int n) + 1) div 4"
        by (simp add: k_def nat_add_distrib div_mult2_eq push_bit_eq_mult drop_bit_eq_div)
      moreover have "bitlen (int n) + 1  0"
        by (auto simp: bitlen_def)
      ultimately show "int k  (bitlen (int n) + 1) div 4"
        by linarith
    next
      show "k > 0"
      proof (rule Nat.gr0I)
        assume "k = 0"
        hence "nat sz + 1 < 2 ^ nat (int limb_size + 2)"
          by (auto simp: k_def div_eq_0_iff sz_def drop_bit_eq_div nat_add_distrib bitlen_def)
        hence "sz + 1 < int (2 ^ nat (int limb_size + 2))"
          by linarith          
        also have " = int (2 ^ (2 + limb_size))"
          by (simp add: nat_add_distrib)
        also have "  int (cutoff + 2)"
          using cutoff by linarith
        finally show False
          using False by simp
      qed
    qed (use n_eq in auto)
    finally show ?thesis .
  qed
qed

definition karatsuba_sqrt where
  "karatsuba_sqrt n = (case karatsuba_sqrt_aux n of (s, r)  (nat s, nat r))"

theorem karatsuba_sqrt_correct: "karatsuba_sqrt n = sqrt_rem' n"
  by (simp add: karatsuba_sqrt_def karatsuba_sqrt_aux_correct case_prod_unfold)

end



subsubsection ‹Concrete instantiation›

text ‹
  We pick a cutoff of 1024 and a limb size of 64 as reasonable default values.
›

definition karatsuba_sqrt_default where
  "karatsuba_sqrt_default = karatsuba_sqrt.karatsuba_sqrt 1024 6"

definition karatsuba_sqrt_default_aux where
  "karatsuba_sqrt_default_aux = karatsuba_sqrt.karatsuba_sqrt_aux 1024 6"

interpretation karatsuba_sqrt_default:
  karatsuba_sqrt 1024 6
  rewrites "karatsuba_sqrt.karatsuba_sqrt 1024 6  karatsuba_sqrt_default"
       and "karatsuba_sqrt.karatsuba_sqrt_aux 1024 6  karatsuba_sqrt_default_aux"
  by unfold_locales (auto simp: nat_add_distrib karatsuba_sqrt_default_aux_def karatsuba_sqrt_default_def)

lemmas [code] = 
  karatsuba_sqrt_default.karatsuba_sqrt_aux.simps[unfolded power2_eq_square]
  karatsuba_sqrt_default.karatsuba_sqrt_def
  karatsuba_sqrt_default.karatsuba_sqrt_correct [symmetric]




subsection ‹Using constsqrt_rem to compute floors and ceilings of constsqrt

definition sqrt_nat_ceiling :: "nat  nat" where
  "sqrt_nat_ceiling n = nat (ceiling (sqrt (real n)))"

definition sqrt_int_floor :: "int  int" where
  "sqrt_int_floor n = floor (sqrt (real_of_int n))"

definition sqrt_int_ceiling :: "int  int" where
  "sqrt_int_ceiling n = ceiling (sqrt (real_of_int n))"

lemma sqrt_nat_ceiling_code [code]:
  "sqrt_nat_ceiling n = (case sqrt_rem' n of (s, r)  if r = 0 then s else s + 1)"
proof -
  have n: "(floor_sqrt n)2 + sqrt_rem n = n"
    by (auto simp: sqrt_rem_def)
  have "sqrt n = sqrt (floor_sqrt n ^ 2 + sqrt_rem n)"
    by (simp add: sqrt_rem_def)
  also have "ceiling  = floor_sqrt n + (if sqrt_rem n = 0 then 0 else 1)"
  proof (cases "sqrt_rem n = 0")
    case False
    have "n < (floor_sqrt n + 1)2"
      using Suc_floor_sqrt_power2_gt le_eq_less_or_eq by auto
    hence "real n < real ((floor_sqrt n + 1)2)"
      by linarith
    hence "sqrt (floor_sqrt n ^ 2 + sqrt_rem n)  floor_sqrt n + 1"
      by (subst n) (auto intro!: real_le_lsqrt simp flip: of_nat_add)
    moreover have "floor_sqrt n < sqrt (floor_sqrt n ^ 2 + sqrt_rem n)"
      by (rule real_less_rsqrt) (use False in auto)
    ultimately have "ceiling (sqrt (floor_sqrt n ^ 2 + sqrt_rem n)) = floor_sqrt n + 1"
      by linarith
    thus ?thesis
      using False by simp
  qed auto
  finally show ?thesis
    by (simp add: sqrt_nat_ceiling_def sqrt_rem'_def nat_add_distrib)
qed  

lemma sqrt_int_floor_code [code]:
  "sqrt_int_floor n =
     (if n  0 then int (floor_sqrt (nat n)) else -int (sqrt_nat_ceiling (nat (-n))))"
  by (auto simp: sqrt_int_floor_def sqrt_nat_ceiling_def floor_sqrt_conv_floor_of_sqrt
                 real_sqrt_minus ceiling_minus)

lemma sqrt_int_ceiling_code [code]:
  "sqrt_int_ceiling n =
     (if n  0 then int (sqrt_nat_ceiling (nat n)) else -int (floor_sqrt (nat (-n))))"
  using sqrt_int_floor_code[of "-n"]
  by (cases n "0 :: int" rule: linorder_cases)
     (auto simp: sqrt_int_ceiling_def sqrt_int_floor_def sqrt_nat_ceiling_def[of 0]
                 real_sqrt_minus floor_minus)

end