Theory HMM_Implementation
section ‹Implementation›
theory HMM_Implementation
imports
Hidden_Markov_Model
"Monad_Memo_DP.State_Main"
begin
subsection ‹The Forward Algorithm›
(* HMM3 with the additional assumption that the list of states, state_list, has no duplicates. *)
locale HMM4 =
  HMM3 _ _ _ 𝒪⇩s 𝒦
  for 𝒪⇩s :: "'t set"
  and 𝒦 :: "'s ⇒ 's pmf"
  +
  assumes states_distinct: "distinct state_list"
context HMM3_defs
begin

context
  fixes os :: "'t iarray"
begin

text ‹
  Alternative definition of the forward algorithm using an index ‹n› into the observation sequence.
  The observations are stored in an immutable array (‹iarray›) for better performance.
›

(* forward_ix_rec s t_end n:
   - if n is at or beyond the end of os, the result is the indicator of s = t_end;
   - otherwise it sums, over every t in state_list, the product of the probability of observing
     os !! n in t, the transition probability from s to t, and the recursive value at n + 1. *)
function forward_ix_rec where
  "forward_ix_rec s t_end n = (if n ≥ IArray.length os then indicator {t_end} s else
    (∑t ← state_list.
      ennreal (pmf (𝒪 t) (os !! n)) * ennreal (pmf (𝒦 s) t) * forward_ix_rec t t_end (n + 1)))"
  (* The argument is a triple (s, t_end, n); pattern completeness is discharged by
     pat_completeness, as for viterbi_ix_rec. *)
  by pat_completeness auto
termination
  (* The number of remaining observations, IArray.length os - n, strictly decreases. *)
  by (relation "Wellfounded.measure (λ(_,_,n). IArray.length os - n)") auto

text ‹Memoization›
memoize_fun forward_ix⇩m: forward_ix_rec
  with_memory dp_consistency_mapping
  monadifies (state) forward_ix_rec.simps[unfolded Let_def]
term forward_ix⇩m'
memoize_correct
  by memoize_prover

text ‹The main theorems generated by memoization.›
context
  includes state_monad_syntax
begin
thm forward_ix⇩m'.simps forward_ix⇩m_def
thm forward_ix⇩m.memoized_correct
end

end

(* Entry point on plain lists: wraps the observation list into an immutable array. *)
definition
  "forward_ix os = forward_ix_rec (IArray os)"

(* None if s is not in state_list; otherwise Some of the sum, over every t in state_list,
   of forward s t os. *)
definition
  "likelihood_compute s os ≡
    if s ∈ set state_list then Some (∑t ← state_list. forward s t os) else None"

end
text ‹Correctness of the alternative definition.›
(* Prepending one observation shifts every index by one: evaluating at index n + 1 on o # os
   gives the same result as evaluating at index n on os. *)
lemma (in HMM3) forward_ix_drop_one:
"forward_ix (o # os) s t (n + 1) = forward_ix os s t n"
(* Induction on the number of remaining observations, length os - n. *)
by (induction "length os - n" arbitrary: s n; simp add: forward_ix_def)
(* The index-based recursion started at index 0 agrees with forward.
   Stated in HMM4 because the proof uses states_distinct, together with
   sum_list_distinct_conv_sum_set, to relate the list sum over state_list to a set sum. *)
lemma (in HMM4) forward_ix_forward:
"forward_ix os s t 0 = forward s t os"
unfolding forward_ix_def
proof (induction os arbitrary: s)
case Nil
then show ?case
by simp
next
case (Cons o os)
show ?case
(* forward_ix_drop_one rewrites index 1 on o # os to index 0 on os, so Cons.IH applies. *)
using forward_ix_drop_one[unfolded forward_ix_def] states_distinct
(* Unfold one step of both recursions, then simplify without unfolding them further. *)
by (subst forward.simps, subst forward_ix_rec.simps)
(simp add: Cons.IH state_list_𝒮 sum_list_distinct_conv_sum_set
del: forward_ix_rec.simps forward.simps
)
qed
text ‹
Instructs the code generator to execute ‹forward› via this equation instead of its defining equations.
The equation uses the memoized version of ‹forward_ix›.
›
(* Code equation for forward. It runs the memoized forward_ix⇩m' on the observations
   converted to an IArray, starting at index 0 with an empty memo table (Mapping.empty).
   It keeps the first component of the run_state result. *)
lemma (in HMM4) forward_code [code]:
"forward s t os = fst (run_state (forward_ix⇩m' (IArray os) s t 0) Mapping.empty)"
by (simp only:
forward_ix_def forward_ix⇩m.memoized_correct forward_ix_forward[symmetric]
states_distinct
)
(* Specification of likelihood_compute: it returns Some x exactly when s ∈ 𝒮 and
   x = likelihood s os. Uses states_distinct to turn the list sum into a set sum. *)
theorem (in HMM4) likelihood_compute:
"likelihood_compute s os = Some x ⟷ s ∈ 𝒮 ∧ x = likelihood s os"
unfolding likelihood_compute_def
by (auto simp: states_distinct state_list_𝒮 sum_list_distinct_conv_sum_set likelihood_forward)
subsection ‹The Viterbi Algorithm›
context HMM3_defs
begin
context
fixes os :: "'t iarray"
begin
text ‹
Alternative definition of the Viterbi algorithm using an index ‹n› into the observation sequence.
The observations are stored in an immutable array (‹iarray›) for better performance.
›
(* viterbi_ix_rec s t_end n:
   - if n is at or beyond the end of os, the result is ([], indicator of s = t_end);
   - otherwise, for each t in state_list, the recursive result (xs, v) at n + 1 is extended to the
     candidate (t # xs, probability of observing os !! n in t times transition s → t, times v);
     argmax snd selects among the candidates and fst of its result is returned.
   NOTE(review): argmax is defined elsewhere; presumably it returns the best element paired with
   its score, so fst yields the (path, probability) pair — confirm. *)
function viterbi_ix_rec where
"viterbi_ix_rec s t_end n = (if n ≥ IArray.length os then ([], indicator {t_end} s) else
fst (
argmax snd (map
(λt. let (xs, v) = viterbi_ix_rec t t_end (n + 1) in
(t # xs, ennreal (pmf (𝒪 t) (os !! n) * pmf (𝒦 s) t) * v))
state_list)))
"
by pat_completeness auto
termination
(* The number of remaining observations, IArray.length os - n, strictly decreases. *)
by (relation "Wellfounded.measure (λ(_,_,n). IArray.length os - n)") auto
text ‹Memoization›
memoize_fun viterbi_ix⇩m: viterbi_ix_rec
with_memory dp_consistency_mapping
monadifies (state) viterbi_ix_rec.simps[unfolded Let_def]
memoize_correct
by memoize_prover
text ‹The main theorems generated by memoization.›
context
includes state_monad_syntax
begin
thm viterbi_ix⇩m'.simps viterbi_ix⇩m_def
thm viterbi_ix⇩m.memoized_correct
end
end
(* Entry point on plain lists: wraps the observation list into an immutable array. *)
definition
"viterbi_ix os = viterbi_ix_rec (IArray os)"
end
context HMM3
begin
(* Prepending one observation shifts every index by one (Viterbi analogue of
   forward_ix_drop_one). *)
lemma viterbi_ix_drop_one:
"viterbi_ix (o # os) s t (n + 1) = viterbi_ix os s t n"
by (induction "length os - n" arbitrary: s n; simp add: viterbi_ix_def)
(* The index-based recursion started at index 0 agrees with viterbi. Unlike
   forward_ix_forward, this is proved in HMM3 and does not use distinctness of state_list. *)
lemma viterbi_ix_viterbi:
"viterbi_ix os s t 0 = viterbi s t os"
unfolding viterbi_ix_def
proof (induction os arbitrary: s)
case Nil
then show ?case
by simp
next
case (Cons o os)
show ?case
(* viterbi_ix_drop_one rewrites index 1 on o # os to index 0 on os, so Cons.IH applies. *)
using viterbi_ix_drop_one[unfolded viterbi_ix_def]
by (subst viterbi.simps, subst viterbi_ix_rec.simps)
(simp add: Cons.IH del: viterbi_ix_rec.simps viterbi.simps)
qed
(* Code equation for viterbi. It runs the memoized viterbi_ix⇩m' at index 0 with an empty memo
   table and keeps the first component of the run_state result. *)
lemma viterbi_code [code]:
"viterbi s t os = fst (run_state (viterbi_ix⇩m' (IArray os) s t 0) Mapping.empty)"
by (simp only: viterbi_ix_def viterbi_ix⇩m.memoized_correct viterbi_ix_viterbi[symmetric])
end
subsection ‹Misc›
(* If every probability listed in μ is nonnegative, then looking up any x in μ
   (0 when x is absent) yields a nonnegative real. *)
lemma pmf_of_alist_support_aux_1:
assumes "∀ (_, p) ∈ set μ. p ≥ 0"
shows "(0 :: real) ≤ (case map_of μ x of None ⇒ 0 | Some p ⇒ p)"
using assms by (auto split: option.split dest: map_of_SomeD)
(* The lookup density x ↦ (probability listed for x in μ, or 0) has total mass 1 over
   count_space UNIV, provided the probabilities are nonnegative, sum to 1, and the keys of μ
   are distinct. *)
lemma pmf_of_alist_support_aux_2:
assumes "∀ (_, p) ∈ set μ. p ≥ 0"
and "sum_list (map snd μ) = 1"
and "distinct (map fst μ)"
shows "∫⇧+ x. ennreal (case map_of μ x of None ⇒ 0 | Some p ⇒ p) ∂count_space UNIV = 1"
using assms
(* Turn the integral into a finite sum over the points where the density is nonzero. *)
apply (subst nn_integral_count_space)
(* That set is finite: it is contained in the keys fst ` set μ. *)
subgoal
by (rule finite_subset[where B = "fst ` set μ"];
force split: option.split_asm simp: image_iff dest: map_of_SomeD)
(* Enlarge the summation domain to all keys fst ` set μ. *)
apply (subst sum.mono_neutral_left[where T = "fst ` set μ"])
apply blast
subgoal
by (smt ennreal_less_zero_iff map_of_eq_None_iff mem_Collect_eq option.case(1) subsetI)
subgoal
by auto
subgoal premises prems
proof -
(* Reindex the sum over keys as a sum over list positions 0..<length μ; distinct keys
   (prems(3)) make i ↦ fst (μ ! i) injective. *)
have "(∑x = 0..<length μ. snd (μ ! x))
= sum (λ x. case map_of μ x of None ⇒ 0 | Some v ⇒ v) (fst ` set μ)"
apply (rule sym)
apply (rule sum.reindex_cong[where l = "λ i. fst (μ ! i)"])
apply (auto split: option.split)
subgoal
using prems(3) by (intro inj_onI, auto simp: distinct_conv_nth)
subgoal
by (auto simp: in_set_conv_nth rev_image_eqI)
subgoal
by (simp add: map_of_eq_None_iff)
subgoal
using map_of_eq_Some_iff[OF prems(3)]
by (metis fst_conv nth_mem option.inject prod_eqI snd_conv)
done
(* Combine with sum_list (map snd μ) = 1 (prems(2)) and move the sum through ennreal. *)
with prems(2) show ?thesis
by (smt pmf_of_alist_support_aux_1[OF assms(1)] atLeastLessThan_iff ennreal_1
length_map nth_map sum.cong sum_ennreal sum_list_sum_nth
)
qed
done
(* Under the same well-formedness assumptions as pmf_of_alist_support_aux_2, the support of
   pmf_of_alist μ is contained in the keys of μ. *)
lemma pmf_of_alist_support:
assumes "∀ (_, p) ∈ set μ. p ≥ 0"
and "sum_list (map snd μ) = 1"
and "distinct (map fst μ)"
shows "set_pmf (pmf_of_alist μ) ⊆ fst ` set μ"
unfolding pmf_of_alist_def
apply (subst set_embed_pmf)
(* Side condition: the density is nonnegative. *)
subgoal for x
using assms(1) by (auto split: option.split dest: map_of_SomeD)
(* Side condition: the density has total mass 1. *)
subgoal
using pmf_of_alist_support_aux_2[OF assms] .
(* Points with nonzero density are keys of μ. *)
apply (force split: option.split_asm simp: image_iff dest: map_of_SomeD)+
done
text ‹Defining a Markov kernel from an association list.›
(* Builds a Markov kernel K' from the association list K, which maps a state to a list of
   (successor, probability) pairs. The assumptions require:
   - S to be nonempty;
   - every listed successor to lie in set S;
   - each list to have nonnegative probabilities, distinct successors, and total mass 1;
   - the keys of K to be distinct. *)
locale Closed_Kernel_From =
fixes K :: "('s × ('t × real) list) list"
and S :: "'t list"
assumes wellformed: "S ≠ []"
and closed: "∀ (s, μ) ∈ set K. ∀ (t, _) ∈ set μ. t ∈ set S"
and is_pmf:
"∀ (_, μ) ∈ set K. ∀ (_, p) ∈ set μ. p ≥ 0"
"∀ (_, μ) ∈ set K. distinct (map fst μ)"
"∀ (s, μ) ∈ set K. sum_list (map snd μ) = 1"
and is_unique:
"distinct (map fst K)"
begin
(* The kernel: the pmf built from the list stored for s in K. If s has no entry, it is the point
   mass on hd S, which exists because of wellformed (S ≠ []). Note: the inner s in
   "Some s ⇒ s" is a fresh binder that shadows the argument s. *)
definition
"K' s ≡ case map_of (map (λ (s, μ). (s, PMF_Impl.pmf_of_alist μ)) K) s of
None ⇒ return_pmf (hd S) |
Some s ⇒ s"
(* NOTE(review): Closed_Kernel is defined elsewhere; presumably it requires the support of
   every K' s to lie within set S — confirm. Proved from closed, is_pmf and
   pmf_of_alist_support. *)
sublocale Closed_Kernel K' "set S"
using wellformed closed is_pmf pmf_of_alist_support
unfolding K'_def by - (standard; fastforce split: option.split_asm dest: map_of_SomeD)
(* Executable lookup table. It maps s to the finite map (via map_of) from successors to
   their listed probabilities. *)
definition [code]:
"K1 = map_of (map (λ (s, μ). (s, map_of μ)) K)"
(* For an entry (s, μ) of K, the pmf of pmf_of_alist μ at t is the probability listed for t,
   or 0 if t is not listed. *)
lemma pmf_of_alist_aux:
assumes "(s, μ) ∈ set K"
shows
"pmf (pmf_of_alist μ) t = (case map_of μ t of
None ⇒ 0
| Some p ⇒ p)"
using assms is_pmf unfolding pmf_of_alist_def
by (intro pmf_embed_pmf pmf_of_alist_support_aux_2)
(auto 4 3 split: option.split dest: map_of_SomeD)
(* Because the keys of K are distinct (is_unique), a state has at most one entry in K. *)
lemma unique: "μ = μ'" if "(s, μ) ∈ set K" "(s, μ') ∈ set K"
using that is_unique
by (smt Pair_inject distinct_conv_nth fst_conv in_set_conv_nth length_map nth_map)
(* If map_of M x = None, then x is not a key of M. Stated outside the locale via (in -). *)
lemma (in -) map_of_NoneD:
"x ∉ fst ` set M" if "map_of M x = None"
using that by (auto dest: weak_map_of_SomeI)
(* Registered as a code_post rule. It expresses pmf (K' s) t through the lookup table K1:
   the point mass on hd S when s has no entry, else the listed probability of t (0 if absent). *)
lemma K'_code [code_post]:
"pmf (K' s) t = (case K1 s of
None ⇒ (if t = hd S then 1 else 0)
| Some μ ⇒ case μ t of
None ⇒ 0
| Some p ⇒ p
)"
unfolding K'_def K1_def
apply (clarsimp split: option.split, safe)
apply (drule map_of_SomeD, drule map_of_NoneD, force)+
apply (fastforce dest: unique map_of_SomeD simp: pmf_of_alist_aux)+
done
end
subsection ‹Executing Concrete HMMs›
(* Syntactic description of a concrete HMM:
   - 𝒦: transition table, mapping a state to a list of (successor, probability) pairs;
   - 𝒪: observation table, mapping a state to a list of (observation, probability) pairs;
   - 𝒪⇩s: the list of observations;
   - 𝒦⇩s: the list of states. *)
locale Concrete_HMM_defs =
  fixes 𝒦 :: "('s × ('s × real) list) list"
    and 𝒪 :: "('s × ('t × real) list) list"
    and 𝒪⇩s :: "'t list"
    and 𝒦⇩s :: "'s list"
begin

(* Transition kernel: the pmf built from the list stored for s in 𝒦,
   or the point mass on hd 𝒦⇩s when s has no entry. *)
definition 𝒦' where
  "𝒦' s ≡ (case map_of (map (λ (x, ps). (x, PMF_Impl.pmf_of_alist ps)) 𝒦) s of
    None ⇒ return_pmf (hd 𝒦⇩s)
  | Some d ⇒ d)"

(* Observation kernel: the pmf built from the list stored for s in 𝒪,
   or the point mass on hd 𝒪⇩s when s has no entry. *)
definition 𝒪' where
  "𝒪' s ≡ (case map_of (map (λ (x, ps). (x, PMF_Impl.pmf_of_alist ps)) 𝒪) s of
    None ⇒ return_pmf (hd 𝒪⇩s)
  | Some d ⇒ d)"

end
(* Well-formedness conditions on a concrete HMM. Both the observation table 𝒪 (over 𝒪⇩s) and the
   transition table 𝒦 (over 𝒦⇩s) must satisfy the assumptions of Closed_Kernel_From. In addition,
   the state list 𝒦⇩s must be distinct. *)
locale Concrete_HMM = Concrete_HMM_defs +
assumes observations_wellformed': "𝒪⇩s ≠ []"
and observations_closed': "∀ (s, μ) ∈ set 𝒪. ∀ (t, _) ∈ set μ. t ∈ set 𝒪⇩s"
and observations_form_pmf':
"∀ (_, μ) ∈ set 𝒪. ∀ (_, p) ∈ set μ. p ≥ 0"
"∀ (_, μ) ∈ set 𝒪. distinct (map fst μ)"
"∀ (s, μ) ∈ set 𝒪. sum_list (map snd μ) = 1"
and observations_unique:
"distinct (map fst 𝒪)"
assumes states_wellformed: "𝒦⇩s ≠ []"
and states_closed: "∀ (s, μ) ∈ set 𝒦. ∀ (t, _) ∈ set μ. t ∈ set 𝒦⇩s"
and states_form_pmf:
"∀ (_, μ) ∈ set 𝒦. ∀ (_, p) ∈ set μ. p ≥ 0"
"∀ (_, μ) ∈ set 𝒦. distinct (map fst μ)"
"∀ (s, μ) ∈ set 𝒦. sum_list (map snd μ) = 1"
and states_unique:
"distinct (map fst 𝒦)" "distinct 𝒦⇩s"
begin
(* The observation table as a Closed_Kernel_From instance; its kernel O.K' is identified with 𝒪'. *)
interpretation O: Closed_Kernel_From 𝒪 𝒪⇩s
rewrites "O.K' = 𝒪'"
proof -
show ‹Closed_Kernel_From 𝒪 𝒪⇩s›
using observations_wellformed' observations_closed' observations_form_pmf' observations_unique
by unfold_locales auto
show ‹Closed_Kernel_From.K' 𝒪 𝒪⇩s = 𝒪'›
unfolding Closed_Kernel_From.K'_def[OF ‹Closed_Kernel_From 𝒪 𝒪⇩s›] 𝒪'_def
by auto
qed
(* The transition table as a Closed_Kernel_From instance; its kernel K.K' is identified with 𝒦'. *)
interpretation K: Closed_Kernel_From 𝒦 𝒦⇩s
rewrites "K.K' = 𝒦'"
proof -
show ‹Closed_Kernel_From 𝒦 𝒦⇩s›
using states_wellformed states_closed states_form_pmf states_unique by unfold_locales auto
show ‹Closed_Kernel_From.K' 𝒦 𝒦⇩s = 𝒦'›
unfolding Closed_Kernel_From.K'_def[OF ‹Closed_Kernel_From 𝒦 𝒦⇩s›] 𝒦'_def
by auto
qed
(* Export the code_post rules and lookup-table definitions from both interpretations. *)
lemmas O_code = O.K'_code O.K1_def
lemmas K_code = K.K'_code K.K1_def
(* The concrete kernels form an HMM4. Its distinct-states assumption comes from
   states_unique(2) (distinct 𝒦⇩s). *)
sublocale HMM_interp: HMM4 𝒪' "set 𝒦⇩s" 𝒦⇩s "set 𝒪⇩s" 𝒦'
using O.Closed_Kernel_axioms K.Closed_Kernel_axioms states_unique(2)
by (intro_locales; intro HMM4_axioms.intro HMM3_axioms.intro HOL.refl)
end
end