(* Theory Exc_Nres_Monad *)

section ‹Exception Monad for Refine-Monadic›
theory Exc_Nres_Monad
imports "Refine_Imperative_HOL.IICF"
begin

(*
  TODO:
    * Integrate with sepref --- currently, it's "integrated" by providing 
        some support to translate the program into a plain nres-monad program
        before sepref is invoked.
    * Move to Refine_Monadic.
*)
  
(* Registers TrueI as a VCG rule, presumably so that trivial "True" goals
   left over by the VCG are closed automatically. *)
declare TrueI[refine_vcg]

(* The exception monad: a nondeterministic result that is either an
   exception (Inl) or a normal value (Inr). *)
type_synonym ('e,'a) enres = "('e + 'a) nres"

(* Definitions tagged with this attribute allow an enres program to be
   unfolded into a plain nres program. *)
named_theorems enres_unfolds ‹Unfolding theorems from enres to nres›


(* Normal return of the value x. *)
definition [enres_unfolds]: "ERETURN x  RETURN (Inr x)"
(* Bind: an exception (Inl e) produced by m is propagated unchanged;
   a normal result (Inr x) is passed on to the continuation f. *)
definition ebind :: "('e,'a) enres  ('a  ('e,'b) enres)  ('e,'b) enres" 
  where [enres_unfolds]:
  "ebind m f  do {
    x  m;
    case x of Inl e  RETURN (Inl e) | Inr x  f x
  }"
(* Raise the exception e. *)
definition [enres_unfolds]: "THROW e == RETURN (Inl e)"

(* Specification: every exception must satisfy Φ, every normal result Ψ. *)
definition [enres_unfolds]: "ESPEC Φ Ψ  SPEC (λInl e  Φ e | Inr r  Ψ r)"

(* Exception handler: runs m; if m raises e, continues with h e,
   otherwise returns the normal result of m unchanged. *)
definition [enres_unfolds]: "CATCH m h  do { rm; case r of Inl e  h e | Inr r  RETURN (Inr r) }"



(* Abbreviation for ebind; the doE-syntax below is translated into it. *)
abbreviation (do_notation) bind_doE where "bind_doE  ebind"

notation (output) bind_doE (infixl  54)
notation (ASCII output) bind_doE (infixl >>= 54)


(* doE { ... } block syntax for the exception monad, mirroring the
   do-notation for nres: bindings, let-bindings, plain statements and
   sequencing. *)
nonterminal doE_binds and doE_bind
syntax
  "_doE_block" :: "doE_binds  'a" (doE {//(2  _)//} [12] 62)
  "_doE_bind"  :: "[pttrn, 'a]  doE_bind" ((2_ / _) 13)
  "_doE_let" :: "[pttrn, 'a]  doE_bind" ((2let _ =/ _) [1000, 13] 13)
  "_doE_then" :: "'a  doE_bind" (‹_› [14] 13)
  "_doE_final" :: "'a  doE_binds" (‹_›)
  "_doE_cons" :: "[doE_bind, doE_binds]  doE_binds" (‹_;//_› [13, 12] 12)
  "_thenM" :: "['a, 'b]  'c" (infixl  54)

(* ASCII alternatives for the binding arrow and for sequencing.
   NOTE(review): _thenM is declared infixr here but infixl in the default
   syntax above -- confirm this asymmetry is intended. *)
syntax (ASCII)
  "_doE_bind" :: "[pttrn, 'a]  doE_bind" ((2_ <-/ _) 13)
  "_thenM" :: "['a, 'b]  'c" (infixr >> 54)

syntax_consts
  "_doE_block" "_doE_bind" "_doE_cons" "_thenM"  bind_doE and
  "_doE_let"  Let

(* Desugaring: a doE-block becomes nested bind_doE (i.e. ebind)
   applications. A statement without a binder binds a dummy variable, and
   a let-binding becomes Let. *)
translations
  "_doE_block (_doE_cons (_doE_then t) (_doE_final e))"
     "CONST bind_doE t (λ_. e)"
  "_doE_block (_doE_cons (_doE_bind p t) (_doE_final e))"
     "CONST bind_doE t (λp. e)"
  "_doE_block (_doE_cons (_doE_let p t) bs)"
     "let p = t in _doE_block bs"
  "_doE_block (_doE_cons b (_doE_cons c cs))"
     "_doE_block (_doE_cons b (_doE_final (_doE_block (_doE_cons c cs))))"
  "_doE_cons (_doE_let p t) (_doE_final s)"
     "_doE_final (let p = t in s)"
  "_doE_block (_doE_final e)"  "e"
  "(m  n)"  "(m  (λ_. n))"


(* Check: returns () if Φ holds, otherwise raises the exception e. *)
definition [enres_unfolds]: "CHECK Φ e  if Φ then ERETURN () else THROW e"

(* Assumption: if Φ does not hold, the program is SUCCEED. *)
definition [enres_unfolds]: "EASSUME Φ  if Φ then ERETURN () else SUCCEED"
(* Assertion: if Φ does not hold, the program is FAIL. *)
definition [enres_unfolds]: "EASSERT Φ  if Φ then ERETURN () else FAIL"

(* Simplification of EASSUME, EASSERT and CHECK on literal True/False. *)
lemma EASSUME_simps[simp]: 
  "EASSUME True = ERETURN ()"
  "EASSUME False = SUCCEED"
  unfolding EASSUME_def by auto

lemma EASSERT_simps[simp]: 
  "EASSERT True = ERETURN ()"
  "EASSERT False = FAIL"
  unfolding EASSERT_def by auto

lemma CHECK_simps[simp]: 
  "CHECK True e = ERETURN ()" 
  "CHECK False e = THROW e" 
  unfolding CHECK_def by auto

(* Pointwise (nofail/inres) characterizations of the enres combinators.
   Lemmas tagged refine_pw_simps are used by the pointwise proofs
   (auto simp: refine_pw_simps) throughout this theory. *)
lemma pw_ESPEC[simp, refine_pw_simps]:
  "nofail (ESPEC Φ Ψ)"
  "inres (ESPEC Φ Ψ) (Inl e)  Φ e"
  "inres (ESPEC Φ Ψ) (Inr x)  Ψ x"
  unfolding enres_unfolds
  by auto

lemma pw_ERETURN[simp, refine_pw_simps]:
  "nofail (ERETURN x)"
  "¬inres (ERETURN x) (Inl e)"
  "inres (ERETURN x) (Inr y)  x=y"
  unfolding enres_unfolds
  by auto

(* An exception of ebind m f comes either from m itself or from f applied
   to a normal result of m. Only tagged refine_pw_simps, not simp. *)
lemma pw_ebind[refine_pw_simps]:
  "nofail (ebind m f)  nofail m  (x. inres m (Inr x)  nofail (f x))"
  "inres (ebind m f) (Inl e)  inres m (Inl e)  (x. inres m (Inr x)  inres (f x) (Inl e))"
  "inres (ebind m f) (Inr x)  nofail m  (y. inres m (Inr y)  inres (f y) (Inr x))"
  unfolding enres_unfolds
  apply (auto simp: refine_pw_simps split: sum.split)
  using sum.exhaust_sel apply blast
  using sum.exhaust_sel apply blast
  done

lemma pw_THROW[simp,refine_pw_simps]:
  "nofail (THROW e)"
  "inres (THROW e) (Inl f)  f=e"
  "¬inres (THROW e) (Inr x)"
  unfolding enres_unfolds
  by (auto simp: refine_pw_simps)

lemma pw_CHECK[simp, refine_pw_simps]:
  "nofail (CHECK Φ e)"
  "inres (CHECK Φ e) (Inl f)  ¬Φ  f=e"
  "inres (CHECK Φ e) (Inr u)  Φ"
  unfolding enres_unfolds
  by (auto simp: refine_pw_simps)
  
lemma pw_EASSUME[simp, refine_pw_simps]:
  "nofail (EASSUME Φ)"
  "¬inres (EASSUME Φ) (Inl e)"
  "inres (EASSUME Φ) (Inr u)  Φ"
  unfolding EASSUME_def  
  by (auto simp: refine_pw_simps)

(* EASSERT fails exactly when Φ is false; since FAIL has every result,
   the Inl case holds exactly when Φ is false. *)
lemma pw_EASSERT[simp, refine_pw_simps]:
  "nofail (EASSERT Φ)  Φ"
  "inres (EASSERT Φ) (Inr u)"
  "inres (EASSERT Φ) (Inl e)  ¬Φ"
  unfolding EASSERT_def  
  by (auto simp: refine_pw_simps)

(* A normal result of CATCH m h is either a normal result of m or a normal
   result of the handler applied to an exception of m. Only tagged
   refine_pw_simps, not simp. *)
lemma pw_CATCH[refine_pw_simps]:
  "nofail (CATCH m h)  (nofail m  (x. inres m (Inl x)  nofail (h x)))"
  "inres (CATCH m h) (Inl e)  (nofail m  (e'. inres m (Inl e')  inres (h e') (Inl e)))"
  "inres (CATCH m h) (Inr x)  inres m (Inr x)  (e. inres m (Inl e)  inres (h e) (Inr x))"
  unfolding CATCH_def
  apply (auto simp add: refine_pw_simps split: sum.splits)  
  using sum.exhaust_sel apply blast
  using sum.exhaust_sel apply blast
  done
  
(* Pointwise characterization of the refinement order and of equality on
   enres, with results split into exceptions (Inl) and normal values (Inr). *)
lemma pw_ele_iff: "m  n  (nofail n  
    nofail m 
   (e. inres m (Inl e)  inres n (Inl e))
   (x. inres m (Inr x)  inres n (Inr x))
  )"
  apply (auto simp: pw_le_iff)
  by (metis sum.exhaust_sel)

lemma pw_eeq_iff: "m = n  
    (nofail m  nofail n) 
   (e. inres m (Inl e)  inres n (Inl e))
   (x. inres m (Inr x)  inres n (Inr x))"
  apply (auto simp: pw_eq_iff)
  by (metis sum.exhaust_sel)+
  


(* Monad laws: left identity, right identity and associativity of ebind. *)
lemma enres_monad_laws[simp]:
  "ebind (ERETURN x) f = f x"
  "ebind m (ERETURN) = m"
  "ebind (ebind m f) g = ebind m (λx. ebind (f x) g)"
  by (auto simp: pw_eeq_iff refine_pw_simps)
  
(* THROW short-circuits ebind. CATCH invokes the handler on THROW and passes
   ERETURN through, and CATCH m THROW equals m. *)
lemma enres_additional_laws[simp]:
  "ebind (THROW e) f = THROW e"
  
  "CATCH (THROW e) h = h e"
  "CATCH (ERETURN x) h = ERETURN x"
  "CATCH m THROW = m"
  
  apply (auto simp: pw_eeq_iff refine_pw_simps)
  done  

(* order_trans instantiated to a right-hand side of the form ESPEC _ _,
   with both postconditions left as schematic variables. *)
lemmas ESPEC_trans = order_trans[where z="ESPEC Error_Postcond Normal_Postcond" for Error_Postcond Normal_Postcond, zero_var_indexes]

(* Consequence rule: both postconditions of an ESPEC may be weakened. *)
lemma ESPEC_cons: 
  assumes "m  ESPEC E Q"
  assumes "e. E e  E' e"
  assumes "x. Q x  Q' x"
  shows "m  ESPEC E' Q'"  
  using assms by (auto simp: pw_ele_iff)
  
  
(* VCG rules. Each is first proved as an iff-lemma; the rule declared
   [refine_vcg] is its right-to-left direction (iffD2). *)

(* Bind: exceptions of m must satisfy Φ, and for every normal result x
   of m, f x must meet the specification. *)
lemma ebind_rule_iff: "doE { xm; f x }  ESPEC Φ Ψ  m  ESPEC Φ (λx. f x  ESPEC Φ Ψ)"
  by (auto simp: pw_ele_iff refine_pw_simps)

lemmas ebind_rule[refine_vcg] = ebind_rule_iff[THEN iffD2]

lemma ERETURN_rule_iff[simp]: "ERETURN x  ESPEC Φ Ψ  Ψ x"
  by (auto simp: pw_ele_iff refine_pw_simps)
lemmas ERETURN_rule[refine_vcg] = ERETURN_rule_iff[THEN iffD2]

(* One ESPEC refines another iff both postconditions are implied pointwise. *)
lemma ESPEC_rule_iff: "ESPEC Φ Ψ  ESPEC Φ' Ψ'  (e. Φ e  Φ' e)  (x. Ψ x  Ψ' x)"
  by (auto simp: pw_ele_iff refine_pw_simps)
lemmas ESPEC_rule[refine_vcg] = ESPEC_rule_iff[THEN iffD2]

lemma THROW_rule_iff: "THROW e  ESPEC Φ Ψ  Φ e"
  by (auto simp: pw_ele_iff refine_pw_simps)
lemmas THROW_rule[refine_vcg] = THROW_rule_iff[THEN iffD2]

(* CATCH: every exception e of m must lead to a handler call h e that meets
   the specification; normal results of m must satisfy Ψ. *)
lemma CATCH_rule_iff: "CATCH m h  ESPEC Φ Ψ  m  ESPEC (λe. h e  ESPEC Φ Ψ) Ψ"
  by (auto simp: pw_ele_iff refine_pw_simps)
lemmas CATCH_rule[refine_vcg] = CATCH_rule_iff[THEN iffD2]



(* CHECK c e meets the specification iff c implies Ψ () and ¬c implies Φ e. *)
lemma CHECK_rule_iff: "CHECK c e  ESPEC Φ Ψ  (c  Ψ ())  (¬c  Φ e)"
  by (auto simp: pw_ele_iff refine_pw_simps)

lemma CHECK_rule[refine_vcg]:
  assumes "c  Ψ ()"
  assumes "¬c  Φ e"
  shows "CHECK c e  ESPEC Φ Ψ"
  using assms by (simp add: CHECK_rule_iff)



(* VCG rule for EASSUME, proved by case distinction on Φ. *)
lemma EASSUME_rule[refine_vcg]: "Φ  Ψ ()  EASSUME Φ  ESPEC E Ψ"
  by (cases Φ) auto

(* EASSERT: Φ must be proved, and Ψ () must follow from Φ. *)
lemma EASSERT_rule[refine_vcg]: "Φ; Φ  Ψ ()  EASSERT Φ  ESPEC E Ψ" by auto

(* Pattern matching on a pair: prove the goal for the components. *)
lemma eprod_rule[refine_vcg]: 
  "a b. p=(a,b)  S a b  ESPEC Φ Ψ  (case p of (a,b)  S a b)  ESPEC Φ Ψ"
  by (auto split: prod.split)

(* TODO: Add a simplifier setup that normalizes nested case-expressions to
  the vcg! *)
(* Pattern matching on two pairs at once: prove the goal for all four
   components. *)
lemma eprod2_rule[refine_vcg]:
  assumes "a b c d. ab=(a,b); cd=(c,d)  f a b c d  ESPEC Φ Ψ"
  shows "(λ(a,b) (c,d). f a b c d) ab cd  ESPEC Φ Ψ"
  using assms
  by (auto split: prod.split)

(* if-then-else: prove both branches under the respective condition. *)
lemma eif_rule[refine_vcg]: 
  " b  S1  ESPEC Φ Ψ; ¬b  S2  ESPEC Φ Ψ 
   (if b then S1 else S2)  ESPEC Φ Ψ"
  by (auto)

(* case_option: prove the None and the Some branch. *)
lemma eoption_rule[refine_vcg]: 
  " v=None  S1  ESPEC Φ Ψ; x. v=Some x  f2 x  ESPEC Φ Ψ 
   case_option S1 f2 v  ESPEC Φ Ψ"
  by (auto split: option.split)

(* Let: substitute the bound value. *)
lemma eLet_rule[refine_vcg]: "f v  ESPEC Φ Ψ  (let x=v in f x)  ESPEC Φ Ψ" by simp

(* Variant of eLet_rule that keeps x=v as an assumption instead of
   substituting it; not declared [refine_vcg]. *)
lemma eLet_rule':
  assumes "x. x=v  f x  ESPEC Φ Ψ"
  shows "Let v (λx. f x)  ESPEC Φ Ψ"
  using assms by simp

(* While loop with annotated invariant I, implemented by WHILEIT on the
   sum-typed state. An exception state (Inl) satisfies the annotated
   invariant and stops the loop. A normal state (Inr s) must satisfy I and
   the loop continues while c s holds. The body asserts that the state is
   normal and runs f on the projected state. *)
definition [enres_unfolds]: "EWHILEIT I c f s  WHILEIT 
  (λInl _  True | Inr s  I s) 
  (λInl _  False | Inr s  c s)
  (λs. ASSERT (¬isl s)  (let s = projr s in f s))
  (Inr s)"

(* Loop without an invariant annotation (annotated invariant True). *)
definition [enres_unfolds]: "EWHILET  EWHILEIT (λ_. True)"

(* Total-correctness rule for EWHILEIT with a well-founded relation R.
   Each step from a state satisfying I and b, and related to s0 by R*,
   must either raise an exception satisfying E or reach a state that
   satisfies I and is R-smaller. When b is false, Φ must hold.
   The proof lifts R to the sum-typed loop state via a lexicographic order
   that ranks exception states (0, _) below normal states (1, s). *)
lemma EWHILEIT_rule[refine_vcg]:
  assumes WF: "wf R"
    and I0: "I s0"
    and IS: "s. I s; b s; (s,s0)R*  f s  ESPEC E (λs'. I s'  (s', s)  R)"
    and IMP: "s. I s; ¬ b s; (s,s0)R*  Φ s"
  shows "EWHILEIT I b f s0  ESPEC E Φ"
  unfolding EWHILEIT_def ESPEC_def
  apply (rule order_trans[OF WHILEIT_weaken[where I="λInl e  E e | Inr s  I s  (s,s0)R*"]])
  apply (auto split: sum.splits) []
  apply (rule WHILEIT_rule[where R="inv_image (less_than <*lex*> R) (λInl e  (0,undefined) | Inr s  (1,s))"])
  subgoal using WF by auto
  subgoal using I0 by auto
  subgoal 
    apply (clarsimp split: sum.splits simp: ESPEC_def)
    apply (rule order_trans[OF IS])
    apply (auto simp: ESPEC_def)
    done
  subgoal using IMP by (auto split: sum.splits)
  done
  
(* The same rule for the unannotated loop EWHILET; the proof replays the
   proof of EWHILEIT_rule.
   NOTE(review): unlike EWHILEIT_rule, this rule is not declared
   [refine_vcg] -- confirm that this is intended. *)
lemma EWHILET_rule:
  assumes WF: "wf R"
    and I0: "I s0"
    and IS: "s. I s; b s; (s,s0)R*  f s  ESPEC E (λs'. I s'  (s', s)  R)"
    and IMP: "s. I s; ¬ b s; (s,s0)R*  Φ s"
  shows "EWHILET b f s0  ESPEC E Φ"
  unfolding EWHILET_def EWHILEIT_def ESPEC_def
  apply (rule order_trans[OF WHILEIT_weaken[where I="λInl e  E e | Inr s  I s  (s,s0)R*"]])
  apply (auto split: sum.splits) []
  apply (rule WHILEIT_rule[where R="inv_image (less_than <*lex*> R) (λInl e  (0,undefined) | Inr s  (1,s))"])
  subgoal using WF by auto
  subgoal using I0 by auto
  subgoal 
    apply (clarsimp split: sum.splits simp: ESPEC_def)
    apply (rule order_trans[OF IS])
    apply (auto simp: ESPEC_def)
    done
  subgoal using IMP by (auto split: sum.splits)
  done
  
(* If I implies I', the loop annotated with I' refines the loop annotated
   with I. *)
lemma EWHILEIT_weaken:
  assumes "x. I x  I' x"
  shows "EWHILEIT I' b f x  EWHILEIT I b f x"
  unfolding enres_unfolds
  apply (rule WHILEIT_weaken)
  using assms by (auto split: sum.split)
    
text ‹Loop rule with an explicitly specified invariant that may differ from the annotated one, provided it implies the annotated invariant.›
(* Proves an EWHILEIT loop annotated with I' using a proof invariant I that
   implies I' (INVIMP): first weaken to the loop annotated with I
   (EWHILEIT_weaken), then apply EWHILEIT_rule. *)
lemma EWHILEIT_expinv_rule:
  assumes WF: "wf R"
    and I0: "I s0"
    and IS: "s. I s; b s; (s,s0)R*  f s  ESPEC E (λs'. I s'  (s', s)  R)"
    and IMP: "s. I s; ¬ b s; (s,s0)R*  Φ s"
    and INVIMP: "s. I s  I' s"
  shows "EWHILEIT I' b f s0  ESPEC E Φ"
  apply (rule order_trans[OF EWHILEIT_weaken])
  using INVIMP apply assumption
  apply (rule EWHILEIT_rule; fact+)
  done
  

(* Fold over the list l with continuation condition c, implemented by
   nfoldli on the sum-typed state. An exception state (Inl) stops the fold.
   For a normal state, the fold continues while c holds, applying f to each
   element. *)
definition [enres_unfolds]: "enfoldli l c f s  
  nfoldli l (λInl eFalse | Inr x  c x) (λx s. do {ASSERT (¬isl s); let s=projr s; f x s}) (Inr s)"

(* Recursion equations for enfoldli. *)
lemma enfoldli_simps[simp]:
  "enfoldli [] c f s = ERETURN s"
  "enfoldli (x#ls) c f s = 
    (if c s then doE { sf x s; enfoldli ls c f s} else ERETURN s)"
  unfolding enres_unfolds
  by (auto split: sum.split intro!: arg_cong[where f = "Refine_Basic.bind _"] ext)

(* Invariant rule for enfoldli. I l1 l2 σ relates the processed prefix l1,
   the remaining suffix l2 and the current state σ. The result must satisfy
   P when the condition becomes false (FNC) or when the list is exhausted
   (FC). Exceptions raised by f must satisfy E. *)
lemma enfoldli_rule:
  assumes I0: "I [] l0 σ0"
  assumes IS: "x l1 l2 σ.  l0=l1@x#l2; I l1 (x#l2) σ; c σ   f x σ  ESPEC E (I (l1@[x]) l2)"
  assumes FNC: "l1 l2 σ.  l0=l1@l2; I l1 l2 σ; ¬c σ   P σ"
  assumes FC: "σ.  I l0 [] σ; c σ   P σ"
  shows "enfoldli l0 c f σ0  ESPEC E P"
  unfolding enfoldli_def ESPEC_def
  apply (rule nfoldli_rule[where I="λl1 l2. λInl e  E e | Inr σ  I l1 l2 σ"])
  subgoal by (auto simp: I0)
  subgoal 
    apply (simp split: sum.splits)
    apply (erule (2) order_trans[OF IS])
    apply (auto simp: ESPEC_def)
    done
  subgoal using FNC by (auto split: sum.split)
  subgoal using FC by (auto split: sum.split)
  done


    
    
    
subsection ‹Data Refinement›

(* Characterizations of membership in the sum relation, plus case
   distinctions on sum values. *)
lemma sum_rel_conv:
  "(Inl l, s')  L,Rsum_rel  (l'. s'=Inl l'  (l,l')L)"
  "(Inr r, s')  L,Rsum_rel  (r'. s'=Inr r'  (r,r')R)"
  "(s, Inl l')  L,Rsum_rel  (l. s=Inl l  (l,l')L)"
  "(s, Inr r')  L,Rsum_rel  (r. s=Inr r  (r,r')R)"
  "(l. s  Inl l)  (r. s=Inr r)"
  "(r. s  Inr r)  (l. s=Inl l)"
  apply -
  subgoal by (cases s'; auto)
  subgoal by (cases s'; auto)
  subgoal by (cases s; auto)
  subgoal by (cases s; auto)
  subgoal by (cases s; auto)
  subgoal by (cases s; auto)
  done  

(* Data refinement for enres: exceptions are related by E and normal
   results by R, via the sum relation. *)
definition econc_fun (E) where [enres_unfolds]: "econc_fun E R  (E,Rsum_rel)"

(* Registers enres refinement goals as a refine_dref_pattern.
   NOTE(review): presumably used by the refinement heuristics to determine
   the relation R via RELATES, as for nres -- confirm. *)
lemma RELATES_pat_erefine[refine_dref_pattern]: "RELATES R; mi E E R m   mi E E R m" .

(* Pointwise characterization of enres data refinement. *)
lemma pw_econc_iff[refine_pw_simps]:
  "inres (E E R m) (Inl ei)  (nofail m  (e. inres m (Inl e)  (ei,e)E))"
  "inres (E E R m) (Inr xi)  (nofail m  (x. inres m (Inr x)  (xi,x)R))"
  "nofail (E E R m)  nofail m"
  by (auto simp: refine_pw_simps econc_fun_def sum_rel_conv)

(* Refinement with identity relations is the identity. *)
lemma econc_fun_id[simp]: "E Id Id = (λx. x)"
  by (auto simp: pw_eeq_iff refine_pw_simps intro!: ext)

(* Data refinement of an ESPEC is again an ESPEC. *)
lemma econc_fun_ESPEC: "E E R (ESPEC Φ Ψ) = ESPEC (λei. e. (ei,e)E  Φ e) (λri. r. (ri,r)R  Ψ r)"
  by (auto simp: pw_eeq_iff refine_pw_simps)

lemma econc_fun_ERETURN: "E E R (ERETURN x) = ESPEC (λ_. False) (λxi. (xi,x)R)"
  by (auto simp: pw_eeq_iff refine_pw_simps)

lemma econc_fun_univ_id[simp]: "E UNIV Id (ESPEC Φ Ψ) = ESPEC (λ_. Ex Φ) Ψ"
  by (auto simp: pw_eeq_iff refine_pw_simps)

(* A program refines itself under relations that contain Id. *)
lemma erefine_same_sup_Id[simp]: " IdE; IdR   m E E R m" by (auto simp: pw_ele_iff refine_pw_simps)

(* Data refinement is monotone in the abstract program. *)
lemma econc_mono3: "mm'  E E R m  E E R m'"
  by (auto simp: pw_ele_iff refine_pw_simps)

(* Order of these two is important! *)    
lemma econc_x_trans[trans]: 
  "x  E E R y  y  z  x  E E R z"
  by (force simp: pw_ele_iff refine_pw_simps)
(* Composition of two data refinements composes the relations. *)
lemma econc_econc_trans[trans]: 
  "x E E1 R1 y  y  E E2 R2 z  x  E (E1 O E2) (R1 O R2) z"    
  by (force simp: pw_ele_iff refine_pw_simps)
  
    
    
(* ERETURN of related values yields related results. *)
lemma ERETURN_refine[refine]: 
  assumes "(xi,x)R"
  shows "ERETURN xi  EE R (ERETURN x)"
  using assms
  by (auto simp: pw_ele_iff refine_pw_simps)

(* An EASSERT on the abstract side may be assumed. *)
lemma EASSERT_bind_refine_right:
  assumes "Φ  mi E E R m"
  shows "mi E E R (doE {EASSERT Φ; m})"
  using assms
  by (simp add: pw_ele_iff refine_pw_simps)
  
(* An EASSERT on the concrete side must be proved. *)
lemma EASSERT_bind_refine_left:
  assumes "Φ"
  assumes "mi E E R m"
  shows "(doE {EASSERT Φ; mi}) E E R m"
  using assms
  by simp

(* An EASSUME on the abstract side must be proved. *)
lemma EASSUME_bind_refine_right:
  assumes "Φ"
  assumes "mi E E R m"
  shows "mi E E R (doE {EASSUME Φ; m})"
  using assms
  by (simp)

(* An EASSUME on the concrete side may be assumed. *)
lemma EASSUME_bind_refine_left:
  assumes "Φ  mi E E R m"
  shows "(doE {EASSUME Φ; mi}) E E R m"
  using assms
  by (simp add: pw_ele_iff refine_pw_simps)

(* Refinement of bind: the first computations are refined w.r.t. some
   relation R' on normal results, and the continuations are refined for
   all R'-related arguments. *)
lemma ebind_refine:
  assumes "mi E E R' m"
  assumes "xi x. (xi,x)R'  fi xi E E R (f x)"
  shows "doE { xi  mi; fi xi }  E E R (doE { x  m; f x })"
  using assms
  by (simp add: pw_ele_iff refine_pw_simps) blast

text ‹The order of these lemmas matters!›
(* Registers the bind rules as [refine] rules: ebind_refine first, followed
   by the rules for a leading EASSERT/EASSUME on either side. *)
lemmas [refine] = 
  ebind_refine
  EASSERT_bind_refine_left EASSUME_bind_refine_right
  EASSUME_bind_refine_left EASSERT_bind_refine_right

(* Variant of ebind_refine whose second premise may additionally use that
   xi and x are possible normal results of mi and m, and that neither mi nor
   m fails. Not declared [refine]. *)
lemma ebind_refine':
  assumes "mi E E R' m"
  assumes "xi x. (xi,x)R'; inres mi (Inr xi); inres m (Inr x); nofail mi; nofail m  fi xi E E R (f x)"
  shows "doE { xi  mi; fi xi }  E E R (doE { x  m; f x })"
  using assms
  by (simp add: pw_ele_iff refine_pw_simps) blast

(* THROW of related exceptions. *)
lemma THROW_refine[refine]: "(ei,e)E  THROW ei E E R (THROW e)"
  by (auto simp: pw_ele_iff refine_pw_simps)

(* Refinement of CATCH: the protected computations are refined with some
   exception relation E', and the handlers are refined for E'-related
   exceptions, with extra inres/nofail assumptions available. *)
lemma CATCH_refine':
  assumes "mi  E E' R m"
  assumes "ei e.  (ei,e)E'; inres mi (Inl ei); inres m (Inl e); nofail mi; nofail m   hi ei E E R (h e)"
  shows "CATCH mi hi  E E R (CATCH m h)"  
  using assms
  by (simp add: pw_ele_iff refine_pw_simps) blast
  
(* CATCH_refine' without the inres/nofail assumptions. *)
lemma CATCH_refine[refine]:
  assumes "mi  E E' R m"
  assumes "ei e.  (ei,e)E'   hi ei E E R (h e)"
  shows "CATCH mi hi  E E R (CATCH m h)"  
  using assms CATCH_refine' by metis

(* CHECK refinement: the conditions are equivalent, and if the check fails,
   the exception values must be E-related. Normal results are related by Id. *)
lemma CHECK_refine[refine]: 
  assumes "Φi  Φ"
  assumes "¬Φ  (msgi,msg)E"
  shows "CHECK Φi msgi E E Id (CHECK Φ msg)"
  using assms by (auto simp: pw_ele_iff refine_pw_simps)

text ‹This must be declared after @{thm CHECK_refine}!›
(* Refinement of a CHECK followed by further code: the remaining code only
   needs to be refined under the assumption that the check succeeds. *)
lemma CHECK_bind_refine[refine]: 
  assumes "Φi  Φ"
  assumes "¬Φ  (msgi,msg)E"
  assumes "Φ  mi E E R m"
  shows "doE {CHECK Φi msgi;mi} E E R (doE {CHECK Φ msg; m})"
  using assms by (auto simp: pw_ele_iff refine_pw_simps)
    
(* Let on both sides: refine the bodies with the bound values substituted. *)
lemma Let_unfold_refine[refine]:
  assumes "f x  E E R (f' x')"
  shows "Let x f  E E R (Let x' f')"
  using assms by auto

(* Let with R'-related bound values; not declared [refine]. *)
lemma Let_refine:
  assumes "(m,m')R'"
  assumes "x x'. (x,x')R'  f x  E E R (f' x')"
  shows "Let m (λx. f x) E E R (Let m' (λx'. f' x'))"
  using assms by auto

(* if-then-else with related conditions: refine the branches pairwise. *)
lemma eif_refine[refine]:
  assumes "(b,b')bool_rel"
  assumes "b;b'  S1  E E R S1'"
  assumes "¬b;¬b'  S2  E E R S2'"
  shows "(if b then S1 else S2)  E E R (if b' then S1' else S2')"
  using assms by auto




(* TODO: Also add enfoldli_invar_refine *)
(* Refinement of enfoldli. The lists are related element-wise by S, the
   initial states by R, and the conditions via R → bool_rel. The bodies
   must be refined for related arguments on which the abstract condition
   holds. *)
lemma enfoldli_refine[refine]:
  assumes "(li, l)  Slist_rel"
    and "(si, s)  R"
    and CR: "(ci, c)  R  bool_rel"
    and FR: "xi x si s.  (xi,x)S; (si,s)R; c s   fi xi si  E E R (f x s)"
  shows "enfoldli li ci fi si  E E R (enfoldli l c f s)"
  unfolding enres_unfolds
  apply (rule nfoldli_refine)
  apply (rule assms(1))
  apply (simp add: assms(2))
  subgoal using CR[param_fo] by (auto split: sum.split simp: sum_rel_conv)
  subgoal 
    apply refine_rcg
    applyS (auto split: sum.splits simp: sum_rel_conv)
    apply (rule FR[unfolded enres_unfolds]) 
    by (auto split: sum.splits simp: sum_rel_conv)
  done

(* Refinement of EWHILET. The conditions agree on R-related states, and the
   loop bodies are refined on related states where the condition holds. *)
lemma EWHILET_refine[refine]:
  assumes R0: "(x,x')R"
  assumes COND_REF: "x x'.  (x,x')R   b x = b' x'"
  assumes STEP_REF: 
    "x x'.  (x,x')R; b x; b' x'   f x  E E R (f' x')"
  shows "EWHILET b f x E E R (EWHILET b' f' x')"
  unfolding enres_unfolds
  apply refine_rcg
  using assms
  by (auto split: sum.splits simp: econc_fun_def)

(* Refinement of EWHILEIT. The initial states are related provided the
   abstract invariant holds (R0), and the abstract invariant I' implies the
   concrete invariant I on related states (I_REF). Conditions and loop
   bodies are related as in EWHILET_refine, with both invariants available
   as assumptions. *)
lemma EWHILEIT_refine[refine]:
  assumes R0: "I' x'  (x,x')R"
  assumes I_REF: "x x'.  (x,x')R; I' x'   I x"  
  assumes COND_REF: "x x'.  (x,x')R; I x; I' x'   b x = b' x'"
  assumes STEP_REF: 
    "x x'.  (x,x')R; b x; b' x'; I x; I' x'   f x  E E R (f' x')"
  shows "EWHILEIT I b f x E E R (EWHILEIT I' b' f' x')"
  unfolding enres_unfolds
  apply refine_rcg
  using assms
  by (auto split: sum.splits simp: econc_fun_def)



subsubsection ‹Refine2 heuristics›

(* Removes a Let on the abstract side by substitution. *)
lemma remove_eLet_refine:
  assumes "M  E E R (f x)"
  shows "M  E E R (Let x f)" using assms by auto

(* Removes a Let on the concrete side by substitution. *)
lemma intro_eLet_refine:
  assumes "f x  E E R M'"
  shows "Let x f  E E R M'" using assms by auto

(* A concrete Let refines an abstract bind whose first computation is
   refined by ERETURN of the bound value. *)
lemma ebind2let_refine[refine2]:
  assumes "ERETURN x  E E R' M'"
  assumes "x'. (x,x')R'  f x  E E R (f' x')"
  shows "Let x f  E E R (ebind M' (λx'. f' x'))"
  using assms 
  apply (simp add: pw_ele_iff refine_pw_simps) 
  apply fast
  done

(* Conversely, a concrete bind refines an abstract Let if its first
   computation refines ERETURN of the bound value. *)
lemma ebind_Let_refine2[refine2]: " 
    m' E E R' (ERETURN x);
    x'. inres m' (Inr x'); (x',x)R'  f' x'  E E R (f x) 
    ebind m' (λx'. f' x')  E E R (Let x (λx. f x))"
  apply (simp add: pw_ele_iff refine_pw_simps)
  apply blast
  done

(* As ebind2let_refine, for a concrete ERETURN (Let x f). *)
lemma ebind2letRETURN_refine[refine2]:
  assumes "ERETURN x  E E R' M'"
  assumes "x'. (x,x')R'  ERETURN (f x)  E E R (f' x')"
  shows "ERETURN (Let x f)  E E R (ebind M' (λx'. f' x'))"
  using assms
  apply (simp add: pw_ele_iff refine_pw_simps)
  apply fast
  done

(* Refining ERETURN a: M must raise no exceptions, and its normal results
   must be R-related to a. *)
lemma ERETURN_as_SPEC_refine[refine2]:
  assumes "RELATES R"
  assumes "M  ESPEC (λ_. False) (λc. (c,a)R)"
  shows "M  E E R (ERETURN a)"
  using assms
  by (simp add: pw_ele_iff refine_pw_simps)

(* ERETURN of an if-then-else refines an if-then-else with an equivalent
   condition. *)
lemma if_ERETURN_refine[refine2]:
  assumes "b  b'"
  assumes "b;b'  ERETURN S1  E E R S1'"
  assumes "¬b;¬b'  ERETURN S2  E E R S2'"
  shows "ERETURN (if b then S1 else S2)  E E R (if b' then S1' else S2')"
  (* this is nice to have for small functions, hence keep it in refine2 *)
  using assms
  by (simp add: pw_le_iff refine_pw_simps)



text ‹Breaking down the enres monad into the nres monad›
(* Lifts an nres program into enres. Only normal results (Inr) are
   produced, so the lifted program never raises an exception. *)
definition enres_lift :: "'a nres  (_,'a) enres" where 
  "enres_lift m  do { x  m; RETURN (Inr x) }"

(* VCG rule: a lifted program meets ESPEC E Φ if the nres program meets
   SPEC Φ. *)
lemma enres_lift_rule[refine_vcg]: "mSPEC Φ  enres_lift m  ESPEC E Φ"
  by (auto simp: pw_ele_iff pw_le_iff refine_pw_simps enres_lift_def)
  
(* Rewrite rules that move enres_lift outwards. They turn enres programs
   built from lifted nres programs back into plain nres programs. *)
named_theorems_rev enres_breakdown  
(* ERETURN and EASSERT are lifts of RETURN and ASSERT; binding a lifted
   program in doE is an ordinary nres bind. *)
lemma [enres_breakdown]:
  "ERETURN x = enres_lift (RETURN x)"  
  "EASSERT Φ = enres_lift (ASSERT Φ)"
  "doE { x  enres_lift m; ef x } = do { x  m; ef x }"
  (*"NO_MATCH (enres_lift m) em ⟹ doE { x ← em; ef x } = do { ex ← em; case ex of Inl e ⇒ RETURN (Inl e) | Inr x ⇒ ef x }"*)
  unfolding enres_unfolds enres_lift_def
  apply (auto split: sum.splits simp: pw_eq_iff refine_pw_simps)
  done
  
(* enres_lift commutes with nres bind and let. *)
lemma [enres_breakdown]: 
  "do { x  m; enres_lift (f x) } = enres_lift (do { x  m; f x })"
  "do { let x = v; enres_lift (f x) } = enres_lift (do { let x=v; f x })"
  unfolding enres_unfolds enres_lift_def
  apply (auto split: sum.splits simp: pw_eq_iff refine_pw_simps)
  done

(* A lifted program never throws, so CATCH leaves it unchanged. *)
lemma [enres_breakdown]:
  "CATCH (enres_lift m) h = enres_lift m"  
  unfolding enres_unfolds enres_lift_def
  apply (auto split: sum.splits simp: pw_eq_iff refine_pw_simps)
  done
  
lemma enres_lift_fail[simp]:  "enres_lift FAIL = FAIL"
  unfolding enres_lift_def by auto


(* TODO: Also do breakdown-thm for RECT. It's exactly the same approach! *)
(* A loop whose body is a lifted nres program is the lift of the
   corresponding WHILEIT loop. Both directions of the antisymmetry are
   proved by transferring the RECT fixed point with the relation
   "concrete state = Inr abstract state". *)
lemma [enres_breakdown]: "EWHILEIT I c (λs. enres_lift (f s)) s = enres_lift (WHILEIT I c f s)"
  (is "?lhs = ?rhs")
proof (rule antisym)
  show "?lhs  ?rhs"
    unfolding enres_unfolds WHILEIT_def WHILET_def
    apply (rule RECT_transfer_rel'[where P="λc a. c = Inr a"])
    apply (simp add: while.WHILEI_body_trimono)
    apply (simp add: while.WHILEI_body_trimono)
    apply simp
    apply simp
    by (auto simp: WHILEI_body_def enres_lift_def pw_le_iff refine_pw_simps)

  show "?rhs  ?lhs"
    unfolding enres_unfolds WHILEIT_def WHILET_def
    apply (rule RECT_transfer_rel'[where P="λa c. c = Inr a"])
    apply (simp add: while.WHILEI_body_trimono)
    apply (simp add: while.WHILEI_body_trimono)
    apply simp
    apply simp
    by (auto simp: WHILEI_body_def enres_lift_def pw_le_iff refine_pw_simps)
qed    



(* The same for the unannotated loop, derived from the rule above. *)
lemma [enres_breakdown]: "EWHILET c (λs. enres_lift (f s)) s = enres_lift (WHILET c f s)"
  unfolding EWHILET_def WHILET_def enres_breakdown ..
  
(* A fold whose body is a lifted nres program is the lift of the
   corresponding nfoldli; proved by induction on the list. *)
lemma [enres_breakdown]: "enfoldli l c (λx s. enres_lift (f x s)) s = enres_lift (nfoldli l c f s)"
  apply (induction l arbitrary: s)
  by (auto simp: enres_breakdown)  
    
    
(* Moves enres_lift out of pair pattern matching. *)
lemma [enres_breakdown]: 
  "(λ(a,b). enres_lift (f a b)) = (λx. enres_lift (case x of (a,b)  f a b))" by auto
  
(* The nres monad laws are also used during breakdown. *)
lemmas [enres_breakdown] = nres_monad_laws nres_bind_let_law

(* A CHECK followed by m becomes an if-then-else with THROW. *)
lemma [enres_breakdown]:
  "doE { CHECK Φ e; m } = (if Φ then m else THROW e)"
  by (auto simp: enres_unfolds)

(* Moves enres_lift out of if-then-else. *)
lemma [enres_breakdown]: "(if b then enres_lift m else enres_lift n) = enres_lift (if b then m else n)"
  by simp

(* Moves enres_lift out of case_option. *)
lemma option_case_enbd[enres_breakdown]:
  "case_option (enres_lift fn) (λv. enres_lift (fs v)) = (λx. enres_lift (case_option fn fs x))"
  by (auto split: option.split)
    
    
    
(* Theorems unfolded as the first step of opt_enres_unfold
   (presumably definitions the user wants inlined -- confirm). *)
named_theorems enres_inline
    
(* Optimizing unfolding: each step is optional. In order: unfold
   enres_inline, unfold enres_monad_laws, unfold enres_breakdown, try refl,
   unfold enres_unfolds / enres_lift_def / nres_monad_laws, try refl. *)
method opt_enres_unfold = ((unfold enres_inline)?; (unfold enres_monad_laws)?; (unfold enres_breakdown)?; (rule refl)?; (unfold enres_unfolds enres_lift_def nres_monad_laws)?; (rule refl)?)

  
subsection ‹More Combinators›  
  subsubsection ‹CHECK-Monadic›
    
  (* Monadic check: runs the boolean-valued computation c and then performs
     CHECK on its result, i.e. raises e if the result is False. *)
  definition [enres_unfolds]: "CHECK_monadic c e  doE { b  c; CHECK b e }"
  
  (* VCG characterization: exceptions of c must satisfy E; a True result
     must yield P (), a False result E e. *)
  lemma CHECK_monadic_rule_iff:
    "(CHECK_monadic c e  ESPEC E P)  (c  ESPEC E (λr. (r  P ())  (¬r  E e)))"
    by (auto simp: pw_ele_iff CHECK_monadic_def refine_pw_simps)
  
  (* Pointwise characterization of CHECK_monadic. *)
  lemma CHECK_monadic_pw[refine_pw_simps]:
    "nofail (CHECK_monadic c e)  nofail c"
    "inres (CHECK_monadic c e) (Inl ee)  (inres c (Inl ee)  inres c (Inr False)  ee=e)"
    "inres (CHECK_monadic c e) (Inr x)  (inres c (Inr True))"
    unfolding CHECK_monadic_def
    by (auto simp: refine_pw_simps)
  
  lemma CHECK_monadic_rule[refine_vcg]:
    assumes "c  ESPEC E (λr. (r  P ())  (¬r  E e))"
    shows "CHECK_monadic c e  ESPEC E P"
    using assms by (simp add: CHECK_monadic_rule_iff)  
  
  (* Refinement with conditions related by bool_rel and ER-related
     exceptions; normal results are related by unit_rel. *)
  lemma CHECK_monadic_refine[refine]:
    assumes "ci  E ER bool_rel c"
    assumes "(ei,e)ER"  
    shows "CHECK_monadic ci ei E ER unit_rel (CHECK_monadic c e)"  
    using assms  
    by (auto simp: pw_ele_iff refine_pw_simps)
  
  (* A monadic check refines a pure CHECK c e, provided ci meets the given
     ESPEC, which relates its exceptions to e and its result to c. *)
  lemma CHECK_monadic_CHECK_refine[refine]:
    assumes "ci  ESPEC (λe'. (e',e)ER  ¬c) (λr. r  c)"
    assumes "(ei,e)ER"
    shows "CHECK_monadic ci ei E ER unit_rel (CHECK c e)"
    using assms  
    by (auto simp: pw_ele_iff refine_pw_simps)
  
  (* Breakdown of CHECK_monadic applied to a lifted computation. *)
  lemma CHECK_monadic_endb[enres_breakdown]: "CHECK_monadic (enres_lift c) e = 
    do {b  c; CHECK b e}"
    by (auto simp: enres_unfolds enres_lift_def)
  
  
  
  
  

end