theory Value_Iteration
imports "MDP-Rewards.MDP_reward"
begin
context MDP_att_ℒ
begin
section ‹Value Iteration›
text ‹
In the previous sections we showed that repeated application of @{const "ℒ⇩b"} to any bounded
function from states to the reals converges to the optimal value @{const "ν⇩b_opt"} of the MDP.
We can turn this procedure into an algorithm that computes not only an approximation of
@{const "ν⇩b_opt"} but also a policy that is arbitrarily close to optimal.
Most of the proofs rely on the assumption that the supremum in @{const "ℒ⇩b"} can always be attained.
›
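text ‹
As an informal illustration only, the recursive definition below corresponds to the following
loop. The sketch is hypothetical Python, not part of the formal development: ‹L›, ‹dist› and ‹l›
stand in for @{const ℒ⇩b}, the supremum distance on bounded functions, and the discount factor.
›
(*
def value_iteration(eps, v, L, dist, l):
    # Iterate the Bellman operator until two successive estimates are close
    # enough for the eps-guarantee; mirrors the stopping test used below.
    while eps > 0 and 2 * l * dist(v, L(v)) >= eps * (1 - l):
        v = L(v)
    return L(v)
*)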
text ‹
The following lemma shows that the measure we use to prove termination of the value iteration
algorithm decreases with each step.
In essence, each iteration multiplies the distance between the estimate and the optimal value
by at most @{term l}.›
abbreviation "term_measure ≡ (λ(eps, v). LEAST n. (2 * l * dist ((ℒ⇩b^^(Suc n)) v) ((ℒ⇩b^^n) v) < eps * (1-l)))"
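text ‹
Concretely, the contraction property yields (in informal notation, writing $d$ for the
supremum distance)
\[
  d\bigl(\mathcal{L}_b^{\,n+1}\, v,\; \mathcal{L}_b^{\,n}\, v\bigr)
  \;\le\; l^{\,n}\; d\bigl(\mathcal{L}_b\, v,\; v\bigr),
\]
so for every @{term "eps > 0"} some index satisfies the stopping test; hence the ‹LEAST› above
is well-defined and strictly decreases with each application of @{const ℒ⇩b}, as the
termination proof below verifies.
›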
lemma Least_Suc_less:
assumes "∃n. P n" "¬P 0"
shows "Least (λn. P (Suc n)) < Least P"
using assms by (auto simp: Least_Suc)
function value_iteration :: "real ⇒ ('s ⇒⇩b real) ⇒ ('s ⇒⇩b real)" where
"value_iteration eps v =
(if 2 * l * dist v (ℒ⇩b v) < eps * (1-l) ∨ eps ≤ 0 then ℒ⇩b v else value_iteration eps (ℒ⇩b v))"
by auto
termination
proof (relation "Wellfounded.measure term_measure")
fix eps v
assume h: "¬ (2 * l * dist v (ℒ⇩b v) < eps * (1 - l) ∨ eps ≤ 0)"
show "((eps, ℒ⇩b v), eps, v) ∈ Wellfounded.measure term_measure"
proof -
have "(λn. dist ((ℒ⇩b ^^ Suc n) v) ((ℒ⇩b ^^ n) v)) ⇢ 0"
using dist_ℒ⇩b_tendsto
by (auto simp: dist_commute)
hence *: "∃n. dist ((ℒ⇩b ^^ Suc n) v) ((ℒ⇩b ^^ n) v) < eps" if "eps > 0" for eps
unfolding LIMSEQ_def using that by auto
have **: "0 < l * 2" if "0 ≠ l"
using zero_le_disc that by linarith
hence "(LEAST n. (2 * l) * dist ((ℒ⇩b ^^ (Suc (Suc n))) v) ((ℒ⇩b ^^ (Suc n)) v) < eps * (1 - l)) <
(LEAST n. (2 * l) * dist ((ℒ⇩b ^^ Suc n) v) ((ℒ⇩b ^^ n) v) < eps * (1 - l))" if "0 ≠ l"
using h *[of "eps * (1-l) / (2 * l)"] that
by (fastforce simp: ** algebra_simps dist_commute pos_less_divide_eq intro!: Least_Suc_less)
thus ?thesis
using h by (cases "l = 0") (auto simp: funpow_swap1)
qed
qed auto
text ‹
The distance between an estimate of the value and the optimal value can be bounded in terms of
the distance between the estimate and the result of applying @{const ℒ⇩b} to it.
›
lemma contraction_ℒ_dist: "(1 - l) * dist v ν⇩b_opt ≤ dist v (ℒ⇩b v)"
using contraction_dist contraction_ℒ disc_lt_one zero_le_disc by fastforce
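text ‹
Spelled out, the argument combines the triangle inequality with the fact that
@{const ν⇩b_opt} is the fixed point of @{const ℒ⇩b} (informally, writing $\nu^*$ for
@{const ν⇩b_opt}):
\[
  d(v, \nu^*) \;\le\; d(v, \mathcal{L}_b\, v) + d(\mathcal{L}_b\, v, \mathcal{L}_b\, \nu^*)
  \;\le\; d(v, \mathcal{L}_b\, v) + l\; d(v, \nu^*),
\]
and rearranging gives $(1 - l)\, d(v, \nu^*) \le d(v, \mathcal{L}_b\, v)$.
›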
lemma dist_ℒ⇩b_opt_eps:
assumes "eps > 0" "2 * l * dist v (ℒ⇩b v) < eps * (1-l)"
shows "2 * dist (ℒ⇩b v) ν⇩b_opt < eps"
proof -
have "2 * l * dist v ν⇩b_opt * (1 - l) ≤ 2 * l * dist v (ℒ⇩b v)"
using contraction_ℒ_dist by (simp add: mult_left_mono mult.commute)
hence "2 * l * dist v ν⇩b_opt * (1 - l) < eps * (1-l)"
using assms(2) by linarith
hence "2 * l * dist v ν⇩b_opt < eps"
by force
thus "2 * dist (ℒ⇩b v) ν⇩b_opt < eps"
using contraction_ℒ[of v ν⇩b_opt] by auto
qed
lemma dist_ℒ⇩b_lt_dist_opt: "dist v (ℒ⇩b v) ≤ 2 * dist v ν⇩b_opt"
proof -
have le1: "dist v (ℒ⇩b v) ≤ dist v ν⇩b_opt + dist (ℒ⇩b v) ν⇩b_opt"
by (simp add: dist_triangle dist_commute)
have le2: "dist (ℒ⇩b v) ν⇩b_opt ≤ l * dist v ν⇩b_opt"
using ℒ⇩b_opt contraction_ℒ by metis
show ?thesis
using mult_right_mono[of l 1] disc_lt_one
by (fastforce intro!: order.trans[OF le2] order.trans[OF le1])
qed
text ‹
The estimates above allow us to bound the error of @{const value_iteration}.
›
declare value_iteration.simps[simp del]
lemma value_iteration_error:
assumes "eps > 0"
shows "2 * dist (value_iteration eps v) ν⇩b_opt < eps"
using assms dist_ℒ⇩b_opt_eps value_iteration.simps
by (induction eps v rule: value_iteration.induct) auto
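text ‹
As a purely numerical illustration: with discount factor $l = 0.9$, the stopping test
$2 l \, d(v, \mathcal{L}_b\, v) < \mathit{eps} \, (1 - l)$ becomes
$1.8 \, d(v, \mathcal{L}_b\, v) < 0.1 \, \mathit{eps}$, i.e.\ iteration continues until
successive estimates are within $\mathit{eps}/18$ of each other, at which point the final
estimate is within $\mathit{eps}/2$ of @{const ν⇩b_opt}.
›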
text ‹
After the value iteration terminates, one can easily obtain a stationary deterministic
epsilon-optimal policy.
Such a policy does not exist in general; attainment of the supremum in @{const ℒ⇩b} is required.
›
definition "find_policy (v :: 's ⇒⇩b real) s = arg_max_on (λa. L⇩a a v s) (A s)"
definition "vi_policy eps v = find_policy (value_iteration eps v)"
abbreviation "vi u n ≡ (ℒ⇩b ^^ n) u"
lemma ℒ⇩b_iter_mono:
assumes "u ≤ v" shows "vi u n ≤ vi v n"
using assms ℒ⇩b_mono by (induction n) auto
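text ‹
The next two lemmas lift this monotonicity to the whole sequence of iterates: if a single
application of @{const ℒ⇩b} decreases (respectively increases) the estimate, then all later
iterates are ordered in the same way.
›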
lemma
assumes "vi v (Suc n) ≤ vi v n"
shows "vi v (Suc n + m) ≤ vi v (n + m)"
proof -
have "vi v (Suc n + m) = vi (vi v (Suc n)) m"
by (simp add: Groups.add_ac(2) funpow_add funpow_swap1)
also have "... ≤ vi (vi v n) m"
using ℒ⇩b_iter_mono[OF assms] by auto
also have "... = vi v (n + m)"
by (simp add: add.commute funpow_add)
finally show ?thesis .
qed
lemma
assumes "vi v n ≤ vi v (Suc n)"
shows "vi v (n + m) ≤ vi v (Suc n + m)"
proof -
have "vi v (n + m) ≤ vi (vi v n) m"
by (simp add: Groups.add_ac(2) funpow_add funpow_swap1)
also have "… ≤ vi v (Suc n + m)"
using ℒ⇩b_iter_mono[OF assms] by (auto simp only: add.commute funpow_add o_apply)
finally show ?thesis .
qed
lemma "(λn. dist (vi v (Suc n)) (vi v n)) ⇢ 0"
using dist_ℒ⇩b_tendsto[of v] by (auto simp: dist_commute)
end
context MDP_att_ℒ
begin
lemma is_arg_max_find_policy: "is_arg_max (λd. L⇩a d (apply_bfun v) s) (λd. d ∈ A s) (find_policy v s)"
using Sup_att
by (simp add: find_policy_def arg_max_on_def arg_max_def someI_ex max_L_ex_def has_arg_max_def)
text ‹
The error of the resulting policy is bounded by the distance from its value to the estimate
computed by the value iteration, plus the error of the value iteration itself.
We show that both contributions are bounded by @{term "eps / 2"} when the algorithm terminates.
›
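text ‹
Schematically, writing $p$ for the extracted policy and $\nu^*$ for @{const ν⇩b_opt}:
\[
  d(\nu_b(p), \nu^*) \;\le\; d(\nu_b(p), \mathcal{L}_b\, v) + d(\mathcal{L}_b\, v, \nu^*),
\]
where the first summand is at most $\mathit{eps}/2$ by the lemma below and the second is
strictly below $\mathit{eps}/2$ by ‹dist_ℒ⇩b_opt_eps›, so the total error is below
$\mathit{eps}$.
›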
lemma find_policy_dist_ℒ⇩b:
assumes "eps > 0" "2 * l * dist v (ℒ⇩b v) < eps * (1-l)"
shows "2 * dist (ν⇩b (mk_stationary_det (find_policy (ℒ⇩b v)))) (ℒ⇩b v) ≤ eps"
proof -
let ?d = "mk_dec_det (find_policy (ℒ⇩b v))"
let ?p = "mk_stationary ?d"
have L_eq_ℒ⇩b: "L (mk_dec_det (find_policy v)) v = ℒ⇩b v" for v
by (auto simp: L_eq_L⇩a_det ℒ⇩b_eq_argmax_L⇩a[OF is_arg_max_find_policy])
have "dist (ν⇩b ?p) (ℒ⇩b v) = dist (L ?d (ν⇩b ?p)) (ℒ⇩b v)"
using L_ν_fix by force
also have "… ≤ dist (L ?d (ν⇩b ?p)) (ℒ⇩b (ℒ⇩b v)) + dist (ℒ⇩b (ℒ⇩b v)) (ℒ⇩b v)"
using dist_triangle by blast
also have "… = dist (L ?d (ν⇩b ?p)) (L ?d (ℒ⇩b v)) + dist (ℒ⇩b (ℒ⇩b v)) (ℒ⇩b v)"
by (auto simp: L_eq_ℒ⇩b)
also have "… ≤ l * dist (ν⇩b ?p) (ℒ⇩b v) + l * dist (ℒ⇩b v) v"
using contraction_ℒ contraction_L by (fastforce intro!: add_mono)
finally have aux: "dist (ν⇩b ?p) (ℒ⇩b v) ≤ l * dist (ν⇩b ?p) (ℒ⇩b v) + l * dist (ℒ⇩b v) v" .
hence "dist (ν⇩b ?p) (ℒ⇩b v) * (1 - l) ≤ l * dist (ℒ⇩b v) v"
by (auto simp: algebra_simps)
hence *: "2 * dist (ν⇩b ?p) (ℒ⇩b v) * (1 - l) ≤ 2 * l * dist (ℒ⇩b v) v"
using zero_le_disc mult_left_mono by auto
hence "2 * dist (ν⇩b ?p) (ℒ⇩b v) * (1 - l) ≤ eps * (1 - l)"
using assms by (fastforce simp: dist_commute intro!: order.trans[OF *])
thus "2 * dist (ν⇩b ?p) (ℒ⇩b v) ≤ eps"
by auto
qed
lemma find_policy_error_bound:
assumes "eps > 0" "2 * l * dist v (ℒ⇩b v) < eps * (1-l)"
shows "dist (ν⇩b (mk_stationary_det (find_policy (ℒ⇩b v)))) ν⇩b_opt < eps"
proof -
let ?p = "mk_stationary_det (find_policy (ℒ⇩b v))"
have "dist (ν⇩b ?p) ν⇩b_opt ≤ dist (ν⇩b ?p) (ℒ⇩b v) + dist (ℒ⇩b v) ν⇩b_opt"
using dist_triangle by blast
thus ?thesis
using find_policy_dist_ℒ⇩b[OF assms] dist_ℒ⇩b_opt_eps[OF assms] by simp
qed
lemma vi_policy_opt:
assumes "0 < eps"
shows "dist (ν⇩b (mk_stationary_det (vi_policy eps v))) ν⇩b_opt < eps"
unfolding vi_policy_def
using assms
proof (induction eps v rule: value_iteration.induct)
case (1 eps v)
then show ?case
using find_policy_error_bound by (subst value_iteration.simps) auto
qed
lemma lemma_6_3_1_d:
assumes "eps > 0" "2 * l * dist (vi v (Suc n)) (vi v n) < eps * (1-l)"
shows "2 * dist (vi v (Suc n)) ν⇩b_opt < eps"
using dist_ℒ⇩b_opt_eps assms by (simp add: dist_commute)
end
context MDP_act_disc
begin
definition "find_policy' (v :: 's ⇒⇩b real) s = arb_act (opt_acts v s)"
definition "vi_policy' eps v = find_policy' (value_iteration eps v)"
lemma is_arg_max_find_policy': "is_arg_max (λd. L⇩a d (apply_bfun v) s) (λd. d ∈ A s) (find_policy' v s)"
using is_opt_act_some by (auto simp: find_policy'_def is_opt_act_def)
lemma find_policy'_dist_ℒ⇩b:
assumes "eps > 0" "2 * l * dist v (ℒ⇩b v) < eps * (1-l)"
shows "2 * dist (ν⇩b (mk_stationary_det (find_policy' (ℒ⇩b v)))) (ℒ⇩b v) ≤ eps"
proof -
let ?d = "mk_dec_det (find_policy' (ℒ⇩b v))"
let ?p = "mk_stationary ?d"
have L_eq_ℒ⇩b: "L (mk_dec_det (find_policy' v)) v = ℒ⇩b v" for v
by (auto simp: L_eq_L⇩a_det ℒ⇩b_eq_argmax_L⇩a[OF is_arg_max_find_policy'])
have "dist (ν⇩b ?p) (ℒ⇩b v) = dist (L ?d (ν⇩b ?p)) (ℒ⇩b v)"
using L_ν_fix by force
also have "… ≤ dist (L ?d (ν⇩b ?p)) (ℒ⇩b (ℒ⇩b v)) + dist (ℒ⇩b (ℒ⇩b v)) (ℒ⇩b v)"
using dist_triangle by blast
also have "… = dist (L ?d (ν⇩b ?p)) (L ?d (ℒ⇩b v)) + dist (ℒ⇩b (ℒ⇩b v)) (ℒ⇩b v)"
by (auto simp: L_eq_ℒ⇩b)
also have "… ≤ l * dist (ν⇩b ?p) (ℒ⇩b v) + l * dist (ℒ⇩b v) v"
using contraction_ℒ contraction_L by (fastforce intro!: add_mono)
finally have aux: "dist (ν⇩b ?p) (ℒ⇩b v) ≤ l * dist (ν⇩b ?p) (ℒ⇩b v) + l * dist (ℒ⇩b v) v" .
hence "dist (ν⇩b ?p) (ℒ⇩b v) * (1 - l) ≤ l * dist (ℒ⇩b v) v"
by (auto simp: algebra_simps)
hence *: "2 * dist (ν⇩b ?p) (ℒ⇩b v) * (1 - l) ≤ 2 * l * dist (ℒ⇩b v) v"
using zero_le_disc mult_left_mono by auto
hence "2 * dist (ν⇩b ?p) (ℒ⇩b v) * (1 - l) ≤ eps * (1 - l)"
using assms by (fastforce simp: dist_commute intro!: order.trans[OF *])
thus "2 * dist (ν⇩b ?p) (ℒ⇩b v) ≤ eps"
by auto
qed
lemma find_policy'_error_bound:
assumes "eps > 0" "2 * l * dist v (ℒ⇩b v) < eps * (1-l)"
shows "dist (ν⇩b (mk_stationary_det (find_policy' (ℒ⇩b v)))) ν⇩b_opt < eps"
proof -
let ?p = "mk_stationary_det (find_policy' (ℒ⇩b v))"
have "dist (ν⇩b ?p) ν⇩b_opt ≤ dist (ν⇩b ?p) (ℒ⇩b v) + dist (ℒ⇩b v) ν⇩b_opt"
using dist_triangle by blast
thus ?thesis
using find_policy'_dist_ℒ⇩b[OF assms] dist_ℒ⇩b_opt_eps[OF assms] by simp
qed
lemma vi_policy'_opt:
assumes "eps > 0" "l > 0"
shows "dist (ν⇩b (mk_stationary_det (vi_policy' eps v))) ν⇩b_opt < eps"
unfolding vi_policy'_def
using assms
proof (induction eps v rule: value_iteration.induct)
case (1 eps v)
then show ?case
using find_policy'_error_bound by (auto simp: value_iteration.simps[of _ v])
qed
end
end