(* Theory State *)

(*  Title:       X86 instruction semantics and basic block symbolic execution
    Authors:     Freek Verbeek, Abhijith Bharadwaj, Joshua Bockenek, Ian Roessle, Timmy Weerwag, Binoy Ravindran
    Year:        2020
    Maintainer:  Freek Verbeek (freek@vt.edu)
*)

section "Concrete state and instructions"

theory State
  imports Main Memory
begin

text ‹A state consists of registers, memory, flags and a rip. Some design considerations here:
\begin{itemize}
\item All register values are 256 bits. We could also distinguish 64-bit registers, 128-bit registers, etc.
      That would increase complexity in proofs and data structures. The cost of using 256 everywhere is
      that a goal typically will have some casted 64-bit values.
\item The instruction pointer RIP is a special 64-bit register outside of the normal register set.
\item Strings are used for registers and flags. We would prefer an enumerative datatype; however, that would be
      extremely slow since there are roughly 100 register names.
\end{itemize}
›

(* Machine state.
   regs  : register file, indexed by register name; every register holds a 256-bit word.
   mem   : byte-addressed memory with 64-bit addresses.
   flags : flag values, indexed by flag name.
   rip   : the 64-bit instruction pointer, kept outside the register file. *)
record state =
  regs  :: "string ⇒ 256 word"
  mem   :: "64 word ⇒ 8 word"
  flags :: "string ⇒ bool"
  rip   :: "64 word"

(* Resolves a register name to the register that actually stores it.
   The result (b, r, l, h) consists of:
     b    - True if a write through this name replaces the whole parent register,
            False if only bits l to h of the parent are replaced (see reg_write);
     r    - the name of the parent register;
     l, h - the bit range that the name occupies within r.
   No default case is given, so names that are not listed are left unspecified. *)
definition real_reg :: "string ⇒ bool × string × nat × nat"
  where "real_reg reg ≡
  ― ‹TODO: ymm, etc. (xmm0 to xmm15 are covered below)›
  case reg of
  ― ‹rip›
    ''rip''   ⇒ (True,  ''rip'', 0,64)
  ― ‹rax,rbx,rcx,rdx›
  | ''rax''   ⇒ (True,  ''rax'', 0,64)
  | ''eax''   ⇒ (True,  ''rax'', 0,32)
  | ''ax''    ⇒ (False, ''rax'', 0,16)
  | ''ah''    ⇒ (False, ''rax'', 8,16)
  | ''al''    ⇒ (False, ''rax'', 0,8)
  | ''rbx''   ⇒ (True,  ''rbx'', 0,64)
  | ''ebx''   ⇒ (True,  ''rbx'', 0,32)
  | ''bx''    ⇒ (False, ''rbx'', 0,16)
  | ''bh''    ⇒ (False, ''rbx'', 8,16)
  | ''bl''    ⇒ (False, ''rbx'', 0,8)
  | ''rcx''   ⇒ (True,  ''rcx'', 0,64)
  | ''ecx''   ⇒ (True,  ''rcx'', 0,32)
  | ''cx''    ⇒ (False, ''rcx'', 0,16)
  | ''ch''    ⇒ (False, ''rcx'', 8,16)
  | ''cl''    ⇒ (False, ''rcx'', 0,8)
  | ''rdx''   ⇒ (True,  ''rdx'', 0,64)
  | ''edx''   ⇒ (True,  ''rdx'', 0,32)
  | ''dx''    ⇒ (False, ''rdx'', 0,16)
  | ''dh''    ⇒ (False, ''rdx'', 8,16)
  | ''dl''    ⇒ (False, ''rdx'', 0,8)
  ― ‹RBP, RSP›
  | ''rbp''   ⇒ (True,  ''rbp'', 0,64)
  | ''ebp''   ⇒ (True,  ''rbp'', 0,32)
  | ''bp''    ⇒ (False, ''rbp'', 0,16)
  | ''bpl''   ⇒ (False, ''rbp'', 0,8)
  | ''rsp''   ⇒ (True,  ''rsp'', 0,64)
  | ''esp''   ⇒ (True,  ''rsp'', 0,32)
  | ''sp''    ⇒ (False, ''rsp'', 0,16)
  | ''spl''   ⇒ (False, ''rsp'', 0,8)
  ― ‹ RDI, RSI, R8 to R15›
  | ''rdi''   ⇒ (True,  ''rdi'', 0,64)
  | ''edi''   ⇒ (True,  ''rdi'', 0,32)
  | ''di''    ⇒ (False, ''rdi'', 0,16)
  | ''dil''   ⇒ (False, ''rdi'', 0,8)
  | ''rsi''   ⇒ (True,  ''rsi'', 0,64)
  | ''esi''   ⇒ (True,  ''rsi'', 0,32)
  | ''si''    ⇒ (False, ''rsi'', 0,16)
  | ''sil''   ⇒ (False, ''rsi'', 0,8)
  | ''r15''   ⇒ (True,  ''r15'', 0,64)
  | ''r15d''  ⇒ (True,  ''r15'', 0,32)
  | ''r15w''  ⇒ (False, ''r15'', 0,16)
  | ''r15b''  ⇒ (False, ''r15'', 0,8)
  | ''r14''   ⇒ (True,  ''r14'', 0,64)
  | ''r14d''  ⇒ (True,  ''r14'', 0,32)
  | ''r14w''  ⇒ (False, ''r14'', 0,16)
  | ''r14b''  ⇒ (False, ''r14'', 0,8)
  | ''r13''   ⇒ (True,  ''r13'', 0,64)
  | ''r13d''  ⇒ (True,  ''r13'', 0,32)
  | ''r13w''  ⇒ (False, ''r13'', 0,16)
  | ''r13b''  ⇒ (False, ''r13'', 0,8)
  | ''r12''   ⇒ (True,  ''r12'', 0,64)
  | ''r12d''  ⇒ (True,  ''r12'', 0,32)
  | ''r12w''  ⇒ (False, ''r12'', 0,16)
  | ''r12b''  ⇒ (False, ''r12'', 0,8)
  | ''r11''   ⇒ (True,  ''r11'', 0,64)
  | ''r11d''  ⇒ (True,  ''r11'', 0,32)
  | ''r11w''  ⇒ (False, ''r11'', 0,16)
  | ''r11b''  ⇒ (False, ''r11'', 0,8)
  | ''r10''   ⇒ (True,  ''r10'', 0,64)
  | ''r10d''  ⇒ (True,  ''r10'', 0,32)
  | ''r10w''  ⇒ (False, ''r10'', 0,16)
  | ''r10b''  ⇒ (False, ''r10'', 0,8)
  | ''r9''    ⇒ (True,  ''r9'' , 0,64)
  | ''r9d''   ⇒ (True,  ''r9'' , 0,32)
  | ''r9w''   ⇒ (False, ''r9'' , 0,16)
  | ''r9b''   ⇒ (False, ''r9'' , 0,8)
  | ''r8''    ⇒ (True,  ''r8'' , 0,64)
  | ''r8d''   ⇒ (True,  ''r8'' , 0,32)
  | ''r8w''   ⇒ (False, ''r8'' , 0,16)
  | ''r8b''   ⇒ (False, ''r8'' , 0,8)
  ― ‹xmm›
  | ''xmm0''  ⇒ (True, ''xmm0'' , 0,128)
  | ''xmm1''  ⇒ (True, ''xmm1'' , 0,128)
  | ''xmm2''  ⇒ (True, ''xmm2'' , 0,128)
  | ''xmm3''  ⇒ (True, ''xmm3'' , 0,128)
  | ''xmm4''  ⇒ (True, ''xmm4'' , 0,128)
  | ''xmm5''  ⇒ (True, ''xmm5'' , 0,128)
  | ''xmm6''  ⇒ (True, ''xmm6'' , 0,128)
  | ''xmm7''  ⇒ (True, ''xmm7'' , 0,128)
  | ''xmm8''  ⇒ (True, ''xmm8'' , 0,128)
  | ''xmm9''  ⇒ (True, ''xmm9'' , 0,128)
  | ''xmm10'' ⇒ (True, ''xmm10'' , 0,128)
  | ''xmm11'' ⇒ (True, ''xmm11'' , 0,128)
  | ''xmm12'' ⇒ (True, ''xmm12'' , 0,128)
  | ''xmm13'' ⇒ (True, ''xmm13'' , 0,128)
  | ''xmm14'' ⇒ (True, ''xmm14'' , 0,128)
  | ''xmm15'' ⇒ (True, ''xmm15'' , 0,128)
  "


text ‹x86 has register aliasing. For example, register EAX is the lower 32 bits of register RAX.
      The function @{const real_reg} maps register aliases to the ``real'' register. For example:

      @{term "real_reg ''ah'' = (False, ''rax'', 8,16)"}.

      This means that register AH is the second byte (bits 8 to 16) of register RAX.
      The bool @{const False} indicates that writing to AH does not overwrite the remainder of RAX.

      @{term "real_reg ''eax'' = (True, ''rax'', 0,32)"}.

      Register EAX is the lower 4 bytes of RAX. Writing to EAX means overwriting the remainder of RAX
      with zeroes.
›

(* Size of a register in bytes: the width h - l of the bit range reported by real_reg,
   divided by 8. *)
definition reg_size :: "string ⇒ nat" ― ‹in bytes›
  where "reg_size reg ≡ let (_,_,l,h) = real_reg reg in (h - l) div 8"


text‹ We now define functions for reading and writing from state.›

(* Reads a register as a 256-bit value.
   - ''rip'' is served from the separate rip field (zero-extended by ucast).
   - The empty name reads as 0.
   - Any other name is resolved through real_reg, and bits l to h of the parent
     register are selected with the bit-range notation from the imported theory. *)
definition reg_read :: "state ⇒ string ⇒ 256 word"
  where "reg_read σ reg ≡
      if reg = ''rip'' then ucast (rip σ) else
      if reg = '''' then 0 else ― ‹happens if no base register is used in an address›
      let (_,r,l,h) = real_reg reg in
        ⟨l,h⟩(regs σ r)"

(* Embeds a boolean into a word of any length: True becomes 1, False becomes 0. *)
primrec fromBool :: "bool ⇒ 'a :: len word"
  where
    "fromBool True = 1"
  | "fromBool False = 0"

(* Reads a flag as a 256-bit value: 1 if the flag is set in the state, 0 otherwise. *)
definition flag_read :: "state ⇒ string ⇒ 256 word"
  where "flag_read σ flag ≡ fromBool (flags σ flag)"

(* Reads si bytes of memory starting at address a and combines them into one 256-bit
   value with word_rcat. The byte list comes from read_bytes, which is defined in the
   imported theory. *)
definition mem_read :: "state ⇒ 64 word ⇒ nat ⇒ 256 word"
  where "mem_read σ a si ≡ word_rcat (read_bytes (mem σ) a si)"


text ‹State updates are performed through a tiny deeply embedded language of state updates. This allows us
      to reason over state updates through theorems.›

(* The deeply embedded language of primitive state updates; state_update gives each
   constructor its meaning. *)
datatype StateUpdate =
    RegUpdate string "256 word"         ― ‹Write value to register›
  | FlagUpdate "string ⇒ bool"          ― ‹Update all flags at once›
  | RipUpdate "64 word"                 ― ‹Update instruction pointer with address›
  | MemUpdate "64 word" nat "256 word"  ― ‹Write a number of bytes of a value to the address›

(* Semantics of a single primitive update, as a state transformer.
   - RegUpdate : point update of the register file at name reg.
   - FlagUpdate: replaces the entire flag map.
   - RipUpdate : replaces the instruction pointer.
   - MemUpdate : for every address a' in region_addresses a si, memory at a' becomes
                 byte number unat (a' - a) of val (via take_byte); all other addresses
                 keep their old contents (override_on). *)
primrec state_update
  where
    "state_update (RegUpdate reg val)  = (λ σ . σ⦇regs := (regs σ)(reg := val)⦈)"
  | "state_update (FlagUpdate  val)    = (λ σ . σ⦇flags := val⦈)"
  | "state_update (RipUpdate a)        = (λ σ . σ⦇rip := a⦈)"
  | "state_update (MemUpdate a si val) = (λ σ .
        let new = (λ a' . take_byte (unat (a' - a)) val) in
         σ⦇mem := override_on (mem σ) new (region_addresses a si)⦈)"

(* Infix notation for a register update: r :=⇩r v stands for RegUpdate r v. *)
abbreviation RegUpdateSyntax (‹_ :=⇩r _› 30)
  where "RegUpdateSyntax reg val ≡ RegUpdate reg val"
(* Notation for a memory update: ⟦a,si⟧ :=⇩m v stands for MemUpdate a si v,
   i.e. address a, size si in bytes, value v. *)
abbreviation MemUpdateSyntax (‹⟦_,_⟧ :=⇩m _› 30)
  where "MemUpdateSyntax a si val ≡ MemUpdate a si val"
(* Notation for a flag update: setFlags v stands for FlagUpdate v. *)
abbreviation FlagUpdateSyntax (‹setFlags›)
  where "FlagUpdateSyntax val ≡ FlagUpdate val"
(* Notation for an instruction-pointer update: setRip a stands for RipUpdate a. *)
abbreviation RipUpdateSyntax (‹setRip›)
  where "RipUpdateSyntax val ≡ RipUpdate val"

text ‹Executes a write to a register in terms of the tiny deeply embedded language above.›
(* Writes val to register reg by issuing one RegUpdate on the parent register r
   obtained from real_reg.
   - If b holds, val becomes the new contents of r as is.
     NOTE(review): val is not masked to the bit range l..h here; presumably callers
     pass a value that is already truncated to the alias width - confirm.
   - Otherwise the current contents of r are combined with val through overwrite l h
     (defined in an imported theory; presumably replaces bits l to h - confirm). *)
definition reg_write
  where "reg_write reg val σ ≡
      let (b,r,l,h)  = real_reg reg;
          curr_val   = reg_read σ r;
          new_val    = if b then val else overwrite l h curr_val val in
        state_update (RegUpdate r new_val) σ"


text ‹A datatype for operands of instructions.›

(* Instruction operands.
   Imm  : an immediate 256-bit value.
   Reg  : a register, identified by name.
   Flag : a flag, identified by name.
   Mem  : a memory operand; its five fields are labelled in the marginal comment below. *)
datatype Operand =
    Imm "256 word"
  | Reg string
  | Flag string
  | Mem  nat    "64 word"   string    "string"      nat
      ― ‹size   offset      base-reg  index-reg    scale›

(* Address expression consisting of a base register only: offset 0, empty index
   register, scale 0. *)
abbreviation mem_op_no_offset_no_index :: "string ⇒ (64 word × string × string × nat)" (‹[_]⇩1› 40)
  where "mem_op_no_offset_no_index r ≡ (0,r,[],0)"

(* Address expression with an offset and a base register: empty index register, scale 0. *)
abbreviation mem_op_no_index :: "64 word ⇒ string ⇒ (64 word × string × string × nat)" (‹[_ + _]⇩2› 40)
  where "mem_op_no_index offset r ≡ (offset,r,[],0)"

(* Full address expression: offset, base register, index register and scale. *)
abbreviation mem_op :: "64 word ⇒ string ⇒ string ⇒ nat ⇒ (64 word × string × string × nat)" (‹[_ + _ + _ * _]⇩3› 40)
  where "mem_op offset r index scale ≡ (offset,r,index,scale)"

(* Turns an address expression (offset, base, index, scale) into a 32-byte memory operand. *)
definition ymm_ptr (‹YMMWORD PTR _›)
  where "YMMWORD PTR x ≡ case x of (offset,base,index,scale) ⇒ Mem 32 offset base index scale"

(* Turns an address expression (offset, base, index, scale) into a 16-byte memory operand. *)
definition xmm_ptr (‹XMMWORD PTR _›)
  where "XMMWORD PTR x ≡ case x of (offset,base,index,scale) ⇒ Mem 16 offset base index scale"

(* Turns an address expression (offset, base, index, scale) into an 8-byte memory operand. *)
definition qword_ptr (‹QWORD PTR _›)
  where "QWORD PTR x ≡ case x of (offset,base,index,scale) ⇒ Mem 8 offset base index scale"

(* Turns an address expression (offset, base, index, scale) into a 4-byte memory operand. *)
definition dword_ptr (‹DWORD PTR _›)
  where "DWORD PTR x ≡ case x of (offset,base,index,scale) ⇒ Mem 4 offset base index scale"

(* Turns an address expression (offset, base, index, scale) into a 2-byte memory operand. *)
definition word_ptr (‹WORD PTR _›)
  where "WORD PTR x ≡ case x of (offset,base,index,scale) ⇒ Mem 2 offset base index scale"

(* Turns an address expression (offset, base, index, scale) into a 1-byte memory operand. *)
definition byte_ptr (‹BYTE PTR _›)
  where "BYTE PTR x ≡ case x of (offset,base,index,scale) ⇒ Mem 1 offset base index scale"


(* Size of an operand in bytes. Only register and memory operands are covered;
   immediates and flags are deliberately left out (nonexhaustive). *)
primrec (nonexhaustive) operand_size :: "Operand ⇒ nat" ― ‹in bytes›
  where
    "operand_size (Reg r) = reg_size r"
  | "operand_size (Mem si _ _ _ _) = si"

(* Computes the effective address offset + base + scale * index in 64-bit arithmetic.
   Both register values are read as 256-bit words and narrowed to 64 bits by ucast. *)
fun resolve_address :: "state ⇒ 64 word ⇒ char list ⇒ char list ⇒ nat ⇒ 64 word"
  where "resolve_address σ offset base index scale =
      (let i = ucast (reg_read σ index);
           b = ucast (reg_read σ base) in
         offset + b + of_nat scale*i)"

(* Evaluates an operand to a 256-bit value in the given state.
   Immediates evaluate to themselves, registers and flags are read from the state,
   and a memory operand reads si bytes at its resolved address. *)
primrec operand_read :: "state ⇒ Operand ⇒ 256 word"
  where
    "operand_read σ (Imm i)  = i"
  | "operand_read σ (Reg r)  = reg_read σ r"
  | "operand_read σ (Flag f) = flag_read σ f"
  | "operand_read σ (Mem si offset base index scale) =
      (let a = resolve_address σ offset base index scale in
        mem_read σ a si
      )"


(* Applies a list of primitive updates to a state.
   The tail of the list is applied first and the head last, so the head of the list
   is the most recent update. *)
primrec state_with_updates :: "state ⇒ StateUpdate list ⇒ state" (infixl ‹with› 66)
  where
    "σ with [] = σ"
  | "(σ with (f#fs)) = state_update f (σ with fs)"

(* Writes a value to an operand. Only register and memory operands can be written
   (nonexhaustive).
   For a memory operand the address is offset + base + scale * index, computed the same
   way as in resolve_address, and the write is expressed as a single memory update. *)
primrec (nonexhaustive) operand_write :: "Operand ⇒ 256 word ⇒ state ⇒ state"
  where
    "operand_write (Reg r)  v σ = reg_write r v σ"
  | "operand_write (Mem si offset base index scale) v σ =
      (let i = ucast (reg_read σ index);
           b = ucast (reg_read σ base);
           a = offset + b + of_nat scale*i in
        σ with [⟦a,si⟧ :=⇩m v]
      )"





text ‹ The following theorems simplify reading from state parts after doing updates to other state parts.›

(* Register file after a register update at the head of the list: the updated register
   yields the written value, any other register is looked up in the remaining updates. *)
lemma regs_reg_write:
  shows "regs (σ with ((r :=⇩r w)#updates)) r' = (if r=r' then w else regs (σ with updates) r')"
  by (induct updates arbitrary: σ, auto simp add: case_prod_unfold Let_def)

(* A memory update does not change the register file. *)
lemma regs_mem_write:
  shows "regs (σ with ((⟦a,si⟧ :=⇩m v)#updates)) r = regs (σ with updates) r"
  by (induct updates arbitrary: σ, auto)

(* A flag update does not change the register file. *)
lemma regs_flag_write:
  shows "regs (σ with ((setFlags v)#updates)) r = regs (σ with updates) r"
  by (induct updates arbitrary: σ, auto)

(* An instruction-pointer update does not change the register file. *)
lemma regs_rip_write:
  shows "regs (σ with ((setRip a)#updates)) f = regs (σ with updates) f"
  by (auto)


(* A register update does not change the value read from any flag. *)
lemma flag_read_reg_write:
  shows "flag_read (σ with ((r :=⇩r w)#updates)) f = flag_read (σ with updates) f"
  by (induct updates arbitrary: σ, auto simp add: flag_read_def)

(* A memory update does not change the value read from any flag. *)
lemma flag_read_mem_write:
  shows "flag_read (σ with ((⟦a,si⟧ :=⇩m v)#updates)) f = flag_read (σ with updates) f"
  by (induct updates arbitrary: σ, auto simp add: flag_read_def)

(* After a flag update with map v at the head of the list, reading flags is exactly
   fromBool composed with v; earlier updates are irrelevant. Stated as an equality of
   functions (flag_read is only partially applied). *)
lemma flag_read_flag_write:
  shows "flag_read (σ with ((setFlags v)#updates)) = fromBool o v"
  by (induct updates arbitrary: σ, auto simp add: flag_read_def)

(* An instruction-pointer update does not change the value read from any flag. *)
lemma flag_read_rip_write:
  shows "flag_read (σ with ((setRip a)#updates)) f = flag_read (σ with updates) f"
  by (auto simp add: flag_read_def)

(* A register update does not change the result of any memory read. *)
lemma mem_read_reg_write:
  shows "mem_read (σ with ((r :=⇩r w)#updates)) a si = mem_read (σ with updates) a si"
  by (auto simp add: mem_read_def read_bytes_def)

(* A flag update does not change the result of any memory read. *)
lemma mem_read_flag_write:
  shows "mem_read (σ with ((setFlags v)#updates)) a si = mem_read (σ with updates) a si"
  by (auto simp add: mem_read_def read_bytes_def)

(* An instruction-pointer update does not change the result of any memory read. *)
lemma mem_read_rip_write:
  shows "mem_read (σ with ((setRip a')#updates)) a si = mem_read (σ with updates) a si"
  by (auto simp add: mem_read_def read_bytes_def)

(* Reading si' bytes at the very address that was just written with si bytes
   (si' at most si) yields the low si'*8 bits of the written value. *)
lemma mem_read_mem_write_alias:
  assumes "si' ≤ si"
      and "si ≤ 2^64"
  shows "mem_read (σ with ((⟦a,si⟧ :=⇩m v)#updates)) a si' = ⟨0,si'*8⟩ v"
  using assms
  by (auto simp add: mem_read_def word_rcat_read_bytes read_bytes_override_on_enclosed[where offset=0 and offset'=0,simplified])

(* A memory write does not affect a read from a region that is separate from the
   written region (separate is defined in the imported theory). *)
lemma mem_read_mem_write_separate:
  assumes "separate a si a' si'"
shows "mem_read (σ with ((⟦a,si⟧ :=⇩m v)#updates)) a' si' = mem_read (σ with updates) a' si'"
  using assms
  by (auto simp add: mem_read_def read_bytes_override_on_separate)

(* Read enclosed in a preceding write, with both addresses given as a minus an offset:
   the write covers si bytes from a - offset, the read takes si' bytes from a - offset'.
   The result is the bit range of the written value that starts at byte
   unat (offset - offset') and spans si' bytes. *)
lemma mem_read_mem_write_enclosed_minus:
  assumes "offset' ≤ offset"
    and "si' ≤ si"
    and "unat (offset - offset') + si' < 2^64"
    and "unat offset + si' ≤ si + unat offset'"
  shows "mem_read (σ with ((⟦a - offset,si⟧ :=⇩m v)#updates)) (a - offset') si' = ⟨unat (offset - offset') * 8,unat (offset - offset') * 8 + si' * 8⟩v"
  using assms
  by (auto simp add: mem_read_def read_bytes_override_on_enclosed word_rcat_read_bytes_enclosed[of "offset - offset'" si' "a - offset" v,simplified])

(* Read enclosed in a preceding write: si bytes were written at a, and si' bytes are read
   at offset + a with unat offset + si' at most si. The result is the bit range of the
   written value from byte unat offset up to byte unat offset + si'. *)
lemma mem_read_mem_write_enclosed_plus:
assumes "unat offset + si' ≤ si"
    and "si < 2 ^ 64"
  shows "mem_read (σ with ((⟦a,si⟧ :=⇩m v)#updates)) (offset + a) si' = ⟨unat offset * 8,(unat offset + si') * 8⟩v"
  using assms
  apply (auto simp add: mem_read_def read_bytes_override_on_enclosed_plus)
  using word_rcat_read_bytes_enclosed[of offset si' a v]
  by auto (simp add: add.commute)

(* Variant of mem_read_mem_write_enclosed_plus with the read address written as
   a + offset instead of offset + a. *)
lemma mem_read_mem_write_enclosed_plus2:
assumes "unat offset + si' ≤ si"
    and "si < 2 ^ 64"
  shows "mem_read (σ with ((⟦a,si⟧ :=⇩m v)#updates)) (a + offset) si' = ⟨unat offset * 8,(unat offset + si') * 8⟩v"
  using mem_read_mem_write_enclosed_plus[OF assms]
  by (auto simp add: add.commute)

(* Simp rule: enclosed read after write where both addresses and both sizes are numerals.
   The read address a' is rewritten as a + (a' - a) so that
   mem_read_mem_write_enclosed_plus2 applies with offset a' - a. *)
lemma mem_read_mem_write_enclosed_numeral[simp]:
assumes "unat (numeral a' - numeral a::64 word) + (numeral si'::nat) ≤ numeral si"
    and "numeral a' ≥ (numeral a::64 word)"
    and "numeral si < (2 ^ 64::nat)"
  shows "mem_read (σ with ((⟦numeral a,numeral si⟧ :=⇩m v)#updates)) (numeral a') (numeral si') = ⟨unat (numeral a' - (numeral a::64 word)) * 8,(unat (numeral a' - (numeral a::64 word)) + (numeral si')) * 8⟩v"
proof-
  have 1: "numeral a + (numeral a' - numeral a) = (numeral a'::64 word)"
    using assms(2) by (metis add.commute diff_add_cancel)
  thus ?thesis
    using mem_read_mem_write_enclosed_plus2[of "numeral a' - numeral a" "numeral si'" "numeral si" σ "numeral a" v updates,OF assms(1,3)]
    by auto
qed

(* Simp rule: as mem_read_mem_write_enclosed_numeral, for a write of size Suc 0
   (one byte) and a read whose size is a numeral. *)
lemma mem_read_mem_write_enclosed_numeral1_[simp]:
assumes "unat (numeral a' - numeral a::64 word) + (numeral si'::nat) ≤ Suc 0"
    and "numeral a' ≥ (numeral a::64 word)"
  shows "mem_read (σ with ((⟦numeral a,Suc 0⟧ :=⇩m v)#updates)) (numeral a') (numeral si') = ⟨unat (numeral a' - (numeral a::64 word)) * 8,(unat (numeral a' - (numeral a::64 word)) + (numeral si')) * 8⟩v"
proof-
  have 1: "numeral a + (numeral a' - numeral a) = (numeral a'::64 word)"
    using assms(2) by (metis add.commute diff_add_cancel)
  thus ?thesis
    using mem_read_mem_write_enclosed_plus2[of "numeral a' - numeral a" "numeral si'" "Suc 0" σ "numeral a" v updates,OF assms(1)]
    by auto
qed

(* Simp rule: as mem_read_mem_write_enclosed_numeral, for a write whose size is a numeral
   and a read of size Suc 0 (one byte). *)
lemma mem_read_mem_write_enclosed_numeral_1[simp]:
assumes "unat (numeral a' - numeral a::64 word) + (Suc 0) ≤ numeral si"
    and "numeral a' ≥ (numeral a::64 word)"
    and "numeral si < (2 ^ 64::nat)"
  shows "mem_read (σ with ((⟦numeral a,numeral si⟧ :=⇩m v)#updates)) (numeral a') (Suc 0) = ⟨unat (numeral a' - (numeral a::64 word)) * 8,(unat (numeral a' - (numeral a::64 word)) + (Suc 0)) * 8⟩v"
proof-
  have 1: "numeral a + (numeral a' - numeral a) = (numeral a'::64 word)"
    using assms(2) by (metis add.commute diff_add_cancel)
  thus ?thesis
    using mem_read_mem_write_enclosed_plus2[of "numeral a' - numeral a" "Suc 0" "numeral si" σ "numeral a" v updates,OF assms(1,3)]
    by auto
qed


(* Simp rule: as mem_read_mem_write_enclosed_numeral, for a write and a read that both
   have size Suc 0 (one byte). *)
lemma mem_read_mem_write_enclosed_numeral11[simp]:
assumes "unat (numeral a' - numeral a::64 word) + (Suc 0) ≤ Suc 0"
    and "numeral a' ≥ (numeral a::64 word)"
  shows "mem_read (σ with ((⟦numeral a,Suc 0⟧ :=⇩m v)#updates)) (numeral a') (Suc 0) = ⟨unat (numeral a' - (numeral a::64 word)) * 8,(unat (numeral a' - (numeral a::64 word)) + (Suc 0)) * 8⟩v"
proof-
  have 1: "numeral a + (numeral a' - numeral a) = (numeral a'::64 word)"
    using assms(2) by (metis add.commute diff_add_cancel)
  thus ?thesis
    using mem_read_mem_write_enclosed_plus2[of "numeral a' - numeral a" "Suc 0" "Suc 0" σ "numeral a" v updates,OF assms(1)]
    by auto
qed


(* Simp rule: a register update does not change the instruction pointer. *)
lemma rip_reg_write[simp]:
  shows "rip (σ with ((r :=⇩r v)#updates)) = rip (σ with updates)"
  by (auto simp add: case_prod_unfold Let_def)

(* Simp rule: a flag update does not change the instruction pointer. *)
lemma rip_flag_write[simp]:
  shows "rip (σ with ((setFlags v)#updates)) = rip (σ with updates)"
  by (auto)

(* Simp rule: a memory update does not change the instruction pointer. *)
lemma rip_mem_write[simp]:
  shows "rip (σ with ((⟦a,si⟧ :=⇩m v)#updates)) = rip (σ with updates)"
  by (auto)

(* Simp rule: after an instruction-pointer update at the head of the list, rip is the
   written address. *)
lemma rip_rip_write[simp]:
 shows "rip (σ with ((setRip a)#updates)) = a"
  by (auto)


(* Two nested applications of update lists collapse into one: the outer list updates'
   is placed in front of the inner list updates. *)
lemma with_with:
  shows "(σ with updates) with updates' = σ with (updates' @ updates)"
by (induct updates' arbitrary: σ,auto)

(* Applying a single update to a state given by an update list is the same as putting
   that update at the head of the list. *)
lemma add_state_update_to_list:
  shows "state_update upd (σ with updates) = σ with (upd#updates)"
  by auto

text ‹The updates performed to a state are ordered: memory, registers, flags, rip.
      This function is basically insertion sort. Moreover, older updates to the same register
      (and likewise older flag and rip updates) that are overwritten by the inserted update are removed.›

(* Inserts one update into an update list, keeping the list grouped in the order
   memory, registers, flags, rip.
   - setRip moves past every other kind of update and drops older setRip entries.
   - setFlags moves past register and memory updates and drops older setFlags entries.
   - A register update moves past memory updates and past updates to other registers,
     and drops older updates to the same register.
   - In every remaining case (in particular memory updates and the empty list) the
     update is simply put in front. *)
fun insert_state_update
  where
    "insert_state_update (setRip a) (setRip a'#updates) = insert_state_update (setRip a) updates"
  | "insert_state_update (setRip a) (setFlags v#updates) = setFlags v # (insert_state_update (setRip a) updates)"
  | "insert_state_update (setRip a) ((r :=⇩r v)#updates) = (r :=⇩r v) # (insert_state_update (setRip a) updates)"
  | "insert_state_update (setRip a) ((⟦a',si⟧ :=⇩m v)#updates) = (⟦a',si⟧ :=⇩m v) # (insert_state_update (setRip a) updates)"

  | "insert_state_update (setFlags v) (setFlags v'#updates) = insert_state_update (setFlags v) updates"
  | "insert_state_update (setFlags v) ((r :=⇩r v')#updates) = (r :=⇩r v') # insert_state_update (setFlags v) updates"
  | "insert_state_update (setFlags v) ((⟦a',si⟧ :=⇩m v')#updates) = (⟦a',si⟧ :=⇩m v') # insert_state_update (setFlags v) updates"

  | "insert_state_update ((r :=⇩r v)) ((r' :=⇩r v')#updates) = (if r = r' then insert_state_update (r :=⇩r v) updates else (r' :=⇩r v')#insert_state_update (r :=⇩r v) updates)"
  | "insert_state_update ((r :=⇩r v)) ((⟦a',si⟧ :=⇩m v')#updates) = (⟦a',si⟧ :=⇩m v') # insert_state_update (r :=⇩r v) updates"

  | "insert_state_update upd updates = upd # updates"

(* Normalises an update list by re-inserting every element, from the back of the list to
   the front, with insert_state_update. Empty and singleton lists are returned unchanged. *)
fun clean
  where
    "clean [] = []"
  | "clean [upd] = [upd]"
  | "clean (upd#upd'#updates) =  insert_state_update upd (clean (upd'#updates))"

(* Correctness of insert_state_update: the resulting list denotes the same state as
   simply putting the update at the head. Proved along the function's own induction rule;
   fun_upd_twist reorders independent function updates. *)
lemma insert_state_update:
  shows "σ with (insert_state_update upd updates) = σ with (upd # updates)"
  by (induct updates rule: insert_state_update.induct,auto simp add: fun_upd_twist)

(* Correctness of clean: the normalised list denotes the same state as the original list. *)
lemma clean_state_updates:
  shows "σ with (clean updates) = σ with updates"
  by (induct updates rule: clean.induct,auto simp add: insert_state_update)



text ‹The set of simplification rules used during symbolic execution.›
(* Named collection of rewrite rules: the definitions of the pointer-size operands and of
   register access, the read-after-write lemmas for registers, flags and memory, and the
   two lemmas that fold updates into a single update list.
   NOTE(review): ymm_ptr_def and xmm_ptr_def are not part of this collection although the
   other *_ptr definitions are - confirm whether that is intended. *)
lemmas state_simps =
      qword_ptr_def dword_ptr_def word_ptr_def byte_ptr_def reg_size_def
      reg_write_def real_reg_def reg_read_def

      regs_rip_write regs_mem_write regs_reg_write regs_flag_write
      flag_read_reg_write flag_read_mem_write flag_read_rip_write flag_read_flag_write
      mem_read_reg_write mem_read_flag_write mem_read_rip_write
      mem_read_mem_write_alias mem_read_mem_write_separate
      mem_read_mem_write_enclosed_minus mem_read_mem_write_enclosed_plus mem_read_mem_write_enclosed_plus2

      with_with add_state_update_to_list

(* From here on the simplifier no longer unfolds the cons case of "with" or the equations
   of state_update automatically; presumably so that update lists stay in symbolic form
   and are handled by the lemmas above - confirm. *)
declare state_with_updates.simps(2)[simp del]
declare state_update.simps[simp del]

end