(* Theory Refine_Imperative_HOL.Sepref_HOL_Bindings *)

section ‹HOL Setup›
theory Sepref_HOL_Bindings
imports Sepref_Tool
begin

subsection ‹Assertion Annotation›
text ‹Annotate an assertion to a term. The term must then be refined with this assertion.›
(* TODO: Version for monadic expressions.*)
(* ASSN_ANNOT A x is just x; the assertion argument A is carried only as an annotation.
   The defining equation is declared [simp]. *)
definition ASSN_ANNOT :: "('a ⇒ 'ai ⇒ assn) ⇒ 'a ⇒ 'a" where [simp]: "ASSN_ANNOT A x ≡ x"
(* Tool setup for ASSN_ANNOT with a fixed assertion A: operator registration,
   a rewrite rule for the applied form ASSN_ANNOT$A, and a refinement rule
   implementing the annotation by returning its argument. *)
context fixes A :: "'a ⇒ 'ai ⇒ assn" begin
  sepref_register "PR_CONST (ASSN_ANNOT A)"
  lemma [def_pat_rules]: "ASSN_ANNOT$A ≡ UNPROTECT (ASSN_ANNOT A)" by simp
  lemma [sepref_fr_rules]: "(return o (λx. x), RETURN o PR_CONST (ASSN_ANNOT A)) ∈ A⇧d→⇩aA"
    by sepref_to_hoare sep_auto
end

(* Rewrite rule that wraps a term in an assertion annotation. *)
lemma annotate_assn: "x ≡ ASSN_ANNOT A x" by simp

subsection ‹Shortcuts›
(* Input-only abbreviations: id_assn specialised to nat, int and bool. *)
abbreviation (input) "nat_assn ≡ (id_assn::nat ⇒ _)"
abbreviation (input) "int_assn ≡ (id_assn::int ⇒ _)"
abbreviation (input) "bool_assn ≡ (id_assn::bool ⇒ _)"

subsection ‹Identity Relations›
(* IS_ID R: R is exactly the identity relation.
   IS_BELOW_ID R: R is contained in the identity relation. *)
definition "IS_ID R ≡ R=Id"
definition "IS_BELOW_ID R ≡ R⊆Id"

(* IS_ID holds for Id and is preserved by the function, option, list,
   product and sum relators. *)
lemma [safe_constraint_rules]: 
  "IS_ID Id"
  "IS_ID R1 ⟹ IS_ID R2 ⟹ IS_ID (R1 → R2)"
  "IS_ID R ⟹ IS_ID (⟨R⟩option_rel)"
  "IS_ID R ⟹ IS_ID (⟨R⟩list_rel)"
  "IS_ID R1 ⟹ IS_ID R2 ⟹ IS_ID (R1 ×⇩r R2)"
  "IS_ID R1 ⟹ IS_ID R2 ⟹ IS_ID (⟨R1,R2⟩sum_rel)"
  by (auto simp: IS_ID_def)

(* IS_BELOW_ID holds for Id and is preserved by the option, product and sum relators. *)
lemma [safe_constraint_rules]: 
  "IS_BELOW_ID Id"
  "IS_BELOW_ID R ⟹ IS_BELOW_ID (⟨R⟩option_rel)"
  "IS_BELOW_ID R1 ⟹ IS_BELOW_ID R2 ⟹ IS_BELOW_ID (R1 ×⇩r R2)"
  "IS_BELOW_ID R1 ⟹ IS_BELOW_ID R2 ⟹ IS_BELOW_ID (⟨R1,R2⟩sum_rel)"
  by (auto simp: IS_ID_def IS_BELOW_ID_def option_rel_def sum_rel_def list_rel_def)

(* Function relator: if the argument relation contains Id and the result relation
   is below Id, then related functions agree pointwise, so R1 → R2 is below Id.
   The premise is a superset condition (R1 ⊇ Id); it is instantiated with R1 = Id below. *)
lemma IS_BELOW_ID_fun_rel_aux: "R1⊇Id ⟹ IS_BELOW_ID R2 ⟹ IS_BELOW_ID (R1 → R2)"
  by (auto simp: IS_BELOW_ID_def dest: fun_relD)

(* Constraint-rule form of IS_BELOW_ID_fun_rel_aux with the argument relation being Id. *)
corollary IS_BELOW_ID_fun_rel[safe_constraint_rules]: 
  "IS_ID R1 ⟹ IS_BELOW_ID R2 ⟹ IS_BELOW_ID (R1 → R2)"
  using IS_BELOW_ID_fun_rel_aux[of Id R2]
  by (auto simp: IS_ID_def)


(* IS_BELOW_ID is preserved by the list relator; proved by induction on the
   list relation. *)
lemma IS_BELOW_ID_list_rel[safe_constraint_rules]: 
  "IS_BELOW_ID R ⟹ IS_BELOW_ID (⟨R⟩list_rel)"
  unfolding IS_BELOW_ID_def
proof safe
  fix l l'
  assume A: "R⊆Id"
  assume "(l,l')∈⟨R⟩list_rel"
  thus "l=l'"
    apply induction
    using A by auto
qed

(* An identity relation is in particular below the identity. *)
lemma IS_ID_imp_BELOW_ID[constraint_rules]: 
  "IS_ID R ⟹ IS_BELOW_ID R"
  by (auto simp: IS_ID_def IS_BELOW_ID_def )



subsection ‹Inverse Relation›

(* Relation converse distributes over the function, option, product and sum relators. *)
lemma inv_fun_rel_eq[simp]: "(A→B)¯ = A¯→B¯"
  by (auto dest: fun_relD)

lemma inv_option_rel_eq[simp]: "(⟨K⟩option_rel)¯ = ⟨K¯⟩option_rel"
  by (auto simp: option_rel_def)

lemma inv_prod_rel_eq[simp]: "(P ×⇩r Q)¯ = P¯ ×⇩r Q¯"
  by (auto)

lemma inv_sum_rel_eq[simp]: "(⟨P,Q⟩sum_rel)¯ = ⟨P¯,Q¯⟩sum_rel"
  by (auto simp: sum_rel_def)

(* Relation converse distributes over the list relator; both inclusions are
   reduced to list.rel_flip. *)
lemma inv_list_rel_eq[simp]: "(⟨R⟩list_rel)¯ = ⟨R¯⟩list_rel"
  unfolding list_rel_def
  apply safe
  apply (subst list.rel_flip[symmetric])
  apply (simp add: conversep_iff[abs_def])
  apply (subst list.rel_flip[symmetric])
  apply (simp add: conversep_iff[abs_def])
  done

(* Register the converse-distribution equations (and converse of Id) as
   constraint_simps. *)
lemmas [constraint_simps] =
  Relation.converse_Id
  inv_fun_rel_eq
  inv_option_rel_eq
  inv_prod_rel_eq
  inv_sum_rel_eq
  inv_list_rel_eq


subsection ‹Single Valued and Total Relations›

(* TODO: Link to other such theories: Transfer, Autoref *)
(* Left-uniqueness is single-valuedness of the converse; left/right totality are
   stated via Domain/Range; right-uniqueness is an input abbreviation for single_valued. *)
definition "IS_LEFT_UNIQUE R ≡ single_valued (R¯)"
definition "IS_LEFT_TOTAL R ≡ Domain R = UNIV"
definition "IS_RIGHT_TOTAL R ≡ Range R = UNIV"
abbreviation (input) "IS_RIGHT_UNIQUE ≡ single_valued"

(* Destruction rules: two elements related to the same partner are equal. *)
lemmas IS_RIGHT_UNIQUED = single_valuedD
lemma IS_LEFT_UNIQUED: "⟦IS_LEFT_UNIQUE r; (y, x) ∈ r; (z, x) ∈ r⟧ ⟹ y = z"
  by (auto simp: IS_LEFT_UNIQUE_def dest: single_valuedD)

(* Translates the set-based uniqueness/totality properties of R into the
   predicate-based left_unique/right_unique/left_total/right_total of rel2p R. *)
lemma prop2p:
  "IS_LEFT_UNIQUE R = left_unique (rel2p R)"
  "IS_RIGHT_UNIQUE R = right_unique (rel2p R)"
  "right_unique (rel2p (R¯)) = left_unique (rel2p R)"
  "IS_LEFT_TOTAL R = left_total (rel2p R)"
  "IS_RIGHT_TOTAL R = right_total (rel2p R)"
  by (auto 
    simp: IS_LEFT_UNIQUE_def left_unique_def single_valued_def
    simp: right_unique_def
    simp: IS_LEFT_TOTAL_def left_total_def
    simp: IS_RIGHT_TOTAL_def right_total_def
    simp: rel2p_def
    )

(* Converse direction of prop2p: predicate-based properties of P expressed on
   p2rel P, plus bi_unique split into left and right uniqueness. *)
lemma p2prop:
  "left_unique P = IS_LEFT_UNIQUE (p2rel P)"
  "right_unique P = IS_RIGHT_UNIQUE (p2rel P)"
  "left_total P = IS_LEFT_TOTAL (p2rel P)"
  "right_total P = IS_RIGHT_TOTAL (p2rel P)"
  "bi_unique P ⟷ left_unique P ∧ right_unique P"
  by (auto 
    simp: IS_LEFT_UNIQUE_def left_unique_def single_valued_def
    simp: right_unique_def bi_unique_alt_def
    simp: IS_LEFT_TOTAL_def left_total_def
    simp: IS_RIGHT_TOTAL_def right_total_def
    simp: p2rel_def
    )

(* Register existing single-valuedness lemmas for Id and the product, list,
   option and sum relators as safe constraint rules. *)
lemmas [safe_constraint_rules] = 
  single_valued_Id  
  prod_rel_sv 
  list_rel_sv 
  option_rel_sv 
  sum_rel_sv

(* Left-uniqueness of Id and its preservation by the product, sum, option and
   list relators. *)
lemma [safe_constraint_rules]:
  "IS_LEFT_UNIQUE Id"
  "IS_LEFT_UNIQUE R1 ⟹ IS_LEFT_UNIQUE R2 ⟹ IS_LEFT_UNIQUE (R1×⇩rR2)"
  "IS_LEFT_UNIQUE R1 ⟹ IS_LEFT_UNIQUE R2 ⟹ IS_LEFT_UNIQUE (⟨R1,R2⟩sum_rel)"
  "IS_LEFT_UNIQUE R ⟹ IS_LEFT_UNIQUE (⟨R⟩option_rel)"
  "IS_LEFT_UNIQUE R ⟹ IS_LEFT_UNIQUE (⟨R⟩list_rel)"
  by (auto simp: IS_LEFT_UNIQUE_def prod_rel_sv sum_rel_sv option_rel_sv list_rel_sv)

(* Pointwise characterisations of left and right totality. *)
lemma IS_LEFT_TOTAL_alt: "IS_LEFT_TOTAL R ⟷ (∀x. ∃y. (x,y)∈R)"
  by (auto simp: IS_LEFT_TOTAL_def)

lemma IS_RIGHT_TOTAL_alt: "IS_RIGHT_TOTAL R ⟷ (∀x. ∃y. (y,x)∈R)"
  by (auto simp: IS_RIGHT_TOTAL_def)

(* Left totality of Id and its preservation by the product, sum and option relators.
   The two remaining goals after auto are closed by case distinction on the element. *)
lemma [safe_constraint_rules]:
  "IS_LEFT_TOTAL Id"
  "IS_LEFT_TOTAL R1 ⟹ IS_LEFT_TOTAL R2 ⟹ IS_LEFT_TOTAL (R1×⇩rR2)"
  "IS_LEFT_TOTAL R1 ⟹ IS_LEFT_TOTAL R2 ⟹ IS_LEFT_TOTAL (⟨R1,R2⟩sum_rel)"
  "IS_LEFT_TOTAL R ⟹ IS_LEFT_TOTAL (⟨R⟩option_rel)"
  apply (auto simp: IS_LEFT_TOTAL_alt sum_rel_def option_rel_def list_rel_def)
  apply (rename_tac x; case_tac x; auto)
  apply (rename_tac x; case_tac x; auto)
  done

(* Left totality is preserved by the list relator; by induction on the left list. *)
lemma [safe_constraint_rules]: "IS_LEFT_TOTAL R ⟹ IS_LEFT_TOTAL (⟨R⟩list_rel)"
  unfolding IS_LEFT_TOTAL_alt
proof safe
  assume A: "∀x.∃y. (x,y)∈R"
  fix l
  show "∃l'. (l,l')∈⟨R⟩list_rel"
    apply (induction l)
    using A
    by (auto simp: list_rel_split_right_iff)
qed

(* Right totality of Id and its preservation by the product, sum and option relators. *)
lemma [safe_constraint_rules]:
  "IS_RIGHT_TOTAL Id"
  "IS_RIGHT_TOTAL R1 ⟹ IS_RIGHT_TOTAL R2 ⟹ IS_RIGHT_TOTAL (R1×⇩rR2)"
  "IS_RIGHT_TOTAL R1 ⟹ IS_RIGHT_TOTAL R2 ⟹ IS_RIGHT_TOTAL (⟨R1,R2⟩sum_rel)"
  "IS_RIGHT_TOTAL R ⟹ IS_RIGHT_TOTAL (⟨R⟩option_rel)"
  apply (auto simp: IS_RIGHT_TOTAL_alt sum_rel_def option_rel_def) []
  apply (auto simp: IS_RIGHT_TOTAL_alt sum_rel_def option_rel_def) []
  apply (auto simp: IS_RIGHT_TOTAL_alt sum_rel_def option_rel_def) []
  apply (rename_tac x; case_tac x; auto)
  apply (clarsimp simp: IS_RIGHT_TOTAL_alt option_rel_def)
  apply (rename_tac x; case_tac x; auto)
  done

(* Right totality is preserved by the list relator; by induction on the right list. *)
lemma [safe_constraint_rules]: "IS_RIGHT_TOTAL R ⟹ IS_RIGHT_TOTAL (⟨R⟩list_rel)"
  unfolding IS_RIGHT_TOTAL_alt
proof safe
  assume A: "∀x.∃y. (y,x)∈R"
  fix l
  show "∃l'. (l',l)∈⟨R⟩list_rel"
    apply (induction l)
    using A
    by (auto simp: list_rel_split_left_iff)
qed
  
(* Taking the converse swaps the left and right variants of totality and uniqueness. *)
lemma [constraint_simps]:
  "IS_LEFT_TOTAL (R¯) ⟷ IS_RIGHT_TOTAL R"
  "IS_RIGHT_TOTAL (R¯) ⟷ IS_LEFT_TOTAL R"
  "IS_LEFT_UNIQUE (R¯) ⟷ IS_RIGHT_UNIQUE R"
  "IS_RIGHT_UNIQUE (R¯) ⟷ IS_LEFT_UNIQUE R"
  by (auto simp: IS_RIGHT_TOTAL_alt IS_LEFT_TOTAL_alt IS_LEFT_UNIQUE_def)

(* Totality and uniqueness of the function relator A→B, derived from the
   corresponding transfer rules after translating to predicates via prop2p/rel2p. *)
lemma [safe_constraint_rules]:
  "IS_RIGHT_UNIQUE A ⟹ IS_RIGHT_TOTAL B ⟹ IS_RIGHT_TOTAL (A→B)"
  "IS_RIGHT_TOTAL A ⟹ IS_RIGHT_UNIQUE B ⟹ IS_RIGHT_UNIQUE (A→B)"
  "IS_LEFT_UNIQUE A ⟹ IS_LEFT_TOTAL B ⟹ IS_LEFT_TOTAL (A→B)"
  "IS_LEFT_TOTAL A ⟹ IS_LEFT_UNIQUE B ⟹ IS_LEFT_UNIQUE (A→B)"
  apply (simp_all add: prop2p rel2p)
  (*apply transfer_step TODO: Isabelle 2016 *)
  apply (blast intro!: transfer_raw)+
  done

(* Relations below Id are left- and right-unique; identity relations are left-
   and right-total. *)
lemma [constraint_rules]: 
  "IS_BELOW_ID R ⟹ IS_RIGHT_UNIQUE R"
  "IS_BELOW_ID R ⟹ IS_LEFT_UNIQUE R"
  "IS_ID R ⟹ IS_RIGHT_TOTAL R"
  "IS_ID R ⟹ IS_LEFT_TOTAL R"
  by (auto simp: IS_BELOW_ID_def IS_ID_def IS_LEFT_UNIQUE_def IS_RIGHT_TOTAL_def IS_LEFT_TOTAL_def
    intro: single_valuedI)

(* Diagnostic only: prints the rules currently registered as constraint_rules. *)
thm constraint_rules

subsubsection ‹Additional Parametricity Lemmas›
(* TODO: Move. Problem: These depend on IS_LEFT_UNIQUE, which would have to be moved, too! *)

(* Parametricity of distinct for relations that are both left- and right-unique,
   obtained from distinct_transfer. *)
lemma param_distinct[param]: "⟦IS_LEFT_UNIQUE A; IS_RIGHT_UNIQUE A⟧ ⟹ (distinct, distinct) ∈ ⟨A⟩list_rel → bool_rel"
  apply (fold rel2p_def)
  apply (simp add: rel2p)
  apply (rule distinct_transfer)
  apply (simp add: p2prop)
  done

(* Parametricity of relational image (``) when the relation on the first
   component is left- and right-unique. *)
lemma param_Image[param]: 
  assumes "IS_LEFT_UNIQUE A" "IS_RIGHT_UNIQUE A"
  shows "((``), (``)) ∈ ⟨A×⇩rB⟩set_rel → ⟨A⟩set_rel → ⟨B⟩set_rel"
  apply (clarsimp simp: set_rel_def; intro conjI)
  apply (fastforce dest: IS_RIGHT_UNIQUED[OF assms(2)])
  apply (fastforce dest: IS_LEFT_UNIQUED[OF assms(1)])
  done

(* Equality is parametric w.r.t. K exactly when K and its converse are single-valued. *)
lemma pres_eq_iff_svb: "((=),(=))∈K→K→bool_rel ⟷ (single_valued K ∧ single_valued (K¯))"
  apply (safe intro!: single_valuedI)
  apply (metis (full_types) IdD fun_relD1)
  apply (metis (full_types) IdD fun_relD1)
  by (auto dest: single_valuedD)

(* IS_PRES_EQ R: equality is parametric w.r.t. R. It follows from
   single-valuedness of R and of its converse. *)
definition "IS_PRES_EQ R ≡ ((=), (=))∈R→R→bool_rel"
lemma [constraint_rules]: "⟦single_valued R; single_valued (R¯)⟧ ⟹ IS_PRES_EQ R"
  by (simp add: pres_eq_iff_svb IS_PRES_EQ_def)


subsection ‹Bounded Assertions›
(* b_rel R P restricts relation R to pairs whose second (abstract) component satisfies P.
   b_assn A P strengthens assertion A by the pure condition P on its first argument. *)
definition "b_rel R P ≡ R ∩ UNIV×Collect P"
definition "b_assn A P ≡ λx y. A x y * ↑(P x)"

(* A bounded pure assertion is the pure assertion of the bounded relation.
   The reversed equation is registered for import rewriting and normalisation. *)
lemma b_assn_pure_conv[constraint_simps]: "b_assn (pure R) P = pure (b_rel R P)"
  by (auto intro!: ext simp: b_rel_def b_assn_def pure_def)
lemmas [sepref_import_rewrite, sepref_frame_normrel_eqs, fcomp_norm_unfold] 
  = b_assn_pure_conv[symmetric]

(* Nested bounds collapse into the conjunction of the two predicates;
   the trivial bound (λ_. True) can be dropped. *)
lemma b_rel_nesting[simp]: 
  "b_rel (b_rel R P1) P2 = b_rel R (λx. P1 x ∧ P2 x)"
  by (auto simp: b_rel_def)
lemma b_rel_triv[simp]: 
  "b_rel R (λ_. True) = R"
  by (auto simp: b_rel_def)
lemma b_assn_nesting[simp]: 
  "b_assn (b_assn A P1) P2 = b_assn A (λx. P1 x ∧ P2 x)"
  by (auto simp: b_assn_def pure_def intro!: ext)
lemma b_assn_triv[simp]: 
  "b_assn A (λ_. True) = A"
  by (auto simp: b_assn_def pure_def intro!: ext)

lemmas [simp,constraint_simps,sepref_import_rewrite, sepref_frame_normrel_eqs, fcomp_norm_unfold]
  = b_rel_nesting b_assn_nesting

(* Membership in b_rel and application of b_assn, unfolded. *)
lemma b_rel_simp[simp]: "(x,y)∈b_rel R P ⟷ (x,y)∈R ∧ P y"
  by (auto simp: b_rel_def)

lemma b_assn_simp[simp]: "b_assn A P x y = A x y * ↑(P x)"
  by (auto simp: b_assn_def)

(* Range of a bounded relation, and rdomp of a bounded assertion. *)
lemma b_rel_Range[simp]: "Range (b_rel R P) = Range R ∩ Collect P" by auto
lemma b_assn_rdom[simp]: "rdomp (b_assn R P) x ⟷ rdomp R x ∧ P x"
  by (auto simp: rdomp_def)


(* Bounding a relation preserves being below Id, left-uniqueness and right-uniqueness. *)
lemma b_rel_below_id[constraint_rules]: 
  "IS_BELOW_ID R ⟹ IS_BELOW_ID (b_rel R P)"
  by (auto simp: IS_BELOW_ID_def)

lemma b_rel_left_unique[constraint_rules]: 
  "IS_LEFT_UNIQUE R ⟹ IS_LEFT_UNIQUE (b_rel R P)"
  by (auto simp: IS_LEFT_UNIQUE_def single_valued_def)
  
lemma b_rel_right_unique[constraint_rules]: 
  "IS_RIGHT_UNIQUE R ⟹ IS_RIGHT_UNIQUE (b_rel R P)"
  by (auto simp: single_valued_def)

― ‹Registered as safe rule, although it may lose information in the 
    odd case that purity depends on the condition.›
lemma b_assn_is_pure[safe_constraint_rules]:
  "is_pure A ⟹ is_pure (b_assn A P)"
  by (auto simp: is_pure_conv b_assn_pure_conv)

― ‹Most general form›
(* Frame matching between bounded assertions: the underlying assertions must match,
   and the target bound P' must follow from the source bound P under the vassn_tag
   facts for both assertions. *)
lemma b_assn_subtyping_match[sepref_frame_match_rules]:
  assumes "hn_ctxt (b_assn A P) x y ⟹⇩t hn_ctxt A' x y"
  assumes "⟦vassn_tag (hn_ctxt A x y); vassn_tag (hn_ctxt A' x y); P x⟧ ⟹ P' x"
  shows "hn_ctxt (b_assn A P) x y ⟹⇩t hn_ctxt (b_assn A' P') x y"
  using assms
  unfolding hn_ctxt_def b_assn_def entailst_def entails_def
  by (fastforce simp: vassn_tag_def mod_star_conv)
  
― ‹Simplified forms:›
(* Special case of b_assn_subtyping_match where both sides share the same
   underlying assertion A. *)
lemma b_assn_subtyping_match_eqA[sepref_frame_match_rules]:
  assumes "⟦vassn_tag (hn_ctxt A x y); P x⟧ ⟹ P' x"
  shows "hn_ctxt (b_assn A P) x y ⟹⇩t hn_ctxt (b_assn A P') x y"
  apply (rule b_assn_subtyping_match)
  subgoal 
    unfolding hn_ctxt_def b_assn_def entailst_def entails_def
    by (fastforce simp: vassn_tag_def mod_star_conv)
  subgoal
    using assms .
  done

(* Matching a bounded assertion against an unbounded target: the bound P x may be
   assumed while matching the underlying assertions. *)
lemma b_assn_subtyping_match_tR[sepref_frame_match_rules]:
  assumes "P x ⟹ hn_ctxt A x y ⟹⇩t hn_ctxt A' x y"
  shows "hn_ctxt (b_assn A P) x y ⟹⇩t hn_ctxt A' x y"
  using assms
  unfolding hn_ctxt_def b_assn_def entailst_def entails_def
  by (fastforce simp: vassn_tag_def mod_star_conv)

(* Matching an unbounded assertion against a bounded target: the target bound P'
   must be established from the vassn_tag fact of the source. *)
lemma b_assn_subtyping_match_tL[sepref_frame_match_rules]:
  assumes "hn_ctxt A x y ⟹⇩t hn_ctxt A' x y"
  assumes "vassn_tag (hn_ctxt A x y) ⟹ P' x"
  shows "hn_ctxt A x y ⟹⇩t hn_ctxt (b_assn A' P') x y"
  using assms
  unfolding hn_ctxt_def b_assn_def entailst_def entails_def
  by (fastforce simp: vassn_tag_def mod_star_conv)


(* A bounded assertion entails its underlying assertion (the bound is dropped). *)
lemma b_assn_subtyping_match_eqA_tR[sepref_frame_match_rules]: 
  "hn_ctxt (b_assn A P) x y ⟹⇩t hn_ctxt A x y"
  unfolding hn_ctxt_def b_assn_def
  by (sep_auto intro!: enttI)

(* Adding a bound P' to assertion A, provided P' follows from the vassn_tag fact for A. *)
lemma b_assn_subtyping_match_eqA_tL[sepref_frame_match_rules]:
  assumes "vassn_tag (hn_ctxt A x y) ⟹ P' x"
  shows "hn_ctxt A x y ⟹⇩t hn_ctxt (b_assn A P') x y"
  using assms
  unfolding hn_ctxt_def b_assn_def entailst_def entails_def
  by (fastforce simp: vassn_tag_def mod_star_conv)

― ‹General form›
(* Merging two bounded assertions: merge the underlying assertions and take the
   disjunction of the two bounds. *)
lemma b_rel_subtyping_merge[sepref_frame_merge_rules]:
  assumes "hn_ctxt A x y ∨⇩A hn_ctxt A' x y ⟹⇩t hn_ctxt Am x y"
  shows "hn_ctxt (b_assn A P) x y ∨⇩A hn_ctxt (b_assn A' P') x y ⟹⇩t hn_ctxt (b_assn Am (λx. P x ∨ P' x)) x y"
  using assms
  unfolding hn_ctxt_def b_assn_def entailst_def entails_def
  by (fastforce simp: vassn_tag_def)
  
― ‹Simplified forms›
(* Instances of b_rel_subtyping_merge: equal underlying assertions (eqA), and
   one side without a bound (tL / tR), obtained by instantiating that bound
   with λ_. True and simplifying. *)
lemma b_rel_subtyping_merge_eqA[sepref_frame_merge_rules]:
  shows "hn_ctxt (b_assn A P) x y ∨⇩A hn_ctxt (b_assn A P') x y ⟹⇩t hn_ctxt (b_assn A (λx. P x ∨ P' x)) x y"
  apply (rule b_rel_subtyping_merge)
  by simp

lemma b_rel_subtyping_merge_tL[sepref_frame_merge_rules]:
  assumes "hn_ctxt A x y ∨⇩A hn_ctxt A' x y ⟹⇩t hn_ctxt Am x y"
  shows "hn_ctxt A x y ∨⇩A hn_ctxt (b_assn A' P') x y ⟹⇩t hn_ctxt Am x y"
  using b_rel_subtyping_merge[of A x y A' Am "λ_. True" P', simplified] assms .

lemma b_rel_subtyping_merge_tR[sepref_frame_merge_rules]:
  assumes "hn_ctxt A x y ∨⇩A hn_ctxt A' x y ⟹⇩t hn_ctxt Am x y"
  shows "hn_ctxt (b_assn A P) x y ∨⇩A hn_ctxt A' x y ⟹⇩t hn_ctxt Am x y"
  using b_rel_subtyping_merge[of A x y A' Am P "λ_. True", simplified] assms .

lemma b_rel_subtyping_merge_eqA_tL[sepref_frame_merge_rules]:
  shows "hn_ctxt A x y ∨⇩A hn_ctxt (b_assn A P') x y ⟹⇩t hn_ctxt A x y"
  using b_rel_subtyping_merge_eqA[of A "λ_. True" x y P', simplified] .

lemma b_rel_subtyping_merge_eqA_tR[sepref_frame_merge_rules]:
  shows "hn_ctxt (b_assn A P) x y ∨⇩A hn_ctxt A x y ⟹⇩t hn_ctxt A x y"
  using b_rel_subtyping_merge_eqA[of A P x y "λ_. True", simplified] .

(* TODO: Combinatorial explosion :( *)
(* Merge rules where at least one side is an invalidated (hn_invalid) context.
   The nine lemmas enumerate the combinations of bounded/unbounded and
   valid/invalid operands; the result is always invalidated. *)
lemma b_assn_invalid_merge1: "hn_invalid (b_assn A P) x y ∨⇩A hn_invalid (b_assn A P') x y
  ⟹⇩t hn_invalid (b_assn A (λx. P x ∨ P' x)) x y"
  by (sep_auto simp: hn_ctxt_def invalid_assn_def entailst_def)

lemma b_assn_invalid_merge2: "hn_invalid (b_assn A P) x y ∨⇩A hn_invalid A x y
  ⟹⇩t hn_invalid A x y"
  by (sep_auto simp: hn_ctxt_def invalid_assn_def entailst_def)
lemma b_assn_invalid_merge3: "hn_invalid A x y ∨⇩A hn_invalid (b_assn A P) x y
  ⟹⇩t hn_invalid A x y"
  by (sep_auto simp: hn_ctxt_def invalid_assn_def entailst_def)

lemma b_assn_invalid_merge4: "hn_invalid (b_assn A P) x y ∨⇩A hn_ctxt (b_assn A P') x y
  ⟹⇩t hn_invalid (b_assn A (λx. P x ∨ P' x)) x y"
  by (sep_auto simp: hn_ctxt_def invalid_assn_def entailst_def)
lemma b_assn_invalid_merge5: "hn_ctxt (b_assn A P') x y ∨⇩A hn_invalid (b_assn A P) x y
  ⟹⇩t hn_invalid (b_assn A (λx. P x ∨ P' x)) x y"
  by (sep_auto simp: hn_ctxt_def invalid_assn_def entailst_def)

lemma b_assn_invalid_merge6: "hn_invalid (b_assn A P) x y ∨⇩A hn_ctxt A x y
  ⟹⇩t hn_invalid A x y"
  by (sep_auto simp: hn_ctxt_def invalid_assn_def entailst_def)
lemma b_assn_invalid_merge7: "hn_ctxt A x y ∨⇩A hn_invalid (b_assn A P) x y
  ⟹⇩t hn_invalid A x y"
  by (sep_auto simp: hn_ctxt_def invalid_assn_def entailst_def)

lemma b_assn_invalid_merge8: "hn_ctxt (b_assn A P) x y ∨⇩A hn_invalid A x y
  ⟹⇩t hn_invalid A x y"
  by (sep_auto simp: hn_ctxt_def invalid_assn_def entailst_def)
lemma b_assn_invalid_merge9: "hn_invalid A x y ∨⇩A hn_ctxt (b_assn A P) x y
  ⟹⇩t hn_invalid A x y"
  by (sep_auto simp: hn_ctxt_def invalid_assn_def entailst_def)

lemmas b_assn_invalid_merge[sepref_frame_merge_rules] = 
  b_assn_invalid_merge1
  b_assn_invalid_merge2
  b_assn_invalid_merge3
  b_assn_invalid_merge4
  b_assn_invalid_merge5
  b_assn_invalid_merge6
  b_assn_invalid_merge7
  b_assn_invalid_merge8
  b_assn_invalid_merge9




(*
lemma list_rel_b_id: "∀x∈set l. B x ⟹ (l,l)∈⟨b_rel B⟩list_rel"
  by (induction l) auto
*)


abbreviation nbn_rel :: "nat ⇒ (nat × nat) set" 
  ― ‹Natural numbers with upper bound.›
  where "nbn_rel n ≡ b_rel nat_rel (λx::nat. x<n)"  

abbreviation nbn_assn :: "nat ⇒ nat ⇒ nat ⇒ assn" 
  ― ‹Natural numbers with upper bound.›
  where "nbn_assn n ≡ b_assn nat_assn (λx::nat. x<n)"

(*
subsection ‹Bounded Identity Relations›
definition "b_rel B ≡ {(x,x) | x. B x}"

lemma b_rel_simp[simp]: "(x,y)∈b_rel B ⟷ x=y ∧ B y"
  by (auto simp: b_rel_def)

lemma b_rel_Range[simp]: "Range (b_rel B) = Collect B" by auto

lemma b_rel_below_id[safe_constraint_rules]: "IS_BELOW_ID (b_rel B)"
  by (auto simp: IS_BELOW_ID_def)

lemma list_rel_b_id: "∀x∈set l. B x ⟹ (l,l)∈⟨b_rel B⟩list_rel"
  by (induction l) auto

lemma b_rel_subtyping_match[sepref_frame_match_rules]:
  "P x ⟹ hn_val Id x y ⟹t hn_val (b_rel P) x y"
  "⟦P1 x ⟹ P2 x⟧ ⟹ hn_val (b_rel P1) x y ⟹t hn_val (b_rel P2) x y"
  "hn_val (b_rel P) x y ⟹t hn_val Id x y"
  by (auto simp: hn_ctxt_def pure_def intro: enttI)

lemma b_rel_subtyping_merge[sepref_frame_merge_rules]:
  "hn_val Id x y ∨A hn_val (b_rel P) x y ⟹t hn_val Id x y"
  "hn_val (b_rel P) x y ∨A hn_val Id x y ⟹t hn_val Id x y"
  "hn_val (b_rel P1) x y ∨A hn_val (b_rel P2) x y ⟹t hn_val (b_rel (λx. P1 x ∨ P2 x)) x y"
  by (auto simp: hn_ctxt_def pure_def intro: enttI)


abbreviation nbn_rel :: "nat ⇒ (nat × nat) set" 
  -- ‹Natural numbers with upper bound.›
  where "nbn_rel n ≡ b_rel (λx::nat. x<n)"  


*)


subsection ‹Tool Setup›
(* Register the relation/assertion properties defined in this theory as
   sepref_relprops, by instantiating sepref_relpropI for each of them. *)
lemmas [sepref_relprops] = 
  sepref_relpropI[of IS_LEFT_UNIQUE]
  sepref_relpropI[of IS_RIGHT_UNIQUE]
  sepref_relpropI[of IS_LEFT_TOTAL]
  sepref_relpropI[of IS_RIGHT_TOTAL]
  sepref_relpropI[of is_pure]
  sepref_relpropI[of "IS_PURE Φ" for Φ]
  sepref_relpropI[of IS_ID]
  sepref_relpropI[of IS_BELOW_ID]
 


(* Implications between IS_PURE constraints: identity implies below-identity and
   totality; below-identity implies left- and right-uniqueness. *)
lemma [sepref_relprops_simps]:
  "CONSTRAINT (IS_PURE IS_ID) A ⟹ CONSTRAINT (IS_PURE IS_BELOW_ID) A"
  "CONSTRAINT (IS_PURE IS_ID) A ⟹ CONSTRAINT (IS_PURE IS_LEFT_TOTAL) A"
  "CONSTRAINT (IS_PURE IS_ID) A ⟹ CONSTRAINT (IS_PURE IS_RIGHT_TOTAL) A"
  "CONSTRAINT (IS_PURE IS_BELOW_ID) A ⟹ CONSTRAINT (IS_PURE IS_LEFT_UNIQUE) A"
  "CONSTRAINT (IS_PURE IS_BELOW_ID) A ⟹ CONSTRAINT (IS_PURE IS_RIGHT_UNIQUE) A"
  by (auto 
    simp: IS_ID_def IS_BELOW_ID_def IS_PURE_def IS_LEFT_UNIQUE_def
    simp: IS_LEFT_TOTAL_def IS_RIGHT_TOTAL_def
    simp: single_valued_below_Id)

(* Add True_implies_equals to sepref_relprops_simps. *)
declare True_implies_equals[sepref_relprops_simps]

(* Transformation rule folding single_valued (R¯) into IS_LEFT_UNIQUE R. *)
lemma [sepref_relprops_transform]: "single_valued (R¯) = IS_LEFT_UNIQUE R"
  by (auto simp: IS_LEFT_UNIQUE_def)


subsection ‹HOL Combinators›
(* Combinator rule for If: extract the boolean condition from the context (P),
   refine both branches under the respective assumption on a (RT, RE), and merge
   the two resulting contexts (IMP). Proved from hnr_If. *)
lemma hn_if[sepref_comb_rules]:
  assumes P: "Γ ⟹⇩t Γ1 * hn_val bool_rel a a'"
  assumes RT: "a ⟹ hn_refine (Γ1 * hn_val bool_rel a a') b' Γ2b R b"
  assumes RE: "¬a ⟹ hn_refine (Γ1 * hn_val bool_rel a a') c' Γ2c R c"
  assumes IMP: "TERM If ⟹ Γ2b ∨⇩A Γ2c ⟹⇩t Γ'"
  shows "hn_refine Γ (if a' then b' else c') Γ' R (If$a$b$c)"
  using P RT RE IMP[OF TERMI]
  unfolding APP_def PROTECT2_def 
  by (rule hnr_If)

lemmas [sepref_opt_simps] = if_True if_False

(* Combinator rule for Let: extract the bound value from the context (P), refine
   the body for a bound variable equal to v (R), and normalise the resulting
   context so that the bound value appears with assertion R' (F). *)
lemma hn_let[sepref_comb_rules]:
  assumes P: "Γ ⟹⇩t Γ1 * hn_ctxt R v v'"
  assumes R: "⋀x x'. x=v ⟹ hn_refine (Γ1 * hn_ctxt R x x') (f' x') 
    (Γ' x x') R2 (f x)"
  assumes F: "⋀x x'. Γ' x x' ⟹⇩t Γ2 * hn_ctxt R' x x'"
  shows 
    "hn_refine Γ (Let v' f') (Γ2 * hn_ctxt R' v v') R2 (Let$v$(λ⇩2x. f x))"
  apply (rule hn_refine_cons[OF P _ F entt_refl])
  apply (simp)
  apply (rule R)
  by simp

subsection ‹Basic HOL types›

(* Import parametricity facts for default, unit, and the basic bool/nat/int operations. *)
lemma hnr_default[sepref_import_param]: "(default,default)∈Id" by simp

lemma unit_hnr[sepref_import_param]: "((),())∈unit_rel" by auto
    
lemmas [sepref_import_param] = 
  param_bool
  param_nat1
  param_int

(* Operator-identification rules for the numeric literals 0, 1 and numeral on
   nat and int, and for the num constructors. *)
lemmas [id_rules] = 
  itypeI[Pure.of 0 "TYPE (nat)"]
  itypeI[Pure.of 0 "TYPE (int)"]
  itypeI[Pure.of 1 "TYPE (nat)"]
  itypeI[Pure.of 1 "TYPE (int)"]
  itypeI[Pure.of numeral "TYPE (num ⇒ nat)"]
  itypeI[Pure.of numeral "TYPE (num ⇒ int)"]
  itype_self[of num.One]
  itype_self[of num.Bit0]
  itype_self[of num.Bit1]

(* Parametricity (w.r.t. identity relations) of min/max on nat and int, unary
   minus on int, and the conversions nat and int. *)
lemma param_min_nat[param,sepref_import_param]: "(min,min)∈nat_rel → nat_rel → nat_rel" by auto
lemma param_max_nat[param,sepref_import_param]: "(max,max)∈nat_rel → nat_rel → nat_rel" by auto

lemma param_min_int[param,sepref_import_param]: "(min,min)∈int_rel → int_rel → int_rel" by auto
lemma param_max_int[param,sepref_import_param]: "(max,max)∈int_rel → int_rel → int_rel" by auto

lemma uminus_hnr[sepref_import_param]: "(uminus,uminus)∈int_rel → int_rel" by auto
    
lemma nat_param[param,sepref_import_param]: "(nat,nat) ∈ int_rel → nat_rel" by auto
lemma int_param[param,sepref_import_param]: "(int,int) ∈ nat_rel → int_rel" by auto
      
      
      
subsection "Product"


lemmas [sepref_import_rewrite, sepref_frame_normrel_eqs, fcomp_norm_unfold] = prod_assn_pure_conv[symmetric]

(* Preciseness is preserved by the product assertion; each component is handled
   via prec_frame plus frame inference. *)
lemma prod_assn_precise[constraint_rules]: 
  "precise P1 ⟹ precise P2 ⟹ precise (prod_assn P1 P2)"
  apply rule
  apply (clarsimp simp: prod_assn_def star_assoc)
  apply safe
  apply (erule (1) prec_frame) apply frame_inference+
  apply (erule (1) prec_frame) apply frame_inference+
  done

(* Same statement as prod_assn_precise, with the original structured proof:
   the heap assertion is re-associated so that preciseD applies to each component. *)
lemma  
  "precise P1 ⟹ precise P2 ⟹ precise (prod_assn P1 P2)" ― ‹Original proof›
  apply rule
  apply (clarsimp simp: prod_assn_def)
proof (rule conjI)
  fix F F' h as a b a' b' ap bp
  assume P1: "precise P1" and P2: "precise P2"
  assume F: "(h, as) ⊨ P1 a ap * P2 b bp * F ∧⇩A P1 a' ap * P2 b' bp * F'"

  from F have "(h, as) ⊨ P1 a ap * (P2 b bp * F) ∧⇩A P1 a' ap * (P2 b' bp * F')"
    by (simp only: mult.assoc)
  with preciseD[OF P1] show "a=a'" .
  from F have "(h, as) ⊨ P2 b bp * (P1 a ap * F) ∧⇩A P2 b' bp * (P1 a' ap * F')"
    by (simp only: mult.assoc[where 'a=assn] mult.commute[where 'a=assn] mult.left_commute[where 'a=assn])
  with preciseD[OF P2] show "b=b'" .
qed

(* TODO Add corresponding rules for other types and add to datatype snippet *)
(* Interface type of a product assertion is the product of the component
   interface types. *)
lemma intf_of_prod_assn[intf_of_assn]:
  assumes "intf_of_assn A TYPE('a)" "intf_of_assn B TYPE('b)"
  shows "intf_of_assn (prod_assn A B) TYPE('a * 'b)"
by simp

(* The product of two pure assertions is pure: the witnesses P1' and P2' of the
   components are combined by conjunction. *)
lemma pure_prod[constraint_rules]: 
  assumes P1: "is_pure P1" and P2: "is_pure P2"
  shows "is_pure (prod_assn P1 P2)"
proof -
  from P1 obtain P1' where P1': "⋀x x'. P1 x x' = ↑(P1' x x')"
    using is_pureE by blast
  from P2 obtain P2' where P2': "⋀x x'. P2 x x' = ↑(P2' x x')"
    using is_pureE by blast

  show ?thesis proof
    fix x x'
    show "prod_assn P1 P2 x x' =
         ↑(case (x, x') of ((a1, a2), c1, c2) ⇒ P1' a1 c1 ∧ P2' a2 c2)"
      unfolding prod_assn_def
      apply (simp add: P1' P2' split: prod.split)
      done
  qed
qed

(* Frame matching for product assertions is componentwise. *)
lemma prod_frame_match[sepref_frame_match_rules]:
  assumes "hn_ctxt A (fst x) (fst y) ⟹⇩t hn_ctxt A' (fst x) (fst y)"
  assumes "hn_ctxt B (snd x) (snd y) ⟹⇩t hn_ctxt B' (snd x) (snd y)"
  shows "hn_ctxt (prod_assn A B) x y ⟹⇩t hn_ctxt (prod_assn A' B') x y"
  apply (cases x; cases y; simp)
  apply (simp add: hn_ctxt_def)
  apply (rule entt_star_mono)
  using assms apply (auto simp: hn_ctxt_def)
  done

(* Frame merging for product assertions is componentwise; reduced to
   prod_frame_match on each disjunct. *)
lemma prod_frame_merge[sepref_frame_merge_rules]:   
  assumes "hn_ctxt A (fst x) (fst y) ∨⇩A hn_ctxt A' (fst x) (fst y) ⟹⇩t hn_ctxt Am (fst x) (fst y)"
  assumes "hn_ctxt B (snd x) (snd y) ∨⇩A hn_ctxt B' (snd x) (snd y) ⟹⇩t hn_ctxt Bm (snd x) (snd y)"
  shows "hn_ctxt (prod_assn A B) x y ∨⇩A hn_ctxt (prod_assn A' B') x y ⟹⇩t hn_ctxt (prod_assn Am Bm) x y"
  by (blast intro: entt_disjE prod_frame_match 
    entt_disjD1[OF assms(1)] entt_disjD2[OF assms(1)]
    entt_disjD1[OF assms(2)] entt_disjD2[OF assms(2)])
  
(* An invalidated product entails the product of the invalidated components;
   the derived merge rule is obtained via gen_merge_cons. *)
lemma entt_invalid_prod: "hn_invalid (prod_assn A B) p p' ⟹⇩t hn_ctxt (prod_assn (invalid_assn A) (invalid_assn B)) p p'"
    apply (simp add: hn_ctxt_def invalid_assn_def[abs_def])
    apply (rule enttI)
    apply clarsimp
    apply (cases p; cases p'; auto simp: mod_star_conv pure_def) 
    done

lemmas invalid_prod_merge[sepref_frame_merge_rules] = gen_merge_cons[OF entt_invalid_prod]

(* Lifts an equation about prod_assn to the hn_ctxt-wrapped form. *)
lemma prod_assn_ctxt: "prod_assn A1 A2 x y = z ⟹ hn_ctxt (prod_assn A1 A2) x y = z"
  by (simp add: hn_ctxt_def)

(* Combinator rule for case_prod: extract the pair from the context (FR), then
   refine the body with the components available and the pair itself invalidated (Pair). *)
lemma hn_case_prod'[sepref_prep_comb_rule,sepref_comb_rules]:
  assumes FR: "Γ⟹⇩thn_ctxt (prod_assn P1 P2) p' p * Γ1"
  assumes Pair: "⋀a1 a2 a1' a2'. ⟦p'=(a1',a2')⟧ 
    ⟹ hn_refine (hn_ctxt P1 a1' a1 * hn_ctxt P2 a2' a2 * Γ1 * hn_invalid (prod_assn P1 P2) p' p) (f a1 a2) 
          (hn_ctxt P1' a1' a1 * hn_ctxt P2' a2' a2 * hn_ctxt XX1 p' p * Γ1') R (f' a1' a2')"
  shows "hn_refine Γ (case_prod f p) (hn_ctxt (prod_assn P1' P2') p' p * Γ1')
    R (case_prod$(λ⇩2a b. f' a b)$p')" (is "?G Γ")
    apply1 (rule hn_refine_cons_pre[OF FR])
    apply1 extract_hnr_invalids
    apply1 (cases p; cases p'; simp add: prod_assn_pair_conv[THEN prod_assn_ctxt])
    apply (rule hn_refine_cons[OF _ Pair _ entt_refl])
    applyS (simp add: hn_ctxt_def)
    applyS simp
    applyS (simp add: hn_ctxt_def entt_fr_refl entt_fr_drop)
    done

(* Older variant of the case_prod rule (not registered as a combinator rule):
   the post-context of the body is an arbitrary Γh that is normalised by a
   separate entailment M. *)
lemma hn_case_prod_old:
  assumes P: "Γ⟹⇩tΓ1 * hn_ctxt (prod_assn P1 P2) p' p"
  assumes R: "⋀a1 a2 a1' a2'. ⟦p'=(a1',a2')⟧ 
    ⟹ hn_refine (Γ1 * hn_ctxt P1 a1' a1 * hn_ctxt P2 a2' a2 * hn_invalid (prod_assn P1 P2) p' p) (f a1 a2) 
          (Γh a1 a1' a2 a2') R (f' a1' a2')"
  assumes M: "⋀a1 a1' a2 a2'. Γh a1 a1' a2 a2' 
    ⟹⇩t Γ' * hn_ctxt P1' a1' a1 * hn_ctxt P2' a2' a2 * hn_ctxt Pxx p' p"
  shows "hn_refine Γ (case_prod f p) (Γ' * hn_ctxt (prod_assn P1' P2') p' p)
    R (case_prod$(λ⇩2a b. f' a b)$p')"
  apply1 (cases p; cases p'; simp)  
  apply1 (rule hn_refine_cons_pre[OF P])
  apply (rule hn_refine_preI)
  apply (simp add: hn_ctxt_def assn_aci)
  apply (rule hn_refine_cons[OF _ R])
  apply1 (rule enttI)
  applyS (sep_auto simp add: hn_ctxt_def invalid_assn_def mod_star_conv)

  applyS simp
  apply1 (rule entt_trans[OF M])
  applyS (sep_auto intro!: enttI simp: hn_ctxt_def)

  applyS simp
  done

(* Refinement rule for Pair: returning the concrete pair refines the abstract
   pair w.r.t. prod_assn P1 P2; both argument contexts are invalidated afterwards. *)
lemma hn_Pair[sepref_fr_rules]: "hn_refine 
  (hn_ctxt P1 x1 x1' * hn_ctxt P2 x2 x2')
  (return (x1',x2'))
  (hn_invalid P1 x1 x1' * hn_invalid P2 x2 x2')
  (prod_assn P1 P2)
  (RETURN$(Pair$x1$x2))"
  unfolding hn_refine_def
  apply (sep_auto simp: hn_ctxt_def prod_assn_def)
  apply (rule ent_frame_fwd[OF invalidate_clone'[of P1]], frame_inference)
  apply (rule ent_frame_fwd[OF invalidate_clone'[of P2]], frame_inference)
  apply sep_auto
  done

(* Refinement rules for the projections fst and snd; the pair argument is
   annotated with ⇧d (destroyed). *)
lemma fst_hnr[sepref_fr_rules]: "(return o fst,RETURN o fst) ∈ (prod_assn A B)⇧d →⇩a A"
  by sepref_to_hoare sep_auto
lemma snd_hnr[sepref_fr_rules]: "(return o snd,RETURN o snd) ∈ (prod_assn A B)⇧d →⇩a B"
  by sepref_to_hoare sep_auto


lemmas [constraint_simps] = prod_assn_pure_conv
lemmas [sepref_import_param] = param_prod_swap

(* If a pair is in the rdomp of a product assertion, each component is in the
   rdomp of the corresponding component assertion. *)
lemma rdomp_prodD[dest!]: "rdomp (prod_assn A B) (a,b) ⟹ rdomp A a ∧ rdomp B b"
  unfolding rdomp_def prod_assn_def
  by (sep_auto simp: mod_star_conv)


subsection "Option"
(* Assertion for options: None relates to None with the empty heap, Some a relates
   to Some c via P, and mixed constructor combinations are false. *)
fun option_assn :: "('a ⇒ 'c ⇒ assn) ⇒ 'a option ⇒ 'c option ⇒ assn" where
  "option_assn P None None = emp"
| "option_assn P (Some a) (Some c) = P a c"
| "option_assn _ _ _ = false"

(* If one side is None, option_assn reduces to the pure fact that the other
   side is None as well. *)
lemma option_assn_simps[simp]:
  "option_assn P None v' = ↑(v'=None)"
  "option_assn P v None = ↑(v=None)"
  apply (cases v', simp_all)
  apply (cases v, simp_all)
  done

(* Case-expression form of option_assn. *)
lemma option_assn_alt_def: "option_assn R a b = 
  (case (a,b) of (Some x, Some y) ⇒ R x y
  | (None,None) ⇒ emp
  | _ ⇒ false)"
  by (auto split: option.split)


(* The option assertion over a pure assertion is the pure assertion of the
   option relator. The reversed equation is registered for import rewriting
   and normalisation. *)
lemma option_assn_pure_conv[constraint_simps]: "option_assn (pure R) = pure (⟨R⟩option_rel)"
  apply (intro ext)
  apply (rename_tac a c)
  apply (case_tac "(pure R,a,c)" rule: option_assn.cases)
  by (auto simp: pure_def)

lemmas [sepref_import_rewrite, sepref_frame_normrel_eqs, fcomp_norm_unfold] = option_assn_pure_conv[symmetric]

(* Composition of an option assertion with an option relation can be pushed
   inside option_assn. *)
lemma hr_comp_option_conv[simp, fcomp_norm_unfold]: "
  hr_comp (option_assn R) (⟨R'⟩option_rel) 
  = option_assn (hr_comp R R')"
  unfolding hr_comp_def[abs_def]
  apply (intro ext ent_iffI)
  apply solve_entails
  apply (case_tac "(R,b,c)" rule: option_assn.cases)
  apply clarsimp_all
  
  apply (sep_auto simp: option_assn_alt_def split: option.splits)
  apply (clarsimp simp: option_assn_alt_def split: option.splits; safe)
  apply (sep_auto split: option.splits)
  apply (intro ent_ex_preI) 
  apply (rule ent_ex_postI)
  apply (sep_auto split: option.splits)
  done
      

(* Preciseness is preserved by option_assn; only the Some/Some case needs the
   preciseness of P, the remaining cases are solved by simplification. *)
lemma option_assn_precise[safe_constraint_rules]: 
  assumes "precise P"  
  shows "precise (option_assn P)"
proof
  fix a a' p h F F'
  assume A: "h ⊨ option_assn P a p * F ∧⇩A option_assn P a' p * F'"
  thus "a=a'" proof (cases "(P,a,p)" rule: option_assn.cases)
    case (2 _ av pv) hence [simp]: "a=Some av" "p=Some pv" by simp_all

    from A obtain av' where [simp]: "a'=Some av'" by (cases a', simp_all)

    from A have "h ⊨ P av pv * F ∧⇩A P av' pv * F'" by simp
    with ‹precise P› have "av=av'" by (rule preciseD)
    thus ?thesis by simp
  qed simp_all
qed

(* The option assertion over a pure assertion is pure; the witness is the
   case distinction over both option values. *)
lemma pure_option[safe_constraint_rules]: 
  assumes P: "is_pure P"
  shows "is_pure (option_assn P)"
proof -
  from P obtain P' where P': "⋀x x'. P x x' = ↑(P' x x')"
    using is_pureE by blast

  show ?thesis proof
    fix x x'
    show "option_assn P x x' =
         ↑(case (x, x') of 
             (None,None) ⇒ True | (Some v, Some v') ⇒ P' v v' | _ ⇒ False
           )"
      apply (simp add: P' split: prod.split option.split)
      done
  qed
qed

(* Lifts an equation about option_assn to the hn_ctxt-wrapped form. *)
lemma hn_ctxt_option: "option_assn A x y = z ⟹ hn_ctxt (option_assn A) x y = z"
  by (simp add: hn_ctxt_def)

(* Combinator rule for case_option: extract the option from the context (FR),
   refine the None branch (Rn) and the Some branch (Rs, with the option itself
   invalidated), and merge the remaining contexts of both branches (MERGE1). *)
lemma hn_case_option[sepref_prep_comb_rule, sepref_comb_rules]:
  fixes p p' P
  defines [simp]: "INVE ≡ hn_invalid (option_assn P) p p'"
  assumes FR: "Γ ⟹⇩t hn_ctxt (option_assn P) p p' * F"
  assumes Rn: "p=None ⟹ hn_refine (hn_ctxt (option_assn P) p p' * F) f1' (hn_ctxt XX1 p p' * Γ1') R f1"
  assumes Rs: "⋀x x'. ⟦ p=Some x; p'=Some x' ⟧ ⟹ 
    hn_refine (hn_ctxt P x x' * INVE * F) (f2' x') (hn_ctxt P' x x' * hn_ctxt XX2 p p' * Γ2') R (f2 x)"
  assumes MERGE1: "Γ1' ∨⇩A Γ2' ⟹⇩t Γ'"  
  shows "hn_refine Γ (case_option f1' f2' p') (hn_ctxt (option_assn P') p p' * Γ') R (case_option$f1$(λ⇩2x. f2 x)$p)"
    apply (rule hn_refine_cons_pre[OF FR])
    apply1 extract_hnr_invalids
    apply (cases p; cases p'; simp add: option_assn.simps[THEN hn_ctxt_option])
    subgoal 
      apply (rule hn_refine_cons[OF _ Rn _ entt_refl]; assumption?)
      applyS (simp add: hn_ctxt_def)

      apply (subst mult.commute, rule entt_fr_drop)
      apply (rule entt_trans[OF _ MERGE1])
      apply (simp add: ent_disjI1' ent_disjI2')
    done  

    subgoal
      apply (rule hn_refine_cons[OF _ Rs _ entt_refl]; assumption?)
      applyS (simp add: hn_ctxt_def)
      apply (rule entt_star_mono)
      apply1 (rule entt_fr_drop)
      applyS (simp add: hn_ctxt_def)
      apply1 (rule entt_trans[OF _ MERGE1])
      applyS (simp add: hn_ctxt_def)
    done
    done

(* Refinement rule for None: needs and leaves an empty context. *)
lemma hn_None[sepref_fr_rules]:
  "hn_refine emp (return None) emp (option_assn P) (RETURN$None)"
  by rule sep_auto

(* Refinement rule for Some: the argument context is invalidated afterwards. *)
lemma hn_Some[sepref_fr_rules]: "hn_refine 
  (hn_ctxt P v v')
  (return (Some v'))
  (hn_invalid P v v')
  (option_assn P)
  (RETURN$(Some$v))"
  by rule (sep_auto simp: hn_ctxt_def invalidate_clone')

(* Monadic equality test on options, parameterised by an element equality eq:
   None/None yields True, Some/Some delegates to eq, mixed cases yield False. *)
definition "imp_option_eq eq a b ≡ case (a,b) of 
  (None,None) ⇒ return True
| (Some a, Some b) ⇒ eq a b
| _ ⇒ return False"

(* TODO: This is some kind of generic algorithm! Use GEN_ALGO here, and 
  let GEN_ALGO re-use the registered operator rules *)
(* Refinement rule for equality on options via imp_option_eq: given a refinement
   eq' of element equality (EQ) whose post-context restores both elements (F2),
   both option arguments are preserved in the post-context. *)
lemma option_assn_eq[sepref_comb_rules]:
  fixes a b :: "'a option"
  assumes F1: "Γ ⟹⇩t hn_ctxt (option_assn P) a a' * hn_ctxt (option_assn P) b b' * Γ1"
  assumes EQ: "⋀va va' vb vb'. hn_refine 
    (hn_ctxt P va va' * hn_ctxt P vb vb' * Γ1)
    (eq' va' vb') 
    (Γ' va va' vb vb') 
    bool_assn
    (RETURN$((=) $va$vb))"
  assumes F2: 
    "⋀va va' vb vb'. 
      Γ' va va' vb vb' ⟹⇩t hn_ctxt P va va' * hn_ctxt P vb vb' * Γ1"
  shows "hn_refine 
    Γ 
    (imp_option_eq eq' a' b') 
    (hn_ctxt (option_assn P) a a' * hn_ctxt (option_assn P) b b' * Γ1)
    bool_assn 
    (RETURN$((=) $a$b))"
  apply (rule hn_refine_cons_pre[OF F1])
  unfolding imp_option_eq_def
  apply rule
  apply (simp split: option.split add: hn_ctxt_def, intro impI conjI)

  apply (sep_auto split: option.split simp: hn_ctxt_def pure_def)
  apply (cases a, (sep_auto split: option.split simp: hn_ctxt_def pure_def)+)[]
  apply (cases a, (sep_auto split: option.split simp: hn_ctxt_def pure_def)+)[]
  apply (cases b, (sep_auto split: option.split simp: hn_ctxt_def pure_def)+)[]
  apply (rule cons_post_rule)
  apply (rule hn_refineD[OF EQ[unfolded hn_ctxt_def]])
  apply simp
  apply (rule ent_frame_fwd[OF F2[THEN enttD,unfolded hn_ctxt_def]])
  apply (fr_rot 2)
  apply (fr_rot_rhs 1)
  apply (rule fr_refl)
  apply (rule ent_refl)
  apply (sep_auto simp: pure_def)
  done

(* Pattern rules: an equality test against `None` (on either side of `(=)`)
   is rewritten to `is_None`, for which `hn_is_None` gives a rule that does
   not need an element comparison. *)
lemma [pat_rules]: 
  "(=) $a$None  is_None$a"
  "(=) $None$a  is_None$a"
  apply (rule eq_reflection, simp split: option.split)+
  done

(* Refinement rule for `is_None`: `return (is_None a')` implements
   `RETURN$(is_None$a)` with a `bool_assn` result. The option context is the
   same before and after, i.e. the argument is not invalidated. *)
lemma hn_is_None[sepref_fr_rules]: "hn_refine 
  (hn_ctxt (option_assn P) a a')
  (return (is_None a'))
  (hn_ctxt (option_assn P) a a')
  bool_assn
  (RETURN$(is_None$a))"
  apply rule
  apply (sep_auto split: option.split simp: hn_ctxt_def pure_def)
  done

(* Refinement rule for `the` on options, for an arbitrary assertion `R`:
   under the side-condition on `x` (the abstract option is not `None`),
   `return (the xi)` implements `RETURN$(the$x)` with result assertion `R`.
   The option argument is invalidated afterwards (`hn_invalid`). *)
lemma (in -) sepref_the_complete[sepref_fr_rules]:
  assumes "xNone"
  shows "hn_refine 
    (hn_ctxt (option_assn R) x xi) 
    (return (the xi)) 
    (hn_invalid (option_assn R) x xi)
    (R)
    (RETURN$(the$x))"
    using assms
    (* Case split on the abstract value, then on the concrete value. *)
    apply (cases x)
    apply simp
    apply (cases xi)
    apply (simp add: hn_ctxt_def)
    apply rule
    apply (sep_auto simp: hn_ctxt_def invalidate_clone' vassn_tagI invalid_assn_const)
    done

(* As the sepref_the_complete rule does not work for us 
  --- the assertion ensuring the side-condition gets decoupled from its variable by a copy-operation ---
  we use the following rule that only works for the identity relation *)
(* Variant of the rule for `the` that requires `CONSTRAINT (IS_PURE IS_ID) R`.
   In contrast to `sepref_the_complete`, it has no side-condition on `x` and
   the option context is kept (same `hn_ctxt` before and after).
   Note that this lemma carries no `sepref_fr_rules` attribute here. *)
lemma (in -) sepref_the_id:
  assumes "CONSTRAINT (IS_PURE IS_ID) R"
  shows "hn_refine 
    (hn_ctxt (option_assn R) x xi) 
    (return (the xi)) 
    (hn_ctxt (option_assn R) x xi)
    (R)
    (RETURN$(the$x))"
    using assms 
    (* Unfold the constraint so that `R` becomes a pure identity assertion. *)
    apply (clarsimp simp: IS_PURE_def IS_ID_def hn_ctxt_def is_pure_conv)
    (* Case split on the abstract value, then on the concrete value. *)
    apply (cases x)
    apply simp
    apply (cases xi)
    apply (simp add: hn_ctxt_def invalid_assn_def)
    apply rule apply (sep_auto simp: pure_def)
    apply rule apply (sep_auto)
    apply (simp add: option_assn_pure_conv)
    apply rule apply (sep_auto simp: pure_def invalid_assn_def)
    done


subsection "Lists"

(* Lifts an element assertion `P` to lists, element by element:
     - two empty lists are related by `emp`
     - two non-empty lists are related by `P` on the heads, separately
       conjoined (`*`) with `list_assn P` on the tails
     - every other combination (one list empty, the other not) is `false` *)
fun list_assn :: "('a  'c  assn)  'a list  'c list  assn" where
  "list_assn P [] [] = emp"
| "list_assn P (a#as) (c#cs) = P a c * list_assn P as cs"
| "list_assn _ _ _ = false"

(* Auxiliary simp rules: if one of the two lists is empty, `list_assn`
   reduces to the assertion built from the other list being empty as well.
   Each statement is proved by a case split on the list that is not fixed.
   The statements do not mention `hn_ctxt`, so no unfolding of `hn_ctxt_def`
   is required (the `hn_ctxt` versions are in `list_assn_simps`). *)
lemma list_assn_aux_simps[simp]:
  "list_assn P [] l' = ((l'=[]))"
  "list_assn P l [] = ((l=[]))"
  apply (cases l')
  apply simp
  apply simp
  apply (cases l)
  apply simp
  apply simp
  done

(* `list_assn` distributes over append when the two prefixes `l1` and `l1'`
   have equal length. Proved by simultaneous induction on the prefixes
   (`list_induct2`). *)
lemma list_assn_aux_append[simp]:
  "length l1=length l1'  
    list_assn P (l1@l2) (l1'@l2') 
    = list_assn P l1 l1' * list_assn P l2 l2'"
  apply (induct rule: list_induct2)
  apply simp
  apply (simp add: star_assoc)
  done

(* Under the stated premise relating `length l` and `length li`,
   `list_assn A l li` collapses to `false`.
   Induction on `l` with `li` arbitrary; the Cons case splits on `li`. *)
lemma list_assn_aux_ineq_len: "length l  length li  list_assn A l li = false"
proof (induction l arbitrary: li)
  case (Cons x l li) thus ?case by (cases li; auto)
qed simp

(* Variant of `list_assn_aux_append` whose length assumption is on the
   suffixes `l2` and `l2'` instead of the prefixes.
   Proof: if the prefixes have equal length, `list_assn_aux_append` applies
   directly; otherwise `list_assn_aux_ineq_len` is used. *)
lemma list_assn_aux_append2[simp]:
  assumes "length l2=length l2'"  
  shows "list_assn P (l1@l2) (l1'@l2') 
    = list_assn P l1 l1' * list_assn P l2 l2'"
  apply (cases "length l1 = length l1'")
  apply (erule list_assn_aux_append)
  apply (simp add: list_assn_aux_ineq_len assms)
  done

(* `list_assn` applied to a pure assertion `pure R` is itself a pure
   assertion, built from the list relation over `R`.
   Shown pointwise (`ext`) by induction along the defining equations of
   `list_assn`. *)
lemma list_assn_pure_conv[constraint_simps]: "list_assn (pure R) = pure (Rlist_rel)"
proof (intro ext)
  fix l li
  show "list_assn (pure R) l li = pure (Rlist_rel) l li"
    apply (induction "pure R" l li rule: list_assn.induct)
    by (auto simp: pure_def)
qed

(* The reversed equation (pure list relation to `list_assn` of a pure
   assertion) is registered for import rewriting, frame normalisation and
   fcomp normalisation. *)
lemmas [sepref_import_rewrite, sepref_frame_normrel_eqs, fcomp_norm_unfold] = list_assn_pure_conv[symmetric]


(* Simp rules for `list_assn` wrapped in `hn_ctxt`: they mirror the defining
   equations of `list_assn` and the rules in `list_assn_aux_simps`.
   The proof first removes the `hn_ctxt` wrapper via `hn_ctxt_def`. *)
lemma list_assn_simps[simp]:
  "hn_ctxt (list_assn P) [] l' = ((l'=[]))"
  "hn_ctxt (list_assn P) l [] = ((l=[]))"
  "hn_ctxt (list_assn P) [] [] = emp"
  "hn_ctxt (list_assn P) (a#as) (c#cs) = hn_ctxt P a c * hn_ctxt (list_assn P) as cs"
  "hn_ctxt (list_assn P) (a#as) [] = false"
  "hn_ctxt (list_assn P) [] (c#cs) = false"
  unfolding hn_ctxt_def
  apply (cases l')
  apply simp
  apply simp
  apply (cases l)
  apply simp
  apply simp
  apply simp_all
  done

(* Preciseness is preserved by `list_assn`: if `P` is precise, so is
   `list_assn P`. The proof is by induction on the concrete list `l`,
   with both abstract lists and both frames generalised. *)
lemma list_assn_precise[constraint_rules]: "precise P  precise (list_assn P)"
proof
  fix l1 l2 l h F1 F2
  assume P: "precise P"
  assume "hlist_assn P l1 l * F1 A list_assn P l2 l * F2"
  thus "l1=l2"
  proof (induct l arbitrary: l1 l2 F1 F2)
    case Nil thus ?case by simp
  next
    case (Cons a ls)
    (* Both abstract lists must be non-empty as well. *)
    from Cons obtain a1 ls1 where [simp]: "l1=a1#ls1"
      by (cases l1, simp)
    from Cons obtain a2 ls2 where [simp]: "l2=a2#ls2"
      by (cases l2, simp)
    
    from Cons.prems have M:
      "h  P a1 a * list_assn P ls1 ls * F1 
        A P a2 a * list_assn P ls2 ls * F2" by simp
    (* Heads agree: preciseness of `P`, with the tails moved into the frames. *)
    have "a1=a2"
      apply (rule preciseD[OF P, where a=a1 and a'=a2 and p=a
        and F= "list_assn P ls1 ls * F1" 
        and F'="list_assn P ls2 ls * F2"
        ])
      using M
      by (simp add: star_assoc)
    
    (* Tails agree: induction hypothesis, with the heads moved into the frames. *)
    moreover have "ls1=ls2"
      apply (rule Cons.hyps[where ?F1.0="P a1 a * F1" and ?F2.0="P a2 a * F2"])
      using M
      by (simp only: star_aci)
    ultimately show ?case by simp
  qed
qed
(* Purity is preserved by `list_assn`: if `P` is pure, so is `list_assn P`.
   The pure predicate `P'` underlying `P` is obtained first; `list_assn P`
   is then characterised through `list_all2 P'`. *)
lemma list_assn_pure[constraint_rules]: 
  assumes P: "is_pure P" 
  shows "is_pure (list_assn P)"
proof -
  from P obtain P' where P_eq: "x x'. P x x' = (P' x x')" 
    by (rule is_pureE) blast

  {
    fix l l'
    have "list_assn P l l' = (list_all2 P' l l')"
      by (induct PP l l' rule: list_assn.induct)
         (simp_all add: P_eq)
  } thus ?thesis by rule
qed

(* Monotonicity of `list_assn` w.r.t. entailment: if `P x x'` entails
   `P' x x'` for all elements, then `list_assn P l l'` entails
   `list_assn P' l l'`.
   Proved by induction along the defining equations of `list_assn`, combining
   head and tail with `ent_star_mono`. The statement does not mention
   `hn_ctxt`, so no unfolding of `hn_ctxt_def` is needed. *)
lemma list_assn_mono: 
  "x x'. P x x'AP' x x'  list_assn P l l' A list_assn P' l l'"
  apply (induct P l l' rule: list_assn.induct)
  by (auto intro: ent_star_mono)

(* Counterpart of `list_assn_mono` for the entailment used by the
   `entt_*` rules: an element-wise entailment from `P` to `P'` lifts to
   `list_assn`.
   Proved by induction along the defining equations of `list_assn`, combining
   head and tail with `entt_star_mono`. The statement does not mention
   `hn_ctxt`, so no unfolding of `hn_ctxt_def` is needed. *)
lemma list_assn_monot: 
  "x x'. P x x'tP' x x'  list_assn P l l' t list_assn P' l l'"
  apply (induct P l l' rule: list_assn.induct)
  by (auto intro: entt_star_mono)

(* Frame-matching congruence for lists: to match `hn_ctxt (list_assn A) l l'`
   against `hn_ctxt (list_assn A') l l'`, it suffices to match the element
   contexts. The premise is only needed for `x` in `set l` and `x'` in
   `set l'`. *)
lemma list_match_cong[sepref_frame_match_rules]: 
  "x x'. xset l; x'set l'  hn_ctxt A x x' t hn_ctxt A' x x'   hn_ctxt (list_assn A) l l' t hn_ctxt (list_assn A') l l'"
  unfolding hn_ctxt_def
  by (induct A l l' rule: list_assn.induct) (simp_all add: entt_star_mono)

(* Frame-merging congruence for lists: if the element contexts for `A` and
   `A'` can be merged into `Am` (for elements of `l` and `l'`), then the list
   contexts for `list_assn A` and `list_assn A'` can be merged into the one
   for `list_assn Am`.
   The proof splits the disjunction (`entt_disjE`) and applies
   `list_match_cong` to each side. *)
lemma list_merge_cong[sepref_frame_merge_rules]:
  assumes "x x'. xset l; x'set l'  hn_ctxt A x x' A hn_ctxt A' x x' t hn_ctxt Am x x'"
  shows "hn_ctxt (list_assn A) l l' A hn_ctxt (list_assn A') l l' t hn_ctxt (list_assn Am) l l'"
  apply (blast intro: entt_disjE list_match_cong entt_disjD1[OF assms] entt_disjD2[OF assms])
  done
  
(* An invalidated non-empty list assertion can be split into the invalidated
   head assertion and the invalidated assertion for the tail. *)
lemma invalid_list_split: 
  "invalid_assn (list_assn A) (x#xs) (y#ys) t invalid_assn A x y * invalid_assn (list_assn A) xs ys"
  by (fastforce simp: invalid_assn_def intro!: enttI simp: mod_star_conv)

(* An invalidated list context can be turned into a list context over
   invalidated elements: `hn_invalid (list_assn A)` yields
   `hn_ctxt (list_assn (invalid_assn A))`.
   Induction along the defining equations of `list_assn`. *)
lemma entt_invalid_list: "hn_invalid (list_assn A) l l' t hn_ctxt (list_assn (invalid_assn A)) l l'"
  apply (induct A l l' rule: list_assn.induct)
  applyS simp

  (* Cons/Cons case: split off the head with `invalid_list_split`, then treat
     head and tail separately; the tail uses the induction hypothesis. *)
  subgoal
    apply1 (simp add: hn_ctxt_def cong del: invalid_assn_cong)
    apply1 (rule entt_trans[OF invalid_list_split])
    apply (rule entt_star_mono)
      applyS simp

      apply (rule entt_trans)
        applyS assumption
        applyS simp
    done
    
  (* Remaining cases of the induction rule. *)
  applyS (simp add: hn_ctxt_def invalid_assn_def) 
  applyS (simp add: hn_ctxt_def invalid_assn_def) 
  done

(* Merge rule obtained by instantiating `gen_merge_cons` with
   `entt_invalid_list`. *)
lemmas invalid_list_merge[sepref_frame_merge_rules] = gen_merge_cons[OF entt_invalid_list]


(* Composing `list_assn A` with a list relation over `B` (via `hr_comp`)
   is the same as `list_assn` over the composed element assertion
   `hr_comp A B`. *)
lemma list_assn_comp[fcomp_norm_unfold]: "hr_comp (list_assn A) (Blist_rel) = list_assn (hr_comp A B)"
proof (intro ext)  
  (* Auxiliary step for two non-empty lists: the composition splits into the
     composition on the heads and the composition on the tails. *)
  { fix x l y m
    have "hr_comp (list_assn A) (Blist_rel) (x # l) (y # m) = 
      hr_comp A B x y * hr_comp (list_assn A) (Blist_rel) l m"
      by (sep_auto 
        simp: hr_comp_def list_rel_split_left_iff
        intro!: ent_ex_preI ent_iffI) (* TODO: ent_ex_preI should be applied by default, before ent_ex_postI!*)
  } note aux = this

  fix l li
  show "hr_comp (list_assn A) (Blist_rel) l li = list_assn (hr_comp A B) l li"
    (* Induction on `l` with a case split on `li`; the cases solved directly by
       `sep_auto` are discharged first, the rest use `aux`. *)
    apply (induction l arbitrary: li; case_tac li; intro ent_iffI)
    apply (sep_auto simp: hr_comp_def; fail)+
    by (simp_all add: aux)
qed  

(* Lifts an equation about `A x y` to the same equation about
   `hn_ctxt A x y`. *)
lemma hn_ctxt_eq: "A x y = z  hn_ctxt A x y = z" by (simp add: hn_ctxt_def)

(* Instance of `hn_ctxt_eq` for list assertions; used as
   `list_assn.simps[THEN hn_ctxt_list]` in `hn_case_list`. *)
lemmas hn_ctxt_list = hn_ctxt_eq[of "list_assn A" for A]

(* Combinator rule for `case_list`.
     INVE   : abbreviation for the invalidated original list context.
     FR     : the context Γ provides the list `p`/`p'` plus a frame `F`.
     Rn     : refinement of the Nil branch, under the assumption `p=[]`.
     Rs     : refinement of the Cons branch; it receives the head and tail
              contexts, the invalidated original list (`INVE`) and the frame.
     MERGE1 : merges the element assertions `P1'` (head) and `P2'` (tail
              elements) left by the Cons branch into `P'`.
     MERGE2 : merges the remaining contexts Γ1' and Γ2' of both branches
              into Γ'.
   The conclusion restores a list context `list_assn P'` for `p`/`p'`. *)
lemma hn_case_list[sepref_prep_comb_rule, sepref_comb_rules]:
  fixes p p' P
  defines [simp]: "INVE  hn_invalid (list_assn P) p p'"
  assumes FR: "Γ t hn_ctxt (list_assn P) p p' * F"
  assumes Rn: "p=[]  hn_refine (hn_ctxt (list_assn P) p p' * F) f1' (hn_ctxt XX1 p p' * Γ1') R f1"
  assumes Rs: "x l x' l'.  p=x#l; p'=x'#l'   
    hn_refine (hn_ctxt P x x' * hn_ctxt (list_assn P) l l' * INVE * F) (f2' x' l') (hn_ctxt P1' x x' * hn_ctxt (list_assn P2') l l' * hn_ctxt XX2 p p' * Γ2') R (f2 x l)"
  assumes MERGE1[unfolded hn_ctxt_def]: "x x'. hn_ctxt P1' x x' A hn_ctxt P2' x x' t hn_ctxt P' x x'"  
  assumes MERGE2: "Γ1' A Γ2' t Γ'"  
  shows "hn_refine Γ (case_list f1' f2' p') (hn_ctxt (list_assn P') p p' * Γ') R (case_list$f1$(λ2x l. f2 x l)$p)"
    apply (rule hn_refine_cons_pre[OF FR])
    apply1 extract_hnr_invalids
    (* Case split on both lists; mismatching shapes are removed by simp. *)
    apply (cases p; cases p'; simp add: list_assn.simps[THEN hn_ctxt_list])
    (* Nil case: use Rn, drop the `XX1` context, merge the rest via MERGE2. *)
    subgoal 
      apply (rule hn_refine_cons[OF _ Rn _ entt_refl]; assumption?)
      applyS (simp add: hn_ctxt_def)

      apply (subst mult.commute, rule entt_fr_drop)
      apply (rule entt_trans[OF _ MERGE2])
      apply (simp add: ent_disjI1' ent_disjI2')
    done  

    (* Cons case: use Rs, drop the `XX2` context, then merge head, tail and
       the remaining context separately. *)
    subgoal
      apply (rule hn_refine_cons[OF _ Rs _ entt_refl]; assumption?)
      applyS (simp add: hn_ctxt_def)
      apply (rule entt_star_mono)
      apply1 (rule entt_fr_drop)
      apply (rule entt_star_mono)

      (* Head element: merge via MERGE1. *)
      apply1 (simp add: hn_ctxt_def)
      apply1 (rule entt_trans[OF _ MERGE1])
      applyS (simp)

      (* Tail: lift MERGE1 to lists with `list_assn_monot`. *)
      apply1 (simp add: hn_ctxt_def)
      apply (rule list_assn_monot)
      apply1 (rule entt_trans[OF _ MERGE1])
      applyS (simp)

      (* Remaining context: merge via MERGE2. *)
      apply1 (rule entt_trans[OF _ MERGE2])
      applyS (simp)
    done
    done

(* Refinement rule for the empty list: `return []` implements `RETURN$[]`
   with result assertion `list_assn P`; the context is `emp` before and
   after. *)
lemma hn_Nil[sepref_fr_rules]: 
  "hn_refine emp (return []) emp (list_assn P) (RETURN$[])"
  unfolding hn_refine_def
  by sep_auto

(* Refinement rule for list cons: `return (x'#xs')` implements
   `RETURN$((#) $x$xs)` with result assertion `list_assn P`.
   Both arguments (head and tail) are invalidated afterwards; the proof
   applies `invalidate_clone'` once for the head and once for the tail. *)
lemma hn_Cons[sepref_fr_rules]: "hn_refine (hn_ctxt P x x' * hn_ctxt (list_assn P) xs xs') 
  (return (x'#xs')) (hn_invalid P x x' * hn_invalid (list_assn P) xs xs') (list_assn P)
  (RETURN$((#) $x$xs))"
  unfolding hn_refine_def
  apply (sep_auto simp: hn_ctxt_def)
  apply (rule ent_frame_fwd[OF invalidate_clone'[of P]], frame_inference)
  apply (rule ent_frame_fwd[OF invalidate_clone'[of "list_assn P"]], frame_inference)
  apply solve_entails
  done

(* `list_assn P l l'` is unchanged when the assertion built from
   `length l = length l'` is conjoined to it; this exposes the length
   equality to later reasoning (see `list_assn_aux_eqlen_simp`). *)
lemma list_assn_aux_len: 
  "list_assn P l l' = list_assn P l l' * (length l = length l')"
  apply (induct PP l l' rule: list_assn.induct)
  apply simp_all
  subgoal for a as c cs
    (* Rewrite the tail assertion with the induction hypothesis. *)
    by (erule_tac t="list_assn P as cs" in subst[OF sym]) simp
  done

(* Extracts `length l' = length l` from a list assertion, in two forms:
   from `vassn_tag (list_assn P l l')` and from a heap `h` paired with
   `list_assn P l l'`. Both follow from `list_assn_aux_len`. *)
lemma list_assn_aux_eqlen_simp: 
  "vassn_tag (list_assn P l l')  length l' = length l"
  "h  (list_assn P l l')  length l' = length l"
  apply (subst (asm) list_assn_aux_len; auto simp: vassn_tag_def)+
  done

(* Refinement rule for list append: `return (l1'@l2')` implements
   `RETURN$((@) $l1$l2)` with result assertion `list_assn P`.
   Both argument lists are invalidated afterwards.
   The proof uses `list_assn_aux_len` to obtain a length equality before
   the appended assertion is handled by `sep_auto`. *)
lemma hn_append[sepref_fr_rules]: "hn_refine (hn_ctxt (list_assn P) l1 l1' * hn_ctxt (list_assn P) l2 l2')
  (return (l1'@l2')) (hn_invalid (list_assn P) l1 l1' * hn_invalid (list_assn P) l2 l2') (list_assn P)
  (RETURN$((@) $l1$l2))"
  apply rule
  apply (sep_auto simp: hn_ctxt_def)
  apply (subst list_assn_aux_len)
  apply (sep_auto)
  apply (rule ent_frame_fwd[OF invalidate_clone'[of "list_assn P"]], frame_inference)
  apply (rule ent_frame_fwd[OF invalidate_clone'[of "list_assn P"]], frame_inference)
  apply solve_entails
  done

(* If the first list is a cons `a#l`, the second list `m` must be of the form
   `b#m'`: the assertion is rewritten to an existential over `b` and `m'`. *)
lemma list_assn_aux_cons_conv1:
  "list_assn R (a#l) m = (Ab m'. R a b * list_assn R l m' * (m=b#m'))"
  apply (cases m)
  apply sep_auto
  apply (sep_auto intro!: ent_iffI)
  done
(* Symmetric version: the second list is a cons `b#m`, and the first list `l`
   is decomposed as `a#l'`. *)
lemma list_assn_aux_cons_conv2:
  "list_assn R l (b#m) = (Aa l'. R a b * list_assn R l' m * (l=a#l'))"
  apply (cases l)
  apply sep_auto
  apply (sep_auto intro!: ent_iffI)
  done
(* Both directions bundled under one name. *)
lemmas list_assn_aux_cons_conv = list_assn_aux_cons_conv1 list_assn_aux_cons_conv2

(* If the first list is an append `l1@l2`, the second list `m` is decomposed
   as `m1@m2`, with `list_assn` holding on the respective parts.
   Induction on `l1`, using `list_assn_aux_cons_conv` in the step case. *)
lemma list_assn_aux_append_conv1:
  "list_assn R (l1@l2) m = (Am1 m2. list_assn R l1 m1 * list_assn R l2 m2 * (m=m1@m2))"
  apply (induction l1 arbitrary: m)
  apply (sep_auto intro!: ent_iffI)
  apply (sep_auto intro!: ent_iffI simp: list_assn_aux_cons_conv)
  done
(* Symmetric version: the second list is an append `m1@m2`, and the first
   list `l` is decomposed as `l1@l2`. Induction on `m1`. *)
lemma list_assn_aux_append_conv2:
  "list_assn R l (m1@m2) = (Al1 l2. list_assn R l1 m1 * list_assn R l2 m2 * (l=l1@l2))"
  apply (induction m1 arbitrary: l)
  apply (sep_auto intro!: ent_iffI)
  apply (sep_auto intro!: ent_iffI simp: list_assn_aux_cons_conv)
  done
(* Both directions bundled under one name. *)
lemmas list_assn_aux_append_conv = list_assn_aux_append_conv1 list_assn_aux_append_conv2  

(* Registers the existing theorem `param_upt` as an import rule
   (`sepref_import_param`). *)
declare param_upt[sepref_import_param]
  
  
subsection ‹Sum-Type›    

(* Lifts two assertions `A` and `B` to the sum type:
     - `Inl`/`Inl` values are related by `A`
     - `Inr`/`Inr` values are related by `B`
     - mixed combinations are `false` *)
fun sum_assn :: "('ai  'a  assn)  ('bi  'b  assn)  ('ai+'bi)  ('a+'b)  assn" where
  "sum_assn A B (Inl ai) (Inl a) = A ai a"
| "sum_assn A B (Inr bi) (Inr b) = B bi b"
| "sum_assn A B _ _ = false"  

(* Infix notation for `sum_assn` (right-associative, priority 67). *)
notation sum_assn (infixr +a 67)
  
(* Purity is preserved by `sum_assn`: if both component assertions are pure,
   so is the sum assertion. Proved by case analysis on both sum values. *)
lemma sum_assn_pure[safe_constraint_rules]: "is_pure A; is_pure B  is_pure (sum_assn A B)"
  apply (auto simp: is_pure_iff_pure_assn)
  apply (rename_tac x x')
  apply (case_tac x; case_tac x'; simp add: pure_def)
  done
  
(* The sum of two identity assertions is the identity assertion on the
   sum type. Shown pointwise by case analysis on both values. *)
lemma sum_assn_id[simp]: "sum_assn id_assn id_assn = id_assn"
  apply (intro ext)
  subgoal for x y by (cases x; cases y; simp add: pure_def)
  done

(* `sum_assn` over two pure assertions is the pure assertion of the
   corresponding sum relation (`sum_rel`). Shown pointwise by case analysis. *)
lemma sum_assn_pure_conv[simp]: "sum_assn (pure A) (pure B) = pure (A,Bsum_rel)"
  apply (intro ext)
  subgoal for a b by (cases a; cases b; auto simp: pure_def)
  done
    
    
(* Frame-matching congruence for sums: to match the context for
   `sum_assn A B` against the one for `sum_assn A' B'`, it suffices to match
   the component contexts. Each premise is only needed for the constructor
   (`Inl` resp. `Inr`) that `e` and `e'` actually have. *)
lemma sum_match_cong[sepref_frame_match_rules]: 
  "
    x y. e = Inl x; e'=Inl y  hn_ctxt A x y t hn_ctxt A' x y;
    x y. e = Inr x; e'=Inr y  hn_ctxt B x y t hn_ctxt B' x y
    hn_ctxt (sum_assn A B) e e' t hn_ctxt (sum_assn A' B') e e'"
  by (cases e; cases e'; simp add: hn_ctxt_def entt_star_mono)

(* Frame-merging congruence for sums (despite the `enum_` prefix in its name,
   the statement is about `sum_assn`): if the component contexts can be
   merged (`A`,`A'` into `Am` and `B`,`B'` into `Bm`), then the two sum
   contexts can be merged into the one for `sum_assn Am Bm`. *)
lemma enum_merge_cong[sepref_frame_merge_rules]:
  assumes "x y. e=Inl x; e'=Inl y  hn_ctxt A x y A hn_ctxt A' x y t hn_ctxt Am x y"
  assumes "x y. e=Inr x; e'=Inr y  hn_ctxt B x y A hn_ctxt B' x y t hn_ctxt Bm x y"
  shows "hn_ctxt (sum_assn A B) e e' A hn_ctxt (sum_assn A' B') e e' t hn_ctxt (sum_assn Am Bm) e e'"
  apply (rule entt_disjE)
  (* First disjunct: match with the first halves of the merge assumptions. *)
  apply (rule sum_match_cong)
  apply (rule entt_disjD1[OF assms(1)]; simp)
  apply (rule entt_disjD1[OF assms(2)]; simp)

  (* Second disjunct: match with the second halves of the merge assumptions. *)
  apply (rule sum_match_cong)
  apply (rule entt_disjD2[OF assms(1)]; simp)
  apply (rule entt_disjD2[OF assms(2)]; simp)
  done

(* An invalidated sum context can be turned into a sum context over
   invalidated components: `hn_invalid (sum_assn A B)` yields
   `hn_ctxt (sum_assn (invalid_assn A) (invalid_assn B))`. *)
lemma entt_invalid_sum: "hn_invalid (sum_assn A B) e e' t hn_ctxt (sum_assn (invalid_assn A) (invalid_assn B)) e e'"
  apply (simp add: hn_ctxt_def invalid_assn_def[abs_def])
  apply (rule enttI)
  apply clarsimp
  apply (cases e; cases e'; auto simp: mod_star_conv pure_def) 
  done

(* Merge rule obtained by instantiating `gen_merge_cons` with
   `entt_invalid_sum`. *)
lemmas invalid_sum_merge[sepref_frame_merge_rules] = gen_merge_cons[OF entt_invalid_sum]

(* Register the sum constructors as operations. *)
sepref_register Inr Inl  

(* `return o Inl` implements `RETURN o Inl`; the result assertion is
   `sum_assn A B`. *)
lemma [sepref_fr_rules]: "(return o Inl,RETURN o Inl)  Ad a sum_assn A B"
  by sepref_to_hoare sep_auto
(* `return o Inr` implements `RETURN o Inr`; the result assertion is
   `sum_assn A B`. *)
lemma [sepref_fr_rules]: "(return o Inr,RETURN o Inr)  Bd a sum_assn A B"
  by sepref_to_hoare sep_auto

(* Setup of `case_sum` for the monadify phase: registration, an arity rule,
   and two combinator rules (monadic and `EVAL` context). *)
sepref_register case_sum

text ‹In the monadify phase, this eta-expands to make visible all required arguments›
lemma [sepref_monadify_arity]: "case_sum  λ2f1 f2 x. SP case_sum$(λ2x. f1$x)$(λ2x. f2$x)$x"
  by simp

text ‹This determines an evaluation order for the first-order operands›  
lemma [sepref_monadify_comb]: "case_sum$f1$f2$x  (⤜) $(EVAL$x)$(λ2x. SP case_sum$f1$f2$x)" by simp

text ‹This enables translation of the case-distinction in a non-monadic context.›  
lemma [sepref_monadify_comb]: "EVAL$(case_sum$(λ2x. f1 x)$(λ2x. f2 x)$x) 
   (⤜) $(EVAL$x)$(λ2x. SP case_sum$(λ2x. EVAL $ f1 x)$(λ2x. EVAL $ f2 x)$x)"
  apply (rule eq_reflection)
  by (simp split: sum.splits)

text ‹Auxiliary lemma, to lift simp-rule over hn_ctxt›  
(* Lifts an equation about `sum_assn A B x y` to the same equation about
   `hn_ctxt (sum_assn A B) x y`; used as `sum_assn.simps[THEN sum_assn_ctxt]`
   in `sum_cases_hnr`. *)
lemma sum_assn_ctxt: "sum_assn A B x y = z  hn_ctxt (sum_assn A B) x y = z"
  by (simp add: hn_ctxt_def)

text ‹The cases lemma first extracts the refinement for the datatype from the precondition.
  Next, it generates proof obligations to refine the functions for every case. 
  Finally, the postconditions of the refinement are merged. 

  Note that we handle the
  destructed values separately, to allow reconstruction of the original datatype after the case-expression.

  Moreover, we provide (invalidated) versions of the original compound value to the cases,
  which allows access to pure compound values from inside the case.
  ›  
(* Case-distinction rule for `case_sum`.
     INVe  : abbreviation for the invalidated original sum context.
     FR    : the context Γ provides the sum value `e`/`e'` plus a frame `F`.
     E1    : refinement of the `Inl` branch; it receives the component
             context, the invalidated original value (`INVe`) and the frame.
     E2    : refinement of the `Inr` branch, analogously.
     MERGE : merges the remaining contexts Γ1' and Γ2' of both branches
             into Γ'.
   The conclusion restores a sum context `sum_assn A' B'` for `e`/`e'`. *)
lemma sum_cases_hnr:
  fixes A B e e'
  defines [simp]: "INVe  hn_invalid (sum_assn A B) e e'"
  assumes FR: "Γ t hn_ctxt (sum_assn A B) e e' * F"
  assumes E1: "x1 x1a. e = Inl x1; e' = Inl x1a  hn_refine (hn_ctxt A x1 x1a * INVe * F) (f1' x1a) (hn_ctxt A' x1 x1a * hn_ctxt XX1 e e' * Γ1') R (f1 x1)"
  assumes E2: "x2 x2a. e = Inr x2; e' = Inr x2a  hn_refine (hn_ctxt B x2 x2a * INVe * F) (f2' x2a) (hn_ctxt B' x2 x2a * hn_ctxt XX2 e e' * Γ2') R (f2 x2)"
  assumes MERGE[unfolded hn_ctxt_def]: "Γ1' A Γ2' t Γ'"
  shows "hn_refine Γ (case_sum f1' f2' e') (hn_ctxt (sum_assn A' B') e e' * Γ') R (case_sum$(λ2x. f1 x)$(λ2x. f2 x)$e)"
  apply (rule hn_refine_cons_pre[OF FR])
  apply1 extract_hnr_invalids
  (* Case split on both sum values; mixed combinations are removed by simp. *)
  apply (cases e; cases e'; simp add: sum_assn.simps[THEN sum_assn_ctxt])
  (* `Inl` case, using E1. *)
  subgoal
    apply (rule hn_refine_cons[OF _ E1 _ entt_refl]; assumption?)
    applyS (simp add: hn_ctxt_def) ― ‹Match precondition for case, get ‹sum_assn› from assumption generated by ‹extract_hnr_invalids››
    apply (rule entt_star_mono) ― ‹Split postcondition into pairs for compounds and frame, drop hn_ctxt XX›
    apply1 (rule entt_fr_drop)
    applyS (simp add: hn_ctxt_def entt_disjI1' entt_disjI2')
    apply1 (rule entt_trans[OF _ MERGE])
    applyS (simp add: entt_disjI1' entt_disjI2')
  done
  (* `Inr` case, using E2; same structure as the `Inl` case. *)
  subgoal 
    apply (rule hn_refine_cons[OF _ E2 _ entt_refl]; assumption?)
    applyS (simp add: hn_ctxt_def)
    apply (rule entt_star_mono)
    apply1 (rule entt_fr_drop)
    applyS (simp add: hn_ctxt_def entt_disjI1' entt_disjI2')
    apply1 (rule entt_trans[OF _ MERGE])
    applyS (simp add: entt_disjI1' entt_disjI2')
  done    
done  

text ‹After some more preprocessing (adding extra frame-rules for non-atomic postconditions, 
  and splitting the merge-terms into binary merges), this rule can be registered›
(* `sepref_prep_comb_rule` is applied to `sum_cases_hnr` first; the resulting
   theorem is the one declared as `sepref_comb_rules`. *)
lemmas [sepref_comb_rules] = sum_cases_hnr[sepref_prep_comb_rule]

(* Register the discriminator and the two projections of the sum type. *)
sepref_register isl projl projr
(* `return o isl` implements `RETURN o isl` with a `bool_assn` result.
   Proved by case analysis on both sum values. *)
lemma isl_hnr[sepref_fr_rules]: "(return o isl,RETURN o isl)  (sum_assn A B)k a bool_assn"
  apply sepref_to_hoare
  subgoal for a b by (cases a; cases b; sep_auto)
  done

(* `return o projl` implements `RETURN o projl` under the precondition `isl`;
   the result assertion is `A`. Proved by case analysis on both sum values. *)
lemma projl_hnr[sepref_fr_rules]: "(return o projl,RETURN o projl)  [isl]a (sum_assn A B)d  A"
  apply sepref_to_hoare
  subgoal for a b by (cases a; cases b; sep_auto)
  done

(* `return o projr` implements `RETURN o projr` under the precondition
   `Not o isl`; the result assertion is `B`. Proved by case analysis on both
   sum values. *)
lemma projr_hnr[sepref_fr_rules]: "(return o projr,RETURN o projr)  [Not o isl]a (sum_assn A B)d  B"
  apply sepref_to_hoare
  subgoal for a b by (cases a; cases b; sep_auto)
  done
  
subsection ‹String Literals›  

(* Setup for the empty string literal: it is registered wrapped in
   `PR_CONST`, imported as a parameter, and a pattern rule rewrites plain
   occurrences to the `UNPROTECT` form. *)
sepref_register "PR_CONST String.empty_literal"

(* The empty literal is related to its `PR_CONST`-wrapped form by `Id`. *)
lemma empty_literal_hnr [sepref_import_param]:
  "(String.empty_literal, PR_CONST String.empty_literal)  Id"
  by simp

(* Pattern rule: rewrite `String.empty_literal` to
   `UNPROTECT String.empty_literal`. *)
lemma empty_literal_pat [def_pat_rules]:
  "String.empty_literal  UNPROTECT String.empty_literal"
  by simp

(* Setup for non-empty string literals `String.Literal b0 ... b6 s`, with the
   seven boolean arguments and the remaining literal `s` fixed in a local
   context. The whole literal is treated as one `PR_CONST`-wrapped constant. *)
context
  fixes b0 b1 b2 b3 b4 b5 b6 :: bool
  and s :: String.literal
begin

sepref_register "PR_CONST (String.Literal b0 b1 b2 b3 b4 b5 b6 s)"

(* The literal is related to its `PR_CONST`-wrapped form by `Id`. *)
lemma Literal_hnr [sepref_import_param]:
  "(String.Literal b0 b1 b2 b3 b4 b5 b6 s,
    PR_CONST (String.Literal b0 b1 b2 b3 b4 b5 b6 s))  Id"
  by simp

end

(* Pattern rule (stated outside the context): a fully applied
   `String.Literal` is rewritten to its `UNPROTECT` form as a whole. *)
lemma Literal_pat [def_pat_rules]:
  "String.Literal $ b0 $ b1 $ b2 $ b3 $ b4 $ b5 $ b6 $ s 
    UNPROTECT (String.Literal $ b0 $ b1 $ b2 $ b3 $ b4 $ b5 $ b6 $ s)"
  by simp
  
end