(* Theory BitByte *)

(*  Title:       X86 instruction semantics and basic block symbolic execution
    Authors:     Freek Verbeek, Abhijith Bharadwaj, Joshua Bockenek, Ian Roessle, Timmy Weerwag, Binoy Ravindran
    Year:        2020
    Maintainer:  Freek Verbeek (freek@vt.edu)
*)

section "Bit and byte-level theorems"

theory BitByte
  imports Main "Word_Lib.Syntax_Bundles" "Word_Lib.Bit_Shifts_Infix_Syntax" "Word_Lib.Bitwise"
begin

subsection ‹Basics›

unbundle bit_operations_syntax
unbundle bit_projection_infix_syntax

(* Bit-range extraction: shift w right by l, then keep the lowest (h - l) bits.
   The mixfix notation ⟨l,h⟩w abbreviates take_bits l h w. *)
definition take_bits :: "nat ⇒ nat ⇒ 'a::len word ⇒ 'a::len word" (‹⟨_,_⟩_› 51) ― ‹little-endian›
  where "take_bits l h w ≡ (w >> l) AND mask (h-l)"

text ‹@{term "take_bits l h w"} takes a subset of bits from word @{term w}, from low (inclusive) to high (exclusive).
  For example, @{term [show_types] "take_bits 2 5 (28::8word) = 7"}.
›

(* Byte extraction: bits n*8 (inclusive) to n*8+8 (exclusive) of w, cast to an 8-bit word. *)
definition take_byte :: "nat ⇒ 'a::len word ⇒ 8word" ― ‹little-endian›
  where "take_byte n w ≡ ucast (⟨n*8,n*8+8⟩w)"

text ‹@{term "take_byte n w"} takes the nth byte from word @{term w}.
  For example, @{term [show_types] "take_byte 1 (42 << 8::16word) = 42"}.
›

(* Bit-range replacement, assembled by OR-ing three disjoint pieces:
   bits [h, LENGTH('a)) of w0, bits [l, h) of w1, and bits [0, l) of w0,
   each shifted back to its original position. *)
definition overwrite :: "nat ⇒ nat ⇒ 'a::len word ⇒ 'a::len word ⇒ 'a::len word"
  where "overwrite l h w0 w1 ≡ ((⟨h,LENGTH('a)⟩w0) << h) OR ((⟨l,h⟩w1) << l) OR (⟨0,l⟩w0)"

text ‹@{term "overwrite l h w0 w1"} overwrites low (inclusive) to high (exclusive) bits in
  word @{term w0} with bits from word @{term w1}.
  For example, @{term [show_types] "(overwrite 2 4 28 227 :: 8word) = 16"}.
›


text ‹We prove some theorems about taking the nth bit/byte of an operation. These are useful to prove
      equality between different words, by first applying rule @{thm word_eqI}.›

(* Bit n of ⟨l,h⟩w is set iff n is within the word length, n lies below the
   extracted width h - l, and bit n + l of w is set. Registered as a bit_simps rule. *)
lemma bit_take_bits_iff [bit_simps]:
  ‹(⟨l,h⟩w) !! n ⟷ n < LENGTH('a) ∧ n < h - l ∧ w !! (n + l)› for w :: ‹'a::len word›
  by (simp add: take_bits_def bit_simps ac_simps)

(* Bit n of byte m of w is set iff n < LENGTH('a), n < 8, and bit n + m * 8 of w is set.
   Registered as a bit_simps rule. *)
lemma bit_take_byte_iff [bit_simps]:
  ‹take_byte m w !! n ⟷ n < LENGTH('a) ∧ n < 8 ∧ w !! (n + m * 8)› for w :: ‹'a::len word›
  by (auto simp add: take_byte_def bit_simps)

(* Bit n of overwrite l h w0 w1 comes from w1 when l ≤ n < h and from w0 otherwise,
   provided n is within the word length. Registered as a bit_simps rule. *)
lemma bit_overwrite_iff [bit_simps]:
  ‹overwrite l h w0 w1 !! n ⟷ n < LENGTH('a) ∧
    (if l ≤ n ∧ n < h then w1 else w0) !! n›
    for w0 w1 :: ‹'a::len word›
  by (auto simp add: overwrite_def bit_simps)

(* If-then-else formulation of the bit-level characterisation of ⟨l,h⟩w. *)
lemma nth_takebits:
  fixes w :: "'a::len word"
  shows "(⟨l,h⟩w) !! n = (if n < LENGTH('a) ∧ n < h - l then w !! (n + l) else False)"
  by (auto simp add: bit_simps)

(* Bit (n mod 8) of byte (n div 8) of w equals bit n of w, guarded by the
   condition n mod 8 < LENGTH('a); otherwise the bit is False. *)
lemma nth_takebyte:
  fixes w :: "'a::len word"
  shows "take_byte (n div 8) w !! (n mod 8) = (if n mod 8 < LENGTH('a) then w!!n else False )"
  by (simp add: bit_simps)

(* Bit i of byte n of an overwritten word: positions outside [l, h) read from the
   original word v, positions inside read from the replacement word v'. *)
lemma nth_take_byte_overwrite:
  fixes v v' :: "'a::len word"
  shows "take_byte n (overwrite l h v v') !! i = (if i + n * 8 < l ∨ i + n * 8 ≥ h then take_byte n v !! i else take_byte n v' !! i)"
  by (auto simp add: bit_simps dest: bit_imp_le_length)

(* Bit n of NOT a is the negation of bit n of a when n is within the word length,
   and False otherwise. *)
lemma nth_bitNOT:
  fixes a :: "'a::len word"
  shows "(NOT a) !! n ⟷ (if n < LENGTH('a)  then ¬(a !! n) else False)"
  by (simp add: bit_simps)


text ‹Various simplification rules›

(* Casting down to a 'b word makes a preceding extraction of the low LENGTH('b)
   bits redundant. Requires that 'b is no wider than 'a. Proved bit-wise. *)
lemma ucast_take_bits:
  fixes w :: "'a::len word"
  assumes "h = LENGTH('b)"
      and "LENGTH('b) ≤ LENGTH('a)"
    shows "ucast (⟨0,h⟩w) = (ucast w ::'b :: len word)"
  apply (rule bit_word_eqI)
  using assms
  apply (simp add: bit_simps)
  done

(* Extracting the low LENGTH('b) bits of a word that was cast from a 'b word
   leaves it unchanged. Proved bit-wise. *)
lemma take_bits_ucast:
  fixes w :: "'b::len word"
  assumes "h = LENGTH('b)"
  shows "(⟨0,h⟩ (ucast w ::'a :: len word)) = (ucast w ::'a :: len word)"
  apply (rule bit_word_eqI)
  using assms
  apply (auto simp add: bit_simps dest: bit_imp_le_length)
  done

(* Two nested bit-range extractions collapse into one. The new lower bound is l + l';
   the upper bound is h' when the outer range (capped at the word length) reaches at
   least h' - l', and l' + min LENGTH('a) h otherwise. Proved bit-wise. *)
lemma take_bits_take_bits:
  fixes w :: "'a::len word"
  shows "(⟨l,h⟩(⟨l',h'⟩w)) = (if min LENGTH('a) h ≥ h' - l' then ⟨l+l',h'⟩w else (⟨l+l',l'+min LENGTH('a) h⟩w))"
  apply (rule bit_word_eqI)
  apply (simp add: bit_simps ac_simps)
  apply auto
  done

(* Reading back exactly the range [l, h) that was overwritten yields the same
   range of the replacement word w1. Proved bit-wise. *)
lemma take_bits_overwrite:
  shows "⟨l,h⟩(overwrite l h w0 w1) = ⟨l,h⟩w1"
  apply (rule bit_word_eqI)
  apply (simp add: bit_simps ac_simps)
  apply (auto dest: bit_imp_le_length)
  done

(* Overwriting the low h bits of a word that only has its low h bits set
   leaves just the low h bits of the replacement word w1. Proved bit-wise. *)
lemma overwrite_0_take_bits_0:
  shows "overwrite 0 h (⟨0,h⟩w0) w1 = ⟨0,h⟩w1"
  apply (rule bit_word_eqI)
  apply (simp add: bit_simps ac_simps)
  done

(* For a 256-bit word shifted left by m whole bytes: byte n of the result is
   byte n - m of the original if byte n lies inside the 256 bits, and 0 otherwise.
   The assumption m ≤ n keeps the natural-number subtraction n - m meaningful. *)
lemma take_byte_shiftlr_256:
  fixes v :: "256 word"
  assumes "m ≤ n"
  shows "take_byte n (v << m*8) = (if (n+1)*8 ≤ 256 then take_byte (n-m) v else 0)"
  apply (rule bit_word_eqI)
  using assms
  apply (simp add: bit_simps)
  apply (simp add: algebra_simps)
  done


subsection ‹ Take\_Bits and arithmetic›

text ‹This definition is based on @{thm to_bl_plus_carry}, which formulates addition as bitwise operations using @{term xor3} and @{term carry}.›

(* Ripple-carry addition over a list of bit pairs with initial carry c.
   The foldr builds a carry-consuming function: for each pair it emits
   xor3 x y car and passes carry x y car on to the rest of the list. *)
definition bitwise_add :: "(bool × bool) list ⇒ bool ⇒ bool list"
  where "bitwise_add x c ≡ foldr (λ(x, y) res car. xor3 x y car # res (carry x y car)) x (λ_. []) c"

(* bitwise_add produces exactly one output bit per input pair.
   Proved by induction on x with the carry c generalized. *)
lemma length_foldr_bitwise_add:
  shows "length (bitwise_add x c) = length x"
  unfolding bitwise_add_def
  by(induct x arbitrary: c) auto

text ‹This is the "heart" of the proof: bitwise addition of two appended zipped lists can be expressed as
      two consecutive bitwise additions.
      Here, I need to make the assumption that the final carry is False.
 ›
(* Splitting an addition over x @ y into two independent additions.
   Assumption: x is empty, or its last bit pair cannot produce a carry even with
   carry-in True. The carry-in x≠[] ∧ c is forced to False when x is empty, so the
   addition over y always starts with carry False. *)
lemma bitwise_add_append:
  assumes "x = [] ∨ ¬carry (fst (last x)) (snd (last x)) True"
  shows "bitwise_add (x @ y) (x≠[] ∧ c) = bitwise_add x (x≠[] ∧ c) @ bitwise_add y False"
  using assms
  unfolding bitwise_add_def
  by(induct x arbitrary: c) (auto simp add: case_prod_unfold xor3_def carry_def split: if_split_asm)

(* The first length x output bits of adding x @ y with carry c are exactly the
   result of adding x alone with carry c. Proved by induction on x with c generalized. *)
lemma bitwise_add_take_append:
  shows "take (length x) (bitwise_add (x @ y) c) = bitwise_add x c"
  unfolding bitwise_add_def
  by(induct x arbitrary: c) (auto simp add: case_prod_unfold xor3_def carry_def split: if_split_asm)

(* Adding n pairs (False, False) with carry-in False yields n False bits.
   Proved by induction on n. *)
lemma bitwise_add_zero:
  shows "bitwise_add (replicate n (False, False)) False = replicate n False "
  unfolding bitwise_add_def
  by(induct n) (auto simp add: xor3_def carry_def)

(* take commutes with bitwise_add: truncating the output to n bits equals adding
   the input truncated to n pairs. Induction on n; the remaining goal is closed by
   metis, using bitwise_add_take_append and length_foldr_bitwise_add among others. *)
lemma bitwise_add_take:
  shows "take n (bitwise_add x c) = bitwise_add (take n x) c"
  unfolding bitwise_add_def
  by (induct n arbitrary: x c,auto)
     (metis append_take_drop_id bitwise_add_def bitwise_add_take_append diff_is_0_eq' length_foldr_bitwise_add length_take nat_le_linear rev_min_pm1 take_all)


(* For equal-length lists and an in-range index n, the first component of the
   head of drop n (zip x y) is the head of drop n x. *)
lemma fst_hd_drop_zip:
  assumes "n < length x"
      and "length x = length y"
  shows "fst (hd (drop n (zip x y))) = hd (drop n x)"
  using assms
  by (induct x arbitrary: n y,auto)
     (metis (no_types, lifting) Cons_nth_drop_Suc drop_zip fst_conv length_Cons list.sel(1) zip_Cons_Cons)

(* For equal-length lists and an in-range index n, the second component of the
   head of drop n (zip x y) is the head of drop n y. *)
lemma snd_hd_drop_zip:
  assumes "n < length x"
      and "length x = length y"
  shows "snd (hd (drop n (zip x y))) = hd (drop n y)"
  using assms
  by (induct x arbitrary: n y,auto)
     (metis (no_types, lifting) Cons_nth_drop_Suc drop_zip snd_conv length_Cons list.sel(1) zip_Cons_Cons)

text ‹
  Ucasting of @{term "a+b"} can be rewritten to taking bits of @{term a} and @{term b}.
›

(* ucast to a strictly narrower word type distributes over addition.
   The proof first shows equality of the bit-list representations (to_bl),
   rewriting addition via to_bl_plus_carry and folding the result back into
   bitwise_add, and then transfers it to the words with word_bl.Rep_eqD. *)
lemma ucast_plus:
  fixes a b :: "'a::len word"
  assumes "LENGTH('a) > LENGTH('b)"
  shows "(ucast (a + b) ::'b::len word) = (ucast a + ucast b::'b::len word)"
proof-
  have "to_bl (ucast (a + b) ::'b::len word) = to_bl (ucast a + ucast b::'b::len word)"
    using assms
    apply (auto simp add: to_bl_ucast to_bl_plus_carry word_rep_drop length_foldr_bitwise_add drop_zip[symmetric] rev_drop bitwise_add_def simp del: foldr_replicate foldr_append)
    apply (simp only: bitwise_add_def[symmetric] length_foldr_bitwise_add)
    by (auto simp add: drop_take bitwise_add_take[symmetric] rev_take length_foldr_bitwise_add)
  thus ?thesis
    using word_bl.Rep_eqD
    by blast
qed

(* ucast to a strictly narrower word type commutes with unary minus.
   The proof rewrites negation with twos_complement and word_succ_p1, applies
   ucast_plus, and settles the remaining equality bit-wise via word_eqI.
   NOTE(review): b is fixed but does not occur in the statement. *)
lemma ucast_uminus:
  fixes a b :: "'a::len word"
assumes "LENGTH('a) > LENGTH('b)"
  shows "ucast (- a) = (- ucast a :: 'b::len word)"
  apply (subst twos_complement)+
  apply (subst word_succ_p1)+
  apply (subst ucast_plus)
  apply (rule assms)
   apply simp
   apply (rule word_eqI)
   apply (auto simp add: word_size nth_ucast nth_bitNOT)
   using assms order.strict_trans
   by blast

(* ucast to a strictly narrower word type distributes over subtraction.
   Obtained from ucast_plus instantiated with a and -b, together with ucast_uminus for b. *)
lemma ucast_minus:
  fixes a b :: "'a::len word"
  assumes "LENGTH('a) > LENGTH('b)"
  shows "(ucast (a - b) ::'b::len word) = (ucast a - ucast b::'b::len word)"
  using ucast_plus[OF assms,of a "-b"] ucast_uminus[OF assms,of b]
  by auto

(* Bit-list view of taking the low h bits: the leading LENGTH('a) - h entries
   become False and the remaining entries of to_bl a are kept. *)
lemma to_bl_takebits:
  fixes a :: "'a::len word"
  shows "to_bl (⟨0,h⟩a) = replicate (LENGTH('a) - h) False @ drop (LENGTH('a) - h) (to_bl a)"
  apply (auto simp add: take_bits_def bl_word_and to_bl_mask)
  apply (rule nth_equalityI)
  by (auto simp add: min_def nth_append)


text ‹ All simplification rules that are used during symbolic execution.›
(* Named collection of rewrite rules: the ucast/take_bits/overwrite lemmas proved in
   this theory plus facts not defined in this file (mask_eq_exp_minus_1,
   ucast_down_ucast_id, is_down, ucast_up_ucast_id, is_up).
   The collection carries no [simp] attribute, so it must be passed to the simplifier explicitly. *)
lemmas BitByte_simps = ucast_plus ucast_minus ucast_uminus take_bits_overwrite take_bits_take_bits
  ucast_take_bits overwrite_0_take_bits_0 mask_eq_exp_minus_1
  ucast_down_ucast_id is_down take_bits_ucast ucast_up_ucast_id is_up

text ‹Simplification for immediate (numeral) values.›
(* Simp instances that unfold take_bits / overwrite only when the word arguments are
   literals (numeral, 0 or 1). Each rule binds with `for` exactly the variables that
   occur in its instantiation. *)
lemmas take_bits_numeral[simp] = take_bits_def[of _ _ "numeral n"] for n
lemmas take_bits_num0[simp] = take_bits_def[of _ _ "0"]
lemmas take_bits_num1[simp] = take_bits_def[of _ _ "1"]
lemmas overwrite_numeral_numeral[simp] = overwrite_def[of _ _ "numeral n" "numeral m"] for n m
lemmas overwrite_num0_numeral[simp] = overwrite_def[of _ _ 0 "numeral m"] for m
lemmas overwrite_numeral_num0[simp] = overwrite_def[of _ _ "numeral m" 0] for m
lemmas overwrite_numeral_00[simp] = overwrite_def[of _ _ 0 0]

end