Theory Spmf_Common

theory Spmf_Common
  imports CryptHOL.CryptHOL
begin

no_adhoc_overloading Monad_Syntax.bind bind_pmf

lemma mk_lossless_back_eq:
  "scale_spmf (weight_spmf s) (mk_lossless s) = s"
  unfolding mk_lossless_def
  unfolding scale_scale_spmf
  by (auto simp: field_simps weight_spmf_eq_0)


lemma cond_spmf_enforce:
  "cond_spmf sx (Collect A) = mk_lossless (enforce_spmf A sx)"
  unfolding enforce_spmf_def
  unfolding cond_spmf_alt
  unfolding restrict_spmf_def
  unfolding enforce_option_alt_def
  apply (rule arg_cong[where f="mk_lossless"])
  apply (rule arg_cong[where f="λx. map_pmf x sx"])
  apply (intro ext)
  apply (rule arg_cong[where f="Option.bind _"])
  apply auto
  done


definition "rel_scale_spmf s t  (mk_lossless s = mk_lossless t)"

lemma rel_scale_spmf_refl:
  "rel_scale_spmf s s"
  unfolding rel_scale_spmf_def ..

lemma rel_scale_spmf_sym:
  "rel_scale_spmf s t  rel_scale_spmf t s"
  unfolding rel_scale_spmf_def by simp

lemma rel_scale_spmf_trans:
  "rel_scale_spmf s t  rel_scale_spmf t u  rel_scale_spmf s u"
  unfolding rel_scale_spmf_def by simp

lemma rel_scale_spmf_equiv:
  "equivp rel_scale_spmf"
  using rel_scale_spmf_refl rel_scale_spmf_sym
  by (auto intro!: equivpI reflpI sympI transpI dest: rel_scale_spmf_trans)


lemma spmf_eq_iff: "p = q  (i. spmf p i = spmf q i)"
  using spmf_eqI by auto

lemma spmf_eq_iff_set:
  "set_spmf a = set_spmf b  (x. x  set_spmf b  spmf a x = spmf b x)  a = b"
  using in_set_spmf_iff_spmf spmf_eq_iff
  by (metis)

lemma rel_scale_spmf_None:
  "rel_scale_spmf s t  s = return_pmf None  t = return_pmf None"
  unfolding rel_scale_spmf_def by auto

lemma rel_scale_spmf_def_alt:
  "rel_scale_spmf s t  (k>0. s = scale_spmf k t)"
proof
  assume rel: "rel_scale_spmf s t"
  then consider (isNone) "s = return_pmf None  t = return_pmf None" | (notNone) "weight_spmf s > 0  weight_spmf t > 0"
    using rel_scale_spmf_None weight_spmf_eq_0 zero_less_measure_iff by blast
  then show "k>0. s = scale_spmf k t"
  proof cases
    case isNone
    show ?thesis
      apply (rule exI[of _ 1])
      using isNone by simp
  next
    case notNone
    have "scale_spmf (weight_spmf s) (mk_lossless s) = scale_spmf (weight_spmf s / weight_spmf t) t"
      unfolding rel[unfolded rel_scale_spmf_def]
      unfolding mk_lossless_def
      unfolding scale_scale_spmf
      by (auto simp: field_simps)
    then show "k>0. s = scale_spmf k t"
      apply (unfold mk_lossless_back_eq)
      using notNone divide_pos_pos by blast
  qed
next
  assume "k>0. s = scale_spmf k t"
  then obtain k where kpos: "k>0" and st: "s = scale_spmf k t" by blast
  then consider (isNone) "weight_spmf s = 0  weight_spmf t = 0" | (notNone) "weight_spmf s > 0  weight_spmf t > 0"
    using zero_less_measure_iff mult_pos_pos zero_less_measure_iff by (fastforce simp: weight_scale_spmf)
  then show "rel_scale_spmf s t"
  proof cases
    case isNone
    then show ?thesis
      unfolding rel_scale_spmf_def weight_spmf_eq_0 by simp
  next
    case notNone
    then show ?thesis
      unfolding rel_scale_spmf_def
      unfolding mk_lossless_def
      unfolding st
      by (cases "k < inverse (weight_spmf t)")
        (auto simp: weight_scale_spmf scale_scale_spmf field_simps)
  qed
qed

lemma rel_scale_spmf_def_alt2:
  "rel_scale_spmf s t 
    (s = return_pmf None  t = return_pmf None)
  | (weight_spmf s > 0  weight_spmf t > 0  s = scale_spmf (weight_spmf s / weight_spmf t) t)"
  (is "?lhs  ?rhs")
proof
  assume rel: ?lhs
  then consider (isNone) "s = return_pmf None  t = return_pmf None" | (notNone) "weight_spmf s > 0  weight_spmf t > 0"
    using rel_scale_spmf_None weight_spmf_eq_0 zero_less_measure_iff by blast
  thus ?rhs
  proof cases
    case notNone
    have "scale_spmf (weight_spmf s) (mk_lossless s) = scale_spmf (weight_spmf s / weight_spmf t) t"
      unfolding rel[unfolded rel_scale_spmf_def]
      unfolding mk_lossless_def
      unfolding scale_scale_spmf
      by (auto simp: field_simps)
    thus ?thesis
      apply (unfold mk_lossless_back_eq)
      using notNone by simp
  qed simp
next
  assume ?rhs
  then show ?lhs
  proof cases
    case right
    then have gt0: "weight_spmf s > 0" "weight_spmf t > 0" and st: "s = scale_spmf (weight_spmf s / weight_spmf t) t"
      by auto
    then have "(1 / weight_spmf t)  (weight_spmf s / weight_spmf t)"
      using weight_spmf_le_1 divide_le_cancel by fastforce
    then show ?thesis
      unfolding rel_scale_spmf_def mk_lossless_def
      apply (subst (3) st)
      using gt0 by (auto simp: scale_scale_spmf field_simps)
  qed (simp add: rel_scale_spmf_refl)
qed

lemma rel_scale_spmf_scale:
  "r > 0  rel_scale_spmf s t  rel_scale_spmf s (scale_spmf r t)"
  apply (unfold rel_scale_spmf_def_alt)
  by (metis rel_scale_spmf_def rel_scale_spmf_def_alt)

lemma rel_scale_spmf_mk_lossless:
  "rel_scale_spmf s t  rel_scale_spmf s (mk_lossless t)"
  unfolding rel_scale_spmf_def_alt
  unfolding mk_lossless_def
  apply (cases "weight_spmf t = 0")
  subgoal by(simp add: weight_spmf_eq_0)
  subgoal
    apply (auto simp: weight_spmf_eq_0 field_simps scale_scale_spmf)
    using rel_scale_spmf_def_alt rel_scale_spmf_def_alt2 by blast
  done

lemma rel_scale_spmf_eq_iff:
  "rel_scale_spmf s t  weight_spmf s = weight_spmf t  s = t"
  unfolding rel_scale_spmf_def_alt2 by auto

lemma rel_scale_spmf_set_restrict:
  "finite A  rel_scale_spmf (restrict_spmf (spmf_of_set A) B) (spmf_of_set (A  B))"
  apply (unfold rel_scale_spmf_def)
  apply (fold cond_spmf_alt)
  apply (subst cond_spmf_spmf_of_set)
  subgoal .
  apply (unfold mk_lossless_spmf_of_set)
  ..

lemma spmf_of_set_restrict_empty:
  "A  B = {}  restrict_spmf (spmf_of_set A) B = return_pmf None"
  unfolding spmf_of_set_def
  by simp

lemma spmf_of_set_restrict_scale:
  "finite A  restrict_spmf (spmf_of_set A) B = scale_spmf (card (AB) / card A) (spmf_of_set (AB))"
  apply (rule rel_scale_spmf_eq_iff)
  subgoal
    apply (cases "AB = {}")
    subgoal
      by (auto simp: spmf_of_set_restrict_empty intro: rel_scale_spmf_refl)
    subgoal
      apply (rule rel_scale_spmf_scale)
      subgoal
        by (metis card_gt_0_iff divide_pos_pos finite_Int inf_bot_left of_nat_0_less_iff)
      subgoal by (rule rel_scale_spmf_set_restrict)
      done
    done
  subgoal
    apply (auto simp add: weight_scale_spmf measure_spmf_of_set)
    by (smt (verit, best) card_gt_0_iff card_mono disjoint_notin1 divide_le_eq_1_pos finite_Int inf_le1 of_nat_0_less_iff of_nat_le_iff)
  done

lemma min_em2:
  "min a b = c  ac  b=c"
  unfolding min_def by auto

lemma weight_0_spmf:
  "weight_spmf s = 0  spmf s i = 0"
  using order_trans[OF spmf_le_weight, of s 0 i] by simp

lemma mk_lossless_scale_absorb:
  "r > 0  mk_lossless (scale_spmf r s) = mk_lossless s"
  apply (rule rel_scale_spmf_eq_iff)
  subgoal
    apply (rule rel_scale_spmf_trans[where t=s])
    subgoal
      apply (rule rel_scale_spmf_sym)
      apply (rule rel_scale_spmf_mk_lossless)
      apply (rule rel_scale_spmf_scale)
      subgoal .
      subgoal by (rule rel_scale_spmf_refl)
      done
    subgoal
      apply (rule rel_scale_spmf_mk_lossless)
      apply (rule rel_scale_spmf_refl)
      done
    done
  subgoal
    unfolding weight_mk_lossless
    by (auto simp flip: weight_spmf_eq_0 simp: weight_scale_spmf dest: min_em2)
  done

lemma scale_spmf_None_iff:
  "scale_spmf k s = return_pmf None  k0  s=return_pmf None"
  apply (auto simp: spmf_eq_iff spmf_scale_spmf)
  using
    inverse_nonpositive_iff_nonpositive
    weight_0_spmf
    measure_le_0_iff
  by (smt (verit))

lemma spmf_of_pmf_the:
  "lossless_spmf s  spmf_of_pmf (map_pmf the s) = s"
  unfolding lossless_spmf_conv_spmf_of_pmf by auto

lemma lossless_mk_lossless:
  "s  return_pmf None  lossless_spmf (mk_lossless s)"
  unfolding lossless_spmf_def
  unfolding weight_mk_lossless
  by simp

definition pmf_of_spmf where
  "pmf_of_spmf p = map_pmf the (mk_lossless p)"

lemma scale_weight_spmf_of_pmf:
  "p = scale_spmf (weight_spmf p) (spmf_of_pmf (pmf_of_spmf p))"
  unfolding pmf_of_spmf_def
  apply (cases "p = return_pmf None")
  subgoal by simp
  subgoal
    apply (subst mk_lossless_back_eq[of p, symmetric])
    apply (rule arg_cong[where f="scale_spmf _ "])
    apply (rule spmf_of_pmf_the[symmetric])
    by (fact lossless_mk_lossless)
  done

lemma pmf_spmf:
  "pmf_of_spmf (spmf_of_pmf p) = p"
  unfolding pmf_of_spmf_def
  unfolding lossless_spmf_spmf_of_spmf[THEN mk_lossless_lossless]
  unfolding map_the_spmf_of_pmf
  ..

lemma spmf_pmf:
  "lossless_spmf p  spmf_of_pmf (pmf_of_spmf p) = p"
  unfolding pmf_of_spmf_def
  by (simp add: spmf_of_pmf_the)

lemma pmf_of_spmf_scale_spmf:
  "r > 0  pmf_of_spmf (scale_spmf r p) = pmf_of_spmf p"
  unfolding pmf_of_spmf_def
  by (simp add: mk_lossless_scale_absorb)

lemma nonempty_spmf_weight:
  "p  return_pmf None  weight_spmf p > 0"
  apply (fold weight_spmf_eq_0)
  using dual_order.not_eq_order_implies_strict[OF _ weight_spmf_nonneg[of p]]
  by auto

lemma pmf_of_spmf_mk_lossless:
  "pmf_of_spmf (mk_lossless p) = pmf_of_spmf p"
  apply (cases "p = return_pmf None")
  subgoal by auto
  apply (unfold mk_lossless_def)
  apply (subst pmf_of_spmf_scale_spmf)
  subgoal by (simp add: nonempty_spmf_weight)
  ..

lemma spmf_pmf':
  "p  return_pmf None  spmf_of_pmf (pmf_of_spmf p) = mk_lossless p"
  apply (subst spmf_pmf[of "mk_lossless p", symmetric])
   apply (unfold pmf_of_spmf_mk_lossless)
  subgoal using lossless_mk_lossless .
  subgoal ..
  done

lemma rel_scale_spmf_cond_UNIV:
  "rel_scale_spmf s (cond_spmf s UNIV)"
  unfolding cond_spmf_UNIV
  by (rule rel_scale_spmf_mk_lossless) (rule rel_scale_spmf_refl)

lemma "set_pmf p  g  {}  pmf_prob p (f  g) = pmf_prob (cond_pmf p g) f * pmf_prob p g"
  using measure_cond_pmf
  unfolding pmf_prob_def
  by (metis Int_commute divide_eq_eq measure_measure_pmf_not_zero)

lemma bayes:
  "set_pmf p  A  {}  set_pmf p  B  {} 
   measure_pmf.prob (cond_pmf p A) B
 = measure_pmf.prob (cond_pmf p B) A * measure_pmf.prob p B / measure_pmf.prob p A"
  unfolding measure_cond_pmf
  by (subst inf_commute) (simp add: measure_pmf_zero_iff)

definition spmf_prob :: "'a spmf  'a set  real" where
  "spmf_prob p = Sigma_Algebra.measure (measure_spmf p)"

lemma "spmf_prob = measure_measure_spmf"
  unfolding spmf_prob_def measure_measure_spmf_def ..

lemma spmf_prob_pmf:
  "spmf_prob p A = pmf_prob p (Some ` A)"
  unfolding spmf_prob_def pmf_prob_def
  unfolding measure_measure_spmf_conv_measure_pmf
  ..

lemma bayes_spmf:
  "spmf_prob (cond_spmf p A) B
 = spmf_prob (cond_spmf p B) A * spmf_prob p B / spmf_prob p A"
  unfolding spmf_prob_def
  unfolding measure_cond_spmf
  by (subst inf_commute) (auto simp: measure_spmf_zero_iff)

lemma spmf_prob_pmf_of_spmf:
  "spmf_prob p A = weight_spmf p * pmf_prob (pmf_of_spmf p) A"
  apply (subst scale_weight_spmf_of_pmf)
  apply (unfold spmf_prob_def)
  apply (subst measure_spmf_scale_spmf')
  subgoal using weight_spmf_le_1 .
  by (simp add: pmf_prob_def)

lemma cond_spmf_Int:
  "cond_spmf (cond_spmf p A) B = cond_spmf p (A  B)"
  apply (rule spmf_eqI)
  apply (unfold spmf_cond_spmf)
  apply(auto simp add: measure_cond_spmf split: if_split_asm)
  using Int_lower1[THEN measure_spmf.finite_measure_mono[simplified]]  measure_le_0_iff
  by metis

lemma cond_spmf_prob:
  "spmf_prob p (A  B) = spmf_prob (cond_spmf p A) B * spmf_prob p A"
  unfolding spmf_prob_def measure_cond_spmf
  using Int_lower1[THEN measure_spmf.finite_measure_mono[simplified]]  measure_le_0_iff
  by (metis mult_eq_0_iff nonzero_eq_divide_eq)

definition "empty_spmf = return_pmf None"

lemma spmf_prob_empty:
  "spmf_prob empty_spmf A = 0"
  unfolding spmf_prob_def empty_spmf_def
  by simp

definition le_spmf :: "'a spmf  'a spmf  bool" where
  "le_spmf p q  (k1. p = scale_spmf k q)"

definition lt_spmf :: "'a spmf  'a spmf  bool" where
  "lt_spmf p q  (k<1. p = scale_spmf k q)"

lemma "class.order_bot empty_spmf le_spmf lt_spmf"
  oops

lemma spmf_prob_cond_Int:
  "spmf_prob (cond_spmf p C) (A  B)
   = spmf_prob (cond_spmf p (B  C)) A * spmf_prob (cond_spmf p C) B"
  apply (subst Int_commute[of B C])
  apply (subst Int_commute[of A B])
  apply (fold cond_spmf_Int)
  using cond_spmf_prob .

lemma cond_spmf_mk_lossless:
  "cond_spmf (mk_lossless p) A = cond_spmf p A"
  apply (fold cond_spmf_UNIV)
  apply (unfold cond_spmf_Int)
  by simp

primrec sequence_spmf :: "'a spmf list  'a list spmf" where
  "sequence_spmf [] = return_spmf []"
| "sequence_spmf (x#xs) = map_spmf (case_prod Cons) (pair_spmf x (sequence_spmf xs))"

lemma set_sequence_spmf:
  "set_spmf (sequence_spmf xs) = {l. list_all2 (λx s. x  set_spmf s) l xs}"
  by (induction xs) (auto simp: list_all2_Cons2)

lemma map_spmf_map_sequence:
  "map_spmf (map f) (sequence_spmf xs) = sequence_spmf (map (map_spmf f) xs)"
  apply (induction xs)
  subgoal by simp
  subgoal premises IH
    unfolding list.map
    unfolding sequence_spmf.simps
    apply (fold IH)
    apply (unfold pair_map_spmf)
    apply (unfold spmf.map_comp)
    by (simp add: comp_def case_prod_map_prod prod.case_distrib)
  done

lemma sequence_map_return_spmf:
  "sequence_spmf (map return_spmf xs) = return_spmf xs"
  by (induction xs) auto

lemma sequence_bind_cong:
  "xs=ys; y. length y = length ys  f y = g y  bind_spmf (sequence_spmf xs) f = bind_spmf (sequence_spmf ys) g"
  apply (rule bind_spmf_cong)
  subgoal by simp
  subgoal unfolding set_sequence_spmf list_all2_iff by simp
  done

lemma bind_spmf_sequence_replicate_cong:
  "(l. length l = n  f l = g l)
     bind_spmf (sequence_spmf (replicate n x)) f = bind_spmf (sequence_spmf (replicate n x)) g"
  by (rule bind_spmf_cong[OF refl]) (simp add: set_spmf_of_set finite_permutations set_sequence_spmf[unfolded list_all2_iff])

lemma bind_spmf_sequence_map_cong:
  "(l. length l = length x  f l = g l)
     bind_spmf (sequence_spmf (map m x)) f = bind_spmf (sequence_spmf (map m x)) g"
  by (rule bind_spmf_cong[OF refl]) (simp add: set_spmf_of_set finite_permutations set_sequence_spmf[unfolded list_all2_iff])

lemma lossless_pair_spmf_iff:
  "lossless_spmf (pair_spmf a b)  lossless_spmf a  lossless_spmf b"
  unfolding pair_spmf_alt_def
  by (auto simp: set_spmf_eq_empty)

lemma lossless_sequence_spmf:
  "(x. xset xs  lossless_spmf x)  lossless_spmf (sequence_spmf xs)"
  by (induction xs) (auto simp: lossless_pair_spmf_iff)

end