subsection ‹Parametricity of the State Monad›

theory DP_CRelVS
  imports "./State_Monad_Ext" "../Pure_Monad"
begin

(* Relates pure functions to memoized state-monad implementations of them:
   a memory is consistent (cmem) when every stored entry agrees with the pure
   function, and crel_vs relates a pure value to a state computation that
   preserves this consistency. *)

(* lift_p P f: the state computation f preserves the predicate P on states.
   Whenever f is run from a state satisfying P, the final state produced by
   State_Monad.run_state also satisfies P. The returned value is ignored. *)
definition lift_p :: "('s ⇒ bool) ⇒ ('s, 'a) state ⇒ bool" where
  "lift_p P f =
    (∀ heap. P heap ⟶ (case State_Monad.run_state f heap of (_, heap) ⇒ P heap))"

(* Consequences of lift_p for a fixed computation f and a fixed start state
   heap that satisfies P. *)
context
  fixes P f heap
  assumes lift: "lift_p P f" and P: "P heap"
begin

(* Running f from heap ends in a state satisfying P (case-split form,
   obtained directly from lift_p_def). *)
lemma run_state_cases:
  "case State_Monad.run_state f heap of (_, heap) ⇒ P heap"
  using lift P unfolding lift_p_def by auto

(* The same fact, stated for an explicitly named result pair (v, heap'). *)
lemma lift_p_P:
  "P heap'" if "State_Monad.run_state f heap = (v, heap')"
  using that run_state_cases by auto

end

(* Abstract memory interface over a memory state 'mem.
   lookup param: state action yielding an optional stored result for param.
   update param r: state action storing the result r for param. *)
locale state_mem_defs =
  fixes lookup :: "'param ⇒ ('mem, 'result option) state"
    and update :: "'param ⇒ 'result ⇒ ('mem, unit) state"
begin

(* Memoized evaluation of calc at param. First look param up; if a result is
   present, return it. Otherwise run calc, store its result with update, and
   return that result. *)
definition checkmem :: "'param ⇒ ('mem, 'result) state ⇒ ('mem, 'result) state" where
  "checkmem param calc ≡ do {
    x ← lookup param;
    case x of
      Some x ⇒ State_Monad.return x
    | None ⇒ do {
        x ← calc;
        update param x;
        State_Monad.return x
      }
  }"

(* Notation "dpT$ param =CHECKMEM= calc": dpT param is defined as the memoized
   version (checkmem) of calc. *)
abbreviation checkmem_eq ::
  "('param ⇒ ('mem, 'result) state) ⇒ 'param ⇒ ('mem, 'result) state ⇒ bool"
  ("_$ _ =CHECKMEM= _" [1000,51] 51) where
  "(dpT$ param =CHECKMEM= calc) ≡ (dpT param = checkmem param calc)"
term 0 (**)

(* View a memory state as a partial map: the value that lookup k yields when
   run on heap. The post-state of the lookup is discarded. *)
definition map_of where
  "map_of heap k = fst (run_state (lookup k) heap)"

(* Variant of checkmem in which calc is passed behind a unit argument and only
   applied (calc ()) on the None branch.
   NOTE(review): presumably this keeps calc unevaluated until needed,
   e.g. for generated code. Confirm against the code-generation setup. *)
definition checkmem' :: "'param ⇒ (unit ⇒ ('mem, 'result) state) ⇒ ('mem, 'result) state" where
  "checkmem' param calc ≡ do {
    x ← lookup param;
    case x of
      Some x ⇒ State_Monad.return x
    | None ⇒ do {
        x ← calc ();
        update param x;
        State_Monad.return x
      }
  }"

(* With a constant thunk, checkmem' coincides with checkmem. *)
lemma checkmem_checkmem':
  "checkmem' param (λ_. calc) = checkmem param calc"
  unfolding checkmem'_def checkmem_def ..

(* checkmem_eq restated in terms of checkmem'. *)
lemma checkmem_eq_alt:
  "checkmem_eq dp param calc = (dp param = checkmem' param (λ _. calc))"
  unfolding checkmem_checkmem' ..

end (* Mem Defs *)


(* Correctness assumptions on the memory interface, relative to a state
   invariant P:
   - lookup_inv / update_inv: lookup and update preserve P (in the sense of
     lift_p).
   - lookup_correct: starting from a state satisfying P, the map after a lookup
     is a sub-map (⊆m) of the map before. Lookups add no entries and change
     none.
   - update_correct: starting from a state satisfying P, the map after
     update k v is a sub-map of the old map extended with k ↦ v. *)
locale mem_correct = state_mem_defs +
  fixes P
  assumes lookup_inv: "lift_p P (lookup k)" and update_inv: "lift_p P (update k v)"
  assumes
    lookup_correct: "P m ⟹ map_of (snd (State_Monad.run_state (lookup k) m)) ⊆m (map_of m)"
      and
    update_correct: "P m ⟹ map_of (snd (State_Monad.run_state (update k v) m)) ⊆m (map_of m)(k ↦ v)"
  (* assumes correct: "lookup (update m k v) ⊆m (lookup m)(k ↦ v)" *)

(* Consistency of memoized state-monad computations with the pure function dp,
   for a memory interface satisfying mem_correct. *)
locale dp_consistency =
  mem_correct lookup update P
  for lookup :: "'param ⇒ ('mem, 'result option) state" and update and P +
  fixes dp :: "'param ⇒ 'result"
begin

context
  includes lifting_syntax state_monad_syntax
begin

(* Memory M is consistent with dp: every key present in map_of M maps to
   Some (dp key). *)
definition cmem :: "'mem ⇒ bool" where
  "cmem M ≡ ∀param∈dom (map_of M). map_of M param = Some (dp param)"

(* crel_vs R v s: for every memory M that is consistent (cmem) and satisfies P,
   running s from M yields a result v' with R v v', and a final memory M' that
   is again consistent and satisfies P. *)
definition crel_vs :: "('a ⇒ 'b ⇒ bool) ⇒ 'a ⇒ ('mem, 'b) state ⇒ bool" where
  "crel_vs R v s ≡ ∀M. cmem M ∧ P M ⟶ (case State_Monad.run_state s M of (v', M') ⇒ R v v' ∧ cmem M' ∧ P M')"
  
(* Function relator for lifted functions: arguments related by R, results
   related by crel_vs R'. *)
abbreviation rel_fun_lifted :: "('a ⇒ 'c ⇒ bool) ⇒ ('b ⇒ 'd ⇒ bool) ⇒ ('a ⇒ 'b) ⇒ ('c ==_⟹ 'd) ⇒ bool" (infixr "===>T" 55) where
  "rel_fun_lifted R R' ≡ R ===> crel_vs R'"
term 0 (**)

(* consistentDP dpT: for equal parameters, the computation of dpT is
   crel_vs (=)-related to the value of dp. *)
definition consistentDP :: "('param == 'mem ⟹ 'result) ⇒ bool" where
  "consistentDP ≡ ((=) ===> crel_vs (=)) dp"
term 0 (**)
  
  (* cmem *)
(* To show cmem M, it suffices that every lookup on M returning Some v
   returns v = dp param. *)
private lemma cmem_intro:
  assumes "⋀param v M'. State_Monad.run_state (lookup param) M = (Some v, M') ⟹ v = dp param"
  shows "cmem M"
  unfolding cmem_def map_of_def
  apply safe
  subgoal for param y
    by (cases "State_Monad.run_state (lookup param) M") (auto intro: assms)
  done

(* On a consistent memory, a lookup returning Some v yields dp param = v. *)
lemma cmem_elim:
  assumes "cmem M" "State_Monad.run_state (lookup param) M = (Some v, M')"
  obtains "dp param = v"
  using assms unfolding cmem_def dom_def map_of_def by auto (metis fst_conv option.inject)
term 0 (**)
  
  (* crel_vs *)
(* Introduction rule for crel_vs, restating crel_vs_def with a named result
   pair (v', M'). *)
lemma crel_vs_intro:
  assumes "⋀M v' M'. ⟦cmem M; P M; State_Monad.run_state vT M = (v', M')⟧ ⟹ R v v' ∧ cmem M' ∧ P M'"
  shows "crel_vs R v vT"
  using assms unfolding crel_vs_def by blast
term 0 (**)
  
(* Elimination rule for crel_vs: from a consistent memory satisfying P, obtain
   the result pair and its guarantees. *)
lemma crel_vs_elim:
  assumes "crel_vs R v vT" "cmem M" "P M"
  obtains v' M' where "State_Monad.run_state vT M = (v', M')" "R v v'" "cmem M'" "P M'"
  using assms unfolding crel_vs_def by blast
term 0 (**)
  
  (* consistentDP *)
(* Pointwise criterion: dpT is consistent if, for every param, dpT param is
   related to dp param by crel_vs (=). *)
lemma consistentDP_intro:
  assumes "⋀param. Transfer.Rel (crel_vs (=)) (dp param) (dpT param)"
  shows "consistentDP dpT"
  using assms unfolding consistentDP_def Rel_def by blast
  
(* Returning a related value: if R x y, then State_Monad.return y is related to
   Wrap x by crel_vs R. *)
lemma crel_vs_return:
  "⟦Transfer.Rel R x y⟧ ⟹ Transfer.Rel (crel_vs R) (Wrap x) (State_Monad.return y)"
  unfolding State_Monad.return_def Wrap_def Rel_def by (fastforce intro: crel_vs_intro)
term 0 (**)
  
(* crel_vs_return with Wrap unfolded (via Wrap_def), relating x itself. *)
lemma crel_vs_return_ext:
  "⟦Transfer.Rel R x y⟧ ⟹ Transfer.Rel (crel_vs R) x (State_Monad.return y)"
  by (fact crel_vs_return[unfolded Wrap_def])
term 0 (**)

  (* Low level operators *)
(* Storing dp param under param keeps a consistent memory consistent
   (uses update_correct). *)
private lemma cmem_upd:
  "cmem M'" if "cmem M" "P M" "State_Monad.run_state (update param (dp param)) M = (v, M')"
  using update_correct[of M param "dp param"] that unfolding cmem_def map_le_def by simp force

(* update preserves P (from update_inv via lift_p_P). *)
private lemma P_upd:
  "P M'" if "P M" "State_Monad.run_state (update param (dp param)) M = (v, M')"
  by (meson lift_p_P that update_inv)

(* Reading the whole memory with get: if sf M is related for every consistent
   M, then get ⤜ sf is related. *)
private lemma crel_vs_get:
  "⟦⋀M. cmem M ⟹ crel_vs R v (sf M)⟧ ⟹ crel_vs R v (State_Monad.get ⤜ sf)"
  unfolding State_Monad.get_def State_Monad.bind_def by (fastforce intro: crel_vs_intro elim: crel_vs_elim split: prod.split)
term 0 (**)
  
(* Overwriting the memory with a consistent M satisfying P, then running a
   related computation sf, is related. *)
private lemma crel_vs_set:
  "⟦crel_vs R v sf; cmem M; P M⟧ ⟹ crel_vs R v (State_Monad.set M ⪢ sf)"
  unfolding State_Monad.set_def State_Monad.bind_def by (fastforce intro: crel_vs_intro elim: crel_vs_elim split: prod.split)
term 0 (**)
  
(* Bind rule for a first computation related by equality: its result is v
   itself, so the continuation sf v determines the outcome. *)
private lemma crel_vs_bind_eq:
  "⟦crel_vs (=) v s; crel_vs R (f v) (sf v)⟧ ⟹ crel_vs R (f v) (s ⤜ sf)"
  unfolding State_Monad.bind_def rel_fun_def by (fastforce intro: crel_vs_intro elim: crel_vs_elim split: prod.split)
term 0 (**)

(* Transfer rule: pure application (λv f. f v) corresponds to monadic
   bind (⤜). *)
lemma bind_transfer[transfer_rule]:
  "(crel_vs R0 ===> (R0 ===>T R1) ===> crel_vs R1) (λv f. f v) (⤜)"
  unfolding State_Monad.bind_def rel_fun_def by (fastforce intro: crel_vs_intro elim: crel_vs_elim split: prod.split)

(* A lookup keeps a consistent memory consistent (uses lookup_correct). *)
private lemma cmem_lookup:
  "cmem M'" if "cmem M" "P M" "State_Monad.run_state (lookup param) M = (v, M')"
  using lookup_correct[of M param] that unfolding cmem_def map_le_def by force

(* lookup preserves P (from lookup_inv via lift_p_P). *)
private lemma P_lookup:
  "P M'" if "P M" "State_Monad.run_state (lookup param) M = (v, M')"
  by (meson lift_p_P that lookup_inv)

(* lookup param is related to dp param. A None result carries no information;
   a Some v' result satisfies v' = dp param. *)
lemma crel_vs_lookup:
  "crel_vs (λ v v'. case v' of None ⇒ True | Some v' ⇒ v = v' ∧ v = dp param) (dp param) (lookup param)"
  by (auto elim: cmem_elim intro: cmem_lookup crel_vs_intro P_lookup split: option.split)

(* Storing dp param under param is related to () by equality. *)
lemma crel_vs_update:
  "crel_vs (=) () (update param (dp param))"
  by (auto intro: cmem_upd crel_vs_intro P_upd)

(* Memoization is sound. If s is related to dp param and R is equality
   (is_equality R), then checkmem param s is still related to dp param. *)
private lemma crel_vs_checkmem:
  "⟦is_equality R; Transfer.Rel (crel_vs R) (dp param) s⟧
  ⟹ Transfer.Rel (crel_vs R) (dp param) (checkmem param s)"
  unfolding checkmem_def Rel_def is_equality_def
  by (rule bind_transfer[unfolded rel_fun_def, rule_format, OF crel_vs_lookup])
     (auto 4 3 intro: crel_vs_lookup crel_vs_update crel_vs_return[unfolded Rel_def Wrap_def] crel_vs_bind_eq
               split: option.split_asm
     )

(* crel_vs_checkmem for a value v that is assumed equal to dp param. *)
lemma crel_vs_checkmem_tupled:
  assumes "v = dp param"
  shows "⟦is_equality R; Transfer.Rel (crel_vs R) v s⟧
        ⟹ Transfer.Rel (crel_vs R) v (checkmem param s)"
  unfolding assms by (fact crel_vs_checkmem)

  (** Transfer rules **)
  (* Basics *)
(* Transfer rule: Wrap corresponds to State_Monad.return. *)
lemma return_transfer[transfer_rule]:
  "(R ===>T R) Wrap State_Monad.return"
  unfolding rel_fun_def by (metis crel_vs_return Rel_def)

(* Transfer rule: pure application App corresponds to lifted application (.),
   where both the function and the argument are monadic. *)
lemma fun_app_lifted_transfer[transfer_rule]:
  "(crel_vs (R0 ===>T R1) ===> crel_vs R0 ===> crel_vs R1) App (.)"
  unfolding App_def fun_app_lifted_def by transfer_prover
    
(* fun_app_lifted_transfer stated with Transfer.Rel premises and conclusion. *)
lemma crel_vs_fun_app:
  "⟦Transfer.Rel (crel_vs R0) x xT; Transfer.Rel (crel_vs (R0 ===>T R1)) f fT⟧ ⟹ Transfer.Rel (crel_vs R1) (App f x) (fT . xT)"
  unfolding Rel_def using fun_app_lifted_transfer[THEN rel_funD, THEN rel_funD] .

  (* HOL *)
(* Transfer rule: HOL's If corresponds to State_Monad_Ext.ifT. The condition is
   related by crel_vs (=) and both branches by crel_vs R. *)
lemma ifT_transfer[transfer_rule]:
  "(crel_vs (=) ===> crel_vs R ===> crel_vs R ===> crel_vs R) If State_Monad_Ext.ifT"
  unfolding State_Monad_Ext.ifT_def by transfer_prover
end (* Lifting Syntax *)

end (* Consistency *)
end (* Theory *)