(* Theory Misc_CryptHOL *)

(* Title: Misc_CryptHOL.thy
  Author: Andreas Lochbihler, ETH Zurich *)

section ‹Miscellaneous library additions›

theory Misc_CryptHOL imports 
  Probabilistic_While.While_SPMF
  "HOL-Library.Rewrite"
  "HOL-Library.Simps_Case_Conv"
  "HOL-Library.Type_Length"
  "HOL-Eisbach.Eisbach"
  Coinductive.TLList
  Monad_Normalisation.Monad_Normalisation
  Monomorphic_Monad.Monomorphic_Monad
  Applicative_Lifting.Applicative
begin

(* Hide the short name of Henstock_Kurzweil_Integration.negligible; the
   constant remains accessible through its qualified name. *)
hide_const (open) Henstock_Kurzweil_Integration.negligible

(* Take eq_on_def out of the default simpset so simp does not unfold it. *)
declare eq_on_def [simp del]

subsection ‹HOL›

(* The meta-level tautology PROP P ==> PROP P is equivalent to Trueprop True. *)
lemma asm_rl_conv: "(PROP P  PROP P)  Trueprop True"
by(rule equal_intr_rule) iprover+

(* Named theorem collection for distributivity rules about If. *)
named_theorems if_distribs "Distributivity theorems for If"

(* Congruence rule for If in monotonicity proofs: the then-branches are
   compared assuming the condition b, the else-branches assuming its negation. *)
lemma if_mono_cong: "b  x  x'; ¬ b  y  y'   If b x y  If b x' y'"
by simp

(* Variant of if_cong where the rewritten condition b' is available only for
   the then-branch; the else-branches must be equal outright. *)
lemma if_cong_then: " b = b'; b'  t = t'; e = e'   If b t e = If b' t' e'"
by simp

(* An If whose condition implies False reduces to its else-branch. *)
lemma if_False_eq: " b  False; e = e'   If b t e = e'"
by auto

(* Implication, read as a relation on bool, is idempotent under OO. *)
lemma imp_OO_imp [simp]: "(⟶) OO (⟶) = (⟶)"
by auto

(* Injectivity of f(x := y) on A yields injectivity of f on A, provided the
   updated point x is not a member of A. *)
lemma inj_on_fun_updD: " inj_on (f(x := y)) A; x  A   inj_on f A"
by(auto simp add: inj_on_def split: if_split_asm)

(* If A and B are disjoint, members of B are not members of A. *)
lemma disjoint_notin1: " A  B = {}; x  B   x  A" by auto

(* Comparing LEAST values in a well-order: if Q has a witness and every
   Q-witness x has some P-witness y with y <= x, then Least P <= Least Q. *)
lemma Least_le_Least:
  fixes x :: "'a :: wellorder"
  assumes "Q x"
  and Q: "x. Q x  yx. P y"
  shows "Least P  Least Q"
  by (metis assms order_trans wellorder_Least_lemma)

(* Set.is_empty does not change when an image is taken. *)
lemma is_empty_image [simp]: "Set.is_empty (f ` A) = Set.is_empty A"
  by(auto simp add: Set.is_empty_def)

subsection ‹Relations›

(* Relational image as a predicate: Imagep R P y holds when some x with P x
   is R-related to y (single introduction rule ImagepI). *)
inductive Imagep :: "('a  'b  bool)  ('a  bool)  'b  bool"
  for R P
where ImagepI: " P x; R x y   Imagep R P y"

(* Two consecutive r-steps form a step of the transitive closure r^++. *)
lemma r_r_into_tranclp: " r x y; r y z   r^++ x z"
by(rule tranclp.trancl_into_trancl)(rule tranclp.r_into_trancl)

(* The transitive closure of an already transitive relation is the relation
   itself; the non-trivial direction is by induction on the closure. *)
lemma transp_tranclp_id:
  assumes "transp R"
  shows "tranclp R = R"
proof(intro ext iffI)
  fix x y
  assume "R^++ x y"
  thus "R x y" by induction(blast dest: transpD[OF assms])+
qed simp

(* Transitivity is preserved when a relation is pulled back along f;
   derived from the set-based lemma trans_inv_image. *)
lemma transp_inv_image: "transp r  transp (λx y. r (f x) (f y))"
using trans_inv_image[where r="{(x, y). r x y}" and f = f]
by(simp add: transp_trans inv_image_def)

(* The domain of the converse relation is the range of the relation. *)
lemma Domainp_conversep: "Domainp R¯¯ = Rangep R"
by(auto)

(* If a bi-unique relation R relates the sets A and B (rel_set), there is a
   bijection f from A to B whose graph lies inside R.  f is obtained by choice;
   injectivity uses bi_uniqueDl and surjectivity uses rel_setD2 together with
   bi_uniqueDr. *)
lemma bi_unique_rel_set_bij_betw:
  assumes unique: "bi_unique R"
  and rel: "rel_set R A B"
  shows "f. bij_betw f A B  (xA. R x (f x))"
proof -
  from assms obtain f where f: "x. x  A  R x (f x)" and B: "x. x  A  f x  B"
    apply(atomize_elim)
    apply(fold all_conj_distrib)
    apply(subst choice_iff[symmetric])
    apply(auto dest: rel_setD1)
    done
  have "inj_on f A" by(rule inj_onI)(auto dest!: f dest: bi_uniqueDl[OF unique])
  moreover have "f ` A = B" using rel
    by(auto 4 3 intro: B dest: rel_setD2 f bi_uniqueDr[OF unique])
  ultimately have "bij_betw f A B" unfolding bij_betw_def ..
  thus ?thesis using f by blast
qed

(* Restriction of a relation: restrict_relp R P Q relates x and y iff R x y,
   P x and Q y hold.  It is given mixfix notation (priority 53). *)
definition restrict_relp :: "('a  'b  bool)  ('a  bool)  ('b  bool)  'a  'b  bool"
  ("_  (_  _)" [53, 54, 54] 53)
where "restrict_relp R P Q = (λx y. R x y  P x  Q y)"

(* Rewrite, introduction and elimination rules for restrict_relp. *)
lemma restrict_relp_apply [simp]: "(R  P  Q) x y  R x y  P x  Q y"
by(simp add: restrict_relp_def)

lemma restrict_relpI [intro?]: " R x y; P x; Q y   (R  P  Q) x y"
by(simp add: restrict_relp_def)

lemma restrict_relpE [elim?, cases pred]:
  assumes "(R  P  Q) x y"
  obtains (restrict_relp) "R x y" "P x" "Q y"
using assms by(simp add: restrict_relp_def)

(* Taking the converse swaps the two restricting predicates. *)
lemma conversep_restrict_relp [simp]: "(R  P  Q)¯¯ = R¯¯  Q  P"
by(auto simp add: fun_eq_iff)

(* Nested restrictions merge by taking inf of the respective predicates. *)
lemma restrict_relp_restrict_relp [simp]: "R  P  Q  P'  Q' = R  inf P P'  inf Q Q'"
by(auto simp add: fun_eq_iff)

(* Congruence rules: the relations need only agree where both predicates
   hold.  The _simp variant states the premises with =simp=> so the
   simplifier can use them. *)
lemma restrict_relp_cong:
  " P = P'; Q = Q'; x y.  P x; Q y   R x y = R' x y   R  P  Q = R'  P'  Q'"
by(auto simp add: fun_eq_iff)

lemma restrict_relp_cong_simp:
  " P = P'; Q = Q'; x y. P x =simp=> Q y =simp=> R x y = R' x y   R  P  Q = R'  P'  Q'"
by(rule restrict_relp_cong; simp add: simp_implies_def)

(* Parametricity rule for the transfer package. *)
lemma restrict_relp_parametric [transfer_rule]:
  includes lifting_syntax shows
  "((A ===> B ===> (=)) ===> (A ===> (=)) ===> (B ===> (=)) ===> A ===> B ===> (=)) restrict_relp restrict_relp"
unfolding restrict_relp_def[abs_def] by transfer_prover

(* Monotonicity in all three arguments; the primed variant is a pointwise
   version whose premise may assume the original restricted relation. *)
lemma restrict_relp_mono: " R  R'; P  P'; Q  Q'   R  P  Q  R'  P'  Q'"
by(simp add: le_fun_def)

lemma restrict_relp_mono': 
  " (R  P  Q) x y;  R x y; P x; Q y   R' x y &&& P' x &&& Q' y 
   (R'  P'  Q') x y"
by(auto dest: conjunctionD1 conjunctionD2)

(* Elements in the domain of a restricted relation are in the domain of R and satisfy P. *)
lemma restrict_relp_DomainpD: "Domainp (R  P  Q) x  Domainp R x  P x"
by(auto simp add: Domainp.simps)

(* Restricting by constant-True predicates is the identity; a constant-False
   predicate on either side yields the empty relation bot. *)
lemma restrict_relp_True: "R  (λ_. True)  (λ_. True) = R"
by(simp add: fun_eq_iff)

lemma restrict_relp_False1: "R  (λ_. False)  Q = bot"
by(simp add: fun_eq_iff)

lemma restrict_relp_False2: "R  P  (λ_. False) = bot"
by(simp add: fun_eq_iff)

(* rel_prod2 R a relates a to a pair (c, b) whenever R a b; the first
   component c is ignored. *)
definition rel_prod2 :: "('a  'b  bool)  'a  ('c × 'b)  bool"
where "rel_prod2 R a = (λ(c, b). R a b)"

lemma rel_prod2_simps [simp]: "rel_prod2 R a (c, b)  R a b"
by(simp add: rel_prod2_def)

(* Restrictions commute with rel_prod: the left-hand and right-hand
   restricting predicates are combined with pred_prod. *)
lemma restrict_rel_prod:
  "rel_prod (R  I1  I2) (S  I1'  I2') = rel_prod R S  pred_prod I1 I1'  pred_prod I2 I2'"
by(auto simp add: fun_eq_iff)

(* Special cases where only one component relation is restricted; the other
   component uses the constant-True predicate. *)
lemma restrict_rel_prod1:
  "rel_prod (R  I1  I2) S = rel_prod R S  pred_prod I1 (λ_. True)  pred_prod I2 (λ_. True)"
by(simp add: restrict_rel_prod[symmetric] restrict_relp_True)

lemma restrict_rel_prod2:
  "rel_prod R (S  I1  I2) = rel_prod R S  pred_prod (λ_. True) I1  pred_prod (λ_. True) I2"
by(simp add: restrict_rel_prod[symmetric] restrict_relp_True)

(* relcompp_witness A B xy chooses a middle element witnessing that
   (A OO B) (fst xy) (snd xy) holds.  It is introduced by a specification
   whose consistency proof uses choice. *)
consts relcompp_witness :: "('a  'b  bool)  ('b  'c  bool)  'a × 'c  'b"
specification (relcompp_witness)
  relcompp_witness1: "(A OO B) (fst xy) (snd xy)  A (fst xy) (relcompp_witness A B xy)"
  relcompp_witness2: "(A OO B) (fst xy) (snd xy)  B (relcompp_witness A B xy) (snd xy)"
  apply(fold all_conj_distrib)
  apply(rule choice allI)+
  by(auto intro: choice allI)

(* The two characteristic properties restated for an explicit pair (x, y). *)
lemmas relcompp_witness[of _ _ "(x, y)" for x y, simplified] = relcompp_witness1 relcompp_witness2

(* Hide the pair-based originals in favour of relcompp_witness(1) and (2). *)
hide_fact (open) relcompp_witness1 relcompp_witness2

(* For equality on both sides the witness is the common element. *)
lemma relcompp_witness_eq [simp]: "relcompp_witness (=) (=) (x, x) = x"
  using relcompp_witness(1)[of "(=)" "(=)" x x] by(simp add: eq_OO)

subsection ‹Pairs›

(* Case analysis on a pair whose first component has been mapped by f. *)
lemma split_apfst [simp]: "case_prod h (apfst f xy) = case_prod (h  f) xy"
by(cases xy) simp

(* corec_prod f g builds a pair by applying f and g to the same seed s. *)
definition corec_prod :: "('s  'a)  ('s  'b)  's  'a × 'b"
where "corec_prod f g = (λs. (f s, g s))"

lemma corec_prod_apply: "corec_prod f g s = (f s, g s)"
by(simp add: corec_prod_def)

(* Selectors, apfst/apsnd/map_prod and case_prod all push through corec_prod. *)
lemma corec_prod_sel [simp]:
  shows fst_corec_prod: "fst (corec_prod f g s) = f s"
  and snd_corec_prod: "snd (corec_prod f g s) = g s"
by(simp_all add: corec_prod_apply)

lemma apfst_corec_prod [simp]: "apfst h (corec_prod f g s) = corec_prod (h  f) g s"
by(simp add: corec_prod_apply)

lemma apsnd_corec_prod [simp]: "apsnd h (corec_prod f g s) = corec_prod f (h  g) s"
by(simp add: corec_prod_apply)

lemma map_corec_prod [simp]: "map_prod f g (corec_prod h k s) = corec_prod (f  h) (g  k) s"
by(simp add: corec_prod_apply)

lemma split_corec_prod [simp]: "case_prod h (corec_prod f g s) = h (f s) (g s)"
by(simp add: corec_prod_apply)

(* A pair whose second component is unit is determined by its first component. *)
lemma Pair_fst_Unity: "(fst x, ()) = x"
  by(cases x) simp

(* rprodl reassociates a nested pair to the right, lprodr to the left.  Both
   are parametric and they are mutually inverse. *)
definition rprodl :: "('a × 'b) × 'c  'a × ('b × 'c)" where "rprodl = (λ((a, b), c). (a, (b, c)))"

lemma rprodl_simps [simp]: "rprodl ((a, b), c) = (a, (b, c))"
  by(simp add: rprodl_def)

lemma rprodl_parametric [transfer_rule]: includes lifting_syntax shows
  "(rel_prod (rel_prod A B) C ===> rel_prod A (rel_prod B C)) rprodl rprodl"
  unfolding rprodl_def by transfer_prover

definition lprodr :: "'a × ('b × 'c)  ('a × 'b) × 'c" where "lprodr = (λ(a, b, c). ((a, b), c))"

lemma lprodr_simps [simp]: "lprodr (a, b, c) = ((a, b), c)"
  by(simp add: lprodr_def)

lemma lprodr_parametric [transfer_rule]: includes lifting_syntax shows
  "(rel_prod A (rel_prod B C) ===> rel_prod (rel_prod A B) C) lprodr lprodr"
  unfolding lprodr_def by transfer_prover

lemma lprodr_inverse [simp]: "rprodl (lprodr x) = x"
  by(cases x) auto

lemma rprodl_inverse [simp]: "lprodr (rprodl x) = x"
  by(cases x) auto

(* Pointwise monotonicity of pred_prod in both predicates, declared with the
   mono attribute. *)
lemma pred_prod_mono' [mono]:
  "pred_prod A B xy  pred_prod A' B' xy"
  if "x. A x  A' x" "y. B y  B' y"
  using that by(cases xy) auto

(* Transposes a pair of pairs: ((a, b), (c, d)) becomes ((a, c), (b, d)). *)
fun rel_witness_prod :: "('a × 'b) × ('c × 'd)  (('a × 'c) × ('b × 'd))" where
  "rel_witness_prod ((a, b), (c, d)) = ((a, c), (b, d))"

subsection ‹Sums›

(* Elimination rule: a value satisfying isl has the form Inl l. *)
lemma islE:
  assumes "isl x"
  obtains l where "x = Inl l"
using assms by(cases x) auto

(* Membership of injections in the disjoint sum A <+> B. *)
lemma Inl_in_Plus [simp]: "Inl x  A <+> B  x  A"
by auto

lemma Inr_in_Plus [simp]: "Inr x  A <+> B  x  B"
by auto

(* When an injection equals the result of map_sum f g. *)
lemma Inl_eq_map_sum_iff: "Inl x = map_sum f g y  (z. y = Inl z  x = f z)"
by(cases y) auto

lemma Inr_eq_map_sum_iff: "Inr x = map_sum f g y  (z. y = Inr z  x = g z)"
by(cases y) auto

(* map_sum of functions injective on A and on B is injective on A <+> B. *)
lemma inj_on_map_sum [simp]:
  " inj_on f A; inj_on g B   inj_on (map_sum f g) (A <+> B)"
proof(rule inj_onI, goal_cases)
  case (1 x y)
  then show ?case by(cases x; cases y; auto simp add: inj_on_def)
qed

(* On the image sum, the inverse of map_sum is computed componentwise. *)
lemma inv_into_map_sum:
  "inv_into (A <+> B) (map_sum f g) x = map_sum (inv_into A f) (inv_into B g) x"
  if "x  f ` A <+> g ` B" "inj_on f A" "inj_on g B"
  using that by(cases rule: PlusE[consumes 1])(auto simp add: inv_into_f_eq f_inv_into_f)

(* rsuml reassociates a nested sum to the right, lsumr to the left; the two
   are mutually inverse. *)
fun rsuml :: "('a + 'b) + 'c  'a + ('b + 'c)" where
  "rsuml (Inl (Inl a)) = Inl a"
| "rsuml (Inl (Inr b)) = Inr (Inl b)"
| "rsuml (Inr c) = Inr (Inr c)"

fun lsumr :: "'a + ('b + 'c)  ('a + 'b) + 'c" where
  "lsumr (Inl a) = Inl (Inl a)"
| "lsumr (Inr (Inl b)) = Inl (Inr b)"
| "lsumr (Inr (Inr c)) = Inr c"

lemma rsuml_lsumr [simp]: "rsuml (lsumr x) = x"
  by(cases x rule: lsumr.cases) simp_all

lemma lsumr_rsuml [simp]: "lsumr (rsuml x) = x"
  by(cases x rule: rsuml.cases) simp_all

subsection ‹Option›

(* Make is_none_bind a default simp rule. *)
declare is_none_bind [simp]

(* A case distinction on an option with the same result in both branches is constant. *)
lemma case_option_collapse: "case_option x (λ_. x) y = x"
by(simp split: option.split)

(* The indicator of a singleton is unchanged when both sides are wrapped in Some. *)
lemma indicator_single_Some: "indicator {Some x} (Some y) = indicator {x} y"
by(simp split: split_indicator)

subsubsection ‹Predicator and relator›

(* Rules about pred_option.  Several restate BNF-generated facts (proved
   by fact) under the names used in this library. *)
lemma option_pred_mono_strong:
  " pred_option P x; a.  a  set_option x; P a   P' a   pred_option P' x"
by(fact option.pred_mono_strong)

lemma option_pred_map [simp]: "pred_option P (map_option f x) = pred_option (P  f) x"
by(fact option.pred_map)

lemma option_pred_o_map [simp]: "pred_option P  map_option f = pred_option (P  f)"
by(simp add: fun_eq_iff)

lemma option_pred_bind [simp]: "pred_option P (Option.bind x f) = pred_option (pred_option P  f) x"
by(simp add: pred_option_def)

lemma pred_option_conj [simp]:
  "pred_option (λx. P x  Q x) = (λx. pred_option P x  pred_option Q x)"
by(auto simp add: pred_option_def)

lemma pred_option_top [simp]:
  "pred_option (λ_. True) = (λ_. True)"
by(fact option.pred_True)

(* rel_option over a restricted element relation is rel_option of the
   unrestricted relation plus pred_option of the two predicates
   (introduction, elimination and iff forms). *)
lemma rel_option_restrict_relpI [intro?]:
  " rel_option R x y; pred_option P x; pred_option Q y   rel_option (R  P  Q) x y"
by(erule option.rel_mono_strong) simp

lemma rel_option_restrict_relpE [elim?]:
  assumes "rel_option (R  P  Q) x y"
  obtains "rel_option R x y" "pred_option P x" "pred_option Q y"
proof
  show "rel_option R x y" using assms by(auto elim!: option.rel_mono_strong)
  have "pred_option (Domainp (R  P  Q)) x" using assms by(fold option.Domainp_rel) blast
  then show "pred_option P x" by(rule option_pred_mono_strong)(blast dest!: restrict_relp_DomainpD)
  have "pred_option (Domainp (R  P  Q)¯¯) y" using assms
    by(fold option.Domainp_rel)(auto simp only: option.rel_conversep Domainp_conversep)
  then show "pred_option Q y" by(rule option_pred_mono_strong)(auto dest!: restrict_relp_DomainpD)
qed

lemma rel_option_restrict_relp_iff:
  "rel_option (R  P  Q) x y  rel_option R x y  pred_option P x  pred_option Q y"
by(blast intro: rel_option_restrict_relpI elim: rel_option_restrict_relpE)

(* Pushing map_option on the left or on the right argument into the
   restricted element relation. *)
lemma option_rel_map_restrict_relp:
  shows option_rel_map_restrict_relp1:
  "rel_option (R  P  Q) (map_option f x) = rel_option (R  f  P  f  Q) x"
  and option_rel_map_restrict_relp2:
  "rel_option (R  P  Q) x (map_option g y) = rel_option ((λx. R x  g)  P  Q  g) x y"
by(simp_all add: option.rel_map restrict_relp_def fun_eq_iff)

(* rel_witness_option combines two Some values into a Some pair; every other
   combination yields None. *)
fun rel_witness_option :: "'a option × 'b option  ('a × 'b) option" where
  "rel_witness_option (Some x, Some y) = Some (x, y)"
| "rel_witness_option (None, None) = None"
| "rel_witness_option _ = None" ― ‹Just to make the definition complete›

(* For rel_option A x y, the witness contains only A-related pairs, and its
   projections give back x and y. *)
lemma rel_witness_option:
  shows set_rel_witness_option: " rel_option A x y; (a, b)  set_option (rel_witness_option (x, y))   A a b"
    and map1_rel_witness_option: "rel_option A x y  map_option fst (rel_witness_option (x, y)) = x"
    and map2_rel_witness_option: "rel_option A x y  map_option snd (rel_witness_option (x, y)) = y"
  by(cases "(x, y)" rule: rel_witness_option.cases; simp; fail)+

(* The witness is related to x (first lemma) and to y (second lemma). *)
lemma rel_witness_option1:
  assumes "rel_option A x y"
  shows "rel_option (λa (a', b). a = a'  A a' b) x (rel_witness_option (x, y))"
  using map1_rel_witness_option[OF assms, symmetric]
  unfolding option.rel_eq[symmetric] option.rel_map
  by(rule option.rel_mono_strong)(auto intro: set_rel_witness_option[OF assms])

lemma rel_witness_option2:
  assumes "rel_option A x y"
  shows "rel_option (λ(a, b') b. b = b'  A a b') (rel_witness_option (x, y)) y"
  using map2_rel_witness_option[OF assms]
  unfolding option.rel_eq[symmetric] option.rel_map
  by(rule option.rel_mono_strong)(auto intro: set_rel_witness_option[OF assms])

subsubsection ‹Orders on option›

(* le_option is ord_option instantiated with equality on the elements. *)
abbreviation le_option :: "'a option  'a option  bool"
where "le_option  ord_option (=)"

(* Monotonicity of Option.bind, reflexivity, and agreement with option_ord. *)
lemma le_option_bind_mono:
  " le_option x y; a. a  set_option x  le_option (f a) (g a) 
   le_option (Option.bind x f) (Option.bind y g)"
by(cases x) simp_all

lemma le_option_refl [simp]: "le_option x x"
by(cases x) simp_all


lemma le_option_conv_option_ord: "le_option = option_ord"
by(auto simp add: fun_eq_iff flat_ord_def elim: ord_option.cases)

(* pcr_Some R x y holds iff y = Some z for some z with R x z. *)
definition pcr_Some :: "('a  'b  bool)  'a  'b option  bool"
where "pcr_Some R x y  (z. y = Some z  R x z)"

lemma pcr_Some_simps [simp]: "pcr_Some R x (Some y)  R x y"
by(simp add: pcr_Some_def)

lemma pcr_SomeE [cases pred]:
  assumes "pcr_Some R x y"
  obtains (pcr_Some) z where "y = Some z" "R x z"
using assms by(auto simp add: pcr_Some_def)

subsubsection ‹Filter for option›

(* filter_option P keeps Some x when P x holds and returns None otherwise. *)
fun filter_option :: "('a  bool)  'a option  'a option"
where
  "filter_option P None = None"
| "filter_option P (Some x) = (if P x then Some x else None)"

(* Characterisations of filter_option: its set, interaction with map_option,
   emptiness, equality with Some, and expression via Option.bind. *)
lemma set_filter_option [simp]: "set_option (filter_option P x) = {y  set_option x. P y}"
by(cases x) auto

lemma filter_map_option: "filter_option P (map_option f x) = map_option f (filter_option (P  f) x)"
by(cases x) simp_all

lemma is_none_filter_option [simp]: "Option.is_none (filter_option P x)  Option.is_none x  ¬ P (the x)"
by(cases x) simp_all

lemma filter_option_eq_Some_iff [simp]: "filter_option P x = Some y  x = Some y  P y"
by(cases x) auto

lemma Some_eq_filter_option_iff [simp]: "Some y = filter_option P x  x = Some y  P y"
by(cases x) auto

lemma filter_conv_bind_option: "filter_option P x = Option.bind x (λy. if P y then Some y else None)"
by(cases x) simp_all

subsubsection ‹Assert for option›

(* assert_option b is Some () if b holds and None otherwise. *)
primrec assert_option :: "bool  unit option" where
  "assert_option True = Some ()"
| "assert_option False = None"

lemma set_assert_option_conv: "set_option (assert_option b) = (if b then {()} else {})"
by(simp)

lemma in_set_assert_option [simp]: "x  set_option (assert_option b)  b"
by(cases b) simp_all


subsubsection ‹Join on options›

(* join_option flattens a nested option: Some y becomes y, None stays None. *)
definition join_option :: "'a option option  'a option"
where "join_option x = (case x of Some y  y | None  None)"

(* Derive pattern-matching equations from the case expression, used for simp and code generation. *)
simps_of_case join_simps [simp, code]: join_option_def

(* Interaction of join_option with set_option, map_option and Option.bind. *)
lemma set_join_option [simp]: "set_option (join_option x) = (set_option ` set_option x)"
by(cases x)(simp_all)

lemma in_set_join_option: "x  set_option (join_option (Some (Some x)))"
by simp

lemma map_join_option: "map_option f (join_option x) = join_option (map_option (map_option f) x)"
by(cases x) simp_all

lemma bind_conv_join_option: "Option.bind x f = join_option (map_option f x)"
by(cases x) simp_all

lemma join_conv_bind_option: "join_option x = Option.bind x id"
by(cases x) simp_all

(* Parametricity, proved via the Option.bind characterisation above. *)
lemma join_option_parametric [transfer_rule]:
  includes lifting_syntax shows
  "(rel_option (rel_option R) ===> rel_option R) join_option join_option"
unfolding join_conv_bind_option[abs_def] by transfer_prover

(* When join_option returns Some y or None. *)
lemma join_option_eq_Some [simp]: "join_option x = Some y  x = Some (Some y)"
by(cases x) simp_all

lemma Some_eq_join_option [simp]: "Some y = join_option x  x = Some (Some y)"
by(cases x) auto

lemma join_option_eq_None: "join_option x = None  x = None  x = Some None"
by(cases x) simp_all

lemma None_eq_join_option: "None = join_option x  x = None  x = Some None"
by(cases x) auto

subsubsection ‹Zip on options›

(* zip_option pairs two Some values and returns None if either argument is
   None.  The last two equations overlap on (None, None) but agree there,
   which the function package accepts (pat_completeness auto). *)
function zip_option :: "'a option  'b option  ('a × 'b) option"
where
  "zip_option (Some x) (Some y) = Some (x, y)"
| "zip_option _ None = None"
| "zip_option None _ = None"
by pat_completeness auto
termination by lexicographic_order

(* zip_option gives Some (a, b) exactly when both arguments are Some. *)
lemma zip_option_eq_Some_iff [iff]:
  "zip_option x y = Some (a, b)  x = Some a  y = Some b"
by(cases "(x, y)" rule: zip_option.cases) simp_all

lemma set_zip_option [simp]:
  "set_option (zip_option x y) = set_option x × set_option y"
by auto

(* Interaction with map_option on either argument and on the result. *)
lemma zip_map_option1: "zip_option (map_option f x) y = map_option (apfst f) (zip_option x y)"
by(cases "(x, y)" rule: zip_option.cases) simp_all

lemma zip_map_option2: "zip_option x (map_option g y) = map_option (apsnd g) (zip_option x y)"
by(cases "(x, y)" rule: zip_option.cases) simp_all

lemma map_zip_option:
  "map_option (map_prod f g) (zip_option x y) = zip_option (map_option f x) (map_option g y)"
by(simp add: zip_map_option1 zip_map_option2 option.map_comp apfst_def apsnd_def o_def prod.map_comp)

lemma zip_conv_bind_option:
  "zip_option x y = Option.bind x (λx. Option.bind y (λy. Some (x, y)))"
by(cases "(x, y)" rule: zip_option.cases) simp_all

(* Parametricity, proved via the Option.bind characterisation. *)
lemma zip_option_parametric [transfer_rule]:
  includes lifting_syntax shows
  "(rel_option R ===> rel_option Q ===> rel_option (rel_prod R Q)) zip_option zip_option"
unfolding zip_conv_bind_option[abs_def] by transfer_prover

(* rel_option with equality is reflexive. *)
lemma rel_option_eqI [simp]: "rel_option (=) x x"
by(simp add: option.rel_eq)

subsubsection ‹Binary supremum on @{typ "'a option"}›

(* Right-biased supremum: a Some on the right wins; if the right argument is
   None, the left argument is returned. *)
primrec sup_option :: "'a option  'a option  'a option"
where
  "sup_option x None = x"
| "sup_option x (Some y) = (Some y)"

(* Idempotence, associativity and left idempotence; sup_option_ai bundles
   the latter two. *)
lemma sup_option_idem [simp]: "sup_option x x = x"
by(cases x) simp_all

lemma sup_option_assoc: "sup_option (sup_option x y) z = sup_option x (sup_option y z)"
by(cases z) simp_all

lemma sup_option_left_idem: "sup_option x (sup_option x y) = sup_option x y"
by(rewrite sup_option_assoc[symmetric])(simp)

lemmas sup_option_ai = sup_option_assoc sup_option_left_idem

(* None is also a left unit. *)
lemma sup_option_None [simp]: "sup_option None y = y"
by(cases y) simp_all

subsubsection ‹Restriction on @{typ "'a option"}›

(* enforce_option P keeps Some x when P x holds and yields None otherwise.
   The (transfer) flag asks primrec to also provide a transfer rule.
   NOTE(review): these defining equations coincide with those of
   filter_option; consider whether both constants are needed. *)
primrec (transfer) enforce_option :: "('a  bool)  'a option  'a option" where
  "enforce_option P (Some x) = (if P x then Some x else None)"
| "enforce_option P None = None"

(* Interaction with set_option, map_option and Option.bind. *)
lemma set_enforce_option [simp]: "set_option (enforce_option P x) = {a  set_option x. P a}"
  by(cases x) auto

lemma enforce_map_option: "enforce_option P (map_option f x) = map_option f (enforce_option (P  f) x)"
  by(cases x) auto

lemma enforce_bind_option [simp]:
  "enforce_option P (Option.bind x f) = Option.bind x (enforce_option P  f)"
  by(cases x) auto

(* Alternative monadic definition via assert_option. *)
lemma enforce_option_alt_def:
  "enforce_option P x = Option.bind x (λa. Option.bind (assert_option (P a)) (λ_ :: unit. Some a))"
  by(cases x) simp_all

(* When enforce_option returns None or Some y. *)
lemma enforce_option_eq_None_iff [simp]:
  "enforce_option P x = None  (a. x = Some a  ¬ P a)"
  by(cases x) auto

lemma enforce_option_eq_Some_iff [simp]:
  "enforce_option P x = Some y  x = Some y  P y"
  by(cases x) auto

lemma Some_eq_enforce_option_iff [simp]:
  "Some y = enforce_option P x  x = Some y  P y"
  by(cases x) auto

(* The always-true predicate (top) gives the identity; the always-false
   predicate (bot) gives constantly None. *)
lemma enforce_option_top [simp]: "enforce_option  = id"
  by(rule ext; rename_tac x; case_tac x; simp)

lemma enforce_option_K_True [simp]: "enforce_option (λ_. True) x = x"
  by(cases x) simp_all

lemma enforce_option_bot [simp]: "enforce_option  = (λ_. None)"
  by(simp add: fun_eq_iff)

lemma enforce_option_K_False [simp]: "enforce_option (λ_. False) x = None"
  by simp

(* If the content already satisfies P, enforce_option P has no effect. *)
lemma enforce_pred_id_option: "pred_option P x  enforce_option P x = x"
  by(cases x) auto

subsubsection ‹Maps›

(* A lookup in m1 ++ m2 is the right-biased sup_option of the two lookups. *)
lemma map_add_apply: "(m1 ++ m2) x = sup_option (m1 x) (m2 x)"
by(simp add: map_add_def split: option.split)

(* If f is below g (map_le) and any value f has at x equals y, then f stays
   below g updated at x with y. *)
lemma map_le_map_upd2: " f m g; y'. f x = Some y'  y' = y   f m g(x  y)"
by(cases "x  dom f")(auto simp add: map_le_def Ball_def)

lemma eq_None_iff_not_dom: "f x = None  x  dom f"
by auto

(* A map with finite domain has a range no larger than its domain. *)
lemma card_ran_le_dom: "finite (dom m)  card (ran m)  card (dom m)"
by(simp add: ran_alt_def card_image_le)

(* With a finite range, inclusion of the domain in the range forces
   equality; the proof compares cardinalities. *)
lemma dom_subset_ran_iff:
  assumes "finite (ran m)"
  shows "dom m  ran m  dom m = ran m"
proof
  assume le: "dom m  ran m"
  then have "card (dom m)  card (ran m)" by(simp add: card_mono assms)
  moreover have "card (ran m)  card (dom m)" by(simp add: finite_subset[OF le assms] card_ran_le_dom)
  ultimately show "dom m = ran m" using card_subset_eq[OF assms le] by simp
qed simp

text ‹
  We need a polymorphic constant for the empty map such that ‹transfer_prover›
  can use a custom transfer rule for @{const Map.empty}›
(* Map_empty is an alias for Map.empty; the simp attribute unfolds it. *)
definition Map_empty where [simp]: "Map_empty  Map.empty"

(* A defined lookup in the smaller map is preserved in the larger map. *)
lemma map_le_Some1D: " m m m'; m x = Some y   m' x = Some y"
by(auto simp add: map_le_def Ball_def)

(* Updating the larger map outside the domain of f keeps f below it. *)
lemma map_le_fun_upd2: " f m g; x  dom f   f m g(x := y)"
by(auto simp add: map_le_def)

(* Two maps are equal if they agree on the union of their domains. *)
lemma map_eqI: "xdom m  dom m'. m x = m' x  m = m'"
by(auto simp add: fun_eq_iff domIff intro: option.expand)


subsection ‹Countable›

(* If F preserves countability and is sup-continuous, lfp F is countable.
   The proof unfolds lfp as the supremum of the iterates of F
   (sup_continuous_lfp) and uses countable_funpow. *)
lemma countable_lfp:
  assumes step: "Y. countable Y  countable (F Y)"
  and cont: "Order_Continuity.sup_continuous F"
  shows "countable (lfp F)"
by(subst sup_continuous_lfp[OF cont])(simp add: countable_funpow[OF step])

(* Pointwise variant for fixpoints of set-valued functions; proved by
   induction over the iterates (F ^^ n) bot. *)
lemma countable_lfp_apply:
  assumes step: "Y x. (x. countable (Y x))  countable (F Y x)"
  and cont: "Order_Continuity.sup_continuous F"
  shows "countable (lfp F x)"
proof -
  { fix n
    have "x. countable ((F ^^ n) bot x)"
      by(induct n)(auto intro: step) }
  thus ?thesis using cont by(simp add: sup_continuous_lfp)
qed


subsection ‹Extended naturals›

(* Truncated subtraction on enat yields a finite value exactly when the
   minuend is finite. *)
lemma idiff_enat_eq_enat_iff: "x - enat n = enat m  (k. x = enat k  k - n = m)"
  by (cases x) simp_all

(* eSuc commutes with suprema over non-empty index sets. *)
lemma eSuc_SUP: "A  {}  eSuc ( (f ` A)) = (xA. eSuc (f x))"
  by (subst eSuc_Sup) (simp_all add: image_comp)

(* Conversions between enat, ereal and ennreal. *)
lemma ereal_of_enat_1: "ereal_of_enat 1 = ereal 1"
  by (simp add: one_enat_def)

lemma ennreal_real_conv_ennreal_of_enat: "ennreal (real n) = ennreal_of_enat n"
  by (simp add: ennreal_of_nat_eq_real_of_nat)

(* Arithmetic on enat: cancelling an added-then-subtracted b under a side
   condition on b, and reordering subtraction and addition. *)
lemma enat_add_sub_same2: "b    a + b - b = (a :: enat)"
  by (cases a; cases b) simp_all

lemma enat_sub_add: "y  x  x - y + z = x + z - (y :: enat)"
  by (cases x; cases y; cases z) simp_all

(* A supremum of enat values is 0 iff every value is 0. *)
lemma SUP_enat_eq_0_iff [simp]: " (f ` A) = (0 :: enat)  (xA. f x = 0)"
  by (simp add: bot_enat_def [symmetric])

(* Adding a constant c commutes with SUP over a non-empty index set.  The
   finite case of c is shown by antisymmetry; the remaining case uses
   SUP_constant. *)
lemma SUP_enat_add_left:
  assumes "I  {}"
  shows "(SUP iI. f i + c :: enat) = (SUP iI. f i) + c" (is "?lhs = ?rhs")
proof(cases "c", rule antisym)
  case (enat n)
  show "?lhs  ?rhs" by(auto 4 3 intro: SUP_upper intro: SUP_least)
  have "(SUP iI. f i)  ?lhs - c" using enat 
    by(auto simp add: enat_add_sub_same2 intro!: SUP_least order_trans[OF _ SUP_upper[THEN enat_minus_mono1]])
  note add_right_mono[OF this, of c]
  also have " + c  ?lhs" using assms
    by(subst enat_sub_add)(auto intro: SUP_upper2 simp add: enat_add_sub_same2 enat)
  finally show "?rhs  ?lhs" .
qed(simp add: assms SUP_constant)

lemma SUP_enat_add_right:
  assumes "I  {}"
  shows "(SUP iI. c + f i :: enat) = c + (SUP iI. f i)"
using SUP_enat_add_left[OF assms, of f c]
by(simp add: add.commute)

(* Upper bounds for a constant plus a supremum; the empty index set is
   handled by a separate branch. *)
lemma iadd_SUP_le_iff: "n + (SUP xA. f x :: enat)  y  (if A = {} then n  y else xA. n + f x  y)"
by(simp add: bot_enat_def SUP_enat_add_right[symmetric] SUP_le_iff)

lemma SUP_iadd_le_iff: "(SUP xA. f x :: enat) + n  y  (if A = {} then n  y else xA. f x + n  y)"
using iadd_SUP_le_iff[of n f A y] by(simp add: add.commute)


subsection ‹Extended non-negative reals›

(* In a finite measure space, the nonnegative integral of indicator A
   composed with f is finite when the preimage f -` A is in sets M.  The
   proof bounds the integrand by the constant 1 (integrable_const_bound). *)
lemma (in finite_measure) nn_integral_indicator_neq_infty: 
  "f -` A  sets M  (+ x. indicator A (f x) M)  "
unfolding ennreal_indicator[symmetric]
apply(rule integrableD)
apply(rule integrable_const_bound[where B=1])
apply(simp_all add: indicator_vimage[symmetric])
done

(* The same bound phrased with top instead of infinity. *)
lemma (in finite_measure) nn_integral_indicator_neq_top: 
  "f -` A  sets M  (+ x. indicator A (f x) M)  "
by(drule nn_integral_indicator_neq_infty) simp

(* Integrating the indicator of a measurable predicate set composed with a
   measurable f gives the measure of the corresponding preimage. *)
lemma nn_integral_indicator_map:
  assumes [measurable]: "f  measurable M N" "{xspace N. P x}  sets N"
  shows "(+x. indicator {xspace N. P x} (f x) M) = emeasure M {xspace M. P (f x)}"
  using assms(1)[THEN measurable_space] 
  by (subst nn_integral_indicator[symmetric])
     (auto intro!: nn_integral_cong split: split_indicator simp del: nn_integral_indicator)


subsection ‹BNF material›

(* rel_fun with an equality as argument relation preserves transitivity. *)
lemma transp_rel_fun: " is_equality Q; transp R   transp (rel_fun Q R)"
by(rule transpI)(auto dest: transpD rel_funD simp add: is_equality_def)

(* rel_fun distributes over inf in the result relation. *)
lemma rel_fun_inf: "inf (rel_fun Q R) (rel_fun Q R') = rel_fun Q (inf R R')"
by(rule antisym)(auto elim: rel_fun_mono dest: rel_funD)

(* rel_fun with an equality as argument relation preserves reflexivity. *)
lemma reflp_fun1: includes lifting_syntax shows " is_equality A; reflp B   reflp (A ===> B)"
by(simp add: reflp_def rel_fun_def is_equality_def)

(* The identity is a trivial type copy of UNIV. *)
lemma type_copy_id': "type_definition (λx. x) (λx. x) UNIV"
by unfold_locales simp_all

lemma type_copy_id: "type_definition id id UNIV"
by(simp add: id_def type_copy_id')

(* Elimination rule for the graph relation BNF_Def.Grp. *)
lemma GrpE [cases pred]:
  assumes "BNF_Def.Grp A f x y"
  obtains (Grp) "y = f x" "x  A"
using assms
by(simp add: Grp_def)

(* For a type copy with Rep/Abs on A, rel_fun between the graph of Abs and
   the graph of g is itself the graph of Rep ---> g, restricted to functions
   that map A into B. *)
lemma rel_fun_Grp_copy_Abs:
  includes lifting_syntax
  assumes "type_definition Rep Abs A"
  shows "rel_fun (BNF_Def.Grp A Abs) (BNF_Def.Grp B g) = BNF_Def.Grp {f. f ` A  B} (Rep ---> g)"
proof -
  interpret type_definition Rep Abs A by fact
  show ?thesis
    by(auto simp add: rel_fun_def Grp_def fun_eq_iff Abs_inverse Rep_inverse intro!: Rep)
qed

(* rel_set of a graph is the graph of image f on subsets of A. *)
lemma rel_set_Grp:
  "rel_set (BNF_Def.Grp A f) = BNF_Def.Grp {B. B  A} (image f)"
by(auto simp add: rel_set_def BNF_Def.Grp_def fun_eq_iff)

(* rel_set R factors as the converse graph of image fst composed (OO) with
   the graph of image snd, both over sets of R-related pairs. *)
lemma rel_set_comp_Grp:
  "rel_set R = (BNF_Def.Grp {x. x  {(x, y). R x y}} ((`) fst))¯¯ OO BNF_Def.Grp {x. x  {(x, y). R x y}} ((`) snd)"
apply(auto 4 4 del: ext intro!: ext simp add: BNF_Def.Grp_def intro!: rel_setI intro: rev_bexI)
apply(simp add: relcompp_apply)
subgoal for A B
  apply(rule exI[where x="A × B  {(x, y). R x y}"])
  apply(auto 4 3 dest: rel_setD1 rel_setD2 intro: rev_image_eqI)
  done
done

(* The domain of a graph relation is its carrier set A. *)
lemma Domainp_Grp: "Domainp (BNF_Def.Grp A f) = (λx. x  A)"
by(auto simp add: fun_eq_iff Grp_def)

(* pred_prod, pred_sum and list_all distribute over pointwise conjunction of
   their argument predicates. *)
lemma pred_prod_conj [simp]:
  shows pred_prod_conj1: "P Q R. pred_prod (λx. P x  Q x) R = (λx. pred_prod P R x  pred_prod Q R x)"
  and pred_prod_conj2: "P Q R. pred_prod P (λx. Q x  R x) = (λx. pred_prod P Q x  pred_prod P R x)"
by(auto simp add: pred_prod.simps)

lemma pred_sum_conj [simp]:
  shows pred_sum_conj1: "P Q R. pred_sum (λx. P x  Q x) R = (λx. pred_sum P R x  pred_sum Q R x)"
  and pred_sum_conj2: "P Q R. pred_sum P (λx. Q x  R x) = (λx. pred_sum P Q x  pred_sum P R x)"
by(auto simp add: pred_sum.simps fun_eq_iff)

lemma pred_list_conj [simp]: "list_all (λx. P x  Q x) = (λx. list_all P x  list_all Q x)"
by(auto simp add: list_all_def)

(* pred_prod of two constant-True predicates is constant True. *)
lemma pred_prod_top [simp]:
  "pred_prod (λ_. True) (λ_. True) = (λ_. True)"
by(simp add: pred_prod.simps fun_eq_iff)

(* rel_fun commutes with taking converses of both relations. *)
lemma rel_fun_conversep: includes lifting_syntax shows
  "(A^--1 ===> B^--1) = (A ===> B)^--1"
by(auto simp add: rel_fun_def fun_eq_iff)

(* Uniqueness and totality of graph relations in terms of injectivity,
   surjectivity and the carrier set. *)
lemma left_unique_Grp [iff]:
  "left_unique (BNF_Def.Grp A f)  inj_on f A"
unfolding Grp_def left_unique_def by(auto simp add: inj_on_def)

lemma right_unique_Grp [simp, intro!]: "right_unique (BNF_Def.Grp A f)"
by(simp add: Grp_def right_unique_def)

lemma bi_unique_Grp [iff]:
  "bi_unique (BNF_Def.Grp A f)  inj_on f A"
by(simp add: bi_unique_alt_def)

lemma left_total_Grp [iff]:
  "left_total (BNF_Def.Grp A f)  A = UNIV"
by(auto simp add: left_total_def Grp_def)

lemma right_total_Grp [iff]:
  "right_total (BNF_Def.Grp A f)  f ` A = UNIV"
by(auto simp add: right_total_def BNF_Def.Grp_def image_def)

lemma bi_total_Grp [iff]:
  "bi_total (BNF_Def.Grp A f)  A = UNIV  surj f"
by(auto simp add: bi_total_alt_def)

(* vimage2p f g P preserves uniqueness under injective f/g and totality
   under surjective g/f.  The proofs rewrite vimage2p as a composition of
   graphs (vimage2p_Grp). *)
lemma left_unique_vimage2p [simp]:
  " left_unique P; inj f   left_unique (BNF_Def.vimage2p f g P)"
unfolding vimage2p_Grp by(intro left_unique_OO) simp_all

lemma right_unique_vimage2p [simp]:
  " right_unique P; inj g   right_unique (BNF_Def.vimage2p f g P)"
unfolding vimage2p_Grp by(intro right_unique_OO) simp_all

lemma bi_unique_vimage2p [simp]:
  " bi_unique P; inj f; inj g   bi_unique (BNF_Def.vimage2p f g P)"
unfolding bi_unique_alt_def by simp

lemma left_total_vimage2p [simp]:
  " left_total P; surj g   left_total (BNF_Def.vimage2p f g P)"
unfolding vimage2p_Grp by(intro left_total_OO) simp_all

lemma right_total_vimage2p [simp]:
  " right_total P; surj f   right_total (BNF_Def.vimage2p f g P)"
unfolding vimage2p_Grp by(intro right_total_OO) simp_all

lemma bi_total_vimage2p [simp]:
  " bi_total P; surj f; surj g   bi_total (BNF_Def.vimage2p f g P)"
unfolding bi_total_alt_def by simp

(* Pulling equality back along an injective f gives equality. *)
lemma vimage2p_eq [simp]:
  "inj f  BNF_Def.vimage2p f f (=) = (=)"
by(auto simp add: vimage2p_def fun_eq_iff inj_on_def)

(* vimage2p commutes with converse, swapping f and g. *)
lemma vimage2p_conversep: "BNF_Def.vimage2p f g R^--1 = (BNF_Def.vimage2p g f R)^--1"
by(simp add: vimage2p_def fun_eq_iff)

(* If A is below equality and equality is below B, then rel_fun A B contains equality. *)
lemma rel_fun_refl: " A  (=); (=)  B   (=)  rel_fun A B"
  by(subst fun.rel_eq[symmetric])(rule fun_mono)

(* Strong monotonicity of rel_fun.  The result relation only needs to be
   weakened on values in the images of f and g over the domain and range of
   the new argument relation A'. *)
lemma rel_fun_mono_strong:
  " rel_fun A B f g; A'  A; x y.  x  f ` {x. Domainp A' x}; y  g ` {x. Rangep A' x}; B x y   B' x y   rel_fun A' B' f g"
  by(auto simp add: rel_fun_def) fastforce

(* Reflexivity of rel_fun A B on a single function f.  B needs to be
   reflexive only on the image of f over the domain of A. *)
lemma rel_fun_refl_strong: 
  assumes "A  (=)" "x. x  f ` {x. Domainp A x}  B x x"
  shows "rel_fun A B f f"
proof -
  have "rel_fun (=) (=) f f" by(simp add: rel_fun_eq)
  then show ?thesis using assms(1)
    by(rule rel_fun_mono_strong) (auto intro: assms(2))
qed

(* Unfolding rule and range of graph relations. *)
lemma Grp_iff: "BNF_Def.Grp B g x y  y = g x  x  B" by(simp add: Grp_def)

lemma Rangep_Grp: "Rangep (BNF_Def.Grp A f) = (λx. x  f ` A)"
  by(auto simp add: fun_eq_iff Grp_iff)

(* rel_fun from a converse total graph of h to the graph of g is the graph
   of map_fun h g on functions mapping range h into A. *)
lemma rel_fun_Grp:
  "rel_fun (BNF_Def.Grp UNIV h)¯¯ (BNF_Def.Grp A g) = BNF_Def.Grp {f. f ` range h  A} (map_fun h g)"
  by(auto simp add: rel_fun_def fun_eq_iff Grp_iff)

subsection ‹Transfer and lifting material›

(* Transfer (parametricity) rules, stated with the lifting_syntax notation ===>. *)
context includes lifting_syntax begin

(* monotone is parametric when the domain relation A is bi-total. *)
lemma monotone_parametric [transfer_rule]:
  assumes [transfer_rule]: "bi_total A"
  shows "((A ===> A ===> (=)) ===> (B ===> B ===> (=)) ===> (A ===> B) ===> (=)) monotone monotone"
unfolding monotone_def[abs_def] by transfer_prover

(* fun_ord is parametric when the argument relation C is bi-total. *)
lemma fun_ord_parametric [transfer_rule]:
  assumes [transfer_rule]: "bi_total C"
  shows "((A ===> B ===> (=)) ===> (C ===> A) ===> (C ===> B) ===> (=)) fun_ord fun_ord"
unfolding fun_ord_def[abs_def] by transfer_prover

(* The disjoint sum of sets <+> is parametric. *)
lemma Plus_parametric [transfer_rule]:
  "(rel_set A ===> rel_set B ===> rel_set (rel_sum A B)) (<+>) (<+>)"
unfolding Plus_def[abs_def] by transfer_prover

(* pred_fun is parametric when the argument relation A is bi-total. *)
lemma pred_fun_parametric [transfer_rule]:
  assumes [transfer_rule]: "bi_total A"
  shows "((A ===> (=)) ===> (B ===> (=)) ===> (A ===> B) ===> (=)) pred_fun pred_fun"
unfolding pred_fun_def by(transfer_prover)

(* rel_fun with equality as argument relation distributes over OO. *)
lemma rel_fun_eq_OO: "((=) ===> A) OO ((=) ===> B) = ((=) ===> A OO B)"
by(clarsimp simp add: rel_fun_def fun_eq_iff relcompp.simps) metis

end

(* For a quotient given by R, Abs, Rep and T, rel_set R on representation
   sets corresponds to equality of the abstract sets under rel_set T.  The
   two directions of the iff are proved separately, using the facts in * that
   are derived from Quotient_alt_def. *)
lemma Quotient_set_rel_eq:
  includes lifting_syntax
  assumes "Quotient R Abs Rep T"
  shows "(rel_set T ===> rel_set T ===> (=)) (rel_set R) (=)"
proof(rule rel_funI iffI)+
  fix A B C D
  assume AB: "rel_set T A B" and CD: "rel_set T C D"
  have *: "x y. R x y = (T x (Abs x)  T y (Abs y)  Abs x = Abs y)"
    "a b. T a b  Abs a = b"
    using assms unfolding Quotient_alt_def by simp_all

  { assume [simp]: "B = D"
    thus "rel_set R A C"
      by(auto 4 4 intro!: rel_setI dest: rel_setD1[OF AB, simplified] rel_setD2[OF AB, simplified] rel_setD2[OF CD] rel_setD1[OF CD] simp add: * elim!: rev_bexI)
  next
    assume AC: "rel_set R A C"
    show "B = D"
      apply safe
       apply(drule rel_setD2[OF AB], erule bexE)
       apply(drule rel_setD1[OF AC], erule bexE)
       apply(drule rel_setD1[OF CD], erule bexE)
       apply(simp add: *)
      apply(drule rel_setD2[OF CD], erule bexE)
      apply(drule rel_setD2[OF AC], erule bexE)
      apply(drule rel_setD1[OF AB], erule bexE)
      apply(simp add: *)
      done
  }
qed

(* The domain of equality is the everywhere-true predicate. *)
lemma Domainp_eq: "Domainp (=) = (λ_. True)"
by(simp add: Domainp.simps fun_eq_iff)

(* A function that is equal to itself under eq_onp (pred_fun P Q) is related to itself by the
   function relator between eq_onp P and eq_onp Q. *)
lemma rel_fun_eq_onpI: "eq_onp (pred_fun P Q) f g  rel_fun (eq_onp P) (eq_onp Q) f g"
by(auto simp add: eq_onp_def rel_fun_def)

(* eq_onp P is bi-unique for every P. *)
lemma bi_unique_eq_onp: "bi_unique (eq_onp P)"
by(simp add: bi_unique_def eq_onp_def)

(* For an equality result relation, taking the converse of the argument relation equals
   taking the converse of the whole function relation. *)
lemma rel_fun_eq_conversep: includes lifting_syntax shows "(A¯¯ ===> (=)) = (A ===> (=))¯¯"
by(auto simp add: fun_eq_iff rel_fun_def)

(* Absorb a function composed onto either side of rel_fun into the result relation B. *)
lemma rel_fun_comp:
  "f g h. rel_fun A B (f  g) h = rel_fun A (λx. B (f x)) g h"
  "f g h. rel_fun A B f (g  h) = rel_fun A (λx y. B x (g y)) f h"
  by(auto simp add: rel_fun_def)

(* Relating arguments through the converse graph of h lets h be pushed into the left
   function as map_fun h id, leaving equality on the arguments. *)
lemma rel_fun_map_fun1: "rel_fun (BNF_Def.Grp UNIV h)¯¯ A f g  rel_fun (=) A (map_fun h id f) g"
  by(auto simp add: rel_fun_def Grp_def)

(* Separate the result map g of map_fun into an outer composition. *)
lemma map_fun2_id: "map_fun f g x = g  map_fun f id x"
  by(simp add: map_fun_def o_assoc)

(* Move the result map h of map_fun inside, composing it with the function f. *)
lemma map_fun_id2_in: "map_fun g h f = map_fun g id (h  f)"
  by(simp add: map_fun_def)

(* The domain of a function relation is bounded by pred_fun of the component domains. *)
lemma Domainp_rel_fun_le: "Domainp (rel_fun A B)  pred_fun (Domainp A) (Domainp B)"
  by(auto dest: rel_funD)

(* Witness for the composite relation A OO A' at function type. From a pair (f, g) it builds
   a function on the middle type 'b. That function pairs f, applied to the A-predecessor of its
   argument, with g, applied to the A'-successor; both are selected with THE. *)
definition rel_witness_fun :: "('a  'b  bool)  ('b  'c  bool)  ('a  'd) × ('c  'e)  ('b  'd × 'e)" where
  "rel_witness_fun A A' = (λ(f, g) b. (f (THE a. A a b), g (THE c. A' b c)))"

(* If f and g are related by rel_fun (A OO A') B, the witness splits this into two function
   relations. The first relates f to the witness, the second relates the witness to g.
   The uniqueness and totality assumptions ensure that the THE-selections pick the unique
   partner on each side. *)
lemma
  assumes fg: "rel_fun (A OO A') B f g"
    and A: "left_unique A" "right_total A"
    and A': "right_unique A'" "left_total A'"
  shows rel_witness_fun1: "rel_fun A (λx (x', y). x = x'  B x' y) f (rel_witness_fun A A' (f, g))"
    and rel_witness_fun2: "rel_fun A' (λ(x, y') y. y = y'  B x y') (rel_witness_fun A A' (f, g)) g"
proof (goal_cases)
  case 1
  (* Pointwise fact: choose an A'-successor of y by left-totality, then use fg. *)
  have "A x y  f x = f (THE a. A a y)  B (f (THE a. A a y)) (g (The (A' y)))" for x y 
    by(rule left_totalE[OF A'(2)]; erule meta_allE[of _ y]; erule exE; frule (1) fg[THEN rel_funD, OF relcomppI])
      (auto intro!: arg_cong[where f=f] arg_cong[where f=g] rel_funI the_equality the_equality[symmetric] dest: left_uniqueD[OF A(1)] right_uniqueD[OF A'(1)] elim!: arg_cong2[where f=B, THEN iffD2, rotated -1])

  with 1 show ?case by(clarsimp simp add: rel_fun_def rel_witness_fun_def)
next
  case 2
  (* Symmetric pointwise fact: choose an A-predecessor of x by right-totality, then use fg. *)
  have "A' x y  g y = g (The (A' x))  B (f (THE a. A a x)) (g (The (A' x)))" for x y
    by(rule right_totalE[OF A(2), of x]; frule (1) fg[THEN rel_funD, OF relcomppI])
      (auto intro!: arg_cong[where f=f] arg_cong[where f=g] rel_funI the_equality the_equality[symmetric] dest: left_uniqueD[OF A(1)] right_uniqueD[OF A'(1)] elim!: arg_cong2[where f=B, THEN iffD2, rotated -1])

  with 2 show ?case by(clarsimp simp add: rel_fun_def rel_witness_fun_def)    
qed

(* With equality relations the witness is plain pointwise pairing of f and g. *)
lemma rel_witness_fun_eq [simp]: "rel_witness_fun (=) (=) (f, g) = (λx. (f x, g x))"
  by(simp add: rel_witness_fun_def)

subsection ‹Arithmetic›

(* Triangle inequality for absolute differences through an intermediate point c. *)
lemma abs_diff_triangle_ineq2: "¦a - b :: _ :: ordered_ab_group_add_abs¦  ¦a - c¦ + ¦c - b¦"
by(rule order_trans[OF _ abs_diff_triangle_ineq]) simp

(* Transitivity step: weaken the right summand of an upper bound a + b to a + c. *)
lemma (in ordered_ab_semigroup_add) add_left_mono_trans:
  " x  a + b; b  c   x  a + c"
by(erule order_trans)(rule add_left_mono)

(* Cancel the real-number cast when comparing a natural number with 1. *)
lemma of_nat_le_one_cancel_iff [simp]:
  fixes n :: nat shows "real n  1  n  1"
by linarith

(* Right-multiplication counterpart of mult_left_le, obtained by commutativity. *)
lemma (in linordered_semidom) mult_right_le: "c  1  0  a  c * a  a"
by(subst mult.commute)(rule mult_left_le)

subsection ‹Chain-complete partial orders and partial_function›

(* fun_ord is pointwise: related functions have related values at every argument x. *)
lemma fun_ordD: "fun_ord ord f g  ord (f x) (g x)"
by(simp add: fun_ord_def)

(* Strengthened parallel fixpoint induction. In the step case the hypothesis may additionally
   assume that x and y lie below the respective least fixpoints of f and g.
   Derived from parallel_fixp_induct by strengthening the invariant. *)
lemma parallel_fixp_induct_strong:
  assumes ccpo1: "class.ccpo luba orda (mk_less orda)"
  and ccpo2: "class.ccpo lubb ordb (mk_less ordb)"
  and adm: "ccpo.admissible (prod_lub luba lubb) (rel_prod orda ordb) (λx. P (fst x) (snd x))"
  and f: "monotone orda orda f"
  and g: "monotone ordb ordb g"
  and bot: "P (luba {}) (lubb {})"
  and step: "x y.  orda x (ccpo.fixp luba orda f); ordb y (ccpo.fixp lubb ordb g); P x y   P (f x) (g y)"
  shows "P (ccpo.fixp luba orda f) (ccpo.fixp lubb ordb g)"
proof -
  (* Strengthened invariant: below both fixpoints, and P holds. *)
  let ?P="λx y. orda x (ccpo.fixp luba orda f)  ordb y (ccpo.fixp lubb ordb g)  P x y"
  show ?thesis using ccpo1 ccpo2 _ f g
  (* The two conjunct2 projections extract P from the strengthened invariant. *)
  proof(rule parallel_fixp_induct[where P="?P", THEN conjunct2, THEN conjunct2])
    (* Registered as cont_intro rules, presumably so that simp can prove the admissibility
       of the strengthened invariant below. *)
    note [cont_intro] = 
      admissible_leI[OF ccpo1] ccpo.mcont_const[OF ccpo1]
      admissible_leI[OF ccpo2] ccpo.mcont_const[OF ccpo2]
    show "ccpo.admissible (prod_lub luba lubb) (rel_prod orda ordb) (λxy. ?P (fst xy) (snd xy))"
      using adm by simp
    (* Base case: the lubs of the empty chain lie below everything, in particular below the fixpoints. *)
    show "?P (luba {}) (lubb {})" using bot by(auto intro: ccpo.ccpo_Sup_least ccpo1 ccpo2 chain_empty)
    (* Step case: unfold both fixpoints once and use monotonicity of f and g. *)
    show "?P (f x) (g y)" if "?P x y" for x y using that
      apply(subst ccpo.fixp_unfold[OF ccpo1 f])
      apply(subst ccpo.fixp_unfold[OF ccpo2 g])
      apply(auto intro: step monotoneD[OF f] monotoneD[OF g])
      done
  qed
qed

(* Version of parallel_fixp_induct_strong for functions handled by partial_function.
   f and g are given through C1/C2 applied to fixpoints of the functionals, and U1/U2 are left
   inverses of C1/C2 (premises inverse and inverse2). The step may assume that f' and g' are
   pointwise below f and g. *)
lemma parallel_fixp_induct_strong_uc:
  assumes a: "partial_function_definitions orda luba"
  and b: "partial_function_definitions ordb lubb"
  and F: "x. monotone (fun_ord orda) orda (λf. U1 (F (C1 f)) x)"
  and G: "y. monotone (fun_ord ordb) ordb (λg. U2 (G (C2 g)) y)"
  and eq1: "f  C1 (ccpo.fixp (fun_lub luba) (fun_ord orda) (λf. U1 (F (C1 f))))"
  and eq2: "g  C2 (ccpo.fixp (fun_lub lubb) (fun_ord ordb) (λg. U2 (G (C2 g))))"
  and inverse: "f. U1 (C1 f) = f"
  and inverse2: "g. U2 (C2 g) = g"
  and adm: "ccpo.admissible (prod_lub (fun_lub luba) (fun_lub lubb)) (rel_prod (fun_ord orda) (fun_ord ordb)) (λx. P (fst x) (snd x))"
  and bot: "P (λ_. luba {}) (λ_. lubb {})"
  and step: "f' g'.  x. orda (U1 f' x) (U1 f x); y. ordb (U2 g' y) (U2 g y); P (U1 f') (U2 g')   P (U1 (F f')) (U2 (G g'))"
  shows "P (U1 f) (U2 g)"
apply(unfold eq1 eq2 inverse inverse2)
apply(rule parallel_fixp_induct_strong[OF partial_function_definitions.ccpo[OF a] partial_function_definitions.ccpo[OF b] adm])
(* Monotonicity of the two functionals on the fun_ord ccpos. *)
using F apply(simp add: monotone_def fun_ord_def)
using G apply(simp add: monotone_def fun_ord_def)
(* Base case: fun_lub of the empty set is the constantly-bottom function. *)
apply(simp add: fun_lub_def bot)
(* Step case: fun_ordD turns the fun_ord hypotheses into the pointwise form of step. *)
apply(rule step; simp add: inverse inverse2 eq1 eq2 fun_ordD)
done

(* Instance for unary functions: U1, C1, U2, C2 are the identity, so the inverse premises
   become trivial and are discharged by refl. *)
lemmas parallel_fixp_induct_strong_1_1 = parallel_fixp_induct_strong_uc[
  of _ _ _ _ "λx. x" _ "λx. x" "λx. x" _ "λx. x",
  OF _ _ _ _ _ _ refl refl]

(* Instance for two-argument functions: U = case_prod, C = curry. After unfolding, the inverse
   premises are discharged by refl, and split_format presents the arguments in curried form. *)
lemmas parallel_fixp_induct_strong_2_2 = parallel_fixp_induct_strong_uc[
  of _ _ _ _ "case_prod" _ "curry" "case_prod" _ "curry",
  where P="λf g. P (curry f) (curry g)",
  unfolded case_prod_curry curry_case_prod curry_K,
  OF _ _ _ _ _ _ refl refl,
  split_format (complete), unfolded prod.case]
  for P

(* Induction rule for option-valued partial functions: whenever U f x = Some y, P x y holds.
   In the step case, the hypothesis is available for any g that is pointwise below f in
   option_ord. *)
lemma fixp_induct_option': ― ‹Stronger induction rule›
  fixes F :: "'c  'c" and
    U :: "'c  'b  'a option" and
    C :: "('b  'a option)  'c" and
    P :: "'b  'a  bool"
  assumes mono: "x. mono_option (λf. U (F (C f)) x)"
  assumes eq: "f  C (ccpo.fixp (fun_lub (flat_lub None)) (fun_ord option_ord) (λf. U (F (C f))))"
  assumes inverse2: "f. U (C f) = f"
  assumes step: "g x y.  x y. U g x = Some y  P x y; U (F g) x = Some y; x. option_ord (U g x) (U f x)   P x y"
  assumes defined: "U f x = Some y"
  shows "P x y"
using step defined option.fixp_strong_induct_uc[of U F C, OF mono eq inverse2 option_admissible, of P]
unfolding fun_lub_def flat_lub_def fun_ord_def
by(simp (no_asm_use)) blast

(* Registers a partial_function mode named option'. It reuses the fixpoint, monotonicity and
   fixpoint rules of the standard option mode. Its last argument supplies fixp_induct_option' as
   the induction rule; presumably this is what distinguishes option' from option. *)
declaration Partial_Function.init "option'" @{term option.fixp_fun}
  @{term option.mono_body} @{thm option.fixp_rule_uc} @{thm option.fixp_induct_uc}
  (SOME @{thm fixp_induct_option'})

(* The constantly-bot function is below every function (via bot_fun_def). *)
lemma bot_fun_least [simp]: "(λ_. bot :: 'a :: order_bot)  x"
by(fold bot_fun_def) simp

(* fun_ord coincides with the function relator whose argument relation is equality. *)
lemma fun_ord_conv_rel_fun: "fun_ord = rel_fun (=)"
by(simp add: fun_ord_def fun_eq_iff rel_fun_def)

(* finite_chains ord holds when every chain with respect to ord is finite. *)
inductive finite_chains :: "('a  'a  bool)  bool"
  for ord
where finite_chainsI: "(Y. Complete_Partial_Order.chain ord Y  finite Y)  finite_chains ord"

(* Elimination form: a chain of a relation with finite chains is finite. *)
lemma finite_chainsD: " finite_chains ord; Complete_Partial_Order.chain ord Y   finite Y"
by(rule finite_chains.cases)

(* Flat orders have finite chains: a chain contains at most the distinguished element x
   and one other value. *)
lemma finite_chains_flat_ord [simp, intro!]: "finite_chains (flat_ord x)"
proof
  fix Y
  assume chain: "Complete_Partial_Order.chain (flat_ord x) Y"
  show "finite Y"
  proof(cases "y  Y. y  x")
    case True
    then obtain y where y: "y  Y" and yx: "y  x" by blast
    hence "Y  {x, y}" by(auto dest: chainD[OF chain] simp add: flat_ord_def)
    thus ?thesis by(rule finite_subset) simp
  next
    case False
    hence "Y  {x}" by auto
    thus ?thesis by(rule finite_subset) simp
  qed
qed    

(* If all chains of ord are finite, every monotone function between the two ccpos is
   continuous. A nonempty finite chain contains its own lub, so f maps that lub to the lub of
   the image chain. *)
lemma mcont_finite_chains:
  assumes finite: "finite_chains ord"
  and mono: "monotone ord ord' f"
  and ccpo: "class.ccpo lub ord (mk_less ord)"
  and ccpo': "class.ccpo lub' ord' (mk_less ord')"
  shows "mcont lub ord lub' ord' f"
proof(intro mcontI contI)
  fix Y
  assume chain: "Complete_Partial_Order.chain ord Y" and Y: "Y  {}"
  from finite chain have fin: "finite Y" by(rule finite_chainsD)
  (* A nonempty finite chain contains its lub. *)
  from ccpo chain fin Y have lub: "lub Y  Y" by(rule ccpo.in_chain_finite)

  interpret ccpo': ccpo lub' ord' "mk_less ord'" by(rule ccpo')

  have chain': "Complete_Partial_Order.chain ord' (f ` Y)" using chain
    by(rule chain_imageI)(rule monotoneD[OF mono])

  (* f (lub Y) belongs to the image chain, hence lies below its lub ... *)
  have "ord' (f (lub Y)) (lub' (f ` Y))" using chain'
    by(rule ccpo'.ccpo_Sup_upper)(simp add: lub)
  moreover
  (* ... and by monotonicity it is also an upper bound of the image chain. *)
  have "ord' (lub' (f ` Y)) (f (lub Y))" using chain'
    by(rule ccpo'.ccpo_Sup_least)(blast intro: monotoneD[OF mono] ccpo.ccpo_Sup_upper[OF ccpo chain])
  ultimately show "f (lub Y) = lub' (f ` Y)" by(rule ccpo'.order.antisym)
qed(fact mono)  

(* Uncurrying preserves function relations: curried related functions give related
   case_prod versions over rel_prod A B. *)
lemma rel_fun_curry: includes lifting_syntax shows
  "(A ===> B ===> C) f g  (rel_prod A B ===> C) (case_prod f) (case_prod g)"
by(auto simp add: rel_fun_def)

(* Let f be monotone from another ccpo (luba, orda) into this one. Then the Sup of the image of
   an orda-chain A lies below f applied to the lub of A. *)
lemma (in ccpo) Sup_image_mono:
  assumes ccpo: "class.ccpo luba orda lessa"
  and mono: "monotone orda (≤) f"
  and chain: "Complete_Partial_Order.chain orda A"
  and "A  {}"
  shows "Sup (f ` A)  (f (luba A))"
proof(rule ccpo_Sup_least)
  from chain show "Complete_Partial_Order.chain (≤) (f ` A)"
    by(rule chain_imageI)(rule monotoneD[OF mono])
  fix x
  assume "x  f ` A"
  then obtain y where "x = f y" "y  A" by blast
  from y  A have "orda y (luba A)" by(rule ccpo.ccpo_Sup_upper[OF ccpo chain])
  hence "f y  f (luba A)" by(rule monotoneD[OF mono])
  thus "x  f (luba A)" using x = f y by simp
qed

(* For monotone f, the predicate that x lies below f x is admissible. *)
lemma (in ccpo) admissible_le_mono:
  assumes "monotone (≤) (≤) f"
  shows "ccpo.admissible Sup (≤) (λx. x  f x)"
proof(rule ccpo.admissibleI)
  fix Y
  assume chain: "Complete_Partial_Order.chain (≤) Y"
    and Y: "Y  {}"
    and le [rule_format]: "xY. x  f x"
  (* Each element of Y is below its image, which is below the Sup of the image chain. *)
  have "Y  (f ` Y)" using chain
    by(rule ccpo_Sup_least)(rule order_trans[OF le]; blast intro!: ccpo_Sup_upper chain_imageI[OF chain] intro: monotoneD[OF assms])
  also have "  f (Y)"
    by(rule Sup_image_mono[OF _ assms chain Y, where lessa="(<)"]) unfold_locales
  finally show "Y  " .
qed

(* Fixpoint induction in which the step may additionally assume that x lies below the
   fixpoint and below f x. The invariant is strengthened with x below f x, which is
   admissible by admissible_le_mono. *)
lemma (in ccpo) fixp_induct_strong2:
  assumes adm: "ccpo.admissible Sup (≤) P"
  and mono: "monotone (≤) (≤) f"
  and bot: "P ({})"
  and step: "x.  x  ccpo_class.fixp f; x  f x; P x   P (f x)"
  shows "P (ccpo_class.fixp f)"
proof(rule fixp_strong_induct[where P="λx. x  f x  P x", THEN conjunct2])
  show "ccpo.admissible Sup (≤) (λx. x  f x  P x)"
    using admissible_le_mono adm by(rule admissible_conj)(rule mono)
next
  show "{}  f ({})  P ({})"
    by(auto simp add: bot chain_empty intro: ccpo_Sup_least)
next
  fix x
  assume "x  ccpo_class.fixp f" "x  f x  P x"
  thus "f x  f (f x)  P (f x)"
    by(auto dest: monotoneD[OF mono] intro: step)
qed(rule mono)

context partial_function_definitions begin

(* Uncurried (U/C) version of fixp_induct_strong2 for use with partial_function. In the step,
   f' may be assumed below f and below F f', in the pointwise order le_fun. *)
lemma fixp_induct_strong2_uc:
  fixes F :: "'c  'c"
    and U :: "'c  'b  'a"
    and C :: "('b  'a)  'c"
    and P :: "('b  'a)  bool"
  assumes mono: "x. mono_body (λf. U (F (C f)) x)"
    and eq: "f  C (fixp_fun (λf. U (F (C f))))"
    and inverse: "f. U (C f) = f"
    and adm: "ccpo.admissible lub_fun le_fun P"
    and bot: "P (λ_. lub {})"
    and step: "f'.  le_fun (U f') (U f); le_fun (U f') (U (F f')); P (U f')   P (U (F f'))"
  shows "P (U f)"
unfolding eq inverse
apply (rule ccpo.fixp_induct_strong2[OF ccpo adm])
(* The first two subgoals are monotonicity and the base case. *)
apply (insert mono, auto simp: monotone_def fun_ord_def bot fun_lub_def)[2]
(* NOTE(review): f'5 appears to be an automatically generated variable name. This
   instantiation may break if the naming of the goal's bound variables changes. *)
apply (rule_tac f'5="C x" in step)
apply (simp_all add: inverse eq)
done

end

(* Instance of parallel_fixp_induct_uc relating a two-argument function with a four-argument
   function. Both are handled through (nested) case_prod and curry, and the inverse premises
   are discharged by refl. *)
lemmas parallel_fixp_induct_2_4 = parallel_fixp_induct_uc[
  of _ _ _ _ "case_prod" _ "curry" "λf. case_prod (case_prod (case_prod f))" _ "λf. curry (curry (curry f))",
  where P="λf g. P (curry f) (curry (curry (curry g)))",
  unfolded case_prod_curry curry_case_prod curry_K,
  OF _ _ _ _ _ _ refl refl]
  for P
  
(* Any x that lies below every y with f y <= y also lies below the least fixpoint, because the
   fixpoint itself is such a y. *)
lemma (in ccpo) fixp_greatest:
  assumes f: "monotone (≤) (≤) f"
    and ge: "y. f y  y  x  y"
  shows "x  ccpo.fixp Sup (≤) f"
  by(rule ge)(simp add: fixp_unfold[OF f, symmetric])

(* Rolling rule for least fixpoints: the fixpoint of g after f equals g applied to the
   fixpoint of f after g. It is proved by antisymmetry, using fixp_lowerbound and
   fixp_greatest. *)
lemma fixp_rolling:
  assumes "class.ccpo lub1 leq1 (mk_less leq1)"
    and "class.ccpo lub2 leq2 (mk_less leq2)"
    and f: "monotone leq1 leq2 f"
    and g: "monotone leq2 leq1 g"
  shows "ccpo.fixp lub1 leq1 (λx. g (f x)) = g (ccpo.fixp lub2 leq2 (λx. f (g x)))"
proof -
  interpret c1: ccpo lub1 leq1 "mk_less leq1" by fact
  interpret c2: ccpo lub2 leq2 "mk_less leq2" by fact
  show ?thesis
  proof(rule c1.order.antisym)
    have fg: "monotone leq2 leq2 (λx. f (g x))" using f g by(rule monotone2monotone) simp_all
    have gf: "monotone leq1 leq1 (λx. g (f x))" using g f by(rule monotone2monotone) simp_all
    (* g applied to the other fixpoint is a fixed point of g after f, so the least fixpoint lies below it. *)
    show "leq1 (c1.fixp (λx. g (f x))) (g (c2.fixp (λx. f (g x))))" using gf
      by(rule c1.fixp_lowerbound)(subst (2) c2.fixp_unfold[OF fg], simp)
    (* Converse: g applied to the other fixpoint lies below every u with g (f u) <= u. *)
    show "leq1 (g (c2.fixp (λx. f (g x)))) (c1.fixp (λx. g (f x)))" using gf
    proof(rule c1.fixp_greatest)
      fix u
      assume u: "leq1 (g (f u)) u"
      have "leq1 (g (c2.fixp (λx. f (g x)))) (g (f u))"
        by(intro monotoneD[OF g] c2.fixp_lowerbound[OF fg] monotoneD[OF f u])
      then show "leq1 (g (c2.fixp (λx. f (g x)))) u" using u by(rule c1.order_trans)
    qed
  qed
qed

(* Parametricity of lfp.fixp_fun for set-valued functions. If F and G are monotone and
   related, their least fixpoints are related by (A ===> (=)). The proof is parallel fixpoint
   induction. *)
lemma fixp_lfp_parametric_eq:
  includes lifting_syntax
  assumes f: "x. lfp.mono_body (λf. F f x)"
  and g: "x. lfp.mono_body (λf. G f x)"
  and param: "((A ===> (=)) ===> A ===> (=)) F G"
  shows "(A ===> (=)) (lfp.fixp_fun F) (lfp.fixp_fun G)"
using f g
proof(rule parallel_fixp_induct_1_1[OF complete_lattice_partial_function_definitions complete_lattice_partial_function_definitions _ _ reflexive reflexive, where P="(A ===> (=))"])
  show "ccpo.admissible (prod_lub lfp.lub_fun lfp.lub_fun) (rel_prod lfp.le_fun lfp.le_fun) (λx. (A ===> (=)) (fst x) (snd x))"
    unfolding rel_fun_def by simp
  show "(A ===> (=)) (λ_. {}) (λ_. {})" by auto
  show "(A ===> (=)) (F f) (G g)" if "(A ===> (=)) f g" for f g
    using that by(rule rel_funD[OF param])
qed

(* map_option is monotone in option_ord. The stored theorem is lifted via option.mono2mono
   and registered as a simp and cont_intro rule. *)
lemma mono2mono_map_option[THEN option.mono2mono, simp, cont_intro]:
  shows monotone_map_option: "monotone option_ord option_ord map_option f"
by(rule monotoneI)(auto simp add: flat_ord_def)

(* map_option is continuous on the flat option ccpo. This follows from monotonicity by
   mcont_finite_chains, since flat orders have finite chains. *)
lemma mcont2mcont_map_option[THEN option.mcont2mcont, simp, cont_intro]:
  shows mcont_map_option: "mcont (flat_lub None) option_ord (flat_lub None) option_ord (map_option f)"
by(rule mcont_finite_chains[OF _ _ flat_interpretation[THEN ccpo] flat_interpretation[THEN ccpo]]) simp_all

(* set_option is monotone from option_ord into set inclusion. *)
lemma mono2mono_set_option [THEN lfp.mono2mono]:
  shows monotone_set_option: "monotone option_ord (⊆) set_option"
by(auto intro!: monotoneI simp add: option_ord_Some1_iff)

(* set_option is continuous from the flat option ccpo into sets ordered by inclusion with Union. *)
lemma mcont2mcont_set_option [THEN lfp.mcont2mcont, cont_intro, simp]:
  shows mcont_set_option: "mcont (flat_lub None) option_ord Union (⊆) set_option"
by(rule mcont_finite_chains)(simp_all add: monotone_set_option ccpo option.partial_function_definitions_axioms)

(* partial_function monotonicity rule: addition on enat is monotone with respect to the
   reversed order. Judging by the name, it presumably serves a gfp-style mode. *)
lemma eadd_gfp_partial_function_mono [partial_function_mono]:
  " monotone (fun_ord (≥)) (≥) f; monotone (fun_ord (≥)) (≥) g 
   monotone (fun_ord (≥)) (≥) (λx. f x + g x :: enat)"
by(rule mono2mono_gfp_eadd)

(* partial_function monotonicity rule for map_option in the option mode: rewrite map_option
   as a bind and use bind_mono. *)
lemma map_option_mono [partial_function_mono]:
  "mono_option B  mono_option (λf. map_option g (B f))"
unfolding map_conv_bind_option by(rule bind_mono) simp_all


subsection ‹Folding over finite sets›

(* Invariant reasoning for Finite_Set.fold f s A. The invariant I is indexed by the set of
   elements still to be processed: it starts at A, each step removes one element, and it
   ends at the empty set. *)
lemma (in comp_fun_commute) fold_invariant_remove [consumes 1, case_names start step]:
  assumes fin: "finite A"
  and start: "I A s"
  and step: "x s A'.  x  A'; I A' s; A'  A   I (A' - {x}) (f x s)"
  shows "I {} (Finite_Set.fold f s A)"
proof -
  (* Copy A to A' so that induction over A' keeps A fixed for the subset side condition. *)
  define A' where "A' == A"
  with fin start have "finite A'" "A'  A" "I A' s" by simp_all
  thus "I {} (Finite_Set.fold f s A')"
  proof(induction arbitrary: s)
    case empty thus ?case by simp
  next
    case (insert x A')
    let ?A' = "insert x A'"
    have "x  ?A'" "I ?A' s" "?A'  A" using insert by auto
    hence "I (?A' - {x}) (f x s)" by(rule step)
    with insert have "A'  A" "I A' (f x s)" by auto
    hence "I {} (Finite_Set.fold f (f x s) A')" by(rule insert.IH)
    thus ?case using insert by(simp add: fold_insert2 del: fold_insert)
  qed
qed

(* Dual formulation of fold_invariant_remove in which I is indexed by the elements already
   processed. It is obtained by instantiating the remove variant with the invariant
   I (A - A'). *)
lemma (in comp_fun_commute) fold_invariant_insert [consumes 1, case_names start step]:
  assumes fin: "finite A"
  and start: "I {} s"
  and step: "x s A'.  I A' s; x  A'; x  A; A'  A   I (insert x A') (f x s)"
  shows "I A (Finite_Set.fold f s A)"
using fin start
proof(rule fold_invariant_remove[where I="λA'. I (A - A')" and A=A and s=s, simplified])
  fix x s A'
  assume *: "x  A'" "I (A - A') s" "A'  A"
  hence "x  A - A'" "x  A" "A - A'  A" by auto
  with I (A - A') s have "I (insert x (A - A')) (f x s)" by(rule step)
  also have "insert x (A - A') = A - (A' - {x})" using * by auto
  finally show "I  (f x s)" .
qed

(* For an idempotent, commuting f, folding over a union of finite sets means folding over A
   and then continuing with B. *)
lemma (in comp_fun_idem) fold_set_union:
  assumes "finite A" "finite B"
  shows "Finite_Set.fold f z (A  B) = Finite_Set.fold f (Finite_Set.fold f z A) B"
using assms(2,1) by induction simp_all


subsection ‹Parametrisation of transfer rules›

(* Attribute transfer_parametric [of parametricity theorem]. It is applied to a transfer rule,
   which it first parametrizes (Lifting_Term.parametrize_transfer_rule). It then combines the
   result with the given parametricity theorem (Lifting_Def.generate_parametric_transfer_rule).
   A MERGE_TRANSFER_REL failure is reported as an error. *)
attribute_setup transfer_parametric = Attrib.thm >> (fn parametricity =>
    Thm.rule_attribute [] (fn context => fn transfer_rule =>
      let
        val ctxt = Context.proof_of context;
        val thm' = Lifting_Term.parametrize_transfer_rule ctxt transfer_rule
      in Lifting_Def.generate_parametric_transfer_rule ctxt thm' parametricity
      end
      handle Lifting_Term.MERGE_TRANSFER_REL msg => error (Pretty.string_of msg)
      )) "combine transfer rule with parametricity theorem"

subsection ‹Lists›

(* Indexing into the tail: element n of xs is element Suc n of x # xs. *)
lemma nth_eq_tlI: "xs ! n = z  (x # xs) ! Suc n = z"
by simp

(* list_all2 splits across appends when the second components have equal length. *)
lemma list_all2_append':
  "length us = length vs  list_all2 P (xs @ us) (ys @ vs)  list_all2 P xs ys  list_all2 P us vs"
by(auto simp add: list_all2_append1 list_all2_append2 dest: list_all2_lengthD)

(* disjointp xs: the predicates in xs hold on pairwise disjoint sets of elements. *)
definition disjointp :: "('a  bool) list  bool"
where "disjointp xs = disjoint_family_on (λn. {x. (xs ! n) x}) {0..<length xs}"

(* Two in-range predicates of a disjoint list that share a witness x have the same index. *)
lemma disjointpD:
  " disjointp xs; (xs ! n) x; (xs ! m) x; n < length xs; m < length xs   n = m"
by(auto 4 3 simp add: disjointp_def disjoint_family_on_def)

(* Variant of disjointpD with the predicates named P and Q. *)
lemma disjointpD':
  " disjointp xs; P x; Q x; xs ! n = P; xs ! m = Q; n < length xs; m < length xs   n = m"
by(auto 4 3 simp add: disjointp_def disjoint_family_on_def)

(* strict_prefix is well-founded, since it strictly decreases the length. *)
lemma wf_strict_prefix: "wfP strict_prefix"
proof -
  from wf have "wf (inv_image {(x, y). x < y} length)" by(rule wf_inv_image)
  moreover have "{(x, y). strict_prefix x y}  inv_image {(x, y). x < y} length" by(auto intro: prefix_length_less)
  ultimately show ?thesis unfolding wfP_def by(rule wf_subset)
qed

(* The elements of a strict prefix are elements of the whole list. *)
lemma strict_prefix_setD:
  "strict_prefix xs ys  set xs  set ys"
  by(auto simp add: strict_prefix_def prefix_def)

subsubsection ‹Lists of a given length›

(* nlists A n: the set of lists of length n whose elements all lie in A. *)
inductive_set nlists :: "'a set  nat  'a list set" for A n
where nlists: " set xs  A; length xs = n   xs  nlists A n"
(* Make the introduction rule accessible only under its qualified name. *)
hide_fact (open) nlists

(* Set-comprehension characterisation of nlists. *)
lemma nlists_alt_def: "nlists A n = {xs. set xs  A  length xs = n}"
by(auto simp add: nlists.simps)

(* Over the empty alphabet only the empty list exists. *)
lemma nlists_empty: "nlists {} n = (if n = 0 then {[]} else {})"
by(auto simp add: nlists_alt_def)

lemma nlists_empty_gt0 [simp]: "n > 0  nlists {} n = {}"
by(simp add: nlists_empty)

lemma nlists_0 [simp]: "nlists A 0 = {[]}"
by(auto simp add: nlists_alt_def)

(* Membership rules for cons and nil. *)
lemma Cons_in_nlists_Suc [simp]: "x # xs  nlists A (Suc n)  x  A  xs  nlists A n"
by(simp add: nlists_alt_def)

lemma Nil_in_nlists [simp]: "[]  nlists A n  n = 0"
by(auto simp add: nlists_alt_def)

lemma Cons_in_nlists_iff: "x # xs  nlists A n  (n'. n = Suc n'  x  A  xs  nlists A n')"
by(cases n) simp_all

lemma in_nlists_Suc_iff: "xs  nlists A (Suc n)  (x xs'. xs = x # xs'  x  A  xs'  nlists A n)"
by(cases xs) simp_all

(* Recursive description: lists of length Suc n are the conses of elements of A onto lists of length n. *)
lemma nlists_Suc: "nlists A (Suc n) = (xA. (#) x ` nlists A n)"
by(auto 4 3 simp add: in_nlists_Suc_iff intro: rev_image_eqI)

lemma replicate_in_nlists [simp, intro]: "x  A  replicate n x  nlists A n"
by(simp add: nlists_alt_def set_replicate_conv_if)

(* nlists A n is empty exactly when n is positive and A is empty. *)
lemma nlists_eq_empty_iff [simp]: "nlists A n = {}  n > 0  A = {}"
using replicate_in_nlists by(cases n)(auto)

lemma finite_nlists [simp]: "finite A  finite (nlists A n)"
by(induction n)(simp_all add: nlists_Suc)

(* If nlists A n is finite, then A is finite or n = 0. For n > 0, A is the image of the set
   under hd. *)
lemma finite_nlistsD: 
  assumes "finite (nlists A n)"
  shows "finite A  n = 0"
proof(rule disjCI)
  assume "n  0"
  then obtain n' where n: "n = Suc n'" by(cases n)auto
  then have "A = hd ` nlists A n" by(auto 4 4 simp add: nlists_Suc intro: rev_image_eqI rev_bexI)
  also have "finite " using assms ..
  finally show "finite A" .
qed

lemma finite_nlists_iff: "finite (nlists A n)  finite A  n = 0"
by(auto dest: finite_nlistsD)

(* Cardinality of nlists A n is card A ^ n. The infinite case is covered as well: for
   infinite A and n > 0 both sides are 0, since card of an infinite set is 0. *)
lemma card_nlists: "card (nlists A n) = card A ^ n"
proof(induction n)
  case (Suc n)
  have "card (xA. (#) x ` nlists A n) = card A * card (nlists A n)"
  proof(cases "finite A")
    case True
    then show ?thesis by(subst card_UN_disjoint)(auto simp add: card_image inj_on_def)
  next
    case False
    hence "¬ finite (xA. (#) x ` nlists A n)"
      unfolding nlists_Suc[symmetric] by(auto dest: finite_nlistsD)
    then show ?thesis using False by simp
  qed
  then show ?case using Suc.IH by(simp add: nlists_Suc)
qed simp

lemma in_nlists_UNIV: "xs  nlists UNIV n  length xs = n"
by(simp add: nlists_alt_def)

subsubsection ‹The type of lists of a given length›

(* Type of lists over 'a whose length is LENGTH('b). The carrier set is nonempty because it
   contains replicate LENGTH('b) undefined. *)
typedef (overloaded) ('a, 'b :: len0) nlist = "nlists (UNIV :: 'a set) (LENGTH('b))"
proof
  show "replicate LENGTH('b) undefined  ?nlist" by simp
qed

(* Set up the lifting package for the new type. *)
setup_lifting type_definition_nlist

subsection ‹Streams and infinite lists›

(* sprefix xs ys: the finite list xs is a prefix of the stream ys. *)
primrec sprefix :: "'a list  'a stream  bool" where
  sprefix_Nil: "sprefix [] ys = True"
| sprefix_Cons: "sprefix (x # xs) ys  x = shd ys  sprefix xs (stl ys)"

(* An appended list is a stream prefix iff its first part is, and the second part is a prefix
   of the stream with the first part dropped. *)
lemma sprefix_append: "sprefix (xs @ ys) zs  sprefix xs zs  sprefix ys (sdrop (length xs) zs)"
by(induct xs arbitrary: zs) simp_all

lemma sprefix_stake_same [simp]: "sprefix (stake n xs) xs"
by(induct n arbitrary: xs) simp_all

(* Two prefixes of the same stream with equal length coincide. *)
lemma sprefix_same_imp_eq:
  assumes "sprefix xs ys" "sprefix xs' ys"
  and "length xs = length xs'"
  shows "xs = xs'"
using assms(3,1,2) by(induct arbitrary: ys rule: list_induct2) auto

lemma sprefix_shift_same [simp]:
  "sprefix xs (xs @- ys)"
by(induct xs) simp_all

(* For lists no longer than ys, being a prefix of ys @- zs reduces to being a list prefix of ys. *)
lemma sprefix_shift [simp]:
  "length xs  length ys  sprefix xs (ys @- zs)  prefix xs ys"
by(induct xs arbitrary: ys)(simp, case_tac ys, auto)

(* Prefix of the first n stream elements: length bounded by n and a prefix of the stream. *)
lemma prefixeq_stake2 [simp]: "prefix xs (stake n ys)  length xs  n  sprefix xs ys"
proof(induct xs arbitrary: n ys)
  case (Cons x xs)
  thus ?case by(cases ys n rule: stream.exhaust[case_product nat.exhaust]) auto
qed simp

(* A terminated lazy list has infinite length iff it is not finite (via the llist counterpart). *)
lemma tlength_eq_infinity_iff: "tlength xs =   ¬ tfinite xs"
including tllist.lifting by transfer(simp add: llength_eq_infty_conv_lfinite)

subsection ‹Monomorphic monads›

context includes lifting_syntax begin
(* Put the following definitions under the mandatory name prefix monad
   (they are referred to as monad.bind_option afterwards). *)
local_setup Local_Theory.map_background_naming (Name_Space.mandatory_path "monad")

(* Bind an option into an arbitrary monad: None becomes fail, Some x' continues with f x'. *)
definition bind_option :: "'m fail  'a option  ('a  'm)  'm"
where "bind_option fail x f = (case x of None  fail | Some x'  f x')" for fail

(* Derive equational simp rules for the None and Some cases from the case expression. *)
simps_of_case bind_option_simps [simp]: bind_option_def

lemma bind_option_parametric [transfer_rule]:
  "(M ===> rel_option B ===> (B ===> M) ===> M) bind_option bind_option"
unfolding bind_option_def by transfer_prover

(* A constant continuation m is returned unchanged, provided m equals fail whenever x is None. *)
lemma bind_option_K:
  "monad. (x = None  m = fail)  bind_option fail x (λ_. m) = m"
by(cases x) simp_all

end

(* In the option monad, monad.bind_option with failure None is Option.bind. *)
lemma bind_option_option [simp]: "monad.bind_option None = Option.bind"
by(simp add: monad.bind_option_def fun_eq_iff split: option.split)

context monad_fail_hom begin

(* Monad homomorphisms that preserve failure commute with bind_option. *)
lemma hom_bind_option: "h (monad.bind_option fail1 x f) = monad.bind_option fail2 x (h  f)"
by(cases x)(simp_all)

end

(* In the set monad, bind_option with fail_set is the union over the image of set_option x. *)
lemma bind_option_set [simp]: "monad.bind_option fail_set = (λx f.  (f ` set_option x))"
by(simp add: monad.bind_option_def fun_eq_iff split: option.split)

(* Running a state-transformer bind_option pushes run_state into the continuation. *)
lemma run_bind_option_stateT [simp]:
  "more. run_state (monad.bind_option (fail_state fail) x f) s = 
  monad.bind_option fail x (λy. run_state (f y) s)"
by(cases x) simp_all

(* Analogous rule for the environment (reader) transformer. *)
lemma run_bind_option_envT [simp]:
  "more. run_env (monad.bind_option (fail_env fail) x f) s = 
  monad.bind_option fail x (λy. run_env (f y) s)"
by(cases x) simp_all


subsection ‹Measures›

(* Let the measurability prover use this rule as a congruence rule. *)
declare sets_restrict_space_count_space [measurable_cong]

(* Over a countable index type, unique existence preserves measurability of sets.
   This is sets_Collect_countable_Ex1' specialised to the index set UNIV. *)
lemma (in sigma_algebra) sets_Collect_countable_Ex1:
  "(i :: 'i :: countable. {x  Ω. P i x}  M)  {x  Ω. ∃!i. P i x}  M"
using sets_Collect_countable_Ex1'[of "UNIV :: 'i set"] by simp

(* Predicate form, registered for the measurability prover. *)
lemma pred_countable_Ex1 [measurable]:
  "(i :: _ :: countable. Measurable.pred M (λx. P i x))
   Measurable.pred M (λx. ∃!i. P i x)"
unfolding pred_def by(rule sets.sets_Collect_countable_Ex1)

(* The second projection out of a product with a count_space factor is measurable into a
   larger count_space. *)
lemma measurable_snd_count_space [measurable]: 
  "A  B  snd  measurable (M1 M count_space A) (count_space B)"
by(auto simp add: measurable_def space_pair_measure snd_vimage_eq_Times Times_Int_Times)

(* Integrability is preserved when the measure is scaled by a finite factor r. *)
lemma integrable_scale_measure [simp]:
  " integrable M f; r <    integrable (scale_measure r M) f" 
  for f :: "'a  'b::{banach, second_countable_topology}"
  by(auto simp add: integrable_iff_bounded nn_integral_scale_measure ennreal_mult_less_top)

(* For integrable real-valued f and finite r, the integral over the scaled measure is
   enn2real r times the original integral. *)
lemma integral_scale_measure:
  assumes "integrable M f" "r < "
  shows "integralL (scale_measure r M) f = enn2real r * integralL M f"
  using assms
  apply(subst (1 2) real_lebesgue_integral_def)
    apply(simp_all add: nn_integral_scale_measure ennreal_enn2real_if)
  by(auto simp add: ennreal_mult_less_top ennreal_less_top_iff ennreal_mult_eq_top_iff enn2real_mult right_diff_distrib elim!: integrableE)

subsection ‹Sequence space›

(* Split a nonnegative integral over the sequence space at index i into an iterated integral,
   combining two sequences with comb_seq i. *)
lemma (in sequence_space) nn_integral_split:
  assumes f[measurable]: "f  borel_measurable S"
  shows "(+ω. f ω S) = (+ω. (+ω'. f (comb_seq i ω ω') S) S)"
by (subst PiM_comb_seq[symmetric, where i=i])
   (simp add: nn_integral_distr P.nn_integral_fst[symmetric])

(* Probability version of nn_integral_split for a measurable event, obtained by instantiating
   it with the indicator function. *)
lemma (in sequence_space) prob_Collect_split:
  assumes f[measurable]: "{xspace S. P x}  sets S"
  shows "𝒫(x in S. P x) = (+x. 𝒫(x' in S. P (comb_seq i x x')) S)"
proof -
  have "𝒫(x in S. P x) = (+x. (+x'. indicator {xspace S. P x} (comb_seq i x x') S) S)"
    using nn_integral_split[of "indicator {xspace S. P x}"] by (auto simp: emeasure_eq_measure)
  also have " = (+x. 𝒫(x' in S. P (comb_seq i x x')) S)"
    by (intro nn_integral_cong) (auto simp: emeasure_eq_measure nn_integral_indicator_map)
  finally show ?thesis .
qed

subsection ‹Probability mass functions›

(* The measure of a mapped pmf is the push-forward (distr) of the original measure. *)
lemma measure_map_pmf_conv_distr:
  "measure_pmf (map_pmf f p) = distr (measure_pmf p) (count_space UNIV) f"
by(fact map_pmf_rep_eq)

(* Uniform distribution over bool, i.e. a fair coin. *)
abbreviation coin_pmf :: "bool pmf" where "coin_pmf  pmf_of_set UNIV"

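text ‹The rule @{thm [source] rel_pmf_bindI} is not complete as a program logic: the following example relates two @{const bind_pmf} terms by @{const rel_pmf}, although no relation satisfies both premises of the rule.›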
notepad begin
  (* Counterexample setup: x and y are fair coins, f ignores its argument and returns a fair
     coin, and g is return_pmf. *)
  define x where "x = pmf_of_set {True, False}"
  define y where "y = pmf_of_set {True, False}"
  define f where "f x = pmf_of_set {True, False}" for x :: bool
  define g :: "bool  bool pmf" where "g = return_pmf"
  define P :: "bool  bool  bool" where "P = (=)"
  (* Both binds are the fair coin, so they are related by equality. *)
  have "rel_pmf P (bind_pmf x f) (bind_pmf y g)"
    by(simp add: P_def f_def[abs_def] g_def y_def bind_return_pmf' pmf.rel_eq)
  have "¬ R x y" if "x y. R x y  rel_pmf P (f x) (g y)" for R x y
    ― ‹Only the empty relation satisfies @{thm [source] rel_pmf_bindI}'s second premise.›
  proof
    assume "R x y"
    hence "rel_pmf P (f x) (g y)" by(rule that)
    thus False by(auto simp add: P_def f_def g_def rel_pmf_return_pmf2)
  qed
  (* The empty relation fails the first premise, since it does not relate x and y. *)
  define R where "R x y = False" for x y :: bool
  have "¬ rel_pmf R x y" by(simp add: R_def[abs_def])
end

(* If every element in the support of p satisfies P and p is related to q by R, then every
   element in the support of q is in the R-image of P. The proof unfolds rel_pmf into its
   coupling characterisation. *)
lemma pred_rel_pmf:
  " pred_pmf P p; rel_pmf R p q   pred_pmf (Imagep R P) q"
unfolding pred_pmf_def
apply(rule ballI)
apply(unfold rel_pmf.simps)
apply(erule exE conjE)+
apply hypsubst
apply(unfold pmf.set_map)
apply(erule imageE, hypsubst)
apply(drule bspec)
 apply(erule rev_image_eqI)
 apply(rule refl)
apply(erule Imagep.intros)
apply(erule allE)+
 apply(erule mp)
apply(unfold prod.collapse)
apply assumption
done

(* rel_pmf is monotone in the relation (inclusion given as a premise). *)
lemma pmf_rel_mono': " rel_pmf P x y; P  Q   rel_pmf Q x y"
by(drule pmf.rel_mono) (auto)

(* rel_pmf (=) is reflexive. *)
lemma rel_pmf_eqI [simp]: "rel_pmf (=) x x"
by(simp add: pmf.rel_eq)

(* Binding the same pmf p with continuations related on the support of p gives related pmfs. *)
lemma rel_pmf_bind_reflI:
  "(x. x  set_pmf p  rel_pmf R (f x) (g x))
   rel_pmf R (bind_pmf p f) (bind_pmf p g)"
by(rule rel_pmf_bindI[where R="λx y. x = y  x  set_pmf p"])(auto intro: rel_pmf_reflI)

(* pred_pmf is monotone, needing the implication only on the support of p. *)
lemma pmf_pred_mono_strong:
  " pred_pmf P p; a.  a  set_pmf p; P a   P' a   pred_pmf P' p"
by(simp add: pred_pmf_def)

(* Introduction rule: rel_pmf R plus pred_pmf on each side gives rel_pmf of the restricted relation. *)
lemma rel_pmf_restrict_relpI [intro?]:
  " rel_pmf R x y; pred_pmf P x; pred_pmf Q y   rel_pmf (R  P  Q) x y"
by(erule pmf.rel_mono_strong)(simp add: pred_pmf_def)

(* Elimination rule, the converse of rel_pmf_restrict_relpI. *)
lemma rel_pmf_restrict_relpE [elim?]:
  assumes "rel_pmf (R  P  Q) x y"
  obtains "rel_pmf R x y" "pred_pmf P x" "pred_pmf Q y"
proof
  show "rel_pmf R x y" using assms by(auto elim!: pmf.rel_mono_strong)
  (* Left side: elements of x are in the domain of the restricted relation. *)
  have "pred_pmf (Domainp (R  P  Q)) x" using assms by(fold pmf.Domainp_rel) blast
  then show "pred_pmf P x" by(rule pmf_pred_mono_strong)(blast dest!: restrict_relp_DomainpD)
  (* Right side: the same argument on the converse relation. *)
  have "pred_pmf (Domainp (R  P  Q)¯¯) y" using assms
    by(fold pmf.Domainp_rel)(auto simp only: pmf.rel_conversep Domainp_conversep)
  then show "pred_pmf Q y" by(rule pmf_pred_mono_strong)(auto dest!: restrict_relp_DomainpD)
qed

lemma rel_pmf_restrict_relp_iff:
  "rel_pmf (R  P  Q) x y  rel_pmf R x y  pred_pmf P x  pred_pmf Q y"
by(blast intro: rel_pmf_restrict_relpI elim: rel_pmf_restrict_relpE)

(* Transitivity of rel_pmf through relation composition, for calculational reasoning. *)
lemma rel_pmf_OO_trans [trans]:
  " rel_pmf R p q; rel_pmf S q r   rel_pmf (R OO S) p r"
unfolding pmf.rel_compp by blast

(* Simplification rules that push pred_pmf through the standard pmf constructors. *)
lemma pmf_pred_map [simp]: "pred_pmf P (map_pmf f p) = pred_pmf (P  f) p"
by(simp add: pred_pmf_def)

lemma pred_pmf_bind [simp]: "pred_pmf P (bind_pmf p f) = pred_pmf (pred_pmf P  f) p"
by(simp add: pred_pmf_def)

lemma pred_pmf_return [simp]: "pred_pmf P (return_pmf x) = P x"
by(simp add: pred_pmf_def)

lemma pred_pmf_of_set [simp]: " finite A; A  {}   pred_pmf P (pmf_of_set A) = Ball A P"
by(simp add: pred_pmf_def)

lemma pred_pmf_of_multiset [simp]: "M  {#}  pred_pmf P (pmf_of_multiset M) = Ball (set_mset M) P"
by(simp add: pred_pmf_def)

lemma pred_pmf_cond [simp]:
  "set_pmf p  A  {}  pred_pmf P (cond_pmf p A) = pred_pmf (λx. x  A  P x) p"
by(auto simp add: pred_pmf_def)

lemma pred_pmf_pair [simp]:
  "pred_pmf P (pair_pmf p q) = pred_pmf (λx. pred_pmf (P  Pair x) q) p"
by(simp add: pred_pmf_def)

lemma pred_pmf_join [simp]: "pred_pmf P (join_pmf p) = pred_pmf (pred_pmf P) p"
by(simp add: pred_pmf_def)

(* For non-degenerate parameters the support is the whole type, so pred_pmf becomes All. *)
lemma pred_pmf_bernoulli [simp]: " 0 < p; p < 1   pred_pmf P (bernoulli_pmf p) = All P"
by(simp add: pred_pmf_def)

lemma pred_pmf_geometric [simp]: " 0 < p; p < 1   pred_pmf P (geometric_pmf p) = All P"
by(simp add: pred_pmf_def set_pmf_geometric)

lemma pred_pmf_poisson [simp]: "0 < rate  pred_pmf P (poisson_pmf rate) = All P"
by(simp add: pred_pmf_def)

(* Push map_pmf out of a restricted rel_pmf on either side. *)
lemma pmf_rel_map_restrict_relp: 
  shows pmf_rel_map_restrict_relp1: "rel_pmf (R  P  Q) (map_pmf f p) = rel_pmf (R  f  P  f  Q) p"
  and pmf_rel_map_restrict_relp2: "rel_pmf (R  P  Q) p (map_pmf g q) = rel_pmf ((λx. R x  g)  P  Q  g) p q"
by(simp_all add: pmf.rel_map restrict_relp_def fun_eq_iff)

lemma pred_pmf_conj [simp]: "pred_pmf (λx. P x  Q x) = (λx. pred_pmf P x  pred_pmf Q x)"
by(auto simp add: pred_pmf_def)

lemma pred_pmf_top [simp]:
  "pred_pmf (λ_. True) = (λ_. True)"
by(simp add: pred_pmf_def)

(* Two uniform distributions on finite nonempty sets A and B are related by R if a counting
   condition holds. For every subset X of A, card B * card X must be at most card A times the
   number of elements of B that are R-related to some element of X. The proof goes through
   rel_pmf_measureI. *)
lemma rel_pmf_of_setI:
  assumes A: "A  {}" "finite A"
  and B: "B  {}" "finite B"
  and card: "X. X  A  card B * card X  card A * card {yB. xX. R x y}"
  shows "rel_pmf R (pmf_of_set A) (pmf_of_set B)"
apply(rule rel_pmf_measureI)
using assms
apply(clarsimp simp add: measure_pmf_of_set card_gt_0_iff field_simps of_nat_mult[symmetric] simp del: of_nat_mult)
apply(subst mult.commute)
apply(erule meta_allE)
apply(erule meta_impE)
 prefer 2
 apply(erule order_trans)
apply(auto simp add: card_gt_0_iff intro: card_mono)
done

(* rel_witness_pmf A (p, q) is a witness (coupling) for rel_pmf A p q, chosen by
   the choice principle through specification.  Whenever rel_pmf A p q holds:
   - its support consists of A-related pairs, and
   - its first and second marginals are p and q.
   The specification only constrains it under the premise rel_pmf A p q. *)
consts rel_witness_pmf :: "('a  'b  bool)  'a pmf × 'b pmf  ('a × 'b) pmf"
specification (rel_witness_pmf)
  set_rel_witness_pmf': "rel_pmf A (fst xy) (snd xy)  set_pmf (rel_witness_pmf A xy)  {(a, b). A a b}"
  map1_rel_witness_pmf': "rel_pmf A (fst xy) (snd xy)  map_pmf fst (rel_witness_pmf A xy) = fst xy"
  map2_rel_witness_pmf': "rel_pmf A (fst xy) (snd xy)  map_pmf snd (rel_witness_pmf A xy) = snd xy"
  apply(fold all_conj_distrib imp_conjR)
  apply(rule choice allI)+
  apply(unfold pmf.in_rel)
  by blast

(* The specification facts instantiated at an explicit pair (x, y) and simplified. *)
lemmas set_rel_witness_pmf = set_rel_witness_pmf'[of _ "(x, y)" for x y, simplified]
lemmas map1_rel_witness_pmf = map1_rel_witness_pmf'[of _ "(x, y)" for x y, simplified]
lemmas map2_rel_witness_pmf = map2_rel_witness_pmf'[of _ "(x, y)" for x y, simplified]
(* All three witness properties bundled under one name. *)
lemmas rel_witness_pmf = set_rel_witness_pmf map1_rel_witness_pmf map2_rel_witness_pmf

(* The witness is related to p by the relation that identifies an element of p
   with the first component of a witness pair, which is A-related to the second. *)
lemma rel_witness_pmf1:
  assumes "rel_pmf A p q" 
  shows "rel_pmf (λa (a', b). a = a'  A a' b) p (rel_witness_pmf A (p, q))"
  using map1_rel_witness_pmf[OF assms, symmetric]
  unfolding pmf.rel_eq[symmetric] pmf.rel_map
  by(rule pmf.rel_mono_strong)(auto dest: set_rel_witness_pmf[OF assms, THEN subsetD])

(* Symmetric counterpart: the witness is related to q via its second component. *)
lemma rel_witness_pmf2:
  assumes "rel_pmf A p q" 
  shows "rel_pmf (λ(a, b') b. b = b'  A a b') (rel_witness_pmf A (p, q)) q"
  using map2_rel_witness_pmf[OF assms]
  unfolding pmf.rel_eq[symmetric] pmf.rel_map
  by(rule pmf.rel_mono_strong)(auto dest: set_rel_witness_pmf[OF assms, THEN subsetD])

(* Conditioning pmf_of_set A on B yields pmf_of_set on the intersection of A and B.
   A must be finite and the intersection non-empty.  The proof compares
   probabilities pointwise via pmf_cond. *)
lemma cond_pmf_of_set:
  assumes fin: "finite A" and nonempty: "A  B  {}"
  shows "cond_pmf (pmf_of_set A) B = pmf_of_set (A  B)" (is "?lhs = ?rhs")
proof(rule pmf_eqI)
  from nonempty have A: "A  {}" by auto
  show "pmf ?lhs x = pmf ?rhs x" for x
    by(subst pmf_cond; clarsimp simp add: fin A nonempty measure_pmf_of_set split: split_indicator)
qed

(* The product of two set distributions is the set distribution on the Cartesian
   product, for finite sets satisfying the stated side conditions. *)
lemma pair_pmf_of_set:
  assumes A: "finite A" "A  {}"
    and B: "finite B" "B  {}"
  shows "pair_pmf (pmf_of_set A) (pmf_of_set B) = pmf_of_set (A × B)"
  by(rule pmf_eqI)(clarsimp simp add: pmf_pair assms split: split_indicator)

(* The emeasure of the conditioned pmf q = cond_pmf p A is the ratio of the
   emeasure of A intersected with B to the emeasure of A.
   Proof: transfer from the measure-level view (pmf_as_measure), using the
   transfer rule for cond_pmf instantiated with the support premise. *)
lemma emeasure_cond_pmf:
  fixes p A
  defines "q  cond_pmf p A"
  assumes "set_pmf p  A  {}"
  shows "emeasure (measure_pmf q) B = emeasure (measure_pmf p) (A  B) / emeasure (measure_pmf p) A"
proof -
  note [transfer_rule] = cond_pmf.transfer[OF assms(2), folded q_def]
  interpret pmf_as_measure .
  show ?thesis by transfer simp
qed

(* Real-valued counterpart of emeasure_cond_pmf, with the same premise. *)
lemma measure_cond_pmf:
  "measure (measure_pmf (cond_pmf p A)) B = measure (measure_pmf p) (A  B) / measure (measure_pmf p) A"
  if "set_pmf p  A  {}"
  using emeasure_cond_pmf[OF that, of B] that 
  by(auto simp add: measure_pmf.emeasure_eq_measure measure_pmf_posI divide_ennreal)

(* A set has emeasure zero under a pmf iff it does not meet the support.
   The proof goes through the almost-everywhere characterisation. *)
lemma emeasure_measure_pmf_zero_iff: "emeasure (measure_pmf p) s = 0  set_pmf p  s = {}" (is "?lhs = ?rhs")
proof -
  have "?lhs  (AE x in measure_pmf p. x  s)"
    by(subst AE_iff_measurable)(auto)
  also have " = ?rhs" by(auto simp add: AE_measure_pmf_iff)
  finally show ?thesis .
qed

subsection ‹Subprobability mass functions›

(* A point distribution return_spmf x lies below p in ord_spmf R iff
   (a) p is lossless, and
   (b) x is R-related to every element of the support of p. *)
lemma ord_spmf_return_spmf1: "ord_spmf R (return_spmf x) p  lossless_spmf p  (yset_spmf p. R x y)"
by(auto simp add: rel_pmf_return_pmf1 ord_option.simps in_set_spmf lossless_iff_set_pmf_None Ball_def) (metis option.exhaust)

(* ord_spmf R factors as rel_spmf R composed with ord_spmf (=).
   Proof: reduce to composition of the underlying ord_option relations. *)
lemma ord_spmf_conv:
  "ord_spmf R = rel_spmf R OO ord_spmf (=)"
apply(subst pmf.rel_compp[symmetric])
apply(rule arg_cong[where f="rel_pmf"])  
apply(rule ext)+
apply(auto elim!: ord_option.cases option.rel_cases intro: option.rel_intros)
done

(* Same equation, guarded by NO_MATCH so it can be used as a rewrite rule.
   Without the guard it would loop, since the right-hand side contains
   ord_spmf (=) again. *)
lemma ord_spmf_expand:
  "NO_MATCH (=) R  ord_spmf R = rel_spmf R OO ord_spmf (=)"
by(rule ord_spmf_conv)

(* ord_spmf (=) is monotone for the (real-valued) measure of any set. *)
lemma ord_spmf_eqD_measure: "ord_spmf (=) p q  measure (measure_spmf p) A  measure (measure_spmf q) A"
by(drule ord_spmf_eqD_measure_spmf)(simp add: le_measure measure_spmf.emeasure_eq_measure)

(* Measure bound for ord_spmf R: the mass of A under p is bounded by the mass
   under q of the set of elements R-related to some element of A.
   The proof splits ord_spmf R via ord_spmf_expand into a rel_spmf R step and an
   ord_spmf (=) step, and chains the corresponding bounds. *)
lemma ord_spmf_measureD:
  assumes "ord_spmf R p q"
  shows "measure (measure_spmf p) A  measure (measure_spmf q) {y. xA. R x y}"
    (is "?lhs  ?rhs")
proof -
  from assms obtain p' where *: "rel_spmf R p p'" and **: "ord_spmf (=) p' q"
    by(auto simp add: ord_spmf_expand)
  have "?lhs  measure (measure_spmf p') {y. xA. R x y}" using * by(rule rel_spmf_measureD)
  also have "  ?rhs" using ** by(rule ord_spmf_eqD_measure)
  finally show ?thesis .
qed

(* If every continuation f x, for x in the support of p, is below q in
   ord_spmf R, then bind_pmf p f is below q.
   The right-hand q is first rewritten into a trivial bind over a unit-valued
   return_pmf, so that rel_pmf_bindI applies to both sides. *)
lemma ord_spmf_bind_pmfI1:
  "(x. x  set_pmf p  ord_spmf R (f x) q)  ord_spmf R (bind_pmf p f) q"
  apply(rewrite at "ord_spmf _ _ " bind_return_pmf[symmetric, where f="λ_ :: unit. q"])
  apply(rule rel_pmf_bindI[where R="λx y. x  set_pmf p"])
  apply(simp_all add: rel_pmf_return_pmf2)
  done
  
(* spmf version, obtained from ord_spmf_bind_pmfI1 by unfolding bind_spmf. *)
lemma ord_spmf_bind_spmfI1:
  "(x. x  set_spmf p  ord_spmf R (f x) q)  ord_spmf R (bind_spmf p f) q"
unfolding bind_spmf_def by(rule ord_spmf_bind_pmfI1)(auto split: option.split simp add: in_set_spmf)

(* spmf_of_set of the empty set is the null subdistribution. *)
lemma spmf_of_set_empty: "spmf_of_set {} = return_pmf None"
by(simp add: spmf_of_set_def)

(* spmf analogue of rel_pmf_of_setI.  Premise eq makes A and B agree on being
   finite and non-empty.  Either both sides reduce to the pmf_of_set case, or
   both are the null subdistribution. *)
lemma rel_spmf_of_setI:
  assumes card: "X. X  A  card B * card X  card A * card {yB. xX. R x y}"
  and eq: "(finite A  A  {})  (finite B  B  {})"
  shows "rel_spmf R (spmf_of_set A) (spmf_of_set B)"
using eq by(clarsimp simp add: spmf_of_set_def card rel_pmf_of_setI simp del: spmf_of_pmf_pmf_of_set cong: conj_cong)

(* Alias for map_spmf_bind_spmf. *)
lemmas map_bind_spmf = map_spmf_bind_spmf

(* A non-negative integral over measure_spmf p equals an integral over
   measure_pmf p restricted to range Some, with the integrand composed with the. *)
lemma nn_integral_measure_spmf_conv_measure_pmf:
  assumes [measurable]: "f  borel_measurable (count_space UNIV)"
  shows "nn_integral (measure_spmf p) f = nn_integral (restrict_space (measure_pmf p) (range Some)) (f  the)"
by(simp add: measure_spmf_def nn_integral_distr o_def)

(* The total mass of an spmf, integrated over the counting space, is finite. *)
lemma nn_integral_spmf_neq_infinity: "(+ x. spmf p x count_space UNIV)  "
using nn_integral_measure_spmf[where f="λ_. 1", of p, symmetric] by simp

(* Option.bind lifted through return_pmf is bind_spmf on return_pmf. *)
lemma return_pmf_bind_option:
  "return_pmf (Option.bind x f) = bind_spmf (return_pmf x) (return_pmf  f)"
by(cases x) simp_all

(* Composing two rel_spmf relations relates the two ends by rel_spmf of the
   composed relation. *)
lemma rel_spmf_pos_distr: "rel_spmf A OO rel_spmf B  rel_spmf (A OO B)"
unfolding option.rel_compp pmf.rel_compp ..

(* Transitivity rule for calculational reasoning (also/finally) with rel_spmf:
   the relations are composed. *)
lemma rel_spmf_OO_trans [trans]:
  " rel_spmf R p q; rel_spmf S q r   rel_spmf (R OO S) p r"
by(rule rel_spmf_pos_distr[THEN predicate2D]) auto

(* Two mapped spmfs are equal iff the originals are related by the relation
   f x = g y. *)
lemma map_spmf_eq_map_spmf_iff: "map_spmf f p = map_spmf g q  rel_spmf (λx y. f x = g y) p q"
by(simp add: spmf_rel_eq[symmetric] spmf_rel_map)

(* Introduction direction of map_spmf_eq_map_spmf_iff. *)
lemma map_spmf_eq_map_spmfI: "rel_spmf (λx y. f x = g y) p q  map_spmf f p = map_spmf g q"
by(simp add: map_spmf_eq_map_spmf_iff)

(* Strong monotonicity of rel_spmf: A only needs to imply B on elements of the
   two supports. *)
lemma spmf_rel_mono_strong:
  "rel_spmf A f g; x y.  x  set_spmf f; y  set_spmf g; A x y   B x y   rel_spmf B f g"
apply(erule pmf.rel_mono_strong)
apply(erule option.rel_mono_strong)
by(clarsimp simp add: in_set_spmf)

(* An spmf has empty support iff it is the null subdistribution. *)
lemma set_spmf_eq_empty: "set_spmf p = {}  p = return_pmf None"
by auto (metis restrict_spmf_empty restrict_spmf_trivial)


(* The measure of a rectangle A × B under pair_spmf p q factors into the product
   of the marginal measures.
   Proof outline:
   - work with emeasure, written as nn_integrals over the counting space;
   - split the double integral with nn_integral_fst_count_space;
   - pull out the constant factors;
   - convert back to measure at the end. *)
lemma measure_pair_spmf_times:
  "measure (measure_spmf (pair_spmf p q)) (A × B) = measure (measure_spmf p) A * measure (measure_spmf q) B"
proof -
  have "emeasure (measure_spmf (pair_spmf p q)) (A × B) = (+ x. ennreal (spmf (pair_spmf p q) x) * indicator (A × B) x count_space UNIV)"
    by(simp add: nn_integral_spmf[symmetric] nn_integral_count_space_indicator)
  also have " = (+ x. (+ y. (ennreal (spmf p x) * indicator A x) * (ennreal (spmf q y) * indicator B y) count_space UNIV) count_space UNIV)"
    by(subst nn_integral_fst_count_space[symmetric])(auto intro!: nn_integral_cong split: split_indicator simp add: ennreal_mult)
  also have " = (+ x. ennreal (spmf p x) * indicator A x * emeasure (measure_spmf q) B count_space UNIV)"
    by(simp add: nn_integral_cmult nn_integral_spmf[symmetric] nn_integral_count_space_indicator)
  also have " = emeasure (measure_spmf p) A * emeasure (measure_spmf q) B"
    by(simp add: nn_integral_multc)(simp add: nn_integral_spmf[symmetric] nn_integral_count_space_indicator)
  finally show ?thesis by(simp add: measure_spmf.emeasure_eq_measure ennreal_mult[symmetric])
qed

(* A lossless spmf has non-empty support. *)
lemma lossless_spmfD_set_spmf_nonempty: "lossless_spmf p  set_spmf p  {}"
using set_pmf_not_empty[of p] by(auto simp add: set_spmf_def bind_UNION lossless_iff_set_pmf_None)

(* The support of return_pmf x, viewed as an spmf, is set_option x. *)
lemma set_spmf_return_pmf: "set_spmf (return_pmf x) = set_option x"
by(cases x) simp_all

(* Associativity of bind_pmf followed by bind_spmf. *)
lemma bind_spmf_pmf_assoc: "bind_spmf (bind_pmf p f) g = bind_pmf p (λx. bind_spmf (f x) g)"
by(simp add: bind_spmf_def bind_assoc_pmf)

(* For A satisfying the stated side conditions, binding on spmf_of_set A is the
   same as binding on pmf_of_set A. *)
lemma bind_spmf_of_set:  " finite A; A  {}   bind_spmf (spmf_of_set A) f = bind_pmf (pmf_of_set A) f"
by(simp add: spmf_of_set_def del: spmf_of_pmf_pmf_of_set)

(* bind_spmf after map_pmf, rewritten as a bind_pmf over p. *)
lemma bind_spmf_map_pmf:
  "bind_spmf (map_pmf f p) g = bind_pmf p (λx. bind_spmf (return_pmf (f x)) g)"
by(simp add: map_pmf_def bind_spmf_def bind_assoc_pmf)

(* Reflexivity of rel_spmf (=). *)
lemma rel_spmf_eqI [simp]: "rel_spmf (=) x x"
by(simp add: option.rel_eq)

(* Support of map_pmf viewed as an spmf: the union of set_option (f x) over the
   support of p. *)
lemma set_spmf_map_pmf: "set_spmf (map_pmf f p) = (xset_pmf p. set_option (f x))" (* Move up *)
by(simp add: set_spmf_def bind_UNION)

(* return_spmf x is below p in ord_spmf (=) iff p is return_spmf x.
   Proved using antisymmetry of the spmf order. *)
lemma ord_spmf_return_spmf [simp]: "ord_spmf (=) (return_spmf x) p  p = return_spmf x"
proof -
  have "p = return_spmf x  ord_spmf (=) (return_spmf x) p" by simp
  thus ?thesis
    by (metis (no_types) ord_option_eq_simps(2) rel_pmf_return_pmf1 rel_pmf_return_pmf2 spmf.leq_antisym)
qed

(* Make these support equations available to the simplifier. *)
declare
  set_bind_spmf [simp]
  set_spmf_return_pmf [simp]

(* A bind_spmf over p commutes with an inner bind_pmf over q. *)
lemma bind_spmf_pmf_commute:
  "bind_spmf p (λx. bind_pmf q (f x)) = bind_pmf q (λy. bind_spmf p (λx. f x y))"
unfolding bind_spmf_def 
by(subst bind_commute_pmf)(auto intro: bind_pmf_cong[OF refl] split: option.split)

(* return_pmf of map_option, expressed as a bind_spmf. *)
lemma return_pmf_map_option_conv_bind:
  "return_pmf (map_option f x) = bind_spmf (return_pmf x) (return_spmf  f)"
by(cases x) simp_all

(* return_pmf x is lossless iff x is not None. *)
lemma lossless_return_pmf_iff [simp]: "lossless_spmf (return_pmf x)  x  None"
by(cases x) simp_all

(* map_pmf f p is lossless iff f never yields None on the support of p. *)
lemma lossless_map_pmf: "lossless_spmf (map_pmf f p)  (x  set_pmf p. f x  None)"
using image_iff by(fastforce simp add: lossless_iff_set_pmf_None)

(* Associativity of bind_spmf followed by bind_pmf g, provided g maps None to
   the null subdistribution. *)
lemma bind_pmf_spmf_assoc:
  "g None = return_pmf None
   bind_pmf (bind_spmf p f) g = bind_spmf p (λx. bind_pmf (f x) g)"
by(auto simp add: bind_spmf_def bind_assoc_pmf bind_return_pmf fun_eq_iff intro!: arg_cong2[where f=bind_pmf] split: option.split)

(* pred_spmf P: predicate lifting for spmfs, defined as pred_pmf of pred_option P.
   pred_spmf_def below characterises it as a bounded quantifier over set_spmf. *)
abbreviation pred_spmf :: "('a  bool)  'a spmf  bool"
where "pred_spmf P  pred_pmf (pred_option P)"

(* Characterisation of pred_spmf as a bounded quantifier over the support. *)
lemma pred_spmf_def: "pred_spmf P p  (xset_spmf p. P x)"
by(auto simp add: pred_pmf_def pred_option_def set_spmf_def)

(* Strong monotonicity: P only needs to imply P' on the support. *)
lemma spmf_pred_mono_strong:
  " pred_spmf P p; a.  a  set_spmf p; P a   P' a   pred_spmf P' p"
by(simp add: pred_spmf_def)

(* The domain of rel_spmf R is pred_spmf of the domain of R. *)
lemma spmf_Domainp_rel: "Domainp (rel_spmf R) = pred_spmf (Domainp R)"
by(simp add: pmf.Domainp_rel option.Domainp_rel)

(* Introduction rule for rel_spmf with a relation restricted by P and Q. *)
lemma rel_spmf_restrict_relpI [intro?]:
  " rel_spmf R p q; pred_spmf P p; pred_spmf Q q   rel_spmf (R  P  Q) p q"
by(erule spmf_rel_mono_strong)(simp add: pred_spmf_def)

(* Elimination rule: a restricted rel_spmf yields three facts:
   - the unrestricted relation;
   - P on the left, derived from the Domainp of the restricted relation;
   - Q on the right, derived from the Domainp of its converse. *)
lemma rel_spmf_restrict_relpE [elim?]:
  assumes "rel_spmf (R  P  Q) x y"
  obtains "rel_spmf R x y" "pred_spmf P x" "pred_spmf Q y"
proof
  show "rel_spmf R x y" using assms by(auto elim!: spmf_rel_mono_strong)
  have "pred_spmf (Domainp (R  P  Q)) x" using assms by(fold spmf_Domainp_rel) blast
  then show "pred_spmf P x" by(rule spmf_pred_mono_strong)(blast dest!: restrict_relp_DomainpD)
  have "pred_spmf (Domainp (R  P  Q)¯¯) y" using assms
    by(fold spmf_Domainp_rel)(auto simp only: spmf_rel_conversep Domainp_conversep)
  then show "pred_spmf Q y" by(rule spmf_pred_mono_strong)(auto dest!: restrict_relp_DomainpD)
qed

(* Combination of the introduction and elimination rules above. *)
lemma rel_spmf_restrict_relp_iff:
  "rel_spmf (R  P  Q) x y  rel_spmf R x y  pred_spmf P x  pred_spmf Q y"
by(blast intro: rel_spmf_restrict_relpI elim: rel_spmf_restrict_relpE)

(* Rules pushing pred_spmf through spmf operations. *)

(* map_spmf is folded into the predicate. *)
lemma spmf_pred_map: "pred_spmf P (map_spmf f p) = pred_spmf (P  f) p"
by(simp)

(* bind_spmf becomes a nested pred_spmf over the continuation. *)
lemma pred_spmf_bind [simp]: "pred_spmf P (bind_spmf p f) = pred_spmf (pred_spmf P  f) p"
by(simp add: pred_spmf_def bind_UNION)

(* On return_spmf x, pred_spmf reduces to P x. *)
lemma pred_spmf_return: "pred_spmf P (return_spmf x) = P x"
by simp

(* pred_spmf holds vacuously on the null subdistribution. *)
lemma pred_spmf_return_pmf_None: "pred_spmf P (return_pmf None)"
by simp

(* On embedded pmfs, pred_spmf coincides with pred_pmf. *)
lemma pred_spmf_spmf_of_pmf [simp]: "pred_spmf P (spmf_of_pmf p) = pred_pmf P p"
unfolding pred_spmf_def by(simp add: pred_pmf_def)

(* pred_spmf on spmf_of_set A, expressed via finiteness of A and Ball A P. *)
lemma pred_spmf_of_set [simp]: "pred_spmf P (spmf_of_set A) = (finite A  Ball A P)"
by(auto simp add: pred_spmf_def set_spmf_of_set)

(* pred_spmf on assert_spmf b, expressed in terms of b and P (). *)
lemma pred_spmf_assert_spmf [simp]: "pred_spmf P (assert_spmf b) = (b  P ())"
by(cases b) simp_all

(* pred_spmf over pair_spmf splits into nested pred_spmf over the components. *)
lemma pred_spmf_pair [simp]:
  "pred_spmf P (pair_spmf p q) = pred_spmf (λx. pred_spmf (P  Pair x) q) p"
by(simp add: pred_spmf_def)

(* Support of TRY p ELSE q: the support of p, together with the support of q
   unless p is lossless (then the fallback contributes nothing). *)
lemma set_spmf_try [simp]:
  "set_spmf (try_spmf p q) = set_spmf p  (if lossless_spmf p then {} else set_spmf q)"
by(auto simp add: try_spmf_def set_spmf_bind_pmf in_set_spmf lossless_iff_set_pmf_None split: option.splits)(metis option.collapse)

(* Pull a bind_spmf into both branches of TRY p ELSE q, provided every
   continuation f x is lossless. *)
lemma try_spmf_bind_out1:
  "(x. lossless_spmf (f x))  bind_spmf (TRY p ELSE q) f = TRY (bind_spmf p f) ELSE (bind_spmf q f)"
  apply(clarsimp simp add: bind_spmf_def try_spmf_def bind_assoc_pmf bind_return_pmf intro!: bind_pmf_cong[OF refl] split: option.split)
  apply(rewrite in " = _" bind_return_pmf'[symmetric])
  apply(rule bind_pmf_cong[OF refl])
  apply(clarsimp split: option.split simp add: lossless_iff_set_pmf_None)
  done

(* pred_spmf over TRY: P must hold on p, and on q whenever p is not lossless. *)
lemma pred_spmf_try [simp]:
  "pred_spmf P (try_spmf p q) = (pred_spmf P p  (¬ lossless_spmf p  pred_spmf P q))"
by(auto simp add: pred_spmf_def)

(* Conditioning on A relativises the predicate to elements of A. *)
lemma pred_spmf_cond [simp]:
  "pred_spmf P (cond_spmf p A) = pred_spmf (λx. x  A  P x) p"
by(auto simp add: pred_spmf_def)

(* rel_spmf with a relation R restricted by P and Q commutes with map_spmf on
   either side: the mapping is pushed into R, P and Q. *)
lemma spmf_rel_map_restrict_relp: 
  shows spmf_rel_map_restrict_relp1: "rel_spmf (R  P  Q) (map_spmf f p) = rel_spmf (R  f  P  f  Q) p"
  and spmf_rel_map_restrict_relp2: "rel_spmf (R  P  Q) p (map_spmf g q) = rel_spmf ((λx. R x  g)  P  Q  g) p q"
by(simp_all add: spmf_rel_map restrict_relp_def)

(* pred_spmf distributes over the pointwise combination of two predicates. *)
lemma pred_spmf_conj: "pred_spmf (λx. P x  Q x) = (λx. pred_spmf P x  pred_spmf Q x)"
by simp

(* Parametricity of spmf_of_pmf, registered as a transfer rule. *)
lemma spmf_of_pmf_parametric [transfer_rule]: 
  includes lifting_syntax shows
  "(rel_pmf A ===> rel_spmf A) spmf_of_pmf spmf_of_pmf"
unfolding spmf_of_pmf_def[abs_def] by transfer_prover

(* return_pmf is monotone from the flat option order into ord_spmf (=).
   The lifted form (spmf.mono2mono) is registered as simp and cont_intro rule
   for monotonicity proofs. *)
lemma mono2mono_return_pmf[THEN spmf.mono2mono, simp, cont_intro]: (* Move to SPMF *)
  shows monotone_return_pmf: "monotone option_ord (ord_spmf (=)) return_pmf"
by(rule monotoneI)(auto simp add: flat_ord_def)

(* return_pmf is mcont from the flat option lattice into spmf.
   This follows from mcont_finite_chains, since chains in the flat order are finite. *)
lemma mcont2mcont_return_pmf[THEN spmf.mcont2mcont, simp, cont_intro]:  (* Move to SPMF *)
  shows mcont_return_pmf: "mcont (flat_lub None) option_ord lub_spmf (ord_spmf (=)) return_pmf"
by(rule mcont_finite_chains[OF _ _ flat_interpretation[THEN ccpo] ccpo_spmf]) simp_all

(* The trivially true predicate holds for every spmf. *)
lemma pred_spmf_top: (* Move up *)
  "pred_spmf (λ_. True) = (λ_. True)"
by(simp)

(* Variant of the restricted-relation introduction rule whose rel_spmf premise
   already mentions P and Q inside the relation. *)
lemma rel_spmf_restrict_relpI' [intro?]:
  " rel_spmf (λx y. P x  Q y  R x y) p q; pred_spmf P p; pred_spmf Q q   rel_spmf (R  P  Q) p q"
by(erule spmf_rel_mono_strong)(simp add: pred_spmf_def)

(* Simp version of set_spmf_map_pmf.  It is guarded by NO_MATCH, so it does not
   fire when f has the form map_option g. *)
lemma set_spmf_map_pmf_MATCH [simp]:
  assumes "NO_MATCH (map_option g) f"
  shows "set_spmf (map_pmf f p) = (xset_pmf p. set_option (f x))"
by(rule set_spmf_map_pmf)

(* Congruence for bind_spmf: the continuations only need to be related for
   A-related arguments that lie in the respective supports. *)
lemma rel_spmf_bindI':
  " rel_spmf A p q; x y.  A x y; x  set_spmf p; y  set_spmf q   rel_spmf B (f x) (g y) 
   rel_spmf B (p  f) (q  g)"
apply(rule rel_spmf_bindI[where R="λx y. A x y  x  set_spmf p  y  set_spmf q"])
 apply(erule spmf_rel_mono_strong; simp)
apply simp
done

(* Witness for rel_spmf A.  It takes the pmf witness for rel_option A and maps
   rel_witness_option over it. *)
definition rel_witness_spmf :: "('a  'b  bool)  'a spmf × 'b spmf  ('a × 'b) spmf" where
  "rel_witness_spmf A = map_pmf rel_witness_option  rel_witness_pmf (rel_option A)"

(* spmf analogues of rel_witness_pmf1 and rel_witness_pmf2. *)
lemma assumes "rel_spmf A p q"
  shows rel_witness_spmf1: "rel_spmf (λa (a', b). a = a'  A a' b) p (rel_witness_spmf A (p, q))"
    and rel_witness_spmf2: "rel_spmf (λ(a, b') b. b = b'  A a b') (rel_witness_spmf A (p, q)) q"
  by(auto simp add: pmf.rel_map rel_witness_spmf_def intro: pmf.rel_mono_strong[OF rel_witness_pmf1[OF assms]] rel_witness_option1 pmf.rel_mono_strong[OF rel_witness_pmf2[OF assms]] rel_witness_option2)

(* assert_spmf b has weight 1 if b holds and 0 otherwise. *)
lemma weight_assert_spmf [simp]: "weight_spmf (assert_spmf b) = indicator {True} b"
  by(simp split: split_indicator)

(* enforce_spmf P p applies enforce_option P to every outcome of p.
   By enforce_return_spmf below, an outcome x is kept if P x holds and is
   turned into failure (None) otherwise. *)
definition enforce_spmf :: "('a  bool)  'a spmf  'a spmf" where
  "enforce_spmf P = map_pmf (enforce_option P)"

(* Parametricity of enforce_spmf, registered as a transfer rule. *)
lemma enforce_spmf_parametric [transfer_rule]: includes lifting_syntax shows
  "((A ===> (=)) ===> rel_spmf A ===> rel_spmf A) enforce_spmf enforce_spmf"
  unfolding enforce_spmf_def by transfer_prover

(* Behaviour on a point distribution. *)
lemma enforce_return_spmf [simp]:
  "enforce_spmf P (return_spmf x) = (if P x then return_spmf x else return_pmf None)"
  by(simp add: enforce_spmf_def)

(* The null subdistribution is left unchanged. *)
lemma enforce_return_pmf_None [simp]:
  "enforce_spmf P (return_pmf None) = return_pmf None"
  by(simp add: enforce_spmf_def)

(* enforce_spmf commutes with map_spmf by composing the predicate with f. *)
lemma enforce_map_spmf:
  "enforce_spmf P (map_spmf f p) = map_spmf f (enforce_spmf (P  f) p)"
  by(simp add: enforce_spmf_def pmf.map_comp o_def enforce_map_option)

(* enforce_spmf can be pushed into the continuation of a bind. *)
lemma enforce_bind_spmf [simp]:
  "enforce_spmf P (bind_spmf p f) = bind_spmf p (enforce_spmf P  f)"
  by(auto simp add: enforce_spmf_def bind_spmf_def map_bind_pmf intro!: bind_pmf_cong split: option.split)

(* The support is filtered by P. *)
lemma set_enforce_spmf [simp]: "set_spmf (enforce_spmf P p) = {a  set_spmf p. P a}"
  by(auto simp add: enforce_spmf_def in_set_spmf)

(* Monadic characterisation: sample a from p, assert P a, return a. *)
lemma enforce_spmf_alt_def:
  "enforce_spmf P p = bind_spmf p (λa. bind_spmf (assert_spmf (P a)) (λ_ :: unit. return_spmf a))"
  by(auto simp add: enforce_spmf_def assert_spmf_def map_pmf_def bind_spmf_def bind_return_pmf intro!: bind_pmf_cong split: option.split)

(* Binding on an enforced spmf is a guarded bind that fails when P does not hold. *)
lemma bind_enforce_spmf [simp]:
  "bind_spmf (enforce_spmf P p) f = bind_spmf p (λx. if P x then f x else return_pmf None)"
  by(auto simp add: enforce_spmf_alt_def assert_spmf_def intro!: bind_spmf_cong)

(* Enforcing P reduces the weight by the mass of outcomes violating P.
   Proof: write the weight as the integral of the indicator of P, then use
   finite_measure_Diff. *)
lemma weight_enforce_spmf:
  "weight_spmf (enforce_spmf P p) = weight_spmf p - measure (measure_spmf p) {x. ¬ P x}" (is "?lhs = ?rhs")
proof -
  have "?lhs = LINT x|measure_spmf p. indicator {x. P x} x"
    by(auto simp add: enforce_spmf_alt_def weight_bind_spmf o_def simp del: Bochner_Integration.integral_indicator intro!: Bochner_Integration.integral_cong split: split_indicator)
  also have " = ?rhs"
    by(subst measure_spmf.finite_measure_Diff[symmetric])(auto simp add: space_measure_spmf intro!: arg_cong2[where f=measure])
  finally show ?thesis .
qed

(* enforce_spmf P p is lossless iff p is lossless and its support lies within
   {x. P x}. *)
lemma lossless_enforce_spmf [simp]:
  "lossless_spmf (enforce_spmf P p)  lossless_spmf p  set_spmf p  {x. P x}"
  by(auto simp add: enforce_spmf_alt_def)

(* Enforcing the top predicate is the identity. *)
lemma enforce_spmf_top [simp]: "enforce_spmf  = id"
  by(simp add: enforce_spmf_def)

(* Pointwise form of enforce_spmf_top. *)
lemma enforce_spmf_K_True [simp]: "enforce_spmf (λ_. True) p = p"
  using enforce_spmf_top[THEN fun_cong, of p] by(simp add: top_fun_def)

(* Enforcing the bottom predicate yields the null subdistribution. *)
lemma enforce_spmf_bot [simp]: "enforce_spmf  = (λ_. return_pmf None)"
  by(simp add: enforce_spmf_def fun_eq_iff)

(* Pointwise form of enforce_spmf_bot. *)
lemma enforce_spmf_K_False [simp]: "enforce_spmf (λ_. False) p = return_pmf None"
  using enforce_spmf_bot[THEN fun_cong, of p] by(simp add: bot_fun_def)

(* If P already holds on the support of p, enforcing P changes nothing. *)
lemma enforce_pred_id_spmf: "enforce_spmf P p = p" if "pred_spmf P p"
proof -
  have "enforce_spmf P p = map_pmf id p" using that
    by(auto simp add: enforce_spmf_def enforce_pred_id_option simp del: map_pmf_id intro!: pmf.map_cong_pred[OF refl] elim!: pmf_pred_mono_strong)
  then show ?thesis by simp
qed

(* Mapping the over spmf_of_pmf p recovers p. *)
lemma map_the_spmf_of_pmf [simp]: "map_pmf the (spmf_of_pmf p) = p"
  by(simp add: spmf_of_pmf_def pmf.map_comp o_def)

(* Two nested binds, where q does not depend on x, equal one bind over pair_spmf. *)
lemma bind_bind_conv_pair_spmf:
  "bind_spmf p (λx. bind_spmf q (f x)) = bind_spmf (pair_spmf p q) (λ(x, y). f x y)"
  by(simp add: pair_spmf_alt_def)

(* Conditioning spmf_of_set A (A finite) on B gives spmf_of_set of the
   intersection. *)
lemma cond_spmf_spmf_of_set:
  "cond_spmf (spmf_of_set A) B = spmf_of_set (A  B)" if "finite A"
  by(rule spmf_eqI)(auto simp add: spmf_of_set measure_spmf_of_set that split: split_indicator)

(* The product of two spmf_of_set is spmf_of_set of the Cartesian product. *)
lemma pair_spmf_of_set:
  "pair_spmf (spmf_of_set A) (spmf_of_set B) = spmf_of_set (A × B)"
  by(rule spmf_eqI)(clarsimp simp add: spmf_of_set card_cartesian_product split: split_indicator)

(* spmf version of emeasure_cond_pmf.  Unlike the pmf version it has no premise
   about the support. *)
lemma emeasure_cond_spmf:
  "emeasure (measure_spmf (cond_spmf p A)) B = emeasure (measure_spmf p) (A  B) / emeasure (measure_spmf p) A"
  apply(clarsimp simp add: cond_spmf_def emeasure_measure_spmf_conv_measure_pmf emeasure_measure_pmf_zero_iff set_pmf_Int_Some split!: if_split)
   apply blast
  apply(subst (asm) emeasure_cond_pmf)
  by(auto simp add: set_pmf_Int_Some image_Int)

(* Real-valued counterpart of emeasure_cond_spmf. *)
lemma measure_cond_spmf:
  "measure (measure_spmf (cond_spmf p A)) B = measure (measure_spmf p) (A  B) / measure (measure_spmf p) A"
  apply(clarsimp simp add: cond_spmf_def measure_measure_spmf_conv_measure_pmf measure_pmf_zero_iff set_pmf_Int_Some split!: if_split)
  apply(subst (asm) measure_cond_pmf)
  by(auto simp add: image_Int set_pmf_Int_Some)


(* cond_spmf p A is lossless iff the support of p meets A. *)
lemma lossless_cond_spmf [simp]: "lossless_spmf (cond_spmf p A)  set_spmf p  A  {}"
  by(clarsimp simp add: cond_spmf_def lossless_iff_set_pmf_None set_pmf_Int_Some)

(* measure_spmf p is the density of spmf p with respect to the counting measure. *)
lemma measure_spmf_eq_density: "measure_spmf p = density (count_space UNIV) (spmf p)"
  by(rule measure_eqI)(simp_all add: emeasure_density nn_integral_spmf[symmetric] nn_integral_count_space_indicator)

(* The Bochner integral over measure_spmf M is a finite sum over A, for any
   finite A that contains every support point where f is non-zero. *)
lemma integral_measure_spmf:
  fixes f :: "'a  'b::{banach, second_countable_topology}"
  assumes A: "finite A"
  shows "(a. a  set_spmf M  f a  0  a  A)  (LINT x|measure_spmf M. f x) = (aA. spmf M a *R f a)"
  unfolding measure_spmf_eq_density
  apply (simp add: integral_density)
  apply (subst lebesgue_integral_count_space_finite_support)
  by (auto intro!: finite_subset[OF _ finite A] sum.mono_neutral_left simp: spmf_eq_0_set_spmf)


(* From equal mapped spmfs (given as an ASSUMPTION), the images of the supports
   coincide. *)
lemma image_set_spmf_eq:
  "f ` set_spmf p = g ` set_spmf q" if "ASSUMPTION (map_spmf f p = map_spmf g q)"
  using that[unfolded ASSUMPTION_def, THEN arg_cong[where f=set_spmf]] by simp

(* Mapping a constant function yields return_spmf x scaled by the weight of p. *)
lemma map_spmf_const: "map_spmf (λ_. x) p = scale_spmf (weight_spmf p) (return_spmf x)"
  by(simp add: map_spmf_conv_bind_spmf bind_spmf_const)

(* Conditioning a point distribution on a set containing the point changes nothing. *)
lemma cond_return_pmf [simp]: "cond_pmf (return_pmf x) A = return_pmf x" if "x  A"
  using that by(intro pmf_eqI)(auto simp add: pmf_cond split: split_indicator)

(* spmf version.  If x is not in A, the result is the null subdistribution. *)
lemma cond_return_spmf [simp]: "cond_spmf (return_spmf x) A = (if x  A then return_spmf x else return_pmf None)"
  by(simp add: cond_spmf_def)

(* The weight of an spmf is the pmf-measure of range Some. *)
lemma measure_range_Some_eq_weight:
  "measure (measure_pmf p) (range Some) = weight_spmf p"
  by (simp add: measure_measure_spmf_conv_measure_pmf space_measure_spmf)

(* Restriction to A yields the null subdistribution iff the support does not
   meet A. *)
lemma restrict_spmf_eq_return_pmf_None [simp]:
  "restrict_spmf p A = return_pmf None  set_spmf p  A = {}"
  by(auto 4 3 simp add: restrict_spmf_def map_pmf_eq_return_pmf_iff bind_UNION in_set_spmf bind_eq_None_conv option.the_def dest: bspec split: if_split_asm option.split_asm)

(* mk_lossless p renormalises p by scaling it with the inverse of its weight. *)
definition mk_lossless :: "'a spmf  'a spmf" where
  "mk_lossless p = scale_spmf (inverse (weight_spmf p)) p"

(* Renormalising twice is the same as renormalising once. *)
lemma mk_lossless_idem [simp]: "mk_lossless (mk_lossless p) = mk_lossless p"
  by(simp add: mk_lossless_def weight_scale_spmf min_def max_def inverse_eq_divide) 

(* Point distributions and the null subdistribution are fixed points. *)
lemma mk_lossless_return [simp]: "mk_lossless (return_pmf x) = return_pmf x"
  by(cases x)(simp_all add: mk_lossless_def)

(* mk_lossless commutes with map_spmf. *)
lemma mk_lossless_map [simp]: "mk_lossless (map_spmf f p) = map_spmf f (mk_lossless p)"
  by(simp add: mk_lossless_def map_scale_spmf)

(* Pointwise: each probability is divided by the weight. *)
lemma spmf_mk_lossless [simp]: "spmf (mk_lossless p) x = spmf p x / weight_spmf p"
  by(simp add: mk_lossless_def spmf_scale_spmf inverse_eq_divide max_def)

(* Renormalisation does not change the support. *)
lemma set_spmf_mk_lossless [simp]: "set_spmf (mk_lossless p) = set_spmf p"
  by(simp add: mk_lossless_def set_scale_spmf measure_spmf_zero_iff zero_less_measure_iff)

(* Lossless spmfs are left unchanged. *)
lemma mk_lossless_lossless [simp]: "lossless_spmf p  mk_lossless p = p"
  by(simp add: mk_lossless_def lossless_weight_spmfD)

(* mk_lossless p is the null subdistribution iff p already is.
   Helper aux: weight zero forces every spmf p i to be zero. *)
lemma mk_lossless_eq_return_pmf_None [simp]: "mk_lossless p = return_pmf None  p = return_pmf None"
proof -
  have aux: "weight_spmf p = 0  spmf p i = 0" for i
    by(rule antisym, rule order_trans[OF spmf_le_weight]) (auto intro!: order_trans[OF spmf_le_weight])

  have[simp]: " spmf (scale_spmf (inverse (weight_spmf p)) p) = spmf (return_pmf None)  spmf p i = 0" for i
    by(drule fun_cong[where x=i]) (auto simp add: aux spmf_scale_spmf max_def)

  show ?thesis by(auto simp add: mk_lossless_def intro: spmf_eqI)
qed

(* Symmetric orientation of mk_lossless_eq_return_pmf_None. *)
lemma return_pmf_None_eq_mk_lossless [simp]: "return_pmf None = mk_lossless p  p = return_pmf None"
  by(metis mk_lossless_eq_return_pmf_None)

(* spmf_of_set is already normalised. *)
lemma mk_lossless_spmf_of_set [simp]: "mk_lossless (spmf_of_set A) = spmf_of_set A"
  by(simp add: spmf_of_set_def del: spmf_of_pmf_pmf_of_set)

(* The weight after renormalisation is 0 for the null subdistribution and 1 otherwise. *)
lemma weight_mk_lossless: "weight_spmf (mk_lossless p) = (if p = return_pmf None then 0 else 1)"
  by(simp add: mk_lossless_def weight_scale_spmf min_def max_def inverse_eq_divide weight_spmf_eq_0)

(* Parametricity of mk_lossless, registered as a transfer rule. *)
lemma mk_lossless_parametric [transfer_rule]: includes lifting_syntax shows
  "(rel_spmf A ===> rel_spmf A) mk_lossless mk_lossless"
  by(simp add: mk_lossless_def rel_fun_def rel_spmf_weightD rel_spmf_scaleI)

(* rel_spmf is preserved by mk_lossless. *)
lemma rel_spmf_mk_losslessI:
  "rel_spmf A p q  rel_spmf A (mk_lossless p) (mk_lossless q)"
  by(rule mk_lossless_parametric[THEN rel_funD])

(* Sufficient condition for relating restrict_spmf p A and restrict_spmf q B by R. *)
lemma rel_spmf_restrict_spmfI:
  "rel_spmf (λx y. (x  A  y  B  R x y)  x  A  y  B) p q
    rel_spmf R (restrict_spmf p A) (restrict_spmf q B)"
  by(auto simp add: restrict_spmf_def pmf.rel_map elim!: option.rel_cases pmf.rel_mono_strong)

(* cond_spmf is restriction followed by renormalisation.  The proof distinguishes
   whether the support meets A. *)
lemma cond_spmf_alt: "cond_spmf p A = mk_lossless (restrict_spmf p A)"
proof(cases "set_spmf p  A = {}")
  case True
  then show ?thesis by(simp add: cond_spmf_def measure_spmf_zero_iff)
next
  case False
  show ?thesis
    by(rule spmf_eqI)(simp add: False cond_spmf_def pmf_cond set_pmf_Int_Some image_iff measure_measure_spmf_conv_measure_pmf[symmetric] spmf_scale_spmf max_def inverse_eq_divide)
qed

(* Conditioning a bind restricts each continuation to A, then renormalises. *)
lemma cond_spmf_bind:
  "cond_spmf (bind_spmf p f) A = mk_lossless (p  (λx. f x  A))"
  by(simp add: cond_spmf_alt restrict_bind_spmf scale_bind_spmf)

(* Conditioning on UNIV only renormalises. *)
lemma cond_spmf_UNIV [simp]: "cond_spmf p UNIV = mk_lossless p"
  by(clarsimp simp add: cond_spmf_alt)

(* If the support of p meets A in exactly the point x, conditioning p on A
   yields return_pmf x. *)
lemma cond_pmf_singleton:
  "cond_pmf p A = return_pmf x" if "set_pmf p  A = {x}"
proof -
  have[simp]: "set_pmf p  A = {x}  x  A  measure_pmf.prob p A = pmf p x"
    by(auto simp add: measure_pmf_single[symmetric] AE_measure_pmf_iff intro!: measure_pmf.finite_measure_eq_AE)

  have "pmf (cond_pmf p A) i = pmf (return_pmf x) i" for i
    using that by(auto simp add: pmf_cond measure_pmf_zero_iff pmf_eq_0_set_pmf split: split_indicator)

  then show ?thesis by(rule pmf_eqI)
qed


(* cond_spmf_fst p a: distribution of the second component of p, conditioned on
   the first component being a. *)
definition cond_spmf_fst :: "('a × 'b) spmf  'a  'b spmf" where
  "cond_spmf_fst p a = map_spmf snd (cond_spmf p ({a} × UNIV))"

(* On a point distribution (x, y), conditioning on x gives the point y. *)
lemma cond_spmf_fst_return_spmf [simp]:
  "cond_spmf_fst (return_spmf (x, y)) x = return_spmf y"
  by(simp add: cond_spmf_fst_def)

(* If the first component is always x, conditioning on x only renormalises. *)
lemma cond_spmf_fst_map_Pair [simp]: "cond_spmf_fst (map_spmf (Pair x) p) x = mk_lossless p"
  by(clarsimp simp add: cond_spmf_fst_def spmf.map_comp o_def)

(* Variant with an additional map f on the second component. *)
lemma cond_spmf_fst_map_Pair' [simp]: "cond_spmf_fst (map_spmf (λy. (x, f y)) p) x = map_spmf f (mk_lossless p)"
  by(subst spmf.map_comp[where f="Pair x", symmetric, unfolded o_def]) simp

(* The conditional is the null subdistribution iff x does not occur as a first
   component in the support. *)
lemma cond_spmf_fst_eq_return_None [simp]: "cond_spmf_fst p x = return_pmf None  x  fst ` set_spmf p"
  by(auto 4 4 simp add: cond_spmf_fst_def map_pmf_eq_return_pmf_iff in_set_spmf[symmetric] dest: bspec[where x="Some _"] intro: ccontr rev_image_eqI)

(* If f is injective on the support and x lies in the support, conditioning the
   image of (f, g) on f x is deterministic.
   Helper ?foo y: the preimage of the fibre over f y under the paired map. *)
lemma cond_spmf_fst_map_Pair1:
  "cond_spmf_fst (map_spmf (λx. (f x, g x)) p) (f x) = return_spmf (g (inv_into (set_spmf p) f (f x)))"
  if "x  set_spmf p" "inj_on f (set_spmf p)"
proof -
  let ?foo="λy. map_option (λx. (f x, g x)) -` Some ` ({f y} × UNIV)"
  have[simp]: "y  set_spmf p  f x = f y  set_pmf p  (?foo y)  {}" for y
    by(auto simp add: vimage_def image_def in_set_spmf)

  have[simp]: "y  set_spmf p  f x = f y   map_spmf snd (map_spmf (λx. (f x, g x)) (cond_pmf p (?foo y))) = return_spmf (g x)" for y
    using that by(subst cond_pmf_singleton[where x="Some x"]) (auto simp add: in_set_spmf elim: inj_onD)

  show ?thesis
    using that
    by(auto simp add: cond_spmf_fst_def cond_spmf_def)
      (erule notE, subst cond_map_pmf, simp_all)
qed

(* The conditional is lossless iff x occurs as a first component in the support. *)
lemma lossless_cond_spmf_fst [simp]: "lossless_spmf (cond_spmf_fst p x)  x  fst ` set_spmf p"
  by(auto simp add: cond_spmf_fst_def intro: rev_image_eqI)

(* Disintegration: sampling x from the first marginal and then the second
   component from cond_spmf_fst p x reconstructs p.
   The proof compares point masses at i:
   - write them as integrals over the first marginal;
   - evaluate the integral as a single-point sum using integral_measure_spmf;
   - cancel the marginal mass of fst i. *)
lemma cond_spmf_fst_inverse:
  "bind_spmf (map_spmf fst p) (λx. map_spmf (Pair x) (cond_spmf_fst p x)) = p"
  (is "?lhs = ?rhs")
proof(rule spmf_eqI)
  fix i :: "'a × 'b"
  have *: "({x} × UNIV  (Pair x  snd) -` {i}) = (if x = fst i then {i} else {})" for x by(cases i)auto
  have "spmf ?lhs i = LINT x|measure_spmf (map_spmf fst p). spmf (map_spmf (Pair x  snd) (cond_spmf p ({x} × UNIV))) i"
    by(auto simp add: spmf_bind spmf.map_comp[symmetric] cond_spmf_fst_def intro!: integral_cong_AE)
  also have " = LINT x|measure_spmf (map_spmf fst p). measure (measure_spmf (cond_spmf p ({x} × UNIV))) ((Pair x  snd) -` {i})"
    by(rule integral_cong_AE)(auto simp add: spmf_map)
  also have " = LINT x|measure_spmf (map_spmf fst p). measure (measure_spmf p) ({x} × UNIV  (Pair x  snd) -` {i}) /
       measure (measure_spmf p) ({x} × UNIV)"
    by(rule integral_cong_AE; clarsimp simp add: measure_cond_spmf)
  also have " = spmf (map_spmf fst p) (fst i) * spmf p i / measure (measure_spmf p) ({fst i} × UNIV)"
    by(simp add: * if_distrib[where f="measure (measure_spmf _)"] cong: if_cong)
      (subst integral_measure_spmf[where A="{fst i}"]; auto split: if_split_asm simp add: spmf_conv_measure_spmf)
  also have " = spmf p i"
    by(clarsimp simp add: spmf_map vimage_fst)(metis (no_types, lifting) Int_insert_left_if1 in_set_spmf_iff_spmf insertI1 insert_UNIV insert_absorb insert_not_empty measure_spmf_zero_iff mem_Sigma_iff prod.collapse)
  finally show "spmf ?lhs i = spmf ?rhs i" .
qed

subsubsection ‹Embedding of @{typ "'a option"} into @{typ "'a spmf"}›

text ‹In theory, this follows from the embedding of @{typ "_ id"} into @{typ "_ prob"} and the isomorphism
  between @{typ "(_, _ prob) optionT"} and @{typ "_ spmf"}, but via this connection we would only get the
  monomorphic version. So we do it directly.
›

(* The generic option-monad bind, instantiated with return_pmf None as failure,
   coincides with bind_spmf on return_pmf x. *)
lemma bind_option_spmf_monad [simp]: "monad.bind_option (return_pmf None) x = bind_spmf (return_pmf x)"
by(cases x)(simp_all add: fun_eq_iff)

locale option_to_spmf begin

text ‹
  We have to register the embedding with the lifting package so that we can use the parametrisation of transfer rules.
›

(* the_pmf: inverse of return_pmf, defined by definite description.
   It is only meaningful on pmfs of the form return_pmf x. *)
definition the_pmf :: "'a pmf  'a" where "the_pmf p = (THE x. p = return_pmf x)"

(* the_pmf inverts return_pmf. *)
lemma the_pmf_return [simp]: "the_pmf (return_pmf x) = x"
by(simp add: the_pmf_def)

(* return_pmf and the_pmf form a type definition of option values as the point
   distributions among option pmfs.  This is the input to setup_lifting. *)
lemma type_definition_option_spmf: "type_definition return_pmf the_pmf {x. y :: 'a option. x = return_pmf y}"
by unfold_locales(auto)

(* Register the option-into-spmf embedding with the lifting package.
   The setup_lifting is private to this anonymous context.  The facts it
   generates are therefore re-exported below under spmf_option names, and the
   bundle spmf_option_lifting restores the lifting setup when included. *)
context begin
private setup_lifting type_definition_option_spmf
abbreviation cr_spmf_option where "cr_spmf_option  cr_option"
abbreviation pcr_spmf_option where "pcr_spmf_option  pcr_option"
lemmas Quotient_spmf_option = Quotient_option
  and cr_spmf_option_def = cr_option_def
  and pcr_spmf_option_bi_unique = option.bi_unique
  and Domainp_pcr_spmf_option = option.domain
  and Domainp_pcr_spmf_option_eq = option.domain_eq
  and Domainp_pcr_spmf_option_par = option.domain_par
  and Domainp_pcr_spmf_option_left_total = option.domain_par_left_total
  and pcr_spmf_option_left_unique = option.left_unique
  and pcr_spmf_option_cr_eq = option.pcr_cr_eq
  and pcr_spmf_option_return_pmf_transfer = option.rep_transfer
  and pcr_spmf_option_right_total = option.right_total
  and pcr_spmf_option_right_unique = option.right_unique
  and pcr_spmf_option_def = pcr_option_def
bundle spmf_option_lifting = [[Lifting.lifting_restore_internal "Misc_CryptHOL.option.lifting"]]
end


context includes lifting_syntax begin

lemma return_option_spmf_transfer [transfer_parametric return_spmf_parametric, transfer_rule]:
  "((=) ===> cr_spmf_option) return_spmf Some"
by(rule rel_funI)(simp add: cr_spmf_option_def)

lemma map_option_spmf_transfer [transfer_parametric map_spmf_parametric, transfer_rule]:
  "(((=) ===> (=)) ===> cr_spmf_option ===> cr_spmf_option) map_spmf map_option"
unfolding rel_fun_eq by(auto simp add: rel_fun_def cr_spmf_option_def)

lemma fail_option_spmf_transfer [transfer_parametric return_spmf_None_parametric, transfer_rule]:
  "cr_spmf_option (return_pmf None) None"
by(simp add: cr_spmf_option_def)

lemma bind_option_spmf_transfer [transfer_parametric bind_spmf_parametric, transfer_rule]:
  "(cr_spmf_option ===> ((=) ===> cr_spmf_option) ===> cr_spmf_option) bind_spmf Option.bind"
apply(clarsimp simp add: rel_fun_def cr_spmf_option_def)
subgoal for x f g by(cases x; simp)
done

lemma set_option_spmf_transfer [transfer_parametric set_spmf_parametric, transfer_rule]:
  "(cr_spmf_option ===> rel_set (=)) set_spmf set_option"
by(clarsimp simp add: rel_fun_def cr_spmf_option_def rel_set_eq)

lemma rel_option_spmf_transfer [transfer_parametric rel_spmf_parametric, transfer_rule]:
  "(((=) ===> (=) ===> (=)) ===> cr_spmf_option ===> cr_spmf_option ===> (=)) rel_spmf rel_option"
unfolding rel_fun_eq by(simp add: rel_fun_def cr_spmf_option_def)

end

end

locale option_le_spmf begin

text ‹
  Embedding in which only successful computations of the option monad are related to Dirac spmfs:
  @{term "Some x"} is related only to @{term "return_spmf x"}, whereas @{term None} is related to
  every spmf, because @{term "return_pmf None"} lies below every spmf in @{const ord_spmf}.
›

definition cr_option_le_spmf :: "'a option ⇒ 'a spmf ⇒ bool"
where "cr_option_le_spmf x p ≡ ord_spmf (=) (return_pmf x) p"

context includes lifting_syntax begin

text ‹
  Transfer rules for the option monad operations. Here the option side is the first argument
  of cr_option_le_spmf, so the option operation appears on the left of each rule.
›

lemma return_option_le_spmf_transfer [transfer_rule]:
  "((=) ===> cr_option_le_spmf) (λx. x) return_pmf"
by(rule rel_funI)(simp add: cr_option_le_spmf_def ord_option_reflI)

lemma map_option_le_spmf_transfer [transfer_rule]:
  "(((=) ===> (=)) ===> cr_option_le_spmf ===> cr_option_le_spmf) map_option map_spmf"
unfolding rel_fun_eq
apply(clarsimp simp add: rel_fun_def cr_option_le_spmf_def rel_pmf_return_pmf1 ord_option_map1 ord_option_map2)
subgoal for f x p y by(cases x; simp add: ord_option_reflI)
done

lemma bind_option_le_spmf_transfer [transfer_rule]:
  "(cr_option_le_spmf ===> ((=) ===> cr_option_le_spmf) ===> cr_option_le_spmf) Option.bind bind_spmf"
apply(clarsimp simp add: rel_fun_def cr_option_le_spmf_def)
subgoal for x p f g by(cases x; auto 4 3 simp add: rel_pmf_return_pmf1 set_pmf_bind_spmf)
done

end

end

text ‹
  Interpret the locale rel_spmf_characterisation; after unfolding the locale, its assumption
  is discharged with rel_pmf_measureI.
›

interpretation rel_spmf_characterisation
  by (unfold_locales, rule rel_pmf_measureI)

text ‹Pushes @{const bind_spmf} into the branches of a conditional in its first argument.›

lemma if_distrib_bind_spmf1 [if_distribs]:
  "bind_spmf (if b then x else y) f = (if b then bind_spmf x f else bind_spmf y f)"
by (cases b) simp_all

text ‹Pulls a conditional whose condition does not depend on the bound value out of the continuation of @{const bind_spmf}.›

lemma if_distrib_bind_spmf2 [if_distribs]:
  "bind_spmf x (λy. if b then f y else g y) = (if b then bind_spmf x f else bind_spmf x g)"
by (cases b) simp_all

text ‹
  Distributes @{const rel_spmf} over conditionals that branch on the same condition in both
  arguments: the relation holds iff it holds between the two branches selected by the condition.
›

lemma rel_spmf_if_distrib [if_distribs]:
  "rel_spmf R (if b then x else y) (if b then x' else y') ⟷
  (b ∧ rel_spmf R x x') ∨ (¬ b ∧ rel_spmf R y y')"
by(simp)

text ‹Pushes @{const map_spmf} into the branches of a conditional.›

lemma if_distrib_map_spmf [if_distribs]:
  "map_spmf f (if b then p else q) = (if b then map_spmf f p else map_spmf f q)"
by (cases b) simp_all

text ‹Pushes @{const restrict_spmf} into the branches of a conditional in its first argument.›

lemma if_distrib_restrict_spmf1 [if_distribs]:
  "restrict_spmf (if b then p else q) A = (if b then restrict_spmf p A else restrict_spmf q A)"
by (cases b) simp_all

end