(* Theory Modified_Policy_Iteration *)

(* Author: Maximilian Schäffeler *)

theory Modified_Policy_Iteration
  imports 
    Policy_Iteration
    Value_Iteration
begin

section ‹Modified Policy Iteration›

(* Locale for Modified Policy Iteration: the conjunction of the locales
   MDP_att_ℒ and MDP_act_disc over the same parameters A, K, r, l
   (MDP_act_disc additionally takes arb_act). *)
locale MDP_MPI = MDP_att_ℒ A K r l + MDP_act_disc arb_act A K r l 
  for A and K :: "'s :: countable × 'a :: countable ⇒ 's pmf" and r l arb_act
begin

subsection ‹The Advantage Function @{term B}›

(* B v s: supremum over all decision rules d in D⇩R of
   r_dec d s + ((l *⇩R 𝒫⇩1 d - id_blinfun) v) s. *)
definition "B v s = (⨆d ∈ D⇩R. (r_dec d s + (l *⇩R 𝒫⇩1 d - id_blinfun) v s))"

text "The function @{const B} denotes the advantage of choosing the optimal action vs.
  the current value estimate"

(* Adding a real constant commutes with a conditionally complete supremum
   over a non-empty index set whose image is bounded above.
   Proved by antisymmetry; each direction is a cSUP_least argument. *)
lemma cSUP_plus:
  assumes "X ≠ {}" "bdd_above (f`X)"
  shows "(⨆x ∈ X. f x + c) = (⨆x ∈ X. f x) + (c::real)"
proof (rule antisym)
  show "(⨆x∈X. f x + c) ≤ ⨆ (f ` X) + c"
    using assms by (fastforce intro: cSUP_least cSUP_upper)
  show "⨆ (f ` X) + c ≤ (⨆x∈X. f x + c)"
    unfolding le_diff_eq[symmetric]
    using assms
    by (intro cSUP_least) (auto simp add: algebra_simps bdd_above_def intro!: cSUP_upper2 intro: add_left_mono)+
qed

(* Subtraction variant of cSUP_plus, obtained by instantiating c with - c. *)
lemma cSUP_minus:
  assumes "X ≠ {}" "bdd_above (f`X)"
  shows "(⨆x ∈ X. f x - c) = (⨆x ∈ X. f x) - (c::real)"
  using cSUP_plus[OF assms, of "- c"] by auto

(* Pointwise, B v is ℒ v minus v.
   Step *: rewrite B as a supremum of L d v s - v s;
   then pull the constant v s out of the supremum with cSUP_minus. *)
lemma B_eq_ℒ: "B v s = ℒ v s - v s"
proof -
  have *: "B v s = (⨆d ∈ D⇩R. L d v s - v s)"
    unfolding B_def L_def by (auto simp add: blinfun.bilinear_simps add_diff_eq)
  have "bdd_above ((λd. L d v s - v s) ` D⇩R)"
    by (auto intro!: bounded_const bounded_minus_comp bounded_imp_bdd_above)
  thus ?thesis
    unfolding * ℒ_def using ex_dec by (fastforce intro!: cSUP_minus)
qed

text ‹@{const B} is a bounded function.›

(* Lifts B to an operator on bounded functions ('s ⇒⇩b real).
   The lifting obligation is discharged via B_eq_ℒ and ℒ_bfun. *)
lift_definition B⇩b :: "('s ⇒⇩b real) ⇒ 's ⇒⇩b real" is "B"
  unfolding B_eq_ℒ using ℒ_bfun by (auto intro: Bounded_Functions.minus_cont)

(* Bounded-function version of B_eq_ℒ: B⇩b v = ℒ⇩b v - v. *)
lemma B⇩b_eq_ℒ⇩b: "B⇩b v = ℒ⇩b v - v"
  by (auto simp: ℒ⇩b.rep_eq B⇩b.rep_eq B_eq_ℒ)

(* ℒ⇩b v at state s written as a supremum of L⇩a a v s over the actions a in A s. *)
lemma ℒ⇩b_eq_SUP_L⇩a': "ℒ⇩b v s = (⨆a ∈ A s. L⇩a a v s)"
  using L_eq_L⇩a_det ℒ⇩b_eq_SUP_det SUP_step_det_eq
  by auto

subsection ‹Optimization of the Value Function over Multiple Steps›

(* U m v s: supremum over decision rules d in D⇩R of
   (ν⇩b_fin (mk_stationary d) m + ((l *⇩R 𝒫⇩1 d)^^m) v) evaluated at s. *)
definition "U m v s = (⨆d ∈ D⇩R. (ν⇩b_fin (mk_stationary d) m + ((l *⇩R 𝒫⇩1 d)^^m) v) s)"

text ‹@{const U} expresses the value estimate obtained by optimizing the first @{term m} steps and 
  afterwards using the current estimate.›

(* With zero steps, U returns its argument unchanged (registered as a simp rule). *)
lemma U_zero [simp]: "U 0 v = v"
  unfolding U_def ℒ_def by (auto simp: ν⇩b_fin.rep_eq)

(* With a single step, U coincides pointwise with ℒ. *)
lemma U_one_eq_ℒ: "U 1 v s = ℒ v s"
  unfolding U_def ℒ_def by (auto simp: ν⇩b_fin_eq_𝒫⇩X L_def blinfun.bilinear_simps)

(* Lifts U to bounded functions. The obligation "U n v ∈ bfun" follows from a
   norm bound * that is uniform in the decision rule d:
   the finite-horizon value is bounded by (∑i<m. l ^ i * r⇩M) and the
   m-fold discounted transition applied to v by l ^ m * norm v. *)
lift_definition U⇩b :: "nat ⇒ ('s ⇒⇩b real) ⇒ ('s ⇒⇩b real)" is U
proof -
  fix n v
  have "norm (ν⇩b_fin (mk_stationary d) m) ≤ (∑i<m. l ^ i * r⇩M)" for d m
    using abs_ν_fin_le ν⇩b_fin.rep_eq by (auto intro!: norm_bound)
  moreover have "norm (((l *⇩R 𝒫⇩1 d)^^m) v) ≤ l ^ m * norm v" for d m
    by (auto simp: 𝒫⇩X_const[symmetric] blinfun.bilinear_simps blincomp_scaleR_right 
        intro!: boundedI order.trans[OF abs_le_norm_bfun] mult_left_mono)
  ultimately have *: "norm (ν⇩b_fin (mk_stationary d) m + ((l *⇩R 𝒫⇩1 d)^^m) v) ≤ (∑i<m. l ^ i * r⇩M) +  l ^ m * norm v" for d m
    using norm_triangle_mono by blast
  show "U n v ∈ bfun"
    using ex_dec order.trans[OF abs_le_norm_bfun *]
    by (fastforce simp: U_def intro!: bfun_normI cSup_abs_le)
qed

(* U⇩b m is Lipschitz with constant l ^ m w.r.t. dist.
   aux handles the pointwise case U⇩b m u s ≤ U⇩b m v s via a calculation chain
   that bounds the difference of suprema by a supremum of differences;
   the symmetric case follows by dist_commute, and dist_bound lifts the
   pointwise bound to bounded functions. *)
lemma U⇩b_contraction: "dist (U⇩b m v) (U⇩b m u) ≤ l ^ m * dist v u"
proof -
  have aux: "dist (U⇩b m v s) (U⇩b m u s) ≤ l ^ m * dist v u" if le: "U⇩b m u s ≤ U⇩b m v s" for s v u
  proof -
    let ?U = "λm v d. (ν⇩b_fin (mk_stationary d) m + ((l *⇩R 𝒫⇩1 d) ^^ m) v) s"
    have "U⇩b m v s - U⇩b m u s ≤ (⨆d ∈ D⇩R. ?U m v d - ?U m u d)"
      using bounded_stationary_ν⇩b_fin bounded_disc_𝒫⇩1 le
      unfolding U⇩b.rep_eq U_def
      by (intro le_SUP_diff') (auto intro: bounded_plus_comp)
    also have "… = (⨆d ∈ D⇩R. ((l *⇩R 𝒫⇩1 d) ^^ m) (v - u) s)"
      by (simp add: L_def scale_right_diff_distrib blinfun.bilinear_simps)
    also have "… = (⨆d ∈ D⇩R. l^m * ((𝒫⇩1 d ^^ m) (v - u) s))"
      by (simp add: blincomp_scaleR_right blinfun.scaleR_left)
    also have "… = l^m * (⨆d ∈ D⇩R. ((𝒫⇩1 d ^^ m) (v - u) s))" 
      using D⇩R_ne bounded_P bounded_disc_𝒫⇩1' by (auto intro: bounded_SUP_mul)
    also have "… ≤ l^m * norm (⨆d ∈ D⇩R. ((𝒫⇩1 d ^^ m) (v - u) s))"
      by (simp add: mult_left_mono)
    also have "… ≤ l^m * (⨆d ∈ D⇩R. norm (((𝒫⇩1 d ^^ m) (v - u) s)))"
      using D⇩R_ne ex_dec bounded_norm_comp bounded_disc_𝒫⇩1'
      by (fastforce intro!: mult_left_mono)
    also have "… ≤ l^m * (⨆d ∈ D⇩R. norm ((𝒫⇩1 d ^^ m) ((v - u))))"
      using ex_dec
      by (fastforce intro!: order.trans[OF norm_blinfun] abs_le_norm_bfun mult_left_mono cSUP_mono)
    also have "… ≤ l^m * (⨆d ∈ D⇩R. norm ((v - u)))"
      using norm_𝒫⇩X_apply by (auto simp: 𝒫⇩X_const[symmetric] cSUP_least mult_left_mono)
    also have "… = l ^m * dist v u"
      by (auto simp: dist_norm)
    finally have "U⇩b m v s - U⇩b m u s ≤ l^m * dist v u" .
    thus ?thesis
      by (simp add: dist_real_def le)
  qed
  moreover have "U⇩b m v s ≤ U⇩b m u s ⟹ dist (U⇩b m v s) (U⇩b m u s) ≤ l^m * dist v u" for u v s
    by (simp add: aux dist_commute)
  ultimately have "dist (U⇩b m v s) (U⇩b m u s) ≤ l^m * dist v u" for u v s
    using linear by blast
  thus "dist (U⇩b m v) (U⇩b m u) ≤ l^m * dist v u"
    by (simp add: dist_bound)
qed

(* For at least one step (Suc m), U⇩b has a unique fixed point and its iterates
   converge to it. Both facts come from banach' once U⇩b (Suc m) is shown to be
   a contraction with constant l^(Suc m); the case l = 0 is treated separately. *)
lemma U⇩b_conv:
  "∃!v. U⇩b (Suc m) v = v" 
  "(λn. (U⇩b (Suc m) ^^ n) v) \<longlonglongrightarrow> (THE v. U⇩b (Suc m) v = v)"
proof -
  have *: "is_contraction (U⇩b (Suc m))"
    unfolding is_contraction_def
    using U⇩b_contraction[of "Suc m"] le_neq_trans[OF zero_le_disc] 
    by (cases "l = 0") (auto intro!: power_Suc_less_one intro: exI[of _ "l^(Suc m)"])
  show "∃!v. U⇩b (Suc m) v = v" "(λn. (U⇩b (Suc m) ^^ n) v) \<longlonglongrightarrow> (THE v. U⇩b (Suc m) v = v)"
    using banach'[OF *] by auto
qed

(* The iterates of U⇩b (Suc m) form a convergent sequence (from U⇩b_conv(2)). *)
lemma U⇩b_convergent: "convergent (λn. (U⇩b (Suc m) ^^ n) v)"
  by (intro convergentI[OF U⇩b_conv(2)])

(* U⇩b m is monotone in its value argument.
   Shown pointwise with cSUP_mono: goal 2 is boundedness of the image,
   goal 3 uses monotonicity of 𝒫⇩X; the remaining goal is closed by auto. *)
lemma U⇩b_mono:
  assumes "v ≤ u" 
  shows "U⇩b m v ≤ U⇩b m u"
proof  -
  have "U⇩b m v s ≤ U⇩b m u s" for s
    unfolding U⇩b.rep_eq U_def
  proof (intro cSUP_mono, goal_cases)
    case 2
    thus ?case
      by (simp add: bounded_imp_bdd_above bounded_disc_𝒫⇩1 bounded_plus_comp bounded_stationary_ν⇩b_fin)
  next
    case (3 n)
    thus ?case 
      using less_eq_bfunD[OF 𝒫⇩X_mono[OF assms]]
      by (auto simp: 𝒫⇩X_const[symmetric] blincomp_scaleR_right blinfun.scaleR_left intro!: mult_left_mono exI)
  qed auto
  thus ?thesis
    using assms by auto
qed

(* U⇩b m is bounded above by the m-fold iterate of ℒ⇩b.
   First rewrite U⇩b m v s as a supremum of (L d^^m) v s, then bound each
   member with L_iter_le_ℒ⇩b. *)
lemma U⇩b_le_ℒ⇩b: "U⇩b m v ≤ (ℒ⇩b ^^ m) v"
proof -
  have "U⇩b m v s = (⨆d ∈ D⇩R. (L d^^ m) v s)" for m v s
    by (auto simp: L_iter U⇩b.rep_eq ℒ⇩b.rep_eq U_def ℒ_def)
  thus ?thesis
    using L_iter_le_ℒ⇩b ex_dec by (fastforce intro!: cSUP_least)
qed

(* For any decision rule d in D⇩R, the m-fold iterate of L d lies below U⇩b m
   (each member of a supremum is below the supremum). *)
lemma L_iter_le_U⇩b: 
  assumes "d ∈ D⇩R" 
  shows "(L d^^m) v ≤ U⇩b m v"
  using assms
  by (fastforce intro!: cSUP_upper bounded_imp_bdd_above
      simp: L_iter U⇩b.rep_eq U_def bounded_disc_𝒫⇩1 bounded_plus_comp bounded_stationary_ν⇩b_fin)

(* The limit of the iterates of U⇩b (Suc m) is ν⇩b_opt.
   ν⇩b_opt is shown to be a fixed point of U⇩b (Suc m) by antisymmetry
   (le_U for one direction, U⇩b_le_ℒ⇩b with ℒ_inc_le_opt for the other);
   the limit is also a fixed point, and fixed points are unique by U⇩b_conv(1). *)
lemma lim_U⇩b: "lim (λn. (U⇩b (Suc m) ^^ n) v) = ν⇩b_opt"
proof -
  have le_U: "ν⇩b_opt ≤ U⇩b m ν⇩b_opt" for m
  proof -
    obtain d where d: "ν_improving ν⇩b_opt (mk_dec_det d)" "d ∈ D⇩D"
      using ex_improving_det by auto
    have "ν⇩b_opt = (L (mk_dec_det d) ^^ m) ν⇩b_opt"
      by (induction m) (metis L_ν_fix_iff ℒ⇩b_opt ν_improving_imp_ℒ⇩b d(1) funpow_swap1)+
    thus ?thesis
      using ‹d ∈ D⇩D› by (auto intro!: order.trans[OF _ L_iter_le_U⇩b])
  qed
  have "U⇩b m ν⇩b_opt ≤ ν⇩b_opt" for m
    using ℒ_inc_le_opt  by (auto intro!: order.trans[OF U⇩b_le_ℒ⇩b] simp: funpow_swap1)
  hence "U⇩b (Suc m) ν⇩b_opt = ν⇩b_opt"
    using le_U by (simp add: antisym)
  moreover have "(lim (λn. (U⇩b (Suc m) ^^n) v)) = U⇩b (Suc m) (lim (λn. (U⇩b (Suc m) ^^n) v))"
    using limI[OF U⇩b_conv(2)] theI'[OF U⇩b_conv(1)] by auto 
  ultimately show ?thesis
    using U⇩b_conv(1) by metis
qed

(* The iterates of U⇩b (Suc m) tend to ν⇩b_opt (combines lim_U⇩b and U⇩b_convergent). *)
lemma U⇩b_tendsto: "(λn. (U⇩b (Suc m) ^^ n) v) \<longlonglongrightarrow> ν⇩b_opt"
  using lim_U⇩b U⇩b_convergent convergent_LIMSEQ_iff by metis

(* v is a fixed point of U⇩b (Suc m) exactly when v = ν⇩b_opt.
   Stated as an equivalence: dist_U⇩b_opt relies on ν⇩b_opt being a fixed point. *)
lemma U⇩b_fix_unique: "U⇩b (Suc m) v = v ⟷ v = ν⇩b_opt" 
  using theI'[OF U⇩b_conv(1)] U⇩b_conv(1)
  by (auto simp: LIMSEQ_unique[OF U⇩b_tendsto U⇩b_conv(2)[of m]])

(* Applying U⇩b m shrinks the distance to ν⇩b_opt by the factor l^m.
   First replace ν⇩b_opt by U⇩b m ν⇩b_opt (fixed point; m = 0 via U_zero),
   then apply U⇩b_contraction. *)
lemma dist_U⇩b_opt: "dist (U⇩b m v) ν⇩b_opt ≤ l^m * dist v ν⇩b_opt"
proof -
  have "dist (U⇩b m v) ν⇩b_opt = dist (U⇩b m v) (U⇩b m ν⇩b_opt)"
    by (metis U⇩b.abs_eq U⇩b_fix_unique U_zero apply_bfun_inverse not0_implies_Suc)
  also have "… ≤ l^m * dist v ν⇩b_opt"
    by (meson U⇩b_contraction)
  finally show ?thesis .
qed

subsection ‹Expressing a Single Step of Modified Policy Iteration›
text ‹The function @{term W} equals the value computed by the Modified Policy Iteration Algorithm
  in a single iteration.
  The right hand addend in the definition describes the advantage of using the optimal action for 
  the first m steps.
  ›
(* W d m v: v plus the sum over i < m of the i-fold powers of (l *⇩R 𝒫⇩1 d),
   applied to B⇩b v. *)
definition "W d m v = v + (∑i < m. (l *⇩R 𝒫⇩1 d)^^i) (B⇩b v)"

(* If d is ν_improving for v, then W d m v equals the m-fold iterate of L d on v.
   The calculation replaces ℒ⇩b v by L d v, unfolds L into reward plus discounted
   transition, shifts the sum index, telescopes the two sums, and finishes with L_iter. *)
lemma W_eq_L_iter:
  assumes "ν_improving v d"
  shows "W d m v = (L d^^m) v"
proof -
  have "(∑i<m. (l *⇩R 𝒫⇩1 d)^^i) (ℒ⇩b v) = (∑i<m. (l *⇩R 𝒫⇩1 d)^^i) (L d v)"
    using ν_improving_imp_ℒ⇩b assms by auto
  hence "W d m v = v + ((∑i<m. (l *⇩R 𝒫⇩1 d)^^i) (L d v)) - (∑i<m. (l *⇩R 𝒫⇩1 d)^^i) v"
    by (auto simp: W_def B⇩b_eq_ℒ⇩b blinfun.bilinear_simps)
  also have "… = v + ν⇩b_fin (mk_stationary d) m + (∑i<m. ((l *⇩R 𝒫⇩1 d)^^i) ((l *⇩R 𝒫⇩1 d) v)) - (∑i<m. (l *⇩R 𝒫⇩1 d)^^i) v"
    by (auto simp: L_def ν⇩b_fin_eq blinfun.bilinear_simps scaleR_right.sum)
  also have "… = v + ν⇩b_fin (mk_stationary d) m + (∑i<m. ((l *⇩R 𝒫⇩1 d)^^Suc i) v) - (∑i<m. (l *⇩R 𝒫⇩1 d)^^i) v"
    by (auto simp del: blinfunpow.simps simp: blinfunpow_assoc)
  also have "… = ν⇩b_fin (mk_stationary d) m + (∑i<Suc m. ((l *⇩R 𝒫⇩1 d)^^ i) v)  - (∑i<m. (l *⇩R 𝒫⇩1 d)^^ i) v"
    by (subst sum.lessThan_Suc_shift) auto
  also have "… =  ν⇩b_fin (mk_stationary d) m + ((l *⇩R 𝒫⇩1 d)^^m) v"
    by (simp add: blinfun.sum_left)
  also have "… = (L d ^^ m) v"
    using L_iter by auto
  finally show ?thesis .
qed

(* U⇩b m u dominates the member of its defining supremum for any d in D⇩R. *)
lemma U⇩b_ge: "d ∈ D⇩R ⟹ U⇩b m u ≥ ν⇩b_fin (mk_stationary d) m + ((l *⇩R 𝒫⇩1 d) ^^ m) u"
  using ν_improving_D_MR bounded_stationary_ν⇩b_fin bounded_disc_𝒫⇩1
  by (fastforce intro!: diff_mono bounded_imp_bdd_above cSUP_upper bounded_plus_comp simp: U⇩b.rep_eq U_def)

(* If v ≤ u and d is ν_improving for v, then W d m v ≤ U⇩b m u
   (W_eq_L_iter, then L_iter_le_U⇩b followed by U⇩b_mono). *)
lemma W_le_U⇩b:
  assumes "v ≤ u" "ν_improving v d"
  shows "W d m v ≤ U⇩b m u"
  using assms
  by (fastforce simp: W_eq_L_iter intro!: order.trans[OF L_iter_le_U⇩b U⇩b_mono])

(* Lower bound for W: if v ≤ u, 0 ≤ B⇩b u and d' is ν_improving for u,
   then ℒ⇩b v ≤ W d' (Suc m) u. The middle term u + B⇩b u equals ℒ⇩b u;
   the second step is an induction on m. *)
lemma W_ge_ℒ⇩b:
  assumes "v ≤ u" "0 ≤ B⇩b u" "ν_improving u d'"
  shows "ℒ⇩b v ≤ W d' (Suc m) u"
proof -
  have "ℒ⇩b v ≤ u + B⇩b u"
    using assms(1) ℒ⇩b_mono B⇩b_eq_ℒ⇩b by auto
  also have "… ≤ W d' (Suc m) u"
    using L_mono ν_improving_imp_ℒ⇩b assms(3) assms 
    by (induction m) (auto simp: W_eq_L_iter B⇩b_eq_ℒ⇩b)
  finally show ?thesis .
qed

(* If d is ν_improving for v, then
   B⇩b v + (l *⇩R 𝒫⇩1 d - id_blinfun) (u - v) ≤ B⇩b u.
   Uses L d u ≤ ℒ⇩b u for the bound on B⇩b u and L d v = ℒ⇩b v for the
   explicit form of B⇩b v. *)
lemma B⇩b_le:
  assumes "ν_improving v d"
  shows "B⇩b v + (l *⇩R 𝒫⇩1 d - id_blinfun) (u - v) ≤ B⇩b u"
proof -
  have "r_dec⇩b d + (l *⇩R 𝒫⇩1 d - id_blinfun) u ≤ B⇩b u"
    using L_def L_le_ℒ⇩b assms by (auto simp: B⇩b_eq_ℒ⇩b ℒ⇩b.rep_eq ℒ_def blinfun.bilinear_simps)
  moreover have "B⇩b v = r_dec⇩b d + (l *⇩R 𝒫⇩1 d - id_blinfun) v"
    using assms by (auto simp: B⇩b_eq_ℒ⇩b ν_improving_imp_ℒ⇩b[of _ d] L_def blinfun.bilinear_simps)
  ultimately show ?thesis
    by (simp add: blinfun.diff_right)
qed


subsection ‹Computing the Bellman Operator over Multiple Steps›
(* L_pow v d m: apply L for the deterministic decision rule mk_dec_det d
   to v, m times. *)
definition "L_pow v d m = (L (mk_dec_det d) ^^ m) v"

(*
lemma sum_telescope': "(∑i≤k. f (Suc i) - f i ) = f (Suc k) - (f 0 :: 'c :: ab_group_add)"
  using sum_telescope[of "-f" k]
  by auto
*)

(* eq 6.5.7
   Closed form of L_pow for a ν_improving deterministic decision rule:
   it is the right-hand side of W_def, obtained through W_eq_L_iter. *)
lemma L_pow_eq:
  fixes d defines "d' ≡ mk_dec_det d"
  assumes "ν_improving v d'" 
  shows "L_pow v d m = v + (∑i < m. ((l *⇩R 𝒫⇩1 d')^^i)) (B⇩b v)"
  using L_pow_def W_def W_eq_L_iter assms by presburger

(* For d in D⇩D, L_pow along policy_improvement d v agrees with W for the
   corresponding deterministic decision rule. *)
lemma L_pow_eq_W:
  assumes "d ∈ D⇩D" 
  shows "L_pow v (policy_improvement d v) m = W (mk_dec_det (policy_improvement d v)) m v" 
  using assms policy_improvement_improving by (auto simp: W_eq_L_iter L_pow_def)


(* find_policy' v satisfies is_dec_det; follows from find_policy'_def,
   is_dec_det_def and some_opt_acts_in_A. *)
lemma find_policy'_is_dec_det: "is_dec_det (find_policy' v)"
  using find_policy'_def is_dec_det_def some_opt_acts_in_A by presburger

(* The deterministic decision rule built from find_policy' v is ν_improving for v. *)
lemma find_policy'_improving: "ν_improving v (mk_dec_det (find_policy' v))"
  using ν_improving_opt_acts find_policy'_def by presburger

(* Variant of L_pow_eq_W for find_policy': L_pow along find_policy' v equals W
   for the corresponding deterministic decision rule. *)
lemma L_pow_eq_W': "L_pow v (find_policy' v) m = W (mk_dec_det (find_policy' v)) m v" 
  using find_policy'_improving by (auto simp: W_eq_L_iter L_pow_def)

(* The property "u ≤ ℒ⇩b u" is preserved by W when d is ν_improving for u.
   The calculation starts from 0 ≤ ((l *⇩R 𝒫⇩1 d) ^^ m) (B⇩b u), rewrites the
   right-hand side as B⇩b u + (l *⇩R 𝒫⇩1 d - id_blinfun) (W d m u - u) and
   bounds it by B⇩b (W d m u) using B⇩b_le. *)
lemma ℒ⇩b_W_ge:
  assumes "u ≤ ℒ⇩b u" "ν_improving u d"
  shows "W d m u ≤ ℒ⇩b (W d m u)"    
proof -
  have "0 ≤ ((l *⇩R 𝒫⇩1 d) ^^ m) (B⇩b u)"
    by (metis B⇩b_eq_ℒ⇩b 𝒫⇩1_n_disc_pos assms(1) blincomp_scaleR_right diff_ge_0_iff_ge)
  also have "… = ((l *⇩R 𝒫⇩1 d)^^0 + (∑i < m. (l *⇩R 𝒫⇩1 d)^^(Suc i))) (B⇩b u) - (∑i < m. (l *⇩R 𝒫⇩1 d)^^ i) (B⇩b u)"
    by (subst sum.lessThan_Suc_shift[symmetric]) (auto simp: blinfun.diff_left[symmetric])
  also have "… = B⇩b u + ((l *⇩R 𝒫⇩1 d - id_blinfun) o⇩L (∑i < m. (l *⇩R 𝒫⇩1 d)^^i)) (B⇩b u)" 
    by (auto simp: blinfun.bilinear_simps sum_subtractf)
  also have "… = B⇩b u + (l *⇩R 𝒫⇩1 d - id_blinfun) (W d m u - u)"
    by (auto simp: W_def sum.lessThan_Suc[unfolded lessThan_Suc_atMost])
  also have "… ≤ B⇩b (W d m u)"
    using B⇩b_le assms(2) by blast
  finally have "0 ≤ B⇩b (W d m u)" .
  thus ?thesis 
    using B⇩b_eq_ℒ⇩b by auto
qed

(* The invariant "v ≤ ℒ⇩b v" is preserved by L_pow along policy_improvement d v. *)
lemma L_pow_ℒ⇩b_mono_inv:
  assumes "d ∈ D⇩D" "v ≤ ℒ⇩b v"
  shows "L_pow v (policy_improvement d v) m ≤ ℒ⇩b (L_pow v (policy_improvement d v) m)"
  using assms L_pow_eq_W ℒ⇩b_W_ge policy_improvement_improving by auto

(* The invariant "v ≤ ℒ⇩b v" is preserved by L_pow along find_policy' v. *)
lemma L_pow_ℒ⇩b_mono_inv':
  assumes "v ≤ ℒ⇩b v"
  shows "L_pow v (find_policy' v) m ≤ ℒ⇩b (L_pow v (find_policy' v) m)"
  using assms L_pow_eq_W' ℒ⇩b_W_ge find_policy'_improving by auto

subsection ‹The Modified Policy Iteration Algorithm›
(* Parameters of the algorithm:
   d0 - a decision rule, assumed to lie in D⇩D;
   v0 - the initial value estimate (a bounded function);
   m  - maps the iteration index and the current value to a natural number
        (used below as "Suc (m n v)", the number of L_pow steps). *)
context
  fixes d0 :: "'s ⇒ 'a"
  fixes v0 :: "'s ⇒⇩b real"
  fixes m :: "nat ⇒ ('s ⇒⇩b real) ⇒ nat"
  assumes d0: "d0 ∈ D⇩D"
begin

text ‹We first define a function that executes the algorithm for n steps.›
(* mpi n returns the pair (decision rule, value) after n iterations.
   Start: (find_policy' v0, v0).
   Step:  from (d, v) compute v' = L_pow v d (Suc (m n v)) and return
          (find_policy' v', v'). *)
fun mpi :: "nat ⇒ (('s ⇒ 'a) × ('s ⇒⇩b real))" where
  "mpi 0 = (find_policy' v0, v0)" |
  "mpi (Suc n) =
  (let (d, v) = mpi n; v' = L_pow v d (Suc (m n v)) in
  (find_policy' v', v'))"

(* Value component (second projection) of mpi n. *)
definition "mpi_val n = snd (mpi n)"
(* Decision-rule component (first projection) of mpi n. *)
definition "mpi_pol n = fst (mpi n)"

(* The initial decision rule is find_policy' v0 (simp rule). *)
lemma mpi_pol_zero[simp]: "mpi_pol 0 = find_policy' v0"
  unfolding mpi_pol_def 
  by auto

(* After a step, the decision rule is find_policy' applied to the new value. *)
lemma mpi_pol_Suc: "mpi_pol (Suc n) = find_policy' (mpi_val (Suc n))"
  by (auto simp: case_prod_beta' Let_def mpi_pol_def mpi_val_def)

(* Every decision rule produced by mpi lies in D⇩D (induction on n). *)
lemma mpi_pol_is_dec_det: "mpi_pol n ∈ D⇩D"
  unfolding mpi_pol_def
  using find_policy'_is_dec_det d0
  by (induction n)  (auto simp: Let_def split: prod.splits)

(* At every iteration n, the current decision rule is ν_improving for the
   current value (case distinction on n). *)
lemma ν_improving_mpi_pol: "ν_improving (mpi_val n) (mk_dec_det (mpi_pol n))"
  using d0 find_policy'_improving mpi_pol_is_dec_det mpi_pol_Suc
  by (cases n) (auto simp: mpi_pol_def mpi_val_def)

(* The initial value is v0 (simp rule). *)
lemma mpi_val_zero[simp]: "mpi_val 0 = v0"
  unfolding mpi_val_def by auto

(* Recursive equation for the value: one mpi step applies L_pow with the
   current decision rule for Suc (m n (mpi_val n)) steps. *)
lemma mpi_val_Suc: "mpi_val (Suc n) = L_pow (mpi_val n) (mpi_pol n) (Suc (m n (mpi_val n)))"
  unfolding mpi_val_def mpi_pol_def
  by (auto simp: case_prod_beta' Let_def)

(* Explicit form of one mpi step: the old value plus a sum (indices up to and
   including m n (mpi_val n)) of powers of the discounted transition operator,
   applied to B⇩b of the old value. *)
lemma mpi_val_eq: "mpi_val (Suc n) = 
  mpi_val n + (∑i ≤ (m n (mpi_val n)). (l *⇩R 𝒫⇩1 (mk_dec_det (mpi_pol n))) ^^ i) (B⇩b (mpi_val n))"
  using lessThan_Suc_atMost by (auto simp: mpi_val_Suc L_pow_eq[OF ν_improving_mpi_pol])

text ‹Value Iteration is a special case of MPI where @{term "∀n v. m n v = 0"}.›
(* When m is constantly 0, one mpi step is exactly one application of ℒ⇩b. *)
lemma mpi_includes_value_it: 
  assumes "∀n v. m n v = 0"
  shows "mpi_val (Suc n) = ℒ⇩b (mpi_val n)"
  using assms B⇩b_eq_ℒ⇩b mpi_val_eq by auto

subsection ‹Convergence Proof›
text ‹We define the sequence @{term w} as an upper bound for the values of MPI.›
(* w starts at v0 and applies U⇩b with the same step count
   Suc (m n (mpi_val n)) that mpi uses in iteration n. *)
fun w where
  "w 0 = v0" |
  "w (Suc n) = U⇩b (Suc (m n (mpi_val n))) (w n)"

(* One step of w shrinks the distance to ν⇩b_opt by at least the factor l
   (from dist_U⇩b_opt, since the step count is at least 1). *)
lemma dist_ν⇩b_opt: "dist (w (Suc n)) ν⇩b_opt ≤ l * dist (w n) ν⇩b_opt"
  by (fastforce simp: algebra_simps intro: order.trans[OF dist_U⇩b_opt] mult_left_mono power_le_one
      mult_left_le_one_le order.strict_implies_order)

(* After n steps the distance of w n to ν⇩b_opt is at most l^n times the
   initial distance (induction on n using dist_ν⇩b_opt). *)
lemma dist_ν⇩b_opt_n: "dist (w n) ν⇩b_opt ≤ l^n * dist v0 ν⇩b_opt"
  by (induction n) (fastforce simp: algebra_simps intro: order.trans[OF dist_ν⇩b_opt] mult_left_mono)+

(* w converges to ν⇩b_opt: the bound l^n * dist v0 ν⇩b_opt from dist_ν⇩b_opt_n
   tends to 0. *)
lemma w_conv: "w \<longlonglongrightarrow> ν⇩b_opt"
proof -
  have "(λn. l^n * dist v0 ν⇩b_opt) \<longlonglongrightarrow> 0"
    using LIMSEQ_realpow_zero by (cases "v0 = ν⇩b_opt") auto
  then show ?thesis
    by (fastforce intro: metric_LIMSEQ_I order.strict_trans1[OF dist_ν⇩b_opt_n] simp: LIMSEQ_def)
qed

text ‹MPI converges monotonically to the optimal value from below. 
  The iterates are sandwiched between the iterates of @{const ℒ⇩b} from below and those of @{const U⇩b} from above.›
(* Main convergence theorem: if v0 ≤ ℒ⇩b v0, the values of mpi converge to
   ν⇩b_opt and are monotonically increasing.
   aux proves four invariants simultaneously by induction on n:
     mpi_val n ≤ ℒ⇩b (mpi_val n), mpi_val n ≤ mpi_val (Suc n),
     y n ≤ mpi_val n (y = iterates of ℒ⇩b on v0), and mpi_val n ≤ w n.
   Convergence then follows by sandwiching between y and w. *)
theorem mpi_conv:
  assumes "v0 ≤ ℒ⇩b v0"
  shows "mpi_val \<longlonglongrightarrow> ν⇩b_opt" and "⋀n. mpi_val n ≤ mpi_val (Suc n)"
proof -
  define y where "y n = (ℒ⇩b^^n) v0" for n
  have aux: "mpi_val n ≤ ℒ⇩b (mpi_val n) ∧ mpi_val n ≤ mpi_val (Suc n) ∧ y n ≤ mpi_val n ∧ mpi_val n ≤ w n" for n
  proof (induction n)
    case 0
    show ?case
      using assms B⇩b_eq_ℒ⇩b
      unfolding y_def
      by (auto simp: mpi_val_eq blinfun.sum_left 𝒫⇩1_n_disc_pos blincomp_scaleR_right sum_nonneg)
  next
    case (Suc n)
    have val_eq_W: "mpi_val (Suc n) = W (mk_dec_det (mpi_pol n)) (Suc (m n (mpi_val n))) (mpi_val n)"
      using ν_improving_mpi_pol mpi_val_Suc W_eq_L_iter L_pow_def by auto
    hence *: "mpi_val (Suc n) ≤ ℒ⇩b (mpi_val (Suc n))"
      using Suc.IH ℒ⇩b_W_ge ν_improving_mpi_pol by presburger
    moreover have "mpi_val (Suc n) ≤ mpi_val (Suc (Suc n))"
      using *
      by (simp add: B⇩b_eq_ℒ⇩b mpi_val_eq 𝒫⇩1_n_disc_pos blincomp_scaleR_right blinfun.sum_left sum_nonneg)
    moreover have "mpi_val (Suc n) ≤ w (Suc n)"
      using Suc.IH ν_improving_mpi_pol by (auto simp: val_eq_W intro: order.trans[OF _ W_le_U⇩b])
    moreover have "y (Suc n) ≤ mpi_val (Suc n)"
      using Suc.IH ν_improving_mpi_pol W_ge_ℒ⇩b by (auto simp: y_def B⇩b_eq_ℒ⇩b val_eq_W)
    ultimately show ?case
      by auto
  qed
  thus "mpi_val n ≤ mpi_val (Suc n)" for n
    by auto
  have "y \<longlonglongrightarrow> ν⇩b_opt"
    using ℒ⇩b_lim y_def by presburger
  thus "mpi_val \<longlonglongrightarrow> ν⇩b_opt"
    using aux by (auto intro: tendsto_bfun_sandwich[OF _ w_conv])
qed

subsection ‹$\epsilon$-Optimality›
text ‹This gives an upper bound on the error of MPI.›
(* If the stopping criterion
     2 * l * dist (mpi_val n) (ℒ⇩b (mpi_val n)) < eps * (1 - l)
   holds, the value of the stationary policy built from mpi_pol n is within
   eps / 2 of ℒ⇩b (mpi_val n).
   The chain uses that ν⇩b ?p is a fixed point of L ?d, the triangle inequality
   and contraction_L, giving fact *; the final division by l requires the
   case split on l = 0. *)
lemma mpi_pol_eps_opt:
  assumes "2 * l * dist (mpi_val n) (ℒ⇩b (mpi_val n)) < eps * (1 - l)" "eps > 0"
  shows "dist (ν⇩b (mk_stationary_det (mpi_pol n))) (ℒ⇩b (mpi_val n)) ≤ eps / 2"
proof -
  let ?p = "mk_stationary_det (mpi_pol n)"
  let ?d = "mk_dec_det (mpi_pol n)"
  let ?v = "mpi_val n"
  have "dist (ν⇩b ?p) (ℒ⇩b ?v) = dist (L ?d (ν⇩b ?p)) (ℒ⇩b ?v)"
    using L_ν_fix by force
  also have "… = dist (L ?d (ν⇩b ?p)) (L ?d ?v)"
    by (metis ν_improving_imp_ℒ⇩b ν_improving_mpi_pol)
  also have "… ≤ dist (L ?d (ν⇩b ?p)) (L ?d (ℒ⇩b ?v)) + dist (L ?d (ℒ⇩b ?v)) (L ?d ?v)"
    using dist_triangle by blast
  also have "… ≤ l * dist (ν⇩b ?p) (ℒ⇩b ?v) + dist (L ?d (ℒ⇩b ?v)) (L ?d ?v)"
    using contraction_L by auto
  also have "… ≤ l * dist (ν⇩b ?p) (ℒ⇩b ?v) + l * dist (ℒ⇩b ?v) ?v"
    using contraction_L by auto
  finally have "dist (ν⇩b ?p) (ℒ⇩b ?v) ≤ l * dist (ν⇩b ?p) (ℒ⇩b ?v) + l * dist (ℒ⇩b ?v) ?v".
  hence *:"(1-l) * dist (ν⇩b ?p) (ℒ⇩b ?v) ≤ l * dist (ℒ⇩b ?v) ?v"
    by (auto simp: left_diff_distrib)
  thus ?thesis
  proof (cases "l = 0")
    case True
    thus ?thesis
      using assms * by auto
  next
    case False
    have **: "dist (ℒ⇩b ?v) (mpi_val n) < eps * (1 - l) / (2 * l)"
      using False le_neq_trans[OF zero_le_disc False[symmetric]] assms
      by (auto simp: dist_commute pos_less_divide_eq Groups.mult_ac(2))
    have "dist (ν⇩b ?p) (ℒ⇩b ?v) ≤ (l/ (1-l)) * dist (ℒ⇩b ?v) ?v"
      using * by (auto simp: mult.commute pos_le_divide_eq)
    also have "… ≤ (l/ (1-l)) * (eps * (1 - l) / (2 * l))"
      using ** by (fastforce intro!: mult_left_mono simp: divide_nonneg_pos)
    also have "… = eps / 2"
      using False disc_lt_one by (auto simp: order.strict_iff_order)
    finally show "dist (ν⇩b ?p) (ℒ⇩b ?v) ≤ eps / 2".    
  qed
qed

(* Under the same stopping criterion, the stationary policy built from
   mpi_pol n has value strictly within eps of ν⇩b_opt.
   Triangle inequality via ℒ⇩b (mpi_val n): the first leg is bounded by
   mpi_pol_eps_opt, the second by dist_ℒ⇩b_opt_eps. *)
lemma mpi_pol_opt:
  assumes "2 * l * dist (mpi_val n) (ℒ⇩b (mpi_val n)) < eps * (1 - l)" "eps > 0"
  shows "dist (ν⇩b (mk_stationary_det (mpi_pol n))) (ν⇩b_opt) < eps"
proof -
  have "dist (ν⇩b (mk_stationary_det (mpi_pol n))) (ν⇩b_opt) ≤ eps/2 + dist (ℒ⇩b (mpi_val n)) ν⇩b_opt"
    by (metis mpi_pol_eps_opt[OF assms] dist_commute dist_triangle_le add_right_mono)
  thus ?thesis
    using dist_ℒ⇩b_opt_eps assms by fastforce
qed

(* If v0 ≤ ℒ⇩b v0 and eps > 0, some iteration n satisfies the stopping
   criterion. Since mpi_val converges to ν⇩b_opt (mpi_conv), the distance
   between mpi_val n and ℒ⇩b (mpi_val n) tends to 0 by a sandwich argument;
   the case l = 0 is handled separately in the last step. *)
lemma mpi_val_term_ex:
  assumes "v0 ≤ ℒ⇩b v0" "eps > 0"
  shows "∃n. 2 * l * dist (mpi_val n) (ℒ⇩b (mpi_val n)) < eps * (1 - l)"
proof -
  have "(λn. dist (mpi_val n) ν⇩b_opt) \<longlonglongrightarrow> 0"
    using mpi_conv(1)[OF assms(1)] tendsto_dist_iff 
    by blast
  hence "(λn. dist (mpi_val n) (ℒ⇩b (mpi_val n))) \<longlonglongrightarrow> 0"
    using dist_ℒ⇩b_lt_dist_opt
    by (auto simp: metric_LIMSEQ_I intro: tendsto_sandwich[of "λ_. 0" _ _ "λn. 2 * dist (mpi_val n) ν⇩b_opt"])
  hence "∀e >0. ∃n. dist (mpi_val n) (ℒ⇩b (mpi_val n)) < e"
    by (fastforce dest!: metric_LIMSEQ_D)
  hence "l ≠ 0 ⟹ ∃n. dist (mpi_val n) (ℒ⇩b (mpi_val n)) < eps * (1 - l) / (2 * l)"
    by (simp add: assms order.not_eq_order_implies_strict)
  thus "∃n. (2 * l) * dist (mpi_val n) (ℒ⇩b (mpi_val n)) < eps * (1 - l)"
    using assms le_neq_trans[OF zero_le_disc]
    by (cases "l = 0") (auto simp: mult.commute pos_less_divide_eq)
qed
end

subsection ‹Unbounded MPI›
(* Fixes the accuracy parameter eps used in the stopping test of mpi_algo.
   NOTE(review): δ and M are fixed here but do not appear to be referenced
   in the part of this context that is visible — confirm whether they are needed. *)
context
  fixes eps δ :: real and M :: nat
begin

(* Modified Policy Iteration as a partial recursive function with a stopping test.
   If 2 * l * dist v (ℒ⇩b v) < eps * (1 - l), return (find_policy' v, v);
   otherwise recurse with the policy find_policy' v, the value
   L_pow v (find_policy' v) (Suc (m 0 v)), and the step-count function m
   shifted by one index.
   Declared with (domintros); termination is established separately in
   termination_mpi_algo. The argument d is not inspected by the body. *)
function (domintros) mpi_algo where "mpi_algo d v m = (
  if 2 * l * dist v (ℒ⇩b v) <  eps * (1 - l)
  then (find_policy' v, v)
  else mpi_algo (find_policy' v) (L_pow v (find_policy' v) (Suc (m 0 v))) (λn. m (Suc n)))"
  by auto

text ‹We define a tail-recursive version of @{const mpi} which more closely resembles @{const mpi_algo}.›
(* Variant of mpi that recurses on an explicit step counter:
   with counter 0 it returns (find_policy' v, v); with Suc n it computes
   d' = find_policy' v and v' = L_pow v d' (Suc (m 0 v)) and continues with
   n remaining steps and m shifted by one index.
   The argument d is not inspected by either equation. *)
fun mpi' where
  "mpi' d v 0 m = (find_policy' v, v)" |
  "mpi' d v (Suc n) m = (
  let d' = find_policy' v; v' = L_pow v d' (Suc (m 0 v)) in mpi' d' v' n (λn. m (Suc n)))"

(* Unfolds mpi at the front: Suc n iterations from v equal n iterations
   started from the value after the first step, with m shifted by one index. *)
lemma mpi_Suc':
  assumes "d ∈ D⇩D"
  shows "mpi v m (Suc n) = mpi (L_pow v (find_policy' v) (Suc (m 0 v))) (λa. m (Suc a)) n"
  using assms
  by (induction n rule: nat.induct) (auto simp: Let_def)

(* mpi and mpi' compute the same result (unnamed lemma; induction on n,
   the step case uses mpi_Suc'). *)
lemma
  assumes "d ∈ D⇩D"
  shows "mpi v m n = mpi' d v n m"
  using assms
proof (induction n arbitrary: d v m rule: nat.induct)
  case (Suc nat)
  thus ?case
    using find_policy'_is_dec_det by (fastforce simp: Let_def mpi_Suc'[OF Suc(2)])
qed auto

(* mpi_algo terminates on (d, v, m) whenever eps > 0, d ∈ D⇩D and v ≤ ℒ⇩b v.
   Induction on n, the least iteration index at which the stopping criterion
   holds (it exists by mpi_val_term_ex):
   - n = 0: the criterion already holds for v, so the domain rule applies directly;
   - Suc n: the criterion fails for v, and the recursive call has least index n
     (n_eq), so the induction hypothesis applies. *)
lemma termination_mpi_algo: 
  assumes "eps > 0" "d ∈ D⇩D" "v ≤ ℒ⇩b v"
  shows "mpi_algo_dom (d, v, m)"
proof -
  define n where "n = (LEAST n. 2 * l * dist (mpi_val v m n) (ℒ⇩b (mpi_val v m n)) < eps * (1 - l))" (is "n = (LEAST n. ?P d v m n)")
  have least0: "∃n. P n ⟹ (LEAST n. P n) = (0 :: nat) ⟹ P 0"  for P
    by (metis LeastI_ex)
  from n_def assms show ?thesis
  proof (induction n arbitrary: v d m)
    case 0
    have "2 * l * dist (mpi_val v m 0) (ℒ⇩b (mpi_val v m 0)) < eps * (1 - l)"
      using least0 mpi_val_term_ex 0 by (metis (no_types, lifting))
    thus ?case
      using 0 mpi_algo.domintros mpi_val_zero by (metis (no_types, opaque_lifting))
  next
    case (Suc n v d m)
    let ?d = "find_policy' v"
    have "Suc n = Suc (LEAST n. 2 * l * dist (mpi_val v m (Suc n)) (ℒ⇩b (mpi_val v m (Suc n))) < eps * (1 - l))"
      using mpi_val_term_ex[OF Suc.prems(3) ‹v ≤ ℒ⇩b v› ‹0 < eps›, of m] Suc.prems 
      by (subst Nat.Least_Suc[symmetric]) (auto intro: LeastI_ex)
    hence "n = (LEAST n. 2 * l * dist (mpi_val v m (Suc n)) (ℒ⇩b (mpi_val v m (Suc n))) < eps * (1 - l))"
      by auto
    hence n_eq: "n =
    (LEAST n. 2 * l * dist (mpi_val (L_pow v ?d (Suc (m 0 v))) (λa. m (Suc a)) n) (ℒ⇩b (mpi_val (L_pow v ?d (Suc (m 0 v))) (λa. m (Suc a)) n))
        < eps * (1 - l))"
      using Suc.prems mpi_Suc' by (auto simp: is_dec_det_pi mpi_val_def)
    have "¬ 2 * l * dist v (ℒ⇩b v) < eps * (1 - l)"
      using Suc mpi_val_zero by force
    moreover have "mpi_algo_dom (?d, L_pow v ?d (Suc (m 0 v)), λa. m (Suc a))"
      apply (rule Suc.IH[OF n_eq ‹0 < eps›])
          using Suc.prems is_dec_det_pi L_pow_ℒ⇩b_mono_inv' find_policy'_is_dec_det by auto
    ultimately show ?case
      using mpi_algo.domintros by blast
  qed
qed

(* Abbreviation for the right-hand side of the defining equation of mpi_algo. *)
abbreviation "mpi_alg_rec d v m ≡ 
    (if 2 * l * dist v (ℒ⇩b v) < eps * (1 - l) then (find_policy' v, v)
     else mpi_algo (find_policy' v) (L_pow v (find_policy' v) (Suc (m 0 v)))
           (λn. m (Suc n)))"

(* Unconditional recursion equation for mpi_algo under the termination
   assumptions (from mpi_algo.psimps and termination_mpi_algo). *)
lemma mpi_algo_def':
  assumes "d ∈ D⇩D" "v ≤ ℒ⇩b v" "eps > 0"
  shows "mpi_algo d v m = mpi_alg_rec d v m"
  using mpi_algo.psimps termination_mpi_algo assms by auto


(* Alternative recursion equation that computes v' = ℒ⇩b v once and reuses it:
   by aux, L_pow (ℒ⇩b v) (find_policy' v) n equals
   L_pow v (find_policy' v) (Suc n), because find_policy' v is ν_improving for v. *)
lemma mpi_algo_def'':
  assumes "d ∈ D⇩D" "v ≤ ℒ⇩b v" "eps > 0"
  shows "mpi_algo d v m = (
  let v' = ℒ⇩b v; d' = find_policy' v in
    if 2 * l * dist v v'  < eps * (1 - l) 
    then (d', v)
    else mpi_algo d' (L_pow v' d' ((m 0 v))) (λn. m (Suc n)))"
proof -
  have "ν_improving v (mk_dec_det (find_policy' v))"
    using ν_improving_opt_acts find_policy'_def by presburger
  hence aux: "L_pow (ℒ⇩b v) (find_policy' v) n = L_pow v (find_policy' v) (Suc n)" for n
    using ‹d ∈ D⇩D›  ν_improving_imp_ℒ⇩b
    by (auto simp: funpow_swap1 L_pow_def)
  show ?thesis
    unfolding mpi_algo_def'[OF assms] Let_def aux[symmetric] by auto
qed

(* mpi_algo returns the same pair as mpi run for the least number of
   iterations at which the stopping criterion holds.
   Induction on that least index n, mirroring termination_mpi_algo:
   in the Suc case the criterion fails at v (not0) and the least index of the
   recursive call is n. *)
lemma mpi_algo_eq_mpi:
  assumes "d ∈ D⇩D" "v ≤ ℒ⇩b v" "eps > 0"
  shows "mpi_algo d v m = mpi v m (LEAST n. 2 * l * dist (mpi_val v m n) (ℒ⇩b (mpi_val v m n)) < eps * (1 - l))"
proof -
  define n where "n = (LEAST n. 2 * l * dist (mpi_val v m n) (ℒ⇩b (mpi_val v m n)) < eps * (1 - l))" (is "n = (LEAST n. ?P d v m n)")
  from n_def assms show ?thesis
  proof (induction n arbitrary: d v m)
    case 0
    have "?P d v m 0"
      by (metis (no_types, lifting) assms(3) LeastI_ex 0 mpi_val_term_ex)
    thus ?case
      using assms 0 by (auto simp: mpi_val_def mpi_algo_def')
  next
    case (Suc n)
    hence not0: "¬ (2 * l * dist v (ℒ⇩b v) < eps * (1 - l))"
      using Suc(3) mpi_val_zero by auto
    obtain n' where "2 * l * dist (mpi_val v m n') (ℒ⇩b (mpi_val v m n')) < eps * (1 - l)"
      using mpi_val_term_ex[OF Suc(3) Suc(4), of _ m] assms by blast
    hence "n = (LEAST n. ?P d v m (Suc n))"
      using Suc(2) Suc by (subst (asm) Least_Suc) auto
    hence "n = (LEAST n. ?P (find_policy' v) (L_pow v (find_policy' v) (Suc (m 0 v))) (λn. m (Suc n)) n)"
      using Suc(3)  mpi_Suc' by (auto simp: mpi_val_def)
    hence "mpi_algo d v m = mpi v m (Suc n)"
      unfolding mpi_algo_def'[OF Suc.prems(2-4)]
      using Suc(1) Suc.prems(2-4) is_dec_det_pi mpi_Suc' not0 L_pow_ℒ⇩b_mono_inv' find_policy'_is_dec_det
      by fastforce
    thus ?case
      using Suc.prems(1) by presburger
  qed
qed

(* Correctness of mpi_algo: if v0 ≤ ℒ⇩b v0, eps > 0 and d ∈ D⇩D, the stationary
   policy built from the returned decision rule has value strictly within eps
   of ν⇩b_opt. Combines mpi_algo_eq_mpi with mpi_pol_opt at the least index
   satisfying the stopping criterion. *)
lemma mpi_algo_opt: 
  assumes "v0 ≤ ℒ⇩b v0" "eps > 0" "d ∈ D⇩D"
  shows "dist (ν⇩b (mk_stationary_det (fst (mpi_algo d v0 m)))) ν⇩b_opt < eps"
proof -
  let ?P = "λn. 2 * l * dist (mpi_val v0 m n) (ℒ⇩b (mpi_val v0 m n)) <  eps * (1 - l)"
  let ?n = "Least ?P"
  have "mpi_algo d v0 m = mpi v0 m ?n" and "?P ?n"
    using mpi_algo_eq_mpi LeastI_ex[OF mpi_val_term_ex] assms by auto
  thus ?thesis
    using assms by (auto simp: mpi_pol_opt mpi_pol_def[symmetric])
qed

end

subsection ‹Initial Value Estimate @{term v0_mpi}›
text ‹We define an initial estimate of the value function for which Modified Policy Iteration 
  always terminates.›

(* r_min: infimum over all states s' of the infimum of r (s', a) over the
   actions a in A s'. *)
abbreviation "r_min ≡ (⨅s'. (⨅a ∈ A s'. r (s', a)))"
(* Constant initial estimate: r_min / (1 - l) at every state s. *)
definition "v0_mpi s = r_min / (1 - l)"

(* v0_mpi as a bounded function (it is constant in s, see v0_mpi_def). *)
lift_definition v0_mpi⇩b :: "'s ⇒⇩b real" is "v0_mpi"
  by (auto simp: v0_mpi_def)

(* The initial estimate satisfies v0_mpi⇩b ≤ ℒ⇩b v0_mpi⇩b, the precondition
   of the convergence and termination results above.
   Pointwise at x: r_min is a lower bound of every reward r (s, a) with
   a ∈ A s, hence r_min / (1 - l) ≤ r (s, a) + l * r_min / (1 - l); this gives
   the bound for the action arb_act (A x), and cSUP_upper2 lifts it to the
   supremum over A x that defines ℒ⇩b (ℒ⇩b_eq_SUP_L⇩a'). *)
lemma v0_mpi⇩b_le_ℒ⇩b: "v0_mpi⇩b ≤ ℒ⇩b v0_mpi⇩b"
proof (rule less_eq_bfunI)
  fix x
  have bounded_r': "bounded ((λa. r (x, a)) ` A x)" for x
    using r_bounded'
    unfolding bounded_def
    by simp (meson UNIV_I)
  have *: "(⨅a∈A x. r (x, a)) ≤ r (x,a)" if "a ∈ A x" for  a x
    using  bounded_r' that 
    by (auto intro!: cInf_lower bounded_imp_bdd_below)
  have ****: "r (s,a) ≤ r⇩M" for s a
    using abs_le_iff abs_r_le_r⇩M by blast
  have **: "bounded (range (λs'. ⨅a∈A s'. r (s', a)))"
    using abs_r_le_r⇩M  ex_dec_det is_dec_det_def A_ne
    by (auto simp add: minus_le_iff abs_le_iff intro!: cINF_greatest order.trans[OF *]  boundedI[of _ "r⇩M"])
  have "r_min ≤ r (s, a)" if "a ∈ A s" for s a
    using  r_bounded' that **
    by (auto intro!: bounded_imp_bdd_below cInf_lower2[OF _ *])
  hence "r_min ≤ (1-l) * r (s, a) + l * r_min" if "a ∈ A s" for s a
    using disc_lt_one zero_le_disc that by (meson order_less_imp_le order_refl segment_bound_lemma)
  hence "r_min / (1 - l) ≤ ((1-l) * r (s, a) + l * r_min) / (1 - l)" if "a ∈ A s" for s a
    using order_less_imp_le[OF disc_lt_one] that  by (auto intro!: divide_right_mono)
  hence "r_min / (1 - l) ≤ r (s, a) + (l * r_min) / (1 - l)" if "a ∈ A s" for s a
    using disc_lt_one that by (auto simp: add_divide_distrib)
  hence "r_min / (1 - l) ≤ L⇩a (arb_act (A x)) (λs. r_min / (1 - l)) x"
    using A_ne arb_act_in by auto
  moreover have "bdd_above ((λa. L⇩a a (λs. r_min / (1 - l)) x) ` A x)"
    using r_bounded
    by (fastforce simp: bounded_def intro!: bounded_imp_bdd_above bounded_plus_comp)
  ultimately show "v0_mpi⇩b x ≤ ℒ⇩b v0_mpi⇩b x"
    unfolding ℒ⇩b_eq_SUP_L⇩a' v0_mpi⇩b.rep_eq v0_mpi_def by (auto simp: A_ne intro!: cSUP_upper2)
qed

subsection ‹An Instance of Modified Policy Iteration with a Valid Conservative Initial Value Estimate›
(* User-facing entry point: runs mpi_algo with accuracy eps, the initial
   decision rule λx. arb_act (A x), and the initial value v0_mpi⇩b.
   For eps ≤ 0 the result is left unspecified (undefined). *)
definition "mpi_user eps m = (
  if eps ≤ 0 then undefined else mpi_algo eps (λx. arb_act (A x)) v0_mpi⇩b m)"

(* For eps > 0, mpi_user satisfies the recursion equation mpi_alg_rec
   (via mpi_algo_def', whose premises hold by v0_mpi⇩b_le_ℒ⇩b and A_ne). *)
lemma mpi_user_eq:
  assumes "eps > 0"
  shows "mpi_user eps = mpi_alg_rec eps (λx. arb_act (A x)) v0_mpi⇩b"
  using v0_mpi⇩b_le_ℒ⇩b assms
  by (auto simp: mpi_user_def mpi_algo_def' A_ne is_dec_det_def)

(* For eps > 0, the stationary policy built from the decision rule returned by
   mpi_user has value strictly within eps of ν⇩b_opt (instance of mpi_algo_opt). *)
lemma mpi_user_opt:
  assumes "eps > 0"
  shows "dist (ν⇩b (mk_stationary_det (fst (mpi_user eps n)))) ν⇩b_opt < eps"
  unfolding mpi_user_def using assms
  by (auto intro: mpi_algo_opt simp: is_dec_det_def A_ne v0_mpi⇩b_le_ℒ⇩b)
end


end