Theory Sqrt_Babylonian

(*  Title:       Computing Square Roots using the Babylonian Method
    Author:      René Thiemann       <rene.thiemann@uibk.ac.at>
    Maintainer:  René Thiemann
    License:     LGPL
*)

(*
Copyright 2009-2014 René Thiemann

This file is part of IsaFoR/CeTA.

IsaFoR/CeTA is free software: you can redistribute it and/or modify it under the
terms of the GNU Lesser General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

IsaFoR/CeTA is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License along
with IsaFoR/CeTA. If not, see <http://www.gnu.org/licenses/>.
*)

theory Sqrt_Babylonian
imports 
  Sqrt_Babylonian_Auxiliary
  NthRoot_Impl
begin

section ‹Executable algorithms for square roots›

text ‹
  This theory provides executable algorithms for computing square-roots of numbers which
  are all based on the Babylonian method (which is also known as Heron's method or Newton's method).
  
  For integers / naturals / rationals precise algorithms are given, i.e., here $sqrt\ x$ delivers
  a list of all integers / naturals / rationals $y$ where $y^2 = x$. 
  To this end, the Babylonian method has been adapted by using integer-divisions.

  In addition to the precise algorithms, we also provide approximation algorithms. One works for 
  arbitrary linear ordered fields, where some number $y$ is computed such that
  @{term "abs(y^2 - x) < ε"}. Moreover, for the naturals, integers, and rationals we provide algorithms to compute
  @{term "floor (sqrt x)"} and @{term "ceiling (sqrt x)"} which are all based
  on the underlying algorithm that is used to compute the precise square-roots on integers, if these 
  exist.

  The major motivation for developing the precise algorithms was given by \ceta{} cite"CeTA",
  a tool for certifiying termination proofs. Here, non-linear equations of the form
  $(a_1x_1 + \dots a_nx_n)^2 = p$ had to be solved over the integers, where $p$ is a concrete polynomial.
  For example, for the equation $(ax + by)^2 = 4x^2 - 12xy + 9y^2$ one easily figures out that
  $a^2 = 4, b^2 = 9$, and $ab = -6$, which results in a possible solution $a = \sqrt 4 = 2, b = - \sqrt 9 = -3$.
›

subsection ‹The Babylonian method›

text ‹
The Babylonian method for computing $\sqrt n$ iteratively computes 
\[
x_{i+1} = \frac{\frac n{x_i} + x_i}2
\]
until $x_i^2 \approx n$. Note that if $x_0^2 \geq n$, then for all $i$ we have both
$x_i^2 \geq n$ and $x_i \geq x_{i+1}$. 
›

subsection ‹The Babylonian method using integer division›
text ‹
  First, the algorithm is developed for the non-negative integers.
  Here, the division operation $\frac xy$ is replaced by @{term "x div y = of_int x / of_int y"}.
  Note that replacing @{term "of_int x / of_int y"} by @{term "of_int x / of_int y"} would lead to non-termination
  in the following algorithm.

  We explicititly develop the algorithm on the integers and not on the naturals, as the calculations
  on the integers have been much easier. For example, $y - x + x = y$ on the integers, which would require
  the side-condition $y \geq x$ for the naturals. These conditions will make the reasoning much more tedious---as
  we have experienced in an earlier state of this development where everything was based on naturals.

  Since the elements
  $x_0, x_1, x_2,\dots$ are monotone decreasing, in the main algorithm we abort as soon as $x_i^2 \leq n$.›


text ‹\textbf{Since in the meantime, all of these algorithms have been generalized to arbitrary
  $p$-th roots in @{theory Sqrt_Babylonian.NthRoot_Impl}, we just instantiate the general algorithms by $p = 2$ and then provide 
  specialized code equations which are more efficient than the general purpose algorithms.}›

definition sqrt_int_main' :: "int  int  int × bool" where
  [simp]: "sqrt_int_main' x n = root_int_main' 1 1 2 x n"

lemma sqrt_int_main'_code[code]: "sqrt_int_main' x n = (let x2 = x * x in if x2  n then (x, x2 = n)
    else sqrt_int_main' ((n div x + x) div 2) n)"
  using root_int_main'.simps[of 1 1 2 x n]
  unfolding Let_def by auto

definition sqrt_int_main :: "int  int × bool" where
  [simp]: "sqrt_int_main x = root_int_main 2 x"

lemma sqrt_int_main_code[code]: "sqrt_int_main x = sqrt_int_main' (start_value x 2) x"
  by (simp add: root_int_main_def Let_def)

definition sqrt_int :: "int  int list" where
  "sqrt_int x = root_int 2 x"

lemma sqrt_int_code[code]: "sqrt_int x = (if x < 0 then [] else case sqrt_int_main x of (y,True)  if y = 0 then [0] else [y,-y] | _  [])"
proof -
  interpret fixed_root 2 1 by (unfold_locales, auto)
  obtain b y where res: "root_int_main 2 x = (b,y)" by force
  show ?thesis
    unfolding sqrt_int_def root_int_def Let_def
    using root_int_main[OF _ res]
    using res
    by simp
qed

lemma sqrt_int[simp]: "set (sqrt_int x) = {y. y * y = x}"
  unfolding sqrt_int_def by (simp add: power2_eq_square)


lemma sqrt_int_pos: assumes res: "sqrt_int x = Cons s ms"
  shows "s  0"
proof -
  note res = res[unfolded sqrt_int_code Let_def, simplified]
  from res have x0: "x  0" by (cases ?thesis, auto)
  obtain ss b where call: "sqrt_int_main x = (ss,b)" by force
  from res[unfolded call] x0 have "ss = s" 
    by (cases b, cases "ss = 0", auto)
  from root_int_main(1)[OF x0 call[unfolded this sqrt_int_main_def]]
  show ?thesis .
qed

definition [simp]: "sqrt_int_floor_pos x = root_int_floor_pos 2 x"

lemma sqrt_int_floor_pos_code[code]: "sqrt_int_floor_pos x = fst (sqrt_int_main x)"
  by (simp add: root_int_floor_pos_def)

lemma sqrt_int_floor_pos: assumes x: "x  0" 
  shows "sqrt_int_floor_pos x =  sqrt (of_int x) "
  using root_int_floor_pos[OF x, of 2] by (simp add: sqrt_def)

definition [simp]: "sqrt_int_ceiling_pos x = root_int_ceiling_pos 2 x"

lemma sqrt_int_ceiling_pos_code[code]: "sqrt_int_ceiling_pos x = (case sqrt_int_main x of (y,b)  if b then y else y + 1)"
  by (simp add: root_int_ceiling_pos_def)

lemma sqrt_int_ceiling_pos: assumes x: "x  0" 
  shows "sqrt_int_ceiling_pos x =  sqrt (of_int x) "
  using root_int_ceiling_pos[OF x, of 2] by (simp add: sqrt_def)

definition "sqrt_int_floor x = root_int_floor 2 x"

lemma sqrt_int_floor_code[code]: "sqrt_int_floor x = (if x  0 then sqrt_int_floor_pos x else - sqrt_int_ceiling_pos (- x))"
  unfolding sqrt_int_floor_def root_int_floor_def by simp

lemma sqrt_int_floor[simp]: "sqrt_int_floor x =  sqrt (of_int x) "
  by (simp add: sqrt_int_floor_def sqrt_def)

definition "sqrt_int_ceiling x = root_int_ceiling 2 x"

lemma sqrt_int_ceiling_code[code]: "sqrt_int_ceiling x = (if x  0 then sqrt_int_ceiling_pos x else - sqrt_int_floor_pos (- x))"
  unfolding sqrt_int_ceiling_def root_int_ceiling_def by simp

lemma sqrt_int_ceiling[simp]: "sqrt_int_ceiling x =  sqrt (of_int x) "
  by (simp add: sqrt_int_ceiling_def sqrt_def)

lemma sqrt_int_ceiling_bound: "0  x  x  (sqrt_int_ceiling x)^2"
  unfolding sqrt_int_ceiling using le_of_int_ceiling sqrt_le_D
  by (metis of_int_power_le_of_int_cancel_iff)


subsection ‹Square roots for the naturals›


definition sqrt_nat :: "nat  nat list"
  where "sqrt_nat x = root_nat 2 x"
 
lemma sqrt_nat_code[code]: "sqrt_nat x  map nat (take 1 (sqrt_int (int x)))"
  unfolding sqrt_nat_def root_nat_def sqrt_int_def by simp

lemma sqrt_nat[simp]: "set (sqrt_nat x) = { y. y * y = x}" 
  unfolding sqrt_nat_def using root_nat[of 2 x] by (simp add: power2_eq_square)

definition sqrt_nat_floor :: "nat  int" where
  "sqrt_nat_floor x = root_nat_floor 2 x"

lemma sqrt_nat_floor_code[code]: "sqrt_nat_floor x = sqrt_int_floor_pos (int x)"
  unfolding sqrt_nat_floor_def root_nat_floor_def by simp

lemma sqrt_nat_floor[simp]: "sqrt_nat_floor x =  sqrt (real x) "
  unfolding sqrt_nat_floor_def by (simp add: sqrt_def)

definition sqrt_nat_ceiling :: "nat  int" where
  "sqrt_nat_ceiling x = root_nat_ceiling 2 x"

lemma sqrt_nat_ceiling_code[code]: "sqrt_nat_ceiling x = sqrt_int_ceiling_pos (int x)"
  unfolding sqrt_nat_ceiling_def root_nat_ceiling_def by simp

lemma sqrt_nat_ceiling[simp]: "sqrt_nat_ceiling x =  sqrt (real x) "
  unfolding sqrt_nat_ceiling_def by (simp add: sqrt_def)

subsection ‹Square roots for the rationals›

definition sqrt_rat :: "rat  rat list" where
  "sqrt_rat x = root_rat 2 x"

lemma sqrt_rat_code[code]: "sqrt_rat x = (case quotient_of x of (z,n)  (case sqrt_int n of 
    []  [] 
  | sn # xs  map (λ sz. of_int sz / of_int sn) (sqrt_int z)))"
proof -
  obtain z n where q: "quotient_of x = (z,n)" by force
  show ?thesis
  unfolding sqrt_rat_def root_rat_def q split sqrt_int_def
  by (cases "root_int 2 n", auto)
qed

lemma sqrt_rat[simp]: "set (sqrt_rat x) = { y. y * y = x}"
  unfolding sqrt_rat_def using root_rat[of 2 x]
  by (simp add: power2_eq_square)

lemma sqrt_rat_pos: assumes sqrt: "sqrt_rat x = Cons s ms" 
  shows "s  0"
proof -
  obtain z n where q: "quotient_of x = (z,n)" by force
  note sqrt = sqrt[unfolded sqrt_rat_code q, simplified]
  let ?sz = "sqrt_int z"
  let ?sn = "sqrt_int n"
  from q have n: "n > 0" by (rule quotient_of_denom_pos)
  from sqrt obtain sz mz where sz: "?sz = sz # mz" by (cases ?sn, auto)
  from sqrt obtain sn mn where sn: "?sn = sn # mn" by (cases ?sn, auto)
  from sqrt_int_pos[OF sz] sqrt_int_pos[OF sn] have pos: "0  sz" "0  sn" by auto
  from sqrt sz sn have s: "s = of_int sz / of_int sn" by auto
  show ?thesis unfolding s using pos
    by (metis of_int_0_le_iff zero_le_divide_iff)
qed

definition sqrt_rat_floor :: "rat  int" where
  "sqrt_rat_floor x = root_rat_floor 2 x"

lemma sqrt_rat_floor_code[code]: "sqrt_rat_floor x = (case quotient_of x of (a,b)  sqrt_int_floor (a * b) div b)"
  unfolding sqrt_rat_floor_def root_rat_floor_def by (simp add: sqrt_def)

lemma sqrt_rat_floor[simp]: "sqrt_rat_floor x =  sqrt (of_rat x) "
  unfolding sqrt_rat_floor_def by (simp add: sqrt_def)

definition sqrt_rat_ceiling :: "rat  int" where
  "sqrt_rat_ceiling x = root_rat_ceiling 2 x"

lemma sqrt_rat_ceiling_code[code]: "sqrt_rat_ceiling x = - (sqrt_rat_floor (-x))"
  unfolding sqrt_rat_ceiling_def sqrt_rat_floor_def root_rat_ceiling_def by simp

lemma sqrt_rat_ceiling: "sqrt_rat_ceiling x =  sqrt (of_rat x) "
  unfolding sqrt_rat_ceiling_def by (simp add: sqrt_def)

lemma sqr_rat_of_int: assumes x: "x * x = rat_of_int i"
  shows " j :: int. j * j = i"
proof -
  from x have mem: "x  set (sqrt_rat (rat_of_int i))" by simp
  from x have "rat_of_int i  0" by (metis zero_le_square)
  hence *: "quotient_of (rat_of_int i) = (i,1)" by (metis quotient_of_int)
  have 1: "sqrt_int 1 = [1,-1]" by code_simp
  from mem sqrt_rat_code * split 1 
  have x: "x  rat_of_int ` {y. y * y = i}" by auto
  thus ?thesis by auto
qed

subsection ‹Approximating square roots›

text ‹
  The difference to the previous algorithms is that now we abort, once the distance is below
  $\epsilon$.  
  Moreover, here we use standard division and not integer division.
  This part is not yet generalized by @{theory Sqrt_Babylonian.NthRoot_Impl}.

  We first provide the executable version without guard @{term "x > 0"} as partial function,
  and afterwards prove termination and soundness for a similar algorithm that is defined within the upcoming
locale.
›

partial_function (tailrec) sqrt_approx_main_impl :: "'a :: linordered_field  'a  'a  'a" where 
  [code]: "sqrt_approx_main_impl ε n x = (if x * x - n < ε then x else sqrt_approx_main_impl ε n 
    ((n / x + x) / 2))"

text ‹We setup a locale where we ensure that we have standard assumptions: positive $\epsilon$ and
  positive $n$. We require sort @{term floor_ceiling}, since @{term " x "} is used for the termination
  argument.›
locale sqrt_approximation = 
  fixes ε :: "'a :: {linordered_field,floor_ceiling}"
  and n :: 'a
  assumes ε : "ε > 0"
  and n: "n > 0"
begin

function sqrt_approx_main :: "'a  'a" where 
  "sqrt_approx_main x = (if x > 0 then (if x * x - n < ε then x else sqrt_approx_main 
    ((n / x + x) / 2)) else 0)"
    by pat_completeness auto

text ‹Termination essentially is a proof of convergence. Here, one complication is the fact
  that the limit is not always defined. E.g., if @{typ "'a"} is @{typ rat} then there is no
  square root of 2. Therefore, the error-rate $\frac x{\sqrt n} - 1$ is not expressible. 
  Instead we use the expression $\frac{x^2}n - 1$ as error-rate which
  does not require any square-root operation.›
termination
proof -
  define er where "er x = (x * x / n - 1)" for x
  define c where "c = 2 * n / ε"
  define m where "m x = nat  c * er x " for x
  have c: "c > 0" unfolding c_def using n ε by auto
  show ?thesis
  proof
    show "wf (measures [m])" by simp
  next
    fix x 
    assume x: "0 < x" and xe: "¬ x * x - n < ε"
    define y where "y = (n / x + x) / 2"    
    show "((n / x + x) / 2,x)  measures [m]" unfolding y_def[symmetric]
    proof (rule measures_less)
      from n have inv_n: "1 / n > 0" by auto
      from xe have "x * x - n  ε" by simp
      from this[unfolded mult_le_cancel_left_pos[OF inv_n, of ε, symmetric]]
      have erxen: "er x  ε / n" unfolding er_def using n by (simp add: field_simps)
      have en: "ε / n > 0" and ne: "n / ε > 0" using ε n by auto
      from en erxen have erx: "er x > 0" by linarith
      have pos: "er x * 4 + er x * (er x * 4) > 0" using erx
        by (auto intro: add_pos_nonneg)
      have "er y = 1 / 4 * (n / (x * x) - 2  + x * x / n)" unfolding er_def y_def using x n
        by (simp add: field_simps)
      also have " = 1 / 4 * er x * er x / (1 + er x)" unfolding er_def using x n
        by (simp add: field_simps)
      finally have "er y = 1 / 4 * er x * er x / (1 + er x)" .
      also have " < 1 / 4 * (1 + er x) * er x / (1 + er x)" using erx erx pos
        by (auto simp: field_simps)
      also have " = er x / 4" using erx by (simp add: field_simps)
      finally have er_y_x: "er y  er x / 4" by linarith
      from erxen have "c * er x  2" unfolding c_def mult_le_cancel_left_pos[OF ne, of _ "er x", symmetric]
        using n ε by (auto simp: field_simps)
      hence pos: "c * er x > 0" "c * er x  2" by auto
      show "m y < m x" unfolding m_def nat_mono_iff[OF pos(1)]
      proof -      
        have "c * er y  c * (er x / 4)"
          by (rule floor_mono, unfold mult_le_cancel_left_pos[OF c], rule er_y_x)
        also have " < c * er x / 4 + 1" by auto
        also have "  c * er x"
          by (rule floor_mono, insert pos(2), simp add: field_simps)
        finally show "c * er y < c * er x" .
      qed
    qed
  qed
qed

text ‹Once termination is proven, it is easy to show equivalence of 
  @{const sqrt_approx_main_impl} and @{const sqrt_approx_main}.›
lemma sqrt_approx_main_impl: "x > 0  sqrt_approx_main_impl ε n x = sqrt_approx_main x"
proof (induct x rule: sqrt_approx_main.induct)
  case (1 x)
  hence x: "x > 0" by auto
  hence nx: "0 < (n / x + x) / 2" using n by (auto intro: pos_add_strict)
  note simps = sqrt_approx_main_impl.simps[of _ _ x] sqrt_approx_main.simps[of x]
  show ?case 
  proof (cases "x * x - n < ε")
    case True
    thus ?thesis unfolding simps using x by auto
  next
    case False
    show ?thesis using 1(1)[OF x False nx] unfolding simps using x False by auto
  qed
qed

text ‹Also soundness is not complicated.›

lemma sqrt_approx_main_sound: assumes x: "x > 0" and xx: "x * x > n"
  shows "sqrt_approx_main x * sqrt_approx_main x > n  sqrt_approx_main x * sqrt_approx_main x - n < ε"
  using assms
proof (induct x rule: sqrt_approx_main.induct)
  case (1 x)
  from 1 have x:  "x > 0" "(x > 0) = True" by auto
  note simp = sqrt_approx_main.simps[of x, unfolded x if_True]
  show ?case
  proof (cases "x * x - n < ε")
    case True
    with 1 show ?thesis unfolding simp by simp
  next
    case False
    let ?y = "(n / x + x) / 2"
    from False simp have simp: "sqrt_approx_main x = sqrt_approx_main ?y" by simp
    from n x have y: "?y > 0" by (auto intro: pos_add_strict)
    note IH = 1(1)[OF x(1) False y]
    from x have x4: "4 * x * x > 0" by (auto intro: mult_sign_intros)
    show ?thesis unfolding simp
    proof (rule IH)
      show "n < ?y * ?y"
        unfolding mult_less_cancel_left_pos[OF x4, of n, symmetric]
      proof -
        have id: "4 * x * x * (?y * ?y) = 4 * x * x * n + (n - x * x) * (n - x * x)" using x(1)
          by (simp add: field_simps)
        from 1(3) have "x * x - n > 0" by auto
        from mult_pos_pos[OF this this]
        show "4 * x * x * n < 4 * x * x * (?y * ?y)" unfolding id 
          by (simp add: field_simps)
      qed
    qed
  qed
qed   

end

text ‹It remains to assemble everything into one algorithm.›

definition sqrt_approx :: "'a :: {linordered_field,floor_ceiling}  'a  'a" where
  "sqrt_approx ε x  if ε > 0 then (if x = 0 then 0 else let xpos = abs x in sqrt_approx_main_impl ε xpos (xpos + 1)) else 0"


lemma sqrt_approx: assumes ε: "ε > 0"
  shows "¦sqrt_approx ε x * sqrt_approx ε x - ¦x¦¦ < ε"
proof (cases "x = 0")
  case True
  with ε show ?thesis unfolding sqrt_approx_def by auto
next
  case False
  let ?x = "¦x¦" 
  let ?sqrti = "sqrt_approx_main_impl ε ?x (?x + 1)"
  let ?sqrt = "sqrt_approximation.sqrt_approx_main ε ?x (?x + 1)"
  define sqrt where "sqrt = ?sqrt"
  from False have x: "?x > 0" "?x + 1 > 0" by auto
  interpret sqrt_approximation ε ?x
    by (unfold_locales, insert x ε, auto)
  from False ε have "sqrt_approx ε x = ?sqrti" unfolding sqrt_approx_def by (simp add: Let_def)
  also have "?sqrti = ?sqrt"
    by (rule sqrt_approx_main_impl, auto)
  finally have id: "sqrt_approx ε x = sqrt" unfolding sqrt_def .
  have sqrt: "sqrt * sqrt > ?x  sqrt * sqrt - ?x < ε" unfolding sqrt_def
    by (rule sqrt_approx_main_sound[OF x(2)], insert x mult_pos_pos[OF x(1) x(1)], auto simp: field_simps)
  show ?thesis unfolding id using sqrt by auto
qed

subsection ‹Some tests›

text ‹Testing executabity and show that sqrt 2 is irrational›
lemma "¬ ( i :: rat. i * i = 2)"
proof -
  have "set (sqrt_rat 2) = {}" by eval
  thus ?thesis by simp
qed

text ‹Testing speed›
lemma "¬ ( i :: int. i * i = 1234567890123456789012345678901234567890)"
proof -
  have "set (sqrt_int 1234567890123456789012345678901234567890) = {}" by eval
  thus ?thesis by simp
qed

text ‹The following test›

value "let ε = 1 / 100000000 :: rat; s = sqrt_approx ε 2 in (s, s * s - 2, ¦s * s - 2¦ < ε)"

text ‹results in (1.4142135623731116, 4.738200762148612e-14, True).›
 
end