Theory Word_Lib.Signed_Division_Word

(*
 * Copyright Data61, CSIRO (ABN 41 687 119 230)
 *
 * SPDX-License-Identifier: BSD-2-Clause
 * Proofs tidied by LCP, 2024-09
 *)

section ‹Signed division on word›

theory Signed_Division_Word
  imports "HOL-Library.Signed_Division" "HOL-Library.Word"
begin

text ‹
  The following specification of division follows ISO C99, which in turn adopted the behavior
  typical of hardware that was current in the early 1990s.
  The underlying integer division is named ``T-division'' in \cite{leijen01}.
›

(* Signed division and modulus on machine words.
   Both operations read their operands as signed integers by applying
   signed_take_bit at bit position LENGTH('a) - 1.  They then apply integer
   sdiv / smod and lift the result back into the word type. *)
instantiation word :: (len) signed_division
begin

(* The by-proof discharges the lifting obligation: the integer function must
   respect the word representation. *)
lift_definition signed_divide_word :: ‹'a::len word ⇒ 'a word ⇒ 'a word›
  is ‹λk l. signed_take_bit (LENGTH('a) - Suc 0) k sdiv signed_take_bit (LENGTH('a) - Suc 0) l›
  by (simp flip: signed_take_bit_decr_length_iff)

lift_definition signed_modulo_word :: ‹'a::len word ⇒ 'a word ⇒ 'a word›
  is ‹λk l. signed_take_bit (LENGTH('a) - Suc 0) k smod signed_take_bit (LENGTH('a) - Suc 0) l›
  by (simp flip: signed_take_bit_decr_length_iff)

(* Characterisation of word sdiv through sint and word_of_int. *)
lemma sdiv_word_def:
  ‹v sdiv w = word_of_int (sint v sdiv sint w)›
  for v w :: ‹'a::len word›
  by transfer simp

(* Characterisation of word smod through sint and word_of_int. *)
lemma smod_word_def:
  ‹v smod w = word_of_int (sint v smod sint w)›
  for v w :: ‹'a::len word›
  by transfer simp

(* Class law: quotient times divisor plus remainder gives back the dividend.
   It is transported from the integer identity sdiv_mult_smod_eq via
   word_of_int. *)
instance proof
  fix v w :: ‹'a word›
  have ‹sint v sdiv sint w * sint w + sint v smod sint w = sint v›
    by (fact sdiv_mult_smod_eq)
  then have ‹word_of_int (sint v sdiv sint w * sint w + sint v smod sint w) = (word_of_int (sint v) :: 'a word)›
    by simp
  then show ‹v sdiv w * w + v smod w = v›
    by (simp add: sdiv_word_def smod_word_def)
qed

end

(* Executable code equation for word sdiv.
   The magnitudes are divided with div, and the result is negated exactly
   when the operands' signs differ, i.e. when (v' < 0) ≠ (w' < 0). *)
lemma signed_divide_word_code [code]: \<^marker>‹contributor ‹Andreas Lochbihler››
  ‹v sdiv w =
   (let v' = sint v; w' = sint w;
        negative = (v' < 0) ≠ (w' < 0);
        result = ¦v'¦ div ¦w'¦
    in word_of_int (if negative then - result else result))›
  for v w :: ‹'a::len word›
  by (simp add: sdiv_word_def signed_divide_int_def sgn_if)

(* Executable code equation for word smod.
   The magnitudes are reduced with mod, and the result is negated exactly
   when the dividend is negative. *)
lemma signed_modulo_word_code [code]: \<^marker>‹contributor ‹Andreas Lochbihler››
  ‹v smod w =
   (let v' = sint v; w' = sint w;
        negative = (v' < 0);
        result = ¦v'¦ mod ¦w'¦
    in word_of_int (if negative then - result else result))›
  for v w :: ‹'a::len word›
  by (simp add: smod_word_def signed_modulo_int_def sgn_if)

(* Word-specific name for the division/modulus class law sdiv_mult_smod_eq. *)
lemma sdiv_smod_id:
  ‹(a sdiv b) * b + (a smod b) = a›
  for a b :: ‹'a::len word›
  by (fact sdiv_mult_smod_eq)

(* The signed value of a word quotient equals the integer quotient of the
   signed operands, re-interpreted through signed_take_bit at the sign bit. *)
lemma signed_div_arith:
  ‹sint (a sdiv b) = signed_take_bit (LENGTH('a) - 1) (sint a sdiv sint b)›
  for a b :: ‹'a::len word›
  by (simp add: sint_sbintrunc' sdiv_word_def)

(* The signed value of a word remainder equals the integer remainder of the
   signed operands, re-interpreted through signed_take_bit at the sign bit. *)
lemma signed_mod_arith:
  ‹sint (a smod b) = signed_take_bit (LENGTH('a) - 1) (sint a smod sint b)›
  for a b :: ‹'a::len word›
  by (simp add: smod_word_def sint_sbintrunc')

(* Signed division of a word by zero yields zero. *)
lemma word_sdiv_div0 [simp]:
  ‹a sdiv 0 = 0› for a :: ‹'a::len word›
  by (auto simp add: sdiv_word_def signed_divide_int_def sgn_if)

(* The signed remainder modulo zero is the dividend itself. *)
lemma smod_word_zero [simp]:
  ‹w smod 0 = w› for w :: ‹'a::len word›
  by transfer (simp add: take_bit_signed_take_bit)

(* Signed division of a word by one returns the word unchanged. *)
lemma word_sdiv_div1 [simp]:
    "(a :: ('a::len) word) sdiv 1 = a"
proof -
  (* Auxiliary fact about sint (- 1); "then" chains it into the metis call. *)
  have "sint (- (1::'a word)) = - 1"
    by simp
  then show ?thesis
    by (metis int_sdiv_simps(1) mult_1 mult_minus_left scast_eq scast_id 
        sdiv_minus_eq sdiv_word_def signed_1 wi_hom_neg)
qed

(* Every word leaves signed remainder zero when divided by one. *)
lemma smod_word_one [simp]:
  ‹w smod 1 = 0› for w :: ‹'a::len word›
  by (simp add: smod_word_def signed_modulo_int_def)

(* Signed division of a word by minus one is negation. *)
lemma word_sdiv_div_minus1 [simp]:
  ‹a sdiv -1 = -a› for a :: ‹'a::len word›
  by (simp add: sdiv_word_def)

(* Every word leaves signed remainder zero when divided by minus one. *)
lemma smod_word_minus_one [simp]:
  ‹w smod - 1 = 0› for w :: ‹'a::len word›
  by (simp add: smod_word_def signed_modulo_int_def)

(* 1 sdiv w is w itself when w is 1 or -1, and 0 otherwise. *)
lemma one_sdiv_word_eq [simp]:
  ‹1 sdiv w = of_bool (w = 1 ∨ w = - 1) * w› for w :: ‹'a::len word›
proof (cases ‹1 < ¦sint w¦›)
  case True
  then show ?thesis
    by (auto simp add: sdiv_word_def signed_divide_int_def split: if_splits)
next
  case False
  (* Only the three small signed values -1, 0 and 1 remain; transport that
     back to the word via word_of_int. *)
  then have ‹¦sint w¦ ≤ 1›
    by simp
  then have ‹sint w ∈ {- 1, 0, 1}›
    by auto
  then have ‹(word_of_int (sint w) :: 'a::len word) ∈ word_of_int ` {- 1, 0, 1}›
    by blast
  then have ‹w ∈ {- 1, 0, 1}›
    by simp
  then show ?thesis by auto
qed

(* 1 smod w is 0 when w is 1 or -1, and 1 otherwise.  Proved by auto from the
   division/modulus identity instantiated at dividend 1. *)
lemma one_smod_word_eq [simp]:
  ‹1 smod w = 1 - of_bool (w = 1 ∨ w = - 1)› for w :: ‹'a::len word›
  using sdiv_smod_id [of 1 w] by auto

(* Dividing -1 is the negation of dividing 1. *)
lemma minus_one_sdiv_word_eq [simp]:
  ‹- 1 sdiv w = - (1 sdiv w)› for w :: ‹'a::len word›
  by (metis (mono_tags, opaque_lifting) minus_sdiv_eq of_int_minus sdiv_word_def signed_1 sint_n1
      word_sdiv_div1 word_sdiv_div_minus1)

(* The remainder of -1 is the negation of the remainder of 1.  Proved by auto
   from the division/modulus identity instantiated at dividend -1. *)
lemma minus_one_smod_word_eq [simp]:
  ‹- 1 smod w = - (1 smod w)› for w :: ‹'a::len word›
  using sdiv_smod_id [of ‹- 1› w] by auto

(* The signed remainder expressed through the signed quotient. *)
lemma smod_word_alt_def:
  ‹a smod b = a - (a sdiv b) * b› for a b :: ‹'a::len word›
  by (simp add: minus_sdiv_mult_eq_smod)

(* Simp rules for sdiv on numeral literals of either sign.  Each is
   sdiv_word_def instantiated at (possibly negated) numerals.  The sint of
   each literal is then simplified with sint_sbintrunc / sint_sbintrunc_neg. *)
lemmas sdiv_word_numeral_numeral [simp] =
  sdiv_word_def [of ‹numeral a› ‹numeral b›, simplified sint_sbintrunc sint_sbintrunc_neg]
  for a b
lemmas sdiv_word_minus_numeral_numeral [simp] =
  sdiv_word_def [of ‹- numeral a› ‹numeral b›, simplified sint_sbintrunc sint_sbintrunc_neg]
  for a b
lemmas sdiv_word_numeral_minus_numeral [simp] =
  sdiv_word_def [of ‹numeral a› ‹- numeral b›, simplified sint_sbintrunc sint_sbintrunc_neg]
  for a b
lemmas sdiv_word_minus_numeral_minus_numeral [simp] =
  sdiv_word_def [of ‹- numeral a› ‹- numeral b›, simplified sint_sbintrunc sint_sbintrunc_neg]
  for a b

(* Simp rules for smod on numeral literals of either sign.  Each is
   smod_word_def instantiated at (possibly negated) numerals.  The sint of
   each literal is then simplified with sint_sbintrunc / sint_sbintrunc_neg. *)
lemmas smod_word_numeral_numeral [simp] =
  smod_word_def [of ‹numeral a› ‹numeral b›, simplified sint_sbintrunc sint_sbintrunc_neg]
  for a b
lemmas smod_word_minus_numeral_numeral [simp] =
  smod_word_def [of ‹- numeral a› ‹numeral b›, simplified sint_sbintrunc sint_sbintrunc_neg]
  for a b
lemmas smod_word_numeral_minus_numeral [simp] =
  smod_word_def [of ‹numeral a› ‹- numeral b›, simplified sint_sbintrunc sint_sbintrunc_neg]
  for a b
lemmas smod_word_minus_numeral_minus_numeral [simp] =
  smod_word_def [of ‹- numeral a› ‹- numeral b›, simplified sint_sbintrunc sint_sbintrunc_neg]
  for a b

(* Alternative name for word_sdiv_div0 (signed division by zero yields zero). *)
lemmas word_sdiv_0 = word_sdiv_div0

(* Lower bound on the integer quotient of two signed word values. *)
lemma sdiv_word_min:
    "- (2 ^ (size a - 1)) ≤ sint (a :: ('a::len) word) sdiv sint (b :: ('a::len) word)"
  by (smt (verit, ccfv_SIG) atLeastAtMost_iff sdiv_int_range sint_ge sint_lt wsst_TYs(3))

(* Upper bound on the integer quotient of two signed word values.
   The bound is non-strict: the most negative signed value divided by -1
   reaches 2 ^ (size a - 1), which lies outside the signed range. *)
lemma sdiv_word_max:
  ‹sint a sdiv sint b ≤ 2 ^ (size a - Suc 0)›
  for a b :: ‹'a::len word›
proof (cases ‹sint a = 0 ∨ sint b = 0 ∨ sgn (sint a) ≠ sgn (sint b)›)
  case True then show ?thesis
  proof -
    have "∀i j::int. i sdiv j ≤ ¦i¦"
      by (meson atLeastAtMost_iff sdiv_int_range)
    then show ?thesis
      by (smt (verit) sint_range_size)
  qed
next
  case False
  (* Both operands are nonzero with equal signs, so the quotient is
     ¦sint a¦ div ¦sint b¦; bound it by ¦sint a¦. *)
  then have ‹¦sint a¦ div ¦sint b¦ ≤ ¦sint a¦›
    by (subst nat_le_eq_zle [symmetric]) (simp_all add: div_abs_eq_div_nat)
  also have ‹¦sint a¦ ≤ 2 ^ (size a - Suc 0)›
    using sint_range_size [of a] by auto
  finally show ?thesis
    using False by (simp add: signed_divide_int_def)
qed

(* Instances of sdiv_word_def with the dividend fixed to a numeral, 0 or 1. *)
lemmas word_sdiv_numerals_lhs = sdiv_word_def[where v=‹numeral x› for x]
    sdiv_word_def[where v=0] sdiv_word_def[where v=1]

(* Further instantiates the divisor of the rules above to a numeral, 0 or 1,
   covering every combination of the two operands. *)
lemmas word_sdiv_numerals = word_sdiv_numerals_lhs[where w=‹numeral y› for y]
    word_sdiv_numerals_lhs[where w=0] word_sdiv_numerals_lhs[where w=1]

(* Non-simp copy of smod_word_zero: remainder modulo zero is the dividend. *)
lemma smod_word_mod_0:
  ‹x smod 0 = x› for x :: ‹'a::len word›
  by (fact smod_word_zero)

(* Zero leaves signed remainder zero for any divisor. *)
lemma smod_word_0_mod [simp]:
  ‹0 smod x = 0› for x :: ‹'a::len word›
  by (clarsimp simp: smod_word_def)

(* Strict upper bound on the integer remainder of two signed word values. *)
lemma smod_word_max:
  "sint (a::'a word) smod sint (b::'a word) < 2 ^ (LENGTH('a::len) - Suc 0)"
proof (cases ‹sint b = 0 ∨ LENGTH('a) = 0›)
  case True
  then show ?thesis
    by (force simp: sint_less)
next
  case False
  then show ?thesis
    by (smt (verit) sint_greater_eq sint_less smod_int_compares)
qed

(* Lower bound on the integer remainder of two signed word values. *)
lemma smod_word_min:
  "- (2 ^ (LENGTH('a::len) - Suc 0)) ≤ sint (a::'a word) smod sint (b::'a word)"
  by (smt (verit) sint_greater_eq sint_less smod_int_compares smod_int_mod_0)

(* Instances of smod_word_def with the dividend fixed to a numeral, 0 or 1. *)
lemmas word_smod_numerals_lhs = smod_word_def[where v=‹numeral x› for x]
    smod_word_def[where v=0] smod_word_def[where v=1]

(* Further instantiates the divisor of the rules above to a numeral, 0 or 1,
   covering every combination of the two operands. *)
lemmas word_smod_numerals = word_smod_numerals_lhs[where w=‹numeral y› for y]
    word_smod_numerals_lhs[where w=0] word_smod_numerals_lhs[where w=1]

end