Theory Heap_Error_Monad

(***********************************************************************************
 * Copyright (c) 2016-2018 The University of Sheffield, UK
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice, this
 *   list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 * 
 * SPDX-License-Identifier: BSD-2-Clause
 ***********************************************************************************)

section‹The Heap Error Monad›
text ‹In this theory, we define a heap and error monad for modeling exceptions. 
This allows us to define composite methods similar to stateful programming in Haskell, 
but also to stay close to the official DOM specification.›
theory 
  Heap_Error_Monad
  imports
    Hiding_Type_Variables
    "HOL-Library.Monad_Syntax"
begin

subsection ‹The Program Data Type›

(* A program maps a heap either to an exception of type 'e (Inl) or to a pair of a
   result and a successor heap (Inr). The selector the_prog unwraps the function. *)
datatype ('heap, 'e, 'result) prog = Prog (the_prog: "'heap ⇒ 'e + 'result × 'heap")
(* Register default type variables for prog (for printing and parsing). *)
register_default_tvars "('heap, 'e, 'result) prog" (print, parse)

subsection ‹Basic Functions›

(* Monadic sequencing: run f on heap h; on success (Inr) pass the result and the new
   heap on to g, on failure (Inl) propagate the exception unchanged. *)
definition 
  bind :: "(_, 'result) prog ⇒ ('result ⇒ (_, 'result2) prog) ⇒ (_, 'result2) prog"
  where
    "bind f g = Prog (λh. (case (the_prog f) h of Inr (x, h') ⇒ (the_prog (g x)) h' 
                                                | Inl exception ⇒ Inl exception))"

adhoc_overloading Monad_Syntax.bind bind

(* Run program p on heap h; written "h ⊢ p". The value is the raw sum:
   Inl for an exception, Inr for a (result, heap) pair. *)
definition 
  execute :: "'heap ⇒ ('heap, 'e, 'result) prog ⇒ ('e + 'result × 'heap)" 
  ("((_)/ ⊢ (_))" [51, 52] 55)
  where
    "execute h p = (the_prog p) h"

(* "h ⊢ p →⇩r r": executing p on h succeeds and its result component is r. *)
definition 
  returns_result :: "'heap ⇒ ('heap, 'e, 'result) prog ⇒ 'result ⇒ bool" 
  ("((_)/ ⊢ (_)/ →⇩r (_))" [60, 35, 61] 65)
  where
    "returns_result h p r ⟷ (case h ⊢ p of Inr (r', _) ⇒ r = r' | Inl _ ⇒ False)"

(* Project the result out of an execution outcome; written "|x|⇩r".
   The value is unspecified (undefined) when the outcome is an exception. *)
fun select_result ("|(_)|⇩r")
  where
    "select_result (Inr (r, _)) = r"
  | "select_result (Inl _) = undefined"

(* The result of a program on a given heap is unique. *)
lemma returns_result_eq [elim]: "h ⊢ f →⇩r y ⟹ h ⊢ f →⇩r y' ⟹ y = y'"
  by(auto simp add: returns_result_def split: sum.splits)

(* "h ⊢ p →⇩h h'": executing p on h succeeds and the successor heap is h'. *)
definition 
  returns_heap :: "'heap ⇒ ('heap, 'e, 'result) prog ⇒ 'heap ⇒ bool" 
  ("((_)/ ⊢ (_)/ →⇩h (_))" [60, 35, 61] 65)
  where
    "returns_heap h p h' ⟷ (case h ⊢ p of Inr (_ , h'') ⇒ h' = h'' | Inl _ ⇒ False)"

(* Project the successor heap out of an execution outcome; written "|x|⇩h".
   The value is unspecified (undefined) when the outcome is an exception. *)
fun select_heap ("|(_)|⇩h")
  where
    "select_heap (Inr ( _, h)) = h"
  | "select_heap (Inl _) = undefined"

(* The successor heap of a program on a given heap is unique. *)
lemma returns_heap_eq [elim]: "h ⊢ f →⇩h h' ⟹ h ⊢ f →⇩h h'' ⟹ h' = h''"
  by(auto simp add: returns_heap_def split: sum.splits)

(* "h ⊢ p →⇩r r →⇩h h'": p run on h returns result r and successor heap h'. *)
definition 
  returns_result_heap :: "'heap ⇒ ('heap, 'e, 'result) prog ⇒ 'result ⇒ 'heap ⇒ bool" 
  ("((_)/ ⊢ (_)/ →⇩r (_) →⇩h (_))" [60, 35, 61, 62] 65)
  where
    "returns_result_heap h p r h' ⟷ h ⊢ p →⇩r r ∧ h ⊢ p →⇩h h'"

(* Code equation: decide returns_result_heap with a single execution of p. *)
lemma return_result_heap_code [code]:
  "returns_result_heap h p r h' ⟷ (case h ⊢ p of Inr (r', h'') ⇒ r = r' ∧ h' = h'' | Inl _ ⇒ False)"
  by(auto simp add: returns_result_heap_def returns_result_def returns_heap_def split: sum.splits)

(* Project the (result, heap) pair out of an execution outcome; written "|x|⇩r⇩h".
   The value is unspecified (undefined) when the outcome is an exception. *)
fun select_result_heap ("|(_)|⇩r⇩h")
  where
    "select_result_heap (Inr (r, h)) = (r, h)"
  | "select_result_heap (Inl _) = undefined"

(* "h ⊢ p →⇩e e": executing p on h fails with exception e. *)
definition 
  returns_error :: "'heap ⇒ ('heap, 'e, 'result) prog ⇒ 'e ⇒ bool" 
  ("((_)/ ⊢ (_)/ →⇩e (_))" [60, 35, 61] 65)
  where
    "returns_error h p e = (case h ⊢ p of Inr _ ⇒ False | Inl e' ⇒ e = e')"

(* "h ⊢ ok p": executing p on h does not raise an exception. *)
definition is_OK :: "'heap ⇒ ('heap, 'e, 'result) prog ⇒ bool" ("((_)/ ⊢ ok (_))" [75, 75])
  where
    "is_OK h p = (case h ⊢ p of Inr _ ⇒ True | Inl _ ⇒ False)"

(* A program that returns a result is ok. *)
lemma is_OK_returns_result_I [intro]: "h ⊢ f →⇩r y ⟹ h ⊢ ok f"
  by(auto simp add: is_OK_def returns_result_def split: sum.splits)

(* An ok program returns some result. *)
lemma is_OK_returns_result_E [elim]:
  assumes "h ⊢ ok f"
  obtains x where "h ⊢ f →⇩r x"
  using assms by(auto simp add: is_OK_def returns_result_def split: sum.splits)

(* A program that returns a heap is ok. *)
lemma is_OK_returns_heap_I [intro]: "h ⊢ f →⇩h h' ⟹ h ⊢ ok f"
  by(auto simp add: is_OK_def returns_heap_def split: sum.splits)

(* An ok program returns some heap. *)
lemma is_OK_returns_heap_E [elim]:
  assumes "h ⊢ ok f"
  obtains h' where "h ⊢ f →⇩h h'"
  using assms by(auto simp add: is_OK_def returns_heap_def split: sum.splits)

(* To show P of the selected result of an ok program, show P for every returned result. *)
lemma select_result_I:
  assumes "h ⊢ ok f"
    and "⋀x. h ⊢ f →⇩r x ⟹ P x"
  shows "P |h ⊢ f|⇩r"
  using assms
  by(auto simp add: is_OK_def returns_result_def split: sum.splits)

(* The selected result equals any returned result. *)
lemma select_result_I2 [simp]:
  assumes "h ⊢ f →⇩r x"
  shows "|h ⊢ f|⇩r = x"
  using assms
  by(auto simp add: is_OK_def returns_result_def split: sum.splits)

(* An ok program returns its selected result. *)
lemma returns_result_select_result [simp]:
  assumes "h ⊢ ok f"
  shows "h ⊢ f →⇩r |h ⊢ f|⇩r"
  using assms
  by (simp add: select_result_I)

(* Elimination: a property of the selected result of an ok program holds for a returned result. *)
lemma select_result_E:
  assumes "P |h ⊢ f|⇩r" and "h ⊢ ok f" 
  obtains x where "h ⊢ f →⇩r x" and "P x"
  using assms
  by(auto simp add: is_OK_def returns_result_def split: sum.splits)

(* If f returns the same results on h and h', the selected results coincide. *)
lemma select_result_eq: "(⋀x .h ⊢ f →⇩r x = h' ⊢ f →⇩r x) ⟹ |h ⊢ f|⇩r = |h' ⊢ f|⇩r"
  by (metis (no_types, lifting) is_OK_def old.sum.simps(6) select_result.elims 
      select_result_I select_result_I2)

(* The program that fails with the given exception on every heap. *)
definition error :: "'e ⇒ ('heap, 'e, 'result) prog"
  where
    "error exception = Prog (λh. Inl exception)"

(* error is a left zero of bind. *)
lemma error_bind [iff]: "(error e ⤜ g) = error e"
  unfolding error_def bind_def by auto

(* error never returns a result. *)
lemma error_returns_result [simp]: "¬ (h ⊢ error e →⇩r y)"
  unfolding returns_result_def error_def execute_def by auto

(* error never returns a heap. *)
lemma error_returns_heap [simp]: "¬ (h ⊢ error e →⇩h h')"
  unfolding returns_heap_def error_def execute_def by auto

(* error e fails with exactly e. *)
lemma error_returns_error [simp]: "h ⊢ error e →⇩e e"
  unfolding returns_error_def error_def execute_def by auto

(* The program that succeeds with the given result and leaves the heap unchanged. *)
definition return :: "'result ⇒ ('heap, 'e, 'result) prog"
  where
    "return result = Prog (λh. Inr (result, h))"

(* return always succeeds. *)
lemma return_ok [simp]: "h ⊢ ok (return x)"
  by(simp add: return_def is_OK_def execute_def)

(* Left identity of bind. *)
lemma return_bind [iff]: "(return x ⤜ g) = g x"
  unfolding return_def bind_def by auto

(* Right identity of bind. *)
lemma return_id [simp]: "f ⤜ return = f"
  by (induct f) (auto simp add: return_def bind_def split: sum.splits prod.splits)

(* return x returns exactly x. *)
lemma return_returns_result [iff]: "(h ⊢ return x →⇩r y) = (x = y)"
  unfolding returns_result_def return_def execute_def by auto

(* return leaves the heap unchanged. *)
lemma return_returns_heap [iff]: "(h ⊢ return x →⇩h h') = (h = h')"
  unfolding returns_heap_def return_def execute_def by auto

(* return never fails. *)
lemma return_returns_error [iff]: "¬ h ⊢ return x →⇩e e"
  unfolding returns_error_def execute_def return_def by auto

(* The program that does nothing: it returns the unit value. *)
definition noop :: "('heap, 'e, unit) prog"
  where
    "noop = return ()"

(* noop leaves the heap unchanged. *)
lemma noop_returns_heap [simp]: "h ⊢ noop →⇩h h' ⟷ h = h'"
  by(simp add: noop_def)

(* The program that returns the current heap as its result. *)
definition get_heap :: "('heap, 'e, 'heap) prog"
  where
    "get_heap = Prog (λh. h ⊢ return h)"

(* get_heap always succeeds. *)
lemma get_heap_ok [simp]: "h ⊢ ok (get_heap)"
  by (simp add: get_heap_def execute_def is_OK_def return_def)

(* Binding get_heap passes the current heap to the continuation (result view). *)
lemma get_heap_returns_result [simp]: "(h ⊢ get_heap ⤜ (λh'. f h') →⇩r x) = (h ⊢ f h →⇩r x)"
  by(simp add: get_heap_def returns_result_def bind_def return_def execute_def)

(* Binding get_heap passes the current heap to the continuation (heap view). *)
lemma get_heap_returns_heap [simp]: "(h ⊢ get_heap ⤜ (λh'. f h') →⇩h h'') = (h ⊢ f h →⇩h h'')"
  by(simp add: get_heap_def returns_heap_def bind_def return_def execute_def)

(* Binding get_heap passes the current heap to the continuation (ok view). *)
lemma get_heap_is_OK [simp]: "(h ⊢ ok (get_heap ⤜ (λh'. f h'))) = (h ⊢ ok (f h))"
  by(auto simp add: get_heap_def is_OK_def bind_def return_def execute_def)

(* The result of get_heap is the heap it was run on. *)
lemma get_heap_E [elim]: "(h ⊢ get_heap →⇩r x) ⟹ x = h"
  by(simp add: get_heap_def returns_result_def return_def execute_def)

(* The program that ignores the current heap and replaces it by h, returning unit. *)
definition return_heap :: "'heap ⇒ ('heap, 'e, unit) prog"
  where
    "return_heap h = Prog (λ_. h ⊢ return ())"

(* The successor heap of return_heap h' is h', whatever the start heap. *)
lemma return_heap_E [iff]: "(h ⊢ return_heap h' →⇩h h'') = (h'' = h')"
  by(simp add: return_heap_def returns_heap_def return_def execute_def)

(* return_heap returns the unit value. *)
lemma return_heap_returns_result [simp]: "h ⊢ return_heap h' →⇩r ()"
  by(simp add: return_heap_def execute_def returns_result_def return_def)


subsection ‹Pure Heaps›

(* f is pure on h if, whenever it succeeds on h, it leaves the heap unchanged. *)
definition pure :: "('heap, 'e, 'result) prog ⇒ 'heap ⇒ bool"
  where "pure f h ⟷ h ⊢ ok f ⟶ h ⊢ f →⇩h h"

(* return is pure. *)
lemma return_pure [simp]: "pure (return x) h"
  by(simp add: pure_def return_def is_OK_def returns_heap_def execute_def)

(* error is pure, since it never succeeds. *)
lemma error_pure [simp]: "pure (error e) h"
  by(simp add: pure_def error_def is_OK_def returns_heap_def execute_def)

(* noop is pure. *)
lemma noop_pure [simp]: "pure (noop) h"
  by (simp add: noop_def)

(* get_heap is pure. *)
lemma get_pure [simp]: "pure get_heap h"
  by(simp add: pure_def get_heap_def is_OK_def returns_heap_def return_def execute_def)

(* A pure program can only return the heap it started from. *)
lemma pure_returns_heap_eq:
  "h ⊢ f →⇩h h' ⟹ pure f h ⟹ h = h'"
  by (meson pure_def is_OK_returns_heap_I returns_heap_eq)

(* Alternative characterisation of purity via returned results and heaps. *)
lemma pure_eq_iff:     
  "(∀h' x. h ⊢ f →⇩r x ⟶ h ⊢ f →⇩h h' ⟶ h = h') ⟷ pure f h"
  by(auto simp add: pure_def)

subsection ‹Bind›

(* Associativity of bind. *)
lemma bind_assoc [simp]:
  "((bind f g) ⤜ h) = (f ⤜ (λx. (g x ⤜ h)))"
  by(auto simp add: bind_def split: sum.splits)

(* If f ⤜ g returns y, then f returns some x and heap h', and g x returns y on h'. *)
lemma bind_returns_result_E:
  assumes "h ⊢ f ⤜ g →⇩r y"
  obtains x h' where "h ⊢ f →⇩r x" and "h ⊢ f →⇩h h'" and "h' ⊢ g x →⇩r y"
  using assms by(auto simp add: bind_def returns_result_def returns_heap_def execute_def 
      split: sum.splits)

(* Variant for pure f: g x runs on the original heap. *)
lemma bind_returns_result_E2:
  assumes "h ⊢ f ⤜ g →⇩r y" and "pure f h"
  obtains x where "h ⊢ f →⇩r x" and "h ⊢ g x →⇩r y"
  using assms pure_returns_heap_eq bind_returns_result_E by metis

(* Variant for pure f where the result x of f is already known. *)
lemma bind_returns_result_E3:
  assumes "h ⊢ f ⤜ g →⇩r y" and "h ⊢ f →⇩r x" and "pure f h"
  shows "h ⊢ g x →⇩r y"
  using assms returns_result_eq bind_returns_result_E2 by metis

(* Variant where the result x of f is already known. *)
lemma bind_returns_result_E4:
  assumes "h ⊢ f ⤜ g →⇩r y" and "h ⊢ f →⇩r x" 
  obtains h' where "h ⊢ f →⇩h h'" and "h' ⊢ g x →⇩r y"
  using assms returns_result_eq bind_returns_result_E by metis

(* If f ⤜ g returns heap h'', then f returns some x and h', and g x takes h' to h''. *)
lemma bind_returns_heap_E:
  assumes "h ⊢ f ⤜ g →⇩h h''"
  obtains x h' where "h ⊢ f →⇩r x" and "h ⊢ f →⇩h h'" and "h' ⊢ g x →⇩h h''"
  using assms by(auto simp add: bind_def returns_result_def returns_heap_def execute_def 
      split: sum.splits)

(* Variant for pure f: g x runs on the original heap. *)
lemma bind_returns_heap_E2 [elim]:
  assumes "h ⊢ f ⤜ g →⇩h h'" and "pure f h"
  obtains x where "h ⊢ f →⇩r x" and "h ⊢ g x →⇩h h'"
  using assms pure_returns_heap_eq by (fastforce elim: bind_returns_heap_E)

(* Variant for pure f where the result x of f is already known. *)
lemma bind_returns_heap_E3 [elim]:
  assumes "h ⊢ f ⤜ g →⇩h h'" and "h ⊢ f →⇩r x" and "pure f h" 
  shows "h ⊢ g x →⇩h h'"
  using assms pure_returns_heap_eq returns_result_eq by (fastforce elim: bind_returns_heap_E)

(* Variant where the intermediate heap h' of f is already known. *)
lemma bind_returns_heap_E4:
  assumes "h ⊢ f ⤜ g →⇩h h''" and "h ⊢ f →⇩h h'"
  obtains x where "h ⊢ f →⇩r x" and "h' ⊢ g x →⇩h h''"
  using assms
  by (metis bind_returns_heap_E returns_heap_eq)

(* If f fails with e, so does f ⤜ g. *)
lemma bind_returns_error_I [intro]:
  assumes "h ⊢ f →⇩e e"
  shows "h ⊢ f ⤜ g →⇩e e"
  using assms
  by(auto simp add: returns_error_def bind_def execute_def split: sum.splits)

(* If f succeeds and the continuation fails with e, so does f ⤜ g. *)
lemma bind_returns_error_I3:
  assumes "h ⊢ f →⇩r x" and "h ⊢ f →⇩h h'" and "h' ⊢ g x →⇩e e"
  shows "h ⊢ f ⤜ g →⇩e e"
  using assms
  by(auto simp add: returns_error_def bind_def execute_def returns_heap_def returns_result_def 
      split: sum.splits)

(* Variant for pure f: the continuation fails on the original heap. *)
lemma bind_returns_error_I2 [intro]:
  assumes "pure f h" and "h ⊢ f →⇩r x" and "h ⊢ g x →⇩e e"
  shows "h ⊢ f ⤜ g →⇩e e"
  using assms
  by (meson bind_returns_error_I3 is_OK_returns_result_I pure_def)

(* If f ⤜ g is ok, then f returns some x and h', and g x is ok on h'. *)
lemma bind_is_OK_E [elim]:
  assumes "h ⊢ ok (f ⤜ g)"
  obtains x h' where "h ⊢ f →⇩r x" and "h ⊢ f →⇩h h'" and "h' ⊢ ok (g x)"
  using assms 
  by(auto simp add: bind_def returns_result_def returns_heap_def is_OK_def execute_def 
      split: sum.splits)

(* Variant where the result x of f is already known. *)
lemma bind_is_OK_E2:
  assumes "h ⊢ ok (f ⤜ g)" and "h ⊢ f →⇩r x"
  obtains h' where "h ⊢ f →⇩h h'" and "h' ⊢ ok (g x)"
  using assms 
  by(auto simp add: bind_def returns_result_def returns_heap_def is_OK_def execute_def 
      split: sum.splits)

(* Introduction rule for the result of f ⤜ g. *)
lemma bind_returns_result_I [intro]:
  assumes "h ⊢ f →⇩r x" and "h ⊢ f →⇩h h'" and "h' ⊢ g x →⇩r y"
  shows "h ⊢ f ⤜ g →⇩r y"
  using assms 
  by(auto simp add: bind_def returns_result_def returns_heap_def execute_def 
      split: sum.splits)

(* Variant for pure f: the continuation runs on the original heap. *)
lemma bind_pure_returns_result_I [intro]:
  assumes "pure f h" and "h ⊢ f →⇩r x" and "h ⊢ g x →⇩r y"
  shows "h ⊢ f ⤜ g →⇩r y"
  using assms
  by (meson bind_returns_result_I pure_def is_OK_returns_result_I)

(* Variant for pure, ok f: quantify over the result of f. *)
lemma bind_pure_returns_result_I2 [intro]:
  assumes "pure f h" and "h ⊢ ok f" and "⋀x. h ⊢ f →⇩r x ⟹ h ⊢ g x →⇩r y"
  shows "h ⊢ f ⤜ g →⇩r y"
  using assms by auto

(* Introduction rule for the successor heap of f ⤜ g. *)
lemma bind_returns_heap_I [intro]:
  assumes "h ⊢ f →⇩r x" and "h ⊢ f →⇩h h'" and "h' ⊢ g x →⇩h h''"
  shows "h ⊢ f ⤜ g →⇩h h''"
  using assms 
  by(auto simp add: bind_def returns_result_def returns_heap_def execute_def 
      split: sum.splits)

(* Variant that quantifies over the result of f. *)
lemma bind_returns_heap_I2 [intro]:
  assumes "h ⊢ f →⇩h h'" and "⋀x. h ⊢ f →⇩r x ⟹ h' ⊢ g x →⇩h h''"
  shows "h ⊢ f ⤜ g →⇩h h''"
  using assms
  by (meson bind_returns_heap_I is_OK_returns_heap_I is_OK_returns_result_E)

(* f ⤜ g is ok if f succeeds and the continuation is ok on the successor heap. *)
lemma bind_is_OK_I [intro]:
  assumes "h ⊢ f →⇩r x" and "h ⊢ f →⇩h h'" and "h' ⊢ ok (g x)"
  shows "h ⊢ ok (f ⤜ g)"
  by (meson assms(1) assms(2) assms(3) bind_returns_heap_I is_OK_returns_heap_E 
      is_OK_returns_heap_I)

(* Variant that quantifies over the result and successor heap of f. *)
lemma bind_is_OK_I2 [intro]:
  assumes "h ⊢ ok f" and "⋀x h'. h ⊢ f →⇩r x ⟹ h ⊢ f →⇩h h' ⟹ h' ⊢ ok (g x)"
  shows "h ⊢ ok (f ⤜ g)"
  using assms by blast  

(* Variant for pure f: the continuation must be ok on the original heap. *)
lemma bind_is_OK_pure_I [intro]:
  assumes "pure f h" and "h ⊢ ok f" and "⋀x. h ⊢ f →⇩r x ⟹ h ⊢ ok (g x)"
  shows "h ⊢ ok (f ⤜ g)"
  using assms by blast

(* Binding a pure program with pure continuations is pure. *)
lemma bind_pure_I:
  assumes "pure f h" and "⋀x. h ⊢ f →⇩r x ⟹ pure (g x) h"
  shows "pure (f ⤜ g) h"
  using assms
  by (metis bind_returns_heap_E2 pure_def pure_returns_heap_eq is_OK_returns_heap_E)

(* A pure, ok program returns the heap it started from. *)
lemma pure_pure:
  assumes "h ⊢ ok f" and "pure f h"
  shows "h ⊢ f →⇩h h"
  using assms returns_heap_eq 
  unfolding pure_def
  by auto

(* Two programs failing with the same exception on h have equal outcomes on h. *)
lemma bind_returns_error_eq: 
  assumes "h ⊢ f →⇩e e"
    and "h ⊢ g →⇩e e"
  shows "h ⊢ f = h ⊢ g"
  using assms 
  by(auto simp add: returns_error_def split: sum.splits)

subsection ‹Map›

(* Monadic map: run f on every list element from left to right and collect the results. *)
fun map_M :: "('x ⇒ ('heap, 'e, 'result) prog) ⇒ 'x list ⇒ ('heap, 'e, 'result list) prog"
  where
    "map_M f [] = return []"
  | "map_M f (x#xs) = do {
      y ← f x;
      ys ← map_M f xs;
      return (y # ys)
    }"

(* map_M is ok if f is ok and pure on every element. *)
lemma map_M_ok_I [intro]: 
  "(⋀x. x ∈ set xs ⟹ h ⊢ ok (f x)) ⟹ (⋀x. x ∈ set xs ⟹ pure (f x) h) ⟹ h ⊢ ok (map_M f xs)"
  apply(induct xs)
  by (simp_all add: bind_is_OK_I2 bind_is_OK_pure_I)

(* map_M is pure if f is pure on every element. *)
lemma map_M_pure_I : "⋀h. (⋀x. x ∈ set xs ⟹ pure (f x) h) ⟹ pure (map_M f xs) h"
  apply(induct xs)
   apply(simp)
  by(auto intro!: bind_pure_I)

(* For pure g: every input element has a corresponding element in the result list. *)
lemma map_M_pure_E :
  assumes "h ⊢ map_M g xs →⇩r ys" and "x ∈ set xs" and "⋀x h. x ∈ set xs ⟹ pure (g x) h"
  obtains y where "h ⊢ g x →⇩r y" and "y ∈ set ys"
  apply(insert assms, induct xs arbitrary: ys)
   apply(simp)
  apply(auto elim!: bind_returns_result_E)[1]
  by (metis (full_types) pure_returns_heap_eq)

(* For pure g: every result element stems from some input element. *)
lemma map_M_pure_E2:
  assumes "h ⊢ map_M g xs →⇩r ys" and "y ∈ set ys" and "⋀x h. x ∈ set xs ⟹ pure (g x) h"
  obtains x where "h ⊢ g x →⇩r y" and "x ∈ set xs"
  apply(insert assms, induct xs arbitrary: ys)
   apply(simp)
  apply(auto elim!: bind_returns_result_E)[1]
  by (metis (full_types) pure_returns_heap_eq)


subsection ‹Forall›

(* Run P on every list element from left to right, discarding the results. *)
fun forall_M :: "('y ⇒ ('heap, 'e, 'result) prog) ⇒ 'y list ⇒ ('heap, 'e, unit) prog"
  where
    "forall_M P [] = return ()"
  | "forall_M P (x # xs) = do {
      P x;
      forall_M P xs
    }"
    

(* forall_M is pure if P is pure on every element. *)
lemma pure_forall_M_I: "(⋀x. x ∈ set xs ⟹ pure (P x) h) ⟹ pure (forall_M P xs) h"
  apply(induct xs)
  by(auto intro!: bind_pure_I)
   

subsection ‹Fold›

(* Monadic left fold: thread the accumulator d through f over the list. *)
fun fold_M :: "('result ⇒ 'y ⇒ ('heap, 'e, 'result) prog) ⇒ 'result ⇒ 'y list
  ⇒ ('heap, 'e, 'result) prog"
  where 
    "fold_M f d [] = return d" |
    "fold_M f d (x # xs) = do { y ← f d x; fold_M f y xs }"

(* fold_M is pure if f is pure for all arguments. *)
lemma fold_M_pure_I : "(⋀d x. pure (f d x) h) ⟹ (⋀d. pure (fold_M f d xs) h)"
  apply(induct xs)
  by(auto intro: bind_pure_I)

subsection ‹Filter›

(* Monadic filter: keep exactly those elements for which the predicate program returns True. *)
fun filter_M :: "('x ⇒ ('heap, 'e, bool) prog) ⇒ 'x list ⇒ ('heap, 'e, 'x list) prog"
  where
    "filter_M P [] = return []"
  | "filter_M P (x#xs) = do {
      p ← P x;
      ys ← filter_M P xs;
      return (if p then x # ys else ys)
    }"

(* filter_M is pure if P is pure on every element. *)
lemma filter_M_pure_I [intro]: "(⋀x. x ∈ set xs ⟹ pure (P x) h) ⟹ pure (filter_M P xs) h"
  apply(induct xs) 
  by(auto intro!: bind_pure_I)

(* filter_M is ok if P is ok and pure on every element. *)
lemma filter_M_is_OK_I [intro]:
  "(⋀x. x ∈ set xs ⟹ h ⊢ ok (P x)) ⟹ (⋀x. x ∈ set xs ⟹ pure (P x) h) ⟹ h ⊢ ok (filter_M P xs)"
  apply(induct xs)
   apply(simp)
  by(auto intro!: bind_is_OK_pure_I)

(* For pure P: filter_M does not invent elements. *)
lemma filter_M_not_more_elements:
  assumes "h ⊢ filter_M P xs →⇩r ys" and "⋀x. x ∈ set xs ⟹ pure (P x) h" and "x ∈ set ys"
  shows "x ∈ set xs"
  apply(insert assms, induct xs arbitrary: ys)
  by(auto elim!: bind_returns_result_E2 split: if_splits intro!: set_ConsD)

(* For pure P: an input element on which P returns True is in the result. *)
lemma filter_M_in_result_if_ok:
  assumes "h ⊢ filter_M P xs →⇩r ys" and "⋀h x. x ∈ set xs ⟹ pure (P x) h" and "x ∈ set xs" and
    "h ⊢ P x →⇩r True"
  shows "x ∈ set ys"
  apply(insert assms, induct xs arbitrary: ys)
   apply(simp)
  apply(auto elim!: bind_returns_result_E2)[1]
  by (metis returns_result_eq)

(* For pure P: P returns True on every element of the result. *)
lemma filter_M_holds_for_result:
  assumes "h ⊢ filter_M P xs →⇩r ys" and "x ∈ set ys" and "⋀x h. x ∈ set xs ⟹ pure (P x) h"
  shows "h ⊢ P x →⇩r True"
  apply(insert assms, induct xs arbitrary: ys)
  by(auto elim!: bind_returns_result_E2 split: if_splits intro!: set_ConsD)

(* For pure P: if P returns False on every element, the result is empty. *)
lemma filter_M_empty_I:
  assumes "⋀x. pure (P x) h"
    and "∀x ∈ set xs. h ⊢ P x →⇩r False"
  shows "h ⊢ filter_M P xs →⇩r []"
  using assms
  apply(induct xs)
  by(auto intro!: bind_pure_returns_result_I)

(* For pure P on both heaps: if every element accepted on h can only be accepted on h',
   then the filter result on h is contained in the filter result on h'. *)
lemma filter_M_subset_2: "h ⊢ filter_M P xs →⇩r ys ⟹ h' ⊢ filter_M P xs →⇩r ys' 
                          ⟹ (⋀x. pure (P x) h) ⟹ (⋀x. pure (P x) h') 
                          ⟹ (∀b. ∀x ∈ set xs. h ⊢ P x →⇩r True ⟶ h' ⊢ P x →⇩r b ⟶ b) 
                          ⟹ set ys ⊆ set ys'"
proof -
  assume 1: "h ⊢ filter_M P xs →⇩r ys" and 2: "h' ⊢ filter_M P xs →⇩r ys'" 
    and 3: "(⋀x. pure (P x) h)" and "(⋀x. pure (P x) h')" 
    and 4: "∀b. ∀x∈set xs. h ⊢ P x →⇩r True ⟶ h' ⊢ P x →⇩r b ⟶ b"
  (* P is ok on h' for every input element, because filter_M succeeded on h'. *)
  have h1: "∀x ∈ set xs. h' ⊢ ok (P x)"
    using 2 3 ‹(⋀x. pure (P x) h')›
    apply(induct xs arbitrary: ys')
    by(auto elim!: bind_returns_result_E2)
  (* Hence acceptance on h implies acceptance on h'. *)
  then have 5: "∀x∈set xs. h ⊢ P x →⇩r True ⟶ h' ⊢ P x →⇩r True"
    using 4
    apply(auto)[1]
    by (metis is_OK_returns_result_E)
  show ?thesis
    using 1 2 3 5 ‹(⋀x. pure (P x) h')›
    apply(induct xs arbitrary: ys ys')
     apply(auto)[1]
    apply(auto elim!: bind_returns_result_E2 split: if_splits)[1]
          apply auto[1]
         apply auto[1]
        apply(metis returns_result_eq)
       apply auto[1]
      apply auto[1]
     apply auto[1]
    by(auto)
qed

(* The result of filter_M is a subset of its input (no purity needed). *)
lemma filter_M_subset: "h ⊢ filter_M P xs →⇩r ys ⟹ set ys ⊆ set xs"
  apply(induct xs arbitrary: h ys)
   apply(auto)[1]
  apply(auto elim!: bind_returns_result_E split: if_splits)[1]
   apply blast
  by blast

(* filter_M preserves distinctness. *)
lemma filter_M_distinct: "h ⊢ filter_M P xs →⇩r ys ⟹ distinct xs ⟹ distinct ys"
  apply(induct xs arbitrary: h ys)
   apply(auto)[1]
  using filter_M_subset
  apply(auto elim!: bind_returns_result_E)[1]
  by fastforce

(* For pure P: filter_M agrees with the list filter over the selected results of P. *)
lemma filter_M_filter: "h ⊢ filter_M P xs →⇩r ys ⟹ (⋀x. x ∈ set xs ⟹ pure (P x) h) 
                       ⟹ (∀x ∈ set xs. h ⊢ ok P x) ∧ ys = filter (λx. |h ⊢ P x|⇩r) xs"
  apply(induct xs arbitrary: ys)
  by(auto elim!: bind_returns_result_E2)

(* Converse direction: for pure and ok P, the list filter yields the result of filter_M. *)
lemma filter_M_filter2: "(⋀x. x ∈ set xs ⟹ pure (P x) h ∧  h ⊢ ok P x) 
                       ⟹ filter (λx. |h ⊢ P x|⇩r) xs = ys ⟹ h ⊢ filter_M P xs →⇩r ys"
  apply(induct xs arbitrary: ys)
  by(auto elim!: bind_returns_result_E2 intro!: bind_pure_returns_result_I)

(* If exactly one element of a distinct list satisfies P, filtering yields that singleton. *)
lemma filter_ex1: "∃!x ∈ set xs. P x ⟹ P x ⟹ x ∈ set xs ⟹ distinct xs 
                  ⟹ filter P xs = [x]"
  apply(auto)[1]
  apply(induct xs)
   apply(auto)[1]
  apply(auto)[1]
  using filter_empty_conv by fastforce

(* Monadic analogue of filter_ex1: for pure P on a distinct list with exactly one
   accepted element x, filter_M returns the singleton [x]. *)
lemma filter_M_ex1:
  assumes "h ⊢ filter_M P xs →⇩r ys"
    and "x ∈ set xs"
    and "∃!x ∈ set xs. h ⊢ P x →⇩r True"
    and "⋀x. x ∈ set xs ⟹ pure (P x) h"
    and "distinct xs"
    and "h ⊢ P x →⇩r True"
  shows "ys = [x]"
proof -
  (* Transfer uniqueness from the relational form to the selected results. *)
  have *: "∃!x ∈ set xs. |h ⊢ P x|⇩r"
    apply(insert assms(1) assms(3) assms(4))
    apply(drule filter_M_filter) 
     apply(simp)
    apply(auto simp add: select_result_I2)[1]
    by (metis (full_types) is_OK_returns_result_E select_result_I2)
  then show ?thesis
    apply(insert assms(1) assms(4))
    apply(drule filter_M_filter)
     apply(auto)[1] 
    by (metis * assms(2) assms(5) assms(6) distinct_filter 
        distinct_length_2_or_more filter_empty_conv filter_set list.exhaust 
        list.set_intros(1) list.set_intros(2) member_filter select_result_I2)
qed

(* If pure P returns the same results on h and h' for all elements, filter_M returns
   the same list on both heaps. *)
lemma filter_M_eq:
  assumes "⋀x. pure (P x) h" and "⋀x. pure (P x) h'"
    and "⋀b x. x ∈ set xs ⟹ h ⊢ P x →⇩r b = h' ⊢ P x →⇩r b"
  shows "h ⊢ filter_M P xs →⇩r ys ⟷ h' ⊢ filter_M P xs →⇩r ys"
  using assms
  apply (induct xs arbitrary: ys)
  by(auto elim!: bind_returns_result_E2 intro!: bind_pure_returns_result_I 
      dest: returns_result_eq)


subsection ‹Map Filter›

(* Map f over the list, drop the None results and unwrap the remaining Some values. *)
definition map_filter_M :: "('x ⇒ ('heap, 'e, 'y option) prog) ⇒ 'x list
  ⇒ ('heap, 'e, 'y list) prog"
  where
    "map_filter_M f xs = do {
      ys_opts ← map_M f xs;
      ys_no_opts ← filter_M (λx. return (x ≠ None)) ys_opts;
      map_M (λx. return (the x)) ys_no_opts
    }"

(* map_filter_M is pure if f is pure on every element (on all heaps). *)
lemma map_filter_M_pure: "(⋀x h. x ∈ set xs ⟹ pure (f x) h) ⟹ pure (map_filter_M f xs) h"
  by(auto simp add: map_filter_M_def map_M_pure_I intro!: bind_pure_I)

(* For pure f: every element y of the result stems from an input x with f x = Some y. *)
lemma map_filter_M_pure_E:
  assumes "h ⊢ (map_filter_M::('x ⇒ ('heap, 'e, 'y option) prog) ⇒ 'x list
  ⇒ ('heap, 'e, 'y list) prog) f xs →⇩r ys" and "y ∈ set ys" and "⋀x h. x ∈ set xs ⟹ pure (f x) h"
  obtains x where "h ⊢ f x →⇩r Some y" and "x ∈ set xs"
proof -
  (* Split the run into the three stages of map_filter_M. *)
  obtain ys_opts ys_no_opts where
    ys_opts: "h ⊢ map_M f xs →⇩r ys_opts" and
    ys_no_opts: "h ⊢ filter_M (λx. (return (x ≠ None)::('heap, 'e, bool) prog)) ys_opts →⇩r ys_no_opts" and
    ys: "h ⊢ map_M (λx. (return (the x)::('heap, 'e, 'y) prog)) ys_no_opts →⇩r ys"
    using assms
    by(auto simp add: map_filter_M_def map_M_pure_I elim!: bind_returns_result_E2)
  have "∀y ∈ set ys_no_opts. y ≠ None"
    using ys_no_opts filter_M_holds_for_result
    by fastforce
  then have "Some y ∈ set ys_no_opts"
    using map_M_pure_E2 ys ‹y ∈ set ys›
    by (metis (no_types, lifting) option.collapse return_pure return_returns_result)
  then have "Some y ∈ set ys_opts"
    using filter_M_subset ys_no_opts by fastforce
  then show "(⋀x. h ⊢ f x →⇩r Some y ⟹ x ∈ set xs ⟹ thesis) ⟹ thesis"
    by (metis assms(3) map_M_pure_E2 ys_opts)
qed


subsection ‹Iterate›

(* Run a list of programs in sequence, ignoring intermediate results.
   The result for the empty list is unspecified (return undefined). *)
fun iterate_M :: "('heap, 'e, 'result) prog list ⇒ ('heap, 'e, 'result) prog"
  where
    "iterate_M [] = return undefined"
  | "iterate_M (x # xs) = x ⤜ (λ_. iterate_M xs)"


(* Heap transitions of iterate_M compose along list concatenation. *)
lemma iterate_M_concat:
  assumes "h ⊢ iterate_M xs →⇩h h'"
    and "h' ⊢ iterate_M ys →⇩h h''"
  shows "h ⊢ iterate_M (xs @ ys) →⇩h h''"
  using assms
  apply(induct "xs" arbitrary: h h'')
   apply(simp)
  apply(auto)[1]
  by (meson bind_returns_heap_E bind_returns_heap_I)

subsection‹Miscellaneous Rules›

(* Executing f ⤜ g equals executing g on the result and successor heap of f. *)
lemma execute_bind_simp:
  assumes "h ⊢ f →⇩r x" and "h ⊢ f →⇩h h'"
  shows "h ⊢ f ⤜ g = h' ⊢ g x"
  using assms 
  by(auto simp add: returns_result_def returns_heap_def bind_def execute_def  
      split: sum.splits)

(* Congruence rule for bind, registered for the function package. *)
lemma bind_cong [fundef_cong]:
  fixes f1 f2 :: "('heap, 'e, 'result) prog"
    and g1 g2 :: "'result ⇒ ('heap, 'e, 'result2) prog"
  assumes "h ⊢ f1 = h ⊢ f2"
    and "⋀y h'. h ⊢ f1 →⇩r y ⟹ h ⊢ f1 →⇩h h' ⟹ h' ⊢ g1 y = h' ⊢ g2 y"
  shows "h ⊢ (f1 ⤜ g1) = h ⊢ (f2 ⤜ g2)"
  apply(insert assms, cases "h ⊢ f1") 
  by(auto simp add: bind_def returns_result_def returns_heap_def execute_def 
      split: sum.splits)

(* Congruence across two heaps for a pure first program. *)
lemma bind_cong_2:
  assumes "pure f h" and "pure f h'"
    and "⋀x. h ⊢ f →⇩r x = h' ⊢ f →⇩r x"
    and "⋀x. h ⊢ f →⇩r x ⟹ h ⊢ g x →⇩r y = h' ⊢ g x →⇩r y'"
  shows "h ⊢ f ⤜ g →⇩r y = h' ⊢ f ⤜ g →⇩r y'"
  using assms
  by(auto intro!: bind_pure_returns_result_I elim!: bind_returns_result_E2)

(* Congruence rule for option case expressions, registered for the function package. *)
lemma bind_case_cong [fundef_cong]:
  assumes "x = x'" and "⋀a. x = Some a ⟹ f a h = f' a h"
  shows "(case x of Some a ⇒ f a | None ⇒ g) h = (case x' of Some a ⇒ f' a | None ⇒ g) h"
  by (insert assms, simp add: option.case_eq_if)


subsection ‹Reasoning About Reads and Writes›

(* f is preserved between h and h' if it returns exactly the same results on both heaps. *)
definition preserved :: "('heap, 'e, 'result) prog ⇒ 'heap ⇒ 'heap ⇒ bool"
  where
    "preserved f h h' ⟷ (∀x. h ⊢ f →⇩r x ⟷ h' ⊢ f →⇩r x)"

(* Code equation: either f is ok on both heaps with equal selected results,
   or it is ok on neither. *)
lemma preserved_code [code]:
  "preserved f h h' = (((h ⊢ ok f) ∧ (h' ⊢ ok f) ∧ |h ⊢ f|⇩r = |h' ⊢ f|⇩r) ∨ ((¬h ⊢ ok f) ∧ (¬h' ⊢ ok f)))"
  apply(auto simp add: preserved_def)[1]
   apply (meson is_OK_returns_result_E is_OK_returns_result_I)+
  done

(* preserved f is reflexive. *)
lemma reflp_preserved_f [simp]: "reflp (preserved f)"
  by(auto simp add: preserved_def reflp_def)
(* preserved f is transitive. *)
lemma transp_preserved_f [simp]: "transp (preserved f)"
  by(auto simp add: preserved_def transp_def)


(* The set of all programs obtained by applying f to some argument. *)
definition 
  all_args :: "('a ⇒ ('heap, 'e, 'result) prog) ⇒ ('heap, 'e, 'result) prog set"
  where
    "all_args f = (⋃arg. {f arg})"


(* reads S getter h h': every predicate in S is reflexive and transitive, and if all
   predicates in S relate h and h', then getter is preserved between h and h'. *)
definition  
  reads :: "('heap ⇒ 'heap ⇒ bool) set ⇒ ('heap, 'e, 'result) prog ⇒ 'heap 
            ⇒ 'heap ⇒ bool"
  where
    "reads S getter h h' ⟷ (∀P ∈ S. reflp P ∧ transp P) ∧ ((∀P ∈ S. P h h') 
                              ⟶ preserved getter h h')"

(* Every program reads its own preservation predicate. *)
lemma reads_singleton [simp]: "reads {preserved f} f h h'"
  by(auto simp add: reads_def)

(* reads is compatible with bind when the first program is pure on both heaps. *)
lemma reads_bind_pure:
  assumes "pure f h" and "pure f h'"
    and "reads S f h h'"
    and "⋀x. h ⊢ f →⇩r x ⟹ reads S (g x) h h'"
  shows "reads S (f ⤜ g) h h'"
  using assms
  by(auto simp add: reads_def pure_pure preserved_def 
      intro!: bind_pure_returns_result_I is_OK_returns_result_I 
      dest: pure_returns_heap_eq 
      elim!: bind_returns_result_E)

(* Extend a singleton read set by a set of reflexive, transitive predicates. *)
lemma reads_insert_writes_set_left:
  "∀P ∈ S. reflp P ∧ transp P ⟹ reads {getter} f h h' ⟹ reads (insert getter S) f h h'"
  unfolding reads_def by simp

(* Extend a read set by one reflexive, transitive predicate. *)
lemma reads_insert_writes_set_right:
  "reflp getter ⟹ transp getter ⟹ reads S f h h' ⟹ reads (insert getter S) f h h'"
  unfolding reads_def by blast

(* reads is monotone in the read set, provided the added predicates are
   reflexive and transitive. *)
lemma reads_subset:
  "reads S f h h' ⟹ ∀P ∈ S' - S. reflp P ∧ transp P ⟹ S ⊆ S' ⟹ reads S' f h h'"
  by(auto simp add: reads_def)

(* return reads nothing. *)
lemma return_reads [simp]: "reads {} (return x) h h'"
  by(simp add: reads_def preserved_def)

(* error reads nothing. *)
lemma error_reads [simp]: "reads {} (error e) h h'"
  by(simp add: reads_def preserved_def)

(* noop reads nothing. *)
lemma noop_reads [simp]: "reads {} noop h h'"
  by(simp add: reads_def noop_def preserved_def)

(* filter_M reads S if the pure predicate program reads S on every element. *)
lemma filter_M_reads:
  assumes "⋀x. x ∈ set xs ⟹ pure (P x) h" and "⋀x. x ∈ set xs ⟹ pure (P x) h'"
    and "⋀x. x ∈ set xs ⟹ reads S (P x) h h'"
    and "∀P ∈ S. reflp P ∧ transp P"
  shows "reads S (filter_M P xs) h h'"
  using assms
  apply(induct xs)
  by(auto intro: reads_subset[OF return_reads] intro!: reads_bind_pure)

(* writes S setter h h': whenever setter takes h to h', some list of programs drawn
   from S, run in sequence by iterate_M, also takes h to h'. *)
definition writes :: 
  "('heap, 'e, 'result) prog set ⇒ ('heap, 'e, 'result2) prog ⇒ 'heap ⇒ 'heap ⇒ bool"
  where                                                                                
    "writes S setter h h' 
     ⟷ (h ⊢ setter →⇩h h' ⟶ (∃progs. set progs ⊆ S ∧ h ⊢ iterate_M progs →⇩h h'))"

(* An instance f a writes at most the set of all instances of f. *)
lemma writes_singleton [simp]: "writes (all_args f) (f a) h h'"
  apply(auto simp add: writes_def all_args_def)[1]
  apply(rule exI[where x="[f a]"])
  by(auto)

(* Every program writes at most itself. *)
lemma writes_singleton2 [simp]: "writes {f} f h h'"
  apply(auto simp add: writes_def all_args_def)[1]
  apply(rule exI[where x="[f]"])
  by(auto)

(* writes is preserved when the write set is enlarged on the right. *)
lemma writes_union_left_I:
  assumes "writes S f h h'"
  shows "writes (S ∪ S') f h h'"
  using assms
  by(auto simp add: writes_def)

(* writes is preserved when the write set is enlarged on the left. *)
lemma writes_union_right_I:
  assumes "writes S' f h h'"
  shows "writes (S ∪ S') f h h'"
  using assms
  by(auto simp add: writes_def)

(* Combine two write sets that both exclude S2. *)
lemma writes_union_minus_split:
  assumes "writes (S - S2) f h h'"
    and "writes (S' - S2) f h h'"
  shows "writes ((S ∪ S') - S2) f h h'"
  using assms
  by(auto simp add: writes_def)

(* writes is monotone in the write set. *)
lemma writes_subset: "writes S f h h' ⟹ S ⊆ S' ⟹ writes S' f h h'"
  by(auto simp add: writes_def)

(* error trivially satisfies writes, since it returns no heap. *)
lemma writes_error [simp]: "writes S (error e) h h'"
  by(simp add: writes_def)

(* A program that is not ok trivially satisfies writes. *)
lemma writes_not_ok [simp]: "¬h ⊢ ok f ⟹ writes S f h h'"
  by(auto simp add: writes_def)

(* A pure program satisfies writes for any write set (witness: the empty program list). *)
lemma writes_pure [simp]:
  assumes "pure f h"
  shows "writes S f h h'"
  using assms
  apply(auto simp add: writes_def)[1]
  by (metis bot.extremum iterate_M.simps(1) list.set(1) pure_returns_heap_eq return_returns_heap)

(* writes is compatible with bind. *)
lemma writes_bind:
  assumes "⋀h2. writes S f h h2" 
  assumes "⋀x h2. h ⊢ f →⇩r x ⟹ h ⊢ f →⇩h h2 ⟹ writes S (g x) h2 h'"
  shows "writes S (f ⤜ g) h h'"
  using assms
  apply(auto simp add: writes_def elim!: bind_returns_heap_E)[1]
  by (metis iterate_M_concat le_supI set_append)

(* Variant of writes_bind for a pure first program. *)
lemma writes_bind_pure:
  assumes "pure f h"
  assumes "⋀x. h ⊢ f →⇩r x ⟹ writes S (g x) h h'"
  shows "writes S (f ⤜ g) h h'"
  using assms
  by(auto simp add: writes_def elim!: bind_returns_heap_E2)

(* If every single program in the write set establishes the reflexive, transitive
   relation P between its pre- and post-heap, then so does the whole setter. *)
lemma writes_small_big:
  assumes "writes SW setter h h'"
  assumes "h ⊢ setter →⇩h h'"
  assumes "⋀h h' w. w ∈ SW ⟹  h ⊢ w →⇩h h' ⟹ P h h'"
  assumes "reflp P"
  assumes "transp P"
  shows "P h h'"
proof -
  (* Unfold writes to obtain the witnessing program list. *)
  obtain progs where "set progs ⊆ SW" and iterate: "h ⊢ iterate_M progs →⇩h h'"
    by (meson assms(1) assms(2) writes_def)
  then have "⋀h h'. ∀prog ∈ set progs. h ⊢ prog →⇩h h' ⟶ P h h'"
    using assms(3) by auto
  (* Induction over the list: reflexivity for Nil, transitivity for Cons. *)
  with iterate assms(4) assms(5) have "h ⊢ iterate_M progs →⇩h h' ⟹ P h h'"
  proof(induct progs arbitrary: h)
    case Nil
    then show ?case
      using reflpE by force
  next
    case (Cons a progs)
    then show ?case
      apply(auto elim!: bind_returns_heap_E)[1]
      by (metis (full_types) transpD)
  qed
  then show ?thesis
    using assms(1) iterate by blast
qed

(* If every program in the write set SW respects every predicate in the read set SR,
   then running setter does not change what getter returns. *)
lemma reads_writes_preserved:
  assumes "reads SR getter h h'"
  assumes "writes SW setter h h'"
  assumes "h ⊢ setter →⇩h h'"
  assumes "⋀h h'. ∀w ∈ SW. h ⊢ w →⇩h h' ⟶ (∀r ∈ SR. r h h')"
  shows "h ⊢ getter →⇩r x ⟷ h' ⊢ getter →⇩r x"
proof -
  obtain progs where "set progs ⊆ SW" and iterate: "h ⊢ iterate_M progs →⇩h h'"
    by (meson assms(2) assms(3) writes_def)
  then have "⋀h h'. ∀prog ∈ set progs. h ⊢ prog →⇩h h' ⟶ (∀r ∈ SR. r h h')"
    using assms(4) by blast
  (* Lift from single writers to the whole setter via writes_small_big. *)
  with iterate have "∀r ∈ SR. r h h'"
    using writes_small_big assms(1) unfolding reads_def
    by (metis assms(2) assms(3) assms(4))
  then show ?thesis
    using assms(1)
    by (simp add: preserved_def reads_def)
qed

(* Forward direction of reads_writes_preserved: a result of getter before the write
   is still returned after it. *)
lemma reads_writes_separate_forwards:
  assumes "reads SR getter h h'"
  assumes "writes SW setter h h'"
  assumes "h ⊢ setter →⇩h h'"
  assumes "h ⊢ getter →⇩r x"
  assumes "⋀h h'. ∀w ∈ SW. h ⊢ w →⇩h h' ⟶ (∀r ∈ SR. r h h')"
  shows "h' ⊢ getter →⇩r x"
  using reads_writes_preserved[OF assms(1) assms(2) assms(3) assms(5)] assms(4)
  by(auto simp add: preserved_def)

(* Backward direction of reads_writes_preserved: a result of getter after the write
   was already returned before it. *)
lemma reads_writes_separate_backwards:
  assumes "reads SR getter h h'"
  assumes "writes SW setter h h'"
  assumes "h ⊢ setter →⇩h h'"
  assumes "h' ⊢ getter →⇩r x"
  assumes "⋀h h'. ∀w ∈ SW. h ⊢ w →⇩h h' ⟶ (∀r ∈ SR. r h h')"
  shows "h ⊢ getter →⇩r x"
  using reads_writes_preserved[OF assms(1) assms(2) assms(3) assms(5)] assms(4)
  by(auto simp add: preserved_def)

end