(* Theory MapExtra *)

(*
 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)

section ‹More properties of maps plus map disjunction.›

theory MapExtra
imports Main
begin

text ‹
  BEWARE: we are not interested in using the @{term "dom x ∩ dom y = {}"}
  rules from Map for our separation logic proofs. As such, where a Map rule
  with that form of disjointness in its assumptions conflicts with a name we
  want to use with @{text "⊥"}, we overwrite the Map rule.›

text ‹
  A note on naming:
  Anything not involving heap disjunction can potentially be incorporated
  directly into Map.thy, thus uses @{text "m"}.
  Anything involving heap disjunction is not really mergeable with Map, is
  destined for use in separation logic, and hence uses @{text "h"}
›

text ‹---------------------------------------›
text ‹Things that should go into Option Type›
text ‹---------------------------------------›

text ‹Misc option lemmas›

(* An option value differs from None exactly when it is Some y for some y. *)
lemma None_not_eq: "(None ≠ x) = (∃y. x = Some y)" by (cases x) auto

(* Orientation of an equation with None: "None = x" and "x = None" are interchangeable. *)
lemma None_com: "(None = x) = (x = None)" by fast

(* Orientation of an equation with Some: "Some y = x" and "x = Some y" are interchangeable. *)
lemma Some_com: "(Some y = x) = (x = Some y)" by fast

text ‹---------------------------------------›
text ‹Things that should go into Map.thy›
text ‹---------------------------------------›

text ‹Map intersection: set of all keys for which the maps agree.›

(* Map intersection: the set of keys in the domain of m1 on which m1 and m2
   return the same value. Written infix as "∩⇩m" (left-associative, priority 70). *)
definition
  map_inter :: "('a ⇀ 'b) ⇒ ('a ⇀ 'b) ⇒ 'a set" (infixl "∩⇩m" 70) where
  "m1 ∩⇩m m2 ≡ {x ∈ dom m1. m1 x = m2 x}"

text ‹Map restriction via domain subtraction›

(* Domain subtraction: "m `- S" behaves like m outside S and is undefined
   (None) on every key in S. *)
definition
  sub_restrict_map :: "('a ⇀ 'b) => 'a set => ('a ⇀ 'b)" (infixl "`-" 110)
  where
  "m `- S ≡ (λx. if x ∈ S then None else m x)"

subsection ‹Properties of maps not related to restriction›

(* A map is the empty map iff it returns None for every key. *)
lemma empty_forall_equiv: "(m = Map.empty) = (∀x. m x = None)"
  by (rule fun_eq_iff)

(* Only the empty map is a submap of the empty map. Declared [simp]. *)
lemma map_le_empty2 [simp]:
  "(m ⊆⇩m Map.empty) = (m = Map.empty)"
  by (auto simp: map_le_def)

(* A key maps to some value iff it is in the domain of the map. *)
lemma dom_iff:
  "(∃y. m x = Some y) = (x ∈ dom m)"
  by auto

(* Looking up a key outside the domain yields None. *)
lemma non_dom_eval:
  "x ∉ dom m ⟹ m x = None"
  by auto

(* Equational form of non_dom_eval: a key is outside the domain iff its lookup is None. *)
lemma non_dom_eval_eq:
  "x ∉ dom m = (m x = None)"
  by auto

(* Equal right operands give equal map additions with a common left operand. *)
lemma map_add_same_left_eq:
  "m1 = m1' ⟹ (m0 ++ m1 = m0 ++ m1')"
  by simp

(* Introduction rule (declared [intro!]): to show two additions with the same
   left operand equal, it suffices to show the right operands equal. *)
lemma map_add_left_cancelI [intro!]:
  "m1 = m1' ⟹ m0 ++ m1 = m0 ++ m1'"
  by simp

(* A map has empty domain iff it is the empty map.
   The forward direction is shown by contradiction: a non-empty map has a
   non-empty domain. *)
lemma dom_empty_is_empty:
  "(dom m = {}) = (m = Map.empty)"
proof (rule iffI)
  assume a: "dom m = {}"
  { assume "m ≠ Map.empty"
    hence "dom m ≠ {}"
      by - (subst (asm) empty_forall_equiv, simp add: dom_def)
    hence False using a by blast
  }
  thus "m = Map.empty" by blast
next
  assume a: "m = Map.empty"
  thus "dom m = {}" by simp
qed

(* If both maps have the same domain, the right operand of ++ overrides the left completely. *)
lemma map_add_dom_eq:
  "dom m = dom m' ⟹ m ++ m' = m'"
  by (rule ext) (auto simp: map_add_def split: option.splits)

(* Equal sums whose right operands have the same domain have equal right operands.
   Proved pointwise by contradiction on an arbitrary key x. *)
lemma map_add_right_dom_eq:
  "⟦ m0 ++ m1 = m0' ++ m1'; dom m1 = dom m1' ⟧ ⟹ m1 = m1'"
  unfolding map_add_def
  apply (rule ext)
  apply (rule ccontr)
  subgoal for x
    apply (drule fun_cong [where x=x], clarsimp split: option.splits)
    apply (drule sym, drule sym, force+)
    done
  done

(* A submap with the same domain as its supermap is equal to it. *)
lemma map_le_same_dom_eq:
  "⟦ m0 ⊆⇩m m1 ; dom m0 = dom m1 ⟧ ⟹ m0 = m1"
  by (simp add: map_le_antisym map_le_def)

subsection ‹Properties of map restriction›

(* Two restrictions of the same map coincide iff the restricting sets agree
   on the domain of the map. *)
lemma restrict_map_cancel:
  "(m |` S = m |` T) = (dom m ∩ S = dom m ∩ T)"
  by (fastforce intro: set_eqI dest: fun_cong
               simp: restrict_map_def None_not_eq
               split: if_split_asm)

(* Overriding a map with a restriction of itself changes nothing. Declared [simp]. *)
lemma map_add_restricted_self [simp]:
  "m ++ m |` S = m"
  by (auto simp: restrict_map_def map_add_def split: option.splits)

(* Restricting a sum to the domain of its right operand recovers that operand.
   Declared [simp]. *)
lemma map_add_restrict_dom_right [simp]:
  "(m ++ m') |` dom m' = m'"
  by (rule ext, auto simp: restrict_map_def map_add_def split: option.splits)

(* Restricting to the universal set is the identity. Declared [simp]. *)
lemma restrict_map_UNIV [simp]:
  "m |` UNIV = m"
  by (simp add: restrict_map_def)

(* Restricting a map to exactly its own domain is the identity. *)
lemma restrict_map_dom:
  "S = dom m ⟹ m |` S = m"
  by (fastforce simp: restrict_map_def None_not_eq)

(* Restricting a map to any superset of its domain is the identity. *)
lemma restrict_map_subdom:
  "dom m ⊆ S ⟹ m |` S = m"
  by (fastforce simp: restrict_map_def None_com)

(* Restriction distributes over map addition. *)
lemma map_add_restrict:
  "(m0 ++ m1) |` S = ((m0 |` S) ++ (m1 |` S))"
  by (force simp: map_add_def restrict_map_def)

(* A submap equals the supermap restricted to the submap's domain. *)
lemma map_le_restrict:
  "m ⊆⇩m m' ⟹ m = m' |` dom m"
  by (force simp: map_le_def restrict_map_def None_com)

(* Any restriction of a map is a submap of it. *)
lemma restrict_map_le:
  "m |` S ⊆⇩m m"
  by (auto simp: map_le_def)

(* Adding the restrictions of one map to two disjoint sets gives the
   restriction to their union. *)
lemma restrict_map_remerge:
  "⟦ S ∩ T = {} ⟧ ⟹ m |` S ++ m |` T = m |` (S ∪ T)"
  by (rule ext, clarsimp simp: restrict_map_def map_add_def
                         split: option.splits)

(* Restricting a map to a set disjoint from its domain yields the empty map. *)
lemma restrict_map_empty:
  "dom m ∩ S = {} ⟹ m |` S = Map.empty"
  by (fastforce simp: restrict_map_def)

(* A map is the sum of its restriction to S and its restriction to the
   complement of S. Declared [simp]. *)
lemma map_add_restrict_comp_right [simp]:
  "(m |` S ++ m |` (UNIV - S)) = m"
  by (force simp: map_add_def restrict_map_def split: option.splits)

(* Variant of map_add_restrict_comp_right that takes the complement of S
   relative to dom m instead of UNIV. Declared [simp]. *)
lemma map_add_restrict_comp_right_dom [simp]:
  "(m |` S ++ m |` (dom m - S)) = m"
  by (fastforce simp: map_add_def restrict_map_def split: option.splits)

(* Mirror image of map_add_restrict_comp_right with the complement on the left;
   obtained by commuting the sum via map_add_comm. Declared [simp]. *)
lemma map_add_restrict_comp_left [simp]:
  "(m |` (UNIV - S) ++ m |` S) = m"
  by (subst map_add_comm, auto)

(* For restriction purposes, the complement of S relative to dom m and
   relative to UNIV give the same result. *)
lemma restrict_self_UNIV:
  "m |` (dom m - S) = m |` (UNIV - S)"
  by (fastforce simp: restrict_map_def)

(* At a key the right operand does not define, the sum restricted to that
   key agrees with the left operand. *)
lemma map_add_restrict_nonmember_right:
  "x ∉ dom m' ⟹ (m ++ m') |` {x} = m |` {x}"
  by (rule ext, auto simp: restrict_map_def map_add_def split: option.splits)

(* At a key the left operand does not define, the sum restricted to that
   key agrees with the right operand. *)
lemma map_add_restrict_nonmember_left:
  "x ∉ dom m ⟹ (m ++ m') |` {x} = m' |` {x}"
  by (rule ext, auto simp: restrict_map_def map_add_def split: option.splits)

(* Here x is a set of keys: restricting a sum to a subset of the right
   operand's domain only sees the right operand. *)
lemma map_add_restrict_right:
  "x ⊆ dom m' ⟹ (m ++ m') |` x = m' |` x"
  by (rule ext, auto simp: restrict_map_def map_add_def split: option.splits)

(* If S and T partition the domain of m, the two restrictions add up to m. *)
lemma restrict_map_compose:
  "⟦ S ∪ T = dom m ; S ∩ T = {} ⟧ ⟹ m |` S ++ m |` T = m"
  by (fastforce simp: map_add_def restrict_map_def)

(* A submap stays a submap after restricting the supermap to any set that
   covers the submap's domain. *)
lemma map_le_dom_subset_restrict:
  "⟦ m' ⊆⇩m m; dom m' ⊆ S ⟧ ⟹ m' ⊆⇩m (m |` S)"
  by (force simp: restrict_map_def map_le_def)

(* Removing a submap's domain from m and then adding the submap back restores m. *)
lemma map_le_dom_restrict_sub_add:
  "m' ⊆⇩m m ⟹ m |` (dom m - dom m') ++ m' = m"
  by (metis map_add_restrict_comp_right_dom map_le_iff_map_add_commute map_le_restrict)

(* For T a subset of S, the restrictions to S - T and to T add up to the
   restriction to S. *)
lemma subset_map_restrict_sub_add:
  "T ⊆ S ⟹ m |` (S - T) ++ m |` T = m |` S"
  by (fastforce simp: restrict_map_def map_add_def split: option.splits)

(* Removing a union of key sets from the domain equals removing the two sets
   one after the other. *)
lemma restrict_map_sub_union:
  "m |` (dom m - (S ∪ T)) = (m |` (dom m - T)) |` (dom m - S)"
  by (auto simp: restrict_map_def)

(* Product-keyed version of restrict_map_remerge: splitting the second
   component's key set U into disjoint S and T splits the restriction to X × U. *)
lemma prod_restrict_map_add:
  "⟦ S ∪ T = U; S ∩ T = {} ⟧ ⟹ m |` (X × S) ++ m |` (X × T) = m |` (X × U)"
  by (auto simp: map_add_def restrict_map_def split: option.splits)


text ‹---------------------------------------›
text ‹Things that should NOT go into Map.thy›
text ‹---------------------------------------›

section ‹Definitions›

text ‹Map disjunction›

(* Map disjunction: two maps are disjoint when their domains do not overlap.
   Written infix as "⊥" (non-associative, priority 51). *)
definition
  map_disj :: "('a ⇀ 'b) ⇒ ('a ⇀ 'b) ⇒ bool" (infix "⊥" 51) where
  "h0 ⊥ h1 ≡ dom h0 ∩ dom h1 = {}"

(* From here on, None_not_eq is part of the default simpset. *)
declare None_not_eq [simp]

text ‹Heap monotonicity and the frame property›

(* Heap monotonicity: if f succeeds with v on heap h, it still returns
   Some v on any extension of h by a disjoint heap h'. *)
definition
  heap_mono :: "(('a ⇀ 'b) ⇒ 'c option) ⇒ bool" where
  "heap_mono f ≡ ∀h h' v. h ⊥ h' ∧ f h = Some v ⟶ f (h ++ h') = Some v"

(* Rule form of heap_mono_def: a successful result is preserved on a
   disjointly extended heap. *)
lemma heap_monoE:
  "⟦ heap_mono f ; f h = Some v ; h ⊥ h' ⟧ ⟹ f (h ++ h') = Some v"
  unfolding heap_mono_def by blast

(* Rewrite form of heap_monoE: under the same premises, f on the extended
   heap equals f on the original heap. *)
lemma heap_mono_simp:
  "⟦ heap_mono f ; f h = Some v ; h ⊥ h' ⟧ ⟹ f (h ++ h') = f h"
  by (frule (2) heap_monoE, simp)

(* Frame property: if f returns Some v on a heap split into disjoint parts
   h and h', then on h alone it either returns the same Some v or None. *)
definition
  heap_frame :: "(('a ⇀ 'b) ⇒ 'c option) ⇒ bool" where
  "heap_frame f ≡ ∀h h' v. h ⊥ h' ∧ f (h ++ h') = Some v
                          ⟶ (f h = Some v ∨ f h = None)"

(* Rule form of heap_frame_def. *)
lemma heap_frameE:
  "⟦ heap_frame f ; f (h ++ h') = Some v ; h ⊥ h' ⟧
   ⟹ f h = Some v ∨ f h = None"
  unfolding heap_frame_def by fastforce


section ‹Properties of @{term "sub_restrict_map"}›

(* The restriction of h to S and h with S subtracted are disjoint. *)
lemma restrict_map_sub_disj: "h |` S ⊥ h `- S"
  by (fastforce simp: sub_restrict_map_def restrict_map_def map_disj_def
               split: option.splits if_split_asm)

(* The restriction of h to S plus h with S subtracted gives back h. *)
lemma restrict_map_sub_add: "h |` S ++ h `- S = h"
  by (fastforce simp: sub_restrict_map_def restrict_map_def map_add_def
               split: option.splits if_split)


section ‹Properties of map disjunction›

(* Every map is disjoint from the empty map (empty on the right). Declared [simp]. *)
lemma map_disj_empty_right [simp]:
  "h ⊥ Map.empty"
  by (simp add: map_disj_def)

(* The empty map is disjoint from every map (empty on the left). Declared [simp]. *)
lemma map_disj_empty_left [simp]:
  "Map.empty ⊥ h"
  by (simp add: map_disj_def)

(* Map disjunction is symmetric. *)
lemma map_disj_com:
  "h0 ⊥ h1 = h1 ⊥ h0"
  by (simp add: map_disj_def, fast)

(* Destruction rule: disjoint maps have non-overlapping domains. *)
lemma map_disjD:
  "h0 ⊥ h1 ⟹ dom h0 ∩ dom h1 = {}"
  by (simp add: map_disj_def)

(* Introduction rule: maps with non-overlapping domains are disjoint. *)
lemma map_disjI:
  "dom h0 ∩ dom h1 = {} ⟹ h0 ⊥ h1"
  by (simp add: map_disj_def)


subsection ‹Map associativity-commutativity based on map disjunction›

(* Map addition commutes for disjoint maps. Stated with ⊥ rather than the
   domain-intersection premise of Map's map_add_comm. *)
lemma map_add_com:
  "h0 ⊥ h1 ⟹ h0 ++ h1 = h1 ++ h0"
  by (drule map_disjD, rule map_add_comm, force)

(* Left-commutativity of map addition for disjoint h0 and h1. *)
lemma map_add_left_commute:
  "h0 ⊥ h1 ⟹ h0 ++ (h1 ++ h2) = h1 ++ (h0 ++ h2)"
  by (simp add: map_add_com map_disj_com)

(* A map is disjoint from a sum iff it is disjoint from both summands
   (sum on the right). *)
lemma map_add_disj:
  "h0 ⊥ (h1 ++ h2) = (h0 ⊥ h1 ∧ h0 ⊥ h2)"
  by (simp add: map_disj_def, fast)

(* A sum is disjoint from a map iff both summands are (sum on the left). *)
lemma map_add_disj':
  "(h1 ++ h2) ⊥ h0 = (h1 ⊥ h0 ∧ h2 ⊥ h0)"
  by (simp add: map_disj_def, fast)

text ‹
  We redefine @{term "map_add"} associativity to bind to the right, which
  seems to be the more common case.
  Note that when a theory includes Map again, @{text "map_add_assoc"} will
  return to the simpset and will cause infinite loops if its symmetric
  counterpart is added (e.g. via @{text "map_ac_simps"})
›

(* Remove Map's map_add_assoc from the default simpset. *)
declare map_add_assoc [simp del]

text ‹
  Since the associativity-commutativity of @{term "map_add"} relies on
  map disjunction, we include some basic rules into the ac set.
›

(* Rule collection for reasoning about ++ modulo associativity and
   commutativity under disjointness. Note that map_add_assoc is used in its
   symmetric orientation here. *)
lemmas map_ac_simps =
  map_add_assoc[symmetric] map_add_com map_disj_com
  map_add_left_commute map_add_disj map_add_disj'


subsection ‹Basic properties›

(* A key defined in the left map of a disjoint pair is undefined in the right one. *)
lemma map_disj_None_right:
  "⟦ h0 ⊥ h1 ; x ∈ dom h0 ⟧ ⟹ h1 x = None"
  by (auto simp: map_disj_def dom_def)

(* A key defined in the right map of a disjoint pair is undefined in the left one. *)
lemma map_disj_None_left:
  "⟦ h0 ⊥ h1 ; x ∈ dom h1 ⟧ ⟹ h0 x = None"
  by (auto simp: map_disj_def dom_def)

(* Variant of map_disj_None_left whose premise is a concrete lookup
   "h0 x = Some y" instead of domain membership. *)
lemma map_disj_None_left':
  "⟦ h0 x = Some y ; h1 ⊥ h0 ⟧ ⟹ h1 x = None "
  by (auto simp: map_disj_def)

(* Variant of map_disj_None_right whose premise is a concrete lookup
   "h1 x = Some y" instead of domain membership. *)
lemma map_disj_None_right':
  "⟦ h1 x = Some y ; h1 ⊥ h0 ⟧ ⟹ h0 x = None "
  by (auto simp: map_disj_def)

(* Disjoint maps cannot both define the same key. *)
lemma map_disj_common:
  "⟦ h0 ⊥ h1 ; h0 p = Some v ; h1 p = Some v' ⟧ ⟹ False"
  by (frule (1) map_disj_None_left', simp)


subsection ‹Map disjunction and addition›

(* On keys of the left operand, a disjoint sum agrees with the left operand. *)
lemma map_add_eval_left:
  "⟦ x ∈ dom h ; h ⊥ h' ⟧ ⟹ (h ++ h') x = h x"
  by (auto dest!: map_disj_None_right simp: map_add_def cong: option.case_cong)

(* On keys of the right operand, a disjoint sum agrees with the right operand. *)
lemma map_add_eval_right:
  "⟦ x ∈ dom h' ; h ⊥ h' ⟧ ⟹ (h ++ h') x = h' x"
  by (auto elim!: map_disjD simp: map_add_comm map_add_eval_left map_disj_com)

(* On keys outside the right operand's domain, the sum agrees with the left operand. *)
lemma map_add_eval_left':
  "⟦ x ∉ dom h' ; h ⊥ h' ⟧ ⟹ (h ++ h') x = h x"
  by (clarsimp simp: map_disj_def map_add_def split: option.splits)

(* On keys outside the left operand's domain, the sum agrees with the right operand. *)
lemma map_add_eval_right':
  "⟦ x ∉ dom h ; h ⊥ h' ⟧ ⟹ (h ++ h') x = h' x"
  by (clarsimp simp: map_disj_def map_add_def split: option.splits)

(* Equal disjoint sums whose left operands share a domain have equal left
   operands. The proof commutes both sums and reuses map_add_right_dom_eq. *)
lemma map_add_left_dom_eq:
  assumes eq: "h0 ++ h1 = h0' ++ h1'"
  assumes etc: "h0 ⊥ h1" "h0' ⊥ h1'" "dom h0 = dom h0'"
  shows "h0 = h0'"
proof -
  from eq have "h1 ++ h0 = h1' ++ h0'" using etc by (simp add: map_ac_simps)
  thus ?thesis using etc
    by (fastforce elim!: map_add_right_dom_eq simp: map_ac_simps)
qed

(* Right cancellation: if h0 and h1 are both disjoint from h and give the
   same sum with h, they are equal. Proved pointwise, by cases on whether
   the key lies in dom h. *)
lemma map_add_left_eq:
  assumes eq: "h0 ++ h = h1 ++ h"
  assumes disj: "h0 ⊥ h" "h1 ⊥ h"
  shows "h0 = h1"
proof (rule ext)
  fix x
  from eq have eq': "(h0 ++ h) x = (h1 ++ h) x" by auto
  { assume "x ∈ dom h"
    hence "h0 x = h1 x" using disj by (simp add: map_disj_None_left)
  } moreover {
    assume "x ∉ dom h"
    hence "h0 x = h1 x" using disj eq' by (simp add: map_add_eval_left')
  }
  ultimately show "h0 x = h1 x" by cases
qed


(* Left cancellation counterpart of map_add_left_eq, with the common map h
   as the left operand. *)
lemma map_add_right_eq:
  "⟦h ++ h0 = h ++ h1; h0 ⊥ h; h1 ⊥ h⟧ ⟹ h0 = h1"
  by (rule map_add_left_eq [where h=h], auto simp: map_ac_simps)

(* Equal disjoint sums whose left operands share a domain have equal right
   operands. Proved pointwise by case analysis on h1 x and h1' x. *)
lemma map_disj_add_eq_dom_right_eq:
  assumes merge: "h0 ++ h1 = h0' ++ h1'" and d: "dom h0 = dom h0'" and
      ab_disj: "h0 ⊥ h1" and cd_disj: "h0' ⊥ h1'"
  shows "h1 = h1'"
proof (rule ext)
  fix x
  from merge have merge_x: "(h0 ++ h1) x = (h0' ++ h1') x" by simp
  with d ab_disj cd_disj show  "h1 x = h1' x"
    apply (cases "h1 x")
     apply (cases "h1' x")
      apply force
     apply (fastforce simp: map_disj_def)
    apply (cases "h1' x")
     apply clarsimp
     apply (simp add: Some_com)
    by (force simp: map_disj_def)+
qed

(* Equal disjoint sums whose right operands share a domain have equal left
   operands. Obtained from map_disj_add_eq_dom_right_eq after commuting
   both sums. *)
lemma map_disj_add_eq_dom_left_eq:
  assumes add: "h0 ++ h1 = h0' ++ h1'" and
          dom: "dom h1 = dom h1'" and
          disj: "h0 ⊥ h1" "h0' ⊥ h1'"
  shows "h0 = h0'"
proof -
  have "h1 ++ h0 = h1' ++ h0'" using add disj by (simp add: map_ac_simps)
  thus ?thesis using dom disj
    by - (rule map_disj_add_eq_dom_right_eq, auto simp: map_disj_com)
qed

(* Cancellation as an equivalence: with h0 disjoint from both h1 and h1',
   the sums with h0 on the left are equal iff h1 = h1'. The non-trivial
   direction is pointwise, by cases on membership in dom h0. *)
lemma map_add_left_cancel:
  assumes disj: "h0 ⊥ h1" "h0 ⊥ h1'"
  shows "(h0 ++ h1 = h0 ++ h1') = (h1 = h1')"
proof (rule iffI, rule ext)
  fix x
  assume "(h0 ++ h1) = (h0 ++ h1')"
  hence "(h0 ++ h1) x = (h0 ++ h1') x" by auto
  hence "h1 x = h1' x" using disj
    by - (cases "x ∈ dom h0",
          simp_all add: map_disj_None_right map_add_eval_right')
  thus "h1 x = h1' x" by auto
qed auto

(* If two sums are equal and their right operands are disjoint, every key of
   h1 must come from h0' on the other side. *)
lemma map_add_lr_disj:
  "⟦ h0 ++ h1 = h0' ++ h1'; h1 ⊥ h1'  ⟧ ⟹ dom h1 ⊆ dom h0'"
  apply (clarsimp simp: map_disj_def map_add_def)
  subgoal for x y
    apply (drule fun_cong [where x=x])
    apply (auto split: option.splits)
    done
  done


subsection ‹Map disjunction and updates›

(* Updating an already-defined key of the right map does not affect
   disjointness. Declared [simp]. *)
lemma map_disj_update_left [simp]:
  "p ∈ dom h1 ⟹ h0 ⊥ h1(p ↦ v) = h0 ⊥ h1"
  by (clarsimp simp add: map_disj_def, blast)

(* Updating an already-defined key of the left map does not affect
   disjointness. Declared [simp]. *)
lemma map_disj_update_right [simp]:
  "p ∈ dom h1 ⟹ h1(p ↦ v) ⊥ h0 = h1 ⊥ h0"
  by (simp add: map_disj_com)

(* An update at a key owned by the left operand of a disjoint sum can be
   pushed into that operand. *)
lemma map_add_update_left:
  "⟦ h0 ⊥ h1 ; p ∈ dom h0 ⟧ ⟹ (h0 ++ h1)(p ↦ v) = (h0(p ↦ v) ++ h1)"
  by (drule (1) map_disj_None_right)
     (auto simp: map_add_def cong: option.case_cong)

(* An update at a key owned by the right operand of a disjoint sum can be
   pushed into that operand. *)
lemma map_add_update_right:
  "⟦ h0 ⊥ h1 ; p ∈ dom h1  ⟧ ⟹ (h0 ++ h1)(p ↦ v) = (h0 ++ h1 (p ↦ v))"
  by (drule (1) map_disj_None_left)
     (auto simp: map_add_def cong: option.case_cong)

(* Three-way version of map_add_update_left: for pairwise disjoint maps, an
   update at a key of h0 can be pushed into h0. *)
lemma map_add3_update:
  "⟦ h0 ⊥ h1 ; h1  ⊥ h2 ; h0 ⊥ h2 ; p ∈ dom h0 ⟧
  ⟹ (h0 ++ h1 ++ h2)(p ↦ v) = h0(p ↦ v) ++ h1 ++ h2"
  by (auto simp: map_add_update_left[symmetric] map_ac_simps)


subsection ‹Map disjunction and @{term "map_le"}›

(* The left operand of a disjoint sum is a submap of the sum. Declared [simp]. *)
lemma map_le_override [simp]:
  "⟦ h ⊥ h' ⟧ ⟹ h ⊆⇩m h ++ h'"
  by (auto simp: map_le_def map_add_def map_disj_def split: option.splits)

(* If h is a disjoint sum h0 ++ h1, then h0 is a submap of h. *)
lemma map_leI_left:
  "⟦ h = h0 ++ h1 ; h0 ⊥ h1 ⟧ ⟹ h0 ⊆⇩m h" by auto

(* If h is a disjoint sum h0 ++ h1, then h1 is a submap of h. *)
lemma map_leI_right:
  "⟦ h = h0 ++ h1 ; h0 ⊥ h1 ⟧ ⟹ h1 ⊆⇩m h" by auto

(* Disjointness is inherited by submaps of the left operand. *)
lemma map_disj_map_le:
  "⟦ h0' ⊆⇩m h0; h0 ⊥ h1 ⟧ ⟹ h0' ⊥ h1"
  by (force simp: map_disj_def map_le_def)

(* If a disjoint sum h0 ++ h1 is a submap of h, then so is its left
   component h0. *)
lemma map_le_on_disj_left:
  "⟦ h' ⊆⇩m h ; h0 ⊥ h1 ; h' = h0 ++ h1 ⟧ ⟹ h0 ⊆⇩m h"
  unfolding map_le_def
  apply (rule ballI)
  subgoal for a
    apply (erule ballE [where x=a], auto simp: map_add_eval_left)+
    done
  done

(* Same as map_le_on_disj_left, but with h0 as the right component of the sum. *)
lemma map_le_on_disj_right:
  "⟦ h' ⊆⇩m h ; h0 ⊥ h1 ; h' = h1 ++ h0 ⟧ ⟹ h0 ⊆⇩m h"
  by (auto simp: map_le_on_disj_left map_ac_simps)

(* The submap relation is preserved when the same disjoint map is added on
   the right of both sides. *)
lemma map_le_add_cancel:
  "⟦ h0 ⊥ h1 ; h0' ⊆⇩m h0 ⟧ ⟹ h0' ++ h1 ⊆⇩m h0 ++ h1"
  by (auto simp: map_le_def map_add_def map_disj_def split: option.splits)

(* Converse of map_le_add_cancel: a common disjoint right summand can be
   cancelled from a submap relation. Proved pointwise on dom h0', using the
   fact that such keys are not in dom h1. *)
lemma map_le_override_bothD:
  assumes subm: "h0' ++ h1 ⊆⇩m h0 ++ h1"
  assumes disj': "h0' ⊥ h1"
  assumes disj: "h0 ⊥ h1"
  shows "h0' ⊆⇩m h0"
unfolding map_le_def
proof (rule ballI)
  fix a
  assume a: "a ∈ dom h0'"
  hence sumeq: "(h0' ++ h1) a = (h0 ++ h1) a"
    using subm unfolding map_le_def by auto
  from a have "a ∉ dom h1" using disj' by (auto dest!: map_disj_None_right)
  thus "h0' a = h0 a" using a sumeq disj disj'
    by (simp add: map_add_eval_left map_add_eval_left')
qed

(* Strict submap characterisation: h0' is a proper submap of h0 iff h0 is
   h0' plus some disjoint remainder h1 (and the two differ). The witness for
   h1 is h0 with the keys of h0' removed. *)
lemma map_le_conv:
  "(h0' ⊆⇩m h0 ∧ h0' ≠ h0) = (∃h1. h0 = h0' ++ h1 ∧ h0' ⊥ h1 ∧ h0' ≠ h0)"
  unfolding map_le_def map_disj_def map_add_def
  using exI[where x="λx. if x ∉ dom h0' then h0 x else None"]
  by (rule iffI, clarsimp)
     (fastforce intro: set_eqI split: option.splits if_split_asm)+

(* Submap characterisation: h0' is a submap of h0 iff h0 is h0' plus some
   disjoint h1. The case h0' = h0 uses the empty map as witness. *)
lemma map_le_conv2:
  "h0' ⊆⇩m h0 = (∃h1. h0 = h0' ++ h1 ∧ h0' ⊥ h1)"
  by (cases "h0'=h0", insert map_le_conv, auto intro: exI[where x=Map.empty])


subsection ‹Map disjunction and restriction›

(* Any map restricted to the complement of dom h0 is disjoint from h0.
   Declared [simp]. *)
lemma map_disj_comp [simp]:
  "h0 ⊥ h1 |` (UNIV - dom h0)"
  by (force simp: map_disj_def)

(* Restrictions of one map to disjoint sets are disjoint maps. *)
lemma restrict_map_disj:
  "S ∩ T = {} ⟹ h |` S ⊥ h |` T"
  by (auto simp: map_disj_def restrict_map_def dom_def)

(* h1 with the keys of h0 removed from its domain is disjoint from h0.
   Declared [simp]. *)
lemma map_disj_restrict_dom [simp]:
  "h0 ⊥ h1 |` (dom h1 - dom h0)"
  by (force simp: map_disj_def)

(* Restricting a map to the domain of a disjoint map leaves nothing. *)
lemma restrict_map_disj_dom_empty:
  "h ⊥ h' ⟹ h |` dom h' = Map.empty"
  by (fastforce simp: map_disj_def restrict_map_def)

(* Removing the keys of a disjoint map from h does not change h. *)
lemma restrict_map_univ_disj_eq:
  "h ⊥ h' ⟹ h |` (UNIV - dom h') = h"
  by (rule ext, auto simp: map_disj_def restrict_map_def)

(* Restricting one map h to the domains of two disjoint maps gives disjoint maps. *)
lemma restrict_map_disj_dom:
  "h0 ⊥ h1 ⟹ h |` dom h0 ⊥ h |` dom h1"
  by (auto simp: map_disj_def restrict_map_def dom_def)

(* Restricting a disjoint sum to the domain of its left operand recovers
   that operand. *)
lemma map_add_restrict_dom_left:
  "h ⊥ h' ⟹ (h ++ h') |` dom h = h"
  by (rule ext, auto simp: restrict_map_def map_add_def dom_def map_disj_def
                     split: option.splits)

(* Disjointness survives restricting the left map. *)
lemma restrict_map_disj_left:
  "h0 ⊥ h1 ⟹ h0 |` S ⊥ h1"
  by (auto simp: map_disj_def)

(* Disjointness survives restricting the right map. *)
lemma restrict_map_disj_right:
  "h0 ⊥ h1 ⟹ h0 ⊥ h1 |` S"
  by (auto simp: map_disj_def)

(* Bundle of the right- and left-restriction disjointness rules. *)
lemmas restrict_map_disj_both = restrict_map_disj_right restrict_map_disj_left

(* When h0 is disjoint from h1, restricting h0 ++ h0' to dom h1 only sees h0'. *)
lemma map_dom_disj_restrict_right:
  "h0 ⊥ h1 ⟹ (h0 ++ h0') |` dom h1 = h0' |` dom h1"
  by (simp add: map_add_restrict restrict_map_empty map_disj_def)

(* Any map restricted to the domain of h0' is disjoint from whatever h0' is
   disjoint from. *)
lemma restrict_map_on_disj:
  "h0' ⊥ h1 ⟹ h0 |` dom h0' ⊥ h1"
  unfolding map_disj_def by auto

(* Disjointness survives restricting the right map (same statement as
   restrict_map_disj_right). *)
lemma restrict_map_on_disj':
  "h0 ⊥ h1 ⟹ h0 ⊥ h1 |` S"
  by (auto simp: map_disj_def map_add_def)

(* If a disjoint sum h0 ++ h1 is a submap of h, then h0 is a submap of h
   with the keys of h1 removed. *)
lemma map_le_sub_dom:
  "⟦ h0 ++ h1 ⊆⇩m h ; h0 ⊥ h1 ⟧ ⟹ h0 ⊆⇩m h |` (dom h - dom h1)"
  by (rule map_le_override_bothD, subst map_le_dom_restrict_sub_add)
     (auto elim: map_add_le_mapE simp: map_ac_simps)

(* A supermap splits into the part outside the submap's domain plus the submap. *)
lemma map_submap_break:
  "⟦ h ⊆⇩m h' ⟧ ⟹ h' = (h' |` (UNIV - dom h)) ++ h"
  by (fastforce split: option.splits
                simp: map_le_restrict restrict_map_def map_le_def map_add_def dom_def)

(* For disjoint h0 and h1, sums of restrictions to pairwise disjoint key
   sets (S vs S', T vs T') are themselves disjoint. *)
lemma map_add_disj_restrict_both:
  "⟦ h0 ⊥ h1; S ∩ S' = {}; T ∩ T' = {} ⟧
   ⟹ (h0 |` S) ++ (h1 |` T) ⊥ (h0 |` S') ++ (h1 |` T')"
  by (auto simp: map_ac_simps intro!: restrict_map_disj_both restrict_map_disj)

end