(* Theory HeapLift *)

(*
 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
 * Copyright (c) 2022 Apple Inc. All rights reserved.
 *
 * SPDX-License-Identifier: BSD-2-Clause
 *)


theory HeapLift
  imports
  In_Out_Parameters
  Split_Heap
  AbstractArrays
begin

section "Refinement Lemmas"

(* Casting a word up to 'b and back down to 'a is the identity when 'b is at
   least as wide as 'a (the is_up condition required by ucast_up_ucast_id). *)
lemma ucast_ucast_id:
  "LENGTH('a)  LENGTH('b)  ucast (ucast (x::'a::len word)::'b::len word) = x"
  by (auto intro: ucast_up_ucast_id simp: is_up_def source_size_def target_size_def word_size)

(* Viewing an 'a word as an 'a signed word via unsigned forms a lense; the
   lense laws are discharged by simp with ucast_ucast_id. *)
lemma lense_ucast_signed:
  "lense (unsigned :: 'a::len word  'a signed word) (λv x. unsigned (v (unsigned x)))"
  by (rule lenseI_equiv) (simp_all add: ucast_ucast_id)

(* Transfer a pointer_lense (r, w) on 'a word to 'a signed word.
   Reads coerce the pointer to an 'a word ptr and UCAST the value read.
   Writes coerce the pointer and convert the update function m through UCAST
   in both directions. *)
lemma pointer_lense_ucast_signed:
  fixes r :: "'h  'a::len8 word ptr  'a word"
  assumes "pointer_lense r w"
  shows "pointer_lense
    (λh p. UCAST('a  'a signed) (r h (PTR_COERCE('a signed word  'a word) p)))
    (λp m. w (PTR_COERCE('a signed word  'a word) p)
      (λw. UCAST('a signed  'a) (m (UCAST('a  'a signed) w))))"
proof -
  interpret pointer_lense r w by fact
  (* For the rest of this proof: drop scast_ucast_norm from the simpset and
     add ucast_ucast_id. *)
  note scast_ucast_norm[simp del]
  note ucast_ucast_id[simp]
  show ?thesis
    (* unfold_locales yields the pointer_lense obligations; each apply step
       below closes one of them. *)
    apply unfold_locales
    apply (simp add: read_write_same)
    apply (simp add: write_same)
    apply (simp add: comp_def)
    apply (simp add: write_other_commute typ_uinfo_t_signed_word_word_conv
                flip: size_of_tag typ_uinfo_size)
    done
qed

(* to_bytes always yields exactly size_of TYPE('a) bytes, whatever byte list
   bs is supplied. *)
lemma (in xmem_type) length_to_bytes:
  "length (to_bytes (v::'a) bs) = size_of TYPE('a)"
  by (simp add: to_bytes_def lense.access_result_size)

(* Given padding bytes bs of the right length, heap_update_padding is the
   same as two steps: first write bs raw at ptr_val p, then do an ordinary
   heap_update of v at p. *)
lemma (in xmem_type) heap_update_padding_eq:
  "length bs = size_of TYPE('a) 
    heap_update_padding p v bs h = heap_update p v (heap_update_list (ptr_val p) bs h)"
  using u.max_size
  by (simp add: heap_update_padding_def heap_update_def size_of_def
      heap_list_heap_update_list_id heap_update_list_overwrite)

(* Function-level (composition) form of heap_update_padding_eq. *)
lemma (in xmem_type) heap_update_padding_eq':
  "length bs = size_of TYPE('a) 
    heap_update_padding p v bs = heap_update p v  heap_update_list (ptr_val p) bs"
  by (simp add: fun_eq_iff heap_update_padding_eq)

(* Split a predicate P applied to a disjunction into a negated case analysis
   on x. Proved by an elementary case split on the boolean x rather than by
   SMT proof reconstruction, which is fragile across Isabelle versions. *)
lemma split_disj_asm: "P (x  y) = (¬ (x  ¬ P x  ¬ x  ¬ P y))"
  by (cases x) simp_all

(* If x is obtained by folding f over xs, and every step f x commutes with a
   (under composition), then x commutes with a as well. *)
lemma comp_commute_of_fold:
  assumes x: "x = fold f xs"
  assumes xs: "list_all (λx. f x o a = a o f x) xs"
  shows "x o a = a o x"
  unfolding x using xs by (induction xs) (simp_all add: fun_eq_iff)

(* padding_closed_under_all_fields t: take any field of t, found by
   field_lookup with type s at offset n. Any two byte lists that are equal up
   to the padding of t are also equal up to the padding of s, once both are
   restricted to that field's bytes (take (size_td s) (drop n _)). *)
definition padding_closed_under_all_fields where
  "padding_closed_under_all_fields t 
    (s f n bs bs'. field_lookup t f 0 = Some (s, n) 
      eq_upto_padding t bs bs'  eq_upto_padding s (take (size_td s) (drop n bs)) (take (size_td s) (drop n bs')))"

(* The exported type info of every xmem_type is padding closed under all of
   its fields. *)
lemma padding_closed_under_all_fields_typ_uinfo_t:
  "padding_closed_under_all_fields (typ_uinfo_t TYPE('a::xmem_type))"
  unfolding padding_closed_under_all_fields_def
proof safe
  fix s f n bs bs' assume s_n: "field_lookup (typ_uinfo_t TYPE('a)) f 0 = Some (s, n)"
    and bs_bs': "eq_upto_padding (typ_uinfo_t TYPE('a)) bs bs'"
  (* eq_upto_padding implies that both byte lists have the size of 'a. *)
  then have len: "length bs = size_of TYPE('a)" "length bs' = size_of TYPE('a) "
    by (auto simp: eq_upto_padding_def size_of_def)

  (* Move from the exported (uinfo) field lookup to the typed one, so that
     the xmem_type padding fact below applies. *)
  from s_n[THEN field_lookup_uinfo_Some_rev] obtain k where
    k: "field_lookup (typ_info_t TYPE('a)) f 0 = Some (k, n)" and k_s: "export_uinfo k = s"
    by auto
  have [simp]: "size_td k = size_td s" by (simp flip: k_s)
  from xmem_type_field_lookup_eq_upto_padding_focus[OF k len bs_bs']
  show "eq_upto_padding s (take (size_td s) (drop n bs)) (take (size_td s) (drop n bs'))"
    unfolding k_s by simp
qed

(* Let a be an address valid for type t. Overwriting memory at a with byte
   lists bs or bs' that are equal up to the padding of t gives the same plift
   view at every xmem_type 'a.
   For each valid 'a pointer p, ptr_valid_u_cases_weak gives three cases:
   (1) the written range and p's footprint are disjoint;
   (2) p lies on an addressable field of t, at some offset from a;
   (3) a lies on an addressable field of 'a, at some offset from p. *)
lemma (in open_types) plift_heap_update_list_eq_upto_padding:
  assumes t: "mem_type_u t" and t': "padding_closed_under_all_fields t"
  assumes a: "ptr_valid_u t (hrs_htd h) a"
  assumes bs_bs': "eq_upto_padding t bs bs'"
  shows "plift (hrs_mem_update (heap_update_list a bs) h) =
    (plift (hrs_mem_update (heap_update_list a bs') h)::'a::xmem_type ptr  'a option)"
  apply (rule plift_eq_plift, simp_all add: h_val_def hrs_mem_update)
proof -
  from bs_bs' have [simp]: "length bs = size_td t" "length bs' = size_td t"
    by (simp_all add: eq_upto_padding_def)
  have [arith]: "size_td t < addr_card"
    using mem_type_u.max_size[OF t] by simp
  have a_bnd: "size_of TYPE('a)  addr_card"
    using max_size[where 'a='a] by arith
  let ?A = "typ_uinfo_t TYPE('a)"

  fix p :: "'a ptr" assume p: "ptr_valid (hrs_htd h) p"
  from ptr_valid_u_cases_weak[OF t a this[unfolded ptr_valid_def]]
  show "from_bytes (heap_list (heap_update_list a bs (hrs_mem h)) (size_of TYPE('a)) (ptr_val p)) =
    (from_bytes (heap_list (heap_update_list a bs' (hrs_mem h)) (size_of TYPE('a)) (ptr_val p))::'a)"
  proof (elim disjE exE conjE)
    (* Case (1): disjoint footprints. Reading at p does not see either
       write. *)
    assume "disjnt {a..+size_td t} {ptr_val p..+size_td (typ_uinfo_t TYPE('a))}"
    with bs_bs' show ?thesis
      unfolding heap_upd_list_def
      by (subst (1 2) heap_list_update_disjoint_same; simp add: size_of_def disjnt_def)
  next
    (* Case (2): p is field path of t, at offset ?n from a. Because t is
       padding closed, the field's slices of bs and bs' are equal up to the
       padding of 'a, so from_bytes agrees on them. *)
    fix path assume path: "addressable_field t path ?A" and
      p_eq: "ptr_val p = a + word_of_nat (field_offset_untyped t path)"
    let ?n = "field_offset_untyped t path"
    have sz: "size_of TYPE('a) + ?n  size_td t"
      using field_lookup_offset_size'[OF addressable_fieldD_field_lookup'[OF path]]
      by (simp add: size_of_def)
    let ?s = "typ_uinfo_t TYPE('a)"
    from addressable_fieldD_field_lookup'[OF path] t' bs_bs' have *:
      "eq_upto_padding ?s (take (size_td ?s) (drop ?n bs)) (take (size_td ?s) (drop ?n bs'))"
      unfolding padding_closed_under_all_fields_def
      by (auto simp flip: typ_uinfo_size)

    show ?thesis unfolding p_eq
      using eq_upto_padding_from_bytes_eq[OF *] sz
      apply (subst (1 2) heap_list_update_list)
      apply (simp_all add: size_of_def)
      done
  next
    (* Case (3): a is field path of 'a, at offset ?n from p. Reading the 'a
       value at p gives super_update_bs of the old bytes with bs or bs', and
       the two results are equal up to the padding of 'a. *)
    fix path assume path: "addressable_field ?A path t"
      and p_eq: "a = ptr_val p + word_of_nat (field_offset_untyped ?A path)"
    let ?n = "field_offset_untyped ?A path"
    have sz: "size_td t + ?n  size_of TYPE('a)"
      using field_lookup_offset_size'[OF addressable_fieldD_field_lookup'[OF path]]
      by (simp add: size_of_def)
    from field_lookup_uinfo_Some_rev[OF addressable_fieldD_field_lookup'[OF path]] obtain k
      where k: "field_lookup (typ_info_t TYPE('a)) path 0 = Some (k, ?n)"
        and eq_t: "export_uinfo k = t" by blast
    then have [simp]: "size_td k = size_td t"
      by (simp flip: eq_t)

    have *: "eq_upto_padding (typ_uinfo_t TYPE('a))
      (super_update_bs bs (heap_list (hrs_mem h) (size_of TYPE('a)) (ptr_val p)) ?n)
      (super_update_bs bs' (heap_list (hrs_mem h) (size_of TYPE('a)) (ptr_val p)) ?n)"
      by (subst xmem_type_field_lookup_eq_upto_padding_super_update_bs[OF k(1)])
         (simp_all add: eq_t bs_bs')

    note 1 = c_guard_no_wrap'[OF ptr_valid_c_guard, OF p]
    show ?thesis unfolding p_eq using sz
      apply (subst (1 2) heap_update_list_super_update_bs_heap_list[OF 1])
      apply (simp_all add: heap_list_heap_update_list_id[OF a_bnd])
      apply (intro eq_upto_padding_from_bytes_eq[OF *])
      done
  qed
qed

(* Consequence of plift_heap_update_list_eq_upto_padding for
   read_dedicated_heap, registered as a simp rule. *)
lemma (in open_types) read_dedicated_heap_heap_update_list_eq_upto_padding[simp]:
  assumes t: "mem_type_u t" and t': "padding_closed_under_all_fields t"
  assumes a: "ptr_valid_u t (hrs_htd h) a"
  assumes bs_bs': "eq_upto_padding t bs bs'"
  shows "read_dedicated_heap (hrs_mem_update (heap_update_list a bs) h) =
    (read_dedicated_heap (hrs_mem_update (heap_update_list a bs') h)::'a::xmem_type ptr  'a)  True"
  by (simp add: plift_heap_update_list_eq_upto_padding[OF assms] read_dedicated_heap_def fun_eq_iff)

(* L2Tcorres st A C: the abstract program A corresponds to the concrete
   program C under the state abstraction st. The result and exception
   translations are both the identity, and the precondition is trivial
   (λ_. True). *)
definition "L2Tcorres st A C = corresXF st (λr _. r) (λr _. r) (λ_. True) A C"

(* With the identity state abstraction, every program corresponds to
   itself. *)
lemma L2Tcorres_id:
  "L2Tcorres id C C"
  by (metis L2Tcorres_def corresXF_id)

(* An abstract L2_fail corresponds to an arbitrary concrete program X. *)
lemma L2Tcorres_fail:
  "L2Tcorres st L2_fail X"
  by (clarsimp simp: L2Tcorres_def L2_defs) (rule corresXF_fail)

(* For fixed st and C, the predicate λA. L2Tcorres st A C is admissible
   w.r.t. the Inf / (≥) ccpo. This reduces to the corresponding corresXF
   fact. *)
lemma admissible_nondet_ord_L2Tcorres [corres_admissible]:
  "ccpo.admissible Inf (≥) (λA. L2Tcorres st A C)"
  unfolding L2Tcorres_def by (rule admissible_nondet_ord_corresXF)

(* The top element on the abstract side corresponds to every concrete
   program C. The symbol is presumably ⊤, going by the lemma name; confirm. *)
lemma L2Tcorres_top [corres_top]: "L2Tcorres st  C"
  by (auto simp add: L2Tcorres_def corresXF_def)

(* Abstraction predicates for inner expressions, relative to the state
   abstraction st. In each, A is the abstract and C the concrete part.
   - abs_guard st A C: whenever the abstract guard A holds of st s, the
     concrete guard C holds of s.
   - abs_expr st P A C: where P (st s) holds, the concrete value C s equals
     the abstract value A (st s).
   - abs_modifies st P A C: where P (st s) holds, abstracting the concrete
     update C s gives the abstract update A (st s). *)
definition "abs_guard    st   A C  s. A (st s)  C s"
definition "abs_expr     st P A C  s. P (st s)  C s = A (st s)"
definition "abs_modifies st P A C  s. P (st s)  st (C s) = A (st s)"

(* Predicates to enable some transformations on the input expressions
   (namely, rewriting uses of field_lvalue) that are best done
   as a preprocessing stage (st = id).
   The corresTA rules should ensure that these are used to rewrite
   any inner expressions before handing off to the predicates above.
   In each, C is the original expression and A its rewritten form.
   struct_rewrite_modifies has the same definition as struct_rewrite_expr;
   it is a separate constant for state updates. *)
definition "struct_rewrite_guard      A C  s. A s  C s"
definition "struct_rewrite_expr     P A C  s. P s  C s = A s"
definition "struct_rewrite_modifies P A C  s. P s  C s = A s"


(* Standard heap abstraction rules. *)
named_theorems heap_abs
(* Rules that require first-order matching. *)
named_theorems heap_abs_fo

(* Further rule collections for heap abstraction. valid_same_typ_descs
   collects the valid_same_typ_desc facts of valid_only_typ_desc_dependent.
   The purposes of the other three are suggested only by their names. *)
named_theorems derived_heap_defs and
 valid_array_defs and
 heap_upd_cong and
 valid_same_typ_descs

(* Congruence helper: equal update functions give equal updated states. *)
lemma deepen_heap_upd_cong: "f = f'  upd f s = upd f' s"
  by simp

(* Congruence helper for updates that also take a pointer argument p. *)
lemma deepen_heap_map_cong: "f = f'  upd f p s = upd f' p s"
  by simp


(* fun_app2 is like fun_app, but it skips an abstraction.
 * We use this for terms like "λs a. Array.update a k (f s)".
 * f and g abstract f' and g'; the concrete application f' s a $ g' s a is
 * abstracted by f s a (g s a) under both preconditions.
 * fixme: ideally, the first order conversion code can skip abstractions. *)

lemma abs_expr_fun_app2 [heap_abs_fo]:
  " abs_expr st P f f';
     abs_expr st Q g g'  
   abs_expr st (λs. P s  Q s) (λs a. f s a (g s a)) (λs a. f' s a $ g' s a)"
  by (simp add: abs_expr_def)

(* First-order application: x abstracts x' and f abstracts f', so the
   concrete application λs. f' s $ x' s is abstracted by λs. f s (x s),
   under both preconditions. *)
lemma abs_expr_fun_app [heap_abs_fo]:
  " abs_expr st Y x x'; abs_expr st X f f'  
      abs_expr st (λs. X s  Y s) (λs. f s (x s)) (λs. f' s $ x' s)"
  by (clarsimp simp: abs_expr_def)

(* Pairs are abstracted componentwise, under both preconditions. *)
lemma abs_expr_Pair [heap_abs]: "
abs_expr st X f1 f1'  abs_expr st Y f2 f2' 
abs_expr st  (λs. X s  Y s)  (λs. (f1 s, f2 s)) (λs. (f1' s, f2' s))"
  by (clarsimp simp: abs_expr_def)

(* A state-independent value abstracts to itself, with no precondition. *)
lemma abs_expr_constant [heap_abs]:
  "abs_expr st (λ_. True) (λs. a) (λs. a)"
  by (clarsimp simp: abs_expr_def)

(* A concrete guard a that is abstracted by a' (under P) is implied by the
   abstract guard built from P and a'. *)
lemma abs_guard_expr [heap_abs]:
  "abs_expr st P a' a  abs_guard st (λs. P s  a' s) a"
  by (simp add: abs_expr_def abs_guard_def)

(* A state-independent guard abstracts to itself. *)
lemma abs_guard_constant [heap_abs]:
  "abs_guard st (λ_. P) (λ_. P)"
  by (clarsimp simp: abs_guard_def)

(* Guards combine pointwise: G abstracts G' and H abstracts H'. *)
lemma abs_guard_conj [heap_abs]:
  " abs_guard st G G'; abs_guard st H H' 
       abs_guard st (λs. G s  H s) (λs. G' s  H' s)"
  by (clarsimp simp: abs_guard_def)


(* A concrete L2_modify c is first rewritten to b (struct_rewrite_modifies,
   under P), then abstracted to a (abs_modifies, under Q). The abstract side
   guards the abstract modify with a guard combining P' (which abstracts P)
   and Q. *)
lemma L2Tcorres_modify [heap_abs]:
    " struct_rewrite_modifies P b c; abs_guard st P' P;
       abs_modifies st Q a b  
     L2Tcorres st (L2_seq (L2_guard (λs. P' s  Q s)) (λ_. (L2_modify a))) (L2_modify c)"
  apply (auto intro!: refines_bind_guard_right refines_modify  
      simp: corresXF_refines_conv 
      L2Tcorres_def L2_defs abs_modifies_def abs_guard_def struct_rewrite_modifies_def struct_rewrite_guard_def)
  done

(* Same shape as L2Tcorres_modify, for L2_gets: the concrete expression c
   is rewritten to b, then abstracted to a, behind a guard on P' and Q. *)
lemma L2Tcorres_gets [heap_abs]:
    " struct_rewrite_expr P b c; abs_guard st P' P;
       abs_expr st Q a b  
     L2Tcorres st (L2_seq (L2_guard (λs. P' s  Q s)) (λ_. L2_gets a n)) (L2_gets c n)"
  apply (auto intro!: refines_bind_guard_right refines_gets   
      simp: corresXF_refines_conv L2Tcorres_def L2_defs abs_expr_def abs_guard_def struct_rewrite_expr_def struct_rewrite_guard_def)
  done

(* L2_gets of a state-independent value needs no guard. *)
lemma L2Tcorres_gets_const [heap_abs]:
    "L2Tcorres st (L2_gets (λ_. a) n) (L2_gets (λ_. a) n)"
  by (simp add: corresXF_refines_conv refines_gets L2Tcorres_def L2_defs)

(* A concrete guard c, rewritten to b and then abstracted to a, gives the
   abstract guard a. *)
lemma L2Tcorres_guard [heap_abs]:
    " struct_rewrite_guard b c; abs_guard st a b  
     L2Tcorres st (L2_guard a) (L2_guard c)"
  by (simp add: corresXF_def L2Tcorres_def L2_defs abs_guard_def struct_rewrite_guard_def)

(* Abstract a concrete L2_while C B into an abstract L2_guarded_while C'' B'.
   - B' abstracts the body B.
   - C is rewritten to C' (under G) and then abstracted to C'' (under H).
   - The abstract loop checks a guard combining G' and H, both before the
     first iteration and after each execution of the body. *)
lemma L2Tcorres_while [heap_abs]:
  assumes body_corres [simplified THIN_def,rule_format]:
    "PROP THIN (x. L2Tcorres st (B' x) (B x))"
  and cond_rewrite [simplified THIN_def,rule_format]:
    "PROP THIN (r. struct_rewrite_expr (G r) (C' r) (C r))"
  and guard_abs[simplified THIN_def,rule_format]:
    "PROP THIN (r. abs_guard st (G' r) (G r))"
  and guard_impl_cond[simplified THIN_def,rule_format]:
    "PROP THIN (r. abs_expr st (H r) (C'' r) (C' r))"
  shows "L2Tcorres st (L2_guarded_while (λi s. G' i s  H i s) C'' B' i n) (L2_while C B i n)"
proof -

  (* Under the abstract guard, the abstract condition C'' agrees with the
     concrete condition C. *)
  have cond_match: "r s. G' r (st s)  H r (st s)  C'' r (st s) = C r s"
    using cond_rewrite guard_abs guard_impl_cond
    by (clarsimp simp: abs_expr_def abs_guard_def struct_rewrite_expr_def)

  (* The correspondence stated directly on guard/whileLoop. The final step
     transfers it to the L2 combinators by unfolding L2_defs. *)
  have "corresXF st (λr _. r) (λr _. r) (λ_. True)
           (do { _  guard (λs. G' i s  H i s);
                     whileLoop C''
                       (λi. do { r  B' i;
                                _  guard (λs. G' r s  H r s);
                                return r
                       }) i
           })
     (whileLoop C B i)"
    apply (rule corresXF_guard_imp)
     apply (rule corresXF_guarded_while [where P="λ_ _. True" and P'="λ_ _. True"])
         apply (clarsimp cong: corresXF_cong)
         apply (rule corresXF_guard_imp)
          apply (rule body_corres [unfolded L2Tcorres_def])
         apply simp
        apply (clarsimp simp: cond_match)
       apply clarsimp
       apply (simp add: runs_to_partial_def_old split: xval_splits)
      apply simp
     apply simp
    apply simp
    done

  thus ?thesis
    by (clarsimp simp: L2Tcorres_def L2_defs gets_return top_fun_def)
qed


named_theorems abs_spec

(* abs_spec st P A C relates an abstract relation A and a concrete relation C
   at states with P (st s). Every concrete transition (s, t) in C is matched
   by the abstract transition (st s, st t) in A. Also, if A has some
   transition from st s, then C has some transition from s. *)
definition "abs_spec st P (A :: ('a × 'a) set) (C :: ('c × 'c) set)
            (s t. P (st s)  (((s, t)  C)  ((st s, st t)  A)))
                               (s. P (st s)  (x. (st s, x)  A)  (x. (s, x)  C))"

(* Given abs_spec st P A C, a concrete L2_spec C is abstracted to
   L2_spec A, behind the guard P. *)
lemma L2Tcorres_spec [heap_abs]:
  " abs_spec st P A C 
      L2Tcorres st (L2_seq (L2_guard P) (λ_. (L2_spec A))) (L2_spec C)"
  unfolding corresXF_refines_conv L2Tcorres_def L2_defs
  apply (clarsimp simp add: abs_spec_def)
  apply (intro refines_bind_guard_right refines_bind_bind_exn_wp refines_state_select)
   apply (force intro!: refines_select simp add: abs_spec_def split: xval_splits)
  apply blast
  done


(* abs_assume st P A C: at states with P (st s), every concrete
   (result, state) pair (r, t) in C s is matched by the abstract pair
   (r, st t) in A (st s). *)
definition "abs_assume st P (A :: 'a  ('b × 'a) set) (C :: 'c  ('b × 'c) set)
   (s t r. P (st s)  (((r, t)  C s)  ((r, st t)  A (st s))))"

(* FIXME: replace refines_assume_result_and_state in spec monad *)
(* assume_result_and_state P refines assume_result_and_state Q under R when
   the sets P s and Q t are related by sim_set, with R lifted to Result
   values. *)
lemma refines_assume_result_and_state': 
  "refines (assume_result_and_state P) (assume_result_and_state Q) s t R"
  if "sim_set (λ(v, s) (w, t). R (Result v, s) (Result w, t))  (P s) (Q t)"
  using that
  by (force simp: refines_def_old  sim_set_def rel_set_def  case_prod_unfold)


(* Given abs_assume st P A C, a concrete L2_assume C is abstracted to
   L2_assume A, behind the guard P. *)
lemma L2Tcorres_assume [heap_abs]:
  " abs_assume st P A C 
      L2Tcorres st (L2_seq (L2_guard P) (λ_. (L2_assume A))) (L2_assume C)"
  unfolding corresXF_refines_conv L2Tcorres_def L2_defs
  apply (clarsimp simp add: abs_assume_def)
  apply (intro refines_bind_guard_right refines_bind_bind_exn_wp refines_assume_result_and_state')
  apply (auto simp add: sim_set_def)
  done

(* A state-independent relation {(a, b). C} abstracts to itself. *)
lemma abs_spec_constant [heap_abs]:
  "abs_spec st (λ_. True) {(a, b). C} {(a, b). C}"
  by (clarsimp simp: abs_spec_def)

(* Abstract an L2_condition.
   - Both branches are abstracted.
   - The concrete condition C is rewritten to C' (under P) and then
     abstracted to C'' (under Q).
   - The abstract condition is guarded by P' (which abstracts P) and Q. *)
lemma L2Tcorres_condition [heap_abs]:
  "PROP THIN (Trueprop (L2Tcorres st L L'));
    PROP THIN (Trueprop (L2Tcorres st R R'));
    PROP THIN (Trueprop (struct_rewrite_expr P C' C));
    PROP THIN (Trueprop (abs_guard st P' P));
    PROP THIN (Trueprop (abs_expr st Q C'' C')) 
   L2Tcorres st (L2_seq (L2_guard (λs. P' s  Q s)) (λ_. L2_condition C'' L R)) (L2_condition C L' R')"
  unfolding THIN_def L2_defs L2Tcorres_def corresXF_refines_conv
  apply clarsimp
  apply (intro refines_bind_guard_right refines_condition)
  apply (auto simp add: abs_expr_def abs_guard_def struct_rewrite_expr_def struct_rewrite_guard_def)
  done

(* Sequential composition. L' abstracts L, and for every result r, R' r
   abstracts R r. *)
lemma L2Tcorres_seq [heap_abs]:
  "PROP THIN (Trueprop (L2Tcorres st L' L)); PROP THIN (r. L2Tcorres st (R' r) (R r)) 
       L2Tcorres st (L2_seq L' R') (L2_seq L R)"
  unfolding THIN_def L2_defs L2Tcorres_def corresXF_refines_conv
  apply clarsimp
  apply (intro refines_bind_bind_exn_wp)
  subgoal for t
    (* Instantiate the first program's refinement at the concrete state t. *)
    apply (erule_tac x=t in allE)
    apply (rule refines_weaken)
     apply assumption
    apply (auto split: xval_splits)
    done
  done


(* Abstract L2_guarded c g to L2_guarded a f. The concrete guard c is
   rewritten to b and abstracted to a. The body g only needs to be
   abstracted in states where c holds. *)
lemma L2Tcorres_guarded_simple [heap_abs]:
  assumes b_c: "struct_rewrite_guard b c"
  assumes a_b: "abs_guard st a b"
  assumes f_g: "s s'. c s  s' = st s  L2Tcorres st f g"
  shows "L2Tcorres st (L2_guarded a f) (L2_guarded c g)"
  unfolding L2_guarded_def L2_defs L2Tcorres_def corresXF_refines_conv
  using b_c a_b f_g
  by (fastforce simp add: refines_def_old L2Tcorres_def corresXF_refines_conv reaches_bind succeeds_bind 
      struct_rewrite_guard_def abs_guard_def abs_expr_def split: xval_splits)

(* Exception handling. L abstracts L', and for every exception r the abstract
   handler R r abstracts the concrete handler R' r. *)
lemma L2Tcorres_catch [heap_abs]:
    "PROP THIN (Trueprop (L2Tcorres st L L'));
      PROP THIN (r. L2Tcorres st (R r) (R' r))
       L2Tcorres st (L2_catch L R) (L2_catch L' R')"
  unfolding THIN_def
  apply (clarsimp simp: L2Tcorres_def L2_defs)
  apply (rule corresXF_guard_imp)
   (* Relate the abstract and concrete exceptions by equality. *)
   apply (erule corresXF_except [where P'="λx y s. x = y" and Q="λ_. True"])
     apply (simp add: corresXF_refines_conv)
    apply (simp add: runs_to_partial_def_old split: xval_splits)
   apply simp
  apply simp
  done

(* return e corresponds to itself under any state abstraction st. *)
lemma corresXF_return_same:
  "corresXF st (λr _. r) (λr _. r) (λ_. True) (return e) (return e)"
  by (clarsimp simp add: corresXF_def)

(* yield e corresponds to itself under any state abstraction st. *)
lemma corresXF_yield_same:
  "corresXF st (λr _. r) (λr _. r) (λ_. True) (yield e) (yield e)"
 by (auto simp add: corresXF_refines_conv intro!: refines_yield split: xval_splits)

(* L2_try L is L2_catch L with a handler that re-yields the caught exception,
   converted by to_xval. *)
lemma L2_try_catch: "L2_try L = L2_catch L (λe. yield (to_xval e))"
  unfolding L2_defs
  apply (rule spec_monad_eqI)
  apply (clarsimp simp add: runs_to_iff)
  apply (auto simp add: runs_to_def_old unnest_exn_def to_xval_def split: xval_splits sum.splits )
  done

(* L2_try preserves correspondence. Proof: reduce to L2Tcorres_catch via
   L2_try_catch; the handler is yield on both sides. *)
lemma L2Tcorres_try [heap_abs]:
    " L2Tcorres st L L'  L2Tcorres st (L2_try L) (L2_try L')"
  apply (simp add: L2_try_catch)
  apply (erule L2Tcorres_catch [simplified THIN_def])
  apply (unfold L2Tcorres_def top_fun_def top_bool_def)
  apply (rule corresXF_yield_same)
  done

(* L2_unknown ns corresponds to itself. *)
lemma L2Tcorres_unknown [heap_abs]:
  "L2Tcorres st (L2_unknown ns) (L2_unknown ns)"
  apply (clarsimp simp: L2_unknown_def)
  apply (clarsimp simp: L2Tcorres_def)
  apply (auto intro!: corresXF_select_select)
  done

(* Throwing the same value x corresponds on both sides. *)
lemma L2Tcorres_throw [heap_abs]:
  "L2Tcorres st (L2_throw x n) (L2_throw x n)"
  apply (clarsimp simp: L2Tcorres_def L2_defs)
  apply (rule corresXF_throw)
  apply simp
  done

(* Pair case distinctions. If every component instance corresponds, then the
   case expression on a corresponds. *)
lemma L2Tcorres_split [heap_abs]:
  " x y. L2Tcorres st (P x y) (P' x y)  
    L2Tcorres st (case a of (x, y)  P x y) (case a of (x, y)  P' x y)"
  by (clarsimp simp: split_def)

(* Special case of L2Tcorres_seq where the continuation ignores the result
   of the first program. *)
lemma L2Tcorres_seq_unused_result [heap_abs]:
  "PROP THIN (Trueprop (L2Tcorres st L L')); PROP THIN (Trueprop (L2Tcorres st R R')) 
   L2Tcorres st (L2_seq L (λ_. R)) (L2_seq L' (λ_. R'))"
  by (rule L2Tcorres_seq, auto)

(* abs_expr commutes with pair case distinctions on r. *)
lemma abs_expr_split [heap_abs]:
  " a b. abs_expr st (P a b) (A a b) (C a b) 
        abs_expr st (case r of (a, b)  P a b)
            (case r of (a, b)  A a b) (case r of (a, b)  C a b)"
  by (auto simp: split_def)

(* abs_guard commutes with pair case distinctions on r. *)
lemma abs_guard_split [heap_abs]:
  " a b. abs_guard st (A a b) (C a b) 
        abs_guard st (case r of (a, b)  A a b) (case r of (a, b)  C a b)"
  by (auto simp: split_def)

(* L2_fail corresponds to L2_fail. This is an instance of L2Tcorres_fail,
   registered as a heap_abs rule. *)
lemma L2Tcorres_abstract_fail [heap_abs]:
  "L2Tcorres st L2_fail L2_fail"
  by (clarsimp simp: L2Tcorres_def L2_defs) (rule corresXF_fail)

(* With the identity abstraction, every expression abstracts to itself. *)
lemma abs_expr_id [heap_abs]:
  "abs_expr id (λ_. True) A A"
  by (clarsimp simp: abs_expr_def)

(* Adding an ignored extra argument r on both sides preserves abs_expr. *)
lemma abs_expr_lambda_null [heap_abs]:
  "abs_expr st P A C  abs_expr st P (λs r. A s) (λs r. C s)"
  by (clarsimp simp: abs_expr_def)

(* With the identity abstraction, every state update abstracts to itself. *)
lemma abs_modify_id [heap_abs]:
  "abs_modifies id (λ_. True) A A"
  by (clarsimp simp: abs_modifies_def)

(* Lift a correspondence under the identity abstraction to abstraction st,
   by running the abstract program through exec_concrete st. *)
lemma corresXF_exec_concrete [intro?]:
  "corresXF id ret_xf ex_xf P A C  corresXF st ret_xf ex_xf P (exec_concrete st A) C"
  by (force simp add: corresXF_refines_conv refines_def_old reaches_exec_concrete succeeds_exec_concrete_iff split: xval_splits)

(* A call whose callee corresponds only under the identity abstraction is
   abstracted by wrapping the abstract call in exec_concrete st. *)
lemma L2Tcorres_exec_concrete [heap_abs]:
  "L2Tcorres id A C  L2Tcorres st (exec_concrete st (L2_call A emb ns)) (L2_call C emb ns)"
  apply (clarsimp simp: L2Tcorres_def L2_call_def map_exn_catch_conv)
  apply (rule corresXF_exec_concrete)
  (* Exceptions are related by equality. *)
  apply (rule CorresXF.corresXF_except  [ where P' = "λx y s. x = y"])
     apply assumption
  subgoal 
    by (auto simp add: corresXF_refines_conv)
  subgoal
    by (auto simp add: runs_to_partial_def_old split: xval_splits)
  subgoal by simp
  done

(* Same as L2Tcorres_exec_concrete, for L2_call_L1 calls (with argument,
   global-state and return translations arg_xf, gs, ret_xf). The proof
   steps through the monadic structure of L2_call_L1 (get state, select,
   run the callee, set state), refining each step in turn. *)
lemma L2Tcorres_exec_concrete_simpl [heap_abs]:
  "L2Tcorres id A C  L2Tcorres st (exec_concrete st (L2_call_L1 arg_xf gs ret_xf A)) (L2_call_L1 arg_xf gs ret_xf C)"
  apply (clarsimp simp: L2Tcorres_def L2_call_L1_def)
  apply (rule corresXF_exec_concrete)
  apply (clarsimp simp add: corresXF_refines_conv)
  apply (rule refines_bind_bind_exn_wp)
  apply (clarsimp split: xval_splits)
  apply (rule refines_get_state)
  apply (clarsimp split: xval_splits)
  apply (rule refines_bind_bind_exn_wp)
  apply (clarsimp split: xval_splits)
  apply (rule refines_select)
  apply (clarsimp split: xval_splits)
  subgoal for x
    (* Choose the same selected value x on the abstract side. *)
    apply (rule exI[where x=x])
    apply (erule_tac x=x in allE)
    apply (clarsimp)
    apply (rule refines_run_bind)
    apply (clarsimp split: exception_or_result_splits)
    apply (erule refines_weaken)
    apply (clarsimp split: xval_splits)
    apply (rule refines_bind_bind_exn_wp)
    apply (clarsimp split: xval_splits)
    apply (rule refines_set_state)
    apply (clarsimp split: xval_splits)
    done
  done

(* Dual of corresXF_exec_concrete. A correspondence under abstraction st
   becomes one under the identity abstraction once the abstract program is
   wrapped in exec_abstract st. *)
lemma corresXF_exec_abstract [intro?]:
  "corresXF st ret_xf ex_xf P A C  corresXF id ret_xf ex_xf P (exec_abstract st A) C"
  by (force simp: corresXF_refines_conv refines_def_old reaches_exec_abstract split: xval_splits)

(* Call-level version of corresXF_exec_abstract. *)
lemma L2Tcorres_exec_abstract [heap_abs]:
    "L2Tcorres st A C  L2Tcorres id (exec_abstract st (L2_call A emb ns)) (L2_call C emb ns)"
  apply (clarsimp simp: L2_call_def map_exn_catch_conv L2Tcorres_def)
  apply (rule corresXF_exec_abstract)
  apply (rule CorresXF.corresXF_except  [ where P' = "λx y s. x = y"])
     apply assumption
    subgoal by (auto simp add: corresXF_refines_conv)
    subgoal by (auto simp add: runs_to_partial_def_old split: xval_splits)
    subgoal by simp  
    done

(* L2_call preserves correspondence under the same abstraction st. *)
lemma L2Tcorres_call [heap_abs]:
  "L2Tcorres st A C  L2Tcorres st (L2_call A emb ns) (L2_call C emb ns)"
  unfolding L2Tcorres_def L2_call_def map_exn_catch_conv
  apply (rule CorresXF.corresXF_except  [ where P' = "λx y s. x = y"])
    apply assumption
    subgoal by (auto simp add: corresXF_refines_conv)
    subgoal by (auto simp add: runs_to_partial_def_old split: xval_splits)
    subgoal by simp  
    done


(* Collections filled by the simulation locales below. For example,
   valid_implies_cguard, read_simulation, write_simulation and
   typ_heap_simulation tag their facts with these names. No fact in the
   visible part of this theory is tagged field_write_commutes. *)
named_theorems
valid_implies_c_guard and
read_commutes and
write_commutes  and
field_write_commutes and
write_valid_preservation and
lift_heap_update_padding_heap_update_conv

(*
 * Assert the given abstracted heap (accessed using "getter" and "setter") for type
 * "'a" is a valid abstraction w.r.t. the given state translation functions.
 * The requirement is split into separate locales below (valid_implies_cguard,
 * read_simulation, write_simulation, write_preserves_valid,
 * valid_only_typ_desc_dependent); typ_heap_simulation combines them.
 *)



(* Validity v of a pointer in the abstracted state st s implies c_guard on
   that pointer. *)
locale valid_implies_cguard =
  fixes st::"'s  't"
  fixes v::"'t  'a::c_type ptr  bool"
  assumes valid_implies_c_guard[valid_implies_c_guard]: "v (st s) p  c_guard p"

(* At pointers valid in the abstract state, the abstract read r agrees with
   h_val on the concrete raw heap. *)
locale read_simulation =
  fixes st ::"'s  't"
  fixes v ::"'t  'a::c_type ptr  bool"
  fixes r :: "'t  'a ptr  'a"
  fixes t_hrs::"'s  heap_raw_state"
  assumes read_commutes[read_commutes]: "v (st s) p 
              r (st s) p = h_val (hrs_mem (t_hrs s)) p"

(* At pointers valid in the abstract state, a concrete heap_update_padding of
   x (with any padding bytes bs of the right length) corresponds to the
   abstract write w p (λ_. x). The lemmas in the locale body derive the same
   for plain heap_update. *)
locale write_simulation =
  heap_raw_state t_hrs t_hrs_upd
  for
    t_hrs :: "('s  heap_raw_state)" and
    t_hrs_upd::"(heap_raw_state  heap_raw_state)  's  's" +
  fixes st ::"'s  't"
  fixes v ::"'t  'a::mem_type ptr  bool"
  fixes w :: "'a ptr  ('a  'a)   't  't"

  assumes write_padding_commutes[write_commutes]: "v (st s) p  length bs = size_of TYPE('a) 
           st (t_hrs_upd (hrs_mem_update (heap_update_padding p x bs)) s) =
                           w p (λ_. x)  (st s)"

begin
(* heap_update p x is heap_update_padding whose padding is the bytes already
   in memory at p, so write_padding_commutes applies. *)
lemma write_commutes[write_commutes]:
  assumes valid: "v (st s) p"
  shows "st (t_hrs_upd (hrs_mem_update (heap_update p x)) s) =
                           w p (λ_. x) (st s)"
proof -
  have eq: "hrs_mem_update (heap_update p x) =
        hrs_mem_update (λh. heap_update_padding p x (heap_list h (size_of TYPE('a)) (ptr_val p)) h)"
    using heap_update_heap_update_padding_conv
    by metis

  show ?thesis
    apply (simp only: eq)
    apply (subst write_padding_commutes [symmetric,  where bs="heap_list (hrs_mem (t_hrs s)) (size_of TYPE('a)) (ptr_val p)"])
      apply (rule valid)
     apply clarsimp
    by (metis (no_types, lifting) heap.upd_cong)
qed

(* At valid pointers, heap_update_padding and heap_update have the same
   abstract effect. *)
lemma lift_heap_update_padding_heap_update_conv[lift_heap_update_padding_heap_update_conv]:
  "v (st s) p  length bs = size_of TYPE('a) 
           st (t_hrs_upd (hrs_mem_update (heap_update_padding p x bs)) s) =
           st (t_hrs_upd (hrs_mem_update (heap_update p x)) s)"
  using write_padding_commutes write_commutes by auto

(* write_commutes with s, p and x quantified inside the statement. *)
lemma write_commutes_atomic: "s p x. v (st s) p 
   st (t_hrs_upd (hrs_mem_update (heap_update p x)) s) =
                           w p (λ_. x) (st s)"
  using  write_commutes by blast

end


(* Abstract writes w (at any pointer p', with any update f) do not change the
   validity v of any pointer p. *)
locale write_preserves_valid =
  fixes v ::"'t  'a::c_type ptr  bool"
  fixes w :: "'a ptr  ('b  'b)  't  't"
  assumes valid_preserved: "v (w p' f s) p = v s p"
begin
(* Pointer-free (function-level) form of valid_preserved, as a simp rule. *)
lemma valid_preserved_pointless[simp]: "v (w p' f s)  = v s"
  by (rule ext) (rule valid_preserved)
end


(* Abstract validity v depends only on the heap type description hrs_htd of
   the concrete state. *)
locale valid_only_typ_desc_dependent =
  fixes t_hrs :: "('s  heap_raw_state)"
  fixes st ::"'s  't"
  fixes v ::"'t  'a::c_type ptr  bool"
  assumes valid_same_typ_desc [valid_same_typ_descs]: "hrs_htd (t_hrs s) = hrs_htd (t_hrs t)  v (st s) p = v (st t) p"

(* Simulation between the concrete raw heap state (t_hrs) and an abstract
   state with its own heap typing, via st:
   - heap_typing_commutes: the abstract heap_typing equals the concrete
     hrs_htd.
   - lift_heap_update_list_stack_byte_independent: writing raw bytes bs at p
     does not change st, provided every p +p i (i < length bs) is a
     root_ptr_valid stack_byte pointer.
   - st_eq_upto_padding: writing byte lists that are equal up to padding, at
     an address valid for a padding-closed mem_type_u t, gives the same
     st. *)
locale heap_typing_simulation =
  open_types 𝒯 + t_hrs: heap_raw_state t_hrs t_hrs_upd + heap_typing_state heap_typing heap_typing_upd
  for
    𝒯 and
    t_hrs :: "('s  heap_raw_state)" and
    t_hrs_upd :: "(heap_raw_state  heap_raw_state)  ('s  's)" and
    heap_typing :: "'t  heap_typ_desc" and
    heap_typing_upd :: "(heap_typ_desc  heap_typ_desc)  't  't" +
  fixes st ::"'s  't"
  assumes heap_typing_commutes[simp]: "heap_typing (st s) = hrs_htd (t_hrs s)"
  assumes lift_heap_update_list_stack_byte_independent:
    "(i. i < length bs  root_ptr_valid (hrs_htd (t_hrs s)) ((p::stack_byte ptr) +p int i)) 
     st (t_hrs_upd (hrs_mem_update (heap_update_list (ptr_val p) bs)) s) = st s"
  assumes st_eq_upto_padding:
    "mem_type_u t  padding_closed_under_all_fields t 
      ptr_valid_u t (hrs_htd (t_hrs s)) a  eq_upto_padding t bs bs' 
      st (t_hrs_upd (hrs_mem_update (heap_update_list a bs)) s) =
      st (t_hrs_upd (hrs_mem_update (heap_update_list a bs')) s)"
begin

(* Reading the abstract heap typing after heap_typing_upd f matches the
   concrete hrs_htd after hrs_htd_update f. *)
lemma heap_typing_upd_commutes: "heap_typing (heap_typing_upd f (st s)) = hrs_htd (t_hrs (t_hrs_upd (hrs_htd_update f) s))"
  apply (simp add: hrs_htd_update)
  done

(* Build a write_simulation from two facts: a validity v that implies
   ptr_valid, and a simulation of plain heap_update. st_eq_upto_padding
   shows that the padding bytes chosen by heap_update_padding do not
   matter. *)
lemma write_simulation_alt:
  assumes v: "s p. v (st s) p  ptr_valid (hrs_htd (t_hrs s)) p"
  assumes *: "s (p::'a::xmem_type ptr) x. v (st s) p 
    st (t_hrs_upd (hrs_mem_update (heap_update p x)) s) = w p (λ_. x)  (st s)"
  shows "write_simulation t_hrs t_hrs_upd st v w"
proof
  fix s p x and bs :: "byte list" assume p: "v (st s) p" and bs: "length bs = size_of TYPE('a)"

  (* heap_update expressed as a raw heap_update_list of to_bytes. *)
  have [simp]: "t_hrs_upd (hrs_mem_update (heap_update p x)) s =
    t_hrs_upd (hrs_mem_update (heap_update_list (ptr_val p)
      (to_bytes x (heap_list (hrs_mem (t_hrs s)) (size_of TYPE('a)) (ptr_val p))))) s"
    by (rule t_hrs.heap.upd_cong) (simp add: heap_update_def)

  show "st (t_hrs_upd (hrs_mem_update (heap_update_padding p x bs)) s) = w p (λ_. x) (st s)"
    apply (subst *[OF p, symmetric])
    apply (simp add: heap_update_padding_def[abs_def])
    apply (rule st_eq_upto_padding[where t="typ_uinfo_t TYPE('a)"])
    apply (rule typ_uinfo_t_mem_type)
    apply (rule padding_closed_under_all_fields_typ_uinfo_t)
    apply (subst ptr_valid_def[symmetric])
    apply (simp add: v p)
    unfolding to_bytes_def typ_uinfo_t_def
    apply (rule field_lookup_access_ti_eq_upto_padding[where f="[]" and 'b='a])
    apply (simp_all add: bs size_of_def)
    done
qed

end

(* Combined simulation for one xmem_type 'a. It requires:
   - reads and writes simulate the concrete heap;
   - abstract writes preserve validity;
   - validity implies c_guard and depends only on hrs_htd;
   - (r, w) forms a pointer_lense. *)
locale typ_heap_simulation =
  heap_raw_state t_hrs t_hrs_update +
  read_simulation st v r t_hrs +
  write_simulation t_hrs t_hrs_update st v w  +
  write_preserves_valid v w +
  valid_implies_cguard st v +
  valid_only_typ_desc_dependent t_hrs st v +
  pointer_lense r w
  for
    st:: "'s  't" and
    r:: "'t  ('a::xmem_type) ptr  'a" and
    w:: "'a ptr  ('a  'a)  't  't" and
    v:: "'t  ('a::xmem_type) ptr  bool" and
    t_hrs :: "'s  heap_raw_state" and
    t_hrs_update:: "(heap_raw_state  heap_raw_state)  's  's"
begin

(* A concrete heap_update leaves hrs_htd unchanged, so the abstract validity
   of every pointer is preserved. *)
lemma write_valid_preservation [write_valid_preservation]:
  shows "v (st (t_hrs_update (hrs_mem_update (heap_update q x)) s)) p = v (st s) p"
  by (metis hrs_htd_mem_update get_upd valid_same_typ_desc)

(* The same for heap_update_padding. *)
lemma write_padding_valid_preservation [write_valid_preservation]:
  shows "v (st (t_hrs_update (hrs_mem_update (heap_update_padding q x bs)) s)) p = v (st s) p"
  by (metis hrs_htd_mem_update get_upd valid_same_typ_desc)

end



(* Extends heap_typing_simulation with stack allocation and release for type
   'a in the stack region 𝒮:
   - sim_stack_alloc: concretely, allocate n objects (stack_allocs, new heap
     typing d) and initialise them with heap_update_padding. Abstractly,
     this is setting the heap typing to d and writing each vs ! i via w.
   - sim_stack_release: concretely, overwrite n root_ptr_valid objects with
     raw bytes bs and apply stack_releases. Abstractly, this is writing
     c_type_class.zero via w and applying stack_releases to the heap typing.
   - stack_byte_zero: in the current state s, the abstract read r at each
     pointer stack_allocs would hand out gives ZERO('a). *)
locale stack_simulation =
  heap_typing_simulation 𝒯 t_hrs t_hrs_update heap_typing heap_typing_upd st +
  typ_heap_typing r w heap_typing heap_typing_upd 𝒮
  for
    𝒯 and
    st:: "'s  't" and
    r:: "'t  ('a::xmem_type) ptr  'a" and
    w:: "'a ptr  ('a  'a)  't  't" and
    t_hrs :: "'s  heap_raw_state" and
    t_hrs_update:: "(heap_raw_state  heap_raw_state)  's  's" and
    heap_typing :: "'t  heap_typ_desc" and
    heap_typing_upd :: "(heap_typ_desc  heap_typ_desc)  't  't" and
    𝒮:: "addr set" +
assumes sim_stack_alloc:
  "p d vs bs s n.
    (p, d)  stack_allocs n 𝒮 TYPE('a) (hrs_htd (t_hrs s))  length vs = n  length bs = n * size_of TYPE ('a) 
      st (t_hrs_update (hrs_mem_update (fold (λi. heap_update_padding (p +p int i) (vs!i) (take (size_of TYPE('a)) (drop (i * size_of TYPE('a)) bs))) [0..<n])  hrs_htd_update (λ_. d)) s) =
      (fold (λi. w (p +p int i) (λ_. (vs ! i))) [0..<n]) (heap_typing_upd (λ_. d) (st s))"

assumes sim_stack_release: "p s n. (i. i < n  root_ptr_valid (hrs_htd (t_hrs s)) (p +p int i)) 
  length bs = n * size_of TYPE('a) 
 st (t_hrs_update (hrs_mem_update (heap_update_list (ptr_val p) bs)  hrs_htd_update (stack_releases n p)) s) =
         ((heap_typing_upd (stack_releases n p) (fold (λi. w (p +p int i) (λ_. c_type_class.zero)) [0..<n] (st s))))"

assumes stack_byte_zero: "p d i. (p, d)  stack_allocs n 𝒮 TYPE('a) (hrs_htd (t_hrs s))  i < n  r (st s) (p +p int i) = ZERO('a)"


(* A concrete IO_modify_heap_paddingE at pointer p, with concrete expression
   c abstracted by a, corresponds to the abstract write w p (λ_. a s). The
   write is guarded by v t p and P t. *)
lemma (in typ_heap_simulation) L2Tcorres_IO_modify_paddingE [heap_abs]:
  assumes "abs_expr st P a c"
  shows "L2Tcorres st (L2_seq (L2_guard (λt. v t p  P t)) (λ_. (L2_modify (λs. w p (λ_. a s) s)))) 
    (IO_modify_heap_paddingE (p::'a ptr) c)"
  using assms
  using length_to_bytes write_padding_commutes unfolding liftE_IO_modify_heap_padding
  by (auto simp add: abs_expr_def L2Tcorres_def corresXF_refines_conv L2_defs 
      IO_modify_heap_padding_def refines_def_old reaches_bind succeeds_bind split: xval_splits)



(* Combination of typ_heap_simulation and stack_simulation, sharing st, r, w,
   t_hrs and t_hrs_update. *)
locale typ_heap_typing_stack_simulation =
  typ_heap_simulation st r w v t_hrs t_hrs_update +
  stack_simulation 𝒯 st r w t_hrs t_hrs_update heap_typing heap_typing_upd 𝒮
  for
    𝒯 and
    st:: "'s  't" and
    r:: "'t  ('a::xmem_type) ptr  'a" and
    w:: "'a ptr  ('a  'a)  't  't" and
    v:: "'t  ('a::xmem_type) ptr  bool" and
    t_hrs :: "'s  heap_raw_state" and
    t_hrs_update:: "(heap_raw_state  heap_raw_state)  's  's" and
    heap_typing :: "'t  heap_typ_desc" and
    heap_typing_upd :: "(heap_typ_desc  heap_typ_desc)  't  't" and
    𝒮:: "addr set"
begin

(* The concrete state, with stack region 𝒮, is a stack_heap_raw_state
   (interpretation prefix: monolithic). *)
sublocale monolithic: stack_heap_raw_state t_hrs t_hrs_update 𝒮
  by (unfold_locales)

(* State relation: the abstract state sa is exactly st of the concrete state
   sc. *)
definition "rel_split_heap  λsc sa. sa = st sc"

(* States related by rel_split_heap have the same stack_free, computed from
   the concrete hrs_htd and the abstract heap_typing respectively. *)
lemma rel_split_heap_stack_free_eq:
  "rel_split_heap sc sa   stack_free (hrs_htd (t_hrs sc)) = stack_free (heap_typing sa)"
  by (simp only: rel_split_heap_def heap_typing_commutes)

definition rel_stack_free_eq where
  "rel_stack_free_eq sc sa   stack_free (hrs_htd (t_hrs sc)) = stack_free (heap_typing sa)"

lemma rel_prod_rel_split_heap_conv:
  "rel_prod (=) rel_split_heap = (λ(v, t) (r, s).
             s = st t  (case v of Exn x  r = Exn x | Result x  r = Result x)) "
  by (auto simp add: rel_split_heap_def rel_prod_conv fun_eq_iff split: prod.splits xval_splits)

lemma L2Tcorres_refines:
  "L2Tcorres st fa fc  refines fc fa s (st s) (rel_prod (=) rel_split_heap)"
  by (simp add: L2Tcorres_def corresXF_refines_conv rel_prod_rel_split_heap_conv)

lemma refines_L2Tcorres:
  assumes f: "s. refines fc fa s (st s) (rel_prod (=) rel_split_heap)"
  shows "L2Tcorres st fa fc"
  using f
  by (simp add: L2Tcorres_def corresXF_refines_conv rel_prod_rel_split_heap_conv)

lemma L2Tcorres_refines_conv:
"L2Tcorres st fa fc  (s. refines fc fa s (st s) (rel_prod (=) rel_split_heap))"
  by (auto simp add: L2Tcorres_refines refines_L2Tcorres)

lemma sim_guard_with_fresh_stack_ptr:
  fixes fc:: "'a ptr  ('b, 'c, 's) exn_monad"
  assumes init: "inita (st s) = initc s"
  assumes f: "s p::'a ptr. refines (fc p) (fa p) s (st s) (rel_prod (=) rel_split_heap)"
  shows "refines
           (monolithic.with_fresh_stack_ptr n initc fc)
           (guard_with_fresh_stack_ptr n inita fa) s (st s)
           (rel_prod (=) rel_split_heap)"
  unfolding monolithic.with_fresh_stack_ptr_def guard_with_fresh_stack_ptr_def
                   stack_ptr_acquire_def stack_ptr_release_def assume_stack_alloc_def
  apply (rule refines_bind_bind_exn [where Q= "(rel_prod (=) rel_split_heap)"])
  subgoal
    apply (rule refines_assume_result_and_state')
    using sim_stack_alloc stack_byte_zero 
    by (fastforce simp add: sim_set_def rel_split_heap_def init split: xval_splits)
  apply simp
  apply simp
  apply simp
  apply (rule refines_rel_prod_guard_on_exit [where S'="rel_split_heap"])
   apply (subst (asm) rel_split_heap_def )
   apply simp
   apply (rule f)
  subgoal by (auto simp add: rel_split_heap_def sim_stack_release)
  subgoal by (simp add: Ex_list_of_length)
  done

lemma sim_with_fresh_stack_ptr:
  fixes fc:: "'a ptr  ('b, 'c, 's) exn_monad"
  assumes init: "inita (st s) = initc s"
  assumes f: "s p::'a ptr. refines (fc p) (fa p) s (st s) (rel_prod (=) rel_split_heap)"
  assumes typing_unchanged: "s p::'a ptr. (fc p)  s ?⦃λr t. typing.unchanged_typing_on 𝒮 s t"
  shows "refines
           (monolithic.with_fresh_stack_ptr n initc fc)
           (with_fresh_stack_ptr n inita fa) s (st s)
           (rel_prod (=) rel_split_heap)"
  apply (simp add: monolithic.with_fresh_stack_ptr_def with_fresh_stack_ptr_def
                   stack_ptr_acquire_def stack_ptr_release_def assume_stack_alloc_def)
  apply (rule refines_bind_bind_exn [where Q= "λ(r,t) (r',t').
           (rel_prod (=) rel_split_heap) (r,t) (r',t') 
          (p. r = Result p  (i < n. ptr_span (p +p int i)  𝒮  root_ptr_valid (heap_typing t') (p +p int i)))"], simp_all)

  subgoal
    apply (rule refines_assume_result_and_state')
    using sim_stack_alloc stack_byte_zero stack_allocs_𝒮 
    apply (clarsimp simp add: sim_set_def init rel_split_heap_def, safe)
      apply blast+
    by (smt (verit) hrs_htd_update stack_allocs_cases)
 
  subgoal for p t t'
    apply (rule refines_runs_to_partial_rel_prod_on_exit [where S'="rel_split_heap" and P="typing.unchanged_typing_on 𝒮 t"])
       apply (subst (asm) rel_split_heap_def )
       apply simp
       apply (rule f)
      apply (rule typing_unchanged)
    subgoal for sc sa tc
      apply clarsimp
      apply (clarsimp simp add: rel_split_heap_def)
      apply (subst sim_stack_release)
      subgoal for bs i
        using typing.unchanged_typing_on_root_ptr_valid_preservation [where S=𝒮 and s=t and t=sc and p=" (p +p int i)"]
        by auto
      subgoal by auto
      subgoal by auto
      done
  subgoal by (simp add: Ex_list_of_length)
  done
  done

lemma sim_assume_with_fresh_stack_ptr:
  fixes fc:: "'a ptr  ('b, 'c, 's) exn_monad"
  assumes init: "inita (st s) = initc s"
  assumes f: "s p::'a ptr. refines (fc p) (fa p) s (st s) (rel_prod (=) rel_split_heap)"
  assumes typing_unchanged: "s p::'a ptr. (fc p)  s ?⦃λr t. typing.unchanged_typing_on 𝒮 s t"
  shows "refines
           (monolithic.with_fresh_stack_ptr n initc fc)
           (assume_with_fresh_stack_ptr n inita fa) s (st s)
           (rel_prod (=) rel_split_heap)"
  unfolding monolithic.with_fresh_stack_ptr_def assume_with_fresh_stack_ptr_def
                   stack_ptr_acquire_def stack_ptr_release_def assume_stack_alloc_def
  apply (rule refines_bind_bind_exn [where Q= "λ(r,t) (r',t').
           (rel_prod (=) rel_split_heap) (r,t) (r',t') 
          (p. r = Result p  (i < n. ptr_span (p +p int i)  𝒮  root_ptr_valid (heap_typing t') (p +p int i)))"])

  subgoal
    apply (rule refines_assume_result_and_state')
    using sim_stack_alloc stack_byte_zero stack_allocs_𝒮
    apply (clarsimp simp add: sim_set_def init rel_split_heap_def hrs_htd_update stack_allocs_root_ptr_valid_same)
     apply blast
    done
  apply simp
  apply simp
  apply simp
  subgoal for p q t t'
    apply (rule refines_runs_to_partial_rel_prod_assume_on_exit [where S'="rel_split_heap" and P="typing.unchanged_typing_on 𝒮 t"])
       apply (subst (asm) rel_split_heap_def )
       apply simp
       apply (rule f)
      apply (rule typing_unchanged)
    subgoal for sc sa tc
      apply clarsimp
      apply (clarsimp simp add: rel_split_heap_def)
      apply (subst sim_stack_release)
      subgoal for bs i
        using typing.unchanged_typing_on_root_ptr_valid_preservation [where S=𝒮 and s=t and t=sc and p=" (p +p int i)"]
        by auto
      subgoal by auto
      subgoal by auto
      done
    subgoal by (simp add: Ex_list_of_length)
    subgoal for sc sa
      apply clarsimp
      apply (clarsimp simp add: rel_split_heap_def)
      subgoal for i
        using typing.unchanged_typing_on_root_ptr_valid_preservation [where S=𝒮 and s=t and t=sc and p=" (p +p int i)"]
        by auto
      done
    done
  done



lemma L2Tcorres_guard_with_fresh_stack_ptr [heap_abs]:
  assumes rew: "struct_rewrite_expr P initc' initc"
  assumes grd: "abs_guard st P' P"
  assumes expr: "abs_expr st Q inita initc'"
  assumes f[simplified THIN_def, rule_format]: "PROP THIN (p::'a ptr. L2Tcorres st (fa p) (fc p))"
  shows "L2Tcorres st (L2_seq (L2_guard (λs. P' s  Q s))
           (λ_. (guard_with_fresh_stack_ptr n inita (L2_VARS fa nm))))
           (monolithic.with_fresh_stack_ptr n initc (L2_VARS fc nm))"
  apply (rule refines_L2Tcorres)
  apply (simp add: L2_seq_def L2_guard_def  L2_VARS_def )
  apply (rule refines_bind_guard_right)
  apply clarsimp
  apply (rule sim_guard_with_fresh_stack_ptr)
  subgoal for s
    using rew grd expr
    by (auto simp add: struct_rewrite_expr_def abs_guard_def abs_expr_def)
  subgoal for s s' p
    apply (rule L2Tcorres_refines)
    apply (rule f)
    done
  done

lemma L2Tcorres_with_fresh_stack_ptr:
  assumes typing_unchanged: "s p::'a ptr. (fc p)  s ?⦃λr t. typing.unchanged_typing_on 𝒮 s t"
  assumes rew: "struct_rewrite_expr P initc' initc"
  assumes grd: "abs_guard st P' P"
  assumes expr: "abs_expr st Q inita initc'"
  assumes f[simplified THIN_def, rule_format]: "PROP THIN (p::'a ptr. L2Tcorres st (fa p) (fc p))"
  shows "L2Tcorres st (L2_seq (L2_guard (λs. P' s  Q s))
           (λ_. (with_fresh_stack_ptr n inita (L2_VARS fa nm))))
           (monolithic.with_fresh_stack_ptr n initc (L2_VARS fc nm))"
  apply (rule refines_L2Tcorres)
  apply (simp add: L2_seq_def L2_guard_def  L2_VARS_def)
  apply (rule refines_bind_guard_right)
  apply clarsimp
  apply (rule sim_with_fresh_stack_ptr)
  subgoal for s
    using rew grd expr
    by (auto simp add: struct_rewrite_expr_def abs_guard_def abs_expr_def)
  subgoal for s s' p
    apply (rule L2Tcorres_refines)
    apply (rule f)
    done
  using typing_unchanged by blast

lemma L2Tcorres_assume_with_fresh_stack_ptr[heap_abs]:
  assumes typing_unchanged: "s p::'a ptr. (fc p)  s ?⦃λr t. typing.unchanged_typing_on 𝒮 s t"
  assumes rew: "struct_rewrite_expr P initc' initc"
  assumes grd: "abs_guard st P' P"
  assumes expr: "abs_expr st Q inita initc'"
  assumes f[simplified THIN_def, rule_format]: "PROP THIN (p::'a ptr. L2Tcorres st (fa p) (fc p))"
  shows "L2Tcorres st (L2_seq (L2_guard (λs. P' s  Q s))
           (λ_. (assume_with_fresh_stack_ptr n inita (L2_VARS fa nm))))
           (monolithic.with_fresh_stack_ptr n initc (L2_VARS fc nm))"
  apply (rule refines_L2Tcorres)
  apply (simp add: L2_seq_def L2_guard_def L2_VARS_def)
  apply (rule refines_bind_guard_right)
  apply clarsimp
  apply (rule sim_assume_with_fresh_stack_ptr)
  subgoal for s
    using rew grd expr
    by (auto simp add: struct_rewrite_expr_def abs_guard_def abs_expr_def)
  subgoal for s s' p
    apply (rule L2Tcorres_refines)
    apply (rule f)
    done
  using typing_unchanged by blast


lemma unchanged_typing_commutes: "typing.unchanged_typing_on S s t  unchanged_typing_on S (st s) (st t)"
  using heap_typing_commutes [of s] heap_typing_commutes [of t]
  by (auto simp add: unchanged_typing_on_def typing.unchanged_typing_on_def)
end

(* The following lemmas help to establish that reading from stack_byte typed locations
   results in ZERO('a) *)

(* Rule collections for proving that reads from stack_byte-typed memory yield ZERO:
   - read_stack_byte_ZERO_base: leaf rules, e.g. for plift-based reads;
   - read_stack_byte_ZERO_step: decomposition rules (arrays, struct fields, split-heap reads);
   - read_stack_byte_ZERO_step_subst: not used in this part of the file; presumably for
     substitution-style steps (TODO confirm). *)
named_theorems read_stack_byte_ZERO_base
  and read_stack_byte_ZERO_step
  and read_stack_byte_ZERO_step_subst

(* If every byte address in ptr_span p is a root_ptr_valid stack_byte pointer in d, then p
   is not ptr_valid_u at its own type 'a (stack_byte typing is disjoint from 'a). *)
lemma (in open_types) ptr_span_with_stack_byte_type_implies_ptr_invalid:
  fixes p :: "('a :: {mem_type, stack_type}) ptr"
  assumes *: "a  ptr_span p. root_ptr_valid d (PTR (stack_byte) a)"
  shows "¬ ptr_valid_u (typ_uinfo_t TYPE('a)) d (ptr_val p)"
  by (metis assms disjoint_iff fold_ptr_valid' in_ptr_span_itself ptr_exhaust_eq
            ptr_valid_stack_byte_disjoint)

(* Base rule: over stack_byte-typed memory, plift d p is None, so reading it with default
   ZERO('a) gives ZERO('a). *)
lemma (in open_types)
  ptr_span_with_stack_byte_type_implies_ZERO[read_stack_byte_ZERO_base]:
  fixes p :: "('a :: {mem_type, stack_type}) ptr"
  assumes "a  ptr_span p. root_ptr_valid (hrs_htd d) (PTR (stack_byte) a)"
  shows "the_default (ZERO('a)) (plift d p) = ZERO('a)"
  using ptr_span_with_stack_byte_type_implies_ptr_invalid[OF assms]
  by (simp add: fold_ptr_valid' plift_None)

(* For an in-bounds index i < CARD('b), the span of the i-th element pointer lies inside the
   span of the whole array pointer. *)
lemma ptr_span_array_ptr_index_subset_ptr_span:
  fixes p :: "(('a :: {array_outer_max_size})['b :: array_max_count]) ptr"
  assumes "i < CARD('b)"
  shows "ptr_span (array_ptr_index p c i)  ptr_span p"
  using assms
  apply (simp add: array_ptr_index_def ptr_add_def)
  apply (rule intvl_sub_offset)
  (* Element i ends at offset i * size_of + size_of, which is within the array. *)
  apply (rule order_trans[of _ "i * size_of TYPE('a) + size_of TYPE('a)"])
  apply (simp add: unat_le_helper)
  apply (simp add: add.commute less_le_mult_nat)
  done

(* Step rule for 1-dimensional arrays: if the whole array span is stack_byte-typed and
   element reads over stack_byte memory give ZERO('a), then reading every element
   gives ZERO of the array type. *)
lemma read_stack_byte_ZERO_array_intro[read_stack_byte_ZERO_step]:
  fixes q :: "('a :: {array_outer_max_size}['b :: array_max_count]) ptr"
  assumes ptr_span_has_stack_byte_type:
    "aptr_span q. root_ptr_valid d (PTR(stack_byte) a)"
  assumes subtype_reads_ZERO:
    "p :: 'a ptr. aptr_span p. root_ptr_valid d (PTR(stack_byte) a)  r s p = ZERO('a)"
  shows "(ARRAY i. r s (array_ptr_index q c i)) = ZERO('a['b])"
  apply (rule array_ext)
  apply (simp add: array_index_zero)
  apply (rule subtype_reads_ZERO)
  using ptr_span_has_stack_byte_type ptr_span_array_ptr_index_subset_ptr_span by blast

(* Step rule for 2-dimensional arrays. The inner element span is contained in the outer span
   by two applications of ptr_span_array_ptr_index_subset_ptr_span. *)
lemma read_stack_byte_ZERO_array_2dim_intro[read_stack_byte_ZERO_step]:
  fixes q :: "('a :: {array_inner_max_size}['b :: array_max_count]['c :: array_max_count]) ptr"
  assumes ptr_span_has_stack_byte_type:
    "aptr_span q. root_ptr_valid d (PTR(stack_byte) a)"
  assumes subtype_reads_ZERO:
    "p :: 'a ptr. aptr_span p. root_ptr_valid d (PTR(stack_byte) a)  r s p = ZERO('a)"
  shows "(ARRAY i j. r s (array_ptr_index (array_ptr_index q c i) c j)) = ZERO('a['b]['c])"
  apply (rule array_ext)
  apply (simp add: array_index_zero)
  apply (rule array_ext)
  apply (simp add: array_index_zero)
  apply (rule subtype_reads_ZERO)
  by (metis (no_types, opaque_lifting) subset_iff ptr_span_has_stack_byte_type
            ptr_span_array_ptr_index_subset_ptr_span)

(* Step rule for struct fields: if field f of 'a has type 'b (subtype_lookup), the field
   pointer's span is inside ptr_span q. A stack_byte-typed q therefore reads ZERO('b) at
   the field. *)
lemma read_stack_byte_ZERO_field_intro[read_stack_byte_ZERO_step]:
  fixes q :: "'a :: mem_type ptr"
  assumes ptr_span_has_stack_byte_type:
    "aptr_span q. root_ptr_valid d (PTR(stack_byte) a)"
  assumes subtype_reads_ZERO:
    "p :: 'b :: mem_type ptr. aptr_span p. root_ptr_valid d (PTR(stack_byte) a)  r s p = ZERO('b)"
  assumes subtype_lookup:
    "field_lookup (typ_uinfo_t TYPE('a)) f 0 = Some (typ_uinfo_t TYPE('b), n)"
  shows "r s (PTR('b) (&(qf))) = ZERO('b)"
proof -
  (* Field footprint containment, from field_tag_sub'. *)
  have "ptr_span (PTR('b) (&(qf)))  ptr_span q"
    using TypHeapSimple.field_tag_sub'[OF subtype_lookup]
    by (simp, metis size_of_fold)
  thus ?thesis
    using ptr_span_has_stack_byte_type subtype_lookup subtype_reads_ZERO by blast
qed


context open_types
begin

(* Simp rule: if every byte of ptr_span p is a root stack_byte in the typing of s, then
   read_dedicated_heap s p is ZERO. *)
lemma ptr_span_with_stack_byte_type_implies_read_dedicated_heap_ZERO[simp]:
  "aptr_span p. root_ptr_valid (hrs_htd s) (PTR(stack_byte) a) 
    read_dedicated_heap s p = ZERO('a::{stack_type, xmem_type})"
  unfolding read_dedicated_heap_def ptr_span_with_stack_byte_type_implies_ZERO[of p] merge_addressable_fields.idem ..

(*
 * Introduction rule for write_simulation t_hrs t_hrs_update l V W, where l is a
 * heap_typing_simulation and 𝒯 maps the type info of 'a to its field list fs.
 * Premises:
 *   l_w - the abstract writes ws (paired with fs by list_all2) each simulate the raw byte
 *         update heap_upd_list of their field at a valid field pointer;
 *   l_u - writing the dedicated heap at a ptr_valid pointer corresponds to u applied to an
 *         upd_fun that merges the addressable fields of the new value;
 *   V   - V on abstract states coincides with ptr_valid on the concrete typing;
 *   W   - W is that u update followed by the fold of the field writes ws, with the new value
 *         computed from the abstract read R.
 *)
lemma write_simulationI:
  fixes R :: "'s  'a::xmem_type ptr  'a"
  assumes fs: "map_of 𝒯 (typ_uinfo_t TYPE('a)) = Some fs"
  assumes "heap_typing_simulation 𝒯 t_hrs t_hrs_update heap_typing heap_typing_update l"
    and l_w: "list_all2 (λf w. t u n h (p::'a ptr) x.
        field_ti TYPE('a) f = Some t 
        field_lookup (typ_uinfo_t TYPE('a)) f 0 = Some (u, n) 
        ptr_valid_u u (hrs_htd (t_hrs h)) &(pf) 
        l (t_hrs_update (hrs_mem_update (heap_upd_list (size_td u) &(pf) (access_ti t x))) h)
          = w p x (l h)) fs ws"
    and l_u: "(p::'a ptr) (x::'a) (s::'b).
      ptr_valid (hrs_htd (t_hrs s)) p 
      l (t_hrs_update (write_dedicated_heap p x) s) = u (upd_fun p (λold. merge_addressable_fields old x)) (l s)"
  assumes V:
    "h p. V (l h) p  ptr_valid (hrs_htd (t_hrs h)) p"
  assumes W:
    "p f h. W p f h =
      fold (λw. w p (f (R h p))) ws (u (upd_fun p (λold. merge_addressable_fields old (f (R h p)))) h)"
  shows "write_simulation t_hrs t_hrs_update l V W"
proof -
  interpret hrs: heap_typing_simulation 𝒯 t_hrs t_hrs_update heap_typing heap_typing_update l
    by fact

  (* A pointer valid at type 'a has a valid field pointer, at the field's own type, for
     every field in fs (from ptr_valid_u_step). *)
  have valid:
    "list_all (λf. u n. field_lookup (typ_uinfo_t TYPE('a)) f 0 = Some (u, n) 
      ptr_valid_u u h &(pf)) fs"
    if *: "ptr_valid_u (typ_uinfo_t TYPE('a)) h (ptr_val p)" for h and p :: "'a ptr"
    using ptr_valid_u_step[OF fs _ _ *]
    by (auto simp: list_all_iff field_lvalue_def field_offset_def)

  (* Folding the raw byte updates of all fields in fs corresponds to folding the abstract
     field writes ws. Proved by induction along the list_all2 premise l_w. *)
  have fold': "l (fold
        (λxa. t_hrs_update
                (hrs_mem_update
                  (heap_upd_list (size_td (the (field_ti TYPE('a) xa))) &(pxa)
                    (access_ti (the (field_ti TYPE('a) xa)) x))))
        fs s) =
      fold (λw. w p x) ws (l s)"
    if p: "ptr_valid_u (typ_uinfo_t TYPE('a)) (hrs_htd (t_hrs s)) (ptr_val p)"
    for p x s
    using l_w wf_𝒯[OF fs] p[THEN valid]
  proof (induction arbitrary: s)
    case (Cons f fs w ws)
    from Cons.prems obtain u n where f_u :"field_lookup (typ_uinfo_t TYPE('a)) f 0 = Some (u, n)"
      and [simp]: "list_all (λf. a b. field_lookup (typ_uinfo_t TYPE('a)) f 0 = Some (a, b)) fs"
      by auto
    (* Recover the typ_info k behind the exported uinfo u, so field_ti can be used. *)
    from f_u[THEN field_lookup_uinfo_Some_rev] obtain k where
      "field_lookup (typ_info_t TYPE('a)) f 0 = Some (k, n)" and u_eq: "u = export_uinfo k"
      by auto
    then have [simp]: "field_ti TYPE('a) f = Some k" by (simp add: field_ti_def)
    have [simp]: "size_td k = size_td u"
      by (simp add: u_eq)
    have [simp]: "ptr_valid_u u (hrs_htd (t_hrs s)) &(pf)"
      using Cons.prems(2) f_u by auto
    show ?case
      using Cons.prems Cons.hyps by (simp add: Cons.IH f_u)
  qed simp

  (* The whole concrete write corresponds to u followed by the ws writes. That write is
     write_dedicated_heap first, then the byte updates of all addressable_fields. *)
  have fold:
    "l ((fold (t_hrs_update 
          (λ(f, u). hrs_mem_update (heap_upd_list (size_td u) &(pf) (access_ti u x))))
        (addressable_fields TYPE('a)) 
      t_hrs_update (write_dedicated_heap p x)) s) =
    fold (λw. w p x) ws (u (upd_fun p (λold. merge_addressable_fields old x)) (l s))"
    if p: "ptr_valid_u (typ_uinfo_t TYPE('a)) (hrs_htd (t_hrs s)) (ptr_val p)"
    for p x s
    by (subst addressable_fields_def)
       (simp add: fs fold_map fold' p ptr_valid_def l_u cong: fold_cong)

  (* Assemble via hrs.write_simulation_alt, rewriting with V, W and fold. *)
  show ?thesis
    apply (rule hrs.write_simulation_alt)
    apply (simp add: V)
    apply (subst hrs_mem_update_heap_update')
    apply (subst W)
    apply (subst (asm) V)
    apply (subst (asm) ptr_valid_def)
    apply (subst hrs.t_hrs.upd_comp[symmetric])
    apply (subst hrs.t_hrs.upd_comp_fold)
    apply (subst fold)
    apply simp_all
    done
qed

end

(*
 * Stack simulation for a stack_type 'a whose abstract validity predicate is
 * open_types.ptr_valid 𝒯 on heap_typing. From three assumptions (allocation, release, and
 * ZERO reads over stack_byte memory), the locale derives the stack_simulation axioms.
 * It then interprets typ_heap_typing_stack_simulation, so that locale's
 * fresh-stack-pointer refinement lemmas become available here.
 *)
locale stack_simulation_heap_typing =
  typ_heap_simulation st r w "λt p. open_types.ptr_valid 𝒯 (heap_typing t) p" t_hrs t_hrs_update +
  heap_typing_simulation 𝒯 t_hrs t_hrs_update heap_typing heap_typing_upd st +
  typ_heap_typing r w heap_typing heap_typing_upd 𝒮
  for
    st:: "'s  't" and
    r:: "'t  ('a::{xmem_type, stack_type}) ptr  'a" and
    w:: "'a ptr  ('a  'a)  't  't" and
    t_hrs :: "'s  heap_raw_state" and
    t_hrs_update:: "(heap_raw_state  heap_raw_state)  's  's" and
    heap_typing :: "'t  heap_typ_desc" and
    heap_typing_upd :: "(heap_typ_desc  heap_typ_desc)  't  't" and
    𝒮:: "addr set" and
    𝒯:: "(typ_uinfo * qualified_field_name list) list" +

(* Allocation: zero-initialising the n cells and setting the concrete typing to d
   corresponds to only updating the abstract heap typing to d. *)
assumes sim_stack_alloc_heap_typing:
  "p d s n.
    (p, d)  stack_allocs n 𝒮 TYPE('a) (hrs_htd (t_hrs s)) 
      st (t_hrs_update (hrs_mem_update (fold (λi. heap_update (p +p int i) c_type_class.zero) [0..<n])  hrs_htd_update (λ_. d)) s) =
      (heap_typing_upd (λ_. d) (st s))"

(* Release: abstracting after releasing the concrete typing equals abstracting after
   zero-writing the n cells, then releasing the abstract typing. *)
assumes sim_stack_release_heap_typing:
"(p::'a ptr) s n. (i. i < n  root_ptr_valid (hrs_htd (t_hrs s)) (p +p int i)) 
  st (t_hrs_update (hrs_htd_update (stack_releases n p)) s) =
    heap_typing_upd (stack_releases n p)
     (st (t_hrs_update (hrs_mem_update (fold (λi. heap_update (p +p int i) c_type_class.zero) [0..<n])) s))"

(* A pointer whose whole span consists of root stack_byte addresses in the concrete typing
   reads as ZERO. It is also registered as a read_stack_byte_ZERO_step rule.
   NOTE(review): the trailing comment holding a lone double-quote character on the next
   line looks like a workaround for editor syntax highlighting. Confirm before removing. *)
assumes sim_stack_stack_byte_zero[read_stack_byte_ZERO_step]:
  "p s. aptr_span p. root_ptr_valid (hrs_htd (t_hrs s)) (PTR(stack_byte) a)  r (st s) p = ZERO('a)" (* " *)

begin

(* If all n cells from p are ptr_valid in the abstract typing, a fold of concrete
   heap_update writes of vs i corresponds to the fold of abstract writes w.
   Proved by induction on n using write_commutes. *)
lemma fold_heap_update_simulation:
  assumes valid: "i. i < n  ptr_valid (heap_typing (st s)) (p +p int i)"
  shows "st (t_hrs_update (hrs_mem_update (fold (λi. heap_update (p +p int i) (vs i)) [0..<n])) s) =
          fold (λi. w (p +p int i) (λ_. vs i)) [0..<n] (st s)"
  using valid
proof (induct n arbitrary: vs s)
  case 0
  then show ?case
    by (simp add: case_prod_unfold hrs_mem_update_def)
next
  case (Suc n)
  from Suc.prems obtain
    valid: "i. i < Suc n  ptr_valid (heap_typing (st s)) (p +p int i)" by blast

  from valid have valid': "i. i < n  ptr_valid (heap_typing (st s)) (p +p int i)" by auto
  note hyp = Suc.hyps [OF valid']
  show ?case
    apply (simp add: hyp [symmetric])
    apply (subst write_commutes [symmetric])
    using valid
    apply simp
    using TypHeapSimple.hrs_mem_update_comp hrs_mem_update_def
    apply simp
    done
qed

(* Padding variant of fold_heap_update_simulation. Cell i is written with
   heap_update_padding, using the i-th size_of TYPE('a) chunk of bs as padding bytes, where
   length bs = n * size_of TYPE('a). Proved by induction on n using write_padding_commutes. *)
lemma fold_heap_update_padding_simulation:
  assumes valid: "i. i < n  ptr_valid (heap_typing (st s)) (p +p int i)"
  assumes lbs: "length bs = n * size_of TYPE('a)"
  shows "st (t_hrs_update (hrs_mem_update (fold (λi. heap_update_padding (p +p int i) (vs i) (take (size_of TYPE('a)) (drop (i * size_of TYPE('a)) bs))) [0..<n])) s) =
          fold (λi. w (p +p int i) (λ_. vs i)) [0..<n] (st s)"
  using valid lbs
proof (induct n arbitrary: bs vs s )
  case 0
  then show ?case
    by (simp add: case_prod_unfold hrs_mem_update_def)
next
  case (Suc n)
  from Suc.prems obtain
    valid: "i. i < Suc n  ptr_valid (heap_typing (st s)) (p +p int i)" and
    lbs': "length (take (n * (size_of TYPE('a))) bs) = n * size_of TYPE('a)"
    by auto

  from valid have valid': "i. i < n  ptr_valid (heap_typing (st s)) (p +p int i)" by auto
  (* The induction hypothesis is used on bs truncated to the first n chunks. *)
  note hyp = Suc.hyps [OF valid' lbs']
  (* For i < n, the i-th chunk is unaffected by that truncation. *)
  have take_eq: "i. i < n  take (size_of TYPE('a)) (drop (i * size_of TYPE('a)) (take (n * size_of TYPE('a)) bs)) =
        take (size_of TYPE('a)) (drop (i * size_of TYPE('a)) bs)"
    by (metis Groups.mult_ac(2) mult_less_cancel1 not_less not_less_eq
        order_less_imp_le take_all take_drop_take times_nat.simps(2))

  (* Consequently the first n padding writes are the same with or without truncation. *)
  have fold_eq: "h. fold
              (λi. heap_update_padding (p +p int i) (vs i)
                        (take (size_of TYPE('a)) (drop (i * size_of TYPE('a)) (take (n * size_of TYPE('a)) bs))))
              [0..<n] h =
            fold
              (λi. heap_update_padding (p +p int i) (vs i)
                        (take (size_of TYPE('a)) (drop (i * size_of TYPE('a)) bs)))
                 [0..<n] h"

    apply (rule fold_cong)
      apply (rule refl)
     apply (rule refl)
    using take_eq
    apply simp
    done


  (* The last cell, number n, uses the n-th chunk of bs. *)
  show ?case
    apply (simp add: hyp [symmetric])
    apply (subst write_padding_commutes [symmetric, where bs = "take (size_of TYPE('a)) (drop (n * size_of TYPE('a)) bs)"])
    subgoal using valid
      by simp
    subgoal using Suc.prems by simp
    subgoal
    using TypHeapSimple.hrs_mem_update_comp hrs_mem_update_def
      by (simp add: fold_eq)
    done
qed

(* Derived allocation axiom for stack_simulation. Padding writes of vs into the freshly
   allocated cells, after setting the typing to d, correspond to abstract writes after
   heap_typing_upd. *)
lemma sim_stack_alloc':
  assumes alloc: "(p, d)  stack_allocs n 𝒮 TYPE('a) (hrs_htd (t_hrs s))"
  assumes len: "length vs = n"
  assumes lbs: "length bs = n * size_of TYPE('a)"
  shows "st (t_hrs_update (hrs_mem_update (fold (λi. heap_update_padding (p +p int i) (vs!i) (take (size_of TYPE('a)) (drop (i * size_of TYPE('a)) bs))) [0..<n])  hrs_htd_update (λ_. d)) s) =
            (fold (λi. w (p +p int i) (λ_. (vs ! i))) [0..<n]) (heap_typing_upd (λ_. d) (st s))"
proof -
  (* After zero-initialised allocation, every allocated cell is ptr_valid in the abstract typing. *)
  {
    fix i
    assume i_bound: "i < n"
    have "ptr_valid (heap_typing (st (t_hrs_update
                 (hrs_mem_update (fold (λi. heap_update (p +p int i) c_type_class.zero) [0..<n]) 
                  hrs_htd_update (λ_. d))
                 s)))
          (p +p int i)"
    proof -
      from stack_allocs_cases [OF alloc] i_bound
      have ptr_valid: "ptr_valid d (p +p int i)"
        using root_ptr_valid_ptr_valid by auto
      thus ?thesis
        using heap_typing_upd_commutes by (simp, metis)
    qed
  } note valids = this

  (* Bounds needed by fold_heap_update_padding_heap_update_collapse. *)
  from stack_allocs_cases [OF alloc] obtain
    bound: "unat (ptr_val p) + n * size_of TYPE('a)  addr_card" and
    not_null: "ptr_val p  0"
    by (metis Ptr_ptr_val c_guard_NULL_simp)

  show ?thesis
    apply (simp add: sim_stack_alloc_heap_typing [OF alloc, symmetric])
    apply (subst fold_heap_update_padding_simulation [OF valids lbs, symmetric])
      apply (simp)
     apply (simp add: len)
    apply (simp add: comp_def hrs_mem_update_comp')
    apply (subst fold_heap_update_padding_heap_update_collapse [OF bound not_null])
    using lbs
     apply (auto simp add: less_le_mult_nat nat_move_sub_le)
    done
qed

(* Release without any byte writes. It equals abstract zero writes to the n cells followed
   by the abstract typing release. *)
lemma sim_stack_release':
  fixes p :: "'a ptr"
  assumes roots: "i. i < n  root_ptr_valid (hrs_htd (t_hrs s)) (p +p int i)"
  shows "st (t_hrs_update (hrs_htd_update (stack_releases n p)) s) =
           ((heap_typing_upd (stack_releases n p) ((fold (λi. w (p +p int i) (λ_. c_type_class.zero)) [0..<n]) (st s))))"
proof -
  from roots root_ptr_valid_ptr_valid heap_typing_commutes
  have valids: "i . i < n  ptr_valid (heap_typing (st s)) (p +p int i)"
    by metis
  note commutes = fold_heap_update_simulation [OF valids, symmetric, of n, simplified]
  show ?thesis
    apply (simp add: commutes )
    apply (simp add: sim_stack_release_heap_typing [OF roots])
    done
qed


(* Derived release axiom for stack_simulation. Writing an arbitrary byte list bs over the
   released region does not change the abstraction, because those bytes are stack_byte-typed
   after the release (lift_heap_update_list_stack_byte_independent). *)
lemma sim_stack_release'':
  fixes p :: "'a ptr"
  assumes roots: "i. i < n  root_ptr_valid (hrs_htd (t_hrs s)) (p +p int i)"
  assumes lbs: "length bs = n * size_of TYPE('a)"
  shows "st (t_hrs_update (hrs_mem_update (heap_update_list (ptr_val p) bs) o hrs_htd_update (stack_releases n p)) s) =
           ((heap_typing_upd (stack_releases n p) ((fold (λi. w (p +p int i) (λ_. c_type_class.zero)) [0..<n]) (st s))))"
proof -
  (* Each byte i < length bs of the released region is a root stack_byte after release. *)
  {
    fix i
    assume bound_i: "i < length bs"
    have "root_ptr_valid (hrs_htd (t_hrs (t_hrs_update (hrs_htd_update (stack_releases n p)) s)))
            ((PTR_COERCE('a  stack_byte) p) +p int i)"
    proof -
      have span: "ptr_val (((PTR_COERCE('a  stack_byte) p) +p int i))  {ptr_val p..+n * size_of TYPE('a)}"
        using lbs bound_i intvlI by (auto simp add: ptr_add_def)
      from roots have "i<n. c_guard (p +p int i)"
        using root_ptr_valid_c_guard by blast
      from stack_releases_root_ptr_valid_footprint [OF span this]
      show ?thesis
        using typing.get_upd by force
    qed
  } note sb = this

  show ?thesis
    apply (simp add:  lift_heap_update_list_stack_byte_independent [OF sb, simplified])
    apply (simp add: sim_stack_release' [OF roots])
    done
qed


(* Derived zero-read axiom for stack_simulation. Before allocation, the cells chosen by
   stack_allocs lie in stack_byte-typed memory, so they read as ZERO. *)
lemma stack_byte_zero':
  assumes "(p, d)  stack_allocs n 𝒮 TYPE('a) (hrs_htd (t_hrs s))"
  assumes "i < n"
  shows "r (st s) (p +p int i) = ZERO('a)"
  by (rule sim_stack_stack_byte_zero)
     (meson assms stack_allocs_cases stack_allocs_contained subsetD)

(* The three derived lemmas above discharge the stack_simulation axioms. *)
sublocale stack_simulation
  using sim_stack_alloc' sim_stack_release'' stack_byte_zero'
  by (unfold_locales) auto

(* Makes the fresh-stack-pointer refinement lemmas of typ_heap_typing_stack_simulation available. *)
sublocale typ_heap_typing_stack_simulation 𝒯 st r w "λt p. open_types.ptr_valid 𝒯 (heap_typing t) p" t_hrs t_hrs_update heap_typing heap_typing_upd 𝒮
  by (unfold_locales)

end





(*
 * Assert the given field ("field_getter", "field_setter") of the given structure
 * can be abstracted into the heap, and then accessed as a HOL object.
 *)

(*
 * This can deal with nested structures, but they must be packed_types.
 * fixme: generalise this framework to mem_types
 *)

(* valid_struct_field field_name field_getter field_setter t_hrs t_hrs_update holds when:
   field_getter/field_setter form a lense; the field_ti of field_name in 'p is the type info
   of 'f adjusted by getter and setter; c_guard of a struct pointer implies c_guard of its
   field pointer; and t_hrs/t_hrs_update form a lense. *)
definition
  valid_struct_field
    :: "string list
            (('p::xmem_type)  ('f::xmem_type))
            (('f  'f)  ('p  'p))
            ('s  heap_raw_state)
            ((heap_raw_state  heap_raw_state)  's  's)
            bool"
 where
  "valid_struct_field field_name field_getter field_setter t_hrs t_hrs_update 
     (lense field_getter field_setter
       field_ti TYPE('p) field_name =
          Some (adjust_ti (typ_info_t TYPE('f)) field_getter (field_setter  (λx _. x)))
       (p :: 'p ptr. c_guard p  c_guard (Ptr &(pfield_name) :: 'f ptr))
       lense t_hrs t_hrs_update)"

(* Under typ_heap_simulation, at a pointer valid per v, the concrete h_val equals the
   abstract read r (read_commutes of the read_simulation component). *)
lemma typ_heap_simulation_get_hvalD:
  " typ_heap_simulation st r w v
        t_hrs t_hrs_update; v (st s) p  
      h_val (hrs_mem (t_hrs s)) p = r (st s) p"
  by (clarsimp simp: read_simulation.read_commutes [OF typ_heap_simulation.axioms(2)])

(* Introduction rule for valid_struct_field, with the lense laws for getter/setter and for
   t_hrs/t_hrs_update spelled out as separate premises. *)
lemma valid_struct_fieldI [intro]:
  fixes field_getter :: "('a::xmem_type)  ('f::xmem_type)"
  shows "
     s f. f (field_getter s) = (field_getter s)  field_setter f s = s;
     s f f'. f (field_getter s) = f' (field_getter s)  field_setter f s = field_setter f' s;
     s f. field_getter (field_setter f s) = f (field_getter s);
     s f g. field_setter f (field_setter g s) = field_setter (f  g) s;
     field_ti TYPE('a) field_name =
         Some (adjust_ti (typ_info_t TYPE('f)) field_getter (field_setter  (λx _. x)));
     (p::'a ptr). c_guard p  c_guard (Ptr &(pfield_name) :: 'f ptr);
     s x. t_hrs (t_hrs_update x s) = x (t_hrs s);
     s f. f (t_hrs s) = t_hrs s  t_hrs_update f s = s;
     s f f'. f (t_hrs s) = f' (t_hrs s)  t_hrs_update f s = t_hrs_update f' s;
     s f g. t_hrs_update f (t_hrs_update g s) = t_hrs_update (λx. f (g x)) s
       
    valid_struct_field field_name field_getter field_setter t_hrs t_hrs_update"
  apply (unfold valid_struct_field_def lense_def o_def)
  apply (safe | assumption | rule ext)+
  done

(* Under typ_heap_simulation, a concrete heap_update at a pointer valid per v corresponds to
   the abstract write w p (λx. v') (write_commutes of the write_simulation component). *)
lemma typ_heap_simulation_t_hrs_updateD:
  " typ_heap_simulation st r w v
         t_hrs t_hrs_update; v (st s) p  
           st (t_hrs_update (hrs_mem_update (heap_update p v')) s) =
                           w p (λx. v') (st s)"
  by (clarsimp simp: write_simulation.write_commutes [OF typ_heap_simulation.axioms(3)])

(* A concrete c_guard on a pointer expression x abstracts to the guard combining the
   expression precondition P with the validity predicate vgetter on the abstract pointer
   x' s. The proof uses valid_implies_c_guard. *)
lemma heap_abs_expr_guard [heap_abs]:
  " typ_heap_simulation st getter setter vgetter t_hrs t_hrs_update;
     abs_expr st P x' x  
     abs_guard st (λs. P s  vgetter s (x' s)) (λs. (c_guard (x s :: ('a::xmem_type) ptr)))"
  apply (clarsimp simp: abs_expr_def abs_guard_def
                        simple_lift_def root_ptr_valid_def
                        valid_implies_cguard.valid_implies_c_guard [OF typ_heap_simulation.axioms(5)])
  done

(* A concrete h_val read at pointer expression x abstracts to the split-heap read r at x' s,
   with precondition P plus validity of x' s. *)
lemma heap_abs_expr_h_val [heap_abs]:
  " typ_heap_simulation st r w v t_hrs t_hrs_update;
     abs_expr st P x' x  
      abs_expr st
       (λs. P s  v s (x' s))
         (λs. (r s (x' s)))
         (λs. (h_val (hrs_mem (t_hrs s))) (x s))"
  apply (clarsimp simp: abs_expr_def simple_lift_def)
  apply (metis typ_heap_simulation_get_hvalD)
  done

(* Simpler form of the heap_update abstraction rule, in which the new value c does not
   depend on the old value. It is not declared [heap_abs] and, as its name says, is unused;
   heap_abs_modifies_heap_update below is the active rule. *)
lemma heap_abs_modifies_heap_update__unused:
  " typ_heap_simulation st r w v t_hrs t_hrs_update;
     abs_expr st Pb b' b;
     abs_expr st Pc c' c  
      abs_modifies st (λs. Pb s  Pc s  v s (b' s))
        (λs. w (b' s) (λx. (c' s)) s)
        (λs. t_hrs_update (hrs_mem_update (heap_update (b s :: ('a::xmem_type) ptr) (c s))) s)"
  apply (clarsimp simp: typ_simple_heap_simps abs_expr_def abs_modifies_def)
  apply (metis typ_heap_simulation_t_hrs_updateD)
  done

(* See comment for heap_lift__wrap_h_val. *)
(* A plain alias of h_val. It appears on the concrete side of heap_abs_modifies_heap_update,
   presumably so that this occurrence can be distinguished from ordinary h_val reads during
   rule matching (TODO confirm). *)
definition "heap_lift__h_val  h_val"

(* See the comment for struct_rewrite_modifies_field.
 * In this case we rely on nice unification for ?c.
 * The heap_abs_syntax generator also relies on this rule
 * and would need to be modified if the previous rule was used instead. *)
(*        (λs. setter (λx. x(b' s := c' (x (b' s)) s)) s) *)
(* Abstraction of a heap_update whose new value may depend on the old value at the same
   pointer, read via heap_lift__h_val. It becomes the split-heap write w (b' s), whose new
   value is c' applied to the abstract old value r s (b' s). *)
lemma heap_abs_modifies_heap_update [heap_abs]:
  " typ_heap_simulation st r w v t_hrs t_hrs_update;
     abs_expr st Pb b' b;
     v. abs_expr st Pc (c' v) (c v)  
      abs_modifies st (λs. Pb s  Pc s  v s (b' s))
        (λs. w (b' s) (λ_. (c' (r s (b' s)) s)) s)
        (λs. t_hrs_update (hrs_mem_update
               (heap_update (b s :: ('a::xmem_type) ptr)
                            (c (heap_lift__h_val (hrs_mem (t_hrs s)) (b s)) s))) s)"
  apply (clarsimp simp: typ_simple_heap_simps abs_expr_def abs_modifies_def heap_lift__h_val_def)
  subgoal for s
    (* Replace the concrete old value by the abstract read (read_commutes), then apply write_commutes. *)
    apply (rule subst[where t = "h_val (hrs_mem (t_hrs s)) (b' (st s))"
        and s = "r (st s) (b' (st s))"])
     apply (clarsimp simp: read_simulation.read_commutes [OF typ_heap_simulation.axioms(2)])
    apply (simp add: write_simulation.write_commutes [OF typ_heap_simulation.axioms(3)])
    done
  done


(*
 * struct_rewrite: remove uses of field_lvalue. (field_lvalue p a = &(p→a))
 * We do three transformations:
 *   c_guard (p→a)  ⟸  c_guard p
 *   h_val s (p→a)   =   p_C.a_C (h_val s p)
 *   heap_update (p→a) v s   =   heap_update p (p_C.a_C_update (λ_. v) (h_val s p)) s
 * However, an inner expression may nest h_vals arbitrarily.
 *
 * Any output of a struct_rewrite rule should be fully rewritten.
 * By doing this, each rule only needs to rewrite the parts of a term that it
 * introduces by itself.
 *)

(* struct_rewrite_guard rules *)

(* A boolean expression used as a guard: the rewritten guard is its precondition P together
   with the rewritten expression a'. *)
lemma struct_rewrite_guard_expr [heap_abs]:
  "struct_rewrite_expr P a' a  struct_rewrite_guard (λs. P s  a' s) a"
  by (simp add: struct_rewrite_expr_def struct_rewrite_guard_def)

(* A state-independent guard is left unchanged. *)
lemma struct_rewrite_guard_constant [heap_abs]:
  "struct_rewrite_guard (λ_. P) (λ_. P)"
  by (simp add: struct_rewrite_guard_def)

(* Conjunctions of guards are rewritten componentwise. *)
lemma struct_rewrite_guard_conj [heap_abs]:
  " struct_rewrite_guard b' b; struct_rewrite_guard a' a  
       struct_rewrite_guard (λs. a' s  b' s) (λs. a s  b s)"
  by (clarsimp simp: struct_rewrite_guard_def)

(* A guard defined by a case split on a pair r is rewritten by rewriting the guard for each
   pair of components (a, b). *)
lemma struct_rewrite_guard_split [heap_abs]:
  " a b. struct_rewrite_guard (A a b) (C a b) 
        struct_rewrite_guard (case r of (a, b)  A a b) (case r of (a, b)  C a b)"
  apply (auto simp: split_def)
  done

lemma struct_rewrite_guard_c_guard_field [heap_abs]:
  " valid_struct_field field_name (field_getter :: ('a :: xmem_type)  ('f :: xmem_type)) field_setter t_hrs t_hrs_update;
     struct_rewrite_expr P p' p;
     struct_rewrite_guard Q (λs. c_guard (p'