(* Theory Fin_Code *)
theory Fin_Code
imports
"../Backward_Induction"
Code_Setup
begin
(* Conjunction of the two locales MDP_nat and MDP_reward_fin.
   It introduces no parameters, assumptions or definitions of its own. *)
locale MDP_nat_fin = MDP_nat + MDP_reward_fin
(* Locale for the executable finite-horizon code.

   Extends MDP_Code_raw with four instances of the Array' locale, one per
   container type used in this context:
     R_Fin_Map : 'tf,  lookup into real   (read through r_fin_code by MDP_r_fin)
     V_Map     : 'tv,  lookup into real   (value containers)
     D_Map     : 'td,  lookup into nat    (one decision per index)
     VD_Map    : 'tvd, lookup into nat × real (decision paired with a value)

   The `for` clauses fix the order of the locale parameters; changing that
   order changes the interface seen by every interpretation of this locale.

   Additional parameters:
     N_code     :: nat  -- number of steps performed by bw_ind_code; also the
                           last argument of the MDP_nat_fin interpretation.
     r_fin_code :: 'tf  -- container queried by MDP_r_fin via r_fin_lookup. *)
locale MDP_Code_Fin = MDP_Code_raw +
R_Fin_Map : Array' "r_fin_lookup :: 'tf ⇒ nat ⇒ real" r_fin_update r_fin_len r_fin_array r_fin_list r_fin_invar +
V_Map : Array' "v_lookup :: 'tv ⇒ nat ⇒ real" v_update v_len v_array v_list v_invar +
D_Map : Array' "d_lookup :: 'td ⇒ nat ⇒ nat" d_update d_len d_array d_list d_invar +
VD_Map : Array' "vd_lookup :: 'tvd ⇒ nat ⇒ (nat × real)" vd_update vd_len vd_array vd_list vd_invar
for v_lookup v_update v_len v_array v_list v_invar
and d_lookup d_update d_len d_array d_list d_invar
and vd_lookup vd_update vd_len vd_array vd_list vd_invar
and r_fin_lookup r_fin_update r_fin_len r_fin_array r_fin_list r_fin_invar +
fixes
N_code :: nat and
r_fin_code :: 'tf
begin
(* Synonym for v_array, the V_Map parameter of this locale, applied to xs. *)
definition v_map_from_list where
  "v_map_from_list xs = v_array xs"
(* Total function on state indices built from r_fin_code:
   for s ≥ states the result is 0, otherwise it is the entry of r_fin_code
   at s, read with r_fin_lookup. *)
definition MDP_r_fin where
  "MDP_r_fin s =
    (if s ≥ states
     then 0
     else r_fin_lookup r_fin_code s)"
(* The set of values taken by MDP_r_fin is bounded.
   This fact is supplied to the MDP_nat_fin interpretation further down.
   The proof unfolds the definition and goes through finiteness of an image
   (finite_imageI, finite_nat_set_iff_bounded_le); MDP_r_fin is 0 for every
   index ≥ states. *)
lemma bounded_r_fin: "bounded (range MDP_r_fin)"
unfolding MDP_r_fin_def
by (fastforce simp add: nle_le bounded_const finite_nat_set_iff_bounded_le intro!: finite_imageI)
(* Interprets MDP_reward_disc under the prefix MDP for the components
   MDP_A, MDP_K, MDP_r, with 0 as the final locale argument
   (presumably the discount factor -- TODO confirm against MDP_reward_disc).
   The only extra fact needed is bounded_MDP_r. *)
sublocale MDP: MDP_reward_disc "(MDP_A)" "(MDP_K)" "(MDP_r)" 0
using bounded_MDP_r
by unfold_locales auto
(* Interprets MDP_act under the prefix MDP, using "least element of the set"
   (λX. LEAST x. x ∈ X) as the third locale argument.
   The proof relies on the MDP_reward_disc interpretation established above
   and on LeastI2 / finite_is_arg_max for the obligations about LEAST. *)
sublocale MDP: MDP_act "(MDP_A)" "(MDP_K)" "λX. LEAST x. x ∈ X"
using MDP.MDP_reward_disc_axioms
by unfold_locales
(auto intro: LeastI2 simp: MDP_reward_disc.max_L_ex_def has_arg_max_def finite_is_arg_max)
(* Interprets MDP_nat_fin under the prefix MDP. Arguments, in order:
   the LEAST-based choice function, MDP_A, MDP_K, states, MDP_r,
   MDP_r_fin and N_code.
   The obligations are discharged from closure facts about MDP_K, the
   behaviour of MDP_r and MDP_A outside the state range, and the two
   boundedness lemmas bounded_MDP_r and bounded_r_fin. *)
sublocale MDP: MDP_nat_fin "λX. LEAST x. x ∈ X" "(MDP_A)" "(MDP_K)" states "(MDP_r)" MDP_r_fin N_code
using MDP_K_closed MDP_K_comp_closed MDP_r_zero_notin_states MDP_A_outside bounded_MDP_r bounded_r_fin
by unfold_locales (auto intro: LeastI2)
(* Additional array locales for the value and decision containers.
   Each interpretation reuses the corresponding Array' parameters and is
   proved by unfold_locales alone, i.e. no new proof obligations remain.
   NOTE(review): V_Map.map_to_bfun, used by the lemmas below, presumably comes
   from one of the V_Map interpretations here -- confirm in Code_Setup. *)
sublocale V_Map: Array_real v_lookup v_update v_len v_array v_list v_invar
by unfold_locales
sublocale V_Map: Array_zero v_lookup v_update v_len v_array v_list v_invar
by unfold_locales
sublocale D_Map: Array_zero d_lookup d_update d_len d_array d_list d_invar
by unfold_locales
(* One-step evaluation of a single action entry.
   rp is a pair (r, ps): r is a real number and ps a list of pairs (index, weight).
   The result is r plus the sum over ps of weight * (entry of v at index),
   accumulated by a left fold that starts at 0. *)
definition L⇩a_code where
  "L⇩a_code rp v =
    (let (r, ps) = rp
     in r + foldl (λsum_so_far (nxt, pr). pr * v_lookup v nxt + sum_so_far) 0 ps)"
(* Correctness of L⇩a_code for one state-action pair (s, a).

   Assumptions (their positions are referenced as assms(k) in the proof):
     1. s is a valid state index;
     2-3. v has length states and satisfies the V_Map invariant;
     4-5. snd rps is a well-formed list whose pmf_of_list is MDP_K (s, a);
     6. every index occurring in snd rps lies in {0..<states};
     7. fst rps is the reward MDP_r (s, a).

   Conclusion: L⇩a_code rps v equals the reward plus the expectation of
   V_Map.map_to_bfun v under MDP_K (s, a). *)
lemma L⇩a_code_correct:
  assumes
    "s < states"
    "v_len v = states" "v_invar v"
    "pmf_of_list (snd rps) = MDP_K (s, a)" "pmf_of_list_wf (snd rps)"
    "fst ` set (snd rps) ⊆ {0..<states}" "fst rps = MDP_r (s, a)"
  shows "L⇩a_code rps v = MDP_r (s, a) + measure_pmf.expectation (MDP_K (s,a)) (V_Map.map_to_bfun v)"
proof -
  (* A left fold with an additive accumulator is the list sum plus the start value. *)
  have "foldl (λacc x. f x + acc) x xs = (∑x←xs. f x) + x" for f xs and x :: real
    by (induction xs arbitrary: x) (auto simp: algebra_simps)
  (* Specialised to start value 0: rewrites a list sum into the fold used by L⇩a_code. *)
  hence *: "(∑x←xs. f x) = foldl (λacc x. f x + acc) (0::real) xs" for f xs
    by (metis add.right_neutral)
  (* The fold over the successor list is the expectation under MDP_K (s, a);
     assms(4) is used right-to-left to express MDP_K (s, a) as pmf_of_list. *)
  have "foldl (λacc (s', p). p * v_lookup v s' + acc) 0 (snd rps) =
      measure_pmf.expectation (MDP_K (s, a)) (apply_bfun (V_Map.map_to_bfun v))"
    unfolding assms(4)[symmetric]
    using assms(5-7)
    by (auto intro!: foldl_cong simp: pmf_of_list_expectation * V_Map.map_to_bfun.rep_eq assms(2,3))
  thus ?thesis
    unfolding L⇩a_code_def
    by (auto simp add: assms case_prod_unfold)
qed
(* Applies least_arg_max_max_ne to the in-order entry list of the action map
   stored for state s (a_inorder (s_lookup mdp s)).
   Each entry is scored by L⇩a_code on its second component, evaluated
   against the value container v; the first component is ignored by the score. *)
definition find_policy_state_code_aux where
  "find_policy_state_code_aux v s =
    least_arg_max_max_ne
      (λ(_, payload). L⇩a_code payload v)
      (a_inorder (s_lookup mdp s))"
(* Projection of find_policy_state_code_aux: from a result of the shape
   ((a, _, _), x) keep only the first component a of the selected entry
   together with the accompanying value x. *)
definition find_policy_state_code_aux' where
  "find_policy_state_code_aux' v s =
    (case find_policy_state_code_aux v s of
       ((act, _, _), best) ⇒ (act, best))"
(* Tabulates find_policy_state_code_aux' v into a VD_Map container,
   with VD_Map.arr_tabulate called with the bound states. *)
definition vi_find_policy_code where
  "vi_find_policy_code (v::'tv) =
    VD_Map.arr_tabulate (λstate. find_policy_state_code_aux' v state) states"
(* For a valid state s, find_policy_state_code_aux' v s can be computed over
   the list of action keys (map fst of the in-order entry list) instead of the
   entries themselves: each key a is scored by L⇩a_code on the entry returned
   by a_lookup' for a.
   The proof rewrites with least_arg_max_max_ne_app' (right-to-left) and uses
   that the action map of s is non-empty and satisfies its invariant. *)
lemma find_policy_state_code_aux_eq:
assumes "s < states"
shows "find_policy_state_code_aux' v s = (least_arg_max_max_ne (λa.
L⇩a_code (a_lookup' (s_lookup mdp s) a) v) ((map fst (a_inorder (s_lookup mdp s)))))"
unfolding find_policy_state_code_aux'_def find_policy_state_code_aux_def
using assms A_Map.is_empty_def ne_s_lookup
by(subst least_arg_max_max_ne_app'[symmetric])
(auto simp: case_prod_unfold a_lookup'_def A_Map.entries_def A_Map.inorder_lookup_Some assms invar_s_lookup)
(* Instance of L⇩a_code_correct for the entry stored in the code-level MDP:
   for a valid state s, a value container v of length states satisfying its
   invariant, and an action a ∈ MDP_A s, evaluating L⇩a_code on
   a_lookup' (s_lookup mdp s) a yields MDP_r (s, a) plus the expectation of
   V_Map.map_to_bfun v under MDP_K (s, a). *)
lemma L_GS_code_correct':
assumes "s < states" "v_len v = states" "v_invar v" "a ∈ MDP_A s"
shows "L⇩a_code (a_lookup' (s_lookup mdp s) a) v =
MDP_r(s, a) + measure_pmf.expectation (MDP_K (s,a)) (V_Map.map_to_bfun v)"
using pmf_of_list_wf_mdp assms set_list_pmf_in_states
by (intro L⇩a_code_correct)
(auto simp: fst_sa_lookup'_eq[symmetric] snd_sa_lookup'_eq)
(* Abstract characterisation of find_policy_state_code_aux'.
   For a valid state s and a value container v of length states satisfying
   its invariant, the result is the pair
     (least_arg_max f (λa. a ∈ MDP_A s),  Max (f ` MDP_A s))
   where f a = MDP_r (s, a) + expectation of V_Map.map_to_bfun v under MDP_K (s, a).

   The proof is a four-step calculation:
     1. switch from entries to action keys (find_policy_state_code_aux_eq);
     2. replace the list fold by least_arg_max / MAX over the key list;
     3. replace the code-level score L⇩a_code by f (L_GS_code_correct');
     4. replace membership in the key list by membership in MDP_A s. *)
lemma find_policy_state_code_aux'_eq':
assumes "s < states" "v_len v = states" "v_invar v"
shows "find_policy_state_code_aux' v s =
(least_arg_max (λa. MDP_r(s, a) + measure_pmf.expectation (MDP_K (s,a)) (V_Map.map_to_bfun v)) (λa. a ∈ MDP_A s),
Max ((λa. MDP_r(s, a) + measure_pmf.expectation (MDP_K (s,a)) (V_Map.map_to_bfun v)) ` (MDP_A s)))"
proof -
(* Step 1: maximise over action keys rather than over entries. *)
have "find_policy_state_code_aux' v s = least_arg_max_max_ne (λa. L⇩a_code (a_lookup' (s_lookup mdp s) a) v) (map fst (a_inorder (s_lookup mdp s)))"
using find_policy_state_code_aux_eq assms by auto
(* Step 2: express the list-based maximisation via least_arg_max and MAX. *)
also have ‹… = (least_arg_max (λa. L⇩a_code (a_lookup' (s_lookup mdp s) a) v) (List.member (map fst (a_inorder (s_lookup mdp s)))),
MAX a∈set (map fst (a_inorder (s_lookup mdp s))). L⇩a_code (a_lookup' (s_lookup mdp s) a) v)›
using A_Map.is_empty_def assms(1) A_Map.invar_def A_inv_locale S_Map.lookup_in_list s_invar s_len A_ne_locale
by (auto simp: fold_max_eq_arg_max')
(* Step 3: replace the code-level score by its abstract counterpart. *)
also have ‹… = (least_arg_max (λa. MDP_r(s, a) + measure_pmf.expectation (MDP_K (s,a)) (V_Map.map_to_bfun v)) (List.member (map fst (a_inorder (s_lookup mdp s)))),
MAX a∈set (map fst (a_inorder (s_lookup mdp s))). MDP_r(s, a) + measure_pmf.expectation (MDP_K (s,a)) (V_Map.map_to_bfun v))›
using assms a_inorderD(1) A_Map.keys_def MDP_A_def
by (auto intro!: least_arg_max_cong simp: L_GS_code_correct' in_set_member[symmetric])
(* Step 4: the keys of the action map of s are exactly MDP_A s. *)
also have ‹… = (least_arg_max (λa. MDP_r(s, a) + measure_pmf.expectation (MDP_K (s,a)) (V_Map.map_to_bfun v)) (λa. a ∈ MDP_A s),
MAX a∈MDP_A s. MDP_r(s, a) + measure_pmf.expectation (MDP_K (s,a)) (V_Map.map_to_bfun v))›
using assms A_Map.entries_def A_Map.keys_def A_Map.entries_imp_keys
by (auto intro!: least_arg_max_cong' simp: MDP_A_def in_set_member[symmetric])
finally show ?thesis.
qed
(* Lookup specification of vi_find_policy_code.
   For a valid state s and a value container v of length states satisfying
   its invariant, the entry of vi_find_policy_code v at s is the pair
     (least maximising action in MDP_A s, maximal value over MDP_A s)
   of the objective a ↦ MDP_r (s, a) + expectation of V_Map.map_to_bfun v
   under MDP_K (s, a).
   Follows from find_policy_state_code_aux'_eq' after unfolding the definition. *)
lemma vi_find_policy_code_correct:
assumes "s < states" "v_len v = states" "v_invar v"
shows "vd_lookup (vi_find_policy_code v) s =
( least_arg_max
(λa. MDP_r(s, a) + measure_pmf.expectation (MDP_K (s,a)) (V_Map.map_to_bfun v))
(λa. a ∈ MDP_A s)
, Max ((λa. MDP_r(s, a) + measure_pmf.expectation (MDP_K (s,a)) (V_Map.map_to_bfun v)) ` (MDP_A s)))"
unfolding vi_find_policy_code_def
by (auto simp: find_policy_state_code_aux'_eq' assms)
(* Iteration loop over a step counter.
   Arguments: remaining step count, the most recent value container last_v,
   the list m_v of earlier value containers, and the list m_d of decision
   containers produced so far.

   Successor case: compute vi_find_policy_code last_v once, tabulate its second
   components into a new value container and its first components into a new
   decision container, push last_v onto m_v and the new decisions onto m_d,
   then continue with the new values and a decremented counter.
   Zero case: return (last_v # m_v, m_d).

   The order of the two equations is kept as in the original so that the
   generated simp and induction rules are unchanged. *)
fun bw_ind_aux_code where
  "bw_ind_aux_code (Suc n) last_v m_v m_d =
    (let
       greedy = vi_find_policy_code last_v;
       v_next = V_Map.arr_tabulate (λs. snd (vd_lookup greedy s)) states;
       d_next = D_Map.arr_tabulate (λs. fst (vd_lookup greedy s)) states
     in
       bw_ind_aux_code n v_next (last_v # m_v) (d_next # m_d))"
| "bw_ind_aux_code 0 last_v m_v m_d = (last_v # m_v, m_d)"
(* Entry point: runs bw_ind_aux_code for N_code steps with empty histories.
   The initial value container tabulates r_fin_lookup r_fin_code
   (V_Map.arr_tabulate called with the bound states). *)
definition bw_ind_code where
  "bw_ind_code =
    bw_ind_aux_code N_code
      (V_Map.arr_tabulate (r_fin_lookup r_fin_code) states)
      []
      []"
(* Indexing into the value list returned by bw_ind_aux_code:
   position i + n of the result holds element i of the initial list vl # v0,
   stated here under the side condition i < length v0.
   Proved by induction on the step count n. *)
lemma bw_ind_aux_code_fst_index: "i < length v0 ⟹ fst (bw_ind_aux_code n vl v0 d0) ! (i + n) =
(vl#v0) ! i"
by (induction n arbitrary: vl v0 d0 i) (auto simp: add_Suc[symmetric] simp del: add_Suc)
(* Variant of the index lemma phrased with subtraction: for n ≤ i, position i
   of the returned value list is element i - n of the initial list vl # v0.
   Proved by induction on the step count n. *)
lemma bw_ind_aux_code_fst_index': "n ≤ i ⟹ fst (bw_ind_aux_code n vl v0 d0) ! i =
(vl#v0) ! (i - n)"
by (induction n arbitrary: vl v0 d0 i) auto
(* Index lemma for the decision list: for n ≤ i, position i of the returned
   decision list is element i - n of the initial decision list d0.
   Proved by induction on the step count n. *)
lemma bw_ind_aux_code_snd_index': "n ≤ i ⟹ snd (bw_ind_aux_code n vl v0 d0) ! i =
(d0) ! (i - n)"
by (induction n arbitrary: vl v0 d0 i) auto
(* Correctness of the loop bw_ind_aux_code with respect to MDP.bw_ind_aux'.

   Setting: v and d are the value list and the decision list returned by
   bw_ind_aux_code n vl v0 d0. The start container vl has length states,
   satisfies the V_Map invariant, and agrees with m n on every valid state.

   Claim, for a valid state s:
     - for i ≤ n, entry s of v ! i equals MDP.bw_ind_aux' n m i s;
     - for i < n, entry s of d ! i is the least maximiser over MDP_A s of
       a ↦ MDP_r (s, a) + expectation of MDP.bw_ind_aux' n m (Suc i)
       under MDP_K (s, a).

   Proof: induction on n. In the successor case three situations are
   distinguished: i = Suc n (the untouched start container), i = n (the
   container and decisions produced by the current step), and i < n
   (induction hypothesis applied to the remaining steps). *)
lemma bw_ind_code_aux_correct:
  fixes n vl v0 d0
  defines "d ≡ snd (bw_ind_aux_code n vl v0 d0)"
  defines "v ≡ fst (bw_ind_aux_code n vl v0 d0)"
  assumes "v_len vl = states"
  assumes "v_invar vl"
  assumes "⋀s. s < states ⟹ m n s = v_lookup vl s"
  assumes "s < states"
  shows "(i ≤ n ⟶ v_lookup (v ! i) s = MDP.bw_ind_aux' n m i s) ∧
    (i < n ⟶ d_lookup (d ! i) s = (least_arg_max
      (λa. MDP_r (s, a) + measure_pmf.expectation (MDP_K (s, a)) (MDP.bw_ind_aux' n m (Suc i)))
      (λa. a ∈ MDP_A s)))"
  unfolding v_def d_def
  using assms
proof (induction n arbitrary: m i v0 d0 vl s)
  case (Suc n)
  show ?case
  proof (cases "i = Suc n")
    case True
    (* Index Suc n addresses the start container vl itself. *)
    then show ?thesis
      by (simp add: Suc bw_ind_aux_code_fst_index')
  next
    case False
    then show ?thesis
    proof (cases "i = n")
      case True
      (* Index n addresses the container and decisions created by this step. *)
      thus ?thesis
        using MDP_K_closed Suc.prems True
        by (auto intro!: least_arg_max_cong Bochner_Integration.integral_cong_AE AE_pmfI SUP_cong
            simp: cSup_eq_Max[symmetric] bw_ind_aux_code_snd_index' bw_ind_aux_code_fst_index'
            subset_eq V_Map.map_to_bfun.rep_eq vi_find_policy_code_correct)
    next
      case False
      (* The container built in this step agrees with the abstract one-step
         maximum over m (Suc n); this is the premise needed to apply the
         induction hypothesis to the remaining n steps. *)
      have *: "⋀s. s < states ⟹
        (⨆a∈MDP_A s. MDP_r (s, a) + measure_pmf.expectation (MDP_K (s, a)) (m (Suc n))) =
        v_lookup (V_Map.arr_tabulate (λs. snd (vd_lookup (vi_find_policy_code vl) s)) states) s"
        using MDP.K_closed
        by (auto simp: subset_eq vi_find_policy_code_correct Suc cSup_eq_Max[symmetric] V_Map.map_to_bfun.rep_eq
            intro!: AE_pmfI Bochner_Integration.integral_cong_AE SUP_cong)
      hence "v_lookup (fst (bw_ind_aux_code (Suc n) vl v0 d0) ! i) s = MDP.bw_ind_aux' (Suc n) m i s" if "i ≤ Suc n"
        unfolding bw_ind_aux_code.simps Let_def
        using ‹i ≤ Suc n› ‹i ≠ Suc n›
        by (subst Suc(1)[THEN conjunct1]) (auto simp: Suc)
      moreover have "d_lookup (snd (bw_ind_aux_code (Suc n) vl v0 d0) ! i) s =
        least_arg_max (λa. MDP_r (s, a) + measure_pmf.expectation (MDP_K (s, a)) (MDP.bw_ind_aux' (Suc n) m (Suc i))) (λa. a ∈ MDP_A s)" if "i < Suc n"
        unfolding bw_ind_aux_code.simps Let_def
        using ‹i < Suc n› ‹i ≠ Suc n› * False
        by (subst Suc(1)[THEN conjunct2]) (auto simp: Suc)
      ultimately show ?thesis
        by auto
    qed
  qed
qed auto
(* Top-level correctness of bw_ind_code, with v and d its value list and
   decision list:
     1. for n ≤ N_code and a valid state s, entry s of v ! n equals MDP.bw_ind n s;
     2. for n < N_code and a valid state s, entry s of d ! n equals
        MDP.bw_ind_pol_gen with the LEAST-based choice function at n and s.
   Both parts are obtained from the two conjuncts of bw_ind_code_aux_correct.

   NOTE(review): in the second statement s is not bound by ⋀ (only n is),
   unlike in the first statement; accordingly the second case fixes only n. *)
lemma bw_ind_code_correct:
defines "d ≡ snd bw_ind_code"
defines "v ≡ fst bw_ind_code"
shows "⋀n s. n ≤ N_code ⟹ s < states ⟹ v_lookup (v ! n) s = MDP.bw_ind n s"
and "⋀n. n < N_code ⟹ s < states ⟹ d_lookup (d ! n) s = MDP.bw_ind_pol_gen (λX. LEAST a. a ∈ X) n s"
proof (goal_cases)
case (1 n s)
(* Values: first conjunct of bw_ind_code_aux_correct, rewritten right-to-left. *)
then show ?case
unfolding MDP.bw_ind_def
by (subst bw_ind_code_aux_correct[THEN conjunct1, THEN mp, symmetric])
(auto simp add: MDP_r_fin_def bw_ind_code_def v_def )
next
case (2 n)
(* Decisions: second conjunct of bw_ind_code_aux_correct. *)
then show ?case
unfolding MDP.bw_ind_pol_gen_def d_def bw_ind_code_def
by (subst bw_ind_code_aux_correct[THEN conjunct2])
(auto simp: least_arg_max_def[symmetric] MDP_r_fin_def MDP.bw_ind_aux'_eq[symmetric])
qed
end