(* Theory PDF_Density_Contexts *)

(*
  Theory: PDF_Density_Contexts.thy
  Authors: Manuel Eberl
*)

section ‹Density Contexts›

theory PDF_Density_Contexts
imports PDF_Semantics
begin

(* Evaluating a state at a variable i with i in V is measurable on state_measure V Γ.
   Registered as a raw measurability rule so the measurable prover can use it directly. *)
lemma measurable_proj_state_measure[measurable (raw)]:
    "i  V  (λx. x i)  measurable (state_measure V Γ) (Γ i)"
  unfolding state_measure_def by measurable

(* Updating a state at one variable preserves measurability:
   - f is measurable into states over V';
   - g is measurable into values of type Γ x;
   - V is V' extended by x.
   Under these hypotheses, the pointwise update (f ω)(x := g ω) is measurable into states over V.
   The side condition on V is why callers sometimes declare set decompositions as [measurable]. *)
lemma measurable_dens_ctxt_fun_upd[measurable (raw)]:
  "f  N M state_measure V' Γ  V = V'  {x} 
    g  N M stock_measure (Γ x) 
    (λω. (f ω)(x := g ω))  N M state_measure V Γ"
  unfolding state_measure_def
  by (rule measurable_fun_upd[where J=V']) auto

(* Shifting indices down by one (composing a product-space element with Suc) is measurable
   from PiM (Suc ` A) (case_nat M N) to PiM A N.  The proof first shows measurability of the
   restricted map (fun x in A. σ (Suc x)) into PiM A (fun x. case_nat M N (Suc x)).  It then
   transfers the result via measurable_cong, since both maps agree on the space and
   case_nat M N (Suc x) reduces to N x. *)
lemma measurable_case_nat_Suc_PiM:
  "(λσ. σ  Suc)  measurable (PiM (Suc ` A) (case_nat M N)) (PiM A N)"
proof-
  have "(λσ. λxA. σ (Suc x))  measurable
      (PiM (Suc ` A) (case_nat M N)) (PiM A (λx. case_nat M N (Suc x)))" (is "?A")
    by measurable
  also have "?A  ?thesis"
    by (force intro!: measurable_cong ext simp: state_measure_def space_PiM dest: PiE_mem)
  finally show ?thesis .
qed

(* State-measure version of measurable_case_nat_Suc_PiM.  Composing with Suc maps states over
   Suc ` A (typed by case_nat t Γ) measurably to states over A (typed by Γ).  The proof has the
   same shape: restricted map first, then measurable_cong. *)
lemma measurable_case_nat_Suc:
  "(λσ. σ  Suc)  measurable (state_measure (Suc ` A) (case_nat t Γ)) (state_measure A Γ)"
proof-
  have "(λσ. λxA. σ (Suc x))  measurable
      (state_measure (Suc ` A) (case_nat t Γ)) (state_measure A (λi. case_nat t Γ (Suc i)))" (is "?A")
    unfolding state_measure_def by measurable
  also have "?A  ?thesis"
    by (force intro!: measurable_cong ext simp: state_measure_def space_PiM dest: PiE_mem)
  finally show ?thesis .
qed

text ‹A density context holds a set of variables @{term V}, a set of variables @{term V'} whose
values are fixed by an outer state, their types (given by @{term Γ}), and a common density
function @{term δ} on the finite product space of all the variables in @{term V} and @{term V'}.
@{term δ} takes a state @{term "σ ∈ (Π⇩E x∈(V ∪ V'). type_universe (Γ x))"} and returns the common
density of these variables.›

(* A density context is a tuple (V, V', Γ, δ):
   - V: the variables whose joint density is tracked;
   - V': the variables whose values come from an outer state (see marg_dens);
   - Γ: the typing of variables;
   - δ: the density on states.
   An expr_density maps an outer state and a value to a density value. *)
type_synonym dens_ctxt = "vname set × vname set × (vname  pdf_type) × (state  ennreal)"
type_synonym expr_density = "state  val  ennreal"

(* The empty density context: no variables in V or V', an unspecified typing, and the
   constant density 1. *)
definition empty_dens_ctxt :: dens_ctxt where
  "empty_dens_ctxt = ({}, {}, λ_. undefined, λ_. 1)"

(* The measure on states over V and V' obtained by pushing state_measure V Γ forward along
   merging each V-state with the fixed outer state ρ on V'. *)
definition state_measure'
    :: "vname set  vname set  (vname  pdf_type)  state  state measure" where
  "state_measure' V V' Γ ρ =
       distr (state_measure V Γ) (state_measure (VV') Γ) (λσ. merge V V' (σ, ρ))"


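text ‹The marginal density of a variable @{term x} is obtained by integrating the common density
@{term δ} over all the remaining variables in @{term V}, while the variables in @{term V'} keep the
values given by the outer state. The joint marginal density of two variables is defined analogously.›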

(* marg_dens (V,V',Γ,δ) x ρ v: the density δ is evaluated at states where x is set to v and the
   V' part comes from ρ.  It is integrated over the remaining variables V - {x}. *)
definition marg_dens :: "dens_ctxt  vname  expr_density" where
  "marg_dens = (λ(V,V',Γ,δ) x ρ v. +σ. δ (merge V V' (σ(x := v), ρ)) state_measure (V-{x}) Γ)"

(* Joint marginal density of two variables x and y.  The value v is split with extract_pair into
   the values for x (fst) and y (snd).  δ is then integrated over the remaining variables
   V - {x,y}, with the V' part taken from ρ. *)
definition marg_dens2 :: "dens_ctxt  vname  vname  expr_density" where
  "marg_dens2  (λ(V,V',Γ,δ) x y ρ v.
       +σ. δ (merge V V' (σ(x := fst (extract_pair v), y := snd (extract_pair v)), ρ))
           state_measure (V-{x,y}) Γ)"

(* The measure on states described by a context for the outer state ρ: state_measure' V V' Γ ρ
   with density δ. *)
definition dens_ctxt_measure :: "dens_ctxt  state  state measure" where
  "dens_ctxt_measure  λ(V,V',Γ,δ) ρ. density (state_measure' V V' Γ ρ) δ"

(* Total mass of the context measure for ρ, i.e. the measure of its whole space. *)
definition branch_prob :: "dens_ctxt  state  ennreal" where
  "branch_prob 𝒴 ρ = emeasure (dens_ctxt_measure 𝒴 ρ) (space (dens_ctxt_measure 𝒴 ρ))"

(* The space of a context measure is never empty.  Proof: unfold, split the context tuple,
   and let simp close the goal. *)
lemma dens_ctxt_measure_nonempty[simp]:
    "space (dens_ctxt_measure 𝒴 ρ)  {}"
  unfolding dens_ctxt_measure_def state_measure'_def by (cases 𝒴) simp

(* A context measure has the same measurable sets as state_measure (V ∪ V') Γ.  It is declared
   measurable_cong so that measurability goals can switch between the two measures. *)
lemma sets_dens_ctxt_measure_eq[measurable_cong]:
    "sets (dens_ctxt_measure (V,V',Γ,δ) ρ) = sets (state_measure (VV') Γ)"
  by (simp_all add: dens_ctxt_measure_def state_measure'_def)

(* Measurability with respect to a context measure is the same as with respect to
   state_measure (V ∪ V') Γ, because the underlying sets coincide. *)
lemma measurable_dens_ctxt_measure_eq:
    "measurable (dens_ctxt_measure (V,V',Γ,δ) ρ) = measurable (state_measure (VV') Γ)"
  by (intro ext measurable_cong_sets)
     (simp_all add: dens_ctxt_measure_def state_measure'_def)

(* The space of a context measure is the space of states over V ∪ V'. *)
lemma space_dens_ctxt_measure:
    "space (dens_ctxt_measure (V,V',Γ,δ) ρ) = space (state_measure (VV') Γ)"
  unfolding dens_ctxt_measure_def state_measure'_def by simp

(* Mixes the parametric density dist_dens dst x y over the parameter x, weighted by the
   state-dependent density f ρ x.  The integral is taken over
   stock_measure (dist_param_type dst).  Presumably this yields the density of a value drawn
   from dst whose parameter has density f; TODO confirm against dist_dens. *)
definition apply_dist_to_dens :: "pdf_dist  (state  val  ennreal)  (state  val  ennreal)" where
  "apply_dist_to_dens dst f = (λρ y. +x. f ρ x * dist_dens dst x y stock_measure (dist_param_type dst))"

(* Drops variable 0 and shifts all other variables down by one.  Variable names are used with
   Suc here, so vname is presumably nat. *)
definition remove_var :: "state  state" where
  "remove_var σ = (λx. σ (Suc x))"

(* remove_var maps states over shift_var_set V (variable 0 typed t) measurably to states over V.
   Proof outline:
   - The restricted map (fun x in V. σ (Suc x)) is measurable.
   - Any state in the source space is undefined at Suc x for x not in V (via PiE_arb).
   - Hence remove_var agrees with the restricted map on the space, and measurable_cong
     transfers the result. *)
lemma measurable_remove_var[measurable]:
  "remove_var  measurable (state_measure (shift_var_set V) (case_nat t Γ)) (state_measure V Γ)"
proof-
  have "(λσ. λxV. σ (Suc x))  measurable
      (state_measure (shift_var_set V) (case_nat t Γ)) (state_measure V (λx. case_nat t Γ (Suc x)))"
    (is "?f  ?M")
    unfolding state_measure_def shift_var_set_def by measurable
  (* This also only stores the previous fact as the start of the calculation.  The fact
     proved here is a helper consumed by the following hence. *)
  also have "x f. x  V  f  space (state_measure (shift_var_set V) (case_nat t Γ)) 
                       f (Suc x) = undefined" unfolding state_measure_def
    by (subst (asm) space_PiM, drule PiE_arb[of _ _ _ "Suc x" for x])
       (simp_all add: space_PiM shift_var_set_def inj_image_mem_iff)
  hence "?f  ?M  remove_var  ?M" unfolding remove_var_def[abs_def] state_measure_def
    by (intro measurable_cong ext) (auto simp: space_PiM intro!: sym[of _ undefined])
  finally show ?thesis by simp
qed

(* Prepending an undefined variable 0 (case_nat undefined) maps states over A measurably to
   states over Suc ` A.  Proof: the restricted version is measurable by measurable_restrict,
   and both versions agree on the space. *)
lemma measurable_case_nat_undefined[measurable]:
  "case_nat undefined  measurable (state_measure A Γ) (state_measure (Suc`A) (case_nat t Γ))" (is "_  ?M")
proof-
  have "(λσ. λxSuc`A. case_nat undefined σ x)  ?M" (is "?f  _")
    unfolding state_measure_def by (rule measurable_restrict) auto
  also have "?f  ?M  ?thesis"
    by (intro measurable_cong ext)
       (auto simp: state_measure_def space_PiM dest: PiE_mem split: nat.split)
  finally show ?thesis .
qed

(* Density of a context extended by a new variable 0.  The old density δ of the shifted-back
   state is multiplied by f, the (state-dependent) density of the new variable's value σ 0.
   Note: the parameters V and V' do not occur in the body. *)
definition insert_dens
     :: "vname set  vname set  expr_density  (state  ennreal)  state  ennreal" where
  "insert_dens V V' f δ  λσ. δ (remove_var σ) * f (remove_var σ) (σ 0)"

(* Weights δ σ by f σ (BoolVal b).  Presumably f σ is the density of a condition, so this
   restricts the density to the branch where the condition equals b. *)
definition if_dens :: "(state  ennreal)  (state  val  ennreal)  bool  (state  ennreal)" where
  "if_dens δ f b  λσ. δ σ * f σ (BoolVal b)"

(* Restricts δ to the states where the expression e evaluates (via expr_sem_rf) to BoolVal b.
   It keeps δ σ there and yields 0 elsewhere. *)
definition if_dens_det :: "(state  ennreal)  expr  bool  (state  ennreal)" where
  "if_dens_det δ e b  λσ. δ σ * (if expr_sem_rf σ e = BoolVal b then 1 else 0)"

(* if_dens preserves measurability, given a measurable δ and a jointly measurable f on
   M times the Boolean values. *)
lemma measurable_if_dens:
  assumes [measurable]: "δ  borel_measurable M"
  assumes [measurable]: "case_prod f  borel_measurable (M M count_space (range BoolVal))"
  shows "if_dens δ f b  borel_measurable M"
  unfolding if_dens_def by measurable

(* if_dens_det preserves measurability when e is a random-free Boolean expression whose free
   variables lie in V.  After splitting the product and the if, the only nontrivial goal is that
   the set of states where e evaluates to BoolVal b is measurable.  That set is a preimage under
   expr_sem_rf, which is measurable by measurable_expr_sem_rf. *)
lemma measurable_if_dens_det:
  assumes e: "Γ  e : BOOL" "randomfree e" "free_vars e  V"
  assumes [measurable]: "δ  borel_measurable (state_measure V Γ)"
  shows "if_dens_det δ e b  borel_measurable (state_measure V Γ)"
unfolding if_dens_det_def
proof (intro borel_measurable_times_ennreal assms measurable_If)
  have "{x  space (state_measure V Γ). expr_sem_rf x e = BoolVal b} =
            (λσ. expr_sem_rf σ e) -` {BoolVal b}  space (state_measure V Γ)" by auto
  also have "...  sets (state_measure V Γ)"
    by (rule measurable_sets, rule measurable_expr_sem_rf[OF e]) simp_all
  finally show "{x  space (state_measure V Γ). expr_sem_rf x e = BoolVal b}
                     sets (state_measure V Γ)" .
qed simp_all

(* Well-formed density contexts (V, V', Γ, δ):
   - subprob_space_dens: for every outer state ρ over V', the context measure is a
     sub-probability space;
   - finite_vars: both variable sets are finite;
   - measurable_dens: δ is measurable on states over V ∪ V';
   - disjoint: V and V' are disjoint. *)
locale density_context =
  fixes V V' Γ δ
  assumes subprob_space_dens:
            "ρ. ρ  space (state_measure V' Γ)  subprob_space (dens_ctxt_measure (V,V',Γ,δ) ρ)"
      and finite_vars[simp]:     "finite V" "finite V'"
      and measurable_dens[measurable]:
                                 "δ  borel_measurable (state_measure (V  V') Γ)"
      and disjoint:              "V  V' = {}"
begin

(* Shorthand for this locale's context tuple (V, V', Γ, δ). *)
abbreviation "𝒴  (V,V',Γ,δ)"

(* For an outer state ρ over V', branch_prob is the integral of δ over the V-states merged
   with ρ.  The first step unfolds the density and pushforward, which introduces an indicator of
   the state space.  The second step drops that indicator, since merging stays inside the space
   (merge_in_state_measure). *)
lemma branch_prob_altdef:
  assumes ρ: "ρ  space (state_measure V' Γ)"
  shows "branch_prob 𝒴 ρ = + x. δ (merge V V' (x, ρ)) state_measure V Γ"
proof-
  have "branch_prob 𝒴 ρ =
          + x. δ (merge V V' (x, ρ)) * indicator (space (state_measure (V  V') Γ))
                  (merge V V' (x, ρ)) state_measure V Γ"
      using ρ unfolding branch_prob_def[abs_def] dens_ctxt_measure_def state_measure'_def
      by (simp add: emeasure_density ennreal_mult'' ennreal_indicator nn_integral_distr)
  also from ρ have "... = + x. δ (merge V V' (x, ρ)) state_measure V Γ"
    by (intro nn_integral_cong) (simp split: split_indicator add: merge_in_state_measure)
  finally show ?thesis .
qed

(* branch_prob 𝒴 is measurable in the outer state ρ.  On the space, it is rewritten with
   branch_prob_altdef (hence cong: measurable_cong).  Sigma-finiteness of state_measure V Γ is
   established first, presumably as required by the rule for measurability of parametric
   nonnegative integrals. *)
lemma measurable_branch_prob[measurable]:
  "branch_prob 𝒴  borel_measurable (state_measure V' Γ)"
proof-
  interpret sigma_finite_measure "state_measure V Γ" by auto
  show ?thesis
    by (simp add: branch_prob_altdef cong: measurable_cong)
qed

(* For x in V, marg_dens 𝒴 x is jointly measurable in the outer state and the value.
   Sigma-finiteness of state_measure (V - {x}) Γ is set up for the inner integral.
   NOTE(review): the fact A (and the equation it is derived from) does not appear to be
   referenced by the final show as written; confirm before removing. *)
lemma measurable_marg_dens':
  assumes [simp]: "x  V"
  shows "case_prod (marg_dens 𝒴 x)  borel_measurable (state_measure V' Γ M stock_measure (Γ x))"
proof-
  interpret sigma_finite_measure "state_measure (V - {x}) Γ"
    unfolding state_measure_def
    by (rule product_sigma_finite.sigma_finite, simp_all add: product_sigma_finite_def)
  from assms have "V = insert x (V - {x})" by blast
  hence A: "PiM V = PiM ..." by simp
  show ?thesis unfolding marg_dens_def
    by (simp add: insert_absorb)
qed

(* Set identity: inserting x after removing B equals inserting x first and then removing B
   without x.
   NOTE(review): this does not depend on the locale parameters.  Inside the locale its name
   appears to shadow HOL's Set.insert_Diff; consider renaming it or moving it out of the
   locale. *)
lemma insert_Diff: "insert x (A - B) = insert x A - (B - {x})"
  by auto

(* For x and y in V, marg_dens2 𝒴 x y is jointly measurable in the outer state and the pair
   value.  The decomposition of V is declared [measurable], presumably to discharge the
   V = V' ∪ {x} side condition of measurable_dens_ctxt_fun_upd for the updates at x and y. *)
lemma measurable_marg_dens2':
  assumes "x  V" "y  V"
  shows "case_prod (marg_dens2 𝒴 x y) 
             borel_measurable (state_measure V' Γ M stock_measure (PRODUCT (Γ x) (Γ y)))"
proof-
  interpret sigma_finite_measure "state_measure (V - {x, y}) Γ"
    unfolding state_measure_def
    by (rule product_sigma_finite.sigma_finite, simp_all add: product_sigma_finite_def)
  have [measurable]: "V = insert x (V - {x, y})  {y}"
    using assms by blast
  show ?thesis unfolding marg_dens2_def
    by simp
qed

(* For a fixed outer state ρ, marg_dens 𝒴 x ρ is measurable.  This is obtained from the joint
   version measurable_marg_dens' by fixing the first component. *)
lemma measurable_marg_dens:
  assumes "x  V" "ρ  space (state_measure V' Γ)"
  shows "marg_dens 𝒴 x ρ  borel_measurable (stock_measure (Γ x))"
  using assms by (intro measurable_Pair_compose_split[OF measurable_marg_dens']) simp_all

(* For a fixed outer state ρ, marg_dens2 𝒴 x y ρ is measurable.  This is obtained from
   measurable_marg_dens2'.  This version additionally assumes that x and y are distinct. *)
lemma measurable_marg_dens2:
  assumes "x  V" "y  V" "x  y" "ρ  space (state_measure V' Γ)"
  shows "marg_dens2 𝒴 x y ρ  borel_measurable (stock_measure (PRODUCT (Γ x) (Γ y)))"
  using assms by (intro measurable_Pair_compose_split[OF measurable_marg_dens2']) simp_all

(* Projection to a variable x in V is measurable from state_measure V Γ into the values of
   type Γ x. *)
lemma measurable_state_measure_component:
    "x  V  (λσ. σ x)  measurable (state_measure V Γ) (stock_measure (Γ x))"
  unfolding state_measure_def
  by (auto intro!: measurable_component_singleton)

(* Projection to a variable x in V is also measurable from the context measure, whose space
   consists of states over V ∪ V'. *)
lemma measurable_dens_ctxt_measure_component:
    "x  V  (λσ. σ x)  measurable (dens_ctxt_measure (V,V',Γ,δ) ρ) (stock_measure (Γ x))"
  unfolding dens_ctxt_measure_def state_measure'_def state_measure_def
  by (auto intro!: measurable_component_singleton)

(* For x in V, the states over V are exactly the states over V - {x} updated at x with a value
   in type_universe (Γ x).
   NOTE(review): despite its name, the statement concerns state_measure V Γ, not
   dens_ctxt_measure.  The proof rewrites V as insert x (V - {x}) and uses PiE_insert_eq. *)
lemma space_dens_ctxt_measure_dens_ctxt_measure':
  assumes "x  V"
  shows "space (state_measure V Γ) =
             {σ(x := y) |σ y. σ  space (state_measure (V-{x}) Γ)  y  type_universe (Γ x)}"
proof-
  from assms have "insert x (V-{x}) = V" by auto
  hence "state_measure V Γ = PiM (insert x (V-{x})) (λy. stock_measure (Γ y))"
    unfolding state_measure_def by simp
  also have "space ... = {σ(x := y) |σ y. σ  space (state_measure (V-{x}) Γ)  y  type_universe (Γ x)}"
    unfolding state_measure_def space_PiM PiE_insert_eq
    by (simp add: image_def Bex_def) blast
  finally show ?thesis .
qed

(* Splitting an integral over states at one variable.  For x in a finite A and measurable f,
   the integral over states on A becomes an outer integral over the value y of x.  The inner
   integral runs over states on A - {x} updated at x with y.  Proved with
   product_nn_integral_insert_rev after writing A as insert x (A - {x}). *)
lemma state_measure_integral_split:
  assumes "x  A" "finite A"
  assumes "f  borel_measurable (state_measure A Γ)"
  shows "(+σ. f σ state_measure A Γ) =
             (+y. +σ. f (σ(x := y)) state_measure (A-{x}) Γ stock_measure (Γ x))"
proof-
  interpret product_sigma_finite "λy. stock_measure (Γ y)"
    unfolding product_sigma_finite_def by auto
  from assms have [simp]: "insert x A = A" by auto
  have "(+σ. f σ state_measure A Γ) = (+σ. f σ ΠM vinsert x (A-{x}). stock_measure (Γ v))"
    unfolding state_measure_def by simp
  also have "... = +y. +σ. f (σ(x := y)) state_measure (A-{x}) Γ stock_measure (Γ x)"
    using assms unfolding state_measure_def
    by (subst product_nn_integral_insert_rev) simp_all
  finally show ?thesis .
qed

(* Updating a state over A at x with a well-typed value gives a state over insert x A. *)
lemma fun_upd_in_state_measure:
  "σ  space (state_measure A Γ); y  space (stock_measure (Γ x))
      σ(x := y)  space (state_measure (insert x A) Γ)"
  unfolding state_measure_def by (auto simp: space_PiM split: if_split_asm)

(* Integrating marg_dens against the indicator of a measurable X equals integrating δ, merged
   with ρ, over the V-states whose x-component lies in X (the set X').
   Proof outline:
   - Split the V-integral at x.
   - Replace the indicator of X' at σ(x := y) by the indicator of X at y.
   - Recognise the inner integral as marg_dens.
   The chain proves the symmetric equation, so finally uses .. to close the goal. *)
lemma marg_dens_integral:
  fixes X :: "val set" assumes "x  V" and [measurable]: "X  sets (stock_measure (Γ x))"
  assumes "ρ  space (state_measure V' Γ)"
  defines "X'  (λσ. σ x) -` X  space (state_measure V Γ)"
  shows "(+ y. marg_dens 𝒴 x ρ y * indicator X y stock_measure (Γ x)) =
              (+σ. δ (merge V V' (σ,ρ)) * indicator X' σ state_measure V Γ)"
proof-
  from assms have [simp]: "insert x V = V" by auto
  interpret product_sigma_finite "λy. stock_measure (Γ y)"
    unfolding product_sigma_finite_def by auto

  have "(+σ. δ (merge V V' (σ,ρ)) * indicator X' σ state_measure V Γ) =
           + y. + σ. δ (merge V V' (σ(x := y), ρ)) * indicator X' (σ(x := y))
               state_measure (V-{x}) Γ stock_measure (Γ x)" using assms(1-3)
    by (subst state_measure_integral_split[of x]) (auto simp: X'_def)
  also have "... = + y. + σ. δ (merge V V' (σ(x := y), ρ)) * indicator X y
                      state_measure (V-{x}) Γ stock_measure (Γ x)"
    by (intro nn_integral_cong)
       (auto simp: X'_def split: split_indicator dest: fun_upd_in_state_measure)
  also have "... = (+ y. marg_dens 𝒴 x ρ y * indicator X y stock_measure (Γ x))"
    using measurable_dens_ctxt_fun_upd unfolding marg_dens_def using assms(1-3)
    by (intro nn_integral_cong) (simp split: split_indicator)
  finally show ?thesis ..
qed

(* Two-variable analogue of marg_dens_integral.  Integrating marg_dens2 against the indicator
   of a measurable X of pair values equals integrating δ, merged with ρ, over the V-states whose
   (x, y) pair lies in X (the set X').
   Proof outline:
   - Write V as insert y (insert x (V - {x,y})) and split the integral twice.
   - Replace the indicator of X' by the indicator of X at the pair.
   - Pull the indicator out of the inner integral and recognise marg_dens2.
   - Fold the two outer integrals back into one integral over PRODUCT (Γ x) (Γ y)
     (Fubini, then the PairVal embedding).
   NOTE(review): the first PairVal fact and the facts V'', A and B do not appear to be
   referenced later in this proof as written; confirm before removing. *)
lemma marg_dens2_integral:
  fixes X :: "val set"
  assumes "x  V" "y  V" "x  y" and [measurable]: "X  sets (stock_measure (PRODUCT (Γ x) (Γ y)))"
  assumes "ρ  space (state_measure V' Γ)"
  defines "X'  (λσ. <|σ x, σ y|>) -` X  space (state_measure V Γ)"
  shows "(+z. marg_dens2 𝒴 x y ρ z * indicator X z stock_measure (PRODUCT (Γ x) (Γ y))) =
              (+σ. δ (merge V V' (σ,ρ)) * indicator X' σ state_measure V Γ)"
proof-
  let ?M = "stock_measure (PRODUCT (Γ x) (Γ y))"
  let ?M' = "stock_measure (Γ x) M stock_measure (Γ y)"
  interpret product_sigma_finite "λx. stock_measure (Γ x)"
    unfolding product_sigma_finite_def by simp
  from assms have "(+ z. marg_dens2 𝒴 x y ρ z * indicator X z ?M) =
      +z. marg_dens2 𝒴 x y ρ (case_prod PairVal z) * indicator X (case_prod PairVal z) ?M'"
    by (subst nn_integral_PairVal)
       (auto simp add: split_beta' intro!: borel_measurable_times_ennreal measurable_marg_dens2)

  have V'': "V - {x, y} = V - {y} - {x}"
    by auto

  from assms have A: "V = insert y (V-{y})" by blast
  from assms have B: "insert x (V-{x,y}) = V - {y}" by blast
  (* X' is measurable: it is the preimage of X under the measurable pairing of the x and y
     components. *)
  from assms have X'[measurable]: "X'  sets (state_measure V Γ)" unfolding X'_def
    by (intro measurable_sets[OF _ assms(4)], unfold state_measure_def, subst stock_measure.simps)
       (rule measurable_Pair_compose_split[OF measurable_embed_measure2], rule inj_PairVal,
        erule measurable_component_singleton, erule measurable_component_singleton)

  have V[simp]: "insert y (V - {y}) = V" "insert x (V - {x, y}) = V - {y}" "insert y V = V"
    and [measurable]: "x  V - {y}"
    using assms by auto

  have "(+σ. δ (merge V V' (σ,ρ)) * indicator X' σ state_measure V Γ) =
      (+σ. δ (merge V V' (σ,ρ)) * indicator X' σ state_measure (insert y (insert x (V-{x, y}))) Γ)"
    using assms by (intro arg_cong2[where f=nn_integral] arg_cong2[where f=state_measure]) auto
  (* Split off y (outer) and then x (inner) with product_nn_integral_insert_rev. *)
  also have "... = +w. +v. +σ. δ (merge V V' (σ(x := v, y := w), ρ)) * indicator X' (σ(x := v, y := w))
      state_measure (V - {x, y}) Γ stock_measure (Γ x) stock_measure (Γ y)" (is "_ = ?I")
    unfolding state_measure_def
    using assms
    apply (subst product_nn_integral_insert_rev)
    apply (auto simp: state_measure_def[symmetric])
    apply (rule nn_integral_cong)
    apply (subst state_measure_def)
    apply (subst V(2)[symmetric])
    apply (subst product_nn_integral_insert_rev)
    apply (auto simp: state_measure_def[symmetric])
    apply measurable
    apply simp_all
    done
  (* This also only extends the calculation up to ?I.  The fact proved here is a helper
     consumed by the following hence. *)
  also from assms(1-5)
    have "v w σ. v  space (stock_measure (Γ x))  w  space (stock_measure (Γ y))
                σ  space (state_measure (V-{x,y}) Γ)
                σ(x := v, y := w)  X'  <|v,w|>  X"
    by (simp add: X'_def space_state_measure PiE_iff extensional_def)
  hence "?I = +w. +v. +σ. δ (merge V V' (σ(x := v, y := w), ρ)) * indicator X <|v,w|>
               state_measure (V - {x,y}) Γ stock_measure (Γ x) stock_measure (Γ y)"
    by (intro nn_integral_cong) (simp split: split_indicator)
  also from assms(5)
    have "... = +w. +v. (+σ. δ (merge V V' (σ(x := v,y := w), ρ)) state_measure (V - {x,y}) Γ)
                    * indicator X <|v,w|> stock_measure (Γ x) stock_measure (Γ y)"
      using assms
      apply (simp add: ennreal_mult'' ennreal_indicator)
      by (intro nn_integral_cong nn_integral_multc) (simp_all add: )
  also have "... = +w. +v. marg_dens2 𝒴 x y ρ <|v,w|> * indicator X <|v,w|>
                       stock_measure (Γ x) stock_measure (Γ y)"
    by (intro nn_integral_cong) (simp add: marg_dens2_def)
  (* Fold the iterated integral into one over the binary product measure. *)
  also from assms(4)
    have "... = +z. marg_dens2 𝒴 x y ρ (case_prod PairVal z) * indicator X (case_prod PairVal z)
                    (stock_measure (Γ x) M stock_measure (Γ y))"
      using assms
      by (subst pair_sigma_finite.nn_integral_snd[symmetric])
         (auto simp add: pair_sigma_finite_def intro!: borel_measurable_times_ennreal measurable_compose[OF _ measurable_marg_dens2])
  (* Transport along the PairVal embedding onto stock_measure (PRODUCT (Γ x) (Γ y)). *)
  also have "... = +z. marg_dens2 𝒴 x y ρ z * indicator X z stock_measure (PRODUCT (Γ x) (Γ y))"
      apply (subst stock_measure.simps, subst embed_measure_eq_distr, rule inj_PairVal)
      apply (rule nn_integral_distr[symmetric], intro measurable_embed_measure2 inj_PairVal)
      apply (subst stock_measure.simps[symmetric])
      apply (intro borel_measurable_times_ennreal)
      apply simp
      apply (intro measurable_marg_dens2)
      apply (insert assms)
      apply simp_all
      done
  finally show ?thesis ..
qed


text ‹The measure given by the marginal density of @{term x} (resp. the joint marginal density of
@{term x} and @{term y}) coincides with the distribution obtained by projecting @{term x} (resp. the
pair of @{term x} and @{term y}) out of the common distribution of all variables.›

(* The measure with density marg_dens 𝒴 x ρ equals the image of the context measure under
   projection to x.  Proof by measure_eqI:
   - For each measurable X, both measures of X are rewritten into the same integral of δ
     over V-states.
   - The first rewrite goes through marg_dens_integral; the second goes through the
     pushforward.
   - Equality of the measurable sets is closed by qed simp. *)
lemma density_marg_dens_eq:
  assumes "x  V" "ρ  space (state_measure V' Γ)"
  shows "density (stock_measure (Γ x)) (marg_dens 𝒴 x ρ) =
              distr (dens_ctxt_measure (V,V',Γ,δ) ρ) (stock_measure (Γ x)) (λσ. σ x)" (is "?M1 = ?M2")
proof (rule measure_eqI)
  fix X assume X: "X  sets ?M1"
  let ?X' = "(λσ. σ x) -` X  space (state_measure V Γ)"
  let ?X'' = "(λσ. σ x) -` X  space (state_measure (V  V') Γ)"
  from X have "emeasure ?M1 X = + σ. δ (merge V V' (σ, ρ)) * indicator ?X' σ state_measure V Γ"
    using assms measurable_marg_dens measurable_dens
    by (subst emeasure_density)
       (auto simp: emeasure_distr nn_integral_distr
        dens_ctxt_measure_def state_measure'_def emeasure_density marg_dens_integral)
  also from assms have "... = + σ. δ (merge V V' (σ, ρ)) *
                                        indicator ?X'' (merge V V' (σ,ρ)) state_measure V Γ"
    by (intro nn_integral_cong)
       (auto split: split_indicator simp: space_state_measure merge_def PiE_iff extensional_def)
  also from X and assms have "... = emeasure ?M2 X" using measurable_dens
    by (auto simp: emeasure_distr emeasure_density nn_integral_distr ennreal_indicator ennreal_mult''
                   dens_ctxt_measure_def state_measure'_def state_measure_def)
  finally show "emeasure ?M1 X = emeasure ?M2 X" .
qed simp

(* Two-variable analogue of density_marg_dens_eq.  The measure with density
   marg_dens2 𝒴 x y ρ equals the image of the context measure under the pairing of the x and
   y components.  Proof by measure_eqI:
   - Rewrite the pushforward measure of X as an integral of δ over V-states.
   - Switch from the merged-state preimage ?X'' to the V-state preimage ?X'.
   - Conclude with marg_dens2_integral. *)
lemma density_marg_dens2_eq:
  assumes "x  V" "y  V" "x  y" "ρ  space (state_measure V' Γ)"
  defines "M  stock_measure (PRODUCT (Γ x) (Γ y))"
  shows "density M (marg_dens2 𝒴 x y ρ) =
              distr (dens_ctxt_measure (V,V',Γ,δ) ρ) M (λσ. <|σ x,σ y|>)" (is "?M1 = ?M2")
proof (rule measure_eqI)
  fix X assume X: "X  sets ?M1"
  let ?X' = "(λσ. <|σ x , σ y|>) -` X  space (state_measure V Γ)"
  let ?X'' = "(λσ. <|σ x , σ y|>) -` X  space (state_measure (VV') Γ)"

  (* The pairing of the x and y components is measurable into the pair value space. *)
  from assms have meas[measurable]: "(λσ. <|σ x,σ y|>)  measurable (state_measure (V  V') Γ)
                                                        (stock_measure (PRODUCT (Γ x) (Γ y)))"
    unfolding state_measure_def
    apply (subst stock_measure.simps)
    apply (rule measurable_Pair_compose_split[OF measurable_embed_measure2[OF inj_PairVal]])
    apply (rule measurable_component_singleton, simp)+
    done
  from assms(1-4) X meas have "emeasure ?M2 X = emeasure (dens_ctxt_measure 𝒴 ρ) ?X''"
    apply (subst emeasure_distr)
    apply (subst measurable_dens_ctxt_measure_eq, unfold state_measure_def M_def)
    apply (simp_all add: space_dens_ctxt_measure state_measure_def)
    done
  also from assms(1-4) X meas
    have "... = +σ. δ (merge V V' (σ, ρ)) * indicator ?X'' (merge V V' (σ, ρ)) state_measure V Γ"
      (is "_ = ?I") unfolding dens_ctxt_measure_def state_measure'_def M_def
    by (simp add: emeasure_density nn_integral_distr ennreal_indicator ennreal_mult'')
  (* This also only extends the calculation up to ?I.  The helper fact proved here is
     consumed by the following hence. *)
  also from assms(1-4) X
    have "σ. σspace (state_measure V Γ)  merge V V' (σ, ρ)  ?X''  σ  ?X'"
    by (auto simp: space_state_measure merge_def PiE_iff extensional_def)
  hence "?I = +σ. δ (merge V V' (σ, ρ)) * indicator ?X' σ state_measure V Γ"
    by (intro nn_integral_cong) (simp split: split_indicator)
  also from assms X have "... = +z. marg_dens2 𝒴 x y ρ z * indicator X z M" unfolding M_def
    by (subst marg_dens2_integral) simp_all
  also from X have "... = emeasure ?M1 X"
    using assms measurable_dens unfolding M_def
    by (subst emeasure_density, intro measurable_marg_dens2) simp_all
  finally show "emeasure ?M1 X = emeasure ?M2 X" ..
qed simp

(* If f (uncurried) is measurable on states over V ∪ V' times values of type t, then
   insert_dens V V' f δ is measurable on states over shift_var_set (V ∪ V'), with variable 0
   typed t.  The projection to variable 0 is shown measurable first.  simp then finishes,
   presumably using the [measurable] rules measurable_remove_var and measurable_dens. *)
lemma measurable_insert_dens[measurable]:
  assumes Mf[measurable]: "case_prod f  borel_measurable (state_measure (V  V') Γ M stock_measure t)"
  shows "insert_dens V V' f δ
              borel_measurable (state_measure (shift_var_set (V  V')) (case_nat t Γ))"
proof-
  have "(λσ. σ 0)  measurable (state_measure (shift_var_set (V  V')) (case_nat t Γ))
                               (stock_measure (case_nat t Γ 0))" unfolding state_measure_def
    unfolding shift_var_set_def by measurable
  thus ?thesis unfolding insert_dens_def[abs_def] by simp
qed

(* Integrating a measurable f against the context measure is the same as integrating
   δ times f, both evaluated at the merge with ρ, over the V-states. *)
lemma nn_integral_dens_ctxt_measure:
  assumes "ρ  space (state_measure V' Γ)"
          "f  borel_measurable (state_measure (V  V') Γ)"
  shows "(+x. f x dens_ctxt_measure (V,V',Γ,δ) ρ) =
           + x. δ (merge V V' (x, ρ)) * f (merge V V' (x, ρ)) state_measure V Γ"
  unfolding dens_ctxt_measure_def state_measure'_def using assms measurable_dens
  by (simp only: prod.case, subst nn_integral_density)
     (auto simp: nn_integral_distr state_measure_def )

(* Combining a shifted V with Suc ` V' gives the shifted union.  Used as a simp rule. *)
lemma shift_var_set_Un[simp]: "shift_var_set V  Suc ` V' = shift_var_set (V  V')"
  unfolding shift_var_set_def by (simp add: image_Un)

(* Measure of a set in a context extended by a new variable 0 of type t.  The extended context
   has V shifted, V' shifted, typing case_nat t Γ, and density insert_dens V V' f δ.  The
   measure of a measurable X is the integral, over states on shift_var_set V, of the inserted
   density times the indicator of X, both evaluated at the merge with ρ.
   The hypothesis dens is only used to obtain measurability of f
   (has_parametrized_subprob_densityD(3)). *)
lemma emeasure_dens_ctxt_measure_insert:
  fixes t f ρ
  defines "M  dens_ctxt_measure (shift_var_set V, Suc`V', case_nat t Γ, insert_dens V V' f δ) ρ"
  assumes dens: "has_parametrized_subprob_density (state_measure (VV') Γ) F (stock_measure t) f"
  assumes ρ: "ρ  space (state_measure (Suc`V') (case_nat t Γ))"
  assumes X: "X  sets M"
  shows "emeasure M X =
           + x. insert_dens V V' f δ (merge (shift_var_set V) (Suc ` V') (x, ρ)) *
                 indicator X (merge (shift_var_set V) (Suc ` V') (x, ρ))
             state_measure (shift_var_set V) (case_nat t Γ)" (is "_ = ?I")
proof-
  note [measurable] = has_parametrized_subprob_densityD(3)[OF dens]
  have [measurable]:
    "(λσ. merge (shift_var_set V) (Suc ` V') (σ, ρ))
        measurable (state_measure (shift_var_set V) (case_nat t Γ))
                    (state_measure (shift_var_set (V  V')) (case_nat t Γ))"
    using ρ unfolding state_measure_def
    by (simp del: shift_var_set_Un add: shift_var_set_Un[symmetric])
  from assms have "emeasure M X = (+x. indicator X x M)"
    by (subst nn_integral_indicator)
       (simp_all add: dens_ctxt_measure_def state_measure'_def)
  (* This also only stores the previous equation as the start of the calculation.  MI is a
     helper fact; the final step is combined with it by finally below. *)
  also have MI: "indicator X  borel_measurable
                     (state_measure (shift_var_set (V  V')) (case_nat t Γ))"
    using X unfolding M_def dens_ctxt_measure_def state_measure'_def by simp
  have "(+x. indicator X x M) = ?I"
    using X unfolding M_def dens_ctxt_measure_def state_measure'_def
    apply (simp only: prod.case)
    apply (subst nn_integral_density)
    apply (simp_all add: nn_integral_density nn_integral_distr MI)
    done
  finally show ?thesis .
qed

(* For ρ a state over Suc ` V' (typed by case_nat t Γ), merging a V-state with the shifted
   state ρ composed with Suc is measurable into states over V ∪ V'. *)
lemma merge_Suc_aux':
  "ρ  space (state_measure (Suc ` V') (case_nat t Γ)) 
    (λσ. merge V V' (σ, ρ  Suc))  measurable (state_measure V Γ) (state_measure (V  V') Γ)"
by (unfold state_measure_def,
    rule measurable_compose[OF measurable_Pair measurable_merge], simp,
    rule measurable_const, auto simp: space_PiM dest: PiE_mem)

(* Composition of the merge from merge_Suc_aux' with the locale density δ is measurable on
   V-states.  The proof composes with measurable_dens and repeats the merge measurability
   argument. *)
lemma merge_Suc_aux:
  "ρ  space (state_measure (Suc ` V') (case_nat t Γ)) 
    (λσ. δ (merge V V' (σ, ρ  Suc)))  borel_measurable (state_measure V Γ)"
by (rule measurable_compose[OF _ measurable_dens], unfold state_measure_def,
    rule measurable_compose[OF measurable_Pair measurable_merge], simp,
    rule measurable_const, auto simp: space_PiM dest: PiE_mem)

(* Integrating over the image of PiM (Suc ` V) (case_nat M N) under composition with Suc is
   the same as integrating over PiM V N.
   - The statement is for the locale's V, whose finiteness (finite_vars(1)) drives the
     induction, generalised over f.
   - The measure M at index 0 never matters, since 0 is not in Suc ` V.  The proof replaces it
     by count_space {} (PiM_cong) to obtain a product_sigma_finite family. *)
lemma nn_integral_PiM_Suc:
  assumes fin: "i. sigma_finite_measure (N i)"
  assumes Mf: "f  borel_measurable (PiM V N)"
  shows "(+x. f x distr (PiM (Suc`V) (case_nat M N)) (PiM V N) (λσ. σ  Suc)) =
             (+x. f x PiM V N)"
         (is "nn_integral (?M1 V) _ = _")
using Mf
proof (induction arbitrary: f
                 rule: finite_induct[OF finite_vars(1), case_names empty insert])
  case empty
  show ?case by (auto simp add: PiM_empty nn_integral_distr intro!: nn_integral_cong)
next
  case (insert v V)
  let ?V = "insert v V" and ?M3 = "PiM (insert (Suc v) (Suc ` V)) (case_nat M N)"
  let ?M4 = "PiM (insert (Suc v) (Suc ` V)) (case_nat (count_space {}) N)"
  let ?M4' = "PiM (Suc ` V) (case_nat (count_space {}) N)"
  have A: "?M3 = ?M4" by (intro PiM_cong) auto
  interpret product_sigma_finite "case_nat (count_space {}) N"
    unfolding product_sigma_finite_def
    by (auto intro: fin sigma_finite_measure_count_space_countable split: nat.split)
  interpret sigma_finite_measure "N v" by (rule assms)
  note Mf[measurable] = insert(4)

  (* Undo the pushforward, then split off the new index Suc v. *)
  from insert have "(+x. f x ?M1 ?V) = +x. f (x  Suc) ?M4"
    by (subst A[symmetric], subst nn_integral_distr)
       (simp_all add: measurable_case_nat_Suc_PiM image_insert[symmetric] del: image_insert)
  also from insert have "... = +x. +y. f (x(Suc v := y)  Suc) N v ?M4'"
    apply (subst product_nn_integral_insert, simp, blast, subst image_insert[symmetric])
    apply (erule measurable_compose[OF measurable_case_nat_Suc_PiM], simp)
    done
  also have "(λx y. x(Suc v := y)  Suc) = (λx y. (x  Suc)(v := y))"
    by (intro ext) (simp add: o_def)
  also have "?M4' = PiM (Suc ` V) (case_nat M N)" by (intro PiM_cong) auto
  (* Re-introduce the pushforward for the smaller set, apply the induction hypothesis, and
     re-assemble the product over insert v V. *)
  also from insert have "(+x. +y. f ((x  Suc)(v := y)) N v ...) =
                             (+x. +y. f (x(v := y)) N v ?M1 V)"
    by (subst nn_integral_distr)
       (simp_all add: borel_measurable_nn_integral measurable_case_nat_Suc_PiM)
  also from insert have "... = (+x. +y. f (x(v := y)) N v PiM V N)"
    by (intro insert(3)) measurable
  also from insert have "... = (+x. f x PiM ?V N)"
    by (subst product_sigma_finite.product_nn_integral_insert)
       (simp_all add: assms product_sigma_finite_def)
  finally show ?case .
qed

(* Measure-level form of nn_integral_PiM_Suc: pushing the Suc-indexed product
   measure forward along σ ↦ σ composed with Suc gives exactly PiM V N.
   The proof goes through measure_eqI, writing each emeasure as the
   nonnegative integral of an indicator function. *)
lemma PiM_Suc:
  assumes "i. sigma_finite_measure (N i)"
  shows "distr (PiM (Suc`V) (case_nat M N)) (PiM V N) (λσ. σ  Suc) = PiM V N" (is "?M1 = ?M2")
  by (intro measure_eqI)
     (simp_all add: nn_integral_indicator[symmetric] nn_integral_PiM_Suc assms
               del: nn_integral_indicator)

(* PiM_Suc specialised to state measures.  Component i of the state measure
   is stock_measure (Γ i), and the new index-0 component is stock_measure t. *)
lemma distr_state_measure_Suc:
  "distr (state_measure (Suc ` V) (case_nat t Γ)) (state_measure V Γ) (λσ. σ  Suc) =
     state_measure V Γ" (is "?M1 = ?M2")
  unfolding state_measure_def
  apply (subst (2) PiM_Suc[of "λx. stock_measure (Γ x)" "stock_measure t", symmetric], simp)
  apply (intro distr_cong PiM_cong)
  apply (simp_all split: nat.split)
  done

(* Explicit formula for the measure of a set X under the extended context.
   The context is extended by inserting a fresh variable 0 of type t, whose
   parametrized density is f; the old variables are shifted by Suc and the
   new density is insert_dens V V' f δ.  The result writes emeasure M X as
   an outer integral over state_measure V Γ.  Its integrand is δ of the
   un-shifted merged state times an inner integral over stock_measure t,
   and that inner integral integrates f times the indicator of X. *)
lemma emeasure_dens_ctxt_measure_insert':
  fixes t f ρ
  defines "M  dens_ctxt_measure (shift_var_set V, Suc`V', case_nat t Γ, insert_dens V V' f δ) ρ"
  assumes dens: "has_parametrized_subprob_density (state_measure (VV') Γ) F (stock_measure t) f"
  assumes ρ: "ρ  space (state_measure (Suc`V') (case_nat t Γ))"
  assumes X: "X  sets M"
  shows "emeasure M X = +σ. δ (merge V V' (σ, ρ  Suc)) * +y. f (merge V V' (σ, ρ  Suc)) y *
                       indicator X (merge (shift_var_set V) (Suc`V') (case_nat y σ, ρ))
                  stock_measure t state_measure V Γ" (is "_ = ?I")
proof-
  (* Measurability side facts used by the calculation below. *)
  let ?m = "λx y. merge (insert 0 (Suc ` V)) (Suc ` V') (x(0 := y), ρ)"
  from dens have Mf:
      "case_prod f  borel_measurable (state_measure (VV') Γ M stock_measure t)"
    by (rule has_parametrized_subprob_densityD)
  note [measurable] = Mf[unfolded state_measure_def]
  have meas_merge: "(λx. merge (shift_var_set V) (Suc`V') (x, ρ))
         measurable (state_measure (shift_var_set V) (case_nat t Γ))
                    (state_measure (shift_var_set (V  V')) (case_nat t Γ))"
    using ρ unfolding state_measure_def shift_var_set_def
    by (simp add: image_Un image_insert[symmetric] Un_insert_left[symmetric]
             del: image_insert Un_insert_left)
  note measurable_insert_dens' =
           measurable_insert_dens[unfolded shift_var_set_def state_measure_def]
  have meas_merge': "(λx. merge (shift_var_set V) (Suc ` V') (case_nat (snd x) (fst x), ρ))
         measurable (state_measure V Γ M stock_measure t)
                    (state_measure (shift_var_set (VV')) (case_nat t Γ))"
    by (rule measurable_compose[OF _ meas_merge]) simp
  have meas_integral: "(λσ. + y. δ (merge V V' (σ, ρ  Suc)) * f (merge V V' (σ, ρ  Suc)) y *
                            indicator X (merge (shift_var_set V) (Suc ` V') (case_nat y σ, ρ))
                           stock_measure t)  borel_measurable (state_measure V Γ)"
    apply (rule sigma_finite_measure.borel_measurable_nn_integral, simp)
    apply (subst measurable_split_conv, intro borel_measurable_times_ennreal)
    apply (rule measurable_compose[OF measurable_fst merge_Suc_aux[OF ρ]])
    apply (rule measurable_Pair_compose_split[OF Mf])
    apply (rule measurable_compose[OF measurable_fst merge_Suc_aux'[OF ρ]], simp)
    apply (rule measurable_compose[OF meas_merge' borel_measurable_indicator])
    apply (insert X, simp add: M_def dens_ctxt_measure_def state_measure'_def)
    done
  have meas': "x. x  space (state_measure V Γ)
                   (λy. f (merge V V' (x, ρ  Suc)) y *
                      indicator X (merge (shift_var_set V) (Suc ` V') (case_nat y x, ρ)))
                   borel_measurable (stock_measure t)" using X
    apply (intro borel_measurable_times_ennreal)
    apply (rule measurable_Pair_compose_split[OF Mf])
    apply (rule measurable_const, erule measurable_space[OF merge_Suc_aux'[OF ρ]])
    apply (simp, rule measurable_compose[OF _ borel_measurable_indicator])
    apply (rule measurable_compose[OF measurable_case_nat'])
    apply (rule measurable_ident_sets[OF refl], erule measurable_const)
    apply (rule meas_merge, simp add: M_def dens_ctxt_measure_def state_measure'_def)
    done

  (* Step 1: write emeasure M X as an integral over the shifted state space. *)
  have "emeasure M X =
           + x. insert_dens V V' f δ (merge (shift_var_set V) (Suc ` V') (x, ρ)) *
                 indicator X (merge (shift_var_set V) (Suc ` V') (x, ρ))
             state_measure (shift_var_set V) (case_nat t Γ)"
    using assms unfolding M_def by (intro emeasure_dens_ctxt_measure_insert)
  (* Step 2: split off the integral over the new variable 0. *)
  also have "... = +x. +y. insert_dens V V' f δ (?m x y) *
                  indicator X (?m x y) stock_measure t state_measure (Suc`V) (case_nat t Γ)"
    (is "_ = ?I") using ρ X meas_merge
    unfolding shift_var_set_def M_def dens_ctxt_measure_def state_measure'_def state_measure_def
    apply (subst product_sigma_finite.product_nn_integral_insert)
    apply (auto simp: product_sigma_finite_def) [3]
    apply (intro borel_measurable_times_ennreal)
    apply (rule measurable_compose[OF _ measurable_insert_dens'], simp)
    apply (simp_all add: measurable_compose[OF _ borel_measurable_indicator] image_Un)
    done
  (* Step 3: removing variable 0 from the merged state gives the un-shifted
     merge, so insert_dens unfolds to delta times f. *)
  also have "σ y. σ  space (state_measure (Suc`V) (case_nat t Γ)) 
                   y  space (stock_measure t) 
                   (remove_var (merge (insert 0 (Suc ` V)) (Suc ` V') (σ(0:=y), ρ))) =
                       merge V V' (σ  Suc, ρ  Suc)"
    by (auto simp: merge_def remove_var_def)
  hence "?I = +σ. +y. δ (merge V V' (σ  Suc, ρ  Suc)) * f (merge V V' (σ  Suc, ρ  Suc)) y *
                       indicator X (?m σ y)
                  stock_measure t state_measure (Suc`V) (case_nat t Γ)" (is "_ = ?I")
    by (intro nn_integral_cong)
       (auto simp: insert_dens_def inj_image_mem_iff merge_def split: split_indicator nat.split)
  also have m_eq: "x y. ?m x y = merge (shift_var_set V) (Suc`V') (case_nat y (x  Suc), ρ)"
    by (intro ext) (auto simp add: merge_def shift_var_set_def split: nat.split)
  (* Step 4: move the outer integral back to state_measure V Γ using
     distr_state_measure_Suc and nn_integral_distr. *)
  have "?I = +σ. +y. δ (merge V V' (σ, ρ  Suc)) * f (merge V V' (σ, ρ  Suc)) y *
                       indicator X (merge (shift_var_set V) (Suc`V') (case_nat y σ, ρ))
                  stock_measure t state_measure V Γ" using ρ X
    apply (subst distr_state_measure_Suc[symmetric, of t])
    apply (subst nn_integral_distr)
    apply (rule measurable_case_nat_Suc)
    apply simp
    apply (rule meas_integral)
    apply (intro nn_integral_cong)
    apply (simp add: m_eq)
    done
  (* Step 5: pull the delta factor out of the inner integral (nn_integral_cmult). *)
  also have "... = +σ. δ (merge V V' (σ, ρ  Suc)) * +y. f (merge V V' (σ, ρ  Suc)) y *
                       indicator X (merge (shift_var_set V) (Suc`V') (case_nat y σ, ρ))
                  stock_measure t state_measure V Γ" using ρ X
    apply (intro nn_integral_cong)
    apply (subst nn_integral_cmult[symmetric])
    apply (erule meas')
    apply (simp add: mult.assoc)
    done
  finally show ?thesis .
qed


(* Inserting a fresh variable 0 of type t preserves the density-context
   invariants.  The variable is inserted with a parametrized subprobability
   density f over the old variables, the old variables are shifted by Suc,
   and the new density is insert_dens V V' f δ.  The obligations are that
   the new density is measurable, that every induced measure is a
   subprobability space, and the disjointness facts. *)
lemma density_context_insert:
  assumes dens: "has_parametrized_subprob_density (state_measure (VV') Γ) F (stock_measure t) f"
  shows "density_context (shift_var_set V) (Suc ` V') (case_nat t Γ) (insert_dens V V' f δ)"
             (is "density_context ?V ?V' ?Γ' ?δ'")
unfolding density_context_def
proof (intro allI conjI impI)
  note measurable_insert_dens[OF has_parametrized_subprob_densityD(3)[OF dens]]
  thus "insert_dens V V' f δ
           borel_measurable (state_measure (shift_var_set V  Suc ` V') (case_nat t Γ))"
    unfolding shift_var_set_def by (simp only: image_Un Un_insert_left)
next
  (* Subprobability: bound the total mass of the new context measure by
     that of the original context measure at the un-shifted parameters,
     which is at most 1. *)
  fix ρ assume ρ: "ρ  space (state_measure ?V' ?Γ')"
  hence ρ': "ρ  Suc  space (state_measure V' Γ)"
    by (auto simp: state_measure_def space_PiM dest: PiE_mem)
  note dens' = has_parametrized_subprob_densityD[OF dens]
  note Mf[measurable] = dens'(3)
  have M_merge: "(λx. merge (shift_var_set V) (Suc ` V') (x, ρ))
                    measurable (PiM (insert 0 (Suc ` V)) (λy. stock_measure (case_nat t Γ y)))
                                (state_measure (shift_var_set (V  V')) (case_nat t Γ))"
    using ρ by (subst shift_var_set_Un[symmetric], unfold state_measure_def)
               (simp add: shift_var_set_def del: shift_var_set_Un Un_insert_left)
  show "subprob_space (dens_ctxt_measure (?V,?V',?Γ',?δ') ρ)" (is "subprob_space ?M")
  proof (rule subprob_spaceI)
    interpret product_sigma_finite "(λy. stock_measure (case y of 0  t | Suc x  Γ x))"
      by (simp add: product_sigma_finite_def)
    have Suc_state_measure:
      "x. x  space (state_measure (Suc ` V) (case_nat t Γ)) 
              merge V V' (x  Suc, ρ  Suc)  space (state_measure (V  V') Γ)" using ρ
      by (intro merge_in_state_measure) (auto simp: state_measure_def space_PiM dest: PiE_mem)

    have S[simp]: "x X. Suc x  Suc ` X  x  X" by (rule inj_image_mem_iff) simp
    let ?M = "dens_ctxt_measure (?V,?V',?Γ',?δ') ρ"
    from ρ have "σ. σ  space (state_measure ?V ?Γ')  merge ?V ?V' (σ, ρ)  space ?M"
      by (auto simp: dens_ctxt_measure_def state_measure'_def simp del: shift_var_set_Un
               intro!: merge_in_state_measure)
    hence "emeasure ?M (space ?M) =
            +σ. insert_dens V V' f δ (merge ?V ?V' (σ, ρ)) state_measure ?V ?Γ'"
     by (subst emeasure_dens_ctxt_measure_insert[OF dens ρ], simp, intro nn_integral_cong)
        (simp split: split_indicator)
    also have "... = +σ. insert_dens V V' f δ (merge ?V ?V' (σ, ρ))
                              state_measure (insert 0 (Suc ` V)) ?Γ'"
      by (simp add: shift_var_set_def)
    also have "... = +σ. +x. insert_dens V V' f δ (merge ?V ?V' (σ(0 := x), ρ))
                       stock_measure t state_measure (Suc ` V) ?Γ'"
       unfolding state_measure_def using M_merge
       by (subst product_nn_integral_insert) auto
    also have "... = +σ. +x. δ (remove_var (merge ?V ?V' (σ(0:=x), ρ))) *
                               f (remove_var (merge ?V ?V' (σ(0:=x), ρ))) x
                        stock_measure t state_measure (Suc ` V) ?Γ'" (is "_ = ?I")
       by (intro nn_integral_cong) (auto simp: insert_dens_def merge_def shift_var_set_def)
    also have "σ x. remove_var (merge ?V ?V' (σ(0:=x), ρ)) = merge V V' (σ  Suc, ρ  Suc)"
      by (intro ext) (auto simp: remove_var_def merge_def shift_var_set_def o_def)
    hence "?I = +σ. +x. δ (merge V V' (σ  Suc, ρ  Suc)) * f (merge V V' (σ  Suc, ρ  Suc)) x
                  stock_measure t state_measure (Suc ` V) ?Γ'" by simp
    also have "... = +σ. δ (merge V V' (σ  Suc, ρ  Suc)) *
                            (+x. f (merge V V' (σ  Suc, ρ  Suc)) x stock_measure t)
                       state_measure (Suc ` V) ?Γ'" (is "_ = ?I")
      using ρ disjoint
      apply (intro nn_integral_cong nn_integral_cmult)
      apply (rule measurable_Pair_compose_split[OF Mf], rule measurable_const)
      apply (auto intro!: Suc_state_measure)
      done
    (* The inner integral of f is the total mass of F at the merged state.
       F is a subprobability space there, so that mass is at most 1. *)
    also {
      fix σ assume σ: "σ  space (state_measure (Suc ` V) ?Γ')"
      let ?σ' = "merge V V' (σ  Suc, ρ  Suc)"
      let ?N = "density (stock_measure t) (f ?σ')"
      have "(+x. f (merge V V' (σ  Suc, ρ  Suc)) x stock_measure t) = emeasure ?N (space ?N)"
        using dens'(3) Suc_state_measure[OF σ]
        by (simp_all cong: nn_integral_cong' add: emeasure_density)
      also have "?N = F ?σ'" by (subst dens') (simp_all add: Suc_state_measure σ)
      also have "subprob_space (F ?σ')" by (rule dens') (simp_all add: Suc_state_measure σ)
      hence "emeasure (F ?σ') (space (F ?σ'))  1" by (rule subprob_space.emeasure_space_le_1)
      finally have "(+x. f (merge V V' (σ  Suc, ρ  Suc)) x stock_measure t)  1" .
    }
    hence "?I  +σ. δ (merge V V' (σ  Suc, ρ  Suc)) * 1 state_measure (Suc ` V) ?Γ'"
      by (intro nn_integral_mono mult_left_mono) (simp_all add: Suc_state_measure)
    (* Move back to state_measure V Γ.  What remains is the total mass of
       the original context measure at the un-shifted parameters. *)
    also have "... = +σ. δ (merge V V' (σ, ρ  Suc))
                       distr (state_measure (Suc ` V) ?Γ') (state_measure V Γ) (λσ. σ  Suc)"
      (is "_ = nn_integral ?N _")
      using ρ by (subst nn_integral_distr) (simp_all add: measurable_case_nat_Suc merge_Suc_aux)
    also have "?N = state_measure V Γ" by (rule distr_state_measure_Suc)
    also have "(+σ. δ (merge V V' (σ, ρ  Suc)) state_measure V Γ) =
                   (+σ. 1 dens_ctxt_measure 𝒴 (ρ  Suc))" (is "_ = nn_integral ?N _")
      by (subst nn_integral_dens_ctxt_measure) (simp_all add: ρ')
    also have "... = (+σ. indicator (space ?N) σ ?N)"
      by (intro nn_integral_cong) (simp split: split_indicator)
    also have "... = emeasure ?N (space ?N)" by simp
    also have "...  1" by (simp_all add: subprob_space.emeasure_space_le_1 subprob_space_dens ρ')
    finally show "emeasure ?M (space ?M)  1" .
  qed (simp_all add: space_dens_ctxt_measure state_measure_def space_PiM PiE_eq_empty_iff)
qed (insert disjoint, auto simp: shift_var_set_def)


(* Monadic characterisation of variable insertion.  The procedure is:
   1. sample a state σ from the density-context measure;
   2. sample y from the kernel M at σ;
   3. return the extended state case_nat y σ.
   This gives the same measure as the extended density context with density
   insert_dens V V' f δ.  The parameter ρ is extended to index 0 with
   undefined.  M must be a measurable subprobability kernel with density f
   at every state. *)
lemma dens_ctxt_measure_insert:
  assumes ρ: "ρ  space (state_measure V' Γ)"
  assumes meas_M: "M  measurable (state_measure (VV') Γ) (subprob_algebra (stock_measure t))"
  assumes meas_f[measurable]: "case_prod f  borel_measurable (state_measure (VV') Γ M stock_measure t)"
  assumes has_dens: "ρ. ρ  space (state_measure (VV') Γ) 
                         has_subprob_density (M ρ) (stock_measure t) (f ρ)"
  shows "do {σ  dens_ctxt_measure (V,V',Γ,δ) ρ;
             y  M σ;
             return (state_measure (shift_var_set (V  V')) (case_nat t Γ)) (case_nat y σ)} =
         dens_ctxt_measure (shift_var_set V, Suc`V', case_nat t Γ, insert_dens V V' f δ)
                           (case_nat undefined ρ)"
         (is "bind ?N (λ_. bind _ (λ_. return ?R _)) = dens_ctxt_measure (?V,?V',?Γ',?δ') _")
proof (intro measure_eqI)
  (* Side facts: measurability and sigma-algebra equalities, so that the
     bind lemmas can be applied. *)
  let ?lhs = "?N  (λσ . M σ  (λy. return ?R (case_nat y σ)))"
  let ?rhs = "dens_ctxt_measure (?V,?V',?Γ',?δ') (case_nat undefined ρ)"

  have meas_f': "M g h. g  measurable M (state_measure (VV') Γ) 
                         h  measurable M (stock_measure t) 
                         (λx. f (g x) (h x))  borel_measurable M" by measurable
  have t: "t = ?Γ' 0" by simp

  have nonempty: "space ?N  {}"
      by (auto simp: dens_ctxt_measure_def state_measure'_def state_measure_def
                     space_PiM PiE_eq_empty_iff)
  have meas_N_eq: "measurable ?N = measurable (state_measure (VV') Γ)"
    by (intro ext measurable_cong_sets) (auto simp: dens_ctxt_measure_def state_measure'_def)
  have meas_M': "M  measurable ?N (subprob_algebra (stock_measure t))"
    by (subst meas_N_eq) (rule meas_M)
  have meas_N': "R. measurable (?N M R) = measurable (state_measure (VV') Γ M R)"
    by (intro ext measurable_cong_sets[OF _ refl] sets_pair_measure_cong)
       (simp_all add: dens_ctxt_measure_def state_measure'_def)
  have meas_M_eq: "ρ. ρ  space ?N  measurable (M ρ) = measurable (stock_measure t)"
    by (intro ext measurable_cong_sets sets_kernel[OF meas_M']) simp_all
  have meas_rhs: "M. measurable M ?rhs = measurable M ?R"
    by (intro ext measurable_cong_sets) (simp_all add: dens_ctxt_measure_def state_measure'_def)
  have subprob_algebra_rhs: "subprob_algebra ?rhs = subprob_algebra (state_measure (shift_var_set (VV')) ?Γ')"
    unfolding dens_ctxt_measure_def state_measure'_def by (intro subprob_algebra_cong) auto
  have nonempty': "ρ. ρ  space ?N  space (M ρ)  {}"
    by (rule subprob_space.subprob_not_empty)
       (auto dest: has_subprob_densityD has_dens simp: space_dens_ctxt_measure)
  have merge_in_space: "x. x  space (state_measure V Γ) 
                              merge V V' (x, ρ)  space (dens_ctxt_measure 𝒴 ρ)"
    by (simp add: space_dens_ctxt_measure merge_in_state_measure ρ)

  have "sets ?lhs = sets (state_measure (shift_var_set (V  V')) ?Γ')"
    using nonempty' by (subst sets_bind, subst sets_bind) auto
  thus sets_eq: "sets ?lhs = sets ?rhs"
    unfolding dens_ctxt_measure_def state_measure'_def by simp

  have meas_merge[measurable]:
    "(λσ. merge V V' (σ, ρ))  measurable (state_measure V Γ) (state_measure (V  V') Γ)"
    using ρ unfolding state_measure_def by - measurable

  (* Equal measures on every measurable X: unfold both binds into iterated
     integrals.  Replace M by its density f, then match the result with
     emeasure_dens_ctxt_measure_insert'. *)
  fix X assume "X  sets ?lhs"
  hence X: "X  sets ?rhs" by (simp add: sets_eq)
  hence "emeasure ?lhs X = +σ. emeasure (M σ  (λy. return ?R (case_nat y σ))) X ?N"
    by (intro emeasure_bind measurable_bind[OF meas_M'])
       (simp, rule measurable_compose[OF _ return_measurable],
        simp_all add: dens_ctxt_measure_def state_measure'_def)
  also from X have "... =
    + x. δ (merge V V' (x, ρ)) * emeasure (M (merge V V' (x, ρ)) 
             (λy. return ?R (case_nat y (merge V V' (x, ρ))))) X state_measure V Γ"
    apply (subst nn_integral_dens_ctxt_measure[OF ρ])
    apply (rule measurable_emeasure_kernel[OF measurable_bind[OF meas_M]])
    apply (rule measurable_compose[OF _ return_measurable], simp)
    apply (simp_all add: dens_ctxt_measure_def state_measure'_def)
    done
  also from X have "... = +x. δ (merge V V' (x, ρ)) *
                              +y. indicator X (case_nat y (merge V V' (x, ρ)))
                                   M (merge V V' (x, ρ)) state_measure V Γ" (is "_ = ?I")
    apply (intro nn_integral_cong)
    apply (subst emeasure_bind, rule nonempty', simp add: merge_in_space)
    apply (rule measurable_compose[OF _ return_measurable], simp add: merge_in_space meas_M_eq)
    apply (simp_all add: dens_ctxt_measure_def state_measure'_def)
    done
  also have "x. x  space (state_measure V Γ) 
                  M (merge V V' (x, ρ)) = density (stock_measure t) (f (merge V V' (x, ρ)))"
    by (intro has_subprob_densityD[OF has_dens]) (simp add: merge_in_state_measure ρ)
  hence "?I = +x. δ (merge V V' (x, ρ)) *
                +y. indicator X (case_nat y (merge V V' (x, ρ)))
                density (stock_measure t) (f (merge V V' (x, ρ))) state_measure V Γ"
    by (intro nn_integral_cong) simp
  also have "... = +x. δ (merge V V' (x, ρ)) *
                     +y. f (merge V V' (x, ρ)) y * indicator X (case_nat y (merge V V' (x, ρ)))
                   stock_measure t state_measure V Γ" (is "_ = ?I") using X
    by (intro nn_integral_cong, subst nn_integral_density, simp)
       (auto simp: mult.assoc dens_ctxt_measure_def state_measure'_def
             intro!: merge_in_state_measure ρ AE_I'[of "{}"]
                     has_subprob_densityD[OF has_dens])
  also have A: "case_nat undefined ρ  Suc = ρ" by (intro ext) simp
  have B: "x y. x  space (state_measure V Γ)  y  space (stock_measure t) 
           (case_nat y (merge V V' (x, ρ))) =
           (merge (shift_var_set V) (Suc ` V') (case_nat y x, case_nat undefined ρ))"
    by (intro ext) (auto simp add: merge_def shift_var_set_def split: nat.split)
  have C: "x. x  space (state_measure V Γ) 
     (+y. f (merge V V' (x, ρ)) y * indicator X (case_nat y (merge V V' (x,ρ))) stock_measure t) =
      +y. f (merge V V' (x, ρ)) y * indicator X (merge (shift_var_set V) (Suc`V')
                 (case_nat y x,case_nat undefined ρ)) stock_measure t"
    by (intro nn_integral_cong) (simp add: B)
  have "?I = emeasure ?rhs X" using X
    apply (subst emeasure_dens_ctxt_measure_insert'[where F = M])
    apply (insert has_dens, simp add: has_parametrized_subprob_density_def)
    apply (rule measurable_space[OF measurable_case_nat_undefined ρ], simp)
    apply (intro nn_integral_cong, simp add: A C)
    done
  finally show "emeasure ?lhs X = emeasure ?rhs X" .
qed

(* Conditioning on a random boolean preserves the density-context invariants.
   Here f is a parametrized subprobability density over count_space
   (range BoolVal), and the context density is replaced by if_dens δ f b.
   The total-mass bound uses the fact that f at BoolVal b is bounded by 1
   (subprob_count_space_density_le_1).  That bounds the new mass by
   branch_prob of the original context, which is at most 1. *)
lemma density_context_if_dens:
  assumes "has_parametrized_subprob_density (state_measure (V  V') Γ) M
               (count_space (range BoolVal)) f"
  shows "density_context V V' Γ (if_dens δ f b)"
unfolding density_context_def
proof (intro allI conjI impI subprob_spaceI)
  note D = has_parametrized_subprob_densityD[OF assms]
  from D(3) show M: "if_dens δ f b  borel_measurable (state_measure (V  V') Γ)"
    by (intro measurable_if_dens) simp_all

  fix ρ assume ρ: "ρ  space (state_measure V' Γ)"
  hence [measurable]: "(λσ. merge V V' (σ, ρ)) 
                            measurable (state_measure V Γ) (state_measure (V  V') Γ)"
    unfolding state_measure_def by simp

  (* Pointwise bounds at each merged state, used in the monotonicity step below. *)
  {
    fix σ assume "σ  space (state_measure V Γ)"
    with ρ have σρ: "merge V V' (σ, ρ)  space (state_measure (V  V') Γ)"
      by (intro merge_in_state_measure)
    with assms have "has_subprob_density (M (merge V V' (σ, ρ))) (count_space (range BoolVal))
                         (f (merge V V' (σ, ρ)))"
      unfolding has_parametrized_subprob_density_def by auto
    with σρ have "f (merge V V' (σ, ρ)) (BoolVal b)  1" "δ (merge V V' (σ, ρ))  0"
      by (auto intro: subprob_count_space_density_le_1)
  } note dens_props = this

  from ρ interpret subprob_space "dens_ctxt_measure 𝒴 ρ" by (rule subprob_space_dens)
  let ?M = "dens_ctxt_measure (V, V', Γ, if_dens δ f b) ρ"
  have "emeasure ?M (space ?M) =
          +x. if_dens δ f b (merge V V' (x, ρ)) state_measure V Γ"
    using M ρ unfolding dens_ctxt_measure_def state_measure'_def
    by (simp only: prod.case space_density)
       (auto simp: nn_integral_distr emeasure_density cong: nn_integral_cong')
  also from ρ have "...  +x. δ (merge V V' (x, ρ)) * 1 state_measure V Γ"
    unfolding if_dens_def using dens_props
    by (intro nn_integral_mono mult_left_mono) simp_all
  also from ρ have "... = branch_prob 𝒴 ρ" by (simp add: branch_prob_altdef)
  also have "... = emeasure (dens_ctxt_measure 𝒴 ρ) (space (dens_ctxt_measure 𝒴 ρ))"
    by (simp add: branch_prob_def)
  also have "...  1" by (rule emeasure_space_le_1)
  finally show "emeasure ?M (space ?M)  1" .
qed (insert disjoint, auto)

(* Deterministic counterpart of density_context_if_dens.  Here e is a
   well-typed, random-free boolean expression whose free variables lie among
   the context variables, and the density is replaced by if_dens_det δ e b.
   The mass bound follows the same route: compare with δ times 1, which
   integrates to branch_prob, which is at most 1. *)
lemma density_context_if_dens_det:
  assumes e: "Γ  e : BOOL" "randomfree e" "free_vars e  V  V'"
  shows "density_context V V' Γ (if_dens_det δ e b)"
unfolding density_context_def
proof (intro allI conjI impI subprob_spaceI)
  from assms show M: "if_dens_det δ e b  borel_measurable (state_measure (V  V') Γ)"
    by (intro measurable_if_dens_det) simp_all

  fix ρ assume ρ: "ρ  space (state_measure V' Γ)"
  hence [measurable]: "(λσ. merge V V' (σ, ρ)) 
                            measurable (state_measure V Γ) (state_measure (V  V') Γ)"
    unfolding state_measure_def by simp

  from ρ interpret subprob_space "dens_ctxt_measure 𝒴 ρ" by (rule subprob_space_dens)
  let ?M = "dens_ctxt_measure (V, V', Γ, if_dens_det δ e b) ρ"
  have "emeasure ?M (space ?M) =
          +x. if_dens_det δ e b (merge V V' (x, ρ)) state_measure V Γ"
    using M ρ unfolding dens_ctxt_measure_def state_measure'_def
    by (simp only: prod.case space_density)
       (auto simp: nn_integral_distr emeasure_density cong: nn_integral_cong')
  also from ρ have "...  +x. δ (merge V V' (x, ρ)) * 1 state_measure V Γ"
    unfolding if_dens_det_def
    by (intro nn_integral_mono mult_left_mono) (simp_all add: merge_in_state_measure)
  also from ρ have "... = branch_prob 𝒴 ρ" by (simp add: branch_prob_altdef)
  also have "... = emeasure (dens_ctxt_measure 𝒴 ρ) (space (dens_ctxt_measure 𝒴 ρ))"
    by (simp add: branch_prob_def)
  also have "...  1" by (rule emeasure_space_le_1)
  finally show "emeasure ?M (space ?M)  1" .
qed (insert disjoint assms, auto intro: measurable_if_dens_det)


(* The trivial context has no random variables, every variable of the
   union of V and V' as a parameter, and constant density 1.  It is a
   density context.  With no random variables, merging the empty state with
   ρ gives back ρ, so the total mass can be computed directly. *)
lemma density_context_empty[simp]: "density_context {} (VV') Γ (λ_. 1)"
unfolding density_context_def
proof (intro allI conjI impI subprob_spaceI)
  fix ρ assume ρ: "ρ  space (state_measure (V  V') Γ)"
  let ?M = "dens_ctxt_measure ({},VV',Γ,λ_. 1) ρ"
  from ρ have "σ. merge {} (VV') (σ,ρ) = ρ"
    by (intro ext) (auto simp: merge_def state_measure_def space_PiM)
  with ρ show "emeasure ?M (space ?M)  1"
    unfolding dens_ctxt_measure_def state_measure'_def
    by (auto simp: emeasure_density emeasure_distr state_measure_def PiM_empty)
qed auto

(* Binding a density-context measure to a kernel that ignores its argument
   and returns a fixed subprobability measure N gives N scaled by a constant.
   The constant is branch_prob 𝒴 ρ, the total mass of the context measure. *)
lemma dens_ctxt_measure_bind_const:
  assumes "ρ  space (state_measure V' Γ)" "subprob_space N"
  shows "dens_ctxt_measure 𝒴 ρ  (λ_. N) = density N (λ_. branch_prob 𝒴 ρ)" (is "?M1 = ?M2")
proof (rule measure_eqI)
  have [simp]: "sets ?M1 = sets N" by (auto simp: space_subprob_algebra assms)
  thus "sets ?M1 = sets ?M2" by simp
  fix X assume X: "X  sets ?M1"
  with assms have "emeasure ?M1 X = emeasure N X * branch_prob 𝒴 ρ"
    unfolding branch_prob_def by (subst emeasure_bind_const') (auto simp: subprob_space_dens)
  also from X have "emeasure N X = +x. indicator X x N" by simp
  also from X have "... * branch_prob 𝒴 ρ = +x. branch_prob 𝒴 ρ * indicator X x N"
    by (subst nn_integral_cmult) (auto simp: branch_prob_def field_simps)
  also from X have "... = emeasure ?M2 X" by (simp add: emeasure_density)
  finally show "emeasure ?M1 X = emeasure ?M2 X" .
qed


(* Integrating a function that depends only on the parameter variables.
   The parameter part restrict x V' of every state in the context measure
   equals ρ.  So the integrand is the constant f ρ, and the integral is
   branch_prob 𝒴 ρ times f ρ.  The proof uses the disjointness of V and V'. *)
lemma nn_integral_dens_ctxt_measure_restrict:
  assumes "ρ  space (state_measure V' Γ)" "f ρ  0"
  assumes "f  borel_measurable (state_measure V' Γ)"
  shows "(+x. f (restrict x V') dens_ctxt_measure 𝒴 ρ) = branch_prob 𝒴 ρ * f ρ"
proof-
  have "(+x. f (restrict x V') dens_ctxt_measure (V,V',Γ,δ) ρ) =
          + x. δ (merge V V' (x, ρ)) * f (restrict (merge V V' (x, ρ)) V') state_measure V Γ"
          (is "_ = ?I")
    by (subst nn_integral_dens_ctxt_measure, simp add: assms,
        rule measurable_compose[OF measurable_restrict], unfold state_measure_def,
        rule measurable_component_singleton, insert assms, simp_all add: state_measure_def)
  also from assms(1) and disjoint
    have "x. x  space (state_measure V Γ)  restrict (merge V V' (x, ρ)) V' = ρ"
    by (intro ext) (auto simp: restrict_def merge_def state_measure_def space_PiM dest: PiE_mem)
  hence "?I = + x. δ (merge V V' (x, ρ)) * f ρ state_measure V Γ"
    by (intro nn_integral_cong) simp
  also have "... = (+x. f ρ dens_ctxt_measure (V,V',Γ,δ) ρ)"
    by (subst nn_integral_dens_ctxt_measure) (simp_all add: assms)
  also have "... = f ρ * branch_prob 𝒴 ρ"
    by (subst nn_integral_const)
       (simp_all add: assms branch_prob_def)
  finally show ?thesis by (simp add: field_simps)
qed

(* Semantics of a primitive operator application over a density context.
   Binding the context measure to expr_sem of oper $$ e gives the same
   measure as this: bind to expr_sem of e, then push forward along
   op_sem oper.  Needs oper $$ e to be well-typed with result type t', and
   the free variables of e to lie in the union of V and V'. *)
lemma expr_sem_op_eq_distr:
  assumes "Γ  oper $$ e : t'" "free_vars e  V  V'" "ρ  space (state_measure V' Γ)"
  defines "M  dens_ctxt_measure (V,V',Γ,δ) ρ"
  shows "M  (λσ. expr_sem σ (oper $$ e)) =
             distr (M  (λσ. expr_sem σ e)) (stock_measure t') (op_sem oper)"
proof-
  from assms(1) obtain t where t1: "Γ  e : t" and t2: "op_type oper t = Some t'" by auto
  let ?N = "stock_measure t" and ?R = "subprob_algebra (stock_measure t')"

  (* On well-typed values, return_val of the operator result is a return
     into stock_measure t'. *)
  {
    fix x assume "x  space (stock_measure t)"
    with t1 assms(2,3) have "val_type x = t"
      by (auto simp: state_measure_def space_PiM dest: PiE_mem)
    hence "return_val (op_sem oper x) = return (stock_measure t') (op_sem oper x)"
      unfolding return_val_def by (subst op_sem_val_type) (simp_all add: t2)
  } note return_op_sem = this

  from assms and t1 have M_e: "(λσ. expr_sem σ e)  measurable M (subprob_algebra (stock_measure t))"
    by (simp add: M_def measurable_dens_ctxt_measure_eq measurable_expr_sem)
  from return_op_sem
    have M_cong: "(λx. return_val (op_sem oper x))  measurable ?N ?R 
                     (λx. return (stock_measure t') (op_sem oper x))  measurable ?N ?R"
    by (intro measurable_cong) simp
  have M_ret: "(λx. return_val (op_sem oper x))  measurable (stock_measure t) ?R"
    by (subst M_cong, intro measurable_compose[OF measurable_op_sem[OF t2]] return_measurable)

  from M_e have [simp]: "sets (M  (λσ. expr_sem σ e)) = sets (stock_measure t)"
    by (intro sets_bind) (auto simp: M_def space_subprob_algebra dest!: measurable_space)
  from measurable_cong_sets[OF this refl]
    have M_op: "op_sem oper  measurable (M  (λσ. expr_sem σ e)) (stock_measure t')"
    by (auto intro!: measurable_op_sem t2)
  have [simp]: "space (M  (λσ. expr_sem σ e)) = space (stock_measure t)"
    by (rule sets_eq_imp_space_eq) simp

  (* Reassociate the binds, replace return_val by return, then turn the
     bind-return into distr. *)
  from M_e and M_ret have "M  (λσ. expr_sem σ (oper $$ e)) =
                              (M  (λσ. expr_sem σ e))  (λx. return_val (op_sem oper x))"
    unfolding M_def by (subst expr_sem.simps, intro bind_assoc[symmetric]) simp_all
  also have "... = (M  (λσ. expr_sem σ e))  (λx. return (stock_measure t') (op_sem oper x))"
    by (intro bind_cong refl) (simp add: return_op_sem)
  also have "... = distr (M  (λσ. expr_sem σ e)) (stock_measure t') (op_sem oper)"
    by (subst bind_return_distr[symmetric]) (simp_all add: o_def M_op)
  finally show ?thesis .
qed

end

(* Density contexts are invariant under changing the density off the state
   space.  Suppose δ' agrees with δ on space (state_measure (V Un V') Γ)
   and δ' is Borel measurable.  If (V, V', Γ, δ) is a density context, then
   so is (V, V', Γ, δ').  The total mass of the δ' context equals
   branch_prob of the δ context, which is at most 1. *)
lemma density_context_equiv:
  assumes "σ. σ  space (state_measure (V  V') Γ)  δ σ = δ' σ"
  assumes [simp, measurable]: "δ'  borel_measurable (state_measure (V  V') Γ)"
  assumes "density_context V V' Γ δ"
  shows "density_context V V' Γ δ'"
proof (unfold density_context_def, intro conjI allI impI subprob_spaceI)
  interpret density_context V V' Γ δ by fact
  fix ρ assume ρ: "ρ  space (state_measure V' Γ)"
  let ?M = "dens_ctxt_measure (V, V', Γ, δ') ρ"
  let ?N = "dens_ctxt_measure (V, V', Γ, δ) ρ"
  from ρ have "emeasure ?M (space ?M) = +x. δ' (merge V V' (x, ρ)) state_measure V Γ"
     unfolding dens_ctxt_measure_def state_measure'_def
    apply (simp only: prod.case, subst space_density)
    apply (simp add: emeasure_density cong: nn_integral_cong')
    apply (subst nn_integral_distr, simp add: state_measure_def, simp_all)
    done
  also from ρ have "... = +x. δ (merge V V' (x, ρ)) state_measure V Γ"
    by (intro nn_integral_cong, subst assms(1)) (simp_all add: merge_in_state_measure)
  also from ρ have "... = branch_prob (V,V',Γ,δ) ρ" by (simp add: branch_prob_altdef)
  also have "... = emeasure ?N (space ?N)" by (simp add: branch_prob_def)
  also from ρ have "...  1" by (intro subprob_space.emeasure_space_le_1 subprob_space_dens)
  finally show "emeasure ?M (space ?M)  1" .
qed (insert assms, auto simp: density_context_def)

end