(* Theory Rel_PMF_Characterisation *)

(* Author: Andreas Lochbihler, ETH Zurich *)

theory Rel_PMF_Characterisation imports
  Matrix_For_Marginals
begin

section ‹Characterisation of @{const rel_pmf}›

(* Sufficient condition for rel_pmf: if every set A has at most as much
   p-mass as the q-mass of its R-image {y. ∃x∈A. R x y}, then p and q are
   related by R.

   Proof outline:
   - Apply bounded_matrix_for_marginals_ennreal (from Matrix_For_Marginals)
     to obtain a matrix h with these properties:
     * h is positive only on R-related support pairs;
     * every entry of h is finite;
     * its row sums are pmf p and its column sums are pmf q.
   - Use h as the weights of the joint distribution z.
   - Check that z is supported on R and has marginals p and q. *)
proposition rel_pmf_measureI:
  fixes p :: "'a pmf" and q :: "'b pmf"
  assumes le: "⋀A. measure (measure_pmf p) A ≤ measure (measure_pmf q) {y. ∃x∈A. R x y}"
  shows "rel_pmf R p q"
proof -
  (* Rows of the matrix are indexed by the support of p and columns by the
     support of q.  R' keeps only the R-related pairs of support points. *)
  let ?A = "set_pmf p" and ?f = "λx. ennreal (pmf p x)"
    and ?B = "set_pmf q" and ?g = "λy. ennreal (pmf q y)"
  define R' where "R' = {(x, y)∈?A × ?B. R x y}"

  (* The facts collected with moreover below are the premises discharged when
     applying bounded_matrix_for_marginals_ennreal.  The first two say that
     the total masses agree and are finite; both totals equal 1. *)
  have "(∑⇧+ x∈?A. ?f x) = (∑⇧+ y∈?B. ?g y)" (is "?lhs = ?rhs")
    and "(∑⇧+ y∈?B. ?g y) ≠ ⊤" (is ?bounded)
  proof -
    have "?lhs = (∑⇧+ x. ?f x)" "?rhs = (∑⇧+ y. ?g y)"
      by(auto simp add: nn_integral_count_space_indicator pmf_eq_0_set_pmf intro!: nn_integral_cong split: split_indicator)
    then show "?lhs = ?rhs" ?bounded by(simp_all add: nn_integral_pmf_eq_1)
  qed
  moreover
  (* Marginal condition: the p-mass of any subset X of the support of p is
     bounded by the q-mass of its R'-image.  This follows from le. *)
  have "(∑⇧+ x∈X. ?f x) ≤ (∑⇧+ y∈R' `` X. ?g y)" (is "?lhs ≤ ?rhs") if "X ⊆ set_pmf p" for X
  proof -
    have "?lhs = measure (measure_pmf p) X" 
      by(simp add: nn_integral_pmf measure_pmf.emeasure_eq_measure)
    also have "… ≤ measure (measure_pmf q) {y. ∃x∈X. R x y}" by(simp add: le)
    (* Passing from the R-image to R' `` X only drops points outside
       set_pmf q, which are q-almost-surely absent. *)
    also have "… = measure (measure_pmf q) (R' `` X)" using that
      by(auto 4 3 simp add: R'_def AE_measure_pmf_iff intro!: measure_eq_AE)
    also have "… = ?rhs" by(simp add: nn_integral_pmf measure_pmf.emeasure_eq_measure)
    finally show ?thesis .
  qed
  moreover have "countable ?A" "countable ?B" by simp_all
  moreover have "R' ⊆ ?A × ?B" by(auto simp add: R'_def)
  ultimately obtain h
    where supp: "⋀x y. 0 < h x y ⟹ (x, y) ∈ R'"
    and bound: "⋀x y. h x y ≠ ⊤"
    and p: "⋀x. x ∈ ?A ⟹ (∑⇧+ y∈?B. h x y) = ?f x"
    and q: "⋀y. y ∈ ?B ⟹ (∑⇧+ x∈?A. h x y) = ?g y"
    by(rule bounded_matrix_for_marginals_ennreal) blast+

  (* Candidate coupling z with weights h.  The conversion to real via
     enn2real loses nothing, because bound says every entry is finite. *)
  let ?z = "λ(x, y). enn2real (h x y)"
  define z where "z = embed_pmf ?z"
  have nonneg: "⋀xy. 0 ≤ ?z xy" by clarsimp
  (* h vanishes off R': outside either support, or on pairs not related by R. *)
  have outside: "h x y = 0" if "x ∉ set_pmf p ∨ y ∉ set_pmf q ∨ ¬ R x y" for x y
    using supp[of x y] that by(cases "h x y > 0")(auto simp add: R'_def)
  (* The weights sum to 1.  Sum each row first, restrict to the supports,
     and then use the row marginals p. *)
  have prob: "(∑⇧+ xy. ?z xy) = 1"
  proof -
    have "(∑⇧+ xy. ?z xy) = (∑⇧+ x. ∑⇧+ y. (ennreal ∘ ?z) (x, y))"
      unfolding nn_integral_fst_count_space by(simp add: split_def o_def)
    also have "… = (∑⇧+ x. (∑⇧+y. h x y))" using bound
      by(simp add: nn_integral_count_space_reindex ennreal_enn2real_if)
    also have "… = (∑⇧+ x∈?A. (∑⇧+y∈?B. h x y))"
      by(auto intro!: nn_integral_cong nn_integral_zero' simp add: nn_integral_count_space_indicator outside split: split_indicator)
    also have "… = (∑⇧+ x∈?A. ?f x)" by(auto simp add: p intro!: nn_integral_cong)
    also have "… = (∑⇧+ x. ?f x)"
      by(auto simp add: nn_integral_count_space_indicator pmf_eq_0_set_pmf intro!: nn_integral_cong split: split_indicator)
    finally show ?thesis by(simp add: nn_integral_pmf_eq_1)
  qed
  (* Side conditions of pmf_embed_pmf. *)
  note z = nonneg prob
  have z_sel [simp]: "pmf z (x, y) = enn2real (h x y)" for x y
    by(simp add: z_def pmf_embed_pmf[OF z])

  (* z witnesses rel_pmf: its support lies in R and its marginals are p and q. *)
  show ?thesis
  proof
    show "R x y" if "(x, y) ∈ set_pmf z" for x y
      using that outside[of x y] unfolding set_pmf_iff
      by(auto simp add: enn2real_eq_0_iff)

    (* The first marginal of z at x is the row sum of h at x. *)
    show "map_pmf fst z = p"
    proof(rule pmf_eqI)
      fix x
      have "pmf (map_pmf fst z) x = (∑⇧+ e∈range (Pair x). pmf z e)"
        by(auto simp add: ennreal_pmf_map nn_integral_measure_pmf nn_integral_count_space_indicator intro!: nn_integral_cong split: split_indicator)
      also have "… = (∑⇧+ y. h x y)"
        using bound by(simp add: nn_integral_count_space_reindex ennreal_enn2real_if)
      also have "… = (∑⇧+y∈?B. h x y)" using outside
        by(auto simp add: nn_integral_count_space_indicator intro!: nn_integral_cong split: split_indicator)
      (* Inside the support of p this is the hypothesis p.  Outside it,
         both sides are 0. *)
      also have "… = ?f x" using p[of x] apply(cases "x ∈ set_pmf p")
        by(auto simp add: set_pmf_iff AE_count_space outside intro!: nn_integral_zero')
      finally show "pmf (map_pmf fst z) x = pmf p x" by simp
    qed

    (* Symmetric argument: the second marginal of z at y is the column sum of h at y. *)
    show "map_pmf snd z = q"
    proof(rule pmf_eqI)
      fix y
      have "pmf (map_pmf snd z) y = (∑⇧+ e∈range (λx. (x, y)). pmf z e)"
        by(auto simp add: ennreal_pmf_map nn_integral_measure_pmf nn_integral_count_space_indicator intro!: nn_integral_cong split: split_indicator)
      also have "… = (∑⇧+ x. h x y)"
        using bound by(simp add: nn_integral_count_space_reindex ennreal_enn2real_if)
      also have "… = (∑⇧+x∈?A. h x y)" using outside
        by(auto simp add: nn_integral_count_space_indicator intro!: nn_integral_cong split: split_indicator)
      also have "… = ?g y" using q[of y] apply(cases "y ∈ set_pmf q")
        by(auto simp add: set_pmf_iff AE_count_space outside intro!: nn_integral_zero')
      finally show "pmf (map_pmf snd z) y = pmf q y" by simp
    qed
  qed
qed

subsection ‹Code generation for @{const rel_pmf}›

(* Relaxed form of rel_pmf_measureI.  The inequality only has to hold for
   subsets A of the support of p.  The right-hand side only counts the
   R-related points inside the support of q.  rel_pmf_code relies on this
   form, since there the quantification ranges over Pow (set_pmf p). *)
proposition rel_pmf_measureI':
  fixes p :: "'a pmf" and q :: "'b pmf"
  assumes le: "⋀A. A ⊆ set_pmf p ⟹ measure_pmf.prob p A ≤ measure_pmf.prob q {y ∈ set_pmf q. ∃x∈A. R x y}"
  shows "rel_pmf R p q"
proof(rule rel_pmf_measureI)
  fix A
  (* Intersecting A with the support of p keeps its p-mass unchanged.  The
     intersected set is a valid argument for le. *)
  let ?A = "A ∩ set_pmf p"
  have "measure_pmf.prob p A = measure_pmf.prob p ?A" by(simp add: measure_Int_set_pmf)
  also have "… ≤ measure_pmf.prob q {y ∈ set_pmf q. ∃x∈?A. R x y}" by(rule le) simp
  (* Dropping both restrictions only enlarges the set, so by monotonicity
     its q-mass can only grow. *)
  also have "… ≤ measure_pmf.prob q {y. ∃x∈A. R x y}"
    by(rule measure_pmf.finite_measure_mono) auto
  finally show "measure_pmf.prob p A ≤ …" .
qed

(* Code equation for rel_pmf.  R relates p and q iff every subset A of the
   support of p has p-mass at most the q-mass of the R-successors of A
   inside the support of q.

   The right-hand side quantifies over Pow (set_pmf p).  In the worst case
   it therefore checks exponentially many subsets in the size of the
   support of p.

   NOTE(review): presumably this equation is only executable when set_pmf p
   and set_pmf q have a finite code representation; confirm. *)
lemma rel_pmf_code [code]:
  "rel_pmf R p q ⟷
   (let B = set_pmf q in
    ∀A∈Pow (set_pmf p). measure_pmf.prob p A ≤ measure_pmf.prob q (snd ` Set.filter (case_prod R) (A × B)))"
  unfolding Let_def
proof(intro iffI strip)
  (* The executable set on the right is the R-image of A intersected with
     the support of q. *)
  have eq: "snd ` Set.filter (case_prod R) (A × set_pmf q) = {y. ∃x∈A. R x y} ∩ set_pmf q" for A
    by(auto intro: rev_image_eqI simp add: Set.filter_def)
  (* Left to right: use rel_pmf_measureD.  Intersecting with set_pmf q does
     not change the q-measure (measure_Int_set_pmf). *)
  show "measure_pmf.prob p A ≤ measure_pmf.prob q (snd ` Set.filter (case_prod R) (A × set_pmf q))"
    if "rel_pmf R p q" and "A ∈ Pow (set_pmf p)" for A
    using that by(auto dest: rel_pmf_measureD simp add: eq measure_Int_set_pmf)
  (* Right to left: rel_pmf_measureI' after rewriting with eq. *)
  show "rel_pmf R p q" if "∀A∈Pow (set_pmf p). measure_pmf.prob p A ≤ measure_pmf.prob q (snd ` Set.filter (case_prod R) (A × set_pmf q))"
    using that by(intro rel_pmf_measureI')(auto intro: ord_le_eq_trans arg_cong2[where f=measure] simp add: eq)
qed

end