Theory Floyd_Warshall

(*<*)
theory Floyd_Warshall
imports
  "../CImperativeHOL"
begin

(*>*)
section‹ Floyd-Warshall all-pairs shortest paths \label{sec:floyd_warshall} ›

text‹

The Floyd-Warshall algorithm computes the lengths of the shortest
paths between all pairs of nodes by updating an adjacency (square)
matrix that represents the edge weights. Our goal here is to present
it at a very abstract level to exhibit the data dependencies.

Source materials:
 \<^item> 🌐‹https://en.wikipedia.org/wiki/Floyd%E2%80%93Warshall_algorithm›
 \<^item> ‹$AFP/Floyd_Warshall/Floyd_Warshall.thy›
  \<^item> a proof by refinement yielding a thorough correctness result, including negative weights but not the absence of edges
 \<^item> \<^citet>‹‹\S6.2› in "Dingel:2002"›
  \<^item> overly parallelised, which is not practically useful but does reveal the data dependencies
  \<^item> the refinement is pretty much the same as the direct partial correctness proof here
  \<^item> the equivalent to ‹fw_update› is a single expression

We are not very ambitious here. This theory:
 \<^item> does not track the actual shortest paths, though it would be easy to add another array to do so
 \<^item> ignores numeric concerns
 \<^item> assumes the graph is complete

A further step would be to refine the parallel program to the classic three-loop presentation.

›

―‹ body of inner loop: update the weight of path ‹a[i, j]› considering the path ‹a[i, k] → a[k, j]› ›
(* Reads a[i, j], a[i, k] and a[k, j], in that order.  It then writes ik + kj into a[i, j] only when
   that sum is strictly smaller than the current a[i, j]; otherwise the array is left unchanged.
   The proof of ag.fw_update unfolds fw_update_def, and its apply script handles these statements
   one by one, so reordering them would most likely break that proof. *)
definition fw_update :: "('i::Ix × 'i, nat) array  'i × 'i  'i  unit imp" where
  "fw_update = (λa (i, j) k. do {
     ij  prog.Array.nth a (i, j);
     ik  prog.Array.nth a (i, k);
     kj  prog.Array.nth a (k, j);
     prog.whenM (ik + kj < ij) (prog.Array.upd a (i, j) (ik + kj))
   })"

―‹ top-level specification: we can process the nodes in an arbitrary order ›
(* For each k supplied by prog.Array.fst_app_chaotic b, run fw_update a (i, j) k for every (i, j)
   in Ix.interval b.  Presumably k ranges over the first-component index range of the bounds
   (cf. fw_post).  The combinator over (i, j) is presumably parallel composition, because
   ag.fw_chaotic discharges it with ag.prog.Parallel. *)
definition fw_chaotic :: "('i::Ix × 'i, nat) array  unit imp" where
  "fw_chaotic a =
    (let b = array.bounds a in
      prog.Array.fst_app_chaotic b (λk. (i, j)set (Ix.interval b). fw_update a (i, j) k))"

―‹ executable version ›
(* Identical to fw_chaotic, except that the outer iteration uses prog.Array.fst_app instead of
   prog.Array.fst_app_chaotic.  fw_fw_chaotic_le shows that it refines fw_chaotic. *)
definition fw :: "('i::Ix × 'i, nat) array  unit imp" where
  "fw a =
    (let b = array.bounds a in
      prog.Array.fst_app b (λk. (i, j)set (Ix.interval b). fw_update a (i, j) k))"

lemma fw_fw_chaotic_le: ―‹ the executable program refines the specification ›
  shows "fw a  fw_chaotic a"
(* The proof proceeds in three steps:
   - unfold both programs;
   - rewrite the outer combinator using prog.Array.fst_app_fst_app_chaotic_le, which
     ord_to_strengthen turns into a strengthening rule;
   - let simp close what remains. *)
unfolding fw_chaotic_def fw_def
by (strengthen ord_to_strengthen(1)[OF prog.Array.fst_app_fst_app_chaotic_le]) simp

paragraph‹ Safety proof ›

(* A weight matrix assigns a nat weight to each (row, column) index pair.  path_weight reads the
   entry at (x, y) as the weight of the edge from x to y. *)
type_synonym 'i matrix = "'i × 'i  nat"

―‹ The weight of the given path ›
(* The list holds only the intermediate nodes of a path from fst ij to snd ij.
   path_weight m (i, j) [k1, ..., kn] is m (i, k1) + m (k1, k2) + ... + m (kn, j).
   The empty list denotes the direct edge (i, j). *)
fun path_weight :: "'i matrix  'i × 'i  'i list  nat" where
  "path_weight m ij [] = m ij"
| "path_weight m ij (k # xs) = m (fst ij, k) + path_weight m (k, snd ij) xs"

―‹ The set of acyclic paths from ‹i› to ‹j› using the nodes ‹ks› ›
(* A path is represented by its list of intermediate nodes.  These must all be drawn from ks, be
   pairwise distinct, and exclude both endpoints. *)
definition paths :: "'i × 'i  'i set  'i list set" where
  "paths ij ks = {p. set p  ks  fst ij  set p  snd ij  set p  distinct p}"

―‹ The minimum weight of a path from ‹i› to ‹j› using the nodes ‹ks›.
    See ‹$AFP/Floyd_Warshall/Floyd_Warshall.thy› for a proof that these are minimal amongst all paths. ›
(* Min is only meaningful on a finite, non-empty set.  Non-emptiness always holds
   (paths.not_empty), and finiteness holds whenever ks is finite (paths.finite); the lemmas below
   assume finite ks. *)
definition min_path_weight :: "'i matrix  'i × 'i  'i set  nat" where
  "min_path_weight m ij ks = Min (path_weight m ij ` paths ij ks)"

context
  fixes a :: "('i::Ix × 'i, nat) array"
  fixes m :: "'i matrix"
begin

(* Per-cell invariant: the heap representation invariant for a holds, and cell ij of a holds the
   minimum weight of a path for ij whose intermediate nodes are drawn from ks. *)
definition fw_p_inv :: "'i × 'i  'i set  heap.t pred" where ―‹ process invariant ›
  "fw_p_inv ij ks = (heap.rep_inv a  Array.get a ij = min_path_weight m ij ks)"

(* Loop invariant: the per-cell invariant fw_p_inv ij ks for every index ij of Array.interval a. *)
definition fw_inv :: "'i set  heap.t pred" where ―‹ loop invariant ›
  "fw_inv ks = (ij. ijset (Array.interval a)  fw_p_inv ij ks)"

(* Precondition: a is square, its representation invariant holds, and every cell of a holds the
   corresponding entry of the initial weight matrix m. *)
definition fw_pre :: "heap.t pred" where ―‹ overall precondition ›
  "fw_pre = (Array.square a  heap.rep_inv a
           (ij. ijset (Array.interval a)  Array.get a ij = m ij))"

(* Postcondition: the loop invariant, with every node of the first-component index range of the
   bounds allowed as an intermediate node. *)
definition fw_post :: "unit  heap.t pred" where ―‹ overall postcondition ›
  "fw_post _ = fw_inv (set (Ix.interval (fst_bounds (array.bounds a))))"

end

setup Sign.mandatory_path "paths"

(* Introduction rule: a list satisfying the four conditions of paths_def is a path. *)
lemma I:
  assumes "set p  ks"
  assumes "i  set p"
  assumes "j  set p"
  assumes "distinct p"
  shows "p  paths (i, j) ks"
using assms by (simp add: paths_def)

(* The direct edge, i.e. the empty list of intermediate nodes, is always a path. *)
lemma Nil:
  shows "[]  paths ij ks"
by (simp add: paths_def)

(* With no intermediate nodes available, the direct edge is the only path.  ag.fw_chaotic uses
   this to establish the initial loop invariant from fw_pre. *)
lemma empty:
  shows "paths ij {} = {[]}"
by (fastforce simp: paths_def)

(* There is always some path, namely paths.Nil.  This keeps the Min in min_path_weight well
   defined. *)
lemma not_empty:
  shows "paths ij ks  {}"
by (metis empty_iff paths.Nil)

(* Allowing more intermediate nodes yields more paths. *)
lemma monotone:
  shows "mono (paths ij)"
by (rule monoI) (auto simp add: paths_def)

(* Derived forms of paths.monotone.  mono is the pointwise consequence; strengthen is registered
   with the strengthen proof method via the strg attribute. *)
lemmas mono = monoD[OF paths.monotone]
lemmas strengthen[strg] = st_monotone[OF paths.monotone]

(* Over a finite node set there are only finitely many paths.  Paths are distinct lists over ks,
   and finite_distinct_conv shows there are finitely many of those. *)
lemma finite:
  assumes "finite ks"
  shows "finite (paths ij ks)"
unfolding paths_def by (rule finite_subset[OF _ iffD1[OF finite_distinct_conv assms]]) auto

(* A path that may use k but does not is already a path over ks alone. *)
lemma unused:
  assumes "p  paths ij (insert k ks)"
  assumes "k  set p"
  shows "p  paths ij ks"
using assms unfolding paths_def by blast

(* Decomposition of a path that passes through k: it splits at k as r @ k # s, where r is an
   (i, k) path over ks and s is a (k, j) path over ks.  The remaining facts record the distinctness
   and endpoint-avoidance conditions inherited from the original path. *)
lemma decompE:
  assumes "p  paths (i, j) (insert k ks)"
  assumes "k  set p"
  obtains r s
    where "p = r @ k # s"
      and "r  paths (i, k) ks" and "s  paths (k, j) ks"
      and "distinct (r @ s)" and "i  set (r @ k # s)" and "j  set (r @ k # s)"
using assms by (fastforce simp: paths_def dest: split_list)

setup Sign.parent_path

setup Sign.mandatory_path "path_weight"

(* The weight of a path through an intermediate node y is the weight of the part up to y plus
   the weight of the part from y. *)
lemma append:
  shows "path_weight m ij (xs @ y # ys) = path_weight m (fst ij, y) xs + path_weight m (y, snd ij) ys"
by (induct xs arbitrary: ij) simp_all

setup Sign.parent_path

(* Introduction rule for min_path_weight m ij ks = x, obtained from Min_eqI.  It leaves three
   subgoals:
   - the set of path weights is finite;
   - x is a lower bound of that set;
   - x is attained by some path. *)
lemmas min_path_weightI = trans[OF min_path_weight_def Min_eqI]

setup Sign.mandatory_path "min_path_weight"

(* If routing through k is strictly cheaper, then allowing k as an intermediate node makes the new
   minimum exactly the sum of the two sub-minima over ks.  This justifies the write performed by
   the fw_update program. *)
lemma fw_update:
  assumes m: "min_path_weight m (i, k) ks + min_path_weight m (k, j) ks < min_path_weight m (i, j) ks"
  assumes "finite ks"
  shows "min_path_weight m (i, j) (insert k ks)
       = min_path_weight m (i, k) ks + min_path_weight m (k, j) ks" (is "?lhs = ?rhs")
(* By min_path_weightI:
   (1) the weight set is finite;
   (2) ?rhs is a lower bound of it;
   (3) ?rhs is attained by some path. *)
proof(rule min_path_weightI)
  from finite ks show "finite (path_weight m (i, j) ` paths (i, j) (insert k ks))"
    by (simp add: paths.finite)
next
  fix w
  assume w: "w  path_weight m (i, j) ` paths (i, j) (insert k ks)"
  then obtain p where p: "w = path_weight m (i, j) p" "p  paths (i, j) (insert k ks)" ..
  show "?rhs  w"
  (* Case analysis on whether the path uses k:
     - a path through k splits at k (paths.decompE) into two paths over ks, each no lighter than
       the corresponding minimum;
     - a path avoiding k is a path over ks (paths.unused), so it is no lighter than the old
       minimum, which exceeds ?rhs by assumption m. *)
  proof(cases "k  set p")
    case True with m finite ks w p show ?thesis
      by (clarsimp simp: min_path_weight_def path_weight.append elim!: paths.decompE)
         (auto simp: Min_plus paths.finite paths.not_empty finite_image_set2 intro!: Min_le)
  next
    case False with m finite ks w p show ?thesis
      unfolding min_path_weight_def
      by (fastforce simp: paths.finite paths.not_empty dest: paths.unused)
  qed
next
  (* Attainment: pick minimal paths pik for (i, k) and pkj for (k, j) over ks, and glue them at k.
     Most of the work is showing that the glued path is acyclic. *)
  from finite ks obtain pik
    where pik: "pik  paths (i, k) ks"
      and mpik: "Min (path_weight m (i, k) ` paths (i, k) ks) = path_weight m (i, k) pik"
    by (meson finite_set Min_in finite_imageI paths.finite image_iff image_is_empty paths.not_empty)
  from finite ks obtain pkj
    where pkj: "pkj  paths (k, j) ks"
      and mpkj: "Min (path_weight m (k, j) ` paths (k, j) ks) = path_weight m (k, j) pkj"
    by (meson finite_set Min_in finite_imageI paths.finite image_iff image_is_empty paths.not_empty)
  let ?p = "pik @ k # pkj"
  have "?p  paths (i, j) (insert k ks)"
  proof(rule paths.I)
    from pik pkj
    show "set ?p  insert k ks" by (auto simp: paths_def)
    (* If i occurred in pkj, the part of pkj after i would be an (i, j) path over ks.  That path
       would be no heavier than pkj, which contradicts assumption m. *)
    show "i  set ?p"
    proof(rule notI)
      assume "i  set ?p"
      with m pik have "i  set pkj" by (fastforce simp: paths_def) (* m implies i ≠ k *)
      then obtain p' zs where *: "pkj = zs @ i # p'" by (meson split_list)
      moreover from pkj * have "p'  paths (i, j) ks" by (simp add: paths_def)
      moreover note m finite ks mpkj
      ultimately show False by (simp add: paths.finite leD min_path_weight_def path_weight.append trans_le_add2)
    qed
    (* Symmetrically, if j occurred in pik, the part of pik before j would be a cheap (i, j) path
       over ks. *)
    show "j  set ?p"
    proof(rule notI)
      assume "j  set ?p"
      with m pkj have "j  set pik" by (fastforce simp: paths_def) (* m implies j ≠ k *)
      then obtain p' zs where *: "pik = p' @ j # zs" by (meson split_list)
      moreover from pik * have "p'  paths (i, j) ks" by (simp add: paths_def)
      moreover note m finite ks mpik
      ultimately show False
        by (fastforce simp: min_path_weight_def path_weight.append paths.finite paths.not_empty)
    qed
    (* If the glued path repeats a node, shortcut it at the first node ?l of pik that also occurs
       in pkj.  The result ?p' is an (i, j) path over ks that is no heavier than pik plus pkj,
       which contradicts assumption m. *)
    show "distinct ?p"
    proof(rule ccontr)
      (* ?p1: the prefix of pik taken by takeWhile on membership in pkj (presumably the nodes not
         yet in pkj).  ?l: the next node of pik, shown below to lie in pkj.  ?p2: the part of pkj
         strictly after ?l. *)
      let ?p1 = "takeWhile (λx. x  set pkj) pik"
      let ?l = "hd (drop (length ?p1) pik)"
      let ?p2 = "tl (dropWhile (λx. x  ?l) pkj)"
      let ?p' = "?p1 @ ?l # ?p2"
      assume "¬distinct (pik @ k # pkj)"
      from pik pkj ¬distinct (pik @ k # pkj) have "strict_prefix ?p1 pik"
        by (auto simp: paths_def strict_prefix_def takeWhile_is_prefix)
      from pik pkj ¬distinct (pik @ k # pkj) strict_prefix ?p1 pik have "strict_suffix ?p2 pkj"
        by (fastforce simp: dropWhile_eq_drop tl_drop
                     intro: drop_strict_suffix[OF strict_suffix_tl]
                      dest: prefix_length_less nth_length_takeWhile)
      from strict_prefix ?p1 pik have "?l  set pkj"
        by (fastforce simp: hd_drop_conv_nth dest: prefix_length_less nth_length_takeWhile)
      have "?p'  paths (i, j) ks"
      proof(rule paths.I)
        from pik pkj strict_prefix ?p1 pik strict_suffix ?p2 pkj ?l  set pkj show "set ?p'  ks"
          by (force dest: set_takeWhileD strict_suffix_set_subset simp: paths_def)
        from i  set ?p strict_suffix ?p2 pkj ?l  set pkj show "i  set ?p'"
          by (auto dest: set_takeWhileD strict_suffix_set_subset)
        from j  set ?p strict_suffix ?p2 pkj ?l  set pkj show "j  set ?p'"
          by (auto dest: set_takeWhileD strict_suffix_set_subset)
        from pik pkj strict_suffix ?p2 pkj ?l  set pkj show "distinct ?p'"
          by (auto simp: paths_def distinct_tl dest!: set_takeWhileD strict_suffix_set_subset
              simp flip: arg_cong[where f=set, OF takeWhile_neq_rev, simplified])
      qed
      (* Split at ?l and compare each half with pik and pkj respectively (add_le_mono). *)
      have "path_weight m (i, j) ?p'  path_weight m (i, k) pik + path_weight m (k, j) pkj"
        unfolding path_weight.append
      proof(induct rule: add_le_mono[case_names l r])
        case l from strict_prefix ?p1 pik show ?case
          by (metis append.right_neutral append_take_drop_id fst_conv linorder_le_less_linear
                    list.collapse not_add_less1 path_weight.append prefix_order.less_le takeWhile_eq_take)
      next
        case r from ?l  set pkj show ?case
          by (smt (verit) append.right_neutral hd_dropWhile le_add2 list.collapse path_weight.append
                          set_takeWhileD snd_conv takeWhile_dropWhile_id)
      qed
      with m finite ks mpik mpkj ?p'  paths (i, j) ks show False
        by (fastforce simp: min_path_weight_def paths.finite paths.not_empty)
    qed
  qed
  with m mpik mpkj
  show "?rhs  path_weight m (i, j) ` paths (i, j) (insert k ks)"
    by (force simp: min_path_weight_def path_weight.append)
qed

(* If routing through k is not strictly cheaper, allowing k as an intermediate node leaves the
   minimum unchanged.  This justifies the no-write branch of the fw_update program. *)
lemma return:
  assumes m: "¬(min_path_weight m (i, k) ks + min_path_weight m (k, j) ks < min_path_weight m (i, j) ks)"
  assumes "finite ks"
  shows "min_path_weight m (i, j) (insert k ks) = min_path_weight m (i, j) ks"
(* By Min_eqI: finiteness, the old minimum is a lower bound of the new weights, and the old
   minimum is attained among the new paths. *)
unfolding min_path_weight_def
proof(rule Min_eqI)
  from finite ks show "finite (path_weight m (i, j) ` paths (i, j) (insert k ks))"
    by (simp add: paths.finite)
next
  fix w
  assume w: "w  path_weight m (i, j) ` paths (i, j) (insert k ks)"
  then obtain p where p: "w = path_weight m (i, j) p" "p  paths (i, j) (insert k ks)" ..
  (* Case analysis on whether the path uses k:
     - a path through k weighs at least the sum of the two sub-minima (paths.decompE), which by
       assumption m is at least the old minimum;
     - a path avoiding k is a path over ks (paths.unused). *)
  with m finite ks show "Min (path_weight m (i, j) ` paths (i, j) ks)  w"
  proof(cases "k  set p")
    case True with m finite ks w p show ?thesis
      by (auto simp: not_less min_path_weight_def path_weight.append paths.finite
              intro: order.trans[OF add_mono[OF Min_le Min_le]]
              elim!: order.trans paths.decompE)
  next
    case False with m finite ks w p show ?thesis
      by (meson Min_le finite_imageI paths.finite image_eqI paths.unused)
  qed
next
  (* The old minimiser over ks is also a path over insert k ks (paths.mono). *)
  from finite ks
  show "Min (path_weight m (i, j) ` paths (i, j) ks)  path_weight m (i, j) ` paths (i, j) (insert k ks)"
    by (fastforce simp: paths.finite paths.not_empty intro: subsetD[OF _ Min_in] subsetD[OF paths.mono])
qed

setup Sign.parent_path

setup Sign.mandatory_path "stable"

(* The loop invariant is stable under environment steps in heap.Id on {a}.  Presumably these are
   steps that leave the array a untouched. *)
lemma Id_on_fw_inv:
  shows "stable heap.Id⇘{a}(fw_inv a m ys)"
by (auto simp: fw_inv_def fw_p_inv_def intro!: stable.intro stable.impliesI)

(* The per-cell invariant is stable under environment steps in heap.Id on {a}. *)
lemma Id_on_fw_p_inv:
  shows "stable heap.Id⇘{a}(fw_p_inv a m ij ks)"
by (auto simp: fw_p_inv_def intro: stable.intro)

(* The per-cell invariant for ij is stable under steps that modify only the cells in is.  This
   requires ij to be a valid index of a that is not in is. *)
lemma modifies_fw_p_inv:
  assumes "ij  set (Array.interval a) - is"
  shows "stable Array.modifies⇘a, is(fw_p_inv a m ij ks)"
using assms by (auto simp: fw_p_inv_def intro: stable.intro)

setup Sign.parent_path

(* Congruence rule: with equal parameters, fw_p_inv depends only on the heap contents at the
   address of a. *)
lemma fw_p_inv_cong:
  assumes "a = a'"
  assumes "m = m'"
  assumes "ij = ij'"
  assumes "ks = ks'"
  assumes "s (heap.addr_of a) = s' (heap.addr_of a')"
  shows "fw_p_inv a m ij ks s = fw_p_inv a' m' ij' ks' s'"
using assms by (simp add: fw_p_inv_def cong: heap.obj_at.cong Array.get.weak_cong)

(* Destruction rules for the two conjuncts of fw_p_inv. *)
lemma fw_p_invD:
  assumes "fw_p_inv a m ij ks s"
  shows "heap.rep_inv a s"
    and "Array.get a ij s = min_path_weight m ij ks"
using assms unfolding fw_p_inv_def by blast+

(* When the path through k is strictly cheaper, writing the sum of the two sub-minima into cell
   ij re-establishes the per-cell invariant for insert k ks.  This follows from
   min_path_weight.fw_update. *)
lemma fw_p_inv_fw_update:
  assumes "finite ks"
  assumes "ij  set (Array.interval a)"
  assumes "fw_p_inv a m ij ks s"
  assumes "min_path_weight m (fst ij, k) ks + min_path_weight m (k, snd ij) ks < min_path_weight m ij ks"
  shows "fw_p_inv a m ij (insert k ks) (Array.set a ij (min_path_weight m (fst ij, k) ks + min_path_weight m (k, snd ij) ks) s)"
using assms by (cases ij) (simp add: fw_p_inv_def Array.simps' min_path_weight.fw_update)

(* When the path through k is not strictly cheaper, the unchanged heap already satisfies the
   per-cell invariant for insert k ks.  This follows from min_path_weight.return. *)
lemma fw_p_inv_return:
  assumes "finite ks"
  assumes "fw_p_inv a m ij ks s"
  assumes "¬(min_path_weight m (fst ij, k) ks + min_path_weight m (k, snd ij) ks < min_path_weight m ij ks)"
  shows "fw_p_inv a m ij (insert k ks) s"
using assms by (cases ij) (simp add: fw_p_inv_def min_path_weight.return)

setup Sign.mandatory_path "ag"

text‹ \<^citet>‹‹p109› in "Dingel:2000"› identifies the key intuition: when processing index ‹k›, neither ‹a[i, k]› nor ‹a[k, j]› changes.
 \<^item> his argument for this is bogus: it is enough to observe that shortest paths never get shorter by adding edges
 \<^item> he unnecessarily assumes that ‹δ(i, i) = 0› for all ‹i›
›

(* Assume/guarantee specification of a single cell update fw_update a ij k:
   - pre: the per-cell invariants for ks hold at ij, (fst ij, k) and (k, snd ij);
   - rely A: heap.Id on {a} combined with the guarantees G of all other cells (presumably their
     union);
   - guarantee G ij: at most cell ij of a is modified, and only under a condition on whether k is
     a coordinate of ij (presumably when it is not, so that a[i, k] and a[k, j] stay fixed);
   - post: the per-cell invariant at ij holds for insert k ks. *)
lemma fw_update:
  assumes "insert k ks  set (Ix.interval (fst_bounds (array.bounds a)))"
  assumes "Array.square a"
  assumes ij: "ij  set (Array.interval a)"
  defines "ij. G ij  Array.modifies⇘a, {ij |_::unit. k  {fst ij, snd ij}}⇙"
  defines "A  heap.Id⇘{a}  (G ` (set (Array.interval a) - {ij}))"
  shows "prog.p2s (fw_update a ij k)
           fw_p_inv a m ij ks  fw_p_inv a m (fst ij, k) ks  fw_p_inv a m (k, snd ij) ks⦄, A
            G ij, λ_. fw_p_inv a m ij (insert k ks)"
proof -
  from assms(1) have "finite ks"
    using finite_subset by auto
  (* The cells read besides ij are valid indices too, because a is square. *)
  from assms(1-3) have ijk: "(fst ij, k)  set (Array.interval a)" "(k, snd ij)  set (Array.interval a)"
    by (auto simp: Ix.square_def interval_prod_def)
  show ?thesis
(* Unfold the program; the three reads become nested binds around the conditional update. *)
apply (simp add: fw_update_def split_def)
apply (rule ag.pre_pre)
 apply (rule ag.prog.bind)+
    (* The whenM conditional. *)
    apply (rule ag.prog.if)
    (* Then-branch: the write to cell ij as an atomic action, with the three read values recorded
       in its precondition.  Subgoals: postcondition, guarantee, and two stability obligations. *)
    apply (rename_tac vij vik vkj)
    apply (subst prog.Array.upd_def)
    apply (rule_tac P="λs. fw_p_inv a m ij ks s  fw_p_inv a m (fst ij, k) ks s  fw_p_inv a m (k, snd ij) ks s
                          vij = Array.get a ij s  vik = Array.get a (fst ij, k) s  vkj = Array.get a (k, snd ij) s"
                in ag.prog.action)
        apply (clarsimp simp: finite ks fw_p_invD(2) fw_p_inv_fw_update ij; fail)
       using ij apply (fastforce simp: G_def intro: Array.modifies.Array_set dest: fw_p_invD(1))
      using ij ijk apply (fastforce simp: A_def G_def intro: stable.intro stable.Id_on_fw_p_inv stable.modifies_fw_p_inv)
     using ij ijk apply (fastforce simp: A_def G_def intro: stable.intro stable.Id_on_fw_p_inv stable.modifies_fw_p_inv)
    (* Else-branch: nothing is written; the recorded read values are carried to the
       postcondition. *)
    apply (rename_tac vij vik vkj)
    apply (rule_tac Q="λ_ s. vij = Array.get a ij s  vik = Array.get a (fst ij, k) s  vkj = Array.get a (k, snd ij) s"
                 in ag.augment_post)
    apply (rule ag.prog.return)
    using ij ijk apply (fastforce simp: A_def G_def intro: stable.intro stable.Id_on_fw_p_inv stable.modifies_fw_p_inv)
   (* The read of a[k, j]: two values, vij and vik, are already bound.  Its postcondition is
      weakened to the conditional's precondition; fw_p_inv_return is used here. *)
   apply (rename_tac vij vik)
   apply (rule_tac Q="λv s. fw_p_inv a m ij ks s  fw_p_inv a m (fst ij, k) ks s  fw_p_inv a m (k, snd ij) ks s
                           vij = Array.get a ij s  vik = Array.get a (fst ij, k) s  v = Array.get a (k, snd ij) s"
                in ag.post_imp)
    apply (force simp: finite ks fw_p_invD(2) fw_p_inv_return)
   apply (subst prog.Array.nth_def)
   apply (rule ag.prog.action)
      apply (clarsimp split del: if_split; assumption)
     apply fast
    using ij ijk apply (fastforce simp: A_def G_def intro: stable.intro stable.Id_on_fw_p_inv stable.modifies_fw_p_inv)
   using ij ijk apply (fastforce simp: A_def G_def intro: stable.intro stable.Id_on_fw_p_inv stable.modifies_fw_p_inv)
  (* The read of a[i, k]. *)
  apply (subst prog.Array.nth_def)
  apply (rule ag.prog.action)
     apply (clarsimp; assumption)
    apply fast
   using ij ijk apply (fastforce simp: A_def G_def intro: stable.intro stable.Id_on_fw_p_inv stable.modifies_fw_p_inv)
  using ij ijk apply (fastforce simp: A_def G_def intro: stable.intro stable.Id_on_fw_p_inv stable.modifies_fw_p_inv)
 (* The read of a[i, j]. *)
 apply (subst prog.Array.nth_def)
 apply (rule ag.prog.action)
    apply (clarsimp; assumption)
   apply fast
  using ij ijk apply (fastforce simp: A_def G_def intro: stable.intro stable.Id_on_fw_p_inv stable.modifies_fw_p_inv)
 using ij ijk apply (fastforce simp: A_def G_def intro: stable.intro stable.Id_on_fw_p_inv stable.modifies_fw_p_inv)
(* Presumably the precondition implication left by ag.pre_pre. *)
apply blast
done
qed

(* Whole-program assume/guarantee specification:
   - starting from fw_pre, fw_chaotic a establishes fw_post;
   - it relies on environment steps in heap.Id on {a};
   - it guarantees heap.modifies on {a}.
   Proof structure:
   - fw_inv a m is the invariant of the outer fst_app_chaotic iteration;
   - each iteration is discharged by ag.prog.Parallel over ag.fw_update instances;
   - the initial invariant follows from fw_pre (using paths.empty);
   - the final invariant yields fw_post. *)
lemma fw_chaotic:
  fixes a :: "('i::Ix × 'i, nat) array"
  fixes m :: "'i matrix"
  shows "prog.p2s (fw_chaotic a)  fw_pre a m⦄, heap.Id⇘{a} heap.modifies⇘{a}, fw_post a m"
unfolding fw_chaotic_def fw_pre_def
apply (simp add: prog.p2s.simps Let_def split_def)
apply (rule ag.gen_asm)
apply (rule ag.pre_pre_post)
  (* Outer iteration, with the loop invariant fw_inv a m. *)
  apply (rule ag.prog.fst_app_chaotic[where P="fw_inv a m"])
   apply (rule ag.pre)
       (* One iteration: all cell updates for the current k, run in parallel. *)
       apply (rule ag.prog.Parallel)
       apply (rule ag.fw_update[where m=m])
         apply (simp; fail)
        apply (simp; fail)
       apply (simp; fail)
      apply (fastforce simp: fw_inv_def split_def Ix.prod.interval_conv Ix.square.conv)
     apply blast
    using Array.modifies.heap_modifies_le apply blast
   apply (simp add: fw_inv_def; fail)
  apply (simp add: stable.Id_on_fw_inv; fail)
 (* Initial invariant from fw_pre. *)
 apply (fastforce simp: fw_pre_def fw_inv_def fw_p_inv_def min_path_weight_def paths.empty)
(* Final invariant gives fw_post. *)
apply (fastforce simp: fw_post_def split_def stable.Id_on_fw_inv)
done

setup Sign.parent_path
(*<*)

end
(*>*)