(* Theory Bit_Shifts_Infix_Syntax *)

(*
 * Copyright Data61, CSIRO (ABN 41 687 119 230)
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

(* Author: Jeremy Dawson, NICTA *)

section ‹Shift operations with infix syntax›

theory Bit_Shifts_Infix_Syntax
  imports "HOL-Library.Word" More_Word
begin

context semiring_bit_operations
begin

(* Left shift: infix syntax for push_bit.  The [code_unfold] equation lets
   code generation see push_bit in place of <<. *)
definition shiftl :: ‹'a ⇒ nat ⇒ 'a›  (infixl "<<" 55)
  where [code_unfold]: ‹a << n = push_bit n a›

(* Bit n of a << m is bit (n - m) of a, provided m ≤ n and possible_bit
   holds for n in 'a; otherwise the bit is unset. *)
lemma bit_shiftl_iff [bit_simps]:
  ‹bit (a << m) n ⟷ m ≤ n ∧ possible_bit TYPE('a) n ∧ bit a (n - m)›
  by (simp add: shiftl_def bit_simps)

(* Right shift: infix syntax for drop_bit, also unfolded for code generation. *)
definition shiftr :: ‹'a ⇒ nat ⇒ 'a›  (infixl ">>" 55)
  where [code_unfold]: ‹a >> n = drop_bit n a›

(* Bit m of a >> n is bit (n + m) of a, stated point-free. *)
lemma bit_shiftr_eq [bit_simps]:
  ‹bit (a >> n) = bit a ∘ (+) n›
  by (simp add: shiftr_def bit_simps)

end

(* Signed (arithmetic) right shift on machine words: infix syntax for
   signed_drop_bit.  Unlike >>, this operator is defined only for 'a::len word. *)
definition sshiftr :: ‹'a::len word ⇒ nat ⇒ 'a word›  (infixl ‹>>>› 55)
  where [code_unfold]: ‹w >>> n = signed_drop_bit n w›

(* Bit characterisation of >>>: positions from LENGTH('a) - m up to the top
   of the word read the top bit LENGTH('a) - 1; every other position n reads
   bit m + n of the original word. *)
lemma bit_sshiftr_iff [bit_simps]:
  ‹bit (w >>> m) n ⟷ bit w (if LENGTH('a) - m ≤ n ∧ n < LENGTH('a) then LENGTH('a) - 1 else m + n)›
  for w :: ‹'a::len word›
  by (simp add: sshiftr_def bit_simps)

context
  includes lifting_syntax
begin

(* Transfer rules relating the word shift operators to operations on the
   int representative (related by pcr_word). *)

(* Left shift: the int-side operation is plain push_bit; no take_bit is
   applied to the representative first. *)
lemma shiftl_word_transfer [transfer_rule]:
  ‹(pcr_word ===> (=) ===> pcr_word) (λk n. push_bit n k) (<<)›
  apply (unfold shiftl_def)
  apply transfer_prover
  done

(* Right shift: the int representative is first truncated to LENGTH('a)
   bits with take_bit, then shifted with drop_bit.  The intermediate fact is
   stated with the composition (drop_bit n ∘ take_bit ...); simp then turns it
   into the nested-application form of the goal. *)
lemma shiftr_word_transfer [transfer_rule]:
  ‹((pcr_word :: int ⇒ 'a::len word ⇒ _) ===> (=) ===> pcr_word)
    (λk n. drop_bit n (take_bit LENGTH('a) k))
    (>>)›
proof -
  have ‹((pcr_word :: int ⇒ 'a::len word ⇒ _) ===> (=) ===> pcr_word)
    (λk n. (drop_bit n ∘ take_bit LENGTH('a)) k)
    (>>)›
    by (unfold shiftr_def) transfer_prover
  then show ?thesis
    by simp
qed

(* Signed right shift: the int representative is sign-extended from bit
   LENGTH('a) - 1 with signed_take_bit, then shifted with drop_bit.  The
   proof follows the same composition-then-simp pattern as the rule above. *)
lemma sshiftr_transfer [transfer_rule]:
  ‹((pcr_word :: int ⇒ 'a::len word ⇒ _) ===> (=) ===> pcr_word)
    (λk n. drop_bit n (signed_take_bit (LENGTH('a) - Suc 0) k))
    (>>>)›
proof -
  have ‹((pcr_word :: int ⇒ 'a::len word ⇒ _) ===> (=) ===> pcr_word)
    (λk n. (drop_bit n ∘ signed_take_bit (LENGTH('a) - Suc 0)) k)
    (>>>)›
    by (unfold sshiftr_def) transfer_prover
  then show ?thesis
    by simp
qed

end

context semiring_bit_operations
begin

(* Simplification rules for << and >>. *)

lemma shiftl_0 [simp]:
  ‹0 << n = 0›
  by (simp add: shiftl_def)

lemma shiftl_of_0 [simp]:
  ‹a << 0 = a›
  by (simp add: shiftl_def)

(* Peels one step off a successor shift amount by doubling the operand. *)
lemma shiftl_of_Suc [simp]:
  ‹a << Suc n = (a * 2) << n›
  by (simp add: shiftl_def)

lemma shiftl_1 [simp]:
  ‹1 << n = 2 ^ n›
  by (simp add: shiftl_def)

(* A numeral shifted by a literal amount (Suc n or a numeral) is rewritten
   back to push_bit. *)
lemma shiftl_numeral_Suc [simp]:
  ‹numeral m << Suc n = push_bit (Suc n) (numeral m)›
  by (fact shiftl_def)

lemma shiftl_numeral_numeral [simp]:
  ‹numeral m << numeral n = push_bit (numeral n) (numeral m)›
  by (fact shiftl_def)

lemma shiftr_0 [simp]:
  ‹0 >> n = 0›
  by (simp add: shiftr_def)

lemma shiftr_of_0 [simp]:
  ‹a >> 0 = a›
  by (simp add: shiftr_def)

(* 1 survives a right shift only when the shift amount is zero. *)
lemma shiftr_1 [simp]:
  ‹1 >> n = of_bool (n = 0)›
  by (simp add: shiftr_def)

(* A numeral shifted right by a literal amount is rewritten back to drop_bit. *)
lemma shiftr_numeral_Suc [simp]:
  ‹numeral m >> Suc n = drop_bit (Suc n) (numeral m)›
  by (fact shiftr_def)

lemma shiftr_numeral_numeral [simp]:
  ‹numeral m >> numeral n = drop_bit (numeral n) (numeral m)›
  by (fact shiftr_def)

(* Arithmetic characterisations of the shifts; deliberately not [simp]. *)
lemma shiftl_eq_mult:
  ‹x << n = x * 2 ^ n›
  unfolding shiftl_def by (fact push_bit_eq_mult)

lemma shiftr_eq_div:
  ‹x >> n = x div 2 ^ n›
  unfolding shiftr_def by (fact drop_bit_eq_div)

end

context ring_bit_operations
begin

(* bit_operations_syntax is included for the NOT notation used below. *)
context
  includes bit_operations_syntax
begin

(* Shifting -1 left by a numeral amount equals the complement of the mask
   of that width. *)
lemma shiftl_minus_1_numeral [simp]:
  ‹- 1 << numeral n = NOT (mask (numeral n))›
  by (simp add: shiftl_def)

end

end

(* Natural-number counterpart of shiftl_1, with 1 written as Suc 0. *)
lemma shiftl_Suc_0 [simp]:
  ‹Suc 0 << n = 2 ^ n›
  by (simp add: shiftl_def)

(* Natural-number counterpart of shiftr_1, with 1 written as Suc 0. *)
lemma shiftr_Suc_0 [simp]:
  ‹Suc 0 >> n = of_bool (n = 0)›
  by (simp add: shiftr_def)

(* A numeral word signed-shifted by Suc n is rewritten back to signed_drop_bit. *)
lemma sshiftr_numeral_Suc [simp]:
  ‹numeral m >>> Suc n = signed_drop_bit (Suc n) (numeral m)›
  by (fact sshiftr_def)

(* A numeral word signed-shifted by a numeral amount is rewritten back to
   signed_drop_bit. *)
lemma sshiftr_numeral_numeral [simp]:
  ‹numeral m >>> numeral n = signed_drop_bit (numeral n) (numeral m)›
  by (fact sshiftr_def)

context ring_bit_operations
begin

(* Negative numerals shifted by a literal amount (Suc n or a numeral) are
   rewritten back to push_bit / drop_bit. *)

lemma shiftl_minus_numeral_Suc [simp]:
  ‹- numeral m << Suc n = push_bit (Suc n) (- numeral m)›
  by (fact shiftl_def)

lemma shiftl_minus_numeral_numeral [simp]:
  ‹- numeral m << numeral n = push_bit (numeral n) (- numeral m)›
  by (fact shiftl_def)

lemma shiftr_minus_numeral_Suc [simp]:
  ‹- numeral m >> Suc n = drop_bit (Suc n) (- numeral m)›
  by (fact shiftr_def)

lemma shiftr_minus_numeral_numeral [simp]:
  ‹- numeral m >> numeral n = drop_bit (numeral n) (- numeral m)›
  by (fact shiftr_def)

end

(* The zero word stays zero under any signed right shift. *)
lemma sshiftr_0 [simp]:
  ‹0 >>> n = 0›
  by (simp add: sshiftr_def)

(* A signed right shift by zero is the identity. *)
lemma sshiftr_of_0 [simp]:
  ‹w >>> 0 = w›
  by (simp add: sshiftr_def)

(* In a 1-bit word, 1 is also the top (sign) bit, so the signed shift leaves
   it unchanged for every n.  In longer words, 1 >>> n is 1 only when n = 0. *)
lemma sshiftr_1 [simp]:
  ‹(1 :: 'a::len word) >>> n = of_bool (LENGTH('a) = 1 ∨ n = 0)›
  by (simp add: sshiftr_def)

(* A negative numeral word signed-shifted by Suc n is rewritten back to
   signed_drop_bit. *)
lemma sshiftr_minus_numeral_Suc [simp]:
  ‹- numeral m >>> Suc n = signed_drop_bit (Suc n) (- numeral m)›
  by (fact sshiftr_def)

(* A negative numeral word signed-shifted by a numeral amount is rewritten
   back to signed_drop_bit. *)
lemma sshiftr_minus_numeral_numeral [simp]:
  ‹- numeral m >>> numeral n = signed_drop_bit (numeral n) (- numeral m)›
  by (fact sshiftr_def)

end