(* Theory Efficient_Discrete_Sqrt *)
theory Efficient_Discrete_Sqrt
imports
Complex_Main
"HOL-Computational_Algebra.Computational_Algebra"
"HOL-Library.Discrete_Functions"
"HOL-Library.Tree"
"HOL-Library.IArray"
begin
section ‹Efficient Algorithms for the Square Root on ‹ℕ››
subsection ‹A Discrete Variant of Heron's Algorithm›
text ‹
An algorithm for calculating the discrete square root, taken from
Cohen~\<^cite>‹"cohen2010algebraic"›. This algorithm is essentially a discretised variant of
Heron's method or Newton's method specialised to the square root function.
›
(* Characterises the discrete square root via the real one:
   floor_sqrt n = nat ⌊sqrt n⌋.
   Proof idea: show that r = nat ⌊sqrt n⌋ satisfies r² ≤ n < (Suc r)², and then
   conclude by the uniqueness lemma floor_sqrt_unique. *)
lemma sqrt_eq_floor_sqrt: "floor_sqrt n = nat ⌊sqrt n⌋"
proof -
(* Lower bound: (nat ⌊sqrt n⌋)² ≤ n, computed in the reals and transferred back. *)
have "real ((nat ⌊sqrt n⌋)⇧2) = (real (nat ⌊sqrt n⌋))⇧2"
by simp
also have "… ≤ sqrt (real n) ^ 2"
by (intro power_mono) auto
also have "… = real n" by simp
finally have "(nat ⌊sqrt n⌋)⇧2 ≤ n"
by (simp only: of_nat_le_iff)
(* Upper bound: n < (Suc (nat ⌊sqrt n⌋))², first shown for the integer 1 + ⌊sqrt n⌋. *)
moreover have "n < (Suc (nat ⌊sqrt n⌋))⇧2" proof -
have "(1 + ⌊sqrt n⌋)⇧2 > n"
using floor_correct[of "sqrt n"] real_le_rsqrt[of "1 + ⌊sqrt n⌋" n]
of_int_less_iff[of n "(1 + ⌊sqrt n⌋)⇧2"] not_le
by fastforce
then show ?thesis
using le_nat_floor[of "Suc (nat ⌊sqrt n⌋)" "sqrt n"]
of_nat_le_iff[of "(Suc (nat ⌊sqrt n⌋))⇧2" n] real_le_rsqrt[of _ n] not_le
by fastforce
qed
ultimately show ?thesis using floor_sqrt_unique by fast
qed
(* One run of the discretised Heron/Newton iteration for ⌊√n⌋.
   newton_sqrt_aux x n computes the next candidate y = (x + n div x) div 2.
   If y strictly improves on x (y < x), it recurses on y; otherwise it returns x.
   The recursive call only happens under y < x, so the first argument strictly
   decreases along the recursion. *)
fun newton_sqrt_aux :: "nat ⇒ nat ⇒ nat" where
"newton_sqrt_aux x n =
(let y = (x + n div x) div 2
in if y < x then newton_sqrt_aux y n else x)"
(* The unconditional recursive equation is removed from the simpset
   (presumably so that simp does not unfold the recursion repeatedly);
   the conditional equations newton_sqrt_aux_simps are used instead. *)
declare newton_sqrt_aux.simps [simp del]
(* Conditional unfolding rules for newton_sqrt_aux, one per branch of its
   if-then-else: recurse when the Heron step decreases x, stop otherwise. *)
lemma newton_sqrt_aux_simps:
"(x + n div x) div 2 < x ⟹ newton_sqrt_aux x n = newton_sqrt_aux ((x + n div x) div 2) n"
"(x + n div x) div 2 ≥ x ⟹ newton_sqrt_aux x n = x"
by (subst newton_sqrt_aux.simps; simp add: Let_def)+
(* The real-valued Heron step never undershoots the square root. For t > 0 and
   n ≥ 0 this is the arithmetic–geometric mean inequality applied to t and n/t,
   whose product is n. *)
lemma heron_step_real: "⟦t > 0; n ≥ 0⟧ ⟹ (t + n/t) / 2 ≥ sqrt n"
using arith_geo_mean_sqrt[of t "n/t"] by simp
(* The natural-number Heron step (with truncating division) equals the floor of
   the real-valued step. This lets bounds on the real step transfer to the
   discrete one. *)
lemma heron_step_div_eq_floored:
"(t::nat) > 0 ⟹ (t + (n::nat) div t) div 2 = nat ⌊(t + n/t) / 2⌋"
proof -
assume "t > 0"
(* Bring the real step to a single fraction (t*t + n) / (2*t). *)
then have "⌊(t + n/t) / 2⌋ = ⌊(t*t + n) / (2*t)⌋"
by (simp add: mult_divide_mult_cancel_right[of t "t + n/t" 2, symmetric]
algebra_simps)
(* The floor of a quotient of naturals is natural division. *)
also have "… = (t*t + n) div (2*t)"
using floor_divide_of_nat_eq by blast
also have "… = (t*t + n) div t div 2"
by (simp add: div_mult2_eq ac_simps)
also have "… = (t + n div t) div 2"
by (simp add: ‹0 < t› power2_eq_square)
finally show ?thesis by simp
qed
(* The discrete Heron step never drops below floor_sqrt n, for any t > 0.
   Combines sqrt_eq_floor_sqrt, heron_step_real and heron_step_div_eq_floored. *)
lemma heron_step: "t > 0 ⟹ (t + n div t) div 2 ≥ floor_sqrt n"
proof -
assume "t > 0"
have "floor_sqrt n = nat ⌊sqrt n⌋" by (rule sqrt_eq_floor_sqrt)
also have "… ≤ nat ⌊(t + n/t) / 2⌋"
using heron_step_real[of t n] ‹t > 0› by linarith
also have "… = (t + n div t) div 2"
using heron_step_div_eq_floored[OF ‹t > 0›] by simp
finally show ?thesis .
qed
(* Correctness of the iteration: started from any x ≥ floor_sqrt n, newton_sqrt_aux
   returns exactly floor_sqrt n. The proof is by induction along the recursion
   of newton_sqrt_aux. *)
lemma newton_sqrt_aux_correct:
assumes "x ≥ floor_sqrt n"
shows "newton_sqrt_aux x n = floor_sqrt n"
using assms
proof (induction x n rule: newton_sqrt_aux.induct)
case (1 x n)
show ?case
proof (cases "x = floor_sqrt n")
case True
(* x is already the answer: since x² ≤ n, the step is ≥ x, so the iteration stops at x. *)
then have "(x ^ 2) div x ≤ n div x" by (intro div_le_mono) simp_all
also have "(x ^ 2) div x = x" by (simp add: power2_eq_square)
finally have "(x + n div x) div 2 ≥ x" by linarith
with True show ?thesis by (auto simp: newton_sqrt_aux_simps)
next
case False
(* x overshoots: show the step strictly decreases x but stays ≥ floor_sqrt n,
   then apply the induction hypothesis to the new candidate. *)
with "1.prems" have x_gt_sqrt: "x > floor_sqrt n" by auto
with le_floor_sqrt_iff[of x n] have "n < x ^ 2" by simp
have "x * (n div x) ≤ n" using mult_div_mod_eq[of x n] by linarith
also have "… < x ^ 2" using le_floor_sqrt_iff[of x n] and x_gt_sqrt by simp
also have "… = x * x" by (simp add: power2_eq_square)
finally have "n div x < x" by (subst (asm) mult_less_cancel1) auto
then have step_decreasing: "(x + n div x) div 2 < x" by linarith
with x_gt_sqrt have step_ge_sqrt: "(x + n div x) div 2 ≥ floor_sqrt n"
by (simp add: heron_step)
from step_decreasing have "newton_sqrt_aux x n = newton_sqrt_aux ((x + n div x) div 2) n"
by (simp add: newton_sqrt_aux_simps)
also have "… = floor_sqrt n"
by (intro "1.IH" step_decreasing step_ge_sqrt) simp_all
finally show ?thesis .
qed
qed
(* Discrete square root via the Heron/Newton iteration, started at x = n.
   The start value is an upper bound for floor_sqrt n (cf. floor_sqrt_le), which
   is what newton_sqrt_aux_correct requires. *)
definition newton_sqrt :: "nat ⇒ nat" where
"newton_sqrt n = newton_sqrt_aux n n"
(* Replace the library's code equation for floor_sqrt with the Newton iteration:
   first drop the existing code equation, then register this theorem as the new one. *)
declare floor_sqrt_code [code del]
theorem Discrete_sqrt_eq_newton_sqrt [code]: "floor_sqrt n = newton_sqrt n"
unfolding newton_sqrt_def by (simp add: newton_sqrt_aux_correct floor_sqrt_le)
subsection ‹Square Testing›
text ‹
Next, we implement an algorithm to determine whether a given natural number is a perfect square,
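as described by Cohen~\<^cite>‹"cohen2010algebraic"›. Essentially, the algorithm first performs cheap
necessary tests: using precomputed lookup tables, it checks whether the number is a quadratic residue
modulo 64, 63, 65, and 11. Only if all of these tests pass does it compute the discrete square root
and check whether squaring it gives back the original number.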
›
(* Quadratic residues modulo 11, 63, 64 and 65: the sets {k² mod m | k}, listed
   explicitly. That they are exactly these sets is proved below in q11_upto_def,
   q63_upto_def, q64_upto_def and q65_upto_def. *)
definition q11 :: "nat set"
where "q11 = {0, 1, 3, 4, 5, 9}"
definition q63 :: "nat set"
where "q63 = {0, 1, 4, 7, 9, 16, 28, 18, 22, 25, 36, 58, 46, 49, 37, 43}"
definition q64 :: "nat set"
where "q64 = {0, 1, 4, 9, 16, 17, 25, 36, 33, 49, 41, 57}"
definition q65 :: "nat set"
where "q65 = {0, 1, 4, 10, 14, 9, 16, 26, 30, 25, 29, 40, 56, 36, 49, 61, 35, 51, 39, 55, 64}"
(* Boolean lookup tables for the residue sets: entry i is True iff i is in the
   corresponding set (see sub_q11_array, sub_q63_array, sub_q64_array).
   q11_array has 11 entries, q63_array 63 and q64_array 64, one per residue. *)
definition q11_array where
"q11_array = IArray [True,True,False,True,True,True,False,False,False,True,False]"
definition q63_array where
"q63_array = IArray [True,True,False,False,True,False,False,True,False,True,False,False,
False,False,False,False,True,False,True,False,False,False,True,False,False,True,False,
False,True,False,False,False,False,False,False,False,True,True,False,False,False,False,
False,True,False,False,True,False,False,True,False,False,False,False,False,False,False,
False,True,False,False,False,False,False]"
definition q64_array where
"q64_array = IArray [True,True,False,False,True,False,False,False,False,True,False,False,
False,False,False,False,True,True,False,False,False,False,False,False,False,True,False,
False,False,False,False,False,False,True,False,False,True,False,False,False,False,True,
False,False,False,False,False,False,False,True,False,False,False,False,False,False,
False,True,False,False,False,False,False,False, False]"
(* Boolean lookup table for the quadratic residues modulo 65: entry i is True
   iff i ∈ q65 (see sub_q65_array). It has exactly 65 entries, one for each
   residue 0..64 that an index of the form r mod 65 can take, matching the
   other tables, which have exactly m entries for modulus m. *)
definition q65_array where
"q65_array = IArray [True,True,False,False,True,False,False,False,False,True,True,False,
False,False,True,False,True,False,False,False,False,False,False,False,False,True,True,
False,False,True,True,False,False,False,False,True,True,False,False,True,True,False,
False,False,False,False,False,False,False,True,False,True,False,False,False,True,True,
False,False,False,False,True,False,False,True]"
(* Each lookup table agrees with its residue set on all valid indices i < m.
   Proof: case split over every i < m and evaluate both sides. *)
lemma sub_q11_array: "i ∈ {..<11} ⟹ IArray.sub q11_array i ⟷ i ∈ q11"
by (simp add: lessThan_nat_numeral lessThan_Suc q11_def q11_array_def, elim disjE; simp)
lemma sub_q63_array: "i ∈ {..<63} ⟹ IArray.sub q63_array i ⟷ i ∈ q63"
by (simp add: lessThan_nat_numeral lessThan_Suc q63_def q63_array_def, elim disjE; simp)
lemma sub_q64_array: "i ∈ {..<64} ⟹ IArray.sub q64_array i ⟷ i ∈ q64"
by (simp add: lessThan_nat_numeral lessThan_Suc q64_def q64_array_def, elim disjE; simp)
lemma sub_q65_array: "i ∈ {..<65} ⟹ IArray.sub q65_array i ⟷ i ∈ q65"
by (simp add: lessThan_nat_numeral lessThan_Suc q65_def q65_array_def, elim disjE; simp)
(* Membership of x mod m in the residue set can be decided by an array lookup.
   x mod m is always a valid index, so the sub_q*_array lemmas apply. *)
lemma in_q11_code: "x mod 11 ∈ q11 ⟷ IArray.sub q11_array (x mod 11)"
by (subst sub_q11_array) auto
lemma in_q63_code: "x mod 63 ∈ q63 ⟷ IArray.sub q63_array (x mod 63)"
by (subst sub_q63_array) auto
lemma in_q64_code: "x mod 64 ∈ q64 ⟷ IArray.sub q64_array (x mod 64)"
by (subst sub_q64_array) auto
lemma in_q65_code: "x mod 65 ∈ q65 ⟷ IArray.sub q65_array (x mod 65)"
by (subst sub_q65_array) auto
(* Perfect-square test with cheap residue filters before the exact check.
   First n mod 64 must be a square residue. Then r = n mod 45045 is reduced
   further modulo 63, 65 and 11; since 45045 = 63·65·11, these agree with
   n mod 63, n mod 65 and n mod 11. Finally the exact check n = (floor_sqrt n)²
   decides. Equivalence with is_square is square_test_correct. *)
definition square_test :: "nat ⇒ bool" where
"square_test n =
(n mod 64 ∈ q64 ∧ (let r = n mod 45045 in
r mod 63 ∈ q63 ∧ r mod 65 ∈ q65 ∧ r mod 11 ∈ q11 ∧ n = (floor_sqrt n)⇧2))"
(* Code equation for square_test: set membership tests are replaced by
   lookups in the precomputed boolean arrays (via the in_q*_code lemmas). *)
lemma square_test_code [code]:
"square_test n =
(IArray.sub q64_array (n mod 64) ∧ (let r = n mod 45045 in
IArray.sub q63_array (r mod 63) ∧
IArray.sub q65_array (r mod 65) ∧
IArray.sub q11_array (r mod 11) ∧ n = (floor_sqrt n)⇧2))"
using in_q11_code [symmetric] in_q63_code [symmetric]
in_q64_code [symmetric] in_q65_code [symmetric]
by (simp add: Let_def square_test_def)
(* Every square residue modulo m > 0 is already attained by some base q' < m.
   The proof uses power_mod / mod_mod_trivial, i.e. reduces the base modulo m. *)
lemma square_mod_lower: "m > 0 ⟹ (q⇧2 :: nat) mod m = a ⟹ ∃q' < m. q'⇧2 mod m = a"
using mod_less_divisor mod_mod_trivial power_mod by blast
(* q11 is exactly the set of squares modulo 11: first over bases k < 11,
   then (using square_mod_lower) over all natural bases. *)
lemma q11_upto_def: "q11 = (λk. k⇧2 mod 11) ` {..<11}"
by (simp add: q11_def lessThan_nat_numeral lessThan_Suc insert_commute)
lemma q11_infinite_def: "q11 = (λk. k⇧2 mod 11) ` {0..}"
unfolding q11_upto_def image_def proof (auto, goal_cases)
case (1 xa)
show ?case
using square_mod_lower[of 11 xa "xa⇧2 mod 11"]
ex_nat_less_eq[of 11 "λx. xa⇧2 mod 11 = x⇧2 mod 11"]
by auto
qed
(* q63 is exactly the set of squares modulo 63: first over bases k < 63,
   then (using square_mod_lower) over all natural bases. *)
lemma q63_upto_def: "q63 = (λk. k⇧2 mod 63) ` {..<63}"
by (simp add: q63_def lessThan_nat_numeral lessThan_Suc insert_commute)
lemma q63_infinite_def: "q63 = (λk. k⇧2 mod 63) ` {0..}"
unfolding q63_upto_def image_def proof (auto, goal_cases)
case (1 xa)
show ?case
using square_mod_lower[of 63 xa "xa⇧2 mod 63"]
ex_nat_less_eq[of 63 "λx. xa⇧2 mod 63 = x⇧2 mod 63"]
by auto
qed
(* q64 is exactly the set of squares modulo 64: first over bases k < 64,
   then (using square_mod_lower) over all natural bases. *)
lemma q64_upto_def: "q64 = (λk. k⇧2 mod 64) ` {..<64}"
by (simp add: q64_def lessThan_nat_numeral lessThan_Suc insert_commute)
lemma q64_infinite_def: "q64 = (λk. k⇧2 mod 64) ` {0..}"
unfolding q64_upto_def image_def proof (auto, goal_cases)
case (1 xa)
show ?case
using square_mod_lower[of 64 xa "xa⇧2 mod 64"]
ex_nat_less_eq[of 64 "λx. xa⇧2 mod 64 = x⇧2 mod 64"]
by auto
qed
(* q65 is exactly the set of squares modulo 65: first over bases k < 65,
   then (using square_mod_lower) over all natural bases. *)
lemma q65_upto_def: "q65 = (λk. k⇧2 mod 65) ` {..<65}"
by (simp add: q65_def lessThan_nat_numeral lessThan_Suc insert_commute)
lemma q65_infinite_def: "q65 = (λk. k⇧2 mod 65) ` {0..}"
unfolding q65_upto_def image_def proof (auto, goal_cases)
case (1 xa)
show ?case
using square_mod_lower[of 65 xa "xa⇧2 mod 65"]
ex_nat_less_eq[of 65 "λx. xa⇧2 mod 65 = x⇧2 mod 65"]
by auto
qed
(* If n is a square q², then n mod k is the square residue q² mod k, for any modulus k. *)
lemma square_mod_existence:
fixes n k :: nat
assumes "∃q. q⇧2 = n"
shows "∃q. n mod k = q⇧2 mod k"
using assms by auto
(* Main correctness theorem: square_test decides exactly the perfect squares. *)
theorem square_test_correct: "square_test n ⟷ is_square n"
proof cases
assume "is_square n"
(* A square passes every residue filter and the final exact check. *)
hence rhs: "∃q. q⇧2 = n" by (auto elim: is_nth_powerE)
note sq_mod = square_mod_existence[OF this]
have q64_member: "n mod 64 ∈ q64" using sq_mod[of 64]
unfolding q64_infinite_def image_def by simp
let ?r = "n mod 45045"
(* 45045 = 11 · 63 · 65, so reducing n mod 45045 first does not change the residues. *)
have "11 dvd (45045::nat)" "63 dvd (45045::nat)" "65 dvd (45045::nat)" by force+
then have mod_45045: "?r mod 11 = n mod 11" "?r mod 63 = n mod 63" "?r mod 65 = n mod 65"
using mod_mod_cancel[of _ 45045 n] by presburger+
then have "?r mod 11 ∈ q11" "?r mod 63 ∈ q63" "?r mod 65 ∈ q65"
using sq_mod[of 11] sq_mod[of 63] sq_mod[of 65]
unfolding q11_infinite_def q63_infinite_def q65_infinite_def image_def mod_45045
by fast+
then show ?thesis unfolding square_test_def Let_def using q64_member rhs by auto
next
assume not_rhs: "¬is_square n"
(* A non-square fails the final exact check n = (floor_sqrt n)². *)
hence "∄q. q⇧2 = n" by auto
then have "(floor_sqrt n)⇧2 ≠ n" by simp
then show ?thesis unfolding square_test_def by (auto simp: is_nth_power_def)
qed
(* Exact square root as an option: Some (floor_sqrt n) when n is a perfect square,
   None otherwise. *)
definition get_nat_sqrt :: "nat ⇒ nat option"
where "get_nat_sqrt n = (if is_square n then Some (floor_sqrt n) else None)"
(* Code equation for get_nat_sqrt. It runs the cheap table-based residue filters
   of square_test first. Only if they all pass does it compute x = floor_sqrt n
   once, reusing x both for the exact check x² = n and as the result. *)
lemma get_nat_sqrt_code [code]:
"get_nat_sqrt n =
(if IArray.sub q64_array (n mod 64) ∧ (let r = n mod 45045 in
IArray.sub q63_array (r mod 63) ∧
IArray.sub q65_array (r mod 65) ∧
IArray.sub q11_array (r mod 11)) then
(let x = floor_sqrt n in if x⇧2 = n then Some x else None) else None)"
unfolding get_nat_sqrt_def square_test_correct [symmetric] square_test_def
using in_q11_code [symmetric] in_q63_code [symmetric]
in_q64_code [symmetric] in_q65_code [symmetric]
by (auto split: if_splits simp: Let_def )
end