(* Theory Pi_pmf *)

(*
  File:    Pi_pmf.thy
  Authors: Manuel Eberl, Max W. Haslbeck
*)
section ‹Indexed products of PMFs›
theory Pi_pmf
  imports "HOL-Probability.Probability"
begin

text ‹Conflicting notation from @{theory "HOL-Analysis.Infinite_Sum"}›
(* Remove the infix syntax of Infinite_Sum.abs_summable_on: it clashes with the
   abs_summable_on notation used by the infsetsum-based proofs below. *)
no_notation Infinite_Sum.abs_summable_on (infixr "abs'_summable'_on" 46)

subsection ‹Definition›

text ‹
  In analogy to @{const PiM}, we define an indexed product of PMFs. In the literature, this
  is typically called taking a vector of independent random variables. Note that the components
  do not have to be identically distributed.

  The operation takes an explicit index set @{term "A :: 'a set"} and a function
  @{term "p :: 'a ⇒ 'b pmf"} that maps each element of @{term A} to a PMF and defines the
  product measure $\bigotimes_{i\in A} p(i)$, which is represented as a @{typ "('a ⇒ 'b) pmf"}.

  Note that unlike @{const PiM}, this only works for ‹finite› index sets. It could
  be extended to countable sets and beyond, but the construction becomes somewhat more involved.
›
(* A function f is assigned the product of the component masses pmf (p x) (f x)
   over x ∈ A, provided f returns dflt everywhere outside A; every other function
   gets mass 0.  The construction is only intended for finite A (see pmf_Pi). *)
definition Pi_pmf :: "'a set ⇒ 'b ⇒ ('a ⇒ 'b pmf) ⇒ ('a ⇒ 'b) pmf" where
  "Pi_pmf A dflt p =
     embed_pmf (λf. if (∀x. x ∉ A ⟶ f x = dflt) then ∏x∈A. pmf (p x) (f x) else 0)"

text ‹
  A technical subtlety that needs to be addressed is this: Intuitively, the functions in the
  support of a product distribution have domain ‹A›. However, since HOL is a total logic, these
  functions must still return ‹some› value for inputs outside ‹A›. The product measure
  @{const PiM} simply lets these functions return @{const undefined} in these cases. We chose a
  different solution here, which is to supply a default value @{term "dflt :: 'b"} that is
  returned in these cases.

  As one possible application, one could model the result of ‹n› different independent coin
  tosses as @{term "Pi_pmf {0..<n} False (λ_. bernoulli_pmf (1 / 2))"}. This returns a function
  of type @{typ "nat ⇒ bool"} that maps every natural number below ‹n› to the result of the
  corresponding coin toss, and every other natural number to @{term False}.
›

(* For finite A, the mass of f under Pi_pmf is the product of the component masses
   if f equals dflt outside A, and 0 otherwise.  The proof discharges the side
   conditions of pmf_embed_pmf: the closing method proves non-negativity, and
   case 2 shows that the nonnegative integral of the mass function is 1. *)
lemma pmf_Pi:
  assumes A: "finite A"
  shows   "pmf (Pi_pmf A dflt p) f =
             (if (∀x. x ∉ A ⟶ f x = dflt) then ∏x∈A. pmf (p x) (f x) else 0)"
  unfolding Pi_pmf_def
proof (rule pmf_embed_pmf, goal_cases)
  case 2
  define S where "S = {f. ∀x. x ∉ A ⟶ f x = dflt}"
  define B where "B = (λx. set_pmf (p x))"

  (* Every f ∈ PiE A B is the restriction of its dflt-extension, which lies in S;
     so this difference is empty and the claim holds vacuously. *)
  have neutral_left: "(∏xa∈A. pmf (p xa) (f xa)) = 0"
    if "f ∈ PiE A B - (λf. restrict f A) ` S" for f
  proof -
    have "restrict (λx. if x ∈ A then f x else dflt) A ∈ (λf. restrict f A) ` S"
      by (intro imageI) (auto simp: S_def)
    also have "restrict (λx. if x ∈ A then f x else dflt) A = f"
      using that by (auto simp: PiE_def Pi_def extensional_def fun_eq_iff)
    finally show ?thesis using that by blast
  qed
  (* Outside PiE A B some component value lies outside its support, so one factor is 0. *)
  have neutral_right: "(∏xa∈A. pmf (p xa) (f xa)) = 0"
    if "f ∈ (λf. restrict f A) ` S - PiE A B" for f
  proof -
    from that obtain f' where f': "f = restrict f' A" "f' ∈ S" by auto
    moreover from this and that have "restrict f' A ∉ PiE A B" by simp
    then obtain x where "x ∈ A" "pmf (p x) (f' x) = 0" by (auto simp: B_def set_pmf_eq)
    with f' and A show ?thesis by auto
  qed

  (* Summability of the mass function over UNIV, transferred from PiE A B. *)
  have "(λf. ∏x∈A. pmf (p x) (f x)) abs_summable_on PiE A B"
    by (intro abs_summable_on_prod_PiE A) (auto simp: B_def)
  also have "?this ⟷ (λf. ∏x∈A. pmf (p x) (f x)) abs_summable_on (λf. restrict f A) ` S"
    by (intro abs_summable_on_cong_neutral neutral_left neutral_right) auto
  also have "… ⟷ (λf. ∏x∈A. pmf (p x) (restrict f A x)) abs_summable_on S"
    by (rule abs_summable_on_reindex_iff [symmetric]) (force simp: inj_on_def fun_eq_iff S_def)
  also have "… ⟷ (λf. if ∀x. x ∉ A ⟶ f x = dflt then ∏x∈A. pmf (p x) (f x) else 0)
                          abs_summable_on UNIV"
    by (intro abs_summable_on_cong_neutral) (auto simp: S_def)
  finally have summable: … .

  (* Each component sums to 1, and the product of the sums is the sum of the products. *)
  have "1 = (∏x∈A. 1::real)" by simp
  also have "(∏x∈A. 1) = (∏x∈A. ∑⇩ay∈B x. pmf (p x) y)"
    unfolding B_def by (subst infsetsum_pmf_eq_1) auto
  also have "(∏x∈A. ∑⇩ay∈B x. pmf (p x) y) = (∑⇩af∈PiE A B. ∏x∈A. pmf (p x) (f x))"
    by (intro infsetsum_prod_PiE [symmetric] A) (auto simp: B_def)
  also have "… = (∑⇩af∈(λf. restrict f A) ` S. ∏x∈A. pmf (p x) (f x))" using A
    by (intro infsetsum_cong_neutral neutral_left neutral_right refl)
  also have "… = (∑⇩af∈S. ∏x∈A. pmf (p x) (restrict f A x))"
    by (rule infsetsum_reindex) (force simp: inj_on_def fun_eq_iff S_def)
  also have "… = (∑⇩af∈S. ∏x∈A. pmf (p x) (f x))"
    by (intro infsetsum_cong) (auto simp: S_def)
  also have "… = (∑⇩af. if ∀x. x ∉ A ⟶ f x = dflt then ∏x∈A. pmf (p x) (f x) else 0)"
    by (intro infsetsum_cong_neutral) (auto simp: S_def)
  also have "ennreal … = (∫⇧+f. ennreal (if ∀x. x ∉ A ⟶ f x = dflt
                             then ∏x∈A. pmf (p x) (f x) else 0) ∂count_space UNIV)"
    by (intro nn_integral_conv_infsetsum [symmetric] summable) (auto simp: prod_nonneg)
  finally show ?case by simp
qed (auto simp: prod_nonneg)

(* Simplified form of pmf_Pi for functions known to return dflt outside A. *)
lemma pmf_Pi':
  assumes "finite A" "⋀x. x ∉ A ⟹ f x = dflt"
  shows   "pmf (Pi_pmf A dflt p) f = (∏x∈A. pmf (p x) (f x))"
  using assms by (subst pmf_Pi) auto

(* A function that differs from dflt somewhere outside A has probability 0. *)
lemma pmf_Pi_outside:
  assumes "finite A" "∃x. x ∉ A ∧ f x ≠ dflt"
  shows   "pmf (Pi_pmf A dflt p) f = 0"
  using assms by (subst pmf_Pi) auto

(* The product over the empty index set is the point mass on the constant
   function λ_. dflt. *)
lemma pmf_Pi_empty [simp]: "Pi_pmf {} dflt p = return_pmf (λ_. dflt)"
  by (intro pmf_eqI, subst pmf_Pi) (auto simp: indicator_def)

(* Every function in the support of Pi_pmf returns dflt outside A. *)
lemma set_Pi_pmf_subset: "finite A ⟹ set_pmf (Pi_pmf A dflt p) ⊆ {f. ∀x. x ∉ A ⟶ f x = dflt}"
  by (auto simp: set_pmf_eq pmf_Pi)

(* Congruence rule: Pi_pmf depends on the component PMFs only at indices in A.
   As a [cong] rule, it lets the simplifier use x ∈ A while rewriting f x. *)
lemma Pi_pmf_cong [cong]:
  assumes "A = A'" "dflt = dflt'" "⋀x. x ∈ A ⟹ f x = f' x"
  shows   "Pi_pmf A dflt f = Pi_pmf A' dflt' f'"
proof -
  have "(λg. ∏x∈A. pmf (f x) (g x)) = (λg. ∏x∈A. pmf (f' x) (g x))"
    by (intro ext prod.cong) (auto simp: assms)
  with assms show ?thesis by (simp add: Pi_pmf_def cong: if_cong)
qed


subsection ‹Dependent product sets with a default›

text ‹
  The following describes a dependent product of sets where the functions are required to return
  the default value @{term dflt} outside their domain, in analogy to @{const PiE}, which uses
  @{const undefined}.
›
(* Functions mapping each x ∈ A into B x and every x ∉ A to dflt. *)
definition PiE_dflt
  where "PiE_dflt A dflt B = {f. ∀x. (x ∈ A ⟶ f x ∈ B x) ∧ (x ∉ A ⟶ f x = dflt)}"

(* Restricting to A maps PiE_dflt A dflt B onto PiE A B. *)
lemma restrict_PiE_dflt: "(λh. restrict h A) ` PiE_dflt A dflt B = PiE A B"
proof (intro equalityI subsetI)
  fix h assume "h ∈ (λh. restrict h A) ` PiE_dflt A dflt B"
  thus "h ∈ PiE A B"
    by (auto simp: PiE_dflt_def)
next
  fix h assume h: "h ∈ PiE A B"
  (* Extend h by dflt outside A, then restrict it back. *)
  hence "restrict (λx. if x ∈ A then h x else dflt) A ∈ (λh. restrict h A) ` PiE_dflt A dflt B"
    by (intro imageI) (auto simp: PiE_def extensional_def PiE_dflt_def)
  also have "restrict (λx. if x ∈ A then h x else dflt) A = h"
    using h by (auto simp: fun_eq_iff)
  finally show "h ∈ (λh. restrict h A) ` PiE_dflt A dflt B" .
qed

(* Extending by dflt outside A maps PiE A B onto PiE_dflt A dflt B. *)
lemma dflt_image_PiE: "(λh x. if x ∈ A then h x else dflt) ` PiE A B = PiE_dflt A dflt B"
  (is "?f ` ?X = ?Y")
proof (intro equalityI subsetI)
  fix h assume "h ∈ ?f ` ?X"
  thus "h ∈ ?Y"
    by (auto simp: PiE_dflt_def PiE_def)
next
  fix h assume h: "h ∈ ?Y"
  (* h is recovered by extending its restriction to A. *)
  hence "?f (restrict h A) ∈ ?f ` ?X"
    by (intro imageI) (auto simp: PiE_def extensional_def PiE_dflt_def)
  also have "?f (restrict h A) = h"
    using h by (auto simp: fun_eq_iff PiE_dflt_def)
  finally show "h ∈ ?f ` ?X" .
qed

(* PiE_dflt is finite when A and all B x (x ∈ A) are finite; it is the image of PiE A B. *)
lemma finite_PiE_dflt [intro]:
  assumes "finite A" "⋀x. x ∈ A ⟹ finite (B x)"
  shows   "finite (PiE_dflt A d B)"
proof -
  have "PiE_dflt A d B = (λf x. if x ∈ A then f x else d) ` PiE A B"
    by (rule dflt_image_PiE [symmetric])
  also have "finite …"
    by (intro finite_imageI finite_PiE assms)
  finally show ?thesis .
qed

(* Cardinality of PiE_dflt: restriction to A is injective on it and maps it onto PiE A B. *)
lemma card_PiE_dflt:
  assumes "finite A" "⋀x. x ∈ A ⟹ finite (B x)"
  shows   "card (PiE_dflt A d B) = (∏x∈A. card (B x))"
proof -
  from assms have "(∏x∈A. card (B x)) = card (PiE A B)"
    by (intro card_PiE [symmetric]) auto
  also have "PiE A B = (λf. restrict f A) ` PiE_dflt A d B"
    by (rule restrict_PiE_dflt [symmetric])
  also have "card … = card (PiE_dflt A d B)"
    by (intro card_image) (force simp: inj_on_def restrict_def fun_eq_iff PiE_dflt_def)
  finally show ?thesis ..
qed

(* PiE_dflt A dflt B is empty iff some B x with x ∈ A is empty. *)
lemma PiE_dflt_empty_iff [simp]: "PiE_dflt A dflt B = {} ⟷ (∃x∈A. B x = {})"
  by (simp add: dflt_image_PiE [symmetric] PiE_eq_empty_iff)

text ‹
  The probability of an independent combination of events is precisely the product
  of the probabilities of each individual event.
›
(* Product formula for the probability of PiE_dflt A dflt B.  The sum over PiE_dflt is
   moved to PiE A B by restriction, and each B x is intersected with the support of p x
   (presumably to supply the countability side condition of infsetsum_prod_PiE). *)
lemma measure_Pi_pmf_PiE_dflt:
  assumes [simp]: "finite A"
  shows   "measure_pmf.prob (Pi_pmf A dflt p) (PiE_dflt A dflt B) =
             (∏x∈A. measure_pmf.prob (p x) (B x))"
proof -
  define B' where "B' = (λx. B x ∩ set_pmf (p x))"
  have "measure_pmf.prob (Pi_pmf A dflt p) (PiE_dflt A dflt B) =
          (∑⇩ah∈PiE_dflt A dflt B. pmf (Pi_pmf A dflt p) h)"
    by (rule measure_pmf_conv_infsetsum)
  also have "… = (∑⇩ah∈PiE_dflt A dflt B. ∏x∈A. pmf (p x) (h x))"
    by (intro infsetsum_cong, subst pmf_Pi') (auto simp: PiE_dflt_def)
  also have "… = (∑⇩ah∈(λh. restrict h A) ` PiE_dflt A dflt B. ∏x∈A. pmf (p x) (h x))"
    by (subst infsetsum_reindex) (force simp: inj_on_def PiE_dflt_def fun_eq_iff)+
  also have "(λh. restrict h A) ` PiE_dflt A dflt B = PiE A B"
    by (rule restrict_PiE_dflt)
  also have "(∑⇩ah∈PiE A B. ∏x∈A. pmf (p x) (h x)) = (∑⇩ah∈PiE A B'. ∏x∈A. pmf (p x) (h x))"
    by (intro infsetsum_cong_neutral) (auto simp: B'_def set_pmf_eq)
  also have "(∑⇩ah∈PiE A B'. ∏x∈A. pmf (p x) (h x)) = (∏x∈A. infsetsum (pmf (p x)) (B' x))"
    by (intro infsetsum_prod_PiE) (auto simp: B'_def)
  also have "… = (∏x∈A. infsetsum (pmf (p x)) (B x))"
    by (intro prod.cong infsetsum_cong_neutral) (auto simp: B'_def set_pmf_eq)
  also have "… = (∏x∈A. measure_pmf.prob (p x) (B x))"
    by (subst measure_pmf_conv_infsetsum) (rule refl)
  finally show ?thesis .
qed

(* The support of Pi_pmf lies in PiE_dflt over the component supports. *)
lemma set_Pi_pmf_subset':
  assumes "finite A"
  shows   "set_pmf (Pi_pmf A dflt p) ⊆ PiE_dflt A dflt (set_pmf ∘ p)"
  using assms by (auto simp: set_pmf_eq pmf_Pi PiE_dflt_def)

(* A product of point masses is the point mass on the corresponding function,
   extended by dflt outside A. *)
lemma Pi_pmf_return_pmf [simp]:
  assumes "finite A"
  shows   "Pi_pmf A dflt (λx. return_pmf (f x)) = return_pmf (λx. if x ∈ A then f x else dflt)"
proof -
  (* The support is contained in a singleton, hence the distribution is a point mass. *)
  have "set_pmf (Pi_pmf A dflt (λx. return_pmf (f x))) ⊆
          PiE_dflt A dflt (set_pmf ∘ (λx. return_pmf (f x)))"
    by (intro set_Pi_pmf_subset' assms)
  also have "… ⊆ {λx. if x ∈ A then f x else dflt}"
    by (auto simp: PiE_dflt_def)
  finally show ?thesis
    by (simp add: set_pmf_subset_singleton)
qed

(* Special case where every component is the point mass on dflt; closed by simp
   (presumably via the [simp] rule Pi_pmf_return_pmf). *)
lemma Pi_pmf_return_pmf' [simp]:
  assumes "finite A"
  shows   "Pi_pmf A dflt (λ_. return_pmf dflt) = return_pmf (λ_. dflt)"
  using assms by simp

(* Product formula for the dependent function set Pi A B instead of PiE_dflt A dflt B.
   (The unused "fixes t::nat" of the original statement has been removed; it did not
   occur in the theorem.) *)
lemma measure_Pi_pmf_Pi:
  assumes [simp]: "finite A"
  shows   "measure_pmf.prob (Pi_pmf A dflt p) (Pi A B) =
             (∏x∈A. measure_pmf.prob (p x) (B x))" (is "?lhs = ?rhs")
proof -
  (* The two sets differ only by functions that are not dflt outside A,
     and those have mass 0 by pmf_Pi_outside. *)
  have "?lhs = measure_pmf.prob (Pi_pmf A dflt p) (PiE_dflt A dflt B)"
    by (intro measure_prob_cong_0)
       (auto simp: PiE_dflt_def PiE_def intro!: pmf_Pi_outside)+
  also have "… = ?rhs"
    using assms by (simp add: measure_Pi_pmf_PiE_dflt)
  finally show ?thesis
    by simp
qed


subsection ‹Common PMF operations on products›

text ‹
  @{const Pi_pmf} distributes over the `bind' operation in the Giry monad:
›
(* Binding each component separately equals first drawing the whole product of the
   p x (default d') and then the product of the q x (f x) (default d). *)
lemma Pi_pmf_bind:
  assumes "finite A"
  shows   "Pi_pmf A d (λx. bind_pmf (p x) (q x)) =
             do {f ← Pi_pmf A d' p; Pi_pmf A d (λx. q x (f x))}" (is "?lhs = ?rhs")
proof (rule pmf_eqI, goal_cases)
  case (1 f)
  show ?case
  proof (cases "∃x∈-A. f x ≠ d")
    case False
    (* f equals d outside A: compute both masses as sums over PiE A B. *)
    define B where "B = (λx. set_pmf (p x))"
    have [simp]: "countable (B x)" for x by (auto simp: B_def)

    {
      fix x :: 'a
      have "(λa. pmf (p x) a * 1) abs_summable_on B x"
        by (simp add: pmf_abs_summable)
      moreover have "norm (pmf (p x) a * 1) ≥ norm (pmf (p x) a * pmf (q x a) (f x))" for a
        unfolding norm_mult by (intro mult_left_mono) (auto simp: pmf_le_1)
      ultimately have "(λa. pmf (p x) a * pmf (q x a) (f x)) abs_summable_on B x"
        by (rule abs_summable_on_comparison_test)
    } note summable = this

    have "pmf ?rhs f = (∑⇩ag. pmf (Pi_pmf A d' p) g * (∏x∈A. pmf (q x (g x)) (f x)))"
      by (subst pmf_bind, subst pmf_Pi')
         (insert assms False, simp_all add: pmf_expectation_eq_infsetsum)
    also have "… = (∑⇩ag∈PiE_dflt A d' B.
                      pmf (Pi_pmf A d' p) g * (∏x∈A. pmf (q x (g x)) (f x)))" unfolding B_def
      using assms by (intro infsetsum_cong_neutral) (auto simp: pmf_Pi PiE_dflt_def set_pmf_eq)
    also have "… = (∑⇩ag∈PiE_dflt A d' B.
                      (∏x∈A. pmf (p x) (g x) * pmf (q x (g x)) (f x)))"
      using assms by (intro infsetsum_cong) (auto simp: pmf_Pi PiE_dflt_def prod.distrib)
    also have "… = (∑⇩ag∈(λg. restrict g A) ` PiE_dflt A d' B.
                      (∏x∈A. pmf (p x) (g x) * pmf (q x (g x)) (f x)))"
      by (subst infsetsum_reindex) (force simp: PiE_dflt_def inj_on_def fun_eq_iff)+
    also have "(λg. restrict g A) ` PiE_dflt A d' B = PiE A B"
      by (rule restrict_PiE_dflt)
    also have "(∑⇩ag∈…. (∏x∈A. pmf (p x) (g x) * pmf (q x (g x)) (f x))) =
                 (∏x∈A. ∑⇩aa∈B x. pmf (p x) a * pmf (q x a) (f x))"
      using assms summable by (subst infsetsum_prod_PiE) simp_all
    also have "… = (∏x∈A. ∑⇩aa. pmf (p x) a * pmf (q x a) (f x))"
      by (intro prod.cong infsetsum_cong_neutral) (auto simp: B_def set_pmf_eq)
    also have "… = pmf ?lhs f"
      using False assms by (subst pmf_Pi') (simp_all add: pmf_bind pmf_expectation_eq_infsetsum)
    finally show ?thesis ..
  next
    case True
    (* f deviates from d outside A, so both sides assign it mass 0. *)
    have "pmf ?rhs f =
            measure_pmf.expectation (Pi_pmf A d' p) (λx. pmf (Pi_pmf A d (λxa. q xa (x xa))) f)"
      using assms by (simp add: pmf_bind)
    also have "… = measure_pmf.expectation (Pi_pmf A d' p) (λx. 0)"
      using assms True by (intro Bochner_Integration.integral_cong pmf_Pi_outside) auto
    also have "… = pmf ?lhs f"
      using assms True by (subst pmf_Pi_outside) auto
    finally show ?thesis ..
  qed
qed

text ‹
  Analogously, any componentwise mapping can be pulled outside the product:
›
(* A componentwise map_pmf with a function f sending dflt to dflt' can be pulled out
   of the product as post-composition with f. *)
lemma Pi_pmf_map:
  assumes [simp]: "finite A" and "f dflt = dflt'"
  shows   "Pi_pmf A dflt' (λx. map_pmf f (g x)) = map_pmf (λh. f ∘ h) (Pi_pmf A dflt g)"
proof -
  have "Pi_pmf A dflt' (λx. map_pmf f (g x)) =
          Pi_pmf A dflt' (λx. g x ⤜ (λx. return_pmf (f x)))"
    using assms by (simp add: map_pmf_def Pi_pmf_bind)
  also have "… = Pi_pmf A dflt g ⤜ (λh. return_pmf (λx. if x ∈ A then f (h x) else dflt'))"
    by (subst Pi_pmf_bind[where d' = dflt]) auto
  (* On the support every h is dflt outside A, so the case split collapses to f ∘ h. *)
  also have "… = map_pmf (λh. f ∘ h) (Pi_pmf A dflt g)"
    unfolding map_pmf_def using set_Pi_pmf_subset'[of A dflt g]
    by (intro bind_pmf_cong refl arg_cong[of _ _ return_pmf])
       (auto simp: fun_eq_iff PiE_dflt_def assms(2))
  finally show ?thesis .
qed

text ‹
  We can exchange the default value in a product of PMFs like this:
›
(* Re-extending the functions of Pi_pmf A dflt p by dflt' outside A yields the
   product with default dflt'. *)
lemma Pi_pmf_default_swap:
  assumes "finite A"
  shows   "map_pmf (λf x. if x ∈ A then f x else dflt') (Pi_pmf A dflt p) =
             Pi_pmf A dflt' p" (is "?lhs = ?rhs")
proof (rule pmf_eqI, goal_cases)
  case (1 f)
  (* Preimage of f, cut down to the functions that can carry mass (dflt outside A). *)
  let ?B = "(λf x. if x ∈ A then f x else dflt') -` {f} ∩ PiE_dflt A dflt (λ_. UNIV)"
  show ?case
  proof (cases "∃x∈-A. f x ≠ dflt'")
    case False
    let ?f' = "λx. if x ∈ A then f x else dflt"
    from False have "pmf ?lhs f = measure_pmf.prob (Pi_pmf A dflt p) ?B"
      using assms unfolding pmf_map
      by (intro measure_prob_cong_0) (auto simp: PiE_dflt_def pmf_Pi_outside)
    also from False have "?B = {?f'}"
      by (auto simp: fun_eq_iff PiE_dflt_def)
    also have "measure_pmf.prob (Pi_pmf A dflt p) {?f'} = pmf (Pi_pmf A dflt p) ?f'"
      by (simp add: measure_pmf_single)
    also have "… = pmf ?rhs f"
      using False assms by (subst (1 2) pmf_Pi) auto
    finally show ?thesis .
  next
    case True
    have "pmf ?lhs f = measure_pmf.prob (Pi_pmf A dflt p) ?B"
      using assms unfolding pmf_map
      by (intro measure_prob_cong_0) (auto simp: PiE_dflt_def pmf_Pi_outside)
    also from True have "?B = {}" by auto
    also have "measure_pmf.prob (Pi_pmf A dflt p) … = 0"
      by simp
    also have "0 = pmf ?rhs f"
      using True assms by (intro pmf_Pi_outside [symmetric]) auto
    finally show ?thesis .
  qed
qed

text ‹
  The following rule allows reindexing the product:
›
(* Reindexing a product with identical components f along a bijection h between A
   and B that maps indices outside A to indices outside B. *)
lemma Pi_pmf_bij_betw:
  assumes "finite A" "bij_betw h A B" "⋀x. x ∉ A ⟹ h x ∉ B"
  shows "Pi_pmf A dflt (λ_. f) = map_pmf (λg. g ∘ h) (Pi_pmf B dflt (λ_. f))"
    (is "?lhs = ?rhs")
proof -
  have B: "finite B"
    using assms bij_betw_finite by auto
  have "pmf ?lhs g = pmf ?rhs g" for g
  proof (cases "∀a. a ∉ A ⟶ g a = dflt")
    case True
    (* h' is the inverse of h on A. *)
    define h' where "h' = the_inv_into A h"
    have h': "h' (h x) = x" if "x ∈ A" for x
      unfolding h'_def using that assms by (auto simp add: bij_betw_def the_inv_into_f_f)
    have h: "h (h' x) = x" if "x ∈ B" for x
      unfolding h'_def using that assms f_the_inv_into_f_bij_betw by fastforce
    have "pmf ?rhs g = measure_pmf.prob (Pi_pmf B dflt (λ_. f)) ((λg. g ∘ h) -` {g})"
      unfolding pmf_map by simp
    also have "… = measure_pmf.prob (Pi_pmf B dflt (λ_. f))
                                (((λg. g ∘ h) -` {g}) ∩ PiE_dflt B dflt (λ_. UNIV))"
      using B by (intro measure_prob_cong_0) (auto simp: PiE_dflt_def pmf_Pi_outside)
    also have "… = pmf (Pi_pmf B dflt (λ_. f)) (λx. if x ∈ B then g (h' x) else dflt)"
    proof -
      (* The cut-down preimage is the single function g ∘ h' (extended by dflt). *)
      have "(if h x ∈ B then g (h' (h x)) else dflt) = g x" for x
        using h' assms True by (cases "x ∈ A") (auto simp add: bij_betwE)
      then have "(λg. g ∘ h) -` {g} ∩ PiE_dflt B dflt (λ_. UNIV) =
            {(λx. if x ∈ B then g (h' x) else dflt)}"
        using assms h' h True unfolding PiE_dflt_def by auto
      then show ?thesis
        by (simp add: measure_pmf_single)
    qed
    also have "… = pmf (Pi_pmf A dflt (λ_. f)) g"
      using B assms True h'_def
      by (auto simp add: pmf_Pi intro!: prod.reindex_bij_betw bij_betw_the_inv_into)
    finally show ?thesis
      by simp
  next
    case False
    (* g is not dflt outside A: every preimage point has mass 0 on the right. *)
    have "pmf ?rhs g = infsetsum (pmf (Pi_pmf B dflt (λ_. f))) ((λg. g ∘ h) -` {g})"
      using assms by (auto simp add: measure_pmf_conv_infsetsum pmf_map)
    also have "… = infsetsum (λ_. 0) ((λg x. g (h x)) -` {g})"
      using B False assms by (intro infsetsum_cong pmf_Pi_outside) fastforce+
    also have "… = 0"
      by simp
    finally show ?thesis
      using assms False by (auto simp add: pmf_Pi pmf_map)
  qed
  then show ?thesis
    by (rule pmf_eqI)
qed

text ‹
  A product of uniform random choices is again a uniform distribution.
›
(* A product of uniform distributions on finite non-empty sets B x is uniform on
   PiE_dflt A d B. *)
lemma Pi_pmf_of_set:
  assumes "finite A" "⋀x. x ∈ A ⟹ finite (B x)" "⋀x. x ∈ A ⟹ B x ≠ {}"
  shows   "Pi_pmf A d (λx. pmf_of_set (B x)) = pmf_of_set (PiE_dflt A d B)" (is "?lhs = ?rhs")
proof (rule pmf_eqI, goal_cases)
  case (1 f)
  show ?case
  proof (cases "∃x. x ∉ A ∧ f x ≠ d")
    case True
    (* f is not in PiE_dflt A d B: both masses are 0. *)
    hence "pmf ?lhs f = 0"
      using assms by (intro pmf_Pi_outside) (auto simp: PiE_dflt_def)
    also from True have "f ∉ PiE_dflt A d B"
      by (auto simp: PiE_dflt_def)
    hence "0 = pmf ?rhs f"
      using assms by (subst pmf_of_set) auto
    finally show ?thesis .
  next
    case False
    hence "pmf ?lhs f = (∏x∈A. pmf (pmf_of_set (B x)) (f x))"
      using assms by (subst pmf_Pi') auto
    also have "… = (∏x∈A. indicator (B x) (f x) / real (card (B x)))"
      by (intro prod.cong refl, subst pmf_of_set) (use assms False in auto)
    also have "… = (∏x∈A. indicator (B x) (f x)) / real (∏x∈A. card (B x))"
      by (subst prod_dividef) simp_all
    also have "(∏x∈A. indicator (B x) (f x) :: real) = indicator (PiE_dflt A d B) f"
      using assms False by (auto simp: indicator_def PiE_dflt_def)
    also have "(∏x∈A. card (B x)) = card (PiE_dflt A d B)"
      using assms by (intro card_PiE_dflt [symmetric]) auto
    also have "indicator (PiE_dflt A d B) f / … = pmf ?rhs f"
      using assms by (intro pmf_of_set [symmetric]) auto
    finally show ?thesis .
  qed
qed


subsection ‹Merging and splitting PMF products›

text ‹
  The following lemma shows that we can add a single PMF to a product:
›
(* Adding a fresh index x ∉ A: draw the value at x from p x independently of the
   product over A, and update the function at x. *)
lemma Pi_pmf_insert:
  assumes "finite A" "x ∉ A"
  shows   "Pi_pmf (insert x A) dflt p = map_pmf (λ(y,f). f(x:=y)) (pair_pmf (p x) (Pi_pmf A dflt p))"
proof (intro pmf_eqI)
  fix f
  let ?M = "pair_pmf (p x) (Pi_pmf A dflt p)"
  have "pmf (map_pmf (λ(y, f). f(x := y)) ?M) f =
          measure_pmf.prob ?M ((λ(y, f). f(x := y)) -` {f})"
    by (subst pmf_map) auto
  also have "((λ(y, f). f(x := y)) -` {f}) = (⋃y'. {(f x, f(x := y'))})"
    by (auto simp: fun_upd_def fun_eq_iff)
  (* Only the pair whose second component is dflt at x carries mass. *)
  also have "measure_pmf.prob ?M … = measure_pmf.prob ?M {(f x, f(x := dflt))}"
    using assms by (intro measure_prob_cong_0) (auto simp: pmf_pair pmf_Pi split: if_splits)
  also have "… = pmf (p x) (f x) * pmf (Pi_pmf A dflt p) (f(x := dflt))"
    by (simp add: measure_pmf_single pmf_pair pmf_Pi)
  also have "… = pmf (Pi_pmf (insert x A) dflt p) f"
  proof (cases "∀y. y ∉ insert x A ⟶ f y = dflt")
    case True
    with assms have "pmf (p x) (f x) * pmf (Pi_pmf A dflt p) (f(x := dflt)) =
                       pmf (p x) (f x) * (∏xa∈A. pmf (p xa) ((f(x := dflt)) xa))"
      by (subst pmf_Pi') auto
    also have "(∏xa∈A. pmf (p xa) ((f(x := dflt)) xa)) = (∏xa∈A. pmf (p xa) (f xa))"
      using assms by (intro prod.cong) auto
    also have "pmf (p x) (f x) * … = pmf (Pi_pmf (insert x A) dflt p) f"
      using assms True by (subst pmf_Pi') auto
    finally show ?thesis .
  qed (insert assms, auto simp: pmf_Pi)
  finally show "… = pmf (map_pmf (λ(y, f). f(x := y)) ?M) f" ..
qed

(* Monadic (do-notation) formulation of Pi_pmf_insert. *)
lemma Pi_pmf_insert':
  assumes "finite A" "x ∉ A"
  shows   "Pi_pmf (insert x A) dflt p =
             do {y ← p x; f ← Pi_pmf A dflt p; return_pmf (f(x := y))}"
  using assms
  by (subst Pi_pmf_insert)
     (auto simp add: map_pmf_def pair_pmf_def case_prod_beta' bind_return_pmf bind_assoc_pmf)

(* Product over a single index x: the value at x is drawn from p x and every other
   index is mapped to dflt.  Derived from Pi_pmf_insert with A = {}. *)
lemma Pi_pmf_singleton:
  "Pi_pmf {x} dflt p = map_pmf (λa b. if b = x then a else dflt) (p x)"
proof -
  have "Pi_pmf {x} dflt p = map_pmf (fun_upd (λ_. dflt) x) (p x)"
    by (subst Pi_pmf_insert) (simp_all add: pair_return_pmf2 pmf.map_comp o_def)
  also have "fun_upd (λ_. dflt) x = (λz y. if y = x then z else dflt)"
    by (simp add: fun_upd_def fun_eq_iff)
  finally show ?thesis .
qed

text ‹
  Projecting a product of PMFs onto a component yields the expected result:
›
(* The marginal at x is p x for x ∈ A, and the point mass on dflt otherwise. *)
lemma Pi_pmf_component:
  assumes "finite A"
  shows   "map_pmf (λf. f x) (Pi_pmf A dflt p) = (if x ∈ A then p x else return_pmf dflt)"
proof (cases "x ∈ A")
  case True
  (* Split off x and project onto the first component of the pair. *)
  define A' where "A' = A - {x}"
  from assms and True have A': "A = insert x A'"
    by (auto simp: A'_def)
  from assms have "map_pmf (λf. f x) (Pi_pmf A dflt p) = p x" unfolding A'
    by (subst Pi_pmf_insert)
       (auto simp: A'_def pmf.map_comp o_def case_prod_unfold map_fst_pair_pmf)
  with True show ?thesis by simp
next
  case False
  (* Every function in the support is dflt at x. *)
  have "map_pmf (λf. f x) (Pi_pmf A dflt p) = map_pmf (λ_. dflt) (Pi_pmf A dflt p)"
    using assms False set_Pi_pmf_subset[of A dflt p]
    by (intro pmf.map_cong refl) (auto simp: set_pmf_eq pmf_Pi_outside)
  with False show ?thesis by simp
qed

text ‹
  We can merge two PMF products on disjoint sets like this:
›
(* The product over a disjoint union A ∪ B combines independent products over A and B.
   Proof by induction on the finite set A. *)
lemma Pi_pmf_union:
  assumes "finite A" "finite B" "A ∩ B = {}"
  shows   "Pi_pmf (A ∪ B) dflt p =
             map_pmf (λ(f,g) x. if x ∈ A then f x else g x)
             (pair_pmf (Pi_pmf A dflt p) (Pi_pmf B dflt p))" (is "_ = map_pmf (?h A) (?q A)")
  using assms(1,3)
proof (induction rule: finite_induct)
  case (insert x A)
  have "map_pmf (?h (insert x A)) (?q (insert x A)) =
          do {v ← p x; (f, g) ← pair_pmf (Pi_pmf A dflt p) (Pi_pmf B dflt p);
              return_pmf (λy. if y ∈ insert x A then (f(x := v)) y else g y)}"
    by (subst Pi_pmf_insert)
       (insert insert.hyps insert.prems,
        simp_all add: pair_pmf_def map_bind_pmf bind_map_pmf bind_assoc_pmf bind_return_pmf)
  also have "… = do {v ← p x; (f, g) ← ?q A; return_pmf ((?h A (f,g))(x := v))}"
    by (intro bind_pmf_cong refl) (auto simp: fun_eq_iff)
  also have "… = do {v ← p x; f ← map_pmf (?h A) (?q A); return_pmf (f(x := v))}"
    by (simp add: bind_map_pmf map_bind_pmf case_prod_unfold cong: if_cong)
  (* Apply the induction hypothesis for A. *)
  also have "… = do {v ← p x; f ← Pi_pmf (A ∪ B) dflt p; return_pmf (f(x := v))}"
    using insert.hyps and insert.prems by (intro bind_pmf_cong insert.IH [symmetric] refl) auto
  also have "… = Pi_pmf (insert x (A ∪ B)) dflt p"
    by (subst Pi_pmf_insert)
       (insert assms insert.hyps insert.prems, auto simp: pair_pmf_def map_bind_pmf)
  also have "insert x (A ∪ B) = insert x A ∪ B"
    by simp
  finally show ?case ..
qed (simp_all add: case_prod_unfold map_snd_pair_pmf)

text ‹
  We can also project a product to a subset of the indices by mapping all the other
  indices to the default value:
›
(* Marginalizing onto A' ⊆ A by resetting the indices outside A' to dflt. *)
lemma Pi_pmf_subset:
  assumes "finite A" "A' ⊆ A"
  shows   "Pi_pmf A' dflt p = map_pmf (λf x. if x ∈ A' then f x else dflt) (Pi_pmf A dflt p)"
proof -
  let ?P = "pair_pmf (Pi_pmf A' dflt p) (Pi_pmf (A - A') dflt p)"
  from assms have [simp]: "finite A'"
    by (blast dest: finite_subset)
  (* Split A into the disjoint parts A' and A - A', then keep only the first factor. *)
  from assms have "A = A' ∪ (A - A')"
    by blast
  also have "Pi_pmf … dflt p = map_pmf (λ(f,g) x. if x ∈ A' then f x else g x) ?P"
    using assms by (intro Pi_pmf_union) auto
  also have "map_pmf (λf x. if x ∈ A' then f x else dflt) … = map_pmf fst ?P"
    unfolding map_pmf_comp o_def case_prod_unfold
    using set_Pi_pmf_subset[of A' dflt p] by (intro map_pmf_cong refl) (auto simp: fun_eq_iff)
  also have "… = Pi_pmf A' dflt p"
    by (simp add: map_fst_pair_pmf)
  finally show ?thesis ..
qed

(* Indices in A - B whose component is the point mass on dflt can be dropped. *)
lemma Pi_pmf_subset':
  fixes f :: "'a ⇒ 'b pmf"
  assumes "finite A" "B ⊆ A" "⋀x. x ∈ A - B ⟹ f x = return_pmf dflt"
  shows "Pi_pmf A dflt f = Pi_pmf B dflt f"
proof -
  have "Pi_pmf (B ∪ (A - B)) dflt f =
          map_pmf (λ(f, g) x. if x ∈ B then f x else g x)
                  (pair_pmf (Pi_pmf B dflt f) (Pi_pmf (A - B) dflt f))"
    using assms by (intro Pi_pmf_union) (auto dest: finite_subset)
  (* The product over A - B is a point mass on λ_. dflt. *)
  also have "Pi_pmf (A - B) dflt f = Pi_pmf (A - B) dflt (λ_. return_pmf dflt)"
    using assms by (intro Pi_pmf_cong) auto
  also have "… = return_pmf (λ_. dflt)"
    using assms by simp
  also have "map_pmf (λ(f, g) x. if x ∈ B then f x else g x)
                  (pair_pmf (Pi_pmf B dflt f) (return_pmf (λ_. dflt))) =
             map_pmf (λf x. if x ∈ B then f x else dflt) (Pi_pmf B dflt f)"
    by (simp add: map_pmf_def pair_pmf_def bind_assoc_pmf bind_return_pmf bind_return_pmf')
  also have "… = Pi_pmf B dflt f"
    using assms by (intro Pi_pmf_default_swap) (auto dest: finite_subset)
  also have "B ∪ (A - B) = A"
    using assms by auto
  finally show ?thesis .
qed

(* Components that are the point mass on dflt whenever b fails can be removed
   from the index set. *)
lemma Pi_pmf_if_set:
  assumes "finite A"
  shows "Pi_pmf A dflt (λx. if b x then f x else return_pmf dflt) =
           Pi_pmf {x∈A. b x} dflt f"
proof -
  have "Pi_pmf A dflt (λx. if b x then f x else return_pmf dflt) =
          Pi_pmf {x∈A. b x} dflt (λx. if b x then f x else return_pmf dflt)"
    using assms by (intro Pi_pmf_subset') auto
  also have "… = Pi_pmf {x∈A. b x} dflt f"
    by (intro Pi_pmf_cong) auto
  finally show ?thesis .
qed

(* Dual of Pi_pmf_if_set: components that are the point mass on dflt whenever b
   holds can be removed from the index set. *)
lemma Pi_pmf_if_set':
  assumes "finite A"
  shows "Pi_pmf A dflt (λx. if b x then return_pmf dflt else f x) =
         Pi_pmf {x∈A. ¬b x} dflt f"
proof -
  have "Pi_pmf A dflt (λx. if b x then return_pmf dflt else f x) =
          Pi_pmf {x∈A. ¬b x} dflt (λx. if b x then return_pmf dflt else f x)"
    using assms by (intro Pi_pmf_subset') auto
  also have "… = Pi_pmf {x∈A. ¬b x} dflt f"
    by (intro Pi_pmf_cong) auto
  finally show ?thesis .
qed

text ‹
  Lastly, we can delete a single component from a product:
›
(* Dropping index x from A corresponds to resetting the value at x to dflt. *)
lemma Pi_pmf_remove:
  assumes "finite A"
  shows   "Pi_pmf (A - {x}) dflt p = map_pmf (λf. f(x := dflt)) (Pi_pmf A dflt p)"
proof -
  have "Pi_pmf (A - {x}) dflt p =
          map_pmf (λf xa. if xa ∈ A - {x} then f xa else dflt) (Pi_pmf A dflt p)"
    using assms by (intro Pi_pmf_subset) auto
  (* On the support, functions are already dflt outside A. *)
  also have "… = map_pmf (λf. f(x := dflt)) (Pi_pmf A dflt p)"
    using set_Pi_pmf_subset[of A dflt p] assms
    by (intro map_pmf_cong refl) (auto simp: fun_eq_iff)
  finally show ?thesis .
qed


subsection ‹Applications›

text ‹
  Choosing a subset of a set uniformly at random is equivalent to tossing a fair coin
  independently for each element and collecting all the elements that came up heads.
›
(* A uniformly random subset of A arises from independent fair coin flips, one per
   element, keeping the elements whose flip is True. *)
lemma pmf_of_set_Pow_conv_bernoulli:
  assumes "finite (A :: 'a set)"
  shows "map_pmf (λb. {x∈A. b x}) (Pi_pmf A P (λ_. bernoulli_pmf (1/2))) = pmf_of_set (Pow A)"
proof -
  have "Pi_pmf A P (λ_. bernoulli_pmf (1/2)) = pmf_of_set (PiE_dflt A P (λx. UNIV))"
    using assms by (simp add: bernoulli_pmf_half_conv_pmf_of_set Pi_pmf_of_set)
  also have "map_pmf (λb. {x∈A. b x}) … = pmf_of_set (Pow A)"
  proof -
    (* Characteristic functions (dflt P outside A) are in bijection with subsets of A. *)
    have "bij_betw (λb. {x ∈ A. b x}) (PiE_dflt A P (λ_. UNIV)) (Pow A)"
      by (rule bij_betwI[of _ _ _ "λB b. if b ∈ A then b ∈ B else P"]) (auto simp add: PiE_dflt_def)
    then show ?thesis
      using assms by (intro map_pmf_of_set_bij_betw) auto
  qed
  finally show ?thesis
    by simp
qed

text ‹
  A binomial distribution can be seen as the number of successes in ‹n› independent coin tosses.
›
(* binomial_pmf n p is the distribution of the number of successes among card A = n
   independent Bernoulli(p) trials indexed by A.  Proof by induction on the finite set A. *)
lemma binomial_pmf_altdef':
  fixes A :: "'a set"
  assumes "finite A" and "card A = n" and p: "p ∈ {0..1}"
  shows   "binomial_pmf n p =
             map_pmf (λf. card {x∈A. f x}) (Pi_pmf A dflt (λ_. bernoulli_pmf p))" (is "?lhs = ?rhs")
proof -
  from assms have "?lhs = binomial_pmf (card A) p"
    by simp
  also have "… = ?rhs"
    using assms(1)
  proof (induction rule: finite_induct)
    case empty
    with p show ?case by (simp add: binomial_pmf_0)
  next
    case (insert x A)
    from insert.hyps have "card (insert x A) = Suc (card A)"
      by simp
    also have "binomial_pmf … p = do {
                                     b ← bernoulli_pmf p;
                                     f ← Pi_pmf A dflt (λ_. bernoulli_pmf p);
                                     return_pmf ((if b then 1 else 0) + card {y ∈ A. f y})
                                   }"
      using p by (simp add: binomial_pmf_Suc insert.IH bind_map_pmf)
    (* Count the new trial at x together with the old ones via the updated function. *)
    also have "… = do {
                      b ← bernoulli_pmf p;
                      f ← Pi_pmf A dflt (λ_. bernoulli_pmf p);
                      return_pmf (card {y ∈ insert x A. (f(x := b)) y})
                    }"
    proof (intro bind_pmf_cong refl, goal_cases)
      case (1 b f)
      have "(if b then 1 else 0) + card {y∈A. f y} = card ((if b then {x} else {}) ∪ {y∈A. f y})"
        using insert.hyps by auto
      also have "(if b then {x} else {}) ∪ {y∈A. f y} = {y∈insert x A. (f(x := b)) y}"
        using insert.hyps by auto
      finally show ?case by simp
    qed
    also have "… = map_pmf (λf. card {y∈insert x A. f y})
                      (Pi_pmf (insert x A) dflt (λ_. bernoulli_pmf p))"
      using insert.hyps by (subst Pi_pmf_insert) (simp_all add: pair_pmf_def map_bind_pmf)
    finally show ?case .
  qed
  finally show ?thesis .
qed

end