Theory Geometric

```(* Title: Geometric.thy
Author: Andreas Lochbihler, ETH Zurich *)

subsection ‹The geometric distribution›

theory Geometric imports
Bernoulli
While_SPMF
begin

text ‹
We define the geometric distribution as a least fixpoint, which is more elegant than
defining it as a loop. To prove probabilistic termination, we show that it is equivalent
to a loop and then use the proof rules for probabilistic termination.
›

(* Geometric distribution on the naturals, written as a recursive sampler in the
   spmf monad: draw b from bernoulli p; if b holds, return 0; otherwise sample
   recursively and add 1 to the result.
   The function is introduced with partial_function (spmf), i.e. as a fixpoint
   of the equation below rather than by an explicit loop.
   NOTE(review): function_internals is presumably switched on so that the
   internal facts generated for this definition remain accessible - confirm. *)
context notes [[function_internals]] begin
partial_function (spmf) geometric_spmf :: "real ⇒ nat spmf" where
"geometric_spmf p = do {
b ← bernoulli p;
if b then return_spmf 0 else map_spmf ((+) 1) (geometric_spmf p)
}"
end

(* Fixpoint induction for geometric_spmf, repackaged with named cases:
     adm    - the predicate P is admissible,
     bottom - P holds for the function that always yields return_pmf None,
     step   - P is preserved by one unfolding of the defining equation.
   The three premises are handed, in this order, to the rule
   geometric_spmf.fixp_induct generated by partial_function. *)
lemma geometric_spmf_fixp_induct [case_names adm bottom step]:
  assumes "spmf.admissible P"
    and "P (λgeometric_spmf. return_pmf None)"
    and "⋀geometric_spmf'. P geometric_spmf' ⟹ P (λp. bernoulli p ⤜ (λb. if b then return_spmf 0 else map_spmf ((+) 1) (geometric_spmf' p)))"
  shows "P geometric_spmf"
  using assms by(rule geometric_spmf.fixp_induct)

(* For p ≤ 0 the sampler never produces a result: geometric_spmf p equals
   return_pmf None. Shown by fixpoint induction (geometric_spmf_fixp_induct);
   all resulting cases are discharged by simp_all. *)
lemma spmf_geometric_nonpos: "p ≤ 0 ⟹ geometric_spmf p = return_pmf None"
by(induction rule: geometric_spmf_fixp_induct) simp_all

(* For 1 ≤ p the sampler returns 0: geometric_spmf p equals return_spmf 0.
   Proved by unfolding the defining equation once with geometric_spmf.simps
   and simplifying.
   NOTE(review): this relies on the simp set reducing bernoulli p when
   1 ≤ p - confirm against Bernoulli.thy. *)
lemma spmf_geometric_ge_1: "1 ≤ p ⟹ geometric_spmf p = return_spmf 0"
  by(simp add: geometric_spmf.simps)

(* Local context for the loop formulation.
   p    - the real parameter passed to bernoulli.
   body - one loop iteration on a state (b, x): draw b' from bernoulli p and
          move to (¬ b', x + (if b' then 0 else 1)), i.e. the flag is the
          negated draw and the counter is incremented exactly when b' is False.
   The definition of body is declared [simp] inside this context. *)
context
fixes p :: real
and body :: "bool × nat ⇒ (bool × nat) spmf"
defines [simp]: "body ≡ λ(b, x). map_spmf (λb'. (¬ b', x + (if b' then 0 else 1))) (bernoulli p)"
begin

(* Interpret the locale loop_spmf with fst and body as its two parameters.
   The rewrites clause replaces body by its defining lambda term in the
   interpreted facts; the resulting obligation is exactly body_def.
   NOTE(review): the constant "while" used in the lemmas below presumably
   comes from this interpretation - confirm against While_SPMF. *)
interpretation loop_spmf fst body
rewrites "body ≡ λ(b, x). map_spmf (λb'. (¬ b', x + (if b' then 0 else 1))) (bernoulli p)"
by(fact body_def)

(* The fixpoint definition coincides with the loop: geometric_spmf p is the
   counter component (snd) of the loop started in state (True, 0).

   The proof generalises the start counter to an arbitrary x,
     map_spmf ((+) x) (geometric_spmf p) = map_spmf snd (while (True, x)),
   and shows the two ord_spmf (=) inequalities separately
   (spmf.leq_antisym):
     - lhs below rhs: fixpoint induction on geometric_spmf;
     - rhs below lhs: fixpoint induction on while, proved simultaneously with
       the auxiliary fact about the start state (False, x).
   Instantiating x with 0 gives the stated equation. *)
lemma geometric_spmf_conv_while:
  shows "geometric_spmf p = map_spmf snd (while (True, 0))"
proof -
  have "map_spmf ((+) x) (geometric_spmf p) = map_spmf snd (while (True, x))" (is "?lhs = ?rhs") for x
  proof(rule spmf.leq_antisym)
    show "ord_spmf (=) ?lhs ?rhs"
    proof(induction arbitrary: x rule: geometric_spmf_fixp_induct)
      case adm show ?case by simp
      case bottom show ?case by simp
      case (step geometric')
      (* Unfold the loop once, match the two Bernoulli draws, then unfold the
         loop a second time and close the remaining goal by simplification.
         The induction hypothesis is used at counter Suc x. *)
      show ?case using step.IH[of "Suc x"]
        apply(rewrite while.simps)
        apply(clarsimp simp add: map_spmf_bind_spmf bind_map_spmf intro!: ord_spmf_bind_reflI)
        apply(rewrite while.simps)
        apply(clarsimp simp add: ac_simps)
        done
    qed
    (* Two statements are proved together: the wanted inequality and the
       auxiliary one for the start state (False, x). *)
    have "ord_spmf (=) ?rhs ?lhs"
      and "ord_spmf (=) (map_spmf snd (while (False, x))) (return_spmf x)"
    proof(induction arbitrary: x and x rule: while_fixp_induct)
      case adm show ?case by simp
      case bottom case 1 show ?case by simp
      case bottom case 2 show ?case by simp
    next
      case (step while')
      case 1 show ?case using step.IH(1)[of "Suc x"] step.IH(2)[of x]
        by(rewrite geometric_spmf.simps)(clarsimp simp add: map_spmf_bind_spmf bind_map_spmf spmf.map_comp o_def intro!: ord_spmf_bind_reflI)
      case 2 show ?case by simp
    qed
    then show "ord_spmf (=) ?rhs ?lhs" by -
  qed
  from this[of 0] show ?thesis by(simp cong: map_spmf_cong)
qed

(* geometric_spmf p is lossless exactly when p > 0.

   Case 0 < p < 1: the loop started in (True, 0) is lossless by
   termination_0_1_immediate, using p itself as the lower bound on the
   probability that one iteration sets the flag to False; the claim then
   transfers to geometric_spmf via geometric_spmf_conv_while.
   Remaining case (p ≤ 0 or 1 ≤ p): settled by the closed forms
   spmf_geometric_nonpos and spmf_geometric_ge_1. *)
lemma lossless_geometric [simp]: "lossless_spmf (geometric_spmf p) ⟷ p > 0"
proof(cases "0 < p ∧ p < 1")
  case True
  let ?body = "λ(b, x :: nat). map_spmf (λb'. (¬ b', x + (if b' then 0 else 1))) (bernoulli p)"
  have "lossless_spmf (while (True, 0))"
  proof(rule termination_0_1_immediate)
    have "{x. x} = {True}" by auto
    then show "p ≤ spmf (map_spmf fst (?body s)) False" for s :: "bool × nat" using True
      by(cases s)(simp add: spmf.map_comp o_def spmf_map vimage_def spmf_conv_measure_spmf[symmetric])
    show "0 < p" using True by simp
  qed(clarsimp)
  with True show ?thesis by(simp add: geometric_spmf_conv_while)
qed(auto simp add: spmf_geometric_nonpos spmf_geometric_ge_1)

end

lemma spmf_geometric:
assumes p: "0 < p" "p < 1"
shows "spmf (geometric_spmf p) n = (1 - p) ^ n * p" (is "?lhs n = ?rhs n")
proof(rule spmf_ub_tight)
fix n
have "ennreal (?lhs n) ≤ ennreal (?rhs n)" using p
proof(induction arbitrary: n rule: geometric_spmf_fixp_induct)
case adm show ?case by(rule cont_intro)+
case bottom show ?case by simp
case (step geometric_spmf')
then show ?case
by(cases n)(simp_all add: ennreal_spmf_bind nn_integral_measure_spmf UNIV_bool nn_integral_count_space_finite ennreal_mult spmf_map vimage_def mult.assoc spmf_conv_measure_spmf[symmetric] mult_mono split: split_indicator)
qed
then show "?lhs n ≤ ?rhs n" using p by(simp)
next
have "(∑i. ennreal (p * (1 - p) ^ i)) = ennreal (p * (1 / (1 - (1 - p))))" using p
by (intro suminf_ennreal_eq sums_mult geometric_sums) auto
then show "(∑⇧+ x. ennreal ((1 - p) ^ x * p)) = weight_spmf (geometric_spmf p)"
using lossless_geometric[of p] p unfolding lossless_spmf_def