subsection ‹The Bellman-Ford Algorithm›

theory Bellman_Ford
  imports
    "HOL-Library.Extended"
    "HOL-Library.IArray"
    "HOL-Library.Code_Target_Numeral"
    "HOL-Library.Product_Lexorder"
    "HOL-Library.RBT_Mapping"
    "../state_monad/State_Main"
    "../heap_monad/Heap_Main"
    Example_Misc
    "../util/Tracing"
    "../util/Ground_Function"
begin

subsubsection ‹Misc›

(*
  Lift countability from 'a to 'a extended.
  The injection maps ∞ to 0, -∞ to 1 and Fin n to to_nat n + 2, so the three
  constructor cases never collide.
  The heap instance needs this, because Imperative HOL's heap class includes
  countable.
*)
instance extended :: (countable) countable
proof standard
  obtain to_nat :: "'a ⇒ nat" where "inj to_nat"
    by auto
  let ?f = "λ x. case x of Fin n ⇒ to_nat n + 2 | Pinf ⇒ 0 | Minf ⇒ 1"
  from ‹inj _ › have "inj ?f"
    by (auto simp: inj_def split: extended.split)
  then show "∃to_nat :: 'a extended ⇒ nat. inj to_nat"
    by auto
qed

(*
  Allow values of type 'a extended (for 'a :: heap) to be stored on the
  Imperative HOL heap.
  The imperative memoization below keeps int extended option values in arrays.
*)
instance extended :: (heap) heap ..

(*
  Invariant rule for List.fold.
  If P holds for the initial accumulator and every step f x preserves P, then
  P holds for the result of the fold.
  Proved by induction on xs, generalizing over the accumulator.
*)
lemma fold_acc_preserv:
  assumes "⋀ x acc. P acc ⟹ P (f x acc)" "P acc"
  shows "P (fold f xs acc)"
  using assms(2) by (induction xs arbitrary: acc) (auto intro: assms(1))

(*
  Reading the state and returning a function of it: from state m the
  computation yields f m and leaves the state unchanged.
*)
lemma get_return:
  "run_state (State_Monad.bind State_Monad.get (λ m. State_Monad.return (f m))) m = (f m, m)"
  by (simp add: State_Monad.bind_def State_Monad.get_def)


subsubsection ‹Single-Sink Shortest Path Problem›

(*
  Possible outcomes of a shortest-path query.
  NOTE(review): presumably Path xs c carries a path as a node list together
  with its integer cost, No_Path signals an unreachable sink, and
  Computation_Error an internal failure. Confirm against the users of this type.
*)
datatype bf_result = Path "nat list" int | No_Path | Computation_Error

(*
  The graph is fixed for the rest of the development.
  - Nodes are the naturals 0..n: OPT only considers paths through {0..n}, and
    bf ranges over [0..<Suc n].
  - W i j is the weight of the edge from i to j. Presumably ∞ encodes a
    missing edge, which is how graph_of in the test section builds W.
*)
context
  fixes n :: nat and W :: "nat ⇒ nat ⇒ int extended"
begin

(*
  All paths considered from here on end in the fixed sink node t.
  weight folds backwards starting from t, and the empty path costs 0 exactly
  when the start node is t.
*)
context
  fixes t :: nat ― ‹Final node›
begin

text ‹
  The correctness proof closely follows Kleinberg ‹&› Tardos: "Algorithm Design",
  chapter "Dynamic Programming" @{cite "Kleinberg-Tardos"}
›

(*
  Total weight of the path that visits the nodes of the list in order and then
  takes a final edge into t.
  The fold walks the reversed list starting from (t, 0). Each node i is paired
  with its successor j, and W i j is added to the accumulated weight.
  Consequences:
  - weight [v] = W v t
  - weight [] = 0, since folding over [] returns (t, 0)
*)
definition weight :: "nat list ⇒ int extended" where
  "weight xs = snd (fold (λ i (j, x). (i, W i j + x)) (rev xs) (t, 0))"

(*
  OPT i v is the minimum weight of a path from v to t that uses at most i
  edges and whose intermediate nodes all lie in {0..n}.
  - A candidate weight (v # xs) has length xs + 1 edges, counting the final
    edge into t. Nodes may repeat, since no distinctness is required.
  - The extra singleton stands for the trivial path. It costs 0 if v = t and ∞
    otherwise, and it also keeps the set non-empty, so Min is well-defined.
*)
definition
  "OPT i v = (
    Min (
      {weight (v # xs) | xs. length xs + 1 ≤ i ∧ set xs ⊆ {0..n}} ∪
      {if t = v then 0 else ∞}
    )
  )"

(* Recursive characterization of weight: the first edge plus the weight of the rest. *)
lemma weight_Cons [simp]:
  "weight (v # w # xs) = W v w + weight (w # xs)"
  by (simp add: case_prod_beta' weight_def)

(* A single node contributes only its edge into the sink t. *)
lemma weight_single [simp]:
  "weight [v] = W v t"
  by (simp add: weight_def)

(*
  Adding a constant on the right commutes with Min over a finite, non-empty
  set of int extended values.
  The two inequalities are proved separately.
*)
(* XXX Generalize to the right type class *)
lemma Min_add_right:
  "Min S + (x :: int extended) = Min ((λ y. y + x) ` S)" (is "?A = ?B") if "finite S" "S ≠ {}"
proof -
  have "?A ≤ ?B"
    using that by (force intro: Min.boundedI add_right_mono)
  moreover have "?B ≤ ?A"
    using that by (force intro: Min.boundedI)
  ultimately show ?thesis
    by simp
qed

(*
  With zero edges no candidate satisfies length xs + 1 ≤ 0, so only the
  trivial path remains.
*)
lemma OPT_0:
  "OPT 0 v = (if t = v then 0 else ∞)"
  unfolding OPT_def by simp

(*
  ∞ absorbs any summand on its right.
  Case analysis on x shows this holds for finite values, ∞ and -∞ alike.
*)
(* TODO: Move to distribution! *)
lemma Pinf_add_right[simp]:
  "∞ + x = ∞"
  by (cases x; simp)

subsubsection ‹Functional Correctness›

(*
  Bellman equation for OPT.
  A cheapest path from v to t with at most Suc i edges does one of two things:
  - it already uses at most i edges, or
  - it starts with an edge v → w (w ≤ n), followed by a cheapest path from w
    with at most i edges.
  The assumption t ≤ n ensures that w = t, i.e. the direct edge v → t, is
  among the candidates w ≤ n.
*)
lemma OPT_Suc:
  "OPT (Suc i) v = min (OPT i v) (Min {OPT i w + W v w | w. w ≤ n})" (is "?lhs = ?rhs")
  if "t ≤ n"
proof -
  (* The candidate sets are finite, so the Min in OPT is well-defined and attained. *)
  have fin': "finite {xs. length xs ≤ i ∧ set xs ⊆ {0..n}}" for i
    by (auto intro: finite_subset[OF _ finite_lists_length_le[OF finite_atLeastAtMost]])
  have fin: "finite {weight (v # xs) |xs. length xs ≤ i ∧ set xs ⊆ {0..n}}"
    for v i using [[simproc add: finite_Collect]] by (auto intro: finite_subset[OF _ fin'])
  (* For i > 0, OPT i v is one of its own candidates. *)
  have OPT_in: "OPT i v ∈
    {weight (v # xs) | xs. length xs + 1 ≤ i ∧ set xs ⊆ {0..n}} ∪
    {if t = v then 0 else ∞}"
    if "i > 0" for i v
    using that unfolding OPT_def
    by - (rule Min_in, auto 4 3 intro: finite_subset[OF _ fin, of _ v "Suc i"])

  (* Direction ?lhs ≤ ?rhs: OPT (Suc i) v lies below OPT i v and below every OPT i w + W v w. *)
  have "OPT i v ≥ OPT (Suc i) v"
    unfolding OPT_def using fin by (auto 4 3 intro: Min_antimono)
  (* Prepending v to a path from w turns a candidate for OPT i w into one for OPT (Suc i) v. *)
  have subs:
    "(λy. y + W v w) ` {weight (w # xs) |xs. length xs + 1 ≤ i ∧ set xs ⊆ {0..n}}
    ⊆ {weight (v # xs) |xs. length xs + 1 ≤ Suc i ∧ set xs ⊆ {0..n}}" if ‹w ≤ n› for v w
    using ‹w ≤ n› apply clarify
    subgoal for _ _ xs
      by (rule exI[where x = "w # xs"]) (auto simp: algebra_simps)
    done
  (* Candidate w = t. *)
  have "OPT i t + W v t ≥ OPT (Suc i) v"
    unfolding OPT_def using subs[OF ‹t ≤ n›, of v] that
    by (subst Min_add_right)
       (auto 4 3
         simp: Bellman_Ford.weight_single
         intro: exI[where x = "[]"] finite_subset[OF _ fin[of _ "Suc i"]] intro!: Min_antimono
       )
  (* Candidates w ≠ t, for a start node v ≠ t. *)
  moreover have "OPT i w + W v w ≥ OPT (Suc i) v" if "w ≤ n" ‹w ≠ t› ‹t ≠ v› for w
    unfolding OPT_def using subs[OF ‹w ≤ n›, of v] that
    by (subst Min_add_right)
       (auto 4 3 intro: finite_subset[OF _ fin[of _ "Suc i"]] intro!: Min_antimono)
  (* Candidates w ≠ t, for the start node v = t. *)
  moreover have "OPT i w + W t w ≥ OPT (Suc i) t" if "w ≤ n" ‹w ≠ t› for w
    unfolding OPT_def
    apply (subst Min_add_right)
      prefer 3
    using ‹w ≠ t›
      apply simp
      apply (cases "i = 0")
       apply (simp; fail)
    using subs[OF ‹w ≤ n›, of t]
    by (subst (2) Min_insert)
       (auto 4 4
         intro: finite_subset[OF _ fin[of _ "Suc i"]] exI[where x = "[]"] intro!: Min_antimono
       )
  ultimately have "Min {local.OPT i w + W v w |w. w ≤ n} ≥ OPT (Suc i) v"
    by (auto intro!: Min.boundedI)
  with ‹OPT i v ≥ _› have "?lhs ≤ ?rhs"
    by simp

  (* Direction ?lhs ≥ ?rhs: case analysis on the candidate that attains OPT (Suc i) v. *)
  from OPT_in[of "Suc i" v] consider
    "OPT (Suc i) v = ∞" | "t = v" "OPT (Suc i) v = 0" |
    xs where "OPT (Suc i) v = weight (v # xs)" "length xs ≤ i" "set xs ⊆ {0..n}"
    by (auto split: if_split_asm)
  then have "?lhs ≥ ?rhs"
  proof cases
    case 1
    then show ?thesis
      by simp
  next
    case 2
    then have "OPT i v ≤ OPT (Suc i) v"
      unfolding OPT_def using [[simproc add: finite_Collect]]
      by (auto 4 4 intro: finite_subset[OF _ fin', of _ "Suc i"] intro!: Min_le)
    then show ?thesis
      by (rule min.coboundedI1)
  next
    case xs: 3
    note [simp] = xs(1)
    (*
      A path that uses all Suc i edges (length xs = i) is split into its first
      edge v → hd xs and the rest. A shorter path is already a candidate for
      OPT i v.
    *)
    show ?thesis
    proof (cases "length xs = i")
      case True
      show ?thesis
      proof (cases "i = 0")
        case True
        with xs have "OPT (Suc i) v = W v t"
          by simp
        also have "W v t = OPT i t + W v t"
          unfolding OPT_def using ‹i = 0› by auto
        also have "… ≥ Min {OPT i w + W v w |w. w ≤ n}"
          using ‹t ≤ n› by (auto intro: Min_le)
        finally show ?thesis
          by (rule min.coboundedI2)
      next
        case False
        with ‹_ = i› have "xs ≠ []"
          by auto
        with xs have "weight xs ≥ OPT i (hd xs)"
          unfolding OPT_def
          by (intro Min_le[rotated] UnI1 CollectI exI[where x = "tl xs"])
             (auto 4 3 intro: finite_subset[OF _ fin, of _ "hd xs" "Suc i"] dest: list.set_sel(2))
        have "Min {OPT i w + W v w |w. w ≤ n} ≤ W v (hd xs) + OPT i (hd xs)"
          using ‹set xs ⊆ _› ‹xs ≠ []› by (force simp: add.commute intro: Min_le)
        also have "… ≤ W v (hd xs) + weight xs"
          using ‹_ ≥ OPT i (hd xs)› by (metis add_left_mono)
        also from ‹xs ≠ []› have "… = OPT (Suc i) v"
          by (cases xs) auto
        finally show ?thesis
          by (rule min.coboundedI2)
      qed
    next
      case False
      with xs have "OPT i v ≤ OPT (Suc i) v"
        by (auto 4 4 intro: Min_le finite_subset[OF _ fin, of _ v "Suc i"] simp: OPT_def)
      then show ?thesis
        by (rule min.coboundedI1)
    qed
  qed
  with ‹?lhs ≤ ?rhs› show ?thesis
    by (rule order.antisym)
qed

(*
  Executable Bellman-Ford recurrence.
  bf k j is the cheapest cost from j to t using at most k edges; bf_correct
  shows it equals OPT k j when t ≤ n.
  Round Suc k takes the minimum of:
  - bf k j, i.e. not using the extra edge, and
  - for every node i ≤ n, W j i + bf k i, i.e. the edge j → i followed by the
    best continuation from i.
*)
fun bf :: "nat ⇒ nat ⇒ int extended" where
  "bf 0 j = (if t = j then 0 else ∞)"
| "bf (Suc k) j = min_list
      (bf k j # [W j i + bf k i . i ← [0 ..< Suc n]])"

(*
  Replace the default simp rules of bf by a variant in which min_list is
  unfolded via min_list_fold (presumably into a fold of min).
  The thm commands only print the results.
*)
lemmas [simp del] = bf.simps
lemmas [simp] = bf.simps[unfolded min_list_fold]
thm bf.simps
thm bf.induct

(*
  The recursive bf computes OPT, provided the sink is a valid node (t ≤ n).
  Induction on i:
  - the base case follows from OPT_0;
  - the step uses OPT_Suc and rewrites its Min over a set comprehension into
    a Min over the list that bf iterates over. Rule * does this, swapping the
    operands of + with add.commute.
*)
lemma bf_correct:
  "OPT i j = bf i j" if ‹t ≤ n›
proof (induction i arbitrary: j)
  case 0
  then show ?case
    by (simp add: OPT_0)
next
  case (Suc i)
  have *:
    "{bf i w + W j w |w. w ≤ n} = set (map (λw. W j w + bf i w) [0..<Suc n])"
    by (fastforce simp: add.commute image_def)
  from Suc ‹t ≤ n› show ?case
    by (simp add: OPT_Suc del: upt_Suc, subst Min.set_eq_fold[symmetric], auto simp: *)
qed


subsubsection ‹Functional Memoization›

(*
  Derive a state-monad memoized version of bf from its equations bf.simps.
  The memory is given by dp_consistency_mapping, presumably a Mapping from
  argument pairs to results.
  - bfm' is the monadified function and bfm_def the memoized wrapper; the thm
    command only displays them.
  - memoize_correct proves the correspondence, and print_theorems lists the
    facts it produced.
  - bfm.memoized_correct is then registered as a code equation.
*)
memoize_fun bfm: bf with_memory dp_consistency_mapping monadifies (state) bf.simps

text ‹Generated Definitions›
context includes state_monad_syntax begin
thm bfm'.simps bfm_def
end

text ‹Correspondence Proof›
memoize_correct
  by memoize_prover
print_theorems
lemmas [code] = bfm.memoized_correct

(*
  Row-major iteration over the (n + 1) × (n + 1) table of pairs (x, y) with
  x, y ≤ n.
  - The step function increments y and wraps to (x + 1, 0) after y = n.
  - x * (n + 1) + y is the position of (x, y) in this order, presumably the
    size/termination measure the iterator locale requires.
  The obligations are discharged by table_iterator_up.
*)
interpretation iterator
  "λ (x, y). x ≤ n ∧ y ≤ n"
  "λ (x, y). if y < n then (x, y + 1) else (x + 1, 0)"
  "λ (x, y). x * (n + 1) + y"
  by (rule table_iterator_up)

(*
  Bottom-up memoization of bf in the state monad.
  - Memory: a Mapping from pairs to int extended. The first argument is a
    trivially true predicate on the memory, presumably its invariant.
  - Lookup and update are implemented with get/set on that Mapping.
  - The table is traversed in the row-major order of the iterator above,
    starting from Mapping.empty.
  Row-major order respects the dependencies of bf: row Suc k only reads row k.
*)
interpretation bottom_up: dp_consistency_iterator_empty
  "λ (_::(nat × nat, int extended) mapping). True"
  "λ (x, y). bf x y"
  "λ k. do {m ← State_Monad.get; State_Monad.return (Mapping.lookup m k :: int extended option)}"
  "λ k v. do {m ← State_Monad.get; State_Monad.set (Mapping.update k v m)}"
  "λ (x, y). x ≤ n ∧ y ≤ n"
  "λ (x, y). if y < n then (x, y + 1) else (x + 1, 0)"
  "λ (x, y). x * (n + 1) + y"
  Mapping.empty ..

(*
  Run bfm' on every table cell from the given start pair onwards, in row-major
  order, so that the results are stored in the memory.
  iter_bf_unfold gives the explicit recursion.
*)
definition
  "iter_bf = iter_state (λ (x, y). bfm' x y)"

(*
  Code equation for iter_bf.
  While (i, j) is inside the table, compute bfm' i j and continue with the next
  cell in row-major order. Otherwise stop with return ().
*)
lemma iter_bf_unfold[code]:
  "iter_bf = (λ (i, j).
    (if i ≤ n ∧ j ≤ n
     then do {
            bfm' i j;
            iter_bf (if j < n then (i, j + 1) else (i + 1, 0))
          }
     else State_Monad.return ()))"
  unfolding iter_bf_def by (rule ext) (safe, clarsimp simp: iter_state_unfold)

(*
  Correctness facts for the state-monad memoization.
  - bf_memoized: presumably bf agrees with running bfm' from an empty memory.
  - bf_bottom_up: the same after first filling the table with iter_bf
    (iter_state is folded back into iter_bf). bottom_up_alt instantiates it
    with the start pair (0, 0).
  The thm command only displays the results.
*)
lemmas bf_memoized = bfm.memoized[OF bfm.crel]
lemmas bf_bottom_up = bottom_up.memoized[OF bfm.crel, folded iter_bf_def]

thm bfm'.simps bf_memoized


subsubsection ‹Imperative Memoization›

(*
  Fix the memory used by the imperative memoization: two nat refs and two
  arrays of int extended option, created by running init_state (n + 1) 1 0 on
  the empty heap.
  NOTE(review): presumably the two arrays hold two rows of size n + 1, keyed by
  the first argument with initial keys 1 and 0, since bf (Suc k) only reads
  row k. Confirm against init_state.
*)
context
  fixes mem :: "nat ref × nat ref × int extended option array ref × int extended option array ref"
  assumes mem_is_init: "mem = result_of (init_state (n + 1) 1 0) Heap.empty"
begin

(*
  The fixed memory satisfies the consistency locale for the two-array layout:
  - size n + 1
  - row key fst and column key snd, mapped to array indices by id
  - initial keys 1 and 0
  These are the same parameters that memoize_fun bfh names as size, key1,
  key2, to_index, k1 and k2.
*)
lemma [intro]:
  "dp_consistency_heap_array_pair' (n + 1) fst snd id 1 0 mem"
  by (standard; simp add: mem_is_init injective_def)

(*
  The same row-major table iterator, over pairs with both components ≤ n,
  interpreted again inside the memory context.
*)
interpretation iterator
  "λ (x, y). x ≤ n ∧ y ≤ n"
  "λ (x, y). if y < n then (x, y + 1) else (x + 1, 0)"
  "λ (x, y). x * (n + 1) + y"
  by (rule table_iterator_up)

(*
  Consistency of the two-array memory combined with the row-major iteration:
  step function, position measure and the in-table predicate.
  This is the locale that memoize_fun bfh instantiates.
*)
lemma [intro]:
  "dp_consistency_heap_array_pair_iterator (n + 1) fst snd id 1 0 mem
  (λ (x, y). if y < n then (x, y + 1) else (x + 1, 0))
  (λ (x, y). x * (n + 1) + y)
  (λ (x, y). x ≤ n ∧ y ≤ n)"
  by (standard; simp add: mem_is_init injective_def)

(*
  Derive a heap-monad memoized version of bf from bf.simps.
  It is stored in the fixed two-array memory and uses the same row-major
  iteration.
  NOTE(review): (default_proof) presumably discharges the locale assumptions
  with the [intro] lemmas above.
*)
memoize_fun bfh: bf
  with_memory (default_proof) dp_consistency_heap_array_pair_iterator
  where size = "n + 1"
    and key1="fst :: nat × nat ⇒ nat" and key2="snd :: nat × nat ⇒ nat"
    and k1="1 :: nat" and k2="0 :: nat"
    and to_index = "id :: nat ⇒ nat"
    and mem = mem
    and cnt = "λ (x, y). x ≤ n ∧ y ≤ n"
    and nxt = "λ (x :: nat, y). if y < n then (x, y + 1) else (x + 1, 0)"
    and sizef = "λ (x, y). x * (n + 1) + y"
monadifies (heap) bf.simps

(*
  Correspondence proof for bfh.
  - memoized_empty states correctness of iterating and then querying from the
    freshly initialized memory; bf_impl_correct uses it.
  - The iter_heap_unfold fact is restated, presumably so that it stays
    available under this name after the context is closed.
*)
memoize_correct
  by memoize_prover

lemmas memoized_empty = bfh.memoized_empty[OF bfh.consistent_DP_iter_and_compute[OF bfh.crel]]
lemmas iter_heap_unfold = iter_heap_unfold

end (* Fixed Memory *)

end (* Final Node *)

end (* Bellman Ford *)



subsubsection ‹Extracting an Executable Constant for the Imperative Implementation›

(*
  bfh'.simps only hold under the mem_is_init assumption of the memory context.
  ground_function turns them into a standalone recursive constant bfh'_impl
  suitable for code export; termination is proved automatically.
  bfh'_impl_def relates the two under that assumption.
*)
ground_function (prove_termination) bfh'_impl: bfh'.simps

(*
  On a freshly initialized memory, the exported constant bfh'_impl coincides
  with the locale-internal bfh'.
  Proved pointwise by induction along bfh'.induct, which carries the mem_is_init
  assumption, unfolding both sides.
*)
lemma bfh'_impl_def:
  fixes n :: nat
  fixes mem :: "nat ref × nat ref × int extended option array ref × int extended option array ref"
  assumes mem_is_init: "mem = result_of (init_state (n + 1) 1 0) Heap.empty"
  shows "bfh'_impl n w t mem = bfh' n w t mem"
proof -
  have "bfh'_impl n w t mem i j = bfh' n w t mem i j" for i j
    by (induction rule: bfh'.induct[OF mem_is_init];
        simp add: bfh'.simps[OF mem_is_init]; solve_cong simp
       )
  then show ?thesis
    by auto
qed

(*
  Heap-monad traversal of the (n + 1) × (n + 1) table in row-major order.
  Each cell is computed with the exported bfh'_impl.
*)
definition
  "iter_bf_heap n w t mem = iterator_defs.iter_heap
      (λ(x, y). x ≤ n ∧ y ≤ n)
      (λ(x, y). if y < n then (x, y + 1) else (x + 1, 0))
      (λ(x, y). bfh'_impl n w t mem x y)"

(*
  Code equation for iter_bf_heap.
  While (i, j) is inside the table, compute bfh'_impl for the cell and move to
  the next one in row-major order. Otherwise return ().
*)
lemma iter_bf_heap_unfold[code]:
  "iter_bf_heap n w t mem = (λ (i, j).
    (if i ≤ n ∧ j ≤ n
     then do {
            bfh'_impl n w t mem i j;
            iter_bf_heap n w t mem (if j < n then (i, j + 1) else (i + 1, 0))
          }
     else Heap_Monad.return ()))"
  unfolding iter_bf_heap_def by (rule ext) (safe, simp add: iter_heap_unfold)

(*
  Complete imperative Bellman-Ford, in three steps:
  1. allocate and initialize the memory;
  2. fill the whole table bottom-up, starting at (0, 0);
  3. read off the entry for (i, j).
  bf_impl_correct shows the result equals bf n w t i j.
*)
definition
  "bf_impl n w t i j = do {
    mem ← (init_state (n + 1) (1::nat) (0::nat) ::
      (nat ref × nat ref × int extended option array ref × int extended option array ref) Heap);
    iter_bf_heap n w t mem (0, 0);
    bfh'_impl n w t mem i j
  }"

(*
  Running the imperative implementation on the empty heap yields bf.
  Follows from memoized_empty together with bfh'_impl_def and the unfolded
  definitions.
*)
lemma bf_impl_correct:
  "bf n w t i j = result_of (bf_impl n w t i j) Heap.empty"
  using memoized_empty[OF HOL.refl, of n w t "(i, j)"]
  by (simp add:
        execute_bind_success[OF succes_init_state] bf_impl_def bfh'_impl_def iter_bf_heap_def
      )


subsubsection ‹Test Cases›

(*
  Example graph on nodes 0..3, as adjacency lists.
  Entry i lists the (target, weight) pairs of the edges leaving i:
  - node 0: 0 → 1 (-6), 0 → 2 (4), 0 → 3 (5)
  - node 1: 1 → 3 (10)
  - node 2: 2 → 3 (2)
  - node 3: no outgoing edges
*)
definition
  "G1_list = [[(1 :: nat,-6 :: int), (2,4), (3,5)], [(3,10)], [(3,2)], []]"

(*
  Build a weight function from an array of adjacency lists.
  W i j is Fin c for the first pair (j, c) in a !! i, and ∞ if there is none.
  NOTE(review): a !! i is not bounds-checked. i should be below the array
  length; out-of-range access is unspecified in HOL and presumably raises an
  exception in generated code.
*)
definition
  "graph_of a i j = case_option ∞ (Fin o snd) (List.find (λ p. fst p = j) (a !! i))"

(*
  Imperative test on G1_list with parameters:
  - n = 3, sink t = 3
  - at most 3 edges, start node 0
  The expected result is Fin 4, via the path 0 → 1 → 3 of weight -6 + 10.
  code_reflect exports test_bf to the ML structure Test so it can be run below.
*)
definition "test_bf = bf_impl 3 (graph_of (IArray G1_list)) 3 3 0"

code_reflect Test functions test_bf

text ‹One can see a trace of the calls to the memory in the output›
ML ‹Test.test_bf ()›

(*
  Code equation for bf via the state-monad bottom-up scheme:
  1. fill the table with iter_bf, starting at (0, 0) from an empty Mapping;
  2. look up (i, j) with bfm'.
  This is bf_bottom_up instantiated with the start pair (0, 0).
*)
lemma bottom_up_alt[code]:
  "bf n W t i j =
     fst (run_state
      (iter_bf n W t (0, 0) ⤜ (λ_. bfm' n W t i j))
      Mapping.empty)"
  using bf_bottom_up by auto

(*
  Front end that takes the graph as a list of adjacency lists W (see graph_of).
  - The whole (n + 1) × (n + 1) table is filled bottom-up, in row-major order
    starting at (0, 0), before bfm' reads off entry (i, j).
  - Starting at (0, 0) follows the dependency order of bf (row Suc k only
    reads row k). For i ≤ n, every entry bfm' needs is therefore already in
    memory, and nothing has to be recomputed by top-down recursion.
  - The right-hand side is literally that of bottom_up_alt, hence
    bf_ia n W t i j = bf n (graph_of (IArray W)) t i j.
*)
definition
  "bf_ia n W t i j = (let W' = graph_of (IArray W) in
    fst (run_state
      (iter_bf n W' t (0, 0) ⤜ (λ_. bfm' n W' t i j))
      Mapping.empty)
  )"

(*
  Evaluate the example with two implementations:
  - the memoized state-monad function bfm', started from an empty memory;
  - bf itself.
  Both should yield Fin 4, via the path 0 → 1 → 3.
*)
value "fst (run_state (bfm' 3 (graph_of (IArray G1_list)) 3 3 0) Mapping.empty)"

value "bf 3 (graph_of (IArray G1_list)) 3 3 0"

end (* Theory *)