(* Theory HMM_Implementation *)

section ‹Implementation›

theory HMM_Implementation
  imports
    Hidden_Markov_Model
    "Monad_Memo_DP.State_Main"
begin

subsection ‹The Forward Algorithm›

(* HMM4 extends HMM3 with the assumption that state_list contains no duplicates.
   Only the last two HMM3 parameters (the set 𝒪s and the kernel 𝒦) are named and typed
   explicitly; the first three are left implicit. *)
locale HMM4 = HMM3 _ _ _ 𝒪s 𝒦 for 𝒪s :: "'t set" and 𝒦 :: "'s ⇒ 's pmf" +
  assumes states_distinct: "distinct state_list"

context HMM3_defs
begin

context
  fixes os :: "'t iarray"
begin

text ‹
  Alternative definition using indices into the list of observations.
  The list of observations is implemented as an immutable array for better performance.
›

(* forward_ix_rec s t_end n works on the fixed observation array os, starting at index n.
   - If n is at or beyond the end of os, the result is the indicator of s = t_end.
   - Otherwise it sums, over every t in state_list, the product of
       pmf (𝒪 t) at the n-th observation, pmf (𝒦 s) at t,
       and the recursive value for t at index n + 1.
   Termination: the recursive call only happens when n < IArray.length os,
   so IArray.length os - n strictly decreases. *)
function forward_ix_rec where
  "forward_ix_rec s t_end n = (if n ≥ IArray.length os then indicator {t_end} s else
    (∑t ← state_list.
      ennreal (pmf (𝒪 t) (os !! n)) * ennreal (pmf (𝒦 s) t) * forward_ix_rec t t_end (n + 1)))
  "
  by auto
termination
  by (relation "Wellfounded.measure (λ(_,_,n). IArray.length os - n)") auto

text ‹Memoization›

(* Derive a memoized variant of forward_ix_rec: its recursion equation (with Let unfolded)
   is lifted into the state monad, using the memory setup dp_consistency_mapping.
   `term forward_ixm'` merely displays the generated constant; forward_code later runs it
   on Mapping.empty. memoize_prover discharges the generated correctness obligation. *)
memoize_fun forward_ixm: forward_ix_rec
  with_memory dp_consistency_mapping
  monadifies (state) forward_ix_rec.simps[unfolded Let_def]
  term forward_ixm'
memoize_correct
  by memoize_prover

text ‹The main theorems generated by memoization.›
(* Diagnostic only: print the facts produced by the memoization commands above.
   state_monad_syntax is included locally, for this context block only. *)
context
  includes state_monad_syntax
begin
thm forward_ixm'.simps forward_ixm_def
thm forward_ixm.memoized_correct
end

end (* Fixed IArray *)

(* List-based entry point: wraps the observation list os into an IArray and
   delegates to forward_ix_rec. *)
definition
  "forward_ix os = forward_ix_rec (IArray os)"

(* likelihood_compute s os returns None when s is not a member of state_list.
   Otherwise it returns Some of the sum of forward s t os over all t in state_list. *)
definition
  "likelihood_compute s os ≡
    if s ∈ set state_list then Some (∑t ← state_list. forward s t os) else None"

end (* HMM3 Defs *)

text ‹Correctness of the alternative definition.›

(* Prepending one observation while advancing the start index by one leaves the value
   of forward_ix unchanged. Proved by induction on length os - n. *)
lemma (in HMM3) forward_ix_drop_one:
  "forward_ix (o # os) s t (n + 1) = forward_ix os s t n"
  by (induction "length os - n" arbitrary: s n; simp add: forward_ix_def)

(* Started at index 0, the index-based forward_ix agrees with forward.
   Induction on the observation list; the Cons case unfolds both recursion equations once,
   shifts the index with forward_ix_drop_one, and uses states_distinct together with
   sum_list_distinct_conv_sum_set and state_list_𝒮. *)
lemma (in HMM4) forward_ix_forward:
  "forward_ix os s t 0 = forward s t os"
  unfolding forward_ix_def
proof (induction os arbitrary: s)
  case Nil
  then show ?case
    by simp
next
  case (Cons o os)
  show ?case
    using forward_ix_drop_one[unfolded forward_ix_def] states_distinct
    by (subst forward.simps, subst forward_ix_rec.simps)
       (simp add: Cons.IH state_list_𝒮 sum_list_distinct_conv_sum_set
             del: forward_ix_rec.simps forward.simps
       )
qed

text ‹
  Instructs the code generator to use this equation instead to execute ‹forward›.
  Uses the memoized version of ‹forward_ix›.
›
(* Code equation for forward: run the memoized forward_ixm' from index 0 on the
   observations packed into an IArray, starting from an empty Mapping, and take the
   first component of the result. *)
lemma (in HMM4) forward_code [code]:
  "forward s t os = fst (run_state (forward_ixm' (IArray os) s t 0) Mapping.empty)"
  by (simp only:
      forward_ix_def forward_ixm.memoized_correct forward_ix_forward[symmetric]
      states_distinct
     )

(* Correctness of likelihood_compute: it returns Some x exactly when s is a state in 𝒮
   and x equals likelihood s os. *)
theorem (in HMM4) likelihood_compute:
  "likelihood_compute s os = Some x ⟷ s ∈ 𝒮 ∧ x = likelihood s os"
  unfolding likelihood_compute_def
  by (auto simp: states_distinct state_list_𝒮 sum_list_distinct_conv_sum_set likelihood_forward)


subsection ‹The Viterbi Algorithm›

context HMM3_defs
begin

context
  fixes os :: "'t iarray"
begin

text ‹
  Alternative definition using indices into the list of observations.
  The list of observations is implemented as an immutable array for better performance.
›

(* viterbi_ix_rec s t_end n works on the fixed observation array os, starting at index n,
   and returns a pair (state list, value).
   - If n is at or beyond the end of os: the empty list with the indicator of s = t_end.
   - Otherwise: for every t in state_list, the recursive result (xs, v) at index n + 1 is
     turned into (t # xs, pmf (𝒪 t) (os !! n) * pmf (𝒦 s) t * v); the candidate is then
     chosen with argmax on the second component.
   NOTE(review): argmax is defined elsewhere; the outer fst suggests it returns the
   selected element paired with something else - confirm against its definition.
   Termination: IArray.length os - n decreases on every recursive call. *)
function viterbi_ix_rec where
  "viterbi_ix_rec s t_end n = (if n ≥ IArray.length os then ([], indicator {t_end} s) else
  fst (
    argmax snd (map
      (λt. let (xs, v) = viterbi_ix_rec t t_end (n + 1) in
        (t # xs, ennreal (pmf (𝒪 t) (os !! n) * pmf (𝒦 s) t) * v))
    state_list)))
  "
  by pat_completeness auto
termination
  by (relation "Wellfounded.measure (λ(_,_,n). IArray.length os - n)") auto

text ‹Memoization›

(* Derive a memoized variant of viterbi_ix_rec: its recursion equation (with Let unfolded)
   is lifted into the state monad, using the memory setup dp_consistency_mapping.
   viterbi_code later runs the generated viterbi_ixm' on Mapping.empty.
   memoize_prover discharges the generated correctness obligation. *)
memoize_fun viterbi_ixm: viterbi_ix_rec
  with_memory dp_consistency_mapping
  monadifies (state) viterbi_ix_rec.simps[unfolded Let_def]

memoize_correct
  by memoize_prover

text ‹The main theorems generated by memoization.›
(* Diagnostic only: print the facts produced by the memoization commands above.
   state_monad_syntax is included locally, for this context block only. *)
context
  includes state_monad_syntax
begin
thm viterbi_ixm'.simps viterbi_ixm_def
thm viterbi_ixm.memoized_correct
end

end (* Fixed IArray *)

(* List-based entry point: wraps the observation list os into an IArray and
   delegates to viterbi_ix_rec. *)
definition
  "viterbi_ix os = viterbi_ix_rec (IArray os)"

end (* HMM3 Defs *)

context HMM3
begin

(* Prepending one observation while advancing the start index by one leaves the value
   of viterbi_ix unchanged. Proved by induction on length os - n. *)
lemma viterbi_ix_drop_one:
  "viterbi_ix (o # os) s t (n + 1) = viterbi_ix os s t n"
  by (induction "length os - n" arbitrary: s n; simp add: viterbi_ix_def)

(* Started at index 0, the index-based viterbi_ix agrees with viterbi.
   Induction on the observation list; the Cons case unfolds both recursion equations once
   and shifts the index with viterbi_ix_drop_one. *)
lemma viterbi_ix_viterbi:
  "viterbi_ix os s t 0 = viterbi s t os"
  unfolding viterbi_ix_def
proof (induction os arbitrary: s)
  case Nil
  then show ?case
    by simp
next
  case (Cons o os)
  show ?case
    using viterbi_ix_drop_one[unfolded viterbi_ix_def]
    by (subst viterbi.simps, subst viterbi_ix_rec.simps)
       (simp add: Cons.IH del: viterbi_ix_rec.simps viterbi.simps)
qed

(* Code equation for viterbi: run the memoized viterbi_ixm' from index 0 on the
   observations packed into an IArray, starting from an empty Mapping, and take the
   first component of the result. *)
lemma viterbi_code [code]:
  "viterbi s t os = fst (run_state (viterbi_ixm' (IArray os) s t 0) Mapping.empty)"
  by (simp only: viterbi_ix_def viterbi_ixm.memoized_correct viterbi_ix_viterbi[symmetric])

end (* Hidden Markov Model 3 *)

subsection ‹Misc›

(* If every weight stored in the association list μ is non-negative, then looking up any
   key x (with 0 as the default for missing keys) yields a non-negative real. *)
lemma pmf_of_alist_support_aux_1:
  assumes "∀ (_, p) ∈ set μ. p ≥ 0"
  shows "(0 :: real) ≤ (case map_of μ x of None ⇒ 0 | Some p ⇒ p)"
  using assms by (auto split: option.split dest: map_of_SomeD)

(* For an association list μ with non-negative weights that sum to 1 and pairwise distinct
   keys, the lookup function (0 for missing keys) integrates to 1 over the counting
   measure on the whole key type.
   Proof outline:
   - the integral becomes a finite sum, since the lookup is non-zero only on fst ` set μ;
   - that sum is extended to all of fst ` set μ (the extra summands are 0);
   - the sum over keys is reindexed by list position (injective because keys are distinct),
     which relates it to sum_list (map snd μ) = 1. *)
lemma pmf_of_alist_support_aux_2:
  assumes "∀ (_, p) ∈ set μ. p ≥ 0"
    and "sum_list (map snd μ) = 1"
    and "distinct (map fst μ)"
  shows "∫⇧+ x. ennreal (case map_of μ x of None ⇒ 0 | Some p ⇒ p) ∂count_space UNIV = 1"
  using assms
  apply (subst nn_integral_count_space)
  subgoal
    by (rule finite_subset[where B = "fst ` set μ"];
        force split: option.split_asm simp: image_iff dest: map_of_SomeD)
  apply (subst sum.mono_neutral_left[where T = "fst ` set μ"])
     apply blast
  subgoal
    by (smt ennreal_less_zero_iff map_of_eq_None_iff mem_Collect_eq option.case(1) subsetI)
  subgoal
    by auto
  subgoal premises prems
  proof -
    have "(∑x = 0..<length μ. snd (μ ! x))
      = sum (λ x. case map_of μ x of None ⇒ 0 | Some v ⇒ v) (fst ` set μ)"
      apply (rule sym)
      apply (rule sum.reindex_cong[where l = "λ i. fst (μ ! i)"])
        apply (auto split: option.split)
      subgoal
        using prems(3) by (intro inj_onI, auto simp: distinct_conv_nth)
      subgoal
        by (auto simp: in_set_conv_nth rev_image_eqI)
      subgoal
        by (simp add: map_of_eq_None_iff)
      subgoal
        using map_of_eq_Some_iff[OF prems(3)]
        by (metis fst_conv nth_mem option.inject prod_eqI snd_conv)
      done
    with prems(2) show ?thesis
      by (smt pmf_of_alist_support_aux_1[OF assms(1)] atLeastLessThan_iff ennreal_1
          length_map nth_map sum.cong sum_ennreal sum_list_sum_nth
          )
  qed
  done

(* Under the same well-formedness conditions on μ (non-negative weights, total weight 1,
   distinct keys), the support of pmf_of_alist μ is contained in the keys of μ. *)
lemma pmf_of_alist_support:
  assumes "∀ (_, p) ∈ set μ. p ≥ 0"
    and "sum_list (map snd μ) = 1"
    and "distinct (map fst μ)"
  shows "set_pmf (pmf_of_alist μ) ⊆ fst ` set μ"
  unfolding pmf_of_alist_def
  apply (subst set_embed_pmf)
  subgoal for x
    using assms(1) by (auto split: option.split dest: map_of_SomeD)
  subgoal
    using pmf_of_alist_support_aux_2[OF assms] .
  apply (force split: option.split_asm simp: image_iff dest: map_of_SomeD)+
  done

text ‹Defining a Markov kernel from an association list.›
(* A kernel given as data: K maps each source key to an association list of
   (target, weight) pairs; S lists the admissible targets.
   Assumptions:
   - wellformed: S is non-empty (hd S is used as the fallback target in K');
   - closed:     every target mentioned in K occurs in S;
   - is_pmf:     each row has non-negative weights, distinct targets and total weight 1;
   - is_unique:  every source key occurs at most once in K. *)
locale Closed_Kernel_From =
  fixes K :: "('s × ('t × real) list) list"
    and S :: "'t list"
  assumes wellformed: "S ≠ []"
      and closed: "∀ (s, μ) ∈ set K. ∀ (t, _) ∈ set μ. t ∈ set S"
      and is_pmf:
        "∀ (_, μ) ∈ set K. ∀ (_, p) ∈ set μ. p ≥ 0"
        "∀ (_, μ) ∈ set K. distinct (map fst μ)"
        "∀ (s, μ) ∈ set K. sum_list (map snd μ) = 1"
      and is_unique:
        "distinct (map fst K)"
begin

(* K' s is the pmf built from the row stored for s in K.
   If K has no row for s, it falls back to the point distribution on hd S. *)
definition
  "K' s ≡ case map_of (map (λ (s, μ). (s, PMF_Impl.pmf_of_alist μ)) K) s of
  None ⇒ return_pmf (hd S) |
  Some s ⇒ s"

(* Register Closed_Kernel, instantiated with K' and set S, as a sublocale.
   The obligations are discharged from wellformed, closed, is_pmf and
   pmf_of_alist_support after unfolding K'_def. *)
sublocale Closed_Kernel K' "set S"
  using wellformed closed is_pmf pmf_of_alist_support
  unfolding K'_def by - (standard; fastforce split: option.split_asm dest: map_of_SomeD)

(* Lookup structure built from K: each source key is mapped to the partial map (map_of)
   of its row. Declared as a code equation; used on the right-hand side of K'_code. *)
definition [code]:
  "K1 = map_of (map (λ (s, μ). (s, map_of μ)) K)"

(* For a row (s, μ) of K, the probability that pmf_of_alist μ assigns to t is the weight
   stored for t in μ, or 0 if t does not occur in μ. *)
lemma pmf_of_alist_aux:
  assumes "(s, μ) ∈ set K"
  shows
    "pmf (pmf_of_alist μ) t = (case map_of μ t of
      None ⇒ 0
    | Some p ⇒ p)"
  using assms is_pmf unfolding pmf_of_alist_def
  by (intro pmf_embed_pmf pmf_of_alist_support_aux_2)
     (auto 4 3 split: option.split dest: map_of_SomeD)

(* A source key determines its row: two entries of K with the same key s carry the same
   association list. Follows from is_unique (distinct keys). *)
lemma unique: "μ = μ'" if "(s, μ) ∈ set K" "(s, μ') ∈ set K"
  using that is_unique
  by (smt Pair_inject distinct_conv_nth fst_conv in_set_conv_nth length_map nth_map)

(* Stated outside the locale ("in -"): if the lookup of x in M fails, then x is not among
   the keys of M. *)
lemma (in -) map_of_NoneD:
  "x ∉ fst ` set M" if "map_of M x = None"
  using that by (auto dest: weak_map_of_SomeI)

(* Pointwise characterisation of K' through the lookup structure K1:
   - no row for s: all mass sits on hd S (1 if t = hd S, else 0);
   - row μ for s:  the weight stored for t in μ, or 0 if t has no entry. *)
lemma K'_code [code_post]:
  "pmf (K' s) t = (case K1 s of
      None ⇒ (if t = hd S then 1 else 0)
    | Some μ ⇒ case μ t of
        None ⇒ 0
      | Some p ⇒ p
  )"
  unfolding K'_def K1_def
  apply (clarsimp split: option.split, safe)
                 apply (drule map_of_SomeD, drule map_of_NoneD, force)+
         apply (fastforce dest: unique map_of_SomeD simp: pmf_of_alist_aux)+
  done

end

subsection ‹Executing Concrete HMMs›

(* Data of a concrete HMM, without any assumptions:
   𝒦  - rows of (state, weight) pairs per state,
   𝒪  - rows of (observation, weight) pairs per state,
   𝒪s - list of observations, 𝒦s - list of states. *)
locale Concrete_HMM_defs =
  fixes 𝒦 :: "('s × ('s × real) list) list"
    and 𝒪 :: "('s × ('t × real) list) list"
    and 𝒪s :: "'t list"
    and 𝒦s :: "'s list"
begin

(* 𝒦' s: pmf built from the row of s in 𝒦; point distribution on hd 𝒦s if s has no row. *)
definition
  "𝒦' s ≡ case map_of (map (λ (s, μ). (s, PMF_Impl.pmf_of_alist μ)) 𝒦) s of
    None ⇒ return_pmf (hd 𝒦s) |
    Some s ⇒ s"

(* 𝒪' s: pmf built from the row of s in 𝒪; point distribution on hd 𝒪s if s has no row. *)
definition
  "𝒪' s ≡ case map_of (map (λ (s, μ). (s, PMF_Impl.pmf_of_alist μ)) 𝒪) s of
    None ⇒ return_pmf (hd 𝒪s) |
    Some s ⇒ s"

end

(* Well-formedness of the concrete HMM data. The same four conditions are imposed on the
   observation table (𝒪, 𝒪s) and on the transition table (𝒦, 𝒦s):
   non-empty target list, all targets listed, every row a valid probability table
   (non-negative weights, distinct targets, total weight 1), and distinct row keys.
   In addition, the state list 𝒦s itself must be free of duplicates. *)
locale Concrete_HMM = Concrete_HMM_defs +
  assumes observations_wellformed': "𝒪s ≠ []"
      and observations_closed': "∀ (s, μ) ∈ set 𝒪. ∀ (t, _) ∈ set μ. t ∈ set 𝒪s"
      and observations_form_pmf':
        "∀ (_, μ) ∈ set 𝒪. ∀ (_, p) ∈ set μ. p ≥ 0"
        "∀ (_, μ) ∈ set 𝒪. distinct (map fst μ)"
        "∀ (s, μ) ∈ set 𝒪. sum_list (map snd μ) = 1"
      and observations_unique:
        "distinct (map fst 𝒪)"
  assumes states_wellformed: "𝒦s ≠ []"
      and states_closed: "∀ (s, μ) ∈ set 𝒦. ∀ (t, _) ∈ set μ. t ∈ set 𝒦s"
      and states_form_pmf:
        "∀ (_, μ) ∈ set 𝒦. ∀ (_, p) ∈ set μ. p ≥ 0"
        "∀ (_, μ) ∈ set 𝒦. distinct (map fst μ)"
        "∀ (s, μ) ∈ set 𝒦. sum_list (map snd μ) = 1"
      and states_unique:
        "distinct (map fst 𝒦)" "distinct 𝒦s"
begin

(* Interpret Closed_Kernel_From for the observation table, and rewrite the locale's K'
   to the constant 𝒪' so that exported facts mention 𝒪' directly.
   First goal: the locale assumptions, from the observations_* facts.
   Second goal: both definitions unfold to the same term. *)
interpretation O: Closed_Kernel_From 𝒪 𝒪s
  rewrites "O.K' = 𝒪'"
proof -
  show ‹Closed_Kernel_From 𝒪 𝒪s›
    using observations_wellformed' observations_closed' observations_form_pmf' observations_unique
    by unfold_locales auto
  show ‹Closed_Kernel_From.K' 𝒪 𝒪s = 𝒪'›
    unfolding Closed_Kernel_From.K'_def[OF ‹Closed_Kernel_From 𝒪 𝒪s›] 𝒪'_def
    by auto
qed

(* Interpret Closed_Kernel_From for the transition table, and rewrite the locale's K'
   to the constant 𝒦' so that exported facts mention 𝒦' directly.
   First goal: the locale assumptions, from the states_* facts.
   Second goal: both definitions unfold to the same term. *)
interpretation K: Closed_Kernel_From 𝒦 𝒦s
  rewrites "K.K' = 𝒦'"
proof -
  show ‹Closed_Kernel_From 𝒦 𝒦s›
    using states_wellformed states_closed states_form_pmf states_unique by unfold_locales auto
  show ‹Closed_Kernel_From.K' 𝒦 𝒦s = 𝒦'›
    unfolding Closed_Kernel_From.K'_def[OF ‹Closed_Kernel_From 𝒦 𝒦s›] 𝒦'_def
    by auto
qed

(* Collect, per interpretation, the pointwise characterisation K'_code together with the
   definition of the lookup structure K1 under one name. *)
lemmas O_code = O.K'_code O.K1_def
lemmas K_code = K.K'_code K.K1_def

(* Every concrete HMM is an HMM4, instantiated with 𝒪', set 𝒦s, 𝒦s, set 𝒪s and 𝒦'.
   The proof uses the Closed_Kernel facts of both interpretations and
   states_unique(2), i.e. "distinct 𝒦s".
   NOTE(review): the meaning of the first three positional parameters is fixed by HMM3,
   which is declared in another theory - confirm the argument order there. *)
sublocale HMM_interp: HMM4 𝒪' "set 𝒦s" 𝒦s "set 𝒪s" 𝒦'
  using O.Closed_Kernel_axioms K.Closed_Kernel_axioms states_unique(2)
  by (intro_locales; intro HMM4_axioms.intro HMM3_axioms.intro HOL.refl)

end (* Concrete HMM *)

end