(* Theory MDP_disc *)

(* Author: Maximilian Schäffeler *)

theory MDP_disc
  imports 
    MDP_cont
    "HOL-Library.Omega_Words_Fun"
begin

section ‹Markov Decision Processes with Discrete State Spaces›

(* Bochner-integral counterpart to nn_integral_stream_space:
   for an integrable, measurable f on the stream space over M, the integral of f
   equals the iterated integral that first fixes the head x (drawn from M) and
   then integrates f (x ## X) over the tail X (again a stream over M). *)
lemma (in prob_space) integral_stream_space:
  fixes f :: "'a stream ⇒ ('b :: {banach, second_countable_topology,real_normed_vector})"
  assumes int_f: "integrable (stream_space M) f"
  assumes [measurable]: "f ∈ borel_measurable (stream_space M)"
  shows "(∫X. f X ∂stream_space M) = (∫x. (∫X. f (x ## X) ∂stream_space M) ∂M)"
proof -
  interpret S: sequence_space M ..
  interpret P: pair_sigma_finite M "Π⇩M i::nat∈UNIV. M" ..

  interpret P': pair_sigma_finite "Π⇩M i::nat∈UNIV. M" M ..

  (* Step 1: transport integrability of f from the stream space to the product
     M ⨂⇩M S.S, so that the iterated-integral rule integral_fst applies. *)
  have "integrable S.S (λX. f (to_stream X))"
    using int_f
    by (metis integrable_distr measurable_to_stream stream_space_eq_distr)
  hence "integrable (distr (M ⨂⇩M PiM UNIV (λi. M)) (PiM UNIV (λi. M)) 
    (λ(x, y). case_nat x y)) (λX. f (to_stream X))"
    by (auto simp: S.PiM_iter)
  moreover have "integrable (distr (M ⨂⇩M PiM UNIV (λi. M)) (PiM UNIV (λi. M)) (λ(x, y). case_nat x y)) 
    (λX. f (to_stream X)) ⟷ 
    integrable (M ⨂⇩M S.S) (λX. f (to_stream ((λ(s, ω). case_nat s ω) X)))"
    by (auto simp: integrable_distr_eq)
  ultimately have "integrable (M ⨂⇩M S.S) (λX. f (to_stream ((λ(s, ω). case_nat s ω) X)))"
    by auto
  hence "integrable (M ⨂⇩M (PiM UNIV (λi. M))) 
    (λX. f (to_stream ((λ(s, ω). case_nat s ω) X)))"
    by auto
  moreover have "integrable (M ⨂⇩M (PiM UNIV (λi. M))) 
    (λX. f (to_stream ((λ(s, ω). case_nat s ω) X))) = 
      integrable (M ⨂⇩M PiM UNIV (λi. M)) (λ(x, X). f (to_stream (case_nat x X)))"
    by (auto intro!: Bochner_Integration.integrable_cong)
  ultimately have *: "integrable (M ⨂⇩M PiM UNIV (λi. M)) (λ(x, X). f (to_stream (case_nat x X)))"
    by auto
  (* Step 2: rewrite the integral along the same chain of distributions. *)
  have "(∫X. f X ∂stream_space M) = (∫X. f (to_stream X) ∂S.S)"
    by (subst stream_space_eq_distr) (simp add: integral_distr)
  also have "… = (∫X. f (to_stream ((λ(s, ω). case_nat s ω) X)) ∂(M ⨂⇩M S.S))"
    by (subst S.PiM_iter[symmetric]) (simp add: integral_distr)
  also have "… = (∫x. ∫X. f (to_stream ((λ(s, ω). case_nat s ω) (x, X))) ∂S.S ∂M)"
    using *
    by (auto simp: pair_sigma_finite.integral_fst P.pair_sigma_finite_axioms case_prod_unfold)
  also have "… = (∫x. ∫X. f (x ## to_stream X) ∂S.S ∂M)"
    by (auto intro!: integral_cong simp: to_stream_nat_case)
  also have "… = (∫x. ∫X. f (x ## X) ∂distr (PiM UNIV (λi. M)) (stream_space M) to_stream ∂M)"
    by (subst Bochner_Integration.integral_cong[OF refl]) (auto simp: integral_distr) 
  also have "… = (∫x. ∫X. f (x ## X) ∂stream_space M ∂M)"
    using stream_space_eq_distr by metis
  finally show ?thesis .
qed

(* The prefix of length Suc n of seq is its element at index 0 consed onto the
   prefix of length n of the sequence shifted by one index. *)
lemma prefix_cons: 
  "Omega_Words_Fun.prefix (Suc n) seq = seq 0#Omega_Words_Fun.prefix n (λn. seq (Suc n))"
  by (metis map_upt_Suc subsequence_def)

(* Restricting y to {0..<Suc i} and reading index Suc n is the same as
   restricting the shifted function to {0..<i} and reading index n. *)
lemma restrict_Suc: "restrict y {0..<Suc i} (Suc n) = (restrict (λn. y (Suc n)) {0..<i}) n"
  by auto

(* The prefix of length i does not change when the sequence is first
   restricted to the index set {0..<i}. Proved by induction on i, with y arbitrary. *)
lemma prefix_restrict: "Omega_Words_Fun.prefix i (restrict y {0..<i}) = Omega_Words_Fun.prefix i y"
proof (induction i arbitrary: y)
  case (Suc i)
  then show ?case
    unfolding restrict_Suc prefix_cons 
    by fastforce+
qed simp

(* Taking the prefix of length i is a measurable map from the finite product
   of count spaces indexed by {0..<i} into the count space of lists.
   Induction on i: the step splits w into (restrict w {0..<i}, w i), applies
   the induction hypothesis to the first component and appends the second. *)
lemma prefix_measurable[measurable]: 
  "Omega_Words_Fun.prefix i ∈ PiM {0..<i} 
  (λ_. count_space (UNIV :: ('s ::countable × 'a::countable) set)) →⇩M count_space UNIV"
proof (induction i)
  case 0
  then show ?case by simp
next
  case (Suc i)
  have aux: "(λw. (restrict w {0..<i}, w i)) ∈ PiM {0..<Suc i} (λ_. count_space UNIV) →⇩M 
    PiM {0..<i} (λ_. count_space UNIV) ⨂⇩M (count_space UNIV)"
    by auto
  have aux': "(λ(w,wi). Omega_Words_Fun.prefix i (restrict w {0..<i})@[wi]) ∈ PiM {0..<i} 
    (λ_. count_space (UNIV :: ('s × 'a) set)) ⨂⇩M (count_space UNIV) →⇩M count_space UNIV"
    using Suc.IH by auto
  have f_eq: "⋀w. Omega_Words_Fun.prefix i (restrict w {0..<i}) @ [w i] = 
    (λ(w,wi). Omega_Words_Fun.prefix i w @[wi]) ((restrict w {0..<i}), w i)"
    by auto
  have "(λw:: nat ⇒ 's × 'a. Omega_Words_Fun.prefix i (restrict w {0..<i}) @ [w i]) ∈ PiM {0..<Suc i} (λ_. count_space UNIV) →⇩M count_space UNIV"
    using aux aux'[unfolded prefix_restrict]
    by (subst f_eq) auto
  thus ?case
    unfolding prefix_restrict[of _ "i"]
    by auto
qed

(* Remove the infix notation ## of Omega_Words_Fun.build; below, ## is only
   used on streams (presumably this avoids ambiguous parses - confirm). *)
no_notation Omega_Words_Fun.build (infixr ‹##› 65)

(* Locale of MDPs over countable state type 's and countable action type 'a:
   A gives the enabled actions per state, K the transition kernel, and every
   state is assumed to have at least one enabled action. *)
locale discrete_MDP =
  fixes A :: "'s::countable ⇒ 'a::countable set" ― ‹enabled actions›
    and K :: "'s × 'a ⇒ 's pmf" ― ‹MDP kernel, transition probabilities›
  assumes
    A_ne: "⋀s. A s ≠ {}" ― ‹set of enabled actions is nonempty›
begin

subsection ‹Policies›
text ‹Type synonym for decision rules.›
(* A (randomized) decision rule maps each state to a distribution over actions. *)
type_synonym ('c, 'd) dec = "'c ⇒ 'd pmf"

(* d is a valid decision rule if, in every state, its support lies within the
   enabled actions (d s is read as a set via the set_pmf coercion). *)
definition is_dec :: "('s, 'a) dec ⇒ bool" where
  "is_dec d ≡ ∀s. d s ⊆ A s"

(* Introduction rule for is_dec. *)
lemma is_decI[intro]: 
  "(⋀s. set_pmf (d s) ⊆ A s) ⟹ is_dec d"
  unfolding is_dec_def
  by auto

(* The set of all valid randomized decision rules. *)
abbreviation "DR ≡ {d. is_dec d}"

(* A deterministic decision rule picks an enabled action in every state. *)
definition is_dec_det :: "('s ⇒ 'a) ⇒ bool" where
  "is_dec_det d ≡ ∀s. d s ∈ A s"

(* The set of all valid deterministic decision rules. *)
abbreviation "DD ≡ {d. is_dec_det d}"

(* Embed a deterministic rule as a randomized one via point distributions. *)
definition "mk_dec_det d s = return_pmf (d s)"

(* The embedding is valid exactly when the deterministic rule is valid. *)
lemma is_dec_mk_dec_det_iff [simp]: "is_dec (mk_dec_det d) ⟷ is_dec_det d"
  by (simp add: is_dec_def is_dec_det_def mk_dec_det_def)

(* One direction of is_dec_mk_dec_det_iff, as an introduction rule. *)
lemma D_det_to_MR[intro]: "is_dec_det d ⟹ is_dec (mk_dec_det d)"
  by simp

text ‹
Due to the assumption @{thm A_ne}, a deterministic decision rule always exists.
It immediately follows via @{thm is_dec_mk_dec_det_iff} that a randomized decision rule also exists.
›

(* Choosing some enabled action in each state (possible by A_ne) is a valid
   deterministic decision rule. *)
lemma SOME_is_dec_det: "is_dec_det (λs. SOME a. a ∈ A s)"
  using A_ne by (simp add: is_dec_det_def some_in_eq)

(* A deterministic decision rule exists. *)
lemma ex_dec_det [simp]: "∃d. is_dec_det d"
  using SOME_is_dec_det by blast

(* The set of deterministic decision rules is nonempty. *)
lemma D_det_ne [simp]: "DD ≠ {}"
  by simp

(* The set of randomized decision rules is nonempty. *)
lemma DR_ne [simp]: "DR ≠ {}"
  using D_det_ne D_det_to_MR by blast

(* A randomized decision rule exists. *)
lemma ex_dec[intro, simp]: "∃d. is_dec d"
  using ex_dec_det by blast

text ‹Type synonym for policies.›
(* A policy maps a history (list of state-action pairs) to a decision rule. *)
type_synonym ('c, 'd) pol = "('c × 'd) list ⇒ ('c, 'd) dec"


text ‹A policy assigns a decision rule to each observed past.›
(* p is a policy if it yields a valid decision rule for every history. *)
definition is_policy :: "('s, 'a) pol ⇒ bool" where
  "is_policy p ≡ ∀hs. is_dec (p hs)"

(* The set of all policies. *)
abbreviation "ΠHR ≡ {p. is_policy p}"

text ‹Deterministic policies›
(* A policy is deterministic if every decision it makes is a point distribution. *)
definition "is_deterministic p ≡ is_policy p ∧ (∀h s. ∃a. p h s = return_pmf a)"

(* Turn a function from histories and states to actions into a pmf-valued policy. *)
definition "mk_det p h s ≡ return_pmf (p h s)"

(* History-dependent functions that yield a valid deterministic decision rule
   for every history. *)
abbreviation "ΠHD ≡ {p. ∀h. p h ∈ DD}"

text ‹Markovian policies›
(* A policy is markovian if its decision rule depends only on the length of
   the history. *)
definition "is_markovian p ≡ is_policy p ∧ (∀h h'. length h = length h' ⟶ p h = p h')"

(* Build a policy from a sequence of decision rules indexed by the epoch,
   where the epoch is taken to be the length of the history. *)
definition mk_markovian :: "(nat ⇒ ('s, 'a) dec) ⇒ ('s, 'a) pol" where
  "mk_markovian p ≡ (λh. p (length h))"

(* mk_markovian p is markovian iff every p n is a valid decision rule. *)
lemma is_markovian_mk_iff[simp]: "is_markovian (mk_markovian p) ⟷ (∀n. is_dec (p n))"
  unfolding is_markovian_def mk_markovian_def is_policy_def
  by (metis (mono_tags, opaque_lifting) Ex_list_of_length)

(* Introduction-rule form of one direction of is_markovian_mk_iff. *)
lemma is_markovian_mk[intro]: "∀n. is_dec (p n) ⟹ is_markovian (mk_markovian p)"
  unfolding is_markovian_def mk_markovian_def is_policy_def
  by auto

(* On the empty history, mk_markovian p uses the decision rule of epoch 0. *)
lemma mk_markovian_nil [simp]: "mk_markovian p [] = p 0"
  unfolding mk_markovian_def by auto

(* Variant of mk_markovian for sequences of deterministic decision rules. *)
definition "mk_markovian_det p ≡ (λh s. return_pmf (p (length h) s))"

(* Sequences of valid deterministic / randomized decision rules. *)
abbreviation "ΠMD ≡ {p. ∀n::nat. p n ∈ DD}"
abbreviation "ΠMR ≡ {p. ∀n. p n ∈ DR}"

(* A sequence of valid decision rules induces a policy via mk_markovian. *)
lemma ΠMR_imp_policies[intro]: "p ∈ ΠMR ⟹ mk_markovian p ∈ ΠHR"
  unfolding is_policy_def mk_markovian_def by auto

(* Embedding every element with mk_dec_det lands in ΠMR iff p is in ΠMD. *)
lemma ΠMD_MR_iff[simp]: "(λn. mk_dec_det (p n)) ∈ ΠMR ⟷ p ∈ ΠMD"
  by auto

(* One direction of ΠMD_MR_iff, as an introduction rule. *)
lemma ΠMD_to_MR[intro]: "p ∈ ΠMD ⟹ (λn. mk_dec_det (p n)) ∈ ΠMR"
  by simp

(* Every element of a sequence in ΠMD is a deterministic decision rule. *)
lemma p_n_π_MD[intro]: "p ∈ ΠMD ⟹ p n ∈ DD"
  by auto

(* Every element of a sequence in ΠMR is a randomized decision rule. *)
lemma p_n_π_MR[intro]: "p ∈ ΠMR ⟹ p n ∈ DR"
  by auto

(* ΠMD is nonempty: the constant sequence of some deterministic rule is a witness. *)
lemma ΠMD_ne[simp]: "ΠMD ≠ {}"
  by (auto simp: someI_ex[OF ex_dec_det] intro: exI[of _ "λn. (SOME d. is_dec_det d)"])

(* ΠMR is nonempty. *)
lemma ΠMR_ne[simp]: "ΠMR ≠ {}"
  using ΠMD_ne by fast

(* The set of policies is nonempty. *)
lemma policies_ne[simp, intro]: "ΠHR ≠ {}"
  using ΠMR_ne is_policy_def by auto

text ‹Stationary policies›
(* A policy is stationary if it uses the same decision rule for every history. *)
definition "is_stationary p ≡ is_policy p ∧ (∀h h'. p h = p h')"

(* A constant policy is stationary exactly when its decision rule is valid. *)
lemma is_stationary_const_iff[simp]: "is_stationary (λ_. d) = is_dec d"
  unfolding is_stationary_def is_policy_def by simp

(* Introduction-rule form of is_stationary_const_iff. *)
lemma is_stationary_const[intro]: "is_dec d ⟹ is_stationary (λ_. d)"
  by simp

(* Stationary policies built from a single (randomized resp. deterministic)
   decision rule, expressed through mk_markovian. *)
abbreviation "mk_stationary p ≡ mk_markovian (λ_. p)"
abbreviation "mk_stationary_det d ≡ mk_markovian (λ_. mk_dec_det d)"

subsubsection ‹Successor Policy›
text ‹
After taking the first step in the MDP, we will know which state and which action got selected 
during the initial epoch. To obtain a policy that acts as if the current epoch was the initial one, 
we prepend the observed state-action pair to the history. The result is again a policy, 
i.e. it satisfies @{const is_policy}.
›
(* Successor policy: behaves like p after the pair sa has been prepended to
   the history. *)
definition π_Suc :: "('s, 'a) pol ⇒ 's × 'a ⇒ ('s, 'a) pol"
  where
    "π_Suc p sa h = p (sa#h)"

(* The successor of a policy is again a policy. *)
lemma is_policy_π_Suc [intro]: "is_policy p ⟹ is_policy (π_Suc p sa)"
  unfolding is_policy_def π_Suc_def by force

(* For markovian policies, taking the successor shifts the sequence of
   decision rules by one epoch, independently of the observed pair. *)
lemma Suc_mk_markovian[simp]: "π_Suc (mk_markovian p) x = mk_markovian (λn. p (Suc n))"
  unfolding π_Suc_def mk_markovian_def by auto

subsection ‹Stream Space of the MDP›
subsubsection ‹Initial State-Action Distribution›
text ‹
If we fix a decision rule @{term d} and an initial distribution of states @{term "S0"}, 
we obtain a distribution over state-action pairs in the following way:
First, the initial state @{term "s"} is sampled from @{term S0}, 
then an action @{term a} is selected from @{term "d s"}.
›

(* Initial state-action distribution: sample s from S0, then a from d s,
   and return the pair (s,a). *)
definition "K0 d S0 = do {
  s ← S0; 
  a ← d s;
  return_pmf (s,a)
}"

(* Subscript notation for K0. NOTE(review): the mixfix template was lost in
   this copy of the file; K⇩0 is the assumed original - confirm. *)
notation K0 (‹K⇩0›)

(* K0 written as a bind over S0 followed by a map that pairs each action with
   the sampled state. *)
lemma K0_iff: "K0 d S0 = S0 ⤜ (λs. map_pmf (λa. (s,a)) (d s))"
  by (simp add: K0_def map_pmf_def)

(* Preimage of a singleton pair under Pair x. *)
lemma vimage_pair[simp]: "Pair x -` {p} = (if x = fst p then {snd p} else {})"
  by auto

(* The probability of (s,a) under K0 factors into the probability of s under
   S0 times the probability of a under d s. *)
lemma pmf_K0 [simp]: "pmf (K0 d S0) (s,a) = pmf S0 s * pmf (d s) a"
  unfolding K0_iff pmf_bind
  by (subst integral_measure_pmf[where A = "{s}"]) (auto simp: pmf_map pmf.rep_eq split: if_splits)

(* Support of K0: pairs whose state is in the support of S0 and whose action
   is in the support of the decision rule at that state. *)
lemma set_pmf_K0: "set_pmf (K0 p S0) = {(s,a). s ∈ S0 ∧ a ∈ p s}"
  by (auto simp add: K0_def)

(* Projecting K0 to the state component recovers S0. *)
lemma fst_K0[simp]: "map_pmf fst (K0 p S0) = S0"
  unfolding K0_def
  by (simp add: map_bind_pmf map_pmf_comp bind_return_pmf')

(* The measurable space of streams over a countable (discrete) type. *)
abbreviation "S ≡ stream_space (count_space UNIV)"

text ‹We inherit the trace space from MDPs with continuous state-action spaces›
(* Instantiate the locale MDP_cont.discrete_MDP with count spaces for states
   and actions. The two explicit goals are measurability of the kernel into
   the probability algebra and existence of a measurable selector of enabled
   actions; the remaining goals are closed by auto using A_ne. *)
interpretation MDP_cont: MDP_cont.discrete_MDP "count_space UNIV" "count_space UNIV" A K
proof standard
  show "(λx. measure_pmf (K x)) ∈ 
    count_space UNIV ⨂⇩M count_space UNIV →⇩M prob_algebra (count_space UNIV)"
    using measurable_prob_algebraI
    by (measurable, auto simp: prob_space_measure_pmf measurable_pair_measure_countable1)+
  show "∃δ∈count_space UNIV →⇩M count_space UNIV. ∀s. δ s ∈ A s"
    by (auto simp: A_ne some_in_eq intro: bexI[of _ "λs. SOME a. a ∈ A s"])
qed (auto simp: A_ne)

(* The measure space MDP_cont.M of the interpreted locale is the count space
   over the whole type. *)
lemma count_space_M[simp]: "MDP_cont.M = count_space UNIV"
  by (auto simp: pair_measure_countable)

(* Consequently its carrier is UNIV. *)
lemma space_M[simp]: "space MDP_cont.M = UNIV"
  by (auto simp: MDP_cont.space_lim_stream)

text ‹We reuse the stream space provided by @{const MDP_cont.lim_stream}.›
(* Trace measure of policy p: MDP_cont.lim_stream applied to the policy that,
   at epoch n, passes the length-n prefix of the history h to p. The initial
   state distribution is the remaining ('s pmf) argument. *)
definition T :: "('s, 'a) pol ⇒ 's pmf ⇒ ('s × 'a) stream measure"
  where "T p = MDP_cont.lim_stream (λn (h,s). p (Omega_Words_Fun.prefix n h) s)"

(* T has the measurable sets of the discrete stream space. *)
lemma sets_T[measurable_cong]: 
  "sets (T p x) = sets S"
  by (auto simp: T_def MDP_cont.sets_lim_stream)

(* The carrier of the discrete stream space is nonempty. *)
lemma space_stream_space_ne[simp]: "space S ≠ {}"
  by (auto simp: space_stream_space)

(* T has the carrier of the discrete stream space. *)
lemma space_T[simp]: "space (T p S0) = space S"
  by (simp add: MDP_cont.space_lim_stream T_def space_stream_space)

(* Any p of the stated type, lifted to the prefix-based form used in T, is a
   policy in the sense of MDP_cont; only measurability and being a probability
   space are required there, and both hold in the discrete setting. *)
lemma is_policy_MDP_cont[intro]: 
  fixes p :: "('s × 'a) list ⇒ 's ⇒ 'a pmf"
  shows "MDP_cont.is_policy (λn (h,s). p (Omega_Words_Fun.prefix n h) s)"
  unfolding MDP_cont.is_policy_def MDP_cont.is_dec_def
  using prefix_measurable measurable_pair_swap_iff
  by (auto simp: prob_space_measure_pmf
      intro: measurable_pair_measure_countable1 measurable_prob_algebraI)

(* T p x is a probability space. *)
lemma prob_space_T[intro, simp]: "prob_space (T p x)"
  by (auto simp add: T_def prob_space_measure_pmf space_prob_algebra)

(* T p S0 is an element of the subprobability algebra over S. *)
lemma T_subprob[simp]: 
  "T p S0 ∈ space (subprob_algebra S)"
  by (metis prob_space.M_in_subprob prob_space_T sets_T subprob_algebra_cong)

(* T p S0 is a subprobability space. *)
lemma T_subprob_space [simp]: "subprob_space (T p S0)"
  by (auto intro: prob_space_imp_subprob_space)

(* The initial distribution MDP_cont.K0 of the lifted policy coincides with
   the discrete K0 applied to the decision rule for the empty history. *)
lemma K0_MDP_cont_eq: 
  "MDP_cont.K0 (λx (h,s). measure_pmf (p (Omega_Words_Fun.prefix x h) s)) (measure_pmf S0) = 
    K0 (p []) S0"
  unfolding MDP_cont.K0_def K0_def MDP_cont.K'_def map_pmf_def
  by (simp add: measure_pmf_bind return_pmf.rep_eq)

subsubsection ‹Decomposition of the Stream Space›
text ‹
The distribution of traces/walks the MDP allows should intuitively satisfy the following rule:

▪ select the initial state @{term s} from @{term S0}
▪ pass it to the decision rule @{term "p []"} to determine a distribution over actions
▪ select the action @{term a}
▪ finally pass the state-action pair @{term "(s,a)"} to the kernel @{term K} to get a new 
  distribution over states @{term s0'}

Then the iteration repeats with the updated policy @{term "π_Suc p (s,a)"}.

The result carries over from @{thm MDP_cont.lim_stream_eq}.
›

(* One-step unfolding of T: draw the first state-action pair sa from K0,
   draw the remaining trace from T of the successor policy started in K sa,
   and prepend sa. Obtained from MDP_cont.lim_stream_eq. *)
lemma T_eq:
  shows "T p S0 = do {
    sa ← measure_pmf (K0 (p []) S0);
    ω ← T (π_Suc p sa) (K sa);
    return S (sa ## ω)
  }"
  unfolding T_def
proof (subst MDP_cont.lim_stream_eq)
  show "MDP_cont.is_policy (λx xa. measure_pmf (case xa of (h, xa) ⇒ p (Omega_Words_Fun.prefix x h) xa))"
    by auto
qed (auto simp: space_prob_algebra prob_space_measure_pmf π_Suc_def MDP_cont.Suc_policy_def 
    prefix_cons K0_MDP_cont_eq  prod.case_distrib)

(* The same unfolding, with the inner bind/return expressed as a distr. *)
lemma T_eq_distr:
  shows "T p S0 = measure_pmf (K0 (p []) S0) ⤜ (λsa. distr (T (π_Suc p sa) (K sa)) S ((##) sa))"
  by (simp add: T_eq[symmetric] bind_return_distr'[symmetric])

text ‹
The iteration rule lets us nicely decompose integrals (expected values) over functions on traces of 
the MDP.
›
(* For a bounded measurable real-valued f on traces, the integral over T p x
   decomposes into an outer integral over the first state-action pair sa
   (distributed as K0 (p []) x) and an inner integral over the remaining trace
   under the successor policy started in K sa. *)
lemma integral_T:
  fixes f :: "('s × 'a) stream ⇒ real"
  assumes f_bounded: "⋀x. ¦f x¦ ≤ B"
  assumes f: "f ∈ borel_measurable S"
  shows "(∫t. f t ∂T p x) = ∫sa. ∫t'. f (sa##t') ∂T (π_Suc p sa) (K sa) ∂K0 (p []) x"
proof -
  have "(∫t. f t ∂T p x) = (∫t. f t ∂measure_pmf (K0 (p []) x) ⤜ (λsa. distr (T (π_Suc p sa) (K sa)) (stream_space (count_space UNIV)) ((##) sa)))"
    using T_eq_distr by metis
  also have "… = measure_pmf.expectation (K0 (p []) x) (λsa. LINT t'|T (π_Suc p sa) (K sa). f (sa ## t'))"
  proof (subst integral_bind[OF f f_bounded, where B' = 1], goal_cases)
    case 1
    then show ?case 
      by (auto intro!: prob_space_imp_subprob_space prob_space.prob_space_distr 
          simp: space_subprob_algebra)
  next
    case 3
    then show ?case
      by (auto intro!: prob_space.emeasure_le_1 prob_space.prob_space_distr)
  next
    case 4
    then show ?case
      by (auto simp: f integral_distr intro: Bochner_Integration.integral_cong)
  qed auto
  finally show ?thesis.
qed


(* Non-negative (Lebesgue) integral version of the one-step decomposition
   over T; here only measurability of f is assumed. *)
lemma nn_integral_T:
  assumes f: "f ∈ borel_measurable S"
  shows "(∫⇧+t. f t ∂T p x) = (∫⇧+sa. ∫⇧+ t'. f (sa##t') ∂T (π_Suc p sa) (K sa) ∂K0 (p []) x)"
  unfolding T_eq_distr[of p]
  by (subst nn_integral_bind[OF f]) 
    (auto intro!: prob_space_imp_subprob_space prob_space.prob_space_distr 
      simp: f nn_integral_distr space_subprob_algebra)

subsubsection ‹A Denotational View on the Stochastic Process›
text ‹
Many definitions on MDPs do not rely on the individual traces but only on the distribution 
of states and actions at each epoch.

We define this view on the trace space as the repeated iteration of @{const K0} and @{term K}.
It coincides with the definition of @{const T}.
›

(* Distribution of the state-action pair at epoch n: at epoch 0 it is K0;
   at epoch Suc n, sample the first pair sa from K0 and continue for n epochs
   with the successor policy started in K sa. *)
primrec Pn :: "('s, 'a) pol ⇒ 's pmf ⇒ nat ⇒ ('s × 'a) pmf" where
  "Pn p S0 0 = K0 (p []) S0"
| "Pn p S0 (Suc n) = K0 (p []) S0 ⤜ (λsa. Pn (π_Suc p sa) (K sa) n)"

(* The successor equation is not unfolded automatically by the simplifier. *)
declare Pn.simps(2)[simp del]

(* Pn p S0 n is the distribution of the n-th element of a trace drawn from
   T p S0. Induction on n with p and S0 arbitrary; both cases unfold T once
   via T_eq and push distr through the bind. *)
lemma Pn_eq_T: "measure_pmf (Pn p S0 n) = distr (T p S0) (count_space UNIV) (λt. t !! n)"
proof (induction n arbitrary: p S0)
  case (0 p S0)
  then show ?case
    unfolding T_eq[of p]
  proof (subst distr_bind[where K = S], goal_cases)
    case 1
    then show ?case 
      by (auto intro!: prob_space_imp_subprob_space subprob_space.bind_in_space)
  next
    case 4
    then show ?case 
      by (subst bind_cong[OF refl, where g = "return (count_space UNIV)"])
        (auto intro!: bind_const' simp: distr_bind[where K = S] distr_return bind_return'' space_stream_space subprob_space_return_ne)
  qed auto
next
  case (Suc n)
  show ?case
    unfolding T_eq[of p]
  proof (subst distr_bind[where K = S], goal_cases)
    case 1
    then show ?case 
      by (auto intro!: prob_space_imp_subprob_space subprob_space.bind_in_space)[1]
  next
    case 4
    then show ?case
      by (auto simp: Pn.simps(2) measure_pmf_bind Suc bind_return_distr' distr_distr comp_def intro!: bind_cong)
  qed auto
qed

text ‹
The definition of @{const Pn} also allows us to easily prove that only enabled actions can occur in
the traces of the MDP.
›

(* Under a policy, every state-action pair in the support of Pn has an
   enabled action. Induction on n with S0 and p arbitrary. *)
lemma Pn_in_A: "is_policy p ⟹ (s, a) ∈ Pn p S0 n ⟹ a ∈ A s"
proof (induction n arbitrary: S0 p)
  case 0
  then show ?case 
    using 0 unfolding is_policy_def is_dec_def
    by (auto simp: K0_def)
next
  case (Suc n)
  then show ?case
    by (auto simp: Pn.simps(2) K0_def)
qed

(* Almost every trace of a policy uses an enabled action at epoch n.
   Transfers Pn_in_A to T through Pn_eq_T. *)
lemma T_in_A:
  assumes "is_policy p"
  shows "AE t in T p S0. snd (t !! n) ∈ A (fst (t !! n))"
proof -
  have aux: "AE t in distr (T p S0) (count_space UNIV) (λt. t !! n). snd t ∈ A (fst t)"
    using assms Pn_eq_T[symmetric]
    by (auto simp: Pn_in_A intro!: AE_pmfI cong: AE_cong_simp)
  show ?thesis
    by (auto intro!: AE_distrD[OF _ aux])
qed

subsubsection ‹State Process›
text ‹Alongside @{const Pn}, we also define the state and action distributions as projections.›

(* State distribution at epoch n: the first projection of Pn. *)
definition "Xn p S0 n = map_pmf fst (Pn p S0 n)"

(* At epoch 0 the state distribution is the initial distribution. *)
lemma X0 [simp]: "Xn p S0 0 = S0"
  using fst_K0 Xn_def by auto

(* The state distribution at epoch Suc n is obtained by applying the kernel K
   to the state-action distribution at epoch n. *)
lemma Xn_Suc: "Xn p S0 (Suc n) = Pn p S0 n ⤜ K"
proof (induction n arbitrary: p S0)
  case 0
  then show ?case
    by (simp add: Pn.simps(2) Xn_def map_bind_pmf)
next
  case (Suc n)
  then show ?case
    by (simp add: Pn.simps(2) Xn_def map_bind_pmf bind_assoc_pmf)
qed

(* For policies built with mk_markovian, the state-action distribution at
   epoch n is K0 applied to the n-th decision rule and the state distribution
   at epoch n. Induction on n with p and S0 arbitrary. *)
lemma Pn_markovian_eq_Xn_bind: "Pn (mk_markovian p) S0 n = K0 (p n) (Xn (mk_markovian p) S0 n)"
proof (induction n arbitrary: p S0)
  case 0
  then show ?case 
    unfolding Xn_def by auto
next
  case (Suc n)
  then show ?case
    unfolding Xn_def K0_def
    by (auto intro!: bind_pmf_cong simp: Pn.simps(2) map_bind_pmf Suc bind_assoc_pmf)
qed

(* Recursive characterization of Xn, mirroring the successor equation of Pn. *)
lemma Xn_Suc': "Xn p S0 (Suc n) = K0 (p []) S0 ⤜ (λsa. Xn (π_Suc p sa) (K sa) n)"
  unfolding Xn_def by (auto simp: Pn.simps(2) map_bind_pmf)

(* The support of the state distribution at epoch 0 is the support of S0. *)
lemma set_pmf_X0 [simp]: "set_pmf (Xn p S0 0) = S0"
  using X0 by auto

(* Support of Pn for mk_markovian policies: states in the support of Xn
   paired with actions in the support of the n-th decision rule. *)
lemma set_pmf_PSuc: "set_pmf (Pn (mk_markovian p) S0 n) = 
  {(s, a). s ∈ set_pmf (Xn (mk_markovian p) S0 n) ∧ a ∈ p n s}"
  using set_pmf_K0 Pn_markovian_eq_Xn_bind by auto

subsubsection ‹The Conditional Distribution of Actions›
text ‹
Actions are selected wrt. the whole history of state-action pairs encountered so far.
The following definition defines the expected action selection when only the current state is given.›
(* Distribution of the action at epoch n, conditioned on the state at epoch n
   being x: condition Pn on the event "first component equals x" and project
   to the second component. *)
definition "Y_cond_X p S0 n x = map_pmf snd (cond_pmf (Pn p S0 n) {(s,a). s = x})"

(* Under K0, the probability that the state component equals x is pmf S0 x. *)
lemma prob_K0_X [simp]: "measure_pmf.prob (K0 p S0) {(s, a). s = x} = pmf S0 x"
  unfolding K0_iff
proof (subst measure_pmf_bind, subst measure_pmf.measure_bind[of _ _ "count_space UNIV"], goal_cases)
  case 1
  then show ?case
    by (simp add: measure_pmf_in_subprob_algebra)
next
  case 3
  then show ?case 
    by (subst integral_measure_pmf_real[of "{x}"]) (auto split: if_splits)
qed simp  

(* Under Pn, the probability that the state component equals x is the
   probability of x under Xn. Induction on n with p and S0 arbitrary. *)
lemma prob_Pn_X[simp]: "measure_pmf.prob (Pn p S0 n) {(s, a). s = x} = pmf (Xn p S0 n) x"
proof (induction n arbitrary: p S0)
  case 0
  then show ?case
    by auto
next
  case (Suc n)
  show ?case
    unfolding Xn_Suc' Pn.simps(2) measure_pmf_bind 
    using Suc
    by (simp add: measure_pmf.measure_bind[of _ _ "count_space UNIV"] K0_def
        measure_pmf_in_subprob_algebra pmf_bind)
qed

(* For a pair sa in the support of Pn, its probability factors into the
   conditional action probability times the state probability. *)
lemma pmf_Pn_pair:
  assumes "sa ∈ set_pmf (Pn p S0 n)"
  shows "pmf (Pn p S0 n) sa = pmf (Y_cond_X p S0 n (fst sa)) (snd sa) * pmf (Xn p S0 n) (fst sa)"
proof -
  (* the conditioning event has a nonempty intersection with the support *)
  have aux: "set_pmf (Pn p S0 n) ∩ {(s, a). s = fst sa} ≠ {}"
    using Xn_def assms by auto
  have aux': "({(s, a). s = fst sa} ∩ snd -` {snd sa}) = {sa}"
    by auto
  show ?thesis
    using assms
    unfolding Y_cond_X_def pmf_map cond_pmf.rep_eq[OF aux]
    by (auto simp: Xn_def pmf_eq_0_set_pmf measure_pmf.emeasure_eq_measure aux' measure_pmf_single)
qed

(* Same factorization as pmf_Pn_pair, stated for a state x in the support of
   Xn and an arbitrary action a. *)
lemma pmf_Pn:
  assumes "x ∈ set_pmf (Xn p S0 n)"
  shows "pmf (Pn p S0 n) (x,a) = pmf (Y_cond_X p S0 n x) a * pmf (Xn p S0 n) x"
proof -
  (* the conditioning event has a nonempty intersection with the support *)
  have aux: "set_pmf (Pn p S0 n) ∩ {(s, a). s = x} ≠ {}"
    using Xn_def assms by auto
  have aux': "({(s, a). s = x} ∩ snd -` {a}) = {(x, a)}"
    by auto
  show ?thesis
    using assms
    unfolding Y_cond_X_def cond_pmf.rep_eq[OF aux] pmf_map
    by (auto simp: pmf_eq_0_set_pmf measure_pmf.emeasure_eq_measure aux' measure_pmf_single)
qed

(* For a state x in the support of Xn, the conditional action probability is
   the joint probability divided by the state probability. *)
lemma pmf_Y_cond_X:
  assumes "x ∈ set_pmf (Xn p S0 n)"
  shows "pmf (Y_cond_X p S0 n x) a = pmf (Pn p S0 n) (x,a) / pmf (Xn p S0 n) x"
proof -
  (* the conditioning event has a nonempty intersection with the support *)
  have aux: "set_pmf (Pn p S0 n) ∩ {(s, a). s = x} ≠ {}"
    using Xn_def assms by auto
  have aux': "({(s, a). s = x} ∩ snd -` {a}) = {(x, a)}"
    by auto
  show ?thesis
    using assms aux'
    unfolding Y_cond_X_def
    by (auto simp: cond_pmf.rep_eq[OF aux] pmf_map pmf_eq_0_set_pmf measure_pmf.emeasure_eq_measure 
        measure_pmf_single)
qed

(* At epoch 0, for a state in the support of S0, the conditional action
   distribution is the decision rule for the empty history. *)
lemma Y_cond_X_0[simp]:
  assumes "x ∈ set_pmf S0"
  shows "Y_cond_X p S0 0 x = p [] x"
  by (auto intro: pmf_eqI simp: assms pmf_Y_cond_X pmf_eq_0_set_pmf)

(* eqn 5.5.3 in Puterman:
   for mk_markovian policies and a state in the support of Xn, the conditional
   action distribution at epoch n is the n-th decision rule. *)
lemma Y_cond_X_markovian[simp]:
  assumes h: "x ∈ Xn (mk_markovian p) S0 n"
  shows "Y_cond_X (mk_markovian p) S0 n x = p n x"
  by (auto intro!: pmf_eqI simp: pmf_Y_cond_X h Pn_markovian_eq_Xn_bind pmf_eq_0_set_pmf)

(* Pn can be reconstructed from the state distribution Xn and the conditional
   action distribution Y_cond_X: sample x from Xn, then an action from
   Y_cond_X at x, and pair them. The Suc case compares point probabilities. *)
lemma Pn_eq_Xn_Y_cond: "Pn p S0 n = Xn p S0 n ⤜ (λx. map_pmf (λa. (x, a)) (Y_cond_X p S0 n x))"
proof (induction n)
  case 0
  then show ?case 
    by (auto simp: K0_iff intro: bind_pmf_cong)
next
  case (Suc n)
  show ?case
  proof (intro pmf_eqI; safe)
    fix a :: 's 
    fix b :: 'a
    (* express the right-hand side as an expectation over Pn *)
    have aux': "pmf (Xn p S0 (Suc n) ⤜ (λx. map_pmf (Pair x) (Y_cond_X p S0 (Suc n) x))) (a,b) 
      = measure_pmf.expectation (Pn p S0 (Suc n)) (λx. 
          if fst x = a then pmf (Y_cond_X p S0 (Suc n) a) b else 0)"
      by (auto intro!: Bochner_Integration.integral_cong[OF refl] 
          simp: Xn_def bind_map_pmf pmf_map pmf_bind measure_pmf_single)
    (* on the support of Pn, replace the conditional probability by the quotient *)
    also have "… = measure_pmf.expectation (Pn p S0 (Suc n))
     (λx. indicator {(s',a'). s' = a} x * (pmf (Pn p S0 (Suc n)) (a, b) / pmf (Xn p S0 (Suc n)) a))"
    proof (intro Bochner_Integration.integral_cong_AE AE_pmfI)
      fix y
      assume h: "y ∈ set_pmf (Pn p S0 (Suc n))"
      hence h': "fst y ∈ set_pmf (Xn p S0 (Suc n))"
        by (metis mult_eq_0_iff pmf_Pn_pair pmf_eq_0_set_pmf)
      show "(if fst y = a then pmf (Y_cond_X p S0 (Suc n) a) b else 0) = 
        indicat_real {(s', a'). s' = a} y * 
        (pmf (Pn p S0 (Suc n)) (a, b) / pmf (Xn p S0 (Suc n)) a)"
        by (auto simp: case_prod_beta' pmf_Y_cond_X[of "fst y" p S0 "(Suc n)" b, OF h'])
    qed auto
    also have "… = measure_pmf.prob (Pn p S0 (Suc n)) {(s',a'). s' = a} * 
      pmf (Pn p S0 (Suc n)) (a, b) / pmf (Xn p S0 (Suc n)) a"
      by auto
    also have "… = pmf (Pn p S0 (Suc n)) (a,b)"
      using prob_Pn_X Xn_def pmf_Pn_pair pmf_eq_0_set_pmf by fastforce
    finally show "pmf (Pn p S0 (Suc n)) (a, b) = pmf (Xn p S0 (Suc n) ⤜ 
      (λx. map_pmf (Pair x) (Y_cond_X p S0 (Suc n) x))) (a, b)"
      by auto
  qed
qed

(* Variant of Pn_eq_Xn_Y_cond with the map written as bind/return. *)
lemma Pn_eq_Xn_Y_cond': 
  "Pn p S0 n = Xn p S0 n ⤜ (λs. Y_cond_X p S0 n s ⤜ (λa. return_pmf (s,a)))"
  by (metis K0_def K0_iff Pn_eq_Xn_Y_cond)

(* Forward recursion for mk_markovian policies: the pair distribution at epoch
   Suc n arises from the one at epoch n by applying the kernel K and then the
   decision rule of epoch Suc n. *)
lemma Pn_markovian_Suc: "Pn (mk_markovian p) S0 (Suc n) = 
  Pn (mk_markovian p) S0 n ⤜ (λsa. K0 (p (Suc n)) (K sa))"
proof (induction n arbitrary: S0 p)
  case 0
  then show ?case
    by (auto intro: bind_pmf_cong simp: Pn.simps(2) π_Suc_def)
next
  case (Suc n)
  show ?case
    by (auto simp add: Suc bind_assoc_pmf Pn.simps(2)[of _ S0] intro: bind_pmf_cong)
qed

subsubsection ‹Action Process›
text ‹The distribution of actions.›
(* Action distribution at epoch n: the second projection of Pn. *)
definition "Yn p S0 n = map_pmf snd (Pn p S0 n)"

(* At epoch 0, actions are drawn by applying the decision rule for the empty
   history to the initial state distribution. *)
lemma Y0: "Yn p S0 0 = S0 ⤜ p []"
  by (simp add: Yn_def K0_iff map_bind_pmf map_pmf_comp)

text ‹
For markovian policies, the decision rules at each epoch are independent of each other,
hence we may express @{const Yn} solely in terms of @{const Xn} and the current decision rule.
›

(* For mk_markovian policies, the action distribution at epoch n is the state
   distribution at epoch n bound with the n-th decision rule. *)
lemma Yn_markovian: "Yn (mk_markovian p) S0 n = Xn (mk_markovian p) S0 n ⤜ p n"
proof (induction n arbitrary: p S0) 
  case 0
  then show ?case
    by (auto simp: Y0)
next
  case (Suc n)
  then show ?case
    by (simp add: Xn_def Yn_def map_bind_pmf Suc Pn.simps(2) bind_assoc_pmf)
qed

subsection ‹Restriction to Markovian Policies›
(* Epoch-indexed decision rule extracted from an arbitrary policy p: on states
   in the support of Xn use the conditional action distribution, otherwise
   return some enabled action. *)
abbreviation "as_markovian p S0 n x ≡ 
  if x ∈ (Xn p S0 n) then Y_cond_X p S0 n x else return_pmf (SOME a. a ∈ A x)"

text ‹
For states which cannot occur we choose an arbitrary enabled action, as in this case we cannot make
any statements about @{const Y_cond_X} (a distribution conditioned on an event with probability 0).
›

(* For a policy p, as_markovian p S0 is a sequence of valid decision rules. *)
lemma is_ΠMR_as_markovian:
  assumes p: "is_policy p" 
  shows "as_markovian p S0 ∈ ΠMR"
proof -
  (* conditioning on a state in the support of Xn is well-defined *)
  have aux: "⋀hs s. s ∈ set_pmf (Xn p S0 hs) ⟹ set_pmf ((Pn p S0 hs)) ∩ {(s', a). s' = s} ≠ {}"
    by (simp add: measure_pmf_zero_iff[symmetric] pmf_eq_0_set_pmf)
  thus ?thesis
    using assms A_ne Pn_in_A by (auto simp: is_dec_def some_in_eq Y_cond_X_def)
qed

(* Hence the induced markovian policy is again a policy. *)
lemma is_policy_as_markovian: "is_policy p ⟹ is_policy (mk_markovian (as_markovian p S0))"
  using is_ΠMR_as_markovian ΠMR_imp_policies by auto

(* The markovian policy extracted from p induces the same state-action
   distribution at every epoch as p itself. Induction on n; the step shows
   that both the conditional action distributions (on the support of Xn) and
   the state distributions agree, then recombines via Pn_eq_Xn_Y_cond. *)
theorem Pn_as_markovian_eq: "Pn (mk_markovian (as_markovian p S0)) S0 = Pn p S0"
proof 
  fix n show "Pn (mk_markovian (as_markovian p S0)) S0 n = Pn p S0 n"
  proof (induction n)
    case 0
    thus ?case
      by (auto intro!: map_pmf_cong bind_pmf_cong simp: K0_def)
  next
    case (Suc n)
    have "⋀x. x ∈ Xn p S0 (Suc n) ⟹ 
      Y_cond_X (mk_markovian (as_markovian p S0)) S0 (Suc n) x = Y_cond_X p S0 (Suc n) x"
      by (auto simp: Suc.IH Xn_Suc)
    moreover have "Xn (mk_markovian (as_markovian p S0)) S0 (Suc n) = Xn p S0 (Suc n)"
      by (simp add: Xn_Suc Suc.IH)
    ultimately show "Pn (mk_markovian (as_markovian p S0)) S0 (Suc n) = Pn p S0 (Suc n)"
      by (auto intro: bind_pmf_cong simp: Pn_eq_Xn_Y_cond)
  qed
qed

subsection ‹MDPs without Initial Distribution›
text ‹
From now on, we assume a known, deterministic initial state.
All results from the previous discussion carry over as we are now in the special case
where the initial distribution is of the form @{term "return_pmf s"}.
›
(* Trace measure for a fixed initial state s: T started in return_pmf s. *)
definition "𝒯 p s ≡ T p (return_pmf s)"

(* One-step unfolding of 𝒯 as a bind over the first action, using distr to
   prepend the first state-action pair. *)
lemma 𝒯_eq_return_distr: "𝒯 p s = 
  measure_pmf (p [] s) ⤜ (λa. distr (T (π_Suc p (s,a)) (K (s,a))) S ((##) (s,a)))"
  unfolding 𝒯_def
  by (subst T_eq_distr) (fastforce intro!: bind_distr subprob_space.subprob_space_distr 
      simp: K0_iff map_pmf_rep_eq space_subprob_algebra bind_return_pmf)+

(* The same unfolding in do-notation. *)
lemma 𝒯_eq_return:
  shows "𝒯 p s = do {
    y ← measure_pmf (p [] s);
    ω ← T (π_Suc p (s,y)) (K (s,y)); 
    return S ((s,y) ## ω)
  }"
  by (auto simp: 𝒯_eq_return_distr bind_return_distr' prob_space.not_empty intro!: bind_cong)

(* T with initial distribution S0 is the mixture of the fixed-start measures
   𝒯 p s, with s drawn from S0. *)
lemma 𝒯_return:
  shows "T p S0 = measure_pmf S0 ⤜ 𝒯 p"
proof -
  have "T p S0 = measure_pmf S0 ⤜ (λx. measure_pmf (map_pmf (Pair x) (p [] x))) ⤜ 
  (λsa. distr (T (π_Suc p sa) (K sa)) (stream_space (count_space UNIV)) ((##) sa))"
    unfolding T_eq_distr[of p] K0_iff measure_pmf_bind
    by auto
  (* reassociate the two binds *)
  also have "… = measure_pmf S0 ⤜
    (λx. distr (measure_pmf (p [] x)) (count_space UNIV) (Pair x) ⤜
          (λsa. distr (T (π_Suc p sa) (K sa)) (stream_space (count_space UNIV)) ((##) sa)))"
    using measurable_measure_pmf 
    by (subst bind_assoc[where N = "count_space UNIV", where R = S])
      (fastforce intro!: prob_space_imp_subprob_space prob_space.prob_space_distr 
        simp: space_subprob_algebra prob_space_measure_pmf map_pmf_rep_eq)+
  also have "… = measure_pmf S0 ⤜ 𝒯 p"
    by (subst bind_distr[where K  = S])
      (auto intro!: prob_space_imp_subprob_space prob_space.prob_space_distr bind_cong
        simp: space_subprob_algebra 𝒯_eq_return_distr)
  finally show ?thesis.
qed

(* Unfolding of 𝒯 that also samples the successor state s' explicitly; the
   remaining trace is drawn from T started in return_pmf s'. *)
lemma 𝒯_return_eq:
"𝒯 p s = do {
  a ← measure_pmf (p [] s);
  s' ← measure_pmf (K (s,a));
  w ← T (π_Suc p (s,a)) (return_pmf s');
  return S ((s,a)##w)
}"
  unfolding 𝒯_eq_return
  unfolding 𝒯_return
  by (subst bind_assoc[of _ _ S _ S]) (auto simp add: 𝒯_def 𝒯_return[symmetric])

(* The same recursion stated purely in terms of 𝒯. *)
lemma 𝒯_eq:
  shows "𝒯 p s = do {
  a ← measure_pmf (p [] s);
  s' ← measure_pmf (K (s,a));
  w ← 𝒯 (π_Suc p (s,a)) s';
  return S ((s,a)##w)
}"
  by (subst 𝒯_return_eq) (auto simp add: 𝒯_def )

(* 𝒯 p s is a probability space. *)
lemma 𝒯_prob_space[intro]: "prob_space (𝒯 p s)"
  by (metis 𝒯_def prob_space_T)

(* 𝒯 p s has the measurable sets of the discrete stream space. *)
lemma 𝒯_sets[measurable_cong]: 
  "sets (𝒯 p s) = sets S"
  by (simp add: 𝒯_def sets_T)

(* The identity is measurable from 𝒯 of a successor policy into S. *)
lemma measurable_ident_Suc'[measurable]: 
  "(λx. x) ∈ 𝒯 (π_Suc p sa) s' →⇩M S"
  by (simp add: 𝒯_def)

(* One-step decomposition of the non-negative integral over 𝒯 p s into
   iterated integrals over the first action a, the successor state s' and the
   remaining trace t'. Each step pulls one bind out of the measure. *)
lemma nn_integral_𝒯: 
  fixes f :: "('s × 'a) stream ⇒ real"
  assumes f[measurable]: "f ∈ borel_measurable S"
  shows "(∫⇧+t. f t ∂𝒯 p s) 
    = ∫⇧+a. ∫⇧+s'. ∫⇧+t'. f ((s,a)##t') ∂𝒯 (π_Suc p (s,a)) s' ∂K (s,a) ∂p [] s"
proof -
  have "(∫⇧+t. f t ∂𝒯 p s) = 
  ∫⇧+ x. ∫⇧+ y. (f y) ∂measure_pmf (K (s, x)) ⤜ (λs'. 𝒯 (π_Suc p (s, x)) s' ⤜ (λw. return S ((s, x) ## w))) ∂(p [] s)"
    unfolding 𝒯_eq[of p]
    by (subst nn_integral_bind[of _ S])
      (auto intro!: measure_pmf.bind_in_space subprob_space.bind_in_space simp: 𝒯_prob_space prob_space_imp_subprob_space)
  also have "… = ∫⇧+ x. ∫⇧+ xa. ∫⇧+ y. (f y) ∂𝒯 (π_Suc p (s, x)) xa ⤜ (λw. return S ((s, x) ## w))
              ∂measure_pmf (K (s, x)) ∂(p [] s)"
    by (subst nn_integral_bind[of _ S])
      (auto intro!: subprob_space.bind_in_space simp: 𝒯_prob_space prob_space_imp_subprob_space)
  also have "… = ∫⇧+ x. ∫⇧+ xa. ∫⇧+ y. (f y) ∂distr (𝒯 (π_Suc p (s, x)) xa) S ((##) (s, x)) ∂measure_pmf (K (s, x)) ∂(p [] s)"
    by (auto simp add: bind_return_distr' 𝒯_prob_space prob_space.not_empty)
  also have "… = ∫⇧+ x. ∫⇧+ xa. ∫⇧+ xa. (f ((s, x) ## xa)) ∂𝒯 (π_Suc p (s, x)) xa ∂measure_pmf (K (s, x)) ∂(p [] s)"
    by (auto simp: nn_integral_distr)
  finally show ?thesis.
qed

(* Bochner-integral version of the one-step decomposition over 𝒯 p s, for
   bounded measurable real-valued f. Derived from integral_T. *)
lemma integral_𝒯: 
  fixes f :: "('s × 'a) stream ⇒ real"
  assumes f_bounded: "⋀x. ¦f x¦ ≤ B"
  assumes f[measurable]: "f ∈ borel_measurable S"
  shows "(∫t. f t ∂𝒯 p s) 
    = ∫a. ∫s'. ∫t'. f ((s,a)##t') ∂𝒯 (π_Suc p (s,a)) s' ∂K (s,a) ∂p [] s"
  unfolding 𝒯_def integral_T[OF f_bounded f] K0_iff bind_return_pmf
  unfolding 𝒯_return[of "π_Suc p _"] integral_map_pmf
  using 𝒯_return[of "π_Suc p _", symmetric]
  by (subst integral_bind[OF _ f_bounded, where B' = 1, where K = S]) 
    (auto simp: 𝒯_def intro: prob_space.emeasure_le_1)

(* A measurable function with bounded range is integrable with respect to
   𝒯 p s. *)
lemma integrable_𝒯_bounded[intro]:
  fixes f :: "('s × 'a) stream ⇒ 'd :: {second_countable_topology,banach}"
  assumes f[measurable]: "f ∈ borel_measurable S"
  assumes b: "bounded (range f)"
  shows "integrable (𝒯 p s) f"
  using b
  by (auto simp: prob_space.finite_measure 𝒯_prob_space bounded_iff 
      intro!: finite_measure.integrable_const_bound)

(* Fixed-start variants of Pn, Xn and Yn: the initial distribution is the
   point distribution on s. *)
definition "Pn' p s = Pn p (return_pmf s)"
definition "Xn' p s = Xn p (return_pmf s)"
definition "Yn' p s = Yn p (return_pmf s)"
(* Fixed-start variant of K0: pair every action drawn from d s with s. *)
definition "K0' d s ≡ map_pmf (λa. (s, a)) (d s)"

(* Successor-state distribution for decision rule d in state s: draw an
   action from d s and apply the kernel K. *)
definition "K_st d s ≡ d s ⤜ (λa. K (s,a))"

(* Point probabilities of K_st as an expectation over the action. *)
lemma pmf_K_st: "pmf (K_st d s) t = ∫a. pmf (K(s, a)) t ∂d s"
  unfolding K_st_def by (auto simp: pmf_bind)

text ‹@{const K_st} defines the distribution over the successor states for a given decision rule and state.
It is mostly useful for markovian policies, as the information which action was selected is lost.›

(* At epoch 0, Pn' is K0' of the decision rule for the empty history. *)
lemma P0'[simp]: "Pn' p s 0 = K0' (p []) s"
  by (simp add: Pn'_def K0'_def K0_iff bind_return_pmf)

(* At epoch 0, Xn' is the point distribution on the start state. *)
lemma X0'[simp]: "Xn' p s 0 = return_pmf s"
  using X0 Xn'_def by auto

(* Pn with initial distribution S0 is the S0-mixture of the fixed-start
   distributions. Induction on n with p and S0 arbitrary. *)
lemma Pn_return_pmf: "S0 ⤜ (λs'. Pn p (return_pmf s') n) = Pn p S0 n"
  by (induction n arbitrary: p S0)
    (auto intro: bind_pmf_cong simp add: Pn.simps(2) K0_def bind_assoc_pmf bind_return_pmf)

(* Successor equation for Pn': draw the first pair via K0', the successor
   state via K, and continue with the successor policy from that state. *)
lemma PSuc': "Pn' p s (Suc n) = K0' (p []) s ⤜ (λsa. K sa ⤜ (λs'. Pn' (π_Suc p sa) s' n))"
  unfolding Pn'_def
  by (auto intro!: bind_pmf_cong
      simp: Pn.simps(2) Pn_return_pmf K0_iff K0'_def bind_return_pmf map_bind_pmf bind_map_pmf)

(* For mk_markovian policies, the first step collapses to K_st of the first
   decision rule and the rest uses the shifted sequence of decision rules. *)
lemma PSuc'_markovian: 
  "Pn' (mk_markovian p) s (Suc n) = K_st (p 0) s ⤜ (λs'. Pn' (mk_markovian (p ∘ Suc)) s' n)"
  unfolding PSuc'
  by (auto simp: bind_map_pmf bind_assoc_pmf comp_def K0'_def K_st_def intro!: bind_pmf_cong)

(* Fixed-start version of Xn_Suc. *)
lemma Xn'_Suc: "Xn' p s (Suc n) = Pn' p s n ⤜ K"
  by (auto simp: Xn_Suc Xn'_def Pn'_def)

(* Xn' is the first projection of Pn'. *)
lemma Xn'_Pn': "Xn' p s n = map_pmf fst (Pn' p s n)"
  by (simp add: Xn_def Xn'_def Pn'_def)

(* Successor equation for Xn': first action, successor state, then continue
   with the successor policy. *)
lemma Suc_Xn': "Xn' p s (Suc n) = p [] s ⤜ (λa. K (s,a) ⤜ (λs'. Xn' (π_Suc p (s,a)) s' n))"
  by (auto simp: Xn'_Pn' map_bind_pmf bind_map_pmf PSuc' K0'_def)

(* Successor equation for Xn' of mk_markovian policies, via K_st. *)
lemma Suc_Xn'_markovian: 
  "Xn' (mk_markovian p) s (Suc n) = K_st (p 0) s ⤜ (λs'. Xn' (mk_markovian (λn. p (Suc n))) s' n)"
  by (auto simp: K_st_def bind_assoc_pmf Suc_Xn')

(* The state distribution after n + m epochs splits into n epochs followed by
   m epochs under the decision rules shifted by n. *)
lemma Xn'_split: "Xn' (mk_markovian p) s (n + m) = 
  Xn' (mk_markovian p) s n ⤜ (λs. Xn' (mk_markovian (λi. p (i + n))) s m)"
  by (induction n arbitrary: p s) (auto intro!: bind_pmf_cong simp: bind_assoc_pmf bind_return_pmf Suc_Xn')

(* Fixed-start version of Yn_markovian. *)
lemma Yn'_markovian: "Yn' (mk_markovian p) s n = Xn' (mk_markovian p) s n ⤜ p n"
  unfolding Yn'_def Xn'_def Yn_markovian by simp

(* Fixed-start version of Pn_markovian_eq_Xn_bind, stated with K0'. *)
lemma Pn'_markovian_eq_Xn'_bind: "Pn' (mk_markovian p) s n = Xn' (mk_markovian p) s n ⤜ K0' (p n)"
  unfolding Xn'_def Pn'_def K0'_def K0_iff Pn_markovian_eq_Xn_bind by simp

(* Fixed-start version of Pn_eq_T: Pn' is the distribution of the n-th element
   of a trace drawn from 𝒯. *)
lemma Pn'_eq_𝒯: "measure_pmf (Pn' p s n) = distr (𝒯 p s) (count_space UNIV) (λt. t !! n)"
  by (auto simp: 𝒯_def Pn'_def Pn_eq_T)

end
end