(* Theory MDP_RP_Certification *)

section ‹Certification of Reachability Problems on MDPs›

theory MDP_RP_Certification
imports
  "../MDP_Reachability_Problem"
  "HOL-Library.IArray"
  "HOL-Library.Code_Target_Numeral"
begin

(* Variants of the locale lemmas p_ub and n_lb whose side conditions can be discharged
   directly from a checked certificate (they are used by P_max / P_min further down). *)
context Reachability_Problem
begin

(* Upper bound on the value p s.
   Hypotheses on x:
   1. s is in S, and x is a super-solution: for every s in S1 and every D in K s the
      expected value of x under D is at most x s;
   2. every s in S1 with x s nonzero reaches some state of S2 through the
      reflexive-transitive closure of the edges from S1-states to the supports of
      their distributions;
   3. x vanishes on S - S1 - S2;
   4. x is 1 on S2.
   The remaining premise of p_ub is derived from 2, 3, p_pos and p_S2. *)
lemma p_ub':
  fixes x
  assumes 1: "s  S" "s D. s  S1  D  K s  (tS. pmf D t * x t)  x s"
  assumes 2: "s. s  S1  x s  0  (tS2. (s, t)  (SIGMA s:S1. DK s. set_pmf D)*)"
  assumes 3: "s. s  S - S1 - S2  x s = 0"
  assumes 4: "s. s  S2  x s = 1"
  shows "enn2real (p s)  x s"
proof (rule p_ub[OF 1 _ 4])
  (* Remaining premise of p_ub: a state with p s = 0 must have x s = 0. *)
  fix s assume "s  S" "p s = 0" with 2[of s] p_pos[of s] p_S2[of s] 3[of s] show "x s = 0"
    by (cases "x s = 0") auto
qed

(* Lower bound on the value n s.
   Hypotheses on x, where R is a well-founded relation:
   1. s is in S, and x is a sub-solution: for every s in S1 and every D in K s, x s is
      at most the expected value of x under D;
   2. for every s in S1 with x s nonzero, every D in K s has a support element t that
      is in S2, or satisfies (t, s) in R, t in S1 and x t nonzero;
   3. x vanishes on S - S1 - S2;
   4. x is 1 on S2. *)
lemma n_lb':
  fixes x
  assumes "wf R"
  assumes 1: "s  S" "s D. s  S1  D  K s  x s  (tS. pmf D t * x t)"
  assumes 2: "s D. s  S1  D  K s  x s  0  tD. ((t, s)  R  t  S1  x t  0)  t  S2"
  assumes 3: "s. s  S - S1 - S2  x s = 0"
  assumes 4: "s. s  S2  x s = 1"
  shows "x s  enn2real (n s)"
proof (rule n_lb[OF 1 _ 4])
  (* Remaining premise of n_lb: n s = 0 forces x s = 0, shown by contradiction. *)
  fix s assume *: "s  S" "n s = 0"
  show "x s = 0"
  proof (rule ccontr)
    assume "x s  0"
    (* x s is nonzero, so s is not in S - S1 - S2 (assumption 3), and n s = 0 rules out
       S2 (presumably n_S2 gives n s = 1 there, hence zero_neq_one); so s is in S1. *)
    with * n_S2[of s] n_nS12[of s] 3[of s] have "s  S1"
      by (metis DiffI zero_neq_one)
    (* n_pos with the predicate "x s nonzero" and the well-founded R yields 0 < n s. *)
    have "0 < n s"
      by (intro n_pos[of "λs. x s  0", OF x s  0 s  S1 wf R])
         (metis zero_less_one n_S2 2)
    with n s = 0 show False by auto
  qed
qed

end

(* Stream.snth occupies the infix !!; drop that notation so that !! below refers to
   IArray subscripting. *)
no_notation Stream.snth (infixl !! 100) ― ‹we use @{text "!!"} for IArray›

subsection ‹Computable representation›

(* Executable MDP reachability problem over the states 0 ..< state_count.
   distrs !! i lists the distributions available in state i. Each distribution is a
   sparse list of (successor state, probability) pairs.
   states1 / states2 mark the sets S1 / S2.
   The well-formedness conditions are collected in valid_mdp_rp. *)
record mdp_reachability_problem =
  state_count :: nat
  distrs :: "(nat × rat) list list iarray"
  states1 :: "bool iarray"
  states2 :: "bool iarray"

(* Half of a certificate.
   solution assigns a rational value to every state.
   witness assigns every state some evidence of type 'a together with a rank (the nat
   component). The rank must decrease strictly along the evidence; see valid_pos_cert
   and valid_neg_cert. *)
record 'a RP_sub_cert =
  solution :: "rat iarray"
  witness :: "('a × nat) iarray"

(* Complete certificate.
   In pos_cert the evidence of a state is one pair (successor state, distribution index).
   In neg_cert it is a list with one successor state per available distribution.
   The checked pos_cert solution bounds the maximal reachability probability from above.
   The checked neg_cert solution bounds the minimal one from below (lemmas P_max / P_min). *)
record RP_cert =
  pos_cert :: "(nat × nat) RP_sub_cert"
  neg_cert :: "nat list RP_sub_cert"

(* Inner product of the sparse vector sx, a list of (index, value) pairs, with the dense
   vector y. Indices are not range-checked here. *)
definition "sparse_mult sx y = sum_list (map (λ(n, x). x * y !! n) sx)"

(* Association-list lookup with default: returns the value paired with the first
   occurrence of key x in the list, or d if x does not occur as a key. *)
primrec lookup where
  "lookup d [] x = d"
| "lookup d (y#ys) x = (if fst y = x then snd y else lookup d ys x)"

(* Characterisation of lookup via map_of, with d as the fallback. *)
lemma lookup_eq_map_of: "lookup d xs x = (case map_of xs x of Some x  x | None  d)"
  by (induct xs) simp_all

(* With distinct keys, lookup returns the value stored with a listed key. *)
lemma lookup_in_set:
  "distinct (map fst xs)  x  set xs  lookup d xs (fst x) = snd x"
  unfolding lookup_eq_map_of by (subst map_of_is_SomeI[where y="snd x"]) simp_all

(* A key that does not occur in the list yields the default. *)
lemma lookup_not_in_set:
  "x  fst ` set xs  lookup d xs x = d"
  unfolding lookup_eq_map_of
  by (subst map_of_eq_None_iff[of xs x, THEN iffD2]) auto

(* If all stored values are nonnegative, lookup with default 0 is nonnegative. *)
lemma lookup_nonneg:
  "(x v. (x, v)  set xs  0  v)  (0::'a::ordered_comm_monoid_add)  lookup 0 xs x"
  apply (induction xs)
  apply simp
  apply force
  done

(* A sparse product whose indices are distinct and below M equals the corresponding
   dense sum over the indices 0 ..< M. *)
lemma sparse_mult_eq_sum_lookup:
  fixes xs :: "(nat × 'a::comm_semiring_1) list"
  assumes "list_all (λ(n, x). n < M) xs" "distinct (map fst xs)"
  shows "sparse_mult xs y = (i<M. lookup 0 xs i * y !! i)"
proof -
  (* Distinct keys make the list duplicate-free and fst injective on it. *)
  from distinct (map fst xs) have "distinct xs" "inj_on fst (set xs)"
    by (simp_all add: distinct_map)
  (* Turn the list sum into a set sum, then re-index it by the keys. *)
  then have "sparse_mult xs y = (xset xs. snd x * y !! fst x)"
    by (auto intro!: sum.cong simp add: sparse_mult_def sum_list_distinct_conv_sum_set)
  also have " = (xset xs. lookup 0 xs (fst x) * y !! fst x)"
    by (intro sum.cong refl arg_cong2[where f="(*)"]) (simp add: lookup_in_set assms)
  also have " = (xfst ` set xs. lookup 0 xs x * y !! x)"
    using inj_on fst (set xs) by (simp add: sum.reindex)
  (* Extend to all indices below M; absent keys contribute lookup 0 = 0. *)
  also have " = (x<M. lookup 0 xs x * y !! x)"
    using assms(1)
    by (intro sum.mono_neutral_cong_left)
       (auto simp: list_all_iff lookup_eq_map_of map_of_eq_None_iff[THEN iffD2])
  finally show ?thesis .
qed

(* Analogue of sparse_mult_eq_sum_lookup for the plain sum of the stored values.
   NOTE(review): sparse_mult_def and the arg_cong2 intro rule below look copied from
   sparse_mult_eq_sum_lookup. No sparse_mult or product occurs in these steps.
   Confirm they are unused before removing them. *)
lemma sum_list_eq_sum_lookup:
  fixes xs :: "(nat × 'a::comm_semiring_1) list"
  assumes "list_all (λ(n, x). n < M) xs" "distinct (map fst xs)"
  shows "sum_list (map snd xs) = (i<M. lookup 0 xs i)"
proof -
  from distinct (map fst xs) have "distinct xs" "inj_on fst (set xs)"
    by (simp_all add: distinct_map)
  then have "sum_list (map snd xs) = (xset xs. snd x)"
    by (auto intro!: sum.cong simp add: sparse_mult_def sum_list_distinct_conv_sum_set)
  also have " = (xset xs. lookup 0 xs (fst x))"
    by (intro sum.cong refl arg_cong2[where f="(*)"]) (simp add: lookup_in_set assms)
  also have " = (xfst ` set xs. lookup 0 xs x)"
    using inj_on fst (set xs) by (simp add: sum.reindex)
  (* Extend to all indices below M; absent keys contribute lookup 0 = 0. *)
  also have " = (x<M. lookup 0 xs x)"
    using assms(1)
    by (intro sum.mono_neutral_cong_left)
       (auto simp: list_all_iff lookup_eq_map_of map_of_eq_None_iff[THEN iffD2])
  finally show ?thesis .
qed

(* Well-formedness of an executable MDP:
   - there is at least one state;
   - all three arrays have length state_count;
   - no state is marked in both states1 and states2;
   - every state has at least one distribution;
   - every distribution has distinct successor indices below state_count, nonnegative
     probabilities, and probabilities summing to 1. *)
definition
  "valid_mdp_rp mdp 
    0 < state_count mdp 
    IArray.length (distrs mdp) = state_count mdp 
    IArray.length (states1 mdp) = state_count mdp 
    IArray.length (states2 mdp) = state_count mdp 
    (i<state_count mdp. ¬ (states1 mdp !! i  states2 mdp !! i) 
      list_all (λds. distinct (map fst ds)  list_all (λ(n, x). 0  x  n < state_count mdp) ds 
                     sum_list (map snd ds) = 1) (distrs mdp !! i) 
      ¬ List.null (distrs mdp !! i))"

(* Generic check of a sub-certificate c for mdp. Both arrays of c have length
   state_count, and for every state i:
   - i in S2: the value is 1;
   - i in S1: three conditions hold.
     * The value is nonnegative.
     * ord (sparse_mult D solution) (value of i) holds for every distribution D of i.
     * If the value is positive, check accepts i's distributions and witness.
   - otherwise: the value is 0.
   ord and check are instantiated by valid_pos_cert and valid_neg_cert. *)
definition
  "valid_sub_cert mdp c ord check 
    IArray.length (witness c) = state_count mdp 
    IArray.length (solution c) = state_count mdp 
    (i<state_count mdp.
      if states2 mdp !! i then solution c !! i = 1
      else if states1 mdp !! i then 0  solution c !! i 
        (list_all (λds. ord (sparse_mult ds (solution c)) (solution c !! i)) (distrs mdp !! i)) 
        (0 < solution c !! i  check (distrs mdp !! i) (witness c !! i))
      else solution c !! i = 0)"

(* Sub-certificate for the upper bound on the maximal reachability probability.
   ord is (≤): for every distribution, the expected solution value is at most the
   state's value.
   A positive state in S1 with witness ((j, a), n) must satisfy:
   - j is a state with positive value and rank below n;
   - a indexes one of the state's distributions;
   - distribution a gives j nonzero mass. *)
definition
  "valid_pos_cert mdp c 
    valid_sub_cert mdp c (≤)
      (λD ((j, a), n). j < state_count mdp  snd (witness c !! j) < n  0 < solution c !! j 
        a < length D  lookup 0 (D ! a) j  0)"

(* Sub-certificate for the lower bound on the minimal reachability probability.
   ord is (≥): for every distribution, the expected solution value is at least the
   state's value.
   A positive state in S1 with witness (J, n) must satisfy the following for each of its
   distributions d (list_all2 pairs J and the distributions by position): the matching
   state j has nonzero mass in d, positive value, and rank below n. *)
definition
  "valid_neg_cert mdp c 
    valid_sub_cert mdp c (≥)
      (λD (J, n). list_all2 (λj d. j < state_count mdp  snd (witness c !! j) < n 
        lookup 0 d j  0  0 < solution c !! j) J D)"

(* A certificate is valid when both of its halves are. *)
definition
  "valid_cert mdp c  valid_pos_cert mdp (pos_cert c)  valid_neg_cert mdp (neg_cert c)"

(* The non-emptiness and array-length facts contained in valid_mdp_rp. *)
lemma valid_mdp_rpD_length:
  assumes "valid_mdp_rp mdp"
  shows "0 < state_count mdp" "IArray.length (distrs mdp) = state_count mdp"
    "IArray.length (states1 mdp) = state_count mdp" "IArray.length (states2 mdp) = state_count mdp"
  using assms by (auto simp: valid_mdp_rp_def)

(* Per-state facts contained in valid_mdp_rp, stated as destruction rules:
   (1) S1/S2 disjointness, (2) successor indices in range, (3) nonnegative probabilities,
   (4) probabilities sum to 1, (5) distinct successors, (6) at least one distribution. *)
lemma valid_mdp_rpD:
  assumes "valid_mdp_rp mdp" "i < state_count mdp"
  shows "¬ (states1 mdp !! i  states2 mdp !! i)"
    and "ds n x. ds  set (distrs mdp !! i)  (n, x)  set ds  n < state_count mdp"
    and "ds n x. ds  set (distrs mdp !! i)  (n, x)  set ds  0  x"
    and "ds. ds  set (distrs mdp !! i)  sum_list (map snd ds) = 1"
    and "ds. ds  set (distrs mdp !! i)  distinct (map fst ds)"
    and "distrs mdp !! i  []"
  using assms by (auto simp: valid_mdp_rp_def list_all_iff List.null_def elim!: allE[of _ i])

(* For a distribution of a valid state, sparse_mult equals the dense sum over all states.
   NOTE(review): the summation variable in the conclusion shadows the state i of the
   assumptions. *)
lemma valid_mdp_rp_sparse_mult:
  assumes "valid_mdp_rp mdp" "i < state_count mdp" "ds  set (distrs mdp !! i)"
  shows "sparse_mult ds y = (i<state_count mdp. lookup 0 ds i * y !! i)"
  using valid_mdp_rpD(2,5)[OF assms] by (intro sparse_mult_eq_sum_lookup) (auto simp: list_all_iff)

(* Destruction rules for valid_sub_cert, one per case of its definition:
   (1) outside S1 and S2, (2) S2, (3) nonnegativity on S1, (4) the ord condition on S1,
   (5) the check on positive S1 states.
   NOTE(review): the bound ds in rule (5) is unused. *)
lemma valid_sub_certD:
  assumes "valid_mdp_rp mdp" "valid_sub_cert mdp c ord check" "i < state_count mdp"
  shows "¬ states1 mdp !! i  ¬ states2 mdp !! i  solution c !! i = 0"
    and "states2 mdp !! i  solution c !! i = 1"
    and "states1 mdp !! i  0  solution c !! i"
    and "ds. states1 mdp !! i  ds  set (distrs mdp !! i)  ord (sparse_mult ds (solution c)) (solution c !! i)"
    and "ds. states1 mdp !! i  0 < solution c !! i  check (distrs mdp !! i) (witness c !! i)"
  using assms(2,3) valid_mdp_rpD(1)[OF assms(1,3)]
  by (auto simp add: valid_sub_cert_def list_all_iff)

(* Unfolds the check of valid_pos_cert for a state i in S1 with positive value and
   witness ((j, a), n). *)
lemma valid_pos_certD:
  assumes "valid_mdp_rp mdp" "valid_pos_cert mdp c" "i < state_count mdp" "states1 mdp !! i"
    "0 < solution c !! i" "witness c !! i = ((j, a), n)"
  shows "snd (witness c !! j) < n  j < state_count mdp  a < length (distrs mdp !! i) 
          lookup 0 ((distrs mdp !! i) ! a) j  0  0 < solution c !! j"
  using valid_sub_certD(5)[OF assms(1) assms(2)[unfolded valid_pos_cert_def] assms(3,4)] assms(5-) by auto

(* Unfolds the check of valid_neg_cert for a state i in S1 with positive value and
   witness (js, n). *)
lemma valid_neg_certD:
  assumes "valid_mdp_rp mdp" "valid_neg_cert mdp c" "i < state_count mdp" "states1 mdp !! i"
    "0 < solution c !! i" "witness c !! i = (js, n)"
  shows "list_all2 (λj ds. j < state_count mdp  snd (witness c !! j) < n  lookup 0 ds j  0  0 < solution c !! j) js (distrs mdp !! i)"
  using valid_sub_certD(5)[OF assms(1) assms(2)[unfolded valid_neg_cert_def] assms(3)] assms(4-) by auto

(* Soundness of the certificate checker.
   In this context the MDP description mdp is assumed well-formed (rp) and the
   certificate c is assumed to pass valid_cert (cert). The executable data is turned into
   an interpretation MDP of the Reachability_Problem locale. The final lemma shows that
   the solution vectors of the two sub-certificates bound P_max from above and P_min
   from below. *)
context
  fixes mdp c
  assumes rp: "valid_mdp_rp mdp"
  assumes cert: "valid_cert mdp c"
begin

(* Presumably provides the lifting setup that lets K below be defined from mass
   functions given as plain functions. TODO confirm against HOL-Probability. *)
interpretation pmf_as_function .

(* State space and the two distinguished state sets, read off the executable arrays. *)
abbreviation "S  {..< state_count mdp}"
abbreviation "S1  {i. i < state_count mdp  (states1 mdp) !! i}"
abbreviation "S2  {i. i < state_count mdp  (states2 mdp) !! i}"

(* K i is the set of distributions available in state i. Each sparse list D in
   distrs mdp !! i is read as the mass function j ↦ lookup 0 D j. Indices outside the
   state space get the point mass on state 0; presumably this only serves to make K
   total. The proof obligations (nonnegative masses, total mass 1) follow from
   valid_mdp_rp. *)
lift_definition K :: "nat  nat pmf set" is
  "λi. if i < state_count mdp then
     { (λj. of_rat (lookup 0 D j) :: real) | D. D  set (distrs mdp !! i) }
     else { indicator {0} }"
proof (auto split: if_split_asm simp del: IArray.sub_def)
  fix n D assume n: "n < state_count mdp" and D: "D  set (distrs mdp !! n)"
  from valid_mdp_rpD(3)[OF rp this] show nn: "i. 0  lookup 0 D i"
    by (auto simp add: lookup_eq_map_of split: option.split dest: map_of_SomeD)
  (* Total mass: the integral over the counting measure reduces to a finite sum over
     the state space, which sum_list_eq_sum_lookup relates to the listed probabilities
     (summing to 1 by valid_mdp_rp). *)
  show "(+ x. ennreal (real_of_rat (lookup 0 D x)) count_space UNIV) = 1"
    using valid_mdp_rpD(2,3,4,5)[OF rp n D]
    apply (subst nn_integral_count_space'[of "{..< state_count mdp}"])
    apply (auto intro: nn lookup_not_in_set simp: of_rat_sum[symmetric] lookup_nonneg)
    apply (subst sum_list_eq_sum_lookup[symmetric])
    apply (auto simp: list_all_iff lookup_eq_map_of split: option.split)
    done
next
  show "(+ x. ennreal (indicator {0} x) count_space UNIV) = 1"
    by (subst nn_integral_count_space'[of "{0}"]) auto
qed

(* The executable data satisfies the Reachability_Problem locale assumptions:
   - S1 and S2 are disjoint subsets of S, and S is finite and nonempty;
   - every K s is finite and nonempty;
   - distributions of states in S are supported in S. *)
interpretation MDP: Reachability_Problem K S S1 S2
proof
  show "S1  S2 = {}" "S1  S" "S2  S"
    using valid_mdp_rpD(1)[OF rp] by auto
  show "finite S" "S  {}"
    using valid_mdp_rp mdp by (auto simp add: valid_mdp_rp_def)
  show "s. K s  {}"
    using valid_mdp_rpD(6)[OF rp] by transfer simp
  show "s. finite (K s)"
    by transfer simp

  fix s assume "s  S" then show "(DK s. set_pmf D)  S"
    using valid_mdp_rpD(2)[OF rp]
    by transfer (auto simp: lookup_eq_map_of split: option.splits dest!: map_of_SomeD)
qed

(* MDP.p and MDP.n of the interpreted locale as reals; presumably the maximal and
   minimal probability of reaching S2. *)
definition "P_max s = enn2real (MDP.p s)"
definition "P_min s = enn2real (MDP.n s)"

(* Main soundness result. For every state i:
   - the pos_cert solution is an upper bound on P_max i;
   - the neg_cert solution is a lower bound on P_min i. *)
lemma
  assumes "i < state_count mdp"
  shows P_max: "P_max i  real_of_rat (solution (pos_cert c) !! i)" (is ?max)
    and P_min: "P_min i  real_of_rat (solution (neg_cert c) !! i)" (is ?min)
proof -
  have "valid_pos_cert mdp (pos_cert c)" "valid_neg_cert mdp (neg_cert c)"
    using valid_cert mdp c by (auto simp: valid_cert_def)
  note pos = this(1)[unfolded valid_pos_cert_def] and neg = this(2)[unfolded valid_neg_cert_def]

  (* Upper bound: apply MDP.p_ub' with x := the pos_cert solution. *)
  let ?x = "λs. real_of_rat (solution (pos_cert c) !! s)"
  have "enn2real (MDP.p i)  ?x i"
  proof (rule MDP.p_ub')
    show "i  S" using assms by simp
  next
    (* Super-solution condition: D is the pmf of the j-th listed distribution, and
       valid_sub_cert with ord (≤) bounds its sparse product by the state's value. *)
    fix s D assume "s  S1" "D  K s"
    then obtain j where j: "j < length (distrs mdp !! s)"
      "i. i < state_count mdp  pmf D i = real_of_rat (lookup 0 (distrs mdp !! s ! j) i)"
      by transfer (auto simp: in_set_conv_nth)
    with valid_sub_certD(4)[OF valid_mdp_rp mdp pos, of s "distrs mdp !! s ! j"] s  S1
         valid_mdp_rp_sparse_mult[OF valid_mdp_rp mdp, of s "distrs mdp !! s ! j" "solution (pos_cert c)"]
    show "(tS. pmf D t * ?x t)  ?x s"
      by (simp add: of_rat_mult[symmetric] of_rat_sum[symmetric] of_rat_less_eq j)
  next
    fix s a assume "s  S2" then show "?x s = 1"
      using valid_sub_certD[OF valid_mdp_rp mdp pos] by simp
  next
    (* Reachability condition: follow witness edges by induction on the witness rank.
       Each step reaches a successor t with positive value and strictly smaller rank.
       If t is outside S1, its positive value forces t into S2. *)
    fix s define X where "X = (SIGMA s:S1. DK s. set_pmf D)"
    assume "s  S1" "?x s  0"
    with valid_sub_certD(3)[OF rp pos, of s]
    have "0 < ?x s"
      by simp
    with sS1 show "tS2. (s, t)  X*"
    proof (induction n"snd (witness (pos_cert c) !! s)" arbitrary: s rule: less_induct)
      case (less s)
      obtain t a n where eq: "witness (pos_cert c) !! s = ((t, a), n)"
        by (metis prod.exhaust)
      from valid_pos_certD[OF rp valid_pos_cert mdp (pos_cert c) _ _ _ this] less.prems
      have ord: "snd (witness (pos_cert c) !! t) < snd (witness (pos_cert c) !! s)"
        and t: "lookup 0 (distrs mdp !! s ! a) t  0" "0 < ?x t" "tS" "a < length (distrs mdp !! s)"
        unfolding eq by auto
      (* The witness edge (s, t) is an edge of X: t has nonzero mass in distribution a. *)
      with sS1 have X: "(s, t)  X"
        unfolding X_def
        by (transfer fixing: s t a c)
           (auto simp: X_def in_set_conv_nth
                 intro!: exI[of _ "λj. real_of_rat (lookup 0 (distrs mdp !! s ! a) j)"]
                         exI[of _ "distrs mdp !! s ! a"] exI[of _ a])
      show ?case
      proof cases
        assume "t  S1"
        with less.hyps[OF ord _ 0 < ?x t] X show ?thesis
          by auto
      next
        assume "t  S1"
        with valid_sub_certD[OF valid_mdp_rp mdp pos, of t] 0 < ?x t tS
        have "t  S2"
          by auto
        with X show ?thesis
          by auto
      qed
    qed
  next
    fix s assume "s  S - S1 - S2" then show "?x s = 0"
      using valid_sub_certD(1)[OF valid_mdp_rp mdp pos, of s] by simp
  qed
  then show ?max
    by (simp add: P_max_def)

  (* Lower bound: apply MDP.n_lb' with x := the neg_cert solution. *)
  let ?x = "λs. real_of_rat (solution (neg_cert c) !! s)"
  have "?x i  enn2real (MDP.n i)"
  proof (rule MDP.n_lb')
    show "i  S" using assms by simp
  next
    (* Sub-solution condition, as above but with ord (≥). *)
    fix s D assume "s  S1" "D  K s"
    then obtain j where j: "j < length (distrs mdp !! s)"
      "i. i < state_count mdp  pmf D i = real_of_rat (lookup 0 (distrs mdp !! s ! j) i)"
      by transfer (auto simp: in_set_conv_nth)
    with valid_sub_certD(4)[OF valid_mdp_rp mdp neg, of s "distrs mdp !! s ! j"] s  S1
         valid_mdp_rp_sparse_mult[OF valid_mdp_rp mdp, of s "distrs mdp !! s ! j" "solution (neg_cert c)"]
    show "?x s  (tS. pmf D t * ?x t)"
      by (simp add: of_rat_mult[symmetric] of_rat_sum[symmetric] of_rat_less_eq j)
  next
    fix s a assume "s  S2" then show "?x s = 1"
      using valid_sub_certD[OF valid_mdp_rp mdp neg] by simp
  next
    (* The ranking relation ?F contains (t, s) when s and t are states and the witness
       rank of t is strictly smaller than that of s. It is well-founded because S is
       finite and the rank order is acyclic. *)
    show "wf ((S × S  {(s, t). snd (witness (neg_cert c) !! t) < snd (witness (neg_cert c) !! s)})¯)" (is "wf ?F")
      using MDP.S_finite
        by (intro finite_acyclic_wf_converse acyclicI_order[where f="λs. snd (witness (neg_cert c) !! s)"]) auto

    (* Progress condition. Let D be distribution a of a state s in S1 with nonzero
       value. The neg witness of s names a successor js ! a that has nonzero mass under D,
       smaller rank and positive value. That successor is in S1, or else (by
       valid_sub_cert) in S2. *)
    fix s D assume 2: "s  S1" "D  K s" and "?x s  0"
    then have "0 < ?x s"
      using valid_sub_certD(3)[OF valid_mdp_rp mdp neg, of s] by auto

    from 2 obtain a where a: "a < length (distrs mdp !! s)"
      "i. i < state_count mdp  pmf D i = real_of_rat (lookup 0 (distrs mdp !! s ! a) i)"
      by transfer (auto simp: in_set_conv_nth)

    obtain js n where eq: "witness (neg_cert c) !! s = (js, n)"
      by (metis prod.exhaust)
    from valid_neg_certD[OF valid_mdp_rp mdp valid_neg_cert mdp (neg_cert c) _ _ _ eq] a s  S1 0 < ?x s
    have *: "length js = length (distrs mdp !! s)" "js ! a  S"
      "snd (witness (neg_cert c) !! (js ! a)) < snd (witness (neg_cert c) !! s)"
      "lookup 0 (distrs mdp !! s ! a) (js ! a)  0"
      "0 < ?x (js ! a)"
      unfolding eq by (auto dest: list_all2_nthD2 list_all2_lengthD)
    with a s  S1 have js_a: "js ! a  D" "(js ! a, s)  ?F"
      by (auto simp: set_pmf_iff)

    show "tD. (t, s)  ?F  t  S1  ?x t  0  t  S2"
    proof cases
      assume "js ! a  S1" with js_a 0 < ?x (js ! a) show ?thesis by auto
    next
      assume "js ! a  S1"
      with 0 < ?x (js ! a) js!a  S valid_sub_certD[OF rp neg, of "js ! a"]
      have "js ! a  S2"
        by (auto simp:  less_le)
      with js ! a  D show ?thesis
        by auto
    qed
  next
    fix s assume "s  S - S1 - S2" then show "?x s = 0"
      using valid_sub_certD(1)[OF valid_mdp_rp mdp neg, of s] by simp
  qed
  then show ?min
    by (simp add: P_min_def)
qed

end

end