(* Theory NREST *)

theory NREST
  imports "HOL-Library.Extended_Nat" "Refine_Monadic.RefineG_Domain" Refine_Monadic.Refine_Misc  
  "HOL-Library.Monad_Syntax" NREST_Auxiliaries
begin

section "NREST"

(* Results of a timed nondeterministic program: either failure (FAILi) or a
   map from result values to an optional extended-natural time (REST). *)
datatype 'a nrest = FAILi | REST "'a ⇒ enat option"

(* Complete-lattice instance for nrest.
   FAILi is the greatest element (top); two REST values are compared through
   the ordering of their underlying functions; REST Map.empty is bottom. *)
instantiation nrest :: (type) complete_lattice
begin

(* Everything is below FAILi; REST values compare via their functions. *)
fun less_eq_nrest where
  "_ ≤ FAILi ⟷ True" |
  "(REST a) ≤ (REST b) ⟷ a ≤ b" |
  "FAILi ≤ (REST _) ⟷ False"

fun less_nrest where
  "FAILi < _ ⟷ False" |
  "(REST _) < FAILi ⟷ True" |
  "(REST a) < (REST b) ⟷ a < b"

(* FAILi absorbs under sup; otherwise take the pointwise max. *)
fun sup_nrest where
  "sup _ FAILi = FAILi" |
  "sup FAILi _ = FAILi" |
  "sup (REST a) (REST b) = REST (λx. max (a x) (b x))"

(* FAILi is neutral under inf; otherwise take the pointwise min. *)
fun inf_nrest where 
  "inf x FAILi = x" |
  "inf FAILi x = x" |
  "inf (REST a) (REST b) = REST (λx. min (a x) (b x))"

(* Sanity checks: None is below every Some in the option ordering. *)
lemma "min (None) (Some (1::enat)) = None" by simp
lemma "max (None) (Some (1::enat)) = Some 1" by eval

definition "Sup X ≡ if FAILi∈X then FAILi else REST (Sup {f . REST f ∈ X})"
definition "Inf X ≡ if ∃f. REST f∈X then REST (Inf {f . REST f ∈ X}) else FAILi"

definition "bot ≡ REST (Map.empty)"
definition "top ≡ FAILi"

instance
proof(intro_classes, goal_cases)
  case (1 x y)
  then show ?case
    by (cases x; cases y) auto
next
  case (2 x)
  then show ?case
    by (cases x) auto
next
  case (3 x y z)
  then show ?case 
    by (cases x; cases y; cases z) auto
next
  case (4 x y)
  then show ?case 
    by (cases x; cases y) auto
next
  case (5 x y)
  then show ?case 
    by (cases x; cases y) (auto intro: le_funI)
next
  case (6 x y)
  then show ?case 
    by (cases x; cases y) (auto intro: le_funI)
next
  case (7 x y z)
  then show ?case
    by (cases x; cases y; cases z) (auto intro!: le_funI dest: le_funD)
next
  case (8 x y)
  then show ?case 
    by (cases x; cases y) (auto intro: le_funI)
next
  case (9 y x)
  then show ?case 
    by (cases x; cases y) (auto intro: le_funI)
next
  case (10 y x z)
  then show ?case
    by (cases x; cases y; cases z) (auto intro!: le_funI dest: le_funD)
next
  case (11 x A)
  then show ?case
    by (cases x) (auto simp add: Inf_nrest_def Inf_lower)
next
  case (12 A z)
  then show ?case
    by (cases z) (fastforce simp add: Inf_nrest_def le_Inf_iff)+
next
  case (13 x A)
  then show ?case
    by (cases x) (auto simp add: Sup_nrest_def Sup_upper)
next
  case (14 A z)
  then show ?case 
    by (cases z) (fastforce simp add: Sup_nrest_def Sup_le_iff)+
next
  case 15
  then show ?case
    by (auto simp add: Inf_nrest_def top_nrest_def)
next
  case 16
  then show ?case
    by (auto simp add: Sup_nrest_def bot_nrest_def bot_option_def)
qed
end


(* RETURNT x yields exactly the result x at cost 0. *)
definition "RETURNT x ≡ REST (λe. if e=x then Some 0 else None)"
(* User-level names for the lattice top/bottom and for the REST constructor. *)
abbreviation "FAILT ≡ top::'a nrest"
abbreviation "SUCCEEDT ≡ bot::'a nrest"
abbreviation SPECT where "SPECT ≡ REST"

(* RETURNT expressed with map-update notation. *)
lemma RETURNT_alt: "RETURNT x = REST [x↦0]"
  unfolding RETURNT_def by auto

(* Distinctness of FAILT from REST, SUCCEEDT and RETURNT, and of SUCCEEDT
   from RETURNT, stated in both orientations where listed. *)
lemma nrest_inequalities[simp]: 
  "FAILT ≠ REST X"
  "FAILT ≠ SUCCEEDT" 
  "FAILT ≠ RETURNT x"
  "SUCCEEDT ≠ FAILT"
  "SUCCEEDT ≠ RETURNT x"
  "REST X ≠ FAILT"
  "RETURNT x ≠ FAILT"
  "RETURNT x ≠ SUCCEEDT"
  unfolding top_nrest_def bot_nrest_def RETURNT_def  
  by simp_all (metis option.distinct(1))+

(* Equalities between REST, SUCCEEDT and RETURNT reduce to equalities on the
   underlying maps / values. *)
lemma nrest_more_simps[simp]:
  "SUCCEEDT = REST X ⟷ X=Map.empty" 
  "REST X = SUCCEEDT ⟷ X=Map.empty" 
  "REST X = RETURNT x ⟷ X=[x↦0]" 
  "REST X = REST Y ⟷ X=Y"
  "RETURNT x = REST X ⟷ X=[x↦0]"
  "RETURNT x = RETURNT y ⟷ x=y" 
  unfolding top_nrest_def bot_nrest_def RETURNT_def
  by (auto simp add: fun_eq_iff)

(* Fold the internal representations back into the user-level names:
   REST Map.empty is SUCCEEDT (bot) and FAILi is FAILT (top). *)
lemma nres_simp_internals: 
  "REST Map.empty = SUCCEEDT"
   "FAILi = FAILT" 
  unfolding top_nrest_def bot_nrest_def by simp_all

(* Ordering facts: FAILT is never below a REST; REST values compare via
   their maps. *)
lemma nres_order_simps[simp]:
  "¬ FAILT ≤ REST M" 
  "REST M ≤ REST M' ⟷ (M≤M')"
  by (auto simp: nres_simp_internals[symmetric])   

(* Only FAILT itself lies above FAILT (instance of top_unique). *)
lemma nres_top_unique[simp]:" FAILT ≤ S' ⟷ S' = FAILT"
  by (rule top_unique) 

(* A case distinction on FAILT selects the FAILi branch. *)
lemma FAILT_cases[simp]: "(case FAILT of FAILi ⇒ P | REST x ⇒ Q x) = P"
  by (auto simp: nres_simp_internals[symmetric])  

(* A supremum fails exactly when the set contains FAILT. *)
lemma nrest_Sup_FAILT: 
  "Sup X = FAILT ⟷ FAILT ∈ X"
  "FAILT = Sup X ⟷ FAILT ∈ X"
  by (auto simp: nres_simp_internals Sup_nrest_def)

(* If a supremum is a SPECT, its map is the pointwise supremum of the member maps. *)
lemma nrest_Sup_SPECT_D: "Sup X = SPECT m ⟹ m x = Sup {f x | f. REST f ∈ X}"
  unfolding Sup_nrest_def by(auto split: if_splits intro!: arg_cong[where f=Sup])

(* From here on, rewrite the internal constructor FAILi to FAILT automatically. *)
declare nres_simp_internals(2)[simp]

(* A value that is no REST must be FAILT. *)
lemma nrest_noREST_FAILT[simp]: "(∀x2. m ≠ REST x2) ⟷ m=FAILT"
  by (cases m) auto

(* Elimination rule: a non-failing value is some REST X. *)
lemma no_FAILTE:  
  assumes "g xa ≠ FAILT" 
  obtains X where "g xa = REST X" 
  using assms by (cases "g xa") auto

(* Refinement lifts through a case distinction on pairs. *)
lemma case_prod_refine:
  fixes P Q :: "'a ⇒ 'b ⇒ 'c nrest"
  assumes "⋀a b. P a b ≤ Q a b"
  shows "(case x of (a,b) ⇒ P a b) ≤ (case x of (a,b) ⇒ Q a b)"
  using assms by (simp add: split_def)

(* Refinement lifts through a case distinction on options. *)
lemma case_option_refine: (* obsolete ? *)
  fixes P Q :: "'a ⇒ 'b ⇒ 'c nrest"
  assumes
    "PN ≤ QN"
    "⋀a. PS a ≤ QS a"
  shows "(case x of None ⇒ PN | Some a ⇒ PS a ) ≤ (case x of None ⇒ QN | Some a ⇒ QS a )"
  using assms by (auto split: option.splits)

(* The empty specification is below every nrest value. *)
lemma SPECT_Map_empty[simp]: "SPECT Map.empty ≤ a" 
  by (cases a) (auto simp: le_fun_def)

(* A set containing FAILT has supremum FAILT. *)
lemma FAILT_SUP: "(FAILT ∈ X) ⟹ Sup X = FAILT "
  by (simp add: nrest_Sup_FAILT)

section "pointwise reasoning"

(* Collection of simp rules used for pointwise (nofailT / inresT) reasoning. *)
named_theorems refine_pw_simps 
(* ML-level Named_Thms structure under the same name; its body must be
   enclosed in a cartouche to be valid ML command syntax. *)
ML ‹structure refine_pw_simps = Named_Thms
    ( val name = @{binding refine_pw_simps}
      val description = "Refinement Framework: " ^
        "Simplifier rules for pointwise reasoning" )›
  
(* nofailT S holds iff S is not the failing computation FAILT. *)
definition nofailT :: "'a nrest ⇒ bool" where "nofailT S ≡ S≠FAILT"

(* Refinement that only has to hold when the left-hand side does not fail. *)
definition le_or_fail :: "'a nrest ⇒ 'a nrest ⇒ bool" (infix "≤⇩n" 50) where
  "m ≤⇩n m' ≡ nofailT m ⟶ m ≤ m'"

(* nofailT evaluated on the basic constructors. *)
lemma nofailT_simps[simp]:
  "nofailT FAILT ⟷ False"
  "nofailT (REST X) ⟷ True"
  "nofailT (RETURNT x) ⟷ True"
  "nofailT SUCCEEDT ⟷ True"
  unfolding nofailT_def
  by (simp_all add: RETURNT_def)

(* inresT S x t: S fails, or S maps x to some time t' with enat t ≤ t'. *)
definition inresT :: "'a nrest ⇒ 'a ⇒ nat ⇒ bool" where 
  "inresT S x t ≡ (case S of FAILi ⇒ True | REST X ⇒ (∃t'. X x = Some t' ∧  enat t≤t'))"

(* inresT as a refinement statement about a singleton specification. *)
lemma inresT_alt: "inresT S x t ⟷ REST ([x↦enat t]) ≤ S"
  unfolding inresT_def by (cases S) (auto dest!: le_funD[where x=x] simp: le_funI less_eq_option_def split: option.splits)


(* inresT is downward closed in the time argument. *)
lemma inresT_mono: "inresT S x t ⟹ t' ≤ t ⟹ inresT S x t'"
  unfolding inresT_def by (cases S) (auto simp add: order_subst2)

(* RETURNT x has x as its only result, and only at time 0. *)
lemma inresT_RETURNT[simp]: "inresT (RETURNT x) y t ⟷ t = 0 ∧ y = x"
  by(auto simp: inresT_def RETURNT_def enat_0_iff split: nrest.splits)

(* On FAILT every result/time pair is admitted. *)
lemma inresT_FAILT[simp]: "inresT FAILT r t"
  by(simp add: inresT_def)

(* The same fact phrased via nofailT, for use as a pointwise simp rule. *)
lemma fail_inresT[refine_pw_simps]: "¬ nofailT M ⟹ inresT M x t"
  unfolding nofailT_def by simp

(* Pointwise characterisation of inresT on a supremum: some member of the set
   admits the result at a time t' that is at least t. *)
lemma pw_inresT_Sup[refine_pw_simps]: "inresT (Sup X) r t ⟷ (∃M∈X. ∃t'≥t.  inresT M r t')"
proof
  assume a: "inresT (Sup X) r t"
  show "(∃M∈X. ∃t'≥t.  inresT M r t')"
  proof(cases "Sup X")
    case FAILi
    then show ?thesis 
      by (force simp: nrest_Sup_FAILT)
  next
    case (REST Y)
    with a obtain t' where t': "Y r = Some t'" "enat t≤t'"
      by (auto simp add: inresT_def)
    from REST have Y: "Y = (⨆ {f. SPECT f ∈ X})"
      by (auto simp add: Sup_nrest_def split: if_splits)
    with t' REST have "(⨆ {f r | f . SPECT f ∈ X}) = Some t'"
      using nrest_Sup_SPECT_D by fastforce

    with t' Y obtain f t'' where f_t'': "SPECT f ∈ X" "f r = Some t''" "t' = Sup {t' | f t'. SPECT f∈X ∧ f r = Some t'}"
      by (auto simp add: SUP_eq_Some_iff Sup_apply[where A=" {f. SPECT f ∈ X}"]) 
    from f_t''(3) t'(2) have "enat t ≤ Sup {t' | f t'. SPECT f∈X ∧ f r = Some t'}" 
      by blast
    with f_t'' obtain f' t''' where "SPECT f' ∈ X" "f' r = Some t'''" "enat t ≤ t'''"
      by (smt (verit) Sup_enat_less empty_iff mem_Collect_eq option.sel)
    with REST show ?thesis 
      by (intro bexI[of _ "SPECT f'"] exI[of _ "t"]) (auto simp add: inresT_def)
  qed
next
  assume a: "(∃M∈X. ∃t'≥t.  inresT M r t')"
  from this obtain M t' where M_t': "M∈X" "t'≥t" "inresT M r t'"
    by blast

  show "inresT (Sup X) r t"
  proof (cases "Sup X")
    case FAILi
    then show ?thesis 
      by (auto simp: nrest_Sup_FAILT top_Sup)
  next
    case (REST Y)
    with M_t' have "Y = ⨆ {f. SPECT f ∈ X}" 
      by (auto simp add: Sup_nrest_def split: if_splits)
    from REST have nf: "FAILT ∉ X"
      using FAILT_SUP by fastforce
    with M_t' REST obtain f t'' where "SPECT f ∈ X" "f r = Some t''" "enat t' ≤ t''" 
      by (auto simp add: inresT_def split: nrest.splits)
    with REST  M_t'(2) obtain a where "Y r = Some a" "enat t ≤ a"
      by (metis Sup_upper enat_ord_simps(1) le_fun_def le_some_optE nres_order_simps(2) order_trans)
    then show ?thesis
      by (auto simp: REST inresT_def)
  qed
qed
(* inresT on a REST: the map yields some time t' ≥ t for x. *)
lemma inresT_REST[simp]:
  "inresT (REST X) x t ⟷ (∃t'≥t. X x = Some t')" 
  unfolding inresT_def 
  by auto

(* A supremum does not fail iff no member fails. *)
lemma pw_Sup_nofail[refine_pw_simps]: "nofailT (Sup X) ⟷ (∀x∈X. nofailT x)"
  by (cases "Sup X") (auto simp add: Sup_nrest_def nofailT_def split: if_splits)

(* inresT as a function: constantly True on FAILT, constantly False on SUCCEEDT. *)
lemma inres_simps[simp]:
  "inresT FAILT = (λ_ _. True)" 
  "inresT SUCCEEDT = (λ_ _ . False)"
  unfolding inresT_def [abs_def]
  by (auto split: nrest.splits simp add: RETURNT_def) 

(* Pointwise characterisation of refinement: if the right-hand side does not
   fail, neither does the left-hand side, and every (result, time) pair
   admitted on the left is admitted on the right. *)
lemma pw_le_iff: 
  "S ≤ S' ⟷ (nofailT S'⟶ (nofailT S ∧ (∀x t. inresT S x t ⟶ inresT S' x t)))"
proof (cases S; cases S')
  fix R R'
  assume a[simp]: "S=SPECT R" "S'=SPECT R'"
  show "S ≤ S' ⟷ (nofailT S'⟶ (nofailT S ∧ (∀x t. inresT S x t ⟶ inresT S' x t)))" (is "?l ⟷ ?r")
  proof (rule iffI; safe)
    assume b: "nofailT S" "∀x t. inresT S x t ⟶ inresT S' x t" 
    hence b_2': "inresT S x t ⟹ inresT S' x t" for x t
      by blast
    from b(1) have "nofailT S'"
      by auto
    with a b_2' have nf: "R x ≠ None ⟹ R' x ≠ None" for x
      by simp (metis enat_ile nle_le not_Some_eq)
    have "R x ≤ R' x" for x
    proof (cases "R x"; cases "R' x")
      fix r r'
      assume c[simp]: "R x = Some r" "R' x = Some r'"
      show ?thesis
      proof(rule ccontr)
        assume "¬ R x ≤ R' x"
        from this obtain vt where "inresT S x vt" "¬inresT S' x vt"
          by simp (metis Suc_ile_eq enat.exhaust linorder_le_less_linear order_less_irrefl)
        with b_2' show False
          by blast
      qed
    qed (use nf in ‹auto simp add: inresT_def nofailT_def›) 
    then show "S ≤ S'" 
      by (auto intro!: le_funI)
  qed(fastforce simp add: less_eq_option_def le_fun_def split: option.splits)+
qed auto

(* RETURNT values are only comparable when they return the same value. *)
lemma RETURN_le_RETURN_iff[simp]: "RETURNT x ≤ RETURNT y ⟷ x=y"
  by (auto simp add: pw_le_iff)

(* Unnamed check: inresT is monotone with respect to refinement. *)
lemma "S≤S' ⟹ inresT S x t ⟹ inresT S' x t"
  unfolding inresT_alt by auto

(* Pointwise characterisation of equality. *)
lemma pw_eq_iff:
  "S=S' ⟷ (nofailT S = nofailT S' ∧ (∀x t. inresT S x t ⟷ inresT S' x t))"
  by (auto intro: antisym simp add: pw_le_iff)

(* Pointwise characterisation of the flat ordering flat_ge. *)
lemma pw_flat_ge_iff: "flat_ge S S' ⟷ 
  (nofailT S) ⟶ nofailT S' ∧ (∀x t. inresT S x t ⟷ inresT S' x t)"
  by (metis flat_ord_def nofailT_def pw_eq_iff)

(* Introduction rule for equality from pointwise agreement. *)
lemma pw_eqI: 
  assumes "nofailT S = nofailT S'" 
  assumes "⋀x t. inresT S x t ⟷ inresT S' x t" 
  shows "S=S'"
  using assms by (simp add: pw_eq_iff)
 
(* consume M t adds the constant time t to every result of M; failure stays failure. *)
definition "consume M t ≡ case M of 
          FAILi ⇒ FAILT |
          REST X ⇒ REST (map_option ((+) t) o (X))"

(* SPEC P t: every value v satisfying P is a result, with time t v. *)
definition "SPEC P t = REST (λv. if P v then Some (t v) else None)"

(* consume is monotone in both the computation and the added time. *)
lemma consume_mono: 
  assumes "t≤t'" "M ≤ M'" 
  shows "consume M t ≤ consume M' t'"
proof(cases M; cases M')
  fix m m' 
  assume a[simp]: "M=REST m" "M'=REST m'"
  from assms(2) have p: "m x ≤ m' x" for x
    by (auto dest: le_funD)
  from assms(1) have "map_option ((+) t) (m x) ≤ map_option ((+) t') (m' x)" for x
    using p[of x] by (cases "m' x"; cases "m x") (auto intro: add_mono)
  then show ?thesis
    by (auto intro: le_funI simp add: consume_def)
qed (use assms(2) in ‹auto simp add: consume_def›)

(* SPEC never fails. *)
lemma nofailT_SPEC[refine_pw_simps]: "nofailT (SPEC a b)"
  unfolding SPEC_def by auto

(* Results of SPEC: values satisfying the predicate, within the time bound. *)
lemma inresT_SPEC[refine_pw_simps]: "inresT (SPEC a b) = (λx t. a x ∧  b x ≥ t)"
  unfolding SPEC_def inresT_REST by (auto split: if_splits)

section ‹Monad Operators›

(* Monadic bind. A failing M fails. Otherwise, for every result x of M with
   time t1, run f x and add t1 to each of its result times (a failing f x
   fails); the outcome is the supremum over all such x. *)
definition bindT :: "'b nrest ⇒ ('b ⇒ 'a nrest) ⇒ 'a nrest" where
  "bindT M f ≡ case M of 
  FAILi ⇒ FAILT |
  REST X ⇒ Sup { (case f x of FAILi ⇒ FAILT 
                | REST m2 ⇒ REST (map_option ((+) t1) o (m2) ))
                    |x t1. X x = Some t1}"

(* bindT expressed through consume. *)
lemma bindT_alt: "bindT M f = (case M of 
  FAILi ⇒ FAILT | 
  REST X ⇒ Sup { consume (f x) t1 |x t1. X x = Some t1})"
  unfolding bindT_def consume_def by simp

(* Unnamed check: bindT on a REST as an indexed supremum over the map's domain. *)
lemma "bindT (REST X) f = 
  (⨆x ∈ dom X. consume (f x) (the (X x)))"
proof -
  have *: "⋀f X. { f x t |x t. X x = Some t}
      = (λx. f x (the (X x))) ` (dom X)"
    by force
  show ?thesis by (auto simp: bindT_alt *)
qed

(* Register bindT as an instance of the generic monad bind of Monad_Syntax,
   which makes do-notation available for nrest. *)
adhoc_overloading
  Monad_Syntax.bind NREST.bindT

(* Binding after a failure fails. *)
lemma bindT_FAIL[simp]: "bindT FAILT g = FAILT"
  by (auto simp: bindT_def)       

(* Binding after the empty computation is empty. *)
lemma "bindT SUCCEEDT f = SUCCEEDT"
  unfolding bindT_def by (auto split: nrest.split simp add: bot_nrest_def) 

(* A result with infinite time is admitted at every finite time. *)
lemma "m r = Some ∞ ⟹ inresT (REST m) r t"
  by auto 

(* Auxiliary pointwise rule for bindT: when m does not fail, r is a result at
   time t iff some intermediate result r' with times t', t'' accounts for it,
   where t is bounded by t' + t''. *)
lemma pw_inresT_bindT_aux: "inresT (bindT m f) r t ⟷
     (nofailT m ⟶ (∃r' t' t''. inresT m r' t' ∧ inresT (f r') r t'' ∧ t ≤ t' + t''))" (is "?l ⟷ ?r")
proof(intro iffI impI)
  assume ?l and "nofailT m"
  show "∃r' t' t''. inresT m r' t' ∧ inresT (f r') r t'' ∧ t ≤ t' + t''"
  proof(cases m)
    case [simp]: (REST X)
    with ‹?l› obtain M x t1 t' where 
      parts: "X x = Some t1"
      "(∀x2. f x = SPECT x2 ⟶
         M = SPECT (map_option ((+) t1) ∘ x2))"
       "t'≥t" "inresT M r t'"
      by (auto simp add: pw_inresT_Sup bindT_def simp flip: nrest_noREST_FAILT split: nrest.splits)

    show ?thesis
    proof(cases "f x")
      case FAILi
      show ?thesis
        by (rule exI[where x=x], cases t1) (auto intro: le_add2 simp add: FAILi parts(1))
    next
      case [simp]: (REST re)
      with parts(2) have M: "M = SPECT (map_option ((+) t1) ∘ re)"
        by blast
      with parts(4) obtain rer where rer[simp]: "re r = Some rer"
        by auto 
      show ?thesis
      proof(rule exI[where x=x])
        from M parts(1,3,4) show "∃t' t''. inresT m x t' ∧ inresT (f x) r t'' ∧ t ≤ t' + t''"
          by (cases "t1";cases rer) (fastforce intro: le_add2)+ 
      qed
    qed
  qed (use ‹nofailT m› in auto)
next
  assume "?r"
  show "?l"
  proof(cases m)
    case [simp]: (REST X)
    then show ?thesis
    proof(cases "nofailT m")
      case True
      with ‹?r› obtain t' t'' t''' r' where
        parts: "enat t' ≤ t'''"
        "X r' = Some t'''"
        "inresT (f r') r t''"
        "t ≤ t' + t''"
        by (auto simp: bindT_def split: nrest.splits)
      then show ?thesis
      proof(cases "f r'")
        case FAILi
        with parts(2) show ?thesis 
          by (fastforce simp add: pw_inresT_Sup bindT_def split: nrest.splits)
      next
        case [simp]: (REST x)
        from parts(3) obtain ta where ta: "x r = Some ta" "enat t''≤ta"
          by auto
        with parts True obtain tf where tf: "tf≥t" "enat tf ≤ t''' + ta"
          using add_mono by fastforce
        with ta parts True show ?thesis
          by (force simp add: pw_inresT_Sup bindT_def split: nrest.splits
              intro!: exI[where x="REST (map_option ((+) t''') ∘ x)"]) (* Witness here *)
      qed
    qed auto
  qed auto
qed

(* Pointwise rule for bindT with an exact split of the time t = t' + t''. *)
lemma pw_inresT_bindT[refine_pw_simps]: "inresT (bindT m f) r t ⟷
     (nofailT m ⟶ (∃r' t' t''. inresT m r' t' ∧ inresT (f r') r t'' ∧ t = t' + t''))"
  by (simp add: pw_inresT_bindT_aux) (metis add_le_imp_le_left inresT_mono le_iff_add nat_le_linear)

(* bindT does not fail iff M does not fail and f does not fail on any result of M. *)
lemma pw_bindT_nofailT[refine_pw_simps]: "nofailT (bindT M f) ⟷ (nofailT M ∧ (∀x t. inresT M x t ⟶ nofailT (f x)))" (is "?l ⟷ ?r")
proof
  assume ?l
  then show ?r
    by (force elim: no_FAILTE simp: bindT_def refine_pw_simps split: nrest.splits )  
next
  assume ?r
  hence a: "nofailT M" "⋀x t . inresT M x t ⟹ nofailT (f x)"
    by auto
  show ?l
  proof(cases M)
    case FAILi
    then show ?thesis
      using ‹?r› by (simp add: nofailT_def)
  next
    case [simp]: (REST m)
    with a have "f x ≠ FAILi" if "m x ≠ None" for x
      using that i0_lb by (auto simp add: nofailT_def zero_enat_def)
    then show ?thesis
      by (force simp add: bindT_def nofailT_def nrest_Sup_FAILT split: nrest.splits)
  qed
qed

(* Adding 0 on enat is the identity function. *)
lemma nat_plus_0_is_id[simp]: "((+) (0::enat)) = id" by auto

(* Let the simplifier rewrite with map_option.id. *)
declare map_option.id[simp] 

section ‹Monad Rules›

(* Monad law: RETURNT is a left identity of bindT. *)
lemma nres_bind_left_identity[simp]: "bindT (RETURNT x) f = f x"
  unfolding bindT_def RETURNT_def 
  by(auto split: nrest.split ) 

(* Monad law: RETURNT is a right identity of bindT. *)
lemma nres_bind_right_identity[simp]: "bindT M RETURNT = M" 
  by(auto intro!: pw_eqI simp: refine_pw_simps) 

(* Monad law: bindT is associative. *)
lemma nres_bind_assoc[simp]: "bindT (bindT M (λx. f x)) g = bindT M (%x. bindT (f x) g)"
proof (rule pw_eqI)
  fix x t
  show "inresT (M ⤜ f ⤜ g) x t = inresT (M ⤜ (λx. f x ⤜ g)) x t"
    by (simp add: refine_pw_simps) (use inresT_mono in fastforce)
qed (fastforce simp add: refine_pw_simps)

section ‹Monotonicity lemmas›

(* bindT is monotone; the continuation only has to be refined on values that
   are results of m, and only when m' does not fail. *)
lemma bindT_mono: 
  "m ≤ m' ⟹ (⋀x. (∃t. inresT m x t) ⟹ nofailT m' ⟹  f x ≤ f' x)
 ⟹ bindT m f ≤ bindT  m' f'"
  by(fastforce simp: pw_le_iff refine_pw_simps) 

(* Unconditional variant, registered for the monotonicity prover. *)
lemma bindT_mono'[refine_mono]: 
  "m ≤ m' ⟹ (⋀x.   f x ≤ f' x)
 ⟹ bindT m f ≤ bindT  m' f'"
  by (erule bindT_mono)

(* bindT is monotone with respect to the flat ordering. *)
lemma bindT_flat_mono[refine_mono]:  
  "⟦ flat_ge M M'; ⋀x. flat_ge (f x) (f' x) ⟧ ⟹ flat_ge (bindT M f) (bindT M' f')" 
  by (fastforce simp: refine_pw_simps pw_flat_ge_iff)

subsection ‹ Derived Program Constructs ›

subsubsection ‹Assertion› 

(* Generic assertion: behave like `ret ()` if Φ holds, otherwise top (failure). *)
definition "iASSERT ret Φ ≡ if Φ then ret () else top"

(* Assertion in the nrest monad, built on RETURNT. *)
definition ASSERT where "ASSERT ≡ iASSERT RETURNT"

(* Basic evaluation of ASSERT. *)
lemma ASSERT_True[simp]: "ASSERT True = RETURNT ()" 
  by (auto simp: ASSERT_def iASSERT_def)
lemma ASSERT_False[simp]: "ASSERT False = FAILT" 
  by (auto simp: ASSERT_def iASSERT_def) 

(* An assertion followed by m is m if the assertion holds, and fails otherwise. *)
lemma bind_ASSERT_eq_if: "do { ASSERT Φ; m } = (if Φ then m else FAILT)"
  unfolding ASSERT_def iASSERT_def by simp

(* Pointwise rules for ASSERT. *)
lemma pw_ASSERT[refine_pw_simps]:
  "nofailT (ASSERT Φ) ⟷ Φ"
  "inresT (ASSERT Φ) x 0"
  by (cases Φ, simp_all)+

(* Refinement rules for assertions. *)
lemma ASSERT_refine: "(Q ⟹ P) ⟹ ASSERT P ≤ ASSERT Q"
  by(auto simp: pw_le_iff refine_pw_simps)

lemma ASSERT_leI: "Φ ⟹ (Φ ⟹ M ≤ M') ⟹ ASSERT Φ ⤜ (λ_. M) ≤ M'"
  by(auto simp: pw_le_iff refine_pw_simps)

lemma le_ASSERTI: "(Φ ⟹ M ≤ M') ⟹ M ≤ ASSERT Φ ⤜ (λ_. M')"
  by(auto simp: pw_le_iff refine_pw_simps)

lemma inresT_ASSERT: "inresT (ASSERT Q ⤜ (λ_. M)) x ta = (Q ⟶ inresT M x ta)"
  unfolding ASSERT_def iASSERT_def by auto

lemma nofailT_ASSERT_bind: "nofailT (ASSERT P ⤜ (λ_. M)) ⟷ (P ∧ nofailT M)"
  by(auto simp: pw_bindT_nofailT pw_ASSERT)

subsection ‹SELECT›

(* emb' Q T embeds a predicate Q and a time function T into a result map:
   values satisfying Q get time T x, all others are absent. *)
definition emb' where "⋀Q T. emb' Q (T::'a ⇒ enat) = (λx. if Q x then Some (T x) else None)"

(* emb is emb' with a constant time. *)
abbreviation "emb Q t ≡ emb' Q (λ_. t)" 

(* Characterisations of emb' results. *)
lemma emb_eq_Some_conv: "⋀T. emb' Q T x = Some t' ⟷ (t'=T x ∧ Q x)"
  by (auto simp: emb'_def)

lemma emb_le_Some_conv: "⋀T. Some t' ≤ emb' Q T x ⟷ ( t'≤T x ∧ Q x)"
  by (auto simp: emb'_def)

(* SPEC is REST of the corresponding emb'. *)
lemma SPEC_REST_emb'_conv: "SPEC P t = REST (emb' P t)"
  unfolding SPEC_def emb'_def by auto

(* Raising the time bound weakens the specification. *)
lemma SPECT_ub: "T≤T' ⟹ SPECT (emb' M' T) ≤ SPECT (emb' M' T')"
  unfolding emb'_def by (auto simp: pw_le_iff le_funD order_trans refine_pw_simps)

text ‹Select some value with given property, or ‹None› if there is none.›  
(* SELECT P tf: if some value satisfies P, every `Some p` with P p is a result
   at time tf; otherwise the only result is None, also at time tf. *)
definition SELECT :: "('a ⇒ bool) ⇒ enat ⇒ 'a option nrest"
  where "SELECT P tf ≡ if ∃x. P x then REST (emb (λr. case r of Some p ⇒ P p | None ⇒ False) tf)
               else REST [None ↦ tf]"
                    
(* Pointwise rules for SELECT, split by the shape of the result. *)
lemma inresT_SELECT_Some: "inresT (SELECT Q tt) (Some x) t' ⟷ (Q x  ∧ (t' ≤ tt))"
  by(auto simp: inresT_def SELECT_def emb'_def) 

lemma inresT_SELECT_None: "inresT (SELECT Q tt) None t' ⟷ (¬(∃x. Q x) ∧ (t' ≤ tt))"
  by(auto simp: inresT_def SELECT_def emb'_def) 

lemma inresT_SELECT[refine_pw_simps]:
 "inresT (SELECT Q tt) x t' ⟷ ((case x of None ⇒ ¬(∃x. Q x) | Some x ⇒ Q x)  ∧ (t' ≤ tt))"
  by(auto simp: inresT_def SELECT_def emb'_def) 


(* SELECT never fails. *)
lemma nofailT_SELECT[refine_pw_simps]: "nofailT (SELECT Q tt)"
  by(auto simp: nofailT_def SELECT_def)

(* SELECT is monotone in the time bound (and only then). *)
lemma s1: "SELECT P T ≤ (SELECT P T') ⟷ T ≤ T'"
  by (cases "∃x. P x"; cases T; cases T')(auto simp: pw_le_iff refine_pw_simps not_le split: option.splits)
     
(* Refinement between SELECTs with the same time bound, in terms of the predicates. *)
lemma s2: "SELECT P T ≤ (SELECT P' T) ⟷ (
    (Ex P' ⟶ Ex P)  ∧ (∀x. P x ⟶ P' x)) "
  by (cases T) (auto simp: pw_le_iff refine_pw_simps split: option.splits)
 
(* Combined refinement rule for SELECT: change predicate and time bound at once. *)
lemma SELECT_refine:
  assumes "⋀x'. P' x' ⟹ ∃x. P x"
  assumes "⋀x. P x ⟹   P' x"
  assumes "T ≤ T'"
  shows "SELECT P T ≤ (SELECT P' T')"
proof -
  have "SELECT P T ≤ SELECT P T'"
    using s1 assms(3) by auto
  also have "… ≤ SELECT P' T'"
    unfolding s2 using assms(1,2) by auto  
  finally show ?thesis .
qed


section ‹RECT›

(* mono2 B: B is monotone both for the flat ordering and for the lattice ordering. *)
definition "mono2 B ≡ flatf_mono_ge B ∧  mono B"


(* Projections of mono2. *)
lemma trimonoD_flatf_ge: "mono2 B ⟹ flatf_mono_ge B"
  unfolding mono2_def by auto

lemma trimonoD_mono: "mono2 B ⟹ mono B"
  unfolding mono2_def by auto

(* Recursion combinator: the greatest fixed point of B when B is mono2,
   and top otherwise. *)
definition "RECT B x = 
  (if mono2 B then (gfp B x) else (top::'a::complete_lattice))"

(* RECT expressed via the flat greatest fixed point. *)
lemma RECT_flat_gfp_def: "RECT B x = 
  (if mono2 B then (flatf_gfp B x) else (top::'a::complete_lattice))"
  unfolding RECT_def
  by (auto simp: gfp_eq_flatf_gfp[OF trimonoD_flatf_ge trimonoD_mono])

(* Unfolding rule for RECT. *)
lemma RECT_unfold: "mono2 B ⟹ RECT B = B (RECT B)"
  unfolding RECT_def [abs_def]  
  by (auto dest: trimonoD_mono simp: gfp_unfold[ symmetric])


(* While loop: as long as b holds, run c and continue; otherwise return the state. *)
definition whileT :: "('a ⇒ bool) ⇒ ('a ⇒ 'a nrest) ⇒ 'a ⇒ 'a nrest" where
  "whileT b c = RECT (λwhileT s. (if b s then bindT (c s) whileT else RETURNT s))"

(* whileT carrying an invariant I and a nat-valued function E as annotations;
   both are ignored by the definition. *)
definition  whileIET :: "('a ⇒ bool) ⇒ ('a ⇒ nat) ⇒ ('a ⇒ bool) ⇒ ('a ⇒ 'a nrest) ⇒ 'a ⇒ 'a nrest" where
  "⋀E c. whileIET I E b c = whileT b c"

(* whileT carrying annotations I and R; both are ignored by the definition. *)
definition whileTI :: "('a ⇒ enat option) ⇒ ( ('a×'a) set) ⇒ ('a ⇒ bool) ⇒ ('a ⇒ 'a nrest) ⇒ 'a ⇒ 'a nrest" where
  "whileTI I R b c s = whileT b c s"

(* Introduction rule for mono2. *)
lemma trimonoI[refine_mono]: 
  "⟦flatf_mono_ge B; mono B⟧ ⟹ mono2 B"
  unfolding mono2_def by auto

(* Naming *)
(* Pointwise criterion for monotonicity of a functional. *)
lemma mono_fun_transform[refine_mono]: "(⋀f g x. (⋀x. f x ≤ g x) ⟹ B f x ≤ B g x) ⟹ mono B"
  by (intro monoI le_funI) (simp add: le_funD)

(* Unfolding rule for whileT. *)
lemma whileT_unfold: "whileT b c = (λs. (if b s then bindT (c s) (whileT b c) else RETURNT s))"
  unfolding whileT_def by (rule RECT_unfold) (refine_mono)   


(* RECT is monotone in its functional, provided the smaller one is mono2. *)
lemma RECT_mono[refine_mono]:
  assumes [simp]: "mono2 B'"
  assumes LE: "⋀F x. (B' F x) ≤ (B F x) "
  shows " (RECT B' x) ≤ (RECT B x)"
  unfolding RECT_def by simp (meson LE gfp_mono le_fun_def)  

(* whileT is monotone in the loop body, on states where the guard holds. *)
lemma whileT_mono: 
  assumes "⋀x. b x ⟹ c x ≤ c' x"
  shows " (whileT b c x) ≤ (whileT b c' x)"
unfolding whileT_def proof (rule RECT_mono)
  show "(if b x then c x ⤜ F else RETURNT x) ≤ (if b x then c' x ⤜ F else RETURNT x)" for F x
    using assms by (auto intro: bindT_mono)
qed refine_mono

(* Well-founded induction for a function f that satisfies the fixed-point
   equation f x = B f x. *)
lemma wf_fp_induct:
  assumes fp: "⋀x. f x = B (f) x"
  assumes wf: "wf R"
  assumes "⋀x D. ⟦⋀y. (y,x)∈R ⟹ P y (D y)⟧ ⟹ P x (B D x)"
  shows "P x (f x)"
  using wf
  apply induction
  apply (subst fp)
  apply fact  
  done

(* The same principle instantiated to RECT, using RECT_unfold. *)
lemma RECT_wf_induct_aux:
  assumes wf: "wf R"
  assumes mono: "mono2 B"
  assumes "(⋀x D. (⋀y. (y, x) ∈ R ⟹ P y (D y)) ⟹ P x (B D x))"
  shows "P x (RECT B x)"
  using wf_fp_induct[where f="RECT B" and B=B] RECT_unfold assms 
     by metis 

(* Induction rule with the RECT result named explicitly as r. *)
theorem RECT_wf_induct[consumes 1]:
  assumes "RECT B x = r"
  assumes "wf R"
    and "mono2 B"
    and "⋀x D r. (⋀y r. (y, x) ∈ R ⟹ D y = r ⟹ P y r) ⟹ B D x = r ⟹ P x r"
  shows "P x r"
 (* using RECT_wf_induct_aux[where P = "λx fx. ∀r. fx=r ⟶ P x fx"] assms by metis *)
  using RECT_wf_induct_aux[where P = "λx fx.  P x fx"] assms by metis



(* While loop with a monadic guard b and an asserted invariant I: each
   iteration asserts I, evaluates the guard, and either runs the body f and
   recurses or returns the current state. *)
definition "monadic_WHILEIT I b f s ≡ do {
  RECT (λD s. do {
    ASSERT (I s);
    bv ← b s;
    if bv then do {
      s ← f s;
      D s
    } else do {RETURNT s}
  }) s
}"

section ‹Generalized Weakest Precondition›

subsection "mm"

(* mm R m x: ∞ if m has no entry for x; None if R has no entry for x or its
   time is smaller than that of m; otherwise the difference of R's and m's
   time at x. *)
definition mm :: "( 'a ⇒ enat option) ⇒ ( 'a ⇒ enat option) ⇒ ( 'a ⇒ enat option)" where
  "mm R m = (λx. (case m x of None ⇒ Some ∞
                                | Some mt ⇒
                                  (case R x of None ⇒ None | Some rt ⇒ (if rt < mt then None else Some (rt - mt)))))"

(* mm is monotone in its first argument ... *)
lemma mm_mono: "Q1 x ≤ Q2 x ⟹ mm Q1 M x ≤ mm Q2 M x"
  unfolding mm_def by (cases "M x") (auto split: option.splits elim!: le_some_optE intro!: helper)


(* ... and antitone in its second argument. *)
lemma mm_antimono: "M1 x ≥ M2 x ⟹ mm Q M1 x ≤ mm Q M2 x"
  unfolding mm_def by (auto split: option.splits intro: helper2)

(* mm commutes with infima taken in its first argument. *)
lemma mm_continous: "mm (λx. Inf {u. ∃y. u = f y x}) m x = Inf {u. ∃y. u = mm (f y) m x}" 
proof(rule antisym)
  show "mm (λx. ⨅ {u. ∃y. u = f y x}) m x ≤ ⨅ {u. ∃y. u = mm (f y) m x}"
  proof (rule Inf_greatest, drule CollectD)
    fix z
    assume z: "∃y. z = mm (f y) m x"
    from this obtain y where y: "z = mm (f y) m x"
      by blast

    show "mm (λx. ⨅ {u. ∃y. u = f y x}) m x ≤ z"
    proof (cases "Inf {u. ∃y. u = f y x}")
      case None
      with z show ?thesis 
        by (cases "m x") (auto simp add: mm_def None)
    next
      case Some_Inf[simp]: (Some l)
      then have i: "⋀y. f y x ≥ Some l"
        by (metis (mono_tags, lifting) Inf_lower mem_Collect_eq) 
      then have I: "⋀y. mm (λx. Inf {u. ∃y. u = f y x}) m x ≤ mm (f y) m x"
        by (intro mm_mono) simp
      show ?thesis
      proof(cases "m x")
        case None
        with y show ?thesis 
          by (auto simp add: mm_def)
      next
        case [simp]: (Some a)
        with y I show ?thesis
          by (auto simp add: mm_def)
      qed
    qed
  qed
  show "⨅ {u. ∃y. u = mm (f y) m x} ≤ mm (λx. ⨅ {u. ∃y. u = f y x}) m x"
  proof(rule Inf_lower, rule CollectI)
    (* The infimum is attained by some member of the family. *)
    have "∃y. Inf {u. ∃y. u = f y x} = f y x"
    proof (cases "Option.these {u. ∃y. u = f y x} = {}")
      case True
      then show ?thesis 
        by (simp add: Inf_option_def Inf_enat_def these_empty_eq) blast
    next
      case False
      then show ?thesis 
        by (auto simp add: Inf_option_def Inf_enat_def in_these_eq intro: LeastI)
    qed
    then obtain y where z: "Inf {u. ∃y. u = f y x} = f y x"
      by blast
    show "∃y. mm (λx. Inf {u. ∃y. u = f y x}) m x = mm (f y) m x"
      by (rule exI[where x=y]) (unfold mm_def z, rule refl)
  qed
qed

(* mm2 is mm at a single point: it works on the two option values directly. *)
definition mm2 :: "(  enat option) ⇒ (   enat option) ⇒ (   enat option)" where
  "mm2 r m = (case m  of None ⇒ Some ∞
                                | Some mt ⇒
                                  (case r of None ⇒ None | Some rt ⇒ (if rt < mt then None else Some (rt - mt))))"


lemma mm_alt: "mm R m x = mm2 (R x) (m x)" unfolding mm_def mm2_def ..

lemma mm2_None[simp]: "mm2 q None = Some ∞" unfolding mm2_def by auto

lemma mm2_Some0[simp]: "mm2 q (Some 0) = q" unfolding mm2_def by (auto split: option.splits)

(* mm2 is antitone in its second argument. *)
lemma mm2_antimono: "x ≥ y ⟹ mm2 q x ≤ mm2 q y"
  unfolding mm2_def by (auto split: option.splits intro: helper2)

(* A lower bound of mm2 q on every member of X is a lower bound on Sup X. *)
lemma mm2_contiuous2:
  assumes "∀x∈X. t ≤ mm2 q x" shows "t ≤ mm2 q (Sup X)"
proof (cases q; cases "Sup X")
  fix q' assume q[simp]: "q = Some q'"
  fix S assume SupX[]: "Sup X = Some S"

  have "t ≤ (if q' < S then None else Some (q' - S))"
  proof(cases "q' < S")
    case True
    with SupX obtain x where "x ∈ Option.these X" "q' < x"
      using less_Sup_iff by (auto simp add: Sup_option_def split: if_splits)
    hence "Some x ∈ X" "q < Some x"
      by (auto simp add: in_these_eq)
    with True assms show ?thesis
      by (auto simp add: mm2_def)
  next
    case False
    with assms SupX show ?thesis
      by (cases q'; cases S) (force simp: mm2_def dest: Sup_finite_enat)+
  qed
  then show ?thesis 
    by (auto simp add: mm2_def SupX)
qed (use assms in ‹auto simp add: mm2_def Sup_option_def split: option.splits if_splits›)
 
(* An enat difference is infinite only if the minuend is infinite. *)
lemma fl: "(a::enat) - b = ∞ ⟹ a = ∞"
  by(cases b; cases a) auto

(* When mm yields ∞: m has no entry, or R is ∞ at that point. *)
lemma mm_inf1: "mm R m x = Some ∞ ⟹ m x = None ∨ R x = Some ∞"
  by (auto simp: mm_def split: option.splits if_splits intro: fl)

lemma mm_inf2: "m x = None ⟹ mm R m x = Some ∞" 
  by(auto simp: mm_def split: option.splits if_splits)  

lemma mm_inf3: " R x = Some ∞ ⟹ mm R m x = Some ∞" 
  by(auto simp: mm_def split: option.splits if_splits)  

lemma mm_inf: "mm R m x = Some ∞ ⟷ m x = None ∨ R x = Some ∞"
  using mm_inf1 mm_inf2 mm_inf3  by metis

(* Facts about Inf on sets of options. *)
lemma InfQ_E: "Inf Q = Some t ⟹ None ∉ Q"
  unfolding Inf_option_def by auto

lemma InfQ_iff: "(∃t'≥enat t. Inf Q = Some t') ⟷ None ∉ Q ∧ Inf (Option.these Q) ≥ t"
  unfolding Inf_option_def 
  by auto

lemma mm2_fst_None[simp]: "mm2 None q = (case q of None ⇒ Some ∞ | _ ⇒ None)"
  by (cases q) (auto simp: mm2_def)


(* The bound and the second argument of mm2 can be exchanged. *)
lemma mm2_auxXX1: "Some t ≤ mm2 (Q x) (Some t') ⟹ Some t' ≤ mm2 (Q x) (Some t)"
  by (auto dest: fl simp: less_eq_enat_def mm2_def split: enat.splits option.splits if_splits)
 
subsection "mii"

(* mii lifts mm to nrest: None on failure, otherwise mm applied to the
   result map of M. *)
definition mii :: "('a ⇒ enat option) ⇒ 'a nrest ⇒ 'a ⇒ enat option" where 
  "mii Qf M x =  (case M of FAILi ⇒ None 
                                             | REST Mf ⇒ (mm Qf Mf) x)"


(* mii expressed through the pointwise operation mm2. *)
lemma mii_alt: "mii Qf M x = (case M of FAILi ⇒ None 
                                             | REST Mf ⇒ (mm2 (Qf x) (Mf x)) )"
  unfolding mii_def mm_alt ..

(* mii commutes with infima in its first argument. *)
lemma mii_continuous: "mii (λx. Inf {f y x|y. True}) m x = Inf {mii (%x. f y x) m x|y. True}"
  unfolding mii_def by (cases m) (auto simp add: mm_continous)
 
(* A lower bound of mii on a supremum of computations is equivalent to the
   same lower bound on every member of the family. *)
lemma mii_continuous2: "(mii Q (Sup {F x t1 |x t1. P x t1}) x ≥ t) = (∀y t1. P y t1 ⟶ mii Q (F y t1) x ≥ t)"
proof(intro iffI allI impI)
  fix y t1
  assume a: "t ≤ mii Q (⨆ {F x t1 |x t1. P x t1}) x" "P y t1"
  then show "t ≤ mii Q (F y t1) x"
  proof(cases "F y t1")
    case REST_F[simp]: (REST Ff)
    then show ?thesis
    proof(cases "(⨆ {F x t1 |x t1. P x t1})")
      case FAILi
      with a show ?thesis
        using a by (simp add: mii_alt less_eq_option_None_is_None')
    next
      case [simp]: (REST Sf)
      note Sf_x = nrest_Sup_SPECT_D[OF this, where x=x]
      from a(1) have "t ≤ mm2 (Q x) (Sf x)" 
        by (auto simp add: mii_alt)
      also from a(2) have "mm2 (Q x) (Sf x) ≤ mm2 (Q x) (Ff x)"
        by (intro mm2_antimono) (force intro: Sup_upper simp add: Sf_x)
      finally show ?thesis
        by (simp add: mii_alt)
    qed
  qed (auto simp: mii_alt  Sup_nrest_def split: if_splits) 
next
  assume "∀y t1. P y t1 ⟶ t ≤ mii Q (F y t1) x"
  then show "t ≤ mii Q (⨆ {F x t1 |x t1. P x t1}) x"
    by (auto simp: mii_alt Sup_nrest_def split: nrest.splits intro: mm2_contiuous2)
qed

(* mii yields ∞ exactly for a non-failing M that has no entry at x, or where Qf is ∞. *)
lemma mii_inf: "mii Qf M x = Some ∞ ⟷ (∃Mf. M = SPECT Mf ∧ (Mf x = None ∨ Qf x = Some ∞))"
  by (auto simp: mii_def mm_inf split: nrest.split )

(* mii on a failing computation is None. *)
lemma miiFailt: "mii Q FAILT x = None" 
  unfolding mii_def by auto

subsection "lst - latest starting time"

(* lst M Qf: the infimum of mii Qf M x over all x. *)
definition lst :: "'a nrest ⇒ ('a ⇒ enat option) ⇒ enat option" 
  where "lst M Qf =  Inf { mii Qf M x | x. True}"

(* lst as an infimum over the image of mii. *)
lemma T_alt_def: "lst M Qf = Inf ( (mii Qf M) ` UNIV )"
  unfolding lst_def by (simp add: full_SetCompr_eq) 


(* Pointwise characterisation of lower bounds on lst. *)
lemma T_pw: "lst M Q ≥ t ⟷ (∀x. mii Q M x ≥ t)"
  by (auto simp add: T_alt_def mii_alt le_Inf_iff)

(* lst m Q ≥ Some 0 implies that m refines SPECT Q ... *)
lemma T_specifies_I: 
  assumes "lst m Q ≥ Some 0" shows "m ≤ SPECT Q"
proof (cases m)
  case [simp]: (REST q)
  from assms have "q x ≤ Q x" for x
    by (cases "q x"; cases "Q x") 
      (force simp add: T_alt_def mii_alt mm2_def le_Inf_iff split: option.splits if_splits nrest.splits)+
  then show ?thesis
    by (auto intro: le_funI)
qed (use assms in ‹auto simp add: T_alt_def mii_alt›)

(* ... and conversely. *)
lemma T_specifies_rev: 
  assumes "m ≤ SPECT Q" shows "lst m Q ≥ Some 0" 
proof (cases m)
  case [simp]: (REST q)
  with assms have le: "q x ≤ Q x " for x
    by (auto dest: le_funD)
  show ?thesis
  proof(subst T_pw, rule allI)
    fix x
    from le[of x] show "Some 0 ≤ mii Q m x"
      by (cases "q x"; cases "Q x") (auto simp add: mii_alt mm2_def)
  qed
qed (use assms in ‹auto simp add: T_alt_def mii_alt›)

(* The two directions combined. *)
lemma T_specifies: "lst m Q ≥ Some 0 = (m ≤ SPECT Q)"
  using T_specifies_I T_specifies_rev by metis

(* x ≤ x' follows when every lower bound of x is a lower bound of x'. *)
lemma pointwise_lesseq:
  fixes x :: "'a::order"
  shows "(⋀t. x ≥ t ⟹ x' ≥ t) ⟹ x ≤ x'"
  by simp

subsection "pointwise reasoning about lst via nres3"


(* nres3 Q M x t: t is a lower bound of mii Q M x. *)
definition nres3 where "nres3 Q M x t ≡ mii Q M x ≥ t"


(* Compare two lst values by comparing their nres3 lower bounds. *)
lemma pw_T_le:
  assumes "⋀t. (∀x. nres3 Q M x t) ⟹ (∀x. nres3 Q' M' x t)"
  shows "lst  M Q ≤ lst  M' Q'"
  apply(rule pointwise_lesseq)
  using assms unfolding T_pw nres3_def by metis 

(* Equality of lst values from equivalence of the nres3 bounds. *)
lemma assumes "⋀t. (∀x. nres3 Q M x t) = (∀x. nres3 Q' M' x t)" 
  shows pw_T_eq_iff: "lst M Q  = lst M' Q'"
  apply (rule antisym)
   apply(rule pw_T_le) using assms apply metis
  apply(rule pw_T_le) using assms apply metis
  done 

(* The same, with the two implications as separate premises. *)
lemma assumes "⋀t. (∀x. nres3 Q M x t) ⟹ (∀x. nres3 Q' M' x t)"
      "⋀t. (∀x. nres3 Q' M' x t) ⟹ (∀x. nres3 Q M x t)"
  shows pw_T_eqI: "lst M Q = lst M' Q'"
  apply (rule antisym)
   apply(rule pw_T_le) 
   apply fact
  apply(rule pw_T_le) 
  apply fact 
  done 

(* Auxiliary fact used (together with lem2) in the proof of mii_bindT.
   From a bound t on mii Q applied to the cost-shifted continuation
   SPECT (map_option ((+) t1) o x2), available for every t1 with M y = Some t1,
   and from f y = SPECT x2, it derives the bound t on the nested expression
   mii (λy. mii Q (f y) x) (SPECT M) y.
   Proof: case analysis on M y and t; in the case where both are defined, a
   further split on x2 x and Q x followed by unfolding mii_def and mm_def. *)
lemma lem: "t1. M y = Some t1  t  mii Q (SPECT (map_option ((+) t1)  x2)) x  f y = SPECT x2  t  mii (λy. mii Q (f y) x) (SPECT M) y"
proof(cases "M y"; cases t)
  fix m' t'
  assume a: "t1. M y = Some t1  t  mii Q (SPECT (map_option ((+) t1)  x2)) x" "f y = SPECT x2"
  assume c[simp]: "M y = Some m'" "t = Some t'"

  from a c show "t  mii (λy. mii Q (f y) x) (SPECT M) y"
  proof (cases "x2 x"; cases "Q x")
    fix x2' Q' 
    assume c'[simp]: "x2 x = Some x2'" "Q x = Some Q'"
    (* all three enat values are split into their cases for the arithmetic *)
    from a c show ?thesis 
      by (cases m'; cases x2'; cases Q') (auto split: option.splits if_splits simp add: add.commute mii_def mm_def)
  qed (auto split: option.splits if_splits simp add: add.commute mii_def mm_def)
qed(auto simp add: mii_def mm_def)

(* TODO Move *)
(* On enat, subtracting a sum b + c equals subtracting b and then c.
   Proved by case analysis on all three values. *)
lemma diff_diff_add_enat:"a - (b+c) = a - b - (c::enat)" 
  by (cases a; cases b; cases c) auto


(* Counterpart of lem, also used in the proof of mii_bindT: from the bound t on
   the nested expression mii (λy. mii Q (f y) x) (SPECT M) y, together with
   M y = Some t1 and f y = SPECT fF, it derives the bound t on mii Q applied to
   the cost-shifted continuation SPECT (map_option ((+) t1) o fF).
   Proof: case analysis on fF x, Q x and t, then unfolding mii_def and mm_def
   with the enat arithmetic facts listed. *)
lemma lem2: "t  mii (λy. mii Q (f y) x) (SPECT M) y  M y = Some t1  f y = SPECT fF  t  mii Q (SPECT (map_option ((+) t1)  fF)) x"
  by (cases "fF x"; cases "Q x"; cases t)  (auto simp add: mii_def mm_def enat_plus_minus_aux2
      add.commute linorder_not_less diff_diff_add_enat split: if_splits)

(* Characterises a bound t on mii Q (bindT m f) x through bounds, for every
   intermediate result y, on the nested expression mii (λy. mii Q (f y) x) m y.
   Structure of the proof:
   - block bl handles m = SPECT M: bindT is rewritten (bindT_def) into a Sup over
     the cost-shifted continuations, mii_continuous2 distributes mii over that
     Sup, and the per-y equivalence h (built from lem and lem2) finishes;
   - the final case distinction on m uses mii_def for FAILi and bl for REST. *)
lemma fixes m :: "'b nrest"
  shows mii_bindT: "(t  mii Q (bindT m f) x)  (y. t  mii (λy. mii Q (f y) x) m y)"
proof -
  { fix M
    assume mM: "m = SPECT M"
    (* ?P: y is a possible result of m with cost t1;
       ?F: the continuation f x with every cost shifted by t1 (or failure);
       ?Sup: supremum of all such shifted continuations *)
    let ?P = "%x t1. M x = Some t1"
    let ?F = "%x t1. case f x of FAILi  FAILT | REST m2  SPECT (map_option ((+) t1)  m2)"
    let ?Sup = "(Sup {?F x t1 |x t1. ?P x t1})" 

    { fix y 
      (* a failing continuation at a y for which M y is constrained forces the
         nested mii value to None *)
      have 1: "mii (λy. mii Q (f y) x) (SPECT M) y = None" if "f y = FAILT" "M y  None"
        using that by (auto simp add: mii_def mm_def less_eq_option_None_is_None' )
      have "(t1. ?P y t1  mii Q (?F y t1) x  t)
              = (t  mii (λy. mii Q (f y) x) m y)"
        by (cases "f y") (auto intro: lem lem2 meta_le_everything_if_top simp add: 1 mM miiFailt mii_inf 
            top_enat_def top_option_def less_eq_option_None_is_None')
    } note h=this


    (* calculational chain from the bindT form to the per-y form *)
    from mM have "mii Q (bindT m f) x = mii Q ?Sup x" by (auto simp: bindT_def)
    then have "(t  mii Q (bindT m f) x) = (t  mii Q ?Sup x)" by simp
    also have " = (y t1. ?P y t1  mii Q (?F y t1) x  t)" by (rule mii_continuous2)  
    also have " = (y. t  mii (λy. mii Q (f y) x) m y)" by (simp only: h)
    finally have ?thesis .
  } note bl=this

  show ?thesis
  proof(cases m)
    case FAILi
    then show ?thesis 
      by (simp add: mii_def)
  next
    case (REST x2)
    then show ?thesis
      by (rule bl)
  qed
qed


(* nres3 characterisation of bindT: the all-x nres3 statement for bindT m f
   with postcondition Q equals the all-y nres3 statement for m with the
   postcondition λy. lst (f y) Q.
   The proof is a calculational chain: unfold nres3, apply mii_bindT, swap the
   two universal quantifiers, fold the inner quantifier into an Inf (fact t and
   mii_continuous), and recognise that Inf as lst via lst_def. *)
lemma nres3_bindT: "(x. nres3 Q (bindT m f) x t) = (y. nres3 (λy. lst (f y) Q) m y t)"
proof -
  (* relates a bound on every f y to a bound on the Inf of the range of f
     (consequence of le_Inf_iff) *)
  have t: "f and t::enat option. (y. t  f y)  (tInf {f y|y. True})"  
    using le_Inf_iff by fastforce   

  have "(x. nres3 Q (bindT m f) x t) = (x.  t  mii Q (bindT m f) x)" unfolding nres3_def by auto
  also have " = (x. (y. t  mii (λy. mii Q (f y) x) m y))" by(simp only: mii_bindT)
  also have " = (y. (x. t  mii (λy. mii Q (f y) x) m y))" by blast
  also have " = (y.  t  mii (λy. Inf {mii Q (f y) x|x. True}) m y)"
    using t by (fastforce simp only: mii_continuous)
  also have " = (y.  t  mii (λy. lst (f y) Q) m y)" unfolding lst_def by auto
  also have " = (y. nres3 (λy. lst (f y) Q) m y t)" unfolding nres3_def by auto
  finally show ?thesis .
qed


subsection "rules for lst"

(* lst of a bind: lst (bindT M f) Q equals lst M with the nested
   postcondition λy. lst (f y) Q.  Follows from pw_T_eq_iff and nres3_bindT. *)
lemma T_bindT: "lst (bindT M f) Q  = lst M (λy. lst (f y) Q)"
  by (rule pw_T_eq_iff, rule nres3_bindT) 

(* lst of a single-entry REST map (result x with cost t) is mm2 (Q x) (Some t).
   The helper fact * computes the Inf of the set that arises after unfolding
   lst_def and mii_alt: it contains v for the entry x and a fixed Some value
   for every other argument, and its Inf is v. *)
lemma T_REST: "lst (REST [xt]) Q = mm2 (Q x) (Some t)"
proof- 
  have *: "Inf {uu. xa. (xa = x  uu= v)  (xa  x  uu = Some )} = v"  (is "Inf ?S = v") for v :: "enat option"
  proof -
    (* the set can only take one of two shapes, both of which are handled by safe/simp *)
    have "?S  { {v}  {Some }, {v}  }" by auto 
    then show ?thesis 
      by safe (simp_all add: top_enat_def[symmetric] top_option_def[symmetric] ) 
  qed
  then show ?thesis
    unfolding lst_def mii_alt by auto
qed


(* lst of RETURNT x is just the postcondition at x.
   Proof: rewrite RETURNT with RETURNT_alt, apply T_REST, simplify. *)
lemma T_RETURNT: "lst (RETURNT x) Q = Q x"
  unfolding RETURNT_alt by (rule trans, rule T_REST) simp
             
(* Rule for SELECT P tf with bound Some tt.  Two premises are required:
   - one for the situation in which P holds for no x, stated for the program
     that returns None with cost tf;
   - one (p) for every x satisfying P, stated for the program that returns
     Some x with cost tf.
   The proof distinguishes whether some x satisfies P.  In the positive case,
   p is unfolded with T_pw and mii_alt into the pointwise facts p' and p'',
   which auto then uses after unfolding SELECT_def and emb'_def. *)
lemma T_SELECT: 
  assumes  
    "x. ¬ P x  Some tt  lst (SPECT [None  tf]) Q"
  and p: "(x. P x  Some tt  lst (SPECT [Some x  tf]) Q)"
  shows "Some tt  lst (SELECT P tf) Q"
proof(cases "x. P x")
  case True
  from p[unfolded T_pw mii_alt] have
    p': "y x. P y  Some tt  mm2 (Q x) ([Some y  tf] x)"
    by auto
  (* specialise p' to the argument that the single-entry map actually defines *)
  hence p'': "y x. P y  x=Some y  Some tt  mm2 (Q x) (Some tf)"
    by (metis fun_upd_same)
  with True show ?thesis
    by (auto simp: SELECT_def emb'_def T_pw mii_alt split: if_splits option.splits)
next
  case False
  with assms show ?thesis 
    by (auto simp: SELECT_def)
qed 



section ‹consequence rules›
   
(* Relates a bound Some t on mm2 Q (Some t') to a bound Some (t+t') on Q
   (presumably as an equivalence; aux1a and T_conseq6 use it as a rewrite).
   Proof: case analysis on t, t' and Q; the case where t and t' are finite and
   Q is defined needs a further split on the value inside Q. *)
lemma aux1: "Some t  mm2 Q (Some t')  Some (t+t')  Q"
proof (cases t; cases t'; cases Q) 
  fix n n' t''
  assume a[simp]: "t = enat n" "t' = enat n'" "Q = Some t''"
  show ?thesis
    by (cases t'') (auto simp: mm2_def)
qed (auto simp: mm2_def elim: less_enatE split: option.splits)

(* Pointwise version of aux1 for two postconditions Q and Q': the statement
   "for all x and t'' with Q' x = Some t'', Q x is related to Some (t + t'')"
   equals the statement "for all x, mm2 (Q x) (Q' x) is related to Some t".
   Both directions are proved separately and both rely on aux1. *)
lemma aux1a: "(x t''. Q' x = Some t''  (Q x)  Some (t + t''))
      = (x. mm2 (Q x) (Q' x)  Some t)"
proof(intro iffI allI impI)
  fix x
  assume a: "x t''. Q' x = Some t''  Some (t + t'')  Q x"
  then show "Some t  mm2 (Q x) (Q' x)"
    by (cases "Q' x") (simp_all add: aux1)
next
  fix x t''
  assume a: "x. Some t  mm2 (Q x) (Q' x)" "Q' x = Some t''"
  then show "Some (t + t'')  Q x"
    using aux1 by metis
qed

(* Consequence rule with a change of time bound.
   Premises: a bound Some t' on lst f Q', and for every x with
   Q' x = Some t'' a relation between Q x and Some ((t - t') + t'').
   Conclusion: the bound Some t on lst f Q.
   The proof works pointwise (T_pw): for a fixed x it splits on f and on
   whether f assigns a cost to x, and closes the arithmetic with the local
   fact arith together with aux1. *)
lemma T_conseq4:
  assumes 
    "lst f Q'  Some t'"
    "x t'' M. Q' x = Some t''  (Q x)  Some ((t - t') + t'')" 
  shows "lst f Q   Some t"
proof -
  {
    fix x
    (* i: first premise at x; ii: second premise at x *)
    from assms(1)[unfolded T_pw] have i: "Some t'  mii Q' f x" by auto
    from assms(2) have ii: "t''. Q' x = Some t''  (Q x)  Some ((t - t') + t'')" by auto
    from i ii have "Some t  mii Q f x"
    proof (cases f)
      case [simp]: (REST Mf)
      then show ?thesis
      proof(cases "Mf x")
        case [simp]: (Some a)
        (* enat arithmetic combining the two bounds, proved by exhaustive case split *)
        have arith: "t' + a  b  t - t' + b  b'  t + a  b'" for b b'
          by (cases t; cases t'; cases a; cases b; cases b') auto
        with i ii show ?thesis
          by (cases "Q' x"; cases "Q x") (auto simp add: mii_alt aux1)
      qed (auto simp add: mii_alt)
    qed (auto simp add: mii_alt)
  } 
  thus ?thesis
    unfolding T_pw ..
qed

(* Consequence rule that keeps the time bound Some t.
   Compared with T_conseq4, the second premise carries the additional
   hypotheses f = SPECT M and a condition on M x (presumably that M x is
   defined), so the postcondition only has to be related at results that f can
   actually produce.  This is the rule composed with vcg_rules in method vcg'.
   Proof: pointwise via T_pw, splitting on f and on the cost f assigns to x. *)
lemma T_conseq6:
  assumes 
    "lst f Q'  Some t"
    "x t'' M. f = SPECT M  M x  None  Q' x = Some t''  (Q x)  Some ( t'')" 
  shows "lst f Q  Some t"
proof -
  {
    fix x
    (* i: first premise at x; ii: second premise at x *)
    from assms(1)[unfolded T_pw] have i: "Some t  mii Q' f x" by auto
    from assms(2) have ii: "t'' M.  f = SPECT M  M x  None  Q' x = Some t''  (Q x)  Some ( t'')"
      by auto
    from i ii have "Some t  mii Q f x"
    proof (cases f)
      case [simp]: (REST Mf)
      then show ?thesis
      proof(cases "Mf x")
        case [simp]: (Some a)
        with i ii show ?thesis
          by (cases "Q' x"; cases "Q x") (fastforce simp add: mii_alt aux1)+
      qed (auto simp add: mii_alt)
    qed (auto simp add: mii_alt)
  } 
  thus ?thesis
    unfolding T_pw ..
qed



(* Variant of T_conseq6 whose second premise relates Q x and Q' x directly
   instead of going through a witness Q' x = Some t''.
   Proved by applying T_conseq6 and discharging its premises from the
   assumptions. *)
lemma T_conseq6':
  assumes 
    "lst f Q'  Some t"
    "x t'' M. f = SPECT M  M x  None    (Q x)  Q' x" 
  shows "lst f Q  Some t"
  by (rule T_conseq6) (auto intro: assms(1) dest: assms(2))

(* Consequence rule combining the features of T_conseq4 and T_conseq6: the time
   bound changes from Some t' to Some t, and the second premise carries the
   extra hypotheses f = SPECT M and a condition on M x.
   Proof: pointwise via T_pw, splitting on f and on the cost f assigns to x;
   the local fact arith provides the enat arithmetic. *)
lemma T_conseq5:
  assumes 
    "lst f Q'  Some t'"
    "x t'' M. f = SPECT M  M x  None  Q' x = Some t''  (Q x)  Some ((t - t') + t'')" 
  shows "lst f Q  Some t"
proof -
  {
    fix x
    (* i: first premise at x; ii: second premise at x *)
    from assms(1)[unfolded T_pw] have i: "Some t'  mii Q' f x" by auto
    from assms(2) have ii: "t'' M.  f = SPECT M  M x  None  Q' x = Some t''  (Q x)  Some ((t - t') + t'')" by auto
    from i ii have "Some t  mii Q f x"
    
    proof (cases f)
      case [simp]: (REST Mf)
      then show ?thesis
      proof(cases "Mf x")
        case [simp]: (Some a)
        (* enat arithmetic combining the two bounds, proved by exhaustive case split *)
        have arith: "t' + a  b  t - t' + b  b'  t + a  b'" for b b' 
          by (cases a; cases b; cases b'; cases t; cases t') auto
        with i ii show ?thesis
          by (cases "Q' x"; cases "Q x") (fastforce simp add: mii_alt aux1)+
      qed (auto simp add: mii_alt)
    qed (auto simp add: mii_alt)
  } 
  thus ?thesis
    unfolding T_pw ..
qed

(* Consequence rule whose second premise is phrased with mm2: for every x,
   mm2 (Q x) (Q' x) is related to Some (t - t').
   Derived from T_conseq4 after translating the premise with aux1a. *)
lemma T_conseq3: 
  assumes 
    "lst f Q'  Some t'"
    "x. mm2 (Q x) (Q' x)  Some (t - t')" 
  shows "lst f Q  Some t"
  using assms T_conseq4 aux1a by metis
             
section "Experimental Hoare reasoning"

(* Collection of specification rules consumed by the vcg and vcg' methods. *)
named_theorems vcg_rules

(* Proof method vcg: repeats (at least once) the following alternative.
   Either apply a rule given via the rls argument or a vcg_rules theorem
   composed with T_conseq4, or, if no rule applies, run clarsimp with the emb
   conversion lemmas, T_bindT and T_RETURNT as simp rules. *)
method vcg uses rls = ((rule rls vcg_rules[THEN T_conseq4] | clarsimp simp: emb_eq_Some_conv emb_le_Some_conv T_bindT T_RETURNT)+)

(* Local experiment (nothing is exported): a small program P that runs f, then
   g, and returns the sum of their results.  The lemma is proved twice to
   illustrate the vcg method: once automatically, once by the equivalent
   manual apply script. *)
experiment
begin

definition P where "P f g = bindT f (λx. bindT g (λy. RETURNT (x+(y::nat))))"

lemma assumes
  f_spec[vcg_rules]: "lst f ( emb' (λx. x > 2) (enat o ((*) 2)) )  Some 0"
and 
  g_spec[vcg_rules]: "lst g ( emb' (λx. x > 2) (enat) )  Some 0"
shows "lst (P f g) ( emb' (λx. x > 5) (enat o (*) 3) )  Some 0"
proof -
  (* automatic proof with the vcg method *)
  have ?thesis
    unfolding P_def by vcg

  (* the same proof spelled out step by step: rewrite the binds and the return,
     then apply each specification through T_conseq4 and simplify *)
  have ?thesis
    unfolding P_def
    apply(simp add: T_bindT )
    apply(simp add: T_RETURNT)
    apply(rule T_conseq4[OF f_spec])
    apply(clarsimp simp: emb_eq_Some_conv)
    apply(rule T_conseq4[OF g_spec])
    apply (clarsimp simp: emb_eq_Some_conv emb_le_Some_conv)
    done
  thus ?thesis .
qed
end


section ‹VCG›

(* Simplification rules passed to clarsimp by the vcg' method; T_RETURNT is
   registered as the initial member. *)
named_theorems vcg_simp_rules
lemmas [vcg_simp_rules] = T_RETURNT

(* Introduction form of T_bindT for use as a rule in vcg': a Some t statement
   about lst of a bind follows from the same statement about the nested lst
   expression.  Proved by rewriting with T_bindT. *)
lemma TbindT_I: "Some t   lst M (λy. lst (f y) Q)   Some t  lst (M  f) Q"
  by(simp add: T_bindT)

(* Proof method vcg': repeats (at least once) the following alternative.
   Either apply a rule given via rls, the bind rule TbindT_I, or a vcg_rules
   theorem composed with T_conseq6; otherwise run clarsimp with if-splitting
   and the vcg_simp_rules collection. *)
method vcg' uses rls = ((rule rls TbindT_I vcg_rules[THEN T_conseq6] | clarsimp split: if_splits simp:  vcg_simp_rules)+)

(* mm2 of a defined value with itself is Some 0, under the side condition on A
   stated in the premise.  Proved by unfolding mm2_def. *)
lemma mm2_refl: "A <   mm2 (Some A) (Some A) = Some 0"
  unfolding mm2_def by auto
 
(* mm3 t A: None when A is None; for A = Some t' the result is
   Some (enat (t - t')) when the guard comparing t' with t holds, and None
   otherwise. *)
definition mm3 where
  "mm3 t A = (case A of None  None | Some t'  if t't then Some (enat (t-t')) else None)"

(* mm3 of a value against itself is Some 0 *)
lemma mm3_same[simp]: "mm3 t0 (Some t0) = Some 0"  by (auto simp: mm3_def zero_enat_def)

(* characterises when mm3 returns Some t: A must be some t' that satisfies the
   guard, and t is the difference t0 - t' *)
lemma mm3_Some_conv: "(mm3 t0 A = Some t) = (t'. A = Some t'  t0  t'  t=t0-t')"
  unfolding mm3_def by(auto split: option.splits)

(* mm3 of an undefined second argument is undefined *)
lemma mm3_None[simp]: "mm3 t0 None = None" unfolding mm3_def by auto

(* lst of the failing program is None for every postcondition.
   Registered as a simp rule; proved by unfolding lst_def and mii_alt. *)
lemma T_FAILT[simp]: "lst FAILT Q = None"
  unfolding lst_def mii_alt by simp

(* progress m: whenever m has the form SPECT M, every s' meeting the condition
   on M s' (presumably that M s' is defined) satisfies M s' > Some 0, i.e. has
   strictly positive cost. *)
definition "progress m  s' M. m = SPECT M  M s'  None  M s' > Some 0"
(* destruction rule: the definition of progress used in forward direction *)
lemma progressD: "progress m  m=SPECT M  M s'  None  M s' > Some 0"
  by (auto simp: progress_def)

(* the failing program satisfies progress (it is never of the form SPECT M) *)
lemma progress_FAILT[simp]: "progress FAILT" by(auto simp: progress_def)

subsection ‹Progress rules›

(* Collection of rules for establishing progress. *)
named_theorems progress_rules

(* progress of SELECT P t is equivalent to t > 0.
   Proved by unfolding progress_def, SELECT_def and emb'_def. *)
lemma progress_SELECT_iff: "progress (SELECT P t)  t > 0"
  unfolding progress_def SELECT_def emb'_def by (auto split: option.splits)

(* only the direction from t > 0 to progress is registered as a progress rule *)
lemmas [progress_rules] = progress_SELECT_iff[THEN iffD2]

(* progress of a single-entry REST map (result x with cost t) is equivalent to
   t > 0. *)
lemma progress_REST_iff: "progress (REST [x  t])  t>0"
  by (auto simp: progress_def)

(* only the direction from t > 0 to progress is registered as a progress rule *)
lemmas [progress_rules] = progress_REST_iff[THEN iffD2]

(* progress of an assertion followed by f reduces to progress of f (), given
   the premise involving the asserted condition.
   Proved by case distinction on the condition and unfolding progress_def. *)
lemma progress_ASSERT_bind[progress_rules]: "Φ  progress (f ())   progress (ASSERT Φf)"
  by (cases Φ) (auto simp: progress_def)

(* a SPECT built with emb from predicate P and a cost t > 0 satisfies progress *)
lemma progress_SPECT_emb[progress_rules]: "t > 0  progress (SPECT (emb P t))" by(auto simp: progress_def emb'_def)

(* If the Sup of a set of enat option values is Some e, then the set contains
   an element of the form Some i.  Used in progress_bind to pick a witness.
   Proved by unfolding Sup_option_def. *)
lemma Sup_Some: "Sup (S::enat option set) = Some e  xS. (i. x = Some i)"
  unfolding Sup_option_def by (auto split: if_splits)

(* progress is preserved by bind: the assumption combines progress m with
   progress (f x) for all x (the proof below uses it as a disjunction), and the
   conclusion is progress of m bound to f.
   Proof outline for m = REST x2: unfold bindT_def and progress_def; the value
   of the bind at s' is a Sup (fact A), so Sup_Some yields one contributing
   summand fa with fa s' = Some i.  That summand is a continuation f x = SPECT
   x2a shifted by the cost t1 of x in m.  Either t1 > 0 or x2a s' > Some 0 by
   the assumption, hence Some 0 < fa s', and fa s' is below the Sup. *)
lemma progress_bind[progress_rules]: assumes "progress m  (x. progress (f x))"
  shows "progress (mf)"
proof (cases m)
  case FAILi
  then show ?thesis by (auto simp: progress_def)
next
  case (REST x2)   
  then show ?thesis unfolding bindT_def progress_def
  proof (safe, goal_cases)
    case (1 s' M y)
    (* ?P fa: fa is one of the cost-shifted continuations collected by bindT *)
    let ?P = "λfa. x. f x  FAILT 
             (t1. x2a. f x = SPECT x2a  fa = map_option ((+) t1)  x2a  x2 x = Some t1)"
    from 1 have A: "Sup {fa s' |fa. ?P fa} = Some y"
      by (auto dest: nrest_Sup_SPECT_D[where x=s'] split: nrest.splits)
    (* pick a contributing summand that is defined at s' *)
    from Sup_Some[OF this] obtain fa i where P: "?P fa" and 3: "fa s' = Some i"   by blast 
    then obtain x t1 x2a  where  a3: "f x = SPECT x2a"
      and "x2a. f x = SPECT x2a  fa = map_option ((+) t1)  x2a" and a2: "x2 x = Some t1"  
      by fastforce 
    then have a1: " fa = map_option ((+) t1)  x2a" by auto
    (* consequences of each of the two possible progress facts *)
    have "progress m  t1 > 0" 
      by (drule progressD) (use 1(1) a2 a1 a3 in auto)  
    moreover
    have "progress (f x)  x2a s' > Some 0"  
      using a1 1(1) a2 3 by (auto dest!: progressD[OF _ a3])   
    ultimately
    have " t1 > 0  x2a s' > Some 0" using assms by auto

    (* the chosen summand is strictly positive and bounded by the Sup *)
    then have "Some 0 < fa s'" using   a1  3 by auto
    also have "  Sup {fa s'|fa. ?P fa}" 
      by (rule Sup_upper) (blast intro: P)
    also have " = M s'" using A 1(3) by simp
    finally show ?case .
  qed 
qed

(* Relates a Some 0 bound on mm2 Qf (Some t) to a direct comparison of Qf with
   Some t (presumably an equivalence, as the _conv suffix suggests).
   Proved by unfolding mm2_def and splitting on the option. *)
lemma mm2SomeleSome_conv: "mm2 (Qf) (Some t)  Some 0  Qf  Some t"
  unfolding mm2_def  by (auto split: option.split)

section "rules for whileT"

(* Rule for whileT with an option-valued invariant I and a well-founded
   relation R.
   Premises: r is the loop whileT b c started in s; for every state s with
   I s = Some t' in which the guard b holds, the body c s satisfies the
   Some t' bound with a postcondition that yields I s' only when (s', s) is in
   R; the invariant holds initially with I s = Some t; R is well-founded.
   Conclusion: the Some t bound for r with the postcondition that is None
   while the guard still holds and I x otherwise.
   Proof: induction with RECT_wf_induct over R; the monotonicity goal is
   solved by refine_mono and the step case by vcg' using the induction
   hypothesis as an additional vcg rule. *)
lemma
  assumes "whileT b c s = r"
  assumes IS[vcg_rules]: "s t'. I s = Some t'  b s 
               lst (c s) (λs'. if (s',s)R then I s' else None)  Some t'"
    (*  "T (λx. T I (c x)) (SPECT (λx. if b x then I x else None)) ≥ Some 0" *) 
  assumes "I s = Some t"
  assumes wf: "wf R"
  shows whileT_rule'': "lst r (λx. if b x then None else I x)  Some t"
  using assms(1,3)
  unfolding whileT_def
proof (induction arbitrary: t rule: RECT_wf_induct[where R="R"])
  case 1  
  show ?case by fact
next
  case 2
  then show ?case by refine_mono
next
  case step: (3 x D r t') 
  (* make the induction hypothesis available to vcg' and unfold the loop once *)
  note IH[vcg_rules] = step.IH[OF _ refl] 
  note step.hyps[symmetric, simp]   

  from step.prems show ?case 
    by simp vcg'      
qed

(* Rule for whileT in which the invariant I maps states to an optional natural
   number, with no explicit well-founded relation.
   Premises: r is the loop started in s0; every loop body satisfies progress;
   for every state s with I s = Some t in which the guard holds, the body
   satisfies the Some 0 bound with postcondition λs'. mm3 t (I s'); and
   I s0 = Some t0.
   Conclusion: the Some 0 bound for r with the postcondition that is None
   while the guard still holds and mm3 t0 (I x) otherwise.
   Proof: T_conseq4 followed by whileT_rule'' instantiated with the invariant
   λs. mm3 t0 (I s) and the relation measure (the_enat o the o I).  The
   remaining step obligation is unfolded pointwise and closed by a sequence of
   automated calls that use the progress assumption. *)
lemma
  fixes I :: "'a  nat option"
  assumes "whileT b c s0 = r"
  assumes progress: "s. progress (c s)" 
  assumes IS[vcg_rules]: "s t t'. I s = Some t   b s   
           lst (c s) (λs'. mm3 t (I s') )  Some 0"
    (*  "T (λx. T I (c x)) (SPECT (λx. if b x then I x else None)) ≥ Some 0" *) 
  assumes [simp]: "I s0 = Some t0" 
    (*  assumes wf: "wf R" *)                         
  shows whileT_rule''': "lst r (λx. if b x then None else mm3 t0 (I x))  Some 0"  
  apply(rule T_conseq4)
   apply(rule whileT_rule''[where I="λs. mm3 t0 (I s)"
        and R="measure (the_enat o the o I)", OF assms(1)])
  subgoal for s t'
    apply(cases "I s"; simp)
    subgoal for ti
      using IS[of s ti]  
      apply (cases "c s"; simp) 
      subgoal for M
        ― ‹TODO›
        using progress[of s, THEN progressD, of M]
        apply(auto simp: T_pw mm3_Some_conv mii_alt mm2_def mm3_def split: option.splits if_splits)
            apply fastforce 
           apply (metis enat_ord_simps(1) le_diff_iff le_less_trans option.distinct(1)) 
          apply (metis diff_is_0_eq' leI less_option_Some option.simps(3) zero_enat_def) 
         apply (smt Nat.add_diff_assoc enat_ile enat_ord_code(1) idiff_enat_enat leI 
                le_add_diff_inverse2 nat_le_iff_add option.simps(3)) 
        using dual_order.trans by blast 
      done
    done
  by auto

(* Rule for whileIET, where the invariant is given as a predicate I together
   with a function E; both are combined into the option-valued invariant
   λe. if I e then Some (E e) else None.
   Premises: the step condition for the body C phrased with mm3, progress of
   every body, and I s0.
   The attribute list composes the proved statement with T_conseq6 and
   registers the result in vcg_rules.
   Proof: unfold whileIET_def and apply whileT_rule''' with the combined
   invariant. *)
lemma whileIET_rule[THEN T_conseq6, vcg_rules]:
  fixes E
  assumes 
    "(s t t'.
    (if I s then Some (E s) else None) = Some t 
    b s  Some 0  lst (C s) (λs'. mm3 t (if I s' then Some (E s') else None)))" 
  "s. progress (C s)"
  "I s0" 
shows "Some 0  lst (whileIET I E b C s0) (λx. if b x then None else mm3 (E s0) (if I x then Some (E x) else None))"
  unfolding whileIET_def  
  apply(rule whileT_rule'''[OF refl, where I="(λe. if I e
                then Some (E e) else None)"])
  using assms by auto 

 

(* Translates a step condition stated directly with I s and E s into the form
   required by whileIET_rule, where the invariant appears as
   (if I s then Some (E s) else None) = Some t.
   Proved by case distinction on I s. *)
lemma transf:
  assumes "I s   b s  Some 0  lst (C s) (λs'. mm3 (E s) (if I s' then Some (E s') else None))" 
  shows "
    (if I s then Some (E s) else None) = Some t 
    b s  Some 0  lst (C s) (λs'. mm3 t (if I s' then Some (E s') else None))"
  using assms by (cases "I s") auto

(* Variant of whileIET_rule whose step premise is stated directly with I s and
   E s rather than through the option-valued invariant.
   Proved by applying whileIET_rule and converting the step premise with
   transf. *)
lemma whileIET_rule':
  fixes E
  assumes 
    "(s t t'. I s   b s  Some 0  lst (C s) (λs'. mm3 (E s) (if I s' then Some (E s') else None)))" 
  "s. progress (C s)"
  "I s0" 
shows "Some 0  lst (whileIET I E b C s0) (λx. if b x then None else mm3 (E s0) (if I x then Some (E x) else None))" 
  by (rule whileIET_rule, rule transf[where b=b]) (use assms in auto)

section "some Monadic Refinement Automation"


(* ML infrastructure for the refine_rcg and refine_vcg methods: five named
   theorem collections, one configuration option, and the tactics rcg_tac and
   post_tac. *)
ML structure Refine = struct

  (* verification condition generation rules, bound to the name refine_vcg *)
  structure vcg = Named_Thms
    ( val name = @{binding refine_vcg}
      val description = "Refinement Framework: " ^ 
        "Verification condition generation rules (intro)" )

  (* consequence rules; rcg_tac tries them in combination with a refinement rule *)
  structure vcg_cons = Named_Thms
    ( val name = @{binding refine_vcg_cons}
      val description = "Refinement Framework: " ^
        "Consequence rules tried by VCG" )

  (* refinement rules that rcg_tac places first in its rule list *)
  structure refine0 = Named_Thms
    ( val name = @{binding refine0}
      val description = "Refinement Framework: " ^
        "Refinement rules applied first (intro)" )

  (* main collection of refinement rules *)
  structure refine = Named_Thms
    ( val name = @{binding refine}
      val description = "Refinement Framework: Refinement rules (intro)" )

  (* refinement rules that rcg_tac places last in its rule list *)
  structure refine2 = Named_Thms
    ( val name = @{binding refine2}
      val description = "Refinement Framework: " ^
        "Refinement rules 2nd stage (intro)" )

  (* If set to true, the product splitter of refine_rcg is disabled. *)
  val no_prod_split = 
    Attrib.setup_config_bool @{binding refine_no_prod_split} (K false);

  (* rcg_tac add_thms ctxt: repeatedly applies, to the goal and to all new
     subgoals, the first alternative that succeeds (made deterministic with
     DETERM):
       1. resolution with a refinement rule, taken in the order
          refine0, add_thms, refine, refine2;
       2. resolution with a consequence rule followed by a refinement rule;
       3. the product-splitting simplification step, unless it is disabled
          through no_prod_split. *)
  fun rcg_tac add_thms ctxt = 
    let 
      val cons_thms = vcg_cons.get ctxt
      val ref_thms = (refine0.get ctxt 
        @ add_thms @ refine.get ctxt @ refine2.get ctxt);
      (* minimal simpset containing only the split rule for pairs *)
      val prod_ss = (Splitter.add_split @{thm prod.split} 
        (put_simpset HOL_basic_ss ctxt));
      (* split pairs, then strip the implications and quantifiers introduced *)
      val prod_simp_tac = 
        if Config.get ctxt no_prod_split then 
          K no_tac
        else
          (simp_tac prod_ss THEN' 
            REPEAT_ALL_NEW (resolve_tac ctxt @{thms impI allI}));
    in
      REPEAT_ALL_NEW_FWD (DETERM o FIRST' [
        resolve_tac ctxt ref_thms,
        resolve_tac ctxt cons_thms THEN' resolve_tac ctxt ref_thms,
        prod_simp_tac
      ])
    end;

  (* post_tac ctxt: clean-up applied to the goals left by rcg_tac; repeatedly
     tries to close a goal by an identical assumption or by the tagged solver,
     the latter only if it solves the goal completely (SOLVED'). *)
  fun post_tac ctxt = REPEAT_ALL_NEW_FWD (FIRST' [
    eq_assume_tac,
    (*match_tac ctxt thms,*)
    SOLVED' (Tagged_Solver.solve_tac ctxt)]) 
         

end;
(* Register the five named theorem collections declared in structure Refine
   with the theory. *)
setup Refine.vcg.setup
setup Refine.vcg_cons.setup
setup Refine.refine0.setup
setup Refine.refine.setup
setup Refine.refine2.setup
(*setup {* Refine.refine_post.setup *}*)

(* Proof method refine_rcg: takes an optional list of additional theorems,
   runs Refine.rcg_tac with them, and then tries Refine.post_tac on every
   resulting subgoal (failure of the post-processing is tolerated via TRY). *)
method_setup refine_rcg = 
  Attrib.thms >> (fn add_thms => fn ctxt => SIMPLE_METHOD' (
    Refine.rcg_tac add_thms ctxt THEN_ALL_NEW_FWD (TRY o Refine.post_tac ctxt)
  )) 
  "Refinement framework: Generate refinement conditions"     

(* Proof method refine_vcg: like refine_rcg, but the theorems of the
   refine_vcg collection are appended to the user-supplied ones before
   Refine.rcg_tac is run. *)
method_setup refine_vcg = 
  Attrib.thms >> (fn add_thms => fn ctxt => SIMPLE_METHOD' (
    Refine.rcg_tac (add_thms @ Refine.vcg.get ctxt) ctxt THEN_ALL_NEW_FWD (TRY o Refine.post_tac ctxt)
  )) 
  "Refinement framework: Generate refinement and verification conditions"

end