theory Policy_Iteration
imports "MDP-Rewards.MDP_reward"
begin
section ‹Policy Iteration›
text ‹
The Policy Iteration algorithm provides another way to find optimal policies under the expected
total reward criterion.
It differs from Value Iteration in that it iteratively improves an initial guess for an optimal
decision rule. Its execution can be subdivided into two alternating steps: policy evaluation and
policy improvement.
Policy evaluation computes the value of the current decision rule.
During the improvement phase, we choose in each state an action that maximizes L applied to the
current value, preferring to keep the old action selection in case of ties.
›
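text ‹
  To make the procedure concrete, the comment below contains a minimal, informal sketch of tabular
  policy iteration in Python/numpy. It is purely illustrative: the arrays P and r, the discount
  factor gamma, and the tie-keeping rule are assumptions of the sketch and not part of the
  formalisation that follows.
›
(*
import numpy as np

def policy_iteration(P, r, gamma, d):
    # P[a, s, t]: probability of moving from state s to state t under action a
    # r[s, a]:    expected reward of choosing action a in state s
    # gamma:      discount factor in [0, 1)
    # d:          initial decision rule, an array assigning an action to every state
    n = r.shape[0]
    while True:
        # policy evaluation: solve (I - gamma * P_d) v = r_d for the value of d
        P_d = P[d, np.arange(n)]
        r_d = r[np.arange(n), d]
        v = np.linalg.solve(np.eye(n) - gamma * P_d, r_d)
        # policy improvement: greedy one-step lookahead, keeping the old action on ties
        q = r + gamma * np.einsum('ast,t->sa', P, v)
        d_new = np.where(np.isclose(q[np.arange(n), d], q.max(axis=1)), d, q.argmax(axis=1))
        if np.array_equal(d_new, d):
            return d, v
        d = d_new
*)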
context MDP_att_ℒ begin
definition "policy_eval d = ν⇩b (mk_stationary_det d)"
end
context MDP_act_disc
begin
definition "policy_improvement d v s = (
if is_arg_max (λa. L⇩a a (apply_bfun v) s) (λa. a ∈ A s) (d s)
then d s
else arb_act (opt_acts v s))"
definition "policy_step d = policy_improvement d (policy_eval d)"
function policy_iteration :: "('s ⇒ 'a) ⇒ ('s ⇒ 'a)" where
"policy_iteration d = (
let d' = policy_step d in
if d = d' ∨ ¬is_dec_det d then d else policy_iteration d')"
by auto
text ‹
The policy iteration algorithm as stated above requires that the supremum in @{const ℒ⇩b} is
always attained, i.e. that in every state some action realizes the optimal one-step value.
›
text ‹
Each policy improvement returns a valid decision rule.
›
lemma is_dec_det_pi: "is_dec_det (policy_improvement d v)"
unfolding policy_improvement_def is_dec_det_def is_arg_max_def
by (auto simp: some_opt_acts_in_A)
lemma policy_improvement_is_dec_det: "d ∈ D⇩D ⟹ policy_improvement d v ∈ D⇩D"
unfolding policy_improvement_def is_dec_det_def
using some_opt_acts_in_A
by auto
lemma policy_improvement_improving:
assumes "d ∈ D⇩D"
shows "ν_improving v (mk_dec_det (policy_improvement d v))"
proof -
have "ℒ⇩b v x = L (mk_dec_det (policy_improvement d v)) v x" for x
using is_opt_act_some
by (fastforce simp: ℒ⇩b_eq_argmax_L⇩a L_eq_L⇩a_det is_opt_act_def policy_improvement_def arg_max_SUP)
thus ?thesis
using policy_improvement_is_dec_det assms by (auto simp: ν_improving_alt)
qed
lemma eval_policy_step_L:
"is_dec_det d ⟹ L (mk_dec_det (policy_step d)) (policy_eval d) = ℒ⇩b (policy_eval d)"
by (auto simp: policy_step_def ν_improving_imp_ℒ⇩b[OF policy_improvement_improving])
text ‹ The sequence of policies generated by policy iteration has monotonically increasing
discounted reward.›
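text ‹
  Informally, identifying a deterministic decision rule d with the induced Markov decision rule
  and writing ν⇩b d for the value of the corresponding stationary policy, the argument is the
  standard one: let d' be the improved rule. Since the improvement step maximizes over actions,
  ν⇩b d = L d (ν⇩b d) ≤ L d' (ν⇩b d). Rearranging gives
  (id_blinfun - l *⇩R 𝒫⇩1 d') (ν⇩b d) ≤ r_dec⇩b d', and applying the monotone operator
  (∑t. l ^ t *⇩R 𝒫⇩1 d' ^^ t), the inverse of (id_blinfun - l *⇩R 𝒫⇩1 d'), yields ν⇩b d ≤ ν⇩b d'.
›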
lemma policy_eval_mon:
assumes "is_dec_det d"
shows "policy_eval d ≤ policy_eval (policy_step d)"
proof -
let ?d' = "mk_dec_det (policy_step d)"
let ?dp = "mk_stationary_det d"
let ?P = "∑t. l ^ t *⇩R 𝒫⇩1 ?d' ^^ t"
have "L (mk_dec_det d) (policy_eval d) ≤ L ?d' (policy_eval d)"
using assms by (auto simp: L_le_ℒ⇩b eval_policy_step_L)
hence "policy_eval d ≤ L ?d' (policy_eval d)"
using L_ν_fix policy_eval_def by auto
hence "ν⇩b ?dp ≤ r_dec⇩b ?d' + l *⇩R 𝒫⇩1 ?d' (ν⇩b ?dp)"
unfolding policy_eval_def L_def by auto
hence "(id_blinfun - l *⇩R 𝒫⇩1 ?d') (ν⇩b ?dp) ≤ r_dec⇩b ?d'"
by (simp add: blinfun.diff_left diff_le_eq scaleR_blinfun.rep_eq)
hence "?P ((id_blinfun - l *⇩R 𝒫⇩1 ?d') (ν⇩b ?dp)) ≤ ?P (r_dec⇩b ?d')"
using lemma_6_1_2_b by auto
hence "ν⇩b ?dp ≤ ?P (r_dec⇩b ?d')"
using inv_norm_le'(2)[OF norm_𝒫⇩1_l_less] by (auto simp: blincomp_scaleR_right)
thus ?thesis
by (auto simp: policy_eval_def ν_stationary)
qed
text ‹
If policy iteration terminates, i.e. @{term "d = policy_step d"}, then the value of the resulting policy is optimal.
›
lemma policy_step_eq_imp_opt:
assumes "is_dec_det d" "d = policy_step d"
shows "ν⇩b (mk_stationary_det d) = ν⇩b_opt"
using L_ν_fix assms eval_policy_step_L[unfolded policy_eval_def]
by (fastforce intro: ℒ_fix_imp_opt)
end
text ‹We prove termination of policy iteration only for the case where both the state and action sets are finite.›
locale MDP_PI_finite = MDP_act_disc arb_act A K r l
for
A and
K :: "'s ::countable × 'a ::countable ⇒ 's pmf" and r l arb_act +
assumes fin_states: "finite (UNIV :: 's set)" and fin_actions: "⋀s. finite (A s)"
begin
text ‹If the state and action sets are both finite,
then so is the set of deterministic decision rules @{const "D⇩D"}›
lemma finite_D⇩D[simp]: "finite D⇩D"
proof -
let ?set = "{d. ∀x :: 's. (x ∈ UNIV ⟶ d x ∈ (⋃s. A s)) ∧ (x ∉ UNIV ⟶ d x = undefined)}"
have "finite (⋃s. A s)"
using fin_actions fin_states by blast
hence "finite ?set"
using fin_states by (fastforce intro: finite_set_of_finite_funs)
moreover have "D⇩D ⊆ ?set"
unfolding is_dec_det_def by auto
ultimately show ?thesis
using finite_subset by auto
qed
lemma finite_rel: "finite {(u, v). is_dec_det u ∧ is_dec_det v ∧ ν⇩b (mk_stationary_det u) >
ν⇩b (mk_stationary_det v)}"
proof-
have aux: "finite {(u, v). is_dec_det u ∧ is_dec_det v}"
by auto
show ?thesis
by (auto intro: finite_subset[OF _ aux])
qed
text ‹
The following auxiliary lemma shows that policy iteration terminates once the value of the
policy can no longer be improved, since in that case the decision rule remains unchanged.
›
lemma eval_eq_imp_policy_eq:
assumes "policy_eval d = policy_eval (policy_step d)" "is_dec_det d"
shows "d = policy_step d"
proof -
have "policy_eval d s = policy_eval (policy_step d) s" for s
using assms by auto
have "policy_eval d = L (mk_dec_det d) (policy_eval (policy_step d))"
unfolding policy_eval_def
using L_ν_fix
by (auto simp: assms(1)[symmetric, unfolded policy_eval_def])
hence "policy_eval d = ℒ⇩b (policy_eval d)"
by (metis L_ν_fix policy_eval_def assms eval_policy_step_L)
hence "L (mk_dec_det d) (policy_eval d) s = ℒ⇩b (policy_eval d) s" for s
using ‹policy_eval d = L (mk_dec_det d) (policy_eval (policy_step d))› assms(1) by auto
hence "is_arg_max (λa. L⇩a a (ν⇩b (mk_stationary (mk_dec_det d))) s) (λa. a ∈ A s) (d s)" for s
unfolding L_eq_L⇩a_det
unfolding policy_eval_def ℒ⇩b.rep_eq ℒ_eq_SUP_det SUP_step_det_eq
using assms(2) is_dec_det_def L⇩a_le
by (auto intro!: SUP_is_arg_max boundedI bounded_imp_bdd_above)
thus ?thesis
unfolding policy_eval_def policy_step_def policy_improvement_def
by auto
qed
text ‹
We are now ready to prove termination in the context of finite state-action spaces.
Intuitively, the algorithm terminates because there are only finitely many decision rules,
and the value of the decision rule strictly increases with each recursive call.
›
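text ‹
  In particular, since the value strictly increases along the run, no decision rule is visited
  twice, so (informally) the number of improvement steps is bounded by the number of deterministic
  decision rules, which for finite state and action spaces is at most ∏s. card (A s).
›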
termination policy_iteration
proof (relation "{(u, v). u ∈ D⇩D ∧ v ∈ D⇩D ∧ ν⇩b (mk_stationary_det u) > ν⇩b (mk_stationary_det v)}")
show "wf {(u, v). u ∈ D⇩D ∧ v ∈ D⇩D ∧ ν⇩b (mk_stationary_det v) < ν⇩b (mk_stationary_det u)}"
using finite_rel by (auto intro!: finite_acyclic_wf acyclicI_order)
next
fix d x
assume h: "x = policy_step d" "¬ (d = x ∨ ¬ is_dec_det d)"
have "is_dec_det d ⟹ ν⇩b (mk_stationary_det d) ≤ ν⇩b (mk_stationary_det (policy_step d))"
using policy_eval_mon by (simp add: policy_eval_def)
hence "is_dec_det d ⟹ d ≠ policy_step d ⟹
ν⇩b (mk_stationary_det d) < ν⇩b (mk_stationary_det (policy_step d))"
using eval_eq_imp_policy_eq policy_eval_def
by (force intro!: order.not_eq_order_implies_strict)
thus "(x, d) ∈ {(u, v). u ∈ D⇩D ∧ v ∈ D⇩D ∧ ν⇩b (mk_stationary_det v) < ν⇩b (mk_stationary_det u)}"
using is_dec_det_pi policy_step_def h by auto
qed
text ‹
The termination proof gives us access to the induction rule and the simplification lemmas
associated with the definition of @{const policy_iteration}.
Thus we can prove that the algorithm finds an optimal policy.
›
lemma is_dec_det_pi': "d ∈ D⇩D ⟹ is_dec_det (policy_iteration d)"
using is_dec_det_pi
by (induction d rule: policy_iteration.induct) (auto simp: Let_def policy_step_def)
lemma pi_pi[simp]: "d ∈ D⇩D ⟹ policy_step (policy_iteration d) = policy_iteration d"
using is_dec_det_pi
by (induction d rule: policy_iteration.induct) (auto simp: policy_step_def Let_def)
lemma policy_iteration_correct:
"d ∈ D⇩D ⟹ ν⇩b (mk_stationary_det (policy_iteration d)) = ν⇩b_opt"
by (induction d rule: policy_iteration.induct)
(fastforce intro!: policy_step_eq_imp_opt is_dec_det_pi' simp del: policy_iteration.simps)
end
context MDP_finite_type begin
text ‹
The following proofs concern code generation, i.e. how to represent @{const 𝒫⇩1} as a matrix.
›
sublocale MDP_att_ℒ
by (auto simp: A_ne finite_is_arg_max MDP_att_ℒ_def MDP_att_ℒ_axioms_def max_L_ex_def
has_arg_max_def MDP_reward_disc_axioms)
definition "fun_to_matrix f = matrix (λv. (χ j. f (vec_nth v) j))"
definition "Ek_mat d = fun_to_matrix (λv. ((𝒫⇩1 d) (Bfun v)))"
definition "nu_inv_mat d = fun_to_matrix ((λv. ((id_blinfun - l *⇩R 𝒫⇩1 d) (Bfun v))))"
definition "nu_mat d = fun_to_matrix (λv. ((∑i. (l *⇩R 𝒫⇩1 d) ^^ i) (Bfun v)))"
lemma apply_nu_inv_mat:
"(id_blinfun - l *⇩R 𝒫⇩1 d) v = Bfun (λi. ((nu_inv_mat d) *v (vec_lambda v)) $ i)"
proof -
have eq_onpI: "P x ⟹ eq_onp P x x" for P x
by(simp add: eq_onp_def)
have "Real_Vector_Spaces.linear (λv. vec_lambda (((id_blinfun - l *⇩R 𝒫⇩1 d) (bfun.Bfun (($) v)))))"
by (auto simp del: real_scaleR_def intro: linearI
simp: scaleR_vec_def eq_onpI plus_vec_def vec_lambda_inverse plus_bfun.abs_eq[symmetric]
scaleR_bfun.abs_eq[symmetric] blinfun.scaleR_right blinfun.add_right)
thus ?thesis
unfolding Ek_mat_def fun_to_matrix_def nu_inv_mat_def
by (auto simp: apply_bfun_inverse vec_lambda_inverse)
qed
lemma bounded_linear_vec_lambda: "bounded_linear (λx. vec_lambda (x :: 's ⇒⇩b real))"
proof (intro bounded_linear_intro)
fix x :: "'s ⇒⇩b real"
have "sqrt (∑ i ∈ UNIV . (apply_bfun x i)⇧2) ≤ (∑ i ∈ UNIV . ¦(apply_bfun x i)¦)"
using L2_set_le_sum_abs
unfolding L2_set_def
by auto
also have "(∑ i ∈ UNIV . ¦(apply_bfun x i)¦) ≤ (card (UNIV :: 's set) * (⨆xa. ¦apply_bfun x xa¦))"
by (auto intro!: cSup_upper sum_bounded_above)
finally show "norm (vec_lambda (apply_bfun x)) ≤ norm x * CARD('s)"
unfolding norm_vec_def norm_bfun_def dist_bfun_def L2_set_def
by (auto simp add: mult.commute)
qed (auto simp: plus_vec_def scaleR_vec_def)
lemma bounded_linear_vec_lambda_blinfun:
fixes f :: "('s ⇒⇩b real) ⇒⇩L ('s ⇒⇩b real)"
shows "bounded_linear (λv. vec_lambda (apply_bfun (blinfun_apply f (bfun.Bfun (($) v)))))"
using blinfun.bounded_linear_right
by (fastforce intro: bounded_linear_compose[OF bounded_linear_vec_lambda]
bounded_linear_bfun_nth bounded_linear_compose[of f])
lemma invertible_nu_inv_max: "invertible (nu_inv_mat d)"
unfolding nu_inv_mat_def fun_to_matrix_def
by (auto simp: matrix_invertible inv_norm_le' vec_lambda_inverse apply_bfun_inverse
bounded_linear.linear[OF bounded_linear_vec_lambda_blinfun]
intro!: exI[of _ "λv. (χ j. (λv. (∑i. (l *⇩R 𝒫⇩1 d) ^^ i) (Bfun v)) (vec_nth v) j)"])
end
locale MDP_ord = MDP_finite_type A K r l
for A and
K :: "'s :: {finite, wellorder} × 'a :: {finite, wellorder} ⇒ 's pmf"
and r l
begin
lemma ℒ_fin_eq_det: "ℒ v s = (⨆a ∈ A s. L⇩a a v s)"
by (simp add: SUP_step_det_eq ℒ_eq_SUP_det)
lemma ℒ⇩b_fin_eq_det: "ℒ⇩b v s = (⨆a ∈ A s. L⇩a a v s)"
by (simp add: SUP_step_det_eq ℒ⇩b.rep_eq ℒ_eq_SUP_det)
sublocale MDP_PI_finite A K r l "λX. Least (λx. x ∈ X)"
by unfold_locales (auto intro: LeastI)
end
end