(* Theory VI_Code *)
theory VI_Code
imports
Code_Setup
"../Value_Iteration"
"HOL-Library.Code_Target_Numeral"
begin
context MDP_Code begin

(* Executable value iteration on value maps.
   Each step computes v' = ℒ_code v.  If check_dist v v' eps holds for the two
   consecutive iterates, v' is returned; otherwise iteration continues from v'.
   Introduced with partial_function (tailrec), so the definition itself needs no
   termination proof; VI_code_aux_correct_aux below relates its result to
   MDP.value_iteration under the hypotheses eps > 0 and l ≠ 0. *)
partial_function (tailrec) VI_code_aux where
"VI_code_aux v eps = (
let v' = ℒ_code v in
if check_dist v v' eps
then v'
else VI_code_aux v' eps)"

(* Register the unfolding equation of VI_code_aux as its code equation. *)
lemmas VI_code_aux.simps[code]

(* Guarded entry point for executable value iteration.
   If l = 0 or eps ≤ 0, a single ℒ_code step is returned instead of iterating.
   NOTE(review): presumably the guard exists because check_dist_correct phrases the
   stopping test as dist < eps * (1 - l) / (2 * l), which divides by l and cannot
   succeed for eps ≤ 0.  VI_code_correct below only covers eps > 0. *)
definition "VI_code v eps = (if l = 0 ∨ eps ≤ 0 then ℒ_code v else VI_code_aux v eps)"

(* Main invariant of VI_code_aux.  For eps > 0, l ≠ 0 and a value map satisfying
   v_invar with v_len v = states, the result agrees with MDP.value_iteration on
   V_Map.map_to_bfun v.  The result again has length states and satisfies v_invar.
   The proof is by induction following MDP.value_iteration.induct. *)
lemma VI_code_aux_correct_aux:
assumes "eps > 0" "v_invar v" "v_len v = states" "l ≠ 0"
shows "V_Map.map_to_fun (VI_code_aux v eps) = MDP.value_iteration eps (V_Map.map_to_bfun v)
∧ v_len (VI_code_aux v eps) = states
∧ v_invar (VI_code_aux v eps)"
using assms
proof (induction eps "V_Map.map_to_bfun v" arbitrary: v rule: MDP.value_iteration.induct)
case (1 eps)
(* The executable stopping test on v and ℒ_code v is equivalent to the abstract
   inequality 2 * l * dist v (ℒ⇩b v) < eps * (1 - l), where v is read as a bounded
   function.  It follows from check_dist_correct and ℒ_code_correct'; the fact
   0 < l is needed to clear the division by 2 * l. *)
have *: "(check_dist v (ℒ_code v) eps) ⟷ 2 * l * dist (V_Map.map_to_bfun v) (MDP.ℒ⇩b (V_Map.map_to_bfun v)) < eps * (1 - l)"
proof (subst check_dist_correct)
have " 0 < l" using 1 MDP.zero_le_disc by linarith
thus "(dist (V_Map.map_to_bfun v) (V_Map.map_to_bfun (ℒ_code v)) < eps * (1 - l) / (2 * l)) =
(2 * l * dist (V_Map.map_to_bfun v) (MDP.ℒ⇩b (V_Map.map_to_bfun v)) < eps * (1 - l))"
by (subst pos_less_divide_eq) (fastforce simp: ℒ_code_correct' 1 algebra_simps)+
qed (auto simp: 1 intro: invar_ℒ_code)
(* If the test fails, VI_code_aux continues from ℒ_code v.  The induction
   hypothesis then identifies the result with value iteration started there. *)
hence **: "V_Map.map_to_fun (VI_code_aux v eps) = (MDP.value_iteration eps (V_Map.map_to_bfun (ℒ_code v)))" if "¬ (check_dist v (ℒ_code v) eps)"
using invar_ℒ_code 1 that by (auto simp: VI_code_aux.simps ℒ_code_correct')
(* Functional part: case split on whether the stopping test succeeds. *)
have "V_Map.map_to_fun (VI_code_aux v eps) = (MDP.value_iteration eps (V_Map.map_to_bfun v))"
proof (cases "(check_dist v (ℒ_code v) eps)")
case True
thus ?thesis
using 1 invar_ℒ_code
by (auto simp: MDP.value_iteration.simps VI_code_aux.simps[of v] * map_to_bfun_eq_fun[symmetric] ℒ_code_correct')
next
case False
thus ?thesis
using 1 ℒ_code_correct' ** * MDP.value_iteration.simps by auto
qed
(* Combine with invar_ℒ_code to obtain the length and invariant parts. *)
thus ?case
using 1 VI_code_aux.simps ℒ_code_correct' * invar_ℒ_code by auto
qed

(* Functional part of VI_code_aux_correct_aux. *)
lemma VI_code_aux_correct:
assumes "eps > 0" "v_invar v" "v_len v = states" "l ≠ 0"
shows "V_Map.map_to_fun (VI_code_aux v eps) = MDP.value_iteration eps (V_Map.map_to_bfun v)"
using assms VI_code_aux_correct_aux by auto

(* VI_code_aux preserves the length of the value map. *)
lemma VI_code_aux_keys:
assumes "eps > 0" "v_invar v" "v_len v = states" "l ≠ 0"
shows "v_len (VI_code_aux v eps) = states"
using assms VI_code_aux_correct_aux by auto

(* VI_code_aux preserves v_invar. *)
lemma VI_code_aux_invar:
assumes "eps > 0" "v_invar v" "v_len v = states" "l ≠ 0"
shows "v_invar (VI_code_aux v eps)"
using assms VI_code_aux_correct_aux by auto

(* Correctness of the guarded entry point for eps > 0.  In the l = 0 case VI_code
   performs a single ℒ_code step, which matches one unfolding of
   MDP.value_iteration.  Otherwise the result follows from VI_code_aux_correct. *)
lemma VI_code_correct:
assumes "eps > 0" "v_invar v" "v_len v = states"
shows "V_Map.map_to_fun (VI_code v eps) = MDP.value_iteration eps (V_Map.map_to_bfun v)"
proof (cases "l = 0")
case True
then show ?thesis
using assms invar_ℒ_code ℒ_code_correct'
unfolding VI_code_def MDP.value_iteration.simps[of _ "V_Map.map_to_bfun v"]
by (fastforce simp: map_to_bfun_eq_fun)
next
case False
then show ?thesis
using assms
by (auto simp add: VI_code_def VI_code_aux_correct)
qed

(* Executable VI policy: a decision rule is extracted from the result of VI_code
   with vi_find_policy_code. *)
definition "VI_policy_code v eps = vi_find_policy_code (VI_code v eps)"

(* VI_policy_code refines the abstract policy MDP.vi_policy'. *)
lemma VI_policy_code_correct:
assumes "eps > 0" "v_invar v" "v_len v = states"
shows "D_Map.map_to_fun (VI_policy_code v eps) = MDP.vi_policy' eps (V_Map.map_to_bfun v)"
proof -
have "V_Map.map_to_bfun (VI_code v eps) = (MDP.value_iteration eps (V_Map.map_to_bfun v))"
using assms VI_code_correct
by (auto simp: VI_code_aux_invar map_to_bfun_eq_fun)
moreover have "D_Map.map_to_fun (VI_policy_code v eps) = MDP.find_policy' (V_Map.map_to_bfun (VI_code v eps))"
unfolding VI_code_def VI_policy_code_def
using assms invar_ℒ_code keys_ℒ_code vi_find_policy_correct VI_code_aux_correct_aux by (cases "l = 0") auto
ultimately show ?thesis
unfolding MDP.vi_policy'_def
by presburger
qed
end
context MDP_nat_disc
begin

(* A-posteriori bound: the distance from v to the optimal value ν⇩b_opt is at most
   dist v (ℒ⇩b v) / (1 - l).  It is derived from contraction_ℒ_dist. *)
lemma dist_opt_bound_ℒ⇩b: "dist v ν⇩b_opt ≤ dist v (ℒ⇩b v) / (1 - l)"
using contraction_ℒ_dist
by (simp add: mult.commute mult_imp_le_div_pos)

(* Certificate form of dist_opt_bound_ℒ⇩b: if the computable bound is at most ε,
   then v is within ε of ν⇩b_opt.
   NOTE(review): the hypothesis ε ≥ 0 looks implied by the second one whenever
   1 - l > 0.  It is kept so that existing instantiations via OF remain valid. *)
lemma cert_ℒ⇩b:
assumes "ε ≥ 0" "dist v (ℒ⇩b v) / (1 - l) ≤ ε"
shows "dist v ν⇩b_opt ≤ ε"
using assms dist_opt_bound_ℒ⇩b order_trans by auto

(* Boolean form of the test in the second hypothesis of cert_ℒ⇩b. *)
definition "check_value_ℒ⇩b eps v ⟷ dist v (ℒ⇩b v) / (1 - l) ≤ eps"

(* One Bellman step v' = ℒ⇩b v followed by policy extraction with find_policy'.
   Returns the pair (err, find_policy' v') with err = 2 * l * dist v v' / (1 - l).
   vi_policy_bound_error_correct below shows that err bounds the distance between
   the value of the resulting stationary deterministic policy and ν⇩b_opt. *)
definition "vi_policy_bound_error v = (
let v' = (ℒ⇩b v); err = (2 * l) * dist v v' / (1 - l) in
(err, find_policy' v'))"

(* Soundness of vi_policy_bound_error: the value of the stationary deterministic
   policy built from the returned decision rule is within err of ν⇩b_opt. *)
lemma vi_policy_bound_error_correct:
assumes "vi_policy_bound_error v = (err, d)"
shows "dist (ν⇩b (mk_stationary_det d)) ν⇩b_opt ≤ err"
proof (cases "l = 0")
case True
(* With l = 0, ℒ⇩b v coincides with ℒ⇩b ν⇩b_opt = ν⇩b_opt, so the extracted policy
   is optimal.  The reported error evaluates to 0 in this case. *)
have "ℒ⇩b v = ℒ⇩b ν⇩b_opt"
by (auto simp: ℒ⇩b.rep_eq L_def simp del: ℒ⇩b_opt intro!: bfun_eqI simp: ℒ_def) (simp add: True)
hence "ℒ⇩b v = ν⇩b_opt"
by auto
hence "ν⇩b (mk_stationary_det (find_policy' (ℒ⇩b v))) = ν⇩b_opt"
using L_ν_fix ν_improving_opt_acts conserving_imp_opt
unfolding find_policy'_def ν_conserving_def
by auto
then show ?thesis
using assms unfolding vi_policy_bound_error_def
by (auto simp: True)
next
case False
then show ?thesis
proof (cases "ℒ⇩b v = v")
case True
(* v is a fixed point of ℒ⇩b: the greedy policy for ℒ⇩b v is optimal, so the
   distance is 0. *)
hence "ν⇩b (mk_stationary_det (find_policy' (ℒ⇩b v))) = ν⇩b_opt"
using L_ν_fix ν_improving_opt_acts conserving_imp_opt
unfolding find_policy'_def ν_conserving_def
by auto
then show ?thesis
using assms unfolding vi_policy_bound_error_def
by (auto simp: True)
next
case False
(* Otherwise err > 0.  The strict bound holds for every err' > err, by
   find_policy'_error_bound, and the non-strict bound follows. *)
hence 1: "dist v (ℒ⇩b v) > 0"
by fastforce
hence "2 * l * dist v (ℒ⇩b v) > 0"
using ‹l ≠ 0› zero_le_disc by (simp add: less_le)
hence "err > 0"
using assms unfolding vi_policy_bound_error_def by auto
hence "dist (ν⇩b (mk_stationary_det (find_policy' (ℒ⇩b v)))) ν⇩b_opt < err'" if "err < err'" for err'
using that assms
unfolding vi_policy_bound_error_def
by (auto simp: pos_divide_less_eq[symmetric] intro: find_policy'_error_bound)
then show ?thesis
using assms unfolding vi_policy_bound_error_def Let_def
by force
qed
qed
end
context MDP_Code
begin

(* Executable counterpart of MDP.vi_policy_bound_error on value maps.
   v' = ℒ_code v.  d is the maximum, over the states {..< states}, of the pointwise
   distance between v and v', and 0 when there are no states.
   err = 2 * l * d / (1 - l).  The result is (err, vi_find_policy_code v'). *)
definition "vi_policy_bound_error_code v = (
let v' = (ℒ_code v);
d = if states = 0 then 0 else (MAX s ∈ {..< states}. dist (v_lookup v s) (v_lookup v' s));
err = (2 * l) * d / (1 - l) in
(err, vi_find_policy_code v'))"

(* The policy component of vi_policy_bound_error_code agrees with
   MDP.vi_policy_bound_error. *)
lemma vi_policy_bound_error_code_snd_correct:
assumes "v_len v = states" "v_invar v"
shows "D_Map.map_to_fun (snd (vi_policy_bound_error_code v)) = snd (MDP.vi_policy_bound_error (V_Map.map_to_bfun v))"
using assms ℒ_code_correct' invar_ℒ_code vi_find_policy_correct
by (auto simp: vi_policy_bound_error_code_def MDP.vi_policy_bound_error_def)

(* Congruence for MAX over a set: pointwise equality on X yields equal maxima. *)
lemma MAX_cong:
assumes "⋀x. x ∈ X ⟹ f x = g x"
shows "(MAX x ∈ X. f x) = (MAX x ∈ X. g x)"
using assms by auto

(* The error component of vi_policy_bound_error_code agrees with
   MDP.vi_policy_bound_error.  The abstract error uses the sup-distance of
   bounded functions (dist_bfun_def).  The pointwise distance between v and
   ℒ_code v is 0 at indices ≥ states (dist_zero_ge), so that supremum reduces to
   the finite maximum over {..<states} that the code version computes. *)
lemma vi_policy_bound_error_code_fst_correct:
assumes "v_len v = states" "v_invar v"
shows "(fst (vi_policy_bound_error_code v)) = fst (MDP.vi_policy_bound_error (V_Map.map_to_bfun v))"
proof-
(* The pointwise distance vanishes outside the state range. *)
have dist_zero_ge: "dist (apply_bfun (V_Map.map_to_bfun v) x) (apply_bfun (V_Map.map_to_bfun (ℒ_code v)) x) = 0" if "x ≥ states" for x
using assms that
by (auto simp: V_Map.map_to_bfun.rep_eq)
have univ: "UNIV = {0..<states} ∪ {states..}" by auto
let ?d = "λx. dist (apply_bfun (V_Map.map_to_bfun v) x) (apply_bfun (V_Map.map_to_bfun (ℒ_code v)) x)"
(* Only finitely many distinct distances occur, so the supremum is a maximum. *)
have fin: "finite (range (λx. ?d x))"
by (auto simp: dist_zero_ge univ Set.image_Un Set.image_constant[of states])
have r: "range (λx. ?d x) = ?d ` {..<states} ∪ ?d ` {states..}"
by force
hence "Sup (range ?d) = Max (range ?d)"
using fin cSup_eq_Max by blast
also have "… = (if states = 0 then (Max (?d ` {states..})) else max (Max (?d ` {..<states})) (Max (?d ` {states..})))"
using r fin by (auto intro: Max_Un)
also have "… = (if states = 0 then 0 else max (Max (?d ` {..<states})) 0)"
using dist_zero_ge
by (auto simp: Set.image_constant[of states] cSup_eq_Max[symmetric, of "(λ_. 0) ` {states..}"])
also have "… = (if states = 0 then 0 else (Max (?d ` {..<states})))"
by (auto intro!: max_absorb1 max_geI)
finally have 1: "Sup (range ?d) = (if states = 0 then 0 else (Max (?d ` {..<states})))".
thus ?thesis
unfolding MDP.vi_policy_bound_error_def vi_policy_bound_error_code_def dist_bfun_def
using assms v_lookup_map_to_bfun ℒ_code_correct' ℒ_code_correct
by fastforce
qed
end