Theory Karatsuba_Sqrt_Float
theory Karatsuba_Sqrt_Float
imports
Karatsuba_Sqrt
"HOL-Library.Interval_Float"
begin
subsection ‹Floating-point approximation of \<^const>‹sqrt››
(* shift_int k n shifts the integer n by k binary digits.
   For k ≥ 0 it multiplies n by 2^k.
   For k < 0 it divides n by 2^(-k) using int div, which rounds towards
   negative infinity. *)
definition shift_int :: "int ⇒ int ⇒ int"
where "shift_int k n = (if k ≥ 0 then n * 2 ^ nat k else n div 2 ^ (nat (-k)))"
(* Code equation for shift_int.  Generated code uses the bit operations
   push_bit and drop_bit instead of multiplying or dividing by a power of two.
   Equivalence with the arithmetic definition follows from
   push_bit_eq_mult and drop_bit_eq_div. *)
lemma shift_int_code [code]:
"shift_int k n = (if k ≥ 0 then push_bit (nat k) n else drop_bit (nat (-k)) n)"
by (simp add: shift_int_def push_bit_eq_mult drop_bit_eq_div)
(* Lower bound for the square root of a float x = n * 2^e, where
   n = mantissa x and e = exponent x.

   The mantissa is shifted left by k' bits and the exponent is lowered by k':
   - k = nat (2 * int prec - bitlen n) is the number of extra bits,
     clamped at 0 by nat.  Presumably this gives the shifted mantissa about
     2 * prec bits, so that its integer square root has about prec bits.
   - k' is k, or k + 1 when needed to give k' the same parity as e.
     Then e - k' is even, and shift_int (-1) (e - k') = (e - k') div 2 is
     an exact halving of the exponent.

   The integer square root of the shifted mantissa is rounded down with
   sqrt_int_floor, and the resulting float is passed through normfloat. *)
definition lb_sqrt :: "nat ⇒ float ⇒ float" where
"lb_sqrt prec x = (let n = mantissa x; e = exponent x; k = nat (2 * int prec - bitlen n);
k' = (if even k = even e then k else k + 1) in
normfloat (Float (sqrt_int_floor (shift_int k' n)) (shift_int (-1) (e - k'))))"
(* Upper bound for the square root of a float x = n * 2^e, where
   n = mantissa x and e = exponent x.

   It uses the same scaling scheme as lb_sqrt:
   - The mantissa is shifted left by k' bits and the exponent is lowered by k'.
   - k = nat (2 * int prec - bitlen n) is clamped at 0 by nat.
   - k' is k, or k + 1, chosen so that k' has the same parity as e.
     This makes e - k' even, so shift_int (-1) (e - k') = (e - k') div 2
     halves the exponent exactly.

   Unlike lb_sqrt, the integer square root is rounded up with
   sqrt_int_ceiling.  The nat-to-int conversion of prec is written
   explicitly, as in lb_sqrt, instead of relying on an inserted coercion. *)
definition ub_sqrt :: "nat ⇒ float ⇒ float" where
"ub_sqrt prec x = (let n = mantissa x; e = exponent x; k = nat (2 * int prec - bitlen n);
k' = (if even k = even e then k else k + 1) in
normfloat (Float (sqrt_int_ceiling (shift_int k' n)) (shift_int (-1) (e - k'))))"
(* Soundness of lb_sqrt: the result, read as a real number, never exceeds
   the real square root of x. *)
lemma lb_sqrt: "lb_sqrt prec x ≤ sqrt x"
proof -
(* Name the intermediate quantities exactly as in the definition of lb_sqrt. *)
define n where "n = mantissa x"
define e where "e = exponent x"
define k where "k = nat (2 * int prec - bitlen n)"
define k' where "k' = (if even k = even e then k else k + 1)"
(* k' has the same parity as e, so e - k' can be halved exactly. *)
have "even (e - k')"
by (auto simp: k'_def)
define e'' where "e'' = (e - k') div 2"
have e'': "k' = e - 2 * e''"
using ‹even (e - k')› by (auto simp: e''_def)
(* Unfold the definition.  The result equals
   floor (sqrt (n * 2^k')) * 2^((e - k') div 2). *)
have "real_of_float (lb_sqrt prec x) = of_int ⌊sqrt (n * 2 powi int k')⌋ * 2 powi ((e - k') div 2)"
by (simp add: lb_sqrt_def n_def e_def k_def k'_def
Let_def powr_real_of_int' shift_int_def add_ac nat_add_distrib
sqrt_int_floor_def sqrt_int_ceiling_def)
(* Dropping the floor can only increase the value. *)
also have "… ≤ sqrt (n * 2 powi int k') * 2 powi ((e - k') div 2)"
by (intro mult_right_mono) auto
(* Substitute k' = e - 2 * e'' and split the root:
   sqrt (n * 2^k') = sqrt (n * 2^e) / sqrt (2^(2 * e'')). *)
also have "… = sqrt (of_int n * 2 powi e) * (2 powi e'' / sqrt (2 powi (2 * e'')))"
unfolding e'' by (simp add: power_int_diff real_sqrt_divide)
(* The compensating factor is 1, because sqrt ((2^e'')^2) = 2^e''. *)
also have "2 powi (2 * e'') = (2 powi e'' :: real) ^ 2"
by (simp add: mult.commute power_int_mult)
also have "sqrt … = 2 powi e''"
by simp
(* n * 2^e is the real value of x itself. *)
also have "real_of_int n * 2 powi e = real_of_float (Float n e)"
by (simp add: powr_real_of_int')
also have "Float n e = x"
by (simp add: n_def e_def Float_mantissa_exponent)
finally show ?thesis
by simp
qed
(* Soundness of ub_sqrt: the result, read as a real number, is never below
   the real square root of x. *)
lemma ub_sqrt: "ub_sqrt prec x ≥ sqrt x"
proof -
(* Name the intermediate quantities exactly as in the definition of ub_sqrt. *)
define n where "n = mantissa x"
define e where "e = exponent x"
define k where "k = nat (2 * int prec - bitlen n)"
define k' where "k' = (if even k = even e then k else k + 1)"
(* k' has the same parity as e, so e - k' can be halved exactly. *)
have "even (e - k')"
by (auto simp: k'_def)
define e'' where "e'' = (e - k') div 2"
have e'': "k' = e - 2 * e''"
using ‹even (e - k')› by (auto simp: e''_def)
(* Write x as n * 2^e. *)
have "sqrt x = sqrt (Float n e)"
by (simp add: n_def e_def Float_mantissa_exponent)
(* Multiply by the factor 2^e'' / sqrt (2^(2 * e'')), which equals 1. *)
also have "… = sqrt (of_int n * 2 powi e) * (2 powi e'' / sqrt (2 powi (2 * e'')))"
by (simp add: mult.commute power_int_mult powr_real_of_int')
(* Move the division by 2^(2 * e'') inside the square root. *)
also have "… = sqrt (of_int n * 2 powi (e - 2 * e'')) * 2 powi e''"
by (simp add: real_sqrt_divide power_int_diff)
(* Express the result in terms of the shift k' used by the definition. *)
also have "… = sqrt (of_int n * 2 powi int k') * 2 powi ((e - k') div 2)"
unfolding e'' by simp
(* Rounding the root up to an integer can only increase the value. *)
also have "… ≤ ⌈sqrt (of_int n * 2 powi int k')⌉ * 2 powi ((e - k') div 2)"
by (intro mult_right_mono) auto
(* This is exactly the unfolded definition of ub_sqrt. *)
also have "… = real_of_float (ub_sqrt prec x)"
by (simp add: ub_sqrt_def n_def e_def k_def k'_def
Let_def powr_real_of_int' shift_int_def add_ac nat_add_distrib
sqrt_int_floor_def sqrt_int_ceiling_def)
finally show ?thesis .
qed
(* Interval version of sqrt on float intervals.  It is defined through the
   lifting package on the underlying (lower, upper) pair representation. *)
context
includes interval.lifting
begin
(* The lower end is rounded down with lb_sqrt and the upper end is rounded up
   with ub_sqrt.  The proof obligation shows the result is a valid interval:
   lb_sqrt prec l ≤ sqrt l ≤ sqrt u ≤ ub_sqrt prec u.  The middle step uses
   hypothesis 1, presumably the invariant l ≤ u of the representation
   together with monotonicity of sqrt. *)
lift_definition sqrt_float_interval :: "nat ⇒ float interval ⇒ float interval" is
"λprec (l, u). (lb_sqrt prec l, ub_sqrt prec u)"
proof goal_cases
case (1 prec lu)
obtain l u where [simp]: "lu = (l, u)"
by (cases lu)
have "real_of_float (lb_sqrt prec l) ≤ sqrt l"
by (rule lb_sqrt)
also have "… ≤ sqrt u"
using 1 by auto
also have "… ≤ real_of_float (ub_sqrt prec u)"
by (rule ub_sqrt)
finally show ?case
by simp
qed
(* Enclosure property: if x lies in the real interval denoted by X, then
   sqrt x lies in the real interval denoted by sqrt_float_interval prec X. *)
lemma sqrt_float_intervalI:
fixes x :: real and X :: "float interval"
assumes "x ∈ set_of (real_interval X)"
shows "sqrt x ∈ set_of (real_interval (sqrt_float_interval prec X))"
using assms
proof (transfer, goal_cases)
case (1 x lu prec)
obtain l u where [simp]: "lu = (l, u)"
by (cases lu)
from 1 have x: "real_of_float l ≤ x" "x ≤ real_of_float u"
by simp_all
(* Lower end: lb_sqrt prec l ≤ sqrt l ≤ sqrt x. *)
have "real_of_float (lb_sqrt prec l) ≤ sqrt x"
using lb_sqrt[of prec l] x(1) by (meson dual_order.trans real_sqrt_le_iff)
(* Upper end: sqrt x ≤ sqrt u ≤ ub_sqrt prec u.  Because ≥ in the statement
   of ub_sqrt abbreviates ≤ with swapped arguments, x occurs there before
   prec, which is why the instantiation is written [of u prec]. *)
moreover have "real_of_float (ub_sqrt prec u) ≥ sqrt x"
using ub_sqrt[of u prec] x(2) by (meson dual_order.trans real_sqrt_le_iff)
ultimately show ?case
by simp
qed
(* Set-level form of the enclosure property: the image of the real interval
   denoted by X under sqrt is contained in the result interval. *)
lemma sqrt_float_interval:
"sqrt ` set_of (real_interval X) ⊆ set_of (real_interval (sqrt_float_interval prec X))"
using sqrt_float_intervalI[of _ X] by blast
end
end