(* Theory Kruskal_Impl *)

section "Kruskal Implementation"

theory Kruskal_Impl
imports Kruskal_Refine Refine_Imperative_HOL.IICF
begin

subsection ‹Refinement III: concrete edges›

text ‹Given a concrete representation of edges and their endpoints as a pair, we refine
  Kruskal's algorithm to work on these concrete edges.›

(* Kruskal_concrete: Kruskal's algorithm over a concrete edge type 'cedge.
   α maps a concrete edge to its abstract edge, and endpoints computes the two
   endpoints of a concrete edge; it is required to refine a_endpoints of the
   abstracted edge (w.r.t. Id). *)
locale Kruskal_concrete = Kruskal_interface E V vertices  joins  forest connected  weight
  for E V vertices joins forest connected  and weight :: "'edge ⇒ int"  +
  fixes
    α :: "'cedge ⇒ 'edge"
    and endpoints :: "'cedge ⇒ ('a*'a) nres"
  assumes  
    endpoints_refine: "α xi = x ⟹ endpoints xi ≤ ⇓ Id (a_endpoints x)"
begin

(* A list of concrete edges is sorted when it is ordered by the weight of the
   abstracted edges. *)
definition wsorted' where "wsorted' == sorted_wrt (λx y. weight (α x) ≤ weight (α y))"

lemma wsorted_mapα[simp]: "wsorted' s ⟹ wsorted (map α s)"
  by(auto simp: wsorted'_def sorted_wrt_map) 

(* Nondeterministically obtain a weight-sorted list of concrete edges whose
   α-image is exactly E. *)
definition "obtain_sorted_carrier' == SPEC (λL. wsorted' L ∧ α ` set L = E)"

(* Relates a concrete edge to its abstraction α; no additional invariant. *)
abbreviation concrete_edge_rel :: "('cedge  × 'edge) set" where
  "concrete_edge_rel ≡ br α (λ_. True)"

lemma obtain_sorted_carrier'_refine:
  "(obtain_sorted_carrier', obtain_sorted_carrier) ∈ ⟨⟨concrete_edge_rel⟩list_rel⟩nres_rel"
  unfolding obtain_sorted_carrier'_def obtain_sorted_carrier_def 
  apply refine_vcg
  apply (auto intro!: RES_refine simp:     )
  subgoal for s apply(rule exI[where x="map α s"])
    by(auto simp: map_in_list_rel_conv in_br_conv)  
  done

(* kruskal2: iterate over the sorted concrete edges. For each edge ce, take its
   endpoints (a,b). If they are not yet related in the union-find PER uf, merge
   them and append ce to the forest T; otherwise keep (uf,T) unchanged.
   The result is the accumulated list of concrete edges. *)
definition kruskal2
  where "kruskal2 ≡ do {
    l ← obtain_sorted_carrier';
    let initial_union_find = per_init V;
    (per, spanning_forest) ← nfoldli l (λ_. True)
        (λce (uf, T). do { 
            ASSERT (α ce ∈ E);
            (a,b) ← endpoints ce;
            ASSERT (a∈V ∧ b∈V ∧ a ∈ Domain uf ∧ b ∈ Domain uf );
            if ¬ per_compare uf a b then
              do { 
                let uf = per_union uf a b;
                ASSERT (ce ∉ set T);
                RETURN (uf, T@[ce])
              }
            else 
              RETURN (uf,T)
        }) (initial_union_find, []);
        RETURN spanning_forest
      }"

(* The empty list of concrete edges corresponds to the empty abstract edge set. *)
lemma lst_graph_rel_empty[simp]: "([], {}) ∈ ⟨concrete_edge_rel⟩list_set_rel"
  unfolding list_set_rel_def apply(rule relcompI[where b="[]"])
  by (auto simp add: in_br_conv)

(* Initial loop states of kruskal2 and kruskal1 are related. *)
lemma loop_initial_rel:
  "((per_init V, []), per_init V, {}) ∈ Id ×⇩r ⟨concrete_edge_rel⟩list_set_rel"
  by simp

(* A concrete edge list related to an abstract edge set has that set as α-image. *)
lemma concrete_edge_rel_list_set_rel:
  "(a, b) ∈ ⟨concrete_edge_rel⟩list_set_rel ⟹ α ` (set a) = b"  
  by (auto simp: in_br_conv list_set_rel_def dest: list_relD2)

(* kruskal2 refines kruskal1: the resulting concrete edge list represents the
   abstract edge set computed by kruskal1. *)
theorem kruskal2_refine: "(kruskal2, kruskal1)∈⟨⟨concrete_edge_rel⟩list_set_rel⟩nres_rel"
  unfolding kruskal1_def kruskal2_def Let_def 
  apply (refine_rcg obtain_sorted_carrier'_refine[THEN nres_relD]
                    endpoints_refine loop_initial_rel)     
  by (auto intro!: list_set_rel_append
            dest: concrete_edge_rel_list_set_rel
            simp: in_br_conv)

end

subsection ‹Refinement to Imperative/HOL with Sepref-Tool›

text ‹Given implementations of the operations that obtain a list of concrete edges
  and that compute the endpoints of a concrete edge, we synthesize Kruskal's algorithm in Imperative/HOL.›

(* Kruskal_Impl: concrete edges are triples (a, w, b) :: nat × int × nat.
   Assumptions:
   - getEdges returns a list whose α-image is E, in which the middle component
     of each triple is the weight of its abstract edge, and which lies in superE;
   - getEdges_impl and endpoints_impl implement getEdges and endpoints
     (stated as sepref refinement rules);
   - for any edge list la with α-image E, max_node la = Max (insert 0 V). *)
locale Kruskal_Impl = Kruskal_concrete E V vertices joins forest connected weight α endpoints                       
  for E V vertices joins forest connected and weight :: "'edge ⇒ int"
    and α and endpoints :: "nat × int × nat ⇒ (nat × nat) nres"
    +
  fixes getEdges  :: "(nat × int × nat) list nres"
    and getEdges_impl :: "(nat × int × nat) list Heap"         
    and superE :: "(nat × int × nat) set"         
    and endpoints_impl :: "(nat × int × nat) ⇒ (nat × nat) Heap"                    
  assumes 
    getEdges_refine: "getEdges ≤ SPEC (λL. α ` set L = E
                            ∧ (∀(a,wv,b)∈set L.  weight (α (a,wv,b)) = wv) ∧ set L ⊆ superE)"
    and
    getEdges_impl: "(uncurry0 getEdges_impl, uncurry0 getEdges)
                     ∈ unit_assn⇧k →⇩a list_assn (nat_assn ×⇩a int_assn ×⇩a nat_assn)"
    and 
    max_node_is_Max_V: "E = α ` set la ⟹ max_node la = Max (insert 0 V)"
    and
    endpoints_impl: "( endpoints_impl,  endpoints)
                       ∈ (nat_assn ×⇩a int_assn ×⇩a nat_assn)⇧k →⇩a (nat_assn ×⇩a nat_assn)"
begin
  
  (* The locale assumptions hold for the fixed parameters; used below to
     instantiate kruskal.refine. *)
  lemma this_loc: "Kruskal_Impl E V vertices  joins  forest connected  weight
            α endpoints getEdges getEdges_impl superE    endpoints_impl" by unfold_locales
  
  
  subsubsection ‹Refinement IV: given an edge set›

  text ‹We now assume that we have an implementation of the operation that obtains a list of the
    edges of a graph. By sorting this list we refine @{term obtain_sorted_carrier'}.›

  (* Obtain an edge list as specified by getEdges_refine, then any list sorted
     by edges_less_eq with the same elements. *)
  definition "obtain_sorted_carrier'' = do {
      l ← SPEC (λL.  α ` set L = E
                               ∧ (∀(a,wv,b)∈set L.  weight (α (a,wv,b)) = wv) ∧ set L ⊆ superE);
      SPEC (λL. sorted_wrt edges_less_eq L ∧ set L = set l)
  }"
  
  (* If every triple carries the weight of its abstract edge, sorting by
     edges_less_eq implies sorting by weight. *)
  lemma wsorted'_sorted_wrt_edges_less_eq:
    assumes "∀(a,wv,b)∈set s.  weight (α (a,wv,b)) = wv"
        "sorted_wrt edges_less_eq s"
    shows "wsorted' s"
    using assms apply -
    unfolding wsorted'_def   unfolding edges_less_eq_def
    apply(rule sorted_wrt_mono_rel )    
    by (auto simp: case_prod_beta)
  
  lemma obtain_sorted_carrier''_refine:
    "(obtain_sorted_carrier'', obtain_sorted_carrier') ∈ ⟨Id⟩nres_rel"
    unfolding obtain_sorted_carrier''_def obtain_sorted_carrier'_def
    apply refine_vcg
     apply(auto simp: in_br_conv   wsorted'_sorted_wrt_edges_less_eq
        distinct_map map_in_list_rel_conv)  
    done

  (* Fetch the edges, sort them with quicksort_by_rel edges_less_eq, and also
     return max_node of the fetched list (kruskal3 uses it to size the
     union-find structure). *)
  definition "obtain_sorted_carrier''' =
        do {
      l ← getEdges; 
      RETURN (quicksort_by_rel edges_less_eq [] l, max_node l)
  }" 
    
  (* Relates (list, n) to the list itself, requiring n = Max (insert 0 V). *)
  definition "add_size_rel   = br fst (λ(l,n). n= Max (insert 0 V))"    
  
  lemma obtain_sorted_carrier'''_refine:
    "(obtain_sorted_carrier''', obtain_sorted_carrier'') ∈ ⟨add_size_rel⟩nres_rel"
    unfolding obtain_sorted_carrier'''_def obtain_sorted_carrier''_def
    apply (refine_rcg getEdges_refine) 
    by (auto intro!: RETURN_SPEC_refine simp: quicksort_by_rel_distinct sort_edges_correct
        add_size_rel_def in_br_conv  max_node_is_Max_V
        dest!: distinct_mapI)

  lemmas osc_refine =  obtain_sorted_carrier'''_refine[FCOMP obtain_sorted_carrier''_refine,
                                                        to_foparam, simplified]

  (* kruskal3: same loop as kruskal2, but the union-find structure is created
     with per_init' (mn + 1) from the computed max node, and the loop assertion
     only requires the endpoints to be in Domain uf. *)
  definition kruskal3 :: "(nat × int × nat) list nres"
    where "kruskal3  ≡ do {
      (sl,mn) ← obtain_sorted_carrier''';
      let initial_union_find = per_init' (mn + 1);
      (per, spanning_forest) ← nfoldli sl (λ_. True)
          (λce (uf, T). do { 
              ASSERT (α ce ∈ E);
              (a,b) ← endpoints ce;
              ASSERT (a ∈ Domain uf ∧ b ∈ Domain uf);
              if ¬ per_compare uf a b then
                do { 
                  let uf = per_union uf a b;
                  ASSERT (ce∉set T);
                  RETURN (uf, T@[ce])
                }
              else 
                RETURN (uf,T)
          }) (initial_union_find, []);
          RETURN spanning_forest
        }"
 
  lemma endpoints_spec: "endpoints ce ≤ SPEC (λ_. True)"
    by(rule order.trans[OF endpoints_refine], auto)    

  (* Stated with ≤⇩n: the result of kruskal3 is a distinct list of edges
     contained in superE. *)
  lemma  kruskal3_subset:
    shows "kruskal3 ≤⇩n SPEC (λT. distinct T ∧ set T ⊆ superE )" 
    unfolding kruskal3_def obtain_sorted_carrier'''_def
    apply (refine_vcg getEdges_refine[THEN leof_lift] endpoints_spec[THEN leof_lift]
        nfoldli_leof_rule[where I="λ_ _ (_, T). distinct T ∧ set T ⊆ superE "])
             apply auto 
    subgoal   
      by (metis append_self_conv in_set_conv_decomp set_quicksort_by_rel subset_iff)  
    done


  (* p1 restricted to Domain p2 × Domain p2 is exactly p2; outside that square
     p1 contains only reflexive pairs. *)
  definition per_supset_rel :: "('a per × 'a per) set" where
    "per_supset_rel
      ≡ {(p1,p2). p1 ∩ Domain p2 × Domain p2 = p2 ∧ p1 - (Domain p2 × Domain p2) ⊆ Id}"

  lemma per_supset_rel_dom: "(p1, p2) ∈ per_supset_rel ⟹ Domain p1 ⊇ Domain p2"
    by (auto simp: per_supset_rel_def)
  
  lemma per_supset_compare:
    "(p1, p2) ∈ per_supset_rel ⟹ x1∈Domain p2 ⟹ x2∈Domain p2
       ⟹ per_compare p1 x1 x2 ⟷ per_compare p2 x1 x2"
    by (auto simp: per_supset_rel_def)
  
  lemma per_supset_union: "(p1, p2) ∈ per_supset_rel ⟹ x1∈Domain p2 ⟹ x2∈Domain p2 ⟹
    (per_union p1 x1 x2, per_union p2 x1 x2) ∈ per_supset_rel"
    apply (clarsimp simp: per_supset_rel_def per_union_def Domain_unfold )
    apply (intro subsetI conjI)
     apply blast
    apply force
    done

  lemma per_initN_refine: "(per_init' (Max (insert 0 V) + 1), per_init V) ∈ per_supset_rel"
    unfolding per_supset_rel_def per_init'_def per_init_def max_node_def
    by (auto simp: less_Suc_eq_le  ) 

  (* kruskal3 refines kruskal2; the union-find states are related by
     per_supset_rel (supplied to the refinement tactic via RELATESI). *)
  theorem kruskal3_refine: "(kruskal3, kruskal2)∈⟨Id⟩nres_rel"
    unfolding kruskal2_def kruskal3_def Let_def
    apply (refine_rcg osc_refine[THEN nres_relD]   )
               supply RELATESI[where R="per_supset_rel::(nat per × _) set", refine_dref_RELATES]
               apply refine_dref_type 
    subgoal by (simp add: add_size_rel_def in_br_conv)
    subgoal using per_initN_refine by (simp add: add_size_rel_def in_br_conv)
    by (auto simp add: add_size_rel_def in_br_conv per_supset_compare per_supset_union
        dest: per_supset_rel_dom
        simp del: per_compare_def ) 


  subsubsection ‹Synthesis of Kruskal by SepRef›


  lemma [sepref_import_param]: "(sort_edges,sort_edges)∈⟨Id×⇩rId×⇩rId⟩list_rel → ⟨Id×⇩rId×⇩rId⟩list_rel"
    by simp
  lemma [sepref_import_param]: "(max_node, max_node) ∈ ⟨Id×⇩rId×⇩rId⟩list_rel → nat_rel" by simp

  sepref_register "getEdges" :: "(nat × int × nat) list nres"
  sepref_register "endpoints" :: "(nat × int × nat) ⇒ (nat*nat) nres"
 
 
  declare getEdges_impl [sepref_fr_rules]
  declare endpoints_impl [sepref_fr_rules]
  
  (* Synthesize an Imperative/HOL implementation ?c of kruskal3 with sepref;
     the empty accumulator of the fold is given a custom HOL-list representation. *)
  schematic_goal kruskal_impl:
    "(uncurry0 ?c, uncurry0 kruskal3 ) ∈ (unit_assn)⇧k →⇩a list_assn (nat_assn ×⇩a int_assn ×⇩a nat_assn)"
    unfolding kruskal3_def obtain_sorted_carrier'''_def 
    unfolding sort_edges_def[symmetric]
    apply (rewrite at "nfoldli _ _ _ (_,rewrite_HOLE)" HOL_list.fold_custom_empty)
    by sepref 

  concrete_definition (in -) kruskal uses Kruskal_Impl.kruskal_impl
  prepare_code_thms (in -) kruskal_def
  lemmas kruskal_refine = kruskal.refine[OF this_loc]

  
  
  abbreviation "MSF == minBasis"   
  abbreviation "SpanningForest == basis"
  lemmas SpanningForest_def = basis_def
  lemmas MSF_def = minBasis_def 
  
  (* Compose the refinement chain kruskal3 → kruskal2 → kruskal1 → kruskal0 →
     minimum weight basis. *)
  lemmas kruskal3_ref_spec_ = kruskal3_refine[FCOMP kruskal2_refine, FCOMP kruskal1_refine,
      FCOMP kruskal0_refine,
      FCOMP minWeightBasis_refine] 
  
  lemma kruskal3_ref_spec':
    "(uncurry0 kruskal3, uncurry0 (SPEC (λr. MSF (α ` set r)))) ∈ unit_rel →⇩f ⟨Id⟩nres_rel" 
    unfolding fref_def 
    apply auto
    apply(rule nres_relI) 
    apply(rule order.trans[OF  kruskal3_ref_spec_[unfolded fref_def, simplified,  THEN nres_relD]])
    by (auto simp: conc_fun_def list_set_rel_def in_br_conv dest!: list_relD2) 
  
  (* Combine the MSF specification with kruskal3_subset. *)
  lemma kruskal3_ref_spec:
   "(uncurry0 kruskal3,
      uncurry0 (SPEC (λr. distinct r ∧ set r ⊆ superE ∧ MSF (α ` set r))))
      ∈ unit_rel →⇩f ⟨Id⟩nres_rel"
    unfolding fref_def 
    apply auto
    apply(rule nres_relI) 
    apply simp
    using SPEC_rule_conj_leofI2[OF kruskal3_subset kruskal3_ref_spec'
              [unfolded fref_def, simplified,  THEN nres_relD, simplified]]
    by simp
  
  (* The list assertion over pure components is the identity assertion. *)
  lemma [fcomp_norm_simps]: "list_assn (nat_assn ×⇩a int_assn ×⇩a nat_assn) = id_assn"
    by (auto simp: list_assn_pure_conv)
  
  lemmas kruskal_ref_spec = kruskal_refine[FCOMP kruskal3_ref_spec]
  

  text ‹The final correctness lemma for Kruskal's algorithm. ›

  lemma kruskal_correct_forest:
    shows "<emp> kruskal getEdges_impl endpoints_impl ()
             <λr. ↑( distinct r ∧ set r ⊆ superE ∧ MSF (set (map α r)))>⇩t"
  proof -
    show ?thesis
      using kruskal_ref_spec[to_hnr]
      unfolding hn_refine_def  
      apply clarsimp
      apply (erule cons_post_rule)
      by (sep_auto simp: hn_ctxt_def pure_def list_set_rel_def in_br_conv dest: list_relD)     
  qed                            

end ― ‹locale @{text Kruskal_Impl}›

end