# Theory Bellman_Ford

theory Bellman_Ford
imports Extended IArray Product_Lexorder RBT_Mapping State_Main Example_Misc Tracing Ground_Function
```
subsection ‹The Bellman-Ford Algorithm›

theory Bellman_Ford
imports
"HOL-Library.Extended"
"HOL-Library.IArray"
"HOL-Library.Code_Target_Numeral"
"HOL-Library.Product_Lexorder"
"HOL-Library.RBT_Mapping"
Example_Misc
"../util/Tracing"
"../util/Ground_Function"
begin

subsubsection ‹Misc›

(* extended inherits countability from its base type: Pinf is coded as 0, Minf as 1,
   and Fin n as the code of n plus 2, which is injective into nat.
   Presumably needed because Imperative HOL's heap class builds on countable. *)
instance extended :: (countable) countable
proof standard
(* Pick an injection of the (countable) base type into nat. *)
obtain to_nat :: "'a ⇒ nat" where "inj to_nat"
by auto
let ?f = "λ x. case x of Fin n ⇒ to_nat n + 2 | Pinf ⇒ 0 | Minf ⇒ 1"
(* The offset 2 keeps the codes of Fin values disjoint from those of Pinf and Minf. *)
from ‹inj _ › have "inj ?f"
by (auto simp: inj_def split: extended.split)
then show "∃to_nat :: 'a extended ⇒ nat. inj to_nat"
by auto
qed

(* Lets (int) extended values be stored in Imperative HOL arrays, as the heap-based
   memoization below does with "int extended option array ref". *)
instance extended :: (heap) heap ..

(* Invariant principle for fold: a property P of the accumulator that holds
   initially and is preserved by every step f x also holds of the final result. *)
lemma fold_acc_preserv:
  assumes step: "⋀ x acc. P acc ⟹ P (f x acc)"
    and init: "P acc"
  shows "P (fold f xs acc)"
  (* Every step changes the accumulator, so it has to be generalised. *)
  using init by (induction xs arbitrary: acc) (auto intro: step)

(* NOTE(review): the statement and proof of get_return are missing from this copy
   of the theory, so this header is dangling and the theory will not parse as is.
   Nothing in the visible rest of the theory refers to get_return by name; restore
   the lemma (or drop the header) once the intended statement is confirmed. *)
lemma get_return:

subsubsection ‹Single-Sink Shortest Path Problem›

(* Judging by the constructor names: a path (as a node list) together with its integer
   weight, no path, or a failed computation.
   NOTE(review): not used anywhere else in this theory. *)
datatype bf_result = Path "nat list" int | No_Path | Computation_Error

context
fixes n :: nat and W :: "nat ⇒ nat ⇒ int extended"
begin

context
fixes t :: nat ― ‹Final node›
begin

text ‹
The correctness proof closely follows Kleinberg ‹&› Tardos: "Algorithm Design",
chapter "Dynamic Programming" @{cite "Kleinberg-Tardos"}
›

(* Weight of the path xs continued to the fixed sink t: the sum of W over the
   consecutive edges of xs @ [t] (so 0 for xs = []). The fold runs over rev xs from
   (t, 0), carrying the successor node j and the weight x accumulated so far. *)
definition weight :: "nat list ⇒ int extended" where
"weight xs = snd (fold (λ i (j, x). (i, W i j + x)) (rev xs) (t, 0))"

(* OPT i v: least weight of a path from v to the sink t with at most i edges whose
   intermediate nodes lie in {0..n}. The path v # xs (implicitly ending in t) has
   length xs + 1 edges; the extra singleton is 0 for the empty path when v = t and
   ∞ otherwise. *)
definition
"OPT i v = (
Min (
{weight (v # xs) | xs. length xs + 1 ≤ i ∧ set xs ⊆ {0..n}} ∪
{if t = v then 0 else ∞}
)
)"

(* Peel off the first edge: the weight of v # w # xs is W v w plus the weight of
   the remaining path w # xs. *)
lemma weight_Cons [simp]:
  "weight (v # w # xs) = W v w + weight (w # xs)"
  (* Unfold weight; case_prod_beta' turns the tupled step function into fst/snd
     form so that simp can evaluate the last two fold steps symbolically. *)
  by (simp add: weight_def case_prod_beta')

(* A one-node path consists of the single edge from v to the sink t. *)
lemma weight_single [simp]:
  "weight [v] = W v t"
  (* One fold step from (t, 0) yields (v, W v t + 0). *)
  by (simp add: weight_def)

(* Adding a constant on the right commutes with Min over a finite, non-empty set.
   XXX Generalize to the right type class *)
lemma Min_add_right:
  "Min S + (x :: int extended) = Min ((λ y. y + x) ` S)" (is "?A = ?B") if "finite S" "S ≠ {}"
proof -
  have "?A ≤ ?B"
    using that by (force intro: Min.boundedI add_right_mono)
  moreover have "?B ≤ ?A"
    using that by (force intro: Min.boundedI)
  ultimately show ?thesis
    by simp
qed

(* With an edge budget of 0 only the empty path is admissible, and it exists iff v = t. *)
lemma OPT_0:
  "OPT 0 v = (if t = v then 0 else ∞)"
  unfolding OPT_def by simp

(* ∞ absorbs any summand. Declared [simp] so that sums such as OPT 0 w + W t w
   (with w ≠ t) simplify automatically in the proof of OPT_Suc.
   TODO: Move to distribution! *)
lemma Pinf_add_left [simp]:
  "∞ + x = ∞"
  by (cases x; simp)

subsubsection ‹Functional Correctness›

(* Bellman-Ford recurrence: a best path with at most Suc i edges either already uses
   at most i edges, or starts with an edge v → w followed by a best path from w with
   at most i edges. *)
lemma OPT_Suc:
  "OPT (Suc i) v = min (OPT i v) (Min {OPT i w + W v w | w. w ≤ n})" (is "?lhs = ?rhs")
  if "t ≤ n"
proof -
  have fin': "finite {xs. length xs ≤ i ∧ set xs ⊆ {0..n}}" for i
    by (auto intro: finite_subset[OF _ finite_lists_length_le[OF finite_atLeastAtMost]])
  have fin: "finite {weight (v # xs) |xs. length xs ≤ i ∧ set xs ⊆ {0..n}}"
    for v i using [[simproc add: finite_Collect]] by (auto intro: finite_subset[OF _ fin'])
  have OPT_in: "OPT i v ∈
    {weight (v # xs) | xs. length xs + 1 ≤ i ∧ set xs ⊆ {0..n}} ∪
    {if t = v then 0 else ∞}"
    if "i > 0" for i v
    using that unfolding OPT_def
    by - (rule Min_in, auto 4 3 intro: finite_subset[OF _ fin, of _ v "Suc i"])

  have "OPT i v ≥ OPT (Suc i) v"
    unfolding OPT_def using fin by (auto 4 3 intro: Min_antimono)
  have subs:
    "(λy. y + W v w) ` {weight (w # xs) |xs. length xs + 1 ≤ i ∧ set xs ⊆ {0..n}}
    ⊆ {weight (v # xs) |xs. length xs + 1 ≤ Suc i ∧ set xs ⊆ {0..n}}" if ‹w ≤ n› for v w
    using ‹w ≤ n› apply clarify
    subgoal for _ _ xs
      by (rule exI[where x = "w # xs"]) (auto simp: algebra_simps)
    done
  (* In each of the three cases below, Min_add_right first moves the edge weight
     inside the Min of OPT i _, so that the two minima can be compared via subs. *)
  have "OPT i t + W v t ≥ OPT (Suc i) v"
    unfolding OPT_def using subs[OF ‹t ≤ n›, of v] that
    by (subst Min_add_right)
      (auto 4 3
        simp: Bellman_Ford.weight_single
        intro: exI[where x = "[]"] finite_subset[OF _ fin[of _ "Suc i"]] intro!: Min_antimono
      )
  moreover have "OPT i w + W v w ≥ OPT (Suc i) v" if "w ≤ n" ‹w ≠ t› ‹t ≠ v› for w
    unfolding OPT_def using subs[OF ‹w ≤ n›, of v] that
    by (subst Min_add_right)
      (auto 4 3 intro: finite_subset[OF _ fin[of _ "Suc i"]] intro!: Min_antimono)
  moreover have "OPT i w + W t w ≥ OPT (Suc i) t" if "w ≤ n" ‹w ≠ t› for w
    unfolding OPT_def
    apply (subst Min_add_right)
      prefer 3
    using ‹w ≠ t›
      apply simp
     apply (cases "i = 0")
      apply (simp; fail)
    using subs[OF ‹w ≤ n›, of t]
    by (subst (2) Min_insert)
      (auto 4 4
        intro: finite_subset[OF _ fin[of _ "Suc i"]] exI[where x = "[]"] intro!: Min_antimono
      )
  ultimately have "Min {local.OPT i w + W v w |w. w ≤ n} ≥ OPT (Suc i) v"
    by (auto intro!: Min.boundedI)
  with ‹OPT i v ≥ _› have "?lhs ≤ ?rhs"
    by simp

  from OPT_in[of "Suc i" v] consider
    "OPT (Suc i) v = ∞" | "t = v" "OPT (Suc i) v = 0" |
    xs where "OPT (Suc i) v = weight (v # xs)" "length xs ≤ i" "set xs ⊆ {0..n}"
    by (auto split: if_split_asm)
  then have "?lhs ≥ ?rhs"
  proof cases
    case 1
    then show ?thesis
      by simp
  next
    case 2
    then have "OPT i v ≤ OPT (Suc i) v"
      unfolding OPT_def using [[simproc add: finite_Collect]]
      by (auto 4 4 intro: finite_subset[OF _ fin', of _ "Suc i"] intro!: Min_le)
    then show ?thesis
      by (rule min.coboundedI1)
  next
    case xs: 3
    note [simp] = xs(1)
    show ?thesis
    proof (cases "length xs = i")
      case True
      show ?thesis
      proof (cases "i = 0")
        case True
        with xs have "OPT (Suc i) v = W v t"
          by simp
        also have "W v t = OPT i t + W v t"
          unfolding OPT_def using ‹i = 0› by auto
        also have "… ≥ Min {OPT i w + W v w |w. w ≤ n}"
          using ‹t ≤ n› by (auto intro: Min_le)
        finally show ?thesis
          by (rule min.coboundedI2)
      next
        case False
        with ‹_ = i› have "xs ≠ []"
          by auto
        with xs have "weight xs ≥ OPT i (hd xs)"
          unfolding OPT_def
          by (intro Min_le[rotated] UnI1 CollectI exI[where x = "tl xs"])
            (auto 4 3 intro: finite_subset[OF _ fin, of _ "hd xs" "Suc i"] dest: list.set_sel(2))
        have "Min {OPT i w + W v w |w. w ≤ n} ≤ W v (hd xs) + OPT i (hd xs)"
          using ‹set xs ⊆ _› ‹xs ≠ []› by (force simp: add.commute intro: Min_le)
        also have "… ≤ W v (hd xs) + weight xs"
          using ‹_ ≥ OPT i (hd xs)› by (metis add_left_mono)
        also from ‹xs ≠ []› have "… = OPT (Suc i) v"
          by (cases xs) auto
        finally show ?thesis
          by (rule min.coboundedI2)
      qed
    next
      case False
      with xs have "OPT i v ≤ OPT (Suc i) v"
        by (auto 4 4 intro: Min_le finite_subset[OF _ fin, of _ v "Suc i"] simp: OPT_def)
      then show ?thesis
        by (rule min.coboundedI1)
    qed
  qed
  with ‹?lhs ≤ ?rhs› show ?thesis
    by (rule order.antisym)
qed

(* Bellman-Ford table. Row 0 only admits the empty path (0 iff j = t, else ∞);
   row Suc k is the minimum of bf k j and W j i + bf k i over all nodes i ≤ n.
   bf_correct shows bf i j = OPT i j, provided t ≤ n. *)
fun bf :: "nat ⇒ nat ⇒ int extended" where
"bf 0 j = (if t = j then 0 else ∞)"
| "bf (Suc k) j = min_list
(bf k j # [W j i + bf k i . i ← [0 ..< Suc n]])"

(* Replace the min_list equations in the simpset by variants rewritten with
   min_list_fold, presumably because bf_correct relates Min to fold
   (Min.set_eq_fold). *)
lemmas [simp del] = bf.simps
lemmas [simp] = bf.simps[unfolded min_list_fold]
thm bf.simps
thm bf.induct

(* The recursive table bf computes OPT, provided the sink t is one of the nodes 0..n. *)
lemma bf_correct:
  "OPT i j = bf i j" if ‹t ≤ n›
proof (induction i arbitrary: j)
  case 0
  then show ?case
    by (simp add: OPT_0)
next
  case (Suc i)
  (* Reshape the candidate set of OPT_Suc into the list that bf folds over.
     upt_Suc is removed so that [0..<Suc n] becomes {0..<Suc n}, i.e. w ≤ n. *)
  have *:
    "{bf i w + W j w |w. w ≤ n} = set (map (λw. W j w + bf i w) [0..<Suc n])"
    by (fastforce simp: add.commute image_def less_Suc_eq_le simp del: upt_Suc)
  from Suc ‹t ≤ n› show ?case
    by (simp add: OPT_Suc del: upt_Suc, subst Min.set_eq_fold[symmetric], auto simp: *)
qed

subsubsection ‹Functional Memoization›

(* Memoize bf. with_memory dp_consistency_mapping and monadifies (state) presumably
   produce a state-monad version bf⇩m' that caches results in a Mapping. *)
memoize_fun bf⇩m: bf with_memory dp_consistency_mapping monadifies (state) bf.simps

text ‹Generated Definitions›
thm bf⇩m'.simps bf⇩m_def

text ‹Correspondence Proof›
memoize_correct
  by memoize_prover
print_theorems
(* Register bf⇩m.memoized_correct as code equation(s). *)
lemmas [code] = bf⇩m.memoized_correct

(* Row-major traversal of the (n + 1) × (n + 1) table of (row, node) pairs: cnt keeps
   (x, y) inside the table, nxt steps to (x, y + 1) or wraps to (x + 1, 0), and
   x * (n + 1) + y is the position of (x, y) in that order. *)
interpretation iterator
"λ (x, y). x ≤ n ∧ y ≤ n"
"λ (x, y). if y < n then (x, y + 1) else (x + 1, 0)"
"λ (x, y). x * (n + 1) + y"
by (rule table_iterator_up)

(* Bottom-up memoization of bf along that iterator with a Mapping as memory:
   no invariant on the memory (True), lookups read Mapping.lookup of the state,
   updates store Mapping.update, and the memory starts as Mapping.empty. *)
interpretation bottom_up: dp_consistency_iterator_empty
"λ (_::(nat × nat, int extended) mapping). True"
"λ (x, y). bf x y"
"λ k. do {m ← State_Monad.get; State_Monad.return (Mapping.lookup m k :: int extended option)}"
"λ k v. do {m ← State_Monad.get; State_Monad.set (Mapping.update k v m)}"
"λ (x, y). x ≤ n ∧ y ≤ n"
"λ (x, y). if y < n then (x, y + 1) else (x + 1, 0)"
"λ (x, y). x * (n + 1) + y"
Mapping.empty ..

(* Run bf⇩m' on each table cell in row-major order, from the given cell onwards
   (see iter_bf_unfold for the unfolded form). *)
definition
"iter_bf = iter_state (λ (x, y). bf⇩m' x y)"

(* Code equation for iter_bf: compute cell (i, j) with bf⇩m', then continue with its
   row-major successor; once (i, j) leaves the (n + 1) × (n + 1) table, stop and
   return unit without touching the memory. *)
lemma iter_bf_unfold[code]:
  "iter_bf = (λ (i, j).
    (if i ≤ n ∧ j ≤ n
     then do {
            bf⇩m' i j;
            iter_bf (if j < n then (i, j + 1) else (i + 1, 0))
          }
     else State_Monad.return ()))"
  unfolding iter_bf_def by (rule ext) (safe, clarsimp simp: iter_state_unfold)

(* Instantiate the generic memoization results with bf⇩m.crel: bf_memoized for the
   memoized evaluation and bf_bottom_up for evaluation after running iter_bf. *)
lemmas bf_memoized = bf⇩m.memoized[OF bf⇩m.crel]
lemmas bf_bottom_up = bottom_up.memoized[OF bf⇩m.crel, folded iter_bf_def]

thm bf⇩m'.simps bf_memoized

subsubsection ‹Imperative Memoization›

context
fixes mem :: "nat ref × nat ref × int extended option array ref × int extended option array ref"
assumes mem_is_init: "mem = result_of (init_state (n + 1) 1 0) Heap.empty"
begin

(* Discharges the dp_consistency_heap_array_pair' assumptions for mem, using that mem
   is the result of init_state (mem_is_init). Declared [intro], presumably so that
   memoize_fun's default proof below can pick it up. *)
lemma [intro]:
"dp_consistency_heap_array_pair' (n + 1) fst snd id 1 0 mem"
by (standard; simp add: mem_is_init injective_def)

(* Row-major iterator over the (n + 1) × (n + 1) table, interpreted again inside
   this memory context. *)
interpretation iterator
"λ (x, y). x ≤ n ∧ y ≤ n"
"λ (x, y). if y < n then (x, y + 1) else (x + 1, 0)"
"λ (x, y). x * (n + 1) + y"
by (rule table_iterator_up)

(* Same as the previous [intro] lemma, for the iterator variant of the array-pair
   memory with the row-major nxt, size and cnt functions. *)
lemma [intro]:
"dp_consistency_heap_array_pair_iterator (n + 1) fst snd id 1 0 mem
(λ (x, y). if y < n then (x, y + 1) else (x + 1, 0))
(λ (x, y). x * (n + 1) + y)
(λ (x, y). x ≤ n ∧ y ≤ n)"
by (standard; simp add: mem_is_init injective_def)

(* Memoize bf in the heap with a pair of arrays of size n + 1 indexed by the node
   (key2 = snd), rows being distinguished via key1 = fst with k1 = 1, k2 = 0.
   Presumably only two rows are kept at a time, which suffices because row Suc k
   of bf only reads row k. Like bf⇩m, the equations bf.simps are monadified, here
   into the heap monad. *)
memoize_fun bf⇩h: bf
  with_memory (default_proof) dp_consistency_heap_array_pair_iterator
  where size = "n + 1"
    and key1="fst :: nat × nat ⇒ nat" and key2="snd :: nat × nat ⇒ nat"
    and k1="1 :: nat" and k2="0 :: nat"
    and to_index = "id :: nat ⇒ nat"
    and mem = mem
    and cnt = "λ (x, y). x ≤ n ∧ y ≤ n"
    and nxt = "λ (x :: nat, y). if y < n then (x, y + 1) else (x + 1, 0)"
    and sizef = "λ (x, y). x * (n + 1) + y"
  monadifies (heap) bf.simps

memoize_correct
  by memoize_prover

(* memoized_empty: the heap memoization result for bf⇩h, starting from the initial
   memory. iter_heap_unfold is re-exported under its own name so that it stays
   available after this context is closed (it is used by iter_bf_heap_unfold). *)
lemmas memoized_empty = bf⇩h.memoized_empty[OF bf⇩h.consistent_DP_iter_and_compute[OF bf⇩h.crel]]
lemmas iter_heap_unfold = iter_heap_unfold

end (* Fixed Memory *)

end (* Final Node *)

end (* Bellman Ford *)

subsubsection ‹Extracting an Executable Constant for the Imperative Implementation›

(* Turn the equations of bf⇩h' into a stand-alone function bf⇩h'_impl with its own
   termination proof; bf⇩h'_impl_def below shows that it agrees with bf⇩h' on the
   initial memory. *)
ground_function (prove_termination) bf⇩h'_impl: bf⇩h'.simps

(* bf⇩h'_impl coincides with bf⇩h' whenever mem is the freshly initialised memory. *)
lemma bf⇩h'_impl_def:
fixes n :: nat
fixes mem :: "nat ref × nat ref × int extended option array ref × int extended option array ref"
assumes mem_is_init: "mem = result_of (init_state (n + 1) 1 0) Heap.empty"
shows "bf⇩h'_impl n w t mem = bf⇩h' n w t mem"
proof -
(* Pointwise agreement, by induction along bf⇩h'.induct. *)
have "bf⇩h'_impl n w t mem i j = bf⇩h' n w t mem i j" for i j
by (induction rule: bf⇩h'.induct[OF mem_is_init];
simp add: bf⇩h'.simps[OF mem_is_init]; solve_cong simp
)
then show ?thesis
by auto
qed

(* Heap-monad iteration of bf⇩h'_impl over the (n + 1) × (n + 1) table in row-major
   order, from the given cell onwards. *)
definition
"iter_bf_heap n w t mem = iterator_defs.iter_heap
(λ(x, y). x ≤ n ∧ y ≤ n)
(λ(x, y). if y < n then (x, y + 1) else (x + 1, 0))
(λ(x, y). bf⇩h'_impl n w t mem x y)"

(* Code equation for iter_bf_heap: compute cell (i, j) with bf⇩h'_impl, then continue
   with its row-major successor; outside the (n + 1) × (n + 1) table the iteration
   stops and returns unit. *)
lemma iter_bf_heap_unfold[code]:
  "iter_bf_heap n w t mem = (λ (i, j).
    (if i ≤ n ∧ j ≤ n
     then do {
            bf⇩h'_impl n w t mem i j;
            iter_bf_heap n w t mem (if j < n then (i, j + 1) else (i + 1, 0))
          }
     else return ()))"
  unfolding iter_bf_heap_def by (rule ext) (safe, simp add: iter_heap_unfold)

(* Imperative Bellman-Ford: allocate fresh memory with init_state, run bf⇩h'_impl on
   all table cells from (0, 0) in row-major order, then evaluate bf⇩h'_impl at (i, j).
   bf_impl_correct relates the result to bf. *)
definition
"bf_impl n w t i j = do {
mem ← (init_state (n + 1) (1::nat) (0::nat) ::
(nat ref × nat ref × int extended option array ref × int extended option array ref) Heap);
iter_bf_heap n w t mem (0, 0);
bf⇩h'_impl n w t mem i j
}"

(* Running the imperative implementation from the empty heap yields bf. *)
lemma bf_impl_correct:
  "bf n w t i j = result_of (bf_impl n w t i j) Heap.empty"
  using memoized_empty[OF HOL.refl, of n w t "(i, j)"]
  (* Execute the init_state step (which succeeds) and unfold the implementation. *)
  by (simp add:
      execute_bind_success[OF succes_init_state] bf_impl_def bf⇩h'_impl_def iter_bf_heap_def
     )

subsubsection ‹Test Cases›

(* Example graph as adjacency lists: row i holds pairs (j, w) for edges i → j of
   weight w (read by graph_of); node 3 has no outgoing edges. *)
definition
"G⇩1_list = [[(1 :: nat,-6 :: int), (2,4), (3,5)], [(3,10)], [(3,2)], []]"

(* Edge-weight function of an adjacency array: graph_of a i j is Fin w for the first
   pair (j, w) in row i of a, and ∞ when row i has no entry for node j. *)
definition
  "graph_of a i j =
    (case List.find (λ (k, _). k = j) (a !! i) of
      None ⇒ ∞
    | Some (_, w) ⇒ Fin w)"

(* Query on the example graph with n = 3 and sink t = 3: bf 3 0, the least weight of a
   path from node 0 to node 3 with at most 3 edges, via the imperative implementation.
   On G⇩1 this is 0 → 1 → 3 with weight -6 + 10 = 4. *)
definition "test_bf = bf_impl 3 (graph_of (IArray G⇩1_list)) 3 3 0"

(* Export test_bf to ML as structure Test, so that it can be run below. *)
code_reflect Test functions test_bf

text ‹One can see a trace of the calls to the memory in the output›
ML ‹Test.test_bf ()›

(* Code equation for bf via the state-monad memoization: fill the Mapping-based table
   bottom-up from (0, 0) with iter_bf, then evaluate bf⇩m' at (i, j). Follows from
   bf_bottom_up. *)
lemma bottom_up_alt[code]:
"bf n W t i j =
fst (run_state
(iter_bf n W t (0, 0) ⤜ (λ_. bf⇩m' n W t i j))
Mapping.empty)"
using bf_bottom_up by auto

(* Bellman-Ford on a graph given as adjacency lists (converted with graph_of over an
   IArray), evaluated with the state-monad memoization from the empty Mapping.
   NOTE(review): the iteration starts at (i, j), whereas bottom_up_alt and bf_impl
   start at (0, 0); with (i, j) the first cell is computed through bf⇩m''s recursion
   rather than bottom-up — confirm this is intended. *)
definition
"bf_ia n W t i j = (let W' = graph_of (IArray W) in
fst (run_state
(iter_bf n W' t (i, j) ⤜ (λ_. bf⇩m' n W' t i j))
Mapping.empty)
)"

(* The same query as test_bf (bf 3 0 on G⇩1 with sink 3), evaluated once through the
   state-monad memoization bf⇩m' and once through bf itself. *)
value "fst (run_state (bf⇩m' 3 (graph_of (IArray G⇩1_list)) 3 3 0) Mapping.empty)"

value "bf 3 (graph_of (IArray G⇩1_list)) 3 3 0"

end (* Theory *)
```