(* Theory Pair_Memory *)

subsection ‹Pair Memory›

theory Pair_Memory
  imports "../state_monad/Memory"
begin

(* XXX Move *)
(* Map addition is monotone w.r.t. the map order ⊆⇩m, provided the domain of the
   left summand m1 is disjoint from the domain of the enlarged right summand m2'. *)
lemma map_add_mono:
  "(m1 ++ m2) ⊆⇩m (m1' ++ m2')" if "m1 ⊆⇩m m1'" "m2 ⊆⇩m m2'" "dom m1 ∩ dom m2' = {}"
  using that unfolding map_le_def map_add_def dom_def by (auto split: option.splits)

(* An update of the left summand f can be pulled out of a map addition when the
   domains of f and g are disjoint and the updated key x is not in dom g. *)
lemma map_add_upd2:
  "f(x ↦ y) ++ g = (f ++ g)(x ↦ y)" if "dom f ∩ dom g = {}" "x ∉ dom g"
  apply (subst map_add_comm)
   defer
   apply simp
   apply (subst map_add_comm)
  using that
  by auto

(* Signature of a memory that is split into two rows.
   - lookup1/update1 and lookup2/update2 operate on the first and second row,
     both indexed by keys of type 'a.
   - get_k1/get_k2 read the 'k1 value currently attached to each row.
   - move12 takes a 'k1 value; its behaviour is only specified by the
     assumptions of locale pair_mem.
   - P is a predicate on the underlying memory state.
   - key1/key2 split a full key 'k into a row identifier and a row-local key. *)
locale pair_mem_defs =
  fixes lookup1 lookup2 :: "'a ⇒ ('mem, 'v option) state"
    and update1 update2 :: "'a ⇒ 'v ⇒ ('mem, unit) state"
    and move12 :: "'k1 ⇒ ('mem, unit) state"
    and get_k1 get_k2 :: "('mem, 'k1) state"
    and P :: "'mem ⇒ bool"
  fixes key1 :: "'k ⇒ 'k1" and key2 :: "'k ⇒ 'a"
begin

text ‹We assume that look-ups happen on the older row, so it is biased towards the second entry.›
(* Compare key1 k first with get_k2 and then with get_k1; on a match, look up
   key2 k in the corresponding row. If neither matches, return None. *)
definition
  "lookup_pair k = do {
     let k' = key1 k;
     k2 ← get_k2;
     if k' = k2
     then lookup2 (key2 k)
     else do {
       k1 ← get_k1;
       if k' = k1
       then lookup1 (key2 k)
       else State_Monad.return None
     }
   }
   "

text ‹We assume that updates happen on the newer row, so it is biased towards the first entry.›
(* Compare key1 k first with get_k1 and then with get_k2; on a match, update the
   corresponding row at key2 k. If neither matches, run move12 on key1 k and
   then update the first row. *)
definition
  "update_pair k v = do {
    let k' = key1 k;
    k1 ← get_k1;
    if k' = k1
    then update1 (key2 k) v
    else do {
      k2 ← get_k2;
      if k' = k2
      then update2 (key2 k) v
      else (move12 k' ⪢ update1 (key2 k) v)
    }
  }
  "

(* Instantiate state_mem_defs for the combined pair operations; this provides
   the prefixed names (e.g. pair.map_of) used in the lemmas below. *)
sublocale pair: state_mem_defs lookup_pair update_pair .

(* The same instantiation for the first row alone (prefix mem1). *)
sublocale mem1: state_mem_defs lookup1 update1 .

(* The same instantiation for the second row alone (prefix mem2). *)
sublocale mem2: state_mem_defs lookup2 update2 .

(* Invariant of the pair memory. With k1 and k2 the values returned by get_k1
   and get_k2 in heap:
   - every key stored in row 1 (resp. row 2) is key2 k' of some full key k'
     whose key1 component is k1 (resp. k2);
   - the two row identifiers differ;
   - the underlying predicate P holds. *)
definition
  "inv_pair heap ≡
    let
      k1 = fst (State_Monad.run_state get_k1 heap);
      k2 = fst (State_Monad.run_state get_k2 heap)
    in
    (∀ k ∈ dom (mem1.map_of heap). ∃ k'. key1 k' = k1 ∧ key2 k' = k) ∧
    (∀ k ∈ dom (mem2.map_of heap). ∃ k'. key1 k' = k2 ∧ key2 k' = k) ∧
    k1 ≠ k2 ∧ P heap
  "

(* View of row 1 on full keys: defined only for keys k whose key1 component
   equals the value returned by get_k1 in m; such a key is looked up under
   key2 k in mem1.map_of m. *)
definition
  "map_of1 m k = (if key1 k = fst (State_Monad.run_state get_k1 m) then mem1.map_of m (key2 k) else None)"

(* The analogous view of row 2, using get_k2 and mem2.map_of. *)
definition
  "map_of2 m k = (if key1 k = fst (State_Monad.run_state get_k2 m) then mem2.map_of m (key2 k) else None)"

end (* Pair Mem Defs *)

(* Correctness assumptions for the two-row memory of pair_mem_defs. *)
locale pair_mem = pair_mem_defs +
  (* Reading a row identifier does not change the state. *)
  assumes get_state:
    "State_Monad.run_state get_k1 m = (k, m') ⟹ m' = m"
    "State_Monad.run_state get_k2 m = (k, m') ⟹ m' = m"
  (* After move12, row 1 is empty and row 2 contains at most the old row 1. *)
  assumes move12_correct:
    "P m ⟹ State_Monad.run_state (move12 k1) m = (x, m') ⟹ mem1.map_of m' ⊆⇩m Map.empty"
    "P m ⟹ State_Monad.run_state (move12 k1) m = (x, m') ⟹ mem2.map_of m' ⊆⇩m mem1.map_of m"
  (* After move12 k1, row 1 is labelled k1 and row 2 carries the old label of row 1. *)
  assumes move12_keys:
    "State_Monad.run_state (move12 k1) m = (x, m') ⟹ fst (State_Monad.run_state get_k1 m') = k1"
    "State_Monad.run_state (move12 k1) m = (x, m') ⟹ fst (State_Monad.run_state get_k2 m') = fst (State_Monad.run_state get_k1 m)"
  (* lift_p P is assumed for move12, the look-ups and the updates. *)
  assumes move12_inv:
    "lift_p P (move12 k1)"
  assumes lookup_inv:
    "lift_p P (lookup1 k')" "lift_p P (lookup2 k')"
  assumes update_inv:
    "lift_p P (update1 k' v)" "lift_p P (update2 k' v)"
  (* Look-ups on either row leave both row identifiers unchanged. *)
  assumes lookup_keys:
    "P m ⟹ State_Monad.run_state (lookup1 k') m = (v', m') ⟹
     fst (State_Monad.run_state get_k1 m') = fst (State_Monad.run_state get_k1 m)"
    "P m ⟹ State_Monad.run_state (lookup1 k') m = (v', m') ⟹
     fst (State_Monad.run_state get_k2 m') = fst (State_Monad.run_state get_k2 m)"
    "P m ⟹ State_Monad.run_state (lookup2 k') m = (v', m') ⟹
     fst (State_Monad.run_state get_k1 m') = fst (State_Monad.run_state get_k1 m)"
    "P m ⟹ State_Monad.run_state (lookup2 k') m = (v', m') ⟹
     fst (State_Monad.run_state get_k2 m') = fst (State_Monad.run_state get_k2 m)"
  (* Updates on either row leave both row identifiers unchanged. *)
  assumes update_keys:
    "P m ⟹ State_Monad.run_state (update1 k' v) m = (x, m') ⟹
     fst (State_Monad.run_state get_k1 m') = fst (State_Monad.run_state get_k1 m)"
    "P m ⟹ State_Monad.run_state (update1 k' v) m = (x, m') ⟹
     fst (State_Monad.run_state get_k2 m') = fst (State_Monad.run_state get_k2 m)"
    "P m ⟹ State_Monad.run_state (update2 k' v) m = (x, m') ⟹
     fst (State_Monad.run_state get_k1 m') = fst (State_Monad.run_state get_k1 m)"
    "P m ⟹ State_Monad.run_state (update2 k' v) m = (x, m') ⟹
     fst (State_Monad.run_state get_k2 m') = fst (State_Monad.run_state get_k2 m)"
  (* A look-up on either row never adds entries to either row. *)
  assumes
    lookup_correct:
      "P m ⟹ mem1.map_of (snd (State_Monad.run_state (lookup1 k') m)) ⊆⇩m (mem1.map_of m)"
      "P m ⟹ mem2.map_of (snd (State_Monad.run_state (lookup1 k') m)) ⊆⇩m (mem2.map_of m)"
      "P m ⟹ mem1.map_of (snd (State_Monad.run_state (lookup2 k') m)) ⊆⇩m (mem1.map_of m)"
      "P m ⟹ mem2.map_of (snd (State_Monad.run_state (lookup2 k') m)) ⊆⇩m (mem2.map_of m)"
  (* An update adds at most the binding k' ↦ v to its own row and adds nothing
     to the other row. *)
  assumes
    update_correct:
      "P m ⟹ mem1.map_of (snd (State_Monad.run_state (update1 k' v) m)) ⊆⇩m (mem1.map_of m)(k' ↦ v)"
      "P m ⟹ mem2.map_of (snd (State_Monad.run_state (update2 k' v) m)) ⊆⇩m (mem2.map_of m)(k' ↦ v)"
      "P m ⟹ mem2.map_of (snd (State_Monad.run_state (update1 k' v) m)) ⊆⇩m mem2.map_of m"
      "P m ⟹ mem1.map_of (snd (State_Monad.run_state (update2 k' v) m)) ⊆⇩m mem1.map_of m"
begin

(* Under the invariant, the map induced by lookup_pair is contained in the
   union of the two row views. *)
lemma map_of_le_pair:
  "pair.map_of m ⊆⇩m map_of1 m ++ map_of2 m"
  if "inv_pair m"
  using that
  unfolding pair.map_of_def map_of1_def map_of2_def
  unfolding lookup_pair_def inv_pair_def map_of_def map_le_def dom_def map_add_def
  unfolding State_Monad.bind_def
  by (auto 4 4
        simp: mem2.map_of_def mem1.map_of_def Let_def
        dest: get_state split: prod.split_asm if_split_asm
     )

(* Converse inclusion: under the invariant, the union of the two row views is
   contained in the map induced by lookup_pair. *)
lemma pair_le_map_of:
  "map_of1 m ++ map_of2 m ⊆⇩m pair.map_of m"
  if "inv_pair m"
  using that
  unfolding pair.map_of_def map_of1_def map_of2_def
  unfolding lookup_pair_def inv_pair_def map_of_def map_le_def dom_def map_add_def
  unfolding State_Monad.bind_def
  by (auto
        simp: mem2.map_of_def mem1.map_of_def State_Monad.run_state_return Let_def
        dest: get_state split: prod.splits if_split_asm option.split
     )

(* Under the invariant, the map induced by lookup_pair is exactly the union of
   the two row views. Used below (via [symmetric]) to rewrite pair.map_of. *)
lemma map_of_eq_pair:
  "map_of1 m ++ map_of2 m = pair.map_of m"
  if "inv_pair m"
  using that
  unfolding pair.map_of_def map_of1_def map_of2_def
  unfolding lookup_pair_def inv_pair_def map_of_def map_le_def dom_def map_add_def
  unfolding State_Monad.bind_def
  by (auto 4 4
        simp: mem2.map_of_def mem1.map_of_def State_Monad.run_state_return Let_def
        dest: get_state split: prod.splits option.split
     )

(* The invariant excludes states in which get_k1 and get_k2 return the same value. *)
lemma inv_pair_neq[simp]:
  False if "inv_pair m" "fst (State_Monad.run_state get_k1 m) = fst (State_Monad.run_state get_k2 m)"
  using that unfolding inv_pair_def by auto

(* Destruction rule: the pair invariant implies the underlying predicate P. *)
lemma inv_pair_P_D:
  "P m" if "inv_pair m"
  using that unfolding inv_pair_def by (auto simp: Let_def)

(* Under the invariant, the two row views have disjoint domains. *)
lemma inv_pair_domD[intro]:
  "dom (map_of1 m) ∩ dom (map_of2 m) = {}" if "inv_pair m"
  using that unfolding inv_pair_def map_of1_def map_of2_def by (auto split: if_split_asm)

(* move12_correct(1) lifted to the full-key view: after move12, map_of1 is empty. *)
lemma move12_correct1:
  "map_of1 heap' ⊆⇩m Map.empty" if "State_Monad.run_state (move12 k1) heap = (x, heap')" "P heap"
  using move12_correct[OF that(2,1)] unfolding map_of1_def by (auto simp: move12_keys map_le_def)

(* move12_correct(2) lifted to the full-key view: after move12, map_of2 contains
   at most what map_of1 contained before. *)
lemma move12_correct2:
  "map_of2 heap' ⊆⇩m map_of1 heap" if "State_Monad.run_state (move12 k1) heap = (x, heap')" "P heap"
  using move12_correct(2)[OF that(2,1)] that unfolding map_of1_def map_of2_def
  by (auto simp: move12_keys map_le_def)

(* After move12 from a state satisfying P, the row-1 view has an empty domain. *)
lemma dom_empty[simp]:
  "dom (map_of1 heap') = {}" if "State_Monad.run_state (move12 k1) heap = (x, heap')" "P heap"
  using move12_correct1[OF that] by (auto dest: map_le_implies_dom_le)

(* A look-up on row 1 preserves the pair invariant. *)
lemma inv_pair_lookup1:
  "inv_pair m'" if "State_Monad.run_state (lookup1 k) m = (v, m')" "inv_pair m"
  using that lookup_inv[of k] inv_pair_P_D[OF ‹inv_pair m›] unfolding inv_pair_def
  by (auto 4 4
        simp: Let_def lookup_keys
        dest: lift_p_P lookup_correct[of _ k, THEN map_le_implies_dom_le]
     )

(* A look-up on row 2 preserves the pair invariant. *)
lemma inv_pair_lookup2:
  "inv_pair m'" if "State_Monad.run_state (lookup2 k) m = (v, m')" "inv_pair m"
  using that lookup_inv[of k] inv_pair_P_D[OF ‹inv_pair m›] unfolding inv_pair_def
  by (auto 4 4
        simp: Let_def lookup_keys
        dest: lift_p_P lookup_correct[of _ k, THEN map_le_implies_dom_le]
     )

(* Updating row 1 at key2 k preserves the pair invariant, provided row 1 is
   currently labelled with key1 k. *)
lemma inv_pair_update1:
  "inv_pair m'"
  if "State_Monad.run_state (update1 (key2 k) v) m = (v', m')" "inv_pair m" "fst (State_Monad.run_state get_k1 m) = key1 k"
  using that update_inv[of "key2 k" v] inv_pair_P_D[OF ‹inv_pair m›] unfolding inv_pair_def
  apply (auto
        simp: Let_def update_keys
        dest: lift_p_P update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]
     )
   apply (frule update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]; auto 13 2; fail)
  apply (frule update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]; auto 13 2; fail)
  done

(* Updating row 2 at key2 k preserves the pair invariant, provided row 2 is
   currently labelled with key1 k. *)
lemma inv_pair_update2:
  "inv_pair m'"
  if "State_Monad.run_state (update2 (key2 k) v) m = (v', m')" "inv_pair m" "fst (State_Monad.run_state get_k2 m) = key1 k"
  using that update_inv[of "key2 k" v] inv_pair_P_D[OF ‹inv_pair m›] unfolding inv_pair_def
  apply (auto
        simp: Let_def update_keys
        dest: lift_p_P update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]
     )
   apply (frule update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]; auto 13 2; fail)
  apply (frule update_correct[of _ "key2 k" v, THEN map_le_implies_dom_le]; auto 13 2; fail)
  done

(* move12 k preserves the pair invariant, provided k differs from the current
   label of row 1 (which becomes the label of row 2 by move12_keys). *)
lemma inv_pair_move12:
  "inv_pair m'"
  if "State_Monad.run_state (move12 k) m = (v', m')" "inv_pair m" "fst (State_Monad.run_state get_k1 m) ≠ k"
  using that move12_inv[of "k"] inv_pair_P_D[OF ‹inv_pair m›] unfolding inv_pair_def
  apply (auto
        simp: Let_def move12_keys
        dest: lift_p_P move12_correct[of _ "k", THEN map_le_implies_dom_le]
     )
  apply (blast dest: move12_correct[of _ "k", THEN map_le_implies_dom_le])
  done

(* Main result of the locale: lookup_pair/update_pair form a correct memory
   w.r.t. inv_pair, provided a full key is determined by its key1 and key2
   components. The four goals are: look-up preserves the invariant, update
   preserves the invariant, look-up does not grow the map, and update adds at
   most the new binding. *)
lemma mem_correct_pair:
  "mem_correct lookup_pair update_pair inv_pair"
  if injective: "∀ k k'. key1 k = key1 k' ∧ key2 k = key2 k' ⟶ k = k'"
proof (standard, goal_cases)
  case (1 k) ― ‹Lookup invariant›
  show ?case
    unfolding lookup_pair_def Let_def
    by (auto 4 4
        intro!: lift_pI
        dest: get_state inv_pair_lookup1 inv_pair_lookup2
        simp: State_Monad.bind_def State_Monad.run_state_return
        split: if_split_asm prod.split_asm
        )
next
  case (2 k v) ― ‹Update invariant›
  show ?case
    unfolding update_pair_def Let_def
    apply (auto 4 4
        intro!: lift_pI intro: inv_pair_update1 inv_pair_update2
        dest: get_state
        simp: State_Monad.bind_def get_state State_Monad.run_state_return
        split: if_split_asm prod.split_asm
        )+
    apply (elim inv_pair_update1 inv_pair_move12)
      apply (((subst get_state, assumption)+)?, auto intro: move12_keys dest: get_state; fail)+
    done
next
  (* Look-up correctness: handle the look-up on row 2 and on row 1 separately. *)
  case (3 m k)
  {
    let ?m = "snd (State_Monad.run_state (lookup2 (key2 k)) m)"
    have "map_of1 ?m ⊆⇩m map_of1 m"
      by (smt 3 domIff inv_pair_P_D local.lookup_keys lookup_correct map_le_def map_of1_def surjective_pairing)
    moreover have "map_of2 ?m ⊆⇩m map_of2 m"
      by (smt 3 domIff inv_pair_P_D local.lookup_keys lookup_correct map_le_def map_of2_def surjective_pairing)
    moreover have "dom (map_of1 ?m) ∩ dom (map_of2 m) = {}"
      using 3 ‹map_of1 ?m ⊆⇩m map_of1 m› inv_pair_domD map_le_implies_dom_le by fastforce
    moreover have "inv_pair ?m"
      using 3 inv_pair_lookup2 surjective_pairing by metis
    ultimately have "pair.map_of ?m ⊆⇩m pair.map_of m"
      apply (subst map_of_eq_pair[symmetric])
       defer
       apply (subst map_of_eq_pair[symmetric])
      by (auto intro: 3 map_add_mono)
  }
  moreover
  {
    let ?m = "snd (State_Monad.run_state (lookup1 (key2 k)) m)"
    have "map_of1 ?m ⊆⇩m map_of1 m"
      by (smt 3 domIff inv_pair_P_D local.lookup_keys lookup_correct map_le_def map_of1_def surjective_pairing)
    moreover have "map_of2 ?m ⊆⇩m map_of2 m"
      by (smt 3 domIff inv_pair_P_D local.lookup_keys lookup_correct map_le_def map_of2_def surjective_pairing)
    moreover have "dom (map_of1 ?m) ∩ dom (map_of2 m) = {}"
      using 3 ‹map_of1 ?m ⊆⇩m map_of1 m› inv_pair_domD map_le_implies_dom_le by fastforce
    moreover have "inv_pair ?m"
      using 3 inv_pair_lookup1 surjective_pairing by metis
    ultimately have "pair.map_of ?m ⊆⇩m pair.map_of m"
      apply (subst map_of_eq_pair[symmetric])
       defer
       apply (subst map_of_eq_pair[symmetric])
      by (auto intro: 3 map_add_mono)
  }
  ultimately show ?case
    by (auto
        split:if_split prod.split
        simp: Let_def lookup_pair_def State_Monad.bind_def State_Monad.run_state_return dest: get_state intro: map_le_refl
        )
next
  (* Update correctness: after unfolding update_pair, three goals remain, one
     per branch (update row 1, update row 2, move12 followed by update row 1). *)
  case prems: (4 m k v)
  let ?m1 = "snd (State_Monad.run_state (update1 (key2 k) v) m)"
  let ?m2 = "snd (State_Monad.run_state (update2 (key2 k) v) m)"
  from prems have disjoint: "dom (map_of1 m) ∩ dom (map_of2 m) = {}"
    by (simp add: inv_pair_domD)
  show ?case
    apply (auto
        intro: map_le_refl dest: get_state
        split: prod.split
        simp: Let_def update_pair_def State_Monad.bind_def State_Monad.run_state_return
        )
  proof goal_cases
    (* Branch 1: key1 k matches the label of row 1. *)
    case (1 m')
    then have "m' = m"
      by (rule get_state)
    from 1 prems have "map_of1 ?m1 ⊆⇩m (map_of1 m)(k ↦ v)"
      by (smt inv_pair_P_D map_le_def map_of1_def surjective_pairing domIff
          fst_conv fun_upd_apply injective update_correct update_keys
          )
    moreover from prems have "map_of2 ?m1 ⊆⇩m map_of2 m"
      by (smt domIff inv_pair_P_D update_correct update_keys map_le_def map_of2_def surjective_pairing)
    moreover from prems have "dom (map_of1 ?m1) ∩ dom (map_of2 m) = {}"
      by (smt inv_pair_P_D[OF ‹inv_pair m›] domIff Int_emptyI eq_snd_iff inv_pair_neq
          map_of1_def map_of2_def update_keys(1)
          )
    moreover from 1 prems have "k ∉ dom (map_of2 m)"
      using inv_pair_neq map_of2_def by fastforce
    moreover from 1 prems have "inv_pair ?m1"
      using inv_pair_update1 fst_conv surjective_pairing by metis
    ultimately show "pair.map_of (snd (State_Monad.run_state (update1 (key2 k) v) m')) ⊆⇩m (pair.map_of m)(k ↦ v)"
      unfolding ‹m' = m› using disjoint
      apply (subst map_of_eq_pair[symmetric])
       defer
       apply (subst map_of_eq_pair[symmetric], rule prems)
       apply (subst map_add_upd2[symmetric])
      by (auto intro: map_add_mono)
  next
    (* Branch 2: key1 k matches the label of row 2. *)
    case (2 k1 m' m'')
    then have "m' = m" "m'' = m"
      by (auto dest: get_state)
    from 2 prems have "map_of2 ?m2 ⊆⇩m (map_of2 m)(k ↦ v)"
      unfolding ‹m' = m› ‹m'' = m›
      by (smt inv_pair_P_D map_le_def map_of2_def surjective_pairing domIff
          fst_conv fun_upd_apply injective update_correct update_keys
          )
    moreover from prems have "map_of1 ?m2 ⊆⇩m map_of1 m"
      by (smt domIff inv_pair_P_D update_correct update_keys map_le_def map_of1_def surjective_pairing)
    moreover from 2 have "dom (map_of1 ?m2) ∩ dom ((map_of2 m)(k ↦ v)) = {}"
      unfolding ‹m' = m›
      by (smt domIff ‹map_of1 ?m2 ⊆⇩m map_of1 m› disjoint_iff_not_equal fst_conv fun_upd_apply
          map_le_def map_of1_def map_of2_def
          )
    moreover from 2 prems have "inv_pair ?m2"
      unfolding ‹m' = m›
      using inv_pair_update2 fst_conv surjective_pairing by metis
    ultimately show "pair.map_of (snd (State_Monad.run_state (update2 (key2 k) v) m'')) ⊆⇩m (pair.map_of m)(k ↦ v)"
      unfolding ‹m' = m› ‹m'' = m›
      apply (subst map_of_eq_pair[symmetric])
       defer
       apply (subst map_of_eq_pair[symmetric], rule prems)
       apply (subst map_add_upd[symmetric])
      by (rule map_add_mono)
  next
    (* Branch 3: neither label matches; m3 is the state after move12. *)
    case (3 k1 m1 k2 m2 m3)
    then have "m1 = m" "m2 = m"
      by (auto dest: get_state)
    let ?m3 = "snd (State_Monad.run_state (update1 (key2 k) v) m3)"
    from 3 prems have "map_of1 ?m3 ⊆⇩m (map_of2 m)(k ↦ v)"
      unfolding ‹m2 = m›
      by (smt inv_pair_P_D map_le_def map_of1_def surjective_pairing domIff
          fst_conv fun_upd_apply injective
          inv_pair_move12 move12_correct move12_keys update_correct update_keys
          )
    moreover have "map_of2 ?m3 ⊆⇩m map_of1 m"
    proof -
      from prems 3 have "P m" "P m3"
        unfolding ‹m1 = m› ‹m2 = m›
        using inv_pair_P_D[OF prems] by (auto elim: lift_p_P[OF move12_inv])
      from 3(3)[unfolded ‹m2 = m›] have "mem2.map_of ?m3 ⊆⇩m mem1.map_of m"
        by - (erule map_le_trans[OF update_correct(3)[OF ‹P m3›] move12_correct(2)[OF ‹P m›]])
      with 3 prems show ?thesis
        unfolding ‹m1 = m› ‹m2 = m› map_le_def map_of2_def
        apply auto
        apply (frule move12_keys(2), simp)
        by (metis
            domI inv_pair_def map_of1_def surjective_pairing
            inv_pair_move12 move12_keys(2) update_keys(2)
            )
    qed
    moreover from prems 3 have "dom (map_of1 ?m3) ∩ dom (map_of1 m) = {}"
      unfolding ‹m1 = m› ‹m2 = m›
      by (smt inv_pair_P_D disjoint_iff_not_equal map_of1_def surjective_pairing domIff
          fst_conv inv_pair_move12 move12_keys update_keys
          )
    moreover from 3 have "k ∉ dom (map_of1 m)"
      by (simp add: domIff map_of1_def)
    moreover from 3 prems have "inv_pair ?m3"
      unfolding ‹m2 = m›
      by (metis inv_pair_move12 inv_pair_update1 move12_keys(1) fst_conv surjective_pairing)
    ultimately show ?case
      unfolding ‹m1 = m› ‹m2 = m› using disjoint
      apply (subst map_of_eq_pair[symmetric])
       defer
       apply (subst map_of_eq_pair[symmetric])
        apply (rule prems)
       apply (subst (2) map_add_comm)
        defer
        apply (subst map_add_upd2[symmetric])
          apply (auto intro: map_add_mono)
      done
  qed
qed

(* If both rows are empty and the invariant holds, the pair memory is empty. *)
lemma emptyI:
  assumes "inv_pair m" "mem1.map_of m ⊆⇩m Map.empty" "mem2.map_of m ⊆⇩m Map.empty"
  shows "pair.map_of m ⊆⇩m Map.empty"
  using assms by (auto simp: map_of1_def map_of2_def map_le_def map_of_eq_pair[symmetric])

end (* Pair Mem *)


(* Record-like storage with four fields: two of type 'k followed by two of
   type 'v. The definitions below use it as Pair_Storage k1 k2 m1 m2, i.e. two
   keys together with the two memories they label. *)
datatype ('k, 'v) pair_storage = Pair_Storage 'k 'k 'v 'v

(* Concrete pair memory built from two copies of a memory satisfying
   mem_correct_empty. *)
context mem_correct_empty
begin

(* key assigns to each memory key the row identifier it belongs to. *)
context
  fixes key :: "'a ⇒ 'k"
begin

text ‹We assume that look-ups happen on the older row, so it is biased towards the second entry.›
(* Compare key k first with k2 and then with k1; on a match, run lookup in the
   corresponding memory and store the resulting memory back. Otherwise return
   None and leave the storage unchanged. *)
definition
  "lookup_pair k =
    State (λ mem.
    (
      case mem of Pair_Storage k1 k2 m1 m2 ⇒ let k' = key k in
        if k' = k2 then case State_Monad.run_state (lookup k) m2 of (v, m) ⇒ (v, Pair_Storage k1 k2 m1 m)
        else if k' = k1 then case State_Monad.run_state (lookup k) m1 of (v, m) ⇒ (v, Pair_Storage k1 k2 m m2)
        else (None, mem)
    )
    )
  "

text ‹We assume that updates happen on the newer row, so it is biased towards the first entry.›
(* Compare key k first with k1 and then with k2; on a match, run update in the
   corresponding memory. Otherwise run update on the empty memory: the result
   becomes the first memory under key k, the old first memory (with k1) becomes
   the second, and the old second memory is dropped. *)
definition
  "update_pair k v =
    State (λ mem.
    (
      case mem of Pair_Storage k1 k2 m1 m2 ⇒ let k' = key k in
        if k' = k1 then case State_Monad.run_state (update k v) m1 of (_, m) ⇒ ((), Pair_Storage k1 k2 m m2)
        else if k' = k2 then case State_Monad.run_state (update k v) m2 of (_, m) ⇒ ((),Pair_Storage k1 k2 m1 m)
        else case State_Monad.run_state (update k v) empty of (_, m) ⇒ ((), Pair_Storage k' k1 m m1)
    )
    )
  "

(* Interpret state_mem_defs for the concrete pair operations; this provides the
   prefixed names (e.g. pair.map_of) used in the lemmas below. *)
interpretation pair: state_mem_defs lookup_pair update_pair .

(* Invariant of the concrete pair storage: every key stored in m1 (resp. m2) is
   mapped by key to k1 (resp. k2), the two labels differ, and P holds for both
   memories. *)
definition
  "inv_pair p = (case p of Pair_Storage k1 k2 m1 m2 ⇒
    key ` dom (map_of m1) ⊆ {k1} ∧ key ` dom (map_of m2) ⊆ {k2} ∧ k1 ≠ k2 ∧ P m1 ∧ P m2
  )"

(* Under the invariant, the map induced by lookup_pair is contained in the
   union of the maps of the two memories. *)
lemma map_of_le_pair:
  "pair.map_of (Pair_Storage k1 k2 m1 m2) ⊆⇩m (map_of m1 ++ map_of m2)"
  if "inv_pair (Pair_Storage k1 k2 m1 m2)"
  using that
  unfolding pair.map_of_def
  unfolding lookup_pair_def inv_pair_def map_of_def map_le_def dom_def map_add_def
  apply auto
  apply (auto 4 6 split: prod.split_asm if_split_asm option.split simp: Let_def)
  done

(* Converse inclusion: under the invariant, the union of the maps of the two
   memories is contained in the map induced by lookup_pair. *)
lemma pair_le_map_of:
  "map_of m1 ++ map_of m2 ⊆⇩m pair.map_of (Pair_Storage k1 k2 m1 m2)"
  if "inv_pair (Pair_Storage k1 k2 m1 m2)"
  using that
  unfolding pair.map_of_def
  unfolding lookup_pair_def inv_pair_def map_of_def map_le_def dom_def map_add_def
  by (auto 4 5 split: prod.split_asm if_split_asm option.split simp: Let_def)

(* Under the invariant, the map induced by lookup_pair is exactly the union of
   the maps of the two memories. Used below (via [symmetric]) as a rewrite rule. *)
lemma map_of_eq_pair:
  "map_of m1 ++ map_of m2 = pair.map_of (Pair_Storage k1 k2 m1 m2)"
  if "inv_pair (Pair_Storage k1 k2 m1 m2)"
  using that
  unfolding pair.map_of_def
  unfolding lookup_pair_def inv_pair_def map_of_def map_le_def dom_def map_add_def
  by (auto 4 7 split: prod.split_asm if_split_asm option.split simp: Let_def)

(* The invariant excludes storages whose two labels coincide. *)
lemma inv_pair_neq[simp, dest]:
  False if "inv_pair (Pair_Storage k k x y)"
  using that unfolding inv_pair_def by auto

(* Destruction rule: the invariant implies P for the first memory. *)
lemma inv_pair_P_D1:
  "P m1" if "inv_pair (Pair_Storage k1 k2 m1 m2)"
  using that unfolding inv_pair_def by auto

(* Destruction rule: the invariant implies P for the second memory. *)
lemma inv_pair_P_D2:
  "P m2" if "inv_pair (Pair_Storage k1 k2 m1 m2)"
  using that unfolding inv_pair_def by auto

(* Under the invariant, the two memories store disjoint sets of keys. *)
lemma inv_pair_domD[intro]:
  "dom (map_of m1) ∩ dom (map_of m2) = {}" if "inv_pair (Pair_Storage k1 k2 m1 m2)"
  using that unfolding inv_pair_def by fastforce

(* Main result for the concrete storage: lookup_pair/update_pair form a correct
   memory w.r.t. inv_pair. The four goals are: look-up preserves the invariant,
   update preserves the invariant, look-up does not grow the map, and update
   adds at most the new binding. *)
lemma mem_correct_pair:
  "mem_correct lookup_pair update_pair inv_pair"
proof (standard, goal_cases)
  case (1 k) ― ‹Lookup invariant›
  with lookup_inv[of k] show ?case
    unfolding lookup_pair_def Let_def
    by (auto intro!: lift_pI split: pair_storage.split_asm if_split_asm prod.split_asm)
       (auto dest: lift_p_P simp: inv_pair_def,
         (force dest!: lookup_correct[of _ k] map_le_implies_dom_le)+
       )
next
  case (2 k v) ― ‹Update invariant›
  with update_inv[of k v] update_correct[OF P_empty, of k v] P_empty show ?case
    unfolding update_pair_def Let_def
    by (auto intro!: lift_pI split: pair_storage.split_asm if_split_asm prod.split_asm)
       (auto dest: lift_p_P simp: inv_pair_def,
         (force dest: lift_p_P dest!: update_correct[of _ k v] map_le_implies_dom_le)+
       )
next
  (* Look-up correctness: A covers the look-up in the first memory, B the
     look-up in the second. *)
  case (3 m k)
  {
    fix m1 v1 m1' m2 v2 m2' k1 k2
    assume assms:
      "State_Monad.run_state (lookup k) m1 = (v1, m1')" "State_Monad.run_state (lookup k) m2 = (v2, m2')"
      "inv_pair (Pair_Storage k1 k2 m1 m2)"
    from assms have "P m1" "P m2"
      by (auto intro: inv_pair_P_D1 inv_pair_P_D2)
    have [intro]: "map_of m1' ⊆⇩m map_of m1" "map_of m2' ⊆⇩m map_of m2"
      using lookup_correct[OF ‹P m1›, of k] lookup_correct[OF ‹P m2›, of k] assms by auto
    from inv_pair_domD[OF assms(3)] have 1: "dom (map_of m1') ∩ dom (map_of m2) = {}"
      by (metis (no_types) ‹map_of m1' ⊆⇩m map_of m1› disjoint_iff_not_equal domIff map_le_def)
    have inv1: "inv_pair (Pair_Storage (key k) k2 m1' m2)" if "k2 ≠ key k" "k1 = key k"
      using that ‹P m1› ‹P m2› unfolding inv_pair_def
      apply clarsimp
      apply safe
      subgoal for x' y
        using assms(1,3) lookup_correct[OF ‹P m1›, of k, THEN map_le_implies_dom_le]
        unfolding inv_pair_def by auto
      subgoal for x' y
        using assms(3) unfolding inv_pair_def by fastforce
      using lookup_inv[of k] assms unfolding lift_p_def by force
    have inv2: "inv_pair (Pair_Storage k1 (key k) m1 m2')" if "k2 = key k" "k1 ≠ key k"
      using that ‹P m1› ‹P m2› unfolding inv_pair_def
      apply clarsimp
      apply safe
      subgoal for x' y
        using assms(3) unfolding inv_pair_def by fastforce
      subgoal for x x' y
        using assms(2,3) lookup_correct[OF ‹P m2›, of k, THEN map_le_implies_dom_le]
        unfolding inv_pair_def by fastforce
      using lookup_inv[of k] assms unfolding lift_p_def by force
    have A:
      "pair.map_of (Pair_Storage (key k) k2 m1' m2) ⊆⇩m pair.map_of (Pair_Storage (key k) k2 m1 m2)"
      if "k2 ≠ key k" "k1 = key k"
      using inv1 assms(3) 1
      by (auto intro: map_add_mono map_le_refl simp: that map_of_eq_pair[symmetric])
    have B:
      "pair.map_of (Pair_Storage k1 (key k) m1 m2') ⊆⇩m pair.map_of (Pair_Storage k1 (key k) m1 m2)"
      if "k2 = key k" "k1 ≠ key k"
      using inv2 assms(3) that
      by (auto intro: map_add_mono map_le_refl simp: map_of_eq_pair[symmetric] dest: inv_pair_domD)
    note A B
  }
  with ‹inv_pair m› show ?case
    by (auto split: pair_storage.split if_split prod.split simp: Let_def lookup_pair_def)
next
  (* Update correctness: A covers the branch that starts from the empty memory,
     B the update of the second memory, C the update of the first. *)
  case (4 m k v)
  {
    fix m1 v1 m1' m2 v2 m2' m3 k1 k2
    assume assms:
      "State_Monad.run_state (update k v) m1 = ((), m1')" "State_Monad.run_state (update k v) m2 = ((), m2')"
      "State_Monad.run_state (update k v) empty = ((), m3)"
      "inv_pair (Pair_Storage k1 k2 m1 m2)"
    from assms have "P m1" "P m2"
      by (auto intro: inv_pair_P_D1 inv_pair_P_D2)
    from assms(3) P_empty update_inv[of k v] have "P m3"
      unfolding lift_p_def by auto
    have [intro]: "map_of m1' ⊆⇩m (map_of m1)(k ↦ v)" "map_of m2' ⊆⇩m (map_of m2)(k ↦ v)"
      using update_correct[OF ‹P m1›, of k v] update_correct[OF ‹P m2›, of k v] assms by auto
    have "map_of m3 ⊆⇩m (map_of empty)(k ↦ v)"
      using assms(3) update_correct[OF P_empty, of k v] by auto
    also have "… ⊆⇩m (map_of m2)(k ↦ v)"
      using empty_correct by (auto elim: map_le_trans intro!: map_le_upd)
    finally have "map_of m3 ⊆⇩m (map_of m2)(k ↦ v)" .
    have 1: "dom (map_of m1) ∩ dom ((map_of m2)(k ↦ v)) = {}" if "k1 ≠ key k"
      using assms(4) that by (force simp: inv_pair_def)
    have 2: "dom (map_of m3) ∩ dom (map_of m1) = {}" if "k1 ≠ key k"
      using ‹local.map_of m3 ⊆⇩m (map_of empty)(k ↦ v)› assms(4) that
      by (fastforce dest!: map_le_implies_dom_le simp: inv_pair_def)
    have inv: "inv_pair (Pair_Storage (key k) k1 m3 m1)" if "k2 ≠ key k" "k1 ≠ key k"
      using that ‹P m1› ‹P m2› ‹P m3› unfolding inv_pair_def
      apply clarsimp
      apply safe
      subgoal for x x' y
        using assms(3) update_correct[OF P_empty, of k v, THEN map_le_implies_dom_le]
          empty_correct
        by (auto dest: map_le_implies_dom_le)
      subgoal for x x' y
        using assms(4) unfolding inv_pair_def by fastforce
      done
    have A:
      "pair.map_of (Pair_Storage (key k) k1 m3 m1) ⊆⇩m (pair.map_of (Pair_Storage k1 k2 m1 m2))(k ↦ v)"
      if "k2 ≠ key k" "k1 ≠ key k"
      using inv assms(4) ‹map_of m3 ⊆⇩m (map_of m2)(k ↦ v)› 1
      apply (simp add: that map_of_eq_pair[symmetric])
      apply (subst map_add_upd[symmetric], subst Map.map_add_comm, rule 2, rule that)
      by (rule map_add_mono; auto)
    have inv1: "inv_pair (Pair_Storage (key k) k2 m1' m2)" if "k2 ≠ key k" "k1 = key k"
      using that ‹P m1› ‹P m2› unfolding inv_pair_def
      apply clarsimp
      apply safe
      subgoal for x' y
        using assms(1,4) update_correct[OF ‹P m1›, of k v, THEN map_le_implies_dom_le]
        unfolding inv_pair_def by auto
      subgoal for x' y
        using assms(4) unfolding inv_pair_def by fastforce
      using update_inv[of k v] assms unfolding lift_p_def by force
    have inv2: "inv_pair (Pair_Storage k1 (key k) m1 m2')" if "k2 = key k" "k1 ≠ key k"
      using that ‹P m1› ‹P m2› unfolding inv_pair_def
      apply clarsimp
      apply safe
      subgoal for x' y
        using assms(4) unfolding inv_pair_def by fastforce
      subgoal for x x' y
        using assms(2,4) update_correct[OF ‹P m2›, of k v, THEN map_le_implies_dom_le]
        unfolding inv_pair_def by fastforce
      using update_inv[of k v] assms unfolding lift_p_def by force
    have C:
      "pair.map_of (Pair_Storage (key k) k2 m1' m2) ⊆⇩m
       (pair.map_of (Pair_Storage (key k) k2 m1 m2))(k ↦ v)"
      if "k2 ≠ key k" "k1 = key k"
      using inv1[OF that] assms(4) ‹inv_pair m›
      by (simp add: that map_of_eq_pair[symmetric])
         (subst map_add_upd2[symmetric]; force simp: inv_pair_def intro: map_add_mono map_le_refl)
    have B:
      "pair.map_of (Pair_Storage k1 (key k) m1 m2') ⊆⇩m
       (pair.map_of (Pair_Storage k1 (key k) m1 m2))(k ↦ v)"
      if "k2 = key k" "k1 ≠ key k"
      using inv2[OF that] assms(4)
      by (simp add: that map_of_eq_pair[symmetric])
         (subst map_add_upd[symmetric]; rule map_add_mono; force simp: inv_pair_def)
    note A B C
  }
  with ‹inv_pair m› show ?case
    by (auto split: pair_storage.split if_split prod.split simp: Let_def update_pair_def)
qed

end (* Key function *)

end (* Lookup & Update w/ Empty *)

end (* Theory *)