Theory Modified_Policy_Iteration
theory Modified_Policy_Iteration
imports
Policy_Iteration
Value_Iteration
begin
section ‹Modified Policy Iteration›
(* Locale for Modified Policy Iteration. It combines MDP_att_ℒ, which presumably provides the
   Bellman operator ℒ / ℒ⇩b and the optimal value ν⇩b_opt used below, with MDP_act_disc, whose
   parameter arb_act is used below to pick an action from an action set (arb_act (A x)). *)
locale MDP_MPI = MDP_att_ℒ A K r l + MDP_act_disc arb_act A K r l
for A and K :: "'s :: countable × 'a :: countable ⇒ 's pmf" and r l arb_act
begin
subsection ‹The Advantage Function @{term B}›
(* B v s: supremum over decision rules d ∈ D⇩R of the reward r_dec d s plus the discounted
   successor value under d, minus the current value v s. *)
definition "B v s = (⨆d ∈ D⇩R. (r_dec d s + (l *⇩R 𝒫⇩1 d - id_blinfun) v s))"
text ‹The function @{const B} denotes the advantage of choosing an optimal action over the
current value estimate; below we show that B v s = ℒ v s - v s.›
(* Adding a real constant commutes with the supremum of a nonempty, bounded-above family.
   Proved by antisymmetry. *)
lemma cSUP_plus:
assumes "X ≠ {}" "bdd_above (f`X)"
shows "(⨆x ∈ X. f x + c) = (⨆x ∈ X. f x) + (c::real)"
proof (rule antisym)
(* Upper bound: every f x + c lies below the supremum of f ` X plus c. *)
show "(⨆x∈X. f x + c) ≤ ⨆ (f ` X) + c"
using assms by (fastforce intro: cSUP_least cSUP_upper)
(* Lower bound: move c to the right-hand side (le_diff_eq) and conclude with cSUP_least. *)
show "⨆ (f ` X) + c ≤ (⨆x∈X. f x + c)"
unfolding le_diff_eq[symmetric]
using assms
by (intro cSUP_least) (auto simp add: algebra_simps bdd_above_def intro!: cSUP_upper2 intro: add_left_mono)+
qed
(* Subtracting a real constant commutes with the supremum: the instance c := - c of cSUP_plus. *)
lemma cSUP_minus:
assumes "X ≠ {}" "bdd_above (f`X)"
shows "(⨆x ∈ X. f x - c) = (⨆x ∈ X. f x) - (c::real)"
using cSUP_plus[OF assms, of "- c"] by auto
(* B is the pointwise Bellman improvement: ℒ v s - v s. *)
lemma B_eq_ℒ: "B v s = ℒ v s - v s"
proof -
(* Rewrite B as a supremum of L d v s - v s over D⇩R. *)
have *: "B v s = (⨆d ∈ D⇩R. L d v s - v s)"
unfolding B_def L_def by (auto simp add: blinfun.bilinear_simps add_diff_eq)
(* The family is bounded above, as cSUP_minus requires. *)
have "bdd_above ((λd. L d v s - v s) ` D⇩R)"
by (auto intro!: bounded_const bounded_minus_comp bounded_imp_bdd_above)
(* Pull the constant v s out of the supremum; ex_dec presumably gives D⇩R ≠ {}. *)
thus ?thesis
unfolding * ℒ_def using ex_dec by (fastforce intro!: cSUP_minus)
qed
text ‹@{const B} maps bounded functions to bounded functions, so it lifts to an operator on
bounded functions.›
(* B v is bounded for bounded v: by B_eq_ℒ it is ℒ v - v, and ℒ v is bounded (ℒ_bfun). *)
lift_definition B⇩b :: "('s ⇒⇩b real) ⇒ 's ⇒⇩b real" is "B"
unfolding B_eq_ℒ using ℒ_bfun by (auto intro: Bounded_Functions.minus_cont)
(* Bounded-function version of B_eq_ℒ. *)
lemma B⇩b_eq_ℒ⇩b: "B⇩b v = ℒ⇩b v - v"
by (auto simp: ℒ⇩b.rep_eq B⇩b.rep_eq B_eq_ℒ)
(* ℒ⇩b v s as a supremum over the admissible actions a ∈ A s of the single-action operator L⇩a. *)
lemma ℒ⇩b_eq_SUP_L⇩a': "ℒ⇩b v s = (⨆a ∈ A s. L⇩a a v s)"
using L_eq_L⇩a_det ℒ⇩b_eq_SUP_det SUP_step_det_eq
by auto
subsection ‹Optimization of the Value Function over Multiple Steps›
(* U m v s: supremum over d ∈ D⇩R of the m-step finite-horizon value of the stationary policy
   mk_stationary d, plus v propagated m steps by the discounted transition operator of d. *)
definition "U m v s = (⨆d ∈ D⇩R. (ν⇩b_fin (mk_stationary d) m + ((l *⇩R 𝒫⇩1 d)^^m) v) s)"
text ‹@{const U} expresses the value estimate obtained by following a single decision rule, chosen
optimally, for the first @{term m} steps and afterwards using the current estimate.›
(* With zero optimized steps the estimate is returned unchanged. *)
lemma U_zero [simp]: "U 0 v = v"
unfolding U_def ℒ_def by (auto simp: ν⇩b_fin.rep_eq)
(* A single optimized step coincides with the Bellman operator ℒ. *)
lemma U_one_eq_ℒ: "U 1 v s = ℒ v s"
unfolding U_def ℒ_def by (auto simp: ν⇩b_fin_eq_𝒫⇩X L_def blinfun.bilinear_simps)
(* U maps bounded functions to bounded functions: for each d the objective is bounded by a
   constant that does not depend on d, hence so is the supremum. *)
lift_definition U⇩b :: "nat ⇒ ('s ⇒⇩b real) ⇒ ('s ⇒⇩b real)" is U
proof -
fix n v
(* Finite-horizon reward part: bounded by the geometric sum of l ^ i * r⇩M. *)
have "norm (ν⇩b_fin (mk_stationary d) m) ≤ (∑i<m. l ^ i * r⇩M)" for d m
using abs_ν_fin_le ν⇩b_fin.rep_eq by (auto intro!: norm_bound)
(* Propagated estimate: the m-th power of the discounted operator scales the norm by at most l ^ m. *)
moreover have "norm (((l *⇩R 𝒫⇩1 d)^^m) v) ≤ l ^ m * norm v" for d m
by (auto simp: 𝒫⇩X_const[symmetric] blinfun.bilinear_simps blincomp_scaleR_right
intro!: boundedI order.trans[OF abs_le_norm_bfun] mult_left_mono)
(* Triangle inequality: a bound that is uniform in d. *)
ultimately have *: "norm (ν⇩b_fin (mk_stationary d) m + ((l *⇩R 𝒫⇩1 d)^^m) v) ≤ (∑i<m. l ^ i * r⇩M) + l ^ m * norm v" for d m
using norm_triangle_mono by blast
show "U n v ∈ bfun"
using ex_dec order.trans[OF abs_le_norm_bfun *]
by (fastforce simp: U_def intro!: bfun_normI cSup_abs_le)
qed
(* U⇩b m is Lipschitz with constant l ^ m with respect to the sup-norm distance. *)
lemma U⇩b_contraction: "dist (U⇩b m v) (U⇩b m u) ≤ l ^ m * dist v u"
proof -
(* Pointwise bound at state s, first assuming U⇩b m u s ≤ U⇩b m v s. *)
have aux: "dist (U⇩b m v s) (U⇩b m u s) ≤ l ^ m * dist v u" if le: "U⇩b m u s ≤ U⇩b m v s" for s v u
proof -
(* ?U m v d: the quantity maximized in U for a fixed decision rule d, evaluated at s. *)
let ?U = "λm v d. (ν⇩b_fin (mk_stationary d) m + ((l *⇩R 𝒫⇩1 d) ^^ m) v) s"
(* A difference of suprema is at most the supremum of the differences (le_SUP_diff'). *)
have "U⇩b m v s - U⇩b m u s ≤ (⨆d ∈ D⇩R. ?U m v d - ?U m u d)"
using bounded_stationary_ν⇩b_fin bounded_disc_𝒫⇩1 le
unfolding U⇩b.rep_eq U_def
by (intro le_SUP_diff') (auto intro: bounded_plus_comp)
(* The finite-horizon reward terms cancel; only the propagated difference v - u remains. *)
also have "… = (⨆d ∈ D⇩R. ((l *⇩R 𝒫⇩1 d) ^^ m) (v - u) s)"
by (simp add: L_def scale_right_diff_distrib blinfun.bilinear_simps)
(* Separate the scalar factor l ^ m from the power of the transition operator. *)
also have "… = (⨆d ∈ D⇩R. l^m * ((𝒫⇩1 d ^^ m) (v - u) s))"
by (simp add: blincomp_scaleR_right blinfun.scaleR_left)
(* Pull l ^ m out of the supremum. *)
also have "… = l^m * (⨆d ∈ D⇩R. ((𝒫⇩1 d ^^ m) (v - u) s))"
using D⇩R_ne bounded_P bounded_disc_𝒫⇩1' by (auto intro: bounded_SUP_mul)
(* Bound the supremum by its absolute value, then by the supremum of absolute values, then by
   the sup-norm of each propagated function. *)
also have "… ≤ l^m * norm (⨆d ∈ D⇩R. ((𝒫⇩1 d ^^ m) (v - u) s))"
by (simp add: mult_left_mono)
also have "… ≤ l^m * (⨆d ∈ D⇩R. norm (((𝒫⇩1 d ^^ m) (v - u) s)))"
using D⇩R_ne ex_dec bounded_norm_comp bounded_disc_𝒫⇩1'
by (fastforce intro!: mult_left_mono)
also have "… ≤ l^m * (⨆d ∈ D⇩R. norm ((𝒫⇩1 d ^^ m) ((v - u))))"
using ex_dec
by (fastforce intro!: order.trans[OF norm_blinfun] abs_le_norm_bfun mult_left_mono cSUP_mono)
(* Powers of the transition operator do not increase the sup-norm (norm_𝒫⇩X_apply). *)
also have "… ≤ l^m * (⨆d ∈ D⇩R. norm ((v - u)))"
using norm_𝒫⇩X_apply by (auto simp: 𝒫⇩X_const[symmetric] cSUP_least mult_left_mono)
also have "… = l ^m * dist v u"
by (auto simp: dist_norm)
finally have "U⇩b m v s - U⇩b m u s ≤ l^m * dist v u" .
thus ?thesis
by (simp add: dist_real_def le)
qed
(* The symmetric case follows by swapping the arguments of dist. *)
moreover have "U⇩b m v s ≤ U⇩b m u s ⟹ dist (U⇩b m v s) (U⇩b m u s) ≤ l^m * dist v u" for u v s
by (simp add: aux dist_commute)
(* Both cases together cover every state, by totality of ≤ (linear). *)
ultimately have "dist (U⇩b m v s) (U⇩b m u s) ≤ l^m * dist v u" for u v s
using linear by blast
(* Lift the pointwise bound to the sup-norm distance (dist_bound). *)
thus "dist (U⇩b m v) (U⇩b m u) ≤ l^m * dist v u"
by (simp add: dist_bound)
qed
(* With at least one step (Suc m), U⇩b is a contraction. By the Banach fixed point theorem
   (banach') it has a unique fixed point, and its iterates from any v converge to it. *)
lemma U⇩b_conv:
"∃!v. U⇩b (Suc m) v = v"
"(λn. (U⇩b (Suc m) ^^ n) v) ⇢ (THE v. U⇩b (Suc m) v = v)"
proof -
(* Contraction constant l ^ Suc m. The case l = 0 is split off, presumably because
   power_Suc_less_one needs a strictly positive base. *)
have *: "is_contraction (U⇩b (Suc m))"
unfolding is_contraction_def
using U⇩b_contraction[of "Suc m"] le_neq_trans[OF zero_le_disc]
by (cases "l = 0") (auto intro!: power_Suc_less_one intro: exI[of _ "l^(Suc m)"])
show "∃!v. U⇩b (Suc m) v = v" "(λn. (U⇩b (Suc m) ^^ n) v) ⇢ (THE v. U⇩b (Suc m) v = v)"
using banach'[OF *] by auto
qed
(* The iterates of U⇩b (Suc m) from any v form a convergent sequence. *)
lemma U⇩b_convergent: "convergent (λn. (U⇩b (Suc m) ^^ n) v)"
by (intro convergentI[OF U⇩b_conv(2)])
(* U⇩b m is monotone with respect to the pointwise order. *)
lemma U⇩b_mono:
assumes "v ≤ u"
shows "U⇩b m v ≤ U⇩b m u"
proof -
(* Pointwise comparison of the two suprema via cSUP_mono. *)
have "U⇩b m v s ≤ U⇩b m u s" for s
unfolding U⇩b.rep_eq U_def
proof (intro cSUP_mono, goal_cases)
(* Goal 2: the family built from u is bounded above. *)
case 2
thus ?case
by (simp add: bounded_imp_bdd_above bounded_disc_𝒫⇩1 bounded_plus_comp bounded_stationary_ν⇩b_fin)
next
(* Goal 3: for each decision rule, the term built from v is below the one built from u,
   using monotonicity of 𝒫⇩X (𝒫⇩X_mono) and a nonnegative scalar factor (mult_left_mono). *)
case (3 n)
thus ?case
using less_eq_bfunD[OF 𝒫⇩X_mono[OF assms]]
by (auto simp: 𝒫⇩X_const[symmetric] blincomp_scaleR_right blinfun.scaleR_left intro!: mult_left_mono exI)
(* The remaining goal is closed by auto. *)
qed auto
thus ?thesis
using assms by auto
qed
(* Following one decision rule for m steps is never better than m Bellman steps. *)
lemma U⇩b_le_ℒ⇩b: "U⇩b m v ≤ (ℒ⇩b ^^ m) v"
proof -
(* Rewrite U⇩b as a supremum over D⇩R of m-fold applications of L d (L_iter). *)
have "U⇩b m v s = (⨆d ∈ D⇩R. (L d^^ m) v s)" for m v s
by (auto simp: L_iter U⇩b.rep_eq ℒ⇩b.rep_eq U_def ℒ_def)
(* Each (L d ^^ m) v is below (ℒ⇩b ^^ m) v (L_iter_le_ℒ⇩b); conclude with cSUP_least. *)
thus ?thesis
using L_iter_le_ℒ⇩b ex_dec by (fastforce intro!: cSUP_least)
qed
(* Conversely, m applications of L d for any fixed d ∈ D⇩R are dominated by U⇩b m. *)
lemma L_iter_le_U⇩b:
assumes "d ∈ D⇩R"
shows "(L d^^m) v ≤ U⇩b m v"
using assms
by (fastforce intro!: cSUP_upper bounded_imp_bdd_above
simp: L_iter U⇩b.rep_eq U_def bounded_disc_𝒫⇩1 bounded_plus_comp bounded_stationary_ν⇩b_fin)
(* The iterates of U⇩b (Suc m) converge to the optimal value ν⇩b_opt. *)
lemma lim_U⇩b: "lim (λn. (U⇩b (Suc m) ^^ n) v) = ν⇩b_opt"
proof -
(* ν⇩b_opt ≤ U⇩b m ν⇩b_opt: a deterministic rule d improving for ν⇩b_opt leaves ν⇩b_opt fixed under
   L (mk_dec_det d), and iterates of L are dominated by U⇩b (L_iter_le_U⇩b). *)
have le_U: "ν⇩b_opt ≤ U⇩b m ν⇩b_opt" for m
proof -
obtain d where d: "ν_improving ν⇩b_opt (mk_dec_det d)" "d ∈ D⇩D"
using ex_improving_det by auto
have "ν⇩b_opt = (L (mk_dec_det d) ^^ m) ν⇩b_opt"
by (induction m) (metis L_ν_fix_iff ℒ⇩b_opt ν_improving_imp_ℒ⇩b d(1) funpow_swap1)+
thus ?thesis
using ‹d ∈ D⇩D› by (auto intro!: order.trans[OF _ L_iter_le_U⇩b])
qed
(* Conversely, U⇩b m ν⇩b_opt ≤ (ℒ⇩b ^^ m) ν⇩b_opt (U⇩b_le_ℒ⇩b), which is bounded by ν⇩b_opt
   using ℒ_inc_le_opt. *)
have "U⇩b m ν⇩b_opt ≤ ν⇩b_opt" for m
using ℒ_inc_le_opt by (auto intro!: order.trans[OF U⇩b_le_ℒ⇩b] simp: funpow_swap1)
(* Hence ν⇩b_opt is a fixed point of U⇩b (Suc m). *)
hence "U⇩b (Suc m) ν⇩b_opt = ν⇩b_opt"
using le_U by (simp add: antisym)
(* The limit of the iterates is also a fixed point; uniqueness identifies the two. *)
moreover have "(lim (λn. (U⇩b (Suc m) ^^n) v)) = U⇩b (Suc m) (lim (λn. (U⇩b (Suc m) ^^n) v))"
using limI[OF U⇩b_conv(2)] theI'[OF U⇩b_conv(1)] by auto
ultimately show ?thesis
using U⇩b_conv(1) by metis
qed
(* Convergence of the iterates of U⇩b (Suc m) to ν⇩b_opt, stated as a limit. *)
lemma U⇩b_tendsto: "(λn. (U⇩b (Suc m) ^^ n) v) ⇢ ν⇩b_opt"
using lim_U⇩b U⇩b_convergent convergent_LIMSEQ_iff by metis
(* ν⇩b_opt is the unique fixed point of U⇩b (Suc m). *)
lemma U⇩b_fix_unique: "U⇩b (Suc m) v = v ⟷ v = ν⇩b_opt"
using theI'[OF U⇩b_conv(1)] U⇩b_conv(1)
by (auto simp: LIMSEQ_unique[OF U⇩b_tendsto U⇩b_conv(2)[of m]])
(* Each U⇩b m, including m = 0, shrinks the distance to ν⇩b_opt by at least the factor l ^ m. *)
lemma dist_U⇩b_opt: "dist (U⇩b m v) ν⇩b_opt ≤ l^m * dist v ν⇩b_opt"
proof -
(* ν⇩b_opt is fixed by U⇩b m: for m = 0 by U_zero, otherwise by U⇩b_fix_unique. *)
have "dist (U⇩b m v) ν⇩b_opt = dist (U⇩b m v) (U⇩b m ν⇩b_opt)"
by (metis U⇩b.abs_eq U⇩b_fix_unique U_zero apply_bfun_inverse not0_implies_Suc)
(* Lipschitz bound. *)
also have "… ≤ l^m * dist v ν⇩b_opt"
by (meson U⇩b_contraction)
finally show ?thesis .
qed
subsection ‹Expressing a Single Step of Modified Policy Iteration›
text ‹The function @{term W} equals the value computed by a single iteration of the Modified Policy
Iteration algorithm.
The right-hand summand in the definition describes the advantage of using the optimal action for
the first m steps.
›
(* W d m v: v plus the advantage B⇩b v propagated by the first m powers (indices 0 to m - 1) of the
   discounted transition operator of d. *)
definition "W d m v = v + (∑i < m. (l *⇩R 𝒫⇩1 d)^^i) (B⇩b v)"
(* For a decision rule d improving for v, W d m v equals m applications of L d to v. *)
lemma W_eq_L_iter:
assumes "ν_improving v d"
shows "W d m v = (L d^^m) v"
proof -
(* d attains the Bellman operator at v, so ℒ⇩b v can be replaced by L d v. *)
have "(∑i<m. (l *⇩R 𝒫⇩1 d)^^i) (ℒ⇩b v) = (∑i<m. (l *⇩R 𝒫⇩1 d)^^i) (L d v)"
using ν_improving_imp_ℒ⇩b assms by auto
(* Expand B⇩b v = ℒ⇩b v - v. *)
hence "W d m v = v + ((∑i<m. (l *⇩R 𝒫⇩1 d)^^i) (L d v)) - (∑i<m. (l *⇩R 𝒫⇩1 d)^^i) v"
by (auto simp: W_def B⇩b_eq_ℒ⇩b blinfun.bilinear_simps)
(* Split L d v into the reward part, giving ν⇩b_fin, and the propagated part. *)
also have "… = v + ν⇩b_fin (mk_stationary d) m + (∑i<m. ((l *⇩R 𝒫⇩1 d)^^i) ((l *⇩R 𝒫⇩1 d) v)) - (∑i<m. (l *⇩R 𝒫⇩1 d)^^i) v"
by (auto simp: L_def ν⇩b_fin_eq blinfun.bilinear_simps scaleR_right.sum)
(* Absorb the extra factor into the power. *)
also have "… = v + ν⇩b_fin (mk_stationary d) m + (∑i<m. ((l *⇩R 𝒫⇩1 d)^^Suc i) v) - (∑i<m. (l *⇩R 𝒫⇩1 d)^^i) v"
by (auto simp del: blinfunpow.simps simp: blinfunpow_assoc)
(* Fold the leading v into the shifted sum as its i = 0 term. *)
also have "… = ν⇩b_fin (mk_stationary d) m + (∑i<Suc m. ((l *⇩R 𝒫⇩1 d)^^ i) v) - (∑i<m. (l *⇩R 𝒫⇩1 d)^^ i) v"
by (subst sum.lessThan_Suc_shift) auto
(* The two sums differ only in the term i = m. *)
also have "… = ν⇩b_fin (mk_stationary d) m + ((l *⇩R 𝒫⇩1 d)^^m) v"
by (simp add: blinfun.sum_left)
(* Closed form of iterating L d. *)
also have "… = (L d ^^ m) v"
using L_iter by auto
finally show ?thesis .
qed
(* U⇩b m u dominates, for every d ∈ D⇩R, the m-step value of d followed by u. *)
lemma U⇩b_ge: "d ∈ D⇩R ⟹ U⇩b m u ≥ ν⇩b_fin (mk_stationary d) m + ((l *⇩R 𝒫⇩1 d) ^^ m) u"
using ν_improving_D_MR bounded_stationary_ν⇩b_fin bounded_disc_𝒫⇩1
by (fastforce intro!: diff_mono bounded_imp_bdd_above cSUP_upper bounded_plus_comp simp: U⇩b.rep_eq U_def)
(* W from v with a v-improving rule is below U⇩b m u for any u ≥ v. The proof combines
   W_eq_L_iter, L_iter_le_U⇩b and U⇩b_mono. *)
lemma W_le_U⇩b:
assumes "v ≤ u" "ν_improving v d"
shows "W d m v ≤ U⇩b m u"
using assms
by (fastforce simp: W_eq_L_iter intro!: order.trans[OF L_iter_le_U⇩b U⇩b_mono])
(* If v ≤ u and B⇩b u ≥ 0, then one Bellman step from v is below W d' (Suc m) u, where d' is
   improving for u. *)
lemma W_ge_ℒ⇩b:
assumes "v ≤ u" "0 ≤ B⇩b u" "ν_improving u d'"
shows "ℒ⇩b v ≤ W d' (Suc m) u"
proof -
(* Monotonicity of ℒ⇩b, and ℒ⇩b u = u + B⇩b u. *)
have "ℒ⇩b v ≤ u + B⇩b u"
using assms(1) ℒ⇩b_mono B⇩b_eq_ℒ⇩b by auto
(* By induction on m, using monotonicity of L. *)
also have "… ≤ W d' (Suc m) u"
using L_mono ν_improving_imp_ℒ⇩b assms(3) assms
by (induction m) (auto simp: W_eq_L_iter B⇩b_eq_ℒ⇩b)
finally show ?thesis .
qed
(* Subgradient-style inequality for B⇩b: for d improving for v,
   B⇩b v + (l *⇩R 𝒫⇩1 d - id_blinfun) (u - v) ≤ B⇩b u. *)
lemma B⇩b_le:
assumes "ν_improving v d"
shows "B⇩b v + (l *⇩R 𝒫⇩1 d - id_blinfun) (u - v) ≤ B⇩b u"
proof -
(* Evaluating with d at u gives a lower bound for B⇩b u (L_le_ℒ⇩b). *)
have "r_dec⇩b d + (l *⇩R 𝒫⇩1 d - id_blinfun) u ≤ B⇩b u"
using L_def L_le_ℒ⇩b assms by (auto simp: B⇩b_eq_ℒ⇩b ℒ⇩b.rep_eq ℒ_def blinfun.bilinear_simps)
(* At v, d attains B⇩b v exactly. *)
moreover have "B⇩b v = r_dec⇩b d + (l *⇩R 𝒫⇩1 d - id_blinfun) v"
using assms by (auto simp: B⇩b_eq_ℒ⇩b ν_improving_imp_ℒ⇩b[of _ d] L_def blinfun.bilinear_simps)
(* Combine using linearity of the operator in its argument. *)
ultimately show ?thesis
by (simp add: blinfun.diff_right)
qed
subsection ‹Computing the Bellman Operator over Multiple Steps›
(* L_pow v d m: m applications to v of L for the deterministic decision rule mk_dec_det d. *)
definition "L_pow v d m = (L (mk_dec_det d) ^^ m) v"
(* For a rule that is improving for v, L_pow has the closed form of W. *)
lemma L_pow_eq:
fixes d defines "d' ≡ mk_dec_det d"
assumes "ν_improving v d'"
shows "L_pow v d m = v + (∑i < m. ((l *⇩R 𝒫⇩1 d')^^i)) (B⇩b v)"
using L_pow_def W_def W_eq_L_iter assms by presburger
(* L_pow with the policy produced by policy_improvement agrees with W. *)
lemma L_pow_eq_W:
assumes "d ∈ D⇩D"
shows "L_pow v (policy_improvement d v) m = W (mk_dec_det (policy_improvement d v)) m v"
using assms policy_improvement_improving by (auto simp: W_eq_L_iter L_pow_def)
(* find_policy' v is a deterministic decision rule. *)
lemma find_policy'_is_dec_det: "is_dec_det (find_policy' v)"
using find_policy'_def is_dec_det_def some_opt_acts_in_A by presburger
(* find_policy' v is improving for v. *)
lemma find_policy'_improving: "ν_improving v (mk_dec_det (find_policy' v))"
using ν_improving_opt_acts find_policy'_def by presburger
(* L_pow with the policy find_policy' v agrees with W. *)
lemma L_pow_eq_W': "L_pow v (find_policy' v) m = W (mk_dec_det (find_policy' v)) m v"
using find_policy'_improving by (auto simp: W_eq_L_iter L_pow_def)
(* W d m preserves the invariant u ≤ ℒ⇩b u (equivalently 0 ≤ B⇩b u) when d is improving for u.
   mpi_conv uses this to keep the MPI iterates monotone. *)
lemma ℒ⇩b_W_ge:
assumes "u ≤ ℒ⇩b u" "ν_improving u d"
shows "W d m u ≤ ℒ⇩b (W d m u)"
proof -
(* The m-th propagated advantage is nonnegative, since B⇩b u ≥ 0 (𝒫⇩1_n_disc_pos). *)
have "0 ≤ ((l *⇩R 𝒫⇩1 d) ^^ m) (B⇩b u)"
by (metis B⇩b_eq_ℒ⇩b 𝒫⇩1_n_disc_pos assms(1) blincomp_scaleR_right diff_ge_0_iff_ge)
(* Telescoping: write it as the sum over indices up to m minus the sum over indices below m. *)
also have "… = ((l *⇩R 𝒫⇩1 d)^^0 + (∑i < m. (l *⇩R 𝒫⇩1 d)^^(Suc i))) (B⇩b u) - (∑i < m. (l *⇩R 𝒫⇩1 d)^^ i) (B⇩b u)"
by (subst sum.lessThan_Suc_shift[symmetric]) (auto simp: blinfun.diff_left[symmetric])
(* Factor out (l *⇩R 𝒫⇩1 d - id_blinfun). *)
also have "… = B⇩b u + ((l *⇩R 𝒫⇩1 d - id_blinfun) o⇩L (∑i < m. (l *⇩R 𝒫⇩1 d)^^i)) (B⇩b u)"
by (auto simp: blinfun.bilinear_simps sum_subtractf)
(* By definition of W, the sum applied to B⇩b u equals W d m u - u. *)
also have "… = B⇩b u + (l *⇩R 𝒫⇩1 d - id_blinfun) (W d m u - u)"
by (auto simp: W_def sum.lessThan_Suc[unfolded lessThan_Suc_atMost])
(* Subgradient inequality B⇩b_le. *)
also have "… ≤ B⇩b (W d m u)"
using B⇩b_le assms(2) by blast
finally have "0 ≤ B⇩b (W d m u)" .
thus ?thesis
using B⇩b_eq_ℒ⇩b by auto
qed
(* The invariant v ≤ ℒ⇩b v is preserved by L_pow with the policy from policy_improvement. *)
lemma L_pow_ℒ⇩b_mono_inv:
assumes "d ∈ D⇩D" "v ≤ ℒ⇩b v"
shows "L_pow v (policy_improvement d v) m ≤ ℒ⇩b (L_pow v (policy_improvement d v) m)"
using assms L_pow_eq_W ℒ⇩b_W_ge policy_improvement_improving by auto
(* The invariant v ≤ ℒ⇩b v is preserved by L_pow with the policy find_policy' v. *)
lemma L_pow_ℒ⇩b_mono_inv':
assumes "v ≤ ℒ⇩b v"
shows "L_pow v (find_policy' v) m ≤ ℒ⇩b (L_pow v (find_policy' v) m)"
using assms L_pow_eq_W' ℒ⇩b_W_ge find_policy'_improving by auto
subsection ‹The Modified Policy Iteration Algorithm›
(* Parameters of the bounded MPI run:
   - d0 is only required to lie in D⇩D; it does not occur in the definitions below, where the
     first policy is find_policy' v0;
   - v0 is the initial value estimate;
   - m gives, for iteration n and current value v, the extra evaluation steps: each iteration
     applies the policy operator Suc (m n v) times. *)
context
fixes d0 :: "'s ⇒ 'a"
fixes v0 :: "'s ⇒⇩b real"
fixes m :: "nat ⇒ ('s ⇒⇩b real) ⇒ nat"
assumes d0: "d0 ∈ D⇩D"
begin
text ‹We first define a function that executes the algorithm for n steps.›
(* mpi n: the pair (policy, value) after n iterations. Each iteration applies L_pow with the
   current policy Suc (m n v) times to the current value v. It then takes find_policy' of the
   result as the next policy. *)
fun mpi :: "nat ⇒ (('s ⇒ 'a) × ('s ⇒⇩b real))" where
"mpi 0 = (find_policy' v0, v0)" |
"mpi (Suc n) =
(let (d, v) = mpi n; v' = L_pow v d (Suc (m n v)) in
(find_policy' v', v'))"
(* Projections: the value and the policy after n iterations. *)
definition "mpi_val n = snd (mpi n)"
definition "mpi_pol n = fst (mpi n)"
(* The initial policy is the one found for v0. *)
lemma mpi_pol_zero[simp]: "mpi_pol 0 = find_policy' v0"
unfolding mpi_pol_def
by auto
(* After each step, the policy is the one found for the new value. *)
lemma mpi_pol_Suc: "mpi_pol (Suc n) = find_policy' (mpi_val (Suc n))"
by (auto simp: case_prod_beta' Let_def mpi_pol_def mpi_val_def)
(* Every policy produced is a deterministic decision rule. *)
lemma mpi_pol_is_dec_det: "mpi_pol n ∈ D⇩D"
unfolding mpi_pol_def
using find_policy'_is_dec_det d0
by (induction n) (auto simp: Let_def split: prod.splits)
(* The policy after n iterations is improving for the value after n iterations. *)
lemma ν_improving_mpi_pol: "ν_improving (mpi_val n) (mk_dec_det (mpi_pol n))"
using d0 find_policy'_improving mpi_pol_is_dec_det mpi_pol_Suc
by (cases n) (auto simp: mpi_pol_def mpi_val_def)
(* The initial value is v0. *)
lemma mpi_val_zero[simp]: "mpi_val 0 = v0"
unfolding mpi_val_def by auto
(* One step: apply the current policy's operator Suc (m n (mpi_val n)) times. *)
lemma mpi_val_Suc: "mpi_val (Suc n) = L_pow (mpi_val n) (mpi_pol n) (Suc (m n (mpi_val n)))"
unfolding mpi_val_def mpi_pol_def
by (auto simp: case_prod_beta' Let_def)
(* Closed form of one MPI step: add the advantage propagated by powers 0 to m n (mpi_val n) of the
   discounted operator of the current policy. *)
lemma mpi_val_eq: "mpi_val (Suc n) =
mpi_val n + (∑i ≤ (m n (mpi_val n)). (l *⇩R 𝒫⇩1 (mk_dec_det (mpi_pol n))) ^^ i) (B⇩b (mpi_val n))"
using lessThan_Suc_atMost by (auto simp: mpi_val_Suc L_pow_eq[OF ν_improving_mpi_pol])
text ‹Value Iteration is a special case of MPI where @{term "∀n v. m n v = 0"}.›
(* If m is constantly 0, every MPI iteration is exactly one Bellman step. *)
lemma mpi_includes_value_it:
assumes "∀n v. m n v = 0"
shows "mpi_val (Suc n) = ℒ⇩b (mpi_val n)"
using assms B⇩b_eq_ℒ⇩b mpi_val_eq by auto
subsection ‹Convergence Proof›
text ‹We define the sequence @{term w} as an upper bound for the values of MPI.›
(* w n: iterates of U⇩b from v0, using at step n the same step count as MPI. *)
fun w where
"w 0 = v0" |
"w (Suc n) = U⇩b (Suc (m n (mpi_val n))) (w n)"
(* Each step of w shrinks the distance to ν⇩b_opt by at least the factor l. It follows from
   dist_U⇩b_opt, with a power l ^ Suc k bounded by l. *)
lemma dist_ν⇩b_opt: "dist (w (Suc n)) ν⇩b_opt ≤ l * dist (w n) ν⇩b_opt"
by (fastforce simp: algebra_simps intro: order.trans[OF dist_U⇩b_opt] mult_left_mono power_le_one
mult_left_le_one_le order.strict_implies_order)
(* Geometric decay of the distance between w n and ν⇩b_opt, by induction on n. *)
lemma dist_ν⇩b_opt_n: "dist (w n) ν⇩b_opt ≤ l^n * dist v0 ν⇩b_opt"
by (induction n) (fastforce simp: algebra_simps intro: order.trans[OF dist_ν⇩b_opt] mult_left_mono)+
(* w converges to ν⇩b_opt. *)
lemma w_conv: "w ⇢ ν⇩b_opt"
proof -
(* The geometric bound tends to 0 (LIMSEQ_realpow_zero). *)
have "(λn. l^n * dist v0 ν⇩b_opt) ⇢ 0"
using LIMSEQ_realpow_zero by (cases "v0 = ν⇩b_opt") auto
then show ?thesis
by (fastforce intro: metric_LIMSEQ_I order.strict_trans1[OF dist_ν⇩b_opt_n] simp: LIMSEQ_def)
qed
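text ‹Starting from an estimate with @{term "v0 ≤ ℒ⇩b v0"}, MPI converges monotonically to the
optimal value from below. The iterates are sandwiched between the value iteration sequence
@{term "(ℒ⇩b ^^ n) v0"} from below and the sequence @{term w} (iterates of @{const U⇩b}) from above.›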
(* Monotone convergence of MPI to ν⇩b_opt, provided v0 ≤ ℒ⇩b v0. *)
theorem mpi_conv:
assumes "v0 ≤ ℒ⇩b v0"
shows "mpi_val ⇢ ν⇩b_opt" and "⋀n. mpi_val n ≤ mpi_val (Suc n)"
proof -
(* y n: plain value iteration from v0, used as the lower bound. *)
define y where "y n = (ℒ⇩b^^n) v0" for n
(* Invariant, by induction on n: mpi_val n ≤ ℒ⇩b (mpi_val n), the sequence is monotone at n,
   y n ≤ mpi_val n, and mpi_val n ≤ w n. *)
have aux: "mpi_val n ≤ ℒ⇩b (mpi_val n) ∧ mpi_val n ≤ mpi_val (Suc n) ∧ y n ≤ mpi_val n ∧ mpi_val n ≤ w n" for n
proof (induction n)
case 0
show ?case
using assms B⇩b_eq_ℒ⇩b
unfolding y_def
by (auto simp: mpi_val_eq blinfun.sum_left 𝒫⇩1_n_disc_pos blincomp_scaleR_right sum_nonneg)
next
case (Suc n)
(* Each MPI step is an instance of W with the current policy. *)
have val_eq_W: "mpi_val (Suc n) = W (mk_dec_det (mpi_pol n)) (Suc (m n (mpi_val n))) (mpi_val n)"
using ν_improving_mpi_pol mpi_val_Suc W_eq_L_iter L_pow_def by auto
(* W preserves v ≤ ℒ⇩b v (ℒ⇩b_W_ge). *)
hence *: "mpi_val (Suc n) ≤ ℒ⇩b (mpi_val (Suc n))"
using Suc.IH ℒ⇩b_W_ge ν_improving_mpi_pol by presburger
(* A nonnegative advantage makes the next step monotone. *)
moreover have "mpi_val (Suc n) ≤ mpi_val (Suc (Suc n))"
using *
by (simp add: B⇩b_eq_ℒ⇩b mpi_val_eq 𝒫⇩1_n_disc_pos blincomp_scaleR_right blinfun.sum_left sum_nonneg)
(* Upper bound via W_le_U⇩b. *)
moreover have "mpi_val (Suc n) ≤ w (Suc n)"
using Suc.IH ν_improving_mpi_pol by (auto simp: val_eq_W intro: order.trans[OF _ W_le_U⇩b])
(* Lower bound via W_ge_ℒ⇩b. *)
moreover have "y (Suc n) ≤ mpi_val (Suc n)"
using Suc.IH ν_improving_mpi_pol W_ge_ℒ⇩b by (auto simp: y_def B⇩b_eq_ℒ⇩b val_eq_W)
ultimately show ?case
by auto
qed
thus "mpi_val n ≤ mpi_val (Suc n)" for n
by auto
(* Sandwich: y and w both converge to ν⇩b_opt. *)
have "y ⇢ ν⇩b_opt"
using ℒ⇩b_lim y_def by presburger
thus "mpi_val ⇢ ν⇩b_opt"
using aux by (auto intro: tendsto_bfun_sandwich[OF _ w_conv])
qed
subsection ‹$\epsilon$-Optimality›
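text ‹This gives an upper bound on the error of MPI: once the stopping criterion
@{term "2 * l * dist (mpi_val n) (ℒ⇩b (mpi_val n)) < eps * (1 - l)"} holds, the value of the
stationary policy built from @{term "mpi_pol n"} is within @{term "eps / 2"} of
@{term "ℒ⇩b (mpi_val n)"}, and hence within @{term eps} of the optimal value.›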
(* Under the stopping criterion, the value of the stationary policy from mpi_pol n is within
   eps / 2 of ℒ⇩b (mpi_val n). *)
lemma mpi_pol_eps_opt:
assumes "2 * l * dist (mpi_val n) (ℒ⇩b (mpi_val n)) < eps * (1 - l)" "eps > 0"
shows "dist (ν⇩b (mk_stationary_det (mpi_pol n))) (ℒ⇩b (mpi_val n)) ≤ eps / 2"
proof -
let ?p = "mk_stationary_det (mpi_pol n)"
let ?d = "mk_dec_det (mpi_pol n)"
let ?v = "mpi_val n"
(* The value of ?p is a fixed point of L ?d (L_ν_fix). *)
have "dist (ν⇩b ?p) (ℒ⇩b ?v) = dist (L ?d (ν⇩b ?p)) (ℒ⇩b ?v)"
using L_ν_fix by force
(* ?d is improving for ?v, so ℒ⇩b ?v = L ?d ?v. *)
also have "… = dist (L ?d (ν⇩b ?p)) (L ?d ?v)"
by (metis ν_improving_imp_ℒ⇩b ν_improving_mpi_pol)
(* Triangle inequality through L ?d (ℒ⇩b ?v). *)
also have "… ≤ dist (L ?d (ν⇩b ?p)) (L ?d (ℒ⇩b ?v)) + dist (L ?d (ℒ⇩b ?v)) (L ?d ?v)"
using dist_triangle by blast
(* L ?d contracts distances by the factor l (contraction_L), applied to each summand. *)
also have "… ≤ l * dist (ν⇩b ?p) (ℒ⇩b ?v) + dist (L ?d (ℒ⇩b ?v)) (L ?d ?v)"
using contraction_L by auto
also have "… ≤ l * dist (ν⇩b ?p) (ℒ⇩b ?v) + l * dist (ℒ⇩b ?v) ?v"
using contraction_L by auto
finally have "dist (ν⇩b ?p) (ℒ⇩b ?v) ≤ l * dist (ν⇩b ?p) (ℒ⇩b ?v) + l * dist (ℒ⇩b ?v) ?v".
(* Rearranged form. *)
hence *:"(1-l) * dist (ν⇩b ?p) (ℒ⇩b ?v) ≤ l * dist (ℒ⇩b ?v) ?v"
by (auto simp: left_diff_distrib)
thus ?thesis
proof (cases "l = 0")
(* With l = 0, * forces the distance to be at most 0. *)
case True
thus ?thesis
using assms * by auto
next
case False
(* Rewrite the stopping criterion as a bound on dist (ℒ⇩b ?v) ?v. *)
have **: "dist (ℒ⇩b ?v) (mpi_val n) < eps * (1 - l) / (2 * l)"
using False le_neq_trans[OF zero_le_disc False[symmetric]] assms
by (auto simp: dist_commute pos_less_divide_eq Groups.mult_ac(2))
(* Divide * by 1 - l, which is positive (disc_lt_one). *)
have "dist (ν⇩b ?p) (ℒ⇩b ?v) ≤ (l/ (1-l)) * dist (ℒ⇩b ?v) ?v"
using * by (auto simp: mult.commute pos_le_divide_eq)
also have "… ≤ (l/ (1-l)) * (eps * (1 - l) / (2 * l))"
using ** by (fastforce intro!: mult_left_mono simp: divide_nonneg_pos)
also have "… = eps / 2"
using False disc_lt_one by (auto simp: order.strict_iff_order)
finally show "dist (ν⇩b ?p) (ℒ⇩b ?v) ≤ eps / 2".
qed
qed
(* Under the stopping criterion, the policy from mpi_pol n is eps-optimal. The proof combines
   mpi_pol_eps_opt with the bound from dist_ℒ⇩b_opt_eps. *)
lemma mpi_pol_opt:
assumes "2 * l * dist (mpi_val n) (ℒ⇩b (mpi_val n)) < eps * (1 - l)" "eps > 0"
shows "dist (ν⇩b (mk_stationary_det (mpi_pol n))) (ν⇩b_opt) < eps"
proof -
have "dist (ν⇩b (mk_stationary_det (mpi_pol n))) (ν⇩b_opt) ≤ eps/2 + dist (ℒ⇩b (mpi_val n)) ν⇩b_opt"
by (metis mpi_pol_eps_opt[OF assms] dist_commute dist_triangle_le add_right_mono)
thus ?thesis
using dist_ℒ⇩b_opt_eps assms by fastforce
qed
(* If v0 ≤ ℒ⇩b v0 and eps > 0, the stopping criterion is met at some iteration. *)
lemma mpi_val_term_ex:
assumes "v0 ≤ ℒ⇩b v0" "eps > 0"
shows "∃n. 2 * l * dist (mpi_val n) (ℒ⇩b (mpi_val n)) < eps * (1 - l)"
proof -
(* The MPI values converge to ν⇩b_opt (mpi_conv). *)
have "(λn. dist (mpi_val n) ν⇩b_opt) ⇢ 0"
using mpi_conv(1)[OF assms(1)] tendsto_dist_iff
by blast
(* The Bellman residual is bounded by 2 * dist (mpi_val n) ν⇩b_opt (dist_ℒ⇩b_lt_dist_opt), so it
   tends to 0 as well. *)
hence "(λn. dist (mpi_val n) (ℒ⇩b (mpi_val n))) ⇢ 0"
using dist_ℒ⇩b_lt_dist_opt
by (auto simp: metric_LIMSEQ_I intro: tendsto_sandwich[of "λ_. 0" _ _ "λn. 2 * dist (mpi_val n) ν⇩b_opt"])
hence "∀e >0. ∃n. dist (mpi_val n) (ℒ⇩b (mpi_val n)) < e"
by (fastforce dest!: metric_LIMSEQ_D)
(* For l ≠ 0, use the positive threshold eps * (1 - l) / (2 * l). *)
hence "l ≠ 0 ⟹ ∃n. dist (mpi_val n) (ℒ⇩b (mpi_val n)) < eps * (1 - l) / (2 * l)"
by (simp add: assms order.not_eq_order_implies_strict)
(* Multiply back; for l = 0 the left-hand side is 0 and the claim is immediate. *)
thus "∃n. (2 * l) * dist (mpi_val n) (ℒ⇩b (mpi_val n)) < eps * (1 - l)"
using assms le_neq_trans[OF zero_le_disc]
by (cases "l = 0") (auto simp: mult.commute pos_less_divide_eq)
qed
end
subsection ‹Unbounded MPI›
(* Unbounded MPI, parameterized by the accuracy eps. The previously fixed δ :: real and
   M :: nat were never referenced in this context and have been dropped. *)
context
fixes eps :: real
begin
(* mpi_algo d v m: unbounded MPI.
   - If 2 * l * dist v (ℒ⇩b v) < eps * (1 - l) holds, it returns (find_policy' v, v).
   - Otherwise it takes one MPI step with Suc (m 0 v) evaluation steps and recurses with the
     shifted step-count function λn. m (Suc n).
   The argument d is not used in the body. The function is partial (domintros); termination is
   proved separately in termination_mpi_algo. *)
function (domintros) mpi_algo where "mpi_algo d v m = (
if 2 * l * dist v (ℒ⇩b v) < eps * (1 - l)
then (find_policy' v, v)
else mpi_algo (find_policy' v) (L_pow v (find_policy' v) (Suc (m 0 v))) (λn. m (Suc n)))"
by auto
text ‹We define a tail-recursive version of @{const mpi} that more closely resembles @{const mpi_algo}.›
(* mpi' d v n m: n MPI steps from v, written tail-recursively. As in mpi_algo, d is not used. *)
fun mpi' where
"mpi' d v 0 m = (find_policy' v, v)" |
"mpi' d v (Suc n) m = (
let d' = find_policy' v; v' = L_pow v d' (Suc (m 0 v)) in mpi' d' v' n (λn. m (Suc n)))"
(* Unfolding the first step of mpi: n + 1 steps from v equal n steps from the value after one step,
   with the step-count function shifted. The premise d ∈ D⇩D presumably discharges the assumption
   on the initial decision rule inherited from the bounded-MPI context. *)
lemma mpi_Suc':
assumes "d ∈ D⇩D"
shows "mpi v m (Suc n) = mpi (L_pow v (find_policy' v) (Suc (m 0 v))) (λa. m (Suc a)) n"
using assms
by (induction n rule: nat.induct) (auto simp: Let_def)
(* mpi and its tail-recursive variant mpi' compute the same result. Named so that the
   equivalence can be referenced later (it was previously anonymous). *)
lemma mpi_eq_mpi':
assumes "d ∈ D⇩D"
shows "mpi v m n = mpi' d v n m"
using assms
proof (induction n arbitrary: d v m rule: nat.induct)
case (Suc nat)
thus ?case
using find_policy'_is_dec_det by (fastforce simp: Let_def mpi_Suc'[OF Suc(2)])
qed auto
(* mpi_algo terminates for eps > 0 and any start value with v ≤ ℒ⇩b v. The proof is by induction
   on the number n of bounded-MPI iterations needed to meet the stopping criterion. *)
lemma termination_mpi_algo:
assumes "eps > 0" "d ∈ D⇩D" "v ≤ ℒ⇩b v"
shows "mpi_algo_dom (d, v, m)"
proof -
(* n: the first iteration at which the stopping criterion holds (one exists by mpi_val_term_ex). *)
define n where "n = (LEAST n. 2 * l * dist (mpi_val v m n) (ℒ⇩b (mpi_val v m n)) < eps * (1 - l))" (is "n = (LEAST n. ?P d v m n)")
(* If a witness exists and the least one is 0, the predicate holds at 0. *)
have least0: "∃n. P n ⟹ (LEAST n. P n) = (0 :: nat) ⟹ P 0" for P
by (metis LeastI_ex)
from n_def assms show ?thesis
proof (induction n arbitrary: v d m)
case 0
(* The criterion holds at v itself, so mpi_algo stops without recursing. *)
have "2 * l * dist (mpi_val v m 0) (ℒ⇩b (mpi_val v m 0)) < eps * (1 - l)"
using least0 mpi_val_term_ex 0 by (metis (no_types, lifting))
thus ?case
using 0 mpi_algo.domintros mpi_val_zero by (metis (no_types, opaque_lifting))
next
case (Suc n v d m)
let ?d = "find_policy' v"
(* Shift the LEAST by one (Least_Suc). *)
have "Suc n = Suc (LEAST n. 2 * l * dist (mpi_val v m (Suc n)) (ℒ⇩b (mpi_val v m (Suc n))) < eps * (1 - l))"
using mpi_val_term_ex[OF Suc.prems(3) ‹v ≤ ℒ⇩b v› ‹0 < eps›, of m] Suc.prems
by (subst Nat.Least_Suc[symmetric]) (auto intro: LeastI_ex)
hence "n = (LEAST n. 2 * l * dist (mpi_val v m (Suc n)) (ℒ⇩b (mpi_val v m (Suc n))) < eps * (1 - l))"
by auto
(* Re-express n via the MPI run started from the value after one step (mpi_Suc'). *)
hence n_eq: "n =
(LEAST n. 2 * l * dist (mpi_val (L_pow v ?d (Suc (m 0 v))) (λa. m (Suc a)) n) (ℒ⇩b (mpi_val (L_pow v ?d (Suc (m 0 v))) (λa. m (Suc a)) n))
< eps * (1 - l))"
using Suc.prems mpi_Suc' by (auto simp: is_dec_det_pi mpi_val_def)
(* The criterion fails at v, so mpi_algo takes the recursive branch. *)
have "¬ 2 * l * dist v (ℒ⇩b v) < eps * (1 - l)"
using Suc mpi_val_zero by force
(* The recursive call lies in the domain by the induction hypothesis; the invariant
   v ≤ ℒ⇩b v carries over by L_pow_ℒ⇩b_mono_inv'. *)
moreover have "mpi_algo_dom (?d, L_pow v ?d (Suc (m 0 v)), λa. m (Suc a))"
apply (rule Suc.IH[OF n_eq ‹0 < eps›])
using Suc.prems is_dec_det_pi L_pow_ℒ⇩b_mono_inv' find_policy'_is_dec_det by auto
ultimately show ?case
using mpi_algo.domintros by blast
qed
qed
(* The right-hand side of the mpi_algo equation, i.e. one unfolding of its recursion. *)
abbreviation "mpi_alg_rec d v m ≡
(if 2 * l * dist v (ℒ⇩b v) < eps * (1 - l) then (find_policy' v, v)
else mpi_algo (find_policy' v) (L_pow v (find_policy' v) (Suc (m 0 v)))
(λn. m (Suc n)))"
(* Unconditional defining equation of mpi_algo, valid whenever termination_mpi_algo applies. *)
lemma mpi_algo_def':
assumes "d ∈ D⇩D" "v ≤ ℒ⇩b v" "eps > 0"
shows "mpi_algo d v m = mpi_alg_rec d v m"
using mpi_algo.psimps termination_mpi_algo assms by auto
(* Variant of the defining equation that reuses v' = ℒ⇩b v. Because find_policy' v is improving
   for v, the first of the Suc (m 0 v) evaluation steps yields ℒ⇩b v, leaving m 0 v steps from v'. *)
lemma mpi_algo_def'':
assumes "d ∈ D⇩D" "v ≤ ℒ⇩b v" "eps > 0"
shows "mpi_algo d v m = (
let v' = ℒ⇩b v; d' = find_policy' v in
if 2 * l * dist v v' < eps * (1 - l)
then (d', v)
else mpi_algo d' (L_pow v' d' ((m 0 v))) (λn. m (Suc n)))"
proof -
have "ν_improving v (mk_dec_det (find_policy' v))"
using ν_improving_opt_acts find_policy'_def by presburger
(* Hence n steps from ℒ⇩b v equal Suc n steps from v. *)
hence aux: "L_pow (ℒ⇩b v) (find_policy' v) n = L_pow v (find_policy' v) (Suc n)" for n
using ‹d ∈ D⇩D› ν_improving_imp_ℒ⇩b
by (auto simp: funpow_swap1 L_pow_def)
show ?thesis
unfolding mpi_algo_def'[OF assms] Let_def aux[symmetric] by auto
qed
(* mpi_algo returns the state of the bounded mpi at the first iteration that meets the stopping
   criterion. *)
lemma mpi_algo_eq_mpi:
assumes "d ∈ D⇩D" "v ≤ ℒ⇩b v" "eps > 0"
shows "mpi_algo d v m = mpi v m (LEAST n. 2 * l * dist (mpi_val v m n) (ℒ⇩b (mpi_val v m n)) < eps * (1 - l))"
proof -
define n where "n = (LEAST n. 2 * l * dist (mpi_val v m n) (ℒ⇩b (mpi_val v m n)) < eps * (1 - l))" (is "n = (LEAST n. ?P d v m n)")
from n_def assms show ?thesis
proof (induction n arbitrary: d v m)
case 0
(* The criterion holds at iteration 0. *)
have "?P d v m 0"
by (metis (no_types, lifting) assms(3) LeastI_ex 0 mpi_val_term_ex)
thus ?case
using assms 0 by (auto simp: mpi_val_def mpi_algo_def')
next
case (Suc n)
(* The criterion fails at v itself. *)
hence not0: "¬ (2 * l * dist v (ℒ⇩b v) < eps * (1 - l))"
using Suc(3) mpi_val_zero by auto
(* A witness exists (mpi_val_term_ex); it is needed to shift the LEAST. *)
obtain n' where "2 * l * dist (mpi_val v m n') (ℒ⇩b (mpi_val v m n')) < eps * (1 - l)"
using mpi_val_term_ex[OF Suc(3) Suc(4), of _ m] assms by blast
hence "n = (LEAST n. ?P d v m (Suc n))"
using Suc(2) Suc by (subst (asm) Least_Suc) auto
(* Re-express n via the run started after one step (mpi_Suc'). *)
hence "n = (LEAST n. ?P (find_policy' v) (L_pow v (find_policy' v) (Suc (m 0 v))) (λn. m (Suc n)) n)"
using Suc(3) mpi_Suc' by (auto simp: mpi_val_def)
(* Unfold one step of mpi_algo and apply the induction hypothesis. *)
hence "mpi_algo d v m = mpi v m (Suc n)"
unfolding mpi_algo_def'[OF Suc.prems(2-4)]
using Suc(1) Suc.prems(2-4) is_dec_det_pi mpi_Suc' not0 L_pow_ℒ⇩b_mono_inv' find_policy'_is_dec_det
by fastforce
thus ?case
using Suc.prems(1) by presburger
qed
qed
(* Correctness of the unbounded algorithm: the policy returned by mpi_algo is eps-optimal. *)
lemma mpi_algo_opt:
assumes "v0 ≤ ℒ⇩b v0" "eps > 0" "d ∈ D⇩D"
shows "dist (ν⇩b (mk_stationary_det (fst (mpi_algo d v0 m)))) ν⇩b_opt < eps"
proof -
(* ?n: the first iteration meeting the stopping criterion. *)
let ?P = "λn. 2 * l * dist (mpi_val v0 m n) (ℒ⇩b (mpi_val v0 m n)) < eps * (1 - l)"
let ?n = "Least ?P"
(* mpi_algo coincides with mpi at ?n, and the criterion holds there. *)
have "mpi_algo d v0 m = mpi v0 m ?n" and "?P ?n"
using mpi_algo_eq_mpi LeastI_ex[OF mpi_val_term_ex] assms by auto
thus ?thesis
using assms by (auto simp: mpi_pol_opt mpi_pol_def[symmetric])
qed
end
subsection ‹Initial Value Estimate @{term v0_mpi}›
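text ‹We define a conservative initial estimate of the value function: the constant function whose
value is the infimum of all rewards divided by 1 - l. It satisfies v ≤ ℒ⇩b v, so Modified Policy
Iteration started from it always terminates.›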
(* r_min: infimum of r (s', a) over all states s' and admissible actions a ∈ A s'. *)
abbreviation "r_min ≡ (⨅s'. (⨅a ∈ A s'. r (s', a)))"
(* v0_mpi: the constant function r_min / (1 - l). *)
definition "v0_mpi s = r_min / (1 - l)"
(* A constant function is bounded, so v0_mpi lifts to a bounded function. *)
lift_definition v0_mpi⇩b :: "'s ⇒⇩b real" is "v0_mpi"
by (auto simp: v0_mpi_def)
(* The conservative estimate satisfies v0_mpi⇩b ≤ ℒ⇩b v0_mpi⇩b, the precondition of mpi_conv and
   of termination_mpi_algo. The unused local fact bounding r (s, a) by r⇩M has been removed;
   no later step referenced it. *)
lemma v0_mpi⇩b_le_ℒ⇩b: "v0_mpi⇩b ≤ ℒ⇩b v0_mpi⇩b"
proof (rule less_eq_bfunI)
fix x
(* At each state, the rewards of the admissible actions form a bounded set. *)
have bounded_r': "bounded ((λa. r (x, a)) ` A x)" for x
using r_bounded'
unfolding bounded_def
by simp (meson UNIV_I)
(* The per-state infimum is a lower bound for the reward of each admissible action. *)
have *: "(⨅a∈A x. r (x, a)) ≤ r (x,a)" if "a ∈ A x" for a x
using bounded_r' that
by (auto intro!: cInf_lower bounded_imp_bdd_below)
(* The per-state infima form a bounded family (bound r⇩M). *)
have **: "bounded (range (λs'. ⨅a∈A s'. r (s', a)))"
using abs_r_le_r⇩M ex_dec_det is_dec_det_def A_ne
by (auto simp add: minus_le_iff abs_le_iff intro!: cINF_greatest order.trans[OF *] boundedI[of _ "r⇩M"])
(* r_min is a lower bound for every admissible reward. *)
have "r_min ≤ r (s, a)" if "a ∈ A s" for s a
using r_bounded' that **
by (auto intro!: bounded_imp_bdd_below cInf_lower2[OF _ *])
(* Convex combination with weights 1 - l and l, using 0 ≤ l < 1. *)
hence "r_min ≤ (1-l) * r (s, a) + l * r_min" if "a ∈ A s" for s a
using disc_lt_one zero_le_disc that by (meson order_less_imp_le order_refl segment_bound_lemma)
(* Divide by 1 - l. *)
hence "r_min / (1 - l) ≤ ((1-l) * r (s, a) + l * r_min) / (1 - l)" if "a ∈ A s" for s a
using order_less_imp_le[OF disc_lt_one] that by (auto intro!: divide_right_mono)
hence "r_min / (1 - l) ≤ r (s, a) + (l * r_min) / (1 - l)" if "a ∈ A s" for s a
using disc_lt_one that by (auto simp: add_divide_distrib)
(* The right-hand side is L⇩a for the action arb_act (A x), applied to the constant estimate. *)
hence "r_min / (1 - l) ≤ L⇩a (arb_act (A x)) (λs. r_min / (1 - l)) x"
using A_ne arb_act_in by auto
(* The family over A x is bounded above, so the supremum ℒ⇩b dominates that action. *)
moreover have "bdd_above ((λa. L⇩a a (λs. r_min / (1 - l)) x) ` A x)"
using r_bounded
by (fastforce simp: bounded_def intro!: bounded_imp_bdd_above bounded_plus_comp)
ultimately show "v0_mpi⇩b x ≤ ℒ⇩b v0_mpi⇩b x"
unfolding ℒ⇩b_eq_SUP_L⇩a' v0_mpi⇩b.rep_eq v0_mpi_def by (auto simp: A_ne intro!: cSUP_upper2)
qed
subsection ‹An Instance of Modified Policy Iteration with a Valid Conservative Initial Value Estimate›
(* Entry point: run mpi_algo with accuracy eps, starting from the decision rule λx. arb_act (A x)
   and the conservative estimate v0_mpi⇩b. The result is undefined for eps ≤ 0. *)
definition "mpi_user eps m = (
if eps ≤ 0 then undefined else mpi_algo eps (λx. arb_act (A x)) v0_mpi⇩b m)"
(* For eps > 0, mpi_user is one unfolding of the mpi_algo recursion from the conservative start. *)
lemma mpi_user_eq:
assumes "eps > 0"
shows "mpi_user eps = mpi_alg_rec eps (λx. arb_act (A x)) v0_mpi⇩b"
using v0_mpi⇩b_le_ℒ⇩b assms
by (auto simp: mpi_user_def mpi_algo_def' A_ne is_dec_det_def)
(* For every eps > 0 and every step-count function n, the policy returned by mpi_user is
   eps-optimal. *)
lemma mpi_user_opt:
assumes "eps > 0"
shows "dist (ν⇩b (mk_stationary_det (fst (mpi_user eps n)))) ν⇩b_opt < eps"
unfolding mpi_user_def using assms
by (auto intro: mpi_algo_opt simp: is_dec_det_def A_ne v0_mpi⇩b_le_ℒ⇩b)
end
end