(* Theory Spec_Monad *)

(*
 * Copyright (c) 2024 Apple Inc. All rights reserved.
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

text_raw ‹\part{AutoCorres}›

chapter ‹Spec-Monad›

theory Spec_Monad
  imports
  "Basic_Runs_To_VCG"
  "HOL-Library.Complete_Partial_Order2"
  "HOL-Library.Monad_Syntax"
  "AutoCorres_Utils"
begin

section ‹‹rel_map› and ‹rel_project››

(* rel_map f r x holds exactly when x is the image of r under f. *)
definition rel_map :: "('a ⇒ 'b) ⇒ 'a ⇒ 'b ⇒ bool" where
  "rel_map f r x = (x = f r)"

lemma rel_map_direct[simp]: "rel_map f a (f a)"
  by (simp add: rel_map_def)

(* rel_project is only an alternative name for rel_map. *)
abbreviation "rel_project ≡ rel_map"

lemmas rel_project_def = rel_map_def

lemma rel_project_id: "rel_project id = (=)"
   "rel_project (λv. v) = (=)"
  by (auto simp add: rel_project_def fun_eq_iff)

(* Projecting onto unit relates everything. *)
lemma rel_project_unit: "rel_project (λ_. ()) x y = True"
  by (simp add: rel_project_def)

lemma rel_projectI: "y = prj x ⟹ rel_project prj x y"
  by (simp add: rel_project_def)

lemma rel_project_conv: "rel_project prj x y = (y = prj x)"
  by (simp add: rel_project_def)

section ‹Misc Theorems›

declare case_unit_Unity [simp] ― ‹without this rule the simplifier seems to loop in unexpected ways›

lemma abs_const_unit: "(λ(v::unit). f) = (λ(). f)"
  by auto

(* Pointwise-monotone bodies give monotone indexed suprema. *)
lemma SUP_mono'': "(⋀x. x∈A ⟹ f x ≤ g x) ⟹ (⨆x∈A. f x) ≤ (⨆x∈A. g x :: _::complete_lattice)"
  by (rule SUP_mono) auto

(* Counting upwards is well-founded as long as the smaller element stays below the bound n. *)
lemma wf_nat_bound: "wf {(a, b). b < a ∧ b ≤ (n::nat)}"
  apply (rule wf_subset)
  apply (rule wf_measure[where f="λa. Suc n - a"])
  apply auto
  done

(* Lower bounds are admissible for the Inf-based (dual) ccpo structure of a complete lattice. *)
lemma (in complete_lattice) admissible_le:
  "ccpo.admissible Inf (≥) (λx. (y ≤ x))"
  by (simp add: ccpo.admissibleI local.Inf_greatest)

lemma mono_const: "mono (λx. c)"
  by (simp add: mono_def)

lemma mono_lam: "(⋀a. mono (λx. F x a)) ⟹ mono F"
  by (simp add: mono_def le_fun_def)

lemma mono_app: "mono (λx. x a)"
  by (simp add: mono_def le_fun_def)

(* Transport a universal quantifier along a surjection f (f' is a right inverse of f). *)
lemma all_cong_map:
  assumes f_f': "⋀y. f (f' y) = y" and P_Q: "⋀x. P x ⟷ Q (f x)"
  shows "(∀x. P x) ⟷ (∀y. Q y)"
  using assms by metis

lemma rel_set_refl: "(⋀x. x ∈ A ⟹ R x x) ⟹ rel_set R A A"
  by (auto simp: rel_set_def)

lemma rel_set_converse_iff: "rel_set R X Y ⟷ rel_set R¯¯ Y X"
  by (auto simp add: rel_set_def)

lemma rel_set_weaken:
  "(⋀x y. x ∈ A ⟹ y ∈ B ⟹ P x y ⟹ Q x y) ⟹ rel_set P A B ⟹ rel_set Q A B"
  by (force simp: rel_set_def)

lemma sim_set_refl: "(⋀x. x ∈ X ⟹ R x x) ⟹ sim_set R X X"
  by (auto simp: sim_set_def)

section ‹Galois Connections›

(* A function that preserves arbitrary suprema is monotone. *)
lemma mono_of_Sup_cont:
  fixes f :: "'a::complete_lattice ⇒ 'b::complete_lattice"
  assumes cont: "⋀X. f (Sup X) = (SUP x∈X. f x)"
  assumes xy: "x ≤ y"
  shows "f x ≤ f y"
proof -
  have "f x ≤ sup (f x) (f y)" by (rule sup_ge1)
  also have "… = (SUP x∈{x, y}. f x)" by simp
  also have "… = f (Sup {x, y})" by (rule cont[symmetric])
  finally show ?thesis using xy by (simp add: sup_absorb2)
qed

(* A Sup-preserving f is the lower adjoint of y ↦ Sup {x. f x ≤ y}. *)
lemma gc_of_Sup_cont:
  fixes f :: "'a::complete_lattice ⇒ 'b::complete_lattice"
  assumes cont: "⋀X. f (Sup X) = (SUP x∈X. f x)"
  shows "f x ≤ y ⟷ x ≤ Sup {x. f x ≤ y}"
proof safe
  assume "f x ≤ y" then show "x ≤ Sup {x. f x ≤ y}"
    by (intro Sup_upper) simp
next
  assume "x ≤ Sup {x. f x ≤ y}"
  then have "f x ≤ f (Sup {x. f x ≤ y})"
    by (rule mono_of_Sup_cont[OF cont])
  also have "… = (SUP x∈{x. f x ≤ y}. f x)" by (rule cont)
  also have "… ≤ y" by (rule Sup_least) auto
  finally show "f x ≤ y" .
qed

(* A function that preserves arbitrary infima is monotone. *)
lemma mono_of_Inf_cont:
  fixes f :: "'a::complete_lattice ⇒ 'b::complete_lattice"
  assumes cont: "⋀X. f (Inf X) = (INF x∈X. f x)"
  assumes xy: "x ≤ y"
  shows "f x ≤ f y"
proof -
  have "f x = f (Inf {x, y})" using xy by (simp add: inf_absorb1)
  also have "… = (INF x∈{x, y}. f x)" by (rule cont)
  also have "… = inf (f x) (f y)" by simp
  also have "… ≤ f y" by (rule inf_le2)
  finally show ?thesis .
qed

(* An Inf-preserving f is the upper adjoint of x ↦ Inf {y. x ≤ f y}. *)
lemma gc_of_Inf_cont:
  fixes f :: "'a::complete_lattice ⇒ 'b::complete_lattice"
  assumes cont: "⋀X. f (Inf X) = (INF x∈X. f x)"
  shows "Inf {y. x ≤ f y} ≤ y ⟷ x ≤ f y"
proof safe
  assume "x ≤ f y" then show "Inf {y. x ≤ f y} ≤ y"
    by (intro Inf_lower) simp
next
  assume *: "Inf {y. x ≤ f y} ≤ y"
  have "x ≤ (INF y∈{y. x ≤ f y}. f y)" by (rule Inf_greatest) auto
  also have "… = f (Inf {y. x ≤ f y})" by (rule cont[symmetric])
  also have "… ≤ f y"
    using * by (rule mono_of_Inf_cont[OF cont])
  finally show "x ≤ f y" .
qed

(* Greatest-fixpoint fusion: f is the upper adjoint of g and commutes a into b,
   hence f maps gfp a onto gfp b. *)
lemma gfp_fusion:
  assumes f_g: "⋀x y. g x ≤ y ⟷ x ≤ f y"
  assumes a: "mono a"
  assumes b: "mono b"
  assumes *: "⋀x. f (a x) = b (f x)"
  shows "f (gfp a) = gfp b"
  apply (intro antisym gfp_upperbound)
  subgoal
    apply (subst *[symmetric])
    apply (subst gfp_fixpoint[OF a])
    ..
  apply (rule f_g[THEN iffD1])
  apply (rule gfp_upperbound)
  apply (rule f_g[THEN iffD2])
  apply (subst *)
  apply (subst gfp_fixpoint[OF b, symmetric])
  apply (rule monoD[OF b])
  apply (rule f_g[THEN iffD1])
  apply (rule order_refl)
  done

(* Least-fixpoint fusion: g is the lower adjoint of f and commutes a into b,
   hence g maps lfp a onto lfp b. *)
lemma lfp_fusion:
  assumes f_g: "⋀x y. g x ≤ y ⟷ x ≤ f y"
  assumes a: "mono a"
  assumes b: "mono b"
  assumes *: "⋀x. g (a x) = b (g x)"
  shows "g (lfp a) = lfp b"
  apply (intro antisym lfp_lowerbound)
  subgoal
    apply (rule f_g[THEN iffD2])
    apply (rule lfp_lowerbound)
    apply (rule f_g[THEN iffD1])
    apply (subst *)
    apply (subst (2) lfp_fixpoint[OF b, symmetric])
    apply (rule monoD[OF b])
    apply (rule f_g[THEN iffD2])
    apply (rule order_refl)
    done
  subgoal
    apply (subst *[symmetric])
    apply (subst lfp_fixpoint[OF a])
    ..
  done

section ‹‹post_state› type›

(* Either a failure or the set of all possible outcomes. *)
datatype 'r post_state = Failure | Success "'r set"

text ‹\<^const>‹Failure› is supposed to model things like undefined behaviour in C. We usually have
to show the absence of \<^const>‹Failure› for all possible executions of the program. Moreover,
it is used to model the 'result' of a non-terminating computation.
›

lemma split_post_state:
  "x = (case x of Success X ⇒ Success X | Failure ⇒ Failure)"
  for x::"'a post_state"
  by (cases x) auto

(* Order: Failure is above everything; Success sets are ordered by inclusion. *)
instantiation post_state :: (type) order
begin

inductive less_eq_post_state :: "'a post_state ⇒ 'a post_state ⇒ bool" where
  Failure_le[simp, intro]: "less_eq_post_state p Failure"
| Success_le_Success[intro]: "r ⊆ q ⟹ less_eq_post_state (Success r) (Success q)"

definition less_post_state :: "'a post_state ⇒ 'a post_state ⇒ bool" where
  "less_post_state p q ⟷ p ≤ q ∧ ¬ q ≤ p"

instance
proof
  fix p q r :: "'a post_state"
  show "p ≤ p" by (cases p) auto
  show "p ≤ q ⟹ q ≤ r ⟹ p ≤ r" by (blast elim: less_eq_post_state.cases)
  show "p ≤ q ⟹ q ≤ p ⟹ p = q" by (blast elim: less_eq_post_state.cases)
qed (fact less_post_state_def)
end

lemma Success_le_Success_iff[simp]: "Success r ≤ Success q ⟷ r ⊆ q"
  by (auto elim: less_eq_post_state.cases)

lemma Failure_le_iff[simp]: "Failure ≤ q ⟷ q = Failure"
  by (auto elim: less_eq_post_state.cases)

instantiation post_state :: (type) complete_lattice
begin

(* Failure is the top element; the empty outcome set is the bottom element. *)
definition top_post_state :: "'a post_state" where
  "top_post_state = Failure"

definition bot_post_state :: "'a post_state" where
  "bot_post_state = Success {}"

definition inf_post_state :: "'a post_state ⇒ 'a post_state ⇒ 'a post_state" where
  "inf_post_state =
    (λFailure ⇒ id | Success res1 ⇒
      (λFailure ⇒ Success res1 | Success res2 ⇒ Success (res1 ∩ res2)))"

definition sup_post_state :: "'a post_state ⇒ 'a post_state ⇒ 'a post_state" where
  "sup_post_state =
    (λFailure ⇒ (λ_. Failure) | Success res1 ⇒
      (λFailure ⇒ Failure | Success res2 ⇒ Success (res1 ∪ res2)))"

(* Infimum: intersect all Success sets; Failure if there are none. *)
definition Inf_post_state :: "'a post_state set ⇒ 'a post_state" where
  "Inf_post_state s = (if Success -` s = {} then Failure else Success (⋂ (Success -` s)))"

(* Supremum: Failure if any element fails, otherwise the union of all Success sets. *)
definition Sup_post_state :: "'a post_state set ⇒ 'a post_state" where
  "Sup_post_state s = (if Failure ∈ s then Failure else Success (⋃ (Success -` s)))"

instance
proof
  fix x y z :: "'a post_state" and A :: "'a post_state set"
  show "inf x y ≤ y" by (simp add: inf_post_state_def split: post_state.split)
  show "inf x y ≤ x" by (simp add: inf_post_state_def split: post_state.split)
  show "x ≤ y ⟹ x ≤ z ⟹ x ≤ inf y z"
    by (auto simp add: inf_post_state_def elim!: less_eq_post_state.cases)
  show "x ≤ sup x y" by (simp add: sup_post_state_def split: post_state.split)
  show "y ≤ sup x y" by (simp add: sup_post_state_def split: post_state.split)
  show "x ≤ y ⟹ z ≤ y ⟹ sup x z ≤ y"
    by (auto simp add: sup_post_state_def elim!: less_eq_post_state.cases)
  show "x ∈ A ⟹ Inf A ≤ x" by (cases x) (auto simp add: Inf_post_state_def)
  show "(⋀x. x ∈ A ⟹ z ≤ x) ⟹ z ≤ Inf A" by (cases z) (force simp add: Inf_post_state_def)+
  show "x ∈ A ⟹ x ≤ Sup A" by (cases x) (auto simp add: Sup_post_state_def)
  show "(⋀x. x ∈ A ⟹ x ≤ z) ⟹ Sup A ≤ z" by (cases z) (force simp add: Sup_post_state_def)+
qed (simp_all add: top_post_state_def Inf_post_state_def bot_post_state_def Sup_post_state_def)
end

(* Total correctness: no Failure, and every outcome satisfies P. *)
primrec holds_post_state :: "('a ⇒ bool) ⇒ 'a post_state ⇒ bool" where
  "holds_post_state P Failure ⟷ False"
| "holds_post_state P (Success X) ⟷ (∀x∈X. P x)"

(* Partial correctness: Failure is accepted; otherwise every outcome satisfies P. *)
primrec holds_partial_post_state :: "('a ⇒ bool) ⇒ 'a post_state ⇒ bool" where
  "holds_partial_post_state P Failure ⟷ True"
| "holds_partial_post_state P (Success X) ⟷ (∀x∈X. P x)"

(* Simulation: every outcome on the left is R-related to some outcome on the right;
   anything is simulated by Failure. *)
inductive
  sim_post_state :: "('a ⇒ 'b ⇒ bool) ⇒ 'a post_state ⇒ 'b post_state ⇒ bool"
  for R
where
  [simp]: "⋀p. sim_post_state R p Failure"
| "⋀A B. sim_set R A B ⟹ sim_post_state R (Success A) (Success B)"

(* Relator: Failure relates only to Failure; Success sets are related by rel_set. *)
inductive
  rel_post_state :: "('a ⇒ 'b ⇒ bool) ⇒ 'a post_state ⇒ 'b post_state ⇒ bool"
  for R
where
  [simp]: "rel_post_state R Failure Failure"
| "⋀A B. rel_set R A B ⟹ rel_post_state R (Success A) (Success B)"

(* Pull a post state back along R: all x related to some outcome. *)
primrec lift_post_state :: "('a ⇒ 'b ⇒ bool) ⇒ 'b post_state ⇒ 'a post_state" where
  "lift_post_state R Failure = Failure"
| "lift_post_state R (Success Y) = Success {x. ∃y∈Y. R x y}"

(* Right adjoint of lift_post_state (see lift_post_state_unlift_post_state). *)
primrec unlift_post_state :: "('a ⇒ 'b ⇒ bool) ⇒ 'a post_state ⇒ 'b post_state" where
  "unlift_post_state R Failure = Failure"
| "unlift_post_state R (Success X) = Success {y. ∀x. R x y ⟶ x ∈ X}"

primrec map_post_state :: "('a ⇒ 'b) ⇒ 'a post_state ⇒ 'b post_state" where
  "map_post_state f Failure  = Failure"
| "map_post_state f (Success r) = Success (f ` r)"

(* Preimage map. *)
primrec vmap_post_state :: "('a ⇒ 'b) ⇒ 'b post_state ⇒ 'a post_state" where
  "vmap_post_state f Failure  = Failure"
| "vmap_post_state f (Success r) = Success (f -` r)"

definition pure_post_state :: "'a ⇒ 'a post_state" where
  "pure_post_state v = Success {v}"

(* Sequencing: Failure propagates; otherwise take the supremum over all continuations. *)
primrec bind_post_state :: "'a post_state ⇒ ('a ⇒ 'b post_state) ⇒ 'b post_state" where
  "bind_post_state Failure p = Failure"
| "bind_post_state (Success res) p = ⨆ (p ` res)"

subsection ‹Order Properties›

lemma top_ne_bot[simp]: "(⊤ :: _ post_state) ≠ ⊥"
  and bot_ne_top[simp]: "(⊥ :: _ post_state) ≠ ⊤"
  by (auto simp: top_post_state_def bot_post_state_def)

lemma Success_ne_top[simp]:
  "Success X ≠ ⊤" "⊤ ≠ Success X"
  by (auto simp: top_post_state_def)

lemma Success_eq_bot_iff[simp]:
  "Success X = ⊥ ⟷ X = {}" "⊥ = Success X ⟷ X = {}"
  by (auto simp: bot_post_state_def)

(* A supremum of Success values is the Success of the union. *)
lemma Sup_Success: "(⨆x∈X. Success (f x)) = Success (⋃x∈X. f x)"
  by (auto simp: Sup_post_state_def)

lemma Sup_Success_pair: "(⨆(x, y)∈X. Success (f x y)) = Success (⋃(x, y)∈X. f x y)"
  by (simp add: split_beta' Sup_Success)

subsection ‹‹holds_post_state››

lemma holds_post_state_iff:
  "holds_post_state P p ⟷ (∃X. p = Success X ∧ (∀x∈X. P x))"
  by (cases p) auto

lemma holds_post_state_weaken:
  "holds_post_state P p ⟹ (⋀x. P x ⟹ Q x) ⟹ holds_post_state Q p"
  by (cases p; auto)

lemma holds_post_state_combine:
  "holds_post_state P p ⟹ holds_post_state Q p ⟹
    (⋀x. P x ⟹ Q x ⟹ R x) ⟹ holds_post_state R p"
  by (cases p; auto)

(* The side condition excludes only Y = {} together with p = Failure, where the
   left side is False but the right side is vacuously True. *)
lemma holds_post_state_Ball:
  "Y ≠ {} ∨ p ≠ Failure ⟹
    holds_post_state (λx. ∀y∈Y. P x y) p ⟷ (∀y∈Y. holds_post_state (λx. P x y) p)"
  by (cases p; auto)

lemma holds_post_state_All:
  "holds_post_state (λx. ∀y. P x y) p ⟷ (∀y. holds_post_state (λx. P x y) p)"
  by (cases p; auto)

lemma holds_post_state_BexI:
  "y ∈ Y ⟹ holds_post_state (λx. P x y) p ⟹ holds_post_state (λx. ∃y∈Y. P x y) p"
  by (cases p; auto)

lemma holds_post_state_conj:
  "holds_post_state (λx. P x ∧ Q x) p ⟷ holds_post_state P p ∧ holds_post_state Q p"
  by (cases p; auto)

(* Simulation characterised via the predicates that hold on the right side. *)
lemma sim_post_state_iff:
  "sim_post_state R p q ⟷
    (∀P. holds_post_state (λy. ∀x. R x y ⟶ P x) q ⟶ holds_post_state P p)"
  apply (cases p; cases q; simp add: sim_set_def sim_post_state.simps)
  apply safe
  apply force
  apply force
  subgoal premises prems for A B
    using prems(3)[THEN spec, of "λa. ∃b∈B. R a b"] prems(4)
    by auto
  done

(* The order is determined by the predicates that hold. *)
lemma post_state_le_iff:
  "p ≤ q ⟷ (∀P. holds_post_state P q ⟶ holds_post_state P p)"
  by (cases p; cases q; force simp add: subset_eq Ball_def)

lemma post_state_eq_iff:
  "p = q ⟷ (∀P. holds_post_state P p ⟷ holds_post_state P q)"
  by (simp add: order_eq_iff post_state_le_iff)

lemma holds_top_post_state[simp]: "¬ holds_post_state P ⊤"
  by (simp add: top_post_state_def)

lemma holds_bot_post_state[simp]: "holds_post_state P ⊥"
  by (simp add: bot_post_state_def)

lemma holds_Sup_post_state[simp]: "holds_post_state P (Sup F) ⟷ (∀f∈F. holds_post_state P f)"
  by (subst (2) split_post_state)
    (auto simp: Sup_post_state_def split: post_state.splits)

lemma holds_post_state_gfp:
  "holds_post_state P (gfp f) ⟷ (∀p. p ≤ f p ⟶ holds_post_state P p)"
  by (simp add: gfp_def)

lemma holds_post_state_gfp_apply:
  "holds_post_state P (gfp f x) ⟷ (∀p. p ≤ f p ⟶ holds_post_state P (p x))"
  by (simp add: gfp_def)

lemma holds_lift_post_state[simp]:
  "holds_post_state P (lift_post_state R x) ⟷ holds_post_state (λy. ∀x. R x y ⟶ P x) x"
  by (cases x) auto

lemma holds_map_post_state[simp]:
  "holds_post_state P (map_post_state f x) ⟷ holds_post_state (λx. P (f x)) x"
  by (cases x) auto

lemma holds_vmap_post_state[simp]:
  "holds_post_state P (vmap_post_state f x) ⟷ holds_post_state (λx. ∀y. f y = x ⟶ P y) x"
  by (cases x) auto

lemma holds_pure_post_state[simp]: "holds_post_state P (pure_post_state x) ⟷ P x"
  by (simp add: pure_post_state_def)

lemma holds_bind_post_state[simp]:
  "holds_post_state P (bind_post_state f g) ⟷ holds_post_state (λx. holds_post_state P (g x)) f"
  by (cases f) auto

lemma holds_post_state_False: "holds_post_state (λx. False) f ⟷ f = ⊥"
  by (cases f) (auto simp: bot_post_state_def)

subsection ‹‹holds_partial_post_state››

lemma holds_partial_post_state_of_holds:
  "holds_post_state P p ⟹ holds_partial_post_state P p"
  by (cases p; simp)

lemma holds_partial_post_state_iff:
  "holds_partial_post_state P p ⟷ (∀X. p = Success X ⟶ (∀x∈X. P x))"
  by (cases p) auto

lemma holds_partial_post_state_True[simp]: "holds_partial_post_state (λx. True) p"
  by (cases p; simp)

lemma holds_partial_post_state_weaken:
  "holds_partial_post_state P p ⟹ (⋀x. P x ⟹ Q x) ⟹ holds_partial_post_state Q p"
  by (cases p; auto)

lemma holds_partial_post_state_Ball:
  "holds_partial_post_state (λx. ∀y∈Y. P x y) p ⟷ (∀y∈Y. holds_partial_post_state (λx. P x y) p)"
  by (cases p; auto)

lemma holds_partial_post_state_All:
  "holds_partial_post_state (λx. ∀y. P x y) p ⟷ (∀y. holds_partial_post_state (λx. P x y) p)"
  by (cases p; auto)

lemma holds_partial_post_state_conj:
  "holds_partial_post_state (λx. P x ∧ Q x) p ⟷
    holds_partial_post_state P p ∧ holds_partial_post_state Q p"
  by (cases p; auto)

lemma holds_partial_top_post_state[simp]: "holds_partial_post_state P ⊤"
  by (simp add: top_post_state_def)

lemma holds_partial_bot_post_state[simp]: "holds_partial_post_state P ⊥"
  by (simp add: bot_post_state_def)

lemma holds_partial_pure_post_state[simp]: "holds_partial_post_state P (pure_post_state x) ⟷ P x"
  by (simp add: pure_post_state_def)

lemma holds_partial_Sup_post_stateI:
  "(⋀x. x ∈ X ⟹ holds_partial_post_state P x) ⟹ holds_partial_post_state P (Sup X)"
  by (force simp: Sup_post_state_def)

lemma holds_partial_bind_post_stateI:
  "holds_partial_post_state (λx. holds_partial_post_state P (g x)) f ⟹
    holds_partial_post_state P (bind_post_state f g)"
  by (cases f) (auto intro: holds_partial_Sup_post_stateI)

lemma holds_partial_map_post_state[simp]:
  "holds_partial_post_state P (map_post_state f x) ⟷ holds_partial_post_state (λx. P (f x)) x"
  by (cases x) auto

lemma holds_partial_vmap_post_state[simp]:
  "holds_partial_post_state P (vmap_post_state f x) ⟷
    holds_partial_post_state (λx. ∀y. f y = x ⟶ P y) x"
  by (cases x) auto

(* Same statement as holds_partial_bind_post_stateI; kept under this name for existing users. *)
lemma holds_partial_bind_post_state:
  "holds_partial_post_state (λx. holds_partial_post_state P (g x)) f ⟹
    holds_partial_post_state P (bind_post_state f g)"
  by (cases f) (auto intro: holds_partial_Sup_post_stateI)

subsection ‹‹sim_post_state››

(* Simulation under equality is exactly the lattice order. *)
lemma sim_post_state_eq_iff_le: "sim_post_state (=) p q ⟷ p ≤ q"
  by (simp add: post_state_le_iff sim_post_state_iff)

lemma sim_post_state_Success_Success_iff[simp]:
  "sim_post_state R (Success r) (Success q) ⟷ (∀a∈r. ∃b∈q. R a b)"
  by (auto simp add: sim_set_def elim: sim_post_state.cases intro: sim_post_state.intros)

lemma sim_post_state_Success2:
  "sim_post_state R f (Success q) ⟷ holds_post_state (λa. ∃b∈q. R a b) f"
  by (cases f; simp add: sim_post_state.simps sim_set_def)

lemma sim_post_state_Failure1[simp]: "sim_post_state R Failure q ⟷ q = Failure"
  by (auto elim: sim_post_state.cases intro: sim_post_state.intros)

lemma sim_post_state_top2[simp, intro]: "sim_post_state R p ⊤"
  by (simp add: top_post_state_def)

lemma sim_post_state_top1[simp]: "sim_post_state R ⊤ q ⟷ q = ⊤"
  using sim_post_state_Failure1 by (metis top_post_state_def)

lemma sim_post_state_bot2[simp, intro]: "sim_post_state R p ⊥ ⟷ p = ⊥"
  by (cases p; simp add: bot_post_state_def)

lemma sim_post_state_bot1[simp, intro]: "sim_post_state R ⊥ q"
  by (cases q; simp add: bot_post_state_def)

lemma sim_post_state_le1: "sim_post_state R f' g ⟹ f ≤ f' ⟹ sim_post_state R f g"
  by (simp add: post_state_le_iff sim_post_state_iff)

lemma sim_post_state_le2: "sim_post_state R f g ⟹ g ≤ g' ⟹ sim_post_state R f g'"
  by (simp add: post_state_le_iff sim_post_state_iff)

lemma sim_post_state_Sup1:
  "sim_post_state R (Sup A) f ⟷ (∀a∈A. sim_post_state R a f)"
  by (auto simp add: sim_post_state_iff)

lemma sim_post_state_Sup2:
  "a ∈ A ⟹ sim_post_state R f a ⟹ sim_post_state R f (Sup A)"
  by (auto simp add: sim_post_state_iff)

lemma sim_post_state_Sup:
  "∀a∈A. ∃b∈B. sim_post_state R a b ⟹ sim_post_state R (Sup A) (Sup B)"
  by (auto simp: sim_post_state_Sup1 intro: sim_post_state_Sup2)

lemma sim_post_state_weaken:
  "sim_post_state R f g ⟹ (⋀x y. R x y ⟹ Q x y) ⟹ sim_post_state Q f g"
  by (cases f; cases g; force)

lemma sim_post_state_trans:
  "sim_post_state R f g ⟹ sim_post_state Q g h ⟹ (⋀x y z. R x y ⟹ Q y z ⟹ S x z) ⟹
    sim_post_state S f h"
  by (fastforce elim!: sim_post_state.cases intro: sim_post_state.intros simp: sim_set_def)

lemma sim_post_state_refl': "holds_partial_post_state (λx. R x x) f ⟹ sim_post_state R f f"
  by (cases f; auto simp: rel_set_def)

lemma sim_post_state_refl: "(⋀x. R x x) ⟹ sim_post_state R f f"
  by (simp add: sim_post_state_refl')

lemma sim_post_state_pure_post_state2:
  "sim_post_state F f (pure_post_state x) ⟷ holds_post_state (λy. F y x) f"
  by (cases f; simp add: pure_post_state_def)

subsection ‹‹rel_post_state››

lemma rel_post_state_top[simp, intro!]: "rel_post_state R ⊤ ⊤"
  by (auto simp: top_post_state_def)

lemma rel_post_state_top_iff[simp]:
  "rel_post_state R ⊤ p ⟷ p = ⊤"
  "rel_post_state R p ⊤ ⟷ p = ⊤"
  by (auto simp: top_post_state_def rel_post_state.simps)

lemma rel_post_state_Success_iff[simp]:
  "rel_post_state R (Success A) (Success B) ⟷ rel_set R A B"
  by (auto elim: rel_post_state.cases intro: rel_post_state.intros)

lemma rel_post_state_bot[simp, intro!]: "rel_post_state R ⊥ ⊥"
  by (auto simp: bot_post_state_def Lifting_Set.empty_transfer)

(* The relator is simulation in both directions. *)
lemma rel_post_state_eq_sim_post_state:
  "rel_post_state R p q ⟷ sim_post_state R p q ∧ sim_post_state R¯¯ q p"
  by (auto simp: rel_set_def sim_set_def elim!: sim_post_state.cases rel_post_state.cases)

lemma rel_post_state_weaken:
  "rel_post_state R f g ⟹ (⋀x y. R x y ⟹ Q x y) ⟹ rel_post_state Q f g"
  by (auto intro: sim_post_state_weaken simp: rel_post_state_eq_sim_post_state)

lemma rel_post_state_eq[relator_eq]: "rel_post_state (=) = (=)"
  by (simp add: rel_post_state_eq_sim_post_state fun_eq_iff sim_post_state_eq_iff_le order_eq_iff)

lemma rel_post_state_mono[relator_mono]:
  "A ≤ B ⟹ rel_post_state A ≤ rel_post_state B"
  using rel_set_mono[of A B]
  by (auto simp add: le_fun_def intro!: rel_post_state.intros elim!: rel_post_state.cases)

lemma rel_post_state_refl': "holds_partial_post_state (λx. R x x) f ⟹ rel_post_state R f f"
  by (cases f; auto simp: rel_set_def)

lemma rel_post_state_refl: "(⋀x. R x x) ⟹ rel_post_state R f f"
  by (simp add: rel_post_state_refl')

subsection ‹‹lift_post_state››

lemma lift_post_state_top: "lift_post_state R ⊤ = ⊤"
  by (auto simp: top_post_state_def)

(* Galois connection between lift_post_state and unlift_post_state. *)
lemma lift_post_state_unlift_post_state: "lift_post_state R p ≤ q ⟷ p ≤ unlift_post_state R q"
  by (cases p; cases q; auto simp add: subset_eq)

lemma lift_post_state_eq_Sup: "lift_post_state R p = Sup {q. sim_post_state R q p}"
  unfolding post_state_eq_iff holds_Sup_post_state
  using holds_lift_post_state sim_post_state_iff
  by blast

lemma le_lift_post_state_iff: "q ≤ lift_post_state R p ⟷ sim_post_state R q p"
  by (simp add: post_state_le_iff sim_post_state_iff)

lemma lift_post_state_eq[simp]: "lift_post_state (=) p = p"
  by (simp add: post_state_eq_iff)

lemma lift_post_state_comp:
  "lift_post_state R (lift_post_state Q p) = lift_post_state (R OO Q) p"
  by (cases p) auto

lemma sim_post_state_lift:
  "sim_post_state Q q (lift_post_state R p) ⟷ sim_post_state (Q OO R) q p"
  unfolding le_lift_post_state_iff[symmetric] lift_post_state_comp ..

lemma lift_post_state_Sup: "lift_post_state R (Sup F) = (SUP f∈F. lift_post_state R f)"
  by (simp add: post_state_eq_iff)

lemma lift_post_state_mono: "p ≤ q ⟹ lift_post_state R p ≤ lift_post_state R q"
  by (simp add: post_state_le_iff)

subsection ‹‹map_post_state››

lemma map_post_state_top: "map_post_state f ⊤ = ⊤"
  by (auto simp: top_post_state_def)

lemma mono_map_post_state: "s1 ≤ s2 ⟹ map_post_state f s1 ≤ map_post_state f s2"
  by (simp add: post_state_le_iff)

(* Mapping is lifting along the graph of f. *)
lemma map_post_state_eq_lift_post_state: "map_post_state f p = lift_post_state (λa b. a = f b) p"
  by (cases p) auto

lemma map_post_state_Sup: "map_post_state f (Sup X) = (SUP x∈X. map_post_state f x)"
  by (simp add: post_state_eq_iff)

lemma map_post_state_comp: "map_post_state f (map_post_state g p) = map_post_state (f ∘ g) p"
  by (simp add: post_state_eq_iff)

lemma map_post_state_id[simp]: "map_post_state id p = p"
  by (simp add: post_state_eq_iff)

lemma map_post_state_pure[simp]: "map_post_state f (pure_post_state x) = pure_post_state (f x)"
  by (simp add: pure_post_state_def)

lemma sim_post_state_map_post_state1:
  "sim_post_state R (map_post_state f p) q ⟷ sim_post_state (λx y. R (f x) y) p q"
  by (cases p; cases q; simp)

lemma sim_post_state_map_post_state2:
  "sim_post_state R p (map_post_state f q) ⟷ sim_post_state (λx y. R x (f y)) p q"
  unfolding map_post_state_eq_lift_post_state sim_post_state_lift by (simp add: OO_def[abs_def])

lemma map_post_state_eq_top[simp]: "map_post_state f p = ⊤ ⟷ p = ⊤"
  by (cases p; simp add: top_post_state_def)

lemma map_post_state_eq_bot[simp]: "map_post_state f p = ⊥ ⟷ p = ⊥"
  by (cases p; simp add: bot_post_state_def)

subsection ‹‹vmap_post_state››

lemma vmap_post_state_top: "vmap_post_state f ⊤ = ⊤"
  by (auto simp: top_post_state_def)

lemma vmap_post_state_Sup:
  "vmap_post_state f (Sup X) = (SUP x∈X. vmap_post_state f x)"
  by (simp add: post_state_eq_iff)

(* vmap_post_state f preserves suprema, so it has an upper adjoint (gc_of_Sup_cont). *)
lemma vmap_post_state_le_iff: "(vmap_post_state f p ≤ q) = (p ≤ ⨆ {p. vmap_post_state f p ≤ q})"
  using vmap_post_state_Sup by (rule gc_of_Sup_cont)

lemma vmap_post_state_eq_lift_post_state: "vmap_post_state f p = lift_post_state (λa b. f a = b) p"
  by (cases p) auto

lemma vmap_post_state_comp: "vmap_post_state f (vmap_post_state g p) = vmap_post_state (g ∘ f) p"
  apply (simp add: post_state_eq_iff)
  apply (intro allI arg_cong[where f="λP. holds_post_state P p"])
  apply auto
  done

lemma vmap_post_state_id: "vmap_post_state id p = p"
  by (simp add: post_state_eq_iff)

lemma sim_post_state_vmap_post_state2:
  "sim_post_state R p (vmap_post_state f q) ⟷
    sim_post_state (λx y. ∃y'. f y' = y ∧ R x y') p q"
  by (cases p; cases q; auto) blast+

(* map_post_state f is the lower adjoint of vmap_post_state f. *)
lemma vmap_post_state_le_iff_le_map_post_state:
  "map_post_state f p ≤ q ⟷ p ≤ vmap_post_state f q"
  by (simp flip: sim_post_state_eq_iff_le
           add: sim_post_state_map_post_state1 sim_post_state_vmap_post_state2)

subsection ‹‹pure_post_state››

lemma pure_post_state_Failure[simp]: "pure_post_state v ≠ Failure"
  by (simp add: pure_post_state_def)

lemma pure_post_state_top[simp]: "pure_post_state v ≠ ⊤"
  by (simp add: pure_post_state_def top_post_state_def)

lemma pure_post_state_bot[simp]: "pure_post_state v ≠ ⊥"
  by (simp add: pure_post_state_def bot_post_state_def)

lemma pure_post_state_inj[simp]: "pure_post_state v = pure_post_state w ⟷ v = w"
  by (simp add: pure_post_state_def)

lemma sim_pure_post_state_iff[simp]:
  "sim_post_state R (pure_post_state a) (pure_post_state b) ⟷ R a b"
  by (simp add: pure_post_state_def)

lemma rel_pure_post_state_iff[simp]:
  "rel_post_state R (pure_post_state a) (pure_post_state b) ⟷ R a b"
  by (simp add: rel_post_state_eq_sim_post_state)

lemma pure_post_state_le[simp]: "pure_post_state v ≤ pure_post_state w ⟷ v = w"
  by (simp add: pure_post_state_def)

lemma Success_eq_pure_post_state[simp]: "Success X = pure_post_state x ⟷ X = {x}"
  by (auto simp add: pure_post_state_def)

lemma pure_post_state_eq_Success[simp]: "pure_post_state x = Success X ⟷ X = {x}"
  by (auto simp add: pure_post_state_def)

lemma pure_post_state_le_Success[simp]: "pure_post_state x ≤ Success X ⟷ x ∈ X"
  by (auto simp add: pure_post_state_def)

lemma Success_le_pure_post_state[simp]: "Success X ≤ pure_post_state x ⟷ X ⊆ {x}"
  by (auto simp add: pure_post_state_def)

subsection ‹‹bind_post_state››

lemma bind_post_state_top: "bind_post_state ⊤ g = ⊤"
  by (auto simp: top_post_state_def)

lemma bind_post_state_Sup1[simp]:
  "bind_post_state (Sup F) g = (SUP f∈F. bind_post_state f g)"
  by (simp add: post_state_eq_iff)

(* Fails only when G = {} and f = Failure: the left side is Failure, the right side is ⊥. *)
lemma bind_post_state_Sup2[simp]:
  "G ≠ {} ∨ f ≠ Failure ⟹ bind_post_state f (Sup G) = (SUP g∈G. bind_post_state f g)"
  by (simp add: post_state_eq_iff holds_post_state_Ball)

lemma bind_post_state_sup1[simp]:
  "bind_post_state (sup f1 f2) g = sup (bind_post_state f1 g) (bind_post_state f2 g)"
  using bind_post_state_Sup1[of "{f1, f2}" g] by simp

lemma bind_post_state_sup2[simp]:
  "bind_post_state f (sup g1 g2) = sup (bind_post_state f g1) (bind_post_state f g2)"
  using bind_post_state_Sup2[of "{g1, g2}" f] by simp

(* Same statement as bind_post_state_top, but declared [simp]. *)
lemma bind_post_state_top1[simp]: "bind_post_state ⊤ f = ⊤"
  by (simp add: post_state_eq_iff)

lemma bind_post_state_bot1[simp]: "bind_post_state ⊥ f = ⊥"
  by (simp add: post_state_eq_iff)

lemma bind_post_state_eq_top:
  "bind_post_state f g = ⊤ ⟷ ¬ holds_post_state (λx. g x ≠ ⊤) f"
  by (cases f) (force simp add: Sup_post_state_def simp flip: top_post_state_def)+

lemma bind_post_state_eq_bot:
  "bind_post_state f g = ⊥ ⟷ holds_post_state (λx. g x = ⊥) f"
  by (cases f) (auto simp flip: top_post_state_def)

lemma lift_post_state_bind_post_state:
  "lift_post_state R (bind_post_state x g) = bind_post_state x (λv. lift_post_state R (g v))"
  by (simp add: post_state_eq_iff)

lemma vmap_post_state_bind_post_state:
  "vmap_post_state f (bind_post_state p g) = bind_post_state p (λv. vmap_post_state f (g v))"
  by (simp add: post_state_eq_iff)

lemma map_post_state_bind_post_state:
  "map_post_state f (bind_post_state x g) = bind_post_state x (λv. map_post_state f (g v))"
  by (simp add: post_state_eq_iff)

(* Monad laws: left unit, right unit (as map), associativity. *)
lemma bind_post_state_pure_post_state1[simp]:
  "bind_post_state (pure_post_state v) f = f v"
  by (simp add: post_state_eq_iff)

lemma bind_post_state_pure_post_state2:
  "bind_post_state p (λx. pure_post_state (f x)) = map_post_state f p"
  by (simp add: post_state_eq_iff)

lemma bind_post_state_map_post_state:
  "bind_post_state (map_post_state f p) g = bind_post_state p (λx. g (f x))"
  by (simp add: post_state_eq_iff)

lemma bind_post_state_assoc[simp]:
  "bind_post_state (bind_post_state f g) h =
    bind_post_state f (λv. bind_post_state (g v) h)"
  by (simp add: post_state_eq_iff)

lemma sim_bind_post_state':
  "sim_post_state (λx y. sim_post_state R (g x) (k y)) f h ⟹
    sim_post_state R (bind_post_state f g) (bind_post_state h k)"
  by (cases f; cases h; simp add: sim_post_state_Sup)

lemma sim_bind_post_state_left_iff:
  "sim_post_state R (bind_post_state f g) h ⟷
    h = Failure ∨ holds_post_state (λx. sim_post_state R (g x) h) f"
  by (cases f; cases h) (simp_all add: sim_post_state_Sup1)

lemma sim_bind_post_state_left:
  "holds_post_state (λx. sim_post_state R (g x) h) f ⟹
    sim_post_state R (bind_post_state f g) h"
  by (simp add: sim_bind_post_state_left_iff)

(* g ≠ ⊥ guarantees at least one outcome to pick a simulating continuation from. *)
lemma sim_bind_post_state_right:
  "g ≠ ⊥ ⟹ holds_partial_post_state (λx. sim_post_state R f (h x)) g ⟹
    sim_post_state R f (bind_post_state g h)"
  by (cases g) (auto intro: sim_post_state_Sup2)

lemma sim_bind_post_state:
  "sim_post_state Q f h ⟹ (⋀x y. Q x y ⟹ sim_post_state R (g x) (k y)) ⟹
    sim_post_state R (bind_post_state f g) (bind_post_state h k)"
  by (rule sim_bind_post_state'[OF sim_post_state_weaken, of Q])

lemma rel_bind_post_state':
  "rel_post_state (λa b. rel_post_state R (g1 a) (g2 b)) f1 f2 ⟹
    rel_post_state R (bind_post_state f1 g1) (bind_post_state f2 g2)"
  by (auto simp: rel_post_state_eq_sim_post_state
           intro: sim_bind_post_state' sim_post_state_weaken)

lemma rel_bind_post_state:
  "rel_post_state Q f1 f2 ⟹ (⋀a b. Q a b ⟹ rel_post_state R (g1 a) (g2 b)) ⟹
    rel_post_state R (bind_post_state f1 g1) (bind_post_state f2 g2)"
  by (rule rel_bind_post_state'[OF rel_post_state_weaken, of Q])

lemma mono_bind_post_state: "f1 ≤ f2 ⟹ g1 ≤ g2 ⟹ bind_post_state f1 g1 ≤ bind_post_state f2 g2"
  by (simp flip: sim_post_state_eq_iff_le add: le_fun_def sim_bind_post_state)


lemma Sup_eq_Failure[simp]: "Sup X = Failure ⟷ Failure ∈ X"
  by (simp add: Sup_post_state_def)

lemma Failure_inf_iff: "(Failure = x ⊓ y) ⟷ (x = Failure ∧ y = Failure)"
  by (metis Failure_le_iff le_inf_iff)

lemma Success_vimage_singleton_cancel: "Success -` {Success X} = {X}"
  by auto

lemma Success_vimage_image_cancel: "Success -` (λx. Success (f x)) ` X = f ` X"
  by auto

lemma Success_image_comp: "(λx. Success (g x)) ` X = Success ` (g ` X)"
  by auto

section ‹‹exception_or_result› type›

instantiation option :: (type) default
begin

definition "default_option = None"

instance ..
end

lemma Some_ne_default[simp]:
  "Some x ≠ default" "default ≠ Some x"
  "default = None ⟷ True" "None = default ⟷ True"
  by (auto simp: default_option_def)

(* Exceptions are the non-default values of 'a (Inl); results are arbitrary values (Inr). *)
typedef (overloaded) ('a::default, 'b) exception_or_result =
  "Inl ` (UNIV - {default}) ∪ Inr ` UNIV :: ('a + 'b) set"
  by auto

setup_lifting type_definition_exception_or_result

context assumes "SORT_CONSTRAINT('e::default)"
begin

(* Raising the default exception value yields Result undefined. *)
lift_definition Exception :: "'e ⇒ ('e, 'v) exception_or_result" is
  "λe. if e = default then Inr undefined else Inl e"
  by auto

lift_definition Result :: "'v ⇒ ('e, 'v) exception_or_result" is
  "λv. Inr v"
  by auto

end

lift_definition
  case_exception_or_result :: "('e::default ⇒ 'a) ⇒ ('v ⇒ 'a) ⇒ ('e, 'v) exception_or_result ⇒ 'a"
  is case_sum .

declare [[case_translation case_exception_or_result Exception Result]]

lemma Result_eq_Result[simp]: "Result a = Result b ⟷ a = b"
  by transfer simp

lemma Exception_eq_Exception[simp]: "Exception a = Exception b ⟷ a = b"
  by transfer simp

lemma Result_eq_Exception[simp]: "Result a = Exception e ⟷ (e = default ∧ a = undefined)"
  by transfer simp

lemma Exception_eq_Result[simp]: "Exception e = Result a ⟷ (e = default ∧ a = undefined)"
  by (metis Result_eq_Exception)

lemma exception_or_result_cases[case_names Exception Result, cases type: exception_or_result]:
  "(⋀e. e ≠ default ⟹ x = Exception e ⟹ P) ⟹ (⋀v. x = Result v ⟹ P) ⟹ P"
  by (transfer fixing: P) auto

lemma case_exception_or_result_Result[simp]:
  "(case Result v of Exception e ⇒ f e | Result w ⇒ g w) = g v"
  by transfer simp

lemma case_exception_or_result_Exception[simp]:
  "(case Exception e of Exception e ⇒ f e | Result w ⇒ g w) = (if e = default then g undefined else f e)"
  by transfer simp

text ‹
Caution: for split rules don't use the syntax
\<^term>‹(case r of Exception e ⇒ f e | Result w ⇒ g w)›, as this introduces the
non-eta-contracted terms \<^term>‹λe. f e› and \<^term>‹λw. g w›, which don't work with the
splitter.
›
(* Split rule for the case combinator in the conclusion: P holds of the case
   expression iff it holds for every non-default exception branch and every
   result branch. *)
lemma exception_or_result_split:
  "P (case_exception_or_result f g r)  
    (e. e  default  r = Exception e  P (f e)) 
    (v. r = Result v  P (g v))"
  by (cases r) simp_all

(* Split rule for the case combinator in an assumption. *)
lemma exception_or_result_split_asm:
  "P (case_exception_or_result f g r) 
  ¬ ((e. r = Exception e  e  default  ¬ P (f e)) 
    (v. r = Result v  ¬ P (g v)))"
  by (cases r) simp_all

lemmas exception_or_result_splits = exception_or_result_split exception_or_result_split_asm

(* Eta-expansion of a value into a case distinction over its constructors;
   useful for rewriting a specific occurrence, e.g. with subst. *)
lemma split_exception_or_result: 
  "r = (case r of Exception e  Exception e | Result v  Result v)"
  by (cases r) auto

(* Every value is either a non-default Exception or a Result. *)
lemma exception_or_result_nchotomy:
  "¬ ((e. e  default  x  Exception e)  (v. x  Result v))" by (cases x) auto

(* Split rules for values with exception type unit.  Unit has a single element,
   which must be the default, so the only possible case is Result.  These
   rules are declared [split] globally below. *)
lemma val_split:
  fixes r::"(unit, 'a) exception_or_result"
  shows
  "P (case_exception_or_result f g r)  
    (v. r = Result v  P (g v))"
  by (cases r) simp_all

lemma val_split_asm:
  fixes r::"(unit, 'a) exception_or_result"
  shows "P (case_exception_or_result f g r) 
    ¬ (v. r = Result v  ¬ P (g v))"
  by (cases r) simp_all

lemmas val_splits[split] = val_split val_split_asm

(* HOL.equal instance (the class used by code generation for equality).
   - Two exceptions are compared with HOL.equal on the exception type.
   - Two results are compared with HOL.equal on the result type.
   - An exception and a result are never equal.
   The instance proof reduces HOL.equal to (=) via equal_eq and the split
   rules. *)
instantiation exception_or_result :: ("{equal, default}", equal) equal begin

definition "equal_exception_or_result a b =
  (case a of
    Exception e  (case b of Exception e'  HOL.equal e e' | Result s  False)
  | Result r     (case b of Exception e  False   | Result s  HOL.equal r s)
)"
instance proof qed (simp add: equal_exception_or_result_def equal_eq
    split: exception_or_result_splits)

end

(* Discriminators and selectors.  Exception default coincides with
   Result undefined, so is_Exception (Exception default) is False and
   is_Result (Exception default) is True.  the_Exception yields default on
   results; the_Result yields undefined on exceptions. *)
definition is_Exception:: "('e::default, 'a) exception_or_result  bool" where
  "is_Exception x  (case x of Exception e  e  default | Result _  False)"

definition is_Result:: "('e::default, 'a) exception_or_result  bool" where
  "is_Result x  (case x of Exception e  e = default | Result _  True)"

definition the_Exception::  "('e::default, 'a) exception_or_result  'e" where
  "the_Exception x  (case x of Exception e  e | Result _  default)"

definition the_Result::  "('e::default, 'a) exception_or_result  'a" where
  "the_Result x  (case x of Exception e  undefined | Result v  v)"

lemma is_Exception_simps[simp]:
  "is_Exception (Exception e) = (e  default)"
  "is_Exception (Result v) = False"
  by (auto simp add: is_Exception_def)

lemma is_Result_simps[simp]:
  "is_Result (Exception e) = (e = default)"
  "is_Result (Result v) = True"
  by (auto simp add: is_Result_def)

(* Holds for every e, including e = default. *)
lemma the_Exception_simp[simp]:
  "the_Exception (Exception e) = e"
  by (auto simp add: the_Exception_def)

(* Not declared [simp]. *)
lemma the_Exception_Result:
  "the_Exception (Result v) = default"
  by (simp add: the_Exception_def)

(* Placeholder for the exception branch of the Res binder below: it ignores its
   unit argument and returns undefined. *)
definition undefined_unit::"unit  'b" where "undefined_unit x  undefined"

(* Binder notation "λRes x. b" for a function on exception_or_result values
   with exception type unit (the type of undefined_unit forces this).  It
   translates to case_exception_or_result with undefined_unit as the
   exception branch and λx. b as the result branch. *)
syntax "_Res" :: "pttrn  pttrn" ((‹open_block notation=‹prefix Res››Res _))
syntax_consts "_Res"  case_exception_or_result
translations "λRes x. b"  "CONST case_exception_or_result (CONST undefined_unit) (λx. b)"

(* Examples of the notation, including tuple patterns and further arguments. *)
term "λRes x. f x"
term "λRes (x, y, z). f x y z"
term "λRes (x, y, z) s. P x y s z"


(* Store the lifting setup for exception_or_result under a bundle name and
   deactivate it.  From here on, proofs work with the abstract lemmas above.
   The setup can be re-enabled locally with
   "including exception_or_result.lifting". *)
lifting_update exception_or_result.lifting
lifting_forget exception_or_result.lifting

(* Relator: two values are related when
   - both are non-default exceptions related by E, or
   - both are results related by R. *)
inductive rel_exception_or_result:: "('e::default  'f::default  bool)  ('a  'b  bool) 
  ('e, 'a) exception_or_result  ('f, 'b) exception_or_result  bool"
  where
Exception:
  "E e f  e  default  f  default 
  rel_exception_or_result E R (Exception e) (Exception f)" |
Result:
  "R a b  rel_exception_or_result E R (Result a) (Result b)"

(* Split universal and existential quantifiers (plain and bounded by a set s)
   over exception_or_result into a quantifier over non-default exceptions and
   one over results. *)
lemma All_exception_or_result_cases:
  "(x. P (x::(_, _) exception_or_result)) 
    (err. err  default  P (Exception err))   (v. P (Result v))"
  by (metis exception_or_result_cases)

lemma Ball_exception_or_result_cases:
  "(xs. P (x::(_, _) exception_or_result)) 
    (err. err  default  Exception err  s  P (Exception err)) 
    (v. Result v  s  P (Result v))"
  by (metis exception_or_result_cases)

lemma Bex_exception_or_result_cases:
  "(xs. P (x::(_, _) exception_or_result)) 
    (err. err  default  Exception err  s  P (Exception err)) 
    (v. Result v  s  P (Result v))"
  by (metis exception_or_result_cases)

lemma Ex_exception_or_result_cases:
  "(x. P (x::(_, _) exception_or_result)) 
    (err. err  default  P (Exception err)) 
    (v. P (Result v))"
  by (metis exception_or_result_cases)

(* Push a function application into both branches of a case expression. *)
lemma case_distrib_exception_or_result:
  "f (case x of Exception e  E e | Result v  R v) = (case x of Exception e  f (E e) |  Result v  f (R v))"
  by (auto split: exception_or_result_split)

(* Simp rules for the relator on constructor applications.  The rules for
   exceptions need the non-default side conditions, because Exception default
   is actually a Result. *)
lemma rel_exception_or_result_Results[simp]:
  "rel_exception_or_result E R (Result a) (Result b)  R a b"
  by (simp add: rel_exception_or_result.simps)

lemma rel_exception_or_result_Exception[simp]:
  "e  default  f  default 
    rel_exception_or_result E R (Exception e) (Exception f)  E e f"
  by (simp add: rel_exception_or_result.simps)

lemma rel_exception_or_result_Result_Exception[simp]:
  "e  default  ¬ rel_exception_or_result E R (Exception e) (Result b)"
  "f  default  ¬ rel_exception_or_result E R (Result a) (Exception f)"
  by (auto simp add: rel_exception_or_result.simps)

(* is_Exception x holds exactly when x is built by Exception from a
   non-default exception value. *)
lemma is_Exception_iff: "is_Exception x  (e. e  default  x = Exception e)"
  by (cases x) auto

(* The converse of the relator is the relator of the converse relations. *)
lemma rel_exception_or_result_converse:
  "(rel_exception_or_result E R)¯¯ = rel_exception_or_result E¯¯ R¯¯"
  apply (rule ext)+
  apply (auto simp add: rel_exception_or_result.simps)
  done

(* Same statement as rel_exception_or_result_Results, stated as an equation.
   NOTE(review): the name is misspelled ("eception").  It is kept unchanged
   because other theories may refer to it. *)
lemma rel_eception_or_result_Result[simp]:
  "rel_exception_or_result E R (Result x) (Result y) = R x y"
  by (auto simp add: rel_exception_or_result.simps)

(* The relator on equalities is equality. *)
lemma rel_exception_or_result_eq_conv: "rel_exception_or_result (=) (=) = (=)"
  apply (rule ext)+
  by (auto simp add: rel_exception_or_result.simps)
    (meson exception_or_result_cases)

(* Identical statement to rel_exception_or_result_eq_conv.  It is kept as an
   alias under its existing name instead of being proved a second time. *)
lemmas rel_exception_or_result_sum_eq = rel_exception_or_result_eq_conv

(* Functorial map: apply E to the exception and F to the result.  If E maps a
   non-default exception to default, the outcome Exception (E e) is
   Result undefined (see Exception_eq_Result); the composition lemma below
   therefore assumes that E preserves non-defaultness. *)
definition
  map_exception_or_result::
    "('e::default  'f::default)  ('a  'b)  ('e, 'a) exception_or_result 
      ('f, 'b) exception_or_result"
where
  "map_exception_or_result E F x =
    (case x of Exception e  Exception (E e) | Result v  Result (F v))"

lemma map_exception_or_result_Exception[simp]:
  "e  default  map_exception_or_result E F (Exception e) = Exception (E e)"
  by (simp add: map_exception_or_result_def)

lemma map_exception_or_result_Result[simp]:
  "map_exception_or_result E F (Result v) = Result (F v)"
  by (simp add: map_exception_or_result_def)

(* Mapping with identities leaves every value unchanged. *)
lemma map_exception_or_result_id: "map_exception_or_result id id x = x"
  by (cases x) simp_all

(* Composition law for map_exception_or_result.  The assumption that E2 maps
   non-default exceptions to non-default ones keeps the intermediate
   exceptions from collapsing into results. *)
lemma map_exception_or_result_comp:
  assumes E2: "x. x  default  E2 x  default"
  shows "map_exception_or_result E1 F1 (map_exception_or_result E2 F2 x) =
    map_exception_or_result (E1  E2) (F1  F2) x"
  by (cases x; auto dest: E2)

(* Monotonicity of bind_post_state with a continuation that cases on the
   outcome.  It suffices to compare the exception branches (X vs X') and the
   result branches (V vs V') on the outcomes that x admits, as expressed by
   the two holds_partial_post_state premises. *)
lemma le_bind_post_state_exception_or_result_cases[case_names Exception Result]:
  assumes
    "holds_partial_post_state (λ(x, t). e. e  default  x = Exception e  X e t  X' e t) x"
  assumes "holds_partial_post_state (λ(x, t). v. x = Result v  V v t  V' v t) x"
  shows "bind_post_state x (λ(r, t). case r of Exception e  X e t | Result v  V v t) 
    bind_post_state x (λ(r, t). case r of Exception e  X' e t| Result v  V' v t)"
  using assms
  by (cases x; force intro!: SUP_mono'' split: exception_or_result_splits prod.splits)

section spec_monad› type›

(* A computation maps an initial state to a post_state over pairs of an outcome
   (exception or result) and a final state.  The carrier is UNIV, so every
   such function is a computation.  run and Spec convert between the two
   views. *)
typedef (overloaded) ('e::default, 'a, 's) spec_monad =
  "UNIV :: ('s  (('e::default, 'a) exception_or_result × 's) post_state) set"
  morphisms run Spec
  by (fact UNIV_witness)

(* Push run through a case on pairs. *)
lemma run_case_prod_distrib[simp]:
  "run (case x of (r, s)  f r s) t = (case x of (r, s)  run (f r s) t)"
  by (rule prod.case_distrib)

(* Fuse an image under a split function with an image that builds pairs. *)
lemma image_case_prod_distrib[simp]:
  "(λ(r, t). f r t) ` (λv. (g v, h v)) ` R = (λv. (f (g v) (h v))) ` R "
  by auto

(* run is a left inverse of Spec. *)
lemma run_Spec[simp]: "run (Spec p) s = p s"
  by (simp add: Spec_inverse)

setup_lifting type_definition_spec_monad

(* Extensionality: two computations are equal when they produce the same post
   state from every initial state. *)
lemma spec_monad_ext: "(s. run f s = run g s)  f = g"
  by transfer auto

(* Equational form of spec_monad_ext. *)
lemma spec_monad_ext_iff: "f = g  (s. run f s = run g s)"
  using spec_monad_ext by auto

(* Complete lattice structure.  Every operation is lifted from the
   representation type, i.e. the function space from states to post_state.
   The instance proof transfers each class axiom to the representation, where
   the corresponding generic lattice lemma proves it. *)
instantiation spec_monad :: (default, type, type) complete_lattice
begin

lift_definition
  less_eq_spec_monad :: "('e::default, 'r, 's) spec_monad  ('e, 'r, 's) spec_monad  bool"
  is "(≤)" .

lift_definition
  less_spec_monad :: "('e::default, 'r, 's) spec_monad  ('e, 'r, 's) spec_monad  bool"
  is "(<)" .

lift_definition bot_spec_monad :: "('e::default, 'r, 's) spec_monad" is "bot" .

lift_definition top_spec_monad :: "('e::default, 'r, 's) spec_monad" is "top" .

lift_definition Inf_spec_monad :: "('e::default, 'r, 's) spec_monad set  ('e, 'r, 's) spec_monad"
  is "Inf" .

lift_definition Sup_spec_monad :: "('e::default, 'r, 's) spec_monad set  ('e, 'r, 's) spec_monad"
  is "Sup" .

lift_definition
  sup_spec_monad ::
    "('e::default, 'r, 's) spec_monad  ('e, 'r, 's) spec_monad  ('e, 'r, 's) spec_monad"
  is "sup" .

lift_definition
  inf_spec_monad ::
    "('e::default, 'r, 's) spec_monad  ('e, 'r, 's) spec_monad  ('e, 'r, 's) spec_monad"
  is "inf" .

(* One subgoal per complete_lattice axiom, in the order produced by
   standard. *)
instance
  apply (standard; transfer)
  subgoal by (rule less_le_not_le)
  subgoal by (rule order_refl)
  subgoal by (rule order_trans)
  subgoal by (rule antisym)
  subgoal by (rule inf_le1)
  subgoal by (rule inf_le2)
  subgoal by (rule inf_greatest)
  subgoal by (rule sup_ge1)
  subgoal by (rule sup_ge2)
  subgoal by (rule sup_least)
  subgoal by (rule Inf_lower)
  subgoal by (rule Inf_greatest)
  subgoal by (rule Sup_upper)
  subgoal by (rule Sup_least)
  subgoal by (rule Inf_empty)
  subgoal by (rule Sup_empty)
  done
end

(* Fails from every initial state. *)
lift_definition fail :: "('e::default, 'a, 's) spec_monad"
  is "λs. Failure" .

(* Sequencing on the whole outcome: h receives every outcome of f, whether
   exception or result, together with the corresponding final state. *)
lift_definition
  bind_exception_or_result :: 
    "('e::default, 'a, 's) spec_monad 
      (('e, 'a) exception_or_result  ('f::default, 'b, 's) spec_monad)  ('f, 'b, 's) spec_monad"
is
  "λf h s. bind_post_state (f s) (λ(v, t). h v t)" .

(* Sequencing with a handler: results of f continue with g, exceptions of f
   continue with h. *)
lift_definition bind_handle ::
    "('e::default, 'a, 's) spec_monad 
      ('a  ('f, 'b, 's) spec_monad)  ('e  ('f, 'b, 's) spec_monad) 
      ('f::default, 'b, 's) spec_monad"
  is "λf g h s. bind_post_state (f s) (λ(r, t). case r of Exception e  h e t | Result v  g v t)" .

(* Produce the single outcome r and leave the state unchanged. *)
lift_definition yield :: "('e, 'a) exception_or_result  ('e::default, 'a, 's) spec_monad"
  is "λr s. pure_post_state (r, s)" .

abbreviation "return r  yield (Result r)"

abbreviation "skip  return ()"

(* Note: throwing the default exception value is the same as
   return undefined, because Exception default = Result undefined. *)
abbreviation "throw_exception_or_result e  yield (Exception e)"

(* Return the current state as the result, leaving it unchanged. *)
lift_definition get_state :: "('e::default, 's, 's) spec_monad"
  is "λs. pure_post_state (Result s, s)" .

(* Replace the current state by t and return (). *)
lift_definition set_state :: "'s  ('e::default, unit, 's) spec_monad"
  is "λt s. pure_post_state (Result (), t)" .

(* Apply f to the outcome component of every post-state element of g, keeping
   the final states. *)
lift_definition map_value ::
    "(('e::default, 'a) exception_or_result  ('f::default, 'b) exception_or_result) 
      ('e, 'a, 's) spec_monad  ('f, 'b, 's) spec_monad"
  is "λf g s. map_post_state (apfst f) (g s)" .

(* Counterpart of map_value in the opposite direction: it turns a computation
   over ('f, 'b) outcomes into one over ('e, 'a) outcomes via vmap_post_state
   on apfst f.  NOTE(review): presumably takes the preimage under f; confirm
   against the definition of vmap_post_state. *)
lift_definition vmap_value ::
    "(('e::default, 'a) exception_or_result  ('f::default, 'b) exception_or_result) 
      ('f, 'b, 's) spec_monad  ('e, 'a, 's) spec_monad"
  is "λf g s. vmap_post_state (apfst f) (g s)" .

(* Monadic bind: results continue with g, and exceptions are re-thrown
   unchanged (same exception type 'e).  It is registered for the do-notation
   of Monad_Syntax. *)
definition bind ::
    "('e::default, 'a, 's) spec_monad  ('a  ('e, 'b, 's) spec_monad)  ('e, 'b, 's) spec_monad"
where
  "bind f g = bind_handle f g throw_exception_or_result"

adhoc_overloading
  Monad_Syntax.bind bind

(* Run p on an inner state space 't from an outer state s.  p is executed from
   every t with R s t, and the post states are combined with SUP.  The result
   is transported with lift_post_state along rel_prod (=) R: outcomes are kept
   equal and final states are related by R. *)
lift_definition lift_state ::
    "('s  't  bool)  ('e::default, 'a, 't) spec_monad  ('e, 'a, 's) spec_monad"
  is "λR p s. lift_post_state (rel_prod (=) R) (SUP t{t. R s t}. p t)" .

(* Run a computation over 's from a state of type 't, where st : 's => 't.
   From state t it considers every s with st s = t. *)
definition exec_concrete ::
  "('s  't)  ('e::default, 'a, 's) spec_monad  ('e, 'a, 't) spec_monad"
where
  "exec_concrete st = lift_state (λt s. t = st s)"

(* Run a computation over 't from a state s of type 's, starting in st s. *)
definition exec_abstract ::
  "('s  't)  ('e::default, 'a, 't) spec_monad  ('e, 'a, 's) spec_monad"
where
  "exec_abstract st = lift_state (λs t. t = st s)"

(* Nondeterministically yield any outcome in R (supremum of the individual
   yields).  The state is unchanged. *)
definition select_value :: "('e, 'a) exception_or_result set  ('e::default, 'a, 's) spec_monad" where
  "select_value R = (SUP rR. yield r)"

(* Nondeterministically return any element of S.  For S = {} this is the
   supremum of the empty set, i.e. bot. *)
definition select :: "'a set  ('e::default, 'a, 's) spec_monad" where
  "select S = (SUP aS. return a)"

(* Return an arbitrary value. *)
definition unknown :: "('e::default, 'a, 's) spec_monad" where
  "unknown  select UNIV"

(* Return f applied to the current state, leaving the state unchanged. *)
definition gets :: "('s  'a)  ('e::default, 'a, 's) spec_monad" where
  "gets f = bind get_state (λs. return (f s))"

(* Fail on None, return the content of Some. *)
definition assert_opt :: "'a option  ('e::default, 'a, 's) spec_monad" where
  "assert_opt v = (case v of None  fail | Some v  return v)"

(* Evaluate f on the current state; fail if the result is None. *)
definition gets_the :: "('s  'a option)  ('e::default, 'a, 's) spec_monad" where
  "gets_the f = bind (gets f) assert_opt"

(* Replace the current state s by f s. *)
definition modify :: "('s  's)  ('e::default, unit, 's) spec_monad" where
  "modify f = bind get_state (λs. set_state (f s))"

(* skip if P holds, otherwise bot. *)
definition "assume" :: "bool  ('e::default, unit, 's) spec_monad" where
  "assume P = (if P then return () else bot)"

(* skip if P holds, otherwise top. *)
definition assert :: "bool  ('e::default, unit, 's) spec_monad" where
  "assert P = (if P then return () else top)"

(* assume applied to a predicate on the current state. *)
definition assuming :: "('s  bool)  ('e::default, unit, 's) spec_monad" where
  "assuming P = do {s   get_state; assume (P s)}"

(* assert applied to a predicate on the current state. *)
definition guard:: "('s  bool)   ('e::default, unit, 's) spec_monad" where
  "guard P = do {s  get_state; assert (P s)}"

(* Choose any (v, t) in f s, move to state t and return v.  If f s is empty,
   select yields bot. *)
definition assume_result_and_state :: "('s  ('a × 's) set)  ('e::default, 'a, 's) spec_monad" where
  "assume_result_and_state f = do {s  get_state; (v, t)  select (f s); set_state t; return v}"

(* Wraps the set f s directly as a Success post state. *)
(* TODO: delete *)
lift_definition assume_outcome ::
    "('s  (('e, 'a) exception_or_result × 's) set)  ('e::default, 'a, 's) spec_monad"
  is "λf s. Success (f s)" .

(* Like assume_result_and_state, but first asserts that f s is non-empty, so
   an empty f s gives top instead of bot. *)
definition assert_result_and_state ::
    "('s  ('a × 's) set)  ('e::default, 'a, 's) spec_monad"
where
  "assert_result_and_state f =
    do {s  get_state; assert (f s  {}); (v, t)  select (f s); set_state t; return v}"

(* Move to any successor state s' with (s, s') in r. *)
abbreviation state_select :: "('s × 's) set  ('e::default, unit, 's) spec_monad" where
  "state_select r  assert_result_and_state (λs. {((), s'). (s, s')  r})"

(* Evaluate C on the current state and continue with T if it holds, otherwise
   with F. *)
definition condition ::
    "('s  bool)  ('e::default, 'a, 's) spec_monad  ('e, 'a, 's) spec_monad 
      ('e, 'a, 's) spec_monad"
where
  "condition C T F =
    bind get_state (λs. if C s then T else F)"

(* Output-only notation for pretty printing. *)
notation (output)
  condition  ((‹notation=‹prefix condition››condition (_)//  (_)//  (_)) [1000,1000,1000] 999)

(* Run f if c holds, otherwise skip. *)
definition "when" ::"bool  ('e::default, unit, 's) spec_monad  ('e, unit, 's) spec_monad"
  where "when c f  condition (λ_. c) f skip"

abbreviation "unless" ::"bool  ('e::default, unit, 's) spec_monad  ('e, unit, 's) spec_monad"
  where "unless c  when (¬c)"

(* Run f, then run c on every outcome of f (result or exception), then
   re-yield that outcome.  c is sequenced with bind, so an exception raised by
   c is propagated instead of f's outcome. *)
definition on_exit' ::
    "('e::default, 'a, 's) spec_monad  ('e, unit, 's) spec_monad  ('e, 'a, 's) spec_monad"
where
  "on_exit' f c 
     bind_exception_or_result f (λr. do { c; yield r })"

(* on_exit' with a cleanup given as a state relation. *)
definition on_exit ::
    "('e::default, 'a, 's) spec_monad  ('s × 's) set  ('e, 'a, 's) spec_monad"
where
  "on_exit f cleanup  on_exit' f (state_select cleanup)"

(* As on_exit, but the cleanup is preceded by guard grd. *)
abbreviation guard_on_exit ::
    "('e::default, 'a, 's) spec_monad  ('s  bool)  ('s × 's) set  ('e, 'a, 's) spec_monad"
where
  "guard_on_exit f grd cleanup  on_exit' f (bind (guard grd) (λ_. state_select cleanup))"

(* As on_exit, but the cleanup is preceded by assuming grd. *)
abbreviation assume_on_exit ::
    "('e::default, 'a, 's) spec_monad  ('s  bool)  ('s × 's) set  ('e, 'a, 's) spec_monad"
where
  "assume_on_exit f grd cleanup 
    on_exit' f (bind (assuming grd) (λ_. state_select cleanup))"

(* Run f from the separate state t, then continue with g r t' from the outer
   state s for every outcome r and final inner state t' of f. *)
lift_definition run_bind ::
    "('e::default, 'a, 't) spec_monad  't 
     (('e, 'a) exception_or_result  't  ('f::default, 'b, 's) spec_monad) 
     ('f::default, 'b, 's) spec_monad"
  is "λf t g s. bind_post_state (f t) (λ(r, t). g r t s)" .
    ― ‹\<^const>‹run_bind› might be a more canonical building block than
   \<^const>‹lift_state›. See the subsection on \<^const>‹run_bind› below.›

(* Postconditions: predicates on an outcome and a final state. *)
type_synonym ('e, 'a, 's) predicate = "('e, 'a) exception_or_result  's  bool"

(* f run from s satisfies Q in the sense of holds_post_state on the post state
   f s. *)
lift_definition runs_to ::
    "('e::default, 'a, 's) spec_monad  's  ('e, 'a, 's) predicate  bool"
    ((‹open_block notation=‹mixfix runs_to››_/  _  _ ) [61, 1000, 0] 30) ― ‹syntax _do_block› has 62›
  is "λf s Q. holds_post_state (λ(r, t). Q r t) (f s)" .

(* Partial variant based on holds_partial_post_state.  It is weaker; e.g. it
   holds for the trivial postcondition on every computation
   (runs_to_partial_True). *)
lift_definition runs_to_partial ::
    "('e::default, 'a, 's) spec_monad  's  ('e, 'a, 's) predicate  bool"
    ((‹open_block notation=‹mixfix runs_to_partial››_/  _ ?⦃ _ ) [61, 1000, 0] 30)
  is "λf s Q. holds_partial_post_state (λ(r, t). Q r t) (f s)" .

(* Simulation of f from s by g from t with respect to R on
   (outcome, state) pairs, via sim_post_state.  refines_iff' gives a
   characterisation in terms of runs_to. *)
lift_definition refines ::
  "('e::default, 'a, 's) spec_monad  ('f::default, 'b, 't) spec_monad  's  't 
    ((('e, 'a) exception_or_result × 's)  (('f, 'b) exception_or_result × 't)  bool)  bool"
  is "λf g s t R. sim_post_state R (f s) (g t)" .

(* The post states f s and g t are related by rel_post_state R. *)
lift_definition rel_spec ::
  "('e::default, 'a, 's) spec_monad  ('f::default, 'b, 't) spec_monad  's  't 
     ((('e, 'a) exception_or_result × 's)  (('f, 'b) exception_or_result × 't)  bool)
    bool"
is
  "λf g s t R. rel_post_state R (f s) (g t)" .

(* Relator on computations: from all R-related initial states, the post states
   are related, with outcomes related by Q and final states by R. *)
definition rel_spec_monad ::
    "('s  't  bool)  (('e, 'a) exception_or_result  ('f, 'b) exception_or_result  bool) 
      ('e::default, 'a, 's) spec_monad  ('f::default, 'b, 't) spec_monad  bool"
where
  "rel_spec_monad R Q f g =
    (s t. R s t  rel_post_state (rel_prod Q R) (run f s) (run g t))"

(* From no initial state is the post state bot. *)
lift_definition always_progress :: "('e::default, 'a, 's) spec_monad  bool" is
  "λp. s. p s  bot" .

(* Theorem collections filled by later declarations. *)
named_theorems run_spec_monad "Simplification rules to run a Spec monad"
named_theorems runs_to_iff "Equivalence theorems for runs to predicate"
named_theorems always_progress_intros "intro rules for always_progress predicate"

(* A postcondition that holds for every outcome and state is satisfied in the
   partial sense by any computation. *)
lemma runs_to_partialI: "(r x. P r x)  f  s ?⦃ P "
  by (cases "run f s") (auto simp: runs_to_partial.rep_eq)

lemma runs_to_partial_True: "f  s ?⦃ λr s. True "
  by (simp add: runs_to_partialI)

(* Partial postconditions can be combined by conjunction. *)
lemma runs_to_partial_conj: 
  "f  s ?⦃ P   f  s ?⦃ Q   f  s ?⦃ λr s. P r s  Q r s "
  by (cases "run f s") (auto simp: runs_to_partial.rep_eq)

(* Characterisation of refines via runs_to: for every P, if g from t runs to
   the R-image of P, then f from s runs to P. *)
lemma refines_iff':
  "refines f g s t R  (P. g  t  λr s. p t. R (p, t) (r, s)  P p t   f  s  P )"
  apply transfer
  apply (simp add: sim_post_state_iff le_fun_def split_beta' all_comm[where 'a='a])
  apply (rule all_cong_map[where f'="λP (x, y). P x y" and f="λP x y. P (x, y)"])
  apply auto
  done

(* The relation of refines can be enlarged. *)
lemma refines_weaken:
  "refines f g s t R  (r s q t. R (r, s) (q, t)  Q (r, s) (q, t))  refines f g s t Q"
  by transfer (auto intro: sim_post_state_weaken)

(* refines_weaken with the premises in the opposite order. *)
lemma refines_mono:
  "(r s q t. R (r, s) (q, t)  Q (r, s) (q, t))  refines f g s t R  refines f g s t Q"
  by (rule refines_weaken)

(* Reflexivity for relations that contain the identity on
   (outcome, state) pairs. *)
lemma refines_refl: "(r t. R (r, t) (r, t))  refines f f s s R"
  by transfer (auto intro: sim_post_state_refl)

(* Transitivity: the composed steps must be covered by S. *)
lemma refines_trans:
  "refines f g s t R  refines g h t u Q 
    (r s p t q u. R (r, s) (p, t)  Q (p, t) (q, u)  S (r, s) (q, u)) 
    refines f h s u S"
  by transfer (auto intro: sim_post_state_trans)

(* Transitivity with the relational composition R OO Q. *)
lemma refines_trans': "refines f g s t1 R  refines g h t1 t2 Q  refines f h s t2 (R OO Q)"
  by (rule refines_trans; assumption?) auto

lemma refines_strengthen:
  "refines f g s t R  f  s ?⦃ F   g  t ?⦃ G  
    (x s y t. R (x, s) (y, t)  F x s  G y t  Q (x, s) (y, t