(* Theory Word_Lib.Bit_Comprehension_Int *)

(*
 * Copyright Brian Huffman, PSU; Jeremy Dawson and Gerwin Klein, NICTA
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

section ‹Comprehension syntax for int›

theory Bit_Comprehension_Int
  imports
    Bit_Comprehension
begin

(* Bit comprehension for int.  set_bits f only gets a non-trivial value when
   f is eventually constant, i.e. when some index n satisfies f m = f n for
   every m >= n; for any other f the definition yields 0. *)
instantiation int :: bit_comprehension
begin

(* Here n is the least index from which f is constant.  The bits
   f 0, ..., f n are packed by horner_sum of_bool 2, with f 0 as the least
   significant bit.  signed_take_bit n then sign-extends from bit n, so every
   bit above position n of the result equals f n. *)
definition
  set_bits f = (
      if n. mn. f m = f n then
      let n = LEAST n. mn. f m = f n
      in signed_take_bit n (horner_sum of_bool 2 (map f [0..<Suc n]))
     else 0 :: int)

(* Class obligation: feeding the bits of any k back through set_bits
   reconstructs k, i.e. set_bits (bit k) = k. *)
instance proof
  fix k :: int
  (* int_bit_bound provides an index n from which every bit of k equals
     bit k n, named * below.  It also shows that, when n > 0, bit k (n - 1)
     differs from bit k n, named ** below; this second fact is what makes n
     the least such index. *)
  from int_bit_bound [of k]
  obtain n where *: m. n  m  bit k m  bit k n
    and **: n > 0  bit k (n - 1)  bit k n
    by blast
  (* So n is exactly the LEAST stabilisation index chosen by set_bits. *)
  have l: (LEAST q. mq. bit k m  bit k q) = n
  proof (rule Least_equality)
    show "mn. bit k m = bit k n"
      using * by blast
    show "y. my. bit k m = bit k y  n  y"
      by (metis "**" One_nat_def Suc_pred le_cases le0 neq0_conv not_less_eq_eq)
  qed
  (* Keeping bits 0..n of k and sign-extending from bit n gives back k,
     because every higher bit of k already equals bit k n. *)
  have "signed_take_bit n (take_bit (Suc n) k) = k"
    apply (rule bit_eqI)
    by (metis "*" bit_signed_take_bit_iff bit_take_bit_iff leI lessI less_SucI min.absorb4 min.order_iff)
  (* Unfold set_bits; horner_sum_bit_eq_take_bit turns the horner sum over
     the bits of k into take_bit (Suc n) k, and l fixes the LEAST index. *)
  then show set_bits (bit k) = k
    unfolding * set_bits_int_def horner_sum_bit_eq_take_bit l
    using "*" by auto
qed

end

(* Comprehending the constantly-False predicate gives the integer 0. *)
lemma int_set_bits_K_False [simp]: "(BITS _. False) = (0 :: int)"
  unfolding set_bits_int_def by simp

(* Comprehending the constantly-True predicate gives -1, the int whose bits
   are all set. *)
lemma int_set_bits_K_True [simp]: "(BITS _. True) = (-1 :: int)"
  unfolding set_bits_int_def by simp

(* Code equation for set_bits on int.  Generated code does not evaluate
   set_bits; it aborts with the message below.  NOTE(review): presumably this
   is because the definition quantifies over all bit indices and so is not
   executable.  Logically the equation is trivial, which is why simp proves
   it. *)
lemma set_bits_code [code]:
  "set_bits = Code.abort (STR ''set_bits is unsupported on type int'') (λ_. set_bits :: _  int)"
  by simp

(* Alternative unfolding of set_bits on int that splits on how f stabilises:
   - If f is eventually False, the result is the non-negative horner sum of
     f 0 .. f (n - 1), where n is the least index from which f stays False.
   - Otherwise, if f is eventually True, the bits f 0 .. f (n - 1) are
     followed by an explicit True at position n and sign-extended from there,
     where n is the least index from which f stays True.
   - Otherwise the result is 0. *)
lemma set_bits_int_unfold':
  set_bits f =
    (if n. n'n. ¬ f n' then
      let n = LEAST n. n'n. ¬ f n'
      in horner_sum of_bool 2 (map f [0..<n])
     else if n. n'n. f n' then
      let n = LEAST n. n'n. f n'
      in signed_take_bit n (horner_sum of_bool 2 (map f [0..<n] @ [True]))
     else 0 :: int)
proof (cases n. mn. f m  f n)
  case True
  then obtain q where q: mq. f m  f q
    by blast
  (* n is the least index from which f is constant; this is the same index
     that set_bits_int_def picks. *)
  define n where n = (LEAST n. mn. f m  f n)
  have mn. f m  f n
    unfolding n_def
    using q by (rule LeastI [of _ q])
  then have n: m. n  m  f m  f n
    by blast
  from n_def have n_eq: (LEAST q. mq. f m  f n) = n
    by (smt (verit, best) Least_le mn. f m = f n dual_order.antisym wellorder_Least_lemma(1))
  (* The constant tail value f n decides which branch of the right-hand side
     applies. *)
  show ?thesis
  proof (cases f n)
    case False
    (* f is False from n onward, so the first branch applies with the same
       least index n. *)
    with n have *: n. n'n. ¬ f n'
      by blast
    have **: (LEAST n. n'n. ¬ f n') = n
      using False n_eq by simp
    from * False show ?thesis
      unfolding set_bits_int_def n_def [symmetric] **
      by (auto simp add: take_bit_horner_sum_bit_eq bit_horner_sum_bit_iff take_map
          signed_take_bit_def set_bits_int_def horner_sum_bit_eq_take_bit simp del: upt.upt_Suc)
  next
    case True
    (* f is True from n onward and is not eventually False, so the second
       branch applies, again with least index n. *)
    with n obtain *: n. n'n. f n' ¬ (n. n'n. ¬ f n')
      by (metis linorder_linear)
    have **: (LEAST n. n'n. f n') = n
      using True n_eq by simp
    from * True show ?thesis
      unfolding set_bits_int_def n_def [symmetric] **
      by (auto simp add: take_bit_horner_sum_bit_eq
      bit_horner_sum_bit_iff take_map
      signed_take_bit_def set_bits_int_def
      horner_sum_bit_eq_take_bit nth_append simp del: upt.upt_Suc)
  qed
next
  (* f never becomes constant, so both sides reduce to 0. *)
  case False
  then show ?thesis
    by (auto simp add: set_bits_int_def)
qed

(* wf_set_bits_int f holds when f is eventually constant.  The rule zeros
   covers f being False at every index from some n on; the rule ones covers
   f being True from some n on.  The lemmas in the context further down
   assume this predicate for the f they talk about. *)
inductive wf_set_bits_int :: "(nat  bool)  bool"
  for f :: "nat  bool"
where
  zeros: "n'  n. ¬ f n'  wf_set_bits_int f"
| ones: "n'  n. f n'  wf_set_bits_int f"

(* Characterises wf_set_bits_int with one shared witness index n, used for
   both the eventually-False and the eventually-True alternative. *)
lemma wf_set_bits_int_simps: "wf_set_bits_int f  (n. (n'n. ¬ f n')  (n'n. f n'))"
  unfolding wf_set_bits_int.simps by auto

(* Any constant predicate is well-formed.  When b holds the rule ones
   applies; otherwise the rule zeros does. *)
lemma wf_set_bits_int_const [simp]: "wf_set_bits_int (λ_. b)"
proof (cases b)
  case True
  then show ?thesis by (auto intro: wf_set_bits_int.intros)
next
  case False
  then show ?thesis by (auto intro: wf_set_bits_int.intros)
qed

(* Changing f at a single position n does not affect well-formedness.  From
   index max (Suc n) n' on, f and f(n := b) agree, so a stabilisation witness
   for one function also works for the other. *)
lemma wf_set_bits_int_fun_upd [simp]:
  "wf_set_bits_int (f(n := b))  wf_set_bits_int f" (is "?lhs  ?rhs")
proof
  assume ?lhs
  then obtain n'
    where "(n''n'. ¬ (f(n := b)) n'')  (n''n'. (f(n := b)) n'')"
    by(auto simp add: wf_set_bits_int_simps)
  (* Move the witness past n so that the update no longer matters. *)
  hence "(n''max (Suc n) n'. ¬ f n'')  (n''max (Suc n) n'. f n'')" by auto
  thus ?rhs by(auto simp only: wf_set_bits_int_simps)
next
  assume ?rhs
  then obtain n' where "(n''n'. ¬ f n'')  (n''n'. f n'')" (is "?wf f n'")
    by(auto simp add: wf_set_bits_int_simps)
  (* Same argument in the other direction: beyond max (Suc n) n' the
     updated function agrees with f. *)
  hence "?wf (f(n := b)) (max (Suc n) n')" by auto
  thus ?lhs by(auto simp only: wf_set_bits_int_simps)
qed

(* Dropping the first value of f, i.e. shifting indices by one, preserves
   well-formedness in both directions: a witness n for one side gives
   Suc n or n - 1 for the other, via le_SucI and Suc_le_D. *)
lemma wf_set_bits_int_Suc [simp]:
  "wf_set_bits_int (λn. f (Suc n))  wf_set_bits_int f" (is "?lhs  ?rhs")
by(auto simp add: wf_set_bits_int_simps intro: le_SucI dest: Suc_le_D)

(* Bit-level facts about set_bits f on int, valid for a fixed f that is
   well-formed, i.e. eventually constant; see the assumption wff. *)
context
  fixes f
  assumes wff: "wf_set_bits_int f"
begin

(* set_bits peels off its lowest bit: the result is f 0 plus twice the
   comprehension of the shifted predicate f o Suc. *)
lemma int_set_bits_unfold_BIT:
  "set_bits f = of_bool (f 0) + (2 :: int) * set_bits (f  Suc)"
using wff proof cases
  case (zeros n)
  (* f is False at every index >= n. *)
  show ?thesis
  proof(cases "n. ¬ f n")
    case True
    hence "f = (λ_. False)" by auto
    thus ?thesis using True by(simp add: o_def)
  next
    case False
    (* Some bit of f is True, so the least eventually-False index is a
       successor; for f o Suc it is one smaller. *)
    then obtain n' where "f n'" by blast
    with zeros have "(LEAST n. n'n. ¬ f n') = Suc (LEAST n. n'Suc n. ¬ f n')"
      by(auto intro: Least_Suc)
    also have "(λn. n'Suc n. ¬ f n') = (λn. n'n. ¬ f (Suc n'))" by(auto dest: Suc_le_D)
    also from zeros have "n'n. ¬ f (Suc n')" by auto
    ultimately show ?thesis using zeros
      apply (simp (no_asm_simp) add: set_bits_int_unfold' exI
        del: upt.upt_Suc flip: map_map split del: if_split)
      apply (simp only: map_Suc_upt upt_conv_Cons)
      apply simp
      done
  qed
next
  case (ones n)
  (* f is True at every index >= n. *)
  show ?thesis
  proof(cases "n. f n")
    case True
    hence "f = (λ_. True)" by auto
    thus ?thesis using True by(simp add: o_def)
  next
    case False
    (* Some bit of f is False, so the least eventually-True index is a
       successor.  Neither f nor f o Suc is eventually False, which selects
       the second branch of set_bits_int_unfold' on both sides. *)
    then obtain n' where "¬ f n'" by blast
    with ones have "(LEAST n. n'n. f n') = Suc (LEAST n. n'Suc n. f n')"
      by(auto intro: Least_Suc)
    also have "(λn. n'Suc n. f n') = (λn. n'n. f (Suc n'))" by(auto dest: Suc_le_D)
    also from ones have "n'n. f (Suc n')" by auto
    moreover from ones have "(n. n'n. ¬ f n') = False"
      by(auto intro!: exI[where x="max n m" for n m] simp add: max_def split: if_split_asm)
    moreover hence "(n. n'n. ¬ f (Suc n')) = False"
      by(auto elim: allE[where x="Suc n" for n] dest: Suc_le_D)
    ultimately show ?thesis using ones
      apply (simp (no_asm_simp) add: set_bits_int_unfold' exI split del: if_split)
      apply (auto simp add: Let_def hd_map map_tl[symmetric] map_map[symmetric] map_Suc_upt upt_conv_Cons signed_take_bit_Suc
        not_le simp del: map_map)
      done
  qed
qed

(* Bit 0 of set_bits f is f 0. *)
lemma bin_last_set_bits [simp]:
  "odd (set_bits f :: int) = f 0"
  by (subst int_set_bits_unfold_BIT) simp_all

(* Halving set_bits f drops bit 0, leaving the comprehension of f o Suc. *)
lemma bin_rest_set_bits [simp]:
  "set_bits f div (2 :: int) = set_bits (f  Suc)"
  by (subst int_set_bits_unfold_BIT) simp_all

(* Bit m of set_bits f is f m.  The proof is by induction on m with f
   generalised, applying the step to f o Suc; wff for that shifted predicate
   follows from wf_set_bits_int_Suc.
   NOTE(review): the theory-qualified names below presumably select the
   exported forms of the two lemmas above, which take wf_set_bits_int f as a
   premise for an arbitrary f, rather than the local forms about the fixed f
   of this context.  That is what the generalised induction needs; confirm. *)
lemma bin_nth_set_bits [simp]:
  "bit (set_bits f :: int) m  f m"
using wff proof (induction m arbitrary: f)
  case 0
  then show ?case
    by (simp add: Bit_Comprehension_Int.bin_last_set_bits bit_0)
next
  case Suc
  from Suc.IH [of "f  Suc"] Suc.prems show ?case
    by (simp add: Bit_Comprehension_Int.bin_rest_set_bits comp_def bit_Suc)
qed

end

end