(* Theory StateMonad *)

theory StateMonad
imports Main "HOL-Library.Monad_Syntax"
begin

section "State Monad with Exceptions"

(* Outcome of running a computation: either a normal payload of type 'n
   or an exception payload of type 'e. *)
datatype ('n, 'e) result =
    Normal 'n
  | Exception 'e

(* A computation takes the current state and either yields a value of type 'a
   paired with the successor state, or raises an exception of type 'e, in
   which case no state is returned. *)
type_synonym ('a, 'e, 's) state_monad = "'s ⇒ ('a × 's, 'e) result"

(* Case analysis on a monad result that exposes the value/state pair of a
   normal result directly (case n) or the exception payload (case e).
   The [cases type: result] attribute installs this as the default rule for
   `cases` on terms of type result.
   NOTE(review): the rule only fits results whose Normal payload is a pair;
   confirm no later proof needs the plain datatype split. *)
lemma result_cases[cases type: result]:
  fixes x :: "('a × 's, 'e) result"
  obtains (n) a s where "x = Normal (a, s)"
        | (e) e where "x = Exception e"
proof (cases x)
  case (Normal p)
  (* split the payload p into its two components so that rule n applies *)
  then show ?thesis
    by (metis n surj_pair)
next
  case (Exception err)
  then show ?thesis ..
qed

subsection ‹Fundamental Definitions›

(* Yield the value a, leaving the state unchanged; never raises. *)
fun
  return :: "'a ⇒ ('a, 'e, 's) state_monad" where
  "return a s = Normal (a, s)"

(* Raise exception e; the incoming state is ignored and not returned. *)
fun
  throw :: "'e ⇒ ('a, 'e, 's) state_monad" where
  "throw e _ = Exception e"

(* Sequential composition: run f on state s.  On a normal result (a, s'),
   continue with g a s'.  If f raises an exception, it is propagated
   unchanged and g is never run. *)
fun
  bind :: "('a, 'e, 's) state_monad ⇒ ('a ⇒ ('b, 'e, 's) state_monad) ⇒
           ('b, 'e, 's) state_monad" (infixl ">>=" 60)
  where
    "bind f g s = (case f s of
                      Normal (a, s') ⇒ g a s'
                    | Exception e ⇒ Exception e)"

(* Hook bind into HOL-Library.Monad_Syntax, so that ⤜ and do-notation on
   state_monad values resolve to the bind above.
   NOTE(review): Monad_Syntax also declares an ASCII ">>=" for its generic
   bind at a different priority.  The local (infixl ">>=" 60) may make ">>="
   input ambiguous; confirm before relying on it. *)
adhoc_overloading Monad_Syntax.bind bind

(* An exception on the left of a bind short-circuits the continuation.
   The statement is an equality of functions, so prove it pointwise. *)
lemma throw_left[simp]: "throw x ⤜ y = throw x" by (rule ext) simp

subsection ‹The Monad Laws›

text ‹@{term return} is absorbed at the left of a @{term bind},
  applying the return value directly:›
(* Left identity.  The statement is an equality of functions, so prove it
   pointwise after extensionality. *)
lemma return_bind [simp]: "(return x ⤜ f) = f x"
  by (rule ext) simp

text ‹@{term return} is absorbed on the right of a @{term bind}:›
(* Right identity.  Proved pointwise: for every start state s, split on the
   result of m s; both the normal and the exception case are closed by
   unfolding bind and return. *)
lemma bind_return [simp]: "(m ⤜ return) = m"
proof
  fix s
  show "bind m return s = m s"
    by (cases "m s" rule: result_cases) auto
qed

text ‹@{term bind} is associative:›
(* Associativity of bind, proved pointwise by case analysis on m s. *)
lemma bind_assoc:
  fixes m :: "('a,'e,'s) state_monad"
  fixes f :: "'a ⇒ ('b,'e,'s) state_monad"
  fixes g :: "'b ⇒ ('c,'e,'s) state_monad"
  shows "(m ⤜ f) ⤜ g = m ⤜ (λx. f x ⤜ g)"
proof
  fix s
  show "bind (bind m f) g s = bind m (λx. bind (f x) g) s"
    by (cases "m s" rule: result_cases, simp+)
qed

subsection ‹Basic Congruence Rules›

(*
  Congruence rule for bind, registered with the function package
  (fundef_cong).  The continuations m3 and m4 only have to agree on the
  value/state pairs that m2 actually returns at state s.

  This lemma is needed for termination proofs; function sub1, for example,
  requires it.
*)
lemma monad_cong[fundef_cong]:
  fixes m1 m2 m3 m4
  assumes "m1 s = m2 s"
      and "⋀v s'. m2 s = Normal (v, s') ⟹ m3 v s' = m4 v s'"
    shows "(bind m1 m3) s = (bind m2 m4) s"
  (* cases uses result_cases, the default case rule for type result *)
  apply(insert assms, cases "m1 s")
  apply (metis StateMonad.bind.simps old.prod.case result.simps(5))
  by (metis bind.elims result.simps(6))

(*
  Congruence rule for a case analysis over nat whose branches are monadic
  computations applied to the state h.  Only the Suc branch may differ
  between the two sides; the 0 branch g is shared by both.

  Required when a bind operand is a case analysis over nat;
  function sub2, for example, requires it.
*)
lemma bind_case_nat_cong [fundef_cong]:
  assumes "x = x'" and "⋀a. x = Suc a ⟹ f a h = f' a h"
  shows "(case x of Suc a ⇒ f a | 0 ⇒ g) h = (case x' of Suc a ⇒ f' a | 0 ⇒ g) h"
  by (metis assms(1) assms(2) old.nat.exhaust old.nat.simps(4) old.nat.simps(5))

(* Congruence rule for an if whose branches are monadic computations applied
   to the state s.  Each branch only has to agree under its own guard.
   NOTE(review): HOL already provides a lemma named if_cong; this declaration
   shadows that name for theories importing StateMonad. Confirm this is intended. *)
lemma if_cong[fundef_cong]:
  assumes "b = b'"
    and "b' ⟹ m1 s = m1' s"
    and "¬ b' ⟹ m2 s = m2' s"
  shows "(if b then m1 else m2) s = (if b' then m1' else m2') s"
  using assms by auto

(* Congruence rule for a pair pattern match whose body is a monadic
   computation applied to the state s. *)
lemma bind_case_pair_cong [fundef_cong]:
  assumes "x = x'" and "⋀a b. x = (a,b) ⟹ f a b s = f' a b s"
  shows "(case x of (a,b) ⇒ f a b) s = (case x' of (a,b) ⇒ f' a b) s"
  by (simp add: assms(1) assms(2) prod.case_eq_if)

(* Congruence rule for a let whose body is a monadic computation applied to
   the state s.  The bodies only have to agree on the bound value N. *)
lemma bind_case_let_cong [fundef_cong]:
  assumes "M = N"
      and "(⋀x. x = N ⟹ f x s = g x s)"
    shows "(Let M f) s = (Let N g) s"
  by (simp add: assms(1) assms(2))

(* Congruence rule for a case analysis over option whose branches are
   monadic computations applied to the state s.  Each branch only has to
   agree in its own case. *)
lemma bind_case_some_cong [fundef_cong]:
  assumes "x = x'" and "⋀a. x = Some a ⟹ f a s = f' a s" and "x = None ⟹ g s = g' s"
  shows "(case x of Some a ⇒ f a | None ⇒ g) s = (case x' of Some a ⇒ f' a | None ⇒ g') s"
  by (simp add: assms(1) assms(2) assms(3) option.case_eq_if)

subsection ‹Other functions›

text ‹
  The basic accessor functions of the state monad. ‹get› returns
  the current state as result, does not fail, and does not change the state.
  ‹put s› returns unit, changes the current state to ‹s› and does not fail.
›
(* Read the state: the current state is returned as the result and is also
   passed on unchanged. *)
fun get :: "('s, 'e, 's) state_monad" where
  "get st = Normal (st, st)"

(* Overwrite the state with s and return unit; the previous state is
   ignored. *)
fun put :: "'s ⇒ (unit, 'e, 's) state_monad" where
  "put s _ = Normal ((), s)"

text ‹Apply a function to the current state and return the result
without changing the state.›
fun
  applyf :: "('s ⇒ 'a) ⇒ ('a, 'e, 's) state_monad" where
  "applyf f = get ⤜ (λs. return (f s))"

text ‹Modify the current state using the function passed in; the result is unit.›
fun
  modify :: "('s ⇒ 's) ⇒ (unit, 'e, 's) state_monad"
where "modify f = get ⤜ (λs::'s. put (f s))"

text ‹
  Perform a test on the current state, running the left computation if
  the result is true or the right computation if the result is false.
  The chosen computation starts from the same, unchanged state.
›
fun
  condition :: "('s ⇒ bool) ⇒ ('a, 'e, 's) state_monad ⇒ ('a, 'e, 's) state_monad ⇒ ('a, 'e, 's) state_monad"
where
  "condition P L R s = (if P s then L s else R s)"

notation (output)
  condition  ("(condition (_)//  (_)//  (_))" [1000,1000,1000] 1000)

(* Congruence rule for condition applied to the state s.  The tests must
   agree at s, and each branch only has to agree when it is selected. *)
lemma condition_cong[fundef_cong]:
  assumes "b s = b' s"
    and "b' s ⟹ m1 s = m1' s"
    and "⋀s'. s' = s ⟹ ¬ b' s' ⟹ m2 s' = m2' s'"
  shows "(condition b m1 m2) s = (condition b' m1' m2') s"
  by (simp add: assms(1) assms(2) assms(3))

(* assert x t m: if the test t holds on the current state, raise exception x;
   otherwise run m.  Note the polarity: t describes the failure condition,
   not the condition that must hold for m to run. *)
fun
  assert :: "'e ⇒ ('s ⇒ bool) ⇒ ('a, 'e, 's) state_monad ⇒ ('a, 'e, 's) state_monad" where
  "assert x t m = condition t (throw x) m"

notation (output)
  assert  ("(assert (_)//  (_)//  (_))" [1000,1000,1000] 1000)

(* Congruence rule for assert applied to the state s.  The tests must agree
   at s, and the continuation only matters when the test fails. *)
lemma assert_cong[fundef_cong]:
  assumes "b s = b' s"
    and "¬ b' s ⟹ m s = m' s"
  shows "(assert x b m) s = (assert x b' m') s"
  by (simp add: assms(1) assms(2))

subsection ‹Some basic examples›

(* do-notation desugars into nested binds, so both sides coincide after
   parsing. *)
lemma "do {
        x ← return 1;
        return (2::nat);
        return x
       } =
       return 1 ⤜ (λx. return (2::nat) ⤜ (λ_. (return x)))" ..

(* With return_bind as a simp rule, the whole block collapses to return 1. *)
lemma "do {
        x ← return 1;
          return 2;
          return x
       } = return 1"
  by auto

(* Example: structural recursion on the state.  From state Suc n, the
   computation continues at state n (get reads n, put writes it back) and
   finally returns () with state 0.  Its termination proof needs
   monad_cong. *)
fun sub1 :: "(unit, nat, nat) state_monad" where
    "sub1 0 = put 0 0"
  | "sub1 (Suc n) = (do {
                     x ← get;
                     put x;
                     sub1
                    }) n"

(* Example: counts the state down to 0 via a case analysis on the value
   read by get.  Its termination proof needs bind_case_nat_cong. *)
fun sub2 :: "(unit, nat, nat) state_monad" where
  "sub2 s =
     (do {
        n ← get;
        (case n of
          0 ⇒ put 0
        | Suc n' ⇒ (do {
                  put n';
                  sub2
                }))
     }) s"

(* Example: returns () once the state is 0; otherwise decrements the state
   and recurses.  Presumably relies on condition_cong for termination. *)
fun sub3 :: "(unit, nat, nat) state_monad" where
  "sub3 s =
     condition (λn. n=0)
       (return ())
       (do {
          n ← get;
          put (n - 1);
          sub3
         }) s"

(* Example: decrements the state until it reaches 0, where assert raises
   exception 0.  Presumably relies on assert_cong for termination. *)
fun sub4 :: "(unit, nat, nat) state_monad" where
  "sub4 s =
     assert (0) (λn. n=0)
       (do {
          n ← get;
          put (n - 1);
          sub4
         }) s"

(* Example over a pair state: decrements the first component until it
   reaches 0, where assert raises exception 0.  The pattern binding
   (n, m) ← get presumably relies on bind_case_pair_cong for termination. *)
fun sub5 :: "(unit, nat, (nat*nat)) state_monad" where
  "sub5 s =
     assert (0) (λn. fst n=0)
       (do {
          (n,m) ← get;
          put (n - 1,m);
          sub5
         }) s"
end