(* Theory Resampling *)

(* Title: Resampling.thy
   Author: Andreas Lochbihler, ETH Zurich *)

theory Resampling imports

(* If p lies below q in the spmf order ord_spmf (=) and p is lossless
   (None is not in set_pmf p), then p and q are equal.
   The proof strengthens the relator ord_option (=) to (=) on the support of p. *)
lemma ord_spmf_lossless:
  assumes "ord_spmf (=) p q" "lossless_spmf p"
  shows "p = q"
  unfolding pmf.rel_eq[symmetric] using assms(1)
  by(rule pmf.rel_mono_strong)(use assms(2) in ‹auto elim!: ord_option.cases simp add: lossless_iff_set_pmf_None›)

(* function_internals keeps the internal theorems resample_def and resample.mono,
   which the parallel fixpoint induction proving resample_conv_while refers to. *)
context notes [[function_internals]] begin

(* Rejection sampling: draw x from spmf_of_set A; return x if x ∈ B,
   otherwise start over. Being a partial function in the spmf monad,
   it is defined without a termination proof. *)
partial_function (spmf) resample :: "'a set ⇒ 'a set ⇒ 'a spmf" where
  "resample A B = bind_spmf (spmf_of_set A) (λx. if x ∈ B then return_spmf x else resample A B)"

end
(* Fixpoint induction for resample, with the case names adm/bottom/step
   used by the proofs below. *)
lemmas resample_fixp_induct[case_names adm bottom step] = resample.fixp_induct

context
  fixes A :: "'a set"
  and B :: "'a set"
begin

(* View resample as a while loop in the spmf monad: the guard x ∉ B keeps
   looping, and the body ignores the current value and draws from spmf_of_set A. *)
interpretation loop_spmf "λx. x ∉ B" "λ_. spmf_of_set A" .

(* resample is one initial draw followed by the loop; proved by parallel
   fixpoint induction over the definitions of resample and while. *)
lemma resample_conv_while: "resample A B = bind_spmf (spmf_of_set A) while"
proof(induction rule: parallel_fixp_induct_2_1[OF partial_function_definitions_spmf partial_function_definitions_spmf resample.mono while.mono resample_def while_def, case_names adm bottom step])
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step resample' while') then show ?case by(simp add: z3_rule(33) cong del: if_cong)
qed

(* From here on: A is finite and B is a non-empty subset of A. *)
context
  assumes A: "finite A"
    and B: "B ⊆ A" "B ≠ {}"
begin

private lemma A_nonempty: "A ≠ {}"
  using B by blast

private lemma B_finite: "finite B"
  using A B by(blast intro: finite_subset)

(* resample terminates almost surely: every round leaves the loop with
   probability at least card (A ∩ B) / card A, which is positive. *)
lemma lossless_resample: "lossless_spmf (resample A B)"
proof -
  from B have [simp]: "A ∩ B ≠ {}" by auto
  have "lossless_spmf (while x)" for x
    by(rule termination_0_1_immediate[where p="card (A ∩ B) / card A"])
      (simp_all add: spmf_map vimage_def measure_spmf_of_set field_simps A_nonempty A not_le card_gt_0_iff B)
  then show ?thesis by(clarsimp simp add: resample_conv_while A A_nonempty)
qed

(* resample A B lies below spmf_of_set B in the spmf order. The step case
   computes the probability of x after one round and bounds the recursive
   call by the induction hypothesis. *)
lemma resample_le_sample:
  "ord_spmf (=) (resample A B) (spmf_of_set B)"
proof(induction rule: resample_fixp_induct)
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step resample')
  note [simp] = B_finite A
  show ?case
  proof(rule ord_pmf_increaseI)
    fix x
    let ?f = "λx. if x ∈ B then return_spmf x else resample' A B"
    have "spmf (bind_spmf (spmf_of_set A) ?f) x =
      (∑n∈B ∪ (A - B). if n ∈ B then (if n = x then 1 else 0) / card A else spmf (resample' A B) x / card A)"
      using B
      by(auto simp add: spmf_bind integral_spmf_of_set sum_divide_distrib if_distrib[where f="λp. spmf p _ / _"] cong: if_cong intro!: sum.cong split: split_indicator_asm)
    also have "… = (∑n∈B. (if n = x then 1 else 0) / card A) + (∑n∈A - B. spmf (resample' A B) x / card A)"
      by(subst sum.union_disjoint)(auto)
    also have "… = (if x ∈ B then 1 / card A else 0) + card (A - B) / card A * spmf (resample' A B) x"
      by(simp cong: sum.cong add: if_distrib[where f="λx. x / _"] cong: if_cong)
    also have "… ≤ (if x ∈ B then 1 / card A else 0) + card (A - B) / card A * spmf (spmf_of_set B) x"
      by(intro add_left_mono mult_left_mono step.IH[THEN ord_spmf_eq_leD]) simp
    also have "… = spmf (spmf_of_set B) x" using B
      by(simp add: spmf_of_set field_simps A_nonempty card_Diff_subset card_mono of_nat_diff)
    finally show "spmf (bind_spmf (spmf_of_set A) ?f) x ≤ …" .
  qed simp
qed

(* Main result: rejection sampling from A until the sample lands in B
   yields exactly spmf_of_set B. *)
lemma resample_eq_sample: "resample A B = spmf_of_set B"
  using resample_le_sample lossless_resample by(rule ord_spmf_lossless)

end

end