(* Theory PI_Code *)
theory PI_Code
imports
"../Policy_Iteration_Fin"
"HOL-Library.Code_Target_Numeral"
"Jordan_Normal_Form.Matrix_Impl"
Code_Setup
begin
context MDP_Code begin
(* Executable closed-form policy evaluation for the decision rule stored in d.
   Computes inverse_mat (1 - l * P_d) applied to r_d, where:
   - row s of P_d is the distribution MDP_K (s, d_lookup' d s);
   - r_d is the reward vector MDP_r (i, d_lookup' d i).
   l is presumably the discount factor of the MDP locale (TODO confirm).
   Correctness w.r.t. MDP.ν⇩b is stated in policy_eval_correct. *)
definition "policy_eval_code d =
inverse_mat (1⇩m states -
l ⋅⇩m (Matrix.mat states states (λ(s, s'). pmf (MDP_K (s, d_lookup' d s)) s'))) *⇩v
(vec states (λi. MDP_r (i, d_lookup' d i)))"
(* On a map satisfying D_Map.invar, the total lookup d_lookup' agrees with
   D_Map.map_to_fun on every key of the map. *)
lemma d_lookup'_eq_map_to_fun: "D_Map.invar d ⟹ s ∈ D_Map.keys d ⟹ d_lookup' d s = D_Map.map_to_fun d s"
using D_Map.lookup_None_set_inorder
by (auto simp: D_Map.map_to_fun_def d_lookup'_def split: option.splits)
(* Entry s of policy_eval_code d is the value MDP.ν⇩b of the stationary
   deterministic policy induced by d.
   Requirements: d satisfies the map invariant, its keys are exactly
   0..<states, and s < states. *)
lemma policy_eval_correct:
assumes "D_Map.keys d = {0..<states}" "D_Map.invar d" "s < states"
shows "policy_eval_code d $v s = MDP.ν⇩b (MDP.mk_stationary_det (D_Map.map_to_fun d)) s"
unfolding policy_eval_code_def MDP.vec_ν⇩b''[OF assms(3)]
using assms d_lookup'_eq_map_to_fun
by (auto cong: vec_cong MDP.mat_cong)
(* Precomputed transition table, one entry per state s < states.
   - The entry is a map built with M.from_list.
   - It maps each action a listed in a_inorder (s_lookup mdp s) to a dense
     vector of length states.
   - Entry s' of that vector is the sum of the probabilities paired with s'
     in the action's successor list ps.
   transition_vecs_correct shows these entries equal pmf (MDP_K (s, a)) s'. *)
definition "transition_vecs =
Matrix.vec states (λs. M.from_list (map (λ(a, _, ps). (a,
Matrix.vec states (λs'. ∑x ← ps. if fst x = s' then snd x else 0))) (a_inorder (s_lookup mdp s))))"
(* Looking up an enabled action a of state s in transition_vecs yields the
   transition probabilities of MDP_K (s, a). *)
lemma transition_vecs_correct:
assumes "s < states" "a ∈ MDP_A s" "s' < states"
shows "M.lookup' (transition_vecs $v s) a $v s' = pmf (MDP_K (s,a)) s'"
proof -
(* The summed vector entry equals the pmf of pmf_of_list on the same
   successor list. *)
have *: "Matrix.vec states (λs'. ∑x←snd (a_lookup' (s_lookup mdp s) a). if fst x = s' then snd x else 0) $v s' = pmf (pmf_of_list (snd (a_lookup' (s_lookup mdp s) a))) s'"
by (auto simp: pmf_pmf_of_list assms pmf_of_list_wf_mdp sum_list_map_filter')
(* Looking a up in the map built by M.from_list returns the vector built for
   a. This relies on the action keys being distinct (A_Map.distinct_inorder). *)
have **: "
M.lookup' (M.from_list (map (λ(a, _, ps). (a, Matrix.vec states (λs'. ∑x←ps. if fst x = s' then snd x else 0))) (a_inorder (s_lookup mdp s)))) a $v s' =
pmf (pmf_of_list (snd (a_lookup' (s_lookup mdp s) a))) s'"
unfolding *[symmetric]
using a_map_entries_lookup[OF assms(1,2)] A_Map.distinct_inorder invar_s_lookup[OF assms(1)]
by (subst M.lookup'_from_list_distinct) (force simp: comp_def case_prod_beta A_Map.entries_def[symmetric] intro!: imageI)+
show ?thesis
unfolding transition_vecs_def MDP_K_def
using assms a_lookup_None_notin_A sa_lookup_eq(2) snd_sa_lookup'_eq
by (auto split: option.splits simp: **)
qed
(* Restates policy_eval_code using the executable mat_inverse in place of
   inverse_mat. Taking "the" of the option result is justified because the
   matrix is invertible (MDP.invertible_ν⇩b_mat'). *)
lemma policy_eval_code: "policy_eval_code d =
the (mat_inverse ((1⇩m states) -
l ⋅⇩m (Matrix.mat states states (λ(s, s'). pmf (MDP_K (s, d_lookup' d s)) s')))) *⇩v
(vec states (λi. MDP_r (i, d_lookup' d i)))"
unfolding policy_eval_code_def
by (subst mat_inverse_eq_inverse_mat) (auto simp: MDP.invertible_ν⇩b_mat')
(* Identity matrix of dimension states. *)
definition "one_st = 1⇩m states"
(* Transition matrix of the decision rule d: entry (s, y) is the probability
   of moving from s to y under the action d_lookup' d s. *)
definition "k_mat d = Matrix.mat states states (λ(s, y). pmf (MDP_K (s, d_lookup' d s)) y)"
(* Variant of k_mat that does not evaluate MDP_K. Row i is read from the
   precomputed table m, indexed first by state and then by action.
   transition_vecs is used as m in k_mat_eq'. *)
definition "k_mat' d m = (
Matrix.mat_of_row_fun states states (λi. M.lookup' (m $v i) (d_lookup' d i)))"
(* An invertible matrix m has a two-sided inverse x with
   x in carrier_mat (dim_row m) (dim_row m). *)
lemma invertible_imp_inv_ex: "invertible_mat m ⟹ ∃x∈carrier_mat (dim_row m) (dim_row m). x * m = 1⇩m (dim_row m) ∧ m * x = 1⇩m (dim_row m)"
by (metis carrier_matD(1) inverse_mat_mult inverse_mats_def invertible_inverse_mats)
(* Expresses policy_eval_code via Gauss-Jordan elimination. Eliminating m
   against the identity reduces the first component to the identity (m is
   invertible). Hence the second component, snd (gauss_jordan m 1), is the
   inverse of m. *)
lemma policy_eval_code':
fixes d
defines "m ≡ (one_st - l ⋅⇩m Matrix.mat states states (λ(s, y). pmf (MDP_K (s, d_lookup' d s)) y))"
shows "policy_eval_code d = snd (gauss_jordan m (1⇩m states)) *⇩v (vec states (λi. MDP_r (i, d_lookup' d i)))"
proof -
have m: "m ∈ carrier_mat states states"
using assms by fastforce
(* Because m is invertible, Gauss-Jordan reduces its left part to the identity. *)
hence "fst (gauss_jordan m (1⇩m states)) = 1⇩m states"
using MDP.invertible_ν⇩b_mat'[of "d_lookup' d", unfolded m_def[symmetric] one_st_def[symmetric]]
using m invertible_imp_inv_ex[of m]
by (auto simp: ring_mat_simps Units_def intro!: gauss_jordan_inverse_other_direction[of _ _ states _ states])
thus ?thesis
unfolding policy_eval_code mat_inverse_def
by (auto split: if_splits simp: one_st_def m_def case_prod_beta)
qed
(* Code equation for policy_eval_code: Gauss-Jordan elimination on
   one_st - l * k_mat d, with the result applied to the reward vector. *)
lemma policy_eval_code''[code]:
fixes d
defines "m ≡ (one_st - l ⋅⇩m ((k_mat d)))"
shows "policy_eval_code d = snd (gauss_jordan m one_st) *⇩v (vec states (λi. MDP_r (i, d_lookup' d i)))"
unfolding m_def policy_eval_code' k_mat_def one_st_def by (simp add: mat_code)
(* Same computation as policy_eval_code, except that the transition matrix is
   assembled from a precomputed table m via k_mat'. *)
definition "policy_eval_code' d m = snd (gauss_jordan (one_st - l ⋅⇩m ((k_mat' d m))) one_st) *⇩v (vec states (λi. MDP_r (i, d_lookup' d i)))"
(* The evaluated value vector has exactly one entry per state. *)
lemma dim_policy_eval_code: "dim_vec (policy_eval_code d) = states"
by (simp add: policy_eval_code_def MDP.invertible_ν⇩b_mat' inverse_mat_dims(2))
(* Lifts policy_eval_correct to bounded functions. Converting the value vector
   with vec_to_bfun gives MDP.ν⇩b of the induced stationary deterministic
   policy. The proof also uses MDP.ν⇩b_zero_notin for arguments outside
   0..<states. *)
lemma policy_eval_correct':
assumes "D_Map.keys d = {0..<states}" "D_Map.invar d"
shows "vec_to_bfun (policy_eval_code d) = MDP.ν⇩b (MDP.mk_stationary_det (D_Map.map_to_fun d))"
using policy_eval_correct assms dim_policy_eval_code MDP.ν⇩b_zero_notin
by (auto simp: vec_to_bfun.rep_eq)
(* Policy improvement for a single state s.
   - find_policy_state_code_aux' v s returns a candidate action d' paired
     with a value v'. In pi_find_policy_state_code_aux'_correct, v' is
     related to the maximum of L⇩a over the enabled actions.
   - The current action d is kept when its value L⇩a_code under v equals v'.
   - Otherwise d' is chosen. *)
definition "pi_find_policy_state_code_aux' d v s = (
let (d', v') = find_policy_state_code_aux' v s in
if L⇩a_code (a_lookup' (s_lookup mdp s) d) v = v' then d else d')"
(* Policy improvement for all states 0..<states. The per-state results are
   collected into a D_Map via D_Map.from_list'. *)
definition "pi_find_policy_code d v =
D_Map.from_list' (λs. pi_find_policy_state_code_aux' (d_lookup' d s) v s) [0..<states]"
(* The improved policy map satisfies the D_Map invariant.
   NOTE: despite the vi_ prefix, this lemma is about pi_find_policy_code. *)
lemma vi_find_policy_code_invar: "D_Map.invar (pi_find_policy_code d v)"
unfolding pi_find_policy_code_def by simp
(* The improved policy map has exactly the keys 0..<states.
   NOTE: despite the vi_ prefix, this lemma is about pi_find_policy_code. *)
lemma keys_vi_find_policy_code_aux_upt: "D_Map.keys (pi_find_policy_code d v) = {0..<states}"
unfolding pi_find_policy_code_def by simp
(* For a value map v of length states satisfying v_invar, the action chosen
   by find_policy_state_code_aux' at s < states is enabled in s. *)
lemma find_policy_state_code_aux'_in_acts:
assumes "s < states" "v_len v = states" "v_invar v"
shows "fst (find_policy_state_code_aux' v s) ∈ MDP_A s"
using MDP.A_ne MDP.A_fin assms least_arg_max_prop[of "λx. x ∈ MDP_A s"]
by (fastforce simp: find_policy_state_code_aux'_eq')
(* Per-state correctness of the executable improvement step: the code agrees
   with MDP.policy_improvement at s. It assumes d is a well-formed decision
   rule over the state space whose action at s is enabled.
   The proof splits on whether the current action already maximises L⇩a at s. *)
lemma pi_find_policy_state_code_aux'_correct:
assumes "s < states" "D_Map.invar d" "v_len v = states" "v_invar v" "D_Map.keys d = MDP.state_space" "d_lookup' d s ∈ MDP_A s"
shows "pi_find_policy_state_code_aux' (d_lookup' d s) v s = MDP.policy_improvement (D_Map.map_to_fun d) (V_Map.map_to_bfun v) s"
proof (cases "is_arg_max (λa. MDP.L⇩a a (apply_bfun (V_Map.map_to_bfun v)) s) (λa. a ∈ MDP_A s) (D_Map.map_to_fun d s)")
case True
(* The current action is an arg-max, so its code value equals the value
   returned by find_policy_state_code_aux' and the current action is kept. *)
hence aux: "L⇩a_code (a_lookup' (s_lookup mdp s) (d_lookup' d s)) v = snd (find_policy_state_code_aux' v s)"
using MDP.A_fin
by (subst L_GS_code_correct') (fastforce intro!: Max_eqI[symmetric] simp: assms find_policy_state_code_aux'_eq' d_lookup'_eq_map_to_fun split: option.splits)+
then show ?thesis
proof -
have "pi_find_policy_state_code_aux' (d_lookup' d s) v s = d_lookup' d s"
unfolding pi_find_policy_state_code_aux'_def
by (simp add: aux case_prod_unfold)
thus ?thesis
using True
by (fastforce simp: assms MDP.policy_improvement_def d_lookup'_eq_map_to_fun split: option.splits)
qed
next
case False
(* The current action is strictly below the maximum, so the code switches to
   the action computed by find_policy_state_code_aux'. *)
hence "L⇩a_code (a_lookup' (s_lookup mdp s) (d_lookup' d s)) v < (MAX a ∈ MDP_A s. MDP.L⇩a a (apply_bfun (V_Map.map_to_bfun v)) s)"
using False assms by (auto simp: L_GS_code_correct' is_arg_max_linorder not_le map_to_fun_lookup Max_gr_iff)
thus ?thesis
unfolding pi_find_policy_state_code_aux'_def MDP.policy_improvement_def
using False
by (auto simp: assms find_policy_state_code_aux'_eq' least_arg_max_def MDP.is_opt_act_def)
qed
(* Whole-map correctness of pi_find_policy_code: viewed as a function, the
   improved map equals MDP.policy_improvement at every s.
   - States below states follow from the per-state lemma.
   - States outside 0..<states are handled via map_to_fun_notin. *)
lemma pi_find_policy_code_correct:
assumes "v_len v = states" "D_Map.keys d = MDP.state_space" "v_invar v" "D_Map.invar d" "⋀s. s < states ⟹ d_lookup' d s ∈ MDP_A s"
shows "D_Map.map_to_fun ((pi_find_policy_code d v)) s = MDP.policy_improvement (D_Map.map_to_fun d) (V_Map.map_to_bfun v) s"
using assms
proof (cases "s < states")
case True
then show ?thesis
unfolding pi_find_policy_code_def
by (simp add: assms pi_find_policy_state_code_aux'_correct D_Map.map_to_fun_def)
next
case False
then show ?thesis
using keys_vi_find_policy_code_aux_upt assms vi_find_policy_code_invar is_arg_max_const MDP.A_outside
by (auto intro!: Least_equality simp: map_to_fun_notin MDP.policy_improvement_def MDP.is_opt_act_def)
qed
(* Two policy maps are considered equal when d_lookup returns the same result
   for every state below states. *)
definition "eq_policy d1 d2 = (∀x<states. d_lookup d1 x = d_lookup d2 x)"
(* One policy-iteration step:
   1. evaluate d exactly;
   2. tabulate the values into a V_Map array of length states;
   3. improve d with respect to those values. *)
definition "policy_step_code d = (
let v = policy_eval_code d in
pi_find_policy_code d (V_Map.arr_tabulate (($v) v) states))"
(* Same as policy_step_code, except that evaluation uses the precomputed
   transition table m (see policy_eval_code'). *)
definition "policy_step_code' d m = (
let v = policy_eval_code' d m in
pi_find_policy_code d (V_Map.arr_tabulate (($v) v) states))"
(* Tail-recursive policy iteration over a precomputed table m. Repeats
   policy_step_code' until the policy no longer changes (eq_policy), then
   returns the last policy. partial_function requires no termination proof
   at definition time. *)
partial_function (tailrec) PI_code_aux' where
"PI_code_aux' d m = (
let d' = policy_step_code' d m in
if eq_policy d d'
then d
else PI_code_aux' d' m)"
(* Tail-recursive policy iteration. Repeats policy_step_code until the policy
   no longer changes (eq_policy), then returns the last policy. *)
partial_function (tailrec) PI_code_aux where
"PI_code_aux d = (
let d' = policy_step_code d in
if eq_policy d d'
then d
else PI_code_aux d')"
(* Reading the tabulated evaluation vector at a state s < states gives
   MDP.policy_eval of the policy represented by d. *)
lemma fold_policy_eval_update_eq:
assumes "s < states" "D_Map.keys d = MDP.state_space" "D_Map.invar d"
shows "v_lookup (V_Map.arr_tabulate (λx. policy_eval_code d $v x) states) s = (MDP.policy_eval (D_Map.map_to_fun d) s)"
using assms
by (auto simp: v_lookup_fold policy_eval_correct MDP.policy_eval_def)
(* Bounded-function version: the whole tabulated value map, converted with
   V_Map.map_to_bfun, equals MDP.policy_eval. *)
lemma fold_policy_eval_update_eq':
assumes "D_Map.keys d = MDP.state_space" "D_Map.invar d"
shows "V_Map.map_to_bfun (V_Map.arr_tabulate (λx. (policy_eval_code d $v x)) states) =
(MDP.policy_eval (D_Map.map_to_fun d))"
proof (rule bfun_eqI)
fix s
show "(V_Map.map_to_bfun (V_Map.arr_tabulate (($v) (policy_eval_code d)) states)) s =
(MDP.policy_eval (D_Map.map_to_fun d)) s"
proof (cases "s < states")
case True
then show ?thesis
by (auto simp: V_Map.map_to_bfun.rep_eq assms policy_eval_correct MDP.policy_eval_def)
next
case False
(* The case outside 0..<states is discharged with MDP.ν⇩b_zero_notin. *)
then show ?thesis
by (auto simp: MDP.policy_eval_def V_Map.map_to_bfun.rep_eq MDP.ν⇩b_zero_notin)
qed
qed
(* Register the unfolding equations of both tail-recursive loops as code
   equations. *)
lemmas PI_code_aux.simps[code]
lemmas PI_code_aux'.simps[code]
(* Remove the unfolding equations of MDP.policy_iteration from the simp set.
   Presumably this stops simp from unfolding the recursion repeatedly;
   proofs below re-add them explicitly where needed. *)
lemmas MDP.policy_iteration.simps[simp del]
(* Executable validity check for a decision-rule map d. It requires that:
   - the keys of d are exactly 0..<states;
   - d satisfies the D_Map invariant;
   - for every state, the stored action has an entry in that state's action
     map. *)
definition "is_dec_det_code d ⟷
D_Map.keys d = {0..<states} ∧ D_Map.invar d ∧ (∀s ∈ set [0..<states]. a_lookup (s_lookup mdp s) (d_lookup' d s) ≠ None)"
(* Code equation for is_dec_det_code. The key-set comparison becomes a list
   comparison: the first components of the in-order traversal must equal
   [0..<states]. Under the invariant the keys are strictly sorted, so the two
   formulations are equivalent. *)
lemma [code]: "is_dec_det_code d ⟷
(map fst (d_inorder d)) = [0..<states] ∧ D_Map.invar d ∧ (∀s ∈ set [0..<states]. a_lookup (s_lookup mdp s) (d_lookup' d s) ≠ None)"
proof -
have "D_Map.invar d ⟹ fst ` set (d_inorder d) = {0..<n} ⟹ (map fst (d_inorder d)) = [0..<n]" for n
by (simp add: D_Map.invar_def strict_sorted_equal)
moreover have "D_Map.invar d ⟹ map fst (d_inorder d) = [0..<n] ⟹ fst ` set (d_inorder d) = {0..<n}" for n
using image_set[of "fst" "d_inorder d"]
by auto
ultimately show ?thesis
unfolding D_Map.keys_def is_dec_det_code_def
by blast
qed
(* Top-level policy iteration.
   - Inputs rejected by is_dec_det_code are returned unchanged.
   - Otherwise PI_code_aux runs from the given policy.
   The parameter d0 here is a bound argument and is unrelated to the constant
   d0 defined at the end of this context. *)
definition "PI_code d0 = (if ¬ is_dec_det_code d0 then d0 else PI_code_aux d0)"
(* For valid inputs, the transition matrix can be read from the precomputed
   transition_vecs. *)
lemma k_mat_eq': "is_dec_det_code d ⟹ k_mat d = k_mat' d transition_vecs"
unfolding k_mat_def k_mat'_def Matrix.mat_eq_iff
by (auto simp: is_dec_det_code_def intro!: transition_vecs_correct[symmetric] intro: a_lookup_some_in_A)
(* Consequently, for valid inputs, evaluation with transition_vecs gives the
   same result as policy_eval_code. *)
lemma policy_eval_code_eq': "is_dec_det_code d ⟹ policy_eval_code d = policy_eval_code' d transition_vecs"
unfolding policy_eval_code'' policy_eval_code'_def
using k_mat_eq'
by force
(* Likewise, a full policy step with transition_vecs gives the same result as
   policy_step_code. *)
lemma policy_step_code_eq': "is_dec_det_code d ⟹ policy_step_code d = policy_step_code' d transition_vecs"
unfolding policy_step_code_def policy_step_code'_def
using policy_eval_code_eq' by presburger
(* One executable policy step agrees with MDP.policy_step. It assumes d is a
   well-formed decision rule over the state space whose actions are all
   enabled. *)
lemma policy_step_code_correct:
assumes "D_Map.keys d = MDP.state_space" "D_Map.invar d" "(⋀s. s < states ⟹ d_lookup' d s ∈ MDP_A s)"
shows "D_Map.map_to_fun (policy_step_code d) = (MDP.policy_step (D_Map.map_to_fun d))"
unfolding policy_step_code_def MDP.policy_step_def
by (auto simp: fold_policy_eval_update_eq' pi_find_policy_code_correct assms)
(* Main correctness lemma for the iteration, by induction along
   MDP.policy_iteration.induct. It proves two things at once:
   1. PI_code_aux computes MDP.policy_iteration;
   2. for inputs accepted by is_dec_det_code, PI_code_aux and PI_code_aux'
      with transition_vecs return the same map. *)
lemma PI_code_aux_correct_aux:
assumes "D_Map.invar d" "D_Map.keys d = {0..<states}" "(⋀s. s < states ⟹ d_lookup' d s ∈ MDP_A s)"
shows "D_Map.map_to_fun (PI_code_aux d) = MDP.policy_iteration (D_Map.map_to_fun d)
∧ (is_dec_det_code d ⟶ PI_code_aux d = PI_code_aux' d transition_vecs)"
using assms
proof (induction "(D_Map.map_to_fun d)" arbitrary: d rule: MDP.policy_iteration.induct)
case 1
show ?case
proof (cases "eq_policy d (policy_step_code d)")
case True
(* The policy is stable. Show that d agrees pointwise with its policy step;
   both loops then return d immediately. *)
hence *: "D_Map.map_to_fun d s = (MDP.policy_step (D_Map.map_to_fun d)) s" for s
proof (cases "s < states")
case True
then show ?thesis
using True vi_find_policy_code_invar 1 ‹eq_policy d (policy_step_code d)›
by (auto simp: D_Map.map_to_fun_def eq_policy_def policy_step_code_correct[symmetric] policy_step_code_def)
next
case False
hence "MDP.policy_step (D_Map.map_to_fun d) s = 0"
by (auto simp: 1 MDP.policy_improvement_def is_arg_max_linorder MDP.policy_step_def MDP_A_def map_to_fun_notin)
then show ?thesis
using 1 D_Map.lookup_some_set_key False
by (auto simp: D_Map.map_to_fun_def split: option.splits)
qed
have "D_Map.map_to_fun (PI_code_aux d) = D_Map.map_to_fun d"
by (simp add: PI_code_aux.simps policy_step_code_correct True)
thus ?thesis
using * MDP.policy_iteration.simps[of "D_Map.map_to_fun d"] True
by (fastforce simp: policy_step_code_eq' PI_code_aux'.simps[of d] PI_code_aux.simps[of d])
next
case False
(* The policy changes at some state s < states. Unfold PI_code_aux once and
   apply the induction hypothesis to policy_step_code d, which again has the
   invariant, the same keys, and only enabled actions. *)
then obtain s where s: "s < states" "d_lookup d s ≠ d_lookup (policy_step_code d) s"
unfolding eq_policy_def policy_step_code_def
using 1(2,3) D_Map.lookup_notin_keys keys_vi_find_policy_code_aux_upt vi_find_policy_code_invar
by (auto simp: d_lookup'_def)
have invar_step: "D_Map.invar (policy_step_code d)"
by (simp add: policy_step_code_def vi_find_policy_code_invar)
have keys_step: "D_Map.keys (policy_step_code d) = D_Map.keys d"
by (simp add: 1 keys_vi_find_policy_code_aux_upt policy_step_code_def)
have *: "D_Map.map_to_fun d s ≠ (MDP.policy_step (D_Map.map_to_fun d)) s"
using D_Map.lookup_in_keys[OF invar_step] D_Map.lookup_notin_keys[OF invar_step] s(2) keys_step invar_step 1(2-4)
by (fastforce dest: D_Map.lookup_None_set_inorder[OF ‹D_Map.invar d›] D_Map.lookup_some_set_key[OF ‹D_Map.invar d›]
split: option.splits simp: policy_step_code_correct[symmetric] D_Map.map_to_fun_def)
have **: "MDP.is_dec_det (D_Map.map_to_fun d)"
using 1 by (auto simp: MDP.is_dec_det_def MDP_A_def map_to_fun_lookup map_to_fun_notin)
have lookup': "s < states ⟹ d_lookup' (policy_step_code d) s ∈ MDP_A s" for s
using 1(2-4) keys_step invar_step MDP.is_dec_det_pi
by (force simp: MDP.is_dec_det_def policy_step_code_correct d_lookup'_eq_map_to_fun MDP.policy_step_def)
have "D_Map.map_to_fun (PI_code_aux d) = D_Map.map_to_fun (PI_code_aux (policy_step_code d))"
by (simp add: PI_code_aux.simps policy_step_code_correct False)
also have "… = MDP.policy_iteration (D_Map.map_to_fun (policy_step_code d))"
using 1(2-4) * ** policy_step_code_correct lookup' invar_step keys_step
by (intro conjunct1[OF 1(1)]) (auto )
also have "… = MDP.policy_iteration (MDP.policy_step (D_Map.map_to_fun d))"
using 1 by (auto simp: policy_step_code_correct)
finally have aux: "D_Map.map_to_fun (PI_code_aux d) = MDP.policy_iteration (D_Map.map_to_fun d)"
unfolding PI_code_aux.simps[of d] PI_code_aux'.simps[of d]
using ** by (auto simp: MDP.policy_iteration.simps)
thus ?thesis
proof -
(* Second conjunct. Both d and its step pass is_dec_det_code, so the
   induction hypothesis and policy_step_code_eq' carry the equality
   PI_code_aux = PI_code_aux' back to d. *)
have d: "is_dec_det_code d"
unfolding is_dec_det_code_def
using 1 a_lookup_None_notin_A
by (metis atLeastLessThan_iff set_upt)
hence "is_dec_det_code (policy_step_code d)"
by (metis a_lookup_None_notin_A atLeastLessThan_iff invar_step is_dec_det_code_def keys_step lookup' set_upt)
hence "PI_code_aux (policy_step_code d) = PI_code_aux' (policy_step_code d) transition_vecs"
using * ** 1 invar_step keys_step lookup' policy_step_code_correct by metis
hence "PI_code_aux d = PI_code_aux' d transition_vecs"
unfolding PI_code_aux.simps[of d] PI_code_aux'.simps[of d]
using policy_step_code_eq'[OF d]
by auto
thus ?thesis
using ** aux
by fastforce
qed
qed
qed
(* Correctness of PI_code. Every well-formed input passes is_dec_det_code, so
   PI_code runs PI_code_aux, which computes MDP.policy_iteration. *)
lemma PI_code_correct:
assumes "D_Map.invar d" "D_Map.keys d = MDP.state_space" "(⋀s. s < states ⟹ d_lookup' d s ∈ MDP_A s)"
shows "D_Map.map_to_fun (PI_code d) = MDP.policy_iteration (D_Map.map_to_fun d)"
proof -
have "is_dec_det_code d"
unfolding is_dec_det_code_def
using a_lookup_None_notin_A assms
by (fastforce simp: not_Some_eq[symmetric])
thus ?thesis
using assms
by (auto simp: PI_code_def conjunct1[OF PI_code_aux_correct_aux])
qed
(* Code equation for PI_code: run the variant PI_code_aux' with the
   precomputed transition_vecs. This is justified by the second conjunct of
   PI_code_aux_correct_aux. *)
lemma [code]: "PI_code d0 = (if ¬ is_dec_det_code d0 then d0 else PI_code_aux' d0 transition_vecs)"
using conjunct2[OF PI_code_aux_correct_aux[of d0]]
unfolding PI_code_def is_dec_det_code_def
using a_lookup_some_in_A
by force
(* Initial decision rule: for each state s < states, the first action in the
   in-order traversal of the action map s_lookup mdp s.
   NOTE(review): hd is unspecified on the empty list, so this presumably
   relies on every state having at least one action (cf. MDP.A_ne). Confirm. *)
definition "d0 = D_Map.from_list' (λs. fst (hd (a_inorder (s_lookup mdp s)))) [0..<states]"
end
(* A Tree2 tree whose in-order traversal is empty is the empty tree. *)
lemma inorder_empty: "Tree2.inorder am = [] ⟹ am = ⟨⟩"
using Tree2.inorder.elims by blast