(* Theory StateMonad *)

theory StateMonad
imports Main "HOL-Library.Monad_Syntax" Utils Solidity_Symbex
begin

section "State Monad with Exceptions"

(* Outcome of running a computation: either a normal value (selector normal)
   or an exception (selector exception). *)
datatype ('n, 'e) result = Normal (normal: 'n) | Exception (exception: 'e)

(* A state-monad computation maps an initial state of type 's either to a
   normal result (a value of type 'a paired with the successor state) or to an
   exception of type 'e. *)
type_synonym ('a, 'e, 's) state_monad = "'s ⇒ ('a × 's, 'e) result"

(*
  Case-analysis rule for results whose normal payload is a pair:
  case n gives x = Normal (a, s), case e gives x = Exception e.
  It is declared as the cases rule for type result, so the proofs below can
  refer to the cases n and e.
  NOTE(review): it is stated for ('a × 's, 'e) result only; `cases` on a result
  whose Normal payload is not a pair presumably will not match this rule.
*)
lemma result_cases[cases type: result]:
  fixes x :: "('a × 's, 'e) result"
  obtains (n) a s where "x = Normal (a, s)"
        | (e) e where "x = Exception e"
proof (cases x)
    case (Normal n)
    then show ?thesis
      by (metis n prod.swap_def swap_swap)
  next
    case (Exception e)
    then show ?thesis ..
qed

subsection ‹Fundamental Definitions›

(* return a: finish normally with value a and leave the state unchanged. *)
fun return :: "'a ⇒ ('a, 'e, 's) state_monad"
where "return a s = Normal (a, s)"

(* throw e: abort with exception e; no successor state is produced. *)
fun throw :: "'e ⇒ ('a, 'e, 's) state_monad"
where "throw e s = Exception e"

(*
  bind f g: run f from state s. If it finishes normally with value a and
  state s', continue with g a from s'. If f throws, the exception is
  propagated and g is never run.
  The adhoc_overloading below makes this the implementation of
  Monad_Syntax.bind, which enables ⤜ and do-notation for this monad.
*)
fun bind :: "('a, 'e, 's) state_monad ⇒ ('a ⇒ ('b, 'e, 's) state_monad) ⇒ ('b, 'e, 's) state_monad" (infixl ">>=" 60)
where "bind f g s = (case f s of
                      Normal (a, s') ⇒ g a s'
                    | Exception e ⇒ Exception e)"

adhoc_overloading Monad_Syntax.bind bind

(* A throw on the left of a bind short-circuits: the continuation y is never run. *)
lemma throw_left[simp]: "throw x ⤜ y = throw x" by auto

subsection ‹The Monad Laws›

text ‹@{term return} is absorbed at the left of a @{term bind},
  applying the return value directly:›
(* Left identity of the monad. *)
lemma return_bind [simp]: "(return x ⤜ f) = f x"
  by auto

text ‹@{term return} is absorbed on the right of a @{term bind}.›
(* Right identity of the monad. It is proved pointwise (for every state s) by
   case analysis on the result of m s; the object-level ∀ lets the bare `proof`
   step introduce s. *)
lemma bind_return [simp]: "(m ⤜ return) = m"
proof -
  have "∀s. bind m return s = m s"
  proof
    fix s
    show "bind m return s = m s"
    proof (cases "m s" rule: result_cases)
      case (n a s)
      then show ?thesis by auto
    next
      case (e e)
      then show ?thesis by auto
    qed
  qed
  thus ?thesis by auto
qed

text ‹@{term bind} is associative.›
(* Associativity of bind. It is proved pointwise in the initial state s by case
   analysis on m s. *)
lemma bind_assoc:
  fixes m :: "('a,'e,'s) state_monad"
  fixes f :: "'a ⇒ ('b,'e,'s) state_monad"
  fixes g :: "'b ⇒ ('c,'e,'s) state_monad"
  shows "(m ⤜ f) ⤜ g  =  m ⤜ (λx. f x ⤜ g)"
proof
  fix s
  show "bind (bind m f) g s = bind m (λx. bind (f x) g) s"
    by (cases "m s" rule: result_cases, simp+)
qed

subsection ‹Basic Congruence Rules›

(*
  This lemma is needed for termination proofs.

  Function sub1, for example, requires it.
*)
(* Congruence rule for bind applied to a state s. The two binds agree at s if
   the first computations agree at s and the continuations agree on every
   (value, state) pair that m2 actually returns normally from s. *)
lemma monad_cong[fundef_cong]:
  fixes m1 m2 m3 m4
  assumes "m1 s = m2 s"
      and "⋀v s'. m2 s = Normal (v, s') ⟹ m3 v s' = m4 v s'"
    shows "(bind m1 m3) s = (bind m2 m4) s"
  apply(insert assms, cases "m1 s")
  apply (metis StateMonad.bind.simps old.prod.case result.simps(5))
  by (metis bind.elims result.simps(6))

(*
  Lemma bind_case_nat_cong is required if a bind operand is a case analysis over nat.

  Function sub2, for example, requires it.
*)
(* Congruence rule for a case distinction over nat applied to an argument h.
   The Suc branches only need to agree at h for the predecessor a of x; the 0
   branch g is the same on both sides. *)
lemma bind_case_nat_cong [fundef_cong]:
  assumes "x = x'" and "⋀a. x = Suc a ⟹ f a h = f' a h"
  shows "(case x of Suc a ⇒ f a | 0 ⇒ g) h = (case x' of Suc a ⇒ f' a | 0 ⇒ g) h"
  by (metis assms(1) assms(2) old.nat.exhaust old.nat.simps(4) old.nat.simps(5))

(* Congruence rule for a conditional applied to a state s: each branch only
   needs to agree at s under its own branch condition.
   NOTE(review): this fact has the same name as HOL's if_cong, so afterwards the
   unqualified name refers to this lemma. Consider a distinct name (e.g.
   bind_if_cong) if that shadowing is unintended. *)
lemma if_cong[fundef_cong]:
  assumes "b = b'"
    and "b' ⟹ m1 s = m1' s"
    and "¬ b' ⟹ m2 s = m2' s"
  shows "(if b then m1 else m2) s = (if b' then m1' else m2') s"
  using assms(1) assms(2) assms(3) by auto

(* Congruence rule for a case distinction over a pair applied to a state s:
   the bodies only need to agree at s for the components of x. *)
lemma bind_case_pair_cong [fundef_cong]:
  assumes "x = x'" and "⋀a b. x = (a,b) ⟹ f a b s = f' a b s"
  shows "(case x of (a,b) ⇒ f a b) s = (case x' of (a,b) ⇒ f' a b) s"
  by (simp add: assms(1) assms(2) prod.case_eq_if)

(* Congruence rule for a Let applied to a state s: the bodies only need to agree
   at s for the bound value N. *)
lemma bind_case_let_cong [fundef_cong]:
  assumes "M = N"
      and "(⋀x. x = N ⟹ f x s = g x s)"
    shows "(Let M f) s = (Let N g) s"
  by (simp add: assms(1) assms(2))

(* Congruence rule for a case distinction over an option applied to a state s:
   each branch only needs to agree at s in its own case. *)
lemma bind_case_some_cong [fundef_cong]:
  assumes "x = x'" and "⋀a. x = Some a ⟹ f a s = f' a s" and "x = None ⟹ g s = g' s"
  shows "(case x of Some a ⇒ f a | None ⇒ g) s = (case x' of Some a ⇒ f' a | None ⇒ g') s"
  by (simp add: assms(1) assms(2) assms(3) option.case_eq_if)

(* Congruence rule for a case distinction over bool applied to a state s:
   each branch only needs to agree at s in its own case. *)
lemma bind_case_bool_cong [fundef_cong]:
  assumes "x = x'" and "x = True ⟹ f s = f' s" and "x = False ⟹ g s = g' s"
  shows "(case x of True ⇒ f | False ⇒ g) s = (case x' of True ⇒ f' | False ⇒ g') s"
  using assms(1) assms(2) assms(3) by auto

subsection ‹Other functions›

text ‹
  The basic accessor functions of the state monad. ‹get› returns
  the current state as result, does not fail, and does not change the state.
  ‹put s› returns unit, changes the current state to ‹s› and does not fail.
›
(* get: return the current state as the value and keep the state unchanged. *)
fun get :: "('s, 'e, 's) state_monad" where
  "get s = Normal (s, s)"

(* put s: replace the current state by s and return unit; never throws. *)
fun put :: "'s ⇒ (unit, 'e, 's) state_monad" where
  "put s _ = Normal ((), s)"

text ‹Apply a function to the current state and return the result
without changing the state.›
(* Defined via get and return, so it never throws and leaves the state as is. *)
fun
  applyf :: "('s ⇒ 'a) ⇒ ('a, 'e, 's) state_monad" where
 "applyf f = get ⤜ (λs. return (f s))"

text ‹Modify the current state using the function passed in.›
(* Read the state with get and write back f applied to it with put; returns unit. *)
fun
  modify :: "('s ⇒ 's) ⇒ (unit, 'e, 's) state_monad"
where "modify f = get ⤜ (λs::'s. put (f s))"

(* assert x t: if the predicate t holds on the current state, return unit and
   keep the state; otherwise throw x. *)
fun
  assert :: "'e ⇒ ('s ⇒ bool) ⇒ (unit, 'e, 's) state_monad" where
 "assert x t = (λs. if (t s) then return () s else throw x s)"

(* option x f: apply the partial function f to the current state. If it yields
   Some y, return y and keep the state; if it yields None, throw x. *)
fun
  option :: "'e ⇒ ('s ⇒ 'a option) ⇒ ('a, 'e, 's) state_monad" where
 "option x f = (λs. (case f s of
    Some y ⇒ return y s
  | None ⇒ throw x s))"

subsection ‹Some basic examples›

(* do-notation translates to nested binds: "x ← m; rest" becomes m ⤜ (λx. rest),
   and "m; rest" becomes m ⤜ (λ_. rest). The two sides are syntactically equal. *)
lemma "do {
        x ← return 1;
        return (2::nat);
        return x
       } =
       return 1 ⤜ (λx. return (2::nat) ⤜ (λ_. (return x)))" ..

(* With the monad laws as simp rules, the intermediate return is discarded. *)
lemma "do {
        x ← return 1;
          return 2;
          return x
       } = return 1"
  by auto

(* Example recursive definition by pattern matching on the state. The recursive
   call sits inside a bind continuation, which is why the function package
   needs the monad_cong rule to prove termination. *)
fun sub1 :: "(unit, nat, nat) state_monad" where
    "sub1 0 = put 0 0"
  | "sub1 (Suc n) = (do {
                     x ← get;
                     put x;
                     sub1
                    }) n"

(* Example recursive definition whose recursion is guided by a case analysis
   over the state read with get; termination uses bind_case_nat_cong. *)
fun sub2 :: "(unit, nat, nat) state_monad" where
  "sub2 s =
     (do {
        n ← get;
        (case n of
          0 ⇒ put 0
        | Suc n' ⇒ (do {
                  put n';
                  sub2
                }))
     }) s"

section "Hoare Logic"

(* Collection of weakest-precondition rules applied by the wp proof method. *)
named_theorems wprule

(*
  Hoare triple ⦃P⦄ f ⦃Q⦄,E: for every initial state s satisfying P, running f
  from s either finishes normally with value r and state s' such that Q r s'
  holds, or throws an exception e such that E e holds.
*)
definition
  valid :: "('s ⇒ bool) ⇒ ('a,'e,'s) state_monad ⇒
             ('a ⇒ 's ⇒ bool) ⇒
             ('e ⇒ bool) ⇒ bool"
  ("⦃_⦄/ _ /(⦃_⦄,/ _)")
where
  "⦃P⦄ f ⦃Q⦄,E ≡ ∀s. P s ⟶ (case f s of
                   Normal (r,s') ⇒ Q r s'
                 | Exception e ⇒ E e)"

(* Consequence rule on the precondition: if P implies Q pointwise, a triple
   with precondition Q also holds with precondition P. *)
lemma weaken:
  assumes "⦃Q⦄ f ⦃R⦄, E"
      and "⋀s. P s ⟹ Q s"
    shows "⦃P⦄ f ⦃R⦄, E"
  using assms by (simp add: valid_def)

(* Consequence rule on the normal postcondition: if Q implies R pointwise, the
   postcondition Q may be replaced by R. The exception condition E is unchanged. *)
lemma strengthen:
  assumes "⦃P⦄ f ⦃Q⦄, E"
      and "⋀a s. Q a s ⟹ R a s"
    shows "⦃P⦄ f ⦃R⦄, E"
  unfolding valid_def
proof(rule allI[OF impI])
  fix s
  assume "P s"
  show "case f s of Normal (r, s') ⇒ R r s' | Exception e ⇒ E e"
  proof (cases "f s")
    case (n a s')
    then show ?thesis using assms valid_def ‹P s›
      by (metis case_prod_conv result.simps(5))
  next
    case (e e)
    then show ?thesis using assms valid_def ‹P s›
      by (metis result.simps(6))
  qed
qed

(* wp f P E s: running f from the single initial state s either finishes
   normally with value r and state s' satisfying P r s', or throws e satisfying
   E e. The definition is also tagged solidity_symbex (an attribute declared
   outside this theory, presumably in Solidity_Symbex). *)
definition wp
  where "wp f P E s ≡ (case f s of
                   Normal (r,s') ⇒ P r s'
                 | Exception e ⇒ E e)"

declare wp_def [solidity_symbex]

(* A triple holds exactly when wp holds in every state satisfying the
   precondition (wp_valid: ⇐, valid_wp: ⇒). wp_valid states its premise with
   ⋀/⟹ so that, after `rule wp_valid`, the wprule rules can be applied
   directly to the subgoal (see method wpvcg). *)
lemma wp_valid: assumes "⋀s. P s ⟹ (wp f Q E s)" shows "⦃P⦄ f ⦃Q⦄,E"
  unfolding valid_def by (metis assms wp_def)

lemma valid_wp: assumes "⦃P⦄ f ⦃Q⦄,E" shows "⋀s. P s ⟹ (wp f Q E s)"
  by (metis assms valid_def wp_def)

(* Hoare and wp rules for put x. After put x the result is () and the state is x. *)
lemma put: "⦃λs. P () x⦄ put x ⦃P⦄,E"
  using valid_def by fastforce

lemma put':
  assumes "⋀s. P s ⟹ Q () x"
  shows "⦃λs. P s⦄ put x ⦃Q⦄,E"
  using assms weaken[OF put, of P Q x E] by blast

lemma wpput[wprule]:
  assumes "P () x"
  shows "wp (put x) P E s"
  unfolding wp_def using assms by simp

(* Hoare and wp rules for get. The result equals the (unchanged) state. *)
lemma get: "⦃λs. P s s⦄ get ⦃P⦄,E"
  using valid_def by fastforce

lemma get':
  assumes "⋀s. P s ⟹ Q s s"
  shows "⦃λs. P s⦄ get ⦃Q⦄,E"
  using assms weaken[OF get] by blast

lemma wpget[wprule]:
  assumes "P s s"
  shows "wp get P E s"
  unfolding wp_def using assms by simp

(* Hoare and wp rules for return x. The result is x and the state is unchanged. *)
lemma return: "⦃λs. P x s⦄ return x ⦃P⦄,E"
  using valid_def by fastforce

lemma return':
  assumes "⋀s. P s ⟹ Q x s"
  shows "⦃λs. P s⦄ return x ⦃Q⦄,E"
  using assms weaken[OF return, of P Q x E] by blast

lemma wpreturn[wprule]:
  assumes "P x s"
  shows "wp (return x) P E s"
  unfolding wp_def using assms by simp

(* Sequential composition rule. If f establishes B, and for every value x the
   continuation g x turns B x into C, then f ⤜ g establishes C; exceptions from
   either part must satisfy E. The first premise uses an object-level ∀ (not ⋀):
   the proof of lemma modify relies on this via `rule allI`. *)
lemma bind:
  assumes "∀x. ⦃B x⦄ g x ⦃C⦄,E"
      and "⦃A⦄ f ⦃B⦄,E"
    shows "⦃A⦄ f ⤜ g ⦃C⦄,E"
  unfolding valid_def
proof (rule allI[OF impI])
  fix s assume a: "A s"
  show "case (f ⤜ g) s of Normal (r, s') ⇒ C r s' | Exception e ⇒ E e"
  proof (cases "f s" rule:result_cases)
    case nf: (n a s')
    with assms(2) a have b: "B a s'" using valid_def[where ?f=f] by auto
    then show ?thesis
    proof (cases "g a s'" rule:result_cases)
      case ng: (n a' s'')
      with assms(1) b have c: "C a' s''" using valid_def[where ?f="g a"] by fastforce
      moreover from ng nf have "(f ⤜ g) s = Normal (a', s'')" by simp
      ultimately show ?thesis by simp
    next
      case eg: (e e)
      with assms(1) b have c: "E e" using valid_def[where ?f="g a"] by fastforce
      moreover from eg nf have "(f ⤜ g) s = Exception e" by simp
      ultimately show ?thesis by simp
    qed
  next
    case (e e)
    with a assms(2) have "E e" using valid_def[where ?f=f] by auto
    moreover from e have "(f ⤜ g) s = Exception e" by simp
    ultimately show ?thesis by simp
  qed
qed

(* wp rule for bind: it suffices that f's wp holds with, as postcondition, the
   wp of the continuation g a for f's result a. *)
lemma wpbind[wprule]:
  assumes "wp f (λa. (wp (g a) P E)) E s"
  shows "wp (f ⤜ g) P E s"
proof (cases "f s" rule:result_cases)
  case nf: (n a s')
  then have **:"wp (g a) P E s'" using wp_def[of f "λa. wp (g a) P E"] assms by simp
  show ?thesis
  proof (cases "g a s'" rule:result_cases)
    case ng: (n a' s'')
    then have "P a' s''" using wp_def[of "g a" P] ** by simp
    moreover from nf ng have "(f ⤜ g) s = Normal (a', s'')" by simp
    ultimately show ?thesis using wp_def by fastforce
  next
    case (e e)
    then have "E e" using wp_def[of "g a" P] ** by simp
    moreover from nf e have "(f ⤜ g) s = Exception e" by simp
    ultimately show ?thesis using wp_def by fastforce
  qed
next
  case (e e)
  then have "E e" using wp_def[of f "λa. wp (g a) P E"] assms by simp
  moreover from e have "(f ⤜ g) s = Exception e" by simp
  ultimately show ?thesis using wp_def by fastforce
qed

(* wp rule for assert: split on whether the predicate t holds in state s. *)
lemma wpassert[wprule]:
  assumes "t s ⟹ wp (return ()) P E s"
      and "¬ t s ⟹ wp (throw x) P E s"
    shows "wp (assert x t) P E s"
  using assms unfolding wp_def by simp

(* Hoare and wp rules for throw x. Only the exception condition E x matters. *)
lemma throw:
  assumes "E x"
  shows "⦃P⦄ throw x ⦃Q⦄, E"
  using valid_def assms by fastforce

lemma wpthrow[wprule]:
  assumes "E x"
  shows "wp (throw x) P E s"
  unfolding wp_def using assms by simp

(* Hoare and wp rules for applyf f. The result is f s and the state s is unchanged. *)
lemma applyf:
    "⦃λs. P (f s) s⦄ applyf f ⦃λa s. P a s⦄,E"
  by (simp add: valid_def)

lemma applyf':
  assumes "⋀s. P s ⟹ Q (f s) s"
  shows "⦃λs. P s⦄ applyf f ⦃λa s. Q a s⦄,E"
  using assms weaken[OF applyf] by blast

lemma wpapplyf[wprule]:
  assumes "P (f s) s"
  shows "wp (applyf f) P E s"
  unfolding wp_def using assms by simp

(* Hoare and wp rules for modify f. The result is () and the new state is f s.
   The Hoare rule is derived from the bind, put and get rules. *)
lemma modify:
  "⦃λs. P () (f s)⦄ modify f ⦃P⦄, E"
  apply simp
  apply (rule bind,rule allI)
  apply (rule put)
  apply (rule get)
  done

lemma modify':
  assumes "⋀s. P s ⟹ Q () (f s)"
  shows "⦃λs. P s⦄ modify f ⦃Q⦄, E"
  using assms weaken[OF modify, of P Q _ E] by blast

lemma wpmodify[wprule]:
  assumes "P () (f s)"
  shows "wp (modify f) P E s"
  unfolding wp_def using assms by simp

(* wp rule for a case distinction over a nat y: one obligation for y = 0 and
   one for each predecessor x with y = Suc x. *)
lemma wpcasenat[wprule]:
  assumes "(y=(0::nat) ⟹ wp (f y) P E s)"
      and "⋀x. y=Suc x ⟹ wp (g x) P E s"
  shows "wp (case y::nat of 0 ⇒ f y | Suc x ⇒ g x) P E s"
  by (metis assms(1) assms(2) not0_implies_Suc old.nat.simps(4) old.nat.simps(5))

(* wp rule for a conditional: one obligation per branch under its condition. *)
lemma wpif[wprule]:
  assumes "c ⟹ wp f P E s"
      and "¬c ⟹ wp g P E s"
  shows "wp (if c then f else g) P E s"
  using assms by simp

(* wp rule for a case distinction over an option value x. *)
lemma wpsome[wprule]:
  assumes "⋀y. x = Some y ⟹ wp (f y) P E s"
      and "x = None ⟹ wp g P E s"
  shows "wp (case x of Some y ⇒ f y | None ⇒ g) P E s"
  using assms unfolding wp_def by (simp split: option.split)

(* wp rule for option x f: split on the value of f at the current state s. *)
lemma wpoption[wprule]:
  assumes "⋀y. f s = Some y ⟹ wp (return y) P E s"
      and "f s = None ⟹ wp (throw x) P E s"
  shows "wp (option x f) P E s"
  using assms unfolding wp_def by (auto split:option.split)

(* wp rule for a case distinction over a pair a: the body must satisfy wp for
   the components of a. *)
lemma wpprod[wprule]:
  assumes "⋀x y. a = (x,y) ⟹ wp (f x y) P E s"
  shows "wp (case a of (x, y) ⇒ f x y) P E s"
  using assms unfolding wp_def
  by (simp split: prod.split)

(* wp: apply a rule from the wprule collection, then recursively try wp again on
   the resulting subgoals (the trailing ? makes the recursive call optional). *)
method wp = rule wprule; wp?
(* wpvcg: reduce a Hoare triple to a wp goal via wp_valid, then run wp. *)
method wpvcg = rule wp_valid, wp

(* Example: after put 5 the state is 5, and get/return do not change it. The
   goal is discharged by wpvcg followed by simp. *)
lemma "⦃λs. s=5⦄ do {
        put (5::nat);
        x ← get;
        return x
       } ⦃λa s. s=5⦄,λe. False"
  by (wpvcg, simp)
end