Theory Challenge3
section ‹Challenge 3›
theory Challenge3
imports Parallel_Multiset_Fold Refine_Imperative_HOL.IICF
begin
text ‹Problem definition:
🌐‹https://ethz.ch/content/dam/ethz/special-interest/infk/chair-program-method/pm/documents/Verify%20This/Challenges%202019/sparse_matrix_multiplication.pdf››
subsection ‹Single-Threaded Implementation›
text ‹We define type synonyms for values (which we fix to integers here) and
triplets, each of which consists of a pair of coordinates (row, column) and a value.
›
(* ‹val›: the type of matrix and vector entries (fixed to ‹int›).
   ‹triplet›: one entry of a sparse matrix, given as ((row, column), value).
   The row/column reading follows ‹pr› below: the first coordinate is
   multiplied with the vector entry, and the second selects the result entry. *)
type_synonym val = int
type_synonym triplet = "(nat × nat) × val"
text ‹We fix a size ‹n› for the vector.›
context
fixes n :: nat
begin
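text ‹An algorithm that processes the triplets in any order: starting from the zero vector, each triplet ‹((r, c), v)› adds ‹x r * v› to entry ‹c› of the result.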
›
(* Sparse vector-matrix product, expressed as a multiset fold over the triplets.
   Starting from the all-zero vector, each triplet ((r, c), v) adds x r * v to
   entry c of the accumulator y. Folding over ‹mset ts› rather than over the list
   leaves the processing order unspecified. The ‹comp_fun_commute> proof below is
   what justifies reasoning about this fold. *)
definition
"alg (ts :: triplet list) x = fold_mset (λ((r,c),v) y. y(c:=y c + x r * v)) (λ_. 0 :: int) (mset ts)"
text ‹
We show that the folding function is left-commutative (an instance of ‹comp_fun_commute›), i.e., the order in which the triplets are folded does not matter.
We use this below to show that the computation can be parallelized.
›
(* The update step of ‹alg› is left-commutative: applying the updates of two
   triplets in either order yields the same vector. ‹x› is free here, so the
   interpretation holds for every vector ‹x›.
   NOTE(review): ‹alg_by_fold› below relies on ‹fold_mset_rewr›, and
   ‹parallel_correct›/‹parallel_termination› use ‹reachable›/‹final›/‹step›
   applied to ‹x›. These presumably come from this interpretation together with
   Parallel_Multiset_Fold -- confirm. *)
interpretation comp_fun_commute "(λ((r, c), v) y. y(c := (y c :: val) + x r * v))"
apply unfold_locales
(* The two update orders agree at every index; ‹ext› reduces the function
   equality to that pointwise statement. *)
apply (auto intro!: ext)
done
subsection ‹Specification›
text ‹Abstraction function, mapping a sparse matrix to a function from coordinates to values.›
(* Abstraction of a triplet list to a total matrix. ‹map_of› looks up the value
   of the first triplet with the given coordinates, and ‹the_default 0› maps
   missing coordinates to 0. The correctness theorems below assume distinct
   coordinates, in which case the lookup is unambiguous. *)
definition α :: "triplet list ⇒ (nat × nat) ⇒ val" where
"α = the_default 0 oo map_of"
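text ‹Abstract (dense) product of the vector ‹x› with the matrix ‹m›: entry ‹i› of the result is ‹∑k=0..<n. x k * m (k, i)›.›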
(* Reference product: entry i of the result is the sum over rows k < n of
   x k * m (k, i). Only rows below n contribute, which is why the correctness
   theorems require every triplet's row index to be below n. *)
definition "pr m x i ≡ ∑k=0..<n. x k * m (k, i)"
subsection ‹Correctness›
(* Helper for lemma ‹1›. The case expression matches the lookup
   ‹map_of (ts1@ts2) (k, i)› once ‹map_add_def› is unfolded. With distinct keys,
   at most one of the two lists contains (k, i), so the looked-up value equals
   the sum of the two individual lookups, with the absent one contributing 0. *)
lemma aux:
"
distinct (map fst (ts1@ts2)) ⟹
the_default (0::val) (case map_of ts1 (k, i) of None ⇒ map_of ts2 (k, i) | Some x ⇒ Some x)
= the_default 0 (map_of ts1 (k, i)) + the_default 0 (map_of ts2 (k, i))
"
apply (auto split: option.splits)
(* Presumably the remaining goal is: if ts1 contains (k, i), then ts2 does not
   (by distinctness), so the ts2 lookup contributes 0. *)
by (metis disjoint_iff_not_equal img_fst map_of_eq_None_iff the_default.simps(2))
(* The abstract product is additive over concatenation of triplet lists whose
   coordinates are distinct. It is declared [simp], so later proofs use it
   automatically. *)
lemma 1[simp]: "distinct (map fst (ts1@ts2)) ⟹
pr (α (ts1@ts2)) x i = pr (α ts1) x i + pr (α ts2) x i"
(* Unfold the definitions and use ‹aux› to split each summand of the lookup. *)
apply (auto simp: pr_def α_def map_add_def aux split: option.splits)
apply (auto simp: algebra_simps)
(* ‹sum.distrib› splits the sum of sums into two separate sums. *)
by (simp add: sum.distrib)
(* Special case of lemma ‹1› with a single leading triplet ((r,c),v). It is used
   to peel off the head of the list in the induction of ‹correct_aux›.
   ‹simplified› presumably rewrites [((r,c),v)] @ ts to ((r,c),v) # ts and turns
   the distinctness premise into statements about the head and the tail. *)
lemmas 2 = 1[of "[((r,c),v)]" "ts", simplified] for r c v ts
(* The empty triplet list abstracts to the all-zero matrix. *)
lemma [simp]: "α [] = (λ_. 0)" by (auto simp: α_def)
(* The product of any vector with the all-zero matrix is the zero vector. *)
lemma [simp]: "pr (λ_. 0::val) x = (λ_. 0)"
by (auto simp: pr_def[abs_def])
(* Pushes ‹the_default 0› through a conditional. It is used as a rewrite rule
   in ‹correct_aux›. *)
lemma aux3: "the_default 0 (if b then Some x else None) = (if b then x else 0)"
by auto
(* Generalized invariant for the sequential list fold: folding the triplets into
   an arbitrary start vector m adds exactly ‹pr (α ts) x› to it.
   The proof is by induction on ts, with m generalized because every step
   changes the accumulator. It requires distinct coordinates and row indices
   below n. *)
lemma correct_aux: "⟦distinct (map fst ts); ∀((r,c),_)∈set ts. r<n⟧
⟹ ∀i. fold (λ((r,c),v) y. y(c:=y c + x r * v)) ts m i = m i + pr (α ts) x i"
apply (induction ts arbitrary: m)
apply auto
(* NOTE(review): the two subgoals below presumably correspond to the result
   index i being equal to, or different from, the column c updated by the head
   triplet. Confirm in the proof state. *)
subgoal
(* Split off the head triplet with lemma 2, then unfold and simplify the sum. *)
apply (subst 2)
apply auto
unfolding pr_def α_def
apply (auto split: if_splits cong: sum.cong simp: aux3)
(* Distribute the multiplication over the conditional so the sum simplifies. *)
apply (auto simp: if_distrib[where f="λx. _*x"] cong: sum.cong if_cong)
done
subgoal
(* Same decomposition; here the simplification closes the goal directly. *)
apply (subst 2)
apply auto
unfolding pr_def α_def
apply (auto split: if_splits cong: sum.cong simp: aux3)
done
done
(* Sequential correctness of the list fold: folding the triplets starting from
   the zero vector yields the abstract product. This is the instance of
   ‹correct_aux› with m = zero, proved pointwise via ‹ext›. *)
lemma correct_fold:
assumes "distinct (map fst ts)"
assumes "∀((r,c),_)∈set ts. r<n"
shows "fold (λ((r,c),v) y. y(c:=y c + x r * v)) ts (λ_. 0) = pr (α ts) x"
apply (rule ext)
using correct_aux[OF assms, rule_format, where m = "λ_. 0", simplified]
by simp
(* The multiset fold in ‹alg› coincides with the left-to-right list fold over ts.
   NOTE(review): ‹fold_mset_rewr› is not defined in this section. It presumably
   comes from the ‹comp_fun_commute› interpretation above -- confirm. *)
lemma alg_by_fold: "alg ts x = fold (λ((r,c),v) y. y(c:=y c + x r * v)) ts (λ_. 0)"
unfolding alg_def by (simp add: fold_mset_rewr)
(* Main sequential theorem: if the coordinates are distinct and all row indices
   are below n, then ‹alg› computes the abstract product ‹pr (α ts) x›. *)
theorem correct:
assumes "distinct (map fst ts)"
assumes "∀((r,c),_)∈set ts. r<n"
shows "alg ts x = pr (α ts) x"
using alg_by_fold correct_fold[OF assms] by simp
subsection ‹Multi-Threaded Implementation›
text ‹Correctness of the parallel implementation:›
(* Any final state that the parallel execution can reach (starting from the zero
   vector) holds the abstract product in its result component r.
   The proof rewrites r with ‹final_state_correct› (presumably from
   Parallel_Multiset_Fold, which is not shown here). It then turns the list fold
   back into ‹alg› and applies ‹correct›.
   NOTE(review): the fixed ‹n›, which is the vector size used by ‹pr›, is also
   passed as the second argument of ‹reachable›/‹final›. Confirm against
   Parallel_Multiset_Fold what that parameter means (e.g. a thread count). *)
theorem parallel_correct:
assumes "distinct (map fst ts)" "∀((r,c),_)∈set ts. r<n"
and "0 < n"
and "reachable x n ts (λ_. 0) (ts', ms, r)" "final n (ts', ms, r)"
shows "r = pr (α ts) x"
unfolding final_state_correct[OF assms(3-)] correct[OF assms(1,2)] alg_by_fold[symmetric] ..
text ‹We also know that the computation will always terminate.›
(* Termination: from every reachable state, some final state can be reached via
   the reflexive-transitive closure of ‹step›. The proof is a direct application
   of the ‹termination› lemma, presumably from Parallel_Multiset_Fold, which is
   not shown here. *)
theorem parallel_termination:
assumes "0 < n"
and "reachable x n ts (λ_. 0) s"
shows "∃s'. final n s' ∧ (step x n)⇧*⇧* s s'"
using assms by (rule "termination")
end
end