```theory StateMonad
begin

(* Outcome of a computation step: either a normal value or an exception value.
   The datatype declaration also generates the selectors normal and exception. *)
datatype ('n, 'e) result = Normal (normal: 'n) | Exception (exception: 'e)

(* A state-monad computation maps an input state either to a pair
   (return value, successor state) or to an exception of type 'e.
   Note that an exception carries no state. *)
type_synonym ('a, 'e, 's) state_monad = "'s ⇒ ('a × 's, 'e) result"

(* Case-analysis rule for results whose normal payload is a pair: a Normal
   result is split directly into its value a and state s. It is registered
   as the default cases rule for type result, with case names n (normal)
   and e (exception). *)
lemma result_cases[cases type: result]:
fixes x :: "('a × 's, 'e) result"
obtains (n) a s where "x = Normal (a, s)"
| (e) e where "x = Exception e"
proof (cases x)
case (Normal n)
then show ?thesis
(* the metis argument n names the obtains-case fact for Normal,
   not the variable n bound by case (Normal n) *)
by (metis n prod.swap_def swap_swap)
next
case (Exception e)
then show ?thesis ..
qed

subsection ‹Fundamental Definitions›

(* return a: succeed with value a and leave the state unchanged. *)
fun return :: "'a ⇒ ('a, 'e, 's) state_monad"
where "return a s = Normal (a, s)"

(* throw e: fail with exception e whatever the input state is;
   the state is not part of the Exception result. *)
fun throw :: "'e ⇒ ('a, 'e, 's) state_monad"
where "throw e s = Exception e"

(* Sequential composition: run f on state s; on Normal (a, s') continue with
   g a from state s', on Exception e propagate e unchanged (g is not run).
   NOTE(review): the syntax declared here is ">>=", while the rest of this
   theory writes ⤜ and uses do-notation. That presumably relies on an
   imported HOL-Library.Monad_Syntax plus an adhoc_overloading of
   Monad_Syntax.bind to this bind. Neither is visible in this file; confirm. *)
fun bind :: "('a, 'e, 's) state_monad ⇒ ('a ⇒ ('b, 'e, 's) state_monad) ⇒ ('b, 'e, 's) state_monad" (infixl ">>=" 60)
where "bind f g s = (case f s of
Normal (a, s') ⇒ g a s'
| Exception e ⇒ Exception e)"

(* An exception on the left short-circuits bind: the continuation y is never run. *)
lemma throw_left[simp]: "throw x ⤜ y = throw x" by auto

text ‹@{term return} is absorbed at the left of a @{term bind},
applying the return value directly:›
(* Left-identity monad law. auto suffices because fun registers the
   defining equations of return and bind as simp rules. *)
lemma return_bind [simp]: "(return x ⤜ f) = f x"
by auto

text ‹@{term return} is absorbed on the right of a @{term bind}›
(* Right-identity monad law. The proof first shows the two computations
   agree on every state, by case analysis on whether m succeeds or throws. *)
lemma bind_return [simp]: "(m ⤜ return) = m"
proof -
have "∀s. bind m return s = m s"
proof
fix s
show "bind m return s = m s"
proof (cases "m s" rule: result_cases)
case (n a s)
then show ?thesis by auto
next
case (e e)
then show ?thesis by auto
qed
qed
(* lift the pointwise equality to equality of the two functions *)
thus ?thesis by auto
qed

text ‹@{term bind} is associative›
(* Associativity monad law. The initial proof step reduces the equality of
   functions to a pointwise goal for an arbitrary state s (see fix s / show).
   It is then closed by cases on the result of m s followed by simp. *)
lemma bind_assoc:
fixes f :: "'a ⇒ ('b,'e,'s) state_monad"
fixes g :: "'b ⇒ ('c,'e,'s) state_monad"
shows "(m ⤜ f) ⤜ g  =  m ⤜ (λx. f x⤜ g)"
proof
fix s
show "bind (bind m f) g s = bind m (λx. bind (f x) g) s"
by (cases "m s" rule: result_cases, simp+)
qed

subsection ‹Basic Congruence Rules›

(*
This lemma is needed for termination proofs.

Function sub1, for example, requires it.

As a [fundef_cong] rule it tells the function package that the continuation
of a bind is only evaluated on a value/state pair produced by a Normal result
of the first computation, so recursive calls inside the continuation are
analysed under that assumption.
*)
lemma bind_cong [fundef_cong]:
fixes m1 m2 m3 m4
assumes "m1 s = m2 s"
and "⋀v s'. m2 s = Normal (v, s') ⟹ m3 v s' = m4 v s'"
shows "(bind m1 m3) s = (bind m2 m4) s"
proof (cases "m1 s" rule: result_cases)
case (n a s')
(* both first computations succeed with (a, s'); the continuations agree there *)
with assms show ?thesis by auto
next
case (e e)
(* both first computations throw e, so both binds propagate it *)
with assms show ?thesis by auto
qed

(*
Lemma bind_case_nat_cong is required if a bind operand is a case analysis over nat.

Function sub2, for example, requires it.

As a [fundef_cong] rule it records that the Suc branch is evaluated under the
assumption x = Suc a. The 0 branch g is identical on both sides, so no extra
context is provided for it.
*)
lemma bind_case_nat_cong [fundef_cong]:
assumes "x = x'" and "⋀a. x = Suc a ⟹ f a h = f' a h"
shows "(case x of Suc a ⇒ f a | 0 ⇒ g) h = (case x' of Suc a ⇒ f' a | 0 ⇒ g) h"
by (metis assms(1) assms(2) old.nat.exhaust old.nat.simps(4) old.nat.simps(5))

(* Congruence rule for an if whose branches are computations applied to the
   state s: each branch only has to agree at s under the corresponding truth
   value of the (equal) conditions.
   NOTE(review): this shares its base name with HOL's if_cong; unqualified
   references to if_cong after this point likely resolve to this lemma.
   Confirm this is intended. *)
lemma if_cong[fundef_cong]:
assumes "b = b'"
and "b' ⟹ m1 s = m1' s"
and "¬ b' ⟹ m2 s = m2' s"
shows "(if b then m1 else m2) s = (if b' then m1' else m2') s"
using assms(1) assms(2) assms(3) by auto

(* Congruence rule for a case split on a pair used as a computation at state s:
   the branch only has to agree on the components (a, b) of x. *)
lemma bind_case_pair_cong [fundef_cong]:
assumes "x = x'" and "⋀a b. x = (a,b) ⟹ f a b s = f' a b s"
shows "(case x of (a,b) ⇒ f a b) s = (case x' of (a,b) ⇒ f' a b) s"
by (simp add: assms(1) assms(2) prod.case_eq_if)

(* Congruence rule for a let used as a computation at state s: the body only
   has to agree at the bound value N. As a [fundef_cong] rule this lets the
   function package analyse recursive calls in the body under x = N. *)
lemma bind_case_let_cong [fundef_cong]:
assumes "M = N"
and "(⋀x. x = N ⟹ f x s = g x s)"
shows "(Let M f) s = (Let N g) s"
(* Let M f unfolds to f M; rewrite M to N, then use the second assumption at x = N *)
by (simp add: assms Let_def)

(* Congruence rule for a case split on an option used as a computation at
   state s: the Some branch is compared under x = Some a, the None branch
   under x = None. *)
lemma bind_case_some_cong [fundef_cong]:
assumes "x = x'" and "⋀a. x = Some a ⟹ f a s = f' a s" and "x = None ⟹ g s = g' s"
shows "(case x of Some a ⇒ f a | None ⇒ g) s = (case x' of Some a ⇒ f' a | None ⇒ g') s"
by (simp add: assms(1) assms(2) assms(3) option.case_eq_if)

(* Congruence rule for a case split on a bool used as a computation at state s:
   each branch is compared only under the matching value of x. *)
lemma bind_case_bool_cong [fundef_cong]:
assumes "x = x'" and "x = True ⟹ f s = f' s" and "x = False ⟹ g s = g' s"
shows "(case x of True ⇒ f | False ⇒ g) s = (case x' of True ⇒ f' | False ⇒ g') s"
using assms(1) assms(2) assms(3) by auto

subsection ‹Other functions›

text ‹
The basic accessor functions of the state monad. ‹get› returns
the current state as result, does not fail, and does not change the state.
‹put s› returns unit, changes the current state to ‹s› and does not fail.
›
(* get: return the current state as the value; the state is unchanged. *)
fun get :: "('s, 'e, 's) state_monad" where
"get s = Normal (s, s)"

(* put s: replace the state by s and return unit; the old state is ignored. *)
fun put :: "'s ⇒ (unit, 'e, 's) state_monad" where
"put s _ = Normal ((), s)"

text ‹Apply a function to the current state and return the result
without changing the state.›
(* Defined via get and return, so the state is never modified. *)
fun
applyf :: "('s ⇒ 'a) ⇒ ('a, 'e, 's) state_monad" where
"applyf f = get ⤜ (λs. return (f s))"

text ‹Modify the current state using the function passed in.›
(* Read the state with get, then put back f applied to it; returns unit. *)
fun
modify :: "('s ⇒ 's) ⇒ (unit, 'e, 's) state_monad"
where "modify f = get ⤜ (λs::'s. put (f s))"

(* assert x t: if predicate t holds for the current state, return unit and keep
   the state; otherwise throw x. *)
fun
assert :: "'e ⇒ ('s ⇒ bool) ⇒ (unit, 'e, 's) state_monad" where
"assert x t = (λs. if (t s) then return () s else throw x s)"

(* option x f: if f applied to the current state is Some y, return y and keep
   the state; if it is None, throw x. *)
fun
option :: "'e ⇒ ('s ⇒ 'a option) ⇒ ('a, 'e, 's) state_monad" where
"option x f = (λs. (case f s of
Some y ⇒ return y s
| None ⇒ throw x s))"

subsection ‹Some basic examples›

(* The do-block is the same computation as the explicit nested binds on the
   right; after parsing both sides are expected to be the same term, which is
   why the default rule step (..) closes the goal. *)
lemma "do {
x ← return 1;
return (2::nat);
return x
} =
return 1 ⤜ (λx. return (2::nat) ⤜ (λ_. (return x)))" ..

(* The [simp] rule return_bind collapses the whole block to return 1. *)
lemma "do {
x ← return 1;
return 2;
return x
} = return 1"
by auto

(* Example recursive computation defined by pattern matching on the state.
   On state 0 it puts 0. On state Suc n it runs the do-block from state n:
   the value read by get is n, it is put back, and sub1 recurses from state n.
   The termination proof needs the bind_cong congruence rule. *)
fun sub1 :: "(unit, nat, nat) state_monad" where
"sub1 0 = put 0 0"
| "sub1 (Suc n) = (do {
x ← get;
put x;
sub1
}) n"

(* Example recursive computation that inspects the state through get and a case
   split: on 0 it puts 0, on Suc n' it puts n' and recurses. The termination
   proof needs the bind_case_nat_cong congruence rule. *)
fun sub2 :: "(unit, nat, nat) state_monad" where
"sub2 s =
(do {
n ← get;
(case n of
0 ⇒ put 0
| Suc n' ⇒ (do {
put n';
sub2
}))
}) s"

section "Hoare Logic"

(* Collection of weakest-precondition rules applied by the wp proof method. *)
named_theorems wprule

(* Hoare triple ⦃P⦄ f ⦃Q⦄,⦃E⦄. For every state s satisfying P, running f from s
   either yields Normal (r, s') with Q r s', or Exception e with E e.
   The exception postcondition E does not see any state. *)
definition
valid :: "('s ⇒ bool) ⇒ ('a,'e,'s) state_monad ⇒
('a ⇒ 's ⇒ bool) ⇒
('e ⇒ bool) ⇒ bool"
("⦃_⦄/ _ /(⦃_⦄,/ ⦃_⦄)")
where
"⦃P⦄ f ⦃Q⦄,⦃E⦄ ≡ ∀s. P s ⟶ (case f s of
Normal (r,s') ⇒ Q r s'
| Exception e ⇒ E e)"

(* Precondition weakening: a triple proved for precondition Q also holds for any
   stronger precondition P (one that implies Q pointwise). *)
lemma weaken:
  assumes "⦃Q⦄ f ⦃R⦄, ⦃E⦄"
    and "∀s. P s ⟶ Q s"
  shows "⦃P⦄ f ⦃R⦄, ⦃E⦄"
  using assms unfolding valid_def by blast

(* Postcondition consequence: a triple with normal postcondition Q also holds
   for any R implied by Q (for every value and state). *)
lemma strengthen:
assumes "⦃P⦄ f ⦃Q⦄, ⦃E⦄"
and "∀a s. Q a s ⟶ R a s"
shows "⦃P⦄ f ⦃R⦄, ⦃E⦄"
unfolding valid_def
(* allI[OF impI] turns the goal ∀s. P s ⟶ … into: fix s, assume P s, show … *)
proof(rule allI[OF impI])
fix s
assume "P s"
show "case f s of Normal (r, s') ⇒ R r s' | Exception e ⇒ E e"
proof (cases "f s")
case (n a s')
then show ?thesis using assms valid_def ‹P s›
by (metis case_prod_conv result.simps(5))
next
case (e e)
then show ?thesis using assms valid_def ‹P s›
by (metis result.simps(6))
qed
qed

(* Weakest precondition at a given state: running f from s either yields
   Normal (r, s') with P r s', or Exception e with E e.
   A triple ⦃Q⦄ f ⦃P⦄,⦃E⦄ holds iff Q s implies wp f P E s for all s
   (lemmas wp_valid and valid_wp). *)
definition wp
where "wp f P E s ≡ (case f s of
Normal (r,s') ⇒ P r s'
| Exception e ⇒ E e)"

(* NOTE(review): the attribute solidity_symbex is not declared in this file;
   presumably it comes from an imported theory. Confirm. *)
declare wp_def [solidity_symbex]

(* A Hoare triple follows from its weakest precondition holding at every
   P-state. This is the entry point of the wpvcg method. *)
lemma wp_valid: assumes "⋀s. P s ⟹ (wp f Q E s)" shows "⦃P⦄ f ⦃Q⦄,⦃E⦄"
unfolding valid_def by (metis assms wp_def)

(* Converse of wp_valid: a valid triple yields the weakest precondition at every P-state. *)
lemma valid_wp: assumes "⦃P⦄ f ⦃Q⦄,⦃E⦄" shows "⋀s. P s ⟹ (wp f Q E s)"
by (metis assms valid_def wp_def)

(* Hoare rule for put x: the postcondition must hold for unit and the new state x,
   independently of the old state. *)
lemma put: "⦃λs. P () x⦄ put x ⦃P⦄,⦃E⦄"
  unfolding valid_def by simp

(* Consequence-style variant of the put rule with an arbitrary precondition P. *)
lemma put':
assumes "∀s. P s ⟶ Q () x"
shows "⦃λs. P s⦄ put x ⦃Q⦄,⦃E⦄"
using assms weaken[OF put, of P Q x E] by blast

(* WP rule for put x: it suffices that the postcondition holds for unit and state x. *)
lemma wpput[wprule]:
  assumes "P () x"
  shows "wp (put x) P E s"
  using assms by (simp add: wp_def)

(* Hoare rule for get: the current state is both the returned value and the
   unchanged successor state. *)
lemma get: "⦃λs. P s s⦄ get ⦃P⦄,⦃E⦄"
  unfolding valid_def by simp

(* Consequence-style variant of the get rule with an arbitrary precondition P. *)
lemma get':
assumes "∀s. P s ⟶ Q s s"
shows "⦃λs. P s⦄ get ⦃Q⦄,⦃E⦄"
using assms weaken[OF get] by blast

(* WP rule for get: the postcondition is applied to the current state twice
   (as the value and as the unchanged state). *)
lemma wpget[wprule]:
  assumes "P s s"
  shows "wp get P E s"
  using assms by (simp add: wp_def)

(* Hoare rule for return x: the postcondition must hold for x and the unchanged state. *)
lemma return: "⦃λs. P x s⦄ return x ⦃P⦄,⦃E⦄"
  unfolding valid_def by simp

(* Consequence-style variant of the return rule with an arbitrary precondition P. *)
lemma return':
assumes "∀s. P s ⟶ Q x s"
shows "⦃λs. P s⦄ return x ⦃Q⦄,⦃E⦄"
using assms weaken[OF return, of P Q x E] by blast

(* WP rule for return x: the postcondition must hold for x in the current state. *)
lemma wpreturn[wprule]:
  assumes "P x s"
  shows "wp (return x) P E s"
  using assms by (simp add: wp_def)

(* Sequential composition rule. If f takes A-states to results satisfying B,
   and for each value x the continuation g x takes B x-states to C, then
   f ⤜ g takes A-states to C. Both parts share the exception postcondition E.
   The proof is a nested case analysis on f s and then g a s'. *)
lemma bind:
assumes "∀x. ⦃B x⦄ g x ⦃C⦄,⦃E⦄"
and "⦃A⦄ f ⦃B⦄,⦃E⦄"
shows "⦃A⦄ f ⤜ g ⦃C⦄,⦃E⦄"
unfolding valid_def
proof (rule allI[OF impI])
fix s assume a: "A s"
show "case (f ⤜ g) s of Normal (r, s') ⇒ C r s' | Exception e ⇒ E e"
proof (cases "f s" rule:result_cases)
case nf: (n a s')
with assms(2) a have b: "B a s'" using valid_def[where ?f=f] by auto
then show ?thesis
proof (cases "g a s'" rule:result_cases)
case ng: (n a' s'')
with assms(1) b have c: "C a' s''" using valid_def[where ?f="g a"] by fastforce
moreover from ng nf have "(f ⤜ g) s = Normal (a', s'')" by simp
ultimately show ?thesis by simp
next
case eg: (e e)
with assms(1) b have c: "E e" using valid_def[where ?f="g a"] by fastforce
moreover from eg nf have "(f ⤜ g) s = Exception e" by simp
ultimately show ?thesis by simp
qed
next
(* f itself throws: the exception postcondition comes from assms(2) *)
case (e e)
with a assms(2) have "E e" using valid_def[where ?f=f] by auto
moreover from e have "(f ⤜ g) s = Exception e" by simp
ultimately show ?thesis by simp
qed
qed

(* WP rule for bind. The wp of f ⤜ g is the wp of f whose normal postcondition
   is the wp of the continuation g a; the exception postcondition E is shared. *)
lemma wpbind[wprule]:
assumes "wp f (λa. (wp (g a) P E)) E s"
shows "wp (f ⤜ g) P E s"
proof (cases "f s" rule:result_cases)
case nf: (n a s')
then have **:"wp (g a) P E s'" using wp_def[of f "λa. wp (g a) P E"] assms by simp
show ?thesis
proof (cases "g a s'" rule:result_cases)
case ng: (n a' s'')
then have "P a' s''" using wp_def[of "g a" P] ** by simp
moreover from nf ng have "(f ⤜ g) s = Normal (a', s'')" by simp
ultimately show ?thesis using wp_def by fastforce
next
case (e e)
then have "E e" using wp_def[of "g a" P] ** by simp
moreover from nf e have "(f ⤜ g) s = Exception e" by simp
ultimately show ?thesis using wp_def by fastforce
qed
next
(* f throws directly: the assumption then reduces to E e *)
case (e e)
then have "E e" using wp_def[of f "λa. wp (g a) P E"] assms by simp
moreover from e have "(f ⤜ g) s = Exception e" by simp
ultimately show ?thesis using wp_def by fastforce
qed

(* WP rule for assert: split on the tested predicate, reducing to the wp of
   return () or of throw x. *)
lemma wpassert[wprule]:
assumes "t s ⟹ wp (return ()) P E s"
and "¬ t s ⟹ wp (throw x) P E s"
shows "wp (assert x t) P E s"
using assms unfolding wp_def by simp

(* Hoare rule for throw x: any pre-/postcondition works as long as the exception
   postcondition accepts x. *)
lemma throw:
  assumes "E x"
  shows "⦃P⦄ throw x ⦃Q⦄, ⦃E⦄"
  unfolding valid_def using assms by simp

(* WP rule for throw x: only the exception postcondition is relevant. *)
lemma wpthrow[wprule]:
  assumes "E x"
  shows "wp (throw x) P E s"
  using assms by (simp add: wp_def)

(* Hoare rule for applyf f: the postcondition receives f s together with the
   unchanged state s. Proved by unfolding the triple and evaluating
   applyf f = get ⤜ (λs. return (f s)) with the simp rules of fun. *)
lemma applyf:
"⦃λs. P (f s) s⦄ applyf f ⦃λa s. P a s⦄,⦃E⦄"
unfolding valid_def by simp

(* Consequence-style variant of the applyf rule with an arbitrary precondition P. *)
lemma applyf':
assumes "∀s. P s ⟶ Q (f s) s"
shows "⦃λs. P s⦄ applyf f ⦃λa s. Q a s⦄,⦃E⦄"
using assms weaken[OF applyf] by blast

(* WP rule for applyf f: the postcondition must hold for f s and the unchanged state. *)
lemma wpapplyf[wprule]:
assumes "P (f s) s"
shows "wp (applyf f) P E s"
unfolding wp_def using assms by simp

(* Hoare rule for modify f: the postcondition must hold for unit and the state f s.
   simp unfolds modify into get ⤜ (λs. put (f s)). The bind rule then
   splits the triple, and the put and get rules close the two halves. *)
lemma modify:
"⦃λs. P () (f s)⦄ modify f ⦃P⦄, ⦃E⦄"
apply simp
apply (rule bind,rule allI)
apply (rule put)
apply (rule get)
done

(* Consequence-style variant of the modify rule with an arbitrary precondition P. *)
lemma modify':
assumes "∀s. P s ⟶ Q () (f s)"
shows "⦃λs. P s⦄ modify f ⦃Q⦄, ⦃E⦄"
using assms weaken[OF modify, of P Q _ E] by blast

(* WP rule for modify f: the postcondition must hold for unit and the state f s. *)
lemma wpmodify[wprule]:
assumes "P () (f s)"
shows "wp (modify f) P E s"
unfolding wp_def using assms by simp

(* WP rule for a case split on a nat used as a computation: each branch's wp is
   required only under the matching shape of y (0 or Suc x). *)
lemma wpcasenat[wprule]:
assumes "(y=(0::nat) ⟹ wp (f y) P E s)"
and "⋀x. y=Suc x ⟹ wp (g x) P E s"
shows "wp (case y::nat of 0 ⇒ f y | Suc x ⇒ g x) P E s"
by (metis assms(1) assms(2) not0_implies_Suc old.nat.simps(4) old.nat.simps(5))

(* WP rule for if: each branch's wp is required only under the matching truth value of c. *)
lemma wpif[wprule]:
assumes "c ⟹ wp f P E s"
and "¬c ⟹ wp g P E s"
shows "wp (if c then f else g) P E s"
using assms by simp

(* WP rule for a case split on an option: the Some branch is required under
   x = Some y, the None branch under x = None. *)
lemma wpsome[wprule]:
assumes "⋀y. x = Some y ⟹ wp (f y) P E s"
and "x = None ⟹ wp g P E s"
shows "wp (case x of Some y ⇒ f y | None ⇒ g) P E s"
using assms unfolding wp_def by (simp split: option.split)

(* WP rule for option x f: reduces to the wp of return y when f s = Some y, and
   of throw x when f s = None. *)
lemma wpoption[wprule]:
assumes "⋀y. f s = Some y ⟹ wp (return y) P E s"
and "f s = None ⟹ wp (throw x) P E s"
shows "wp (option x f) P E s"
using assms unfolding wp_def by (auto split:option.split)

(* WP rule for a case split on a pair: the branch's wp is required for the
   components of a. *)
lemma wpprod[wprule]:
assumes "⋀x y. a = (x,y) ⟹ wp (f x y) P E s"
shows "wp (case a of (x, y) ⇒ f x y) P E s"
using assms unfolding wp_def
by (simp split: prod.split)

(* wp: apply a rule from the wprule collection, then (via ;) try wp again on
   every resulting subgoal; the trailing ? makes it stop quietly when no rule applies.
   wpvcg: turn a Hoare triple into a wp goal with wp_valid, then run wp.
   NOTE(review): the method command is provided by Eisbach, whose import is
   not visible in this file. Confirm. *)
method wp = rule wprule; wp?
method wpvcg = rule wp_valid, wp

(* Example use of wpvcg: after writing 5 and reading it back, the final state is 5.
   wpvcg generates the verification condition and simp discharges it. *)
lemma "⦃λs. s=5⦄ do {
put (5::nat);
x ← get;
return x
} ⦃λa s. s=5⦄,⦃λe. False⦄"
by (wpvcg, simp)
end
```