(* Theory Challenge3 *)

section ‹Challenge 3›
(* Theory header. The theory body is opened by the keyword begin below.
   NOTE(review): the matching theory-level end is not visible in this
   excerpt; confirm that it follows the end of the context block. *)
theory Challenge3
  imports Parallel_Multiset_Fold Refine_Imperative_HOL.IICF
begin

text ‹Problem definition: multiplication of a vector with a sparse matrix that is given as a list of triplets.›

subsection ‹Single-Threaded Implementation›
text ‹We define type synonyms for values (which we fix to integers here) and 
  triplets, each of which pairs a pair of coordinates with a value.
›
type_synonym val = int
type_synonym triplet = "(nat × nat) × val"

text ‹We fix a size ‹n› for the vector.›
context
  fixes n :: nat 
begin

  text ‹An algorithm that processes the triplets in any order: it folds over the
    multiset of the triplets, starting from the all-zero function, and for each
    triplet ‹((r,c),v)› adds ‹x r * v› to component ‹c› of the result.›
  definition
    "alg (ts :: triplet list) x = fold_mset (λ((r,c),v) y. y(c:=y c + x r * v)) (λ_. 0 :: int) (mset ts)"

  text ‹
    We show that the folding function is commutative, i.e., the order of the folding does not matter.
    We will use this below to show that the computation can be parallelized.
  ›
  interpretation comp_fun_commute "(λ((r, c), v) y. y(c := (y c :: val) + x r * v))"
    apply unfold_locales
    apply (auto intro!: ext)
    done

  subsection ‹Specification›
  text ‹Abstraction function, mapping a sparse matrix to a function from coordinates to values.›
  (* The triplet list is read as an association list via map_of; the result of
     the lookup is passed through the_default 0 (cf. lemma aux3 below, which
     shows that a missing entry yields 0). *)
  definition α :: "triplet list ⇒ (nat × nat) ⇒ val" where 
    "α = the_default 0 oo map_of"

  text ‹Abstract product.›
  (* pr m x i is the sum of x k * m (k, i) over all k with 0 <= k < n,
     where n is the size fixed by the enclosing context. *)
  definition "pr m x i ≡ ∑k=0..<n. x k * m (k, i)"    

  subsection ‹Correctness›

  (* If the keys of ts1@ts2 are distinct, then looking up (k, i) first in ts1
     and otherwise in ts2, with default 0, equals the sum of the two separate
     lookups with default 0. *)
  lemma aux: 
    "distinct (map fst (ts1@ts2)) ⟹ 
    the_default (0::val) (case map_of ts1 (k, i) of None ⇒ map_of ts2 (k, i) | Some x ⇒ Some x)
    = the_default 0 (map_of ts1 (k, i)) + the_default 0 (map_of ts2 (k, i))"
    apply (auto split: option.splits)
    by (metis disjoint_iff_not_equal img_fst map_of_eq_None_iff the_default.simps(2))
  (* The abstract product distributes over concatenation of triplet lists,
     provided the coordinates of the concatenated list are distinct. *)
  lemma 1[simp]: "distinct (map fst (ts1@ts2)) ⟹ 
    pr (α (ts1@ts2)) x i = pr (α ts1) x i + pr (α ts2) x i"
    apply (auto simp: pr_def α_def map_add_def aux split: option.splits)
    apply (auto simp: algebra_simps)
    by (simp add: sum.distrib)

  (* Instance of lemma 1 in which the first list is the singleton
     [((r,c),v)] and the second is ts, post-processed by the simplifier. *)
  lemmas 2 = 1[of "[((r,c),v)]" "ts", simplified] for r c v ts 
  (* Simp rule: the empty triplet list abstracts to the all-zero function. *)
  lemma [simp]: "α [] = (λ_. 0)" by (auto simp: α_def)  
  (* Simp rule: the abstract product with the all-zero function is all-zero. *)
  lemma [simp]: "pr (λ_. 0::val) x = (λ_. 0)" 
    by (auto simp: pr_def[abs_def])
  (* Pushes the_default 0 through a conditional option: Some x yields x,
     None yields 0. Passed as a simp rule in the proof of correct_aux. *)
  lemma aux3: "the_default 0 (if b then Some x else None) = (if b then x else 0)"
    by auto
  (* Generalised correctness of the list fold for an arbitrary start value m:
     if the coordinates in ts are distinct and the first coordinate of every
     triplet is below n, then folding the update function over ts adds
     pr (α ts) x i to component i of m.
     Proof: induction on ts with m arbitrary; each of the two goals left by
     auto is closed in its own subgoal block after rewriting with lemma 2. *)
  lemma correct_aux: "⟦distinct (map fst ts); ∀((r,c),_)∈set ts. r<n⟧ 
    ⟹ ∀i. fold (λ((r,c),v) y. y(c:=y c + x r * v)) ts m i = m i + pr (α ts) x i"  
    apply (induction ts arbitrary: m)
    apply auto
    subgoal
      apply (subst 2)
      apply auto 
      unfolding pr_def α_def
      apply (auto split: if_splits cong: sum.cong simp: aux3)
      apply (auto simp: if_distrib[where f="λx. _*x"] cong: sum.cong if_cong)
      done
    subgoal
      apply (subst 2)
      apply auto 
      unfolding pr_def α_def
      apply (auto split: if_splits cong: sum.cong simp: aux3)
      done
    done

  (* Correctness of the list fold started from the all-zero function:
     the result equals the abstract product. Obtained from correct_aux
     with m instantiated to the all-zero function. *)
  lemma correct_fold: 
    assumes "distinct (map fst ts)"
    assumes "∀((r,c),_)∈set ts. r<n"
    shows "fold (λ((r,c),v) y. y(c:=y c + x r * v)) ts (λ_. 0) = pr (α ts) x"
    apply (rule ext)
    using correct_aux[OF assms, rule_format, where m = "λ_. 0", simplified]
    by simp

  (* Expresses alg (a fold over the multiset of ts) as a fold over the list ts
     itself; proved by unfolding alg_def and simplifying with fold_mset_rewr. *)
  lemma alg_by_fold: "alg ts x = fold (λ((r,c),v) y. y(c:=y c + x r * v)) ts (λ_. 0)"    
    unfolding alg_def by (simp add: fold_mset_rewr)
  (* Main single-threaded result: for a triplet list with distinct coordinates
     whose first coordinates are all below n, alg computes the abstract
     product. Follows from alg_by_fold and correct_fold. *)
  theorem correct: 
    assumes "distinct (map fst ts)"
    assumes "∀((r,c),_)∈set ts. r<n"
    shows "alg ts x = pr (α ts) x"
    using alg_by_fold correct_fold[OF assms] by simp 

  subsection ‹Multi-Threaded Implementation›
  text ‹Correctness of the parallel implementation:›
  (* The result component r of any reachable final state equals the abstract
     product. Combines final_state_correct with the single-threaded theorem
     correct and alg_by_fold.
     NOTE(review): the context's n is the argument passed to reachable and
     final here; the marginal comment below reads it as a number of threads.
     Confirm this is intended. *)
  theorem parallel_correct:
    assumes "distinct (map fst ts)" "∀((r,c),_)∈set ts. r<n"
        and "0 < n" ― ‹At least one thread›
        ―‹We have reached a final state.›
        and "reachable x n ts (λ_. 0) (ts', ms, r)" "final n (ts', ms, r)"
      shows "r = pr (α ts) x"
    unfolding final_state_correct[OF assms(3-)] correct[OF assms(1,2)] alg_by_fold[symmetric] ..

  text ‹We also know that the computation will always terminate.›
  (* From every reachable state s there is a final state s' that is reachable
     from s by zero or more steps of step x n. Direct consequence of the
     imported fact named termination. *)
  theorem parallel_termination:
    assumes "0 < n"
      and "reachable x n ts (λ_. 0) s"
    shows "∃s'. final n s' ∧ (step x n)⇧*⇧* s s'"
    using assms by (rule "termination")

end ― ‹Context for fixed ‹n›.›