Theory Challenge3

section ‹Challenge 3›
theory Challenge3
  imports Parallel_Multiset_Fold Refine_Imperative_HOL.IICF
begin

text ‹Problem definition:
🌐‹https://ethz.ch/content/dam/ethz/special-interest/infk/chair-program-method/pm/documents/Verify%20This/Challenges%202019/sparse_matrix_multiplication.pdf›
›

subsection ‹Single-Threaded Implementation›
text ‹We define type synonyms for values (which we fix to integers here) and 
  triplets, which are a pair of coordinates and a value.
›
(* Type of matrix and vector entries; fixed to integers in this theory. *)
type_synonym val = int
(* One entry of a sparse matrix in coordinate form: ((row, column), value). *)
type_synonym triplet = "(nat × nat) × val"

text ‹We fix the size ‹n› of the vector.›
context 
  fixes n :: nat 
begin

  text ‹An algorithm that processes the triplets in any order.
  ›
  (* Single-threaded algorithm. The triplets are consumed as a multiset, so no
     processing order is prescribed. Starting from the all-zero vector, each
     triplet ((i, j), w) increases entry j of the accumulator by x i * w. *)
  definition
    "alg (ts :: triplet list) x =
      fold_mset (λ((i, j), w) z. z(j := z j + x i * w)) (λ_. 0 :: val) (mset ts)"

  text ‹
    We show that the folding function is commutative, i.e., the order of the folding does not matter.
    We will use this below to show that the computation can be parallelized.
  ›  
  (* The update step folded by alg is left-commutative. Two triplets can be
     applied in either order with the same result, because each update only
     adds a value to one entry, and addition is commutative and associative.
     The interpretation makes the comp_fun_commute facts available for this
     particular function. *)
  interpretation comp_fun_commute "(λ((r, c), v) y. y(c := (y c :: val) + x r * v))"
    by unfold_locales (auto intro!: ext)

  subsection ‹Specification›
  text ‹Abstraction function, mapping a sparse matrix to a function from coordinates to values.›
  (* Abstraction of a sparse matrix (a list of ((row, column), value) triplets)
     to a total function on coordinates:
     α ts = (λk. the_default 0 (map_of ts k)).
     Coordinates that do not occur in ts are mapped to 0. For a repeated
     coordinate, map_of returns the value of its first occurrence. *)
  definition α :: "triplet list ⇒ (nat × nat) ⇒ val" where
    "α = the_default 0 oo map_of"

  text ‹Abstract product.›
  (* Abstract vector-matrix product, restricted to the fixed dimension n.
     Entry i of the result is the sum, over all rows k < n, of x k * m (k, i). *)
  definition "pr m x i ≡ ∑k=0..<n. x k * m (k, i)"

  subsection ‹Correctness›

  (* If the coordinates of ts1 @ ts2 are distinct, a coordinate occurs in at most
     one of the two lists. Therefore the defaulted lookup that tries ts1 first
     and then ts2 equals the sum of the two separate defaulted lookups. *)
  lemma aux: 
    "
    distinct (map fst (ts1@ts2)) ⟹
    the_default (0::val) (case map_of ts1 (k, i) of None ⇒ map_of ts2 (k, i) | Some x ⇒ Some x)
    
    = the_default 0 (map_of ts1 (k, i)) + the_default 0 (map_of ts2 (k, i))
    "    
    apply (auto split: option.splits)
    by (metis disjoint_iff_not_equal img_fst map_of_eq_None_iff the_default.simps(2))
    
  (* The abstract product is additive over the concatenation of two sparse
     matrices whose coordinates are disjoint. This lemma is declared [simp]. *)
  lemma 1[simp]: "distinct (map fst (ts1@ts2)) ⟹
    pr (α (ts1@ts2)) x i = pr (α ts1) x i + pr (α ts2) x i"
    apply (auto simp: pr_def α_def map_add_def aux split: option.splits)
    apply (auto simp: algebra_simps)
    by (simp add: sum.distrib)

  (* Instance of lemma 1 with a single-triplet first list. After simplification
     it splits the head ((r,c),v) off a cons list; correct_aux uses it via subst. *)
  lemmas 2 = 1[of "[((r,c),v)]" "ts", simplified] for r c v ts 
    
  (* The empty sparse matrix abstracts to the zero matrix. *)
  lemma [simp]: "α [] = (λ_. 0)" by (auto simp: α_def)  
    
  (* The product with the zero matrix is the zero vector. *)
  lemma [simp]: "pr (λ_. 0::val) x = (λ_. 0)" 
    by (auto simp: pr_def[abs_def])
  
  (* Pushes the_default 0 through a conditional option. It is used as a
     simp rule in the proof of correct_aux after unfolding α_def. *)
  lemma aux3: "the_default 0 (if b then Some x else None) = (if b then x else 0)"
    by auto
  
  (* Invariant of the sequential fold. Starting from an arbitrary accumulator m,
     folding the update over ts adds pr (α ts) x to m pointwise.
     Two side conditions are required:
     - The coordinates must be distinct, so that α ts sees every triplet.
     - All row indices must lie below n, because pr only sums over rows k < n,
       whereas the fold would also process rows at or beyond n. *)
  lemma correct_aux: "⟦distinct (map fst ts); ∀((r,c),_)∈set ts. r<n⟧ 
    ⟹ ∀i. fold (λ((r,c),v) y. y(c:=y c + x r * v)) ts m i = m i + pr (α ts) x i"  
    (* Induct over ts, generalising the accumulator m. *)
    apply (induction ts arbitrary: m)
    apply auto
    subgoal
      apply (subst 2)
      apply auto 
      unfolding pr_def α_def
      apply (auto split: if_splits cong: sum.cong simp: aux3)
      apply (auto simp: if_distrib[where f="λx. _*x"] cong: sum.cong if_cong)
      done
        
    subgoal
      apply (subst 2)
      apply auto 
      unfolding pr_def α_def
      apply (auto split: if_splits cong: sum.cong simp: aux3)
      done
    done

    
  (* The sequential left-to-right fold, started from the zero vector, computes
     the abstract product. This is correct_aux instantiated with m = (λ_. 0). *)
  lemma correct_fold: 
    assumes "distinct (map fst ts)"
    assumes "∀((r,c),_)∈set ts. r<n"
    shows "fold (λ((r,c),v) y. y(c:=y c + x r * v)) ts (λ_. 0) = pr (α ts) x"
    apply (rule ext)
    using correct_aux[OF assms, rule_format, where m = "λ_. 0", simplified]
    by simp

  (* The multiset fold in alg coincides with the ordinary left-to-right list fold
     over ts. The rewrite uses fold_mset_rewr, presumably provided through the
     comp_fun_commute interpretation / Parallel_Multiset_Fold; it is not defined
     in the visible part of this file. *)
  lemma alg_by_fold: "alg ts x = fold (λ((r,c),v) y. y(c:=y c + x r * v)) ts (λ_. 0)"    
    unfolding alg_def by (simp add: fold_mset_rewr)
          
  (* Main correctness theorem of the single-threaded algorithm. For a sparse matrix
     with distinct coordinates and all row indices below n, alg computes the
     abstract vector-matrix product. *)
  theorem correct: 
    assumes "distinct (map fst ts)"
    assumes "∀((r,c),_)∈set ts. r<n"
    shows "alg ts x = pr (α ts) x"
    using alg_by_fold correct_fold[OF assms] by simp 

  subsection ‹Multi-Threaded Implementation›
  text ‹Correctness of the parallel implementation:›
  (* Correctness of the parallel execution. Any final state (ts', ms, r) that is
     reachable from the initial accumulator (λ_. 0) holds the abstract product in
     its result component r. reachable, final and final_state_correct are not
     defined in the visible part of this file; presumably they come from
     Parallel_Multiset_Fold.
     NOTE(review): the fixed dimension n is also passed as the second argument of
     reachable/final, which the comment on "0 < n" describes as the number of
     threads. Confirm against Parallel_Multiset_Fold that sharing n is intended. *)
  theorem parallel_correct:
    assumes "distinct (map fst ts)" "∀((r,c),_)∈set ts. r<n"
        and "0 < n" ― ‹At least one thread›
        ―‹We have reached a final state.›
        and "reachable x n ts (λ_. 0) (ts', ms, r)" "final n (ts', ms, r)"
      shows "r = pr (α ts) x"
    unfolding final_state_correct[OF assms(3-)] correct[OF assms(1,2)] alg_by_fold[symmetric] ..

  text ‹We also show that the computation always terminates, in the sense that from every reachable state some final state can be reached.›
  (* From every state s reachable from the initial accumulator (λ_. 0), some
     final state s' can be reached by a finite sequence of step transitions
     (reflexive-transitive closure of step x n). The proof is the generic
     "termination" lemma, which is presumably provided by Parallel_Multiset_Fold. *)
  theorem parallel_termination:
    assumes "0 < n"
      and "reachable x n ts (λ_. 0) s"
    shows "∃s'. final n s' ∧ (step x n)⇧*⇧* s s'"
    using assms by (rule "termination")

end ― ‹Context for fixed ‹n›.›

end