(* Theory Monad_Memo_DP.DP_CRelVS *)

subsection ‹Parametricity of the State Monad›

theory DP_CRelVS
  imports "./State_Monad_Ext" "../Pure_Monad"
begin

(* lift_p P f: running the state computation f from any state satisfying P
   produces a final state that again satisfies P, i.e. P is preserved by f. *)
definition lift_p :: "('s ⇒ bool) ⇒ ('s, 'a) state ⇒ bool" where
  "lift_p P f =
    (∀ heap. P heap ⟶ (case State_Monad.run_state f heap of (_, heap) ⇒ P heap))"

(* Elimination forms for lift_p. The context fixes a computation f that
   preserves P and a start state heap satisfying P. The state produced by
   running f from heap then satisfies P as well. *)
context
  fixes P f heap
  assumes lift: "lift_p P f" and P: "P heap"
begin

(* The final state of running f from heap satisfies P (stated by case
   distinction on the result pair). *)
lemma run_state_cases:
  "case State_Monad.run_state f heap of (_, heap) ⇒ P heap"
  using lift P unfolding lift_p_def by auto

(* The same fact, stated for an explicitly named result (v, heap'). *)
lemma lift_p_P:
  "P heap'" if "State_Monad.run_state f heap = (v, heap')"
  using that run_state_cases by auto

end

(* Memory interface used for memoization:
   - lookup param is a state computation returning the stored result, if any;
   - update param r is a state computation that stores r under param. *)
locale state_mem_defs =
  fixes lookup :: "'param ⇒ ('mem, 'result option) state"
    and update :: "'param ⇒ 'result ⇒ ('mem, unit) state"
begin

(* checkmem param calc first looks param up.
   - If a value is stored, that value is returned.
   - Otherwise calc is run, its result is stored under param via update,
     and the result is returned. *)
definition checkmem :: "'param ⇒ ('mem, 'result) state ⇒ ('mem, 'result) state" where
  "checkmem param calc ≡ do {
    x ← lookup param;
    case x of
      Some x ⇒ State_Monad.return x
    | None ⇒ do {
        x ← calc;
        update param x;
        State_Monad.return x
      }
  }"

(* Notation: "dpT$ param =CHECKMEM= calc" abbreviates
   "dpT param = checkmem param calc". *)
abbreviation checkmem_eq ::
  "('param ⇒ ('mem, 'result) state) ⇒ 'param ⇒ ('mem, 'result) state ⇒ bool"
  (‹_$ _ =CHECKMEM= _› [1000,51] 51) where
  "(dpT$ param =CHECKMEM= calc) ≡ (dpT param = checkmem param calc)"
term 0 (**)

(* Memory contents of heap viewed as a partial map: the value component of
   running lookup k on heap. The resulting state is discarded. *)
definition map_of where
  "map_of heap k = fst (run_state (lookup k) heap)"

(* Variant of checkmem whose fallback computation takes a unit argument.
   By checkmem_checkmem' below, it coincides with checkmem for constant
   thunks. *)
definition checkmem' :: "'param ⇒ (unit ⇒ ('mem, 'result) state) ⇒ ('mem, 'result) state" where
  "checkmem' param calc ≡ do {
    x ← lookup param;
    case x of
      Some x ⇒ State_Monad.return x
    | None ⇒ do {
        x ← calc ();
        update param x;
        State_Monad.return x
      }
  }"

(* checkmem' applied to a constant thunk equals checkmem. *)
lemma checkmem_checkmem':
  "checkmem' param (λ_. calc) = checkmem param calc"
  unfolding checkmem'_def checkmem_def ..

(* checkmem_eq restated in terms of checkmem'. *)
lemma checkmem_eq_alt:
  "checkmem_eq dp param calc = (dp param = checkmem' param (λ _. calc))"
  unfolding checkmem_checkmem' ..

end (* Mem Defs *)


(* Correctness assumptions on the memory interface:
   - P is an invariant of every lookup and every update (lift_p);
   - from a state satisfying P, the map after lookup k is contained (⊆⇩m)
     in the map before it;
   - from a state satisfying P, the map after update k v is contained in
     the previous map extended with k ↦ v. *)
locale mem_correct = state_mem_defs +
  fixes P
  assumes lookup_inv: "lift_p P (lookup k)" and update_inv: "lift_p P (update k v)"
  assumes
    lookup_correct: "P m ⟹ map_of (snd (State_Monad.run_state (lookup k) m)) ⊆⇩m (map_of m)"
      and
    update_correct: "P m ⟹ map_of (snd (State_Monad.run_state (update k v) m)) ⊆⇩m (map_of m)(k ↦ v)"
  (* assumes correct: "lookup (update m k v) ⊆⇩m (lookup m)(k ↦ v)" *)

(* Relates stateful computations over a memory satisfying mem_correct to the
   pure specification dp. *)
locale dp_consistency =
  mem_correct lookup update P
  for lookup :: "'param ⇒ ('mem, 'result option) state" and update and P +
  fixes dp :: "'param ⇒ 'result"
begin

context
  includes lifting_syntax and state_monad_syntax
begin

(* cmem M: every entry stored in memory M agrees with dp. *)
definition cmem :: "'mem ⇒ bool" where
  "cmem M ≡ ∀param∈dom (map_of M). map_of M param = Some (dp param)"

(* crel_vs R v s: for every memory M with cmem M and P M, running s on M
   yields a value v' with R v v' and a final memory M' with cmem M' and
   P M'. *)
definition crel_vs :: "('a ⇒ 'b ⇒ bool) ⇒ 'a ⇒ ('mem, 'b) state ⇒ bool" where
  "crel_vs R v s ≡ ∀M. cmem M ∧ P M ⟶ (case State_Monad.run_state s M of (v', M') ⇒ R v v' ∧ cmem M' ∧ P M')"

(* R ===>⇩T R': the function relator whose result side is related by
   crel_vs R'. *)
abbreviation rel_fun_lifted :: "('a ⇒ 'c ⇒ bool) ⇒ ('b ⇒ 'd ⇒ bool) ⇒ ('a ⇒ 'b) ⇒ ('c ==_⟹ 'd) ⇒ bool" (infixr ‹===>⇩T› 55) where
  "rel_fun_lifted R R' ≡ R ===> crel_vs R'"
term 0 (**)

(* consistentDP dpT: for equal arguments, dpT is related to dp by
   crel_vs (=). *)
definition consistentDP :: "('param == 'mem ⟹ 'result) ⇒ bool" where
  "consistentDP ≡ ((=) ===> crel_vs (=)) dp"
term 0 (**)

  (* cmem *)
(* Intro rule: cmem M holds if every lookup on M that returns Some v
   returns v = dp param. *)
private lemma cmem_intro:
  assumes "⋀param v M'. State_Monad.run_state (lookup param) M = (Some v, M') ⟹ v = dp param"
  shows "cmem M"
  unfolding cmem_def map_of_def
  apply safe
  subgoal for param y
    by (cases "State_Monad.run_state (lookup param) M") (auto intro: assms)
  done

(* Elim rule: a lookup on a consistent memory that returns Some v
   returns the dp value. *)
lemma cmem_elim:
  assumes "cmem M" "State_Monad.run_state (lookup param) M = (Some v, M')"
  obtains "dp param = v"
  using assms unfolding cmem_def dom_def map_of_def by auto (metis fst_conv option.inject)
term 0 (**)

  (* crel_vs *)
(* Intro rule for crel_vs, phrased on an explicitly named result (v', M'). *)
lemma crel_vs_intro:
  assumes "⋀M v' M'. ⟦cmem M; P M; State_Monad.run_state vT M = (v', M')⟧ ⟹ R v v' ∧ cmem M' ∧ P M'"
  shows "crel_vs R v vT"
  using assms unfolding crel_vs_def by blast
term 0 (**)

(* Elim rule for crel_vs: from a consistent memory satisfying P, the result
   pair of vT has the properties required by crel_vs. *)
lemma crel_vs_elim:
  assumes "crel_vs R v vT" "cmem M" "P M"
  obtains v' M' where "State_Monad.run_state vT M = (v', M')" "R v v'" "cmem M'" "P M'"
  using assms unfolding crel_vs_def by blast
term 0 (**)

  (* consistentDP *)
(* consistentDP dpT follows from a pointwise crel_vs (=) relation between
   dp and dpT. *)
lemma consistentDP_intro:
  assumes "⋀param. Transfer.Rel (crel_vs (=)) (dp param) (dpT param)"
  shows "consistentDP dpT"
  using assms unfolding consistentDP_def Rel_def by blast

(* Returning a value related by R gives a computation related by
   crel_vs R to the wrapped pure value. *)
lemma crel_vs_return:
  "Transfer.Rel R x y ⟹ Transfer.Rel (crel_vs R) (Wrap x) (State_Monad.return y)"
  unfolding State_Monad.return_def Wrap_def Rel_def by (fastforce intro: crel_vs_intro)
term 0 (**)

(* crel_vs_return with Wrap unfolded. *)
lemma crel_vs_return_ext:
  "Transfer.Rel R x y ⟹ Transfer.Rel (crel_vs R) x (State_Monad.return y)"
  by (fact crel_vs_return[unfolded Wrap_def])
term 0 (**)

  (* Low level operators *)
(* Storing the dp value for param preserves cmem. *)
private lemma cmem_upd:
  "cmem M'" if "cmem M" "P M" "State_Monad.run_state (update param (dp param)) M = (v, M')"
  using update_correct[of M param "dp param"] that unfolding cmem_def map_le_def by simp force

(* Storing the dp value for param preserves P (via update_inv). *)
private lemma P_upd:
  "P M'" if "P M" "State_Monad.run_state (update param (dp param)) M = (v, M')"
  by (meson lift_p_P that update_inv)

(* Reading the state with get and continuing with sf: it suffices that sf M
   is related for every consistent M. *)
private lemma crel_vs_get:
  "⟦⋀M. cmem M ⟹ crel_vs R v (sf M)⟧ ⟹ crel_vs R v (State_Monad.get ⤜ sf)"
  unfolding State_Monad.get_def State_Monad.bind_def by (fastforce intro: crel_vs_intro elim: crel_vs_elim split: prod.split)
term 0 (**)

(* Setting the state to a consistent memory satisfying P, then continuing
   with a related computation, stays related. *)
private lemma crel_vs_set:
  "⟦crel_vs R v sf; cmem M; P M⟧ ⟹ crel_vs R v (State_Monad.set M ⤜ sf)"
  unfolding State_Monad.set_def State_Monad.bind_def by (fastforce intro: crel_vs_intro elim: crel_vs_elim split: prod.split)
term 0 (**)

(* Bind where the first computation is related by equality to v. *)
private lemma crel_vs_bind_eq:
  "⟦crel_vs (=) v s; crel_vs R (f v) (sf v)⟧ ⟹ crel_vs R (f v) (s ⤜ sf)"
  unfolding State_Monad.bind_def rel_fun_def by (fastforce intro: crel_vs_intro elim: crel_vs_elim split: prod.split)
term 0 (**)

(* Transfer rule relating pure application (λv f. f v) with monadic bind. *)
lemma bind_transfer[transfer_rule]:
  "(crel_vs R0 ===> (R0 ===>⇩T R1) ===> crel_vs R1) (λv f. f v) (⤜)"
  unfolding State_Monad.bind_def rel_fun_def by (fastforce intro: crel_vs_intro elim: crel_vs_elim split: prod.split)

(* A lookup preserves cmem (via lookup_correct). *)
private lemma cmem_lookup:
  "cmem M'" if "cmem M" "P M" "State_Monad.run_state (lookup param) M = (v, M')"
  using lookup_correct[of M param] that unfolding cmem_def map_le_def by force

(* A lookup preserves P (via lookup_inv). *)
private lemma P_lookup:
  "P M'" if "P M" "State_Monad.run_state (lookup param) M = (v, M')"
  by (meson lift_p_P that lookup_inv)

(* lookup param returns either None or Some of a value equal to dp param. *)
lemma crel_vs_lookup:
  "crel_vs (λ v v'. case v' of None ⇒ True | Some v' ⇒ v = v' ∧ v = dp param) (dp param) (lookup param)"
  by (auto elim: cmem_elim intro: cmem_lookup crel_vs_intro P_lookup split: option.split)

(* Storing dp param under param is related to the unit value. *)
lemma crel_vs_update:
  "crel_vs (=) () (update param (dp param))"
  by (auto intro: cmem_upd crel_vs_intro P_upd)

(* For an equality relator R, wrapping a computation related to dp param in
   checkmem keeps it related to dp param. *)
private lemma crel_vs_checkmem:
  "⟦is_equality R; Transfer.Rel (crel_vs R) (dp param) s⟧
    ⟹ Transfer.Rel (crel_vs R) (dp param) (checkmem param s)"
  unfolding checkmem_def Rel_def is_equality_def
  by (rule bind_transfer[unfolded rel_fun_def, rule_format, OF crel_vs_lookup])
     (auto 4 3 intro: crel_vs_lookup crel_vs_update crel_vs_return[unfolded Rel_def Wrap_def] crel_vs_bind_eq
               split: option.split_asm
     )

(* crel_vs_checkmem with the value given as v under the side condition
   v = dp param. *)
lemma crel_vs_checkmem_tupled:
  assumes "v = dp param"
  shows "⟦is_equality R; Transfer.Rel (crel_vs R) v s⟧
        ⟹ Transfer.Rel (crel_vs R) v (checkmem param s)"
  unfolding assms by (fact crel_vs_checkmem)

  (** Transfer rules **)
  (* Basics *)
(* Transfer rule relating Wrap with State_Monad.return. *)
lemma return_transfer[transfer_rule]:
  "(R ===>⇩T R) Wrap State_Monad.return"
  unfolding rel_fun_def by (metis crel_vs_return Rel_def)

(* Transfer rule relating App with lifted application (.). *)
lemma fun_app_lifted_transfer[transfer_rule]:
  "(crel_vs (R0 ===>⇩T R1) ===> crel_vs R0 ===> crel_vs R1) App (.)"
  unfolding App_def fun_app_lifted_def by transfer_prover

(* Rel-form of fun_app_lifted_transfer. *)
lemma crel_vs_fun_app:
  "⟦Transfer.Rel (crel_vs R0) x xT; Transfer.Rel (crel_vs (R0 ===>⇩T R1)) f fT⟧ ⟹ Transfer.Rel (crel_vs R1) (App f x) (fT . xT)"
  unfolding Rel_def using fun_app_lifted_transfer[THEN rel_funD, THEN rel_funD] .

  (* HOL *)
(* Transfer rule relating If with State_Monad_Ext.ifT; the condition is
   related by crel_vs (=). *)
lemma ifT_transfer[transfer_rule]:
  "(crel_vs (=) ===> crel_vs R ===> crel_vs R ===> crel_vs R) If State_Monad_Ext.ifT"
  unfolding State_Monad_Ext.ifT_def by transfer_prover
end (* Lifting Syntax *)

end (* Consistency *)
end (* Theory *)