Theory Challenge3
section ‹Challenge 3›
theory Challenge3
imports Parallel_Multiset_Fold Refine_Imperative_HOL.IICF
text ‹Problem definition: multiply a vector with a sparse matrix that is given as a list of coordinate/value triplets.›
subsection ‹Single-Threaded Implementation›
text ‹We define type synonyms for values (which we fix to integers here) and
triplets, each of which is a pair of coordinates together with a value.›
(* Values (matrix entries and vector components) are fixed to integers. *)
type_synonym val = int
(* A triplet ((r, c), v) pairs a coordinate (r, c) of naturals with a value v. *)
type_synonym triplet = "(nat × nat) × val"
text ‹We fix a size ‹n› for the vector.›
(* The size n is a fixed parameter of the development below; pr sums over the
   index range 0..<n.
   NOTE(review): a bare fixes element is only valid inside a context or locale
   header; the enclosing header is not visible here, confirm it against the
   original theory. *)
fixes n :: nat
text ‹An algorithm that processes the triplets in any order.›
(* alg ts x starts from the all-zero function and, for every triplet
   ((r, c), v) in the multiset of ts, adds x r * v to component c.
   The definition keyword is required: a bare proposition is not a command,
   and the generated fact alg_def is unfolded later in alg_by_fold. *)
definition
"alg (ts :: triplet list) x = fold_mset (λ((r,c),v) y. y(c:=y c + x r * v)) (λ_. 0 :: int) (mset ts)"
text ‹
We show that the folding function is commutative, i.e., the order in which the triplets are folded does not matter.
We will use this below to show that the computation can be parallelized.
›
(* The per-triplet update function commutes with itself, so it is a valid
   argument for fold_mset. Interpreting the locale makes its facts available
   for this function. *)
interpretation comp_fun_commute "(λ((r, c), v) y. y(c := (y c :: val) + x r * v))"
(* Reduce the interpretation to the locale's assumptions. *)
apply unfold_locales
(* Function extensionality (ext) turns the equation between updated functions
   into pointwise goals, which auto then solves. *)
apply (auto intro!: ext)
(* An apply script has to be closed explicitly. *)
done
subsection ‹Specification›
text ‹Abstraction function, mapping a sparse matrix to a function from coordinates to values.›
(* α ts looks a coordinate up in ts, read as an association list via map_of;
   the_default 0 maps a missing entry to 0 and Some v to v. *)
definition α :: "triplet list ⇒ (nat × nat) ⇒ val" where
"α = the_default 0 oo map_of"
text ‹Abstract product.›
(* pr m x i is component i of the product of vector x with matrix m:
   the sum of x k * m (k, i) over all k in 0..<n. *)
definition "pr m x i ≡ ∑k=0..<n. x k * m (k, i)"
subsection ‹Correctness›
(* If the keys of ts1 @ ts2 are distinct, then looking a key up first in ts1
   and otherwise in ts2 (with default 0) equals the sum of the two separate
   default-0 lookups.
   The statement must be a quoted inner-syntax term; without the quotes the
   lemma does not parse. *)
lemma aux:
"distinct (map fst (ts1@ts2)) ⟹
the_default (0::val) (case map_of ts1 (k, i) of None ⇒ map_of ts2 (k, i) | Some x ⇒ Some x)
= the_default 0 (map_of ts1 (k, i)) + the_default 0 (map_of ts2 (k, i))"
(* Case-split on the results of both lookups. *)
apply (auto split: option.splits)
(* Remaining case, closed by metis from the listed facts about map_of,
   disjointness and the_default. *)
by (metis disjoint_iff_not_equal img_fst map_of_eq_None_iff the_default.simps(2))
(* pr is additive over appending two triplet lists, provided the coordinates
   of the appended list are distinct. Declared as a simp rule. *)
lemma 1[simp]: "distinct (map fst (ts1@ts2)) ⟹
pr (α (ts1@ts2)) x i = pr (α ts1) x i + pr (α ts2) x i"
(* Unfold pr and α; aux rewrites the lookup in the appended list into the sum
   of the two separate lookups. *)
apply (auto simp: pr_def α_def map_add_def aux split: option.splits)
(* Normalise the arithmetic inside the sum (algebra_simps). *)
apply (auto simp: algebra_simps)
(* Split the sum of a pointwise sum into two sums. *)
by (simp add: sum.distrib)
(* Instance of lemma 1 with the singleton list [((r,c),v)] as first argument,
   simplified; it is used via subst in the proof of correct_aux. *)
lemmas 2 = 1[of "[((r,c),v)]" "ts", simplified] for r c v ts
(* Simp rule: the empty triplet list is abstracted to the all-zero function. *)
lemma [simp]: "α [] = (λ_. 0)" by (auto simp: α_def)
(* Simp rule: the product with the all-zero matrix is the all-zero function. *)
lemma [simp]: "pr (λ_. 0::val) x = (λ_. 0)"
(* pr_def[abs_def] is the definition with its arguments turned into a lambda
   abstraction, so it rewrites pr applied to only some of its arguments. *)
by (auto simp: pr_def[abs_def])
(* Pushes the_default 0 through a conditional option value; passed as a simp
   rule in the proof of correct_aux. *)
lemma aux3: "the_default 0 (if b then Some x else None) = (if b then x else 0)"
by auto
(* Generalised correctness of the list fold: for triplets with distinct
   coordinates and first coordinates below n, folding the update function
   over ts from an arbitrary start m adds pr (α ts) x to m pointwise.

   NOTE(review): the subgoal/done structure and the closing done below are
   reconstructed. unfolding acts on every open goal, so without isolating the
   goals the second subst 2 could no longer find pr; and an apply script must
   be closed. Confirm by replaying the proof in Isabelle. *)
lemma correct_aux: "⟦distinct (map fst ts); ∀((r,c),_)∈set ts. r<n⟧
⟹ ∀i. fold (λ((r,c),v) y. y(c:=y c + x r * v)) ts m i = m i + pr (α ts) x i"
(* Induct over the list; m is generalised because the accumulator changes
   in every fold step. *)
apply (induction ts arbitrary: m)
apply auto
(* Split pr of the cons list into the head contribution plus the tail (lemma 2). *)
apply (subst 2)
apply auto
(* First remaining goal: unfold the definitions and compute the sum. *)
subgoal
  unfolding pr_def α_def
  apply (auto split: if_splits cong: sum.cong simp: aux3)
  apply (auto simp: if_distrib[where f="λx. _*x"] cong: sum.cong if_cong)
  done
(* Second remaining goal: split the cons list first, then unfold. *)
subgoal
  apply (subst 2)
  apply auto
  unfolding pr_def α_def
  apply (auto split: if_splits cong: sum.cong simp: aux3)
  done
done
(* Correctness of the list fold started from the all-zero function: under the
   two well-formedness assumptions it computes pr (α ts) x. *)
lemma correct_fold:
assumes "distinct (map fst ts)"
assumes "∀((r,c),_)∈set ts. r<n"
shows "fold (λ((r,c),v) y. y(c:=y c + x r * v)) ts (λ_. 0) = pr (α ts) x"
(* Prove the equality of functions pointwise. *)
apply (rule ext)
(* Instantiate correct_aux with the all-zero start function and simplify. *)
using correct_aux[OF assms, rule_format, where m = "λ_. 0", simplified]
by simp
(* The multiset fold of alg coincides with the ordinary list fold over ts.
   NOTE(review): fold_mset_rewr is not defined in this file; presumably it
   comes from the imported Parallel_Multiset_Fold theory, confirm. *)
lemma alg_by_fold: "alg ts x = fold (λ((r,c),v) y. y(c:=y c + x r * v)) ts (λ_. 0)"
unfolding alg_def by (simp add: fold_mset_rewr)
(* Main single-threaded result: for triplets with distinct coordinates and
   first coordinates below n, alg computes the abstract product pr (α ts) x. *)
theorem correct:
assumes "distinct (map fst ts)"
assumes "∀((r,c),_)∈set ts. r<n"
shows "alg ts x = pr (α ts) x"
(* Combine the list-fold characterisation of alg with correct_fold. *)
using alg_by_fold correct_fold[OF assms] by simp
subsection ‹Multi-Threaded Implementation›
text ‹Correctness of the parallel implementation:›
(* If a state (ts', ms, r) is reachable from the initial all-zero result and
   is final, then its component r equals the abstract product pr (α ts) x.
   NOTE(review): reachable, final and final_state_correct are not defined in
   this file; presumably they come from Parallel_Multiset_Fold, confirm their
   exact statements there. *)
theorem parallel_correct:
assumes "distinct (map fst ts)" "∀((r,c),_)∈set ts. r<n"
and "0 < n"
and "reachable x n ts (λ_. 0) (ts', ms, r)" "final n (ts', ms, r)"
shows "r = pr (α ts) x"
(* Rewrite with final_state_correct (from the third assumption onwards), with
   correct (from the first two assumptions) and with alg_by_fold read right
   to left; the final .. closes what remains by a default rule step. *)
unfolding final_state_correct[OF assms(3-)] correct[OF assms(1,2)] alg_by_fold[symmetric] ..
text ‹We also know that the computation will always terminate.›
(* From every reachable state s, some final state s' can be reached by finitely
   many steps; (step x n)⇧*⇧* is the reflexive-transitive closure of the step
   relation.
   NOTE(review): step and the fact named termination are not defined in this
   file; presumably they come from Parallel_Multiset_Fold, confirm. *)
theorem parallel_termination:
assumes "0 < n"
and "reachable x n ts (λ_. 0) s"
shows "∃s'. final n s' ∧ (step x n)⇧*⇧* s s'"
(* The fact name is quoted because termination is also an Isabelle keyword. *)
using assms by (rule "termination")