(* Theory Applicative_Probability_List
   (imports Applicative_List and Complex_Main; the actual theory header follows below) *)

(* Author: Andreas Lochbihler, ETH Zurich *)

subsection ‹Probability mass functions implemented as lists with duplicates›

theory Applicative_Probability_List imports
  Applicative_List
  Complex_Main
begin

(* Summing a concatenation of lists equals summing the per-list sums. *)
lemma sum_list_concat_map: "sum_list (concat (map f xs)) = sum_list (map (λx. sum_list (f x)) xs)"
proof(induction xs)
  case Nil
  show ?case by simp
next
  case (Cons y ys)
  (* The induction hypothesis rewrites the tail; the head is handled by simp. *)
  then show ?case by simp
qed

context includes applicative_syntax begin

(* Element set of a list-level application: set (f ⋄ x) is the image of
   function application over all pairs from set f × set x.
   Declared as a simp rule; proved by unfolding ap_list_def and List.bind_def. *)
lemma set_ap_list [simp]: "set (f ⋄ x) = (λ(f, x). f x) ` (set f × set x)"
by(auto simp add: ap_list_def List.bind_def)

text ‹We call the implementation type ‹pfp› because it is the basis for the Haskell library
  Probability by Martin Erwig and Steve Kollmansberger (Probabilistic Functional Programming).›

(* 'a pfp: lists of (value, weight) pairs over real weights in which every
   weight is strictly positive and the weights sum to 1.
   The same value may occur in several pairs. *)
typedef 'a pfp = "{xs :: ('a × real) list. (∀(_, p) ∈ set xs. p > 0) ∧ sum_list (map snd xs) = 1}"
proof
  (* Non-emptiness witness: the singleton list carrying weight 1. *)
  show "[(x, 1)] ∈ ?pfp" for x by simp
qed

(* Register the pfp type definition with the Lifting package; the
   lift_definition commands and transfer proofs below rely on this. *)
setup_lifting type_definition_pfp

(* pure_pfp x is represented by the singleton list [(x, 1)];
   the trailing simp discharges the pfp invariant for that list. *)
lift_definition pure_pfp :: "'a ⇒ 'a pfp" is "λx. [(x, 1)]" by simp

(* ap_pfp fs xs: on the representing lists, every weighted function (f, p) is
   combined with every weighted argument (x, q) to give (f x, p * q), using
   the list-level application ⋄.  The proof shows the result again satisfies
   the pfp invariant. *)
lift_definition ap_pfp :: "('a ⇒ 'b) pfp ⇒ 'a pfp ⇒ 'b pfp"
is "λfs xs. [λ(f, p) (x, q). (f x, p * q)] ⋄ fs ⋄ xs"
proof safe
  (* Here xs is the list of weighted functions and ys the list of weighted arguments. *)
  fix xs :: "(('a ⇒ 'b) × real) list" and ys :: "('a × real) list"
  (* Invariants of the two inputs: positive weights that sum to 1. *)
  assume xs: "∀(x, y) ∈ set xs. 0 < y" "sum_list (map snd xs) = 1"
    and ys: "∀(x, y) ∈ set ys. 0 < y" "sum_list (map snd ys) = 1"
  let ?ap = "[λ(f, p) (x, q). (f x, p * q)] ⋄ xs ⋄ ys"
  (* Every resulting weight is a product of two positive weights, hence positive. *)
  show "0 < b" if "(a, b) ∈ set ?ap" for a b using that xs ys
    by(auto intro!: mult_pos_pos)
  (* The resulting weights sum to 1; shown by unfolding the list application
     and simplifying with the two input sums. *)
  show "sum_list (map snd ?ap) = 1" using xs ys
    by(simp add: ap_list_def List.bind_def map_concat o_def split_beta sum_list_concat_map sum_list_const_mult)
qed

(* Add ap_pfp as a variant of the overloaded constant Applicative.ap, so the
   ⋄ notation in the applicative declaration below can denote ap_pfp. *)
adhoc_overloading Applicative.ap ap_pfp

(* Register pfp as an applicative functor with pure_pfp and ap_pfp.
   The four goals are the applicative laws; each is proved by transfer to the
   representing lists. *)
applicative pfp
 for pure: pure_pfp
     ap: ap_pfp
proof -
  (* Identity law. *)
  show "pure_pfp (λx. x) ⋄ x = x" for x :: "'a pfp"
    by transfer(simp add: ap_list_def List.bind_def)
  (* Homomorphism law. *)
  show "pure_pfp f ⋄ pure_pfp x = pure_pfp (f x)" for f :: "'a ⇒ 'b" and x
    by transfer (applicative_lifting; simp)
  (* Composition law. *)
  show "pure_pfp (λg f x. g (f x)) ⋄ g ⋄ f ⋄ x = g ⋄ (f ⋄ x)"
    for g :: "('b ⇒ 'c) pfp" and f :: "('a ⇒ 'b) pfp" and x
    by transfer(applicative_lifting; clarsimp)
  (* Interchange law. *)
  show "f ⋄ pure_pfp x = pure_pfp (λf. f x) ⋄ f" for f :: "('a ⇒ 'b) pfp" and x
    by transfer(applicative_lifting; clarsimp)
qed

end

end