(* Theory Spec_Monad *)

(*
 * Copyright (c) 2024 Apple Inc. All rights reserved.
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

text_raw ‹\part{AutoCorres}›

chapter ‹Spec-Monad›

theory Spec_Monad
  imports
  "Basic_Runs_To_VCG"
  "HOL-Library.Complete_Partial_Order2"
  "HOL-Library.Monad_Syntax"
  "AutoCorres_Utils"
begin

section ‹‹rel_map› and ‹rel_project››

(* rel_map f r x holds exactly when x = f r, i.e. the graph of f viewed as a
   relation.  rel_project is an abbreviation for rel_map, and rel_project_def
   is the very same equation as rel_map_def. *)
definition rel_map :: "('a  'b)  'a  'b  bool" where
  "rel_map f r x = (x = f r)"

lemma rel_map_direct[simp]: "rel_map f a (f a)"
  by (simp add: rel_map_def)

abbreviation "rel_project  rel_map"

lemmas rel_project_def = rel_map_def

(* Projecting with the identity (in either spelling) is plain equality. *)
lemma rel_project_id: "rel_project id = (=)" 
   "rel_project (λv. v) = (=)"
  by (auto simp add: rel_project_def fun_eq_iff)

(* Projecting to unit relates every pair of values. *)
lemma rel_project_unit: "rel_project (λ_. ()) x y = True"
  by (simp add: rel_project_def)

lemma rel_projectI: "y = prj x  rel_project prj x y"
  by (simp add: rel_project_def)

(* Unfolding rule: rel_project prj x y is the equation y = prj x. *)
lemma rel_project_conv: "rel_project prj x y = (y = prj x)"
  by (simp add: rel_project_def)

section ‹Misc Theorems›

(* Miscellaneous helper lemmas used throughout this theory. *)
declare case_unit_Unity [simp] ― ‹without this rule the simplifier seems to loop in unexpected ways›

(* A function on unit can be written as a lambda that matches the unit value. *)
lemma abs_const_unit: "(λ(v::unit). f) = (λ(). f)"
  by auto

(* If f is below g pointwise on A, then the supremum of f over A is below the
   supremum of g over A.  This restates SUP_mono. *)
lemma SUP_mono'': "(x. xA  f x  g x)  (xA. f x)  (xA. g x :: _::complete_lattice)"
  by (rule SUP_mono) auto

(* Well-foundedness of a "counting up" relation on naturals bounded by n,
   proved via the measure λa. Suc n - a. *)
lemma wf_nat_bound: "wf {(a, b). b < a  b  (n::nat)}"
  apply (rule wf_subset)
  apply (rule wf_measure[where f="λa. Suc n - a"])
  apply auto
  done

(* Admissibility, with respect to Inf and (≤), of the predicate that compares
   x with a fixed element y. *)
lemma (in complete_lattice) admissible_le:
  "ccpo.admissible Inf (≤) (λx. (y  x))"
  by (simp add: ccpo.admissibleI local.Inf_greatest)

(* Monotonicity helpers:
   - constant functions are monotone;
   - a function-valued F is monotone if F x a is monotone in x for every a;
   - applying to a fixed argument a is monotone. *)
lemma mono_const: "mono (λx. c)"
  by (simp add: mono_def)

lemma mono_lam: "(a. mono (λx. F x a))  mono F"
  by (simp add: mono_def le_fun_def)

lemma mono_app: "mono (λx. x a)"
  by (simp add: mono_def le_fun_def)

(* Re-index a universal quantifier along f.  Assumptions: f' is a right
   inverse of f (f (f' y) = y), and P x corresponds to Q (f x). *)
lemma all_cong_map:
  assumes f_f': "y. f (f' y) = y" and P_Q: "x. P x  Q (f x)"
  shows "(x. P x)  (y. Q y)"
  using assms by metis

(* Basic facts about the set relators rel_set and sim_set: reflexivity,
   converse, and weakening of the element relation. *)
lemma rel_set_refl: "(x. x  A  R x x)  rel_set R A A"
  by (auto simp: rel_set_def)

lemma rel_set_converse_iff: "rel_set R X Y  rel_set R¯¯ Y X"
  by (auto simp add: rel_set_def)

lemma rel_set_weaken:
  "(x y. x  A  y  B  P x y  Q x y)  rel_set P A B  rel_set Q A B"
  by (force simp: rel_set_def)

lemma sim_set_refl: "(x. x  X  R x x)  sim_set R X X"
  by (auto simp: sim_set_def)

section ‹Galois Connections›

(* Galois-connection helpers.  A function that preserves all suprema (resp.
   infima) is monotone.  It also has an upper (resp. lower) adjoint, given by
   the supremum (resp. infimum) of the set of elements it maps below (resp.
   above) a given point. *)

(* Sup-preserving functions are monotone: f x is below sup (f x) (f y), which
   equals f (Sup {x, y}), which equals f y when x is below y. *)
lemma mono_of_Sup_cont:
  fixes f :: "'a::complete_lattice  'b::complete_lattice"
  assumes cont: "X. f (Sup X) = (SUP xX. f x)"
  assumes xy: "x  y"
  shows "f x  f y"
proof -
  have "f x  sup (f x) (f y)" by (rule sup_ge1)
  also have " = (SUP x{x, y}. f x)" by simp
  also have " = f (Sup {x, y})" by (rule cont[symmetric])
  finally show ?thesis using xy by (simp add: sup_absorb2)
qed

(* Upper adjoint of a Sup-preserving f: f x is below y iff x is below
   Sup {x. f x below y}. *)
lemma gc_of_Sup_cont:
  fixes f :: "'a::complete_lattice  'b::complete_lattice"
  assumes cont: "X. f (Sup X) = (SUP xX. f x)"
  shows "f x  y  x  Sup {x. f x  y}"
proof safe
  assume "f x  y" then show "x  Sup {x. f x  y}"
    by (intro Sup_upper) simp
next
  assume "x  Sup {x. f x  y}"
  then have "f x  f (Sup {x. f x  y})"
    by (rule mono_of_Sup_cont[OF cont])
  also have " = (SUP x{x. f x  y}. f x)" by (rule cont)
  also have "  y" by (rule Sup_least) auto
  finally show "f x  y" .
qed

(* Dual of mono_of_Sup_cont: Inf-preserving functions are monotone. *)
lemma mono_of_Inf_cont:
  fixes f :: "'a::complete_lattice  'b::complete_lattice"
  assumes cont: "X. f (Inf X) = (INF xX. f x)"
  assumes xy: "x  y"
  shows "f x  f y"
proof -
  have "f x = f (Inf {x, y})" using xy by (simp add: inf_absorb1)
  also have " = (INF x{x, y}. f x)" by (rule cont)
  also have " = inf (f x) (f y)" by simp
  also have "  f y" by (rule inf_le2)
  finally show ?thesis .
qed

(* Dual of gc_of_Sup_cont: lower adjoint of an Inf-preserving f, given by
   Inf {y. x below f y}. *)
lemma gc_of_Inf_cont:
  fixes f :: "'a::complete_lattice  'b::complete_lattice"
  assumes cont: "X. f (Inf X) = (INF xX. f x)"
  shows "Inf {y. x  f y}  y  x  f y"
proof safe
  assume "x  f y" then show "Inf {y. x  f y}  y"
    by (intro Inf_lower) simp
next
  assume *: "Inf {y. x  f y}  y"
  have "x  (INF y{y. x  f y}. f y)" by (rule Inf_greatest) auto
  also have " = f (Inf {y. x  f y})" by (rule cont[symmetric])
  also have "  f y"
    using * by (rule mono_of_Inf_cont[OF cont])
  finally show "x  f y" .
qed

(* Fixpoint fusion.  Both lemmas assume a Galois connection between g and f
   (assumption f_g, used in both directions via iffD1/iffD2) and monotone
   functionals a and b.
   - gfp_fusion: if f commutes with the functionals (f (a x) = b (f x)),
     then f (gfp a) = gfp b.
   - lfp_fusion: if g commutes with the functionals (g (a x) = b (g x)),
     then g (lfp a) = lfp b. *)
lemma gfp_fusion:
  assumes f_g: "x y. g x  y  x  f y"
  assumes a: "mono a"
  assumes b: "mono b"
  assumes *: "x. f (a x) = b (f x)"
  shows "f (gfp a) = gfp b"
  apply (intro antisym gfp_upperbound)
  subgoal
    apply (subst *[symmetric])
    apply (subst gfp_fixpoint[OF a])
    ..
  apply (rule f_g[THEN iffD1])
  apply (rule gfp_upperbound)
  apply (rule f_g[THEN iffD2])
  apply (subst *)
  apply (subst gfp_fixpoint[OF b, symmetric])
  apply (rule monoD[OF b])
  apply (rule f_g[THEN iffD1])
  apply (rule order_refl)
  done

lemma lfp_fusion:
  assumes f_g: "x y. g x  y  x  f y"
  assumes a: "mono a"
  assumes b: "mono b"
  assumes *: "x. g (a x) = b (g x)"
  shows "g (lfp a) = lfp b"
  apply (intro antisym lfp_lowerbound)
  subgoal
    apply (rule f_g[THEN iffD2])
    apply (rule lfp_lowerbound)
    apply (rule f_g[THEN iffD1])
    apply (subst *)
    apply (subst (2) lfp_fixpoint[OF b, symmetric])
    apply (rule monoD[OF b])
    apply (rule f_g[THEN iffD2])
    apply (rule order_refl)
    done
  subgoal
    apply (subst *[symmetric])
    apply (subst lfp_fixpoint[OF a])
    ..
  done

section ‹‹post_state› type›

(* Outcome of a computation: either Failure, or Success with the set of
   possible results (a set, to allow non-determinism). *)
datatype 'r post_state = Failure | Success "'r set"

text ‹\<^const>‹Failure› is supposed to model things like undefined behaviour in C. We usually have
to show the absence of \<^const>‹Failure› for all possible executions of the program. Moreover,
it is used to model the 'result' of a non-terminating computation.
›

(* Expands a post_state into a case expression that rebuilds it.  Rewriting
   with it, e.g. via subst, exposes the constructors to a case split. *)
lemma split_post_state:
  "x = (case x of Success X  Success X | Failure  Failure)"
  for x::"'a post_state"
  by (cases x) auto

(* Partial order on post_state.  Failure is above every element
   (Failure_le).  Two Success values are ordered by inclusion of their
   result sets (Success_le_Success).  The strict order is derived from the
   non-strict one. *)
instantiation post_state :: (type) order
begin

inductive less_eq_post_state :: "'a post_state  'a post_state  bool" where
  Failure_le[simp, intro]: "less_eq_post_state p Failure"
| Success_le_Success[intro]: "r  q  less_eq_post_state (Success r) (Success q)"

definition less_post_state :: "'a post_state  'a post_state  bool" where
  "less_post_state p q  p  q  ¬ q  p"

instance
proof
  fix p q r :: "'a post_state"
  show "p  p" by (cases p) auto
  show "p  q  q  r  p  r" by (blast elim: less_eq_post_state.cases)
  show "p  q  q  p  p = q" by (blast elim: less_eq_post_state.cases)
qed (fact less_post_state_def)
end

(* Simp characterisations of the order on constructors. *)
lemma Success_le_Success_iff[simp]: "Success r  Success q  r  q"
  by (auto elim: less_eq_post_state.cases)

(* Only Failure is above Failure. *)
lemma Failure_le_iff[simp]: "Failure  q  q = Failure"
  by (auto elim: less_eq_post_state.cases)

(* Complete-lattice structure on post_state.
   - top is Failure; bottom is Success {}.
   - inf: Failure is neutral (inf Failure p = p); two Success values combine
     their result sets.
   - sup: Failure is absorbing; two Success values combine their result sets.
   - Inf of a set is Failure only if the set contains no Success value.
   - Sup of a set is Failure as soon as the set contains Failure. *)
instantiation post_state :: (type) complete_lattice
begin

definition top_post_state :: "'a post_state" where
  "top_post_state = Failure"

definition bot_post_state :: "'a post_state" where
  "bot_post_state = Success {}"

definition inf_post_state :: "'a post_state  'a post_state  'a post_state" where
  "inf_post_state =
    (λFailure  id | Success res1 
      (λFailure  Success res1 | Success res2  Success (res1  res2)))"

definition sup_post_state :: "'a post_state  'a post_state  'a post_state" where
  "sup_post_state =
    (λFailure  (λ_. Failure) | Success res1 
      (λFailure  Failure | Success res2  Success (res1  res2)))"

definition Inf_post_state :: "'a post_state set  'a post_state" where
  "Inf_post_state s = (if Success -` s = {} then Failure else Success ( (Success -` s)))"

definition Sup_post_state :: "'a post_state set  'a post_state" where
  "Sup_post_state s = (if Failure  s then Failure else Success ( (Success -` s)))"

instance
proof
  fix x y z :: "'a post_state" and A :: "'a post_state set"
  show "inf x y  y" by (simp add: inf_post_state_def split: post_state.split)
  show "inf x y  x" by (simp add: inf_post_state_def split: post_state.split)
  show "x  y  x  z  x  inf y z"
    by (auto simp add: inf_post_state_def elim!: less_eq_post_state.cases)
  show "x  sup x y" by (simp add: sup_post_state_def split: post_state.split)
  show "y  sup x y" by (simp add: sup_post_state_def split: post_state.split)
  show "x  y  z  y  sup x z  y"
    by (auto simp add: sup_post_state_def elim!: less_eq_post_state.cases)
  show "x  A  Inf A  x" by (cases x) (auto simp add: Inf_post_state_def)
  show "(x. x  A  z  x)  z  Inf A" by (cases z) (force simp add: Inf_post_state_def)+
  show "x  A  x  Sup A" by (cases x) (auto simp add: Sup_post_state_def)
  show "(x. x  A  x  z)  Sup A  z" by (cases z) (force simp add: Sup_post_state_def)+
qed (simp_all add: top_post_state_def Inf_post_state_def bot_post_state_def Sup_post_state_def)
end

(* holds_post_state P p: p is a Success and every result satisfies P.
   It never holds for Failure, so it also demands the absence of Failure. *)
primrec holds_post_state :: "('a  bool)  'a post_state  bool" where
  "holds_post_state P Failure  False"
| "holds_post_state P (Success X)  (xX. P x)"

(* holds_partial_post_state P p: same as holds_post_state on Success, but
   trivially true for Failure. *)
primrec holds_partial_post_state :: "('a  bool)  'a post_state  bool" where
  "holds_partial_post_state P Failure  True"
| "holds_partial_post_state P (Success X)  (xX. P x)"

(* Simulation along R:
   - anything is simulated by Failure on the right;
   - Success A is simulated by Success B when sim_set R A B.
   There is no rule with Failure on the left and Success on the right. *)
inductive
  sim_post_state :: "('a  'b  bool)  'a post_state  'b post_state  bool"
  for R
where
  [simp]: "p. sim_post_state R p Failure"
| "A B. sim_set R A B  sim_post_state R (Success A) (Success B)"

(* Relator for post_state: Failure is related only to Failure, and
   Success A and Success B are related when rel_set R A B. *)
inductive
  rel_post_state :: "('a  'b  bool)  'a post_state  'b post_state  bool"
  for R
where
  [simp]: "rel_post_state R Failure Failure"
| "A B. rel_set R A B  rel_post_state R (Success A) (Success B)"

(* Transports a 'b post_state to an 'a post_state along R.  Failure is
   preserved; a Success set is replaced by the elements x that R relates to
   elements of Y. *)
primrec lift_post_state :: "('a  'b  bool)  'b post_state  'a post_state" where
  "lift_post_state R Failure = Failure"
| "lift_post_state R (Success Y) = Success {x. yY. R x y}"

(* Transport in the opposite direction ('a to 'b).  This forms a Galois
   connection with lift_post_state; see lift_post_state_unlift_post_state. *)
primrec unlift_post_state :: "('a  'b  bool)  'a post_state  'b post_state" where
  "unlift_post_state R Failure = Failure"
| "unlift_post_state R (Success X) = Success {y. x. R x y  x  X}"

(* Image of the result set under f; Failure is preserved. *)
primrec map_post_state :: "('a  'b)  'a post_state  'b post_state" where
  "map_post_state f Failure  = Failure"
| "map_post_state f (Success r) = Success (f ` r)"

(* Preimage of the result set under f; Failure is preserved. *)
primrec vmap_post_state :: "('a  'b)  'b post_state  'a post_state" where
  "vmap_post_state f Failure  = Failure"
| "vmap_post_state f (Success r) = Success (f -` r)"

(* Deterministic single result. *)
definition pure_post_state :: "'a  'a post_state" where
  "pure_post_state v = Success {v}"

(* Sequential composition.  Failure propagates.  Otherwise the continuation
   is applied to every result and the outcomes are joined with Sup, so any
   Failure among them yields Failure (see Sup_post_state_def). *)
primrec bind_post_state :: "'a post_state  ('a  'b post_state)  'b post_state" where
  "bind_post_state Failure p = Failure"
| "bind_post_state (Success res) p =  (p ` res)"

subsection ‹Order Properties›

(* Basic facts:
   - top and bottom are distinct;
   - Success values are never top, and are bottom only for the empty set;
   - a supremum of Success values is a Success of the combined sets. *)
lemma top_ne_bot[simp]: "(:: _ post_state)  "
  and bot_ne_top[simp]: "(:: _ post_state)  "
  by (auto simp: top_post_state_def bot_post_state_def)

lemma Success_ne_top[simp]:
  "Success X  " "  Success X"
  by (auto simp: top_post_state_def)

lemma Success_eq_bot_iff[simp]:
  "Success X =   X = {}" " = Success X  X = {}"
  by (auto simp: bot_post_state_def)

lemma Sup_Success: "(xX. Success (f x)) = Success (xX. f x)"
  by (auto simp: Sup_post_state_def)

(* Variant of Sup_Success for functions over pairs. *)
lemma Sup_Success_pair: "((x, y)X. Success (f x y)) = Success ((x, y)X. f x y)"
  by (simp add: split_beta' Sup_Success)

subsection ‹‹holds_post_state››

(* Basic reasoning principles for holds_post_state: unfolding, weakening,
   combining postconditions, and interaction with quantifiers and
   conjunction. *)
lemma holds_post_state_iff:
  "holds_post_state P p  (X. p = Success X  (xX. P x))"
  by (cases p) auto

lemma holds_post_state_weaken:
  "holds_post_state P p  (x. P x  Q x)  holds_post_state Q p"
  by (cases p; auto)

lemma holds_post_state_combine:
  "holds_post_state P p  holds_post_state Q p 
    (x. P x  Q x  R x)  holds_post_state R p"
  by (cases p; auto)

(* Side conditions Y non-empty and p not Failure are required.
   holds_post_state is False on Failure, while a bounded quantifier over an
   empty Y is vacuously true. *)
lemma holds_post_state_Ball:
  "Y  {}  p  Failure 
    holds_post_state (λx. yY. P x y) p  (yY. holds_post_state (λx. P x y) p)"
  by (cases p; auto)

lemma holds_post_state_All:
  "holds_post_state (λx. y. P x y) p  (y. holds_post_state (λx. P x y) p)"
  by (cases p; auto)

lemma holds_post_state_BexI:
  "y  Y  holds_post_state (λx. P x y) p  holds_post_state (λx. yY. P x y) p"
  by (cases p; auto)

lemma holds_post_state_conj:
  "holds_post_state (λx. P x  Q x) p  holds_post_state P p  holds_post_state Q p"
  by (cases p; auto)

(* Characterises simulation purely in terms of which postconditions hold. *)
lemma sim_post_state_iff:
  "sim_post_state R p q 
    (P. holds_post_state (λy. x. R x y  P x) q  holds_post_state P p)"
  apply (cases p; cases q; simp add: sim_set_def sim_post_state.simps)
  apply safe
  apply force
  apply force
  subgoal premises prems for A B
    using prems(3)[THEN spec, of "λa. bB. R a b"] prems(4)
    by auto
  done

(* Extensionality principles: the order and equality on post_state are
   determined by the postconditions that hold.  Many later lemmas are proved
   by simp with these. *)
lemma post_state_le_iff:
  "p  q  (P. holds_post_state P q  holds_post_state P p)"
  by (cases p; cases q; force simp add: subset_eq Ball_def)

lemma post_state_eq_iff:
  "p = q  (P. holds_post_state P p  holds_post_state P q)"
  by (simp add: order_eq_iff post_state_le_iff)

(* Top (Failure) satisfies no postcondition; bottom (Success {}) satisfies
   every postcondition. *)
lemma holds_top_post_state[simp]: "¬ holds_post_state P "
  by (simp add: top_post_state_def)

lemma holds_bot_post_state[simp]: "holds_post_state P "
  by (simp add: bot_post_state_def)



(* holds_post_state pushed through Sup, gfp and the post_state combinators.
   Most of these are simp rules. *)
lemma holds_Sup_post_state[simp]: "holds_post_state P (Sup F)  (fF. holds_post_state P f)"
  by (subst (2) split_post_state)
    (auto simp: Sup_post_state_def split: post_state.splits)

(* Unfolds gfp through gfp_def: a postcondition holds for gfp f iff it holds
   for every post-fixpoint p of f. *)
lemma holds_post_state_gfp:
  "holds_post_state P (gfp f)  (p. p  f p  holds_post_state P p)"
  by (simp add: gfp_def)

(* Same as holds_post_state_gfp for a function-valued gfp applied to x. *)
lemma holds_post_state_gfp_apply:
  "holds_post_state P (gfp f x)  (p. p  f p  holds_post_state P (p x))"
  by (simp add: gfp_def)

lemma holds_lift_post_state[simp]:
  "holds_post_state P (lift_post_state R x)  holds_post_state (λy. x. R x y  P x) x"
  by (cases x) auto

lemma holds_map_post_state[simp]:
  "holds_post_state P (map_post_state f x)  holds_post_state (λx. P (f x)) x"
  by (cases x) auto

lemma holds_vmap_post_state[simp]:
  "holds_post_state P (vmap_post_state f x)  holds_post_state (λx. y. f y = x  P y) x"
  by (cases x) auto

lemma holds_pure_post_state[simp]: "holds_post_state P (pure_post_state x)  P x"
  by (simp add: pure_post_state_def)

lemma holds_bind_post_state[simp]:
  "holds_post_state P (bind_post_state f g)  holds_post_state (λx. holds_post_state P (g x)) f"
  by (cases f) auto

(* The unsatisfiable postcondition holds exactly for bottom. *)
lemma holds_post_state_False: "holds_post_state (λx. False) f  f = "
  by (cases f) (auto simp: bot_post_state_def)

subsection ‹‹holds_partial_post_state››

(* Lemmas for holds_partial_post_state.  It is implied by holds_post_state.
   Because it is True on Failure, the quantifier rules below need no side
   conditions, unlike holds_post_state_Ball. *)
lemma holds_partial_post_state_of_holds:
  "holds_post_state P p  holds_partial_post_state P p"
  by (cases p; simp)

lemma holds_partial_post_state_iff:
  "holds_partial_post_state P p  (X. p = Success X  (xX. P x))"
  by (cases p) auto

lemma holds_partial_post_state_True[simp]: "holds_partial_post_state (λx. True) p"
  by (cases p; simp)

lemma holds_partial_post_state_weaken:
  "holds_partial_post_state P p  (x. P x  Q x)  holds_partial_post_state Q p"
  by (cases p; auto)

lemma holds_partial_post_state_Ball:
  "holds_partial_post_state (λx. yY. P x y) p  (yY. holds_partial_post_state (λx. P x y) p)"
  by (cases p; auto)

lemma holds_partial_post_state_All:
  "holds_partial_post_state (λx. y. P x y) p  (y. holds_partial_post_state (λx. P x y) p)"
  by (cases p; auto)

lemma holds_partial_post_state_conj:
  "holds_partial_post_state (λx. P x  Q x) p 
    holds_partial_post_state P p  holds_partial_post_state Q p"
  by (cases p; auto)

lemma holds_partial_top_post_state[simp]: "holds_partial_post_state P "
  by (simp add: top_post_state_def)

lemma holds_partial_bot_post_state[simp]: "holds_partial_post_state P "
  by (simp add: bot_post_state_def)

lemma holds_partial_pure_post_state[simp]: "holds_partial_post_state P (pure_post_state x)  P x"
  by (simp add: pure_post_state_def)

(* Introduction rules only, for Sup and bind. *)
lemma holds_partial_Sup_post_stateI:
  "(x. x  X  holds_partial_post_state P x)  holds_partial_post_state P (Sup X)"
  by (force simp: Sup_post_state_def)

lemma holds_partial_bind_post_stateI:
  "holds_partial_post_state (λx. holds_partial_post_state P (g x)) f 
    holds_partial_post_state P (bind_post_state f g)"
  by (cases f) (auto intro: holds_partial_Sup_post_stateI)

lemma holds_partial_map_post_state[simp]:
  "holds_partial_post_state P (map_post_state f x)  holds_partial_post_state (λx. P (f x)) x"
  by (cases x) auto

lemma holds_partial_vmap_post_state[simp]:
  "holds_partial_post_state P (vmap_post_state f x) 
    holds_partial_post_state (λx. y. f y = x  P y) x"
  by (cases x) auto

(* Partial-correctness rule for bind_post_state: if every result of f leads,
   through g, to an outcome satisfying P partially, then so does the bind.
   This statement and its proof were an exact duplicate of
   holds_partial_bind_post_stateI.  The name is kept for existing users, but
   it is now an alias of that theorem so the two cannot drift apart. *)
lemmas holds_partial_bind_post_state = holds_partial_bind_post_stateI

subsection ‹‹sim_post_state››

(* Simulation with equality is exactly the order, so sim_post_state R
   generalises refinement to a relation between result types.  Further below:
   behaviour on constructors, top and bottom, and monotonicity in each
   argument. *)
lemma sim_post_state_eq_iff_le: "sim_post_state (=) p q  p  q"
  by (simp add: post_state_le_iff sim_post_state_iff)

lemma sim_post_state_Success_Success_iff[simp]:
  "sim_post_state R (Success r) (Success q)  (ar. bq. R a b)"
  by (auto simp add: sim_set_def elim: sim_post_state.cases intro: sim_post_state.intros)

lemma sim_post_state_Success2:
  "sim_post_state R f (Success q)  holds_post_state (λa. bq. R a b) f"
  by (cases f; simp add: sim_post_state.simps sim_set_def)

(* Failure on the left can only be simulated by Failure. *)
lemma sim_post_state_Failure1[simp]: "sim_post_state R Failure q  q = Failure"
  by (auto elim: sim_post_state.cases intro: sim_post_state.intros)

lemma sim_post_state_top2[simp, intro]: "sim_post_state R p "
  by (simp add: top_post_state_def)

lemma sim_post_state_top1[simp]: "sim_post_state R  q  q = "
  using sim_post_state_Failure1 by (metis top_post_state_def)

lemma sim_post_state_bot2[simp, intro]: "sim_post_state R p   p = "
  by (cases p; simp add: bot_post_state_def)

lemma sim_post_state_bot1[simp, intro]: "sim_post_state R  q"
  by (cases q; simp add: bot_post_state_def)

(* Simulation is preserved by decreasing the left side or increasing the
   right side. *)
lemma sim_post_state_le1: "sim_post_state R f' g  f  f'  sim_post_state R f g"
  by (simp add: post_state_le_iff sim_post_state_iff)

lemma sim_post_state_le2: "sim_post_state R f g  g  g'  sim_post_state R f g'"
  by (simp add: post_state_le_iff sim_post_state_iff)

(* Simulation and Sup, plus weakening, transitivity and reflexivity of
   sim_post_state. *)
lemma sim_post_state_Sup1:
  "sim_post_state R (Sup A) f  (aA. sim_post_state R a f)"
  by (auto simp add: sim_post_state_iff)

lemma sim_post_state_Sup2:
  "a  A  sim_post_state R f a  sim_post_state R f (Sup A)"
  by (auto simp add: sim_post_state_iff)

lemma sim_post_state_Sup:
  "aA. bB. sim_post_state R a b  sim_post_state R (Sup A) (Sup B)"
  by (auto simp: sim_post_state_Sup1 intro: sim_post_state_Sup2)

lemma sim_post_state_weaken:
  "sim_post_state R f g  (x y. R x y  Q x y)  sim_post_state Q f g"
  by (cases f; cases g; force)

(* Chaining two simulations whose relations compose into S. *)
lemma sim_post_state_trans:
  "sim_post_state R f g  sim_post_state Q g h  (x y z. R x y  Q y z  S x z) 
    sim_post_state S f h"
  by (fastforce elim!: sim_post_state.cases intro: sim_post_state.intros simp: sim_set_def)

(* Reflexivity only needs R x x for the results that f can actually produce. *)
lemma sim_post_state_refl': "holds_partial_post_state (λx. R x x) f  sim_post_state R f f"
  by (cases f; auto simp: rel_set_def)

lemma sim_post_state_refl: "(x. R x x)  sim_post_state R f f"
  by (simp add: sim_post_state_refl')

lemma sim_post_state_pure_post_state2:
  "sim_post_state F f (pure_post_state x)  holds_post_state (λy. F y x) f"
  by (cases f; simp add: pure_post_state_def)

subsection ‹‹rel_post_state››

(* Properties of the relator rel_post_state.  It equals simulation in both
   directions (rel_post_state_eq_sim_post_state).  It is registered with the
   Lifting/Transfer infrastructure via [relator_eq] and [relator_mono]. *)
lemma rel_post_state_top[simp, intro!]: "rel_post_state R  "
  by (auto simp: top_post_state_def)

lemma rel_post_state_top_iff[simp]:
  "rel_post_state R  p  p = "
  "rel_post_state R p   p = "
  by (auto simp: top_post_state_def rel_post_state.simps)

lemma rel_post_state_Success_iff[simp]:
  "rel_post_state R (Success A) (Success B)  rel_set R A B"
  by (auto elim: rel_post_state.cases intro: rel_post_state.intros)

lemma rel_post_state_bot[simp, intro!]: "rel_post_state R  "
  by (auto simp: bot_post_state_def Lifting_Set.empty_transfer)

lemma rel_post_state_eq_sim_post_state:
  "rel_post_state R p q  sim_post_state R p q  sim_post_state R¯¯ q p"
  by (auto simp: rel_set_def sim_set_def elim!: sim_post_state.cases rel_post_state.cases)

lemma rel_post_state_weaken:
  "rel_post_state R f g  (x y. R x y  Q x y)  rel_post_state Q f g"
  by (auto intro: sim_post_state_weaken simp: rel_post_state_eq_sim_post_state)

lemma rel_post_state_eq[relator_eq]: "rel_post_state (=) = (=)"
  by (simp add: rel_post_state_eq_sim_post_state fun_eq_iff sim_post_state_eq_iff_le order_eq_iff)

lemma rel_post_state_mono[relator_mono]:
  "A  B  rel_post_state A  rel_post_state B"
  using rel_set_mono[of A B]
  by (auto simp add: le_fun_def intro!: rel_post_state.intros elim!: rel_post_state.cases)

(* As for sim_post_state_refl', R x x is only needed on the results of f. *)
lemma rel_post_state_refl': "holds_partial_post_state (λx. R x x) f  rel_post_state R f f"
  by (cases f; auto simp: rel_set_def)

lemma rel_post_state_refl: "(x. R x x)  rel_post_state R f f"
  by (simp add: rel_post_state_refl')

subsection ‹‹lift_post_state››

(* Properties of lift_post_state.
   - It is part of a Galois connection with unlift_post_state.
   - lift_post_state R p is the Sup of all q that R-simulate into p.
   - It is the identity for (=), composes via OO, preserves Sup and is
     monotone. *)
lemma lift_post_state_top: "lift_post_state R  = "
  by (auto simp: top_post_state_def)

lemma lift_post_state_unlift_post_state: "lift_post_state R p  q  p  unlift_post_state R q"
  by (cases p; cases q; auto simp add: subset_eq)

lemma lift_post_state_eq_Sup: "lift_post_state R p = Sup {q. sim_post_state R q p}"
  unfolding post_state_eq_iff holds_Sup_post_state
  using holds_lift_post_state sim_post_state_iff
  by blast

(* Being below lift_post_state R p is the same as R-simulating into p. *)
lemma le_lift_post_state_iff: "q  lift_post_state R p  sim_post_state R q p"
  by (simp add: post_state_le_iff sim_post_state_iff)

lemma lift_post_state_eq[simp]: "lift_post_state (=) p = p"
  by (simp add: post_state_eq_iff)

lemma lift_post_state_comp:
  "lift_post_state R (lift_post_state Q p) = lift_post_state (R OO Q) p"
  by (cases p) auto

lemma sim_post_state_lift:
  "sim_post_state Q q (lift_post_state R p)  sim_post_state (Q OO R) q p"
  unfolding le_lift_post_state_iff[symmetric] lift_post_state_comp ..

lemma lift_post_state_Sup: "lift_post_state R (Sup F) = (SUP fF. lift_post_state R f)"
  by (simp add: post_state_eq_iff)

lemma lift_post_state_mono: "p  q  lift_post_state R p  lift_post_state R q"
  by (simp add: post_state_le_iff)

subsection ‹‹map_post_state››

(* Properties of map_post_state.
   - It is a special case of lift_post_state (map_post_state_eq_lift_post_state).
   - It is monotone and preserves Sup.
   - It composes, and is the identity for id.
   - It commutes with pure_post_state.
   - It interacts with simulation on either side.
   - It maps top to top and bottom to bottom, and only those. *)
lemma map_post_state_top: "map_post_state f  = "
  by (auto simp: top_post_state_def)

lemma mono_map_post_state: "s1  s2  map_post_state f s1   map_post_state f s2"
  by (simp add: post_state_le_iff)

lemma map_post_state_eq_lift_post_state: "map_post_state f p = lift_post_state (λa b. a = f b) p"
  by (cases p) auto

lemma map_post_state_Sup: "map_post_state f (Sup X) = (SUP xX. map_post_state f x)"
  by (simp add: post_state_eq_iff)

lemma map_post_state_comp: "map_post_state f (map_post_state g p) = map_post_state (f  g) p"
  by (simp add: post_state_eq_iff)

lemma map_post_state_id[simp]: "map_post_state id p = p"
  by (simp add: post_state_eq_iff)

lemma map_post_state_pure[simp]: "map_post_state f (pure_post_state x) = pure_post_state (f x)"
  by (simp add: pure_post_state_def)

lemma sim_post_state_map_post_state1:
  "sim_post_state R (map_post_state f p) q  sim_post_state (λx y. R (f x) y) p q"
  by (cases p; cases q; simp)

lemma sim_post_state_map_post_state2:
  "sim_post_state R p (map_post_state f q)  sim_post_state (λx y. R x (f y)) p q"
  unfolding map_post_state_eq_lift_post_state sim_post_state_lift by (simp add: OO_def[abs_def])

lemma map_post_state_eq_top[simp]: "map_post_state f p =   p = "
  by (cases p; simp add: top_post_state_def)

lemma map_post_state_eq_bot[simp]: "map_post_state f p =   p = "
  by (cases p; simp add: bot_post_state_def)

subsection ‹‹vmap_post_state››

(* Properties of vmap_post_state (preimage).
   - It preserves Sup, so gc_of_Sup_cont yields an adjoint
     (vmap_post_state_le_iff).
   - It is a special case of lift_post_state.
   - It composes contravariantly.
   - It is related to map_post_state by a Galois connection
     (vmap_post_state_le_iff_le_map_post_state). *)
lemma vmap_post_state_top: "vmap_post_state f  = "
  by (auto simp: top_post_state_def)

lemma vmap_post_state_Sup:
  "vmap_post_state f (Sup X) = (SUP xX. vmap_post_state f x)"
  by (simp add: post_state_eq_iff)

lemma vmap_post_state_le_iff: "(vmap_post_state f p  q) = (p   {p. vmap_post_state f p  q})"
  using vmap_post_state_Sup by (rule gc_of_Sup_cont)

lemma vmap_post_state_eq_lift_post_state: "vmap_post_state f p = lift_post_state (λa b. f a = b) p"
  by (cases p) auto

(* Note the order of composition: g ∘ f, because preimages reverse it. *)
lemma vmap_post_state_comp: "vmap_post_state f (vmap_post_state g p) = vmap_post_state (g  f) p"
  apply (simp add: post_state_eq_iff)
  apply (intro allI arg_cong[where f="λP. holds_post_state P p"])
  apply auto
  done

lemma vmap_post_state_id: "vmap_post_state id p = p"
  by (simp add: post_state_eq_iff)

lemma sim_post_state_vmap_post_state2:
  "sim_post_state R p (vmap_post_state f q) 
    sim_post_state (λx y. y'. f y' = y  R x y') p q"
  by (cases p; cases q; auto) blast+

lemma vmap_post_state_le_iff_le_map_post_state:
  "map_post_state f p  q  p  vmap_post_state f q"
  by (simp flip: sim_post_state_eq_iff_le
           add: sim_post_state_map_post_state1 sim_post_state_vmap_post_state2)

subsection ‹‹pure_post_state››

(* Simp rules for pure_post_state (Success {v}).
   - It is never Failure, top or bottom.
   - It is injective.
   - Simulation and relation between two pure values reduce to R.
   - Equalities and order against Success X reduce to set statements. *)
lemma pure_post_state_Failure[simp]: "pure_post_state v  Failure"
  by (simp add: pure_post_state_def)

lemma pure_post_state_top[simp]: "pure_post_state v  "
  by (simp add: pure_post_state_def top_post_state_def)

lemma pure_post_state_bot[simp]: "pure_post_state v  "
  by (simp add: pure_post_state_def bot_post_state_def)

lemma pure_post_state_inj[simp]: "pure_post_state v = pure_post_state w  v = w"
  by (simp add: pure_post_state_def)

lemma sim_pure_post_state_iff[simp]:
  "sim_post_state R (pure_post_state a) (pure_post_state b)  R a b"
  by (simp add: pure_post_state_def)

lemma rel_pure_post_state_iff[simp]:
  "rel_post_state R (pure_post_state a) (pure_post_state b)  R a b"
  by (simp add: rel_post_state_eq_sim_post_state)

lemma pure_post_state_le[simp]: "pure_post_state v  pure_post_state w  v = w"
  by (simp add: pure_post_state_def)

lemma Success_eq_pure_post_state[simp]: "Success X = pure_post_state x  X = {x}"
  by (auto simp add: pure_post_state_def)

lemma pure_post_state_eq_Success[simp]: "pure_post_state x = Success X  X = {x}"
  by (auto simp add: pure_post_state_def)

lemma pure_post_state_le_Success[simp]: "pure_post_state x  Success X  x  X"
  by (auto simp add: pure_post_state_def)

lemma Success_le_pure_post_state[simp]: "Success X  pure_post_state x  X  {x}"
  by (auto simp add: pure_post_state_def)

subsection ‹‹bind_post_state››

(* Algebraic laws for bind_post_state:
   - distribution over Sup and sup in each argument;
   - behaviour on top and bottom;
   - commutation with lift, map and vmap;
   - the monad laws (left unit, bind with pure as map, associativity). *)

(* NOTE(review): this appears to restate bind_post_state_top1 further below,
   which is the [simp] version. *)
lemma bind_post_state_top: "bind_post_state  g = "
  by (auto simp: top_post_state_def)

lemma bind_post_state_Sup1[simp]:
  "bind_post_state (Sup F) g = (SUP fF. bind_post_state f g)"
  by (simp add: post_state_eq_iff)

(* Side conditions are needed: the proof relies on holds_post_state_Ball,
   which requires a non-empty G and f different from Failure. *)
lemma bind_post_state_Sup2[simp]:
  "G  {}  f  Failure  bind_post_state f (Sup G) = (SUP gG. bind_post_state f g)"
  by (simp add: post_state_eq_iff holds_post_state_Ball)

lemma bind_post_state_sup1[simp]:
  "bind_post_state (sup f1 f2) g = sup (bind_post_state f1 g) (bind_post_state f2 g)"
  using bind_post_state_Sup1[of "{f1, f2}" g] by simp

lemma bind_post_state_sup2[simp]:
  "bind_post_state f (sup g1 g2) = sup (bind_post_state f g1) (bind_post_state f g2)"
  using bind_post_state_Sup2[of "{g1, g2}" f] by simp

lemma bind_post_state_top1[simp]: "bind_post_state  f = "
  by (simp add: post_state_eq_iff)

lemma bind_post_state_bot1[simp]: "bind_post_state  f = "
  by (simp add: post_state_eq_iff)

lemma bind_post_state_eq_top:
  "bind_post_state f g =   ¬ holds_post_state (λx. g x  ) f"
  by (cases f) (force simp add: Sup_post_state_def simp flip: top_post_state_def)+

lemma bind_post_state_eq_bot:
  "bind_post_state f g =   holds_post_state (λx. g x = ) f"
  by (cases f) (auto simp flip: top_post_state_def)

lemma lift_post_state_bind_post_state:
  "lift_post_state R (bind_post_state x g) = bind_post_state x (λv. lift_post_state R (g v))"
  by (simp add: post_state_eq_iff)

lemma vmap_post_state_bind_post_state:
  "vmap_post_state f (bind_post_state p g) = bind_post_state p (λv. vmap_post_state f (g v))"
  by (simp add: post_state_eq_iff)

lemma map_post_state_bind_post_state:
  "map_post_state f (bind_post_state x g) = bind_post_state x (λv. map_post_state f (g v))"
  by (simp add: post_state_eq_iff)

(* Monad law: left unit. *)
lemma bind_post_state_pure_post_state1[simp]:
  "bind_post_state (pure_post_state v) f = f v"
  by (simp add: post_state_eq_iff)

(* Binding into pure is map (with f = id this is the right-unit law). *)
lemma bind_post_state_pure_post_state2:
  "bind_post_state p (λx. pure_post_state (f x)) = map_post_state f p"
  by (simp add: post_state_eq_iff)

lemma bind_post_state_map_post_state:
  "bind_post_state (map_post_state f p) g = bind_post_state p (λx. g (f x))"
  by (simp add: post_state_eq_iff)

(* Monad law: associativity. *)
lemma bind_post_state_assoc[simp]:
  "bind_post_state (bind_post_state f g) h =
    bind_post_state f (λv. bind_post_state (g v) h)"
  by (simp add: post_state_eq_iff)

(* Simulation and relator rules for bind_post_state: the compositional rules
   used to relate two sequential compositions.  Also monotonicity of bind. *)
lemma sim_bind_post_state':
  "sim_post_state (λx y. sim_post_state R (g x) (k y)) f h 
    sim_post_state R (bind_post_state f g) (bind_post_state h k)"
  by (cases f; cases h; simp add: sim_post_state_Sup)

lemma sim_bind_post_state_left_iff:
  "sim_post_state R (bind_post_state f g) h 
    h = Failure  holds_post_state (λx. sim_post_state R (g x) h) f"
  by (cases f; cases h) (simp_all add: sim_post_state_Sup1)

lemma sim_bind_post_state_left:
  "holds_post_state (λx. sim_post_state R (g x) h) f 
    sim_post_state R (bind_post_state f g) h"
  by (simp add: sim_bind_post_state_left_iff)

lemma sim_bind_post_state_right:
  "g    holds_partial_post_state (λx. sim_post_state R f (h x)) g 
    sim_post_state R f (bind_post_state g h) "
  by (cases g) (auto intro: sim_post_state_Sup2)

(* Variant of sim_bind_post_state' with an intermediate relation Q on the
   first components. *)
lemma sim_bind_post_state:
  "sim_post_state Q f h  (x y. Q x y  sim_post_state R (g x) (k y)) 
    sim_post_state R (bind_post_state f g) (bind_post_state h k)"
  by (rule sim_bind_post_state'[OF sim_post_state_weaken, of Q])

lemma rel_bind_post_state':
  "rel_post_state (λa b. rel_post_state R (g1 a) (g2 b)) f1 f2 
    rel_post_state R (bind_post_state f1 g1) (bind_post_state f2 g2)"
  by (auto simp: rel_post_state_eq_sim_post_state
           intro: sim_bind_post_state' sim_post_state_weaken)

lemma rel_bind_post_state:
  "rel_post_state Q f1 f2  (a b. Q a b  rel_post_state R (g1 a) (g2 b)) 
    rel_post_state R (bind_post_state f1 g1) (bind_post_state f2 g2)"
  by (rule rel_bind_post_state'[OF rel_post_state_weaken, of Q])

(* bind_post_state is monotone in both arguments (pointwise for the
   continuation). *)
lemma mono_bind_post_state: "f1  f2  g1  g2  bind_post_state f1 g1  bind_post_state f2 g2"
  by (simp flip: sim_post_state_eq_iff_le add: le_fun_def sim_bind_post_state)


(* Miscellaneous post_state facts about Failure under Sup and inf, and about
   images and preimages under Success. *)
lemma Sup_eq_Failure[simp]: "Sup X = Failure  Failure  X"
  by (simp add: Sup_post_state_def)

lemma Failure_inf_iff: "(Failure = x  y)  (x = Failure  y = Failure)"
  by (metis Failure_le_iff le_inf_iff)

lemma Success_vimage_singleton_cancel: "Success -` {Success X} = {X}"
  by auto

lemma Success_vimage_image_cancel: "Success -` (λx. Success (f x)) ` X = f ` X"
  by auto

lemma Success_image_comp: "(λx. Success (g x)) ` X = Success ` (g ` X)"
  by auto

section ‹‹exception_or_result› type›

(* The default value of an option type is None. *)
instantiation option :: (type) default
begin

definition "default_option = None"

instance ..
end

(* Simp rules: Some x is never the default, and default is None. *)
lemma Some_ne_default[simp]:
  "Some x  default" "default  Some x"
  "default = None  True" "None = default  True"
  by (auto simp: default_option_def)

(* exception_or_result is a subtype of 'a + 'b.  Exceptions are Inl e with
   e different from the default value; results are Inr v for any v.  The
   default exception value therefore has no Inl representation. *)
typedef (overloaded) ('a::default, 'b) exception_or_result =
  "Inl ` (UNIV - {default})  Inr ` UNIV :: ('a + 'b) set"
  by auto

setup_lifting type_definition_exception_or_result

context assumes "SORT_CONSTRAINT('e::default)"
begin

(* Exception e is Inl e, except that the default value maps to
   Inr undefined, i.e. Exception default = Result undefined (see
   Result_eq_Exception). *)
lift_definition Exception :: "'e  ('e, 'v) exception_or_result" is
  "λe. if e = default then Inr undefined else Inl e"
  by auto

lift_definition Result :: "'v  ('e, 'v) exception_or_result" is
  "λv. Inr v"
  by auto

end

(* Case combinator lifted from case_sum.  It is registered as a case
   translation, so "case r of Exception e => ... | Result v => ..." syntax
   is available. *)
lift_definition
  case_exception_or_result :: "('e::default  'a)  ('v  'a)  ('e, 'v) exception_or_result  'a"
  is case_sum .

declare [[case_translation case_exception_or_result Exception Result]]

(* Injectivity and distinctness of the constructors.  Result and Exception
   coincide only for Exception default = Result undefined. *)
lemma Result_eq_Result[simp]: "Result a = Result b  a = b"
  by transfer simp

lemma Exception_eq_Exception[simp]: "Exception a = Exception b  a = b"
  by transfer simp

lemma Result_eq_Exception[simp]: "Result a = Exception e  (e = default  a = undefined)"
  by transfer simp

lemma Exception_eq_Result[simp]: "Exception e = Result a  (e = default  a = undefined)"
  by (metis Result_eq_Exception)

(* Default case-analysis rule; the Exception case carries e ≠ default. *)
lemma exception_or_result_cases[case_names Exception Result, cases type: exception_or_result]:
  "(e. e  default  x = Exception e  P)  (v. x = Result v  P)  P"
  by (transfer fixing: P) auto

lemma case_exception_or_result_Result[simp]:
  "(case Result v of Exception e  f e | Result w  g w) = g v"
  by transfer simp

(* Exception default behaves as Result undefined in a case expression. *)
lemma case_exception_or_result_Exception[simp]:
  "(case Exception e of Exception e  f e | Result w  g w) = (if e = default then g undefined else f e)"
  by transfer simp


text ‹
Caution: for split rules don't use the syntax
\<^term>‹(case r of Exception e ⇒ f e | Result w ⇒ g w)›, as this introduces
the non-eta-contracted terms \<^term>‹λe. f e› and \<^term>‹λw. g w›, which don't work with the
splitter.
›
(* Split rule for case_exception_or_result in goals, stated with the
   eta-contracted branch functions f and g so the splitter can use it. *)
lemma exception_or_result_split:
  "P (case_exception_or_result f g r)  
    (e. e  default  r = Exception e  P (f e)) 
    (v. r = Result v  P (g v))"
  by (cases r) simp_all

(* Split rule for case_exception_or_result occurring in assumptions. *)
lemma exception_or_result_split_asm:
  "P (case_exception_or_result f g r) 
  ¬ ((e. r = Exception e  e  default  ¬ P (f e)) 
    (v. r = Result v  ¬ P (g v)))"
  by (cases r) simp_all

lemmas exception_or_result_splits = exception_or_result_split exception_or_result_split_asm

(* Rebuilding r from its own case distinction. Useful with subst to expose
   the case structure of an otherwise opaque value r. *)
lemma split_exception_or_result: 
  "r = (case r of Exception e  Exception e | Result v  Result v)"
  by (cases r) auto

(* Exhaustiveness: no value is both distinct from every non-default Exception
   and from every Result. *)
lemma exception_or_result_nchotomy:
  "¬ ((e. e  default  x  Exception e)  (v. x  Result v))" by (cases x) auto

(* Split rules specialised to exception type unit. Only the Result branch
   appears, i.e. the Exception branch is vacuous for unit (presumably because
   the only unit value is its default element). Declared [split] globally below. *)
lemma val_split:
  fixes r::"(unit, 'a) exception_or_result"
  shows
  "P (case_exception_or_result f g r)  
    (v. r = Result v  P (g v))"
  by (cases r) simp_all

lemma val_split_asm:
  fixes r::"(unit, 'a) exception_or_result"
  shows "P (case_exception_or_result f g r) 
    ¬ (v. r = Result v  ¬ P (g v))"
  by (cases r) simp_all

lemmas val_splits[split] = val_split val_split_asm

(* Instance of the equal class (executable equality): two values are equal
   iff they use the same constructor and their payloads are HOL.equal. *)
instantiation exception_or_result :: ("{equal, default}", equal) equal begin

definition "equal_exception_or_result a b =
  (case a of
    Exception e  (case b of Exception e'  HOL.equal e e' | Result s  False)
  | Result r     (case b of Exception e  False   | Result s  HOL.equal r s)
)"
instance proof qed (simp add: equal_exception_or_result_def equal_eq
    split: exception_or_result_splits)

end

(* Discriminators and selectors. A default exception counts as a Result, so
   is_Exception requires a non-default payload. the_Exception of a Result is
   default; the_Result of an Exception is undefined. *)
definition is_Exception:: "('e::default, 'a) exception_or_result  bool" where
  "is_Exception x  (case x of Exception e  e  default | Result _  False)"

definition is_Result:: "('e::default, 'a) exception_or_result  bool" where
  "is_Result x  (case x of Exception e  e = default | Result _  True)"

definition the_Exception::  "('e::default, 'a) exception_or_result  'e" where
  "the_Exception x  (case x of Exception e  e | Result _  default)"

definition the_Result::  "('e::default, 'a) exception_or_result  'a" where
  "the_Result x  (case x of Exception e  undefined | Result v  v)"

lemma is_Exception_simps[simp]:
  "is_Exception (Exception e) = (e  default)"
  "is_Exception (Result v) = False"
  by (auto simp add: is_Exception_def)

lemma is_Result_simps[simp]:
  "is_Result (Exception e) = (e = default)"
  "is_Result (Result v) = True"
  by (auto simp add: is_Result_def)

lemma the_Exception_simp[simp]:
  "the_Exception (Exception e) = e"
  by (auto simp add: the_Exception_def)

lemma the_Exception_Result:
  "the_Exception (Result v) = default"
  by (simp add: the_Exception_def)

(* Exception handler used by the Res binder below: ignores its unit argument
   and returns undefined. *)
definition undefined_unit::"unit  'b" where "undefined_unit x  undefined"

(* Binder syntax: a lambda with the pattern Res x is translated into
   case_exception_or_result undefined_unit (a function of x), i.e. a function
   on outcomes that binds the result value and has undefined_unit as its
   exception branch. *)
syntax "_Res" :: "pttrn  pttrn" ((‹open_block notation=‹prefix Res››Res _))
syntax_consts "_Res"  case_exception_or_result
translations "λRes x. b"  "CONST case_exception_or_result (CONST undefined_unit) (λx. b)"

(* Sanity checks that the syntax parses and prints, including tuple patterns
   and additional curried arguments. *)
term "λRes x. f x"
term "λRes (x, y, z). f x y z"
term "λRes (x, y, z) s. P x y s z"


(* Store the lifting/transfer setup for exception_or_result in the bundle
   exception_or_result.lifting, then remove it from the context. From here on,
   reasoning about the type uses the lemmas above rather than the Inl/Inr
   representation (the setup can be re-enabled locally by including the bundle). *)
lifting_update exception_or_result.lifting
lifting_forget exception_or_result.lifting

(* Relator for exception_or_result. Two values are related when both are
   non-default exceptions related by E, or both are results related by R. *)
inductive rel_exception_or_result:: "('e::default  'f::default  bool)  ('a  'b  bool) 
  ('e, 'a) exception_or_result  ('f, 'b) exception_or_result  bool"
  where
Exception:
  "E e f  e  default  f  default 
  rel_exception_or_result E R (Exception e) (Exception f)" |
Result:
  "R a b  rel_exception_or_result E R (Result a) (Result b)"

(* Rewrite quantifiers over exception_or_result into the two constructor cases
   (non-default Exception, Result); the bounded variants range over a set s. *)
lemma All_exception_or_result_cases:
  "(x. P (x::(_, _) exception_or_result)) 
    (err. err  default  P (Exception err))   (v. P (Result v))"
  by (metis exception_or_result_cases)

lemma Ball_exception_or_result_cases:
  "(xs. P (x::(_, _) exception_or_result)) 
    (err. err  default  Exception err  s  P (Exception err)) 
    (v. Result v  s  P (Result v))"
  by (metis exception_or_result_cases)

lemma Bex_exception_or_result_cases:
  "(xs. P (x::(_, _) exception_or_result)) 
    (err. err  default  Exception err  s  P (Exception err)) 
    (v. Result v  s  P (Result v))"
  by (metis exception_or_result_cases)

lemma Ex_exception_or_result_cases:
  "(x. P (x::(_, _) exception_or_result)) 
    (err. err  default  P (Exception err)) 
    (v. P (Result v))"
  by (metis exception_or_result_cases)

(* Push a function application into both branches of a case expression. *)
lemma case_distrib_exception_or_result:
  "f (case x of Exception e  E e | Result v  R v) = (case x of Exception e  f (E e) |  Result v  f (R v))"
  by (auto split: exception_or_result_split)

(* Simplification rules for the relator on each constructor combination. *)
lemma rel_exception_or_result_Results[simp]:
  "rel_exception_or_result E R (Result a) (Result b)  R a b"
  by (simp add: rel_exception_or_result.simps)

lemma rel_exception_or_result_Exception[simp]:
  "e  default  f  default 
    rel_exception_or_result E R (Exception e) (Exception f)  E e f"
  by (simp add: rel_exception_or_result.simps)

lemma rel_exception_or_result_Result_Exception[simp]:
  "e  default  ¬ rel_exception_or_result E R (Exception e) (Result b)"
  "f  default  ¬ rel_exception_or_result E R (Result a) (Exception f)"
  by (auto simp add: rel_exception_or_result.simps)

lemma is_Exception_iff: "is_Exception x  (e. e  default  x = Exception e)"
  apply (cases x)
   apply (auto)
  done

(* The relator commutes with converse. *)
lemma rel_exception_or_result_converse:
  "(rel_exception_or_result E R)¯¯ = rel_exception_or_result E¯¯ R¯¯"
  apply (rule ext)+
  apply (auto simp add: rel_exception_or_result.simps)
  done

(* NOTE(review): same statement as rel_exception_or_result_Results above, and
   the name misspells exception. Presumably kept because other theories refer
   to this name; confirm before removing. *)
lemma rel_eception_or_result_Result[simp]:
  "rel_exception_or_result E R (Result x) (Result y) = R x y"
  by (auto simp add: rel_exception_or_result.simps)

(* The relator on equalities is equality. *)
lemma rel_exception_or_result_eq_conv: "rel_exception_or_result (=) (=) = (=)"
  apply (rule ext)+
  by (auto simp add: rel_exception_or_result.simps)
    (meson exception_or_result_cases)

(* NOTE(review): same statement as rel_exception_or_result_eq_conv above, with
   a different proof; presumably kept for name compatibility. *)
lemma rel_exception_or_result_sum_eq: "rel_exception_or_result (=) (=) = (=)"
  apply (rule ext)+
  apply (subst (1) split_exception_or_result)
  apply (auto simp add: rel_exception_or_result.simps split:  exception_or_result_splits)
  done

(* Map over both components: E on the exception payload, F on the result.
   If E sends a (non-default) exception to default, the mapped value collapses
   to Result undefined, because Exception default = Result undefined. *)
definition
  map_exception_or_result::
    "('e::default  'f::default)  ('a  'b)  ('e, 'a) exception_or_result 
      ('f, 'b) exception_or_result"
where
  "map_exception_or_result E F x =
    (case x of Exception e  Exception (E e) | Result v  Result (F v))"

lemma map_exception_or_result_Exception[simp]:
  "e  default  map_exception_or_result E F (Exception e) = Exception (E e)"
  by (simp add: map_exception_or_result_def)

lemma map_exception_or_result_Result[simp]:
  "map_exception_or_result E F (Result v) = Result (F v)"
  by (simp add: map_exception_or_result_def)

lemma map_exception_or_result_id: "map_exception_or_result id id x = x"
  by (cases x; simp)

(* Composition law. The side condition on E2 is needed: if E2 sent an exception
   to default, the inner map would produce Result undefined and the outer map
   would then apply F1 instead of E1. *)
lemma map_exception_or_result_comp:
  assumes E2: "x. x  default  E2 x  default"
  shows "map_exception_or_result E1 F1 (map_exception_or_result E2 F2 x) =
    map_exception_or_result (E1  E2) (F1  F2) x"
  by (cases x; auto dest: E2)

(* Monotonicity of bind_post_state in a continuation that case-splits on the
   outcome. It suffices that, for every outcome of x in the sense of
   holds_partial_post_state, the exception branches X, X' resp. the result
   branches V, V' are ordered. *)
lemma le_bind_post_state_exception_or_result_cases[case_names Exception Result]:
  assumes
    "holds_partial_post_state (λ(x, t). e. e  default  x = Exception e  X e t  X' e t) x"
  assumes "holds_partial_post_state (λ(x, t). v. x = Result v  V v t  V' v t) x"
  shows "bind_post_state x (λ(r, t). case r of Exception e  X e t | Result v  V v t) 
    bind_post_state x (λ(r, t). case r of Exception e  X' e t| Result v  V' v t)"
  using assms
  by (cases x; force intro!: SUP_mono'' split: exception_or_result_splits prod.splits)

section ‹‹spec_monad› type›

(* A specification monad: a function from an initial state to a post_state over
   pairs of an outcome (exception or result) and a final state. The carrier is
   UNIV, so every such function is a spec_monad. run extracts the function;
   Spec builds a spec_monad from one. *)
typedef (overloaded) ('e::default, 'a, 's) spec_monad =
  "UNIV :: ('s  (('e::default, 'a) exception_or_result × 's) post_state) set"
  morphisms run Spec
  by (fact UNIV_witness)

(* run distributes over a case split on a pair. *)
lemma run_case_prod_distrib[simp]:
  "run (case x of (r, s)  f r s) t = (case x of (r, s)  run (f r s) t)"
  by (rule prod.case_distrib)

lemma image_case_prod_distrib[simp]:
  "(λ(r, t). f r t) ` (λv. (g v, h v)) ` R = (λv. (f (g v) (h v))) ` R "
  by auto

lemma run_Spec[simp]: "run (Spec p) s = p s"
  by (simp add: Spec_inverse)

setup_lifting type_definition_spec_monad

(* Extensionality: spec monads are equal iff they run identically from every state. *)
lemma spec_monad_ext: "(s. run f s = run g s)  f = g"
  apply transfer
  apply auto
  done

lemma spec_monad_ext_iff: "f = g  (s. run f s = run g s)"
  using spec_monad_ext by auto

(* Spec monads form a complete lattice. Every operation is lifted from the
   corresponding operation on the underlying function type (state to post_state),
   and each class axiom is discharged by transferring the same axiom from there. *)
instantiation spec_monad :: (default, type, type) complete_lattice
begin

lift_definition
  less_eq_spec_monad :: "('e::default, 'r, 's) spec_monad  ('e, 'r, 's) spec_monad  bool"
  is "(≤)" .

lift_definition
  less_spec_monad :: "('e::default, 'r, 's) spec_monad  ('e, 'r, 's) spec_monad  bool"
  is "(<)" .

lift_definition bot_spec_monad :: "('e::default, 'r, 's) spec_monad" is "bot" .

lift_definition top_spec_monad :: "('e::default, 'r, 's) spec_monad" is "top" .

lift_definition Inf_spec_monad :: "('e::default, 'r, 's) spec_monad set  ('e, 'r, 's) spec_monad"
  is "Inf" .

lift_definition Sup_spec_monad :: "('e::default, 'r, 's) spec_monad set  ('e, 'r, 's) spec_monad"
  is "Sup" .

lift_definition
  sup_spec_monad ::
    "('e::default, 'r, 's) spec_monad  ('e, 'r, 's) spec_monad  ('e, 'r, 's) spec_monad"
  is "sup" .

lift_definition
  inf_spec_monad ::
    "('e::default, 'r, 's) spec_monad  ('e, 'r, 's) spec_monad  ('e, 'r, 's) spec_monad"
  is "inf" .

(* One subgoal per complete_lattice axiom, each closed by the corresponding
   lattice lemma on the representation type. *)
instance
  apply (standard; transfer)
  subgoal by (rule less_le_not_le)
  subgoal by (rule order_refl)
  subgoal by (rule order_trans)
  subgoal by (rule antisym)
  subgoal by (rule inf_le1)
  subgoal by (rule inf_le2)
  subgoal by (rule inf_greatest)
  subgoal by (rule sup_ge1)
  subgoal by (rule sup_ge2)
  subgoal by (rule sup_least)
  subgoal by (rule Inf_lower)
  subgoal by (rule Inf_greatest)
  subgoal by (rule Sup_upper)
  subgoal by (rule Sup_least)
  subgoal by (rule Inf_empty)
  subgoal by (rule Sup_empty)
  done
end

(* fail: the post-state is Failure from every initial state. *)
lift_definition fail :: "('e::default, 'a, 's) spec_monad"
  is "λs. Failure" .

(* Sequential composition in which the continuation h receives the complete
   outcome (exception or result) of f together with the intermediate state. *)
lift_definition
  bind_exception_or_result :: 
    "('e::default, 'a, 's) spec_monad 
      (('e, 'a) exception_or_result  ('f::default, 'b, 's) spec_monad)  ('f, 'b, 's) spec_monad"
is
  "λf h s. bind_post_state (f s) (λ(v, t). h v t)" .

(* bind_handle f g h: run f, then continue with g on a result and with h on an
   exception. *)
lift_definition bind_handle ::
    "('e::default, 'a, 's) spec_monad 
      ('a  ('f, 'b, 's) spec_monad)  ('e  ('f, 'b, 's) spec_monad) 
      ('f::default, 'b, 's) spec_monad"
  is "λf g h s. bind_post_state (f s) (λ(r, t). case r of Exception e  h e t | Result v  g v t)" .

(* yield r: the single outcome r with the state unchanged (pure_post_state). *)
lift_definition yield :: "('e, 'a) exception_or_result  ('e::default, 'a, 's) spec_monad"
  is "λr s. pure_post_state (r, s)" .

abbreviation "return r  yield (Result r)"

abbreviation "skip  return ()"

(* Note: throwing the default exception is the same as return undefined,
   since Exception default = Result undefined. *)
abbreviation "throw_exception_or_result e  yield (Exception e)"

(* get_state returns the current state; set_state t replaces it by t. *)
lift_definition get_state :: "('e::default, 's, 's) spec_monad"
  is "λs. pure_post_state (Result s, s)" .

lift_definition set_state :: "'s  ('e::default, unit, 's) spec_monad"
  is "λt s. pure_post_state (Result (), t)" .

(* map_value f g: apply f to the outcome component of each (outcome, state)
   pair produced by g, leaving states unchanged. *)
lift_definition map_value ::
    "(('e::default, 'a) exception_or_result  ('f::default, 'b) exception_or_result) 
      ('e, 'a, 's) spec_monad  ('f, 'b, 's) spec_monad"
  is "λf g s. map_post_state (apfst f) (g s)" .

(* Counterpart of map_value via vmap_post_state. Note the direction: f maps
   outcomes of the resulting computation to outcomes of the argument g. *)
lift_definition vmap_value ::
    "(('e::default, 'a) exception_or_result  ('f::default, 'b) exception_or_result) 
      ('f, 'b, 's) spec_monad  ('e, 'a, 's) spec_monad"
  is "λf g s. vmap_post_state (apfst f) (g s)" .

(* Monadic bind: results are passed to g, exceptions are re-thrown unchanged. *)
definition bind ::
    "('e::default, 'a, 's) spec_monad  ('a  ('e, 'b, 's) spec_monad)  ('e, 'b, 's) spec_monad"
where
  "bind f g = bind_handle f g throw_exception_or_result"

(* Make bind available through Monad_Syntax, enabling do-notation. *)
adhoc_overloading
  Monad_Syntax.bind bind

(* lift_state R p: from state s, run p from every t with R s t (joined by SUP),
   then transport the resulting post-state back with lift_post_state over
   rel_prod (=) R, i.e. outcomes unchanged and final states related by R. *)
lift_definition lift_state ::
    "('s  't  bool)  ('e::default, 'a, 't) spec_monad  ('e, 'a, 's) spec_monad"
  is "λR p s. lift_post_state (rel_prod (=) R) (SUP t{t. R s t}. p t)" .

(* Run a computation over 's from a 't-state: instance of lift_state with the
   relation holding between t and s when t = st s. *)
definition exec_concrete ::
  "('s  't)  ('e::default, 'a, 's) spec_monad  ('e, 'a, 't) spec_monad"
where
  "exec_concrete st = lift_state (λt s. t = st s)"

(* Run a computation over 't from an 's-state via st: instance of lift_state
   with the relation holding between s and t when t = st s. *)
definition exec_abstract ::
  "('s  't)  ('e::default, 'a, 't) spec_monad  ('e, 'a, 's) spec_monad"
where
  "exec_abstract st = lift_state (λs t. t = st s)"

(* Nondeterministic choice of an outcome from R; for empty R this is the SUP of
   the empty set, i.e. bot. *)
definition select_value :: "('e, 'a) exception_or_result set  ('e::default, 'a, 's) spec_monad" where
  "select_value R = (SUP rR. yield r)"

(* Nondeterministically return some element of S (bot for empty S). *)
definition select :: "'a set  ('e::default, 'a, 's) spec_monad" where
  "select S = (SUP aS. return a)"

definition unknown :: "('e::default, 'a, 's) spec_monad" where
  "unknown  select UNIV"

(* Return a function of the current state. *)
definition gets :: "('s  'a)  ('e::default, 'a, 's) spec_monad" where
  "gets f = bind get_state (λs. return (f s))"

(* None fails, Some v returns v. *)
definition assert_opt :: "'a option  ('e::default, 'a, 's) spec_monad" where
  "assert_opt v = (case v of None  fail | Some v  return v)"

definition gets_the :: "('s  'a option)  ('e::default, 'a, 's) spec_monad" where
  "gets_the f = bind (gets f) assert_opt"

definition modify :: "('s  's)  ('e::default, unit, 's) spec_monad" where
  "modify f = bind get_state (λs. set_state (f s))"

(* assume P is bot when P does not hold; assert P is top when P does not hold. *)
definition "assume" :: "bool  ('e::default, unit, 's) spec_monad" where
  "assume P = (if P then return () else bot)"

definition assert :: "bool  ('e::default, unit, 's) spec_monad" where
  "assert P = (if P then return () else top)"

(* State-predicate versions of assume and assert. *)
definition assuming :: "('s  bool)  ('e::default, unit, 's) spec_monad" where
  "assuming P = do {s   get_state; assume (P s)}"

definition guard:: "('s  bool)   ('e::default, unit, 's) spec_monad" where
  "guard P = do {s  get_state; assert (P s)}"

(* Nondeterministically pick a (value, state) pair from f s, install the state
   and return the value; bot if f s is empty (select of the empty set). *)
definition assume_result_and_state :: "('s  ('a × 's) set)  ('e::default, 'a, 's) spec_monad" where
  "assume_result_and_state f = do {s  get_state; (v, t)  select (f s); set_state t; return v}"

(* TODO: delete *)
(* Post-state Success (f s): the outcomes given directly by f. *)
lift_definition assume_outcome ::
    "('s  (('e, 'a) exception_or_result × 's) set)  ('e::default, 'a, 's) spec_monad"
  is "λf s. Success (f s)" .

(* Like assume_result_and_state, but first asserts that f s is non-empty, so an
   empty choice gives top rather than bot. *)
definition assert_result_and_state ::
    "('s  ('a × 's) set)  ('e::default, 'a, 's) spec_monad"
where
  "assert_result_and_state f =
    do {s  get_state; assert (f s  {}); (v, t)  select (f s); set_state t; return v}"

(* Nondeterministic state transition along relation r, asserting that some
   successor state exists. *)
abbreviation state_select :: "('s × 's) set  ('e::default, unit, 's) spec_monad" where
  "state_select r  assert_result_and_state (λs. {((), s'). (s, s')  r})"

(* Run T if the current state satisfies C, otherwise F. *)
definition condition ::
    "('s  bool)  ('e::default, 'a, 's) spec_monad  ('e, 'a, 's) spec_monad 
      ('e, 'a, 's) spec_monad"
where
  "condition C T F =
    bind get_state (λs. if C s then T else F)"

notation (output)
  condition  ((‹notation=‹prefix condition››condition (_)//  (_)//  (_)) [1000,1000,1000] 999)

(* when c f runs f if c holds and skips otherwise; unless negates c. *)
definition "when" ::"bool  ('e::default, unit, 's) spec_monad  ('e, unit, 's) spec_monad"
  where "when c f  condition (λ_. c) f skip"

abbreviation "unless" ::"bool  ('e::default, unit, 's) spec_monad  ('e, unit, 's) spec_monad"
  where "unless c  when (¬c)"

(* Run f, then the cleanup c whatever outcome f produced, then re-yield that
   outcome. c is sequenced with bind, so an exception thrown by c propagates and
   replaces the outcome of f. *)
definition on_exit' ::
    "('e::default, 'a, 's) spec_monad  ('e, unit, 's) spec_monad  ('e, 'a, 's) spec_monad"
where
  "on_exit' f c 
     bind_exception_or_result f (λr. do { c; yield r })"

(* on_exit' with a cleanup given as a state relation. *)
definition on_exit ::
    "('e::default, 'a, 's) spec_monad  ('s × 's) set  ('e, 'a, 's) spec_monad"
where
  "on_exit f cleanup  on_exit' f (state_select cleanup)"

(* Variants that check grd before the cleanup: guard (top if violated) resp.
   assuming (bot if violated). *)
abbreviation guard_on_exit ::
    "('e::default, 'a, 's) spec_monad  ('s  bool)  ('s × 's) set  ('e, 'a, 's) spec_monad"
where
  "guard_on_exit f grd cleanup  on_exit' f (bind (guard grd) (λ_. state_select cleanup))"

abbreviation assume_on_exit ::
    "('e::default, 'a, 's) spec_monad  ('s  bool)  ('s × 's) set  ('e, 'a, 's) spec_monad"
where
  "assume_on_exit f grd cleanup 
    on_exit' f (bind (assuming grd) (λ_. state_select cleanup))"

(* run_bind f t g: run f from the given state t (not from the current state),
   feed each outcome r and final state t' to g, and run g r t' from the
   current state s. *)
lift_definition run_bind ::
    "('e::default, 'a, 't) spec_monad  't 
     (('e, 'a) exception_or_result  't  ('f::default, 'b, 's) spec_monad) 
     ('f::default, 'b, 's) spec_monad"
  is "λf t g s. bind_post_state (f t) (λ(r, t). g r t s)" .
    ― ‹@{const run_bind} might be a more canonical building block compared to
   @{const lift_state}. See subsection on @{const run_bind} below.›

(* Postconditions: predicates on an outcome and a final state. *)
type_synonym ('e, 'a, 's) predicate = "('e, 'a) exception_or_result  's  bool"

(* runs_to f s Q (total correctness): holds_post_state of Q on the post-state of
   f from s. Unlike runs_to_partial it is not implied by failure (compare
   runs_to_partial_runs_to_iff). *)
lift_definition runs_to ::
    "('e::default, 'a, 's) spec_monad  's  ('e, 'a, 's) predicate  bool"
    ((‹open_block notation=‹mixfix runs_to››_/  _  _ ) [61, 1000, 0] 30) ― ‹syntax _do_block› has 62›
  is "λf s Q. holds_post_state (λ(r, t). Q r t) (f s)" .

(* runs_to_partial f s Q (partial correctness): holds iff f fails from s or
   runs_to f s Q holds (runs_to_partial_runs_to_iff). *)
lift_definition runs_to_partial ::
    "('e::default, 'a, 's) spec_monad  's  ('e, 'a, 's) predicate  bool"
    ((‹open_block notation=‹mixfix runs_to_partial››_/  _ ?⦃ _ ) [61, 1000, 0] 30)
  is "λf s Q. holds_partial_post_state (λ(r, t). Q r t) (f s)" .

(* refines f g s t R: sim_post_state R between f from s and g from t.
   Equivalently (refines_iff_runs_to), every postcondition P holds for f from s
   whenever g from t establishes P transported backwards along R. *)
lift_definition refines ::
  "('e::default, 'a, 's) spec_monad  ('f::default, 'b, 't) spec_monad  's  't 
    ((('e, 'a) exception_or_result × 's)  (('f, 'b) exception_or_result × 't)  bool)  bool"
  is "λf g s t R. sim_post_state R (f s) (g t)" .

(* rel_spec: rel_post_state between the two post-states; equivalent to
   refinement in both directions (rel_spec_iff_refines). *)
lift_definition rel_spec ::
  "('e::default, 'a, 's) spec_monad  ('f::default, 'b, 't) spec_monad  's  't 
     ((('e, 'a) exception_or_result × 's)  (('f, 'b) exception_or_result × 't)  bool)
    bool"
is
  "λf g s t R. rel_post_state R (f s) (g t)" .

(* Relator on spec monads: R relates states, Q relates outcomes; from every pair
   of R-related initial states the post-states are related by rel_prod Q R. *)
definition rel_spec_monad ::
    "('s  't  bool)  (('e, 'a) exception_or_result  ('f, 'b) exception_or_result  bool) 
      ('e::default, 'a, 's) spec_monad  ('f::default, 'b, 't) spec_monad  bool"
where
  "rel_spec_monad R Q f g =
    (s t. R s t  rel_post_state (rel_prod Q R) (run f s) (run g t))"

(* always_progress f: from no state is the post-state of f equal to bot. *)
lift_definition always_progress :: "('e::default, 'a, 's) spec_monad  bool" is
  "λp. s. p s  bot" .

named_theorems run_spec_monad "Simplification rules to run a Spec monad"
named_theorems runs_to_iff "Equivalence theorems for runs to predicate"
named_theorems always_progress_intros "intro rules for always_progress predicate"

(* A postcondition that holds everywhere is established partially by any program. *)
lemma runs_to_partialI: "(r x. P r x)  f  s ?⦃ P "
  by (cases "run f s") (auto simp: runs_to_partial.rep_eq)

lemma runs_to_partial_True: "f  s ?⦃ λr s. True "
  by (simp add: runs_to_partialI)

(* Partial correctness is closed under conjunction of postconditions. *)
lemma runs_to_partial_conj: 
  "f  s ?⦃ P   f  s ?⦃ Q   f  s ?⦃ λr s. P r s  Q r s "
  by (cases "run f s") (auto simp: runs_to_partial.rep_eq)

(* Characterisation of refinement through total-correctness postconditions. *)
lemma refines_iff':
  "refines f g s t R  (P. g  t  λr s. p t. R (p, t) (r, s)  P p t   f  s  P )"
  apply transfer
  apply (simp add: sim_post_state_iff le_fun_def split_beta' all_comm[where 'a='a])
  apply (rule all_cong_map[where f'="λP (x, y). P x y" and f="λP x y. P (x, y)"])
  apply auto
  done

(* Refinement is monotone in the relation; refines_mono has the premises in
   the opposite order. *)
lemma refines_weaken:
  "refines f g s t R  (r s q t. R (r, s) (q, t)  Q (r, s) (q, t))  refines f g s t Q"
  by transfer (auto intro: sim_post_state_weaken)

lemma refines_mono:
  "(r s q t. R (r, s) (q, t)  Q (r, s) (q, t))  refines f g s t R  refines f g s t Q"
  by (rule refines_weaken)

(* Reflexivity for relations that are reflexive on (outcome, state) pairs. *)
lemma refines_refl: "(r t. R (r, t) (r, t))  refines f f s s R"
  by transfer (auto intro: sim_post_state_refl)

(* Transitivity; refines_trans' uses relation composition OO. *)
lemma refines_trans:
  "refines f g s t R  refines g h t u Q 
    (r s p t q u. R (r, s) (p, t)  Q (p, t) (q, u)  S (r, s) (q, u)) 
    refines f h s u S"
  by transfer (auto intro: sim_post_state_trans)

lemma refines_trans': "refines f g s t1 R  refines g h t1 t2 Q  refines f h s t2 (R OO Q)"
  by (rule refines_trans; assumption?) auto

(* Strengthen the refinement relation using partial-correctness facts F and G
   about the two sides; the 1/2 variants use a fact about one side only. *)
lemma refines_strengthen:
  "refines f g s t R  f  s ?⦃ F   g  t ?⦃ G  
    (x s y t. R (x, s) (y, t)  F x s  G y t  Q (x, s) (y, t)) 
  refines f g s t Q"
  by transfer (fastforce elim!: sim_post_state.cases simp: sim_set_def split_beta')

lemma refines_strengthen1:
  "refines f g s t R  f  s ?⦃ F  
    (x s y t. R (x, s) (y, t)  F x s  Q (x, s) (y, t)) 
  refines f g s t Q"
  by (rule refines_strengthen[OF _ _ runs_to_partial_True]; assumption?)

lemma refines_strengthen2:
  "refines f g s t R  g  t ?⦃ G  
    (x s y t. R (x, s) (y, t)  G y t  Q (x, s) (y, t)) 
  refines f g s t Q"
  by (rule refines_strengthen[OF _ runs_to_partial_True]; assumption?)

(* Replace the relation R by Q when they agree on all four
   Exception/Result combinations (with non-default exceptions). *)
lemma refines_cong_cases:
  assumes "e s e' s'. e  default  e'  default 
    R (Exception e, s) (Exception e', s')  Q (Exception e, s) (Exception e', s')"
  assumes "e s x' s'. e  default 
    R (Exception e, s) (Result x', s')  Q (Exception e, s) (Result x', s')"
  assumes "x s e' s'. e'  default 
    R (Result x, s) (Exception e', s')  Q (Result x, s) (Exception e', s')"
  assumes "x s x' s'. R (Result x, s) (Result x', s')  Q (Result x, s) (Result x', s')"
  shows "refines f f' s s' R  refines f f' s s' Q"
  apply (clarsimp intro!: arg_cong[where f="refines f f' s s'"] simp: fun_eq_iff del: iffI)
  subgoal for v s v' s' by (cases v; cases v'; simp add: assms)
  done

(* Postcondition weakening, registered as runs_to_vcg_weaken rules. *)
lemma runs_to_partial_weaken[runs_to_vcg_weaken]:
  "f  s ?⦃Q  (r t. Q r t  Q' r t)  f  s ?⦃Q'"
  by transfer (auto intro: holds_partial_post_state_weaken)

lemma runs_to_weaken[runs_to_vcg_weaken]: "f  s Q  (r t. Q r t  Q' r t)  f  s Q'"
  by transfer (auto intro: holds_post_state_weaken)

(* The same monotonicity facts in implication form, declared [mono]. *)
lemma runs_to_partial_imp_runs_to_partial[mono]:
  "(a s. P a s  Q a s)  f  s ?⦃ P   f  s ?⦃ Q "
  using runs_to_partial_weaken[of f s P Q] by auto

lemma runs_to_imp_runs_to[mono]:
  "(a s. P a s  Q a s)  f  s  P   f  s  Q "
  using runs_to_weaken[of f s P Q] by auto

lemma refines_imp_refines[mono]:
  "(a s. P a s  Q a s)  refines f g s t P  refines f g s t Q"
  using refines_weaken[of f g s t P Q] by auto

(* Replace postcondition P by Q when they agree on non-default exceptions and
   on results. *)
lemma runs_to_cong_cases:
  assumes "e s. e  default  P (Exception e) s  Q (Exception e) s"
  assumes "x s. P (Result x) s  Q (Result x) s"
  shows "f  s  P   f  s  Q "
  apply (clarsimp intro!: arg_cong[where f="runs_to _ _"] simp: fun_eq_iff del: iffI)
  subgoal for v s by (cases v; simp add: assms)
  done

(* The lattice order on spec monads in terms of refinement and of runs_to:
   f is below g iff every total-correctness postcondition of g also holds for f. *)
lemma le_spec_monad_le_refines_iff: "f  g  (s. refines f g s s (=))"
  by transfer (simp add: le_fun_def sim_post_state_eq_iff_le)

lemma spec_monad_le_iff: "f  g  (P (s::'a). g  s  P   f  s  P )"
  by (simp add: le_spec_monad_le_refines_iff refines_iff' all_comm[where 'a='a])

(* Hence spec monads are equal iff they satisfy the same runs_to facts. *)
lemma spec_monad_eq_iff: "f = g  (P s. f  s  P   g  s  P )"
  by (simp add: order_eq_iff spec_monad_le_iff)

lemma spec_monad_eqI: "(P s. f  s  P   g  s  P )  f = g"
  by (simp add: spec_monad_eq_iff)

(* A supremum of programs establishes P iff every member does. *)
lemma runs_to_Sup_iff: "(Sup X)  s  P   (xX. x  s  P )"
  by transfer simp

(* Induction rule for total correctness of lfp f applied to x: show that f
   preserves the property "P x s implies runs_to R" (the Sup case of the ordinal
   induction is handled by runs_to_Sup_iff). *)
lemma runs_to_lfp:
  assumes f: "mono f"
  assumes x_s: "P x s"
  assumes *: "p. (x s. P x s  p x  s  R )  (x s. P x s  f p x  s  R )"
  shows "lfp f x  s  R "
proof -
  have "x s. P x s  lfp f x  s  R "
    apply (rule lfp_ordinal_induct[OF f])
    subgoal using * by blast
    subgoal by (simp add: runs_to_Sup_iff)
    done
  with x_s show ?thesis by auto
qed

(* Partial correctness holds iff the post-state is top or total correctness holds. *)
lemma runs_to_partial_alt:
  "f  s ?⦃ Q   run f s =   f  s  Q "
  by (cases "run f s"; simp add: runs_to.rep_eq runs_to_partial.rep_eq top_post_state_def)

(* Termination facts (runs_to P) combine with partial correctness to give
   total correctness of Q. *)
lemma runs_to_of_runs_to_partial_runs_to':
  "f  s  P   f  s ?⦃ Q   f  s  Q "
  by (auto simp: runs_to_partial_alt runs_to.rep_eq)

lemma runs_to_partial_of_runs_to:
  "f  s  Q   f  s ?⦃ Q "
  by (auto simp: runs_to_partial_alt)

(* NOTE(review): the name misspells trivial; left unchanged for compatibility. *)
lemma runs_to_partial_tivial[simp]: "f  s ?⦃ λ_ _. True "
  by (cases "run f s"; simp add: runs_to_partial.rep_eq)

lemma le_spec_monad_le_run_iff: "f  g  (s. run f s  run g s)"
  by (simp add: le_fun_def less_eq_spec_monad.rep_eq)

(* Refinement is closed under moving up on the right and down on the left,
   both pointwise (on run) and for the spec_monad order. *)
lemma refines_le_run_trans: "refines f g s s1 R  run g s1  run h s2  refines f h s s2 R"
  by transfer (rule sim_post_state_le2)

lemma le_run_refines_trans: "run f s  run g s1  refines g h s1 s2 R  refines f h s s2 R"
  by transfer (rule sim_post_state_le1)

lemma refines_le_trans: "refines f g s t R  g  h  refines f h s t R"
  by (auto simp: le_spec_monad_le_run_iff refines_le_run_trans)

lemma le_refines_trans: "f  g  refines g h s t R  refines f h s t R"
  by (auto simp: le_spec_monad_le_run_iff intro: le_run_refines_trans)

lemma le_run_refines_iff: "run f s  run g t  refines f g s t (=)"
  by transfer (simp add: sim_post_state_eq_iff_le)

(* Every program refines top. *)
lemma refines_top[simp]: "refines f  s t R"
  unfolding top_spec_monad.rep_eq refines.rep_eq by simp

(* A supremum refines g iff each member does. *)
lemma refines_Sup1: "refines (Sup F) g s s' R  (fF. refines f g s s' R)"
  by (simp add: refines.rep_eq Sup_spec_monad.rep_eq image_image sim_post_state_Sup1)

(* Monotonicity of a spec-monad valued function expressed via refinement. *)
lemma monotone_le_iff_refines:
  "monotone R (≤) F  (x y. R x y  (s. refines (F x) (F y) s s (=)))"
  by (auto simp: monotone_def le_fun_def le_run_refines_iff[symmetric] le_spec_monad_le_run_iff)

(* Refinement as transfer of total-correctness postconditions along R. *)
lemma refines_iff_runs_to:
  "refines f g s t R  (P. g  t  λr t'. p s'. R (p, s') (r, t')  P p s'   f  s  P )"
  by transfer (auto simp: sim_post_state_iff split_beta')

lemma runs_to_refines_weaken':
  "refines f g s t R  g  t  λr t'. p s'. R (p, s') (r, t')  P p s'   f  s  P "
  by (simp add: refines_iff_runs_to)

lemma runs_to_refines_weaken: "refines f f' s t (=)  f'  t  P   f  s  P "
  by (auto simp: runs_to_refines_weaken')

(* Forward direction: a postcondition of g yields the R-related postcondition for f. *)
lemma refinesD_runs_to:
  assumes f_g: "refines f g s t R" and g: "g  t  P "
  shows "f  s  λr t. r' t'. R (r, t) (r', t')  P r' t' "
  by (rule runs_to_refines_weaken'[OF f_g]) (auto intro: runs_to_weaken[OF g])

(* rel_spec is refinement in both directions. *)
lemma rel_spec_iff_refines:
  "rel_spec f g s t R  refines f g s t R  refines g f t s (R¯¯)"
  unfolding rel_spec_def by transfer (simp add: rel_post_state_eq_sim_post_state)

lemma rel_spec_symm: "rel_spec f g s t R  rel_spec g f t s R¯¯"
  by (simp add: rel_spec_iff_refines ac_simps)

lemma rel_specD_refines1: "rel_spec f g s t R  refines f g s t R"
  by (simp add: rel_spec_iff_refines)

lemma rel_spec_mono: "P  Q  rel_spec f g s t P  rel_spec f g s t Q"
  using rel_post_state_mono by (auto simp: rel_spec_def)

(* rel_spec with equality is equality of the post-states. *)
lemma rel_spec_eq: "rel_spec f g s t (=)  run f s = run g t"
  by (simp add: rel_spec_def rel_post_state_eq)

lemma rel_spec_eq_conv: "(s. rel_spec f g s s (=))  f = g"
  using rel_spec_eq spec_monad_ext by metis

lemma rel_spec_eqD: "(s. rel_spec f g s s (=))  f = g"
  using rel_spec_eq_conv by metis

(* Reflexivity of rel_spec, needing reflexivity of R only on reachable outcomes
   (primed version) or everywhere. *)
lemma rel_spec_refl': "f  s ?⦃ λr s. R (r, s) (r, s)  rel_spec f f s s R"
  by transfer  (simp add: rel_post_state_refl' split_beta')

lemma rel_spec_refl: "(r s. R (r, s) (r, s))  rel_spec f f s s R"
  by (rule rel_spec_mono[OF _ rel_spec_eq[THEN iffD2]]) auto

(* Analogous reflexivity rules for refines. *)
lemma refines_same_runs_to_partialI:
  "f  s ?⦃λr s'. R (r, s') (r, s')  refines f f s s R"
  by transfer
     (auto intro!: sim_post_state_refl' simp: split_beta')

lemma refines_same_runs_toI:
  "f  s λr s'. R (r, s') (r, s')  refines f f s s R"
  by (rule refines_same_runs_to_partialI[OF runs_to_partial_of_runs_to])

lemma always_progress_case_prod[always_progress_intros]:
  "always_progress (f (fst p) (snd p))  always_progress (case_prod f p)"
  by (simp add: split_beta)

(* Add a partial-correctness fact about f to the refinement relation. *)
lemma refines_runs_to_partial_fuse:
  "refines f f' s s' Q  f  s ?⦃P 
    refines f f' s s' (λ(r,t) (r',t'). Q (r,t) (r',t')  P r t)"
  by transfer (auto elim!: sim_post_state.cases simp: sim_set_def split_beta')

(* Pull a total-correctness fact of the abstract program f' back to f. *)
lemma runs_to_refines:
  "f'  s'  P   refines f f' s s' Q  (x s y t. P x s  Q (y, t) (x, s)  R y t) 
    f  s  R "
  by transfer (force elim!: sim_post_state.cases simp: sim_set_def split_beta')

(* Relating postconditions that use the_Result with the lambda-Res binder form. *)
lemma runs_to_partial_subst_Res: "f  s ?⦃λr. P (the_Result r)  f  s ?⦃λRes r. P r"
  by (intro arg_cong[where f="λP. f  s ?⦃ P "]) (auto simp: fun_eq_iff the_Result_def)

lemma runs_to_subst_Res: "f  s λr. P (the_Result r)  f  s λRes r. P r"
  by (intro arg_cong[where f="λP. f  s  P "]) (auto simp: fun_eq_iff the_Result_def)

lemma runs_to_Res[simp]: "f  s λr t. v. r = Result v  P v t  f  s λRes v t. P v t"
  by (intro arg_cong[where f="λP. f  s  P "]) (auto simp: fun_eq_iff the_Result_def)

lemma runs_to_partial_Res[simp]:
  "f  s ?⦃λr t. v. r = Result v  P v t  f  s ?⦃λRes v t. P v t"
  by (intro arg_cong[where f="λP. f  s ?⦃ P "]) (auto simp: fun_eq_iff the_Result_def)

(* rel_spec from total-correctness facts on both sides: both programs always
   progress, and every P-outcome of f is R-related to every Q-outcome of g. *)
lemma rel_spec_runs_to:
  assumes f: "f  s  P " "always_progress f"
    and g: "g  t  Q " "always_progress g"
    and P: "(r s' p t'. P r s'  Q p t'  R (r, s') (p, t'))"
  shows "rel_spec f g s t R"
  using assms unfolding rel_spec_def always_progress.rep_eq runs_to.rep_eq
  by (cases "run f s"; cases "run g t"; simp add: rel_set_def split_beta')
     (metis all_not_in_conv bot_post_state_def prod.exhaust_sel)

lemma runs_to_res_independent_res: "f  s λ_. P  f  s λRes _. P"
  by (rule runs_to_subst_Res)

(* Lifting along the equality relation is the identity. *)
lemma lift_state_spec_monad_eq[simp]: "lift_state (=) p = p"
  by transfer (simp add: prod.rel_eq fun_eq_iff rel_post_state_eq)

lemma rel_post_state_Sup:
  "rel_set (λx y. rel_post_state Q (f x) (g y)) X Y  rel_post_state Q (xX. f x) (xY. g x)"
  by (force simp: rel_post_state_eq_sim_post_state rel_set_def intro!: sim_post_state_Sup)

lemma rel_set_Result_image_iff:
  "rel_set (rel_prod (λRes v1 Res v2. (P v1 v2)) R)
     ((λx. case x of (v, s)  (Result v, s)) ` Vals1)
     ((λx. case x of (v, s)  (Result v, s)) ` Vals2)
   
   rel_set (rel_prod P R) Vals1 Vals2"
  by (force simp add: rel_set_def)

lemma rel_post_state_converse_iff:
  "rel_post_state R X Y  rel_post_state R¯¯ Y X"
  by (metis conversep_conversep rel_post_state_eq_sim_post_state)

(* runs_to as an inequality: the post-state is below Success of the set of
   outcomes satisfying Q. *)
lemma runs_to_le_post_state_iff: "runs_to f s Q  run f s  Success {(r, s). Q r s}"
  by (cases "run f s") (auto simp add: runs_to.rep_eq)

(* Partial correctness holds iff f fails or total correctness holds. *)
lemma runs_to_partial_runs_to_iff: "runs_to_partial f s Q  (run f s = Failure  runs_to f s Q)"
  by (cases "run f s"; simp add: runs_to_partial.rep_eq runs_to.rep_eq)

(* Post-states are determined by the runs_to facts they satisfy. *)
lemma run_runs_to_extensionality:
  "run f s = run g s  (P. f  s P  g  s P)"
  apply transfer
  apply (simp add: post_state_eq_iff)
  apply (rule all_cong_map[where f'="λP (x, y). P x y" and f="λP x y. P (x, y)"])
  apply auto
  done

lemma le_spec_monad_runI: "(s. run f s  run g s)  f  g"
  by transfer (simp add: le_fun_def)

subsection ‹@{const rel_spec_monad}›

(* rel_spec_monad in terms of rel_spec from R-related initial states. *)
lemma rel_spec_monad_iff_rel_spec:
  "rel_spec_monad R Q f g  (s t. R s t  rel_spec f g s t (rel_prod Q R))"
  unfolding rel_spec_monad_def rel_fun_def rel_spec.rep_eq ..

lemma rel_spec_monadI:
  "(s t. R s t  rel_spec f g s t (rel_prod Q R))  rel_spec_monad R Q f g"
  by (auto simp: rel_spec_monad_iff_rel_spec)

lemma rel_spec_monadD:
  "rel_spec_monad R Q f g  R s t  rel_spec f g s t (rel_prod Q R)"
  by (auto simp: rel_spec_monad_iff_rel_spec)

(* The relator on equalities is equality. *)
lemma rel_spec_monad_eq_conv: "rel_spec_monad (=) (=) = (=)"
  unfolding rel_spec_monad_def by transfer (simp add: fun_eq_iff prod.rel_eq rel_post_state_eq)

lemma rel_spec_monad_converse_iff:
  "rel_spec_monad R Q f g  rel_spec_monad R¯¯ Q¯¯ g f"
  using rel_post_state_converse_iff
  by (metis conversep.cases conversep.intros prod.rel_conversep rel_spec_monad_def)

(* rel_spec_monad as refinement in both directions. *)
lemma rel_spec_monad_iff_refines:
  "rel_spec_monad S R f g 
    (s t. S s t  (refines f g s t (rel_prod R S)  refines g f t s (rel_prod R¯¯ S¯¯)))"
  using rel_spec_iff_refines rel_spec_monad_iff_rel_spec
  by (metis prod.rel_conversep)

lemma rel_spec_monad_rel_exception_or_resultI:
  "rel_spec_monad R (rel_exception_or_result (=) (=)) f g  rel_spec_monad R (=) f g"
  by (simp add: rel_exception_or_result_sum_eq)

(* Combine a partial-correctness fact Q with a total-correctness fact P. *)
lemma runs_to_partial_runs_to_fuse: 
  assumes part: "f  s ?⦃Q"
  assumes tot: "f  s P"
  shows "f  s λr t. Q r t  P r t"
  using assms
  apply (cases "run f s")
   apply (auto simp add: runs_to_def runs_to_partial_def)
  done

section ‹VCG basic setup›

(* Congruence rules that rewrite only the postcondition, only the initial state,
   or only the program; the latter two kinds are registered for the VCG. *)
lemma runs_to_cong_pred_only:
  "P = Q  (p  s  P )  (p  s  Q )"
  by simp

lemma runs_to_cong_state_only[runs_to_vcg_cong_state_only]:
  "s = t  (p  s  Q )  (p  t  Q )"
  by simp

lemma runs_to_partial_cong_state_only[runs_to_vcg_cong_state_only]:
  "s = t  (p  s ?⦃ Q )  (p  t ?⦃ Q )"
  by simp

lemma runs_to_cong_program_only[runs_to_vcg_cong_program_only]:
  "p = q  (p  s  Q )  (q  s  Q )"
  by simp

lemma runs_to_partial_cong_program_only[runs_to_vcg_cong_program_only]:
  "p = q  (p  s ?⦃ Q )  (q  s ?⦃ Q )"
  by simp

(* Reduce a case split on an explicit pair in the program. *)
lemma runs_to_case_conv[simp]:
  "((case (a, b) of (x, y)  f x y)  s  Q )  ((f a b)  s  Q )"
  by simp

lemma runs_to_partial_case_conv[simp]:
  "((case (a, b) of (x, y)  f x y)  s ?⦃ Q )  ((f a b)  s ?⦃ Q )"
  by simp

lemma always_progress_prod_case[always_progress_intros]:
  "always_progress (f (fst p) (snd p))  always_progress (case p of (x, y)  f x y)"
  by (auto simp add: always_progress_def split: prod.splits)

(* runs_to distributes over conjunction and universal quantification of the
   postcondition; a constant antecedent can be split off. *)
lemma runs_to_conj:
  "(f  s  λr s. P r s  Q r s)  (f  s  λr s. P r s )  (f  s  λr s. Q r s)"
  by (cases "run f s") (auto simp: runs_to.rep_eq)

lemma runs_to_all:
  "(f  s  λr s. x. P x r s )  (x. f  s  λr s. P x r s )"
  by (cases "run f s") (auto simp: runs_to.rep_eq)

lemma runs_to_imp_const:
  "(f  s  λr s. P r s  Q )  (Q  (f  s λ_ _. True ))  (¬ Q  (f  s  λr s. ¬ P r s))"
  by (cases "run f s") (auto simp: runs_to.rep_eq)

subsection ‹‹res_monad› and ‹exn_monad› Types›

(* val: exception_or_result with unit as the exception type.  The lemma
   is_Exception_val shows that such a value is never an exception.
   xval: the exception type is wrapped in an option; real exceptions are
   Exn e = Exception (Some e).
   res_monad / exn_monad: the spec_monad instances for these two result
   types. *)
type_synonym 'a val = "(unit, 'a) exception_or_result"
type_synonym ('e, 'a) xval = "('e option, 'a) exception_or_result"
type_synonym ('a, 's) res_monad = "(unit, 'a , 's) spec_monad"
type_synonym ('e, 'a, 's) exn_monad = "('e option, 'a, 's) spec_monad"

(* Exn e is the exception carrying payload e, encoded as Exception (Some e). *)
definition Exn :: "'e  ('e, 'a) xval" where
  "Exn e = Exception (Some e)"

(* Eliminator for xval: f handles exception payloads and g handles results.
   Exception None has no payload and is mapped to undefined. *)
definition
  case_xval :: "('e  'a)  ('v  'a)  ('e, 'v) xval  'a" where
  "case_xval f g x =
    (case x of
       Exception v  (case v of Some e  f e | None  undefined)
     | Result v  g v)"

(* Enables case syntax with the patterns Exn e and Result v on xval values;
   map_xval below is written with it. *)
declare [[case_translation case_xval Exn Result]]

(* Relator on xval.  E relates exception payloads and R relates results.
   Only the two intro rules exist, so an exception is never related to a
   result (see rel_xval_simps). *)
inductive rel_xval:: "('e  'f  bool)  ('a  'b  bool) 
  ('e, 'a) xval  ('f, 'b) xval  bool"
  where
Exn: "E e f  rel_xval E R (Exn e) (Exn f)" | 
Result: "R a b  rel_xval E R (Result a) (Result b)"

(* Map over xval: f on exception payloads, g on results. *)
definition map_xval :: "('e  'f)  ('a  'b)  ('e, 'a) xval  ('f, 'b) xval" where
  "map_xval f g x  case x of Exn e  Exn (f e) | Result v  Result (g v)"

(* rel_xval with equality on both components is equality on xval. *)
lemma rel_xval_eq: "rel_xval (=) (=) = (=)"
  apply (rule ext)+
  subgoal for x y
    apply (cases x)
    apply (auto simp add: rel_xval.simps default_option_def Exn_def)
    done
  done

(* rel_xval expressed through rel_exception_or_result.  The option-wrapped
   exception payloads are related by composing rel_map the, E and
   rel_map Some. *)
lemma rel_xval_rel_exception_or_result_conv: 
  "rel_xval E R = rel_exception_or_result (rel_map the OO E OO rel_map Some) R"
  apply (rule ext)+
  subgoal for x y
    apply (cases x)
    subgoal
      by (auto simp add: rel_exception_or_result.simps rel_xval.simps Exn_def default_option_def rel_map_def)
        (metis option.sel rel_map_direct relcomppI)+
    subgoal
      by (auto simp add: rel_exception_or_result.simps rel_xval.simps Exn_def default_option_def rel_map_def)
    done
  done

(* Named copies of the two introduction rules of rel_xval. *)
lemmas rel_xval_Exn = rel_xval.intros(1)
lemmas rel_xval_Result = rel_xval.intros(2)

(* Computation rules for case_xval on the constructors Exn and Result. *)
lemma case_xval_simps[simp]:
  "case_xval f g (Exn v) = f v"
  "case_xval f g (Result e) = g e"
   by (auto simp add: case_xval_def Exn_def)

(* The generic eliminator case_exception_or_result sees Exn x as the
   exception Some x. *)
lemma case_exception_or_result_Exn[simp]:
  "case_exception_or_result f g (Exn x) = f (Some x)"
  by (simp add: Exn_def)

(* Split rules for case_xval, one for goals and one for assumptions.  They
   are bundled as xval_splits for use with "split:". *)
lemma xval_split:
  "P (case_xval f g r) 
    (e. r = Exn e  P (f e)) 
    (v. r = Result v  P (g v))"
  by (auto simp add: case_xval_def Exn_def split: exception_or_result_splits option.splits)

lemma xval_split_asm:
  "P (case_xval f g r) 
  ¬ ((e. r = Exn e  ¬ P (f e)) 
    (v. r = Result v  ¬ P (g v)))"
  by (auto simp add: case_xval_def Exn_def split: exception_or_result_splits option.splits)

lemmas xval_splits = xval_split xval_split_asm

(* Injectivity of Exn and distinctness of Exn and Result. *)
lemma Exn_eq_Exn[simp]: "Exn x = Exn y  x = y"
  by (simp add: Exn_def)

lemma Exn_neq_Result[simp]: "Exn x = Result e  False"
  by (simp add: Exn_def)

lemma Result_neq_Exn[simp]: "Result e = Exn x  False"
  by (simp add: Exn_def)

(* Relating Exn to the underlying Exception constructor, in both orientations. *)
lemma Exn_eq_Exception[simp]:
  "Exn x = Exception a  a = Some x"
  "Exception a = Exn x   a = Some x"
  by (auto simp: Exn_def)

(* map_exception_or_result applied to Exn x.  The premise requires f to map
   every Some x to a value other than None, so that the result is again an
   Exn. *)
lemma map_exception_or_result_Exn[simp]:
  "(x. f (Some x)  None)  map_exception_or_result f g (Exn x) = Exn (the (f (Some x)))"
  by (simp add: Exn_def)

(* A case with the Exn/Result patterns applied to a literal
   Exception (Some y) selects the Exn branch. *)
lemma case_xval_Exception_Some_simp[simp]: "(case Exception (Some y) of Exn e  f e | Result v  g v) = f y "
  by (metis Exn_def case_xval_simps(1))

(* rel_xval on constructors: same constructors are related componentwise;
   mixed constructors are unrelated. *)
lemma rel_xval_simps[simp]: 
  "rel_xval E R (Exn e) (Exn f) = E e f"
  "rel_xval E R (Result v) (Result w) = R v w"
  "rel_xval E R (Exn e) (Result w) = False"
  "rel_xval E R (Result v) (Exn f) = False"
  by (auto simp add: rel_xval.simps)

(* map_xval on constructors. *)
lemma map_xval_simps[simp]: 
  "map_xval f g (Exn e) = Exn (f e)"
  "map_xval f g (Result v) = Result (g v)"
  by (auto simp add: map_xval_def)

(* Inversion rules for map_xval: the image is an Exn (respectively a
   Result) exactly when the argument is one, with the payload mapped. *)
lemma map_xval_Exn: "map_xval f g x = Exn y  (e. x = Exn e  y = f e)"
  by (auto simp add: map_xval_def split: xval_splits)

lemma map_xval_Result: "map_xval f g x = Result y  (v. x = Result v  y = g v)"
  by (auto simp add: map_xval_def split: xval_splits)

(* Every value of type unit val equals Result (). *)
lemma Result_unit_eq: "(x:: unit val) = Result ()"
  by (cases x) auto

(* Simproc that matches any term of type unit val.  If the term is already
   literally Result (), it returns NONE, which prevents the simplifier from
   looping.  Otherwise it rewrites the term to Result () using
   Result_unit_eq. *)
simproc_setup Result_unit_eq ("x::(unit val)") = let
  fun is_Result_unit ConstResult @{typ "unit"} @{typ "unit"} for ConstProduct_Type.Unity = true
    | is_Result_unit _ = false
  in
    K (K (fn ct =>
      if is_Result_unit (Thm.term_of ct) then NONE
      else SOME (mk_meta_eq @{thm Result_unit_eq})))
  end

(* ex_val_Result1 .. ex_val_Result11: every val is Result of some payload.
   Variants 2..11 name the components of an n-tuple payload.  All variants
   are collected into the simp set ex_val_Result after the list. *)
lemma ex_val_Result1:
  "v1. (x::'v1 val) = Result v1"
  by (cases x) auto

lemma ex_val_Result2:
  "v1 v2. (x::('v1 * 'v2) val) = Result (v1, v2)"
  by (cases x) auto

lemma ex_val_Result3:
  "v1 v2 v3. (x::('v1 * 'v2 * 'v3) val) = Result (v1, v2, v3)"
  by (cases x) auto

lemma ex_val_Result4:
  "v1 v2 v3 v4. (x::('v1 * 'v2 * 'v3 * 'v4) val) = Result (v1, v2, v3, v4)"
  by (cases x) auto

lemma ex_val_Result5:
  "v1 v2 v3 v4 v5. (x::('v1 * 'v2 * 'v3 * 'v4 * 'v5) val) = Result (v1, v2, v3, v4, v5)"
  by (cases x) auto

lemma ex_val_Result6:
  "v1 v2 v3 v4 v5 v6. (x::('v1 * 'v2 * 'v3 * 'v4 * 'v5  * 'v6) val) = Result (v1, v2, v3, v4, v5, v6)"
  by (cases x) auto

lemma ex_val_Result7:
  "v1 v2 v3 v4 v5 v6 v7. (x::('v1 * 'v2 * 'v3 * 'v4 * 'v5  * 'v6 * 'v7) val) = Result (v1, v2, v3, v4, v5, v6, v7)"
  by (cases x) auto

lemma ex_val_Result8:
  "v1 v2 v3 v4 v5 v6 v7 v8. (x::('v1 * 'v2 * 'v3 * 'v4 * 'v5  * 'v6 * 'v7 * 'v8) val) = Result (v1, v2, v3, v4, v5, v6, v7, v8)"
  by (cases x) auto

lemma ex_val_Result9:
  "v1 v2 v3 v4 v5 v6 v7 v8 v9. (x::('v1 * 'v2 * 'v3 * 'v4 * 'v5  * 'v6 * 'v7 * 'v8 * 'v9) val) = Result (v1, v2, v3, v4, v5, v6, v7, v8, v9)"
  by (cases x) auto

lemma ex_val_Result10:
  "v1 v2 v3 v4 v5 v6 v7 v8 v9 v10. (x::('v1 * 'v2 * 'v3 * 'v4 * 'v5  * 'v6 * 'v7 * 'v8 * 'v9 * 'v10) val) = Result (v1, v2, v3, v4, v5, v6, v7, v8, v9, v10)"
  by (cases x) auto


lemma ex_val_Result11:
  "v1 v2 v3 v4 v5 v6 v7 v8 v9 v10 v11. (x::('v1 * 'v2 * 'v3 * 'v4 * 'v5  * 'v6 * 'v7 * 'v8 * 'v9 * 'v10 * 'v11) val) = Result (v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)"
  by (cases x) auto


lemmas ex_val_Result[simp] =
  ex_val_Result1
  ex_val_Result2
  ex_val_Result3
  ex_val_Result4
  ex_val_Result5
  ex_val_Result6
  ex_val_Result7
  ex_val_Result8
  ex_val_Result9
  ex_val_Result10
  ex_val_Result11

(* A statement about r that quantifies over its payload collapses to P,
   since every val is a Result.  Proved by cases on r. *)
lemma all_val_imp_iff: "(v. (r::'a val) = Result v  P)  P"
  by (cases r) auto

(* map_exn applies f to exception payloads and leaves results unchanged. *)
definition map_exn :: "('e  'f)  ('e, 'a) xval  ('f, 'a) xval" where
  "map_exn f x =
    (case x of
       Exn e  Exn (f e)
     | Result v  Result v)"

(* Simp rules for map_exn.  The third rule states that map_exn f x is a
   Result exactly when x is that same Result, so an exception is never
   turned into a result. *)
lemma map_exn_simps[simp]:
  "map_exn f (Exn e) = Exn (f e)"
  "map_exn f (Result v) = Result v"
  "map_exn f x = Result v  x = Result v"
  by (auto simp add: map_exn_def split: xval_splits)

(* map_exn with the identity function is the identity. *)
lemma map_exn_id[simp]: "map_exn (λx. x) = (λx. x)"
  by (auto simp add: map_exn_def split: xval_splits)

(* unnest_exn flattens an exception of sum type:
   - Exn (Inl l) stays an exception with payload l;
   - Exn (Inr r) becomes Result r;
   - results are unchanged. *)
definition unnest_exn :: "('e + 'a, 'a) xval  ('e, 'a) xval" where
  "unnest_exn x =
     (case x of
        Exn e  (case e of Inl l  Exn l | Inr r  Result r)
      | Result v  Result v)"

lemma unnest_exn_simps[simp]:
  "unnest_exn (Exn (Inl l)) = Exn l"
  "unnest_exn (Exn (Inr r)) = Result r"
  "unnest_exn (Result v) = Result v"
   by (auto simp add: unnest_exn_def split: xval_splits)

(* Inversion rules for equations that have unnest_exn on the left-hand side. *)
lemma unnest_exn_eq_simps[simp]:
  "unnest_exn (Result r) = Result r'  r = r'"
  "unnest_exn (Result e)  Exn e"
  "unnest_exn (Exn e) = Result r  e = Inr r"
  "unnest_exn (Exn e) = Exn e'  e = Inl e'"
  by (auto simp: unnest_exn_def split: xval_splits sum.splits)

(* the_Exn extracts the payload of an Exn; on results it is undefined. *)
definition the_Exn :: "('e, 'a) xval  'e" where
  "the_Exn x  (case x of Exn e  e | Result _  undefined)"

(* the_Res is shorthand for the_Result on val. *)
abbreviation the_Res :: "'a val  'a" where "the_Res  the_Result"

(* Simp rules for the discriminators and selectors.  A val (with unit as
   exception type) is never an exception and is always a result. *)
lemma is_Exception_val[simp]: "is_Exception (x::'a val) = False"
  by (auto simp add: is_Exception_def)

lemma is_Exception_Exn[simp]: "is_Exception (Exn x)"
  by (simp add: Exn_def)

lemma is_Result_val[simp]: "is_Result (v::'a val)"
  by (auto simp add: is_Result_def)

lemma the_Exception_Exn[simp]: "the_Exception (Exn e) = Some e"
  by (auto simp add: the_Exception_def)

lemma the_Exn_Exn[simp]:
  "the_Exn (Exn e) = e"
  "the_Exn (Exception (Some e)) = e"
   by (auto simp add: the_Exn_def case_xval_def)

lemma the_Result_simp[simp]:
  "the_Result (Result v) = v"
  by (auto simp add: the_Result_def)

lemma Result_the_Result_val[simp]: "Result (the_Result (x::'a val)) = x"
  by (auto simp add: the_Result_def)

(* For val, the exception relation of rel_exception_or_result is irrelevant
   and may be replaced by any other relation E'.  The first lemma states this
   for applied terms, the second as an equation of relations. *)
lemma rel_exception_or_result_val_apply:
  fixes x::"'a val"
  fixes y ::"'b val"
  shows "rel_exception_or_result E R x y  rel_exception_or_result E' R x y"
  by (auto simp add: rel_exception_or_result.simps)

lemma rel_exception_or_result_val:
  shows "((rel_exception_or_result E R)::'a val  'b val  bool) = rel_exception_or_result E' R"
  apply (rule ext)+
  apply (auto simp add: rel_exception_or_result.simps)
  done

(* Consider a relation that matches both arguments as results and relates
   their payloads by R.  It is rel_exception_or_result with the empty
   exception relation. *)
lemma rel_exception_or_result_Res_val:
  "(λRes x Res y. R x y) = rel_exception_or_result (λ_ _. False) R"
  apply (rule ext)+
  apply auto
  done

(* unite collapses an xval whose exception and result types coincide into a
   val: both Exn v and Result v become Result v. *)
definition unite:: "('a, 'a) xval  'a val" where
  "unite x = (case x of Exn v  Result v | Result v  Result v)"

lemma unite_simps[simp]:
  "unite (Exn v) = Result v"
  "unite (Result v) = Result v"
  by (auto simp add: unite_def)

(* Case analysis, case split and exhaustiveness for val: every val is
   Result v for some v. *)
lemma val_exhaust: "(v. (x::'a val) = Result v  P)  P"
  by (cases x) auto

lemma val_iff: "P (x::'a val)  (v. x = Result v  P (Result v))"
  by (cases x) auto

lemma val_nchotomy: "(x::'a val). v. x = Result v"
  by (meson val_exhaust)

(* Unnamed sanity check: a meta-quantified property of all vals can be
   instantiated at Result v.  The fact is not stored under a name. *)
lemma "(x::'a val. PROP P x)  PROP P (Result (v::'a))"
  by (rule meta_spec)

lemma val_Result_the_Result_conv: "(x::'a val) = Result (the_Result x)"
  by (cases x) (auto simp add: the_Result_def)

(* Meta-level, existential and universal quantification over val can be
   replaced by quantification over the payload (all three are [simp]).
   In split_val_all, the backward direction rewrites x as
   Result (the_Result x) via val_Result_the_Result_conv. *)
lemma split_val_all[simp]: "(x::'a val. PROP P x)  (v::'a. PROP P (Result v))"
proof
  fix v::'a
  assume "x::'a val. PROP P x"
  then show "PROP P (Result v)" .
next
  fix x::"'a val"
  assume "v::'a. PROP P (Result v)"
  hence "PROP P (Result (the_Result x))" .
  thus "PROP P x"
    apply (subst val_Result_the_Result_conv)
    apply (assumption)
    done
qed

lemma split_val_Ex[simp]: "((x::'a val). P x)  ((v::'a). P (Result v))"
  by (auto)

lemma split_val_All[simp]: "((x::'a val). P x)  ((v::'a). P (Result v))"
  by (auto)

(* to_xval converts a sum into an xval: Inl e becomes Exn e and Inr v
   becomes Result v. *)
definition to_xval :: "('e + 'a)  ('e, 'a) xval" where
 "to_xval x = (case x of Inl e  Exn e | Inr v  Result v)"

lemma to_xval_simps[simp]:
  "to_xval (Inl e) = Exn e"
  "to_xval (Inr v) = Result v"
   by (auto simp add: to_xval_def)

(* Inversion rules for equations of the form to_xval x = ... .  In the
   Exception variant, the case v = None corresponds to x = Inr undefined
   (presumably because Exception None is the default exception and coincides
   with Result undefined; confirm against the exception_or_result
   definition). *)
lemma to_xval_Result_iff[simp]: "to_xval x = Result v  x = Inr v"
  apply (cases x)
   apply (auto simp add: to_xval_def default_option_def)
  done

lemma to_xval_Exn_iff[simp]: 
 "to_xval x = Exn v 
    (x  = Inl v)"
  apply (cases x)
   apply (auto simp add: to_xval_def default_option_def Exn_def)
  done

lemma to_xval_Exception_iff[simp]: 
 "to_xval x = Exception v 
    ((v = None  x = Inr undefined)  (e. v = Some e  x  = Inl e))"
  apply (cases x)
   apply (auto simp add: to_xval_def default_option_def Exn_def)
  done


(* from_xval converts an xval back into a sum: Exn e becomes Inl e and
   Result v becomes Inr v. *)
definition from_xval :: "('e, 'a) xval  ('e + 'a)" where
 "from_xval x = (case x of Exn e  Inl e | Result v  Inr v)"

lemma from_xval_simps[simp]:
  "from_xval (Exn e) = Inl e"
  "from_xval (Result v) = Inr v"
  by (auto simp add: from_xval_def)

(* Inversion rules for equations of the form from_xval x = Inr / Inl ... . *)
lemma from_xval_Inr_iff[simp]: "from_xval x = Inr v  x = Result v"
  apply (cases x)
   apply (auto simp add: from_xval_def split: xval_splits)
  done

lemma from_xval_Inl_iff[simp]: "from_xval x = Inl e  x = Exn e"
  apply (cases x)
   apply (auto simp add: from_xval_def split: xval_splits)
  done

(* to_xval and from_xval are mutually inverse. *)
lemma to_xval_from_xval[simp]: "to_xval (from_xval x) = x"
  apply (cases x)
   apply (auto simp add: Exn_def default_option_def)
  done

lemma from_xval_to_xval[simp]: "from_xval (to_xval x) = x"
  apply (cases x)
   apply (auto simp add: Exn_def default_option_def)
  done

(* Simp rules for rel_map to_xval and rel_map from_xval with a constructor on
   one side.  Each reduces the relation to an equation for the other side. *)
lemma rel_map_to_xval_Exn_iff[simp]:  "rel_map to_xval y (Exn e)  y = Inl e"
  by (auto simp add: rel_map_def to_xval_def split: sum.splits)

lemma rel_map_to_xval_Inr_iff[simp]:  "rel_map to_xval (Inr r) x  x = Result r"
  by (auto simp add: rel_map_def)

lemma rel_map_to_xval_Inl_iff[simp]:  "rel_map to_xval (Inl l) x  x = Exn l"
  by (auto simp add: rel_map_def)

lemma rel_map_to_xval_Result_iff[simp]:  "rel_map to_xval y (Result r)  y = Inr r"
  by (auto simp add: rel_map_def to_xval_def split: sum.splits)

lemma rel_map_from_xval_Exn_iff[simp]:  "rel_map from_xval (Exn l) x  x = Inl l"
  by (auto simp add: rel_map_def)

lemma rel_map_from_xval_Inl_iff[simp]:  "rel_map from_xval y (Inl e)  y = Exn e"
  by (auto simp add: rel_map_def from_xval_def split: xval_splits)

lemma rel_map_from_xval_Result_iff[simp]:  "rel_map from_xval (Result r) x  x = Inr r"
  by (auto simp add: rel_map_def)

lemma rel_map_from_xval_Inr_iff[simp]:  "rel_map from_xval y (Inr r)  y = Result r"
  by (auto simp add: rel_map_def from_xval_def split: xval_splits)

section res_monad› and exn_monad› functions›

(* throw e yields the exception Exn e. *)
abbreviation "throw e  yield (Exn e)"

(* try flattens sum-typed exceptions with unnest_exn: Inl exceptions remain
   exceptions and Inr r exceptions become results. *)
definition "try" :: "('e + 'a, 'a, 's) exn_monad  ('e, 'a, 's) exn_monad" where
  "try = map_value unnest_exn"

(* finally merges exceptions and results of the same type into results
   using unite. *)
definition "finally" :: "('a, 'a, 's) exn_monad  ('a, 's) res_monad" where
  "finally = map_value unite"

(* f <catch> handler: results of f are returned unchanged.  Exceptions e are
   passed to handler as "the e", which strips the option wrapper. *)
definition
  catch :: "('e, 'a, 's) exn_monad 
            ('e  ( 'f::default, 'a, 's) spec_monad) 
            ('f::default, 'a, 's) spec_monad" (infix <catch> 10)
where
  "f <catch> handler  bind_handle f return (handler o the)"

(* bind_finally passes the complete xval outcome (exception or result) of
   the first computation to the continuation. *)
abbreviation bind_finally ::
  "('e, 'a, 's) exn_monad  (('e, 'a) xval  ('b, 's) res_monad)  ('b, 's) res_monad"
where
  "bind_finally  bind_exception_or_result"

(* ignoreE discards the exceptional behaviours of f by handling every
   exception with bot. *)
definition ignoreE ::"('e, 'a, 's) exn_monad  ('a, 's) res_monad" where
  "ignoreE f = catch f (λ_. bot)"

(* liftE embeds a res_monad into an exn_monad.  Results are kept (id), and
   the unit exception component is mapped to undefined. *)
definition liftE :: "('a, 's) res_monad  ('e, 'a, 's) exn_monad" where
  "liftE = map_value (map_exception_or_result (λx. undefined) id)"

(* check e p: presumably, if some x satisfies p s in the current state s,
   nondeterministically return such an x without changing the state;
   otherwise throw e. *)
definition check:: "'e  ('s  'a  bool)  ('e, 'a, 's) exn_monad" where
  "check e p =
    condition (λs. x. p s x)
    (do { s  get_state; select {x. p s x} })
    (throw e)"

(* check' is check with a predicate on the state only (the value argument
   is ignored). *)
abbreviation check' :: "'e  ('s  bool)  ('e, unit, 's) exn_monad" where
  "check' e p  check e (λs _. p s)"

section "Monad operations"

subsection consttop

(* Rules for the top element of spec_monad:
   - its rep_eq is registered as a run_spec_monad / simp rule;
   - it satisfies always_progress;
   - it never satisfies a total-correctness (runs_to) judgement;
   - it satisfies every partial-correctness (runs_to_partial) judgement. *)
declare top_spec_monad.rep_eq[run_spec_monad, simp]

lemma always_progress_top[always_progress_intros]: "always_progress "
  by transfer simp

lemma runs_to_top[simp]: "  s  Q   False"
  by transfer simp

lemma runs_to_partial_top[simp]: "  s ?⦃ Q   True"
  by transfer simp

subsection constbot

(* Rules for the bottom element of spec_monad:
   - always_progress fails for it;
   - it satisfies every total-correctness judgement;
   - it satisfies every partial-correctness judgement. *)
declare bot_spec_monad.rep_eq[run_spec_monad, simp]

lemma always_progress_bot_iff[iff]: "always_progress   False"
  by transfer simp

lemma runs_to_bot[simp]: "  s  Q   True"
  by transfer simp

lemma runs_to_partial_bot[simp]: "  s ?⦃ Q   True"
  by transfer simp

subsection constfail

(* Rules for fail:
   - it satisfies always_progress;
   - its run is the top post state (the proof unfolds top_post_state_def);
   - it never satisfies runs_to and always satisfies runs_to_partial;
   - refines f fail holds for every f, and fail is rel_spec-related to
     itself. *)
lemma always_progress_fail[always_progress_intros]: "always_progress fail"
  by transfer (simp add: bot_post_state_def)

lemma run_fail[run_spec_monad, simp]: "run fail s = "
  by transfer (simp add: top_post_state_def)

lemma runs_to_fail[simp]: "fail  s  R   False"
  by transfer simp

lemma runs_to_partial_fail[simp]: "fail  s ?⦃ R   True"
  by transfer simp

lemma refines_fail[simp]: "refines f fail s t R"
  by (simp add: refines.rep_eq)

lemma rel_spec_fail[simp]: "rel_spec fail fail s t R"
  by (simp add: rel_spec_def)

subsection constyield

(* yield r returns r without touching the state: its run is the pure post
   state (r, s).  It satisfies always_progress and is injective in r.
   runs_to, runs_to_partial, refines and rel_spec on yield reduce to the
   postcondition or relation evaluated at (r, s). *)
lemma run_yield[run_spec_monad, simp]: "run (yield r) s = pure_post_state (r, s)"
  by transfer simp

lemma always_progress_yield[always_progress_intros]: "always_progress (yield r)"
  by (simp add: always_progress_def)

lemma yield_inj[simp]: "yield x = yield y  x = y"
  by transfer (auto simp: fun_eq_iff)

lemma runs_to_yield_iff[simp]: "((yield r)  s  Q )  Q r s"
  by transfer simp

lemma runs_to_yield[runs_to_vcg]: "Q r s  yield r  s  Q "
  by simp

lemma runs_to_partial_yield_iff[simp]: "((yield r)  s ?⦃ Q )  Q r s"
  by transfer simp

lemma runs_to_partial_yield[runs_to_vcg]: "Q r s  yield r  s ?⦃ Q "
  by simp

lemma refines_yield_iff[simp]: "refines (yield r) (yield r') s s' R  R (r, s) (r', s')"
  by transfer simp

lemma refines_yield: "R (a, s) (b, t)  refines (yield a) (yield b) s t R"
  by simp

(* Refining a yield on the right reduces to a runs_to judgement about f,
   where every outcome (r, s') of f must be related to (e, t). *)
lemma refines_yield_right_iff:
  "refines f (yield e) s t R  (f  s  λr s'. R (r, s') (e, t) )"
  by (auto simp: refines.rep_eq runs_to.rep_eq sim_post_state_pure_post_state2 split_beta')

lemma rel_spec_yield: "R (a, s) (b, t)  rel_spec (yield a) (yield b) s t R"
  by (simp add: rel_spec_def)

lemma rel_spec_yield_iff[simp]: "rel_spec (yield x) (yield y) s t Q  Q (x, s) (y, t)"
  by (simp add: rel_spec.rep_eq)

lemma rel_spec_monad_yield: "Q x y  rel_spec_monad R Q (yield x) (yield y)"
  by (auto simp add: rel_spec_monad_iff_rel_spec)

subsection constthrow_exception_or_result

(* Binding after throw_exception_or_result e, where e is not the default
   exception, discards the continuation f. *)
lemma throw_exception_or_result_bind[simp]:
  "e  default  throw_exception_or_result e >>= f = throw_exception_or_result e"
  unfolding bind_def by transfer auto

subsection constthrow

(* Binding after throw e discards the continuation f.  No side condition is
   needed; the proof only unfolds Exn_def. *)
lemma throw_bind[simp]: "throw e >>= f = throw e"
  by (simp add: Exn_def)

subsection constget_state

(* get_state returns the current state s as Result s and leaves the state
   unchanged: its run is pure_post_state (Result s, s).  VCG, refinement and
   rel_spec rules follow directly from this. *)
lemma always_progress_get_state[always_progress_intros]: "always_progress get_state"
  by transfer simp

lemma run_get_state[run_spec_monad, simp]: "run get_state s = pure_post_state (Result s, s)"
  by transfer simp

lemma runs_to_get_state[runs_to_vcg]: "get_state  s λr t. r = Result s  t = s"
  by transfer simp

lemma runs_to_get_state_iff[runs_to_iff]: "get_state  s Q  (Q (Result s) s)"
  by transfer simp

lemma runs_to_partial_get_state[runs_to_vcg]: "get_state  s ?⦃λr t. r = Result s  t = s"
  by transfer simp

lemma refines_get_state: "R (Result s, s) (Result t, t)  refines get_state get_state s t R"
  by transfer simp

lemma rel_spec_get_state: "R (Result s, s) (Result t, t)  rel_spec get_state get_state s t R"
  by (auto simp: rel_spec_iff_refines intro: refines_get_state)

subsection constset_state

(* set_state t replaces the current state with t and returns Result ():
   its run is pure_post_state (Result (), t).  set_state is injective, and
   the VCG, refinement and rel_spec rules follow from the run equation. *)
lemma always_progress_set_state[always_progress_intros]: "always_progress (set_state t)"
  by transfer simp

lemma set_state_inj[simp]: "set_state x = set_state y  x = y"
  by transfer (simp add: fun_eq_iff)

lemma run_set_state[run_spec_monad, simp]: "run (set_state t) s = pure_post_state (Result (), t)"
  by transfer simp

lemma runs_to_set_state[runs_to_vcg]: "set_state t  s λr t'. t' = t"
  by transfer simp

lemma runs_to_set_state_iff[runs_to_iff]: "set_state t  s Q  Q (Result ()) t"
  by transfer simp

lemma runs_to_partial_set_state[runs_to_vcg]: "set_state t  s ?⦃λr t'. t' = t"
  by transfer simp

lemma refines_set_state:
  "R (Result (), s') (Result (), t')  refines (set_state s') (set_state t') s t R"
  by transfer simp

lemma rel_spec_set_state:
  "R (Result (), s') (Result (), t')  rel_spec (set_state s') (set_state t') s t R"
  by (auto simp add: rel_spec_iff_refines intro: refines_set_state)

subsection constselect

(* select S nondeterministically returns some element of S and leaves the
   state unchanged; its run is Success of the set of all (Result v, s) with
   v in S.
   - It satisfies always_progress when S is nonempty.
   - runs_to requires the postcondition for every element of S.
   - The runs_to_partial rule is derived from the runs_to rule via
     runs_to_partial_of_runs_to. *)
lemma always_progress_select[always_progress_intros]: "S  {}  always_progress (select S)"
  unfolding select_def by transfer simp

lemma runs_to_select[runs_to_vcg]: "(x. x  S  Q (Result x) s)  select S  s  Q "
  unfolding select_def by transfer simp

lemma runs_to_select_iff[runs_to_iff]: "select S  s Q  (xS. Q (Result x) s)"
  unfolding select_def by transfer auto

lemma runs_to_partial_select[runs_to_vcg]: "(x. x  S  Q (Result x) s)  select S  s ?⦃ Q "
  using runs_to_select by (rule runs_to_partial_of_runs_to)

lemma run_select[run_spec_monad, simp]: "run (select S) s = Success ((λv. (Result v, s)) ` S)"
  unfolding select_def by transfer (auto simp add: image_image pure_post_state_def Sup_Success)

(* Refinement: every choice in P must be matched by some choice in Q. *)
lemma refines_select:
  "(x. x  P  xaQ. R (Result x, s) (Result xa, t))  refines (select P) (select Q) s t R"
  unfolding select_def by transfer (auto intro!: sim_post_state_Sup)

(* rel_spec: the choice sets must be related by rel_set. *)
lemma rel_spec_select:
  "rel_set (λa b. R (Result a, s) (Result b, t)) P Q  rel_spec (select P) (select Q) s t R"
  by (auto simp add: rel_spec_def rel_set_def)

subsection constunknown

(* unknown returns an arbitrary value and leaves the state unchanged: its
   run is Success of (Result v, s) for every v in UNIV.  The rules are
   obtained by unfolding unknown_def and reusing the select rules. *)
lemma runs_to_unknown[runs_to_vcg]: "(x. Q (Result x) s)  unknown  s  Q "
  by (simp add: unknown_def runs_to_select_iff)

lemma runs_to_unknown_iff[runs_to_iff]: "unknown  s  Q   (x. Q (Result x) s)"
  by (simp add: unknown_def runs_to_select_iff)

lemma runs_to_partial_unknown[runs_to_vcg]: "(x. Q (Result x) s)  unknown  s ?⦃ Q "
  using runs_to_unknown by (rule runs_to_partial_of_runs_to)

lemma run_unknown[run_spec_monad, simp]: "run unknown s = Success ((λv. (Result v, s)) ` UNIV)"
  unfolding unknown_def by simp

lemma always_progress_unknown[always_progress_intros]: "always_progress unknown"
  unfolding unknown_def by (simp add: always_progress_intros)

subsection constlift_state

(* lift_state R f runs f on every state t with R s t and joins the outcomes
   (SUP).  The resulting states are then lifted back through R with
   lift_post_state (rel_prod (=) R); see run_lift_state. *)
lemma run_lift_state[run_spec_monad]:
  "run (lift_state R f) s = lift_post_state (rel_prod (=) R) (SUP t{t. R s t}. run f t)"
  by transfer standard

(* VCG rules: for every related state s', f must establish that the outcome
   can be lifted back to a state t related by R that satisfies Q. *)
lemma runs_to_lift_state_iff[runs_to_iff]:
  "(lift_state R f)  s  Q   (s'. R s s'  f  s' λr t'. t. R t t'  Q r t)"
  by (simp add: runs_to.rep_eq lift_state.rep_eq rel_prod_sel split_beta')

lemma runs_to_lift_state[runs_to_vcg]:
  "(s'. R s s'  f  s' λr t'. t. R t t'  Q r t)  lift_state R f  s  Q "
  by (simp add: runs_to_lift_state_iff)

lemma runs_to_partial_lift_state[runs_to_vcg]:
  "(s'. R s s'  f  s' ?⦃λr t'.  t. R t t'  Q r t)  lift_state R f  s ?⦃ Q "
  by (cases "s'. R s s'  run f s'  Failure")
     (auto simp: runs_to_partial_alt run_lift_state lift_post_state_Sup image_image
                 image_iff runs_to_lift_state_iff top_post_state_def)

(* lift_state R is monotone in the lifted program. *)
lemma mono_lift_state: "f  f'  lift_state R f  lift_state R f'"
  by transfer (auto simp: le_fun_def intro!: lift_post_state_mono SUP_mono)

subsection ‹constexec_concrete›

(* exec_concrete st f runs f on every concrete state t with st t = s (the
   preimage st -` {s}), joins the outcomes, and maps the resulting states
   back through st. *)
lemma run_exec_concrete[run_spec_monad]:
  "run (exec_concrete st f) s = map_post_state (apsnd st) ( (run f ` st -` {s}))"
  by (auto simp add: exec_concrete_def lift_state_def map_post_state_eq_lift_post_state fun_eq_iff
           intro!: arg_cong2[where f=lift_post_state] SUP_cong)

(* VCG rules: f must establish Q on the image under st of its final
   states, starting from every concrete t with s = st t. *)
lemma runs_to_exec_concrete_iff[runs_to_iff]:
  "exec_concrete st f  s  Q   (t. s = st t  f  t λr t.  Q r (st t))"
  by (auto simp: runs_to.rep_eq run_exec_concrete split_beta')

lemma runs_to_exec_concrete[runs_to_vcg]:
  "(t. s = st t  f  t λr t.  Q r (st t))  exec_concrete st f  s  Q "
  by (simp add: runs_to_exec_concrete_iff)

lemma runs_to_partial_exec_concrete[runs_to_vcg]:
  "(t. s = st t  f  t ?⦃λr t.  Q r (st t))  exec_concrete st f  s ?⦃ Q "
  by (simp add: exec_concrete_def runs_to_partial_lift_state)

(* exec_concrete st is monotone in f.  This gives partial_function
   monotonicity rules for both the (<=) and (>=) orders. *)
lemma mono_exec_concrete: "f  f'  exec_concrete st f  exec_concrete st f'"
  unfolding exec_concrete_def by (rule mono_lift_state)

lemma monotone_exec_concrete_le[partial_function_mono]:
  "monotone Q (≤) f  monotone Q (≤) (λf'. exec_concrete st (f f'))"
  using mono_exec_concrete
  by (fastforce simp add: monotone_def le_fun_def)

lemma monotone_exec_concrete_ge[partial_function_mono]:
  "monotone Q (≥) f  monotone Q (≥) (λf'. exec_concrete st (f f'))"
  using mono_exec_concrete
  by (fastforce simp add: monotone_def le_fun_def)

subsection ‹constexec_abstract›

(* exec_abstract st f runs f on the abstract state st s.  The resulting
   abstract states are mapped back with vmap_post_state (apsnd st). *)
lemma run_exec_abstract[run_spec_monad]:
  "run (exec_abstract st f) s = vmap_post_state (apsnd st) (run f (st s))"
  by (auto simp add: exec_abstract_def vmap_post_state_eq_lift_post_state lift_state_def fun_eq_iff
           intro!: arg_cong2[where f=lift_post_state])

(* VCG rules: Q must hold for every concrete s' that st maps to the
   resulting abstract state t. *)
lemma runs_to_exec_abstract_iff[runs_to_iff]:
  "exec_abstract st f  s  Q   f  (st s) λr t. s'. t = st s'  Q r s' "
  by (simp add: runs_to.rep_eq run_exec_abstract split_beta' prod_eq_iff eq_commute)

lemma runs_to_exec_abstract[runs_to_vcg]:
  "f  (st s) λr t. s'. t = st s'  Q r s'  exec_abstract st f  s  Q "
  by (simp add: runs_to_exec_abstract_iff)

lemma runs_to_partial_exec_abstract[runs_to_vcg]:
  "f  (st s) ?⦃λr t. s'. t = st s'  Q r s'  exec_abstract st f  s ?⦃ Q "
  by (simp add: runs_to_partial.rep_eq run_exec_abstract split_beta' prod_eq_iff eq_commute)

(* exec_abstract st is monotone in f.  This gives partial_function
   monotonicity rules for both the (<=) and (>=) orders. *)
lemma mono_exec_abstract: "f  f'  exec_abstract st f  exec_abstract st f'"
  unfolding exec_abstract_def by (rule mono_lift_state)

lemma monotone_exec_abstract_le[partial_function_mono]:
  "monotone Q (≤) f  monotone Q (≤) (λf'. exec_abstract st (f f'))"
  by (simp add: monotone_def mono_exec_abstract)

lemma monotone_exec_abstract_ge[partial_function_mono]:
  "monotone Q (≥) f  monotone Q (≥) (λf'. exec_abstract st (f f'))"
  by (simp add: monotone_def mono_exec_abstract)

subsection constbind_exception_or_result

(* bind_exception_or_result f g passes every outcome r of f (exception or
   result) to g.  runs_to therefore requires g r to establish Q from each
   final state t of f. *)
lemma runs_to_bind_exception_or_result_iff[runs_to_iff]:
  "bind_exception_or_result f g  s  Q   f  s  λr t. g r  t  Q "
  by transfer (simp add: split_beta')

lemma runs_to_bind_exception_or_result[runs_to_vcg]:
  "f  s  λr t. g r  t  Q   bind_exception_or_result f g  s  Q "
  by (auto simp: runs_to_bind_exception_or_result_iff)

lemma runs_to_partial_bind_exception_or_result[runs_to_vcg]:
  "f  s ?⦃ λr t. g r  t ?⦃ Q   bind_exception_or_result f g  s ?⦃ Q "
  by transfer (simp add: split_beta' holds_partial_bind_post_state)

(* Refinement of a bind by a bind: every pair of related outcomes of f and
   f' must lead to refining continuations.  The primed variant uses an
   explicit intermediate relation Q. *)
lemma refines_bind_exception_or_result:
  "refines f f' s s' (λ(r, t) (r', t'). refines (g r) (g' r') t t' R) 
    refines (bind_exception_or_result f g) (bind_exception_or_result f' g') s s' R"
  by transfer (auto intro: sim_bind_post_state)

lemma refines_bind_exception_or_result':
  assumes f: "refines f f' s s' Q"
  assumes g: "r t r' t'. Q (r, t) (r', t')  refines (g r) (g' r') t t' R"
  shows "refines (bind_exception_or_result f g) (bind_exception_or_result f' g') s s' R"
  by (auto intro: refines_bind_exception_or_result refines_mono[OF _ f] g)

(* Monotonicity in both arguments, and the derived partial_function
   monotonicity rules for the (<=) and (>=) orders. *)
lemma mono_bind_exception_or_result:
  "f  f'  g  g'  bind_exception_or_result f g  bind_exception_or_result f' g'"
  unfolding le_fun_def
  by transfer (auto simp add: le_fun_def intro!: mono_bind_post_state)

lemma monotone_bind_exception_or_result_le[partial_function_mono]:
  "monotone R (≤) (λf'. f f')  (v. monotone R (≤) (λf'. g f' v)) 
    monotone R (≤) (λf'. bind_exception_or_result (f f') (g f'))"
  by (simp add: monotone_def) (metis le_fun_def mono_bind_exception_or_result)

lemma monotone_bind_exception_or_result_ge[partial_function_mono]:
  "monotone R (≥) (λf'. f f')  (v. monotone R (≥) (λf'. g f' v)) 
    monotone R (≥) (λf'. bind_exception_or_result (f f') (g f'))"
  by (simp add: monotone_def) (metis le_fun_def mono_bind_exception_or_result)

(* Congruence for run: the two binds have the same run at s if f and f'
   have the same run at s and, for f's outcomes (stated as a
   runs_to_partial fact), g and g' have equal runs. *)
lemma run_bind_exception_or_result_cong:
  assumes *: "run f s = run f' s"
  assumes **: "f  s ?⦃ λx s'. run (g x) s' = run (g' x) s' "
  shows "run (bind_exception_or_result f g) s = run (bind_exception_or_result f' g') s"
  using assms
  by (cases "run f' s")
     (auto simp: bind_exception_or_result_def runs_to_partial.rep_eq
           intro!: SUP_cong split: exception_or_result_splits)

subsection constbind_handle

(* bind_handle f g h continues with g v on Result v and with h e on
   Exception e; it is bind_exception_or_result with a case distinction in
   the continuation. *)
lemma bind_handle_eq:
  "bind_handle f g h =
    bind_exception_or_result f (λr. case r of Exception e  h e | Result v  g v)"
  unfolding spec_monad_ext_iff bind_handle.rep_eq bind_exception_or_result.rep_eq
  by (intro allI arg_cong[where f="bind_post_state _"] ext)
     (auto split: exception_or_result_split)

(* VCG rules for bind_handle.  Results v require g v to establish Q.
   Exceptions only need to be handled for e different from default, and
   then require h e to establish Q. *)
lemma runs_to_bind_handle_iff[runs_to_iff]:
  "bind_handle f g h  s  Q   f  s  λr t.
    (v. r = Result v  g v  t  Q ) 
    (e. r = Exception e  e  default  h e  t  Q )"
  by transfer
     (auto intro!: arg_cong[where f="λp. holds_post_state p _"] simp: fun_eq_iff
           split: prod.splits exception_or_result_splits)

lemma runs_to_bind_handle[runs_to_vcg]:
  "f  s  λr t. (v. r = Result v  g v  t  Q ) 
                 (e. r = Exception e  e  default  h e  t  Q ) 
  bind_handle f g h  s  Q "
  by (auto simp: runs_to_bind_handle_iff)

(* Specialisation to exn_monad: exceptions are Exception (Some e), so the
   "not default" side condition disappears (default_option_def).  The lemma
   name keeps the spelling "exeption". *)
lemma runs_to_bind_handle_exeption_monad[runs_to_vcg]:
  fixes f :: "('e, 'a, 's) exn_monad"
  assumes f:
    "f  s  λr t. (v. r = Result v  (g v  t  Q )) 
                   (e. r = Exception (Some e)  (h (Some e)  t  Q ))"
  shows "bind_handle f g h  s  Q "
  by (rule runs_to_bind_handle[OF runs_to_weaken[OF f]]) (auto simp: default_option_def)

(* Specialisation to res_monad: only the result case needs to be considered. *)
lemma runs_to_bind_handle_res_monad[runs_to_vcg]:
  fixes f :: "('a, 's) res_monad"
  assumes f: "f  s  λr t. (v. r = Result v  (g v  t  Q ))"
  shows "bind_handle f g h  s  Q "
  by (rule runs_to_bind_handle[OF runs_to_weaken[OF f]]) (auto simp: default_option_def)

(* Partial-correctness counterparts of the three rules above. *)
lemma runs_to_partial_bind_handle[runs_to_vcg]:
  "f  s ?⦃ λr t. (v. r = Result v  g v  t ?⦃ Q ) 
                 (e. r = Exception e  e  default  h e  t ?⦃ Q ) 
  bind_handle f g h  s ?⦃ Q "
  by transfer
     (auto intro!: holds_partial_bind_post_stateI intro: holds_partial_post_state_weaken
           split: exception_or_result_splits prod.splits)

lemma runs_to_partial_bind_handle_exeption_monad[runs_to_vcg]:
  fixes f :: "('e, 'a, 's) exn_monad"
  assumes f: "f  s ?⦃ λr t. (v. r = Result v  (g v  t ?⦃ Q )) 
    (e. r = Exception (Some e)  (h (Some e)  t ?⦃ Q ))"
  shows "bind_handle f g h  s ?⦃ Q "
  by (rule runs_to_partial_bind_handle[OF runs_to_partial_weaken[OF f]])
     (auto simp: default_option_def)

lemma runs_to_partial_bind_handle_res_monad[runs_to_vcg]:
  fixes f :: "('a, 's) res_monad"
  assumes f: "f  s ?⦃ λr t. (v. r = Result v  (g v  t ?⦃ Q ))"
  shows "bind_handle f g h  s ?⦃ Q "
  by (rule runs_to_partial_bind_handle[OF runs_to_partial_weaken[OF f]])
     (auto simp: default_option_def)

(* bind_handle is monotone in all three arguments.  This gives
   partial_function monotonicity rules for the (<=) and (>=) orders. *)
lemma mono_bind_handle:
  "f  f'  g  g'  h  h'  bind_handle f g h  bind_handle f' g' h'"
  unfolding le_fun_def
  by transfer
     (auto simp add: le_fun_def intro!: mono_bind_post_state split: exception_or_result_splits)

lemma monotone_bind_handle_le[partial_function_mono]:
  "monotone R (≤) (λf'. f f')  (v. monotone R (≤) (λf'. g f' v)) 
    (e. monotone R (≤) (λf'. h f' e)) 
    monotone R (≤) (λf'. bind_handle (f f') (g f') (h f'))"
  by (simp add: monotone_def) (metis le_fun_def mono_bind_handle)

lemma monotone_bind_handle_ge[partial_function_mono]:
  "monotone R (≥) (λf'. f f')  (v. monotone R (≥) (λf'. g f' v)) 
    (e. monotone R (≥) (λf'. h f' e)) 
    monotone R (≥) (λf'. bind_handle (f f') (g f') (h f'))"
  by (simp add: monotone_def) (metis le_fun_def mono_bind_handle)

(* run of bind_handle: bind_post_state on run f s, dispatching each outcome
   to h (exceptions) or g (results). *)
lemma run_bind_handle[run_spec_monad]:
  "run (bind_handle f g h) s = bind_post_state (run f s)
    (λ(r, t). case r of Exception e  run (h e) t | Result v  run (g v) t)"
  by transfer simp

(* bind_handle f g h satisfies always_progress if f, every g v and every
   h e do. *)
lemma always_progress_bind_handle[always_progress_intros]:
  "always_progress f  (v. always_progress (g v))  (e. always_progress (h e))
    always_progress (bind_handle f g h)"
  by (auto simp: always_progress.rep_eq run_bind_handle bind_post_state_eq_bot prod_eq_iff
                 exception_or_result_nchotomy holds_post_state_False
           split: exception_or_result_splits prod.splits)

(* Refinement rule for bind_handle. It suffices to show that f refines f'
   with a postcondition that, for each pairing of outcomes of the two sides
   (a non-default Exception or a Result on each side), requires the matching
   continuations (h/h' for exceptions, g/g' for results) to refine each
   other from the reached states t and t'. *)
lemma refines_bind_handle':
  "refines f f' s s' (λ(r, t) (r', t').
    (e e'. e  default  e'  default  r = Exception e  r' = Exception e' 
      refines (h e) (h' e') t t' R) 
    (e x'. e  default  r = Exception e  r' = Result x' 
      refines (h e) (g' x') t t' R) 
    (x e'. e'  default  r = Result x  r' = Exception e' 
      refines (g x) (h' e') t t' R) 
    (x x'. r = Result x  r' = Result x' 
      refines (g x) (g' x') t t' R)) 
  refines (bind_handle f g h) (bind_handle f' g' h') s s' R"
  apply transfer
  apply (auto intro!: sim_bind_post_state')[1]
  apply (rule sim_post_state_weaken, assumption)
  apply (auto split: exception_or_result_splits prod.splits)
  done

(* Variant of refines_bind_handle' with an explicit intermediate relation Q
   between f and f'. The premises ll, lr, rl and rr cover the four outcome
   pairings (Exception/Exception, Exception/Result, Result/Exception,
   Result/Result). Exception premises only concern non-default exceptions. *)
lemma refines_bind_handle_bind_handle:
  assumes f: "refines f f' s s' Q"
  assumes ll: "e e' t t'. Q (Exception e, t) (Exception e', t') 
    e  default  e'  default 
    refines (h e) (h' e') t t' R"
  assumes lr: "e v' t t'. Q (Exception e, t) (Result v', t')  e  default 
    refines (h e) (g' v') t t' R"
  assumes rl: "v e' t t'. Q (Result v, t) (Exception e', t')  e'  default 
    refines (g v) (h' e') t t' R"
  assumes rr: "v v' t t'. Q (Result v, t) (Result v', t') 
    refines (g v) (g' v') t t' R"
  shows "refines (bind_handle f g h) (bind_handle f' g' h') s s' R"
  apply (rule refines_bind_handle')
  apply (rule refines_weaken[OF f])
  apply (auto split: exception_or_result_splits prod.splits intro: ll lr rl rr)
  done

(* Specialisation of refines_bind_handle_bind_handle to exceptions written
   with Exn. The handlers are applied to Some e, and the non-default side
   conditions are discharged by unfolding Exn_def and default_option_def. *)
lemma refines_bind_handle_bind_handle_exn:
  assumes f: "refines f f' s s' Q"
  assumes ll: "e e' t t'. Q (Exn e, t) (Exn e', t') 
    refines (h (Some e)) (h' (Some e')) t t' R"
  assumes lr: "e v' t t'. Q (Exn e, t) (Result v', t') 
    refines (h (Some e)) (g' v') t t' R"
  assumes rl: "v e' t t'. Q (Result v, t) (Exn e', t') 
    refines (g v) (h' (Some e')) t t' R"
  assumes rr: "v v' t t'. Q (Result v, t) (Result v', t') 
    refines (g v) (g' v') t t' R"
  shows "refines (bind_handle f g h) (bind_handle f' g' h') s s' R"
  apply (rule refines_bind_handle')
  apply (rule refines_weaken[OF f])
  using ll lr rl rr
  apply (auto split: exception_or_result_splits prod.splits simp: Exn_def default_option_def)
  done

(* bind_handle applied to an immediate return runs the result continuation g. *)
lemma bind_handle_return_spec_monad[simp]: "bind_handle (return v) g h = g v"
  by transfer simp

(* bind_handle applied to a non-default throw_exception_or_result value runs
   the handler h on that value. *)
lemma bind_handle_throw_spec_monad[simp]:
  "v  default  bind_handle (throw_exception_or_result v) g h = h v"
  by transfer simp

(* Associativity of bind_handle: the outer continuations g2/h2 are pushed
   into both branches (g1 and h1) of the inner bind_handle. *)
lemma bind_handle_bind_handle_spec_monad:
  "bind_handle (bind_handle f g1 h1) g2 h2 =
    bind_handle f
      (λv. bind_handle (g1 v) g2 h2)
      (λe. bind_handle (h1 e) g2 h2)"
  apply transfer
  apply (auto simp: fun_eq_iff intro!: arg_cong[where f="bind_post_state _"])[1]
  subgoal for g1 h1 g2 h2 r v by (cases r) simp_all
  done

(* bind_handle is monotone in the parameter x, given monotonicity of f and
   pointwise monotonicity of g and h. *)
lemma mono_bind_handle_spec_monad:
  "mono f  (v. mono (λx. g x v))  (e. mono (λx. h x e)) 
    mono (λx. bind_handle (f x) (g x) (h x))"
  unfolding mono_def
  apply transfer
  apply (auto simp: le_fun_def intro!: mono_bind_post_state)[1]
  subgoal for f g h x y r q by (cases r) simp_all
  done

(* rel_spec counterpart of refines_bind_handle'. It is proved by splitting
   rel_spec into refinement obligations (rel_spec_iff_refines). *)
lemma rel_spec_bind_handle:
  "rel_spec f f' s s' (λ(r, t) (r', t').
    (e e'. e  default  e'  default  r = Exception e  r' = Exception e' 
      rel_spec (h e) (h' e') t t' R) 
    (e x'. e  default  r = Exception e  r' = Result x' 
      rel_spec (h e) (g' x') t t' R) 
    (x e'. e'  default  r = Result x  r' = Exception e' 
      rel_spec (g x) (h' e') t t' R) 
    (x x'. r = Result x  r' = Result x' 
      rel_spec (g x) (g' x') t t' R)) 
  rel_spec (bind_handle f g h) (bind_handle f' g' h') s s' R"
  by (auto simp: rel_spec_iff_refines intro!: refines_bind_handle' intro: refines_weaken)

(* bind_finally expressed through bind_handle: g receives the outcome of f
   re-wrapped as Result v or Exception e. *)
lemma bind_finally_bind_handle_conv: "bind_finally f g = bind_handle f (λv. g (Result v)) (λe. g (Exception e))"
  apply (rule spec_monad_eqI)
  apply (auto simp add: runs_to_iff Exn_def [symmetric] default_option_def elim!: runs_to_weaken)[1]
  using exception_or_result_nchotomy by force

subsection constbind

(* Unfolding of run for bind. An Exception from f is re-emitted unchanged
   through pure_post_state; a Result v continues with run (g v). *)
lemma run_bind[run_spec_monad]: "run (bind f g) s =
    bind_post_state (run f s) (λ(r, t). case r of
      Exception e  pure_post_state (Exception e, t)
    | Result v  run (g v) t)"
  by (auto simp add: bind_def run_bind_handle )

(* Characterises when run (bind f g) s equals the failing post state, in
   terms of runs_to on f and the runs of g on the results of f. *)
lemma run_bind_eq_top_iff:
  "run (bind f g) s =   ¬ (f  s  λx s. a. x = Result a  run (g a) s   )"
  by (simp add: run_bind bind_post_state_eq_top runs_to.rep_eq prod_eq_iff split_beta'
           split: exception_or_result_splits prod.splits)

(* bind preserves always_progress when f and every continuation g v have it. *)
lemma always_progress_bind[always_progress_intros]:
  "always_progress f  (v. always_progress (g v))  always_progress (bind f g)"
  by (simp add: always_progress_intros bind_def)

(* Congruence for run of bind. Two conditions are required: f and f' must
   have the same run at s, and g and g' must agree on every Result that f can
   reach (stated as partial correctness, ?-runs_to). *)
lemma run_bind_cong:
  assumes *: "run f s = run f' s"
  assumes **: "f  s ?⦃ λx s'. v. x = (Result v)  run (g v) s' = run (g' v) s' "
  shows "run (bind f g) s = run (bind f' g') s"
  using assms
  by (cases "run f' s")
     (auto simp: run_bind runs_to_partial.rep_eq
           intro!: SUP_cong split: exception_or_result_splits)

(* Weakest-precondition characterisation of bind. Each Result v of f must
   continue with g v satisfying Q. Each non-default Exception must satisfy Q
   directly, since bind propagates it unchanged. *)
lemma runs_to_bind_iff[runs_to_iff]:
  "bind f g  s  Q   f  s  λr t.
    (v. r = Result v  g v  t  Q ) 
    (e. r = Exception e  e  default  Q (Exception e) t)"
  by (simp add: bind_def runs_to_bind_handle_iff)

(* VCG rule for bind, obtained directly from runs_to_bind_iff. *)
lemma runs_to_bind[runs_to_vcg]:
  "f  s  λr t. (v. r = Result v  g v  t  Q ) 
                 (e. r = Exception e  e  default  Q (Exception e) t) 
      bind f g  s  Q "
  by (simp add: runs_to_bind_iff)

(* VCG rule for bind on exn_monad, with exceptions phrased via Exn instead of
   Exception. runs_to_bind is enabled locally for the proof. *)
lemma runs_to_bind_exception[runs_to_vcg]:
  fixes f :: "('e, 'a, 's) exn_monad"
  assumes [runs_to_vcg]: "f  s  λr t. (v. r = Result v  g v  t  Q ) 
                 (e. r = Exn e  Q (Exn e) t)"
  shows "bind f g  s  Q "
  supply runs_to_bind[runs_to_vcg]
  by runs_to_vcg (auto simp: Exn_def default_option_def)

(* VCG rule for bind on res_monad, where the premise only mentions Res
   outcomes of f. *)
lemma runs_to_bind_res[runs_to_vcg]:
  fixes f :: "('a, 's) res_monad"
  assumes [runs_to_vcg]: "f  s  λRes v t. g v  t  Q "
  shows "bind f g  s  Q "
  supply runs_to_bind[runs_to_vcg]
  by runs_to_vcg

(* Partial-correctness analogue of runs_to_bind, derived from
   runs_to_partial_bind_handle. *)
lemma runs_to_partial_bind[runs_to_vcg]:
  "f  s ?⦃ λr t. (v. r = Result v  g v  t ?⦃ Q ) 
                 (e. r = Exception e  e  default  Q (Exception e) t) 
      bind f g  s ?⦃ Q "
  apply (simp add: bind_def)
  apply (rule runs_to_partial_bind_handle)
  apply simp
  done

(* Partial-correctness bind rule for exn_monad, with exceptions phrased via Exn.
   NOTE: "exeption" in the lemma name misspells "exception". The name is kept
   as is because renaming would break existing references. *)
lemma runs_to_partial_bind_exeption_monad[runs_to_vcg]:
  fixes f :: "('e, 'a, 's) exn_monad"
  assumes [runs_to_vcg]: "f  s ?⦃ λr t. (v. r = Result v  g v  t ?⦃ Q ) 
    (e. r = Exn e  Q (Exn e) t)"
  shows "bind f g  s ?⦃ Q "
  by runs_to_vcg (auto simp add: default_option_def Exn_def)

(* Partial-correctness bind rule for res_monad (only Res outcomes). *)
lemma runs_to_partial_bind_res_monad[runs_to_vcg]:
  fixes f :: "('a, 's) res_monad"
  assumes [runs_to_vcg]: "f  s ?⦃ λRes v t. g v  t ?⦃ Q "
  shows "bind f g  s ?⦃ Q "
  by runs_to_vcg

(* Monad law, right identity: binding into return leaves m unchanged. *)
lemma bind_return[simp]: "(bind m return) = m"
  apply (clarsimp simp: spec_monad_eq_iff runs_to_bind_iff fun_eq_iff
                  intro!: runs_to_cong_pred_only)
  subgoal for P r t
    by (cases r) auto
  done

(* Binding into the constant continuation skip leaves m unchanged;
   derived from bind_return by simp. *)
lemma bind_skip[simp]: "bind m (λx. skip) = m"
  using bind_return[of m] by simp

(* Monad law, left identity: bind of return x runs f x. *)
lemma return_bind[simp]: "bind (return x) f = f x"
  unfolding bind_def by transfer simp

(* Monad law, associativity of bind. *)
lemma bind_assoc: "bind (bind f g) h = bind f (λx. bind (g x) h)"
  apply (rule spec_monad_ext)
  apply (auto simp: split_beta' run_bind intro!: arg_cong[where f="bind_post_state _"]
              split: exception_or_result_splits)
  done

(* bind is monotone in both arguments; reduced to mono_bind_handle. *)
lemma mono_bind: "f  f'  g  g'  bind f g  bind f' g'"
  unfolding bind_def
  by (auto intro: mono_bind_handle)

(* bind is monotone in a parameter x, given mono f and pointwise
   monotonicity of g. *)
lemma mono_bind_spec_monad:
  "mono f  (v. mono (λx. g x v))  mono (λx. bind (f x) (g x))"
  unfolding bind_def
  by (intro mono_bind_handle_spec_monad mono_const) auto

(* partial_function monotonicity rule for bind w.r.t. the (≤) ordering. *)
lemma monotone_bind_le[partial_function_mono]:
  "monotone R (≤) (λf'. f f')  (v. monotone R (≤) (λf'. g f' v))
   monotone R (≤) (λf'. bind (f f') (g f'))"
  unfolding bind_def
  by (auto intro: monotone_bind_handle_le)

(* partial_function monotonicity rule for bind w.r.t. the (≥) ordering. *)
lemma monotone_bind_ge[partial_function_mono]:
  "monotone R (≥) (λf'. f f')  (v. monotone R (≥) (λf'. g f' v))
   monotone R (≥) (λf'. bind (f f') (g f'))"
  unfolding bind_def
  by (auto intro: monotone_bind_handle_ge)

(* Refinement rule for bind, derived from refines_bind_handle'. In the
   postcondition, an abstract-side exception is either related to a
   concrete-side exception directly by R, or is compared (as
   throw_exception_or_result e) against the concrete continuation g'. *)
lemma refines_bind':
  assumes f: "refines f f' s s' (λ(x, t) (x', t').
    (e. e  default  x = Exception e 
      ((e'. e'  default  x' = Exception e'  R (Exception e, t) (Exception e', t')) 
      (v'. x' = Result v'  refines (throw_exception_or_result e) (g' v') t t' R))) 
    (v. x = Result v 
      ((e'. e'  default  x' = Exception e' 
        refines (g v) (throw_exception_or_result e') t t' R) 
      (v'. x' = Result v'  refines (g v) (g' v') t t' R))))"
  shows "refines (bind f g) (bind f' g') s s' R"
  unfolding bind_def
  by (rule refines_bind_handle'[OF refines_weaken, OF f]) auto

(* Bind refinement with an explicit intermediate relation Q. Mixed outcomes
   are handled by comparing yield (Exception _) with the other side's
   continuation; matching exceptions must be related by R directly. *)
lemma refines_bind:
  assumes f: "refines f f' s s' Q"
  assumes ll: "e e' t t'.
    Q (Exception e, t) (Exception e', t')  e  default  e'  default 
    R (Exception e, t) (Exception e', t')"
  assumes lr: "e v' t t'. Q (Exception e, t) (Result v', t')  e  default 
    refines (yield (Exception e)) (g' v') t t' R"
  assumes rl: "v e' t t'. Q (Result v, t) (Exception e', t')  e'  default 
    refines (g v) (yield (Exception e')) t t' R"
  assumes rr: "v v' t t'. Q (Result v, t) (Result v', t') 
    refines (g v) (g' v') t t' R"
  shows "refines (bind f g) (bind f' g') s s' R"
  by (rule refines_bind'[OF refines_weaken[OF f]])
     (auto simp: ll lr rl rr)

(* Exn-flavoured version of refines_bind: exceptions are written with Exn,
   and mixed outcomes are compared against throw. *)
lemma refines_bind_bind_exn:
  assumes f: "refines f f' s s' Q"
  assumes ll: "e e' t t'. Q (Exn e, t) (Exn e', t')  R (Exn e, t) (Exn e', t')"
  assumes lr: "e v' t t'. Q (Exn e, t) (Result v', t')  refines (throw e) (g' v') t t' R"
  assumes rl: "v e' t t'. Q (Result v, t) (Exn e', t')  refines (g v) (throw e') t t' R"
  assumes rr: "v v' t t'. Q (Result v, t) (Result v', t')  refines (g v) (g' v') t t' R"
  shows "refines (f  g) (f'  g') s s' R"
  using ll lr rl rr
  by (intro refines_bind[OF f])
     (auto simp: Exn_def default_option_def)

(* Bind refinement for res_monad: only Res/Res outcomes need to be related,
   by refinement of the continuations. *)
lemma refines_bind_res:
  assumes f: "refines f f' s s' (λ(Res r, t) (Res r', t'). refines (g r) (g' r') t t' R) "
  shows "refines ((bind f g)::('a,'s) res_monad) ((bind f' g')::('b, 't) res_monad) s s' R"
  by (rule refines_bind[OF f]) auto

(* refines_bind_res with an explicit intermediate relation Q. *)
lemma refines_bind_res':
  assumes f: "refines f f' s s' Q"
  assumes g: "r t r' t'. Q (Result r, t) (Result r', t')  refines (g r) (g' r') t t' R"
  shows "refines ((bind f g)::('a,'s) res_monad) ((bind f' g')::('b, 't) res_monad) s s' R"
  by (auto intro!: refines_bind_res refines_weaken[OF f] g)

(* Weakest-precondition style bind refinement: a single postcondition for
   f/f' case-splits on the Exn/Result outcome of both sides. *)
lemma refines_bind_bind_exn_wp: 
  assumes f: "refines f f' s s' (λ(r, t) (r',t'). 
     (case r of 
        Exn e  (case r' of Exn e'  R (Exn e, t) (Exn e', t') | Result v'  refines (throw e) (g' v') t t' R)
      | Result v  (case r' of Exn e'  refines (g v) (throw e') t t' R | Result v'  refines (g v) (g' v') t t' R)))"
   shows "refines (f  g) (f'  g') s s' R"
  apply (rule refines_bind')
  apply (rule refines_weaken [OF f])
  apply (auto simp add: Exn_def default_option_def split: xval_splits )
  done

(* rel_spec rule for bind on res_monad, proved through both refinement
   directions of rel_spec_iff_refines. *)
lemma rel_spec_bind_res:
  "rel_spec f f' s s' (λ(Res r, t) (Res r', t'). rel_spec (g r) (g' r') t t' R) 
    rel_spec ((bind f g)::('a,'s) res_monad) ((bind f' g')::('b, 't) res_monad) s s' R"
  unfolding rel_spec_iff_refines
  by (safe intro!: refines_bind_res; (rule refines_weaken, assumption, simp))

(* rel_spec_bind_res with an explicit intermediate relation Q. *)
lemma rel_spec_bind_res':
  assumes f: "rel_spec f f' s s' Q"
  assumes g: "r t r' t'. Q (Result r, t) (Result r', t')  rel_spec (g r) (g' r') t t' R"
  shows "rel_spec ((bind f g)::('a,'s) res_monad) ((bind f' g')::('b, 't) res_monad) s s' R"
  by (auto intro!: rel_spec_bind_res rel_spec_mono[OF _ f] g)

(* rel_spec analogue of refines_bind. It is proved by applying refines_bind
   once with Q and once with the converse relation Q¯¯. *)
lemma rel_spec_bind_bind:
  assumes f: "rel_spec f f' s s' Q"
  assumes ll: "e e' t t'.
    Q (Exception e, t) (Exception e', t')  e  default  e'  default 
    R (Exception e, t) (Exception e', t')"
  assumes lr: "e v' t t'. Q (Exception e, t) (Result v', t')  e  default 
    rel_spec (yield (Exception e)) (g' v') t t' R"
  assumes rl: "v e' t t'. Q (Result v, t) (Exception e', t')  e'  default 
    rel_spec (g v) (yield (Exception e')) t t' R"
  assumes rr: "v v' t t'. Q (Result v, t) (Result v', t') 
    rel_spec (g v) (g' v') t t' R"
  shows "rel_spec (bind f g) (bind f' g') s s' R"
  using assms unfolding rel_spec_iff_refines
  apply (intro conjI)
  subgoal by (rule refines_bind[where Q=Q]) auto
  subgoal by (rule refines_bind[where Q="Q¯¯"]) auto
  done

(* Exn-flavoured version of rel_spec_bind_bind, using throw for mixed outcomes. *)
lemma rel_spec_bind_exn:
  assumes f: "rel_spec f f' s s' Q"
  assumes ll: "e e' t t'. Q (Exn e, t) (Exn e', t')  R (Exn e, t) (Exn e', t')"
  assumes lr: "e v' t t'. Q (Exn e, t) (Result v', t')  rel_spec (throw e) (g' v') t t' R"
  assumes rl: "v e' t t'. Q (Result v, t) (Exn e', t')  rel_spec (g v) (throw e') t t' R"
  assumes rr: "v v' t t'. Q (Result v, t) (Result v', t')  rel_spec (g v) (g' v') t t' R"
  shows "rel_spec (bind f g) (bind f' g') s s' R"
  using ll lr rl rr
  by (intro rel_spec_bind_bind[OF f])
     (auto simp: Exn_def default_option_def)

(* Refine an unbound f against bind f' g'. f is viewed as bind f return, and
   refines_bind is applied; the continuation on the left is therefore return. *)
lemma refines_bind_right:
  assumes f: "refines f f' s s' Q"
  assumes ll: "e e' t t'.
    Q (Exception e, t) (Exception e', t')  e  default  e'  default 
    R (Exception e, t) (Exception e', t')"
  assumes lr: "e v' t t'. Q (Exception e, t) (Result v', t')  e  default 
    refines (yield (Exception e)) (g' v') t t' R"
  assumes rl: "v e' t t'. Q (Result v, t) (Exception e', t')  e'  default 
    R (Result v, t) (Exception e', t')"
  assumes rr: "v v' t t'. Q (Result v, t) (Result v', t') 
    refines (return v) (g' v') t t' R"
  shows "refines f (bind f' g') s s' R"
proof -
  have "refines (bind f return) (bind f' g') s s' R"
    using ll lr rl rr by (intro refines_bind[OF f]) auto
  then show ?thesis by simp
qed

(* Mirror image of refines_bind_right: refine bind f g against an unbound f',
   viewing f' as bind f' return. *)
lemma refines_bind_left:
  assumes f: "refines f f' s s' Q"
  assumes ll: "e e' t t'.
    Q (Exception e, t) (Exception e', t')  e  default  e'  default 
    R (Exception e, t) (Exception e', t')"
  assumes lr: "e v' t t'. Q (Exception e, t) (Result v', t')  e  default 
    R (Exception e, t) (Result v', t')"
  assumes rl: "v e' t t'. Q (Result v, t) (Exception e', t')  e'  default 
    refines (g v) (yield (Exception e')) t t' R"
  assumes rr: "v v' t t'. Q (Result v, t) (Result v', t') 
    refines (g v) (return v') t t' R"
  shows "refines (bind f g) f' s s' R"
proof -
  have "refines (bind f g) (bind f' return) s s' R"
    using ll lr rl rr by (intro refines_bind[OF f]) auto
  then show ?thesis by simp
qed

(* rel_spec version of refines_bind_right. The forward direction uses
   refines_bind_right; the backward direction uses refines_bind_left with
   the converse relation Q¯¯. *)
lemma rel_spec_bind_right:
  assumes f: "rel_spec f f' s s' Q"
  assumes ll: "e e' t t'.
    Q (Exception e, t) (Exception e', t')  e  default  e'  default 
    R (Exception e, t) (Exception e', t')"
  assumes lr: "e v' t t'. Q (Exception e, t) (Result v', t')  e  default 
    rel_spec (yield (Exception e)) (g' v') t t' R"
  assumes rl: "v e' t t'. Q (Result v, t) (Exception e', t')  e'  default 
    R (Result v, t) (Exception e', t')"
  assumes rr: "v v' t t'. Q (Result v, t) (Result v', t') 
    rel_spec (return v) (g' v') t t' R"
  shows "rel_spec f (bind f' g') s s' R"
  using assms unfolding rel_spec_iff_refines
  apply (intro conjI)
  subgoal by (rule refines_bind_right[where Q=Q]) auto
  subgoal by (rule refines_bind_left[where Q ="Q¯¯"]) auto
  done

(* One-sided rule: bind_handle f g h refines an arbitrary k if every outcome
   of f (total correctness) leads to a continuation, g r for results or h e
   for non-default exceptions, that refines k from the reached state. *)
lemma refines_bind_handle_left':
  "f  s  λv s'. (r. v = Result r  refines (g r) k s' t R) 
    (e. e  default  v = Exception e  refines (h e) k s' t R)  
  refines (bind_handle f g h) k s t R"
  by (auto simp: runs_to.rep_eq refines.rep_eq run_bind_handle split_beta' ac_simps
           intro!: sim_bind_post_state_left split: prod.splits exception_or_result_splits)

(* res_monad instance of refines_bind_handle_left' for f >>= g. *)
lemma refines_bind_left_res:
  "f  s  λRes r s'. refines (g r) h s' t R   refines (f >>= g) h s t R"
  unfolding bind_def by (rule refines_bind_handle_left') simp

(* Exn-flavoured instance of refines_bind_handle_left' for f >>= g. An
   exception Exn e must be refined as throw e. *)
lemma refines_bind_left_exn:
  "f  s  λr s'. (a. r = Result a  refines (g a) h s' t R) 
    (e. r = Exn e  refines (throw e) h s' t R)  
    refines (f >>= g) h s t R"
  unfolding bind_def
  by (rule refines_bind_handle_left')
     (auto simp add: Exn_def default_option_def imp_ex)

(* Partial correctness of f >>= g on res_monad follows if every continuation
   g r satisfies P partially from every state; nothing is required of f. *)
lemma runs_to_partial_bind1:
  "(r s. (g r)  s ?⦃P)  ((f >>= g)::('a, 's) res_monad)  s ?⦃P"
  apply (rule runs_to_partial_bind)
  apply (rule runs_to_partial_weaken[OF runs_to_partial_True])
  apply auto
  done

(* Binding unknown into a continuation that ignores its argument yields that
   continuation. *)
lemma unknown_bind_const[simp]: "unknown >>= (λx. f) = f"
  by (rule spec_monad_ext) (simp add: run_bind)

(* Congruence on the continuation of bind: pointwise equal g and g' give
   equal binds. *)
lemma bind_cong_left:
  fixes f::"('e::default, 'a, 's) spec_monad"
  shows "(r. g r = g' r)  (f >>= g) = (f >>= g')"
  by (rule spec_monad_ext) (simp add: run_bind)

(* Congruence on the first computation of bind. *)
lemma bind_cong_right:
  fixes f::"('e::default, 'a, 's) spec_monad"
  shows "f = f'  (f >>= g) = (f' >>= g)"
  by (rule spec_monad_ext) (simp add: run_bind)

(* With the same first computation f and the same initial state s on both
   sides, rel_spec of the binds follows if the continuations are related on
   every Res result (and state) f can reach, stated as partial correctness. *)
lemma rel_spec_bind_res'':
  fixes f::"('a, 's) res_monad"
  shows "f  s ?⦃ λRes r t. rel_spec (g r) (g' r) t t R  rel_spec (f >>= g) (f >>= g') s s R"
  by (intro rel_spec_bind_res rel_spec_refl') (auto simp: split_beta')

(* rel_spec_monad is preserved by bind when outcomes are related by
   rel_exception_or_result E P, and the continuations map P-related results
   to computations related by rel_exception_or_result E Q. *)
lemma rel_spec_monad_bind_rel_exception_or_result:
  assumes mn: "rel_spec_monad R (rel_exception_or_result E P) m n"
    and fg: "rel_fun P (rel_spec_monad R (rel_exception_or_result E Q)) f g"
  shows "rel_spec_monad R (rel_exception_or_result E Q) (m >>= f) (n >>= g)"
  by (intro rel_spec_monadI rel_spec_bind_bind[OF rel_spec_monadD[OF mn]])
     (simp_all add: fg[THEN rel_funD, THEN rel_spec_monadD])

(* Instance of the previous lemma with equality on states and exceptions and
   the same first computation m on both sides. *)
lemma rel_spec_monad_bind_rel_exception_or_result':
  "(x. rel_spec_monad (=) (rel_exception_or_result (=) Q) (f x) (g x)) 
    rel_spec_monad (=) (rel_exception_or_result (=) Q) (m >>= f) (m >>= g)"
  apply (rule rel_spec_monad_bind_rel_exception_or_result[where P="(=)"])
  apply (auto simp add: rel_exception_or_result_eq_conv rel_spec_monad_eq_conv)
  done

(* rel_spec_monad bind rule when the first computations only relate Res
   outcomes, by P on their values. *)
lemma rel_spec_monad_bind:
  "rel_spec_monad R (λRes v1 Res v2. P v1 v2) m n  rel_fun P (rel_spec_monad R Q) f g 
    rel_spec_monad R Q (m >>= f) (n >>= g)"
  apply (clarsimp simp add: rel_spec_monad_iff_rel_spec[abs_def] rel_fun_def
                  intro!: rel_spec_bind_res)
  subgoal premises prems for s t
    by (rule rel_spec_mono[OF _ prems(1)[rule_format, OF prems(3)]])
       (clarsimp simp: le_fun_def prems(2)[rule_format])
  done

(* Left-only variant of rel_spec_monad_bind: the right side n is treated as
   n >>= return. *)
lemma rel_spec_monad_bind_left:
  assumes mn: "rel_spec_monad R (λRes v1 Res v2. P v1 v2) m n"
    and fg: "rel_fun P (rel_spec_monad R Q) f return"
  shows "rel_spec_monad R Q (m >>= f) n"
proof -
  from mn fg
  have "rel_spec_monad R Q (m >>= f) (n >>= return)"
    by (rule rel_spec_monad_bind)
  then show ?thesis
    by (simp)
qed

(* Same first computation m (res_monad) on both sides: pointwise related
   continuations give related binds. *)
lemma rel_spec_monad_bind':
  fixes m:: "('a, 's) res_monad"
  shows "(x. rel_spec_monad (=) Q (f x) (g x))  rel_spec_monad (=) Q (m >>= f) (m >>= g)"
  by (rule rel_spec_monad_bind[where R="(=)" and P="(=)"])
     (auto simp: rel_fun_def rel_spec_monad_iff_rel_spec intro!: rel_spec_refl)

(* Simple congruence for run of bind: equal runs of f and f' at s, and equal
   runs of g v and g' v for all v and states. *)
lemma run_bind_cong_simple:
  "run f s = run f' s  (v s. run (g v) s = run (g' v) s) 
    run (bind f g) s = run (bind f' g') s"
  by (simp add: run_bind)

(* run_bind_cong_simple specialised to the same first computation f. *)
lemma run_bind_cong_simple_same_head:  "(v s. run (g v) s = run (g' v) s) 
  run (bind f g) s = run (bind f g') s"
  by (rule run_bind_cong_simple) auto

(* Same first computation f and state s on both sides. Refinement of the
   binds follows if, for every Res result and state f reaches (total
   correctness), the continuations refine each other. *)
lemma refines_bind_same:
  "refines (f >>= g) (f >>= g') s s R" if "f  s λRes y t. refines (g y) (g' y) t t R "
  apply (rule refines_bind_res)
  apply (rule refines_same_runs_toI)
  apply (rule runs_to_weaken[OF that])
  apply auto
  done

(* One-sided rule on the concrete side: f refines bind_handle g h k from
   state t if g always progresses and, in partial correctness, every outcome
   of g leads to a continuation that f refines. That continuation is k e for
   a non-default exception and h a for a result. *)
lemma refines_bind_handle_right_runs_to_partialI:
  "g  t ?⦃ λr t'. (e. e  default  r = Exception e  refines f (k e) s t' R) 
     (a. r = Result a  refines f (h a) s t' R)   always_progress g 
  refines f (bind_handle g h k) s t R"
  by transfer
     (auto simp: prod_eq_iff split_beta'
           intro!: sim_bind_post_state_right
           split: exception_or_result_splits prod.splits)

(* bind instance of refines_bind_handle_right_runs_to_partialI: exceptions of
   g are compared against throw_exception_or_result e. *)
lemma refines_bind_right_runs_to_partialI:
  "g  t ?⦃ λr t'. (e. e  default  r = Exception e 
      refines f (throw_exception_or_result e) s t' R) 
     (a. r = Result a  refines f (h a) s t' R)   always_progress g 
  refines f (bind g h) s t R"
  unfolding bind_def by (rule refines_bind_handle_right_runs_to_partialI)

(* Variant for Res outcomes with a total-correctness premise, which is
   weakened to partial correctness via runs_to_partial_of_runs_to. *)
lemma refines_bind_right_runs_toI:
  "g  t  λRes r t'. refines f (h r) s t' R   always_progress g 
    refines f (g >>= h) s t R"
  unfolding bind_def
  by (rule refines_bind_handle_right_runs_to_partialI)
     (auto intro: runs_to_partial_of_runs_to)

(* Replace the first computation on the left by a refinement f' of it (with
   equality relation at the same state s), via transitivity of refines. *)
lemma refines_bind_left_refine:
  "refines (f >>= g) h s t R"
  if "refines f f' s s (=)" "refines (f' >>= g) h s t R"
  by (rule refines_trans[OF refines_bind[OF that(1)] that(2), where R="(=)"])
     (auto simp add: refines_refl)

(* If the run of g at t is above the single outcome pure_post_state
   (Result x, u), then refining h x from u suffices to refine g >>= h. *)
lemma refines_bind_right_single:
  assumes x: "pure_post_state (Result x, u)  run g t" and h: "refines f (h x) s u R"
  shows "refines f (g  h) s t R"
  apply (rule refines_le_run_trans[OF h])
  apply (simp add: run_bind)
  apply (rule order_trans[OF _ mono_bind_post_state, OF _ x order_refl])
  apply simp
  done

(* A return of a let-expression equals binding the let-bound value first. *)
lemma return_let_bind: " (return (let v = f' in (g' v))) = do {v <- return f'; return (g' v)}"
  apply (rule spec_monad_eqI)
  apply (auto simp add: runs_to_iff)
  done

(* rel_spec_monad bind rule for rel_xval: results of f/f' are related by P,
   and P-related results give rel_xval L R related continuations. *)
lemma rel_spec_monad_rel_xval_bind:
  assumes f_f': "rel_spec_monad S (rel_xval L P) f f'"
  assumes Res_Res: "v v'. P v v'  rel_spec_monad S (rel_xval L R) (g v) (g' v')"
  shows "rel_spec_monad S (rel_xval L R) (f  g) (f'  g')"
  apply (intro rel_spec_monadI rel_spec_bind_exn)
  apply (rule f_f'[THEN rel_spec_monadD], assumption)
  using Res_Res[THEN rel_spec_monadD]
  apply auto
  done

(* Instance of rel_spec_monad_rel_xval_bind where the intermediate result
   relation equals the final one (R). *)
lemma rel_spec_monad_rel_xval_same_bind:
  assumes f_f': "rel_spec_monad S (rel_xval L R) f f'"
  assumes Res_Res: "v v'. R v v'   rel_spec_monad S (rel_xval L R) (g v) (g' v')"
  shows "rel_spec_monad S (rel_xval L R) (f  g) (f'  g')"
  using assms by (rule rel_spec_monad_rel_xval_bind)

(* Instance of rel_spec_monad_rel_xval_bind with equality on results. *)
lemma rel_spec_monad_rel_xval_result_eq_bind:
  assumes f_f': "rel_spec_monad S (rel_xval L (=)) f f'"
  assumes Res_Res: "v. rel_spec_monad S (rel_xval L (=)) (g v) (g' v)"
  shows "rel_spec_monad S (rel_xval L (=)) (f  g) (f'  g')"
  apply (rule rel_spec_monad_rel_xval_bind [OF f_f'])
  subgoal using Res_Res by auto
  done

(* fail is related to itself under any state relation Q and outcome relation R. *)
lemma rel_spec_monad_fail: "rel_spec_monad Q R fail fail"
  by (auto simp add: rel_spec_monad_def rel_set_def)

(* Binding after fail is still fail. *)
lemma bind_fail[simp]: "fail  X = fail"
  apply (rule spec_monad_ext)
  apply (simp add: run_bind)
  done

subsection constassert

(* assert on a constant: True gives return (), False gives top. *)
lemma assert_simps[simp]:
  "assert True = return ()"
  "assert False = top"
   by (auto simp add: assert_def)

(* assert P always progresses, for any P. *)
lemma always_progress_assert[always_progress_intros]: "always_progress (assert P)"
  by (simp add: always_progress_def assert_def)

(* run of assert P: the single outcome (Result (), s) if P holds, otherwise
   the failing post state (cf. assert False = top in assert_simps). *)
lemma run_assert[run_spec_monad]:
  "run (assert P) s = (if P then pure_post_state (Result (), s) else )"
  by (simp add: assert_def)

(* Total correctness of assert P, stated in terms of P and Q (Result ()) s. *)
lemma runs_to_assert_iff[simp]: "assert P  s  Q   P  Q (Result ()) s"
  by (simp add: runs_to.rep_eq run_assert)

(* VCG rule for assert, from runs_to_assert_iff. *)
lemma runs_to_assert[runs_to_vcg]: "P  Q (Result ()) s  assert P  s  Q "
  by simp

(* Partial correctness of assert P, stated in terms of P and Q (Result ()) s. *)
lemma runs_to_partial_assert_iff[simp]: "assert P  s ?⦃ Q   (P  Q (Result ()) s)"
  by (simp add: runs_to_partial.rep_eq run_assert)

(* Refinement with the top computation on the left, characterised by the
   run of g at t. *)
lemma refines_top_iff[simp]: "refines  g s t R  run g t = "
  by transfer auto

(* Refinement between two asserts reduces to a propositional condition on
   P, Q and R (Result (), s) (Result (), t). *)
lemma refines_assert: 
  "refines (assert P) (assert Q) s t R  (Q  P  R (Result (), s) (Result (), t))"
  by (simp add: assert_def)

subsection "constassume"

(* assume on a constant: True gives return (), False gives bot. *)
lemma assume_simps[simp]:
  "assume True = return ()"
  "assume False = bot"
   by (auto simp add: assume_def)

(* assume P always progresses provided P holds. *)
lemma always_progress_assume[always_progress_intros]: "P  always_progress (assume P)"
  by (simp add: always_progress_def assume_def)

(* run of assume P: the single outcome (Result (), s) if P holds, otherwise
   the empty post state (cf. assume False = bot in assume_simps). *)
lemma run_assume[run_spec_monad]:
  "run (assume P) s = (if P then pure_post_state (Result (), s) else )"
  by (simp add: assume_def)

(* run_assume split into its two cases, usable as simp rules. *)
lemma run_assume_simps[run_spec_monad, simp]:
  "P  run (assume P) s = pure_post_state (Result (), s)"
  "¬ P  run (assume P) s = "
  by (simp_all add: run_assume)

(* Total correctness of assume P, stated in terms of P and Q (Result ()) s. *)
lemma runs_to_assume_iff[simp]: "assume P  s  Q   (P  Q (Result ()) s)"
  by (simp add: runs_to.rep_eq run_assume)

(* Partial correctness of assume P, stated in terms of P and Q (Result ()) s. *)
lemma runs_to_partial_assume_iff[simp]: "assume P  s ?⦃ Q   (P  Q (Result ()) s)"
  by (simp add: runs_to_partial.rep_eq run_assume)

(* VCG rule for assume, from runs_to_assume_iff. *)
lemma runs_to_assume[runs_to_vcg]: "(P  Q (Result ()) s)  assume P  s  Q "
  by simp

subsection constassume_outcome

(* run of assume_outcome f at s is Success of the outcome set f s. *)
lemma run_assume_outcome[run_spec_monad, simp]: "run (assume_outcome f) s = Success (f s)"
  apply transfer
  apply simp
  done

(* assume_outcome f always progresses if f s is non-empty for every state s. *)
lemma always_progress_assume_outcome[always_progress_intros]:
  "(s. f s  {})  always_progress (assume_outcome f)"
  by (auto simp add: always_progress_def bot_post_state_def)

(* Every outcome (r, t) of assume_outcome f from s is a member of f s. *)
lemma runs_to_assume_outcome[runs_to_vcg]:
  "(assume_outcome f)  s λr t. (r, t)  f s"
  by (simp add: runs_to.rep_eq)

(* Total correctness of assume_outcome f: Q must hold on every pair in f s. *)
lemma runs_to_assume_outcome_iff[runs_to_iff]:
  "(assume_outcome f)  s Q  ((r, t)  f s. Q r t)"
  by (auto simp add: runs_to.rep_eq)

(* Partial-correctness counterpart of runs_to_assume_outcome. *)
lemma runs_to_partial_assume_outcome[runs_to_vcg]:
  "(assume_outcome f)  s ?⦃λr t. (r, t)  f s"
  by (simp add: runs_to_partial.rep_eq)

(* assume_outcome decomposed into elementary operations: read the state,
   select an outcome (r, t) from f s, set the state to t, then yield r. *)
lemma assume_outcome_elementary:
  "assume_outcome f = do {s  get_state; (r, t)  select (f s); set_state t; yield r}"
  apply (rule spec_monad_ext)
  apply (simp add: run_bind Sup_Success pure_post_state_def)
  done

subsection constassume_result_and_state

(* run of assume_result_and_state f at s is Success of f s with every value
   v wrapped as Result v. *)
lemma run_assume_result_and_state[run_spec_monad, simp]:
  "run (assume_result_and_state f) s = Success ((λ(v, t). (Result v, t)) `  f s)"
  by (auto simp add: assume_result_and_state_def run_bind Sup_Success_pair pure_post_state_def)

(* assume_result_and_state f always progresses if f s is non-empty for every s. *)
lemma always_progress_assume_result_and_state[always_progress_intros]:
  "(s. f s  {})  always_progress (assume_result_and_state f)"
  by (simp add: always_progress_def)

(* Every outcome is Result v with (v, t) a member of f s. *)
lemma runs_to_assume_result_and_state[runs_to_vcg]:
  "assume_result_and_state f  s λr t. v. r = Result v  (v, t)  f s"
  by (auto simp add: runs_to.rep_eq)

(* Total correctness: Q (Result v) t must hold for every (v, t) in f s. *)
lemma runs_to_assume_result_and_state_iff[runs_to_iff]:
  "assume_result_and_state f  s Q  ((v, t)  f s. Q (Result v) t)"
  by (auto simp add: runs_to.rep_eq)

(* Partial-correctness counterpart of runs_to_assume_result_and_state. *)
lemma runs_to_partial_assume_result_and_state[runs_to_vcg]:
  "assume_result_and_state f  s ?⦃λr t. v. r = Result v  (v, t)  f s"
  by (auto simp add: runs_to_partial.rep_eq)

(* f refines assume_result_and_state g from t if every outcome (r, s') of f
   is R-related to some (Result r', t') with (r', t') a member of g t. *)
lemma refines_assume_result_and_state_right:
  "f  s λr s'. r' t'.  (r', t')  g t  R (r, s') (Result r', t')  
  refines f (assume_result_and_state g) s t R"
  by (simp add: refines.rep_eq runs_to.rep_eq sim_post_state_Success2 split_beta' Bex_def)

(* Refinement between two assume_result_and_state computations from a
   sim_set simulation of P s by Q t. *)
lemma refines_assume_result_and_state:
  "sim_set R (P s) (Q t) 
    refines (assume_result_and_state P) (assume_result_and_state Q) s t
     (λ(Res v, s) (Res w, t). R (v, s) (w, t))"
  by (force simp: refines.rep_eq sim_set_def)

(* Exact characterisation of refinement between two assume_result_and_state
   computations as a sim_set condition on A s and B t. *)
lemma refines_assume_result_and_state_iff:
  "refines (assume_result_and_state A) (assume_result_and_state B) s t Q  
        sim_set (λ(v, s') (w, t'). Q (Result v, s') (Result w, t')) (A s) (B t) "
  apply (simp add: refines.rep_eq, intro iffI)
  subgoal
    by (fastforce simp add: sim_set_def)
  subgoal
    by (force simp add: sim_set_def)
  done

subsection constgets

(* run of gets f: the single outcome (Result (f s), s); the state is unchanged. *)
lemma run_gets[run_spec_monad, simp]:"run (gets f) s = pure_post_state (Result (f s), s)"
  by (simp add: gets_def run_bind)

(* gets f always progresses. *)
lemma always_progress_gets[always_progress_intros]: "always_progress (gets f)"
  by (simp add: always_progress_def)

(* gets f returns Result (f s) and leaves the state at s. *)
lemma runs_to_gets[runs_to_vcg]: "gets f  s λr t. r = Result (f s)  t = s"
  by (simp add: runs_to.rep_eq)

(* Total correctness of gets f reduces to Q (Result (f s)) s. *)
lemma runs_to_gets_iff[runs_to_iff]: "gets f  s Q  Q (Result (f s)) s"
  by (simp add: runs_to.rep_eq)

(* Partial-correctness counterpart of runs_to_gets. *)
lemma runs_to_partial_gets[runs_to_vcg]: "gets f  s ?⦃λr t. r = Result (f s)  t = s"
  by (simp add: runs_to_partial.rep_eq)

(* gets f refines gets g when their single outcomes are R-related. *)
lemma refines_gets: "R (Result (f s), s) (Result (g t), t)  refines (gets f) (gets g) s t R"
  by (auto simp add: refines.rep_eq)

(* rel_spec counterpart of refines_gets. *)
lemma rel_spec_gets: "R (Result (f s), s) (Result (g t), t)  rel_spec (gets f) (gets g) s t R"
  by (auto simp add: rel_spec_iff_refines intro: refines_gets)

(* A computation f that always progresses and, from every s, only returns
   Result (v s) without changing the state is equal to gets v. *)
lemma runs_to_always_progress_to_gets:
  "(s. f  s λr t. t = s  r = Result (v s))  always_progress f  f = gets v"
  apply (clarsimp simp: spec_monad_ext_iff runs_to.rep_eq always_progress.rep_eq)
  subgoal premises prems for s
    using prems[rule_format, of s]
    by (cases "run f s"; auto simp add: pure_post_state_def Ball_def)
  done

(* gets of a let-expression equals first getting the let-bound value with
   gets f', then gets (g' v). *)
lemma gets_let_bind: "(gets (λs. let v = f' s in (g' v s))) = do {v <- gets f'; gets (g' v)}"
  apply (rule spec_monad_eqI)
  apply (auto simp add: runs_to_iff)
  done

subsection constassert_result_and_state

(* run of assert_result_and_state f at s. When f s is empty the run is a
   separate post state in which runs_to fails (see
   runs_to_assert_result_and_state_iff). Otherwise it is Success of f s with
   values wrapped as Result. *)
lemma run_assert_result_and_state[run_spec_monad]:
  "run (assert_result_and_state f) s =
    (if f s = {} then  else Success ((λ(v, t). (Result v, t)) `  f s))"
  by (auto simp add: assert_result_and_state_def run_bind pure_post_state_def Sup_Success_pair)

(* assert_result_and_state f always progresses, for any f. *)
lemma always_progress_assert_result_and_state[always_progress_intros]:
  "always_progress (assert_result_and_state f)"
  by (auto simp add: always_progress_def run_assert_result_and_state)

(* Total correctness needs a condition on f s being empty and Q (Result v) t
   for every (v, t) in f s. *)
lemma runs_to_assert_result_and_state_iff[runs_to_iff]:
  "assert_result_and_state f  s Q  (f s  {}  ((v,t)  f s. Q (Result v) t))"
  by (auto simp add: runs_to.rep_eq run_assert_result_and_state)

(* VCG rule: under the f s / {} side condition, every outcome is Result v
   with (v, t) a member of f s. *)
lemma runs_to_assert_result_and_state[runs_to_vcg]:
  "f s  {}  assert_result_and_state f  s λr t. v. r = Result v  (v, t)  f s"
  by (simp add: runs_to_assert_result_and_state_iff)

(* Partial-correctness counterpart, with no side condition on f s. *)
lemma runs_to_partial_assert_result_and_state[runs_to_vcg]:
  "assert_result_and_state f  s ?⦃λr t. v. r = Result v  (v, t)  f s"
  by (auto simp add: runs_to_partial.rep_eq run_assert_result_and_state)

(* VCG rule for state_select R, proved via runs_to_assert_result_and_state.
   It requires some successor t with (s, t) in R; the result is Result ()
   and the new state is an R-successor of s. *)
lemma runs_to_state_select[runs_to_vcg]:
  "t. (s, t)  R  state_select R  s  λr t. r = Result ()  (s, t)  R "
  by (auto intro!: runs_to_assert_result_and_state[THEN runs_to_weaken])

(* Partial-correctness counterpart of runs_to_state_select, with no
   existence premise. *)
lemma runs_to_partial_state_select[runs_to_vcg]:
  "state_select R  s ?⦃ λr t. r = Result ()  (s, t)  R "
  by (simp add: runs_to_partial.rep_eq run_assert_result_and_state)

(* Refinement between two assert_result_and_state computations. The
   premises are a simulation of f s by g t (sim) and a condition relating
   emptiness of f s and g t (emp). *)
lemma refines_assert_result_and_state:
  assumes sim: "(r' s'. (r', s')  f s  (v' t'. (v', t')  g t  R (Result r', s') (Result v', t')))" 
  assumes emp: "f s = {}  g t = {}"
  shows "refines (assert_result_and_state f) (assert_result_and_state g) s t R"
  using sim emp by (fastforce simp add: refines_iff_runs_to runs_to_iff)

(* state_select instance of refines_assert_result_and_state. *)
lemma refines_state_select: 
  assumes sim: "(s'. (s, s')  f  (t'. (t, t')  g  R (Result (), s') (Result (), t')))" 
  assumes emp: "s'. (s, s')  f  t'. (t, t')  g"
  shows "refines (state_select f) (state_select g) s t R"
  using sim emp by (intro refines_assert_result_and_state) auto

subsection "constassuming"

(* run of assuming g: the single outcome (Result (), s) when g s holds,
   and a different (constant) post state otherwise. *)
lemma run_assuming[run_spec_monad]:
  "run (assuming g) s = (if g s then pure_post_state (Result (), s) else )"
  by (simp add: assuming_def run_bind)

(* assuming g always progresses exactly when g holds in every state. *)
lemma always_progress_assuming[always_progress_intros]: "always_progress (assuming g)  (s. g s)"
  by (auto simp add: always_progress_def run_assuming)

(* Total-correctness VCG rule for assuming g. *)
lemma runs_to_assuming[runs_to_vcg]: "assuming g  s λr t. g s  r = Result ()  t = s"
  by (simp add: runs_to.rep_eq run_assuming)

(* Total correctness of assuming g, stated in terms of g s and Q (Result ()) s. *)
lemma runs_to_assuming_iff[runs_to_iff]: "assuming g  s Q  (g s  Q (Result ()) s)"
  by (auto simp add: runs_to.rep_eq run_assuming)

(* Partial-correctness VCG rule for assuming g.
   NOTE(review): the postcondition mentions g s twice; the second occurrence
   may be redundant. Confirm against the intended statement. *)
lemma runs_to_partial_assuming[runs_to_vcg]:
  "assuming g  s ?⦃λr t. g s  r = Result ()  t = s  g s"
  by (simp add: runs_to_partial.rep_eq run_assuming)

(* assuming P expressed as assume_result_and_state with the singleton
   {((), s)} when P s holds and the empty set otherwise. *)
lemma assuming_state_assume:
  "assuming P = assume_result_and_state (λs. (if P s then {((), s)} else {}))"
  by (simp add: spec_monad_eq_iff runs_to_iff)

(* Assuming the trivially true predicate is skip. *)
lemma assuming_True[simp]: "assuming (λs. True) = skip"
  by (simp add: spec_monad_eq_iff runs_to_iff)

(* Refinement between two assumings, from a condition on P s and Q t and
   R-relatedness of the unit outcomes. *)
lemma refines_assuming:
  "(P s  Q t)  (P s  Q t  R (Result (), s) (Result (), t)) 
    refines (assuming P) (assuming Q) s t R"
  by (auto simp add: refines.rep_eq run_assuming)

(* rel_spec counterpart of refines_assuming. *)
lemma rel_spec_assuming:
  "(Q t  P s)  (P s  Q t  R (Result (), s) (Result (), t)) 
    rel_spec (assuming P) (assuming Q) s t R"
  by (auto simp add: rel_spec_def run_assuming rel_set_def)

(* Refining against assuming P >>= g on the concrete side. *)
lemma refines_bind_assuming_right:
  "P t  (P t  refines f (g ()) s t R)  refines f (assuming P  g) s t R"
  by (simp add: refines.rep_eq run_assuming run_bind)

(* Refining assuming P >>= f on the abstract side: it suffices to refine
   f () against g under the condition on P s. *)
lemma refines_bind_assuming_left:
  "(P s  refines (f ()) g s t R)  refines (assuming P >>= f) g s t R"
  by (simp add: refines.rep_eq run_assuming run_bind)

subsection "constguard"

(* run of guard g: the single outcome (Result (), s) when g s holds,
   otherwise a constant post state (cf. guard_False_fail below). *)
lemma run_guard[run_spec_monad]:
  "run (guard g) s = (if g s then pure_post_state (Result (), s) else )"
  by (simp add: guard_def run_bind)

(* guard g always progresses, for any g. *)
lemma always_progress_guard[always_progress_intros]: "always_progress (guard g)"
  by (auto simp add: always_progress_def run_guard)

(* Total-correctness VCG rule for guard g, under g s. *)
lemma runs_to_guard[runs_to_vcg]: "g s  guard g  s λr t. r = Result ()  t = s"
  by (simp add: runs_to.rep_eq run_guard)

(* Total correctness of guard g, stated in terms of g s and Q (Result ()) s. *)
lemma runs_to_guard_iff[runs_to_iff]: "guard g  s Q  (g s  Q (Result ()) s)"
  by (auto simp add: runs_to.rep_eq run_guard)

(* Partial-correctness VCG rule for guard g; the postcondition records g s. *)
lemma runs_to_partial_guard[runs_to_vcg]: "guard g  s ?⦃λr t. r = Result ()  t = s  g s"
  by (simp add: runs_to_partial.rep_eq run_guard)

(* Refinement between two guards, from a condition on Q t and P s and
   R-relatedness of the unit outcomes. *)
lemma refines_guard:
  "(Q t  P s)  (P s  Q t  R (Result (), s) (Result (), t)) 
    refines (guard P) (guard Q) s t R"
  by (auto simp add: refines.rep_eq run_guard)

(* rel_spec counterpart of refines_guard. *)
lemma rel_spec_guard:
  "(Q t  P s)  (P s  Q t  R (Result (), s) (Result (), t)) 
    rel_spec (guard P) (guard Q) s t R"
  by (auto simp add: rel_spec_def run_guard rel_set_def)

(* Refining against guard P >>= g on the concrete side, from P t and
   refinement of g (). *)
lemma refines_bind_guard_right:
  "refines f (guard P  g) s t R" if "P t  refines f (g ()) s t R"
  using that
  by (auto simp: refines.rep_eq run_guard run_bind)

(* A guard on the constantly false predicate is fail. *)
lemma guard_False_fail: "guard (λ_. False) = fail"
  by (simp add: spec_monad_ext run_guard)

(* The same guard followed by pointwise related continuations gives related
   computations. *)
lemma rel_spec_monad_bind_guard:
  shows "(x. rel_spec_monad (=) Q (f x) (g x)) 
    rel_spec_monad (=) Q (guard P >>= f) (guard P >>= g)"
  by (auto simp add: rel_spec_monad_def run_bind run_guard)

(* Total correctness of guard P >>= f, stated in terms of P s and f (). *)
lemma runs_to_guard_bind_iff: "((guard P >>= f)  s  Q )  P s  ((f ())  s  Q )"
  by (simp add: runs_to_iff)

(* Exact characterisation of refining against guard P >>= g. *)
lemma refines_bind_guard_right_iff:
  "refines f (guard P  g) s t R  (P t  refines f (g ()) s t R)"
  by (auto simp: refines.rep_eq run_guard run_bind)

(* Appending guard G; return res after g on the concrete side preserves
   refinement, and the relation additionally records G t for Result
   outcomes. Exception outcomes add no constraint. *)
lemma refines_bind_guard_right_end:
  assumes f_g: "refines f g s t R"
  shows "refines f (do {res <- g; guard G; return res}) s t 
            (λ(r, s) (q, t). R (r, s) (q, t)  
                (case q of Exception e  True | Result _  G t))"
  apply (subst bind_return[symmetric])
  apply (rule refines_bind[OF f_g])
  apply (auto simp: refines_bind_guard_right_iff)
  done

(* Variant of refines_bind_guard_right_end where the guard depends on the
   result: G v t is recorded for Result v outcomes. *)
lemma refines_bind_guard_right_end':
  assumes f_g: "refines f g s t R"
  shows "refines f (do {res <- g; guard (G res); return res}) s t 
            (λ(r, s) (q, t). R (r, s) (q, t)  
                (case q of Exception e  True | Result v  G v t))"
  apply (subst bind_return[symmetric])
  apply (rule refines_bind[OF f_g])
  apply (auto simp: refines_bind_guard_right_iff)
  done

subsection "constassert_opt"

(* assert_opt x: for Some v it returns Result v in the unchanged state
   (run_assert_opt); the remaining lemmas are the always_progress, runs_to and
   runs_to_partial rules derived from that equation. *)
lemma run_assert_opt[run_spec_monad, simp]:
  "run (assert_opt x) s = (case x of Some v  pure_post_state (Result v, s) | None  )"
  by (simp add: assert_opt_def run_bind split: option.splits)

lemma always_progress_assert_opt[always_progress_intros]: "always_progress (assert_opt x)"
  by (simp add: always_progress_intros assert_opt_def split: option.splits)

lemma runs_to_assert_opt[runs_to_vcg]: "(v. x = Some v  Q (Result v) s)  assert_opt x  s Q"
  by (auto simp add: runs_to.rep_eq)

lemma runs_to_assert_opt_iff[runs_to_iff]:
  "assert_opt x  s Q  (v. x = Some v  Q (Result v) s)"
  by (auto simp add: runs_to.rep_eq split: option.split)

lemma runs_to_partial_assert_opt[runs_to_vcg]:
  "(v. x = Some v  Q (Result v) s)  assert_opt x  s ?⦃Q"
  by (auto simp add: runs_to_partial.rep_eq split: option.split)

subsection "constgets_the"

(* gets_the f: like assert_opt, but the option is read from the state via f s.
   NOTE(review): run_gets_the' states the same equation as run_gets_the (only
   the proof differs, and only the latter is a simp/run_spec_monad rule);
   consider consolidating after checking for external uses. *)
lemma run_gets_the':
  "run (gets_the f) s = (case f s of Some v  pure_post_state (Result v, s) | None  )"
  by (simp add: gets_the_def run_bind top_post_state_def pure_post_state_def split: option.split)

lemma run_gets_the[run_spec_monad, simp]:
  "run (gets_the f) s = (case (f s) of Some v  pure_post_state (Result v, s) | None  )"
  by (simp add: gets_the_def run_bind)

lemma always_progress_gets_the[always_progress_intros]: "always_progress (gets_the f)"
  by (simp add: always_progress_intros gets_the_def split: option.splits)

lemma runs_to_gets_the[runs_to_vcg]: "(v. f s = Some v  Q (Result v) s)  gets_the f  s Q"
  by (auto simp add: runs_to.rep_eq)

lemma runs_to_gets_the_iff[runs_to_iff]:
  "gets_the f  s Q  (v. f s = Some v  Q (Result v) s)"
  by (auto simp add: runs_to.rep_eq split: option.split)

lemma runs_to_partial_gets_the[runs_to_vcg]:
  "(v. f s = Some v  Q (Result v) s)  gets_the f  s ?⦃Q"
  by (auto simp add: runs_to_partial.rep_eq split: option.split)

subsection constmodify

(* modify f: returns Result () and replaces the state s by f s (run_modify).
   The _res variants specialise the postcondition to res_monad, where only the
   final state is constrained.
   NOTE(review): "always_progress_modifiy" looks like a misspelling of
   "always_progress_modify"; the name is kept since other theories may use it. *)
lemma run_modify[run_spec_monad, simp]: "run (modify f) s = pure_post_state (Result (), f s)"
  by (simp add: modify_def run_bind)

lemma always_progress_modifiy[always_progress_intros]: "always_progress (modify f)"
  by (simp add: modify_def always_progress_intros)

lemma runs_to_modify[runs_to_vcg]: "modify f  s  λr t. r = Result ()  t = f s "
  by (simp add: runs_to.rep_eq)

lemma runs_to_modify_res[runs_to_vcg]: "((modify f)::(unit, 's) res_monad)  s  λr t. t = f s "
  by (simp add: runs_to.rep_eq)

lemma runs_to_modify_iff[runs_to_iff]: "modify f  s Q  Q (Result ()) (f s)"
  by (simp add: runs_to.rep_eq)

lemma runs_to_partial_modify[runs_to_vcg]: "modify f  s ?⦃ λr t. r = Result ()  t = f s "
  by (simp add: runs_to_partial.rep_eq)

lemma runs_to_partial_modify_res[runs_to_vcg]:
  "((modify f)::(unit, 's) res_monad)  s ?⦃ λr t. t = f s "
  by (simp add: runs_to_partial.rep_eq)

lemma refines_modify:
  "R (Result (), f s) (Result (), g t)  refines (modify f) (modify g) s t R"
  by (auto simp add: refines.rep_eq)

lemma rel_spec_modify:
  "R (Result (), f s) (Result (), g t)  rel_spec (modify f) (modify g) s t R"
  by (auto simp add: rel_spec_iff_refines intro: refines_modify)

subsection ‹condition›

(* condition c f g runs f if c holds in the current state and g otherwise
   (run_condition).  condition_swap negates the test and swaps the branches;
   condition_fail_rhs/lhs express a failing branch as a guard. *)
lemma run_condition[run_spec_monad]: "run (condition c f g) s = (if c s then run f s else run g s)"
  by (simp add: condition_def run_bind)

lemma run_condition_True[run_spec_monad, simp]: "c s  run (condition c f g) s = run f s"
  by (simp add: run_condition)

lemma run_condition_False[run_spec_monad, simp]: "¬c s  run (condition c f g) s = run g s"
  by (simp add: run_condition)

lemma always_progress_condition[always_progress_intros]:
  "always_progress f  always_progress g  always_progress (condition c f g)"
  by (auto simp add: always_progress_def run_condition)

lemma condition_swap: "(condition C A B) = (condition (λs. ¬ C s) B A)"
  by (rule spec_monad_ext) (clarsimp simp add: run_condition)

lemma condition_fail_rhs: "(condition C X fail) = (guard C >>= (λ_. X))"
  by (rule spec_monad_ext) (simp add: run_bind run_guard run_condition)

lemma condition_fail_lhs: "(condition C fail X) = (guard (λs. ¬ C s) >>= (λ_. X))"
  by (metis condition_fail_rhs condition_swap)

(* Simplification and ordering facts for condition: binding into fail
   distributes over the branches, constant tests select a branch, and condition
   is monotone in both branches (mono_condition) and as a functional
   (mono_condition_spec_monad). *)
lemma condition_bind_fail[simp]:
  "(condition C A B >>= (λ_. fail)) = condition C (A >>= (λ_. fail)) (B >>= (λ_. fail))"
  apply (rule spec_monad_ext)
  apply (clarsimp simp add: run_condition run_bind)
  done

lemma condition_True[simp]: "condition (λ_. True) f g = f"
  apply (rule spec_monad_ext)
  apply (clarsimp simp add: run_condition run_bind)
  done

lemma condition_False[simp]: "condition (λ_. False) f g = g"
  apply (rule spec_monad_ext)
  apply (clarsimp simp add: run_condition run_bind)
  done

(* Pointwise introduction rule: compare h with f on states where c holds and
   with g on the others. *)
lemma le_condition_runI:
  "(s. c s  run h s  run f s)  (s. ¬ c s  run h s  run g s)
   h  condition c f g"
  by (simp add: le_spec_monad_runI run_condition)

lemma mono_condition_spec_monad:
  "mono T  mono F  mono (λx. condition C (F x) (T x))"
  by (auto simp: condition_def intro!: mono_bind_spec_monad mono_const)

lemma mono_condition: "f  f'  g  g'  condition c f g  condition c f' g'"
  by (simp add: le_fun_def less_eq_spec_monad.rep_eq run_condition)

(* partial_function monotonicity rules for condition, for both (≤) and (≥),
   followed by the verification-condition rules that split on c s. *)
lemma monotone_condition_le[partial_function_mono]:
  "monotone R (≤) (λf'. f f')  (monotone R (≤) (λf'. g f'))
   monotone R (≤) (λf'. condition c (f f') (g f'))"
  by (auto simp add: monotone_def intro!: mono_condition)

lemma monotone_condition_ge[partial_function_mono]:
  "monotone R (≥) (λf'. f f')  (monotone R (≥) (λf'. g f'))
   monotone R (≥) (λf'. condition c (f f') (g f'))"
  by (auto simp add: monotone_def intro!: mono_condition)

lemma runs_to_condition[runs_to_vcg]:
  "(c s  f  s  Q )  (¬ c s  g  s  Q )  condition c f g  s  Q "
  by (simp add: runs_to.rep_eq run_condition)

lemma runs_to_condition_iff[runs_to_iff]:
  "condition c f g  s  Q   (if c s then f  s  Q  else g  s  Q )"
  by (simp add: runs_to.rep_eq run_condition)

lemma runs_to_partial_condition[runs_to_vcg]:
  "(c s  f  s ?⦃ Q )  (¬ c s  g  s ?⦃ Q )  condition c f g  s ?⦃ Q "
  by (simp add: runs_to_partial.rep_eq run_condition)

(* refines rules for condition.  When both programs branch, the two tests must
   agree on the related states (stated as c' s' = c s in the TrueI/FalseI
   variants), so that matching branches are related.  The _bind_left and
   _bind_right lemmas split a condition that is followed by a bind. *)
lemma refines_condition_iff:
  assumes "c' s'  c s"
  shows "refines (condition c f g) (condition c' f' g') s s' R 
   (if c' s' then refines f f' s s' R else refines g g' s s' R)"
  using assms
  by (auto simp add: refines.rep_eq run_condition)

lemma refines_condition:
  "P s  P' s' 
    (P s  P' s'  refines f f' s s' R) 
    (¬ P s  ¬ P' s'  refines g g' s s' R) 
    refines (condition P f g) (condition P' f' g') s s' R"
  using refines_condition_iff
  by metis

lemma refines_condition_TrueI:
  assumes "c' s' = c s" and "c' s'" "refines f f' s s' R"
  shows "refines (condition c f g) (condition c' f' g') s s' R"
  by (simp add:  refines_condition_iff[where c'=c' and c=c and s'=s' and s=s, OF assms(1)] assms(2, 3))

lemma refines_condition_FalseI:
  assumes "c' s' = c s" and "¬c' s'" "refines g g' s s' R"
  shows "refines (condition c f g) (condition c' f' g') s s' R"
  by (simp add:  refines_condition_iff[where c'=c' and c=c and s'=s' and s=s, OF assms(1)] assms(2, 3))

lemma refines_condition_bind_left:
  "refines (condition C T F  X) Y s t R 
    (C s  refines (T  X) Y s t R)  (¬ C s  refines (F  X) Y s t R)"
  by (simp add: refines.rep_eq run_bind run_condition)

lemma refines_condition_bind_right:
  "refines X (condition C T F  Y) s t R 
    (C t  refines X (T  Y) s t R)  (¬ C t  refines X (F  Y) s t R)"
  by (simp add: refines.rep_eq run_bind run_condition)

(* rel_spec rules for condition mirroring refines_condition_iff and its
   variants, together with one-sided rules for both refines and rel_spec:
   _left splits a condition in the first program on P s, and _true/_false
   select a branch of a condition in the second program from P t. *)
lemma rel_spec_condition_iff:
  assumes "c' s'  c s"
  shows "rel_spec (condition c f g) (condition c' f' g') s s' R 
   (if c' s' then rel_spec f f' s s' R else rel_spec g g' s s' R)"
  using assms
  by (auto simp add: rel_spec_def run_condition)

lemma rel_spec_condition:
  "P s  P' s' 
    (P s  P' s'  rel_spec f f' s s' R) 
    (¬ P s  ¬ P' s'  rel_spec g g' s s' R) 
    rel_spec (condition P f g) (condition P' f' g') s s' R"
  using rel_spec_condition_iff
  by metis

lemma rel_spec_condition_TrueI:
  assumes "c' s' = c s" and "c' s'" "rel_spec f f' s s' R"
  shows "rel_spec (condition c f g) (condition c' f' g') s s' R"
  by (simp add:  rel_spec_condition_iff[where c'=c' and c=c and s'=s' and s=s, OF assms(1)] assms(2, 3))

lemma rel_spec_condition_FalseI:
  assumes "c' s' = c s" and "¬c' s'" "rel_spec g g' s s' R"
  shows "rel_spec (condition c f g) (condition c' f' g') s s' R"
  by (simp add:  rel_spec_condition_iff[where c'=c' and c=c and s'=s' and s=s, OF assms(1)] assms(2, 3))

lemma refines_condition_left:
  "(P s  refines f h s t R)  (¬ P s  refines g h s t R) 
    refines (condition P f g) h s t R"
  by (simp add: refines.rep_eq run_condition)

lemma rel_spec_condition_left:
  "(P s  rel_spec f h s t R)  (¬ P s  rel_spec g h s t R) 
    rel_spec (condition P f g) h s t R"
  by (auto simp add: rel_spec_def run_condition)

lemma refines_condition_true:
  "P t  refines f g s t R  refines f (condition P g h) s t R"
  by (simp add: refines.rep_eq run_condition)

lemma rel_spec_condition_true:
  "P t  rel_spec f g s t R 
    rel_spec f (condition P g h) s t R"
  by (auto simp add: rel_spec_def run_condition)

lemma refines_condition_false:
  "¬ P t  refines f h s t R 
    refines f (condition P g h) s t R"
  by (simp add: refines.rep_eq run_condition)

lemma rel_spec_condition_false:
  "¬ P t  rel_spec f h s t R 
    rel_spec f (condition P g h) s t R"
  by (auto simp add: rel_spec_def run_condition)

(* bind distributes over condition (condition_bind), and rel_spec_monad is
   preserved by condition when the tests agree on related states
   (rel_fun R (=) P P') or are constants. *)
lemma condition_bind:
  "(condition P f g >>= h) = condition P (f >>= h) (g >>= h)"
  by (simp add: spec_monad_ext_iff run_condition run_bind)

lemma rel_spec_monad_condition:
  assumes "rel_fun R (=) P P'"
    and "rel_spec_monad R Q f f'"
    and "rel_spec_monad R Q g g'"
  shows "rel_spec_monad R Q (condition P f g) (condition P' f' g')"
  using assms
  by (auto simp add: rel_spec_monad_def run_condition rel_fun_def)

lemma rel_spec_monad_condition_const:
  "P  P'  (P  rel_spec_monad R Q f f') 
    (¬ P  rel_spec_monad R Q g g') 
  rel_spec_monad R Q (condition (λ_. P) f g) (condition (λ_. P') f' g')"
  by (cases P) (auto simp add: condition_def)

subsection "constwhen"

(* when c f runs f if c is True and otherwise returns Result () without changing
   the state (run_when); the remaining rules are derived from run_when or by
   unfolding when_def. *)
lemma run_when[run_spec_monad]:
  "run (when c f) s = (if c then run f s else pure_post_state (Result (), s))"
  unfolding when_def by simp

lemma always_progress_when[always_progress_intros]:
  "always_progress f  always_progress (when c f)"
  unfolding when_def by (simp add: always_progress_intros)

lemma runs_to_when[runs_to_vcg]:
  "(c  f  s  Q )  (¬ c  Q (Result ()) s)  when c f  s  Q "
  by (auto simp add: runs_to.rep_eq run_when)

lemma runs_to_when_iff[runs_to_iff]: "
  (when c f)  s  Q   (if c then f  s  Q  else Q (Result ()) s)"
  by (auto simp add: runs_to.rep_eq run_when)

lemma runs_to_partial_when[runs_to_vcg]:
  "(c  f  s ?⦃ Q )  (¬ c  Q (Result ()) s)  when c f  s ?⦃ Q "
  by (auto simp add: runs_to_partial.rep_eq run_when)

lemma mono_when: "f  f'  when c f  when c f'"
  unfolding when_def by (simp add: mono_condition)

lemma monotone_when_le[partial_function_mono]:
  "monotone R (≤) (λf'. f f')
   monotone R (≤) (λf'. when c (f f'))"
  unfolding when_def by (simp add: monotone_condition_le)

lemma monotone_when_ge[partial_function_mono]:
  "monotone R (≥) (λf'. f f')
   monotone R (≥) (λf'. when c (f f'))"
  unfolding when_def by (simp add: monotone_condition_ge)

lemma when_True[simp]: "when True f = f"
  apply (rule spec_monad_ext)
  apply (simp add: run_when)
  done

lemma when_False[simp]: "when False f = return ()"
  apply (rule spec_monad_ext)
  apply (simp add: run_when)
  done

subsection ‹While›

context fixes C :: "'a  's  bool" and B :: "'a  ('e::default, 'a, 's) spec_monad"
begin

(* whileLoop a: while C holds for the current value a and the current state, run
   B a and continue with the returned value; otherwise return a.  whileLoop is
   the greatest fixed point of this unrolling functional and whileLoop_finite
   the least one (see whileLoop_unroll and whileLoop_finite_unfold below). *)
definition whileLoop :: "'a  ('e, 'a, 's) spec_monad" where
  "whileLoop =
    gfp (λW a. condition (C a) (bind (B a) W) (return a))"
  ― ‹Collapses to constFailure in case of any non terminating computation.›

definition whileLoop_finite :: "'a  ('e, 'a, 's) spec_monad" where
  "whileLoop_finite =
    lfp (λW a. condition (C a) (bind (B a) W) (return a))"
  ― ‹Does not collapse to constFailure in presence of a non terminating computation.
    constFailure can still occur when the body fails in some iteration. It
    captures the outcomes of all terminating and thus finite computations.›

(* Inductive termination predicate: a configuration (a, s) terminates if, when
   C a s holds, every Result a' of B a (judged with partial correctness) leads
   to a terminating configuration again. *)
inductive whileLoop_terminates :: "'a  's  bool" where
  step: "a s. (C a s  B a  s ?⦃ λv s. a. v = Result a  whileLoop_terminates a s ) 
    whileLoop_terminates a s"
  ― ‹This is weaker than ‹run (whileLoop a) s ≠ ⊤›, as it uses partial correctness.›

lemma mono_whileLoop_functional:
  "mono (λW a. condition (C a) (bind (B a) W) (return a))"
  by (intro mono_app mono_lam mono_const mono_bind_spec_monad mono_condition_spec_monad)

lemma whileLoop_unroll:
  "whileLoop a =
    condition (C a) (bind (B a) whileLoop) (return a)"
  unfolding whileLoop_def
  by (subst gfp_unfold[OF mono_whileLoop_functional]) simp

lemma whileLoop_finite_unfold:
  "whileLoop_finite a =
    condition (C a) (bind (B a) whileLoop_finite) (return a)"
  unfolding whileLoop_finite_def
  by (subst lfp_unfold[OF mono_whileLoop_functional]) simp

(* One-step criterion: if, whenever C a s holds, every Result of B a continues
   with a loop run that is not Failure, then this loop run is not Failure. *)
lemma whileLoop_ne_Failure:
  "(C a s  B a  s  λx s. a. x = Result a  run (whileLoop a) s  Failure ) 
    run (whileLoop a) s  Failure"
  by (subst whileLoop_unroll)
     (simp add: run_condition run_bind_eq_top_iff flip: top_post_state_def)

(* Induction over loop runs that are not ⊤: proved by showing that the program
   which is ⊥ where P holds and ⊤ elsewhere lies below whileLoop
   (gfp_upperbound). *)
lemma whileLoop_ne_top_induct[consumes 1, case_names step]:
  assumes a_s: "run (whileLoop a) s  "
    and step: "a s. (C a s  B a  s  λx s. a. x = Result a  P a s )  P a s"
  shows "P a s"
proof -
  have "(λa. Spec (λs. if P a s then  else ))  whileLoop"
    unfolding whileLoop_def
    by (intro gfp_upperbound)
       (auto intro: step runs_to_weaken
         simp: le_fun_def less_eq_spec_monad.rep_eq run_condition top_unique run_bind_eq_top_iff)
  with a_s show ?thesis
    by (auto simp add: le_fun_def less_eq_spec_monad.rep_eq top_unique top_post_state_def
             split: if_splits)
qed

(* Total-correctness rule: invariant I on outcomes plus a well-founded relation
   R on (value, state) pairs that every Result of the body must decrease.
   P_Result covers loop exit (¬ C), P_Exception covers exceptional outcomes. *)
lemma runs_to_whileLoop:
  assumes R: "wf R"
  assumes *: "I (Result a) s"
  assumes P_Result: "a s. ¬ C a s  I (Result a) s  P (Result a) s"
  assumes P_Exception: "a s. a  default  I (Exception a) s  P (Exception a) s"
  assumes B: "a s. C a s  I (Result a) s 
    B a   s λr t. I r t  (b. r = Result b  ((b, t), (a, s))  R)"
  shows "whileLoop a  s P"
proof (use R * in induction x"(a, s)" arbitrary: a s)
  case (less a s)
  show ?case
    by (auto simp: P_Result P_Exception whileLoop_unroll[of a]
                   runs_to_condition_iff runs_to_bind_iff less
             intro!: runs_to_weaken[OF B])
qed

(* For whileLoop_finite no well-founded relation is needed; the proof is by
   least-fixed-point induction (runs_to_lfp). *)
lemma runs_to_whileLoop_finite:
  assumes *: "I (Result a) s"
  assumes P_Result: "a s. ¬ C a s  I (Result a) s  P (Result a) s"
  assumes P_Exception: "a s. a  default  I (Exception a) s  P (Exception a) s"
  assumes B: "a s. C a s  I (Result a) s  B a  s  I "
  shows "whileLoop_finite a  s P"
proof -
  have "whileLoop_finite a  s
     λx s. I x s  (a. x = Result a  C a s  P (Result a) s) "
    unfolding whileLoop_finite_def
    apply (rule runs_to_lfp[OF mono_whileLoop_functional, of "λa. I (Result a)", OF *])
    subgoal premises prems for W x s
      using prems(2)
      by (intro runs_to_condition runs_to_bind runs_to_yield_iff
                runs_to_weaken[OF B] runs_to_weaken[OF prems(1)] conjI allI impI)
         simp_all
    done
  then show ?thesis
    apply (rule runs_to_weaken)
    subgoal for r s
      by (cases r; cases "a. r = Result a  C a s"; simp add: P_Result P_Exception)
    done
qed

lemma runs_to_partial_whileLoop_finite:
  assumes *: "I (Result a) s"
  assumes B: "a s. C a s  I (Result a) s  (B a)  s ?⦃ I "
  assumes P_Result: "a s. ¬ C a s  I (Result a) s  P (Result a) s"
  assumes P_Exception: "a s. a  default  I (Exception a) s  P (Exception a) s"
  shows "whileLoop_finite a  s ?⦃ P "
proof -
  have "whileLoop_finite a  s P" if **: "run (whileLoop_finite a) s  "
    apply (rule runs_to_whileLoop_finite[where
          I="λx s. I x s  (a. x = Result a  run (whileLoop_finite a) s  )"])
       apply (auto simp: * ** P_Result P_Exception)[3]
    subgoal for a s using B[of a s]
      by (subst (asm) whileLoop_finite_unfold)
        (auto simp: runs_to_conj run_bind_eq_top_iff  dest: runs_to_of_runs_to_partial_runs_to')
    done
  then show ?thesis
    by (auto simp: runs_to_partial_alt top_post_state_def)
qed

(* On terminating configurations both loop definitions have the same run. *)
lemma whileLoop_finite_eq_whileLoop_of_whileLoop_terminates:
  assumes "whileLoop_terminates a s"
  shows "run (whileLoop_finite a) s = run (whileLoop a) s"
proof (use assms in induction)
  case (step a s)
  show ?case
    apply (subst whileLoop_unroll)
    apply (subst whileLoop_finite_unfold)
    apply (auto simp: run_condition intro!: run_bind_cong[OF refl runs_to_partial_weaken, OF step])
    done
qed

lemma whileLoop_terminates_of_succeeds:
  "run (whileLoop a) s    whileLoop_terminates a s"
  by (induction rule: whileLoop_ne_top_induct)
     (auto intro: whileLoop_terminates.step runs_to_partial_of_runs_to)

lemma whileLoop_finite_eq_whileLoop:
  "run (whileLoop a) s    run (whileLoop_finite a) s = run (whileLoop a) s"
  by (rule whileLoop_finite_eq_whileLoop_of_whileLoop_terminates[OF
    whileLoop_terminates_of_succeeds])

lemma runs_to_whileLoop_of_runs_to_whileLoop_finite_if_terminates:
  "whileLoop_terminates i s  whileLoop_finite i  s Q  whileLoop i  s Q"
  unfolding runs_to.rep_eq  whileLoop_finite_eq_whileLoop_of_whileLoop_terminates .

(* Follows from lfp_le_gfp for the shared unrolling functional. *)
lemma whileLoop_finite_le_whileLoop: "whileLoop_finite a  whileLoop a"
  using le_fun_def lfp_le_gfp mono_whileLoop_functional
  unfolding whileLoop_finite_def whileLoop_def
  by fastforce

lemma runs_to_partial_whileLoop_finite_whileLoop:
  "whileLoop_finite i  s ?⦃Q  whileLoop i  s ?⦃Q"
  by (cases "run (whileLoop i) s = ")
     (auto simp:
       runs_to.rep_eq runs_to_partial_alt whileLoop_finite_eq_whileLoop top_post_state_def)

(* Partial-correctness rule for whileLoop, obtained from the whileLoop_finite
   rule; no well-founded relation is required. *)
lemma runs_to_partial_whileLoop:
  assumes "I (Result a) s"
  assumes "a s. ¬ C a s  I (Result a) s  P (Result a) s"
  assumes "a s. a  default  I (Exception a) s  P (Exception a) s"
  assumes "a s. C a s  I (Result a) s  (B a)  s ?⦃ I "
  shows "whileLoop a  s ?⦃ P "
  using assms(1,4,2,3)
  by (rule runs_to_partial_whileLoop_finite_whileLoop[OF runs_to_partial_whileLoop_finite])

(* always_progress is preserved by whileLoop when every loop body has it. *)
lemma always_progress_whileLoop[always_progress_intros]:
  assumes B: "(v. always_progress (B v))"
  shows "always_progress (whileLoop a)"
  unfolding always_progress.rep_eq
proof
  fix s have "run (whileLoop a) s  " if *: "run (whileLoop a) s  "
  proof (use * in induction rule: whileLoop_ne_top_induct)
    case (step a s) then show ?case
      apply (subst whileLoop_unroll)
      apply (simp add: run_condition run_bind bind_post_state_eq_bot prod_eq_iff runs_to.rep_eq
                  split: prod.splits exception_or_result_splits)
      apply clarsimp
      subgoal premises prems
        using holds_post_state_combine[OF prems(1,3), where R="λx. False"] B
        by (auto simp: holds_post_state_False exception_or_result_nchotomy always_progress.rep_eq)
      done
  qed
  then show "run (whileLoop a) s  "
    by (cases "run (whileLoop a) s = Failure") (auto simp: bot_post_state_def)
qed

(* Under partial correctness, every Result of the loop has a false condition on
   the final value and state. *)
lemma runs_to_partial_whileLoop_cond_false:
  "(whileLoop I)  s ?⦃ λr t. a. r = Result a  ¬ C a t "
  by (rule runs_to_partial_whileLoop[of "λ_ _. True"]) simp_all

end

(* Refinement of loops.  Fixed assumptions: related Result configurations agree
   on the loop conditions (C); related bodies refine each other when the
   conditions hold (B); and R is assumed to relate Result outcomes with Result
   outcomes (R).  The _strong variants additionally record that on Result
   outcomes both loop conditions are false. *)
context
  fixes R
  fixes C :: "'a  's  bool" and B :: "'a  ('e::default, 'a, 's) spec_monad"
    and C' :: "'b  't  bool" and B' :: "'b  ('f::default, 'b, 't) spec_monad"
  assumes C: "x s x' s'. R (Result x, s) (Result x', s')  C x s  C' x' s'"
  assumes B: "x s x' s'. R (Result x, s) (Result x', s')  C x s  C' x' s' 
    refines (B x) (B' x') s s' R"
  assumes R: "v s v' s'. R (v, s) (v', s')  (x. v = Result x)  (x'. v' = Result x')"
begin

(* Proved by ordinal induction over the least fixed point defining
   whileLoop_finite C B (lfp_ordinal_induct). *)
lemma refines_whileLoop_finite_strong:
  assumes x_x': "R (Result x, s) (Result x', s')"
  shows "refines (whileLoop_finite C B x) (whileLoop_finite C' B' x') s s'
    (λ(r, s) (r', s'). (v. r = Result v  (v'. r' = Result v'  ¬ C v s  ¬ C' v' s')) 
     R (r, s) (r', s'))"
    (is "refines _ _ _ _ ?R")
proof -
  let ?P = "λp. x s x' s'. R (Result x, s) (Result x', s') 
    refines (p x) (whileLoop_finite C' B' x') s s' ?R"

  have "?P (whileLoop_finite C B)"
    unfolding whileLoop_finite_def[of C B]
  proof (rule lfp_ordinal_induct[OF mono_whileLoop_functional], intro allI impI)
    fix S x s x' s' assume S: "?P S" and x_x': "R (Result x, s) (Result x', s')"
    from x_x' show
      "refines (condition (C x) (B x  S) (return x)) (whileLoop_finite C' B' x') s s' ?R"
      by (subst whileLoop_finite_unfold)
         (auto dest: R simp: C[OF x_x'] refines_condition_iff
               intro!: refines_bind'[OF refines_mono[OF _ B[OF x_x']]] S[rule_format])
  qed (simp add: refines_Sup1)
  with x_x' show ?thesis by simp
qed

lemma refines_whileLoop_finite:
  assumes x_x': "R (Result x, s) (Result x', s')"
  shows "refines (whileLoop_finite C B x) (whileLoop_finite C' B' x') s s' R"
  by (rule refines_mono[OF _ refines_whileLoop_finite_strong, OF _ x_x']) simp

(* Failure transfer: under the stated assumption on the run of the second loop,
   the run of the first loop from a related configuration is not Failure. *)
lemma whileLoop_succeeds_terminates_of_refines:
  assumes "run (whileLoop C' B' x') s'  "
  shows "R (Result x, s) (Result x', s')  run (whileLoop C B x) s  Failure"
proof (use assms in induction arbitrary: x s rule: whileLoop_ne_top_induct)
  case (step a' s' a s)

  show ?case
  proof (rule whileLoop_ne_Failure)
    assume "C a s"
    with C[of a s a' s'] step.prems have "C' a' s'" by simp

    show "B a  s  λx s. a. x = Result a  run (whileLoop C B a) s  Failure "
      using step.IH[OF C' a' s'] B[OF step.prems C a s C' a' s']
      by (rule runs_to_refines) (metis R)
  qed
qed

(* If the second loop fails, refines holds trivially; otherwise both loops
   coincide with their whileLoop_finite versions and the finite rule applies. *)
lemma refines_whileLoop_strong:
  assumes x_x': "R (Result x, s) (Result x', s')"
  shows "refines (whileLoop C B x) (whileLoop C' B' x') s s'
      (λ(r, s) (r', s'). (v. r = Result v  (v'. r' = Result v'  ¬ C v s  ¬ C' v' s')) 
     R (r, s) (r', s'))"
  apply (cases "run (whileLoop C' B' x') s' = Failure")
  subgoal by (simp add: refines.rep_eq )
  subgoal
    using
      whileLoop_succeeds_terminates_of_refines[OF _ x_x']
      refines_whileLoop_finite_strong[OF x_x']
    by (simp add: refines.rep_eq whileLoop_finite_eq_whileLoop top_post_state_def)
  done

lemma refines_whileLoop:
  assumes x_x': "R (Result x, s) (Result x', s')"
  shows "refines (whileLoop C B x) (whileLoop C' B' x') s s' R"
  by (rule refines_mono[OF _ refines_whileLoop_strong, OF _ x_x']) simp

end

(* whileLoop rules specialised to res_monad.  The invariant is stated on plain
   values; runs_to_whileLoop_res' instantiates the general invariant as False
   on Exception outcomes.  runs_to_whileLoop_variant_res uses an explicit
   variant c that must decrease in the well-founded relation R. *)
lemma runs_to_whileLoop_res':
  assumes R: "wf R"
  assumes *: "I a s"
  assumes P_Result: "a s. ¬ C a s  I a s  P (Result a) s"
  assumes B: "a s. C a s  I a s 
    B a   s λr t. (b. r = Result b  I b t  ((b, t), (a, s))  R)"
  shows "(whileLoop C B a::('a, 's) res_monad)  s P"
  apply (rule runs_to_whileLoop[OF R, where I = "λException _  (λ_. False) | Result v  I v"])
  subgoal using * by auto
  subgoal using P_Result by auto
  subgoal by auto
  subgoal for a b by (rule B[of a b, THEN runs_to_weaken]) auto
  done

lemma runs_to_whileLoop_res:
  assumes B: "a s. C a s  I a s 
    B a   s λRes r t. I r t  ((r, t), (a, s))  R"
  assumes P_Result: "a s. I a s  ¬ C a s  P a s"
  assumes R: "wf R"
  assumes *: "I a s"
  shows "(whileLoop C B a::('a, 's) res_monad)  s λRes r. P r"
  apply (rule runs_to_whileLoop_res' [where R = R and I = I and B = B,  OF R *])
  using B P_Result by auto

lemma runs_to_whileLoop_variant_res:
  assumes I: "r s c. I r s c 
    C r s  (B r)  s  λRes q t. c'. I q t c'  (c', c)  R "
  assumes Q: "r s c. I r s c  ¬ C r s  Q r s"
  assumes R: "wf R"
  shows "I r s c  (whileLoop C B r::('a, 's) res_monad)  s λRes r. Q r "
proof (induction arbitrary: r s rule: wf_induct[OF R])
  case (1 c) show ?case
    apply (subst whileLoop_unroll)
    supply I[where c=c, runs_to_vcg]
    apply (runs_to_vcg)
    subgoal by (simp add: 1)
    subgoal using 1 by simp
    subgoal using Q 1 by blast
    done
qed

(* Loops that step through a known sequence of values F i and states S i.
   runs_to_whileLoop_inc_res counts i up from 0 to M (for i ≤ M the condition
   is characterised by i < M); runs_to_whileLoop_dec_res counts down from M to
   0 (the condition is characterised by i > 0).  Both are instances of
   runs_to_whileLoop_variant_res. *)
lemma runs_to_whileLoop_inc_res:
  assumes *: "i. i < M  (B (F i))  (S i)  λRes r' t'. r' = (F (Suc i))  t' = (S (Suc i)) "
      and [simp]: "i. i  M  C (F i) (S i)  i < M"
      and [simp]: "i = F 0" "si = S 0" "t = F M" "st = S M"
    shows "(whileLoop C B i::('a, 's) res_monad)  si  λRes r' t'. r' = t  t' = st "
  apply (rule runs_to_whileLoop_variant_res[where I="λr t i. r = F i  t = S i  i  M" and c=0,
    OF _ _ wf_nat_bound[of M]])
  apply simp_all
  apply (rule runs_to_weaken[OF *])
   apply auto
  done

lemma runs_to_whileLoop_dec_res:
  assumes *: "i::nat. i > 0  i  M 
        (B (F i))  (S i)  λRes r' t'. r' = (F (i - 1))  t' = (S (i - 1)) "
      and [simp]: "i. C (F i) (S i)  i > 0"
      and [simp]: "i = F M" "si = S M" "t = F 0" "st = S 0"
    shows "(whileLoop C B i::('a, 's) res_monad)  si  λRes r' t'. r' = t  t' = st "
  apply (rule runs_to_whileLoop_variant_res[where I="λr t i. r = F i  t = S i  i  M" and c=M,
        OF _ _ wf_measure[of "λx. x"]])
    apply (fastforce intro!: runs_to_weaken[OF *])+
  done


(* Instances of runs_to_whileLoop for exceptions built with Exn; the
   default-exception side condition is discharged via default_option_def and
   Exn_def.  The primed variant only reorders the assumptions. *)
lemma runs_to_whileLoop_exn:
  assumes R: "wf R"
  assumes *: "I (Result a) s"
  assumes P_Result: "a s. ¬ C a s  I (Result a) s  P (Result a) s"
  assumes P_Exn: "a s. I (Exn a) s  P (Exn a) s"
  assumes B: "a s. C a s  I (Result a) s 
    B a   s λr t. I r t  (b. r = Result b  ((b, t), (a, s))  R)"
  shows "whileLoop C B a  s P"
  apply (rule runs_to_whileLoop[where I = I, OF R *])
  subgoal using P_Result by auto
  subgoal using P_Exn by (auto simp add: default_option_def Exn_def)
  subgoal using B by blast
  done

lemma runs_to_whileLoop_exn':
  assumes B: "a s. I (Result a) s  C a s 
    B a   s λr t. I r t  (b. r = Result b  ((b, t), (a, s))  R)"
  assumes P_Result: "a s.  I (Result a) s   ¬ C a s  P (Result a) s"
  assumes P_Exn: "a s. I (Exn a) s  P (Exn a) s"
  assumes R: "wf R"
  assumes *: "I (Result a) s"
  shows "whileLoop C B a  s P"
  by (rule runs_to_whileLoop_exn[where R=R and I=I and B=B and C=C and P=P,
        OF R * P_Result P_Exn B])

(* Partial-correctness whileLoop rules for res_monad (invariant on plain values,
   False on Exception outcomes) and for Exn-style exceptions; both are
   instances of runs_to_partial_whileLoop, so no well-founded relation is
   required. *)
lemma runs_to_partial_whileLoop_res:
  assumes *: "I a s"
  assumes P_Result: "a s. ¬ C a s  I a s  P (Result a) s"
  assumes B: "a s. C a s  I a s 
    (B a)  s ?⦃λr t. (b. r = Result b  I b t)"
  shows "(whileLoop C B a::('a, 's) res_monad)  s ?⦃P"
  apply (rule runs_to_partial_whileLoop[where I = "λException _  (λ_. False) | Result v  I v"])
  subgoal using * by auto
  subgoal using P_Result by auto
  subgoal by auto
  subgoal by (rule B[THEN runs_to_partial_weaken]) auto
  done

lemma runs_to_partial_whileLoop_exn:
  assumes *: "I (Result a) s"
  assumes P_Result: "a s. ¬ C a s  I (Result a) s  P (Result a) s"
  assumes P_Exn: "a s. I (Exn a) s  P (Exn a) s"
  assumes B: "a s. C a s  I (Result a) s 
    (B a)  s ?⦃λr t. I r t"
  shows "whileLoop C B a  s ?⦃P"
  apply (rule runs_to_partial_whileLoop[where I = I, OF *])
  subgoal using P_Result by auto
  subgoal using P_Exn by (auto simp add: default_option_def Exn_def)
  subgoal using B by blast
  done

(* Output-only notation (print mode "output") for pretty-printing whileLoop
   terms; "//" in the mixfix template is a line break. *)
notation (output)
  whileLoop  ((‹notation=‹prefix whileLoop››whileLoop (_)//  (_)) [1000, 1000] 1000)

(* whileLoop and whileLoop_finite are monotone in the loop body (via gfp_mono
   and lfp_mono respectively); the partial_function_mono rules lift this to
   monotone R (≤) and monotone R (≥). *)
lemma whileLoop_mono: "b  b'  whileLoop c b i  whileLoop c b' i"
  unfolding whileLoop_def
  by (simp add: gfp_mono le_funD le_funI mono_bind mono_condition)

lemma whileLoop_finite_mono: "b  b'  whileLoop_finite c b i  whileLoop_finite c b' i"
  unfolding whileLoop_finite_def
  by (simp add: le_funD le_funI lfp_mono mono_bind mono_condition)

lemma monotone_whileLoop_le[partial_function_mono]:
  "(x. monotone R (≤) (λf. b f x))  monotone R (≤) (λf. whileLoop c (b f) i)"
  using whileLoop_mono unfolding monotone_def
  by (metis le_fun_def)

lemma monotone_whileLoop_ge[partial_function_mono]:
  "(x. monotone R (≥) (λf. b f x))  monotone R (≥) (λf. whileLoop c (b f) i)"
  using whileLoop_mono unfolding monotone_def
  by (metis le_fun_def)

lemma monotone_whileLoop_finite_le[partial_function_mono]:
  "(x. monotone R (≤) (λf. b f x))  monotone R (≤) (λf. whileLoop_finite c (b f) i)"
  using whileLoop_finite_mono unfolding monotone_def
  by (metis le_fun_def)

lemma monotone_whileLoop_finite_ge[partial_function_mono]:
  "(x. monotone R (≥) (λf. b f x))  monotone R (≥) (λf. whileLoop_finite c (b f) i)"
  using whileLoop_finite_mono unfolding monotone_def
  by (metis le_fun_def)

(* Congruence rules for loops started from the same value i in the same state.
   The conditions must agree, and the bodies must have equal runs, wherever an
   invariant I holds that the body preserves under partial correctness (the
   body equation is only needed where the condition also holds).
   The _le version is proved via refines_whileLoop with the relation
   "equal and satisfying I".  whileLoop_cong is the invariant-free special case
   (I = λ_ _. True). *)
lemma run_whileLoop_le_invariant_cong:
  assumes I: "I (Result i) s"
  assumes invariant: "r s. C r s  I (Result r) s  B r  s ?⦃I"
  assumes C: "r s. I (Result r) s  C r s = C' r s"
  assumes B: "r s. C r s  I (Result r) s  run (B r) s = run (B' r) s"
  shows "run (whileLoop C B i) s  run (whileLoop C' B' i) s"
  unfolding le_run_refines_iff
  apply (rule refines_mono[of "λa b. a = b  I (fst a) (snd a)"])
  apply simp
  apply (rule refines_whileLoop)
  subgoal using C by auto
  subgoal for x s y t using invariant[of x s] B[of x t]
    by (cases "run (B' y) t")
       (auto simp add: refines.rep_eq runs_to_partial_alt runs_to.rep_eq split_beta')
  subgoal by simp
  subgoal using I by simp
  done

lemma run_whileLoop_invariant_cong:
  assumes I: "I (Result i) s"
  assumes invariant: "r s. C r s  I (Result r) s  B r  s ?⦃I"
  assumes C: "r s. I (Result r) s  C r s = C' r s"
  assumes B: "r s. C r s  I (Result r) s  run (B r) s = run (B' r) s"
  shows "run (whileLoop C B i) s = run (whileLoop C' B' i) s"
proof (rule antisym)
  from I invariant C B
  show "run (whileLoop C B i) s  run (whileLoop C' B' i) s"
    by (rule run_whileLoop_le_invariant_cong)
next
  from B invariant have invariant_sym: "r s. C r s  I (Result r) s  B' r  s ?⦃I"
    by (simp add: runs_to_partial.rep_eq)
  show "run (whileLoop C' B' i) s  run (whileLoop C B i) s"
    apply (rule run_whileLoop_le_invariant_cong
      [where I=I, OF I invariant_sym C[symmetric] B[symmetric]])
    apply (auto simp add: C)
    done
qed

lemma whileLoop_cong:
  assumes C: "r s. C r s = C' r s"
  assumes B: "r s. C r s  run (B r) s = run (B' r) s"
  shows "whileLoop C B = whileLoop C' B'"
proof (rule ext)
  fix i
  show "whileLoop C B i = whileLoop C' B' i"
  proof (rule spec_monad_ext)
    fix s

    show "run (whileLoop C B i) s = run (whileLoop C' B' i) s"
      by (rule run_whileLoop_invariant_cong[where I = "λ_ _. True"])
         (simp_all add: C B)
  qed
qed

(* Relational whileLoop rules outside the refinement context.  The side
   condition on R is phrased with rel_exception_or_result (λ_ _. True)
   (λ_ _. True); it is used to discharge the context assumption that R relates
   Result outcomes with Result outcomes.  The primed versions additionally
   record that both loop conditions are false on Result outcomes.  The
   rel_spec version is obtained from refines in both directions (the second
   one via the converse relation R¯¯). *)
lemma refines_whileLoop':
  assumes C: "a s b t. R (Result a, s) (Result b, t)  C a s  C' b t"
    and B:
      "a s b t. R (Result a, s) (Result b, t)  C a s  C' b t  refines (B a) (B' b) s t R"
    and I: "R (Result I, s) (Result I', s')"
    and R:
      "r s r' s'. R (r, s) (r', s')  rel_exception_or_result (λ_ _ . True) (λ_ _ . True) r r'"
  shows "refines (whileLoop C B I) (whileLoop C' B' I') s s'
    (λ(r, s) (r', s'). (v. r = Result v  (v'. r' = Result v'  ¬ C v s  ¬ C' v' s')) 
     R (r, s) (r', s'))"
  apply (rule refines_whileLoop_strong)
  subgoal using C by simp
  subgoal using B by simp
  subgoal for v s v' s' using R[of v s v' s'] by (auto elim!: rel_exception_or_result.cases)
  subgoal using I by simp
  done

lemma rel_spec_whileLoop':
  assumes C: "a s b t. R (Result a, s) (Result b, t)  C a s  C' b t"
    and B: "a s b t. R (Result a, s) (Result b, t)  C a s  C' b t  rel_spec (B a) (B' b) s t R"
    and I: "R (Result I, s) (Result I', s')"
    and R: "r s r' s'. R (r, s) (r', s')  rel_exception_or_result (λ_ _ . True) (λ_ _ . True) r r'"
  shows "rel_spec (whileLoop C B I) (whileLoop C' B' I') s s'
    (λ(r, s) (r', s'). (v. r = Result v  (v'. r' = Result v'  ¬ C v s  ¬ C' v' s')) 
     R (r, s) (r', s'))"
  using C B I R
  unfolding rel_spec_iff_refines
  apply (intro conjI)
  subgoal
    by (intro refines_whileLoop') auto
  subgoal
    apply (intro refines_mono[OF _ refines_whileLoop'[where R="R¯¯"]])
    subgoal by simp (smt (verit) Exception_eq_Result rel_exception_or_result.cases)
    subgoal by auto
    subgoal by auto
    subgoal by auto
    subgoal by simp (smt (verit, best) rel_exception_or_result.simps)
    done
  done

lemma rel_spec_whileLoop:
  assumes C: "a s b t. R (Result a, s) (Result b, t)  C a s  C' b t"
    and B: "a s b t. R (Result a, s) (Result b, t)  C a s  C' b t  rel_spec (B a) (B' b) s t R"
    and I: "R (Result I, s) (Result I', s')"
    and R: "r s r' s'. R (r, s) (r', s')  rel_exception_or_result (λ_ _ . True) (λ_ _ . True) r r'"
  shows "rel_spec (whileLoop C B I) (whileLoop C' B' I') s s' R"
  by (rule rel_spec_mono[OF _ rel_spec_whileLoop'[OF assms]]) auto

(* Running body once and then looping equals a single loop over flagged
   pairs (b, x). The flag starts as False, and each body step returns
   (True, y). The final value y is returned.
   NOTE(review): the loop condition combines b with P x s. It presumably
   reads b --> P x s, so the first step is taken unconditionally and later
   steps run only while P holds. Confirm.
   Proof: unroll one loop occurrence (subst (2) whileLoop_unroll). Then
   relate the remaining loops with rel_spec_whileLoop, using a result
   relation that matches x with (True, x). *)
lemma do_whileLoop_combine:
  "do { x1  body x0; whileLoop P body x1 } =
    do {
      (b, y)  whileLoop (λ(b, x) s. b  P x s) (λ(b, x). do { y  body x; return (True, y) })
        (False, x0);
      return y
    }"
  apply (subst (2) whileLoop_unroll)
  apply (simp add: bind_assoc)
  apply (rule arg_cong[where f="bind _", OF ext])
  apply (subst bind_return[symmetric])
  apply (rule rel_spec_eqD)
  apply (rule rel_spec_bind_bind[where
    Q="rel_prod (rel_exception_or_result (=) (λx (b, y). b = True  x = y)) (=)"])
  subgoal for x s
    apply (rule rel_spec_whileLoop)
    subgoal by auto
    subgoal 
      apply (simp split: prod.splits) [1]
      apply (subst bind_return[symmetric])
      apply clarify
      apply (rule rel_spec_bind_bind[where Q="(=)"])
          apply (simp_all add: rel_spec_refl)
      done
    by (auto simp: rel_exception_or_result.simps split: prod.splits)
     apply (simp_all split: prod.splits)
  done

(* Specialisation of rel_spec_whileLoop to res_monad. The side condition that
   R does not mix exceptions and results is discharged by auto for this
   monad type, so it is not an assumption here. *)
lemma rel_spec_whileLoop_res:
  assumes C: "a s b t. R (Result a, s) (Result b, t)  C a s  C' b t"
    and B: "a s b t. R (Result a, s) (Result b, t)  C a s  C' b t  rel_spec (B a) (B' b) s t R"
    and I: "R (Result I, s) (Result I', s')"
  shows "rel_spec ((whileLoop C B I)::('a, 's) res_monad) ((whileLoop C' B' I')::('b, 't) res_monad) s s' R"
  by (rule rel_spec_whileLoop[OF C B I]) auto

(* rel_spec_monad version of the while-loop rule. The loops are related if:
   - the initial values are R-related;
   - the conditions of R-related values are related by rel_fun S (=);
   - the bodies of R-related values are related by
     rel_spec_monad S (rel_exception_or_result E R).
   Proof: reduce to rel_spec (rel_spec_monad_iff_rel_spec) and apply
   rel_spec_whileLoop. *)
lemma rel_spec_monad_whileLoop:
  assumes init: "R I I'"
  assumes cond: "x y. R x y  (rel_fun S (=)) (C x) (C' y)"
  assumes body: "x y. R x y  rel_spec_monad S (rel_exception_or_result E R) (B x) (B' y)"
  shows "rel_spec_monad S (rel_exception_or_result E R) (whileLoop C B I) (whileLoop C' B' I')"
  apply (clarsimp simp add: rel_spec_monad_iff_rel_spec)
  apply (rule rel_spec_whileLoop)
  subgoal for s t x s' y t' using cond by (simp add: rel_fun_def)
  subgoal for s t x s' y t' using body[of x y] by (simp add: rel_spec_monad_iff_rel_spec)
  subgoal using init by simp
  subgoal by (auto elim!: rel_exception_or_result.cases)
  done

(* res_monad instance of rel_spec_monad_whileLoop. Outcomes are related
   through the Res pattern by R on the result values. The body premise is
   stated for each R-related pair of inputs. *)
lemma rel_spec_monad_whileLoop_res':
  assumes init: "R I I'"
  assumes cond: "x y. R x y  (rel_fun S (=)) (C x) (C' y)"
  assumes body: "x y. R x y  rel_spec_monad S (λRes x Res y. R x y) (B x) (B' y)"
  shows "rel_spec_monad S (λRes x Res y. R x y)
            ((whileLoop C B I)::('a, 's) res_monad)
            ((whileLoop C' B' I')::('b, 't) res_monad)"
  using assms
  by (auto intro: rel_spec_monad_whileLoop simp add: rel_exception_or_result_Res_val)

(* Same statement as rel_spec_monad_whileLoop_res', with the condition and
   body premises stated point-free via rel_fun. *)
lemma rel_spec_monad_whileLoop_res:
  assumes init: "R I I'"
  assumes cond: "rel_fun R (rel_fun S (=)) C C'"
  assumes body: "rel_fun R (rel_spec_monad S (λRes x Res y. R x y)) B B'"
  shows "rel_spec_monad S (λRes x Res y. R x y)
            ((whileLoop C B I)::('a, 's) res_monad)
            ((whileLoop C' B' I')::('b, 't) res_monad)"
  using assms
  by (auto intro: rel_spec_monad_whileLoop_res' simp add: rel_fun_def)

(* Invariant rule for whileLoop_finite with exceptions. The invariant I is
   stated on outcomes (Result r or Exn e) and states:
   - B:  while C holds, one body step from a Result satisfying I yields
         outcomes satisfying I;
   - Qr: a Result satisfying I on which C fails establishes Q;
   - Ql: an exception satisfying I establishes Q;
   - I:  the initial value satisfies I.
   Proof: instance of runs_to_whileLoop_finite. *)
lemma runs_to_whileLoop_finite_exn:
  assumes B: "r s. I (Result r) s  C r s  (B r)  s  I "
  assumes Qr: "r s. I (Result r) s  ¬ C r s  Q (Result r) s"
  assumes Ql: "e s. I (Exn e) s  Q (Exn e) s"
  assumes I: "I (Result r) s"
  shows "(whileLoop_finite C B r)  s  Q "
  using assms by (intro runs_to_whileLoop_finite[of I]) (auto simp: Exn_def default_option_def)

(* res_monad variant of runs_to_whileLoop_finite_exn. The invariant I and
   the postcondition Q are stated on result values (through the Res
   pattern), so there is no exception case. *)
lemma runs_to_whileLoop_finite_res:
  assumes B: "r s. I r s  C r s  (B r)  s λRes r. I r "
  assumes Q: "r s. I r s  ¬ C r s  Q r s"
  assumes I: "I r s"
  shows "((whileLoop_finite C B r)::('a, 's) res_monad)  s λRes r. Q r "
  using assms by (intro runs_to_whileLoop_finite[of "λRes r. I r"]) simp_all

(* Given the premise on run (whileLoop C B r) s, runs_to for whileLoop and
   for whileLoop_finite coincide. The premise presumably excludes the top
   outcome; it is used through whileLoop_finite_eq_whileLoop. Confirm. *)
lemma runs_to_whileLoop_eq_whileLoop_finite:
  "run (whileLoop C B r) s   
    (whileLoop C B r)  s  Q   (whileLoop_finite C B r)  s  Q "
  by (simp add: whileLoop_finite_eq_whileLoop runs_to.rep_eq)

(* If the loop condition fails initially, whileLoop_finite behaves like
   return r (stated on run). Proved by unfolding the loop once. *)
lemma whileLoop_finite_cond_fail:
    "¬ C r s  (run (whileLoop_finite C B r) s) = (run (return r) s)"
  apply (subst whileLoop_finite_unfold)
  apply simp
  done

(* runs_to counterpart of whileLoop_finite_cond_fail. When C r s is false,
   a postcondition of the loop is a postcondition of return r. *)
lemma runs_to_whileLoop_finite_cond_fail:
  "¬ C r s  (whileLoop_finite C B r)  s  Q   (return r)  s  Q "
  apply (subst whileLoop_finite_unfold)
  apply (simp add: runs_to.rep_eq)
  done

(* Same as runs_to_whileLoop_finite_cond_fail, for whileLoop: when C r s is
   false, runs_to of the loop reduces to runs_to of return r. *)
lemma runs_to_whileLoop_cond_fail:
  "¬ C r s  (whileLoop C B r)  s  Q   (return r)  s  Q "
  apply (subst whileLoop_unroll)
  apply (simp add: runs_to.rep_eq)
  done

(* Alternative one-step unfolding of whileLoop_finite. It first runs the body
   if the condition holds (otherwise return r), then continues the loop on
   the outcome. Proof: unfold once, then case split on C r s. *)
lemma whileLoop_finite_unfold':
  "(whileLoop_finite C B r) =
     ((condition (C r) (B r) (return r)) >>= (whileLoop_finite C B))"
  apply (rule spec_monad_ext)
  apply (subst (1) whileLoop_finite_unfold)
  subgoal for s
    apply (cases "C r s")
    subgoal by (simp add: run_bind)
    subgoal by (simp add: whileLoop_finite_cond_fail run_bind)
    done
  done

(* One-step unrolling rule for a runs_to postcondition of whileLoop, with
   results matched through the Res pattern.
   - If C fails, P must hold directly.
   - Otherwise one body step must reach states from which the loop again
     satisfies P.
   The second premise is added as a local runs_to_vcg rule, so the VCG call
   in the proof uses it. *)
lemma runs_to_whileLoop_unroll:
  assumes "¬ C r s  P r s"
  assumes [runs_to_vcg]: "C r s  (B r)  s  λRes r t. ((whileLoop C B r)  t λRes r. P r ) "
  shows "(whileLoop C B r)  s λRes r. P r "
  using assms(1)
  by (subst whileLoop_unroll) runs_to_vcg

(* Partial-correctness (runs_to_partial) counterpart of
   runs_to_whileLoop_unroll. *)
lemma runs_to_partial_whileLoop_unroll:
  assumes "¬ C r s  P r s"
  assumes "C r s  (B r)  s ?⦃ λRes r t. ((whileLoop C B r)  t ?⦃λRes r. P r ) "
  shows "(whileLoop C B r)  s ?⦃λRes r. P r "
  using assms
  by (subst whileLoop_unroll) runs_to_vcg


(* One-step unrolling rule for whileLoop with exceptions.
   - If C fails, P holds for the initial Result.
   - Otherwise, after one body step, each Result a must lead to a loop
     satisfying P, and each exception Exn e must satisfy P directly. *)
lemma runs_to_whileLoop_unroll_exn:
  assumes "¬ C r s  P (Result r) s"
  assumes [runs_to_vcg]: "C r s  (B r)  s  λr t.
    (a. r = Result a  ((whileLoop C B a)  t  P )) 
    (e. r = Exn e  P (Exn e) t) "
  shows "(whileLoop C B r)  s λr s'. P r s' "
  using assms(1)
  by (subst whileLoop_unroll) (runs_to_vcg, auto simp: Exn_def default_option_def)


(* Partial-correctness (runs_to_partial) counterpart of
   runs_to_whileLoop_unroll_exn. *)
lemma runs_to_partial_whileLoop_unroll_exn:
  assumes "¬ C r s  P (Result r) s"
  assumes "C r s  (B r)  s ?⦃ λr t.
    (a. r = Result a  ((whileLoop C B a)  t ?⦃ P )) 
    (e. r = Exn e  P (Exn e) t) "
  shows "(whileLoop C B r)  s ?⦃λr s'. P r s' "
  using assms
  by (subst whileLoop_unroll) runs_to_vcg

(* Variant of refines_whileLoop' where the no-mixing condition on R is given
   as two assumptions. R never relates an exception with a result (R1), nor
   a result with an exception (R2). The conclusion uses plain R.
   Proof: obtain the strengthened refinement from refines_whileLoop', then
   weaken it with refines_mono. *)
lemma refines_whileLoop_exn:
  assumes C: "a s b t. R (Result a, s) (Result b, t)  C a s  C' b t"
    and B: "a s b t. R (Result a, s) (Result b, t)  C a s  C' b t  refines (B a) (B' b) s t R"
    and I: "R (Result I, s) (Result I', s')"
    and R1: "a s b t. R (Exn a, s) (Result b, t)  False"
    and R2: "a s b t. R (Result a, s) (Exn b, t)  False"
  shows "refines (whileLoop C B I) (whileLoop C' B' I') s s' R"
proof -
  have ref: "refines (whileLoop C B I) (whileLoop C' B' I') s s'
         (λ(r, s) (r', s'). (v. r = Result v  (v'. r' = Result v'  ¬ C v s  ¬ C' v' s'))  
          R (r, s) (r', s'))"

    apply (rule refines_whileLoop' [OF C])
    subgoal by assumption
    subgoal by (rule B)
    subgoal by (rule I)
    (* the no-mixing premise of refines_whileLoop' follows from R1 and R2 *)
    subgoal using R1 R2
      by (auto simp add: rel_exception_or_result.simps Exn_def default_option_def)
        (metis default_option_def exception_or_result_cases not_None_eq)+
    done
  show ?thesis
    apply (rule refines_mono [OF _ ref])
    apply (auto)
    done
qed

(* refines_whileLoop' with the plain relation R in the conclusion. The extra
   information about the final loop conditions is dropped via
   refines_weaken. *)
lemma refines_whileLoop'':
  assumes C: "a s b t. R (Result a, s) (Result b, t)  C a s  C' b t"
    and B: "a s b t. R (Result a, s) (Result b, t)  C a s  C' b t  refines (B a) (B' b) s t R"
    and I: "R (Result I, s) (Result I', s')"
    and R: "r s r' s'. R (r, s) (r', s')  rel_exception_or_result (λ_ _ . True) (λ_ _ . True) r r'"
  shows "refines (whileLoop C B I) (whileLoop C' B' I') s s' R"
proof -
  have "refines (whileLoop C B I) (whileLoop C' B' I') s s'
    (λ(r, s) (r', s'). (v. r = Result v  (v'. r' = Result v'  ¬ C v s  ¬ C' v' s'))  
     R (r, s) (r', s'))"
    by (rule refines_whileLoop' [OF C B I R])
  then show ?thesis
    apply (rule refines_weaken)
    apply auto
    done
qed

(* rel_xval version of rel_spec_monad_whileLoop. It is obtained by rewriting
   rel_xval into rel_exception_or_result
   (rel_xval_rel_exception_or_result_conv), both in the goal and in the body
   premise. *)
lemma rel_spec_monad_whileLoop_exn:
  assumes init: "R I I'"
  assumes cond: "x y. R x y  (rel_fun S (=)) (C x) (C' y)"
  assumes body: "x y. R x y  rel_spec_monad S (rel_xval E R) (B x) (B' y)"
  shows "rel_spec_monad S (rel_xval E R) (whileLoop C B I) (whileLoop C' B' I')"
  unfolding rel_xval_rel_exception_or_result_conv
  by (rule rel_spec_monad_whileLoop [OF init cond body [simplified rel_xval_rel_exception_or_result_conv]])

(* Refinement where the concrete side checks guard G before its loop and
   guard (G' res) after each body step.
   Assumptions:
   1. on R-related results whose concrete value satisfies G', the loop
      conditions C and C' agree;
   2. the bodies refine each other when both conditions and G' hold;
   3. R relates a result exactly with a result;
   4. the initial values are R-related;
   5. G s' coincides with G' x' s'.
   Proof: after refines_bind_guard_right, the loop relation is strengthened
   with the invariant that concrete results satisfy G'. *)
lemma refines_whileLoop_guard_right: 
  assumes "x s x' s'. R (Result x, s) (Result x', s')  G' x' s'  C x s = C' x' s'"
  assumes "x s x' s'. R (Result x, s) (Result x', s')  C x s  C' x' s'  G' x' s'  refines (B x) (B' x') s s' R" 
  assumes "v s v' s'. R (v, s) (v', s')  (x. v = Result x) = (x'. v' = Result x')"
  assumes "R (Result x, s) (Result x', s')"
  assumes "G s' = G' x' s'"
  shows "refines (whileLoop C B x) (guard G  (λ_. whileLoop C' (λr. do {res <- B' r; guard (G' res); return res}) x')) s s' R"
  apply (rule refines_bind_guard_right)
  apply (rule refines_mono [where R="λ(r, s) (q, t). R (r, s) (q, t)  
                (case q of Exception e  True | Result v  G' v t)"])
  subgoal by simp
  subgoal
    apply (rule refines_whileLoop)
    subgoal using assms(1) by auto
    subgoal apply (rule  refines_bind_guard_right_end')
      using assms(2)
      by auto
    subgoal using assms(3) by (auto split: exception_or_result_splits)
    subgoal using assms(4,5) by auto
    done
  done

subsection constmap_value

(* map_value commutes with lift_state: mapping results before or after
   lifting the state gives the same computation. *)
lemma map_value_lift_state: "map_value f (lift_state R g) = lift_state R (map_value f g)"
  apply transfer
  apply (simp add: map_post_state_eq_lift_post_state lift_post_state_comp
      lift_post_state_Sup image_image OO_def prod_eq_iff rel_prod.simps)
  apply (simp add: ac_simps)
  done

(* Execution of map_value f g: every outcome (v, s) of g becomes (f v, s),
   and the state component is unchanged. *)
lemma run_map_value[run_spec_monad]: "run (map_value f g) s =
  map_post_state (λ(v, s). (f v, s)) (run g s)"
  by transfer (auto simp add: apfst_def map_prod_def map_post_state_def
      split: post_state.splits prod.splits)

(* map_value preserves always_progress. *)
lemma always_progress_map_value[always_progress_intros]:
  "always_progress g  always_progress (map_value f g)"
  by (simp add: always_progress_def run_map_value)

(* VCG rules for map_value. A postcondition Q of map_value f g is the
   postcondition of g obtained by applying Q to f r. The _iff forms are
   runs_to_iff rules and the implications are runs_to_vcg rules. Each comes
   in a runs_to and a runs_to_partial version. *)
lemma runs_to_map_value_iff[runs_to_iff]: "map_value f g  s  Q   g  s λr t. Q (f r) t"
  by (auto simp add: runs_to.rep_eq run_map_value split_beta')

lemma runs_to_partial_map_value_iff[runs_to_iff]:
  "map_value f g  s ?⦃ Q   g  s ?⦃λr t. Q (f r) t"
  by (auto simp add: runs_to_partial.rep_eq run_map_value split_beta')

lemma runs_to_map_value[runs_to_vcg]: "g  s λr t. Q (f r) t  map_value f g  s  Q "
  by (simp add: runs_to_map_value_iff)

lemma runs_to_partial_map_value[runs_to_vcg]:
  "g  s ?⦃λr t. Q (f r) t  map_value f g  s ?⦃ Q "
  by (auto simp add: runs_to_partial.rep_eq run_map_value split_beta')

(* map_value is monotone in its computation argument. The derived monotone
   rules, for both orders, are registered as partial_function_mono rules. *)
lemma mono_map_value: "g  g'  map_value f g  map_value f g'"
  by transfer (simp add: le_fun_def mono_map_post_state)

lemma monotone_map_value_le[partial_function_mono]:
  "monotone R (≤) (λf'. g f')  monotone R (≤) (λf'. map_value f (g f'))"
  by (auto simp: monotone_def mono_map_value)

lemma monotone_map_value_ge[partial_function_mono]:
  "monotone R (≥) (λf'. g f')  monotone R (≥) (λf'. map_value f (g f'))"
  by (auto simp: monotone_def mono_map_value)

(* Simplification rules: map_value f leaves fail unchanged, and mapping
   exceptions with map_exn has no effect on gets x. *)
lemma map_value_fail[simp]: "map_value f fail = fail"
  by transfer simp

lemma map_value_map_exn_gets[simp]: "map_value (map_exn emb) (gets x) = gets x"
  by (simp add: spec_monad_ext_iff run_map_value)

(* map_value on either side of refines can be moved into the result
   relation. The mapping function is then applied to the corresponding
   component of R. *)
lemma refines_map_value_right_iff:
  "refines f (map_value m g) s t R  refines f g s t (λ(x, s) (y, t). R (x, s) (m y, t))"
  by transfer (simp add: sim_post_state_map_post_state2 split_beta' apfst_def map_prod_def)

lemma refines_map_value_left_iff:
  "refines (map_value m f) g s t R  refines f g s t (λ(x, s) (y, t). R (m x, s) (y, t))"
  by transfer (simp add: sim_post_state_map_post_state1 split_beta' apfst_def map_prod_def)

(* rel_spec counterparts of refines_map_value_right_iff and
   refines_map_value_left_iff, proved through rel_spec_iff_refines. *)
lemma rel_spec_map_value_right_iff:
  "rel_spec f (map_value m g) s t R  rel_spec f g s t (λ(x, s) (y, t). R (x, s) (m y, t))"
  by (simp add: rel_spec_iff_refines refines_map_value_right_iff refines_map_value_left_iff
                conversep.simps[abs_def] split_beta')

lemma rel_spec_map_value_left_iff:
  "rel_spec (map_value m f) g s t R  rel_spec f g s t (λ(x, s) (y, t). R (m x, s) (y, t))"
  by (simp add: rel_spec_iff_refines refines_map_value_right_iff refines_map_value_left_iff
                conversep.simps[abs_def] split_beta')

(* f is refined by map_value m f from the same state. Each result x is
   related to m x, and the states are equal. *)
lemma refines_map_value_right:
  "refines f (map_value m f) s s (λ(x, s) (y, t). y = m x  s = t)"
  by (simp add: refines_map_value_right_iff refines_refl)

(* Congruence rule for map_value on both sides. If f refines f' with Q, and
   Q-related outcomes become R-related after applying g and g' to their
   values, then map_value g f refines map_value g' f' with R. *)
lemma refines_map_value: 
  assumes "refines f f' s t Q"
  assumes "r s' w t'. Q (r, s') (w, t')  R (g r, s') (g' w, t')"
  shows "refines (map_value g f) (map_value g' f') s t R"
  apply (simp add: refines_map_value_left_iff refines_map_value_right_iff)
  apply (rule refines_weaken [OF assms(1)])
  using assms(2) by auto

(* map_value with the identity function is the identity on computations. *)
lemma map_value_id[simp]: "map_value (λx. x) = (λx. x)"
  apply (simp add: fun_eq_iff, intro allI iffI)
  apply (rule spec_monad_eqI)
  apply (rule runs_to_iff)
  done

subsection constliftE

(* Execution of liftE f. Each outcome of f keeps its state. Result values
   are unchanged (id), and exception payloads are replaced by the
   constant-undefined function via map_exception_or_result. *)
lemma run_liftE[run_spec_monad]:
  "run (liftE f) s =
    map_post_state (λ(v, s). (map_exception_or_result (λx. undefined) id v, s)) (run f s)"
  by (simp add: run_map_value liftE_def)

(* liftE preserves always_progress. *)
lemma always_progress_liftE[always_progress_intros]:
  "always_progress f  always_progress (liftE f)"
  by (simp add: liftE_def always_progress_intros)

(* VCG rules for liftE.
   - runs_to_liftE_iff mirrors run_liftE.
   - The _Res form (a runs_to_iff rule) states the obligation on f through
     the Res pattern, as Q (Result r).
   - runs_to_liftE' and runs_to_liftE are the introduction directions of
     the raw and Res forms; the latter is a runs_to_vcg rule. *)
lemma runs_to_liftE_iff:
  "liftE f  s  Q   f  s λr t. Q (map_exception_or_result (λx. undefined) id r) t"
  by (simp add: liftE_def runs_to_map_value_iff)

lemma runs_to_liftE_iff_Res[runs_to_iff]:
  "liftE f  s  Q   f  s  λRes r. Q (Result r)"
  by (auto intro!: runs_to_cong_pred_only simp: runs_to_liftE_iff)

lemma runs_to_liftE':
  "f  s λr t. Q (map_exception_or_result (λx. undefined) id r) t  liftE f  s  Q "
  by (simp add: runs_to_liftE_iff)

lemma runs_to_liftE[runs_to_vcg]: "f  s  λRes r. Q (Result r)   liftE f  s  Q "
  by (simp add: runs_to_liftE_iff_Res)

(* refines with liftE on one side reduces to refines on the unlifted
   computation. On that side the relation is restricted to Result outcomes
   Result v, which must be R-related to the other side. *)
lemma refines_liftE_left_iff:
  "refines (liftE f) g s t R 
    refines f g s t (λ(x, s') (y, t'). v. x = Result v  R (Result v, s') (y, t'))"
  by (auto simp add: liftE_def refines_map_value_left_iff intro!: refines_cong_cases del: iffI)

lemma refines_liftE_right_iff:
  "refines f (liftE g) s t R 
    refines f g s t (λ(x, s') (y, t'). v. y = Result v  R (x, s') (Result v, t'))"
  by (auto simp add: liftE_def refines_map_value_right_iff intro!: refines_cong_cases del: iffI)

(* rel_spec counterpart of refines_liftE_left_iff, proved from both
   refines_liftE lemmas via rel_spec_iff_refines. *)
lemma rel_spec_liftE:
  "rel_spec (liftE f) g s t R 
    rel_spec f g s t (λ(x, s') (y, t'). v. x = Result v  R (Result v, s') (y, t'))"
  by (simp add: rel_spec_iff_refines refines_liftE_right_iff refines_liftE_left_iff
                conversep.simps[abs_def] split_beta')

(* Congruence of bind with liftE'd first computations under rel_spec_monad.
   The binds are related if:
   - m and n are related with results related by P;
   - f and g map P-related values to rel_spec_monad R Q related
     computations. *)
lemma rel_spec_monad_bind_liftE:
  "rel_spec_monad R (λRes v1 Res v2. P v1 v2) m n  rel_fun P (rel_spec_monad R Q) f g 
    rel_spec_monad R Q ((liftE m) >>= f) ((liftE n) >>= g)"
  apply (auto intro!: rel_bind_post_state' rel_post_state_refl
    simp: rel_spec_monad_def run_bind run_liftE bind_post_state_map_post_state rel_fun_def)[1]
  subgoal premises prems for s t
    by (rule rel_post_state_weaken[OF prems(1)[rule_format, OF prems(3)]])
       (auto simp add: rel_prod.simps prems(2))
  done

(* Special case of rel_spec_monad_bind_liftE with the same lifted first
   computation m and equality as the state relation. Pointwise related
   continuations give related binds. *)
lemma rel_spec_monad_bind_liftE':
  "(x. rel_spec_monad (=) Q (f x) (g x))  rel_spec_monad (=) Q (liftE m >>= f) (liftE m >>= g)"
  by (auto simp add: rel_spec_monad_def run_bind run_liftE bind_post_state_map_post_state
           intro!: rel_bind_post_state' rel_post_state_refl)

(* Partial-correctness VCG rules for liftE.
   - runs_to_partial_liftE' is the raw form mirroring run_liftE.
   - runs_to_partial_liftE is the Res-pattern form, a runs_to_vcg rule. It
     is proved by the VCG with the raw form temporarily supplied as a VCG
     rule. *)
lemma runs_to_partial_liftE':
  "f  s ?⦃λr t. Q (map_exception_or_result (λx. undefined) id r) t  liftE f  s ?⦃ Q "
  by (simp add: runs_to_partial.rep_eq run_liftE split_beta')

lemma runs_to_partial_liftE[runs_to_vcg]:
  assumes [runs_to_vcg]: "f  s ?⦃ λRes r. Q (Result r)" shows "liftE f  s ?⦃ Q "
  supply runs_to_partial_liftE'[runs_to_vcg] by runs_to_vcg

(* liftE is monotone. This and the derived partial_function_mono rules are
   inherited from the map_value lemmas via liftE_def.
   NOTE(review): the name mono_lifE looks like a typo for mono_liftE. It is
   kept unchanged because other theories may refer to it. *)
lemma mono_lifE: "f  f'  liftE f   liftE f'"
  unfolding liftE_def by (rule mono_map_value)

lemma monotone_liftE_le[partial_function_mono]:
  "monotone R (≤) (λf'. f f')  monotone R (≤) (λf'. liftE (f f'))"
  unfolding liftE_def by (rule monotone_map_value_le)

lemma monotone_liftE_ge[partial_function_mono]:
  "monotone R (≥) (λf'. f f')  monotone R (≥) (λf'. liftE (f f'))"
  unfolding liftE_def by (rule monotone_map_value_ge)

(* bind_handle on a liftE'd computation never uses the handler h. It is the
   plain bind f g. *)
lemma bind_handle_liftE: "bind_handle (liftE f) g h = bind f g"
  by (simp add: spec_monad_eq_iff runs_to_iff)

(* Simplification rules for liftE on basic computations, each proved by
   extensionality and run_liftE.
   - Unchanged: the top element, bot, fail, return, get_state, set_state,
     select and unknown.
   - yield (Exception x) becomes skip.
   - throw_exception_or_result x becomes return undefined. *)
lemma liftE_top[simp]: "liftE  = "
  by (rule spec_monad_ext) (simp add: run_liftE)

lemma liftE_bot[simp]: "liftE bot = bot"
  by (rule spec_monad_ext) (simp add: run_liftE)

lemma liftE_fail[simp]: "liftE fail = fail"
  by (rule spec_monad_ext) (simp add: run_liftE)

lemma liftE_return[simp]: "liftE (return x) = return x"
  by (rule spec_monad_ext) (simp add: run_liftE)

lemma liftE_yield_Exception[simp]:
  "liftE (yield (Exception x)) = skip"
  by (rule spec_monad_ext) (simp add: run_liftE)

lemma liftE_throw_exception_or_result[simp]:
  "liftE (throw_exception_or_result x) = return undefined"
  by (rule spec_monad_ext) (simp add: run_liftE map_exception_or_result_def)

lemma liftE_get_state[simp]: "liftE (get_state) = get_state"
  by (rule spec_monad_ext) (simp add: run_liftE)

lemma liftE_set_state[simp]: "liftE (set_state s) = set_state s"
  by (rule spec_monad_ext) (simp add: run_liftE)

lemma liftE_select[simp]: "liftE (select S) = select S"
  by (rule spec_monad_ext) (simp add: run_liftE)

lemma liftE_unknown[simp]: "liftE (unknown) = unknown"
  by (rule spec_monad_ext) (simp add: run_liftE)

(* liftE commutes with lift_state, and therefore with exec_concrete and
   exec_abstract, which are both unfolded to lift_state. *)
lemma liftE_lift_state: "liftE (lift_state R f) = lift_state R (liftE f)"
  unfolding liftE_def by (rule map_value_lift_state)

lemma liftE_exec_concrete: "liftE (exec_concrete st f) = exec_concrete st (liftE f)"
  unfolding exec_concrete_def by (rule liftE_lift_state)

lemma liftE_exec_abstract: "liftE (exec_abstract st f) = exec_abstract st (liftE f)"
  unfolding exec_abstract_def by (rule liftE_lift_state)

(* Simplification rules: liftE leaves assert, assume, gets, guard,
   assert_opt, gets_the and modify unchanged. *)
lemma liftE_assert[simp]: "liftE (assert P) = assert P"
  by (rule spec_monad_ext) (simp add: run_liftE run_assert)

lemma liftE_assume[simp]: "liftE (assume P) = assume P"
  by (rule spec_monad_ext) (simp add: run_liftE run_assume)

lemma liftE_gets[simp]: "liftE (gets f) = gets f"
  by (rule spec_monad_ext) (simp add: run_liftE)

lemma liftE_guard[simp]: "liftE (guard P) = guard P"
  by (rule spec_monad_ext) (simp add: run_liftE run_guard)

lemma liftE_assert_opt[simp]: "liftE (assert_opt v) = assert_opt v"
  by (rule spec_monad_ext) (auto simp add: run_liftE split: option.splits)

lemma liftE_gets_the[simp]: "liftE (gets_the f) = gets_the f"
  by (rule spec_monad_ext) (auto simp add: run_liftE split: option.splits)

lemma liftE_modify[simp]: "liftE (modify f) = modify f"
  by (rule spec_monad_ext) (simp add: run_liftE)

(* Binding lifted computations equals lifting the bind. Read right to left,
   liftE distributes over bind. *)
lemma liftE_bind: "liftE x >>= (λa. liftE (y a)) = liftE (x >>= y)"
  by (simp add: spec_monad_eq_iff runs_to_iff)

(* liftE of (f followed by skip), then g, equals liftE f followed by g
   applied to the unit value (). *)
lemma bindE_liftE_skip: "liftE (f  (λy. skip))  g = liftE f  (λ_. g ())"
  by (auto simp add: spec_monad_eq_iff runs_to_iff intro!: arg_cong[where f="runs_to f _"])

(* liftE leaves state_select and assume_result_and_state unchanged. *)
lemma liftE_state_select[simp]: "liftE (state_select f) = state_select f"
  apply (rule spec_monad_eqI)
  apply (auto simp add: runs_to_iff)
  done

lemma liftE_assume_result_and_state[simp]:
  "liftE (assume_result_and_state f) = assume_result_and_state f"
  apply (rule spec_monad_eqI)
  apply (auto simp add: runs_to_iff)
  done

(* Mapping exceptions with map_exn does not change a liftE'd computation,
   and liftE distributes over condition. *)
lemma map_value_map_exn_liftE [simp]:
  "map_value (map_exn emb) (liftE f) = liftE f"
  by (simp add: spec_monad_eq_iff runs_to_iff)

lemma liftE_condition: "liftE (condition c f g) = condition c (liftE f) (liftE g)"
  by (simp add: spec_monad_eq_iff runs_to_iff)

(* liftE commutes with whileLoop: lifting the whole loop equals looping over
   the lifted body.
   Proof: turn the equation into rel_spec (rel_spec_eqD), move liftE into
   the relation (rel_spec_liftE), then apply rel_spec_whileLoop. *)
lemma liftE_whileLoop: "liftE (whileLoop C B I) = whileLoop C (λr. liftE (B r)) I"
  apply (rule rel_spec_eqD)
  apply (subst rel_spec_liftE)
  apply (rule rel_spec_whileLoop)
  subgoal by simp
  apply (subst rel_spec_symm)
  apply (subst rel_spec_liftE)
  apply (simp add: rel_spec_refl)
  apply auto
  done

subsection consttry

(* Basic facts for try, all following from try_def and the map_value
   lemmas:
   - execution maps every outcome through unnest_exn;
   - always_progress is preserved;
   - the runs_to and runs_to_partial VCG rules apply the postcondition to
     unnest_exn r. *)
lemma run_try[run_spec_monad]:
  "run (try f) s = map_post_state (λ(v, s). (unnest_exn v, s)) (run f s)"
  by (simp add: try_def run_map_value)

lemma always_progress_try[always_progress_intros]: "always_progress f  always_progress (try f)"
  by (simp add: try_def always_progress_intros)

lemma runs_to_try[runs_to_vcg]: "f  s λr t. Q (unnest_exn r) t  try f  s  Q "
  by (simp add: try_def runs_to_map_value)

lemma runs_to_try_iff[runs_to_iff]: "try f  s  Q   f  s λr t. Q (unnest_exn r) t"
  by (auto simp add: try_def runs_to_iff)

lemma runs_to_partial_try[runs_to_vcg]: "f  s ?⦃λr t. Q (unnest_exn r) t  try f  s ?⦃ Q "
  by (simp add: try_def runs_to_partial_map_value)

(* try is monotone. This and the derived partial_function_mono rules are
   inherited from map_value via try_def. *)
lemma mono_try: "f  f'  try f  try f'"
  unfolding try_def
  by (rule mono_map_value)

lemma monotone_try_le[partial_function_mono]:
  "monotone R (≤) (λf'. f f')  monotone R (≤) (λf'. try (f f'))"
  unfolding try_def by (rule monotone_map_value_le)

lemma monotone_try_ge[partial_function_mono]:
  "monotone R (≥) (λf'. f f')  monotone R (≥) (λf'. try (f f'))"
  unfolding try_def by (rule monotone_map_value_ge)

(* f is refined by try f from the same state. Each outcome x is related to
   unnest_exn x, and the states are equal. *)
lemma refines_try_right:
  "refines f (try f) s s (λ(x, s) (y, t). y = unnest_exn x  s = t)"
  by (auto simp add: try_def refines_map_value_right)

subsection constfinally

(* Execution of finally f maps every outcome through unite. finally
   preserves always_progress. *)
lemma run_finally[run_spec_monad]:
  "run (finally f) s = map_post_state (λ(v, s). (unite v, s)) (run f s)"
  by (simp add: finally_def run_map_value)

lemma always_progress_finally[always_progress_intros]:
  "always_progress f  always_progress (finally f)"
  by (simp add: finally_def always_progress_intros)

(* runs_to rules for finally. Both a Result v and an exception Exn v of f
   must satisfy Q as Result v: finally turns an exception into an ordinary
   result carrying the same value. runs_to_finally' is the raw form stated
   with unite. *)
lemma runs_to_finally_iff[runs_to_iff]:
  "finally f  s  Q  
    f  s λr t. (v. r = Result v  Q (Result v) t)  (v. r = Exn v  Q (Result v) t)"
  by (auto simp: finally_def runs_to_map_value_iff unite_def
           intro!: arg_cong[where f="runs_to _ _"] split: xval_splits)

lemma runs_to_finally': "f  s λr t. Q (unite r) t  finally f  s  Q "
  by (simp add: finally_def runs_to_map_value)

lemma runs_to_finally[runs_to_vcg]:
  "f  s λr t. (v. r = Result v  Q (Result v) t)  (v. r = Exn v  Q (Result v) t) 
  finally f  s  Q "
  by (simp add: runs_to_finally_iff)

(* Partial-correctness counterparts of the runs_to rules for finally.
   NOTE(review): unlike runs_to_finally_iff, runs_to_partial_finally_iff has
   no runs_to_iff attribute. Confirm this is intended. *)
lemma runs_to_partial_finally': "f  s ?⦃λr t. Q (unite r) t  finally f  s ?⦃ Q "
  by (simp add: finally_def runs_to_partial_map_value)

lemma runs_to_partial_finally_iff:
  "finally f  s ?⦃ Q  
    f  s ?⦃λr t. (v. r = Result v  Q (Result v) t)  (v. r = Exn v  Q (Result v) t)"
  by (auto simp: finally_def runs_to_partial_map_value_iff unite_def
           intro!: arg_cong[where f="runs_to_partial _ _"] split: xval_splits)

lemma runs_to_partial_finally[runs_to_vcg]:
  "f  s ?⦃λr t. (v. r = Result v  Q (Result v) t)  (v. r = Exn v  Q (Result v) t) 
  finally f  s ?⦃ Q "
  by (simp add: runs_to_partial_finally_iff)

(* finally is monotone. This and the derived partial_function_mono rules are
   inherited from map_value via finally_def. *)
lemma mono_finally: "f  f'  finally f  finally f'"
  unfolding finally_def
  by (rule mono_map_value)

lemma monotone_finally_le[partial_function_mono]:
  "monotone R (≤) (λf'. f f')  monotone R (≤) (λf'. finally (f f'))"
  unfolding finally_def by (rule monotone_map_value_le)

lemma monotone_finally_ge[partial_function_mono]:
  "monotone R (≥) (λf'. f f')  monotone R (≥) (λf'. finally (f f'))"
  unfolding finally_def by (rule monotone_map_value_ge)

subsection constcatch

(* Execution of catch. An outcome Exn e of f continues with h e in the
   resulting state; an outcome Result v is returned unchanged. *)
lemma run_catch[run_spec_monad]:
  "run (f <catch> h) s =
    bind_post_state (run f s)
      (λ(r, t). case r of Exn e  run (h e) t | Result v  run (return v) t)"
  apply (simp add: catch_def run_bind_handle)
  apply (rule arg_cong[where f="bind_post_state _", OF ext])
  apply (auto simp add: Exn_def default_option_def split: exception_or_result_splits xval_splits)
  done

(* catch preserves always_progress if f and every handler h e do. *)
lemma always_progress_catch[always_progress_intros]:
  "always_progress f  (e. always_progress (h e))  always_progress (f <catch> h)"
  unfolding catch_def
  by (auto intro: always_progress_intros)

(* runs_to rules for catch. Each Result v of f must satisfy Q directly, and
   each Exn e must lead to a handler run h e satisfying Q in the state
   reached. Versions: an _iff rule, a runs_to_vcg rule, and a partial
   runs_to_vcg rule. *)
lemma runs_to_catch_iff[runs_to_iff]:
  "(f <catch> h)  s  Q  
    f  s  λr t. (v. r = Result v  Q (Result v) t)  (e. r = Exn e  h e  t  Q )"
  by (simp add: runs_to.rep_eq run_catch split_beta' ac_simps split: xval_splits prod.splits)

lemma runs_to_catch[runs_to_vcg]:
  "f  s  λr t. (v. r = Result v   Q (Result v) t) 
                 (e. r = Exn e  h e  t  Q ) 
      (f <catch> h)  s  Q "
  by (simp add: runs_to_catch_iff)

lemma runs_to_partial_catch[runs_to_vcg]:
  "f  s ?⦃ λr t. (v. r = Result v   Q (Result v) t) 
                 (e. r = Exn e  h e  t ?⦃ Q ) 
      (f <catch> h)  s ?⦃ Q "
  unfolding runs_to_partial.rep_eq run_catch
  by (rule holds_partial_bind_post_state)
     (simp add: split_beta' ac_simps split: xval_splits prod.splits)

(* catch is monotone in the protected computation and (pointwise) in the
   handler. The derived partial_function_mono rules, for both orders, come
   from the bind_handle rules via catch_def. *)
lemma mono_catch: "f  f'  h  h'  catch f h  catch f' h'"
  unfolding catch_def
  apply (rule mono_bind_handle)
  by (simp_all add: le_fun_def)

lemma monotone_catch_le[partial_function_mono]:
  "monotone R (≤) (λf'. f f')  (e. monotone R (≤) (λf'. h f' e))
   monotone R (≤) (λf'. catch (f f') (h f'))"
  unfolding catch_def
  apply (rule monotone_bind_handle_le)
  by auto

lemma monotone_catch_ge[partial_function_mono]:
  "monotone R (≥) (λf'. f f')  (e. monotone R (≥) (λf'. h f' e))
   monotone R (≥) (λf'. catch (f f') (h f'))"
  unfolding catch_def
  apply (rule monotone_bind_handle_ge)
  by auto

(* Catching a liftE'd computation never runs the handler:
   catch (liftE g) h equals g. *)
lemma catch_liftE: "catch (liftE g) h = g"
  by (simp add: catch_def bind_handle_liftE)

(* Refinement of two catch expressions. Start from a refinement f/f' with
   relation Q. Then, for Q-related outcome pairs:
   - ll: exception/exception pairs are refined by the two handlers;
   - lr: exception/result pairs are refined by the handler against return;
   - rl: result/exception pairs are refined by return against the handler;
   - rr: result/result pairs must be R-related directly. *)
lemma refines_catch:
  assumes f: "refines f f' s s' Q"
  assumes ll: "e e' t t'. Q (Exn e, t) (Exn e', t') 
    refines (h e) (h' e') t t' R"
  assumes lr: "e v' t t'. Q (Exn e, t) (Result v', t') 
    refines (h e) (return v') t t' R"
  assumes rl: "v e' t t'. Q (Result v, t) (Exn e', t') 
    refines (return v) (h' e') t t' R"
  assumes rr: "v v' t t'. Q (Result v, t) (Result v', t') 
    R (Result v, t) (Result v', t')"
  shows "refines (catch f h) (catch f' h') s s' R"
  unfolding catch_def
  apply (rule refines_bind_handle_bind_handle_exn[OF f])
  subgoal using ll by auto
  subgoal using lr by auto
  subgoal using rl by auto
  subgoal using rr by (auto simp add: refines_yield)
  done

(* rel_spec counterpart of refines_catch, with the same four case
   assumptions stated with rel_spec. Proved from refines_catch in both
   directions via rel_spec_iff_refines. *)
lemma rel_spec_catch:
  assumes f: "rel_spec f f' s s' Q"
  assumes ll: "e e' t t'. Q (Exn e, t) (Exn e', t') 
    rel_spec (h e) (h' e') t t' R"
  assumes lr: "e v' t t'. Q (Exn e, t) (Result v', t') 
    rel_spec (h e) (return v') t t' R"
  assumes rl: "v e' t t'. Q (Result v, t) (Exn e', t') 
    rel_spec (return v) (h' e') t t' R"
  assumes rr: "v v' t t'. Q (Result v, t) (Result v', t') 
    R (Result v, t) (Result v', t')"
  shows "rel_spec (catch f h) (catch f' h') s s' R"
  using assms by (auto simp: rel_spec_iff_refines intro!: refines_catch)

subsection constcheck

(* Execution of check e p in state s.
   - If some x satisfies p s x, it may return any such x as a Result, with
     the state unchanged.
   - Otherwise it yields Exn e in state s.
   check always has an outcome, so it satisfies always_progress. *)
lemma run_check[run_spec_monad]: "run (check e p) s =
  (if (x. p s x) then Success ((λx. (x, s)) ` (Result ` {x. p s x})) else
    pure_post_state (Exn e, s))"
  by (auto simp add: check_def run_bind)

lemma always_progress_check[always_progress_intros, simp]: "always_progress (check e p)"
  by (auto simp add: always_progress_def run_check)

(* runs_to rules for check e p in state s. If some x satisfies p s x, every
   such x must give Q (Result x) s; otherwise Q (Exn e) s must hold. The
   VCG rules (total and partial) take one premise covering both cases. *)
lemma runs_to_check_iff[runs_to_iff]:
  "(check e p)  s Q  (if (x. p s x) then (x. p s x  Q (Result x) s) else Q (Exn e) s)"
  by (auto simp add: runs_to.rep_eq run_check)

lemma runs_to_check[runs_to_vcg]:
  "(x. p s x  (x. p s x  Q (Result x) s))  (x. ¬ p s x)  Q (Exn e) s 
    check e p  s Q"
  by (simp add: runs_to_check_iff)

lemma runs_to_partial_check[runs_to_vcg]:
  "(x. p s x  (x. p s x  Q (Result x) s))  (x. ¬ p s x)  Q (Exn e) s 
    check e p  s ?⦃Q"
  by (simp add: runs_to_partial.rep_eq run_check)

(* Refinement against a check on the right.
   - ok: if the witness a satisfies Q t a and f is refined by g a, then f
     is refined by check q Q >>= g.
   - fail: if no a satisfies Q t a and every outcome of f is R-related to
     the failure (Exn q, t), then f is refined by check q Q >>= g. *)
lemma refines_check_right_ok:
  "Q t a  refines f (g a) s t R  refines f (check q Q >>= g) s t R"
  by (rule refines_bind_right_single[of a t])
     (auto simp add: run_check)

lemma refines_check_right_fail:
  "(a. ¬ Q t a)  f  s  λr s'. R (r, s') (Exn q, t)   refines f (check q Q >>= g) s t R"
  apply (rule refines_bind_right_runs_to_partialI[OF runs_to_partial_of_runs_to])
  apply (auto simp: runs_to_check_iff Exn_def refines_yield_right_iff)
  done

(* throw r refines a failing check q P >>= f when no a satisfies P t a and
   Exn r is R-related to Exn q. Instance of refines_check_right_fail. *)
lemma refines_throwError_check:
  "(a. ¬ P t a)  R (Exn r, s) (Exn q, t) 
    refines (throw r) (check q P >>= f) s t R"
  by (intro refines_check_right_fail) auto

(* A condition P on the left, refined by check q Q >>= g on the right.
   - The return r branch (P s) is matched with the failing check: its result
     must be R-related to Exn q.
   - The f branch is matched with a successful witness a of Q t a.
   NOTE(review): the first premise links P s with the absence of any a
   satisfying Q t a (presumably an equivalence). Confirm. *)
lemma refines_condition_neg_check:
  "P s  (a. ¬ Q t a) 
    (P s  a. ¬ Q t a  R (Result r, s) (Exn q, t)) 
    (a. ¬ P s  Q t a  refines f (g a) s t R) 
    refines (condition P (return r) f) (check q Q >>= g) s t R"
 by (intro refines_condition_left)
    (auto intro: refines_check_right_ok refines_check_right_fail)

subsection constignoreE

(* Rule collection ignoreE_simps: rewrite rules that push ignoreE into
   compound computations. In this section these are the rules for
   condition, bind and whileLoop. *)
named_theorems ignoreE_simps ‹Rewrite rules to push const‹ignoreE› inside.›

(* Characterises ignoreE f as vmap_value applied to
   map_exception_or_result (constant undefined) id. This is the same value
   mapping that run_liftE shows for liftE. *)
lemma ignoreE_eq: "ignoreE f = vmap_value (map_exception_or_result (λx. undefined) id) f"
  unfolding ignoreE_def catch_def
  by transfer
     (simp add: fun_eq_iff post_state_eq_iff eq_commute all_comm[where 'a='a]
           split: prod.splits exception_or_result_splits)

(* Execution of ignoreE f, derived from run_catch.
   - Result outcomes are kept with their state.
   - Exception outcomes are mapped to a constant post_state. This is
     presumably bot, consistent with ignoreE_throw. *)
lemma run_ignoreE: "run (ignoreE f) s =
  bind_post_state (run f s)
    (λ(r, t). case r of Exn e   | Result v  pure_post_state (Result v, t))"
  unfolding ignoreE_def by (simp add: run_catch)

(* runs_to rules for ignoreE: only Result outcomes of f have to satisfy Q. *)
lemma runs_to_ignoreE_iff[runs_to_iff]:
  "(ignoreE f)  s  Q   f  s  λr t. (v. r = Result v   Q (Result v) t)"
  unfolding ignoreE_eq runs_to.rep_eq
  by transfer
     (simp add: split_beta' eq_commute prod_eq_iff split: prod.splits exception_or_result_splits)

lemma runs_to_ignoreE[runs_to_vcg]:
  "f  s  λr t. (v. r = Result v   Q (Result v) t)  (ignoreE f)  s  Q "
  by (simp add: runs_to_ignoreE_iff)

(* liftE and ignoreE form a Galois connection: liftE f is below g iff f is
   below ignoreE g (order via le_fun_def). ignoreE_whileLoop uses this for
   gfp fusion. *)
lemma liftE_le_iff_le_ignoreE: "liftE f  g  f  ignoreE g"
  unfolding liftE_def ignoreE_eq[abs_def]
  by transfer (simp add: le_fun_def vmap_post_state_le_iff_le_map_post_state id_def)

(* Partial-correctness VCG rule for ignoreE: only Result outcomes of f have
   to satisfy Q. *)
lemma runs_to_partial_ignoreE[runs_to_vcg]:
  "f  s ?⦃ λr t. (v. r = Result v   Q (Result v) t) 
      (ignoreE f)  s ?⦃ Q "
  unfolding ignoreE_def
  apply (runs_to_vcg)
  apply (fastforce simp add: runs_to_partial.rep_eq)
  done

(* ignoreE is monotone. This and the derived partial_function_mono rules are
   inherited from the catch rules via ignoreE_def. *)
lemma mono_ignoreE: "f  f'  ignoreE f  ignoreE f'"
  unfolding ignoreE_def
  apply (rule mono_catch)
  by simp_all

lemma monotone_ignoreE_le[partial_function_mono]:
  "monotone R (≤) (λf'. f f')
   monotone R (≤) (λf'. ignoreE (f f'))"
  unfolding ignoreE_def
  apply (rule monotone_catch_le)
  by simp_all

lemma monotone_ignoreE_ge[partial_function_mono]:
  "monotone R (≥) (λf'. f f')
   monotone R (≥) (λf'. ignoreE (f f'))"
  unfolding ignoreE_def
  apply (rule monotone_catch_ge)
  by simp_all

(* Simplification rules for ignoreE.
   - ignoreE undoes liftE.
   - It leaves the top element, bot, fail and return unchanged.
   - It turns throw x into bot.
   - yield (Exception e) also becomes bot, under a premise relating e to
     default (presumably e different from default). Confirm. *)
lemma ignoreE_liftE [simp]: "ignoreE (liftE f) = f"
  by (simp add: catch_liftE ignoreE_def)

lemma ignoreE_top[simp]: "ignoreE  = "
  by (rule spec_monad_eqI) (simp add: runs_to_iff)

lemma ignoreE_bot[simp]: "ignoreE bot = bot"
  by (rule spec_monad_eqI) (simp add: runs_to_iff)

lemma ignoreE_fail[simp]: "ignoreE fail = fail"
  by (rule spec_monad_eqI) (simp add: runs_to_iff)

lemma ignoreE_return[simp]: "ignoreE (return x) = return x"
  by (rule spec_monad_eqI) (simp add: runs_to_iff)

lemma ignoreE_throw[simp]: "ignoreE (throw x) = bot"
  by (rule spec_monad_eqI) (simp add: runs_to_iff)

lemma ignoreE_yield_Exception[simp]: "e  default  ignoreE (yield (Exception e)) = bot"
  by (rule spec_monad_eqI) (simp add: runs_to_iff)

(* Simplification rules: ignoreE leaves guard, get_state, set_state, select,
   unknown, assert, assume, gets, assert_opt, gets_the and modify unchanged,
   and commutes with exec_concrete. *)
lemma ignoreE_guard[simp]: "ignoreE (guard P) = guard P"
  by (rule spec_monad_eqI) (simp add: runs_to_iff)

lemma ignoreE_get_state[simp]: "ignoreE (get_state) = get_state"
  by (rule spec_monad_eqI) (simp add: runs_to_iff)

lemma ignoreE_set_state[simp]: "ignoreE (set_state s) = set_state s"
  by (rule spec_monad_eqI) (simp add: runs_to_iff)

lemma ignoreE_select[simp]: "ignoreE (select S) = select S"
  by (rule spec_monad_eqI) (simp add: runs_to_iff)

lemma ignoreE_unknown[simp]: "ignoreE (unknown) = unknown"
  by (rule spec_monad_eqI) (simp add: runs_to_iff)

lemma ignoreE_exec_concrete[simp]: "ignoreE (exec_concrete st f) = exec_concrete st (ignoreE f)"
  by (rule spec_monad_eqI) (auto simp add: runs_to_iff)

lemma ignoreE_assert[simp]: "ignoreE (assert P) = assert P"
  by (rule spec_monad_eqI) (simp add: runs_to_iff)

lemma ignoreE_assume[simp]: "ignoreE (assume P) = assume P"
  by (rule spec_monad_eqI) (simp add: runs_to_iff)

lemma ignoreE_gets[simp]: "ignoreE (gets f) = gets f"
  by (rule spec_monad_eqI) (simp add: runs_to_iff)

lemma ignoreE_assert_opt[simp]: "ignoreE (assert_opt v) = assert_opt v"
  by (rule spec_monad_eqI) (simp add: runs_to_iff)

lemma ignoreE_gets_the[simp]: "ignoreE (gets_the f) = gets_the f"
  by (rule spec_monad_eqI) (auto simp add: runs_to_iff)

lemma ignoreE_modify[simp]: "ignoreE (modify f) = modify f"
  by (rule spec_monad_eqI) (simp add: runs_to_iff)

(* ignoreE distributes over the control-flow combinators.  The rules for
   condition and bind go into the ignoreE_simps collection rather than the
   global simpset; ignoreE_when is derived from ignoreE_condition. *)
lemma ignoreE_condition[ignoreE_simps]:
  "ignoreE (condition c f g) = condition c (ignoreE f) (ignoreE g)"
  by (rule spec_monad_eqI) (simp add: runs_to_iff)

lemma ignoreE_when[simp]: "ignoreE (when P f) = when P (ignoreE f)"
  by (simp add: when_def ignoreE_condition)

lemma ignoreE_bind[ignoreE_simps]:
  "ignoreE (bind f g) = bind (ignoreE f) (λv. (ignoreE (g v)))"
  by (simp add: spec_monad_eq_iff runs_to_ignoreE_iff runs_to_bind_iff)

(* Mapping over the exception values is irrelevant once they are ignored. *)
lemma ignoreE_map_exn[simp]: "ignoreE (map_value (map_exn f) g) = ignoreE g"
  by (simp add: spec_monad_eq_iff runs_to_iff)

(* ignoreE commutes with whileLoop.  whileLoop is unfolded to its greatest
   fixed point (whileLoop_def, monotone by mono_whileLoop_functional), and
   gfp_fusion transports the equation along g = liftE.  The order condition
   is liftE_le_iff_le_ignoreE; the functional step uses the distribution
   rules ignoreE_bind / ignoreE_condition. *)
lemma ignoreE_whileLoop[ignoreE_simps]:
  "ignoreE (whileLoop C B I) = whileLoop C (λx. ignoreE (B x)) I"
proof -
  have "(λI. ignoreE (whileLoop C B I)) = whileLoop C (λx. ignoreE (B x))"
    unfolding whileLoop_def
    apply (rule gfp_fusion[OF _ mono_whileLoop_functional mono_whileLoop_functional,
      where g="λx a. liftE (x a)"])
    subgoal by (simp add: le_fun_def liftE_le_iff_le_ignoreE)
    subgoal by (simp add: ignoreE_bind ignoreE_condition)
    done
  then show ?thesis by (simp add: fun_eq_iff)
qed

subsection ‹\<^const>‹on_exit'››

(* bind_finally_def is provided as another name for bind_exception_or_result_def. *)
lemmas bind_finally_def = bind_exception_or_result_def

(* A liftE prefix of the first argument can be moved in front of
   bind_exception_or_result. *)
lemma bind_exception_or_result_liftE_assoc:
  "bind_exception_or_result (bind (liftE f) g) h = bind f (λv. bind_exception_or_result (g v) h)"
  by (rule spec_monad_eqI) (auto simp add: runs_to_iff)

(* Likewise for a guard prefix. *)
lemma bind_exception_or_result_bind_guard_assoc:
  "bind_exception_or_result (bind (guard P) g) h =
    bind (guard P) (λv. bind_exception_or_result (g v) h)"
  by (rule spec_monad_eqI) (auto simp add: runs_to_iff)

(* on_exit f cleanup in bind_exception_or_result form: after f, the cleanup
   relation is applied with state_select and f's outcome x is yielded again. *)
lemma on_exit_bind_exception_or_result_conv: 
  "on_exit f cleanup = bind_exception_or_result f (λx. do {state_select cleanup; yield x})"
  by (simp add: on_exit_def on_exit'_def)

(* The same for guard_on_exit, which additionally checks guard P before
   running the cleanup. *)
lemma guard_on_exit_bind_exception_or_result_conv: 
  "guard_on_exit f P cleanup =
    bind_exception_or_result f (λx. do {guard P; state_select cleanup; yield x})"
  by (simp add: on_exit'_def bind_assoc)

(* An assume_result_and_state that keeps the state unchanged and only
   requires P of it is the same as assuming P. *)
lemma assume_result_and_state_check_only_state:
  "assume_result_and_state (λs. {((), s'). s' = s ∧ P s}) = assuming P"
  by (simp add: spec_monad_eq_iff runs_to_iff)

(* assume_on_exit in bind_exception_or_result form: after f, assume P of the
   current state, apply the cleanup relation, and yield f's outcome x again. *)
lemma assume_on_exit_bind_exception_or_result_conv: 
  "assume_on_exit f P cleanup = bind_exception_or_result f
    (λx. do {assume_result_and_state (λs. {((), s'). s' = s ∧ P s}); state_select cleanup; yield x})"
  by (simp add: on_exit'_def bind_assoc assume_result_and_state_check_only_state)

(* on_exit' is monotone in both arguments, for (≤) and (≥), so it can be
   used inside partial_function definitions. *)
lemma monotone_on_exit'_le[partial_function_mono]:
  "monotone R (≤) (λf'. f f') ⟹ monotone R (≤) (λf'. g f') ⟹
    monotone R (≤) (λf'. on_exit' (f f') (g f'))"
  unfolding on_exit'_def by (intro partial_function_mono)

lemma monotone_on_exit'_ge[partial_function_mono]:
  "monotone R (≥) (λf'. f f') ⟹ monotone R (≥) (λf'. g f') ⟹
    monotone R (≥) (λf'. on_exit' (f f') (g f'))"
  unfolding on_exit'_def by (intro partial_function_mono)

(* Weakest-precondition characterisation of on_exit' f c.  f runs first.
   The cleanup c then runs from f's final state t.  The overall outcome is
   f's outcome r, unless c itself ends in an exception, which is then
   reported instead.  The final state is c's final state. *)
lemma runs_to_on_exit'_iff[runs_to_iff]:
  "on_exit' f c ∙ s ⦃ Q ⦄ ⟷
    f ∙ s ⦃ λr t. c ∙ t ⦃ λq t. Q (case q of Exception e ⇒ Exception e | Result _ ⇒ r) t ⦄ ⦄"
  by (auto simp add: on_exit'_def runs_to_iff fun_eq_iff split: exception_or_result_split
           intro!: arg_cong[where f="runs_to _ _"])

(* VCG rule (total correctness) derived from the characterisation above. *)
lemma runs_to_on_exit'[runs_to_vcg]:
  "f ∙ s ⦃ λr t. c ∙ t ⦃ λq t. Q (case q of Exception e ⇒ Exception e | Result _ ⇒ r) t ⦄ ⦄ ⟹ on_exit' f c ∙ s ⦃ Q ⦄"
  by (simp add: runs_to_on_exit'_iff)

(* The corresponding partial-correctness VCG rule. *)
lemma runs_to_partial_on_exit'[runs_to_vcg]:
  "f ∙ s ?⦃ λr t. c ∙ t ?⦃ λq t. Q (case q of Exception e ⇒ Exception e | Result _ ⇒ r) t ⦄ ⦄ ⟹ on_exit' f c ∙ s ?⦃ Q ⦄"
  unfolding on_exit'_def
  apply (rule runs_to_partial_bind_exception_or_result)
  apply (rule runs_to_partial_weaken, assumption)
  apply (rule runs_to_partial_bind)
  apply (rule runs_to_partial_weaken, assumption)
  apply auto
  done

(* Refinement of on_exit'.  f is related to f' first; from related final
   states the cleanups c / c' must be related.  R receives the combined
   outcome: the exception of the cleanup if it throws, otherwise the
   outcome of f (resp. f'). *)
lemma refines_on_exit':
  assumes f: "refines f f' s s'
    (λ(r, t) (r', t'). refines c c' t t' (λ(q, t) (q', t'). R
      (case q of Exception e ⇒ Exception e | Result _ ⇒ r, t)
      (case q' of Exception e' ⇒ Exception e' | Result _ ⇒ r', t')))"
  shows "refines (on_exit' f c) (on_exit' f' c') s s' R"
  unfolding on_exit'_def
  apply (rule refines_bind_exception_or_result[OF f[THEN refines_weaken]])
  apply (clarsimp intro!: refines_bind')
  apply (rule refines_weaken, assumption)
  apply auto
  done

(* A cleanup of skip makes on_exit' a no-op. *)
lemma on_exit'_skip: "on_exit' f skip = f"
  by (simp add: spec_monad_eq_iff runs_to_iff)

subsection ‹\<^const>‹run_bind››

(* run_bind f t g runs f from the explicitly given state t.  For every
   outcome (r, t') of f it continues with g r t', run from the state s of
   the surrounding computation. *)
lemma run_run_bind: "run (run_bind f t g) s = bind_post_state (run f t) (λ(r, t). run (g r t) s)"
  by (transfer) simp

lemma runs_to_run_bind_iff[runs_to_iff]:
  "(run_bind f t g) ∙ s ⦃Q⦄ ⟷ f ∙ t ⦃λr t. (g r t) ∙ s ⦃Q⦄ ⦄"
  by (simp add: runs_to.rep_eq run_run_bind split_beta')

(* VCG rules (total and partial correctness) for run_bind. *)
lemma runs_to_run_bind[runs_to_vcg]:
  "f ∙ t ⦃λr t. (g r t) ∙ s ⦃Q⦄ ⦄ ⟹ (run_bind f t g) ∙ s ⦃Q⦄"
  by (simp add: runs_to_run_bind_iff)

lemma runs_to_partial_run_bind[runs_to_vcg]:
  "f ∙ t ?⦃λr t. (g r t) ∙ s ?⦃Q⦄ ⦄ ⟹ (run_bind f t g) ∙ s ?⦃Q⦄"
  by (auto simp: runs_to_partial.rep_eq run_run_bind split_beta'
           intro!: holds_partial_bind_post_state)

(* run_bind is monotone in both the first computation and the continuation. *)
lemma mono_run_bind: "f ≤ f' ⟹ g ≤ g' ⟹ run_bind f t g ≤ run_bind f' t g'"
  apply (rule le_spec_monad_runI)
  apply (simp add: run_run_bind)
  apply (rule mono_bind_post_state)
   apply (auto simp add: le_fun_def less_eq_spec_monad.rep_eq split: exception_or_result_splits)
  done

(* Monotonicity rules for run_bind, for (≤) and (≥), registered for
   partial_function. *)
lemma monotone_run_bind_le[partial_function_mono]:
  "monotone R (≤) (λf'. f f') ⟹ (⋀r t. monotone R (≤) (λf'. g f' r t))
   ⟹ monotone R (≤) (λf'. run_bind (f f') t (g f'))"
  apply (simp add: monotone_def)
  using mono_run_bind
  by (metis le_fun_def)

lemma monotone_run_bind_ge[partial_function_mono]:
  "monotone R (≥) (λf'. f f') ⟹ (⋀r t. monotone R (≥) (λf'. g f' r t))
   ⟹ monotone R (≥) (λf'. run_bind (f f') t (g f'))"
  apply (simp add: monotone_def)
  using mono_run_bind
  by (metis le_fun_def)

(* liftE can be pushed into the continuation of run_bind. *)
lemma liftE_run_bind: "liftE (run_bind f t g) = run_bind f t (λr t. liftE (g r t))"
  apply (rule spec_monad_eqI)
  apply (auto simp add: runs_to_iff)
  done

(* exec_concrete via run_bind: choose any concrete state t with st t = s,
   run f from t, and write back the abstraction st t' of f's final state. *)
lemma exec_concrete_run_bind: "exec_concrete st f =
  do {
   s ← get_state;
   t ← select {t. st t = s};
   run_bind f t (λr' t'. do {set_state (st t'); yield r' })
  }"
  apply (rule spec_monad_eqI)
  apply (auto simp add: runs_to_iff)
  done

(* exec_abstract via run_bind: run f from st s, then choose any state s'
   with st s' = t' for f's final state t'. *)
lemma exec_abstract_run_bind: "exec_abstract st f =
  do {
   s ← get_state;
   run_bind f (st s) (λr' t'. do {s' ← select {s'. t' = st s'}; set_state s'; yield r' })
  }"
  apply (rule spec_monad_eqI)
  apply (auto simp add: runs_to_iff)
  done

(* Refinement of run_bind.  f / f' are related from the explicit states
   x / y.  For related outcomes the continuations must be related from
   the outer states s / t. *)
lemma refines_run_bind:
  "refines f f' x y (λ(r,x') (r', y'). refines (g r x') (g' r' y') s t R) ⟹
    refines (run_bind f x g) (run_bind f' y g') s t R"
  apply transfer
  apply (rule sim_bind_post_state, assumption)
  apply auto
  done


subsection ‹Iteration of monadic actions›

(* iter_spec_monad f n runs f 0, f 1, ..., f (n - 1) in sequence;
   for n = 0 it is skip. *)
fun iter_spec_monad where
  "iter_spec_monad f 0       = skip" |
  "iter_spec_monad f (Suc n) = do {iter_spec_monad f n; f n }"

(* Unfolding the last iteration for a positive count. *)
lemma iter_spec_monad_unfold:
  "0 < n ⟹ iter_spec_monad f n = do {iter_spec_monad f (n-1); f (n-1) }"
  by (metis Suc_pred' iter_spec_monad.simps(2))

(* Only the bodies for indices below i matter. *)
lemma iter_spec_monad_cong:
  "(⋀j. j < i ⟹ f j = g j) ⟹ iter_spec_monad f i = iter_spec_monad g i"
  by (induction i) auto

subsection ‹forLoop›

(* forLoop z m B runs B z, B (z + 1), ..., B (m - 1) as a whileLoop on the
   counter.  The results of B are discarded; the final counter value is
   returned. *)
definition forLoop:: "int ⇒ int ⇒ (int ⇒ ('a, 's) res_monad) ⇒ (int, 's) res_monad" where
  "forLoop z (m::int) B = whileLoop (λi s. i < m) (λi. do {B i; return (i + 1) }) z"

(* Total-correctness rule for forLoop with invariant I on the counter and
   state.  The loop terminates by the measure nat (m - i) and returns
   exactly m. *)
lemma runs_to_forLoop:
  assumes "z ≤ m"
  assumes I: "⋀r s. I r s ⟹ z ≤ r ⟹ r < m ⟹ (B r) ∙ s ⦃ λ_ t. I (r + 1) t ⦄"
  assumes Q: "⋀s. I m s ⟹ Q m s"
  shows "I z s ⟹ (forLoop z m B) ∙ s ⦃λRes r. Q r⦄"
  unfolding forLoop_def
  using ‹z ≤ m› Q I
  by (intro runs_to_whileLoop_res[where I="λq t. I q t ∧ z ≤ q ∧ q ≤ m"
        and P=Q
        and R="measure (λ(i, _). nat (m - i))"])
    (auto intro!: runs_to_bind)

(* Congruence for whileLoop on res_monad up to an invariant I.  I must hold
   initially and be preserved by f (partial correctness) whenever C holds.
   The bodies and conditions need only agree on states satisfying I.
   NOTE(review): R is fixed in the 'for' clause but does not occur in the
   statement. *)
lemma whileLoop_cong_inv:
  "run ((whileLoop C f z)::('a, 's) res_monad) s = run (whileLoop D g z) s"
  if I: "⋀i s. I i s ⟹ C i s ⟹ f i ∙ s ?⦃ λv' s'. ∃i'. v' = Result i' ∧ I i' s'⦄"
    and I_eq: "⋀i s. I i s ⟹ C i s ⟹ run (f i) s = run (g i) s"
    and C_eq: "⋀i s. I i s ⟹ C i s ⟷ D i s"
    and "I z s"
  for s::'s and z::'a and R::"('a × 's) rel"
proof -
  have "rel_spec (whileLoop C f z) (whileLoop D g z) s s
    (λ(Res i, s) (Res i', s'). i = i' ∧ s = s' ∧ I i s)"
    apply (intro rel_spec_whileLoop_res)
    subgoal for i s j t using C_eq[of i s] by auto
    subgoal premises prems for i s j t
      using prems I[of i s]
      apply (clarsimp simp: rel_spec.rep_eq runs_to_partial.rep_eq I_eq
                      intro!: rel_post_state_refl')
      apply (auto intro: holds_partial_post_state_weaken)
      done
    subgoal using ‹I z s› by simp
    done
  from rel_spec_mono[OF _ this, of "(=)"]
  show ?thesis
    by (auto simp: le_fun_def rel_spec_eq)
qed

(* An empty range (z ≥ m): the loop condition z < m fails immediately, so
   forLoop returns the start value without running the body. *)
lemma forLoop_skip: "forLoop z m f = return z" if "z ≥ m"
  using that unfolding forLoop_def
  by (subst whileLoop_unroll) simp

(* forLoop as an arbitrary whileLoop.  On counters in [z, m] the body g and
   condition C must behave like forLoop's.  If z > m, C must be false at
   the start, so both sides return z. *)
lemma forLoop_eq_whileLoop: "forLoop z m f = whileLoop C g z"
  if "⋀i. z ≤ i ⟹ i < m ⟹ g i = do { y ← f i; return (i + 1) }"
    "⋀i s. z ≤ i ⟹ i ≤ m ⟹ C i s ⟷ i < m"
    "⋀s. z > m ⟹ ¬ C z s"
proof cases
  assume "z ≤ m"
  then show ?thesis
    unfolding forLoop_def
    apply (intro spec_monad_ext ext whileLoop_cong_inv[where I="λi s. z ≤ i ∧ i ≤ m"])
    apply (auto simp: that intro!: runs_to_partial_bind)
    done
next
  assume "¬ z ≤ m"
  then show ?thesis
    apply (subst whileLoop_unroll)
    apply (simp add: forLoop_skip)
    apply (rule spec_monad_ext)
    apply (simp add: that)
    done
qed

(* Partial correctness: for z ≤ m, forLoop can only return the counter m.
   The proof rewrites the loop condition to i ≠ m (equivalent to i < m on
   [z, m]) and uses that the condition is false on exit. *)
lemma runs_to_partial_forLoopE: "forLoop z m f ∙ s ?⦃ λx s'. ∃v. x = Result v ∧ v = m ⦄"
  if "z ≤ m"
proof -
  have *: "forLoop z m f = whileLoop (λi _. i ≠ m) (λi. do {f i; return (i + 1) }) z"
    using ‹z ≤ m›
    by (auto intro!: forLoop_eq_whileLoop)
  show ?thesis unfolding *
    by (rule runs_to_partial_whileLoop_cond_false[THEN runs_to_partial_weaken]) simp
qed

subsection ‹whileLoop_unroll_reachable›

(* Reachability of loop states for whileLoop C B started at (a, s).  The
   start (a, s) is reachable.  From a reachable (b, t) with C b t, every
   Result outcome (c, u) of a successful run of B b from t is reachable. *)
context fixes C :: "'a ⇒ 's ⇒ bool" and B :: "'a ⇒ ('e::default, 'a, 's) spec_monad"
begin

inductive whileLoop_unroll_reachable :: "'a ⇒ 's ⇒ 'a ⇒ 's ⇒ bool" for a s where
  initial[intro, simp]: "whileLoop_unroll_reachable a s a s"
| step: "⋀b t X c u.
    whileLoop_unroll_reachable a s b t ⟹ C b t ⟹
    run (B b) t = Success X ⟹ (Result c, u) ∈ X ⟹
    whileLoop_unroll_reachable a s c u"

(* Reachability is transitive; proved by induction on the second derivation. *)
lemma whileLoop_unroll_reachable_trans:
  assumes a_b: "whileLoop_unroll_reachable a s b t" and b_c: "whileLoop_unroll_reachable b t c u"
  shows "whileLoop_unroll_reachable a s c u"
proof (use b_c in induction arbitrary:)
  case (step c u X d v)
  show ?case
    by (rule whileLoop_unroll_reachable.step[OF step(5,2,3,4)])
qed (simp add: a_b)

end

(* Congruence for whileLoop: conditions and bodies only need to agree on
   loop states reachable from (a, s) (whileLoop_unroll_reachable). *)
lemma run_whileLoop_unroll_reachable_cong:
  assumes eq:
    "⋀b t. whileLoop_unroll_reachable C B a s b t ⟹ C b t ⟷ C' b t"
    "⋀b t. whileLoop_unroll_reachable C B a s b t ⟹ C b t ⟹ run (B b) t = run (B' b) t"
  shows "run (whileLoop C B a) s  = run (whileLoop C' B' a) s"
proof -
  have "rel_spec (whileLoop C B a) (whileLoop C' B' a) s s
    (λ(x', s') (y, t). y = x' ∧ t = s' ∧
      (∀a'. x' = Result a' ⟶ whileLoop_unroll_reachable C B a s a' s'))"
    apply (rule rel_spec_whileLoop)
    subgoal by (simp add: eq)
    subgoal for a' s'
      using whileLoop_unroll_reachable.step[of C B a s a' s']
      by (cases "run (B' a') s'") (auto simp add: rel_spec_def eq intro!: rel_set_refl)
    subgoal by simp
    subgoal for a' s' b' t' by (cases a') (auto simp add: rel_exception_or_result.simps)
    done
  then have "rel_spec (whileLoop C B a) (whileLoop C' B' a) s s (=)"
    by (rule rel_spec_mono[rotated]) auto
  then show ?thesis
    by (simp add: rel_spec_eq)
qed

subsection ‹\<^const>‹on_exit››

(* Refinement of on_exit for product relations.  f relates results by R
   and states by S'.
   - cleanup: every concrete cleanup step from S'-related states is matched
     by an abstract step into S-related states.
   - emp: if the concrete cleanup has no successor state, neither has the
     abstract one.
   NOTE(review): the quantifiers in emp were reconstructed from garbled
   text; confirm against the premises of refines_state_select. *)
lemma refines_rel_prod_on_exit:
  assumes f: "refines fc fa sc sa (rel_prod R S')"
  assumes cleanup: "⋀sc sa tc. S' sc sa ⟹ (sc, tc) ∈ cleanupc ⟹ ∃ta. (sa, ta) ∈ cleanupa ∧ S tc ta"
  assumes emp: "⋀sc sa. S' sc sa ⟹ ∄tc. (sc, tc) ∈ cleanupc ⟹ ∄ta. (sa, ta) ∈ cleanupa"
  shows "refines (on_exit fc cleanupc) (on_exit fa cleanupa) sc sa (rel_prod R S)"
  unfolding on_exit_bind_exception_or_result_conv
  apply (rule refines_bind_exception_or_result)
  apply (rule refines_mono [OF _ f])
  apply (clarsimp)
  apply (rule refines_bind')
  apply (rule refines_state_select)
  using cleanup emp
   apply auto
  done

(* Variant of refines_rel_prod_on_exit.  A partial-correctness
   postcondition P of fc on its final state may be assumed in the cleanup
   and emp premises.  P is fused into the state relation via
   refines_runs_to_partial_fuse. *)
lemma refines_runs_to_partial_rel_prod_on_exit:
  assumes f: "refines fc fa sc sa (rel_prod R S')"
  assumes runs_to: "fc ∙ sc ?⦃λr t. P t⦄"
  assumes cleanup: "⋀sc sa tc. S' sc sa ⟹ (sc, tc) ∈ cleanupc ⟹ P sc ⟹ ∃ta. (sa, ta) ∈ cleanupa ∧ S tc ta"
  assumes emp: "⋀sc sa. S' sc sa ⟹ P sc ⟹ ∄tc. (sc, tc) ∈ cleanupc ⟹ ∄ta. (sa, ta) ∈ cleanupa"
  shows "refines (on_exit fc cleanupc) (on_exit fa cleanupa) sc sa (rel_prod R S)"
proof -
  from refines_runs_to_partial_fuse [OF f runs_to]
  have "refines fc fa sc sa (λ(r, t) (r', t'). rel_prod R S' (r, t) (r', t') ∧ P t)" .
  moreover have "(λ(r, t) (r', t'). rel_prod R S' (r, t) (r', t') ∧ P t) = rel_prod R (λt t'. S' t t' ∧ P t)"
    by (auto simp add: rel_prod_conv)
  ultimately have "refines fc fa sc sa (rel_prod R (λt t'. S' t t' ∧ P t))" by simp
  then show ?thesis
    apply (rule refines_rel_prod_on_exit)
    subgoal using cleanup by blast
    subgoal using emp by blast
    done
qed


(* rel_spec_monad is monotone in its result relation: it can be weakened
   from Q to any Q' that contains Q. *)
lemma rel_spec_monad_mono:
  assumes Q: "rel_spec_monad R Q f g" and QQ': "⋀x y. Q x y ⟹ Q' x y"
  shows "rel_spec_monad R Q' f g"
proof -
  have "rel_spec_monad R Q ≤ rel_spec_monad R Q'"
    unfolding rel_spec_monad_def using QQ'
    by (auto simp add: rel_post_state.simps rel_set_def)
      (metis rel_prod_sel)+
  with Q show ?thesis by auto
qed

(* gets with a constant function is return. *)
lemma gets_return: "gets (λ_. x) = return x"
  by (rule spec_monad_eqI) (auto simp add: runs_to_iff)

(* bind_handle f g h dispatches on f's outcome: exceptions go to h and
   results go to g.  It is expressed here via bind_exception_or_result. *)
lemma bind_handle_bind_exception_or_result_conv:
  "bind_handle f g h =
         bind_exception_or_result f
          (λException e ⇒ h e | Result v ⇒ g v)"
  apply (rule spec_monad_eqI)
  by (auto simp add: runs_to_iff)
    (auto elim!: runs_to_weaken split: exception_or_result_splits)

(* The same for handlers that match on Some, stated with Exn. *)
lemma bind_handle_bind_exception_or_result_conv_exn: "bind_handle f g (λSome e ⇒ h e) =
         bind_exception_or_result f
          (λExn e ⇒ h e | Result v ⇒ g v)"
  by (simp add: bind_handle_bind_exception_or_result_conv case_xval_def)

(* try of a bind.  An exception of f carries a sum value.  Inl l is
   rethrown as l and Inr r is returned as r.  A result v continues with
   try (g v). *)
lemma try_nested_bind_exception_or_result_conv:
  shows "try (f >>= g) =
    (bind_exception_or_result f
      (λExn e ⇒ (case e of Inl l ⇒ throw l | Inr r ⇒ return r )
      | Result v ⇒ try (g v)))"
  apply (rule spec_monad_eqI)
  by (auto simp add: runs_to_iff )
    (auto elim!: runs_to_weaken simp add: runs_to_iff  unnest_exn_def
      split: xval_splits sum.splits)

(* The same, expressed via bind_handle. *)
lemma try_nested_bind_handle_conv:
  shows "try (f >>= g) =
    (bind_handle f (λv. try (g v))
      (λSome e ⇒ (case e of Inl l ⇒ throw l | Inr r ⇒ return r )))"
  by (simp add: bind_handle_bind_exception_or_result_conv_exn try_nested_bind_exception_or_result_conv)

(* no_fail P f: from every state satisfying P, run f s is not ⊤. *)
definition no_fail:: "('s ⇒ bool) ⇒ ('e::default, 'a, 's) spec_monad ⇒ bool" where
 "no_fail P f ≡ ∀s. P s ⟶ run f s ≠ ⊤"

(* no_throw P f: from states satisfying P, every outcome of f is a Result
   (partial correctness). *)
definition no_throw:: "('s ⇒ bool) ⇒ ('e::default, 'a, 's) spec_monad ⇒ bool" where
 "no_throw P f ≡ ∀s. P s ⟶ f ∙ s ?⦃ λr t. ∃v. r = Result v⦄"

(* no_return P f: from states satisfying P, every outcome of f is an
   exception whose value differs from default (partial correctness). *)
definition no_return:: "('s ⇒ bool) ⇒ ('e::default, 'a, 's) spec_monad ⇒ bool" where
 "no_return P f ≡ ∀s. P s ⟶ f ∙ s ?⦃ λr t. ∃e. r = Exception e ∧ e ≠ default⦄"

(* Equivalent formulation of no_return using Exn. *)
lemma no_return_exn_def: "no_return P f ≡ (∀s. P s ⟶ f ∙ s ?⦃ λr t. ∃e. r = Exn e⦄)"
  by (auto simp add: no_return_def Exn_def default_option_def elim!: runs_to_partial_weaken )

(* Basic operations satisfying no_throw for any precondition P. *)
lemma no_throw_gets[simp]: "no_throw P (gets f)"
  by (auto simp add: no_throw_def runs_to_partial_def)

lemma no_throw_modify[simp]: "no_throw P (modify f)"
  by (auto simp add: no_throw_def runs_to_partial_def)

lemma no_throw_select[simp]: "no_throw P (select f)"
  by (auto simp add: no_throw_def runs_to_partial_def)

(* Selecting from the universal set satisfies always_progress. *)
lemma always_progress_select_UNIV[simp]: "always_progress (select UNIV)"
  by (auto simp add: always_progress_def bot_post_state_def)

(* catch preserves rel_spec_monad with rel_xval, provided the handlers are
   related for E-related exceptions; the result exception relation becomes E2. *)
lemma rel_spec_monad_rel_xval_catch:
  assumes fh: "rel_spec_monad R (rel_xval E Q) f h"
  assumes gi: "rel_fun E (rel_spec_monad R (rel_xval E2 Q)) g i"
  shows "rel_spec_monad R (rel_xval E2 Q) (f <catch> g) (h <catch> i)"
  using assms unfolding rel_spec_monad_iff_rel_spec
  by (auto intro!: rel_spec_catch simp: rel_fun_def)

section ‹Setup for Tagging›

(* VCG rule for a tagged program (tag ¦ f), proved by unfolding TAG_def and
   ASM_TAG_def.
   NOTE(review): this statement appears to have lost several symbols.
   Examples are the runs_to brackets around P and the connective between
   the two parts.  The exact ASM_TAG notation is not visible in this
   chunk, so the statement should be restored from upstream before the
   theory is checked. *)
lemma runs_to_tag:
  "( tag  f  s  P )  (tag ¦ f)  s  P "
  unfolding TAG_def ASM_TAG_def by auto

(* VCG rule for a tagged guard.  Proving the tagged guard condition g s and
   the postcondition for the unit result in the unchanged state is enough. *)
lemma runs_to_tag_guard:
  fixes g :: "'a ⇒ bool"
    and s :: "'a"
  assumes "tag ¦ g s"
  assumes "P (Result ()) s"
  shows "(tag ¦ guard g) ∙ s ⦃ P ⦄"
  using assms unfolding TAG_def by - (rule runs_to_weaken, rule runs_to_guard; simp)

(* Bundle that enables the basic VCG tagging setup and registers the two
   tag rules above as runs_to_vcg rules. *)
bundle runs_to_vcg_tagging_setup
begin

unbundle basic_vcg_tagging_setup

lemmas [runs_to_vcg] = runs_to_tag runs_to_tag_guard

end

end