(* Theory BitByte *)
section "Bit and byte-level theorems"
theory BitByte
imports Main "Word_Lib.Syntax_Bundles" "Word_Lib.Bit_Shifts_Infix_Syntax" "Word_Lib.Bitwise"
begin
subsection ‹Basics›
unbundle bit_operations_syntax
unbundle bit_projection_infix_syntax
(* take_bits l h w (notation: ⟨l,h⟩w) shifts w right by l and keeps the low
   h - l bits of the result. In other words, it extracts bits l (inclusive)
   to h (exclusive) of w and moves them down to position 0.
   h - l is natural-number subtraction: if h ≤ l the mask is empty and the
   result is 0. *)
definition take_bits :: "nat ⇒ nat ⇒ 'a::len word ⇒ 'a::len word" (‹⟨_,_⟩_› 51)
where "take_bits l h w ≡ (w >> l) AND mask (h-l)"
text ‹
@{term "take_bits l h w"} takes a subset of bits from word @{term w}, from low (inclusive) to high (exclusive),
and shifts them down so that bit @{term l} of @{term w} becomes bit 0 of the result.
For example, @{term [show_types] "take_bits 2 5 (28::8word) = 7"}.
›
(* take_byte n w extracts bits n*8 (inclusive) to n*8+8 (exclusive) of w with
   take_bits and casts them (ucast) to an 8-bit word. Bytes are numbered from
   the least significant end, starting at 0. *)
definition take_byte :: "nat ⇒ 'a::len word ⇒ 8word"
where "take_byte n w ≡ ucast (⟨n*8,n*8+8⟩w)"
text ‹
@{term "take_byte n w"} takes the nth byte from word @{term w}, counting from the least significant byte (byte 0),
and returns it as an 8-bit word.
For example, @{term [show_types] "take_byte 1 (42 << 8::16word) = 42"}.
›
(* overwrite l h w0 w1 replaces bits l (inclusive) to h (exclusive) of w0 by
   the corresponding bits of w1. The result is the OR of three disjoint
   pieces, each extracted with take_bits and shifted back into place:
   - bits h .. LENGTH('a) - 1 of w0,
   - bits l .. h - 1 of w1,
   - bits 0 .. l - 1 of w0. *)
definition overwrite :: "nat ⇒ nat ⇒ 'a::len word ⇒ 'a::len word ⇒ 'a::len word"
where "overwrite l h w0 w1 ≡ ((⟨h,LENGTH('a)⟩w0) << h) OR ((⟨l,h⟩w1) << l) OR (⟨0,l⟩w0)"
text ‹
@{term "overwrite l h w0 w1"} overwrites low (inclusive) to high (exclusive) bits in
word @{term w0} with bits from word @{term w1}.
For example, @{term [show_types] "(overwrite 2 4 28 227 :: 8word) = 16"}.
›
text ‹We prove some theorems about taking the nth bit/byte of an operation. These are useful to prove
equality between different words, by first applying rule @{thm word_eqI}.›
(* Pointwise characterisation of take_bits: bit n of ⟨l,h⟩w is bit n + l of w,
   provided n is below both the word length and the width h - l.
   Registered in bit_simps, which the proofs below use to reason bit by bit. *)
lemma bit_take_bits_iff [bit_simps]:
‹(⟨l,h⟩w) !! n ⟷ n < LENGTH('a) ∧ n < h - l ∧ w !! (n + l)› for w :: ‹'a::len word›
by (simp add: take_bits_def bit_simps ac_simps)
(* Pointwise characterisation of take_byte: bit n of byte m of w is bit
   n + m * 8 of w, for n < 8. Registered in bit_simps. *)
lemma bit_take_byte_iff [bit_simps]:
‹take_byte m w !! n ⟷ n < LENGTH('a) ∧ n < 8 ∧ w !! (n + m * 8)› for w :: ‹'a::len word›
by (auto simp add: take_byte_def bit_simps)
(* Pointwise characterisation of overwrite: for l <= n < h the bit comes from
   w1, otherwise from w0. Bits at or above LENGTH('a) are False.
   Registered in bit_simps. *)
lemma bit_overwrite_iff [bit_simps]:
‹overwrite l h w0 w1 !! n ⟷ n < LENGTH('a) ∧
(if l ≤ n ∧ n < h then w1 else w0) !! n›
for w0 w1 :: ‹'a::len word›
by (auto simp add: overwrite_def bit_simps)
(* Restates bit_take_bits_iff with an if-then-else on the range condition.
   It is proved directly from bit_simps. *)
lemma nth_takebits:
fixes w :: "'a::len word"
shows "(⟨l,h⟩w) !! n = (if n < LENGTH('a) ∧ n < h - l then w !! (n + l) else False)"
by (auto simp add: bit_simps)
(* Addressing a bit of w through its byte: bit (n mod 8) of byte (n div 8) is
   bit n of w, as long as n mod 8 is below the word length. *)
lemma nth_takebyte:
fixes w :: "'a::len word"
shows "take_byte (n div 8) w !! (n mod 8) = (if n mod 8 < LENGTH('a) then w!!n else False )"
by (simp add: bit_simps)
(* Bit i of byte n of an overwritten word sits at absolute bit position
   i + n * 8. That bit is taken from v' when the position lies in
   l .. h - 1, and from v otherwise. *)
lemma nth_take_byte_overwrite:
fixes v v' :: "'a::len word"
shows "take_byte n (overwrite l h v v') !! i = (if i + n * 8 < l ∨ i + n * 8 ≥ h then take_byte n v !! i else take_byte n v' !! i)"
by (auto simp add: bit_simps dest: bit_imp_le_length)
(* Bitwise NOT flips each bit below the word length; bits at or above
   LENGTH('a) are False. *)
lemma nth_bitNOT:
fixes a :: "'a::len word"
shows "(NOT a) !! n ⟷ (if n < LENGTH('a) then ¬(a !! n) else False)"
by (simp add: bit_simps)
text ‹Various simplification rules›
(* Suppose h = LENGTH('b) and 'b is not wider than 'a. Then masking w to its
   low h bits before casting to 'b is redundant: the cast keeps only bits
   below LENGTH('b) anyway. Proved bit by bit (bit_word_eqI). *)
lemma ucast_take_bits:
fixes w :: "'a::len word"
assumes "h = LENGTH('b)"
and "LENGTH('b) ≤ LENGTH('a)"
shows "ucast (⟨0,h⟩w) = (ucast w ::'b :: len word)"
apply (rule bit_word_eqI)
using assms
apply (simp add: bit_simps)
done
(* Casting w from 'b to 'a only produces bits below LENGTH('b). Hence taking
   the low LENGTH('b) bits of the result changes nothing. bit_imp_le_length
   discharges the out-of-range bits. *)
lemma take_bits_ucast:
fixes w :: "'b::len word"
assumes "h = LENGTH('b)"
shows "(⟨0,h⟩ (ucast w ::'a :: len word)) = (ucast w ::'a :: len word)"
apply (rule bit_word_eqI)
using assms
apply (auto simp add: bit_simps dest: bit_imp_le_length)
done
(* Nested take_bits collapse into a single take_bits with lower bound l + l'.
   The upper bound depends on the width of the outer window:
   - if min LENGTH('a) h >= h' - l', the upper bound is h';
   - otherwise the outer window truncates it to l' + min LENGTH('a) h. *)
lemma take_bits_take_bits:
fixes w :: "'a::len word"
shows "(⟨l,h⟩(⟨l',h'⟩w)) = (if min LENGTH('a) h ≥ h' - l' then ⟨l+l',h'⟩w else (⟨l+l',l'+min LENGTH('a) h⟩w))"
apply (rule bit_word_eqI)
apply (simp add: bit_simps ac_simps)
apply auto
done
(* Taking exactly the overwritten range l .. h - 1 of an overwrite recovers
   the corresponding bits of w1. *)
lemma take_bits_overwrite:
shows "⟨l,h⟩(overwrite l h w0 w1) = ⟨l,h⟩w1"
apply (rule bit_word_eqI)
apply (simp add: bit_simps ac_simps)
apply (auto dest: bit_imp_le_length)
done
(* ⟨0,h⟩w0 has no bits at or above h, and overwrite 0 h replaces every bit
   below h by the corresponding bit of w1. Hence only the low h bits of w1
   remain. *)
lemma overwrite_0_take_bits_0:
shows "overwrite 0 h (⟨0,h⟩w0) w1 = ⟨0,h⟩w1"
apply (rule bit_word_eqI)
apply (simp add: bit_simps ac_simps)
done
(* For 256-bit words, shifting left (<<) by m bytes moves byte n - m to
   byte n, assuming m <= n. Bytes that do not fit in 256 bits, i.e. those
   with (n+1)*8 > 256, are 0. *)
lemma take_byte_shiftlr_256:
fixes v :: "256 word"
assumes "m ≤ n"
shows "take_byte n (v << m*8) = (if (n+1)*8 ≤ 256 then take_byte (n-m) v else 0)"
apply (rule bit_word_eqI)
using assms
apply (simp add: bit_simps)
apply (simp add: algebra_simps)
done
subsection ‹ Take\_Bits and arithmetic›
text ‹This definition is based on @{thm to_bl_plus_carry}, which formulates addition as bitwise operations using @{term xor3} and @{term carry}.›
(* Ripple-carry addition over a list of bit pairs, starting with carry-in c.
   - The head pair (x, y) yields the sum bit xor3 x y car.
   - It passes the carry carry x y car on to the rest of the list, so carries
     flow from the head towards the tail.
   - The base case ignores its carry argument, so the final carry-out is
     dropped.
   The result has one bit per input pair. *)
definition bitwise_add :: "(bool × bool) list ⇒ bool ⇒ bool list"
where "bitwise_add x c ≡ foldr (λ(x, y) res car. xor3 x y car # res (carry x y car)) x (λ_. []) c"
(* bitwise_add produces exactly one output bit per input pair. *)
lemma length_foldr_bitwise_add:
shows "length (bitwise_add x c) = length x"
unfolding bitwise_add_def
by(induct x arbitrary: c) auto
text ‹This is the "heart" of the proof: bitwise addition over the concatenation of two zipped lists can be
expressed as two consecutive bitwise additions, where the second one starts with carry False.
For this, we assume that no carry leaves the last pair of the first list, even when its carry-in is True.
›
(* Splitting bitwise_add at an append. Assume no carry leaves the last pair of
   x, even with carry-in True. Then adding x @ y equals adding x followed by
   adding y with carry-in False.
   The carry-in x ≠ [] ∧ c is forced to False when x is empty. Otherwise the
   right-hand side, which restarts y with carry False, could not match. *)
lemma bitwise_add_append:
assumes "x = [] ∨ ¬carry (fst (last x)) (snd (last x)) True"
shows "bitwise_add (x @ y) (x≠[] ∧ c) = bitwise_add x (x≠[] ∧ c) @ bitwise_add y False"
using assms
unfolding bitwise_add_def
by(induct x arbitrary: c) (auto simp add: case_prod_unfold xor3_def carry_def split: if_split_asm)
(* Carries only flow towards the tail. Hence the first length x output bits
   of adding x @ y depend only on x. *)
lemma bitwise_add_take_append:
shows "take (length x) (bitwise_add (x @ y) c) = bitwise_add x c"
unfolding bitwise_add_def
by(induct x arbitrary: c) (auto simp add: case_prod_unfold xor3_def carry_def split: if_split_asm)
(* Adding n all-False pairs with carry-in False yields n False bits. *)
lemma bitwise_add_zero:
shows "bitwise_add (replicate n (False, False)) False = replicate n False "
unfolding bitwise_add_def
by(induct n) (auto simp add: xor3_def carry_def)
(* bitwise_add commutes with take: a prefix of the output depends only on the
   same prefix of the input. After induction and auto, the remaining goals
   are closed by metis, mainly via bitwise_add_take_append and
   append_take_drop_id. *)
lemma bitwise_add_take:
shows "take n (bitwise_add x c) = bitwise_add (take n x) c"
unfolding bitwise_add_def
by (induct n arbitrary: x c,auto)
(metis append_take_drop_id bitwise_add_def bitwise_add_take_append diff_is_0_eq' length_foldr_bitwise_add length_take nat_le_linear rev_min_pm1 take_all)
(* Helper for zipped lists of equal length, with n < length x: the first
   component of the head of drop n (zip x y) is the head of drop n x. *)
lemma fst_hd_drop_zip:
assumes "n < length x"
and "length x = length y"
shows "fst (hd (drop n (zip x y))) = hd (drop n x)"
using assms
by (induct x arbitrary: n y,auto)
(metis (no_types, lifting) Cons_nth_drop_Suc drop_zip fst_conv length_Cons list.sel(1) zip_Cons_Cons)
(* Counterpart of fst_hd_drop_zip for the second component, again for
   equal-length lists with n < length x: the second component of the head of
   drop n (zip x y) is the head of drop n y. *)
lemma snd_hd_drop_zip:
assumes "n < length x"
and "length x = length y"
shows "snd (hd (drop n (zip x y))) = hd (drop n y)"
using assms
by (induct x arbitrary: n y,auto)
(metis (no_types, lifting) Cons_nth_drop_Suc drop_zip snd_conv length_Cons list.sel(1) zip_Cons_Cons)
text ‹
Ucasting of @{term "a+b"} can be rewritten to taking bits of @{term a} and @{term b}.
›
(* Casting down (LENGTH('b) < LENGTH('a)) distributes over addition: the low
   bits of a sum depend only on the low bits of the summands.
   The proof proceeds in two steps:
   - compare the to_bl bit lists, rewriting sums via to_bl_plus_carry and
     bitwise_add and cutting prefixes with bitwise_add_take;
   - lift the list equality back to words with word_bl.Rep_eqD. *)
lemma ucast_plus:
fixes a b :: "'a::len word"
assumes "LENGTH('a) > LENGTH('b)"
shows "(ucast (a + b) ::'b::len word) = (ucast a + ucast b::'b::len word)"
proof-
have "to_bl (ucast (a + b) ::'b::len word) = to_bl (ucast a + ucast b::'b::len word)"
using assms
apply (auto simp add: to_bl_ucast to_bl_plus_carry word_rep_drop length_foldr_bitwise_add drop_zip[symmetric] rev_drop bitwise_add_def simp del: foldr_replicate foldr_append)
apply (simp only: bitwise_add_def[symmetric] length_foldr_bitwise_add)
by (auto simp add: drop_take bitwise_add_take[symmetric] rev_take length_foldr_bitwise_add)
thus ?thesis
using word_bl.Rep_eqD
by blast
qed
(* Casting down distributes over unary minus. The proof proceeds in three
   steps:
   - rewrite both negations with twos_complement and word_succ_p1;
   - apply ucast_plus;
   - finish bit by bit with word_eqI, nth_ucast and nth_bitNOT.
   NOTE(review): b is fixed but does not occur in the statement. *)
lemma ucast_uminus:
fixes a b :: "'a::len word"
assumes "LENGTH('a) > LENGTH('b)"
shows "ucast (- a) = (- ucast a :: 'b::len word)"
apply (subst twos_complement)+
apply (subst word_succ_p1)+
apply (subst ucast_plus)
apply (rule assms)
apply simp
apply (rule word_eqI)
apply (auto simp add: word_size nth_ucast nth_bitNOT)
using assms order.strict_trans
by blast
(* Casting down distributes over subtraction. Derived from ucast_plus
   (applied to a and -b) together with ucast_uminus (applied to b). *)
lemma ucast_minus:
fixes a b :: "'a::len word"
assumes "LENGTH('a) > LENGTH('b)"
shows "(ucast (a - b) ::'b::len word) = (ucast a - ucast b::'b::len word)"
using ucast_plus[OF assms,of a "-b"] ucast_uminus[OF assms,of b]
by auto
(* Bit-list view of ⟨0,h⟩a: the first LENGTH('a) - h entries of to_bl a are
   replaced by False and the remaining entries are kept. When
   h >= LENGTH('a), the subtraction is 0 and this is just to_bl a. *)
lemma to_bl_takebits:
fixes a :: "'a::len word"
shows "to_bl (⟨0,h⟩a) = replicate (LENGTH('a) - h) False @ drop (LENGTH('a) - h) (to_bl a)"
apply (auto simp add: take_bits_def bl_word_and to_bl_mask)
apply (rule nth_equalityI)
by (auto simp add: min_def nth_append)
text ‹ All simplification rules that are used during symbolic execution.›
(* Rewrite rules bundled for symbolic execution:
   - ucast over +, - and unary minus;
   - take_bits / overwrite simplifications;
   - mask_eq_exp_minus_1;
   - Word_Lib facts about up- and down-casts (ucast_down_ucast_id, is_down,
     ucast_up_ucast_id, is_up). *)
lemmas BitByte_simps = ucast_plus ucast_minus ucast_uminus take_bits_overwrite take_bits_take_bits
ucast_take_bits overwrite_0_take_bits_0 mask_eq_exp_minus_1
ucast_down_ucast_id is_down take_bits_ucast ucast_up_ucast_id is_up
text ‹Simplification for immediate (numeral) values.›
(* Unfolding rules for take_bits and overwrite applied to literal words
   (numeral, 0 or 1). They are declared [simp] so that such terms are unfolded
   automatically, supporting simplification of immediate values.
   Each "for" clause binds exactly the numeral variables that occur in its
   instantiation; the previous vacuous binders were dropped. *)
lemmas take_bits_numeral[simp] = take_bits_def[of _ _ "numeral n"] for n
lemmas take_bits_num0[simp] = take_bits_def[of _ _ "0"]
lemmas take_bits_num1[simp] = take_bits_def[of _ _ "1"]
lemmas overwrite_numeral_numeral[simp] = overwrite_def[of _ _ "numeral n" "numeral m"] for n m
lemmas overwrite_num0_numeral[simp] = overwrite_def[of _ _ 0 "numeral m"] for m
lemmas overwrite_numeral_num0[simp] = overwrite_def[of _ _ "numeral m" 0] for m
lemmas overwrite_numeral_00[simp] = overwrite_def[of _ _ 0 0]
end