(* Theory Parallel_Multiset_Fold *)

section ‹Iterating a Commutative Computation Concurrently›

theory Parallel_Multiset_Fold
  imports "HOL-Library.Multiset"
begin

text ‹
This theory formalizes a deep embedding of a simple parallel computation model.
Within this model, we define a scheme that executes a fold over a commutative
operation concurrently, and we prove the scheme correct.
›

subsection ‹Misc›

(* TODO: Move *)
(* Folding the locale's commutative function over the multiset of a list gives
   the same result as List.fold over the list itself. Proved by induction on the
   list, generalizing the start value. Used below to restate the multiset-level
   correctness result in terms of List.fold (final_state_correct). *)
lemma (in comp_fun_commute) fold_mset_rewr: "fold_mset f a (mset l) = fold f l a" 
  by (induction l arbitrary: a; clarsimp; metis fold_mset_fun_left_comm)

text ‹
  There are only finitely many partial maps whose domain is contained in a finite set ‹A›
  and whose range is contained in a finite set ‹B›.
  The proof covers this set by the union, over all subsets ‹S ⊆ A›, of the maps whose
  domain is exactly ‹S›; each of those sets is finite by the library lemma
  ‹Map.finite_set_of_finite_maps›.›
lemma finite_set_of_finite_maps:
  fixes A :: "'a set"
    and B :: "'b set"
  assumes "finite A"
    and "finite B"
  shows "finite {m. dom m ⊆ A ∧ ran m ⊆ B}"
proof -
  have "{m. dom m ⊆ A ∧ ran m ⊆ B} ⊆ (⋃S ∈ {S. S ⊆ A}. {m. dom m = S ∧ ran m ⊆ B})"
    by auto
  moreover have "finite …"
    (* Refer to the library version explicitly: this lemma shadows its unqualified name. *)
    using assms by (auto intro!: Map.finite_set_of_finite_maps intro: finite_subset)
  ultimately show ?thesis
    by (rule finite_subset)
qed

text ‹
  Eventuality induction. Assume the converse of ‹R› is well-founded, and assume every state
  reachable from ‹a› via ‹R› either satisfies ‹P› or has an ‹R›-successor. Then some state
  reachable from ‹a› satisfies ‹P›.›
lemma wf_rtranclp_ev_induct[consumes 1, case_names step]:
  assumes "wf {(x, y). R y x}" and step: "⋀ x. R** a x ⟹ P x ∨ (∃ y. R x y)"
  shows "∃x. P x ∧ R** a x"
proof -
  (* Generalize to an arbitrary reachable start x, then do well-founded induction on x. *)
  have "∃y. P y ∧ R** x y" if "R** a x" for x
    using assms(1) that
  proof induction
    case (less x)
    from step[OF ‹R** a x›] have "P x ∨ (∃y. R x y)" .
    then show ?case
    proof
      assume "P x"
      then show ?case
        by auto
    next
      assume "∃y. R x y"
      then obtain y where "R x y" ..
      (* y is smaller w.r.t. the well-founded order and still reachable from a. *)
      with less(1)[of y] less(2) show ?thesis
        by simp (meson converse_rtranclp_into_rtranclp rtranclp.rtrancl_into_rtrancl)
    qed
  qed
  then show ?thesis
    by blast
qed

subsection ‹The Concurrent System›
text ‹
  A state of our concurrent system consists of a list of pending tasks,
  a partial map from threads to the task they are currently working on,
  and the current computation result.›
(* Components: pending tasks, partial map from thread index to its current task,
   and the accumulated result. *)
type_synonym ('a, 's) state = "'a list × (nat ⇀ 'a) × 's"

context comp_fun_commute
begin

context
  fixes n :: nat ― ‹The number of threads.›
  assumes n_gt_0[simp, intro]: "n > 0"
begin

text ‹
  A state is ‹final› if there are no remaining tasks and if all workers have finished their work.
  Only threads with an index below ‹n› are taken into account.›
definition
  "final ≡ λ(ts, ws, r). ts = [] ∧ dom ws ∩ {0..<n} = {}"

text ‹At any point a thread ‹i < n› can:
  \<^item> pick a new task from the queue if it is currently not busy,
  \<^item> or execute its current task, combining it into the result with ‹f›.›
inductive step :: "('a, 'b) state ⇒ ('a, 'b) state ⇒ bool" where
  pick: "step (t # ts, ws, s) (ts, ws(i := Some t), s)"   if "ws i = None"   and "i < n"
| exec: "step (ts, ws, s)     (ts, ws(i := None), f a s)" if "ws i = Some a" and "i < n"

text ‹Progress: every state that is not final has a successor.›
lemma no_deadlock:
  assumes "¬ final cfg"
  shows "∃cfg'. step cfg cfg'"
  using assms
  apply (cases cfg)
  apply safe
  subgoal for ts ws s
    (* Case split on the queue and on the status of thread 0 (which exists since n > 0). *)
    by (cases ts; cases "ws 0") (auto 4 5 simp: final_def intro: step.intros)
  done

text ‹
  The step relation is well-founded on states whose tasks and in-progress work come from a
  finite set ‹S› and whose busy threads have indices below ‹n›. A ‹pick› shortens the task
  list. An ‹exec› keeps the list unchanged and strictly shrinks the domain of the worker map.›
lemma wf_step:
  "wf {((ts', ws', r'), (ts, ws, r)).
    step (ts, ws, r) (ts', ws', r') ∧ set ts' ⊆ S ∧ dom ws ⊆ {0..<n} ∧ ran ws ⊆ S}"
  if "finite S"
proof -
  (* ?R1: worker maps ordered by strict inclusion of their domains, bounded by n and S. *)
  let ?R1 = "{(x, y). dom x ⊂ dom y ∧ ran x ⊆ S ∧ dom y ⊆ {0..<n} ∧ ran y ⊆ S}"
  have "?R1 ⊆ {y. dom y ⊆ {0..<n} ∧ ran y ⊆ S} × {y. dom y ⊆ {0..<n} ∧ ran y ⊆ S}"
    by auto
  then have "finite ?R1"
    using ‹finite S› by - (erule finite_subset, auto intro: finite_set_of_finite_maps)
  (* A finite acyclic relation is well-founded; acyclicity via the measure n - card (dom x). *)
  then have [intro]: "wf ?R1"
    apply (rule finite_acyclic_wf)
    apply (rule preorder_class.acyclicI_order[where f = "λx. n - card (dom x)"])
    apply clarsimp
    by (metis (full_types)
        cancel_ab_semigroup_add_class.diff_right_commute diff_diff_cancel domD domI
        psubsetI psubset_card_mono subset_eq_atLeast0_lessThan_card
        subset_eq_atLeast0_lessThan_finite zero_less_diff)
  (* Lexicographic order: queue length first, then the worker map; the result is irrelevant. *)
  let ?R = "measure length <*lex*> ?R1 <*lex*> {}"
  have "wf ?R"
    by auto
  then show ?thesis
    apply (rule wf_subset)
    apply clarsimp
    apply (erule step.cases; clarsimp)
    by (smt
        Diff_iff domIff fun_upd_apply mem_Collect_eq option.simps(3) psubsetI ran_def
        singletonI subset_iff)
qed

context
  fixes ts :: "'a list" and start :: "'b"
begin

text ‹Initial state: all tasks are pending, no thread is busy, and the result is ‹start›.›
definition
  "s0 = (ts, λ_. None, start)"

text ‹States reachable from ‹s0› in any number of steps.›
definition "reachable ≡ (step**) s0"

lemma reachable0[simp]: "reachable s0"
  unfolding reachable_def by auto

text ‹‹I› is an invariant if it holds initially and is preserved by every step taken from a
  reachable state.›
definition "is_invar I ≡ I s0 ∧ (∀s s'. reachable s ∧ I s ∧ step s s' ⟶ I s')"

lemma is_invarI[intro?]:
  "⟦ I s0; ⋀s s'. ⟦ reachable s; I s; step s s'⟧ ⟹ I s' ⟧ ⟹ is_invar I"
  by (auto simp: is_invar_def)

lemma invar_reachable: "is_invar I ⟹ reachable s ⟹ I s"
  unfolding reachable_def
  by rotate_tac (induction rule: rtranclp_induct, auto simp: is_invar_def reachable_def)

text ‹
  The key invariant: the multiset of initial tasks splits into the tasks already folded into
  the result (‹ts1›), the tasks currently held by threads below ‹n›, and the tasks still
  pending (‹ts2›).›
definition
  "invar ≡ λ(ts2, ws, r).
    (∃ts1.
      mset ts = ts1 + {# the (ws i). i ∈# mset_set (dom ws ∩ {0..<n}) #} + mset ts2
    ∧ r = fold_mset f start ts1
    ∧ set ts2 ⊆ set ts ∧ ran ws ⊆ set ts ∧ dom ws ⊆ {0..<n})"

text ‹‹invar› is an invariant of the system.›
lemma invariant:
  "is_invar invar"
  apply rule
  subgoal
    unfolding s0_def unfolding invar_def by simp
  subgoal
    unfolding invar_def
    apply (elim step.cases)
     apply (clarsimp split: option.split_asm)
    (* pick: the folded part ts1 is unchanged; the picked task moves to a worker. *)
    subgoal for ws i t ts ts1
      apply (rule exI[where x = ts1])
       apply (subst mset_set.insert)
         apply (auto intro!: multiset.map_cong0)
      done
    apply (clarsimp split!: prod.splits)
    (* exec: the executed task a moves from the worker into the folded part. *)
    subgoal for ws i a ts ts1
      apply (rule exI[where x = "add_mset a ts1"])
         apply (subst Diff_Int_distrib2)
         apply (subst mset_set.remove)
           apply (auto intro!: multiset.map_cong0 split: if_split_asm simp: ran_def)
      done
    done
  done

text ‹A final state that satisfies the invariant carries the folded result.›
lemma final_state_correct1:
  assumes "invar (ts', ms, r)" "final (ts', ms, r)"
  shows "r = fold_mset f start (mset ts)"
  using assms unfolding invar_def final_def by auto

lemma final_state_correct2:
  assumes "reachable (ts', ms, r)" "final (ts', ms, r)"
  shows "r = fold_mset f start (mset ts)"
  using assms by - (rule final_state_correct1, rule invar_reachable[OF invariant])

text ‹Soundness: whenever we reach a final state, the computation result is correct.›
theorem final_state_correct:
  assumes "reachable (ts', ms, r)" "final (ts', ms, r)"
  shows "r = fold f ts start"
  using final_state_correct2[OF assms] by (simp add: fold_mset_rewr)

text ‹Termination: at any point during the program execution, we can continue to a final state.
That is, the computation always terminates.
(The proof also establishes that ‹step›, restricted to reachable states, is well-founded.)
›
theorem "termination":
  assumes "reachable s"
  shows "∃s'. final s' ∧ step** s s'"
proof -
  have "{(s', s). step s s' ∧ reachable s} ⊆ {(s', s). step s s' ∧ reachable s ∧ reachable s'}"
    unfolding reachable_def by auto
  also have "… ⊆ {((ts', ws', r'), (ts1, ws, r)).
    step (ts1, ws, r) (ts', ws', r') ∧ set ts' ⊆ set ts ∧ dom ws ⊆ {0..<n} ∧ ran ws ⊆ set ts}"
    by (force dest!: invar_reachable[OF invariant] simp: invar_def)
  finally have "wf {(s', s). step s s' ∧ reachable s}"
    by (elim wf_subset[OF wf_step, rotated]) simp
  then have "∃s'. final s' ∧ (λs s'. step s s' ∧ reachable s)** s s'"
  proof (induction rule: wf_rtranclp_ev_induct)
    case (step x)
    then have "(λs s'. step s s')** s x"
      by (elim mono_rtranclp[rule_format, rotated] conjE)
    with ‹reachable s› have "reachable x"
      unfolding reachable_def by auto
    then show ?case
      using no_deadlock[of x] by auto
  qed
  then show ?thesis
    apply clarsimp
    apply (intro exI conjI, assumption)
    apply (rule mono_rtranclp[rule_format])
     apply auto
    done
qed

end (* Fixed task list *)

end (* Fixed number of workers *)

end (* Commutative function *)

text ‹The main theorems outside the locale:›
(* Display the soundness and termination results as exported from the
   comp_fun_commute locale context. *)
thm comp_fun_commute.final_state_correct comp_fun_commute.termination

end (* End of theory *)