(* Theory Bellman_Ford *)

subsection ‹The Bellman-Ford Algorithm›

theory Bellman_Ford
  imports
    "HOL-Library.IArray"
    "HOL-Library.Code_Target_Numeral"
    "HOL-Library.Product_Lexorder"
    "HOL-Library.RBT_Mapping"
    "../heap_monad/Heap_Main"
    Example_Misc
    "../util/Tracing"
    "../util/Ground_Function"
begin

subsubsection ‹Misc›

(* Case split for a bounded natural number: from i ≤ n we get either i < n or i = n.
   Intended for use as an elimination rule (e.g. `elim: nat_le_cases`). *)
lemma nat_le_cases:
  fixes n :: nat
  assumes "i ≤ n"
  obtains "i < n" | "i = n"
  using assms by (cases "i = n") auto

(* Two facts about iter_state, stated inside the dp_consistency_iterator locale.
   Both conclude crel_vs (=) () (iter_state f x). *)
context dp_consistency_iterator
begin

(* Version with the premise relating g and f; derived by metis from the
   facts crel_vs_iterate_state and iter_state_iterate_state. *)
lemma crel_vs_iterate_state:
  "crel_vs (=) () (iter_state f x)" if "((=) ===>T R) g f"
  by (metis crel_vs_iterate_state iter_state_iterate_state that)

(* Version with the premise consistentDP f; obtained from the lemma above
   after bringing consistentDP_def into play. *)
lemma consistent_crel_vs_iterate_state:
  "crel_vs (=) () (iter_state f x)" if "consistentDP f"
  using consistentDP_def crel_vs_iterate_state that by simp

end

(* extended over a countable type is countable.
   Given an injection to_nat on the base type, Pinf is sent to 0, Minf to 1
   and Fin n to to_nat n + 2, which is again injective. *)
instance extended :: (countable) countable
proof standard
  obtain to_nat :: "'a ⇒ nat" where "inj to_nat"
    by auto
  (* Shift finite values by 2 to make room for the two infinities. *)
  let ?f = "λ x. case x of Fin n ⇒ to_nat n + 2 | Pinf ⇒ 0 | Minf ⇒ 1"
  from ‹inj _› have "inj ?f"
    by (auto simp: inj_def split: extended.split)
  then show "∃to_nat :: 'a extended ⇒ nat. inj to_nat"
    by auto
qed

(* heap instance for extended over a heap type; discharged by the default proof (..). *)
instance extended :: (heap) heap ..

(* Equips extended (over a conditionally complete lattice) with Inf and Sup and
   proves the complete_lattice class axioms.
   Both operations are defined by a three-way case split: a degenerate case,
   an "unbounded" case yielding an infinity, and otherwise the Inf/Sup of the
   finite part Fin -` A taken in the underlying type. *)
instantiation "extended" :: (conditionally_complete_lattice) complete_lattice
begin

(* Infimum: falls back to Inf on the preimage of Fin when that preimage is bounded below. *)
definition
  "Inf A = (
    if A = {}  A = {} then 
    else if -∞  A  ¬ bdd_below (Fin -` A) then -∞
    else Fin (Inf (Fin -` A)))"

(* Supremum: dual construction, using bdd_above and Sup on the preimage of Fin. *)
definition
  "Sup A = (
    if A = {}  A = {-∞} then -∞
    else if   A  ¬ bdd_above (Fin -` A) then 
    else Fin (Sup (Fin -` A)))"

instance
proof standard
  (* Helper facts about Inf on the finite part, via cInf_lower / cInf_greatest. *)
  have [dest]: "Inf (Fin -` A)  x" if "Fin x  A" "bdd_below (Fin -` A)" for A and x :: 'a
    using that by (intro cInf_lower) auto
  have *: False if "¬ z  Inf (Fin -` A)" "x. x  A  Fin z  x" "Fin x  A" for A and x z :: 'a
    using cInf_greatest[of "Fin -` A" z] that vimage_eq by force
  (* First Inf axiom: comparison of Inf A with a member x of A. *)
  show "Inf A  x" if "x  A" for x :: "'a extended" and A
    using that unfolding Inf_extended_def by (cases x) auto
  (* Second Inf axiom: comparison of a bound z with Inf A; case analysis on z (and y). *)
  show "z  Inf A" if "x. x  A  z  x" for z :: "'a extended" and A
    using that
    unfolding Inf_extended_def
    apply (clarsimp; safe)
         apply force
        apply force
    subgoal
      by (cases z; force simp: bdd_below_def)
    subgoal
      by (cases z; force simp: bdd_below_def)
    subgoal for x y
      by (cases z; cases y) (auto elim: *)
    subgoal for x y
      by (cases z; cases y; simp; metis * less_eq_extended.elims(2))
    done
  (* Dual helper facts for Sup, via cSup_upper / cSup_least.
     Note that * is re-bound here to the Sup version. *)
  have [dest]: "x  Sup (Fin -` A)" if "Fin x  A" "bdd_above (Fin -` A)" for A and x :: 'a
    using that by (intro cSup_upper) auto
  have *: False if "¬ Sup (Fin -` A)  z" "x. x  A  x  Fin z" "Fin x  A" for A and x z :: 'a
    using cSup_least[of "Fin -` A" z] that vimage_eq by force
  (* First Sup axiom. *)
  show "x  Sup A" if "x  A" for x :: "'a extended" and A
    using that unfolding Sup_extended_def by (cases x) auto
  (* Second Sup axiom; same proof shape as for Inf. *)
  show "Sup A  z" if "x. x  A  x  z" for z :: "'a extended" and A
    using that
    unfolding Sup_extended_def
    apply (clarsimp; safe)
         apply force
        apply force
    subgoal
      by (cases z; force)
    subgoal
      by (cases z; force)
    subgoal for x y
      by (cases z; cases y) (auto elim: *)
    subgoal for x y
      by (cases z; cases y; simp; metis * extended.exhaust)
    done
  (* Empty-set axioms: Inf {} is top and Sup {} is bot. *)
  show "Inf {} = (top::'a extended)"
    unfolding Inf_extended_def top_extended_def by simp
  show "Sup {} = (bot::'a extended)"
    unfolding Sup_extended_def bot_extended_def by simp
qed

end

(* With a base type that is both a conditionally complete lattice and a linorder,
   extended is a complete_linorder; discharged by the default proof (..). *)
instance "extended" :: ("{conditionally_complete_lattice,linorder}") complete_linorder ..


(* Neither infinity equals 0 in extended; registered as simp rules so that such
   equations rewrite to False. Follows from zero_extended_def. *)
lemma Minf_eq_zero[simp]: "-∞ = 0 ⟷ False" and Pinf_eq_zero[simp]: "∞ = 0 ⟷ False"
  unfolding zero_extended_def by auto

(* Characterisation of Sup on a set X of integers under the two assumptions
   (one about {} and bdd_above X): the conjunction in `shows` about Sup X.
   Proof idea, as visible below: pick a bound y and an element x of X, work with the
   finite set obtained by combining X with the interval {x..y}, and show that its Max
   is the unique witness of the property; Sup_int_def then yields the claim via theI'. *)
lemma Sup_int:
  fixes x :: int and X :: "int set"
  assumes "X  {}" "bdd_above X"
  shows "Sup X  X  (yX. y  Sup X)"
proof -
  from assms obtain x y where "X  {..y}" "x  X"
    by (auto simp: bdd_above_def)
  (* The restricted set is finite and non-trivial, so Max is well-behaved on it. *)
  then have *: "finite (X  {x..y})" "X  {x..y}  {}" and "x  y"
    by (auto simp: subset_eq)
  have "∃!xX. (yX. y  x)"
  proof
    (* Every z in X is compared against the Max, by cases on how z relates to x. *)
    { fix z assume "z  X"
      have "z  Max (X  {x..y})"
      proof cases
        assume "x  z" with z  X X  {..y} *(1) show ?thesis
          by (auto intro!: Max_ge)
      next
        assume "¬ x  z"
        then have "z < x" by simp
        also have "x  Max (X  {x..y})"
          using x  X *(1) x  y by (intro Max_ge) auto
        finally show ?thesis by simp
      qed }
    note le = this
    (* Existence: the Max is a witness. *)
    with Max_in[OF *] show ex: "Max (X  {x..y})  X  (zX. z  Max (X  {x..y}))" by auto

    (* Uniqueness: any other witness z coincides with the Max. *)
    fix z assume *: "z  X  (yX. y  z)"
    with le have "z  Max (X  {x..y})"
      by auto
    moreover have "Max (X  {x..y})  z"
      using * ex by auto
    ultimately show "z = Max (X  {x..y})"
      by auto
  qed
  then show "Sup X  X  (yX. y  Sup X)"
    unfolding Sup_int_def by (rule theI')
qed

(* First conjunct of Sup_int as a standalone fact. *)
lemmas Sup_int_in = Sup_int[THEN conjunct1]

(* Counterpart of Sup_int_in for Inf on integer sets, under the two assumptions on S
   (one about {} and bdd_below S). Proved by unfolding Inf_int_def and reducing to
   Sup_int_in; the smt call also uses bdd_above_uminus and two image facts. *)
lemma Inf_int_in:
  fixes S :: "int set"
  assumes "S  {}" "bdd_below S"
  shows "Inf S  S"
  using assms unfolding Inf_int_def by (smt Sup_int_in bdd_above_uminus image_iff image_is_empty)


(* Finiteness of a set comprehension {f x |x. P x} is rephrased in terms of the
   image f ` {x. P x} (via setcompr_eq_image). *)
lemma finite_setcompr_eq_image: "finite {f x |x. P x}  finite (f ` {x. P x})"
  by (simp add: setcompr_eq_image)

(* Finiteness of the set of lists over {0..n} subject to a length constraint in terms of i;
   reduced to the library fact finite_lists_length_le by finite_subset. *)
lemma finite_lists_length_le1: "finite {xs. length xs  i  set xs  {0..(n::nat)}}" for i
  by (auto intro: finite_subset[OF _ finite_lists_length_le[OF finite_atLeastAtMost]])

(* Same with the constraint stated on length xs + 1; reduced to the previous lemma. *)
lemma finite_lists_length_le2: "finite {xs. length xs + 1  i  set xs  {0..(n::nat)}}" for i
  by (auto intro: finite_subset[OF _ finite_lists_length_le1[of "i"]])

(* Make the three finiteness facts available to the simplifier
   (finite_lists_length_le2 in its simplified form). *)
lemmas [simp] =
  finite_setcompr_eq_image finite_lists_length_le2[simplified] finite_lists_length_le1


(* Running "get, then return (f m)" in the state monad on state m yields the pair
   (f m, m). Proved by unfolding the bind and get definitions. *)
lemma get_return:
  "run_state (State_Monad.bind State_Monad.get (λ m. State_Monad.return (f m))) m = (f m, m)"
  by (simp add: State_Monad.bind_def State_Monad.get_def)


(* Pigeonhole principle on lists: under the assumptions relating xs to a finite set S
   with card S < length xs, some element a occurs twice in xs, which is exposed as the
   decomposition xs = as @ a # bs @ a # cs.
   Proof: xs cannot be distinct (card_mono, distinct_card); a non-distinct list splits
   at a repeated element (not_distinct_conv_prefix, split_list). *)
lemma list_pidgeonhole:
  assumes "set xs  S" "card S < length xs" "finite S"
  obtains as a bs cs where "xs = as @ a # bs @ a # cs"
proof -
  from assms have "¬ distinct xs"
    by (metis card_mono distinct_card not_le)
  then show ?thesis
    by (metis append.assoc append_Cons not_distinct_conv_prefix split_list that)
qed

(* Elimination rule: if the path v # ys @ [t] equals a list of the shape
   as @ a # bs @ a # cs (i.e. a occurs twice), then exactly how the two shapes align
   depends on whether the prefix as and the suffix cs are empty:
     Nil_Nil   - both empty: the repeated element is both v and t, and ys = bs;
     Nil_Cons  - as empty: v = a, and the suffix ends in t;
     Cons_Nil  - cs empty: a = t, and the prefix starts with v;
     Cons_Cons - neither empty: both occurrences of a lie inside ys. *)
lemma path_eq_cycleE:
  assumes "v # ys @ [t] = as @ a # bs @ a # cs"
  obtains (Nil_Nil) "as = []" "cs = []" "v = a" "a = t" "ys = bs"
  | (Nil_Cons) cs' where "as = []" "v = a" "ys = bs @ a # cs'" "cs = cs' @ [t]"
  | (Cons_Nil) as' where "as = v # as'" "cs = []" "a = t" "ys = as' @ a # bs"
  | (Cons_Cons) as' cs' where "as = v # as'" "cs = cs' @ [t]" "ys = as' @ a # bs @ a # cs'"
  using assms by (auto simp: Cons_eq_append_conv append_eq_Cons_conv append_eq_append_conv2)

(* Cancellation rule on int extended: when a is bounded strictly on both sides by the
   two premises, comparing a + b with a is equivalent to comparing b with 0.
   Proved by case analysis on a and b, using zero_extended_def. *)
lemma le_add_same_cancel1:
  "a + b  a  b  0" if "a < " "-∞ < a" for a b :: "int extended"
  using that by (cases a; cases b) (auto simp add: zero_extended_def)

(* Introduction rule: if both summands are strictly above -∞, so is their sum.
   Proved by case analysis on both summands. *)
lemma add_gt_minfI:
  assumes "-∞ < a" "-∞ < b"
  shows "-∞ < a + b"
  using assms by (cases a; cases b) auto

(* Introduction rule: if both summands are strictly below ∞, so is their sum.
   Mirror image of add_gt_minfI; proved by case analysis on both summands. *)
lemma add_lt_infI:
  assumes "a < ∞" "b < ∞"
  shows "a + b < ∞"
  using assms by (cases a; cases b) auto

(* A list of int extended values that are all strictly below ∞ has a sum strictly
   below ∞. Induction on the list; the empty case uses zero_extended_def. *)
lemma sum_list_not_infI:
  "sum_list xs < ∞" if "∀ x ∈ set xs. x < ∞" for xs :: "int extended list"
  using that
  apply (induction xs)
   apply (simp add: zero_extended_def)+
  by (smt less_extended_simps(2) plus_extended.elims)

(* A list of int extended values that are all strictly above -∞ has a sum strictly
   above -∞. Induction on the list, using add_gt_minfI in the step case. *)
lemma sum_list_not_minfI:
  "sum_list xs > -∞" if "∀ x ∈ set xs. x > -∞" for xs :: "int extended list"
  using that by (induction xs) (auto intro: add_gt_minfI simp: zero_extended_def)



subsubsection ‹Single-Sink Shortest Path Problem›

(* Result type with three constructors: Path (carrying a list of naturals and an int),
   No_Path and Computation_Error.
   NOTE(review): presumably the node sequence and total weight of a shortest path -
   confirm at the use site, which is not part of this section. *)
datatype bf_result = Path "nat list" int | No_Path | Computation_Error

context
  fixes n :: nat and W :: "nat  nat  int extended"
begin

context
  fixes t :: nat ― ‹Final node›
begin

text ‹
  The correctness proof closely follows Kleinberg ‹&› Tardos: "Algorithm Design",
  chapter "Dynamic Programming" \<^cite>‹"Kleinberg-Tardos"›.
›

(* Weight of a path given as a list of nodes: the sum of the edge weights W v w over
   consecutive nodes. A single node has weight 0.
   No equation is given for the empty list. *)
fun weight :: "nat list ⇒ int extended" where
  "weight [v] = 0"
| "weight (v # w # xs) = W v w + weight (w # xs)"

(* OPT i v: the minimum weight of a path from v to the sink t that uses at most i edges
   (i.e. has intermediate nodes xs with length xs + 1 ≤ i, all drawn from {0..n}).
   The second set makes the minimum well-defined when no such path exists:
   it contributes 0 if v is the sink itself and ∞ otherwise. *)
definition
  "OPT i v = (
    Min (
      {weight (v # xs @ [t]) | xs. length xs + 1 ≤ i ∧ set xs ⊆ {0..n}} ∪
      {if t = v then 0 else ∞}
    )
  )"

(* Accumulator form of weight: weight (s # xs) + w equals the second component of a
   fold over xs that carries the previous node and the running sum, started at (s, w).
   Induction on xs with s and w generalised. *)
lemma weight_alt_def':
  "weight (s # xs) + w = snd (fold (λj (i, x). (j, W i j + x)) xs (s, w))"
  by (induction xs arbitrary: s w; simp; smt add.commute add.left_commute)

(* weight as a fold started with accumulator 0; the instance w = 0 of weight_alt_def'. *)
lemma weight_alt_def:
  "weight (s # xs) = snd (fold (λj (i, x). (j, W i j + x)) xs (s, 0))"
  by (rule weight_alt_def'[of s xs 0, simplified])

(* Splitting a path at a node a: the weight of xs @ a # ys is the weight of the part
   up to and including a plus the weight of the part starting at a.
   Induction along weight.induct, using associativity of addition. *)
lemma weight_append:
  "weight (xs @ a # ys) = weight (xs @ [a]) + weight (a # ys)"
  by (induction xs rule: weight.induct; simp add: add.assoc)

(* Base case of the recurrence: with 0 edges allowed, only the sink itself has a
   finite optimum (0); every other node gets ∞. *)
lemma OPT_0:
  "OPT 0 v = (if t = v then 0 else ∞)"
  unfolding OPT_def by simp


subsubsection ‹Functional Correctness›

(* Case analysis on the value of OPT i v. Since a Min over a finite non-empty set is
   attained (Min_in), OPT i v is either
     path        - the weight of some path v # xs @ [t] from the set in OPT_def,
     sink        - 0, with v = t, or
     unreachable - the remaining value of the if-expression in OPT_def. *)
lemma OPT_cases:
  obtains (path) xs where "OPT i v = weight (v # xs @ [t])" "length xs + 1  i" "set xs  {0..n}"
  | (sink) "v = t" "OPT i v = 0"
  | (unreachable) "v  t" "OPT i v = "
  unfolding OPT_def
  using Min_in[of "{weight (v # xs @ [t]) |xs. length xs + 1  i  set xs  {0..n}}
     {if t = v then 0 else }"]
  by (auto simp: finite_lists_length_le2[simplified] split: if_split_asm)

(* Recurrence for OPT: OPT (Suc i) v is the min of OPT i v and the Min, over nodes w,
   of OPT i w + W v w. Requires the side condition on t and n.
   The proof shows the two comparisons between ?lhs and ?rhs separately and
   combines them with order.antisym. *)
lemma OPT_Suc:
  "OPT (Suc i) v = min (OPT i v) (Min {OPT i w + W v w | w. w  n})" (is "?lhs = ?rhs")
  if "t  n"
proof -
  (* First direction: extending an optimal path for w by the edge (v, w) gives a
     candidate in the set defining OPT (Suc i) v. Case analysis via OPT_cases. *)
  have "OPT i w + W v w  OPT (Suc i) v" if "w  n" for w
    using OPT_cases[of i w]
  proof cases
    case (path xs)
    with w  n show ?thesis
      by (subst OPT_def) (auto intro!: Min_le exI[where x = "w # xs"] simp: add.commute)
  next
    case sink
    then show ?thesis
      by (subst OPT_def) (auto intro!: Min_le exI[where x = "[]"])
  next
    case unreachable
    then show ?thesis
      by simp
  qed
  then have "Min {OPT i w + W v w |w. w  n}  OPT (Suc i) v"
    by (auto intro!: Min.boundedI)
  (* The set defining OPT only grows with i (Min_antimono). *)
  moreover have "OPT i v  OPT (Suc i) v"
    unfolding OPT_def by (rule Min_antimono) auto
  ultimately have "?lhs  ?rhs"
    by simp

  (* Second direction: case analysis on what OPT (Suc i) v is. *)
  from OPT_cases[of "Suc i" v] have "?lhs  ?rhs"
  proof cases
    case (path xs)
    note [simp] = path(1)
    (* Distinguish by the length of the witnessing path relative to i. *)
    from path consider
      (zero) "i = 0" "length xs = 0" | (new) "i > 0" "length xs = i" | (old) "length xs < i"
      by (cases "length xs = i") auto
    then show ?thesis
    proof cases
      case zero
      (* The path is the single edge from v to t. *)
      with path have "OPT (Suc i) v = W v t"
        by simp
      also have "W v t = OPT i t + W v t"
        unfolding OPT_def using i = 0 by auto
      also have "  Min {OPT i w + W v w |w. w  n}"
        using t  n by (auto intro: Min_le)
      finally show ?thesis
        by (rule min.coboundedI2)
    next
      case new
      (* The path has a first intermediate node u; its tail is a candidate for OPT i u. *)
      with _ = i obtain u ys where [simp]: "xs = u # ys"
        by (cases xs) auto
      from path have "OPT i u  weight (u # ys @ [t])"
        unfolding OPT_def by (intro Min_le) auto
      from path have "Min {OPT i w + W v w |w. w  n}  W v u + OPT i u"
        by (intro Min_le) (auto simp: add.commute)
      also from OPT i u  _ have "  OPT (Suc i) v"
        by (simp add: add_left_mono)
      finally show ?thesis
        by (rule min.coboundedI2)
    next
      case old
      (* The path is already a candidate for OPT i v. *)
      with path have "OPT i v  OPT (Suc i) v"
        by (auto 4 3 intro: Min_le simp: OPT_def)
      then show ?thesis
        by (rule min.coboundedI1)
    qed
  next
    case unreachable
    then show ?thesis
      by simp
  next
    case sink
    then have "OPT i v  OPT (Suc i) v"
      unfolding OPT_def by simp
    then show ?thesis
      by (rule min.coboundedI1)
  qed

  with ?lhs  ?rhs show ?thesis
    by (rule order.antisym)
qed

(* Recursive Bellman-Ford function, the executable counterpart of OPT:
     bf 0 v       is 0 at the sink t and ∞ elsewhere;
     bf (Suc i) v is the minimum of bf i v and W v w + bf i w over all nodes w in 0..n. *)
fun bf :: "nat ⇒ nat ⇒ int extended" where
  "bf 0 v = (if t = v then 0 else ∞)"
| "bf (Suc i) v = min_list
      (bf i v # [W v w + bf i w . w ← [0 ..< Suc n]])"

(* Replace the raw equations in the simpset by versions where min_list is
   unfolded with min_list_fold. *)
lemmas [simp del] = bf.simps
lemmas bf_simps[simp] = bf.simps[unfolded min_list_fold]

(* Functional correctness: bf computes OPT (under the side condition on t and n).
   Induction on i; the step case rewrites with OPT_Suc and identifies the set from
   the recurrence with the list that bf folds over (fact * ). *)
lemma bf_correct:
  "OPT i j = bf i j" if t  n
proof (induction i arbitrary: j)
  case 0
  then show ?case
    by (simp add: OPT_0)
next
  case (Suc i)
  (* Set comprehension from OPT_Suc vs. the mapped index list used by bf. *)
  have *:
    "{bf i w + W j w |w. w  n} = set (map (λw. W j w + bf i w) [0..<Suc n])"
    by (fastforce simp: add.commute image_def)
  from Suc t  n show ?case
    by (simp add: OPT_Suc del: upt_Suc, subst Min.set_eq_fold[symmetric], auto simp: *)
qed


subsubsection ‹Functional Memoization›

(* Generate the state-monad memoized version bfm of bf, using a mapping as memory,
   from the equations bf.simps. *)
memoize_fun bfm: bf with_memory dp_consistency_mapping monadifies (state) bf.simps

text ‹Generated Definitions›
(* Display the generated definitions (inspection only). *)
context includes state_monad_syntax begin
thm bfm'.simps bfm_def
end

text ‹Correspondence Proof›
(* Discharge the correspondence goal with the memoize_prover method. *)
memoize_correct
  by memoize_prover
print_theorems
(* Use the memoized version as the code equation. *)
lemmas [code] = bfm.memoized_correct

(* Iterator over index pairs (x, y): the successor of (x, y) is (x, y + 1) while
   y < n and (x + 1, 0) otherwise; the third argument maps (x, y) to x * (n + 1) + y.
   Justified by table_iterator_up. *)
interpretation iterator
  "λ (x, y). x  n  y  n"
  "λ (x, y). if y < n then (x, y + 1) else (x + 1, 0)"
  "λ (x, y). x * (n + 1) + y"
  by (rule table_iterator_up)

(* Bottom-up instance for bf: memory is a mapping from index pairs to values,
   looked up and updated through the state monad, starting from Mapping.empty;
   iteration uses the same three functions as above. *)
interpretation bottom_up: dp_consistency_iterator_empty
  "λ (_::(nat × nat, int extended) mapping). True"
  "λ (x, y). bf x y"
  "λ k. do {m  State_Monad.get; State_Monad.return (Mapping.lookup m k :: int extended option)}"
  "λ k v. do {m  State_Monad.get; State_Monad.set (Mapping.update k v m)}"
  "λ (x, y). x  n  y  n"
  "λ (x, y). if y < n then (x, y + 1) else (x + 1, 0)"
  "λ (x, y). x * (n + 1) + y"
  Mapping.empty ..

(* iter_bf: iterate the memoized function bfm' over index pairs via iter_state. *)
definition
  "iter_bf = iter_state (λ (x, y). bfm' x y)"

(* Code equation for iter_bf: while the index pair satisfies the loop condition,
   run bfm' on it and continue with the successor pair; otherwise return ().
   Obtained from iter_state_unfold. *)
lemma iter_bf_unfold[code]:
  "iter_bf = (λ (i, j).
    (if i  n  j  n
     then do {
            bfm' i j;
            iter_bf (if j < n then (i, j + 1) else (i + 1, 0))
          }
     else State_Monad.return ()))"
  unfolding iter_bf_def by (rule ext) (safe, clarsimp simp: iter_state_unfold)

(* Instances of the generic memoization facts for bfm: bfm.memoized and
   bottom_up.memoized applied to bfm.crel, the latter folded with iter_bf_def. *)
lemmas bf_memoized = bfm.memoized[OF bfm.crel]
lemmas bf_bottom_up = bottom_up.memoized[OF bfm.crel, folded iter_bf_def]

text ‹
This will be our final implementation, which includes detection of negative cycles.
See the corresponding section below for the correctness proof.
›
(* Final state-monad implementation with negative-cycle detection:
     1. fill the memo table bottom-up by running iter_bf from (n, n);
     2. read off row n and row n + 1 of bfm' for all nodes 0..n;
     3. return Some of row n if the two rows agree, and None otherwise. *)
definition
  "bellman_ford ≡
    do {
      _  ← iter_bf (n, n);
      xs ← State_Main.map⇩T' (λi. bfm' n i) [0..<n+1];
      ys ← State_Main.map⇩T' (λi. bfm' (n + 1) i) [0..<n+1];
      State_Monad.return (if xs = ys then Some xs else None)
    }"

(* Alternative, application-style form of bellman_ford, stated with the lifted
   application syntax from state_monad_syntax. Obtained purely by unfolding
   (lifted application, bellman_ford_def, the lifted map and bind_left_identity). *)
context
  includes state_monad_syntax
begin

lemma bellman_ford_alt_def:
  "bellman_ford 
    do {
      _   iter_bf (n, n);
      (λxs. λys. State_Monad.return (if xs = ys then Some xs else None)
      . (State_Main.mapT . λi. bfm' (n + 1) i . [0..<n+1]))
      . (State_Main.mapT . λi. bfm' n i       . [0..<n+1])
    }"
  unfolding
    State_Monad_Ext.fun_app_lifted_def bellman_ford_def State_Main.mapT_def bind_left_identity
  .

end



subsubsection ‹Imperative Memoization›

(* Heap-based memoization of bf. The context fixes a memory consisting of two nat
   references and two references to arrays of optional values, and assumes it is
   the result of init_state (n + 1) 1 0 on the empty heap. *)
context
  fixes mem :: "nat ref × nat ref × int extended option array ref × int extended option array ref"
  assumes mem_is_init: "mem = result_of (init_state (n + 1) 1 0) Heap.empty"
begin

(* The fixed memory satisfies the array-pair consistency locale with keys fst and snd. *)
lemma [intro]:
  "dp_consistency_heap_array_pair' (n + 1) fst snd id 1 0 mem"
  by (standard; simp add: mem_is_init injective_def)

(* Same index iterator as in the state-monad development. *)
interpretation iterator
  "λ (x, y). x  n  y  n"
  "λ (x, y). if y < n then (x, y + 1) else (x + 1, 0)"
  "λ (x, y). x * (n + 1) + y"
  by (rule table_iterator_up)

(* The fixed memory together with the iterator satisfies the combined locale. *)
lemma [intro]:
  "dp_consistency_heap_array_pair_iterator (n + 1) fst snd id 1 0 mem
  (λ (x, y). if y < n then (x, y + 1) else (x + 1, 0))
  (λ (x, y). x * (n + 1) + y)
  (λ (x, y). x  n  y  n)"
  by (standard; simp add: mem_is_init injective_def)

(* Generate the heap-monad memoized version bfh of bf, instantiating every
   parameter of the memory locale explicitly. *)
memoize_fun bfh: bf
  with_memory (default_proof) dp_consistency_heap_array_pair_iterator
  where size = "n + 1"
    and key1 = "fst :: nat × nat  nat" and key2 = "snd :: nat × nat  nat"
    and k1 = "1 :: nat" and k2 = "0 :: nat"
    and to_index = "id :: nat  nat"
    and mem = mem
    and cnt = "λ (x, y). x  n  y  n"
    and nxt = "λ (x :: nat, y). if y < n then (x, y + 1) else (x + 1, 0)"
    and sizef = "λ (x, y). x * (n + 1) + y"
monadifies (heap) bf.simps

(* Correspondence proof for the heap version. *)
memoize_correct
  by memoize_prover

(* Re-export the facts needed outside this context under local names. *)
lemmas memoized_empty = bfh.memoized_empty[OF bfh.consistent_DP_iter_and_compute[OF bfh.crel]]
lemmas iter_heap_unfold = iter_heap_unfold

end (* Fixed Memory *)


subsubsection ‹Detecting Negative Cycles›

(* shortest v: the infimum of the weights of all paths from v to the sink t through
   nodes in {0..n}, with no bound on the path length (unlike OPT).
   As in OPT, the second set contributes 0 if v is the sink and ∞ otherwise.
   An infimum is used because the set of path weights may be infinite and unbounded below. *)
definition
  "shortest v = (
    Inf (
      {weight (v # xs @ [t]) | xs. set xs ⊆ {0..n}} ∪
      {if t = v then 0 else ∞}
    )
  )"

(* is_path xs: the node list xs, extended by the sink t, has weight strictly below ∞,
   i.e. following xs leads to t along edges of non-infinite weight. *)
definition
  "is_path xs ≡ weight (xs @ [t]) < ∞"

(* has_negative_cycle: there is a cycle a # xs @ [a] of negative weight from which the
   sink is reachable (via the path a # ys), with all nodes involved in {0..n}. *)
definition
  "has_negative_cycle ≡
  ∃xs a ys. set (a # xs @ ys) ⊆ {0..n} ∧ weight (a # xs @ [a]) < 0 ∧ is_path (a # ys)"

(* reaches a: node a (itself at most n) has a path to the sink whose intermediate
   nodes all lie in {0..n}. *)
definition
  "reaches a ≡ ∃xs. is_path (a # xs) ∧ a ≤ n ∧ set xs ⊆ {0..n}"

(* Summation lemma along a path from a to b: given the edge-wise assumption relating
   f v + W u v and f u for the nodes involved, the sum of f over a # xs is related to
   the sum of f over xs @ [b] plus the weight of the path a # xs @ [b].
   Induction on xs with a generalised. *)
lemma fold_sum_aux':
  assumes "u  set (a # xs). v  set (xs @ [b]). f v + W u v  f u"
  shows "sum_list (map f (a # xs))  sum_list (map f (xs @ [b])) + weight (a # xs @ [b])"
  using assms
  by (induction xs arbitrary: a; simp)
     (smt ab_semigroup_add_class.add_ac(1) add.left_commute add_mono)

(* Cycle version of fold_sum_aux' (the instance b = a): the sum of f over the cycle
   a # xs @ [a] is related to the same sum plus the weight of the cycle. *)
lemma fold_sum_aux:
  assumes "u  set (a # xs). v  set (a # xs). f v + W u v  f u"
  shows "sum_list (map f (a # xs @ [a]))  sum_list (map f (a # xs @ [a])) + weight (a # xs @ [a])"
  using fold_sum_aux'[of a xs a f] assms
  by auto (metis (no_types, opaque_lifting) add.assoc add.commute add_left_mono)

(* Removing a cycle from a path keeps it a path. The auxiliary predicate is_path2
   (stated on the full node list, without appending t) and the lemmas about it are
   private to this context; only the two is_path lemmas are exported. *)
context
begin

private definition "is_path2 xs  weight xs < "

(* Cutting out the segment between two occurrences of a preserves is_path2:
   split the weight with weight_append and recombine the two outer parts. *)
private lemma is_path2_remove_cycle:
  assumes "is_path2 (as @ a # bs @ a # cs)"
  shows "is_path2 (as @ a # cs)"
proof -
  have "weight (as @ a # bs @ a # cs) =
    weight (as @ [a]) + weight (a # bs @ [a]) + weight (a # cs)"
    by (metis Bellman_Ford.weight_append append_Cons append_assoc)
  with assms have "weight (as @ [a]) < " "weight (a # cs) < "
    unfolding is_path2_def
    by (simp, metis Pinf_add_right antisym less_extended_simps(4) not_less add.commute)+
  then show ?thesis
    unfolding is_path2_def by (subst weight_append) (rule add_lt_infI)
qed

(* Bridge between the two predicates: is_path xs is is_path2 on xs @ [t]. *)
private lemma is_path_eq:
  "is_path xs  is_path2 (xs @ [t])"
  unfolding is_path_def is_path2_def ..

(* Exported: a cycle at an inner node a can be cut out of a path. *)
lemma is_path_remove_cycle:
  assumes "is_path (as @ a # bs @ a # cs)"
  shows "is_path (as @ a # cs)"
  using assms unfolding is_path_eq by (simp add: is_path2_remove_cycle)

(* Exported: if the path passes through t, everything from that occurrence of t
   onwards can be dropped. *)
lemma is_path_remove_cycle2:
  assumes "is_path (as @ t # cs)"
  shows "is_path as"
  using assms unfolding is_path_eq by (simp add: is_path2_remove_cycle)

end (* private lemmas *)

(* Any path from i to t can be replaced by one with fewer than n intermediate nodes.
   If the given path is already short there is nothing to do. Otherwise the node
   list i # xs @ [t] is longer than card {0..n}, so by list_pidgeonhole some node
   repeats; path_eq_cycleE locates the repetition and the cycle is cut out, giving a
   strictly shorter path. Well-founded induction on length xs repeats this. *)
lemma is_path_shorten:
  assumes "is_path (i # xs)" "i  n" "set xs  {0..n}" "t  n" "t  i"
  obtains xs where "is_path (i # xs)" "i  n" "set xs  {0..n}" "length xs < n"
proof (cases "length xs < n")
  case True
  with assms show ?thesis
    by (auto intro: that)
next
  case False
  then have "length xs  n"
    by auto
  with assms(1,3) show ?thesis
  proof (induction "length xs" arbitrary: xs rule: less_induct)
    case less
    (* Too many nodes for {0..n}: a repetition must exist. *)
    then have "length (i # xs @ [t]) > card ({0..n})"
      by auto
    moreover from less.prems i  n t  n have "set (i # xs @ [t])  {0..n}"
      by auto
    ultimately obtain a as bs cs where *: "i # xs @ [t] = as @ a # bs @ a # cs"
      by (elim list_pidgeonhole) auto
    (* A strictly shorter path ys, one construction per alignment case. *)
    obtain ys where ys: "is_path (i # ys)" "length ys < length xs" "set (i # ys)  {0..n}"
      apply atomize_elim
      using *
    proof (cases rule: path_eq_cycleE)
      case Nil_Nil
      with t  i show "ys. is_path (i # ys)  length ys < length xs  set (i # ys)  {0..n}"
        by auto
    next
      case (Nil_Cons cs')
      then show "ys. is_path (i # ys)  length ys < length xs  set (i # ys)  {0..n}"
        using set (i # xs @ [t])  {0..n} is_path (i # xs) is_path_remove_cycle[of "[]"]
        by - (rule exI[where x = cs'], simp)
    next
      case (Cons_Nil as')
      then show "ys. is_path (i # ys)  length ys < length xs  set (i # ys)  {0..n}"
        using set (i # xs @ [t])  {0..n} is_path (i # xs)
        by - (rule exI[where x = as'], auto intro: is_path_remove_cycle2)
    next
      case (Cons_Cons as' cs')
      then show "ys. is_path (i # ys)  length ys < length xs  set (i # ys)  {0..n}"
        using set (i # xs @ [t])  {0..n} is_path (i # xs) is_path_remove_cycle[of "i # as'"]
        by - (rule exI[where x = "as' @ a # cs'"], auto)
    qed
    (* Either ys is short enough already, or the induction hypothesis applies to it. *)
    then show ?thesis
      by (cases "n  length ys") (auto intro: that less)
  qed
qed

(* A node that reaches the sink has a non-infinite OPT n value.
   For the sink itself OPT is bounded by 0. Otherwise take a witnessing path, shorten
   it with is_path_shorten so that it is a candidate in the set defining OPT n i, and
   bound OPT by its weight. *)
lemma reaches_non_inf_path:
  assumes "reaches i" "i  n" "t  n"
  shows "OPT n i < "
proof (cases "t = i")
  case True
  with i  n t  n have "OPT n i  0"
    unfolding OPT_def
    by (auto intro: Min_le simp: finite_lists_length_le2[simplified])
  then show ?thesis
    using less_linear by (fastforce simp: zero_extended_def)
next
  case False
  from assms(1) obtain xs where "is_path (i # xs)" "i  n" "set xs  {0..n}"
    unfolding reaches_def by safe
  (* Replace the witness by one with fewer than n intermediate nodes. *)
  then obtain xs where xs: "is_path (i # xs)" "i  n" "set xs  {0..n}" "length xs < n"
    using t  i t  n by (auto intro: is_path_shorten)
  then have "weight (i # xs @ [t]) < "
    unfolding is_path_def by auto
  with xs(2-) show ?thesis
    unfolding OPT_def
    by (elim order.strict_trans1[rotated])
       (auto simp: setcompr_eq_image finite_lists_length_le2[simplified])
qed

(* At the sink, OPT is bounded by 0 for every i: the value 0 is always a member of the
   set in OPT_def when v = t. *)
lemma OPT_sink_le_0:
  "OPT i t  0"
  unfolding OPT_def by (auto simp: finite_lists_length_le2[simplified])

(* Destruction rule: any suffix of a path that starts at one of its nodes is itself a
   path. Uses weight_append to split the weight at a. *)
lemma is_path_appendD:
  assumes "is_path (as @ a # bs)"
  shows "is_path (a # bs)"
  using assms weight_append[of as a "bs @ [t]"] unfolding is_path_def
  by simp (metis Pinf_add_right add.commute less_extended_simps(4) not_less_iff_gr_or_eq)

(* Introduction rule for has_negative_cycle with explicit witnesses a, xs and ys:
   a node-set condition, a cycle of negative weight through a, and a path from a. *)
lemma has_negative_cycleI:
  assumes "set (a # xs @ ys)  {0..n}" "weight (a # xs @ [a]) < 0" "is_path (a # ys)"
  shows has_negative_cycle
  using assms unfolding has_negative_cycle_def by auto

(* Variant of OPT_cases whose cases are organised by whether v is the sink:
   the path case additionally records facts about v and about OPT i v, and the
   sink case only bounds OPT i v by 0 instead of fixing its value.
   Same proof idea: Min_in on the set from OPT_def, after a case split on v = t. *)
lemma OPT_cases2:
  obtains (path) xs where
    "v  t" "OPT i v  " "OPT i v = weight (v # xs @ [t])" "length xs + 1  i" "set xs  {0..n}"
  | (unreachable) "v  t" "OPT i v = "
  | (sink) "v = t" "OPT i v  0"
  unfolding OPT_def
  using Min_in[of "{weight (v # xs @ [t]) |xs. length xs + 1  i  set xs  {0..n}}
     {if t = v then 0 else }"]
  by (cases "v = t"; force simp: finite_lists_length_le2[simplified] split: if_split_asm)

(* shortest v is a bound for OPT i v, for every i: after rewriting Min as Inf
   (Min_Inf, with its finiteness side conditions), the set defining OPT is contained
   in the set defining shortest, so Inf_superset_mono applies. *)
lemma shortest_le_OPT:
  assumes "v  n"
  shows "shortest v  OPT i v"
  unfolding OPT_def shortest_def
  apply (subst Min_Inf)
    apply (simp add: setcompr_eq_image finite_lists_length_le2[simplified]; fail)+
  apply (rule Inf_superset_mono)
  apply auto
  done


context
  assumes W_wellformed: "i  n. j  n. W i j > -∞"
  assumes "t  n"
begin

(* Under W_wellformed, a non-empty node list over {0..n} has weight strictly above -∞.
   Induction over lists of length 0, 1 and at least 2, using add_gt_minfI. *)
lemma weight_not_minfI:
  "-∞ < weight xs" if "set xs  {0..n}" "xs  []"
  using that using W_wellformed t  n
  by (induction xs rule: induct_list012) (auto intro: add_gt_minfI simp: zero_extended_def)

(* OPT n i is strictly above -∞: the Min in OPT_def is attained by a member of the
   defining set (Min_in), and every member is above -∞ by weight_not_minfI. *)
lemma OPT_not_minfI:
  "OPT n i > -∞" if "i  n"
proof -
  have "OPT n i 
    {weight (i # xs @ [t]) |xs. length xs + 1  n  set xs  {0..n}}  {if t = i then 0 else }"
    unfolding OPT_def
    by (rule Min_in) (auto simp: setcompr_eq_image finite_lists_length_le2[simplified])
  with that t  n show ?thesis
    by (auto 4 3 intro!: weight_not_minfI simp: zero_extended_def)
qed

(* Completeness of the cycle check: if there is a negative cycle, then some node's
   OPT value still strictly decreases from round n to round n + 1.
   Proof by contradiction: if no edge (u, v) on the cycle allowed an improvement,
   summing OPT n around the cycle (fold_sum_aux) and cancelling the finite sum
   (le_add_same_cancel1) would contradict the cycle weight being negative. *)
theorem detects_cycle:
  assumes has_negative_cycle
  shows "i  n. OPT (n + 1) i < OPT n i"
proof -
  from assms t  n obtain xs a ys where cycle:
    "a  n" "set xs  {0..n}" "set ys  {0..n}"
    "weight (a # xs @ [a]) < 0" "is_path (a # ys)"
    unfolding has_negative_cycle_def by clarsimp
  then have "reaches a"
    unfolding reaches_def by auto
  (* Every node on the cycle reaches the sink: go round to a, then follow a's path. *)
  have reaches: "reaches x" if "x  set xs" for x
  proof -
    from that obtain as bs where "xs = as @ x # bs"
      by atomize_elim (rule split_list)
    with cycle have "weight (x # bs @ [a]) < "
      using weight_append[of "a # as" x "bs @ [a]"]
      by simp (metis Pinf_add_right Pinf_le add.commute less_eq_extended.simps(2) not_less)

    moreover from reaches a obtain cs where "local.weight (a # cs @ [t]) < " "set cs  {0..n}"
      unfolding reaches_def is_path_def by auto
    ultimately show ?thesis
      unfolding reaches_def is_path_def
      using a  n weight_append[of "x # bs" a "cs @ [t]"] cycle(2) xs = _
      by - (rule exI[where x = "bs @ [a] @ cs"], auto intro: add_lt_infI)
  qed
  (* Sum of the OPT n values around the cycle. *)
  let ?S = "sum_list (map (OPT n) (a # xs @ [a]))"
  obtain u v where "u  n" "v  n" "OPT n v + W u v < OPT n u"
  proof (atomize_elim, rule ccontr)
    assume "u v. u  n  v  n  OPT n v + W u v < OPT n u"
    then have "?S  ?S + weight (a # xs @ [a])"
      using cycle(1-3) by (subst fold_sum_aux; fastforce simp: subset_eq)
    (* ?S is finite on both sides, so it can be cancelled. *)
    moreover have "?S > -∞"
      using cycle(1-4) by (intro sum_list_not_minfI, auto intro!: OPT_not_minfI)
    moreover have "?S < "
      using reaches t  n cycle(1,2)
      by (intro sum_list_not_infI) (auto intro: reaches_non_inf_path reaches a simp: subset_eq)
    ultimately have "weight (a # xs @ [a])  0"
      by (simp add: le_add_same_cancel1)
    with weight _ < 0 show False
      by simp
  qed
  (* The improving edge (u, v) makes OPT (n + 1) u strictly smaller, via OPT_Suc. *)
  then show ?thesis
    by -
       (rule exI[where x = u],
        auto 4 4 intro: Min.coboundedI min.strict_coboundedI2 elim: order.strict_trans1[rotated]
          simp: OPT_Suc[OF t  n])
qed

(* detects_cycle restated for the executable function bf, by rewriting with bf_correct. *)
corollary bf_detects_cycle:
  assumes has_negative_cycle
  shows "i  n. bf (n + 1) i < bf n i"
  using detects_cycle[OF assms] unfolding bf_correct[OF t  n] .

(* Case analysis on the value of shortest v:
     path           - it is the weight of an actual path;
     sink           - v = t and the value is 0;
     unreachable    - the "no path" value;
     negative_cycle - it is -∞, and paths of arbitrarily small weight exist.
   The proof splits on the constructor of shortest v (Fin / Pinf / Minf) and unfolds
   Inf_extended_def in each case. *)
lemma shortest_cases:
  assumes "v  n"
  obtains (path) xs where "shortest v = weight (v # xs @ [t])" "set xs  {0..n}"
  | (sink) "v = t" "shortest v = 0"
  | (unreachable) "v  t" "shortest v = "
  | (negative_cycle) "shortest v = -∞" "x. xs. set xs  {0..n}  weight (v # xs @ [t]) < Fin x"
proof -
  (* The set whose Inf defines shortest v. *)
  let ?S = "{weight (v # xs @ [t]) | xs. set xs  {0..n}}  {if t = v then 0 else }"
  have "?S  {}"
    by auto
  have Minf_lowest: False if  "-∞ < a" "-∞ = a" for a :: "int extended"
    using that by auto
  show ?thesis
  proof (cases "shortest v")
    case (Fin x)
    (* Finite infimum: on integers it is attained (Inf_int_in), so it is a path weight
       or the sink value. *)
    then have "-∞  ?S" "bdd_below (Fin -` ?S)" "?S  {}" "x = Inf (Fin -` ?S)"
      unfolding shortest_def Inf_extended_def by (auto split: if_split_asm)
    from this(1-3) have "x  Fin -` ?S"
      unfolding x = _
      by (intro Inf_int_in, auto simp: zero_extended_def)
        (smt empty_iff extended.exhaust insertI2 mem_Collect_eq vimage_eq)
    with shortest v = _ show ?thesis
      unfolding vimage_eq by (auto split: if_split_asm intro: that)
  next
    case Pinf
    with ?S  {} have "t  v"
      unfolding shortest_def Inf_extended_def by (auto split: if_split_asm)
    with _ =  show ?thesis
      by (auto intro: that)
  next
    case Minf
    then have "?S  {}" "?S  {}" "-∞  ?S  ¬ bdd_below (Fin -` ?S)"
      unfolding shortest_def Inf_extended_def by (auto split: if_split_asm)
    from this(3) have "x. xs. set xs  {0..n}  weight (v # xs @ [t]) < Fin x"
    proof
      (* -∞ cannot be a member: weights are above -∞ by weight_not_minfI. *)
      assume "-∞  ?S"
      with weight_not_minfI have False
        using v  n t  n by (auto split: if_split_asm elim: Minf_lowest[rotated])
      then show ?thesis ..
    next
      (* Unboundedness: for any x pick a member below min x (-1). *)
      assume "¬ bdd_below (Fin -` ?S)"
      show ?thesis
      proof
        fix x :: int
        let ?m = "min x (-1)"
        from ¬ bdd_below _ obtain m where "Fin m  ?S" "m < ?m"
          unfolding bdd_below_def by - (simp, drule spec[of _ "?m"], force)
        then show "xs. set xs  {0..n}  weight (v # xs @ [t]) < Fin x"
          by (auto split: if_split_asm simp: zero_extended_def) (metis less_extended_simps(1))+
      qed
    qed
    with shortest v = _ show ?thesis
      by (auto intro: that)
  qed
qed

(* Without negative cycles, every path from v to t of non-infinite weight can be
   replaced by one with fewer than n intermediate nodes whose weight compares as stated
   in the first `obtains` alternative (or else v = t).
   Well-founded induction on the path length: a long path repeats a node
   (list_pidgeonhole); by path_eq_cycleE the repetition delimits a cycle a # bs @ [a].
   If that cycle were negative it would contradict the assumption
   (has_negative_cycleI); otherwise it is cut out and the induction hypothesis is
   applied to the shorter path. *)
lemma simple_paths:
  assumes "¬ has_negative_cycle" "weight (v # xs @ [t]) < " "set xs  {0..n}" "v  n"
  obtains ys where
    "weight (v # ys @ [t])  weight (v # xs @ [t])" "set ys  {0..n}" "length ys < n" | "v = t"
  using assms(2-)
proof (atomize_elim, induction "length xs" arbitrary: xs rule: less_induct)
  case (less ys)
  note ys = less.prems(1,2)
  note IH = less.hyps
  have path: "is_path (v # ys)"
    using is_path_def not_less_iff_gr_or_eq ys(1) by fastforce
  show ?case
  proof (cases "length ys  n")
    case True
    (* Long path: some node repeats. *)
    with ys v  n t  n obtain a as bs cs where "v # ys @ [t] = as @ a # bs @ a # cs"
      by - (rule list_pidgeonhole[of "v # ys @ [t]" "{0..n}"], auto)
    then show ?thesis
    proof (cases rule: path_eq_cycleE)
      case Nil_Nil
      then show ?thesis
        by simp
    next
      case (Nil_Cons cs')
      (* Cycle at the start of the path: weight = cycle + remainder. *)
      then have *: "weight (v # ys @ [t]) = weight (a # bs @ [a]) + weight (a # cs' @ [t])"
        by (simp add: weight_append[of "a # bs" a "cs' @ [t]", simplified])
      show ?thesis
      proof (cases "weight (a # bs @ [a]) < 0")
        case True
        with Nil_Cons set ys  _ path show ?thesis
          using assms(1) by (force intro: has_negative_cycleI[of a bs ys])
      next
        case False
        then have "weight (a # bs @ [a])  0"
          by auto
        with * ys have "weight (a # cs' @ [t])  weight (v # ys @ [t])"
          using add_mono not_le by fastforce
        with Nil_Cons length ys  n ys show ?thesis
          using IH[of cs'] by simp (meson le_less_trans order_trans)
      qed
    next
      case (Cons_Nil as')
      (* Cycle at the end of the path (through t). *)
      with ys have *: "weight (v # ys @ [t]) = weight (v # as' @ [t]) + weight (a # bs @ [a])"
        using weight_append[of "v # as'" t "bs @ [t]"] by simp
      show ?thesis
      proof (cases "weight (a # bs @ [a]) < 0")
        case True
        with Cons_Nil set ys  _ path assms(1) show ?thesis
          using is_path_appendD[of "v # as'"] by (force intro: has_negative_cycleI[of a bs bs])
      next
        case False
        then have "weight (a # bs @ [a])  0"
          by auto
        with * ys(1) have "weight (v # as' @ [t])  weight (v # ys @ [t])"
          using add_left_mono by fastforce
        with Cons_Nil length ys  n v  n ys show ?thesis
          using IH[of as'] by simp (meson le_less_trans order_trans)
      qed
    next
      case (Cons_Cons as' cs')
      (* Cycle strictly inside the path. *)
      with ys have *:
        "weight (v # ys @ [t]) = weight (v # as' @ a # cs' @ [t]) + weight (a # bs @ [a])"
        using
          weight_append[of "v # as'" a "bs @ a # cs' @ [t]"]
          weight_append[of "a # bs" a "cs' @ [t]"]
          weight_append[of "v # as'" a "cs' @ [t]"]
        by (simp add: algebra_simps)
      show ?thesis
      proof (cases "weight (a # bs @ [a]) < 0")
        case True
        with Cons_Cons set ys  _ path assms(1) show ?thesis
          using is_path_appendD[of "v # as'"]
          by (force intro: has_negative_cycleI[of a bs "bs @ a # cs'"])
      next
        case False
        then have "weight (a # bs @ [a])  0"
          by auto
        with * ys have "weight (v # as' @ a # cs' @ [t])  weight (v # ys @ [t])"
          using add_left_mono by fastforce
        with Cons_Cons v  n ys show ?thesis
          using is_path_remove_cycle2 IH[of "as' @ a # cs'"]
          by simp (meson le_less_trans order_trans)
      qed
    qed
  next
    case False
    (* Short path: it is its own witness. *)
    with set ys  _ show ?thesis
      by auto
  qed
qed

(* If the unrestricted optimum shortest v is strictly below OPT n v, there must be a
   negative cycle. First extract a concrete path ys whose weight is below OPT n v
   (combining OPT_cases2 and shortest_cases). For v = t such a path is itself a
   negative cycle through t. Otherwise, assuming no negative cycle, simple_paths
   yields a short path with a comparable weight, which is a candidate for OPT n v -
   a contradiction. *)
theorem shorter_than_OPT_n_has_negative_cycle:
  assumes "shortest v < OPT n v" "v  n"
  shows has_negative_cycle
proof -
  from assms obtain ys where ys:
    "weight (v # ys @ [t]) < OPT n v" "set ys  {0..n}"
    apply (cases rule: OPT_cases2[of v n]; cases rule: shortest_cases[OF v  n]; simp)
      apply (metis uminus_extended.cases)
    using less_extended_simps(2) less_trans apply blast
    apply (metis less_eq_extended.elims(2) less_extended_def zero_extended_def)
    done
  show ?thesis
  proof (cases "v = t")
    case True
    with ys t  n show ?thesis
      using OPT_sink_le_0[of n] unfolding has_negative_cycle_def is_path_def
      using less_extended_def by force
  next
    case False
    show ?thesis
    proof (rule ccontr)
      assume "¬ has_negative_cycle"
      with False False ys v  n obtain xs where
        "weight (v # xs @ [t])  weight (v # ys @ [t])" "set xs  {0..n}" "length xs < n"
        using less_extended_def by (fastforce elim!: simple_paths[of v ys])
      (* The short path xs is a candidate in the Min defining OPT n v. *)
      then have "OPT n v  weight (v # xs @ [t])"
        unfolding OPT_def by (intro Min_le) auto
      with _  weight (v # ys @ [t]) weight (v # ys @ [t]) < OPT n v show False
        by simp
    qed
  qed
qed

(* Soundness of the cycle check: a strict decrease of OPT from round n to round n + 1
   implies a negative cycle. Combines shortest_le_OPT at n + 1 with
   shorter_than_OPT_n_has_negative_cycle. *)
corollary detects_cycle_has_negative_cycle:
  assumes "OPT (n + 1) v < OPT n v" "v  n"
  shows has_negative_cycle
  using assms shortest_le_OPT[of v "n + 1"] shorter_than_OPT_n_has_negative_cycle[of v] by auto

(* Characterisation of negative cycles by the OPT test, combining both directions
   (detects_cycle and detects_cycle_has_negative_cycle). *)
corollary bellman_ford_detects_cycle:
  "has_negative_cycle  (v  n. OPT (n + 1) v < OPT n v)"
  using detects_cycle_has_negative_cycle detects_cycle by blast

(* Without negative cycles, bf n computes shortest for every node in range.
   One comparison is the contrapositive of shorter_than_OPT_n_has_negative_cycle,
   the other is shortest_le_OPT; bf_correct transfers the result from OPT to bf. *)
corollary bellman_ford_shortest_paths:
  assumes "¬ has_negative_cycle"
  shows "v  n. bf n v = shortest v"
proof -
  have "OPT n v  shortest v" if "v  n" for v
    using that assms shorter_than_OPT_n_has_negative_cycle[of v] by force
  then show ?thesis
    unfolding bf_correct[OF t  n, symmetric]
    by (safe, rule order.antisym) (auto elim: shortest_le_OPT)
qed

(* OPT is antitone in its bound: a larger bound m ≥ n can only lower (or keep)
   the optimum for a node v ≤ n.  Proved with Min_antimono on the sets
   obtained by unfolding OPT_def. *)
lemma OPT_mono:
  "OPT m v ≤ OPT n v" if ‹v ≤ n› ‹n ≤ m›
  using that unfolding OPT_def by (intro Min_antimono) auto

(* Fixed point without negative cycles: for any bound m ≥ n and node v ≤ n,
   bf m v coincides with bf n v.
   The proof sandwiches the values:
     shortest v ≤ OPT m v ≤ OPT n v ≤ shortest v,
   and then rewrites between bf and OPT using bf_correct. *)
corollary bf_fix:
  assumes "¬ has_negative_cycle" "m ≥ n"
  shows "∀v ≤ n. bf m v = bf n v"
proof (intro allI impI)
  fix v assume "v ≤ n"
  (* Lower bound: shortest never exceeds OPT. *)
  from ‹v ≤ n› ‹n ≤ m› have "shortest v ≤ OPT m v"
    by (simp add: shortest_le_OPT)
  (* Monotonicity in the bound. *)
  moreover from ‹v ≤ n› ‹n ≤ m› have "OPT m v ≤ OPT n v"
    by (rule OPT_mono)
  (* Upper bound: otherwise a negative cycle would exist. *)
  moreover from ‹v ≤ n› assms have "OPT n v ≤ shortest v"
    using shorter_than_OPT_n_has_negative_cycle[of v] by force
  ultimately show "bf m v = bf n v"
    unfolding bf_correct[OF ‹t ≤ n›, symmetric] by simp
qed

(* Relational correctness of the monadic program bellman_ford: it is related
   (via bfm.crel_vs with equality) to the specification value, which is None
   when a negative cycle exists and otherwise the list of shortest v for
   v in [0..<n+1].

   The proof first shows that the specification ?l equals ?r, the result of
   comparing the tables of bf n and bf (n + 1) over [0..<n + 1] (Some table
   if they agree, None otherwise), and then discharges the relation for ?r
   step by step with the transfer rules of bfm.

   NOTE(review): several tokens in this proof look truncated (the literal
   fact after bellman_ford_alt_def, the arguments of the tactic method, and
   two of the marginal comments) — confirm against the upstream sources
   before relying on this text. *)
lemma bellman_ford_correct':
  "bfm.crel_vs (=) (if has_negative_cycle then None else Some (map shortest [0..<n+1])) bellman_ford"
proof -
  include state_monad_syntax app_syntax
  (* ?l: the specification; ?r: its reformulation in terms of bf. *)
  let ?l = "if has_negative_cycle then None else Some (map shortest [0..<n + 1])"
  let ?r = "(λxs. (λys. (if xs = ys then Some xs else None))
    $ (map $ bf (n + 1) $ [0..<n + 1])) $ (map $ bf n $ [0..<n + 1])"
  (* Pointwise form of the consistency of bfm, specialised to pairs (m, x). *)
  note crel_bfm' = bfm.crel[unfolded bfm.consistentDP_def, THEN rel_funD,
      of "(m, x)" "(m, y)" for m x y, unfolded prod.case]
  (* The two formulations agree; uses bf_fix, bellman_ford_shortest_paths and
     bf_detects_cycle. *)
  have "?l = ?r"
    supply [simp del] = bf_simps
    supply [simp add] =
      bf_fix[rule_format, symmetric] bellman_ford_shortest_paths[rule_format, symmetric]
    unfolding Wrap_def App_def using bf_detects_cycle by (fastforce elim: nat_le_cases)
  ― ‹Slightly transform the goal, then apply parametric reasoning like usual.›
  show ?thesis
    ― ‹Roughly ›
    unfolding bellman_ford_alt_def ?l = ?r ― ‹Obtain parametric form.›
    apply (rule bfm.crel_vs_bind_ignore[rotated]) ― ‹Drop bind.›
     apply (rule bottom_up.consistent_crel_vs_iterate_state[OF bfm.crel, folded iter_bf_def])
    apply (subst Transfer.Rel_def[symmetric]) ― ‹Setup typical goal for automated reasoner.›
    ― ‹We need to reason manually because we are not in the context where bfm was defined.›
    ― ‹This is roughly what @{method "memoize_prover_match_step"}/Transform_Tactic.step_tac› does.›
    (* Repeatedly try each alternative until no subgoal is left. *)
    apply (tactic Transform_Tactic.solve_relator_tac context 1
          | rule HOL.refl
          | rule bfm.dp_match_rule
          | rule bfm.crel_vs_return_ext
          | (subst Rel_def, rule crel_bfm')
          | tactic Transform_Tactic.transfer_raw_tac context 1)+
    done
qed

(* Final correctness statement: running bellman_ford on the empty Mapping
   memory and taking the first component of the result yields None if there
   is a negative cycle and otherwise the list of shortest v for v in
   [0..<n+1].  Obtained by instantiating bellman_ford_correct' at
   Mapping.empty, using bfm.cmem_empty. *)
theorem bellman_ford_correct:
  "fst (run_state bellman_ford Mapping.empty) =
  (if has_negative_cycle then None else Some (map shortest [0..<n+1]))"
  using bfm.cmem_empty bellman_ford_correct'[unfolded bfm.crel_vs_def, rule_format, of Mapping.empty]
  unfolding bfm.crel_vs_def by auto

end (* Wellformedness *)

end (* Final Node *)

end (* Bellman Ford *)



subsubsection ‹Extracting an Executable Constant for the Imperative Implementation›

(* Introduce the constant bfh'_impl from the equations bfh'.simps using the
   ground_function command, with the prove_termination option enabled.
   Its agreement with bfh' is stated separately in bfh'_impl_def. *)
ground_function (prove_termination) bfh'_impl: bfh'.simps

(* bfh'_impl coincides with bfh', provided the memory mem is the one produced
   by init_state (n + 1) 1 0 on the empty heap.
   The equality is first shown for all argument pairs i j by the induction
   rule of bfh' and then lifted to the functions themselves. *)
lemma bfh'_impl_def:
  fixes n :: nat
  fixes mem :: "nat ref × nat ref × int extended option array ref × int extended option array ref"
  assumes mem_is_init: "mem = result_of (init_state (n + 1) 1 0) Heap.empty"
  shows "bfh'_impl n w t mem = bfh' n w t mem"
proof -
  (* Pointwise equality; each case is closed by simp and solve_cong. *)
  have "bfh'_impl n w t mem i j = bfh' n w t mem i j" for i j
    by (induction rule: bfh'.induct[OF mem_is_init];
        simp add: bfh'.simps[OF mem_is_init]; solve_cong simp
       )
  (* Extensionality. *)
  then show ?thesis
    by auto
qed

(* Heap-monad driver that runs bfh'_impl over index pairs (x, y).
   The three arguments passed to iterator_defs.iter_heap are, in order:
     - the continuation test: keep going while x ≤ n and y ≤ n;
     - the successor: increase y while y < n, otherwise move to (x + 1, 0);
     - the action executed at the current pair.
   iter_bf_heap_unfold gives the resulting recursion equation. *)
definition
  "iter_bf_heap n w t mem = iterator_defs.iter_heap
      (λ(x, y). x ≤ n ∧ y ≤ n)
      (λ(x, y). if y < n then (x, y + 1) else (x + 1, 0))
      (λ(x, y). bfh'_impl n w t mem x y)"

(* Code equation for iter_bf_heap: while the pair (i, j) is within bounds,
   run bfh'_impl at (i, j) and continue with the successor pair; once out of
   bounds, return unit.  Derived from iter_heap_unfold. *)
lemma iter_bf_heap_unfold[code]:
  "iter_bf_heap n w t mem = (λ (i, j).
    (if i ≤ n ∧ j ≤ n
     then do {
            bfh'_impl n w t mem i j;
            iter_bf_heap n w t mem (if j < n then (i, j + 1) else (i + 1, 0))
          }
     else Heap_Monad.return ()))"
  unfolding iter_bf_heap_def by (rule ext) (safe, simp add: iter_heap_unfold)

(* Imperative entry point.
   1. Allocate the memory with init_state (n + 1) 1 0; the type annotation
      fixes its representation (two nat references and two references to
      arrays of optional int extended values).
   2. Run bfh'_impl on every index pair, starting from (0, 0), via
      iter_bf_heap.
   3. Return the value of bfh'_impl at the requested pair (i, j). *)
definition
  "bf_impl n w t i j = do {
    mem ← (init_state (n + 1) (1::nat) (0::nat) ::
      (nat ref × nat ref × int extended option array ref × int extended option array ref) Heap);
    iter_bf_heap n w t mem (0, 0);
    bfh'_impl n w t mem i j
  }"

(* Correctness of the imperative implementation: the functional value
   bf n w t i j equals the result of executing bf_impl n w t i j on the empty
   heap.  Follows from memoized_empty after unfolding bf_impl, bfh'_impl and
   iter_bf_heap and stepping through the initial init_state bind. *)
lemma bf_impl_correct:
  "bf n w t i j = result_of (bf_impl n w t i j) Heap.empty"
  using memoized_empty[OF HOL.refl, of n w t "(i, j)"]
  by (simp add:
        execute_bind_success[OF succes_init_state] bf_impl_def bfh'_impl_def iter_bf_heap_def
      )


subsubsection ‹Test Cases›

(* Test graph 1 (4 nodes), as adjacency lists: the i-th entry lists the
   (target, weight) pairs of the edges leaving node i.
     0 -> 1 (-6), 0 -> 2 (4), 0 -> 3 (5)
     1 -> 3 (10)
     2 -> 3 (2)
     3 has no outgoing edges *)
definition
  "G1_list =
    [ [(1 :: nat, -6 :: int), (2, 4), (3, 5)]
    , [(3, 10)]
    , [(3, 2)]
    , []
    ]"

(* Test graph 2 (4 nodes), as adjacency lists of (target, weight) pairs.
   Same edges as G1_list plus the edge 3 -> 0 (-5), which closes the cycle
   0 -> 1 -> 3 -> 0 of total weight -6 + 10 - 5 = -1. *)
definition
  "G2_list =
    [ [(1 :: nat, -6 :: int), (2, 4), (3, 5)]
    , [(3, 10)]
    , [(3, 2)]
    , [(0, -5)]
    ]"

(* Test graph 3 (5 nodes), as adjacency lists of (target, weight) pairs.
     0 -> 1 (-1), 0 -> 2 (2)
     1 -> 2 (5),  1 -> 3 (4)
     2 -> 3 (2),  2 -> 4 (3)
     3 -> 2 (-2), 3 -> 4 (2)
     4 has no outgoing edges
   The cycle 2 -> 3 -> 2 has total weight 2 - 2 = 0, i.e. it is not negative. *)
definition
  "G3_list =
    [ [(1 :: nat, -1 :: int), (2, 2)]
    , [(2, 5), (3, 4)]
    , [(3, 2), (4, 3)]
    , [(2, -2), (4, 2)]
    , []
    ]"

(* Test graph 4 (5 nodes), as adjacency lists of (target, weight) pairs.
   Identical to G3_list except that the edge 3 -> 2 has weight -3, so the
   cycle 2 -> 3 -> 2 has total weight 2 - 3 = -1. *)
definition
  "G4_list =
    [ [(1 :: nat, -1 :: int), (2, 2)]
    , [(2, 5), (3, 4)]
    , [(3, 2), (4, 3)]
    , [(2, -3), (4, 2)]
    , []
    ]"

(* Weight function induced by an immutable array of adjacency lists.
   graph_of a i j looks in the i-th entry of a for the first pair whose
   first component is j and returns its second component wrapped in Fin;
   if there is no such pair the result is ∞. *)
definition
  "graph_of a i j = case_option ∞ (Fin o snd) (List.find (λ p. fst p = j) (a !! i))"

(* Sample run of the imperative implementation on G1_list with
   n = 3, t = 3, i = 3, j = 0. *)
definition "test_bf = bf_impl 3 (graph_of (IArray G1_list)) 3 3 0"

(* Generate code for test_bf and make it available in ML as structure Test. *)
code_reflect Test functions test_bf

text ‹One can see a trace of the calls to the memory in the output›
(* Run the reflected test; the ML source must be passed as a cartouche. *)
ML ‹Test.test_bf ()›

(* Code equation for bf: run iter_bf from (0, 0) in the state monad, then
   call the memoized function bfm' at (i, j), all starting from the empty
   Mapping memory, and take the first component of the result.
   Follows from bf_bottom_up. *)
lemma bottom_up_alt[code]:
  "bf n W t i j =
     fst (run_state
      (iter_bf n W t (0, 0) ⤜ (λ_. bfm' n W t i j))
      Mapping.empty)"
  using bf_bottom_up by auto

(* Variant of the bottom-up computation that takes the graph as a list of
   adjacency lists W, converted once through graph_of (IArray W).
   The iteration starts at (0, 0), exactly as in bottom_up_alt, so that the
   run of iter_bf covers all index pairs before bfm' is called at (i, j). *)
definition
  "bf_ia n W t i j = (let W' = graph_of (IArray W) in
    fst (run_state
      (iter_bf n W' t (0, 0) ⤜ (λ_. bfm' n W' t i j))
      Mapping.empty)
  )"

― ‹Component tests.›
(* On G1_list with n = 3, t = 3, i = 3, j = 0, both the memoized state-monad
   function bfm' (run from the empty memory) and bf evaluate to 4.
   Checked by code evaluation. *)
lemma
  "fst (run_state (bfm' 3 (graph_of (IArray G1_list)) 3 3 0) Mapping.empty) = 4"
  "bf 3 (graph_of (IArray G1_list)) 3 3 0 = 4"
  by eval+

― ‹Regular test cases.›
(* Graphs without a negative cycle: bellman_ford returns Some list with one
   entry per node, for G1_list (sink 3) and G3_list (sink 4).
   Checked by code evaluation. *)
lemma
  "fst (run_state (bellman_ford 3 (graph_of (IArray G1_list)) 3) Mapping.empty) = Some [4, 10, 2, 0]"
  "fst (run_state (bellman_ford 4 (graph_of (IArray G3_list)) 4) Mapping.empty) = Some [4, 5, 3, 1, 0]"
  by eval+

― ‹Test detection of negative cycles.›
(* Graphs with a negative cycle: bellman_ford returns None for G2_list and
   G4_list.  Checked by code evaluation. *)
lemma
  "fst (run_state (bellman_ford 3 (graph_of (IArray G2_list)) 3) Mapping.empty) = None"
  "fst (run_state (bellman_ford 4 (graph_of (IArray G4_list)) 4) Mapping.empty) = None"
  by eval+

end (* Theory *)