(* Theory PDF_Semantics *)

(*
  Theory: PDF_Semantics.thy
  Author: Manuel Eberl

  Defines the expressions of the PDF language, their typing rules and their semantics,
  and proves a number of standard semantics-related theorems, such as substitution.
*)

theory PDF_Semantics
imports PDF_Values
begin

(* Measurability criterion for kernels given by densities.  If N is sigma-finite with a
   non-empty space, f is jointly measurable on the product of M and N, and for every x
   in the space of M the integral of f x over N is bounded by 1, then every
   density N (f x) is a sub-probability space and x |-> density N (f x) is a measurable
   map into subprob_algebra N. *)
lemma measurable_subprob_algebra_density:
  assumes "sigma_finite_measure N"
  assumes "space N  {}"
  assumes [measurable]: "case_prod f  borel_measurable (M M N)"
  assumes "x. x  space M  (+y. f x y N)  1"
  shows "(λx. density N (f x))  measurable M (subprob_algebra N)"
proof (rule measurable_subprob_algebra)
  fix x assume "x  space M"
  with assms show "subprob_space (density N (f x))"
    by (intro subprob_spaceI) (auto simp: emeasure_density cong: nn_integral_cong')
next
  (* The measure of a fixed set X under density N (f x) is the integral of
     f x * indicator X; that integral is measurable in x because f is jointly measurable. *)
  interpret sigma_finite_measure N by fact
  fix X assume X: "X  sets N"
  hence "(λx. (+y. f x y * indicator X y N))  borel_measurable M" by simp
  moreover from X and assms have
      "x. x  space M  emeasure (density N (f x)) X = (+y. f x y * indicator X y N)"
    by (simp add: emeasure_density)
  ultimately show "(λx. emeasure (density N (f x)) X)  borel_measurable M"
    by (simp only: cong: measurable_cong)
qed simp_all

section ‹Built-in Probability Distributions›

subsection ‹Bernoulli›

(* Density of the Bernoulli distribution with parameter p: p for True and 1 - p for False.
   For a parameter outside [0, 1] the density is 0 everywhere. *)
definition bernoulli_density :: "real  bool  ennreal" where
  "bernoulli_density p b = (if p  {0..1} then (if b then p else 1 - p) else 0)"

(* The Bernoulli distribution as a measure on BOOL values.  The parameter value is read
   with extract_real and the outcome with extract_bool. *)
definition bernoulli :: "val  val measure" where
  "bernoulli p = density BOOL (bernoulli_density (extract_real p) o extract_bool)"

(* Joint measurability of the density in (parameter, outcome). *)
lemma measurable_bernoulli_density[measurable]:
  "case_prod bernoulli_density  borel_measurable (borel M count_space UNIV)"
  unfolding bernoulli_density_def[abs_def] by measurable

(* bernoulli is a measurable kernel from REAL values into sub-probability measures on BOOL.
   The mass bound comes from summing bernoulli_density over both booleans. *)
lemma measurable_bernoulli[measurable]: "bernoulli  measurable REAL (subprob_algebra BOOL)"
  unfolding bernoulli_def[abs_def]
  by (auto intro!: measurable_subprob_algebra_density
           simp: measurable_split_conv nn_integral_BoolVal bernoulli_density_def
             ennreal_plus[symmetric]
           simp del: ennreal_plus)

subsection ‹Uniform›

(* Density of the continuous uniform distribution on [a, b]: 1 / (b - a) inside the
   interval when a < b, and 0 otherwise (so 0 everywhere for a degenerate interval). *)
definition uniform_real_density :: "real × real  real  ennreal" where
  "uniform_real_density  λ(a,b) x. ennreal (if a < b  x  {a..b} then inverse (b - a) else 0)"

(* Density of the discrete uniform distribution on the integers {a..b}: 1 / (b - a + 1)
   inside the range and 0 outside.  If the range is empty, no x lies in {a..b}, so the
   clamping performed by nat never affects a returned value. *)
definition uniform_int_density :: "int × int  int  ennreal" where
  "uniform_int_density  λ(a,b) x. (if x  {a..b} then inverse (nat (b - a + 1)) else 0)"

(* Joint measurability of the discrete uniform density.  This is immediate because all
   spaces involved are countable (pair_measure_countable). *)
lemma measurable_uniform_density_int[measurable]:
  "(case_prod uniform_int_density)
        borel_measurable ((count_space UNIV M count_space UNIV) M count_space UNIV)"
  by (simp add: pair_measure_countable)

(* Joint measurability of the continuous uniform density.  The density is first rewritten
   in terms of fst/snd so that the measurability prover can decompose it. *)
lemma measurable_uniform_density_real[measurable]:
  "(case_prod uniform_real_density)  borel_measurable (borel M borel)"
proof-
  have "(case_prod uniform_real_density) =
            (λx. uniform_real_density (fst (fst x), snd (fst x)) (snd x))"
      by (rule ext) (simp split: prod.split)
  also have "...  borel_measurable (borel M borel)"
      unfolding uniform_real_density_def
      by (simp only: prod.case) (simp add: borel_prod[symmetric])
  finally show ?thesis .
qed

(* Discrete and continuous uniform distributions, parameterised by a pair of bounds (l, u).
   NOTE(review): the last argument of map_int_pair / map_real_pair is presumably the result
   for parameter values that are not pairs of the expected kind -- confirm in PDF_Values. *)
definition uniform_int :: "val  val measure" where
  "uniform_int = map_int_pair (λl u. density INTEG (uniform_int_density (l,u) o extract_int)) (λ_. undefined)"

definition uniform_real :: "val  val measure" where
  "uniform_real = map_real_pair (λl u. density REAL (uniform_real_density (l,u) o extract_real)) (λ_. undefined)"

(* Turns a conditional on a bounded index into a product with the indicator of {a .. b}.
   Used below to express the uniform densities as set integrals. *)
lemma if_bounded: "(if a  i  i  b then v else 0) = (v::real) * indicator {a .. b} i"
  by auto

(* uniform_int is a measurable kernel from integer pairs into sub-probability measures on
   INTEG.  The total-mass bound is proved by a case split on the bounds fst x and snd x. *)
lemma measurable_uniform_int[measurable]:
  "uniform_int  measurable (PRODUCT INTEG INTEG) (subprob_algebra INTEG)"
  unfolding uniform_int_def
proof (rule measurable measurable_subprob_algebra_density)+
  fix x :: "int × int"

  show "integralN INTEG (uniform_int_density (fst x, snd x)  extract_int)  1"
  proof cases
    assume "fst x  snd x" then show ?thesis
      by (cases x)
         (simp add: uniform_int_density_def comp_def nn_integral_IntVal nn_integral_cmult
                    nn_integral_set_ennreal[symmetric] ennreal_of_nat_eq_real_of_nat
                    if_bounded[where 'a=int] ennreal_mult[symmetric]
               del: ennreal_plus)
  qed (simp add: uniform_int_density_def comp_def split_beta' if_bounded[where 'a=int])
qed (auto simp: comp_def)

(* density M f depends only on the values of f on the space of M. *)
lemma density_cong':
  "(x. x  space M  f x = g x)  density M f = density M g"
  unfolding density_def
  by (auto dest: sets.sets_into_space intro!: nn_integral_cong measure_of_eq)

(* uniform_real is a measurable kernel from real pairs into sub-probability measures on REAL.
   For l < u the integral is computed through the indicator form (if_bounded); otherwise the
   density is 0. *)
lemma measurable_uniform_real[measurable]:
  "uniform_real  measurable (PRODUCT REAL REAL) (subprob_algebra REAL)"
  unfolding uniform_real_def
proof (rule measurable measurable_subprob_algebra_density)+
  fix x :: "real × real"
  obtain l u where [simp]: "x = (l, u)"
    by (cases x) auto
  show "(+y. (uniform_real_density (fst x, snd x) o extract_real) y REAL)  1"
  proof cases
    assume "l < u" then show ?thesis
      by (simp add: nn_integral_RealVal uniform_real_density_def if_bounded nn_integral_cmult
                    nn_integral_set_ennreal[symmetric] ennreal_mult[symmetric])
  qed (simp add: uniform_real_density_def comp_def)
qed (auto simp: comp_def borel_prod)

subsection ‹Gaussian›

(* Density of the normal distribution with mean m and standard deviation s when s > 0,
   i.e. exp (-(x - m)^2 / (2 s^2)) / sqrt (2 pi s^2); 0 everywhere otherwise. *)
definition gaussian_density :: "real × real  real  ennreal" where
  "gaussian_density 
      λ(m,s) x. (if s > 0 then exp (-(x - m)2 / (2 * s2)) / sqrt (2 * pi * s2) else 0)"

(* Joint measurability of the Gaussian density, proved via an equivalent formulation
   in terms of fst/snd. *)
lemma measurable_gaussian_density[measurable]:
  "case_prod gaussian_density  borel_measurable (borel M borel)"
proof-
  have "case_prod gaussian_density =
              (λ(x,y). (if snd x > 0 then exp (-((y - fst x)^2) / (2 * snd x^2)) /
                             sqrt (2 * pi * snd x^2) else 0))"
    unfolding gaussian_density_def by (intro ext) (simp split: prod.split)
  also have "...  borel_measurable (borel M borel)"
    by (simp add: borel_prod[symmetric])
  finally show ?thesis .
qed

(* The normal distribution as a measure on REAL values, parameterised by the pair
   (mean, standard deviation).
   NOTE(review): unlike uniform_int and uniform_real, the fallback argument of
   map_real_pair here is the plain constant undefined rather than (lambda _. undefined).
   It is left unchanged because gaussian_def is unfolded in later proofs; confirm whether
   the difference is intentional. *)
definition gaussian :: "val  val measure" where
  "gaussian = map_real_pair (λm s. density REAL (gaussian_density (m,s) o extract_real)) undefined"

(* gaussian is a measurable kernel from real pairs into sub-probability measures on REAL.
   For a positive standard deviation the density coincides with normal_density and
   integrates to exactly 1; otherwise it is 0. *)
lemma measurable_gaussian[measurable]: "gaussian  measurable (PRODUCT REAL REAL) (subprob_algebra REAL)"
  unfolding gaussian_def
proof (rule measurable measurable_subprob_algebra_density)+
  fix x :: "real × real"
  show "integralN (stock_measure REAL) (gaussian_density (fst x, snd x)  extract_real)  1"
  proof cases
    assume "snd x > 0"
    then have "integralN lborel (gaussian_density x) = (+y. normal_density (fst x) (snd x) y lborel)"
      by (auto simp add: gaussian_density_def normal_density_def split_beta' intro!: nn_integral_cong)
    also have " = 1"
      using snd x > 0
      by (subst nn_integral_eq_integral) (auto intro!: normal_density_nonneg)
    finally show ?thesis
      by (cases x) (simp add: nn_integral_RealVal comp_def)
  next
    assume "¬ snd x > 0" then show ?thesis
      by (cases x)
         (simp add: nn_integral_RealVal comp_def gaussian_density_def zero_ennreal_def[symmetric])
  qed
qed (auto simp: comp_def borel_prod)

subsection ‹Poisson›

(* Poisson probability of k for the given rate, taken from poisson_pmf.  The indicator
   restricts it to rate > 0 and k >= 0; the density is 0 otherwise. *)
definition poisson_density' :: "real  int  ennreal" where
  "poisson_density' rate k = pmf (poisson_pmf rate) (nat k) * indicator ({0 <..} × {0..}) (rate, k)"

(* Joint measurability of the Poisson density, via the closed form
   rate^k / k! * exp (-rate) restricted by the same indicator. *)
lemma measurable_poisson_density'[measurable]:
    "case_prod poisson_density'  borel_measurable (borel M count_space UNIV)"
proof -
  have "case_prod poisson_density' =
    (λ(rate, k). rate ^ nat k / real_of_nat (fact (nat k)) * exp (-rate) * indicator ({0 <..} × {0..}) (rate, k))"
    by (auto split: split_indicator simp: fun_eq_iff poisson_density'_def)
  then show ?thesis
    by simp
qed

(* The Poisson distribution as a measure on INTEG values.  The rate is read with
   extract_real and the outcome with extract_int. *)
definition poisson :: "val  val measure" where
  "poisson rate = density INTEG (poisson_density' (extract_real rate) o extract_int)"

(* poisson is a measurable kernel from REAL values into sub-probability measures on INTEG.
   For r > 0 the density sums to the total mass of poisson_pmf r, which is 1; otherwise
   it is 0. *)
lemma measurable_poisson[measurable]: "poisson  measurable REAL (subprob_algebra INTEG)"
  unfolding poisson_def[abs_def]
proof (rule measurable measurable_subprob_algebra_density)+
  fix r :: real
  have [simp]: "nat ` {0..} = UNIV"
    by (auto simp: image_iff intro!: bexI[of _ "int x" for x])

  (* Case r > 0: reindex the sum over the non-negative integers as a sum of pmf values. *)
  { assume "0 < r"
    then have "(+ x. ennreal (r ^ nat x * exp (- r) * indicator ({0<..} × {0..}) (r, x) / (fact (nat x))) count_space UNIV)
      = (+ x. ennreal (pmf (poisson_pmf r) (nat x)) count_space {0 ..})"
      by (auto intro!: nn_integral_cong simp add: nn_integral_count_space_indicator split: split_indicator)
    also have " = 1"
      using measure_pmf.emeasure_space_1[of "poisson_pmf r"]
      by (subst nn_integral_pmf') (auto simp: inj_on_def)
    finally have "(+ x. ennreal (r ^ nat x * exp (- r) * indicator ({0<..} × {0..}) (r, x) / (fact (nat x))) count_space UNIV) = 1"
      . }
  then show "integralN INTEG (poisson_density' r  extract_int)  1"
    by (cases "0 < r")
       (auto simp: nn_integral_IntVal poisson_density'_def zero_ennreal_def[symmetric])
qed (auto simp: comp_def)

section ‹Source Language Syntax and Semantics›

subsection ‹Expressions›

(* Types equipped with a set of free variables (variable names of type vname). *)
class expr = fixes free_vars :: "'a  vname set"

(* Syntax of the source language: the built-in distributions, the primitive operators,
   and expressions.  Variables are de Bruijn indices: in LetVar e1 e2 the value of e1
   becomes variable 0 of e2 and the other variables are shifted up by one (cf.
   free_vars_expr and expr_sem).  Val embeds a constant value; Fail t denotes the null
   measure of type t (cf. expr_sem). *)
datatype pdf_dist = Bernoulli | UniformInt | UniformReal | Poisson | Gaussian

datatype pdf_operator = Fst | Snd | Add | Mult | Minus | Less | Equals | And | Not | Or | Pow |
                        Sqrt | Exp | Ln | Fact | Inverse | Pi | Cast pdf_type

datatype expr =
      Var vname
    | Val val
    | LetVar expr expr (LET _ IN _› [0, 60] 61)
    | Operator pdf_operator expr (infixl $$ 999)
    | Pair expr expr  (<_ ,  _>  [0, 60] 1000)
    | Random pdf_dist expr
    | IfThenElse expr expr expr (IF _ THEN _ ELSE _› [0, 0, 70] 71)
    | Fail pdf_type

(* A type environment assigns a type to every variable name. *)
type_synonym tyenv = "vname  pdf_type"

(* Free variables of an expression.  In LetVar e1 e2, variable 0 of the body is bound,
   so the free variables of e2 are shifted down by taking the preimage under Suc. *)
instantiation expr :: expr
begin

primrec free_vars_expr :: "expr  vname set" where
  "free_vars_expr (Var x) = {x}"
| "free_vars_expr (Val _) = {}"
| "free_vars_expr (LetVar e1 e2) = free_vars_expr e1  Suc -` free_vars_expr e2"
| "free_vars_expr (Operator _ e) = free_vars_expr e"
| "free_vars_expr (<e1, e2>) = free_vars_expr e1  free_vars_expr e2"
| "free_vars_expr (Random _ e) = free_vars_expr e"
| "free_vars_expr (IF b THEN e1 ELSE e2) =
       free_vars_expr b  free_vars_expr e1  free_vars_expr e2"
| "free_vars_expr (Fail _) = {}"

instance ..
end

(* Alternative formulation of free_vars for code generation.  The preimage under Suc is
   replaced by removing 0 and subtracting 1, presumably because Suc -` is not directly
   executable on sets.  The lemma below proves both formulations equal and registers the
   result as the code equation for free_vars on expr. *)
primrec free_vars_expr_code :: "expr  vname set" where
  "free_vars_expr_code (Var x) = {x}"
| "free_vars_expr_code (Val _) = {}"
| "free_vars_expr_code (LetVar e1 e2) =
      free_vars_expr_code e1  (λx. x - 1) ` (free_vars_expr_code e2 - {0})"
| "free_vars_expr_code (Operator _ e) = free_vars_expr_code e"
| "free_vars_expr_code (<e1, e2>) = free_vars_expr_code e1  free_vars_expr_code e2"
| "free_vars_expr_code (Random _ e) = free_vars_expr_code e"
| "free_vars_expr_code (IF b THEN e1 ELSE e2) =
       free_vars_expr_code b  free_vars_expr_code e1  free_vars_expr_code e2"
| "free_vars_expr_code (Fail _) = {}"

lemma free_vars_expr_code[code]:
  "free_vars (e::expr) = free_vars_expr_code e"
proof-
  have "A. Suc -` A = (λx. x - 1) ` (A - {0})" by force
  thus ?thesis by (induction e) simp_all
qed


(* Type of the parameter expected by each built-in distribution. *)
primrec dist_param_type where
  "dist_param_type Bernoulli = REAL"
| "dist_param_type Poisson = REAL"
| "dist_param_type Gaussian = PRODUCT REAL REAL"
| "dist_param_type UniformInt = PRODUCT INTEG INTEG"
| "dist_param_type UniformReal = PRODUCT REAL REAL"

(* Type of the values produced by each built-in distribution. *)
primrec dist_result_type where
  "dist_result_type Bernoulli = BOOL"
| "dist_result_type UniformInt = INTEG"
| "dist_result_type UniformReal = REAL"
| "dist_result_type Poisson = INTEG"
| "dist_result_type Gaussian = REAL"

(* The measure kernel of each built-in distribution. *)
primrec dist_measure :: "pdf_dist  val  val measure" where
  "dist_measure Bernoulli = bernoulli"
| "dist_measure UniformInt = uniform_int"
| "dist_measure UniformReal = uniform_real"
| "dist_measure Poisson = poisson"
| "dist_measure Gaussian = gaussian"

(* Each kernel is measurable from its parameter type into sub-probability measures on its
   result type; this follows from the per-distribution measurability lemmas. *)
lemma measurable_dist_measure[measurable]:
  "dist_measure d  measurable (dist_param_type d) (subprob_algebra (dist_result_type d))"
  by (cases d) simp_all

(* For a parameter of the expected type, the measure produced by a built-in distribution
   has the sigma-algebra and the space of the stock measure of its result type. *)
lemma sets_dist_measure[simp]:
  "val_type x = dist_param_type dst 
       sets (dist_measure dst x) = sets (stock_measure (dist_result_type dst))"
  by (rule sets_kernel[OF measurable_dist_measure]) simp

lemma space_dist_measure[simp]:
  "val_type x = dist_param_type dst 
       space (dist_measure dst x) = type_universe (dist_result_type dst)"
  by (subst space_stock_measure[symmetric]) (intro sets_eq_imp_space_eq sets_dist_measure)

(* Density function of each built-in distribution, taking the parameter value and the
   outcome value. *)
primrec dist_dens :: "pdf_dist  val  val  ennreal" where
  "dist_dens Bernoulli x y = bernoulli_density (extract_real x) (extract_bool y)"
| "dist_dens UniformInt x y = uniform_int_density (extract_int_pair x) (extract_int y)"
| "dist_dens UniformReal x y = uniform_real_density (extract_real_pair x) (extract_real y)"
| "dist_dens Gaussian x y = gaussian_density (extract_real_pair x) (extract_real y)"
| "dist_dens Poisson x y = poisson_density' (extract_real x) (extract_int y)"

(* Composed with measurable parameter and outcome functions, the density is Borel
   measurable. *)
lemma measurable_dist_dens[measurable]:
    assumes "f  measurable M (stock_measure (dist_param_type dst))" (is "_  measurable M ?N")
    assumes "g  measurable M (stock_measure (dist_result_type dst))" (is "_  measurable M ?R")
    shows "(λx. dist_dens dst (f x) (g x))  borel_measurable M"
  apply (rule measurable_Pair_compose_split[of "dist_dens dst", OF _ assms])
  apply (subst dist_dens_def, cases dst, simp_all)
  done

(* For a well-typed parameter, each built-in distribution has density dist_dens with
   respect to the stock measure of its result type. *)
lemma dist_measure_has_density:
  "v  type_universe (dist_param_type dst) 
       has_density (dist_measure dst v) (stock_measure (dist_result_type dst)) (dist_dens dst v)"
proof (intro has_densityI)
  fix v assume "v  type_universe (dist_param_type dst)"
  thus "dist_measure dst v = density (stock_measure (dist_result_type dst)) (dist_dens dst v)"
    by (cases dst)
       (auto simp: bernoulli_def uniform_int_def uniform_real_def poisson_def gaussian_def
             intro!: density_cong' elim!: PROD_E REAL_E INTEG_E)
qed simp_all

(* For a well-typed parameter, every built-in distribution is a sub-probability space... *)
lemma subprob_space_dist_measure:
    "v  type_universe (dist_param_type dst)  subprob_space (dist_measure dst v)"
  using subprob_space_kernel[OF measurable_dist_measure, of v dst] by simp

(* ...and hence has dist_dens as a sub-probability density. *)
lemma dist_measure_has_subprob_density:
  "v  type_universe (dist_param_type dst) 
       has_subprob_density (dist_measure dst v) (stock_measure (dist_result_type dst)) (dist_dens dst v)"
  unfolding has_subprob_density_def
  by (auto intro: subprob_space_dist_measure dist_measure_has_density)

(* The density of a built-in distribution integrates to at most 1 over the result space.
   The integral is the measure of the whole space under the density measure, which
   equals dist_measure dst v and is therefore a sub-probability measure. *)
lemma dist_dens_integral_space:
  assumes "v  type_universe (dist_param_type dst)"
  shows "(+u. dist_dens dst v u stock_measure (dist_result_type dst))  1"
proof-
  let ?M = "density (stock_measure (dist_result_type dst)) (dist_dens dst v)"
  from assms have "(+u. dist_dens dst v u stock_measure (dist_result_type dst)) =
                       emeasure ?M (space ?M)"
    by (subst space_density, subst emeasure_density)
       (auto intro!: measurable_dist_dens cong: nn_integral_cong')
  also have "?M = dist_measure dst v" using dist_measure_has_density[OF assms]
    by (auto dest: has_densityD)
  also from assms have "emeasure ... (space ...)  1"
    by (intro subprob_space.emeasure_space_le_1 subprob_space_dist_measure)
  finally show ?thesis .
qed


subsection ‹Typing›

(* Typing of the primitive operators: given the argument type, returns the result type,
   or None if the operator is not applicable.  Binary operators take a PRODUCT argument.
   Equals requires both components to have the same type.  Cast supports only the listed
   BOOL/INTEG/REAL conversions.  Pi takes a UNIT argument.  Pow takes an INTEG exponent
   with either a REAL or an INTEG base. *)
primrec op_type :: "pdf_operator  pdf_type  pdf_type option" where
  "op_type Add x =
      (case x of
         PRODUCT INTEG INTEG  Some INTEG
       | PRODUCT REAL REAL  Some REAL
       | _  None)"
| "op_type Mult x =
      (case x of
         PRODUCT INTEG INTEG  Some INTEG
       | PRODUCT REAL REAL  Some REAL
       | _  None)"
| "op_type Minus x =
      (case x of
         INTEG  Some INTEG
       | REAL  Some REAL
       | _  None)"
| "op_type Equals x =
      (case x of
         PRODUCT t1 t2  if t1 = t2 then Some BOOL else None
       | _  None)"
| "op_type Less x =
      (case x of
         PRODUCT INTEG INTEG  Some BOOL
       | PRODUCT REAL REAL  Some BOOL
       | _  None)"
| "op_type (Cast t) x =
      (case (x, t) of
         (BOOL, INTEG)  Some INTEG
       | (BOOL, REAL)  Some REAL
       | (INTEG, REAL)  Some REAL
       | (REAL, INTEG)  Some INTEG
       | _  None)"
| "op_type Or x = (case x of PRODUCT BOOL BOOL  Some BOOL | _  None)"
| "op_type And x = (case x of PRODUCT BOOL BOOL  Some BOOL | _  None)"
| "op_type Not x = (case x of BOOL  Some BOOL | _  None)"
| "op_type Inverse x = (case x of REAL  Some REAL | _  None)"
| "op_type Fact x = (case x of INTEG  Some INTEG | _  None)"
| "op_type Sqrt x = (case x of REAL  Some REAL | _  None)"
| "op_type Exp x = (case x of REAL  Some REAL | _  None)"
| "op_type Ln x = (case x of REAL  Some REAL | _  None)"
| "op_type Pi x = (case x of UNIT  Some REAL | _  None)"
| "op_type Pow x = (case x of
                      PRODUCT REAL INTEG  Some REAL
                    | PRODUCT INTEG INTEG  Some INTEG
                    | _  None)"
| "op_type Fst x = (case x of PRODUCT t _   Some t | _  None)"
| "op_type Snd x = (case x of PRODUCT _ t   Some t | _  None)"


subsection ‹Typing Judgement›

(* Infix notation for case_nat x f: extends the environment f with x as the new
   variable 0, shifting the existing variables up by one. *)
abbreviation (input) de_bruijn_insert (infixr  65) where
  "de_bruijn_insert x f  case_nat x f"

(* The typing judgement: expression e has type t in environment Gamma.  In et_let the
   type t1 of the bound expression becomes variable 0 of the body's environment.  Fail t
   is accepted at its annotated type t. *)
inductive expr_typing :: "tyenv  expr  pdf_type  bool" ((1_/ / (_ :/ _)) [50,0,50] 50) where
  et_var:  "Γ  Var x : Γ x"
| et_val:  "Γ  Val v : val_type v"
| et_let:  "Γ  e1 : t1  t1  Γ  e2 : t2  Γ  LetVar e1 e2 : t2"
| et_op:   "Γ  e : t  op_type oper t = Some t'  Γ  Operator oper e : t'"
| et_pair: "Γ  e1 : t1   Γ  e2 : t2   Γ  <e1, e2> : PRODUCT t1 t2"
| et_rand: "Γ  e : dist_param_type dst  Γ  Random dst e :  dist_result_type dst"
| et_if:   "Γ  b : BOOL  Γ  e1 : t  Γ  e2 : t  Γ  IF b THEN e1 ELSE e2 : t"
| et_fail: "Γ  Fail t : t"

(* Typing depends only on the types the environment assigns to the free variables.
   The only non-trivial case is et_let, where the environments are extended by case_nat. *)
lemma expr_typing_cong':
  "Γ  e : t  (x. x  free_vars e  Γ x = Γ' x)  Γ'  e : t"
proof (induction arbitrary: Γ' rule: expr_typing.induct)
  case (et_let Γ e1 t1 e2 t2 Γ')
  have "Γ'  e1 : t1" using et_let.prems by (intro et_let.IH(1)) auto
  moreover have "case_nat t1 Γ'  e2 : t2"
    using et_let.prems by (intro et_let.IH(2)) (auto split: nat.split)
  ultimately show ?case by (auto intro!: expr_typing.intros)
qed (auto intro!: expr_typing.intros)

(* Symmetric version of the above, stated for two environments agreeing on free_vars e. *)
lemma expr_typing_cong:
  "(x. x  free_vars e  Γ x = Γ' x)  Γ  e : t  Γ'  e : t"
  by (intro iffI) (simp_all add: expr_typing_cong')

(* Elimination rules for the typing judgement, one per expression constructor. *)
inductive_cases expr_typing_valE[elim]:  "Γ  Val v : t"
inductive_cases expr_typing_varE[elim]:  "Γ  Var x : t"
inductive_cases expr_typing_letE[elim]:  "Γ  LetVar e1 e2 : t"
inductive_cases expr_typing_ifE[elim]:  "Γ  IfThenElse b e1 e2 : t"
inductive_cases expr_typing_opE[elim]:   "Γ  Operator oper e : t"
inductive_cases expr_typing_pairE[elim]: "Γ  <e1, e2> : t"
inductive_cases expr_typing_randE[elim]: "Γ  Random dst e : t"
inductive_cases expr_typing_failE[elim]: "Γ  Fail t : t'"

(* Types are unique.  After the rule induction, the apply steps below handle the cases in
   the order of the rules: et_var, et_val, et_let, et_op, et_pair, et_rand, et_if, et_fail. *)
lemma expr_typing_unique:
  "Γ  e : t  Γ  e : t'  t = t'"
  apply (induction arbitrary: t' rule: expr_typing.induct)
  apply blast
  apply blast
  apply (erule expr_typing_letE, blast)
  apply (erule expr_typing_opE, simp)
  apply (erule expr_typing_pairE, blast)
  apply (erule expr_typing_randE, blast)
  apply (erule expr_typing_ifE, blast)
  apply blast
  done

(* Type inference: computes the type of an expression, or None if it is ill-typed.
   Agreement with the typing judgement is proved in expr_type_Some_iff. *)
fun expr_type :: "tyenv  expr  pdf_type option" where
  "expr_type Γ (Var x) = Some (Γ x)"
| "expr_type Γ (Val v) = Some (val_type v)"
| "expr_type Γ (LetVar e1 e2) =
       (case expr_type Γ e1 of
          Some t  expr_type (case_nat t Γ) e2
        | None  None)"
| "expr_type Γ (Operator oper e) =
       (case expr_type Γ e of Some t  op_type oper t | None  None)"
| "expr_type Γ (<e1, e2>) =
       (case (expr_type Γ e1, expr_type Γ e2) of
          (Some t1, Some t2)  Some (PRODUCT t1 t2)
        |  _  None)"
| "expr_type Γ (Random dst e) =
       (if expr_type Γ e = Some (dist_param_type dst) then
           Some (dist_result_type dst)
        else None)"
| "expr_type Γ (IF b THEN e1 ELSE e2) =
       (if expr_type Γ b = Some BOOL then
          (case (expr_type Γ e1, expr_type Γ e2) of
             (Some t, Some t')  if t = t' then Some t else None
           | _  None) else None)"
| "expr_type Γ (Fail t) = Some t"

(* Type inference agrees with the typing judgement.  The first induction (over e) proves
   one direction and the second (over the typing rules) proves the other.  The symmetric
   form is registered as code_unfold so that the judgement is decided by evaluating
   expr_type. *)
lemma expr_type_Some_iff: "expr_type Γ e = Some t  Γ  e : t"
  apply rule
  apply (induction e arbitrary: Γ t,
         auto intro!: expr_typing.intros split: option.split_asm if_split_asm) []
  apply (induction rule: expr_typing.induct, auto simp del: fun_upd_apply)
  done

lemmas expr_typing_code[code_unfold] = expr_type_Some_iff[symmetric]


subsubsection ‹Countable types›

(* Types whose universe is countable: UNIT, BOOL, INTEG, and products of countable types.
   REAL, and any product containing it, is excluded. *)
primrec countable_type :: "pdf_type  bool" where
  "countable_type UNIT = True"
| "countable_type BOOL = True"
| "countable_type INTEG = True"
| "countable_type REAL = False"
| "countable_type (PRODUCT t1 t2) = (countable_type t1  countable_type t2)"

(* The space of the stock measure of a countable type is countable. *)
lemma countable_type_countable[dest]:
    "countable_type t  countable (space (stock_measure t))"
  by (induction t)
     (auto simp: pair_measure_countable space_embed_measure space_pair_measure stock_measure.simps)

(* For a countable type, the stock measure is the counting measure on the type universe.
   In the PRODUCT case, the pair of counting measures is collapsed with
   pair_measure_countable and then pushed through embed_measure. *)
lemma countable_type_imp_count_space:
  "countable_type t  stock_measure t = count_space (type_universe t)"
proof (subst space_stock_measure[symmetric], induction t)
  case (PRODUCT t1 t2)
    hence countable: "countable_type t1" "countable_type t2" by simp_all
    note A = PRODUCT.IH(1)[OF countable(1)] and B = PRODUCT.IH(2)[OF countable(2)]
    show "stock_measure (PRODUCT t1 t2) = count_space (space (stock_measure (PRODUCT t1 t2)))"
      apply (subst (1 2) stock_measure.simps)
      apply (subst (1 2) A, subst (1 2) B)
      apply (subst (1 2) pair_measure_countable)
      apply (auto intro: countable_type_countable simp: countable simp del: space_stock_measure) [2]
      apply (subst (1 2) embed_measure_count_space, force intro: injI)
      apply simp
      done
qed (simp_all add: stock_measure.simps)

(* For a value of a countable type, return_val v coincides with the measure whose density
   with respect to the stock measure is the indicator of {v}. *)
lemma return_val_countable:
  assumes "countable_type (val_type v)"
  shows "return_val v = density (stock_measure (val_type v)) (indicator {v})" (is "?M1 = ?M2")
proof (rule measure_eqI)
  let ?M3 = "count_space (type_universe (val_type v))"
  fix X assume asm: "X  ?M1"
  with assms have "emeasure ?M2 X = + x. indicator {v} x * indicator X x
                                              count_space (type_universe (val_type v))"
    by (simp add: return_val_def emeasure_density countable_type_imp_count_space)
  also have "(λx. indicator {v} x * indicator X x :: ennreal) = (λx. indicator (X  {v}) x)"
    by (rule ext, subst Int_commute) (simp split: split_indicator)
  also have "nn_integral ?M3 ... = emeasure ?M3 (X  {v})"
    by (subst nn_integral_indicator[symmetric]) auto
  also from asm have "... = emeasure ?M1 X" by (auto simp: return_val_def split: split_indicator)
  finally show "emeasure ?M1 X = emeasure ?M2 X" ..
qed (simp add: return_val_def)



subsection ‹Semantics›

(* Conversion of a boolean to 1 (True) or 0 (False), used by Cast; measurable because
   its domain carries the counting measure. *)
definition bool_to_int :: "bool  int" where
  "bool_to_int b = (if b then 1 else 0)"

lemma measurable_bool_to_int[measurable]:
  "bool_to_int  measurable (count_space UNIV) (count_space UNIV)"
  by (rule measurable_count_space)

(* Real-valued counterpart of bool_to_int. *)
definition bool_to_real :: "bool  real" where
  "bool_to_real b = (if b then 1 else 0)"

lemma measurable_bool_to_real[measurable]:
  "bool_to_real  borel_measurable (count_space UNIV)"
  by (rule borel_measurable_count_space)

(* Total variant of the logarithm used by the Ln operator: ln x for x > 0, and 0 otherwise. *)
definition safe_ln :: "real  real" where
  "safe_ln x = (if x > 0 then ln x else 0)"

lemma safe_ln_gt_0[simp]: "x > 0  safe_ln x = ln x"
  by (simp add: safe_ln_def)

lemma borel_measurable_safe_ln[measurable]: "safe_ln  borel_measurable borel"
  unfolding safe_ln_def[abs_def] by simp


(* Total variant of the square root used by the Sqrt operator: sqrt x for non-negative x,
   and 0 for negative arguments. *)
definition safe_sqrt :: "real  real" where
  "safe_sqrt x = (if x  0 then sqrt x else 0)"

lemma safe_sqrt_ge_0[simp]: "x  0  safe_sqrt x = sqrt x"
  by (simp add: safe_sqrt_def)

lemma borel_measurable_safe_sqrt[measurable]: "safe_sqrt  borel_measurable borel"
  unfolding safe_sqrt_def[abs_def] by simp


(* Semantics of the primitive operators on values.  Edge cases visible below:
   - Fact maps an integer i to fact (nat i), so negative arguments give fact 0;
   - Pow with a negative exponent yields 0;
   - Cast from REAL to INTEG rounds with floor;
   - Sqrt and Ln use the total variants safe_sqrt and safe_ln;
   - Fst and Snd project through extract_pair. *)
fun op_sem :: "pdf_operator  val  val" where
  "op_sem Add = lift_RealIntVal2 (+) (+)"
| "op_sem Mult = lift_RealIntVal2 (*) (*)"
| "op_sem Minus = lift_RealIntVal uminus uminus"
| "op_sem Equals = (λ <|v1, v2|>  BoolVal (v1 = v2))"
| "op_sem Less = lift_Comp (<) (<)"
| "op_sem Or = (λ <|BoolVal a, BoolVal b|>  BoolVal (a  b))"
| "op_sem And = (λ <|BoolVal a, BoolVal b|>  BoolVal (a  b))"
| "op_sem Not = (λ BoolVal a  BoolVal (¬a))"
| "op_sem (Cast t) = (case t of
                        INTEG  (λ BoolVal b  IntVal (bool_to_int b)
                                  | RealVal r  IntVal (floor r))
                      | REAL   (λ BoolVal b  RealVal (bool_to_real b)
                                  | IntVal i  RealVal (real_of_int i)))"
| "op_sem Inverse = lift_RealVal inverse"
| "op_sem Fact = lift_IntVal (λi::int. fact (nat i))"
| "op_sem Sqrt = lift_RealVal safe_sqrt"
| "op_sem Exp = lift_RealVal exp"
| "op_sem Ln = lift_RealVal safe_ln"
| "op_sem Pi = (λ_. RealVal pi)"
| "op_sem Pow = (λ <|RealVal x, IntVal n|>  if n < 0 then RealVal 0 else RealVal (x ^ nat n)
                 | <|IntVal x, IntVal n|>  if n < 0 then IntVal 0 else IntVal (x ^ nat n))"
| "op_sem Fst = fst  extract_pair"
| "op_sem Snd = snd  extract_pair"


text ‹The semantics of expressions. It assumes that the given expression is well-typed.›

(* Semantics of an expression in a state, as a measure on values, written monadically
   with bind.  LET evaluates e1 and runs e2 with that value as variable 0.  IF takes the
   first branch exactly when the condition evaluates to TRUE.  Random feeds the parameter's
   value to the distribution's kernel.  Fail t denotes the null measure on the stock
   measure of t. *)
primrec expr_sem :: "state  expr  val measure" where
  "expr_sem σ (Var x) = return_val (σ x)"
| "expr_sem σ (Val v) = return_val v"
| "expr_sem σ (LET e1 IN e2) =
      do {
        v  expr_sem σ e1;
        expr_sem (v  σ) e2
      }"
| "expr_sem σ (oper $$ e) =
      do {
        x  expr_sem σ e;
        return_val (op_sem oper x)
      }"
| "expr_sem σ <v, w> =
      do {
        x  expr_sem σ v;
        y  expr_sem σ w;
        return_val <|x, y|>
      }"
| "expr_sem σ (IF b THEN e1 ELSE e2) =
     do {
       b'  expr_sem σ b;
       if b' = TRUE then expr_sem σ e1 else expr_sem σ e2
     }"
| "expr_sem σ (Random dst e) =
     do {
       x  expr_sem σ e;
       dist_measure dst x
     }"
| "expr_sem σ (Fail t) = null_measure (stock_measure t)"

(* A pair of two variables evaluates to return_val of the pair of their values. *)
lemma expr_sem_pair_vars: "expr_sem σ <Var x, Var y> = return_val <|σ x, σ y|>"
  by (simp add: return_val_def bind_return[where N="PRODUCT (val_type (σ x)) (val_type (σ y))"]
           cong: bind_cong_simp)

text ‹
  Well-typed expressions produce a result in the measure space that corresponds to their type.
›

(* Type preservation for operators: the result of op_sem has the type predicted by op_type. *)
lemma op_sem_val_type:
    "op_type oper (val_type v) = Some t'  val_type (op_sem oper v) = t'"
  by (cases oper) (auto split: val.split if_split_asm pdf_type.split_asm
                        simp: lift_RealIntVal_def lift_Comp_def
                              lift_IntVal_def lift_RealVal_def lift_RealIntVal2_def
                        elim!: PROD_E INTEG_E REAL_E)

(* A well-typed expression, evaluated in a state whose free variables carry the types
   given by the environment, yields a measure on the sigma-algebra of the stock measure
   of its type.  Proof by induction over the typing judgement.  The cases built with bind
   use that the space of each sub-result is non-empty (witnessed by a SOME-chosen element),
   so that bind_nonempty applies. *)
lemma sets_expr_sem:
  "Γ  w : t  (x  free_vars w. val_type (σ x) = Γ x) 
       sets (expr_sem σ w) = sets (stock_measure t)"
proof (induction arbitrary: σ rule: expr_typing.induct)
  case (et_var Γ x σ)
  thus ?case by (simp add: return_val_def)
next
  case (et_val Γ v σ)
  thus ?case by (simp add: return_val_def)
next
  case (et_let Γ e1 t1 e2 t2 σ)
  hence "sets (expr_sem σ e1) = sets (stock_measure t1)" by simp
  from sets_eq_imp_space_eq[OF this]
    have A: "space (expr_sem σ e1) = type_universe t1" by simp
  hence B: "(SOME x. x  space (expr_sem σ e1))  space (expr_sem σ e1)" (is "?v  _")
    unfolding some_in_eq by simp
  with A et_let have "sets (expr_sem (case_nat ?v σ) e2) = sets (stock_measure t2)"
    by (intro et_let.IH(2)) (auto split: nat.split)
  with B show "sets (expr_sem σ (LetVar e1 e2)) = sets (stock_measure t2)"
    by (subst expr_sem.simps, subst bind_nonempty) auto
next
  case (et_op Γ e t oper t' σ)
  from et_op.IH[of σ] and et_op.prems
      have [simp]: "sets (expr_sem σ e) = sets (stock_measure t)" by simp
  from sets_eq_imp_space_eq[OF this]
      have [simp]: "space (expr_sem σ e) = type_universe t" by simp
  have "(SOME x. x  space (expr_sem σ e))  space (expr_sem σ e)"
    unfolding some_in_eq by simp
  with et_op show ?case by (simp add: bind_nonempty return_val_def op_sem_val_type)
next
  case (et_pair Γ e1 t1 e2 t2 σ)
  hence [simp]: "space (expr_sem σ e1) = type_universe t1"
                "space (expr_sem σ e2) = type_universe t2"
    by (simp_all add: sets_eq_imp_space_eq)
  have "(SOME x. x  space (expr_sem σ e1))  space (expr_sem σ e1)"
       "(SOME x. x  space (expr_sem σ e2))  space (expr_sem σ e2)"
    unfolding some_in_eq by simp_all
  with et_pair.hyps show ?case by (simp add: bind_nonempty return_val_def)
next
  case (et_rand Γ e dst σ)
  from et_rand.IH[of σ] et_rand.prems
  have "sets (expr_sem σ e) = sets (stock_measure (dist_param_type dst))" by simp
  from this sets_eq_imp_space_eq[OF this]
  show ?case
    apply simp_all
    apply (subst sets_bind)
    apply auto
    done
next
  case (et_if Γ b e1 t e2 σ)
  have "sets (expr_sem σ b) = sets (stock_measure BOOL)"
    using et_if.prems by (intro et_if.IH) simp
  from sets_eq_imp_space_eq[OF this]
    have "space (expr_sem σ b)  {}" by simp
  moreover have "sets (expr_sem σ e1) = sets (stock_measure t)"
                "sets (expr_sem σ e2) = sets (stock_measure t)"
    using et_if.prems by (intro et_if.IH, simp)+
  ultimately show ?case by (simp add: bind_nonempty)
qed simp_all

(* Consequences of sets_expr_sem: the space of the semantics is the type universe, and
   measurability out of expr_sem σ e is the same as measurability out of the stock measure,
   for states drawn from state_measure V Gamma with free_vars e contained in V. *)
lemma space_expr_sem:
    "Γ  w : t  (x  free_vars w. val_type (σ x) = Γ x)
       space (expr_sem σ w) = type_universe t"
  by (subst space_stock_measure[symmetric]) (intro sets_expr_sem sets_eq_imp_space_eq)

lemma measurable_expr_sem_eq:
    "Γ  e : t  σ  space (state_measure V Γ)  free_vars e  V 
       measurable (expr_sem σ e) = measurable (stock_measure t)"
  by (intro ext measurable_cong_sets sets_expr_sem)
     (auto simp: state_measure_def space_PiM dest: PiE_mem)

lemma measurable_expr_semI:
    "Γ  e : t  σ  space (state_measure V Γ)  free_vars e  V 
       f  measurable (stock_measure t) M  f  measurable (expr_sem σ e) M"
  by (subst measurable_expr_sem_eq)

(* The semantics depends only on the values of the free variables.  Structural induction;
   in the LetVar case the two extended states agree on the body's free variables by a
   case split on the variable index (nat.split). *)
lemma expr_sem_eq_on_vars:
  "(x. xfree_vars e  σ1 x = σ2 x)  expr_sem σ1 e = expr_sem σ2 e"
proof (induction e arbitrary: σ1 σ2)
  case (LetVar e1 e2 σ1 σ2)
    from LetVar.prems have A: "expr_sem σ1 e1 = expr_sem σ2 e1" by (rule LetVar.IH(1)) simp_all
    from LetVar.prems show ?case
      by (subst (1 2) expr_sem.simps, subst A)
         (auto intro!: bind_cong LetVar.IH(2) split: nat.split)
next
  case (Operator oper e σ1 σ2)
  from Operator.IH[OF Operator.prems] show ?case by simp
next
  case (Pair e1 e2 σ1 σ2)
  from Pair.prems have "expr_sem σ1 e1 = expr_sem σ2 e1" by (intro Pair.IH) auto
  moreover from Pair.prems have "expr_sem σ1 e2 = expr_sem σ2 e2" by (intro Pair.IH) auto
  ultimately show ?case by simp
next
  case (Random dst e σ1 σ2)
  from Random.prems have A: "expr_sem σ1 e = expr_sem σ2 e" by (rule Random.IH) simp_all
  show ?case
    by (subst (1 2) expr_sem.simps, subst A) (auto intro!: bind_cong)
next
  case (IfThenElse b e1 e2 σ1 σ2)
  have A: "expr_sem σ1 b = expr_sem σ2 b"
          "expr_sem σ1 e1 = expr_sem σ2 e1"
          "expr_sem σ1 e2 = expr_sem σ2 e2"
    using IfThenElse.prems by (intro IfThenElse.IH, simp)+
  thus ?case by (simp only: expr_sem.simps A)
qed simp_all


subsection ‹Measurability›

(* Equality of two Borel-measurable real-valued functions is a measurable predicate,
   proved by reducing it to f x - g x = 0. *)
lemma borel_measurable_eq[measurable (raw)]:
  assumes [measurable]: "f  borel_measurable M" "g  borel_measurable M"
  shows "Measurable.pred M (λx. f x = (g x::real))"
proof -
  have *: "(λx. f x = g x) = (λx. f x - g x = 0)"
    by simp
  show ?thesis
    unfolding * by measurable
qed

(* Equality of two values of type t is measurable with respect to the product of the stock
   measure of t with itself.  Induction on t: REAL reduces to equality of the underlying
   reals, PRODUCT to componentwise equality, and the remaining (countable) types follow
   from pair_measure_countable. *)
lemma measurable_equals:
  "(λ(x,y). x = y)  measurable (stock_measure t M stock_measure t) (count_space UNIV)"
proof (induction t)
  case REAL
  let ?f = "λx. extract_real (fst x) = extract_real (snd x)"
  show ?case
  proof (subst measurable_cong)
    fix x assume "x  space (stock_measure REAL M stock_measure REAL)"
    thus "(λ(x,y). x = y) x = ?f x"
      by (auto simp: space_pair_measure elim!: REAL_E)
  next
    show "?f  measurable (stock_measure REAL M stock_measure REAL) (count_space UNIV)"
      by measurable
  qed
next
  case (PRODUCT t1 t2)
  let ?g = "λ(x,y). x = y"
  let ?f = "λx. ?g (fst (extract_pair (fst x)), fst (extract_pair (snd x))) 
                ?g (snd (extract_pair (fst x)), snd (extract_pair (snd x)))"
  show ?case
  proof (subst measurable_cong)
    fix x assume "x  space (stock_measure (PRODUCT t1 t2) M stock_measure (PRODUCT t1 t2))"
    thus "(λ(x,y). x = y) x = ?f x"
      apply (auto simp: space_pair_measure)
      apply (elim PROD_E)
      apply simp
      done
  next
    note PRODUCT[measurable]
    show "Measurable.pred (stock_measure (PRODUCT t1 t2) M stock_measure (PRODUCT t1 t2)) ?f"
      by measurable
  qed
qed (simp_all add: pair_measure_countable stock_measure.simps)

(*
  For functions f and g that are measurable into stock_measure t, the predicate
  f x = g x is measurable. The proof composes the pairing of f and g
  (measurable_Pair) with measurable_equals. The [measurable (raw)] attribute makes
  this available to the measurable proof method.
*)
lemma measurable_equals_stock_measure[measurable (raw)]:
  assumes "f  measurable M (stock_measure t)" "g  measurable M (stock_measure t)"
  shows "Measurable.pred M (λx. f x = g x)"
  using measurable_compose[OF measurable_Pair[OF assms] measurable_equals] by simp

(*
  Every operator applied at a type for which op_type yields a result type t' has a
  measurable semantics from stock_measure t to stock_measure t'.
  Proof by case analysis on the operator:
  - Fst, Snd, Less: by splitting on the argument type.
  - Equals, and the cases closed by the final qed: via val_case_stock_measurable.
  - Pow: on the space, op_sem Pow is replaced (measurable_cong) by an explicit map
    that returns RealVal 0 for a negative integer exponent and the real power otherwise.
*)
lemma measurable_op_sem:
  assumes "op_type oper t = Some t'"
  shows "op_sem oper  measurable (stock_measure t) (stock_measure t')"
proof (cases oper)
  case Fst with assms show ?thesis by (simp split: pdf_type.split_asm)
next
  case Snd with assms show ?thesis by (simp split: pdf_type.split_asm)
next
  case Equals with assms show ?thesis
    by (auto intro!: val_case_stock_measurable split: if_split_asm)
next
  (* The exponent is an integer; the witness g below makes the sign case split explicit. *)
  case Pow with assms show ?thesis
    apply (auto intro!: val_case_stock_measurable split: pdf_type.splits)
    apply (subst measurable_cong[where
      g="λ(x, n). if extract_int n < 0 then RealVal 0 else RealVal (extract_real x ^ nat (extract_int n))"])
    apply (auto simp: space_pair_measure elim!: REAL_E INTEG_E)
    done
next
  case Less with assms show ?thesis
    by (auto split: pdf_type.splits)
qed (insert assms, auto split: pdf_type.split_asm intro!: val_case_stock_measurable)

(*
  Variable set of a context extended by one binder. Variables are de Bruijn-style
  natural-number indices (states are extended with case_nat): index 0 is the newly
  bound variable, and every old variable x in V becomes Suc x.
*)
definition shift_var_set :: "vname set  vname set" where
  "shift_var_set V = insert 0 (Suc ` V)"

(* The freshly bound variable 0 always belongs to a shifted variable set. *)
lemma shift_var_set_0[simp]: "0  shift_var_set V"
  unfolding shift_var_set_def by simp

(* A shifted index Suc x belongs to shift_var_set V exactly when x belongs to V. *)
lemma shift_var_set_Suc[simp]: "Suc x  shift_var_set V  x  V"
  unfolding shift_var_set_def by auto

(* Overwriting index 0 of an extended state replaces the value bound at index 0. *)
lemma case_nat_update_0[simp]: "(case_nat x σ)(0 := y) = case_nat y σ"
  by (rule ext) (simp split: nat.split)

(*
  Reindexing a twice-extended state by case_nat 0 (λx. Suc (Suc x)) skips index 1,
  i.e. it removes the value y bound second and keeps x at index 0.
*)
lemma case_nat_delete_var_1[simp]:
    "case_nat x (case_nat y σ)  case_nat 0 (λx. Suc (Suc x)) = case_nat x σ"
  by (rule ext) (simp split: nat.split)

(*
  Preimage of a doubly shifted variable set under the reindexing
  case_nat 0 (λx. Suc (Suc x)), which maps 0 to 0 and Suc n to Suc (Suc n)
  (so index 1 is skipped), is the singly shifted set.
*)
lemma delete_var_1_vimage[simp]:
    "case_nat 0 (λx. Suc (Suc x)) -` (shift_var_set (shift_var_set V)) = shift_var_set V"
  by (auto simp: shift_var_set_def split: nat.split_asm)


(*
  Extending a family h x in PiM V M with a new component g x at index 0 (via case_nat)
  is measurable into the product space over shift_var_set V. In that space the
  component measure is N at index 0 and M i at index Suc i.
  Proof: reduce to the uncurried map (t, f) ↦ case_nat t f on N ⊗ PiM V M.
  First show measurability of the version restricted to shift_var_set V (statement ?P).
  Then use that elements of PiM V M are undefined outside V (PiE_arb) to show that ?P is
  equivalent to the unrestricted statement, and conclude by calculational reasoning.
  NOTE: ?P is rebound by the second (is ...) pattern, so the final "show ?P" refers
  to the unrestricted statement.
*)
lemma measurable_case_nat[measurable]:
  assumes "g  measurable R N" "h  measurable R (PiM V M)"
  shows "(λx. case_nat (g x) (h x))  measurable R (PiM (shift_var_set V) (case_nat N M))"
proof (rule measurable_Pair_compose_split[OF _ assms])
  (* Restricted version: measurable by measurable_restrict, componentwise. *)
  have "(λ(t,f). λxshift_var_set V. case_nat t f x)
           measurable (N M PiM V M) (PiM (shift_var_set V) (case_nat N M))" (is ?P)
    unfolding shift_var_set_def
    by (subst measurable_split_conv, rule measurable_restrict) (auto split: nat.split_asm)
  (* Outside V, members of PiM V M take the value undefined, so the restriction is a no-op. *)
  also have "x f. f  space (PiM V M)  x  V  undefined = f x"
    by (rule sym, subst (asm) space_PiM, erule PiE_arb)
  hence "?P  (λ(t,f). case_nat t f)
            measurable (N M PiM V M) (PiM (shift_var_set V) (case_nat N M))" (is "_ = ?P")
    by (intro measurable_cong ext)
       (auto split: nat.split simp: inj_image_mem_iff space_pair_measure shift_var_set_def)
  finally show ?P .
qed

(*
  State-measure version of measurable_case_nat. Extending a state h x (on V, context Γ)
  with a value g x of type t at index 0 is measurable into the state measure on
  shift_var_set V with context case_nat t Γ.
  The only work is rewriting the component measures of the extended context
  (equation A) so that measurable_case_nat applies after unfolding state_measure_def.
*)
lemma measurable_case_nat'[measurable]:
  assumes "g  measurable R (stock_measure t)" "h  measurable R (state_measure V Γ)"
  shows "(λx. case_nat (g x) (h x)) 
           measurable R (state_measure (shift_var_set V) (case_nat t Γ))"
proof-
  have A: "(λx. stock_measure (case_nat t Γ x)) =
                 case_nat (stock_measure t) (λx. stock_measure (Γ x))"
    by (intro ext) (simp split: nat.split)
  show ?thesis using assms unfolding state_measure_def by (simp add: A)
qed

(*
  Extending a state σ of the state space on V with a value x of type t1 at index 0
  gives a state of the space on shift_var_set V with context case_nat t1 Γ.
  Obtained from measurable_case_nat' instantiated with the identity map (for x)
  and the constant map σ, via measurable_space.
*)
lemma case_nat_in_state_measure[intro]:
  assumes "x  type_universe t1" "σ  space (state_measure V Γ)"
  shows "case_nat x σ  space (state_measure (shift_var_set V) (case_nat t1 Γ))"
  apply (rule measurable_space[OF measurable_case_nat'])
  apply (rule measurable_ident_sets[OF refl], rule measurable_const[OF assms(2)])
  using assms
  apply simp
  done

(*
  If the predecessors (Suc -`) of the indices in A all lie in V, then A is contained
  in shift_var_set V. Index 0 is always included. Used below to discharge
  free-variable side conditions for the body of a LetVar.
*)
lemma subset_shift_var_set:
    "Suc -` A  V  A  shift_var_set V"
  by (rule subsetI, rename_tac x, case_tac x) (auto simp: shift_var_set_def)

(*
  Measurability of the semantics. For a well-typed expression e of type t whose free
  variables lie in V, the map σ ↦ expr_sem σ e is measurable from the state measure
  on V (context Γ) into the sub-probability measures over stock_measure t.
  Proof by induction on the typing derivation, generalising over V. This lets the
  LetVar body be handled at the shifted variable set shift_var_set V.
*)
lemma measurable_expr_sem[measurable]:
  assumes "Γ  e : t" and "free_vars e  V"
  shows "(λσ. expr_sem σ e)  measurable (state_measure V Γ)
                                         (subprob_algebra (stock_measure t))"
using assms
proof (induction arbitrary: V rule: expr_typing.induct)
  case (et_var Γ x)
  (* A variable evaluates to the Dirac measure at σ x, i.e. return_val after a projection. *)
  have A: "(λσ. expr_sem σ (Var x)) = return_val  (λσ. σ x)" by (simp add: o_def)
  with et_var show ?case unfolding state_measure_def
    by (subst A) (rule measurable_comp[OF measurable_component_singleton], simp_all)
next
  case (et_val Γ v)
  thus ?case by (auto intro!: measurable_const subprob_space_return
                      simp: space_subprob_algebra return_val_def)
next
  case (et_let Γ e1 t1 e2 t2 V)
  (* Bind e1, then evaluate e2 in the extended state; the IH for e2 is used at shift_var_set V. *)
  from et_let.prems and et_let.hyps show ?case
    apply (subst expr_sem.simps, intro measurable_bind)
    apply (rule et_let.IH(1), simp)
    apply (rule measurable_compose[OF _ et_let.IH(2)[of "shift_var_set V"]])
    apply (simp_all add: subset_shift_var_set)
    done
next
  case (et_op Γ e t oper t')
  thus ?case by (auto intro!: measurable_bind2 measurable_compose[OF _ measurable_return_val]
                              measurable_op_sem cong: measurable_cong)
next
  case (et_pair t t1 t2 Γ e1 e2)
  (* Pair construction is injective, which the measurability proof of the pairing needs. *)
  have "inj (λ(a,b). <|a, b|>)" by (auto intro: injI)
  with et_pair show ?case
    apply (subst expr_sem.simps)
    apply (rule measurable_bind, (auto) [])
    apply (rule measurable_bind[OF measurable_compose[OF measurable_fst]], (auto) [])
    apply (rule measurable_compose[OF _ measurable_return_val], simp)
    done
next
  case (et_rand Γ e dst V)
  from et_rand.prems and et_rand.hyps show ?case
    by (auto intro!: et_rand.IH measurable_compose[OF measurable_snd]
                     measurable_bind measurable_dist_measure)
next
  case (et_if Γ b e1 t e2 V)
  (* ?M e t abbreviates the measurability statement for sub-expression e at type t. *)
  let ?M = "λe t. (λσ. expr_sem σ e) 
                      measurable (state_measure V Γ) (subprob_algebra (stock_measure t))"
  from et_if.prems have A[measurable]: "?M b BOOL" "?M e1 t" "?M e2 t" by (intro et_if.IH, simp)+
  show ?case by (subst expr_sem.simps, rule measurable_bind[OF A(1)]) simp_all
next
  case (et_fail Γ t V)
  show ?case
    by (auto intro!: measurable_subprob_algebra subprob_spaceI)
qed


subsection ‹Randomfree expressions›

(*
  Syntactic check that an expression contains no source of randomness or failure.
  - Random and Fail nodes are not random-free.
  - Val and Var are random-free.
  - Compound expressions (Pair, Operator, LetVar, IfThenElse) are random-free when
    all of their sub-expressions are.
*)
fun randomfree :: "expr  bool" where
  "randomfree (Val _) = True"
| "randomfree (Var _) = True"
| "randomfree (Pair e1 e2) = (randomfree e1  randomfree e2)"
| "randomfree (Operator _ e) = randomfree e"
| "randomfree (LetVar e1 e2) = (randomfree e1  randomfree e2)"
| "randomfree (IfThenElse b e1 e2) = (randomfree b  randomfree e1  randomfree e2)"
| "randomfree (Random _ _) = False"
| "randomfree (Fail _) = False"

(*
  Deterministic evaluator for random-free expressions. It returns a value directly
  instead of a measure and is only meaningful when randomfree e holds
  (see expr_sem_rf_sound). On Random and Fail it is undefined.
  - IfThenElse selects e1 exactly when b evaluates to BoolVal True, and e2 otherwise.
  - LetVar evaluates e2 in σ extended with the value of e1 as the new variable 0.
    This is the same case_nat-based state extension that expr_sem uses.
*)
primrec expr_sem_rf :: "state  expr  val" where
  "expr_sem_rf _ (Val v) = v"
| "expr_sem_rf σ (Var x) = σ x"
| "expr_sem_rf σ (<e1, e2>) = <|expr_sem_rf σ e1, expr_sem_rf σ e2|>"
| "expr_sem_rf σ (Operator oper e) = op_sem oper (expr_sem_rf σ e)"
| "expr_sem_rf σ (LetVar e1 e2) = expr_sem_rf (expr_sem_rf σ e1  σ) e2"
| "expr_sem_rf σ (IfThenElse b e1 e2) =
      (if expr_sem_rf σ b = BoolVal True then expr_sem_rf σ e1 else expr_sem_rf σ e2)"
| "expr_sem_rf _ (Random _ _) = undefined"
| "expr_sem_rf _ (Fail _) = undefined"


(*
  Measurability of the deterministic evaluator. For a well-typed, random-free
  expression e of type t with free variables in V, the map σ ↦ expr_sem_rf σ e is
  measurable from the state measure on V into stock_measure t.
  Induction over the typing derivation, generalising over V so that the LetVar body
  can be handled at shift_var_set V. The remaining cases are closed by the final
  simp_all.
*)
lemma measurable_expr_sem_rf[measurable]:
  "Γ  e : t  randomfree e  free_vars e  V 
       (λσ. expr_sem_rf σ e)  measurable (state_measure V Γ) (stock_measure t)"
proof (induction arbitrary: V rule: expr_typing.induct)
  case (et_val Γ v V)
  thus ?case by (auto intro!: measurable_const simp:)
next
  case (et_var Γ x V)
  thus ?case by (auto simp: state_measure_def intro!: measurable_component_singleton)
next
  case (et_pair Γ e1 t1 e2 t2 V)
  have "inj (λ(x,y). <|x, y|>)" by (auto intro: injI)
  with et_pair show ?case by simp
next
  case (et_op Γ e t oper t' V)
  thus ?case by (auto intro!: measurable_compose[OF _ measurable_op_sem])
next
  case (et_let Γ e1 t1 e2 t2 V)
  (* Measurability of both sub-evaluations; e2 lives in the extended state space. *)
  hence M1: "(λσ. expr_sem_rf σ e1)  measurable (state_measure V Γ) (stock_measure t1)"
    and M2: "(λσ. expr_sem_rf σ e2)  measurable (state_measure (shift_var_set V) (case_nat t1 Γ))
                                           (stock_measure t2)"
    using subset_shift_var_set
    by (auto intro!: et_let.IH(1)[of V] et_let.IH(2)[of "shift_var_set V"])
  (* Decompose the Let evaluator: pair σ with the value of e1, extend the state via
     case_nat, then evaluate e2. Each piece is measurable. *)
  have "(λσ. expr_sem_rf σ (LetVar e1 e2)) =
            (λσ. expr_sem_rf σ e2)  (λ(σ,y). case_nat y σ)  (λσ. (σ, expr_sem_rf σ e1))" (is "_ = ?f")
    by (intro ext) simp
  also have "?f  measurable (state_measure V Γ) (stock_measure t2)"
    apply (intro measurable_comp, rule measurable_Pair, rule measurable_ident_sets[OF refl])
    apply (rule M1, subst measurable_split_conv, rule measurable_case_nat')
    apply (rule measurable_snd, rule measurable_fst, rule M2)
    done
  finally show ?case .
qed (simp_all add: expr_sem_rf_def)

(*
  Soundness of the deterministic evaluator. Assume e is well-typed and random-free,
  its free variables lie in V, and σ is in the state space. Then the semantics
  expr_sem σ e is the Dirac measure return_val of the value computed by expr_sem_rf.
  Induction over the typing derivation, generalising over V and σ. The Random and
  Fail cases are vacuous because randomfree is False for them; they are closed by
  the final simp_all.
*)
lemma expr_sem_rf_sound:
  "Γ  e : t  randomfree e  free_vars e  V  σ  space (state_measure V Γ) 
       return_val (expr_sem_rf σ e) = expr_sem σ e"
proof (induction arbitrary: V σ rule: expr_typing.induct)
  case (et_val Γ v)
  thus ?case by simp
next
  case (et_var Γ x)
  thus ?case by simp
next
  case (et_pair Γ e1 t1 e2 t2 V σ)
  (* Soundness for both components, from the induction hypotheses. *)
  from et_pair.hyps and et_pair.prems
    have e1: "return_val (expr_sem_rf σ e1) = expr_sem σ e1" and
         e2: "return_val (expr_sem_rf σ e2) = expr_sem σ e2"
      by (auto intro!: et_pair.IH[of V])

  (* Read off the types of both component values from the spaces of their measures. *)
  from e1 and et_pair.prems have "space (return_val (expr_sem_rf σ e1)) = type_universe t1"
    by (subst e1, subst space_expr_sem[OF et_pair.hyps(1)])
       (auto dest: state_measure_var_type)
  hence A: "val_type (expr_sem_rf σ e1) = t1" "expr_sem_rf σ e1  type_universe t1"
    by (auto simp add: return_val_def)
  from e2 and et_pair.prems have "space (return_val (expr_sem_rf σ e2)) = type_universe t2"
    by (subst e2, subst space_expr_sem[OF et_pair.hyps(2)])
       (auto dest: state_measure_var_type)
  hence B: "val_type (expr_sem_rf σ e2) = t2" "expr_sem_rf σ e2  type_universe t2"
    by (auto simp add: return_val_def)

  (* Collapse the two nested binds over Dirac measures step by step. *)
  have "expr_sem σ (<e1, e2>) = expr_sem σ e1 
            (λv. expr_sem σ e2  (λw. return_val (<|v,w|>)))" by simp
  also have "expr_sem σ e1 = return (stock_measure t1) (expr_sem_rf σ e1)"
    using e1 by (simp add: et_pair.prems return_val_def A)
  also have "...  (λv. expr_sem σ e2  (λw. return_val (<|v,w|>))) =
          ...  (λv. return_val (<|v, expr_sem_rf σ e2|>))"
  proof (intro bind_cong refl)
    fix v assume "v  space (return (stock_measure t1) (expr_sem_rf σ e1))"
    hence v: "val_type v = t1" "v  type_universe t1" by simp_all
    have "expr_sem σ e2  (λw. return_val (<|v,w|>)) =
              return (stock_measure t2) (expr_sem_rf σ e2)  (λw. return_val (<|v,w|>))"
      using e2 by (simp add: et_pair.prems return_val_def B)
    also have "... = return (stock_measure t2) (expr_sem_rf σ e2) 
                         (λw. return (stock_measure (PRODUCT t1 t2)) (<|v,w|>))"
    proof (intro bind_cong refl)
      fix w assume "w  space (return (stock_measure t2) (expr_sem_rf σ e2))"
      hence w: "val_type w = t2" by simp
      thus "return_val (<|v,w|>) = return (stock_measure (PRODUCT t1 t2)) (<|v,w|>)"
        by (auto simp: return_val_def v w)
    qed
    also have "... = return_val (<|v, expr_sem_rf σ e2|>)"
      using v B
      by (subst bind_return[where N="PRODUCT t1 t2"]) (auto simp: return_val_def)
    finally show "expr_sem σ e2  (λw. return_val (<|v,w|>)) = return_val (<|v, expr_sem_rf σ e2|>)" .
  qed
  also have "(λv. <|v, expr_sem_rf σ e2|>)  measurable (stock_measure t1) (stock_measure (PRODUCT t1 t2))"
    using B by (auto intro!: injI)
  hence "return (stock_measure t1) (expr_sem_rf σ e1)  (λv. return_val (<|v, expr_sem_rf σ e2|>)) =
             return_val (<|expr_sem_rf σ e1, expr_sem_rf σ e2|>)"
    by (subst bind_return, rule measurable_compose[OF _ measurable_return_val])
       (auto simp: A)
  finally show "return_val (expr_sem_rf σ (<e1,e2>)) = expr_sem σ (<e1, e2>)" by simp
next
  case (et_if Γ b e1 t e2 V σ)
  (* ?P e: soundness for sub-expression e, in the symmetric orientation. *)
  let ?P = "λe. expr_sem σ e = return_val (expr_sem_rf σ e)"
  from et_if.prems have A: "?P b" "?P e1" "?P e2" by ((intro et_if.IH[symmetric], simp_all) [])+
  from et_if.prems and et_if.hyps have "space (expr_sem σ b) = type_universe BOOL"
    by (intro space_expr_sem) (auto simp: state_measure_var_type)
  hence [simp]: "val_type (expr_sem_rf σ b) = BOOL" by (simp add: A return_val_def)
  have B: "return_val (expr_sem_rf σ e1)  space (subprob_algebra (stock_measure t))"
          "return_val (expr_sem_rf σ e2)  space (subprob_algebra (stock_measure t))"
    using et_if.hyps and et_if.prems
    by ((subst A[symmetric], intro measurable_space[OF measurable_expr_sem], auto) [])+
  thus ?case
    by (auto simp: A bind_return_val''[where M=t])
next
  case (et_op Γ e t oper t' V)
  from et_op.prems have e: "return_val (expr_sem_rf σ e) = expr_sem σ e"
    by (intro et_op.IH[of V]) auto

  (* The argument value has type t, so op_sem can be applied measurably to it. *)
  with et_op.prems have "space (return_val (expr_sem_rf σ e)) = type_universe t"
    by (subst e, subst space_expr_sem[OF et_op.hyps(1)])
       (auto dest: state_measure_var_type)
  hence A: "val_type (expr_sem_rf σ e) = t" "expr_sem_rf σ e  type_universe t"
    by (auto simp add: return_val_def)
  from et_op.prems e
    have "expr_sem σ (Operator oper e) =
                 return_val (expr_sem_rf σ e)  (λv. return_val (op_sem oper v))" by simp
  also have "... = return_val (op_sem oper (expr_sem_rf σ e))"
    by (subst return_val_def, rule bind_return,
        rule measurable_compose[OF measurable_op_sem measurable_return_val])
       (auto simp: A et_op.hyps)
  finally show "return_val (expr_sem_rf σ (Operator oper e)) = expr_sem σ (Operator oper e)" by simp
next
  case (et_let Γ e1 t1 e2 t2 V)
  let ?M = "state_measure V Γ" and ?N = "state_measure (shift_var_set V) (case_nat t1 Γ)"
  let ?σ' = "case_nat (expr_sem_rf σ e1) σ"
  from et_let.prems have e1: "return_val (expr_sem_rf σ e1) = expr_sem σ e1"
    by (auto intro!: et_let.IH(1)[of V])
  from et_let.prems have S: "space (return_val (expr_sem_rf σ e1)) = type_universe t1"
    by (subst e1, subst space_expr_sem[OF et_let.hyps(1)])
       (auto dest: state_measure_var_type)
  hence A: "val_type (expr_sem_rf σ e1) = t1" "expr_sem_rf σ e1  type_universe t1"
    by (auto simp add: return_val_def)
  (* Soundness of the body e2 in every extended state. *)
  with et_let.prems have e2: "σ. σ  space ?N  return_val (expr_sem_rf σ e2) = expr_sem σ e2"
    using subset_shift_var_set
    by (intro et_let.IH(2)[of "shift_var_set V"]) (auto simp del: fun_upd_apply)

  from et_let.prems have "expr_sem σ (LetVar e1 e2) =
                              return_val (expr_sem_rf σ e1)  (λv. expr_sem (case_nat v σ) e2)"
    by (simp add: e1)
  also from et_let.prems
    have "... = return_val (expr_sem_rf σ e1)  (λv. return_val (expr_sem_rf (case_nat v σ) e2))"
    by (intro bind_cong refl, subst e2) (auto simp: S)
  also from et_let have Me2[measurable]: "(λσ. expr_sem_rf σ e2)  measurable ?N (stock_measure t2)"
    using subset_shift_var_set by (intro measurable_expr_sem_rf) auto
  (* Binding a Dirac measure collapses to evaluating the body at ?σ'. *)
  have  "return_val (expr_sem_rf σ e1)  (λv. return_val (expr_sem_rf (case_nat v σ) e2)) =
              return_val (expr_sem_rf ?σ' e2)" using σ  space ?M
    by (subst return_val_def, intro bind_return, subst A)
       (rule measurable_compose[OF _ measurable_return_val[of t2]], simp_all)
  finally show ?case by simp
qed simp_all

(*
  The value computed by expr_sem_rf has the type t assigned by the typing judgement.
  Proof:
  - The space of return_val v is type_universe (val_type v).
  - By expr_sem_rf_sound, that Dirac measure equals expr_sem σ e.
  - By space_expr_sem, the space of expr_sem σ e is type_universe t.
  The final simp step turns the resulting equality of type universes into equality
  of the types.
*)
lemma val_type_expr_sem_rf:
  assumes "Γ  e : t" "randomfree e" "free_vars e  V" "σ  space (state_measure V Γ)"
  shows "val_type (expr_sem_rf σ e) = t"
proof-
  have "type_universe (val_type (expr_sem_rf σ e)) = space (return_val (expr_sem_rf σ e))"
    by (simp add: return_val_def)
  also from assms have "return_val (expr_sem_rf σ e) = expr_sem σ e"
    by (intro expr_sem_rf_sound) auto
  also from assms have "space ... = type_universe t"
    by (intro space_expr_sem[of Γ])
       (auto simp: state_measure_def space_PiM  dest: PiE_mem)
  finally show ?thesis by simp
qed

(*
  expr_sem_rf depends only on the values of the free variables of e: two states that
  agree on free_vars e give the same result. Structural induction on e, generalising
  over both states. The cases not listed explicitly (Val, Var, Fail) are closed by
  the final auto.
*)
lemma expr_sem_rf_eq_on_vars:
  "(x. xfree_vars e  σ1 x = σ2 x)  expr_sem_rf σ1 e = expr_sem_rf σ2 e"
proof (induction e arbitrary: σ1 σ2)
  case (Operator oper e σ1 σ2)
  hence "expr_sem_rf σ1 e = expr_sem_rf σ2 e" by (intro Operator.IH) auto
  thus ?case by simp
next
  case (LetVar e1 e2 σ1 σ2)
  hence A: "expr_sem_rf σ1 e1 = expr_sem_rf σ2 e1" by (intro LetVar.IH) auto
  (* The two extended states agree on free_vars e2. At index 0 both hold the (equal)
     values of e1 (by A). At the other indices, the premise for LetVar e1 e2,
     i.e. LetVar(3), applies. *)
  {
    fix y assume "y  free_vars e2"
    hence "case_nat (expr_sem_rf σ1 e1) σ1 y = case_nat (expr_sem_rf σ2 e1) σ2 y"
      using LetVar(3) by (auto simp add: A split: nat.split)
  }
  hence "expr_sem_rf (case_nat (expr_sem_rf σ1 e1) σ1) e2 =
           expr_sem_rf (case_nat (expr_sem_rf σ2 e1) σ2) e2" by (intro LetVar.IH) simp
  thus ?case by simp
next
  case (Pair e1 e2 σ1 σ2)
  have "expr_sem_rf σ1 e1 = expr_sem_rf σ2 e1" "expr_sem_rf σ1 e2 = expr_sem_rf σ2 e2"
    by (intro Pair.IH, simp add: Pair)+
  thus ?case by simp
next
  case (IfThenElse b e1 e2 σ1 σ2)
  have "expr_sem_rf σ1 b = expr_sem_rf σ2 b" "expr_sem_rf σ1 e1 = expr_sem_rf σ2 e1"
       "expr_sem_rf σ1 e2 = expr_sem_rf σ2 e2" by (intro IfThenElse.IH, simp add: IfThenElse)+
  thus ?case by simp
next
  case (Random dst e σ1 σ2)
  have "expr_sem_rf σ1 e = expr_sem_rf σ2 e" by (intro Random.IH) (simp add: Random)
  thus ?case by simp
qed auto


(*
subsection {* Substitution of free variables *}

primrec expr_subst :: "expr ⇒ expr ⇒ vname ⇒ expr" ("_⟨_'/_⟩" [1000,0,0] 999) where
  "(Val v)⟨_/_⟩ = Val v"
| "(Var y)⟨f/x⟩ = (if y = x then f else Var y)"
| "<e1,e2>⟨f/x⟩ = <e1⟨f/x⟩, e2⟨f/x⟩>"
| "(<oper> e)⟨f/x⟩ = <oper> (e⟨f/x⟩)"
| "(LET e1 IN e2)⟨f/x⟩ = LET y = e1⟨f/x⟩ IN if y = x then e2 else e2⟨f/x⟩"
| "(IF b THEN e1 ELSE e2)⟨f/x⟩ = IF b⟨f/x⟩ THEN e1⟨f/x⟩ ELSE e2⟨f/x⟩"
| "(Random dst e)⟨f/x⟩ = Random dst (e⟨f/x⟩)"
| "(Fail t)⟨f/x⟩ = Fail t"

primrec bound_vars :: "expr ⇒ vname set" where
  "bound_vars (Val _) = {}"
| "bound_vars (Var _) = {}"
| "bound_vars <e1,e2> = bound_vars e1 ∪ bound_vars e2"
| "bound_vars (<_> e) = bound_vars e"
| "bound_vars (LET x = e1 IN e2) = {x} ∪ bound_vars e1 ∪ bound_vars e2"
| "bound_vars (IF b THEN e1 ELSE e2) = bound_vars b ∪ bound_vars e1 ∪ bound_vars e2"
| "bound_vars (Random _ e) = bound_vars e"
| "bound_vars (Fail _) = {}"

lemma expr_typing_eq_on_free_vars:
  "Γ1 ⊢ e : t ⟹ (⋀x. x ∈ free_vars e ⟹ Γ1 x = Γ2 x) ⟹ Γ2 ⊢ e : t"
proof (induction arbitrary: Γ2 rule: expr_typing.induct)
  case et_let
  thus ?case by (intro expr_typing.intros) auto
qed (auto intro!: expr_typing.intros simp del: fun_upd_apply)

lemma expr_typing_subst:
  assumes "Γ ⊢ e : t1" "Γ ⊢ f : t'" "Γ x = t'" "free_vars f ∩ bound_vars e = {}"
  shows "Γ ⊢ e⟨f/x⟩ : t1"
using assms
proof (induction rule: expr_typing.induct)
  case (et_let Γ e1 t1 y e2 t2)
  from et_let.prems have A: "Γ ⊢ e1⟨f/x⟩ : t1" by (intro et_let.IH) auto
  show ?case
  proof (cases "y = x")
    assume "y ≠ x"
    from et_let.prems have "Γ(y := t1) ⊢ f : t'"
      by (intro expr_typing_eq_on_free_vars[OF `Γ ⊢ f : t'`]) auto
    moreover from `y ≠ x` have "(Γ(y := t1)) x = Γ x" by simp
    ultimately have "Γ(y := t1) ⊢ e2⟨f/x⟩ : t2" using et_let.prems
      by (intro et_let.IH) (auto simp del: fun_upd_apply)
    with A and `y ≠ x` show ?thesis by (auto intro: expr_typing.intros)
  qed (insert et_let, auto intro!: expr_typing.intros simp del: fun_upd_apply)
qed (insert assms(2), auto intro: expr_typing.intros)

lemma expr_subst_randomfree:
  assumes "Γ ⊢ f : t" "randomfree f" "free_vars f ⊆ V" "free_vars f ∩ bound_vars e = {}"
          "σ ∈ space (state_measure V Γ)"
  shows   "expr_sem σ (e⟨f/x⟩) = expr_sem (σ(x := expr_sem_rf f σ)) e"
using assms(1,3,4,5)
proof (induction e arbitrary: σ V Γ)
  case (Pair e1 e2 σ V Γ)
    from Pair.prems have "expr_sem σ (e1⟨f/x⟩) = expr_sem (σ(x := expr_sem_rf f σ)) e1"
                     and "expr_sem σ (e2⟨f/x⟩) = expr_sem (σ(x := expr_sem_rf f σ)) e2"
      by (auto intro!: Pair.IH[of Γ V σ])
    thus ?case by (simp del: fun_upd_apply)
next
  case (IfThenElse b e1 e2 σ V Γ)
    from IfThenElse.prems
      have "expr_sem σ (b⟨f/x⟩) = expr_sem (σ(x := expr_sem_rf f σ)) b"
           "expr_sem σ (e1⟨f/x⟩) = expr_sem (σ(x := expr_sem_rf f σ)) e1"
           "expr_sem σ (e2⟨f/x⟩) = expr_sem (σ(x := expr_sem_rf f σ)) e2"
      by (auto intro!: IfThenElse.IH[of Γ V σ])
    thus ?case by (simp only: expr_sem.simps expr_subst.simps)
next
  case (LetVar y e1 e2)
  from LetVar.prems have A: "expr_sem σ (e1⟨f/x⟩) = expr_sem (σ(x := expr_sem_rf f σ)) e1"
    by (intro LetVar.IH) auto
  show ?case
  proof (cases "y = x")
    assume "y = x"
    with LetVar.prems show ?case by (auto simp add: A simp del: fun_upd_apply)
  next
    assume "y ≠ x"
    {
      fix v assume "v ∈ space (expr_sem (σ(x := expr_sem_rf f σ)) e1)"
      let ?σ' = "σ(y := v)" and ?Γ' = "Γ(y := val_type v)"
      from LetVar.prems have "Γ(y := val_type v) ⊢ f : t" by (auto intro: expr_typing_eq_on_free_vars)
      moreover from LetVar.prems have "?σ' ∈ space (state_measure (insert y V) ?Γ')"
        by (auto simp: state_measure_def space_PiM split: if_split_asm)
      ultimately have "expr_sem ?σ' (e2⟨f/x⟩) = expr_sem (?σ'(x := expr_sem_rf f ?σ')) e2"
        using LetVar.prems and `y ≠ x`
        by (intro LetVar.IH(2)[of "Γ(y := val_type v)" "insert y V"]) (auto simp del: fun_upd_apply)
      also from LetVar.prems have "expr_sem_rf f ?σ' = expr_sem_rf f σ"
        by (intro expr_sem_rf_eq_on_vars) auto
      finally have "expr_sem (σ(y := v)) (e2⟨f/x⟩) = expr_sem (σ(x := expr_sem_rf f σ, y := v)) e2"
        using `y ≠ x` by (subst fun_upd_twist) (simp_all del: fun_upd_apply)
    }
    with A and `y ≠ x` show ?thesis by (auto simp del: fun_upd_apply intro!: bind_cong)
  qed
qed (simp_all add: expr_sem_rf_sound assms)

lemma stock_measure_context_upd:
  "(λy. stock_measure ((Γ(x := t)) y)) = (λy. stock_measure (Γ y))(x := stock_measure t)"
  by (intro ext) simp

lemma Let_det_eq_subst:
  assumes "Γ ⊢ LET x = f IN e : t" "randomfree f" "free_vars (LET x = f IN e) ⊆ V"
          "free_vars f ∩ bound_vars e = {}" "σ ∈ space (state_measure V Γ)"
  shows   "expr_sem σ (LET x = f IN e) = expr_sem σ (e⟨f/x⟩)"
proof-
  from assms(1) obtain t' where t1: "Γ ⊢ f : t'" and t2: "Γ(x := t') ⊢ e : t" by auto
  with assms have "expr_sem σ (LET x = f IN e) =
                       return_val (expr_sem_rf f σ) ⤜ (λv. expr_sem (σ(x := v)) e)" (is "_ = ?M")
    by (auto simp: expr_sem_rf_sound)
  also have "(λσ. expr_sem σ e) ∘ (λ(σ,v). σ(x := v)) ∘ (λv. (σ,v)) ∈
                 measurable (stock_measure ((Γ(x := t')) x)) (subprob_algebra (stock_measure t))"
    apply (intro measurable_comp, rule measurable_Pair1', rule assms)
    apply (subst fun_upd_same, unfold state_measure_def)
    apply (rule measurable_add_dim', subst stock_measure_context_upd[symmetric])
    apply (insert assms, auto intro!: measurable_expr_sem[unfolded state_measure_def] t1 t2
                              simp del: fun_upd_apply)
    done
  hence "(λv. expr_sem (σ(x := v)) e) ∈
                 measurable (stock_measure ((Γ(x := t')) x)) (subprob_algebra (stock_measure t))"
    by (simp add: o_def)
  with assms have "?M = expr_sem (σ(x := expr_sem_rf f σ)) e"
    unfolding return_val_def
    by (intro bind_return) (auto simp: val_type_expr_sem_rf[OF t1]
                                       type_universe_def simp del: type_universe_type)
  also from assms t1 t2 have "... = expr_sem σ (e⟨f/x⟩)"
    by (intro expr_subst_randomfree[symmetric]) auto
  finally show ?thesis .
qed *)

end