Theory Sepref_Dijkstra

section ‹Imperative Implementation of Dijkstra's Shortest Paths Algorithm›
theory Sepref_Dijkstra
imports 
  "../IICF/IICF"
  "../Sepref_ICF_Bindings"
  Dijkstra_Shortest_Path.Dijkstra
  Dijkstra_Shortest_Path.Test
  "HOL-Library.Code_Target_Numeral"
  (*"../../../DFS_Framework/Misc/DFS_Framework_Refine_Aux"*)
  Sepref_WGraph
begin


(* Setup for Infty *)

(* Heap instance for infty over any heap element type.
   The instance proof exhibits an injective encoding into nat:
   Infty is sent to 0 and Num a to to_nat a + 1, so the two
   constructors never collide. *)
instantiation infty :: (heap) heap
begin
  instance
    apply standard
    apply (rule_tac x="λInfty ⇒ 0 | Num a ⇒ to_nat a + 1" in exI)
    apply (rule injI)
    apply (auto split: infty.splits)
    done
end

(* Lifts an assertion A on weights to an assertion on infty values:
   - two Num values are related exactly when A relates their payloads,
   - Infty is related to Infty with the empty heap (emp),
   - every mixed constructor combination is unsatisfiable (false).
   Clause order matters: the catch-all must stay last. *)
fun infty_assn where
  "infty_assn A (Num x) (Num y) = A x y"
| "infty_assn A Infty Infty = emp"
| "infty_assn _ _ _ = false"

text ‹Connection with ‹infty_rel››
(* On a pure assertion, infty_assn is the pure assertion of the lifted
   relation ⟨A⟩infty_rel. Proved pointwise by case distinction on both
   the abstract and the concrete value. *)
lemma infty_assn_pure_conv: "infty_assn (pure A) = pure (⟨A⟩infty_rel)"
  apply (intro ext)
  subgoal for x y by (cases x; cases y; simp add: pure_def)
  done

(* Register infty_assn_pure_conv in both directions:
   - right-to-left (symmetric) for the import-rewrite, fcomp
     normalisation and frame normalisation rule sets, so that
     pure (<A>infty_rel) is folded back into infty_assn (pure A);
   - left-to-right as a constraint simplification rule. *)
lemmas [sepref_import_rewrite, fcomp_norm_unfold, sepref_frame_normrel_eqs] =
  infty_assn_pure_conv[symmetric]
lemmas [constraint_simps] = infty_assn_pure_conv

(* Purity is preserved by infty_assn: if A is pure, so is infty_assn A. *)
lemma infty_assn_pure[safe_constraint_rules]: "is_pure A ⟹ is_pure (infty_assn A)"
  by (auto simp: is_pure_conv infty_assn_pure_conv)

(* Lifting the identity assertion yields the identity assertion again;
   follows from infty_assn_pure_conv by simplification. *)
lemma infty_assn_id[simp]: "infty_assn id_assn = id_assn"
  by (simp add: infty_assn_pure_conv)

(* IS_BELOW_ID is preserved when a relation is lifted to infty_rel. *)
lemma [safe_constraint_rules]: "IS_BELOW_ID R ⟹ IS_BELOW_ID (⟨R⟩infty_rel)"
  by (auto simp: infty_rel_def IS_BELOW_ID_def)

(* Make both constructors of infty known to Sepref as operations. *)
sepref_register Num Infty

(* Refinement rule for the Num constructor: wrapping a value refined by A
   gives a value refined by infty_assn A. The argument carries the ⇧d
   annotation, i.e. it is not kept available after the call. *)
lemma Num_hnr[sepref_fr_rules]: "(return o Num,RETURN o Num)∈A⇧d →⇩a infty_assn A"
  by sepref_to_hoare sep_auto

(* Refinement rule for the constant Infty, phrased as a nullary operation
   (uncurry0) over a unit argument. *)
lemma Infty_hnr[sepref_fr_rules]: "(uncurry0 (return Infty),uncurry0 (RETURN Infty))∈unit_assn⇧k →⇩a infty_assn A"
  by sepref_to_hoare sep_auto

(* Monadify setup for the case combinator of infty.
   - arity rule: case_infty takes three arguments, and its Num branch
     binds one variable (hence the protected abstraction λ⇩2);
   - first comb rule: evaluate the scrutinee before the case distinction;
   - second comb rule: the same under EVAL, pushing EVAL into both
     branches. *)
sepref_register case_infty
lemma [sepref_monadify_arity]: "case_infty ≡ λ⇩2f1 f2 x. SP case_infty$f1$(λ⇩2x. f2$x)$x"
  by simp
lemma [sepref_monadify_comb]: "case_infty$f1$f2$x ≡ (⤜)$(EVAL$x)$(λ⇩2x. SP case_infty$f1$f2$x)" by simp
lemma [sepref_monadify_comb]: "EVAL$(case_infty$f1$(λ⇩2x. f2 x)$x)
  ≡ (⤜)$(EVAL$x)$(λ⇩2x. SP case_infty$(EVAL $ f1)$(λ⇩2x. EVAL $ f2 x)$x)"
  apply (rule eq_reflection)
  by (simp split: infty.splits)

(* Transfers an equation about infty_assn to the same equation wrapped in
   hn_ctxt. Used below to turn infty_assn.simps into hn_ctxt-level
   simplification rules. *)
lemma infty_assn_ctxt: "infty_assn A x y = z ⟹ hn_ctxt (infty_assn A) x y = z"
  by (simp add: hn_ctxt_def)

(* Refinement rule for case distinction on an infty value.
   Assumptions:
     FR     - the scrutinee e/e' can be framed out of the context Γ;
     Infty  - refinement of the Infty branch;
     Num    - refinement of the Num branch, with the payload available
              under A and the whole scrutinee marked invalid (INVe);
     MERGE2 - the post-contexts Γ1' and Γ2' of both branches merge into Γ'.
   Conclusion: the concrete case expression refines the abstract one, and
   the scrutinee is afterwards described by infty_assn A'. *)
lemma infty_cases_hnr[sepref_prep_comb_rule, sepref_comb_rules]:
  fixes A e e'
  defines [simp]: "INVe ≡ hn_invalid (infty_assn A) e e'"
  assumes FR: "Γ ⟹⇩t hn_ctxt (infty_assn A) e e' * F"
  assumes Infty: "⟦e = Infty; e' = Infty⟧ ⟹ hn_refine (hn_ctxt (infty_assn A) e e' * F) f1' (hn_ctxt XX1 e e' * Γ1') R f1"
  assumes Num: "⋀x1 x1a. ⟦e = Num x1; e' = Num x1a⟧ ⟹ hn_refine (hn_ctxt A x1 x1a * INVe * F) (f2' x1a) (hn_ctxt A' x1 x1a * hn_ctxt XX2 e e' * Γ2') R (f2 x1)"
  assumes MERGE2[unfolded hn_ctxt_def]: "Γ1' ∨⇩A Γ2' ⟹⇩t Γ'"
  shows "hn_refine Γ (case_infty f1' f2' e') (hn_ctxt (infty_assn A') e e' * Γ') R (case_infty$f1$(λ⇩2x. f2 x)$e)"
  apply (rule hn_refine_cons_pre[OF FR])
  apply1 extract_hnr_invalids
  (* Split on both abstract and concrete scrutinee; the mixed cases are
     discharged by the simplifier since infty_assn is false there. *)
  apply (cases e; cases e'; simp add: infty_assn.simps[THEN infty_assn_ctxt])
  subgoal
    (* Infty / Infty *)
    apply (rule hn_refine_cons[OF _ Infty _ entt_refl]; assumption?)
    applyS (simp add: hn_ctxt_def)
    apply (subst mult.commute, rule entt_fr_drop)
    apply (rule entt_trans[OF _ MERGE2])
    apply (simp add:)
  done
  subgoal
    (* Num / Num *)
    apply (rule hn_refine_cons[OF _ Num _ entt_refl]; assumption?)
    applyS (simp add: hn_ctxt_def)
    apply (rule entt_star_mono)
    apply1 (rule entt_fr_drop)
    applyS (simp add: hn_ctxt_def)
    apply1 (rule entt_trans[OF _ MERGE2])
    applyS (simp add:)
  done
  done
  
(* Refinement rule for Weight.val, the projection out of Num.
   Only applicable under the precondition that the argument is not Infty. *)
lemma hnr_val[sepref_fr_rules]: "(return o Weight.val,RETURN o Weight.val) ∈ [λx. x≠Infty]⇩a (infty_assn A)⇧d → A"
  apply sepref_to_hoare
  subgoal for x y by (cases x; cases y; sep_auto)
  done

(* Addition on infty, synthesised generically.
   The context assumes an implementation plusi of (+) on the underlying
   weights w.r.t. assertion A (as a GEN_ALGO side condition). Sepref then
   synthesises addition on infty from the explicit definition infty_plus,
   which is first substituted for (+) via infty_plus_eq_plus. *)
context
  fixes A :: "'a::weight ⇒ 'b ⇒ assn"
  fixes plusi
  assumes GA[unfolded GEN_ALGO_def, sepref_fr_rules]: "GEN_ALGO plusi (λf. (uncurry f,uncurry (RETURN oo (+)))∈A⇧k*⇩aA⇧k →⇩a A)"
begin
  sepref_thm infty_plus_impl is "uncurry (RETURN oo (+))" :: "((infty_assn A)⇧k *⇩a (infty_assn A)⇧k →⇩a infty_assn A)"
    unfolding infty_plus_eq_plus[symmetric] infty_plus_def[abs_def]
    by sepref
end
(* Extract the synthesised program as a constant and register its
   refinement theorem for later synthesis. *)
concrete_definition infty_plus_impl uses infty_plus_impl.refine_raw is "(uncurry ?impl,_)∈_"
lemmas [sepref_fr_rules] = infty_plus_impl.refine

(* Strict order on infty, parameterised by the strict order lt on the
   underlying weights: two Num values compare by lt, every Num is below
   Infty, and nothing is below a Num from Infty or below Infty from Infty. *)
definition infty_less where
  "infty_less lt a b ≡ case (a,b) of (Num a, Num b) ⇒ lt a b | (Num _, Infty) ⇒ True | _ ⇒ False"

(* Parametricity of infty_less in the element relation R. *)
lemma infty_less_param[param]:
  "(infty_less,infty_less) ∈ (R→R→bool_rel) → ⟨R⟩infty_rel → ⟨R⟩infty_rel → bool_rel"
  unfolding infty_less_def[abs_def]
  by parametricity

(* Instantiated with the strict order (<) on weights, infty_less is the
   strict order (<) on infty. Proved pointwise by case distinction on
   both arguments. *)
lemma infty_less_eq_less: "infty_less (<) = (<)"
  unfolding infty_less_def[abs_def] 
  apply (clarsimp intro!: ext)
  subgoal for a b by (cases a; cases b; auto)
  done

(* Strict comparison on infty, synthesised generically.
   The context assumes an implementation lessi of (<) on the underlying
   weights w.r.t. assertion A (as a GEN_ALGO side condition). Sepref then
   synthesises (<) on infty from the explicit definition infty_less,
   which is first substituted for (<) via infty_less_eq_less. *)
context
  fixes A :: "'a::weight ⇒ 'b ⇒ assn"
  fixes lessi
  assumes GA[unfolded GEN_ALGO_def, sepref_fr_rules]: "GEN_ALGO lessi (λf. (uncurry f,uncurry (RETURN oo (<)))∈A⇧k*⇩aA⇧k →⇩a bool_assn)"
begin
  sepref_thm infty_less_impl is "uncurry (RETURN oo (<))" :: "((infty_assn A)⇧k *⇩a (infty_assn A)⇧k →⇩a bool_assn)"
    unfolding infty_less_eq_less[symmetric] infty_less_def[abs_def]
    by sepref
end
(* Extract the synthesised program as a constant and register its
   refinement theorem for later synthesis. *)
concrete_definition infty_less_impl uses infty_less_impl.refine_raw is "(uncurry ?impl,_)∈_"
lemmas [sepref_fr_rules] = infty_less_impl.refine

(* Parametricity of mpath'. The argument is an optional pair of a path
   (a list of node/weight/node triples) and a weight; the result is the
   optional path. The proof first shows that mpath' is just
   map_option fst and then appeals to parametricity of that term. *)
lemma param_mpath': "(mpath',mpath')
  ∈ ⟨⟨A×⇩r B ×⇩r A⟩list_rel ×⇩r B⟩option_rel → ⟨⟨A×⇩r B ×⇩r A⟩list_rel⟩option_rel"
proof -
  have 1: "mpath' = map_option fst"
    apply (intro ext, rename_tac x)
    apply (case_tac x)
    apply simp
    apply (rename_tac a)
    apply (case_tac a)
    apply simp
    done
  show ?thesis
    unfolding 1
    by parametricity
qed
(* Make the rule available to Sepref's import of parametric constants. *)
lemmas (in -) [sepref_import_param] = param_mpath'

(* Parametricity of mpath_weight': from an optional (path, weight) pair
   to an infty-valued weight. *)
lemma param_mpath_weight':
  "(mpath_weight', mpath_weight') ∈ ⟨⟨A×⇩rB×⇩rA⟩list_rel ×⇩r B⟩option_rel → ⟨B⟩infty_rel"
  by (auto elim!: option_relE simp: infty_rel_def top_infty_def)

(* Make the rule available to Sepref's import of parametric constants. *)
lemmas [sepref_import_param] = param_mpath_weight'

context Dijkstra begin
  (* mdijkstra with its three auxiliary operations unfolded; this is the
     program text that the synthesis below starts from. *)
  lemmas impl_aux = mdijkstra_def[unfolded mdinit_def mpop_min_def mupdate_def]

  (* End-to-end correctness of mdijkstra, obtained by chaining the
     refinement steps mdijkstra -> dijkstra' -> dijkstra with the
     correctness theorem of dijkstra. The result relation maps concrete
     results to abstract ones via αr under the invariant res_invarm. *)
  lemma mdijkstra_correct:
    "(mdijkstra, SPEC (is_shortest_path_map v0)) ∈ ⟨br αr res_invarm⟩nres_rel"
  proof -
    note mdijkstra_refines
    also note dijkstra'_refines
    also note dijkstra_correct
    finally show ?thesis
      by (rule nres_relI)
  qed

end

(* Setup shared by the synthesis below. The locale only fixes a dummy
   value whose sole purpose is to pin down the weight type 'W, which must
   be both a weight and heap-storable. *)
locale Dijkstra_Impl = fixes w_dummy :: "'W::{weight,heap}"
begin
  text ‹Weights›
  sepref_register "0::'W"
  lemmas [sepref_import_param] =
    IdI[of "0::'W"]

  (* Weights are refined by identity. *)
  abbreviation "weight_assn ≡ id_assn :: 'W ⇒ _"

  (* Addition and comparison on weights are imported as identity-refined
     operations ... *)
  lemma w_plus_param: "((+), (+)::'W⇒_) ∈ Id → Id → Id" by simp
  lemma w_less_param: "((<), (<)::'W⇒_) ∈ Id → Id → Id" by simp
  lemmas [sepref_import_param] = w_plus_param w_less_param
  (* ... and also provided as GEN_ALGO instances, as required by the
     generic infty_plus_impl / infty_less_impl developments. *)
  lemma [sepref_gen_algo_rules]:
    "GEN_ALGO (return oo (+)) (λf. (uncurry f, uncurry (RETURN ∘∘ (+))) ∈ id_assn⇧k *⇩a id_assn⇧k →⇩a id_assn)"
    "GEN_ALGO (return oo (<)) (λf. (uncurry f, uncurry (RETURN ∘∘ (<))) ∈ id_assn⇧k *⇩a id_assn⇧k →⇩a id_assn)"
    by (sep_auto simp: GEN_ALGO_def pure_def intro!: hfrefI hn_refineI)+

  (* Expresses prio_pop_min through the priority-map operation
     mop_pm_pop_min (with the identity as priority function), guarded by
     an assertion that the map is non-empty. *)
  lemma conv_prio_pop_min: "prio_pop_min m = do {
      ASSERT (dom m ≠ {});
      ((k,v),m) ← mop_pm_pop_min id m;
      RETURN (k,v,m)
    }"
    unfolding prio_pop_min_def mop_pm_pop_min_def
    by (auto simp: pw_eq_iff refine_pw_simps ran_def)
end

(* Synthesis of the imperative implementation for graphs whose nodes are
   natural numbers below N and whose weights have type 'W. *)
context fixes N :: nat and w_dummy::"'W::{heap,weight}" begin

  interpretation Dijkstra_Impl w_dummy .

  (* Concrete result representation: an iam map from nodes to pairs of a
     path (list of node/weight/node triples) and its weight. *)
  definition "drmap_assn2 ≡ IICF_Sepl_Binding.iam.assn
    (pure (node_rel N))
    (prod_assn
      (list_assn (prod_assn (pure (node_rel N)) (prod_assn weight_assn (pure (node_rel N)))))
      weight_assn)
    "


  (* mdijkstra with its auxiliary operations unfolded, as a constant
     outside the Dijkstra locale. *)
  concrete_definition mdijkstra' uses Dijkstra.impl_aux

  (* Before running sepref, the abstract program is massaged:
     - prio_pop_min is expressed via mop_pm_pop_min,
     - the empty result map is annotated as an iam map,
     - the empty priority map is annotated with size bound N,
     - the empty path stored for the start node is annotated as a HOL list. *)
  sepref_definition dijkstra_imp is "uncurry mdijkstra'"
    :: "(is_graph N (Id::('W×'W) set))⇧k *⇩a (pure (node_rel N))⇧k →⇩a drmap_assn2"
    unfolding mdijkstra'_def
    apply (subst conv_prio_pop_min)
    apply (rewrite in "RETURN (_,⌑)" iam.fold_custom_empty)
    apply (rewrite hm_fold_custom_empty_sz[of N])
    apply (rewrite in "_(_ ↦ (⌑,0))" HOL_list.fold_custom_empty)
    unfolding drmap_assn2_def
    using [[id_debug, goals_limit = 1]]
    by sepref
  (* Sanity check: the synthesised program must be exportable to
     imperative SML. *)
  export_code dijkstra_imp checking SML_imp
end


text ‹The main correctness theorem›

thm Dijkstra.mdijkstra_correct

(* Abstract refinement in fref form: under the precondition that the
   Dijkstra locale holds for graph G and start node v0, mdijkstra'
   refines the specification of shortest-path maps. This is
   Dijkstra.mdijkstra_correct restated for the extracted constant. *)
lemma mdijkstra'_aref: "(uncurry mdijkstra',uncurry (SPEC oo weighted_graph.is_shortest_path_map))
  ∈ [λ(G,v0). Dijkstra G v0]⇩f Id×⇩rId → ⟨br Dijkstra.αr Dijkstra.res_invarm⟩nres_rel"
  using Dijkstra.mdijkstra_correct
  by (fastforce intro!: frefI simp: mdijkstra'.refine[symmetric])

(* Assertion relating the imperative result map to the abstract result:
   the concrete representation drmap_assn2 composed with the abstraction
   function Dijkstra.αr under the invariant Dijkstra.res_invarm. *)
definition "drmap_assn N ≡ hr_comp (drmap_assn2 N) (br Dijkstra.αr Dijkstra.res_invarm)"

(* Fold hr_comp (drmap_assn2 N) ... back into drmap_assn N when the
   synthesis result is composed with the abstract refinement. *)
context notes [fcomp_norm_unfold] = drmap_assn_def[symmetric] begin

(* Main theorem: for a start node v0 of G and non-negative edge weights,
   dijkstra_imp N refines the specification of shortest-path maps.
   The proof composes the synthesis theorem with mdijkstra'_aref and then
   shows that the stated precondition, together with the graph being
   representable by is_graph, implies the Dijkstra locale. *)
theorem dijkstra_imp_correct: "(uncurry (dijkstra_imp N), uncurry (SPEC ∘∘ weighted_graph.is_shortest_path_map))
  ∈ [λ(G, v0). v0 ∈ nodes G ∧ (∀(v, w, v') ∈ edges G. 0 ≤ w)]⇩a (is_graph N Id)⇧k *⇩a (node_assn N)⇧k → drmap_assn N"
  apply (rule hfref_weaken_pre'[OF _ dijkstra_imp.refine[FCOMP mdijkstra'_aref]])
proof clarsimp
  fix G :: "(nat,'w::{weight,heap}) graph" and v0
  assume v0_is_node: "v0 ∈ nodes G"
    and nonneg_weights: "∀(v, w, v') ∈ edges G. 0 ≤ w"
    and "v0<N"
    and RDOM: "rdomp (is_graph N Id) G"

  from RDOM interpret valid_graph G unfolding is_graph_def rdomp_def by auto

  (* Finiteness of the node set follows from representability. *)
  from RDOM have [simp]: "finite V" unfolding is_graph_def rdomp_def by auto

  (* Each successor set is the abstraction of a list, hence finite. *)
  from RDOM have "∀v∈V. {(w, v'). (v, w, v') ∈ E} ∈
    Range (⟨Id ×⇩r node_rel N⟩list_set_rel)"
    by (auto simp: succ_def is_graph_def rdomp_def)
  hence "∀v∈V. finite {(w, v'). (v, w, v') ∈ E}"
    unfolding list_set_rel_range by simp
  hence "finite (Sigma V (λv. {(w, v'). (v, w, v') ∈ E}))"
    by auto
  (* The edge set embeds into that finite dependent sum. *)
  also have "E ⊆ (Sigma V (λv. {(w, v'). (v, w, v') ∈ E}))"
    using E_valid
    by auto
  finally (finite_subset[rotated]) have [simp]: "finite E" .

  show "Dijkstra G v0"
    apply (unfold_locales)
    unfolding is_graph_def using v0_is_node nonneg_weights
    by auto
qed

end
  
(* Hoare-triple form of dijkstra_imp_correct.
   Precondition: Gi represents the graph G, v0 is a node of G, and all
   edge weights are non-negative.
   Postcondition: the graph representation is preserved, and the result
   mi represents some map m that is a shortest-path map for G from v0. *)
corollary dijkstra_imp_rule: "
  <is_graph n Id G Gi * ↑(v0 ∈ nodes G ∧ (∀(v, w, v') ∈ edges G. 0 ≤ w))>
    dijkstra_imp n Gi v0
  <λmi. (is_graph n Id) G Gi
      * (∃⇩Am. drmap_assn n m mi * ↑(weighted_graph.is_shortest_path_map G v0 m)) >⇩t"
  using dijkstra_imp_correct[to_hnr, of v0 G n v0 Gi]
  unfolding hn_refine_def
  apply (clarsimp)
  apply (erule cons_rule[rotated -1])
  apply (sep_auto simp: hn_ctxt_def pure_def is_graph_def)
  apply (sep_auto simp: hn_ctxt_def)
  done


end