Theory Challenge3

```section ‹Challenge 3›
theory Challenge3
imports Parallel_Multiset_Fold Refine_Imperative_HOL.IICF
begin

text ‹Problem definition:
🌐‹https://ethz.ch/content/dam/ethz/special-interest/infk/chair-program-method/pm/documents/Verify%20This/Challenges%202019/sparse_matrix_multiplication.pdf››

text ‹We define type synonyms for values (which we fix to integers here) and
triplets. A triplet pairs a coordinate (two natural numbers) with the value
stored at that coordinate.
›
type_synonym val = int
type_synonym triplet = "(nat × nat) × val"

text ‹We fix a size ‹n› for the input vector: the abstract product below sums over the indices ‹0..<n›.›
context
fixes n :: nat
begin

text ‹The algorithm folds over the multiset of triplets, so the triplets may be
processed in any order. Starting from the all-zero result, each triplet
‹((r, c), v)› adds ‹x r * v› to entry ‹c› of the result.
›
definition alg :: "triplet list ⇒ (nat ⇒ val) ⇒ nat ⇒ val" where
  "alg ts x = fold_mset (λ((r, c), v) out. out(c := out c + x r * v)) (λ_. 0) (mset ts)"

text ‹
We show that the folding function is commutative, i.e., the order in which the
triplets are folded does not matter.
We will use this below to show that the computation can be parallelized.
›
interpretation comp_fun_commute "(λ((r, c), v) y. y(c := (y c :: val) + x r * v))"
― ‹This is the step function used in ‹alg›; ‹x› is left free here.›
apply unfold_locales
― ‹Function extensionality (‹ext›) is supplied as an additional introduction rule.›
apply (auto intro!: ext)
done

subsection ‹Specification›
text ‹Abstraction function, mapping a sparse matrix to a function from coordinates to values.
A coordinate is looked up in the triplet list with ‹map_of›; ‹the_default 0› then
turns a failed lookup into the value ‹0›.›
definition α :: "triplet list ⇒ (nat × nat) ⇒ val" where
"α = the_default 0 oo map_of"

text ‹Abstract product: entry ‹i› of the result is the sum of ‹x k * m (k, i)›
over all ‹k› with ‹k < n›.›
definition "pr m x i ≡ ∑k=0..<n. x k * m (k, i)"

subsection ‹Correctness›

text ‹Auxiliary lemma: if the keys of ‹ts1 @ ts2› are distinct, then a lookup that
tries ‹ts1› first and falls back to ‹ts2› yields, after defaulting to ‹0›, the sum
of the two separately defaulted lookups.›
lemma aux:
"
distinct (map fst (ts1@ts2)) ⟹
the_default (0::val) (case map_of ts1 (k, i) of None ⇒ map_of ts2 (k, i) | Some x ⇒ Some x)

= the_default 0 (map_of ts1 (k, i)) + the_default 0 (map_of ts2 (k, i))
"
― ‹Case distinction on the results of the option-valued lookups.›
apply (auto split: option.splits)
by (metis disjoint_iff_not_equal img_fst map_of_eq_None_iff the_default.simps(2))

text ‹For triplet lists whose keys are distinct, the abstract product distributes
over list append. The lemma is registered as a simplification rule.›
lemma 1[simp]: "distinct (map fst (ts1@ts2)) ⟹
pr (α (ts1@ts2)) x i = pr (α ts1) x i + pr (α ts2) x i"
― ‹Unfold the product and the abstraction function; ‹aux› handles the combined lookup.›
apply (auto simp: pr_def α_def map_add_def aux split: option.splits)
apply (auto simp: algebra_simps)
by (simp add: sum.distrib)

text ‹Instance of lemma ‹1› in which the first list is the singleton
‹[((r, c), v)]›, in simplified form.›
lemmas 2 = 1[of "[((r,c),v)]" "ts", simplified] for r c v ts

text ‹The empty triplet list is abstracted to the function that is ‹0› everywhere.›
lemma [simp]: "α [] = (λ_. 0)" by (auto simp: α_def)

text ‹The abstract product with the everywhere-‹0› matrix is the everywhere-‹0› result.›
lemma [simp]: "pr (λ_. 0::val) x = (λ_. 0)"
by (auto simp: pr_def[abs_def])

text ‹Defaulting a conditional option value: the default ‹0› is returned exactly
when the condition ‹b› does not hold.›
lemma aux3: "the_default 0 (if b then Some x else None) = (if b then x else 0)"
  by (cases b) simp_all

text ‹Generalized correctness of the list fold: if the keys of ‹ts› are distinct and
all row indices in ‹ts› are below ‹n›, then folding the update function over ‹ts›
from an arbitrary start value ‹m› adds ‹pr (α ts) x i› to every entry ‹m i›.›
lemma correct_aux: "⟦distinct (map fst ts); ∀((r,c),_)∈set ts. r<n⟧
⟹ ∀i. fold (λ((r,c),v) y. y(c:=y c + x r * v)) ts m i = m i + pr (α ts) x i"
― ‹Induction on the triplet list, with the start value ‹m› generalized.›
apply (induction ts arbitrary: m)
apply auto
subgoal
― ‹Lemma ‹2› splits the product for the list into head triplet and tail.›
apply (subst 2)
apply auto
unfolding pr_def α_def
apply (auto split: if_splits cong: sum.cong simp: aux3)
apply (auto simp: if_distrib[where f="λx. _*x"] cong: sum.cong if_cong)
done

subgoal
― ‹Same first steps as in the previous subgoal.›
apply (subst 2)
apply auto
unfolding pr_def α_def
apply (auto split: if_splits cong: sum.cong simp: aux3)
done
done

text ‹Correctness of the list fold started from the all-zero value, stated as an
equality of functions. This is ‹correct_aux› instantiated with ‹m = (λ_. 0)›.›
lemma correct_fold:
assumes "distinct (map fst ts)"
assumes "∀((r,c),_)∈set ts. r<n"
shows "fold (λ((r,c),v) y. y(c:=y c + x r * v)) ts (λ_. 0) = pr (α ts) x"
― ‹Reduce the function equality to a pointwise statement.›
apply (rule ext)
using correct_aux[OF assms, rule_format, where m = "λ_. 0", simplified]
by simp

text ‹The multiset fold in ‹alg› can be expressed as an ordinary list fold over ‹ts›.›
lemma alg_by_fold: "alg ts x = fold (λ((r,c),v) y. y(c:=y c + x r * v)) ts (λ_. 0)"
unfolding alg_def by (simp add: fold_mset_rewr)

text ‹Main correctness theorem: for a triplet list with distinct keys and all row
indices below ‹n›, the algorithm computes the abstract product.›
theorem correct:
  assumes "distinct (map fst ts)"
  assumes "∀((r,c),_)∈set ts. r<n"
  shows "alg ts x = pr (α ts) x"
  ― ‹Rewrite ‹alg› into the list fold, then apply ‹correct_fold› directly.›
  unfolding alg_by_fold by (rule correct_fold[OF assms])

text ‹Correctness of the parallel implementation: the result component ‹r› of any
reachable final state equals the abstract product.›
theorem parallel_correct:
assumes "distinct (map fst ts)" "∀((r,c),_)∈set ts. r<n"
and "0 < n" ― ‹At least one thread›
―‹We have reached a final state.›
and "reachable x n ts (λ_. 0) (ts', ms, r)" "final n (ts', ms, r)"
shows "r = pr (α ts) x"
― ‹Combines ‹final_state_correct› (assumptions 3 onwards) with ‹correct› (assumptions 1 and 2).›
unfolding final_state_correct[OF assms(3-)] correct[OF assms(1,2)] alg_by_fold[symmetric] ..

text ‹The computation always terminates: from every reachable state, some final
state can be reached by a finite number of steps.›
theorem parallel_termination:
  assumes "0 < n"
    and "reachable x n ts (λ_. 0) s"
  shows "∃s'. final n s' ∧ (step x n)⇧*⇧* s s'"
  by (rule "termination"[OF assms])

end ― ‹Context for fixed ‹n›.›

end```