(* Theory Efficient-Mergesort.Efficient_Sort *)

(*  Author:      Christian Sternagel <c.sternagel@gmail.com>
    Maintainer:  Christian Sternagel <c.sternagel@gmail.com>
*)
theory Efficient_Sort
  imports "HOL-Library.Multiset"
begin


section ‹GHC Version of Mergesort›

text ‹
  In the following we show that the mergesort implementation
  used in GHC (see @{url "http://hackage.haskell.org/package/base-4.11.1.0/docs/src/Data.OldList.html#sort"})
  is a correct and stable sorting algorithm. Furthermore, experimental
  data suggests that the code generated for this implementation is much more
  efficient than the code generated for the implementation provided by @{theory "HOL-Library.Multiset"}.

  A high-level overview of an older version of this formalization, as well as some experimental data,
  can be found in @{cite "Sternagel2012"}.
›
subsection ‹Definition of Natural Mergesort›

(* All definitions and lemmas up to the matching "end" are parameterized by
   the key function "key"; outside this context it becomes an explicit
   first argument of each constant. *)
context
  fixes key :: "'a ⇒ 'k::linorder"
begin

text ‹
  Split a list into chunks of ascending (non-decreasing) and strictly descending parts,
  where descending parts are reversed on the fly.
  Thus, the result is a list of sorted lists.
›
(*
  sequences splits its input into sorted runs; asc and desc extend the
  current run.
  - sequences inspects the first two elements: if the key of the second one
    is strictly smaller, it starts a descending run (desc), otherwise an
    ascending run (asc).
  - asc keeps its run as a function "as" that prepends the elements seen so
    far (see ascP below); the run is extended while keys are non-decreasing
    and is closed by applying "as" to [a].
  - desc keeps its run reversed as a plain list, so the emitted chunk
    "a # as" is ascending; the run is extended only while keys are strictly
    decreasing, so two elements with equal keys never end up in the same
    reversed run.
*)
fun sequences :: "'a list ⇒ 'a list list"
  and asc :: "'a ⇒ ('a list ⇒ 'a list) ⇒ 'a list ⇒ 'a list list"
  and desc :: "'a ⇒ 'a list ⇒ 'a list ⇒ 'a list list"
  where
    "sequences (a # b # xs) =
      (if key a > key b then desc b [a] xs else asc b ((#) a) xs)"
  | "sequences [x] = [[x]]"
  | "sequences [] = []"
  | "asc a as (b # bs) =
      (if key a ≤ key b then asc b (λys. as (a # ys)) bs
      else as [a] # sequences (b # bs))"
  | "asc a as [] = [as [a]]"
  | "desc a as (b # bs) =
      (if key a > key b then desc b (a # as) bs
      else (a # as) # sequences (b # bs))"
  | "desc a as [] = [a # as]"

(* Merge two lists. The comparison is a strict ">", so on equal keys the
   head of the first list is emitted first; this left bias is what makes
   the overall sort stable. *)
fun merge :: "'a list ⇒ 'a list ⇒ 'a list"
  where
    "merge (a # as) (b # bs) =
      (if key a > key b then b # merge (a # as) bs else a # merge as (b # bs))"
  | "merge [] bs = bs"
  | "merge as [] = as"

(* Merge adjacent pairs of lists; a trailing odd list (and the empty list of
   lists) is returned unchanged by the catch-all equation. *)
fun merge_pairs :: "'a list list ⇒ 'a list list"
  where
    "merge_pairs (a # b # xs) = merge a b # merge_pairs xs"
  | "merge_pairs xs = xs"

(* merge neither drops nor duplicates elements: lengths add up. *)
lemma length_merge [simp]:
  "length (merge xs ys) = length xs + length ys"
  by (induct xs ys rule: merge.induct) simp_all

(* merge_pairs roughly halves the number of lists (rounding up). Note that
   this lemma is stated, as a [simp] rule, before merge_all is defined;
   presumably so that the automatic termination proof of merge_all can use
   it. TODO confirm. *)
lemma length_merge_pairs [simp]:
  "length (merge_pairs xs) = (1 + length xs) div 2"
  by (induct xs rule: merge_pairs.induct) simp_all

(* Repeatedly merge pairs of lists until at most one list is left. By
   length_merge_pairs, each round strictly shortens a list of two or more
   lists. *)
fun merge_all :: "'a list list ⇒ 'a list"
  where
    "merge_all [] = []"
  | "merge_all [x] = x"
  | "merge_all xs = merge_all (merge_pairs xs)"

(* Natural mergesort: split the input into sorted runs, then merge them. *)
fun msort_key :: "'a list ⇒ 'a list"
  where
    "msort_key xs = merge_all (sequences xs)"


subsection ‹The Functional Argument of @{const asc}›

text ‹‹f› is a function that only prepends a fixed prefix (namely ‹f []›) to a given list.›
(* ascP f holds iff f xs = f [] @ xs for every xs, i.e. f only prepends
   the fixed list f []. This is the invariant of asc's accumulator. *)
definition "ascP f = (∀xs. f xs = f [] @ xs)"

(* The initial accumulator (#) x, as passed to asc by sequences, satisfies ascP. *)
lemma ascP_Cons [simp]: "ascP ((#) x)" by (simp add: ascP_def)

(* Prepending the fixed list f [] @ [x] satisfies ascP, for any f. *)
lemma ascP_comp_append_Cons [simp]:
  "ascP (λxs. f [] @ x # xs)"
  by (auto simp: ascP_def)

(* Instance of the ascP invariant at a non-empty argument x # xs. *)
lemma ascP_f_Cons:
  assumes "ascP f"
  shows "f (x # xs) = f [] @ x # xs"
  using ‹ascP f› [unfolded ascP_def, THEN spec, of "x # xs"] .

(* ascP is preserved by the accumulator update performed in asc,
   namely f ↦ (λys. f (x # ys)). *)
lemma ascP_comp_Cons [simp]:
  assumes "ascP f"
  shows "ascP (λys. f (x # ys))"
proof (unfold ascP_def, intro allI)
  fix xs show "f (x # xs) = f [x] @ xs"
    using assms by (simp add: ascP_f_Cons)
qed

(* Singleton case of ascP_f_Cons. asc closes a run by applying its
   accumulator to [a]; this lemma is used in the mset and sortedness
   proofs below. *)
lemma ascP_f_singleton:
  assumes "ascP f"
  shows "f [x] = f [] @ [x]"
  by (rule ascP_f_Cons [OF assms])


subsection ‹Facts about Lengths›

(* Splitting into runs never yields more chunks than there are elements. *)
lemma
  shows length_sequences: "length (sequences xs) ≤ length xs"
    and length_asc: "ascP f ⟹ length (asc a f ys) ≤ 1 + length ys"
    and length_desc: "length (desc a xs ys) ≤ 1 + length ys"
  by (induct xs and a f ys and a xs ys rule: sequences_asc_desc.induct) auto

(* merge_pairs preserves the total number of elements across all lists. *)
lemma length_concat_merge_pairs [simp]:
  "length (concat (merge_pairs xss)) = length (concat xss)"
  by (induct xss rule: merge_pairs.induct) simp_all


subsection ‹Functional Correctness›

(* merge preserves the multiset of elements. *)
lemma mset_merge [simp]:
  "mset (merge xs ys) = mset xs + mset ys"
  by (induct xs ys rule: merge.induct) simp_all

(* Set version of mset_merge, obtained via set_mset_mset. *)
lemma set_merge [simp]:
  "set (merge xs ys) = set xs ∪ set ys"
  by (simp flip: set_mset_mset)

(* merge_pairs preserves the multiset of all elements. *)
lemma mset_concat_merge_pairs [simp]:
  "mset (concat (merge_pairs xs)) = mset (concat xs)"
  by (induct xs rule: merge_pairs.induct) auto

(* Set version of mset_concat_merge_pairs, obtained via set_mset_mset. *)
lemma set_concat_merge_pairs [simp]:
  "set (concat (merge_pairs xs)) = set (concat xs)"
  by (simp flip: set_mset_mset)

(* merge_all yields exactly the elements of all input lists, as a multiset. *)
lemma mset_merge_all [simp]:
  "mset (merge_all xs) = mset (concat xs)"
  by (induct xs rule: merge_all.induct) simp_all

(* Set version of mset_merge_all, obtained via set_mset_mset. *)
lemma set_merge_all [simp]:
  "set (merge_all xs) = set (concat xs)"
  by (simp flip: set_mset_mset)

(* Splitting into runs neither loses nor duplicates elements. The first
   fact keeps its historical (misspelled) name for backward compatibility;
   mset_sequences below is a correctly spelled alias. *)
lemma
  shows mset_seqeuences [simp]: "mset (concat (sequences xs)) = mset xs"
    and mset_asc: "ascP f ⟹ mset (concat (asc x f ys)) = {#x#} + mset (f []) + mset ys"
    and mset_desc: "mset (concat (desc x xs ys)) = {#x#} + mset xs + mset ys"
  by (induct xs and x f ys and x xs ys rule: sequences_asc_desc.induct)
    (auto simp: ascP_f_singleton)

lemmas mset_sequences = mset_seqeuences

(* msort_key returns a permutation of its input. The proof unfolds
   msort_key and uses the [simp] multiset lemmas above. *)
lemma mset_msort_key:
  "mset (msort_key xs) = mset xs"
  by (auto)

(* Merging two key-sorted lists yields a key-sorted list. *)
lemma sorted_merge [simp]:
  assumes "sorted (map key xs)" and "sorted (map key ys)"
  shows "sorted (map key (merge xs ys))"
  using assms by (induct xs ys rule: merge.induct) (auto)

(* If all lists are key-sorted, so are all lists produced by merge_pairs. *)
lemma sorted_merge_pairs [simp]:
  assumes "∀x∈set xs. sorted (map key x)"
  shows "∀x∈set (merge_pairs xs). sorted (map key x)"
  using assms by (induct xs rule: merge_pairs.induct) simp_all

(* Merging a list of key-sorted lists yields a key-sorted list. *)
lemma sorted_merge_all:
  assumes "∀x∈set xs. sorted (map key x)"
  shows "sorted (map key (merge_all xs))"
  using assms by (induct xs rule: merge_all.induct) simp_all

(* Every run produced by sequences is key-sorted. The invariants for asc and
   desc state that the run accumulated so far is sorted and that the
   current element fits at its end (asc) or at its front (desc). *)
lemma
  shows sorted_sequences: "∀x ∈ set (sequences xs). sorted (map key x)"
    and sorted_asc: "ascP f ⟹ sorted (map key (f [])) ⟹ ∀x∈set (f []). key x ≤ key a ⟹ ∀x∈set (asc a f ys). sorted (map key x)"
    and sorted_desc: "sorted (map key xs) ⟹ ∀x∈set xs. key a ≤ key x ⟹ ∀x∈set (desc a xs ys). sorted (map key x)"
  by (induct xs and a f ys and a xs ys rule: sequences_asc_desc.induct)
    (auto simp: ascP_f_singleton sorted_append not_less dest: order_trans, fastforce)

(* msort_key returns a key-sorted list: all runs are sorted, and merge_all
   preserves sortedness. *)
lemma sorted_msort_key:
  "sorted (map key (msort_key xs))"
  by (unfold msort_key.simps) (intro sorted_merge_all sorted_sequences)


subsection ‹Stability›

(* Splitting into runs keeps the relative order of elements with equal key:
   filtering the concatenated runs by any key k gives the same list as
   filtering the input. *)
lemma
  shows filter_by_key_sequences [simp]: "[y←concat (sequences xs). key y = k] = [y←xs. key y = k]"
    and filter_by_key_asc: "ascP f ⟹ [y←concat (asc a f ys). key y = k] = [y←f [a] @ ys. key y = k]"
    and filter_by_key_desc: "sorted (map key xs) ⟹ ∀x∈set xs. key a ≤ key x ⟹ [y←concat (desc a xs ys). key y = k] = [y←a # xs @ ys. key y = k]"
proof (induct xs and a f ys and a xs ys rule: sequences_asc_desc.induct)
  case (4 a f b bs)
  then show ?case
    by (auto simp: o_def ascP_f_Cons [where f = f])
next
  case (6 a as b bs)
  then show ?case
  proof (cases "key b < key a")
    case True
    with 6 have "[y←concat (desc b (a # as) bs). key y = k] = [y←b # (a # as) @ bs. key y = k]"
      by (auto simp: less_le order_trans)
    then show ?thesis
      using True and 6
      by (cases "key a = k", cases "key b = k")
        (auto simp: Cons_eq_append_conv intro!: filter_False)
  qed auto
qed auto
 
(* Stability of merge: if the left list is key-sorted, then for every key k
   its k-elements come out before those of the right list, each group in
   its original order. *)
lemma filter_by_key_merge_is_append [simp]:
  assumes "sorted (map key xs)"
  shows "[y←merge xs ys. key y = k] = [y←xs. key y = k] @ [y←ys. key y = k]"
  using assms
  by (induct xs ys rule: merge.induct) (auto simp: Cons_eq_append_conv leD intro!: filter_False)

(* Stability of merge_pairs on a list of key-sorted lists. *)
lemma filter_by_key_merge_pairs [simp]:
  assumes "∀xs∈set xss. sorted (map key xs)"
  shows "[y←concat (merge_pairs xss). key y = k] = [y←concat xss. key y = k]"
  using assms by (induct xss rule: merge_pairs.induct) simp_all

(* Stability of merge_all on a list of key-sorted lists. *)
lemma filter_by_key_merge_all [simp]:
  assumes "∀xs∈set xss. sorted (map key xs)"
  shows "[y←merge_all xss. key y = k] = [y←concat xss. key y = k]"
  using assms by (induct xss rule: merge_all.induct) simp_all

(* Combines the two stability results: runs are sorted (sorted_sequences),
   so merge_all over sequences keeps equal-key elements in input order. *)
lemma filter_by_key_merge_all_sequences [simp]:
  "[x←merge_all (sequences xs) . key x = k] = [x←xs . key x = k]"
  using sorted_sequences [of xs] by simp

(* msort_key is stable: for every key k, the elements with key k appear in
   the output in the same order as in the input. *)
lemma msort_key_stable:
  "[x←msort_key xs. key x = k] = [x←xs. key x = k]"
  by auto

(* msort_key coincides with the library function sort_key. The proof uses
   properties_for_sort_key with the three properties shown above:
   - permutation (mset_msort_key);
   - sortedness (sorted_msort_key);
   - stability (msort_key_stable). *)
lemma sort_key_msort_key_conv:
  "sort_key key = msort_key"
  using msort_key_stable [of "key x" for x]
  by (intro ext properties_for_sort_key mset_msort_key sorted_msort_key)
    (metis (mono_tags, lifting) filter_cong)

end

text ‹
  Replace the existing code equations for @{const sort_key} by @{const msort_key}.
›
(* Code setup: first drop the library's quicksort-based code equation for
   sort_key, then register sort_key_msort_key_conv, so that generated code
   for sort_key runs msort_key. *)
declare
  sort_key_by_quicksort_code [code del] and
  sort_key_msort_key_conv [code]

end