(* Theory Sorting_Algorithms *)

(*  Title:      HOL/Library/Sorting_Algorithms.thy
    Author:     Florian Haftmann, TU Muenchen
*)

theory Sorting_Algorithms
  imports Main Multiset Comparator
begin

section ‹Stably sorted lists›

text ‹
  The stable segment of ‹x› in a list consists of exactly those elements ‹y›
  with ‹compare cmp x y = Equiv›, kept in their original order (it is a plain
  ‹filter›).  In this theory, stability of sorting is expressed as: sorting
  leaves every stable segment unchanged.
›

abbreviation (input) stable_segment :: ‹'a comparator ⇒ 'a ⇒ 'a list ⇒ 'a list›
  where ‹stable_segment cmp x ≡ filter (λy. compare cmp x y = Equiv)›

text ‹
  A list is sorted with respect to ‹cmp› when no element compares ‹Greater›
  than its immediate successor.  Only adjacent pairs are checked here; the
  comparison against all later elements follows by transitivity
  (see ‹sorted_Cons_imp_not_less›).
›

fun sorted :: ‹'a comparator ⇒ 'a list ⇒ bool›
  where sorted_Nil: ‹sorted cmp [] ⟷ True›
  | sorted_single: ‹sorted cmp [x] ⟷ True›
  | sorted_rec: ‹sorted cmp (y # x # xs) ⟷ compare cmp y x ≠ Greater ∧ sorted cmp (x # xs)›

text ‹
  Introduction rule: prepending ‹x› to a sorted list keeps it sorted,
  provided ‹x› does not compare ‹Greater› than the old head (if there is one).
›

lemma sorted_ConsI:
  ‹sorted cmp (x # xs)› if ‹sorted cmp xs›
    and ‹⋀y ys. xs = y # ys ⟹ compare cmp x y ≠ Greater›
  using that by (cases xs) simp_all

text ‹The tail of a sorted list is sorted.›

lemma sorted_Cons_imp_sorted:
  ‹sorted cmp xs› if ‹sorted cmp (x # xs)›
  using that by (cases xs) simp_all

text ‹
  In a sorted list the head compares not ‹Greater› than every later element,
  not just its immediate successor; the proof chains the adjacent
  comparisons with ‹compare.trans_not_greater›.
›

lemma sorted_Cons_imp_not_less:
  ‹compare cmp y x ≠ Greater› if ‹sorted cmp (y # xs)›
    and ‹x ∈ set xs›
  using that by (induction xs arbitrary: y) (auto dest: compare.trans_not_greater)

text ‹
  Induction over sorted lists.  In the step case the tail ‹xs› is sorted,
  the property already holds for it, and the new head ‹x› compares not
  ‹Greater› than any element of ‹xs›.  Declared as the induction rule for the
  predicate ‹sorted›.
›

lemma sorted_induct [consumes 1, case_names Nil Cons, induct pred: sorted]:
  ‹P xs› if ‹sorted cmp xs› and ‹P []›
    and *: ‹⋀x xs. sorted cmp xs ⟹ P xs
      ⟹ (⋀y. y ∈ set xs ⟹ compare cmp x y ≠ Greater) ⟹ P (x # xs)›
using ‹sorted cmp xs› proof (induction xs)
  case Nil
  show ?case
    by (rule ‹P []›)
next
  case (Cons x xs)
  from ‹sorted cmp (x # xs)› have ‹sorted cmp xs›
    by (cases xs) simp_all
  moreover have ‹P xs› using ‹sorted cmp xs›
    by (rule Cons.IH)
  (* the head is below every tail element, not only the next one *)
  moreover have ‹compare cmp x y ≠ Greater› if ‹y ∈ set xs› for y
  using that ‹sorted cmp (x # xs)› proof (induction xs)
    case Nil
    then show ?case
      by simp
  next
    case (Cons z zs)
    then show ?case
    proof (cases zs)
      case Nil
      with Cons.prems show ?thesis
        by simp
    next
      case (Cons w ws)
      with Cons.prems have ‹compare cmp z w ≠ Greater› ‹compare cmp x z ≠ Greater›
        by auto
      then have ‹compare cmp x w ≠ Greater›
        by (auto dest: compare.trans_not_greater)
      with Cons show ?thesis
        using Cons.prems Cons.IH by auto
    qed
  qed
  ultimately show ?case
    by (rule *)
qed

text ‹
  Induction on a sorted list by removing a least element.  In the step case
  ‹x› occurs in ‹xs›, is the first element of its own stable segment, and
  compares not ‹Greater› than any element of ‹xs›.  The property is assumed
  for ‹remove1 x xs›.
›

lemma sorted_induct_remove1 [consumes 1, case_names Nil minimum]:
  ‹P xs› if ‹sorted cmp xs› and ‹P []›
    and *: ‹⋀x xs. sorted cmp xs ⟹ P (remove1 x xs)
      ⟹ x ∈ set xs ⟹ hd (stable_segment cmp x xs) = x ⟹ (⋀y. y ∈ set xs ⟹ compare cmp x y ≠ Greater)
    ⟹ P xs›
using ‹sorted cmp xs› proof (induction xs)
  case Nil
  show ?case
    by (rule ‹P []›)
next
  case (Cons x xs)
  then have ‹sorted cmp (x # xs)›
    by (simp add: sorted_ConsI)
  moreover note Cons.IH
  moreover have ‹⋀y. compare cmp x y = Greater ⟹ y ∈ set xs ⟹ False›
    using Cons.hyps by simp
  ultimately show ?case
    by (auto intro!: * [of ‹x # xs› x]) blast
qed

text ‹Removing one occurrence of an element preserves sortedness.›

lemma sorted_remove1:
  ‹sorted cmp (remove1 x xs)› if ‹sorted cmp xs›
proof (cases ‹x ∈ set xs›)
  case False
  with that show ?thesis
    by (simp add: remove1_idem)
next
  case True
  with that show ?thesis proof (induction xs)
    case Nil
    then show ?case
      by simp
  next
    case (Cons y ys)
    show ?case proof (cases ‹x = y›)
      case True
      with Cons.hyps show ?thesis
        by simp
    next
      case False
      then have ‹sorted cmp (remove1 x ys)›
        using Cons.IH Cons.prems by auto
      then have ‹sorted cmp (y # remove1 x ys)›
      proof (rule sorted_ConsI)
        fix z zs
        assume ‹remove1 x ys = z # zs›
        (* the new head comes from ys, which y already bounds from below *)
        with ‹x ≠ y› have ‹z ∈ set ys›
          using notin_set_remove1 [of z ys x] by auto
        then show ‹compare cmp y z ≠ Greater›
          by (rule Cons.hyps(2))
      qed
      with False show ?thesis
        by simp
    qed
  qed
qed

text ‹
  Every stable segment is sorted.  All its elements compare ‹Equiv› to ‹x›,
  and hence (by symmetry and transitivity of ‹Equiv›) to each other.
›

lemma sorted_stable_segment:
  ‹sorted cmp (stable_segment cmp x xs)›
proof (induction xs)
  case Nil
  show ?case
    by simp
next
  case (Cons y ys)
  then show ?case
    by (auto intro!: sorted_ConsI simp add: filter_eq_Cons_iff compare.sym)
      (auto dest: compare.trans_equiv simp add: compare.sym compare.greater_iff_sym_less)
qed

text ‹
  ‹insort cmp y xs› places ‹y› immediately before the first element ‹x› of
  ‹xs› with ‹compare cmp y x ≠ Greater›, or at the end if there is none.
  Because ‹Equiv› counts as "not ‹Greater›", ‹y› goes in front of any
  elements equivalent to it.
›

primrec insort :: ‹'a comparator ⇒ 'a ⇒ 'a list ⇒ 'a list›
  where ‹insort cmp y [] = [y]›
  | ‹insort cmp y (x # xs) = (if compare cmp y x ≠ Greater
       then y # x # xs
       else x # insort cmp y xs)›

text ‹‹insort› adds exactly one occurrence of the inserted element.›

lemma mset_insort [simp]:
  ‹mset (insort cmp x xs) = add_mset x (mset xs)›
  by (induction xs) simp_all

text ‹‹insort› grows the list by exactly one element.›

lemma length_insort [simp]:
  ‹length (insort cmp x xs) = Suc (length xs)›
  by (induction xs) simp_all

text ‹Inserting into a sorted list yields a sorted list.›

lemma sorted_insort:
  ‹sorted cmp (insort cmp x xs)› if ‹sorted cmp xs›
using that proof (induction xs)
  case Nil
  then show ?case
    by simp
next
  case (Cons y ys)
  then show ?case by (cases ys)
    (auto, simp_all add: compare.greater_iff_sym_less)
qed

text ‹
  Inserting ‹x› puts it at the front of every stable segment it belongs to.
  This is what makes the insertion-based ‹sort› below stable.
›

lemma stable_insort_equiv:
  ‹stable_segment cmp y (insort cmp x xs) = x # stable_segment cmp y xs›
    if ‹compare cmp y x = Equiv›
proof (induction xs)
  case Nil
  from that show ?case
    by simp
next
  case (Cons z xs)
  (* z is in y's segment exactly when it is in x's segment *)
  moreover from that have ‹compare cmp y z = Equiv ⟷ compare cmp z x = Equiv›
    by (auto intro: compare.trans_equiv simp add: compare.sym)
  ultimately show ?case
    using that by (auto simp add: compare.greater_iff_sym_less)
qed

text ‹Inserting ‹x› leaves stable segments it does not belong to unchanged.›

lemma stable_insort_not_equiv:
  ‹stable_segment cmp y (insort cmp x xs) = stable_segment cmp y xs›
    if ‹compare cmp y x ≠ Equiv›
  using that by (induction xs) simp_all

text ‹Removing the first occurrence of ‹x› undoes ‹insort cmp x›; no sortedness is required.›

lemma remove1_insort_same_eq [simp]:
  ‹remove1 x (insort cmp x xs) = xs›
  by (induction xs) simp_all

text ‹
  If ‹x› compares not ‹Greater› than every element of the sorted list ‹xs›,
  ‹insort› places it at the front.
›

lemma insort_eq_ConsI:
  ‹insort cmp x xs = x # xs›
    if ‹sorted cmp xs› ‹⋀y. y ∈ set xs ⟹ compare cmp x y ≠ Greater›
  using that by (induction xs) (simp_all add: compare.greater_iff_sym_less)

text ‹On sorted lists, inserting ‹x› commutes with removing a different element ‹y›.›

lemma remove1_insort_not_same_eq [simp]:
  ‹remove1 y (insort cmp x xs) = insort cmp x (remove1 y xs)›
    if ‹sorted cmp xs› ‹x ≠ y›
using that proof (induction xs)
  case Nil
  then show ?case
    by simp
next
  case (Cons z zs)
  show ?case
  proof (cases ‹compare cmp x z = Greater›)
    case True
    with Cons show ?thesis
      by simp
  next
    case False
    (* x stays below z, hence below everything after z *)
    then have ‹compare cmp x y ≠ Greater› if ‹y ∈ set zs› for y
      using that Cons.hyps
      by (auto dest: compare.trans_not_greater)
    with Cons show ?thesis
      by (simp add: insort_eq_ConsI)
  qed
qed

text ‹
  Removing ‹x› from a sorted list and inserting it again restores the list,
  provided ‹x› is the first element of its stable segment.  That position is
  exactly where ‹insort› puts it back.
›

lemma insort_remove1_same_eq:
  ‹insort cmp x (remove1 x xs) = xs›
    if ‹sorted cmp xs› and ‹x ∈ set xs› and ‹hd (stable_segment cmp x xs) = x›
using that proof (induction xs)
  case Nil
  then show ?case
    by simp
next
  case (Cons y ys)
  then have ‹compare cmp x y ≠ Less›
    by (auto simp add: compare.greater_iff_sym_less)
  then consider ‹compare cmp x y = Greater› | ‹compare cmp x y = Equiv›
    by (cases ‹compare cmp x y›) auto
  then show ?case proof cases
    case 1
    with Cons.prems Cons.IH show ?thesis
      by auto
  next
    case 2
    with Cons.prems have ‹x = y›
      by simp
    with Cons.hyps show ?thesis
      by (simp add: insort_eq_ConsI)
  qed
qed

text ‹
  An append is sorted iff both parts are sorted and no element of the first
  part compares ‹Greater› than any element of the second part.
›

lemma sorted_append_iff:
  ‹sorted cmp (xs @ ys) ⟷ sorted cmp xs ∧ sorted cmp ys
     ∧ (∀x ∈ set xs. ∀y ∈ set ys. compare cmp x y ≠ Greater)› (is ‹?P ⟷ ?R ∧ ?S ∧ ?Q›)
proof
  assume ?P
  have ?R
    using ‹?P› by (induction xs)
      (auto simp add: sorted_Cons_imp_not_less,
        auto simp add: sorted_Cons_imp_sorted intro: sorted_ConsI)
  moreover have ?S
    using ‹?P› by (induction xs) (auto dest: sorted_Cons_imp_sorted)
  moreover have ?Q
    using ‹?P› by (induction xs) (auto simp add: sorted_Cons_imp_not_less,
      simp add: sorted_Cons_imp_sorted)
  ultimately show ‹?R ∧ ?S ∧ ?Q›
    by simp
next
  assume ‹?R ∧ ?S ∧ ?Q›
  then have ?R ?S ?Q
    by simp_all
  then show ?P
    by (induction xs)
      (auto simp add: append_eq_Cons_conv intro!: sorted_ConsI)
qed

text ‹
  Reference sorting function: insertion sort, folding ‹insort› over the list
  from the right.  The alternative algorithms below are proven equal to it.
›

definition sort :: ‹'a comparator ⇒ 'a list ⇒ 'a list›
  where ‹sort cmp xs = foldr (insort cmp) xs []›

text ‹Recursive equations for ‹sort›, derived from its ‹foldr› definition.›

lemma sort_simps [simp]:
  ‹sort cmp [] = []›
  ‹sort cmp (x # xs) = insort cmp x (sort cmp xs)›
  by (simp_all add: sort_def)

text ‹‹sort› is a permutation: it preserves the multiset of elements.›

lemma mset_sort [simp]:
  ‹mset (sort cmp xs) = mset xs›
  by (induction xs) simp_all

text ‹‹sort› preserves length.›

lemma length_sort [simp]:
  ‹length (sort cmp xs) = length xs›
  by (induction xs) simp_all

text ‹The result of ‹sort› is sorted.›

lemma sorted_sort [simp]:
  ‹sorted cmp (sort cmp xs)›
  by (induction xs) (simp_all add: sorted_insort)

text ‹‹sort› is stable: it leaves every stable segment unchanged.›

lemma stable_sort:
  ‹stable_segment cmp x (sort cmp xs) = stable_segment cmp x xs›
  by (induction xs) (simp_all add: stable_insort_equiv stable_insort_not_equiv)

text ‹‹sort› commutes with removing one occurrence of an element.›

lemma sort_remove1_eq [simp]:
  ‹sort cmp (remove1 x xs) = remove1 x (sort cmp xs)›
  by (induction xs) simp_all

text ‹The element set after ‹insort›.›

lemma set_insort [simp]:
  ‹set (insort cmp x xs) = insert x (set xs)›
  by (induction xs) auto

text ‹‹sort› preserves the element set.›

lemma set_sort [simp]:
  ‹set (sort cmp xs) = set xs›
  by (induction xs) auto

text ‹
  Characterisation of ‹sort›: a list ‹xs› that
  ▪ is a permutation of ‹ys›,
  ▪ is sorted, and
  ▪ has the same stable segments as ‹ys› for all elements of ‹ys›
  is ‹sort cmp ys›.  This is the tool used below to justify the alternative
  algorithms.
›

lemma sort_eqI:
  ‹sort cmp ys = xs›
    if permutation: ‹mset ys = mset xs›
    and sorted: ‹sorted cmp xs›
    and stable: ‹⋀y. y ∈ set ys ⟹
      stable_segment cmp y ys = stable_segment cmp y xs›
proof -
  (* extend the stable-segment hypothesis from elements of ys to arbitrary y *)
  have stable': ‹stable_segment cmp y ys =
    stable_segment cmp y xs› for y
  proof (cases ‹∃x∈set ys. compare cmp y x = Equiv›)
    case True
    then obtain z where ‹z ∈ set ys› and ‹compare cmp y z = Equiv›
      by auto
    then have ‹compare cmp y x = Equiv ⟷ compare cmp z x = Equiv› for x
      by (meson compare.sym compare.trans_equiv)
    moreover have ‹stable_segment cmp z ys =
      stable_segment cmp z xs›
      using ‹z ∈ set ys› by (rule stable)
    ultimately show ?thesis
      by simp
  next
    case False
    (* nothing in ys (nor, by the permutation, in xs) is Equiv to y *)
    moreover from permutation have ‹set ys = set xs›
      by (rule mset_eq_setD)
    ultimately show ?thesis
      by simp
  qed
  show ?thesis
  using sorted permutation stable' proof (induction xs arbitrary: ys rule: sorted_induct_remove1)
    case Nil
    then show ?case
      by simp
  next
    case (minimum x xs)
    from ‹mset ys = mset xs› have ys: ‹set ys = set xs›
      by (rule mset_eq_setD)
    from minimum.prems have stable: ‹stable_segment cmp x ys = stable_segment cmp x xs›
      by simp
    have ‹sort cmp (remove1 x ys) = remove1 x xs›
      by (rule minimum.IH) (simp_all add: minimum.prems filter_remove1)
    then have ‹remove1 x (sort cmp ys) = remove1 x xs›
      by simp
    then have ‹insort cmp x (remove1 x (sort cmp ys)) =
      insort cmp x (remove1 x xs)›
      by simp
    also from minimum.hyps ys stable have ‹insort cmp x (remove1 x (sort cmp ys)) = sort cmp ys›
      by (simp add: stable_sort insort_remove1_same_eq)
    also from minimum.hyps have ‹insort cmp x (remove1 x xs) = xs›
      by (simp add: insort_remove1_same_eq)
    finally show ?case .
  qed
qed

text ‹On a sorted list, filtering commutes with inserting an element that satisfies the filter.›

lemma filter_insort:
  ‹filter P (insort cmp x xs) = insort cmp x (filter P xs)›
    if ‹sorted cmp xs› and ‹P x›
  using that by (induction xs)
    (auto simp add: compare.trans_not_greater insort_eq_ConsI)

text ‹Inserting an element that the filter rejects does not change the filtered list.›

lemma filter_insort_triv:
  ‹filter P (insort cmp x xs) = filter P xs›
    if ‹¬ P x›
  using that by (induction xs) simp_all

text ‹‹sort› commutes with ‹filter›.›

lemma filter_sort:
  ‹filter P (sort cmp xs) = sort cmp (filter P xs)›
  by (induction xs) (auto simp add: filter_insort filter_insort_triv)


section ‹Alternative sorting algorithms›

subsection ‹Quicksort›

text ‹
  ‹quicksort› is ‹sort› by definition.  Its efficient recursive behaviour is
  given by the code equation ‹quicksort_code› below.
›

definition quicksort :: ‹'a comparator ⇒ 'a list ⇒ 'a list›
  where quicksort_is_sort [simp]: ‹quicksort = sort›

text ‹Rewrite ‹sort› into ‹quicksort›, e.g. to use the quicksort code equations.›

lemma sort_by_quicksort:
  ‹sort = quicksort›
  by simp

text ‹
  One quicksort step around the middle element as pivot:
  ▪ the elements ‹Less› than the pivot are sorted;
  ▪ the pivot's stable segment is kept unchanged;
  ▪ the elements ‹Greater› than the pivot are sorted.
›

lemma sort_by_quicksort_rec:
  ‹sort cmp xs = sort cmp [x←xs. compare cmp x (xs ! (length xs div 2)) = Less]
    @ stable_segment cmp (xs ! (length xs div 2)) xs
    @ sort cmp [x←xs. compare cmp x (xs ! (length xs div 2)) = Greater]› (is ‹_ = ?rhs›)
proof (rule sort_eqI)
  show ‹mset xs = mset ?rhs›
    by (rule multiset_eqI) (auto simp add: compare.sym intro: comp.exhaust)
next
  show ‹sorted cmp ?rhs›
    by (auto simp add: sorted_append_iff sorted_stable_segment compare.equiv_subst_right dest: compare.trans_greater)
next
  let ?pivot = ‹xs ! (length xs div 2)›
  fix l
  (* elements Equiv to l relate to the pivot exactly as l does *)
  have ‹compare cmp x ?pivot = comp ∧ compare cmp l x = Equiv
    ⟷ compare cmp l ?pivot = comp ∧ compare cmp l x = Equiv› for x comp
  proof -
    have ‹compare cmp x ?pivot = comp ⟷ compare cmp l ?pivot = comp›
      if ‹compare cmp l x = Equiv›
      using that by (simp add: compare.equiv_subst_left compare.sym)
    then show ?thesis by blast
  qed
  then show ‹stable_segment cmp l xs = stable_segment cmp l ?rhs›
    by (simp add: stable_sort compare.sym [of _ ?pivot])
      (cases ‹compare cmp l ?pivot›, simp_all)
qed

context
begin

text ‹
  Split a list into the elements ‹Less› than, ‹Equiv› to, and ‹Greater› than
  ‹pivot›, each part keeping the original order.
›

qualified definition partition :: ‹'a comparator ⇒ 'a ⇒ 'a list ⇒ 'a list × 'a list × 'a list›
  where ‹partition cmp pivot xs =
    ([x ← xs. compare cmp x pivot = Less], stable_segment cmp pivot xs, [x ← xs. compare cmp x pivot = Greater])›

text ‹Single-pass code equations for ‹partition›.›

qualified lemma partition_code [code]:
  ‹partition cmp pivot [] = ([], [], [])›
  ‹partition cmp pivot (x # xs) =
    (let (lts, eqs, gts) = partition cmp pivot xs
     in case compare cmp x pivot of
       Less ⇒ (x # lts, eqs, gts)
     | Equiv ⇒ (lts, x # eqs, gts)
     | Greater ⇒ (lts, eqs, x # gts))›
  using comp.exhaust by (auto simp add: partition_def Let_def compare.sym [of _ pivot])

text ‹
  Code equation for ‹quicksort›:
  ▪ lists of length at most 2 are handled directly;
  ▪ longer lists are partitioned around their middle element, and the
    ‹Less› and ‹Greater› parts are sorted recursively.
›

lemma quicksort_code [code]:
  ‹quicksort cmp xs =
    (case xs of
      [] ⇒ []
    | [x] ⇒ xs
    | [x, y] ⇒ (if compare cmp x y ≠ Greater then xs else [y, x])
    | _ ⇒
        let (lts, eqs, gts) = partition cmp (xs ! (length xs div 2)) xs
        in quicksort cmp lts @ eqs @ quicksort cmp gts)›
proof (cases ‹length xs ≥ 3›)
  case False
  (* short lists: the explicit case branches apply *)
  then have ‹length xs ∈ {0, 1, 2}›
    by (auto simp add: not_le le_less less_antisym)
  then consider ‹xs = []› | x where ‹xs = [x]› | x y where ‹xs = [x, y]›
    by (auto simp add: length_Suc_conv numeral_2_eq_2)
  then show ?thesis
    by cases simp_all
next
  case True
  (* long lists: only the wildcard branch applies *)
  then obtain x y z zs where ‹xs = x # y # z # zs›
    by (metis le_0_eq length_0_conv length_Cons list.exhaust not_less_eq_eq numeral_3_eq_3)
  moreover have ‹quicksort cmp xs =
    (let (lts, eqs, gts) = partition cmp (xs ! (length xs div 2)) xs
    in quicksort cmp lts @ eqs @ quicksort cmp gts)›
    using sort_by_quicksort_rec [of cmp xs] by (simp add: partition_def)
  ultimately show ?thesis
    by simp
qed

end


subsection ‹Mergesort›

text ‹
  ‹mergesort› is ‹sort› by definition.  Its efficient recursive behaviour is
  given by the code equation ‹mergesort_code› below.
›

definition mergesort :: ‹'a comparator ⇒ 'a list ⇒ 'a list›
  where mergesort_is_sort [simp]: ‹mergesort = sort›

text ‹Rewrite ‹sort› into ‹mergesort›, e.g. to use the mergesort code equations.›

lemma sort_by_mergesort:
  ‹sort = mergesort›
  by simp

context
  fixes cmp :: ‹'a comparator›
begin

text ‹
  Merge two lists.  When the left head compares ‹Greater› than the right
  head, the right head is emitted first; otherwise the left head is.  So on
  ties the elements of the left list come first.
›

qualified function merge :: ‹'a list ⇒ 'a list ⇒ 'a list›
  where ‹merge [] ys = ys›
  | ‹merge xs [] = xs›
  | ‹merge (x # xs) (y # ys) = (if compare cmp x y = Greater
      then y # merge (x # xs) ys else x # merge xs (y # ys))›
  by pat_completeness auto

qualified termination by lexicographic_order

text ‹‹merge› is a permutation of the two inputs taken together.›

lemma mset_merge:
  ‹mset (merge xs ys) = mset xs + mset ys›
  by (induction xs ys rule: merge.induct) simp_all

text ‹The head of a ‹merge› result is the head of one of the inputs.›

lemma merge_eq_Cons_imp:
  ‹xs ≠ [] ∧ z = hd xs ∨ ys ≠ [] ∧ z = hd ys›
    if ‹merge xs ys = z # zs›
  using that by (induction xs ys rule: merge.induct) (auto split: if_splits)

text ‹On sorted inputs, ‹filter› commutes with ‹merge›.›

lemma filter_merge:
  ‹filter P (merge xs ys) = merge (filter P xs) (filter P ys)›
    if ‹sorted cmp xs› and ‹sorted cmp ys›
using that proof (induction xs ys rule: merge.induct)
  case (1 ys)
  then show ?case
    by simp
next
  case (2 xs)
  then show ?case
    by simp
next
  case (3 x xs y ys)
  show ?case
  proof (cases ‹compare cmp x y = Greater›)
    case True
    with 3 have hyp: ‹filter P (merge (x # xs) ys) =
      merge (filter P (x # xs)) (filter P ys)›
      by (simp add: sorted_Cons_imp_sorted)
    show ?thesis
    proof (cases ‹¬ P x ∧ P y›)
      case False
      with ‹compare cmp x y = Greater› show ?thesis
        by (auto simp add: hyp)
    next
      case True
      (* x is filtered away; its successors are still above y *)
      from ‹compare cmp x y = Greater› "3.prems"
      have *: ‹compare cmp z y = Greater› if ‹z ∈ set (filter P xs)› for z
        using that by (auto dest: compare.trans_not_greater sorted_Cons_imp_not_less)
      from ‹compare cmp x y = Greater› show ?thesis
        by (cases ‹filter P xs›) (simp_all add: hyp *)
    qed
  next
    case False
    with 3 have hyp: ‹filter P (merge xs (y # ys)) =
      merge (filter P xs) (filter P (y # ys))›
      by (simp add: sorted_Cons_imp_sorted)
    show ?thesis
    proof (cases ‹P x ∧ ¬ P y›)
      case False
      with ‹compare cmp x y ≠ Greater› show ?thesis
        by (auto simp add: hyp)
    next
      case True
      (* y is filtered away; x is still not above y's successors *)
      from ‹compare cmp x y ≠ Greater› "3.prems"
      have *: ‹compare cmp x z ≠ Greater› if ‹z ∈ set (filter P ys)› for z
        using that by (auto dest: compare.trans_not_greater sorted_Cons_imp_not_less)
      from ‹compare cmp x y ≠ Greater› show ?thesis
        by (cases ‹filter P ys›) (simp_all add: hyp *)
    qed
  qed
qed

text ‹Merging two sorted lists yields a sorted list.›

lemma sorted_merge:
  ‹sorted cmp (merge xs ys)› if ‹sorted cmp xs› and ‹sorted cmp ys›
using that proof (induction xs ys rule: merge.induct)
  case (1 ys)
  then show ?case
    by simp
next
  case (2 xs)
  then show ?case
    by simp
next
  case (3 x xs y ys)
  show ?case
  proof (cases ‹compare cmp x y = Greater›)
    case True
    with 3 have ‹sorted cmp (merge (x # xs) ys)›
      by (simp add: sorted_Cons_imp_sorted)
    then have ‹sorted cmp (y # merge (x # xs) ys)›
    proof (rule sorted_ConsI)
      fix z zs
      assume ‹merge (x # xs) ys = z # zs›
      with 3(4) True show ‹compare cmp y z ≠ Greater›
        by (clarsimp simp add: sorted_Cons_imp_sorted dest!: merge_eq_Cons_imp)
          (auto simp add: compare.asym_greater sorted_Cons_imp_not_less)
    qed
    with True show ?thesis
      by simp
  next
    case False
    with 3 have ‹sorted cmp (merge xs (y # ys))›
      by (simp add: sorted_Cons_imp_sorted)
    then have ‹sorted cmp (x # merge xs (y # ys))›
    proof (rule sorted_ConsI)
      fix z zs
      assume ‹merge xs (y # ys) = z # zs›
      with 3(3) False show ‹compare cmp x z ≠ Greater›
        by (clarsimp simp add: sorted_Cons_imp_sorted dest!: merge_eq_Cons_imp)
          (auto simp add: compare.asym_greater sorted_Cons_imp_not_less)
    qed
    with False show ?thesis
      by simp
  qed
qed

text ‹
  If no element of ‹xs› compares ‹Greater› than any element of ‹ys›,
  ‹merge› is plain append.
›

lemma merge_eq_appendI:
  ‹merge xs ys = xs @ ys›
    if ‹⋀x y. x ∈ set xs ⟹ y ∈ set ys ⟹ compare cmp x y ≠ Greater›
  using that by (induction xs ys rule: merge.induct) simp_all

text ‹Merging two stable segments of the same element is plain append.›

lemma merge_stable_segments:
  ‹merge (stable_segment cmp l xs) (stable_segment cmp l ys) =
     stable_segment cmp l xs @ stable_segment cmp l ys›
  by (rule merge_eq_appendI) (auto dest: compare.trans_equiv_greater)

text ‹
  One mergesort step: sort both halves (split at ‹length xs div 2›) and merge
  the results.
›

lemma sort_by_mergesort_rec:
  ‹sort cmp xs =
    merge (sort cmp (take (length xs div 2) xs))
      (sort cmp (drop (length xs div 2) xs))› (is ‹_ = ?rhs›)
proof (rule sort_eqI)
  have ‹mset (take (length xs div 2) xs) + mset (drop (length xs div 2) xs) =
    mset (take (length xs div 2) xs @ drop (length xs div 2) xs)›
    by (simp only: mset_append)
  then show ‹mset xs = mset ?rhs›
    by (simp add: mset_merge)
next
  show ‹sorted cmp ?rhs›
    by (simp add: sorted_merge)
next
  fix l
  have ‹merge (stable_segment cmp l (take (length xs div 2) xs))
    (stable_segment cmp l (drop (length xs div 2) xs)) =
    stable_segment cmp l (take (length xs div 2) xs) @ stable_segment cmp l (drop (length xs div 2) xs)›
    by (rule merge_stable_segments)
  also have ‹… = stable_segment cmp l xs›
    by (simp only: filter_append [symmetric] append_take_drop_id)
  finally show ‹stable_segment cmp l xs = stable_segment cmp l ?rhs›
    by (simp add: stable_sort filter_merge)
qed

text ‹
  Code equation for ‹mergesort›:
  ▪ lists of length at most 2 are handled directly;
  ▪ longer lists are split at ‹length xs div 2›, both halves are sorted
    recursively, and the results are merged.
›

lemma mergesort_code [code]:
  ‹mergesort cmp xs =
    (case xs of
      [] ⇒ []
    | [x] ⇒ xs
    | [x, y] ⇒ (if compare cmp x y ≠ Greater then xs else [y, x])
    | _ ⇒
        let
          half = length xs div 2;
          ys = take half xs;
          zs = drop half xs
        in merge (mergesort cmp ys) (mergesort cmp zs))›
proof (cases ‹length xs ≥ 3›)
  case False
  (* short lists: the explicit case branches apply *)
  then have ‹length xs ∈ {0, 1, 2}›
    by (auto simp add: not_le le_less less_antisym)
  then consider ‹xs = []› | x where ‹xs = [x]› | x y where ‹xs = [x, y]›
    by (auto simp add: length_Suc_conv numeral_2_eq_2)
  then show ?thesis
    by cases simp_all
next
  case True
  (* long lists: only the wildcard branch applies *)
  then obtain x y z zs where ‹xs = x # y # z # zs›
    by (metis le_0_eq length_0_conv length_Cons list.exhaust not_less_eq_eq numeral_3_eq_3)
  moreover have ‹mergesort cmp xs =
    (let
       half = length xs div 2;
       ys = take half xs;
       zs = drop half xs
     in merge (mergesort cmp ys) (mergesort cmp zs))›
    using sort_by_mergesort_rec [of xs] by (simp add: Let_def)
  ultimately show ?thesis
    by simp
qed

end


subsection ‹Lexicographic products›

text ‹A list sorted lexicographically on pairs is sorted by the first components.›

lemma sorted_prod_lex_imp_sorted_fst:
  ‹sorted (key fst cmp1) ps› if ‹sorted (prod_lex cmp1 cmp2) ps›
using that proof (induction rule: sorted_induct)
  case Nil
  then show ?case
    by simp
next
  case (Cons p ps)
  have ‹compare (key fst cmp1) p q ≠ Greater› if ‹ps = q # qs› for q qs
    using that Cons.hyps(2) [of q] by (simp add: compare_prod_lex_apply split: comp.splits)
  with Cons.IH show ?case
    by (rule sorted_ConsI) simp
qed

text ‹
  If all first components of a lexicographically sorted list of pairs are
  ‹Equiv› (to a common ‹a›), the list is also sorted by the second components.
›

lemma sorted_prod_lex_imp_sorted_snd:
  ‹sorted (key snd cmp2) ps› if ‹sorted (prod_lex cmp1 cmp2) ps› ‹⋀a' b'. (a', b') ∈ set ps ⟹ compare cmp1 a a' = Equiv›
using that proof (induction rule: sorted_induct)
  case Nil
  then show ?case
    by simp
next
  case (Cons p ps)
  then show ?case
    apply (cases p)
    apply (rule sorted_ConsI)
     apply (simp_all add: compare_prod_lex_apply)
     apply (auto cong del: comp.case_cong_weak)
    apply (metis comp.simps(8) compare.equiv_subst_left)
    done
qed

text ‹
  Because ‹sort› is stable, sorting by the second component and then by the
  first yields the same result as a single lexicographic sort on pairs.
›

lemma sort_comp_fst_snd_eq_sort_prod_lex:
  ‹sort (key fst cmp1) ∘ sort (key snd cmp2) = sort (prod_lex cmp1 cmp2)› (is ‹sort ?cmp1 ∘ sort ?cmp2 = sort ?cmp›)
proof
  fix ps :: ‹('a × 'b) list›
  have ‹sort ?cmp1 (sort ?cmp2 ps) = sort ?cmp ps›
  proof (rule sort_eqI)
    show ‹mset (sort ?cmp2 ps) = mset (sort ?cmp ps)›
      by simp
    show ‹sorted ?cmp1 (sort ?cmp ps)›
      by (rule sorted_prod_lex_imp_sorted_fst [of _ cmp2]) simp
  next
    fix p :: ‹'a × 'b›
    define a b where ab: ‹a = fst p› ‹b = snd p›
    moreover assume ‹p ∈ set (sort ?cmp2 ps)›
    ultimately have ‹(a, b) ∈ set (sort ?cmp2 ps)›
      by simp
    (* the pairs whose first component is Equiv to a *)
    let ?qs = ‹filter (λ(a', _). compare cmp1 a a' = Equiv) ps›
    have ‹sort ?cmp2 ?qs = sort ?cmp ?qs›
    proof (rule sort_eqI)
      show ‹mset ?qs = mset (sort ?cmp ?qs)›
        by simp
      show ‹sorted ?cmp2 (sort ?cmp ?qs)›
        by (rule sorted_prod_lex_imp_sorted_snd) auto
    next
      fix q :: ‹'a × 'b›
      define c d where ‹c = fst q› ‹d = snd q›
      moreover assume ‹q ∈ set ?qs›
      ultimately have ‹(c, d) ∈ set ?qs›
        by simp
      from sorted_stable_segment [of ?cmp ‹(a, d)› ps]
      have ‹sorted ?cmp (filter (λ(c, b). compare (prod_lex cmp1 cmp2) (a, d) (c, b) = Equiv) ps)›
        by (simp only: case_prod_unfold prod.collapse)
      also have ‹(λ(c, b). compare (prod_lex cmp1 cmp2) (a, d) (c, b) = Equiv) =
        (λ(c, b). compare cmp1 a c = Equiv ∧ compare cmp2 d b = Equiv)›
        by (simp add: fun_eq_iff compare_prod_lex_apply split: comp.split)
      finally have *: ‹sorted ?cmp (filter (λ(c, b). compare cmp1 a c = Equiv ∧ compare cmp2 d b = Equiv) ps)› .
      let ?rs = ‹filter (λ(_, d'). compare cmp2 d d' = Equiv) ?qs›
      have ‹sort ?cmp ?rs = ?rs›
        by (rule sort_eqI) (use * in ‹simp_all add: case_prod_unfold›)
      then show ‹filter (λr. compare ?cmp2 q r = Equiv) ?qs =
        filter (λr. compare ?cmp2 q r = Equiv) (sort ?cmp ?qs)›
        by (simp add: filter_sort case_prod_unfold flip: ‹d = snd q›)
    qed
    then show ‹filter (λq. compare ?cmp1 p q = Equiv) (sort ?cmp2 ps) =
      filter (λq. compare ?cmp1 p q = Equiv) (sort ?cmp ps)›
      by (simp add: filter_sort case_prod_unfold flip: ab)
  qed
  then show ‹(sort (key fst cmp1) ∘ sort (key snd cmp2)) ps = sort (prod_lex cmp1 cmp2) ps›
    by simp
qed

end