(* Theory Closest_Pair *)

section "Closest Pair Algorithm"

theory Closest_Pair
  imports Common
begin

text‹
  Formalization of a slightly optimized divide-and-conquer algorithm solving the Closest Pair Problem
  based on the presentation of Cormen \emph{et al.} \cite{Introduction-to-Algorithms:2009}.
›

subsection "Functional Correctness Proof"

subsubsection "Combine Step"

(* Step-counting (tm) version of find_closest. Each call is charged one step
   (compare t_find_closest); its value agrees with find_closest
   (find_closest_eq_val_find_closest_tm). *)
fun find_closest_tm :: "point ⇒ real ⇒ point list ⇒ point tm" where
  "find_closest_tm _ _ [] =1 return undefined"
| "find_closest_tm _ _ [p] =1 return p"
| "find_closest_tm p δ (p0 # ps) =1 (
    if δ ≤ snd p0 - snd p then
      return p0
    else
      do {
        p1 <- find_closest_tm p (min δ (dist p p0)) ps;
        if dist p p0 ≤ dist p p1 then
          return p0
        else
          return p1
      }
  )"

(* Searches ps for a point close to p. δ is the current distance bound.
   If the vertical offset snd p0 - snd p of the head p0 is at least δ,
   p0 is returned without scanning further. Otherwise the rest of the list is
   searched with the bound tightened to min δ (dist p p0), and the closer of p0
   and that result is returned. Correctness (find_closest_dist) assumes
   p # ps is sorted by snd. *)
fun find_closest :: "point ⇒ real ⇒ point list ⇒ point" where
  "find_closest _ _ [] = undefined"
| "find_closest _ _ [p] = p"
| "find_closest p δ (p0 # ps) = (
    if δ ≤ snd p0 - snd p then
      p0
    else
      let p1 = find_closest p (min δ (dist p p0)) ps in
      if dist p p0 ≤ dist p p1 then
        p0
      else
        p1
  )"

(* The value computed by the step-counting find_closest_tm equals find_closest. *)
lemma find_closest_eq_val_find_closest_tm:
  "val (find_closest_tm p δ ps) = find_closest p δ ps"
  by (induction p δ ps rule: find_closest.induct) (auto simp: Let_def)

(* On a non-empty list, find_closest returns an element of that list. *)
lemma find_closest_set:
  "0 < length ps ⟹ find_closest p δ ps ∈ set ps"
  by (induction p δ ps rule: find_closest.induct)
     (auto simp: Let_def)

(* If p # ps is sorted by snd and some point of ps lies strictly within δ of p,
   then the point returned by find_closest is at least as close to p as every
   point of ps. *)
lemma find_closest_dist:
  assumes "sorted_snd (p # ps)" "∃q ∈ set ps. dist p q < δ"
  shows "∀q ∈ set ps. dist p (find_closest p δ ps) ≤ dist p q"
  using assms
proof (induction p δ ps rule: find_closest.induct)
  case (3 p δ p0 p2 ps)
  let ?ps = "p0 # p2 # ps"
  define p1 where p1_def: "p1 = find_closest p (min δ (dist p p0)) (p2 # ps)"
  (* The scan cannot stop at p0: otherwise every point of ?ps would be at least δ away from p. *)
  have A: "¬ δ ≤ snd p0 - snd p"
  proof (rule ccontr)
    assume B: "¬ ¬ δ ≤ snd p0 - snd p"
    have "∀q ∈ set ?ps. snd p ≤ snd q"
      using sorted_snd_def "3.prems"(1) by simp
    moreover have "∀q ∈ set ?ps. δ ≤ snd q - snd p"
      using sorted_snd_def "3.prems"(1) B by auto
    ultimately have "∀q ∈ set ?ps. δ ≤ dist (snd p) (snd q)"
      using dist_real_def by simp
    hence "∀q ∈ set ?ps. δ ≤ dist p q"
      using dist_snd_le order_trans
      apply (auto split: prod.splits) by fastforce+
    thus False
      using "3.prems"(2) by fastforce
  qed
  show ?case
  proof cases
    assume "∃q ∈ set (p2 # ps). dist p q < min δ (dist p p0)"
    hence "∀q ∈ set (p2 # ps). dist p p1 ≤ dist p q"
      using "3.IH" "3.prems"(1) A p1_def sorted_snd_def by simp
    thus ?thesis
      using p1_def A by (auto split: prod.splits)
  next
    assume B: "¬ (∃q ∈ set (p2 # ps). dist p q < min δ (dist p p0))"
    hence "dist p p0 < δ"
      using "3.prems"(2) p1_def by auto
    hence C: "∀q ∈ set ?ps. dist p p0 ≤ dist p q"
      using p1_def B by auto
    have "p1 ∈ set (p2 # ps)"
      using p1_def find_closest_set by blast
    hence "dist p p0 ≤ dist p p1"
      using p1_def C by auto
    thus ?thesis
      using p1_def A C by (auto split: prod.splits)
  qed
qed auto

declare find_closest.simps [simp del]

(* Step-counting (tm) version of find_closest_pair; its value agrees with
   find_closest_pair (find_closest_pair_eq_val_find_closest_pair_tm). *)
fun find_closest_pair_tm :: "(point * point) ⇒ point list ⇒ (point × point) tm" where
  "find_closest_pair_tm (c0, c1) [] =1 return (c0, c1)"
| "find_closest_pair_tm (c0, c1) [_] =1 return (c0, c1)"
| "find_closest_pair_tm (c0, c1) (p0 # ps) =1 (
    do {
      p1 <- find_closest_tm p0 (dist c0 c1) ps;
      if dist c0 c1 ≤ dist p0 p1 then
        find_closest_pair_tm (c0, c1) ps
      else
        find_closest_pair_tm (p0, p1) ps
    }
  )"

(* Walks ps while keeping the best pair found so far, initially (c0, c1).
   For each head p0, find_closest searches the remaining points for a partner,
   bounded by the current best distance. The pair is replaced only if
   (p0, p1) is strictly closer than (c0, c1). *)
fun find_closest_pair :: "(point * point) ⇒ point list ⇒ (point × point)" where
  "find_closest_pair (c0, c1) [] = (c0, c1)"
| "find_closest_pair (c0, c1) [_] = (c0, c1)"
| "find_closest_pair (c0, c1) (p0 # ps) = (
    let p1 = find_closest p0 (dist c0 c1) ps in
    if dist c0 c1 ≤ dist p0 p1 then
      find_closest_pair (c0, c1) ps
    else
      find_closest_pair (p0, p1) ps
  )"

(* The value computed by the step-counting find_closest_pair_tm equals find_closest_pair. *)
lemma find_closest_pair_eq_val_find_closest_pair_tm:
  "val (find_closest_pair_tm (c0, c1) ps) = find_closest_pair (c0, c1) ps"
  by (induction "(c0, c1)" ps arbitrary: c0 c1 rule: find_closest_pair.induct)
     (auto simp: Let_def find_closest_eq_val_find_closest_tm)

(* The resulting pair either consists of points of ps or is the unchanged starting pair. *)
lemma find_closest_pair_set:
  assumes "(C0, C1) = find_closest_pair (c0, c1) ps"
  shows "(C0 ∈ set ps ∧ C1 ∈ set ps) ∨ (C0 = c0 ∧ C1 = c1)"
  using assms
proof (induction "(c0, c1)" ps arbitrary: c0 c1 C0 C1 rule: find_closest_pair.induct)
  case (3 c0 c1 p0 p2 ps)
  define p1 where p1_def: "p1 = find_closest p0 (dist c0 c1) (p2 # ps)"
  hence A: "p1 ∈ set (p2 # ps)"
    using find_closest_set by blast
  show ?case
  proof (cases "dist c0 c1 ≤ dist p0 p1")
    case True
    obtain C0' C1' where C'_def: "(C0', C1') = find_closest_pair (c0, c1) (p2 # ps)"
      using prod.collapse by blast
    note defs = p1_def C'_def
    hence "(C0' ∈ set (p2 # ps) ∧ C1' ∈ set (p2 # ps)) ∨ (C0' = c0 ∧ C1' = c1)"
      using "3.hyps"(1) True p1_def by blast
    moreover have "C0 = C0'" "C1 = C1'"
      using defs True "3.prems" by (auto split: prod.splits, metis Pair_inject)+
    ultimately show ?thesis
      by auto
  next
    case False
    obtain C0' C1' where C'_def: "(C0', C1') = find_closest_pair (p0, p1) (p2 # ps)"
      using prod.collapse by blast
    note defs = p1_def C'_def
    hence "(C0' ∈ set (p2 # ps) ∧ C1' ∈ set (p2 # ps)) ∨ (C0' = p0 ∧ C1' = p1)"
      using "3.hyps"(2) p1_def False by blast
    moreover have "C0 = C0'" "C1 = C1'"
      using defs False "3.prems" by (auto split: prod.splits, metis Pair_inject)+
    ultimately show ?thesis
      using A by auto
  qed
qed auto

(* For a distinct list and a starting pair with c0 ≠ c1, the resulting pair
   consists of two different points. *)
lemma find_closest_pair_c0_ne_c1:
  "c0 ≠ c1 ⟹ distinct ps ⟹ (C0, C1) = find_closest_pair (c0, c1) ps ⟹ C0 ≠ C1"
proof (induction "(c0, c1)" ps arbitrary: c0 c1 C0 C1 rule: find_closest_pair.induct)
  case (3 c0 c1 p0 p2 ps)
  define p1 where p1_def: "p1 = find_closest p0 (dist c0 c1) (p2 # ps)"
  hence A: "p0 ≠ p1"
    using "3.prems"(1,2)
    by (metis distinct.simps(2) find_closest_set length_pos_if_in_set list.set_intros(1))
  show ?case
  proof (cases "dist c0 c1 ≤ dist p0 p1")
    case True
    obtain C0' C1' where C'_def: "(C0', C1') = find_closest_pair (c0, c1) (p2 # ps)"
      using prod.collapse by blast
    note defs = p1_def C'_def
    hence "C0' ≠ C1'"
      using "3.hyps"(1) "3.prems"(1,2) True p1_def by simp
    moreover have "C0 = C0'" "C1 = C1'"
      using defs True "3.prems"(3) by (auto split: prod.splits, metis Pair_inject)+
    ultimately show ?thesis
      by simp
  next
    case False
    obtain C0' C1' where C'_def: "(C0', C1') = find_closest_pair (p0, p1) (p2 # ps)"
      using prod.collapse by blast
    note defs = p1_def C'_def
    hence "C0' ≠ C1'"
      using "3.hyps"(2) "3.prems"(2) A False p1_def by simp
    moreover have "C0 = C0'" "C1 = C1'"
      using defs False "3.prems"(3) by (auto split: prod.splits, metis Pair_inject)+
    ultimately show ?thesis
      by simp
  qed
qed auto

(* The distance of the resulting pair never exceeds that of the starting pair. *)
lemma find_closest_pair_dist_mono:
  assumes "(C0, C1) = find_closest_pair (c0, c1) ps"
  shows "dist C0 C1 ≤ dist c0 c1"
  using assms
proof (induction "(c0, c1)" ps arbitrary: c0 c1 C0 C1 rule: find_closest_pair.induct)
  case (3 c0 c1 p0 p2 ps)
  define p1 where p1_def: "p1 = find_closest p0 (dist c0 c1) (p2 # ps)"
  show ?case
  proof (cases "dist c0 c1 ≤ dist p0 p1")
    case True
    obtain C0' C1' where C'_def: "(C0', C1') = find_closest_pair (c0, c1) (p2 # ps)"
      using prod.collapse by blast
    note defs = p1_def C'_def
    hence "dist C0' C1' ≤ dist c0 c1"
      using "3.hyps"(1) True p1_def by simp
    moreover have "C0 = C0'" "C1 = C1'"
      using defs True "3.prems" by (auto split: prod.splits, metis Pair_inject)+
    ultimately show ?thesis
      by simp
  next
    case False
    obtain C0' C1' where C'_def: "(C0', C1') = find_closest_pair (p0, p1) (p2 # ps)"
      using prod.collapse by blast
    note defs = p1_def C'_def
    hence "dist C0' C1' ≤ dist p0 p1"
      using "3.hyps"(2) False p1_def by blast
    moreover have "C0 = C0'" "C1 = C1'"
      using defs False "3.prems"(1) by (auto split: prod.splits, metis Pair_inject)+
    ultimately show ?thesis
      using False by simp
  qed
qed auto

(* For a list sorted by snd, the distance of the resulting pair is a sparsity
   bound for the whole list. *)
lemma find_closest_pair_dist:
  assumes "sorted_snd ps" "(C0, C1) = find_closest_pair (c0, c1) ps"
  shows "sparse (dist C0 C1) (set ps)"
  using assms
proof (induction "(c0, c1)" ps arbitrary: c0 c1 C0 C1 rule: find_closest_pair.induct)
  case (3 c0 c1 p0 p2 ps)
  define p1 where p1_def: "p1 = find_closest p0 (dist c0 c1) (p2 # ps)"
  show ?case
  proof cases
    assume "∃p ∈ set (p2 # ps). dist p0 p < dist c0 c1"
    hence A: "∀p ∈ set (p2 # ps). dist p0 p1 ≤ dist p0 p" "dist p0 p1 < dist c0 c1"
      using p1_def find_closest_dist "3.prems"(1) le_less_trans by blast+
    obtain C0' C1' where C'_def: "(C0', C1') = find_closest_pair (p0, p1) (p2 # ps)"
      using prod.collapse by blast
    hence B: "(C0', C1') = find_closest_pair (c0, c1) (p0 # p2 # ps)"
      using A(2) p1_def by simp
    have "sparse (dist C0' C1') (set (p2 # ps))"
      using "3.hyps"(2)[of p1 C0' C1'] p1_def C'_def "3.prems"(1) A(2) sorted_snd_def by fastforce
    moreover have "dist C0' C1' ≤ dist p0 p1"
      using C'_def find_closest_pair_dist_mono by blast
    ultimately have "sparse (dist C0' C1') (set (p0 # p2 # ps))"
      using A sparse_identity order_trans by blast
    thus ?thesis
      using B by (metis "3.prems"(2) Pair_inject)
  next
    assume A: "¬ (∃p ∈ set (p2 # ps). dist p0 p < dist c0 c1)"
    hence B: "dist c0 c1 ≤ dist p0 p1"
      using find_closest_set[of "p2 # ps" p0 "dist c0 c1"] p1_def by auto
    obtain C0' C1' where C'_def: "(C0', C1') = find_closest_pair (c0, c1) (p2 # ps)"
      using prod.collapse by blast
    hence C: "(C0', C1') = find_closest_pair (c0, c1) (p0 # p2 # ps)"
      using B p1_def by simp
    have "sparse (dist C0' C1') (set (p2 # ps))"
      using "3.hyps"(1)[of p1 C0' C1'] p1_def C'_def B "3.prems" sorted_snd_def by simp
    moreover have "dist C0' C1' ≤ dist c0 c1"
      using C'_def find_closest_pair_dist_mono by blast
    ultimately have "sparse (dist C0' C1') (set (p0 # p2 # ps))"
      using A sparse_identity[of "dist C0' C1'" "p2 # ps" p0] order_trans by force
    thus ?thesis
      using C by (metis "3.prems"(2) Pair_inject)
  qed
qed (auto simp: sparse_def)

declare find_closest_pair.simps [simp del]

(* Step-counting (tm) version of combine; its value agrees with combine
   (combine_eq_val_combine_tm). *)
fun combine_tm :: "(point × point) ⇒ (point × point) ⇒ int ⇒ point list ⇒ (point × point) tm" where
  "combine_tm (p0L, p1L) (p0R, p1R) l ps =1 (
    let (c0, c1) = if dist p0L p1L < dist p0R p1R then (p0L, p1L) else (p0R, p1R) in
    do {
      ps' <- filter_tm (λp. dist p (l, snd p) < dist c0 c1) ps;
      find_closest_pair_tm (c0, c1) ps'
    }
  )"

(* Combine step of the divide-and-conquer algorithm:
   1. Pick the closer of the two sub-results (p0L, p1L) and (p0R, p1R).
   2. Keep only the points p of ps with dist p (l, snd p) below that distance.
   3. Run find_closest_pair over the remaining points, starting from that pair. *)
fun combine :: "(point × point) ⇒ (point × point) ⇒ int ⇒ point list ⇒ (point × point)" where
  "combine (p0L, p1L) (p0R, p1R) l ps = (
    let (c0, c1) = if dist p0L p1L < dist p0R p1R then (p0L, p1L) else (p0R, p1R) in
    let ps' = filter (λp. dist p (l, snd p) < dist c0 c1) ps in
    find_closest_pair (c0, c1) ps'
  )"

(* The value computed by the step-counting combine_tm equals combine. *)
lemma combine_eq_val_combine_tm:
  "val (combine_tm (p0L, p1L) (p0R, p1R) l ps) = combine (p0L, p1L) (p0R, p1R) l ps"
  by (auto simp: filter_eq_val_filter_tm find_closest_pair_eq_val_find_closest_pair_tm)

(* The pair returned by combine consists of points of ps, or is one of the two input pairs. *)
lemma combine_set:
  assumes "(c0, c1) = combine (p0L, p1L) (p0R, p1R) l ps"
  shows "(c0 ∈ set ps ∧ c1 ∈ set ps) ∨ (c0 = p0L ∧ c1 = p1L) ∨ (c0 = p0R ∧ c1 = p1R)"
proof -
  obtain C0' C1' where C'_def: "(C0', C1') = (if dist p0L p1L < dist p0R p1R then (p0L, p1L) else (p0R, p1R))"
    by metis
  define ps' where ps'_def: "ps' = filter (λp. dist p (l, snd p) < dist C0' C1') ps"
  obtain C0 C1 where C_def: "(C0, C1) = find_closest_pair (C0', C1') ps'"
    using prod.collapse by blast
  note defs = C'_def ps'_def C_def
  have "(C0 ∈ set ps' ∧ C1 ∈ set ps') ∨ (C0 = C0' ∧ C1 = C1')"
    using C_def find_closest_pair_set by blast+
  hence "(C0 ∈ set ps ∧ C1 ∈ set ps) ∨ (C0 = C0' ∧ C1 = C1')"
    using ps'_def by auto
  moreover have "C0 = c0" "C1 = c1"
    using assms defs apply (auto split: if_splits prod.splits) by (metis Pair_inject)+
  ultimately show ?thesis
    using C'_def by (auto split: if_splits)
qed

(* If both input pairs consist of different points and ps is distinct, the pair
   returned by combine consists of different points. *)
lemma combine_c0_ne_c1:
  assumes "p0L ≠ p1L" "p0R ≠ p1R" "distinct ps"
  assumes "(c0, c1) = combine (p0L, p1L) (p0R, p1R) l ps"
  shows "c0 ≠ c1"
proof -
  obtain C0' C1' where C'_def: "(C0', C1') = (if dist p0L p1L < dist p0R p1R then (p0L, p1L) else (p0R, p1R))"
    by metis
  define ps' where ps'_def: "ps' = filter (λp. dist p (l, snd p) < dist C0' C1') ps"
  obtain C0 C1 where C_def: "(C0, C1) = find_closest_pair (C0', C1') ps'"
    using prod.collapse by blast
  note defs = C'_def ps'_def C_def
  have "C0 ≠ C1"
    using defs find_closest_pair_c0_ne_c1[of C0' C1' ps'] assms by (auto split: if_splits)
  moreover have "C0 = c0" "C1 = c1"
    using assms(4) defs apply (auto split: if_splits prod.splits) by (metis Pair_inject)+
  ultimately show ?thesis
    by blast
qed

(* Correctness of the combine step. The preconditions are:
   - ps is sorted by snd and splits into psL (all fst ≤ l) and psR (all l ≤ fst);
   - each half is sparse with respect to its own closest-pair distance.
   Then the distance of the pair returned by combine is a sparsity bound for all of ps. *)
lemma combine_dist:
  assumes "sorted_snd ps" "set ps = psL ∪ psR"
  assumes "∀p ∈ psL. fst p ≤ l" "∀p ∈ psR. l ≤ fst p"
  assumes "sparse (dist p0L p1L) psL" "sparse (dist p0R p1R) psR"
  assumes "(c0, c1) = combine (p0L, p1L) (p0R, p1R) l ps"
  shows "sparse (dist c0 c1) (set ps)"
proof -
  obtain C0' C1' where C'_def: "(C0', C1') = (if dist p0L p1L < dist p0R p1R then (p0L, p1L) else (p0R, p1R))"
    by metis
  define ps' where ps'_def: "ps' = filter (λp. dist p (l, snd p) < dist C0' C1') ps"
  obtain C0 C1 where C_def: "(C0, C1) = find_closest_pair (C0', C1') ps'"
    using prod.collapse by blast
  note defs = C'_def ps'_def C_def
  have EQ: "C0 = c0" "C1 = c1"
    using defs assms(7) apply (auto split: if_splits prod.splits) by (metis Pair_inject)+
  have ps': "ps' = filter (λp. l - dist C0' C1' < fst p ∧ fst p < l + dist C0' C1') ps"
    using ps'_def dist_transform by simp
  have psL: "sparse (dist C0' C1') psL"
    using assms(3,5) C'_def sparse_def apply (auto split: if_splits) by force+
  have psR: "sparse (dist C0' C1') psR"
    using assms(4,6) C'_def sparse_def apply (auto split: if_splits) by force+
  have "sorted_snd ps'"
    using ps'_def assms(1) sorted_snd_def sorted_wrt_filter by blast
  hence *: "sparse (dist C0 C1) (set ps')"
    using find_closest_pair_dist C_def by simp
  (* Any pair closer than dist C0' C1' must lie inside the filtered band ps'. *)
  have "∀p0 ∈ set ps. ∀p1 ∈ set ps. p0 ≠ p1 ∧ dist p0 p1 < dist C0' C1' ⟶ p0 ∈ set ps' ∧ p1 ∈ set ps'"
    using set_band_filter ps' psL psR assms(2,3,4) by blast
  moreover have "dist C0 C1 ≤ dist C0' C1'"
    using C_def find_closest_pair_dist_mono by blast
  ultimately have "∀p0 ∈ set ps. ∀p1 ∈ set ps. p0 ≠ p1 ∧ dist p0 p1 < dist C0 C1 ⟶ p0 ∈ set ps' ∧ p1 ∈ set ps'"
    by simp
  hence "sparse (dist C0 C1) (set ps)"
    using sparse_def * by (meson not_less)
  thus ?thesis
    using EQ by blast
qed

declare combine.simps [simp del]
declare combine_tm.simps[simp del]

subsubsection "Divide and Conquer Algorithm"

declare split_at_take_drop_conv [simp add]

(* Step-counting (tm) version of closest_pair_rec. It returns the input sorted
   by snd together with a closest pair.
   - Inputs of length ≤ 3 use mergesort_tm and closest_pair_bf_tm.
   - Longer inputs are split at n div 2 and solved recursively; the results
     are merged by snd and combined with l = fst (hd xsR). *)
function closest_pair_rec_tm :: "point list ⇒ (point list × point × point) tm" where
  "closest_pair_rec_tm xs =1 (
    do {
      n <- length_tm xs;
      if n ≤ 3 then
        do {
          ys <- mergesort_tm snd xs;
          p <- closest_pair_bf_tm xs;
          return (ys, p)
        }
      else
        do {
          (xsL, xsR) <- split_at_tm (n div 2) xs;
          (ysL, p0L, p1L) <- closest_pair_rec_tm xsL;
          (ysR, p0R, p1R) <- closest_pair_rec_tm xsR;
          ys <- merge_tm snd ysL ysR;
          (p0, p1) <- combine_tm (p0L, p1L) (p0R, p1R) (fst (hd xsR)) ys;
          return (ys, p0, p1)
       }
    }
  )"
  by pat_completeness auto
termination closest_pair_rec_tm
  by (relation "Wellfounded.measure (λxs. length xs)")
     (auto simp add: length_eq_val_length_tm split_at_eq_val_split_at_tm)

(* Divide-and-conquer core. It returns xs sorted by snd together with a
   closest pair.
   - For inputs of length ≤ 3, sort with mergesort and use the brute-force
     closest_pair_bf.
   - Otherwise split at n div 2, recurse on both halves, merge the sorted
     halves by snd, and combine the two sub-results with l = fst (hd xsR). *)
function closest_pair_rec :: "point list ⇒ (point list * point * point)" where
  "closest_pair_rec xs = (
    let n = length xs in
    if n ≤ 3 then
      (mergesort snd xs, closest_pair_bf xs)
    else
      let (xsL, xsR) = split_at (n div 2) xs in
      let (ysL, p0L, p1L) = closest_pair_rec xsL in
      let (ysR, p0R, p1R) = closest_pair_rec xsR in
      let ys = merge snd ysL ysR in
      (ys, combine (p0L, p1L) (p0R, p1R) (fst (hd xsR)) ys)
  )"
  by pat_completeness auto
termination closest_pair_rec
  by (relation "Wellfounded.measure (λxs. length xs)")
     (auto simp: Let_def)

declare split_at_take_drop_conv [simp del]

(* Unfolding rule for the recursive case (length > 3). It is used after
   closest_pair_rec.simps is removed from the simp set below. *)
lemma closest_pair_rec_simps:
  assumes "n = length xs" "¬ (n ≤ 3)"
  shows "closest_pair_rec xs = (
    let (xsL, xsR) = split_at (n div 2) xs in
    let (ysL, p0L, p1L) = closest_pair_rec xsL in
    let (ysR, p0R, p1R) = closest_pair_rec xsR in
    let ys = merge snd ysL ysR in
    (ys, combine (p0L, p1L) (p0R, p1R) (fst (hd xsR)) ys)
  )"
  using assms by (auto simp: Let_def)

declare closest_pair_rec.simps [simp del]

(* The value computed by closest_pair_rec_tm equals closest_pair_rec.
   Proved by strong induction on the length of xs. *)
lemma closest_pair_rec_eq_val_closest_pair_rec_tm:
  "val (closest_pair_rec_tm xs) = closest_pair_rec xs"
proof (induction rule: length_induct)
  case (1 xs)
  define n where "n = length xs"
  obtain xsL xsR where xs_def: "(xsL, xsR) = split_at (n div 2) xs"
    by (metis surj_pair)
  note defs = n_def xs_def
  show ?case
  proof cases
    assume "n ≤ 3"
    then show ?thesis
      using defs
      by (auto simp: length_eq_val_length_tm mergesort_eq_val_mergesort_tm
                     closest_pair_bf_eq_val_closest_pair_bf_tm closest_pair_rec.simps)
  next
    assume asm: "¬ n ≤ 3"
    have "length xsL < length xs" "length xsR < length xs"
      using asm defs by (auto simp: split_at_take_drop_conv)
    hence "val (closest_pair_rec_tm xsL) = closest_pair_rec xsL"
          "val (closest_pair_rec_tm xsR) = closest_pair_rec xsR"
      using "1.IH" by blast+
    thus ?thesis
      using asm defs
      apply (subst closest_pair_rec.simps, subst closest_pair_rec_tm.simps)
      by (auto simp del: closest_pair_rec_tm.simps
               simp add: Let_def length_eq_val_length_tm merge_eq_val_merge_tm
                         split_at_eq_val_split_at_tm combine_eq_val_combine_tm
               split: prod.split)
  qed
qed

(* The list returned by closest_pair_rec has the same elements and length as
   the input and is sorted by snd. *)
lemma closest_pair_rec_set_length_sorted_snd:
  assumes "(ys, p) = closest_pair_rec xs"
  shows "set ys = set xs ∧ length ys = length xs ∧ sorted_snd ys"
  using assms
proof (induction xs arbitrary: ys p rule: length_induct)
  case (1 xs)
  let ?n = "length xs"
  show ?case
  proof (cases "?n ≤ 3")
    case True
    thus ?thesis using "1.prems" sorted_snd_def
      by (auto simp: mergesort closest_pair_rec.simps)
  next
    case False

    obtain XSL XSR where XSLR_def: "(XSL, XSR) = split_at (?n div 2) xs"
      using prod.collapse by blast
    define L where "L = fst (hd XSR)"
    obtain YSL PL where YSPL_def: "(YSL, PL) = closest_pair_rec XSL"
      using prod.collapse by blast
    obtain YSR PR where YSPR_def: "(YSR, PR) = closest_pair_rec XSR"
      using prod.collapse by blast
    define YS where "YS = merge (λp. snd p) YSL YSR"
    define P where "P = combine PL PR L YS"
    note defs = XSLR_def L_def YSPL_def YSPR_def YS_def P_def

    have "length XSL < length xs" "length XSR < length xs"
      using False defs by (auto simp: split_at_take_drop_conv)
    hence IH: "set XSL = set YSL" "set XSR = set YSR"
              "length XSL = length YSL" "length XSR = length YSR"
              "sorted_snd YSL" "sorted_snd YSR"
      using "1.IH" defs by metis+

    have "set xs = set XSL ∪ set XSR"
      using defs by (auto simp: set_take_drop split_at_take_drop_conv)
    hence SET: "set xs = set YS"
      using set_merge IH(1,2) defs by fast

    have "length xs = length XSL + length XSR"
      using defs by (auto simp: split_at_take_drop_conv)
    hence LENGTH: "length xs = length YS"
      using IH(3,4) length_merge defs by metis

    have SORTED: "sorted_snd YS"
      using IH(5,6) by (simp add: defs sorted_snd_def sorted_merge)

    have "(YS, P) = closest_pair_rec xs"
      using False closest_pair_rec_simps defs by (auto simp: Let_def split: prod.split)
    hence "(ys, p) = (YS, P)"
      using "1.prems" by argo
    thus ?thesis
      using SET LENGTH SORTED by simp
  qed
qed

(* If the input is distinct, so is the sorted list returned by closest_pair_rec. *)
lemma closest_pair_rec_distinct:
  assumes "distinct xs" "(ys, p) = closest_pair_rec xs"
  shows "distinct ys"
  using assms
proof (induction xs arbitrary: ys p rule: length_induct)
  case (1 xs)
  let ?n = "length xs"
  show ?case
  proof (cases "?n ≤ 3")
    case True
    thus ?thesis using "1.prems"
      by (auto simp: mergesort closest_pair_rec.simps)
  next
    case False

    obtain XSL XSR where XSLR_def: "(XSL, XSR) = split_at (?n div 2) xs"
      using prod.collapse by blast
    define L where "L = fst (hd XSR)"
    obtain YSL PL where YSPL_def: "(YSL, PL) = closest_pair_rec XSL"
      using prod.collapse by blast
    obtain YSR PR where YSPR_def: "(YSR, PR) = closest_pair_rec XSR"
      using prod.collapse by blast
    define YS where "YS = merge (λp. snd p) YSL YSR"
    define P where "P = combine PL PR L YS"
    note defs = XSLR_def L_def YSPL_def YSPR_def YS_def P_def

    have "length XSL < length xs" "length XSR < length xs"
      using False defs by (auto simp: split_at_take_drop_conv)
    moreover have "distinct XSL" "distinct XSR"
      using "1.prems"(1) defs by (auto simp: split_at_take_drop_conv)
    ultimately have IH: "distinct YSL" "distinct YSR"
      using "1.IH" defs by blast+

    have "set XSL ∩ set XSR = {}"
      using "1.prems"(1) defs by (auto simp: split_at_take_drop_conv set_take_disj_set_drop_if_distinct)
    moreover have "set XSL = set YSL" "set XSR = set YSR"
      using closest_pair_rec_set_length_sorted_snd defs by blast+
    ultimately have "set YSL ∩ set YSR = {}"
      by blast
    hence DISTINCT: "distinct YS"
      using distinct_merge IH defs by blast

    have "(YS, P) = closest_pair_rec xs"
      using False closest_pair_rec_simps defs by (auto simp: Let_def split: prod.split)
    hence "(ys, p) = (YS, P)"
      using "1.prems" by argo
    thus ?thesis
      using DISTINCT by blast
  qed
qed

(* For a distinct input with at least two points, closest_pair_rec returns two
   different points of the input. *)
lemma closest_pair_rec_c0_c1:
  assumes "1 < length xs" "distinct xs" "(ys, c0, c1) = closest_pair_rec xs"
  shows "c0 ∈ set xs ∧ c1 ∈ set xs ∧ c0 ≠ c1"
  using assms
proof (induction xs arbitrary: ys c0 c1 rule: length_induct)
  case (1 xs)
  let ?n = "length xs"
  show ?case
  proof (cases "?n ≤ 3")
    case True
    hence "(c0, c1) = closest_pair_bf xs"
      using "1.prems"(3) closest_pair_rec.simps by simp
    thus ?thesis
      using "1.prems"(1,2) closest_pair_bf_c0_c1 by simp
  next
    case False

    obtain XSL XSR where XSLR_def: "(XSL, XSR) = split_at (?n div 2) xs"
      using prod.collapse by blast
    define L where "L = fst (hd XSR)"

    obtain YSL C0L C1L where YSC01L_def: "(YSL, C0L, C1L) = closest_pair_rec XSL"
      using prod.collapse by metis
    obtain YSR C0R C1R where YSC01R_def: "(YSR, C0R, C1R) = closest_pair_rec XSR"
      using prod.collapse by metis

    define YS where "YS = merge (λp. snd p) YSL YSR"
    obtain C0 C1 where C01_def: "(C0, C1) = combine (C0L, C1L) (C0R, C1R) L YS"
      using prod.collapse by metis
    note defs = XSLR_def L_def YSC01L_def YSC01R_def YS_def C01_def

    have "1 < length XSL" "length XSL < length xs" "distinct XSL"
      using False "1.prems"(2) defs by (auto simp: split_at_take_drop_conv)
    hence "C0L ∈ set XSL" "C1L ∈ set XSL" and IHL1: "C0L ≠ C1L"
      using "1.IH" defs by metis+
    hence IHL2: "C0L ∈ set xs" "C1L ∈ set xs"
      using split_at_take_drop_conv in_set_takeD fst_conv defs by metis+

    have "1 < length XSR" "length XSR < length xs" "distinct XSR"
      using False "1.prems"(2) defs by (auto simp: split_at_take_drop_conv)
    hence "C0R ∈ set XSR" "C1R ∈ set XSR" and IHR1: "C0R ≠ C1R"
      using "1.IH" defs by metis+
    hence IHR2: "C0R ∈ set xs" "C1R ∈ set xs"
      using split_at_take_drop_conv in_set_dropD snd_conv defs by metis+

    have *: "(YS, C0, C1) = closest_pair_rec xs"
      using False closest_pair_rec_simps defs by (auto simp: Let_def split: prod.split)
    have YS: "set xs = set YS" "distinct YS"
      using "1.prems"(2) closest_pair_rec_set_length_sorted_snd closest_pair_rec_distinct * by blast+

    have "C0 ∈ set xs" "C1 ∈ set xs"
      using combine_set IHL2 IHR2 YS defs by blast+
    moreover have "C0 ≠ C1"
      using combine_c0_ne_c1 IHL1(1) IHR1(1) YS defs by blast
    ultimately show ?thesis
      using "1.prems"(3) * by (metis Pair_inject)
  qed
qed

(* For an input sorted by fst with at least two points, the distance of the
   pair returned by closest_pair_rec is a sparsity bound for the input. *)
lemma closest_pair_rec_dist:
  assumes "1 < length xs" "sorted_fst xs" "(ys, c0, c1) = closest_pair_rec xs"
  shows "sparse (dist c0 c1) (set xs)"
  using assms
proof (induction xs arbitrary: ys c0 c1 rule: length_induct)
  case (1 xs)
  let ?n = "length xs"
  show ?case
  proof (cases "?n ≤ 3")
    case True
    hence "(c0, c1) = closest_pair_bf xs"
      using "1.prems"(3) closest_pair_rec.simps by simp
    thus ?thesis
      using "1.prems"(1,3) closest_pair_bf_dist by metis
  next
    case False

    obtain XSL XSR where XSLR_def: "(XSL, XSR) = split_at (?n div 2) xs"
      using prod.collapse by blast
    define L where "L = fst (hd XSR)"

    obtain YSL C0L C1L where YSC01L_def: "(YSL, C0L, C1L) = closest_pair_rec XSL"
      using prod.collapse by metis
    obtain YSR C0R C1R where YSC01R_def: "(YSR, C0R, C1R) = closest_pair_rec XSR"
      using prod.collapse by metis

    define YS where "YS = merge (λp. snd p) YSL YSR"
    obtain C0 C1 where C01_def: "(C0, C1) = combine (C0L, C1L) (C0R, C1R) L YS"
      using prod.collapse by metis
    note defs = XSLR_def L_def YSC01L_def YSC01R_def YS_def C01_def

    have XSLR: "XSL = take (?n div 2) xs" "XSR = drop (?n div 2) xs"
      using defs by (auto simp: split_at_take_drop_conv)

    have "1 < length XSL" "length XSL < length xs"
      using False XSLR by simp_all
    moreover have "sorted_fst XSL"
      using "1.prems"(2) XSLR by (auto simp: sorted_fst_def sorted_wrt_take)
    ultimately have L: "sparse (dist C0L C1L) (set XSL)"
                       "set XSL = set YSL"
      using 1 closest_pair_rec_set_length_sorted_snd closest_pair_rec_c0_c1
              YSC01L_def by blast+
    hence IHL: "sparse (dist C0L C1L) (set YSL)"
      by argo

    have "1 < length XSR" "length XSR < length xs"
      using False XSLR by simp_all
    moreover have "sorted_fst XSR"
      using "1.prems"(2) XSLR by (auto simp: sorted_fst_def sorted_wrt_drop)
    ultimately have R: "sparse (dist C0R C1R) (set XSR)"
                       "set XSR = set YSR"
      using 1 closest_pair_rec_set_length_sorted_snd closest_pair_rec_c0_c1
              YSC01R_def by blast+
    hence IHR: "sparse (dist C0R C1R) (set YSR)"
      by argo

    have *: "(YS, C0, C1) = closest_pair_rec xs"
      using False closest_pair_rec_simps defs by (auto simp: Let_def split: prod.split)

    have "set xs = set YS" "sorted_snd YS"
      using "1.prems"(2) closest_pair_rec_set_length_sorted_snd closest_pair_rec_distinct * by blast+
    moreover have "∀p ∈ set YSL. fst p ≤ L"
      using False "1.prems"(2) XSLR L_def L(2) sorted_fst_take_less_hd_drop by simp
    moreover have "∀p ∈ set YSR. L ≤ fst p"
      using False "1.prems"(2) XSLR L_def R(2) sorted_fst_hd_drop_less_drop by simp
    moreover have "set YS = set YSL ∪ set YSR"
      using set_merge defs by fast
    moreover have "(C0, C1) = combine (C0L, C1L) (C0R, C1R) L YS"
      by (auto simp add: defs)
    ultimately have "sparse (dist C0 C1) (set xs)"
      using combine_dist IHL IHR by auto
    moreover have "(YS, C0, C1) = (ys, c0, c1)"
      using "1.prems"(3) * by simp
    ultimately show ?thesis
      by blast
  qed
qed

(* Step-counting (tm) version of closest_pair: sort by fst, then run
   closest_pair_rec_tm and keep only the pair. Undefined for fewer than two points. *)
fun closest_pair_tm :: "point list ⇒ (point * point) tm" where
  "closest_pair_tm [] =1 return undefined"
| "closest_pair_tm [_] =1 return undefined"
| "closest_pair_tm ps =1 (
    do {
      xs <- mergesort_tm fst ps;
      (_, p) <- closest_pair_rec_tm xs;
      return p
    }
  )"

(* Top-level algorithm: sort the points by fst, run closest_pair_rec and return
   its pair. Undefined for fewer than two points. *)
fun closest_pair :: "point list ⇒ (point * point)" where
  "closest_pair [] = undefined"
| "closest_pair [_] = undefined"
| "closest_pair ps = (let (_, p) = closest_pair_rec (mergesort fst ps) in p)"

(* The value computed by the step-counting closest_pair_tm equals closest_pair. *)
lemma closest_pair_eq_val_closest_pair_tm:
  "val (closest_pair_tm ps) = closest_pair ps"
  by (induction ps rule: induct_list012)
     (auto simp del: closest_pair_rec_tm.simps mergesort_tm.simps
           simp add: closest_pair_rec_eq_val_closest_pair_rec_tm mergesort_eq_val_mergesort_tm
           split: prod.split)

(* Unfolding rule for closest_pair on inputs with at least two points. It is
   used after closest_pair.simps is removed from the simp set below. *)
lemma closest_pair_simps:
  "1 < length ps ⟹ closest_pair ps = (let (_, p) = closest_pair_rec (mergesort fst ps) in p)"
  by (induction ps rule: induct_list012) auto

declare closest_pair.simps [simp del]

(* Main correctness theorem, membership part: for a distinct input with at least
   two points, closest_pair returns two different points of the input. *)
theorem closest_pair_c0_c1:
  assumes "1 < length ps" "distinct ps" "(c0, c1) = closest_pair ps"
  shows "c0 ∈ set ps" "c1 ∈ set ps" "c0 ≠ c1"
  using assms closest_pair_rec_c0_c1[of "mergesort fst ps"]
  by (auto simp: closest_pair_simps mergesort split: prod.splits)

(* Main correctness theorem, distance part: for at least two input points, the
   distance of the returned pair is a sparsity bound for the input. Presumably
   this means no two distinct input points are closer; see sparse_def in Common. *)
theorem closest_pair_dist:
  assumes "1 < length ps" "(c0, c1) = closest_pair ps"
  shows "sparse (dist c0 c1) (set ps)"
  using assms sorted_fst_def closest_pair_rec_dist[of "mergesort fst ps"] closest_pair_rec_c0_c1[of "mergesort fst ps"]
  by (auto simp: closest_pair_simps mergesort split: prod.splits)


subsection "Time Complexity Proof"

subsubsection "Core Argument"

(* Core packing argument. Assume p # ps is distinct, sorted by snd, and lies in
   the open band (l - δ, l + δ). Assume it splits into a δ-sparse left part
   psL (fst ≤ l) and a δ-sparse right part psR (l ≤ fst). Then at most 7 points
   of ps lie vertically within δ above p. Proof idea: those points, together
   with p, fit into two δ×δ squares, each holding at most 4 points. *)
lemma core_argument:
  fixes δ :: real and p :: point and ps :: "point list"
  assumes "distinct (p # ps)" "sorted_snd (p # ps)" "0 ≤ δ" "set (p # ps) = psL ∪ psR"
  assumes "∀q ∈ set (p # ps). l - δ < fst q ∧ fst q < l + δ"
  assumes "∀q ∈ psL. fst q ≤ l" "∀q ∈ psR. l ≤ fst q"
  assumes "sparse δ psL" "sparse δ psR"
  shows "length (filter (λq. snd q - snd p ≤ δ) ps) ≤ 7"
proof -
  define PS where "PS = p # ps"
  define R where "R = cbox (l - δ, snd p) (l + δ, snd p + δ)"
  define RPS where "RPS = { p ∈ set PS. p ∈ R }"
  define LSQ where "LSQ = cbox (l - δ, snd p) (l, snd p + δ)"
  define LSQPS where "LSQPS = { p ∈ psL. p ∈ LSQ }"
  define RSQ where "RSQ = cbox (l, snd p) (l + δ, snd p + δ)"
  define RSQPS where "RSQPS = { p ∈ psR. p ∈ RSQ }"
  note defs = PS_def R_def RPS_def LSQ_def LSQPS_def RSQ_def RSQPS_def

  have "R = LSQ ∪ RSQ"
    using defs cbox_right_un by auto
  moreover have "∀p ∈ psL. p ∈ RSQ ⟶ p ∈ LSQ"
    using RSQ_def LSQ_def assms(6) by auto
  moreover have "∀p ∈ psR. p ∈ LSQ ⟶ p ∈ RSQ"
    using RSQ_def LSQ_def assms(7) by auto
  ultimately have "RPS = LSQPS ∪ RSQPS"
    using LSQPS_def RSQPS_def PS_def RPS_def assms(4) by blast

  have "sparse δ LSQPS"
    using assms(8) LSQPS_def sparse_def by simp
  hence CLSQPS: "card LSQPS ≤ 4"
    using max_points_square[of LSQPS "l - δ" "snd p" δ] assms(3) LSQ_def LSQPS_def by auto

  have "sparse δ RSQPS"
    using assms(9) RSQPS_def sparse_def by simp
  hence CRSQPS: "card RSQPS ≤ 4"
    using max_points_square[of RSQPS l "snd p" δ] assms(3) RSQ_def RSQPS_def by auto

  have CRPS: "card RPS ≤ 8"
    using CLSQPS CRSQPS card_Un_le[of LSQPS RSQPS] ‹RPS = LSQPS ∪ RSQPS› by auto

  have "set (p # filter (λq. snd q - snd p ≤ δ) ps) ⊆ RPS"
  proof standard
    fix q
    assume *: "q ∈ set (p # filter (λq. snd q - snd p ≤ δ) ps)"
    hence CPS: "q ∈ set PS"
      using PS_def by auto
    hence "snd p ≤ snd q" "snd q ≤ snd p + δ"
      using assms(2,3) PS_def sorted_snd_def * by (auto split: if_splits)
    moreover have "l - δ < fst q" "fst q < l + δ"
      using CPS assms(5) PS_def by blast+
    ultimately have "q ∈ R"
      using R_def mem_cbox_2D[of "l - δ" "fst q" "l + δ" "snd p" "snd q" "snd p + δ"]
      by (simp add: prod.case_eq_if)
    thus "q ∈ RPS"
      using CPS RPS_def by simp
  qed
  moreover have "finite RPS"
    by (simp add: RPS_def)
  ultimately have "card (set (p # filter (λq. snd q - snd p ≤ δ) ps)) ≤ 8"
    using CRPS card_mono[of RPS "set (p # filter (λq. snd q - snd p ≤ δ) ps)"] by simp
  moreover have "distinct (p # filter (λq. snd q - snd p ≤ δ) ps)"
    using assms(1) by simp
  ultimately have "length (p # filter (λq. snd q - snd p ≤ δ) ps) ≤ 8"
    using assms(1) PS_def distinct_card by metis
  thus ?thesis
    by simp
qed

subsubsection "Combine Step"

(* Running-time function for find_closest. It charges 1 per call and mirrors the
   recursion of find_closest: stop once δ ≤ snd p0 - snd p, otherwise recurse
   with the tightened bound min δ (dist p p0). *)
fun t_find_closest :: "point ⇒ real ⇒ point list ⇒ nat" where
  "t_find_closest _ _ [] = 1"
| "t_find_closest _ _ [_] = 1"
| "t_find_closest p δ (p0 # ps) = 1 + (
    if δ ≤ snd p0 - snd p then 0
    else t_find_closest p (min δ (dist p p0)) ps
  )"

(* t_find_closest counts exactly the steps recorded by the tm version find_closest_tm. *)
lemma t_find_closest_eq_time_find_closest_tm:
  "t_find_closest p δ ps = time (find_closest_tm p δ ps)"
  by (induction p δ ps rule: t_find_closest.induct)
     (auto simp: time_simps)

(* Shrinking the distance bound never increases the cost of find_closest. *)
lemma t_find_closest_mono:
  "δ'  δ  t_find_closest p δ' ps  t_find_closest p δ ps"
  by (induction rule: t_find_closest.induct)
     (auto simp: Let_def min_def)

(*
  The cost of find_closest is at most 1 plus the number of points q in ps
  with snd q - snd p <= delta. The bound is counted against the ORIGINAL
  delta; the tightened bounds used in recursive calls are relaxed back to it.
*)
lemma t_find_closest_cnt:
  "t_find_closest p δ ps  1 + length (filter (λq. snd q - snd p  δ) ps)"
proof (induction p δ ps rule: t_find_closest.induct)
  case (3 p δ p0 p2 ps)
  show ?case
  proof (cases "δ  snd p0 - snd p")
    case True
    thus ?thesis
      by simp
  next
    case False
    hence *: "snd p0 - snd p  δ"
      by simp
    (* Unfold one step, apply the IH, relax min delta (dist p p0) back to delta, and count p0 itself. *)
    have "t_find_closest p δ (p0 # p2 # ps) = 1 + t_find_closest p (min δ (dist p p0)) (p2 # ps)"
      using False by simp
    also have "...  1 + 1 + length (filter (λq. snd q - snd p  min δ (dist p p0)) (p2 # ps))"
      using False 3 by simp
    also have "...  1 + 1 + length (filter (λq. snd q - snd p  δ) (p2 # ps))"
      using * by (meson add_le_cancel_left length_filter_P_impl_Q min.bounded_iff)
    also have "...  1 + length (filter (λq. snd q - snd p  δ) (p0 # p2 # ps))"
      using False by simp
    (* The also-chain composes to exactly the goal, so close it with finally. *)
    finally show ?thesis
      by simp
  qed
qed auto

(*
  Constant cost bound for a single find_closest scan. The hypotheses are
  those of core_argument:
  - p and the remaining points lie in the open strip l - delta < x < l + delta;
  - the points left of l (psL) and right of l (psR) are each delta-sparse.
  Combined with t_find_closest_cnt, this limits the cost to at most 8.
*)
corollary t_find_closest_bound:
  fixes δ :: real and p :: point and ps :: "point list" and l :: int
  assumes "distinct (p # ps)" "sorted_snd (p # ps)" "0  δ" "set (p # ps) = psL  psR"
  assumes "p'  set (p # ps). l - δ < fst p'  fst p' < l + δ"
  assumes "p  psL. fst p  l" "p  psR. l  fst p"
  assumes "sparse δ psL" "sparse δ psR"
  shows "t_find_closest p δ ps  8"
  using assms core_argument[of p ps δ psL psR l] t_find_closest_cnt[of p δ ps] by linarith

(*
  Hand-written cost function for find_closest_pair_tm. Each step does three things:
  - it pays 1;
  - it pays for one find_closest scan from p0, using the current best distance dist c0 c1;
  - it continues on the tail with the current pair, or with (p0, p1) when that pair is strictly closer.
*)
fun t_find_closest_pair :: "(point * point)  point list  nat" where
  "t_find_closest_pair _ [] = 1"
| "t_find_closest_pair _ [_] = 1"
| "t_find_closest_pair (c0, c1) (p0 # ps) = 1 + (
    let p1 = find_closest p0 (dist c0 c1) ps in
    t_find_closest p0 (dist c0 c1) ps + (
    if dist c0 c1  dist p0 p1 then
      t_find_closest_pair (c0, c1) ps
    else
      t_find_closest_pair (p0, p1) ps
  ))"

(* The cost function t_find_closest_pair agrees exactly with the time counted by find_closest_pair_tm. *)
lemma t_find_closest_pair_eq_time_find_closest_pair_tm:
  "t_find_closest_pair (c0, c1) ps = time (find_closest_pair_tm (c0, c1) ps)"
  by (induction "(c0, c1)" ps arbitrary: c0 c1 rule: t_find_closest_pair.induct)
     (auto simp: time_simps find_closest_eq_val_find_closest_tm t_find_closest_eq_time_find_closest_tm)

(*
  Linear cost bound for the strip search.

  Hypotheses:
  - the current best distance dist c0 c1 is at most Delta;
  - all points lie in the open strip l - Delta < x < l + Delta;
  - the points left of l (psL) and right of l (psR) are each Delta-sparse.

  Argument: each step pays 1 plus at most 8 for its find_closest scan. The
  recursive call keeps the invariant, because the pair is only replaced by a
  strictly closer one. Together this gives 9 * length ps + 1.
*)
lemma t_find_closest_pair_bound:
  assumes "distinct ps" "sorted_snd ps" "δ = dist c0 c1" "set ps = psL  psR"
  assumes "p  set ps. l - Δ < fst p  fst p < l + Δ"
  assumes "p  psL. fst p  l" "p  psR. l  fst p"
  assumes "sparse Δ psL" "sparse Δ psR" "δ  Δ"
  shows "t_find_closest_pair (c0, c1) ps  9 * length ps + 1"
  using assms
proof (induction "(c0, c1)" ps arbitrary: δ c0 c1 psL psR rule: t_find_closest_pair.induct)
  case (3 c0 c1 p0 p2 ps)
  let ?ps = "p2 # ps"
  define p1 where p1_def: "p1 = find_closest p0 (dist c0 c1) ?ps"
  define PSL where PSL_def: "PSL = psL - { p0 }"
  define PSR where PSR_def: "PSR = psR - { p0 }"
  note defs = p1_def PSL_def PSR_def
  (* Bound the single find_closest scan from p0: first at Delta, then at the smaller dist c0 c1. *)
  have *: "0  Δ"
    using "3.prems"(3,10) zero_le_dist[of c0 c1] by argo
  hence "t_find_closest p0 Δ ?ps  8"
    using t_find_closest_bound[of p0 ?ps Δ psL psR] "3.prems" by blast
  hence A: "t_find_closest p0 (dist c0 c1) ?ps  8"
    by (metis "3.prems"(3,10) order_trans t_find_closest_mono)
  (* The tail, with p0 removed from both halves, still satisfies the hypotheses. *)
  have B: "distinct ?ps" "sorted_snd ?ps"
    using "3.prems"(1,2) sorted_snd_def by simp_all
  have C: "set ?ps = PSL  PSR"
    using defs "3.prems"(1,4) by auto
  have D: "p  set ?ps. l - Δ < fst p  fst p < l + Δ"
    using "3.prems"(5) by simp
  have E: "p  PSL. fst p  l" "p  PSR. l  fst p"
    using defs "3.prems"(6,7) by simp_all
  have F: "sparse Δ PSL" "sparse Δ PSR"
    using defs "3.prems"(8,9) sparse_def by simp_all
  (* The case split mirrors the two branches of t_find_closest_pair. *)
  show ?case
  proof (cases "dist c0 c1  dist p0 p1")
    case True
    hence "t_find_closest_pair (c0, c1) ?ps  9 * length ?ps + 1"
      using "3.hyps"(1) "3.prems"(3,10) defs(1) B C D E F by blast
    moreover have "t_find_closest_pair (c0, c1) (p0 # ?ps) =
                   1 + t_find_closest p0 (dist c0 c1) ?ps + t_find_closest_pair (c0, c1) ?ps"
      using defs True by (auto split: prod.splits)
    ultimately show ?thesis
      using A by auto
  next
    case False
    moreover have "0  dist p0 p1"
      by auto
    ultimately have "t_find_closest_pair (p0, p1) ?ps  9 * length ?ps + 1"
      using "3.hyps"(2) "3.prems"(3,10) defs(1) B C D E F by auto
    moreover have "t_find_closest_pair (c0, c1) (p0 # ?ps) =
                   1 + t_find_closest p0 (dist c0 c1) ?ps + t_find_closest_pair (p0, p1) ?ps"
      using defs False by (auto split: prod.splits)
    ultimately show ?thesis
      using A by simp
  qed
qed auto

(*
  Hand-written cost function for combine_tm. It pays:
  - 1;
  - the cost of filtering ps down to the points whose distance to the
    vertical line x = l is below dist c0 c1, where (c0, c1) is the closer of
    the two input pairs;
  - the cost of find_closest_pair on that strip.
*)
fun t_combine :: "(point * point)  (point * point)  int  point list  nat" where
  "t_combine (p0L, p1L) (p0R, p1R) l ps = 1 + (
    let (c0, c1) = if dist p0L p1L < dist p0R p1R then (p0L, p1L) else (p0R, p1R) in
    let ps' = filter (λp. dist p (l, snd p) < dist c0 c1) ps in
    time (filter_tm (λp. dist p (l, snd p) < dist c0 c1) ps) + t_find_closest_pair (c0, c1) ps'
  )"

(* The cost function t_combine agrees exactly with the time counted by combine_tm. *)
lemma t_combine_eq_time_combine_tm:
  "t_combine (p0L, p1L) (p0R, p1R) l ps = time (combine_tm (p0L, p1L) (p0R, p1R) l ps)"
  by (auto simp: combine_tm.simps time_simps t_find_closest_pair_eq_time_find_closest_pair_tm filter_eq_val_filter_tm)

(*
  Linear cost bound for the combine step.

  Hypotheses:
  - ps is distinct and sorted by y;
  - ps is split by the line x = l into psL and psR;
  - each side is sparse for the distance of its own closest pair.

  Cost: filtering costs 1 + length ps, and the strip search costs at most
  9 * length ps + 1 (t_find_closest_pair_bound). The total is at most
  10 * length ps + 3.
*)
lemma t_combine_bound:
  fixes ps :: "point list"
  assumes "distinct ps" "sorted_snd ps" "set ps = psL  psR"
  assumes "p  psL. fst p  l" "p  psR. l  fst p"
  assumes "sparse (dist p0L p1L) psL" "sparse (dist p0R p1R) psR"
  shows "t_combine (p0L, p1L) (p0R, p1R) l ps  10 * length ps + 3"
proof -
  obtain c0 c1 where c_def:
    "(c0, c1) = (if dist p0L p1L < dist p0R p1R then (p0L, p1L) else (p0R, p1R))" by metis
  let ?P = "(λp. dist p (l, snd p) < dist c0 c1)"
  define ps' where ps'_def: "ps' = filter ?P ps"
  define psL' where psL'_def: "psL' = { p  psL. ?P p }"
  define psR' where psR'_def: "psR' = { p  psR. ?P p }"
  note defs = c_def ps'_def psL'_def psR'_def
  (* Establish the hypotheses of t_find_closest_pair_bound for the strip ps', with Delta := dist c0 c1. *)
  have "sparse (dist c0 c1) psL" "sparse (dist c0 c1) psR"
    using assms(6,7) sparse_mono c_def by (auto split: if_splits)
  hence "sparse (dist c0 c1) psL'" "sparse (dist c0 c1) psR'"
    using psL'_def psR'_def sparse_def by auto
  moreover have "distinct ps'"
    using ps'_def assms(1) by simp
  moreover have "sorted_snd ps'"
    using ps'_def assms(2) sorted_snd_def sorted_wrt_filter by blast
  moreover have "0  dist c0 c1"
    by simp
  moreover have "set ps' = psL'  psR'"
    using assms(3) defs(2,3,4) filter_Un by auto
  moreover have "p  set ps'. l - dist c0 c1 < fst p  fst p < l + dist c0 c1"
    using ps'_def dist_transform by force
  moreover have "p  psL'. fst p  l" "p  psR'. l  fst p"
    using assms(4,5) psL'_def psR'_def by blast+
  ultimately have "t_find_closest_pair (c0, c1) ps'  9 * length ps' + 1"
    using t_find_closest_pair_bound by blast
  moreover have "length ps'  length ps"
    using ps'_def by simp
  ultimately have *: "t_find_closest_pair (c0, c1) ps'  9 * length ps + 1"
    by simp
  (* Unfold t_combine and use that filtering costs 1 + length ps. *)
  have "t_combine (p0L, p1L) (p0R, p1R) l ps =
        1 + time (filter_tm ?P ps) + t_find_closest_pair (c0, c1) ps'"
    using defs by (auto split: prod.splits)
  also have "... = 2 + length ps + t_find_closest_pair (c0, c1) ps'"
    using time_filter_tm by auto
  finally show ?thesis
    using * by simp
qed

(* From here on t_combine is no longer a simp rule and is only unfolded explicitly. *)
declare t_combine.simps [simp del]

subsubsection "Divide and Conquer Algorithm"

(* Cost of closest_pair_rec_tm in the base case (at most 3 points): length, mergesort by y, and brute force. *)
lemma time_closest_pair_rec_tm_simps_1:
  assumes "length xs  3"
  shows "time (closest_pair_rec_tm xs) = 1 + time (length_tm xs) + time (mergesort_tm snd xs) + time (closest_pair_bf_tm xs)"
  using assms by  (auto simp: time_simps length_eq_val_length_tm)

(*
  Cost of closest_pair_rec_tm in the recursive case (more than 3 points).
  It is the sum of:
  - length and split_at;
  - both recursive calls;
  - the merge by y;
  - the combine step, measured by t_combine around the line x = fst (hd xsR).
*)
lemma time_closest_pair_rec_tm_simps_2:
  assumes "¬ (length xs  3)"
  shows "time (closest_pair_rec_tm xs) = 1 + (
    let (xsL, xsR) = val (split_at_tm (length xs div 2) xs) in
    let (ysL, pL) = val (closest_pair_rec_tm xsL) in
    let (ysR, pR) = val (closest_pair_rec_tm xsR) in
    let ys = val (merge_tm (λp. snd p) ysL ysR) in
    time (length_tm xs) + time (split_at_tm (length xs div 2) xs) + time (closest_pair_rec_tm xsL) +
    time (closest_pair_rec_tm xsR) + time (merge_tm (λp. snd p) ysL ysR) + t_combine pL pR (fst (hd xsR)) ys
  )"
  using assms
  apply (subst closest_pair_rec_tm.simps)
  by (auto simp del: closest_pair_rec_tm.simps
           simp add: time_simps length_eq_val_length_tm t_combine_eq_time_combine_tm
              split: prod.split)

(*
  Recurrence bounding the cost of closest_pair_rec_tm.
  - Up to 3 points it covers length, mergesort, and the brute-force search.
  - Above 3 points it adds a linear overhead 7 + 13 * n to two recursive
    calls, one per half (floor and ceiling of n / 2).
*)
function closest_pair_recurrence :: "nat  real" where
  "n  3  closest_pair_recurrence n = 3 + n + mergesort_recurrence n + n * n"
| "3 < n  closest_pair_recurrence n = 7 + 13 * n +
    closest_pair_recurrence (nat real n / 2) + closest_pair_recurrence (nat real n / 2)"
  by force simp_all
(* Termination is discharged by the Akra-Bazzi termination method. *)
termination by akra_bazzi_termination simp_all

(* The recurrence is non-negative; needed to relate it to the natural-number cost. *)
lemma closest_pair_recurrence_nonneg[simp]:
  "0  closest_pair_recurrence n"
  by (induction n rule: closest_pair_recurrence.induct) auto

(*
  For distinct, x-sorted input, the measured cost of closest_pair_rec_tm is
  bounded by closest_pair_recurrence of the input length. The proof is by
  strong induction on the length, following the two cases of the algorithm.
*)
lemma time_closest_pair_rec_conv_closest_pair_recurrence:
  assumes "distinct ps" "sorted_fst ps"
  shows "time (closest_pair_rec_tm ps)  closest_pair_recurrence (length ps)"
  using assms
proof (induction ps rule: length_induct)
  case (1 ps)
  let ?n = "length ps"
  show ?case
  proof (cases "?n  3")
    case True
    hence "time (closest_pair_rec_tm ps) = 1 + time (length_tm ps) + time (mergesort_tm snd ps) + time (closest_pair_bf_tm ps)"
      using time_closest_pair_rec_tm_simps_1 by simp
    moreover have "closest_pair_recurrence ?n = 3 + ?n + mergesort_recurrence ?n + ?n * ?n"
      using True by simp
    moreover have "time (length_tm ps)  1 + ?n" "time (mergesort_tm snd ps)  mergesort_recurrence ?n"
                  "time (closest_pair_bf_tm ps)  1 + ?n * ?n"
      using time_length_tm[of ps] time_mergesort_conv_mergesort_recurrence[of snd ps] time_closest_pair_bf_tm[of ps] by auto
    ultimately show ?thesis
      by linarith
  next
    case False

    (* Name the intermediate results of one recursive step. *)
    obtain XSL XSR where XS_def: "(XSL, XSR) = val (split_at_tm (?n div 2) ps)"
      using prod.collapse by blast
    obtain YSL C0L C1L where CPL_def: "(YSL, C0L, C1L) = val (closest_pair_rec_tm XSL)"
      using prod.collapse by metis
    obtain YSR C0R C1R where CPR_def: "(YSR, C0R, C1R) = val (closest_pair_rec_tm XSR)"
      using prod.collapse by metis
    define YS where "YS = val (merge_tm (λp. snd p) YSL YSR)"
    obtain C0 C1 where C01_def: "(C0, C1) = val (combine_tm (C0L, C1L) (C0R, C1R) (fst (hd XSR)) YS)"
      using prod.collapse by metis
    note defs = XS_def CPL_def CPR_def YS_def C01_def

    (* The halves have lengths n div 2 and n - n div 2, i.e. the two recursion arguments. *)
    have XSLR: "XSL = take (?n div 2) ps" "XSR = drop (?n div 2) ps"
      using defs by (auto simp: split_at_take_drop_conv split_at_eq_val_split_at_tm)
    hence "length XSL = ?n div 2" "length XSR = ?n - ?n div 2"
      by simp_all
    hence *: "(nat real ?n / 2) = length XSL" "(nat real ?n / 2) = length XSR"
      by linarith+
    have "length XSL = length YSL" "length XSR = length YSR"
      using defs closest_pair_rec_set_length_sorted_snd closest_pair_rec_eq_val_closest_pair_rec_tm by metis+
    hence L: "?n = length YSL + length YSR"
      using defs XSLR by fastforce

    (* Induction hypothesis for the left half. *)
    have "1 < length XSL" "length XSL < length ps"
      using False XSLR by simp_all
    moreover have "distinct XSL" "sorted_fst XSL"
      using XSLR "1.prems"(1,2) sorted_fst_def sorted_wrt_take by simp_all
    ultimately have "time (closest_pair_rec_tm XSL)  closest_pair_recurrence (length XSL)"
      using "1.IH" by simp
    hence IHL: "time (closest_pair_rec_tm XSL)  closest_pair_recurrence (nat real ?n / 2)"
      using * by simp

    (* Induction hypothesis for the right half. *)
    have "1 < length XSR" "length XSR < length ps"
      using False XSLR by simp_all
    moreover have "distinct XSR" "sorted_fst XSR"
      using XSLR "1.prems"(1,2) sorted_fst_def sorted_wrt_drop by simp_all
    ultimately have "time (closest_pair_rec_tm XSR)  closest_pair_recurrence (length XSR)"
      using "1.IH" by simp
    hence IHR: "time (closest_pair_rec_tm XSR)  closest_pair_recurrence (nat real ?n / 2)"
      using * by simp

    (* Establish the hypotheses of t_combine_bound for the merged list YS. *)
    have "(YS, C0, C1) = val (closest_pair_rec_tm ps)"
      using False closest_pair_rec_simps defs by (auto simp: Let_def length_eq_val_length_tm split!: prod.split)
    hence "set ps = set YS" "length ps = length YS" "distinct YS" "sorted_snd YS"
      using "1.prems" closest_pair_rec_set_length_sorted_snd closest_pair_rec_distinct
            closest_pair_rec_eq_val_closest_pair_rec_tm by auto
    moreover have "p  set YSL. fst p  fst (hd XSR)"
      using False "1.prems"(2) XSLR length XSL < length ps length XSL = length ps div 2
            CPL_def sorted_fst_take_less_hd_drop closest_pair_rec_set_length_sorted_snd
            closest_pair_rec_eq_val_closest_pair_rec_tm by metis
    moreover have "p  set YSR. fst (hd XSR)  fst p"
      using False "1.prems"(2) XSLR CPR_def sorted_fst_hd_drop_less_drop
            closest_pair_rec_set_length_sorted_snd closest_pair_rec_eq_val_closest_pair_rec_tm by metis
    moreover have "set YS = set YSL  set YSR"
      using set_merge defs by (metis merge_eq_val_merge_tm)
    moreover have "sparse (dist C0L C1L) (set YSL)"
      using CPL_def 1 < length XSL distinct XSL sorted_fst XSL
            closest_pair_rec_dist closest_pair_rec_set_length_sorted_snd
            closest_pair_rec_eq_val_closest_pair_rec_tm by auto
    moreover have "sparse (dist C0R C1R) (set YSR)"
      using CPR_def 1 < length XSR distinct XSR sorted_fst XSR
            closest_pair_rec_dist closest_pair_rec_set_length_sorted_snd
            closest_pair_rec_eq_val_closest_pair_rec_tm by auto
    ultimately have combine_bound: "t_combine (C0L, C1L) (C0R, C1R) (fst (hd XSR)) YS  3 + 10 * ?n"
      using t_combine_bound[of YS "set YSL" "set YSR" "fst (hd XSR)"] by (simp add: add.commute)
    (* Unfold the cost, bound the non-recursive overhead by 7 + 13 * n, then apply both IHs. *)
    have "time (closest_pair_rec_tm ps) = 1 + time (length_tm ps) + time (split_at_tm (?n div 2) ps) +
              time (closest_pair_rec_tm XSL) + time (closest_pair_rec_tm XSR) + time (merge_tm (λp. snd p) YSL YSR) +
              t_combine (C0L, C1L) (C0R, C1R) (fst (hd XSR)) YS"
      using time_closest_pair_rec_tm_simps_2[OF False] defs
      by (auto simp del: closest_pair_rec_tm.simps simp add: Let_def split: prod.split)
    also have "...  7 + 13 * ?n + time (closest_pair_rec_tm XSL) + time (closest_pair_rec_tm XSR)"
      using time_merge_tm[of "(λp. snd p)" YSL YSR] L combine_bound by (simp add: time_length_tm time_split_at_tm)
    also have "...  7 + 13 * ?n + closest_pair_recurrence (nat real ?n / 2) +
              closest_pair_recurrence (nat real ?n / 2)"
      using IHL IHR by simp
    also have "... = closest_pair_recurrence (length ps)"
      using False by simp
    finally show ?thesis
      by simp
  qed
qed

(* By the master theorem, the recurrence is Theta(n * ln n). *)
theorem closest_pair_recurrence:
  "closest_pair_recurrence  Θ(λn. n * ln n)"
  by (master_theorem) auto

(*
  The measured cost of closest_pair_rec_tm is O(n * ln n), where n is the
  list length. The inputs range over distinct, x-sorted lists.
*)
theorem time_closest_pair_rec_bigo:
  "(λxs. time (closest_pair_rec_tm xs))  O[length going_to at_top within { ps. distinct ps  sorted_fst ps }]((λn. n * ln n) o length)"
proof -
  have 0: "ps. ps  { ps. distinct ps  sorted_fst ps } 
           time (closest_pair_rec_tm ps)  (closest_pair_recurrence o length) ps"
    unfolding comp_def using time_closest_pair_rec_conv_closest_pair_recurrence by auto
  show ?thesis
    using bigo_measure_trans[OF 0] bigthetaD1[OF closest_pair_recurrence] of_nat_0_le_iff by blast
qed

(* Cost bound for the top-level closest_pair: 1 + an initial mergesort + the recursive algorithm. *)
definition closest_pair_time :: "nat  real" where
  "closest_pair_time n = 1 + mergesort_recurrence n + closest_pair_recurrence n"

(*
  For distinct input, closest_pair_tm stays within closest_pair_time.
  - It first mergesorts by x (fst).
  - It then runs closest_pair_rec_tm on the sorted list, which is distinct,
    sorted by x, and of the same length.
  - Lists with fewer than two elements are handled by the trailing auto.
*)
lemma time_closest_pair_conv_closest_pair_recurrence:
  assumes "distinct ps"
  shows "time (closest_pair_tm ps)  closest_pair_time (length ps)"
  using assms
  unfolding closest_pair_time_def
proof (induction rule: induct_list012)
  case (3 x y zs)
  let ?ps = "x # y # zs"
  define xs where "xs = val (mergesort_tm fst ?ps)"
  have *: "distinct xs" "sorted_fst xs" "length xs = length ?ps"
    using xs_def mergesort(4)[OF "3.prems", of fst] mergesort(1)[of fst ?ps] mergesort(3)[of fst ?ps]
          sorted_fst_def mergesort_eq_val_mergesort_tm by metis+
  have "time (closest_pair_tm ?ps) = 1 + time (mergesort_tm fst ?ps) + time (closest_pair_rec_tm xs)"
    using xs_def by (auto simp del: mergesort_tm.simps closest_pair_rec_tm.simps simp add: time_simps split: prod.split)
  also have "...  1 + mergesort_recurrence (length ?ps) + time (closest_pair_rec_tm xs)"
    using time_mergesort_conv_mergesort_recurrence[of fst ?ps] by simp
  also have "...  1 + mergesort_recurrence (length ?ps) + closest_pair_recurrence (length ?ps)"
    using time_closest_pair_rec_conv_closest_pair_recurrence[of xs] * by auto
  finally show ?case
    by blast
qed (auto simp: time_simps)

(* closest_pair_time is O(n * ln n): each of its three summands is. *)
corollary closest_pair_time:
  "closest_pair_time  O(λn. n * ln n)"
  unfolding closest_pair_time_def
  using mergesort_recurrence closest_pair_recurrence sum_in_bigo(1) const_1_bigo_n_ln_n by blast

(* Hence closest_pair_tm runs in O(n * ln n) on distinct inputs, where n is the list length. *)
corollary time_closest_pair_bigo:
  "(λps. time (closest_pair_tm ps))  O[length going_to at_top within { ps. distinct ps }]((λn. n * ln n) o length)"
proof -
  have 0: "ps. ps  { ps. distinct ps } 
           time (closest_pair_tm ps)  (closest_pair_time o length) ps"
    unfolding comp_def using time_closest_pair_conv_closest_pair_recurrence by auto
  show ?thesis
    using bigo_measure_trans[OF 0] closest_pair_time by fastforce
qed


subsection "Code Export"

subsubsection "Combine Step"

(*
  Executable counterpart of find_closest on integer bounds. It returns the
  chosen point together with its dist_code distance from p.

  The vertical cut-off compares the bound delta with the squared y-gap
  (snd p0 - snd p)^2, so delta is expected to be on the dist_code scale.
  NOTE(review): presumably dist_code is the squared Euclidean distance;
  confirm against its definition.

  As for find_closest, the result on the empty list is undefined.
*)
fun find_closest_code :: "point  int  point list  (int * point)" where
  "find_closest_code _ _ [] = undefined"
| "find_closest_code p _ [p0] = (dist_code p p0, p0)"
| "find_closest_code p δ (p0 # ps) = (
    let δ0 = dist_code p p0 in
    if δ  (snd p0 - snd p)2 then
      (δ0, p0)
    else
      let (δ1, p1) = find_closest_code p (min δ δ0) ps in
      if δ0  δ1 then
        (δ0, p0)
      else
        (δ1, p1)
  )"

(* On a non-empty list, the distance returned by find_closest_code is dist_code of p and the returned point. *)
lemma find_closest_code_dist_eq:
  "0 < length ps  (δc, c) = find_closest_code p δ ps  δc = dist_code p c"
proof (induction p δ ps arbitrary: δc c rule: find_closest_code.induct)
  case (3 p δ p0 p2 ps)
  show ?case
  proof cases
    assume "δ  (snd p0 - snd p)2"
    thus ?thesis
      using "3.prems"(2) by simp
  next
    assume A: "¬ δ  (snd p0 - snd p)2"
    define δ0 where δ0_def: "δ0 = dist_code p p0"
    obtain δ1 p1 where δ1_def: "(δ1, p1) = find_closest_code p (min δ δ0) (p2 # ps)"
      by (metis surj_pair)
    note defs = δ0_def δ1_def
    have "δ1 = dist_code p p1"
      using "3.IH"[of δ0 δ1 p1] A defs by simp
    thus ?thesis
      using defs "3.prems" by (auto simp: Let_def split: if_splits prod.splits)
  qed
qed simp_all

(*
  Make the defining equations of find_closest available to simp for the
  refinement proof below. NOTE(review): presumably they were removed from
  the simpset earlier in the theory.
*)
declare find_closest.simps [simp add]

(*
  Refinement: find_closest_code picks the same point as find_closest.
  Conditions:
  - delta = dist c0 c1 and delta' = dist_code c0 c1 describe the same bound;
  - the list is non-empty;
  - the list, with p in front, is sorted by y.
*)
lemma find_closest_code_eq:
  assumes "0 < length ps" "δ = dist c0 c1" "δ' = dist_code c0 c1" "sorted_snd (p # ps)"
  assumes "c = find_closest p δ ps" "(δc', c') = find_closest_code p δ' ps"
  shows "c = c'"
  using assms
proof (induction p δ ps arbitrary: δ' c0 c1 c δc' c' rule: find_closest.induct)
  case (3 p δ p0 p2 ps)
  define δ0 δ0' where δ0_def: "δ0 = dist p p0" "δ0' = dist_code p p0"
  obtain p1 δ1' p1' where δ1_def: "p1 = find_closest p (min δ δ0) (p2 # ps)"
    "(δ1', p1') = find_closest_code p (min δ' δ0') (p2 # ps)"
    by (metis surj_pair)
  note defs = δ0_def δ1_def
  show ?case
  proof cases
    assume *: "δ  snd p0 - snd p"
    hence "δ'  (snd p0 - snd p)2"
      using "3.prems"(2,3) dist_eq_dist_code_abs_le by fastforce
    thus ?thesis
      using * "3.prems"(5,6) by simp
  next
    assume *: "¬ δ  snd p0 - snd p"
    moreover have "0  snd p0 - snd p"
      using "3.prems"(4) sorted_snd_def by simp
    ultimately have A: "¬ δ'  (snd p0 - snd p)2"
      using "3.prems"(2,3) dist_eq_dist_code_abs_le[of c0 c1 "snd p0 - snd p"] by simp
    (* Both versions tighten their bound in the same way, so the IH applies to the recursive calls. *)
    have "min δ δ0 = δ  min δ' δ0' = δ'" "min δ δ0 = δ0  min δ' δ0' = δ0'"
      by (metis "3.prems"(2,3) defs(1,2) dist_eq_dist_code_le min.commute min_def)+
    moreover have "sorted_snd (p # p2 # ps)"
      using "3.prems"(4) sorted_snd_def by simp
    ultimately have B: "p1 = p1'"
      using "3.IH"[of c0 c1 δ' p1 δ1' p1'] "3.IH"[of p p0 δ0' p1 δ1' p1'] "3.prems"(2,3) defs * by auto
    have "δ1' = dist_code p p1'"
      using find_closest_code_dist_eq defs by blast
    hence "δ0  dist p p1  δ0'  δ1'"
      using defs(1,2) dist_eq_dist_code_le by (simp add: B)
    thus ?thesis
      using "3.prems"(5,6) * A B defs by (auto simp: Let_def split: prod.splits)
  qed
qed auto

(*
  Executable counterpart of find_closest_pair. The accumulator
  (delta, c0, c1) holds the best pair so far together with its distance.
  For each p0:
  - find_closest_code (with cut-off delta) yields a candidate p1 and its
    distance delta';
  - the candidate replaces the accumulator only when delta' < delta.
*)
fun find_closest_pair_code :: "(int * point * point)  point list  (int * point * point)" where
  "find_closest_pair_code (δ, c0, c1) [] = (δ, c0, c1)"
| "find_closest_pair_code (δ, c0, c1) [p] = (δ, c0, c1)"
| "find_closest_pair_code (δ, c0, c1) (p0 # ps) = (
    let (δ', p1) = find_closest_code p0 δ ps in
    if δ  δ' then
      find_closest_pair_code (δ, c0, c1) ps
    else
      find_closest_pair_code (δ', p0, p1) ps
  )"

(* If the initial accumulator carries dist_code of its pair, so does the result of find_closest_pair_code. *)
lemma find_closest_pair_code_dist_eq:
  assumes "δ = dist_code c0 c1" "(Δ, C0, C1) = find_closest_pair_code (δ, c0, c1) ps"
  shows "Δ = dist_code C0 C1"
  using assms
proof (induction "(δ, c0, c1)" ps arbitrary: δ c0 c1 Δ C0 C1 rule: find_closest_pair_code.induct)
  case (3 δ c0 c1 p0 p2 ps)
  obtain δ' p1 where δ'_def: "(δ', p1) = find_closest_code p0 δ (p2 # ps)"
    by (metis surj_pair)
  hence A: "δ' = dist_code p0 p1"
    using find_closest_code_dist_eq by blast
  show ?case
  proof (cases "δ  δ'")
    case True
    obtain Δ' C0' C1' where Δ'_def: "(Δ', C0', C1') = find_closest_pair_code (δ, c0, c1) (p2 # ps)"
      by (metis prod_cases4)
    note defs = δ'_def Δ'_def
    hence "Δ' = dist_code C0' C1'"
      using "3.hyps"(1)[of "(δ', p1)" δ' p1] "3.prems"(1) True δ'_def by blast
    moreover have "Δ = Δ'" "C0 = C0'" "C1 = C1'"
      using defs True "3.prems"(2) apply (auto split: prod.splits) by (metis Pair_inject)+
    ultimately show ?thesis
      by simp
  next
    case False
    obtain Δ' C0' C1' where Δ'_def: "(Δ', C0', C1') = find_closest_pair_code (δ', p0, p1) (p2 # ps)"
      by (metis prod_cases4)
    note defs = δ'_def Δ'_def
    hence "Δ' = dist_code C0' C1'"
      using "3.hyps"(2)[of "(δ', p1)" δ' p1] A False δ'_def by blast
    moreover have "Δ = Δ'" "C0 = C0'" "C1 = C1'"
      using defs False "3.prems"(2) apply (auto split: prod.splits) by (metis Pair_inject)+
    ultimately show ?thesis
      by simp
  qed
qed auto

(*
  Make the defining equations of find_closest_pair available to simp for the
  refinement proof below. NOTE(review): presumably they were removed from
  the simpset earlier in the theory.
*)
declare find_closest_pair.simps [simp add]

(*
  Refinement: starting from the same pair (c0, c1), find_closest_pair and
  find_closest_pair_code return the same pair on a y-sorted list.
  The starting distances are dist c0 c1 and dist_code c0 c1 respectively.
*)
lemma find_closest_pair_code_eq:
  assumes "δ = dist c0 c1" "δ' = dist_code c0 c1" "sorted_snd ps"
  assumes "(C0, C1) = find_closest_pair (c0, c1) ps"
  assumes "(Δ', C0', C1') = find_closest_pair_code (δ', c0, c1) ps"
  shows "C0 = C0'  C1 = C1'"
  using assms
proof (induction "(c0, c1)" ps arbitrary: δ δ' c0 c1 C0 C1 Δ' C0' C1' rule: find_closest_pair.induct)
  case (3 c0 c1 p0 p2 ps)
  obtain p1 δp' p1' where δp_def: "p1 = find_closest p0 δ (p2 # ps)"
    "(δp', p1') = find_closest_code p0 δ' (p2 # ps)"
    by (metis surj_pair)
  hence A: "δp' = dist_code p0 p1'"
    using find_closest_code_dist_eq by blast
  (* Both versions pick the same candidate for p0. *)
  have B: "p1 = p1'"
    using "3.prems"(1,2,3) δp_def find_closest_code_eq by blast
  show ?case
  proof (cases "δ  dist p0 p1")
    case True
    hence C: "δ'  δp'"
      by (simp add: "3.prems"(1,2) A B dist_eq_dist_code_le)
    obtain C0i C1i Δi' C0i' C1i' where Δi_def:
      "(C0i, C1i) = find_closest_pair (c0, c1) (p2 # ps)"
      "(Δi', C0i', C1i') = find_closest_pair_code (δ', c0, c1) (p2 # ps)"
      by (metis prod_cases3)
    note defs = δp_def Δi_def
    have "sorted_snd (p2 # ps)"
      using "3.prems"(3) sorted_snd_def by simp
    hence "C0i = C0i'  C1i = C1i'"
      using "3.hyps"(1) "3.prems"(1,2) True defs by blast
    moreover have "C0 = C0i" "C1 = C1i"
      using defs(1,3) True "3.prems"(1,4) apply (auto split: prod.splits) by (metis Pair_inject)+
    moreover have "Δ' = Δi'" "C0' = C0i'" "C1' = C1i'"
      using defs(2,4) C "3.prems"(5) apply (auto split: prod.splits) by (metis Pair_inject)+
    ultimately show ?thesis
      by simp
  next
    case False
    hence C: "¬ δ'  δp'"
      by (simp add: "3.prems"(1,2) A B dist_eq_dist_code_le)
    obtain C0i C1i Δi' C0i' C1i' where Δi_def:
      "(C0i, C1i) = find_closest_pair (p0, p1) (p2 # ps)"
      "(Δi', C0i', C1i') = find_closest_pair_code (δp', p0, p1') (p2 # ps)"
      by (metis prod_cases3)
    note defs = δp_def Δi_def
    have "sorted_snd (p2 # ps)"
      using "3.prems"(3) sorted_snd_def by simp
    hence "C0i = C0i'  C1i = C1i'"
      using "3.prems"(1) "3.hyps"(2) A B False defs by blast
    moreover have "C0 = C0i" "C1 = C1i"
      using defs(1,3) False "3.prems"(1,4) apply (auto split: prod.splits) by (metis Pair_inject)+
    moreover have "Δ' = Δi'" "C0' = C0i'" "C1' = C1i'"
      using defs(2,4) C "3.prems"(5) apply (auto split: prod.splits) by (metis Pair_inject)+
    ultimately show ?thesis
      by simp
  qed
qed auto

(*
  Executable counterpart of combine.
  1. Choose the closer input pair by comparing the cached distances deltaL
     and deltaR.
  2. Keep the points whose squared horizontal offset (fst p - l)^2 from the
     line x = l is below that distance.
  3. Search this strip with find_closest_pair_code.
*)
fun combine_code :: "(int * point * point)  (int * point * point)  int  point list  (int * point * point)" where
  "combine_code (δL, p0L, p1L) (δR, p0R, p1R) l ps = (
    let (δ, c0, c1) = if δL < δR then (δL, p0L, p1L) else (δR, p0R, p1R) in
    let ps' = filter (λp. (fst p - l)2 < δ) ps in
    find_closest_pair_code (δ, c0, c1) ps'
  )"

(* If both input triples carry dist_code of their pair, so does the result of combine_code. *)
lemma combine_code_dist_eq:
  assumes "δL = dist_code p0L p1L" "δR = dist_code p0R p1R"
  assumes "(δ, c0, c1) = combine_code (δL, p0L, p1L) (δR, p0R, p1R) l ps"
  shows "δ = dist_code c0 c1"
  using assms by (auto simp: find_closest_pair_code_dist_eq split: if_splits)

(*
  Refinement: combine_code selects the same pair as combine, provided
  - the cached distances are dist_code of the input pairs, and
  - ps is sorted by y.
*)
lemma combine_code_eq:
  assumes "δL' = dist_code p0L p1L" "δR' = dist_code p0R p1R" "sorted_snd ps"
  assumes "(c0, c1) = combine (p0L, p1L) (p0R, p1R) l ps"
  assumes "(δ', c0', c1') = combine_code (δL', p0L, p1L) (δR', p0R, p1R) l ps"
  shows "c0 = c0'  c1 = c1'"
proof -
  obtain C0i C1i Δi' C0i' C1i' where Δi_def:
    "(C0i, C1i) = (if dist p0L p1L < dist p0R p1R then (p0L, p1L) else (p0R, p1R))"
    "(Δi', C0i', C1i') = (if δL' < δR' then (δL', p0L, p1L) else (δR', p0R, p1R))"
    by metis
  define ps' ps'' where ps'_def:
    "ps' = filter (λp. dist p (l, snd p) < dist C0i C1i) ps"
    "ps'' = filter (λp. (fst p - l)2 < Δi') ps"
  obtain C0 C1 Δ' C0' C1' where Δ_def:
    "(C0, C1) = find_closest_pair (C0i, C1i) ps'"
    "(Δ', C0', C1') = find_closest_pair_code (Δi', C0i', C1i') ps''"
    by (metis prod_cases3)
  note defs = Δi_def ps'_def Δ_def
  (* Both versions choose the same initial pair, and their strip filters select the same points. *)
  have *: "C0i = C0i'" "C1i = C1i'" "Δi' = dist_code C0i' C1i'"
    using Δi_def assms(1,2,3,4) dist_eq_dist_code_lt by (auto split: if_splits)
  hence "p. ¦fst p - l¦ < dist C0i C1i  (fst p - l)2 < Δi'"
    using dist_eq_dist_code_abs_lt by (metis (mono_tags) of_int_abs)
  hence "ps' = ps''"
    using ps'_def dist_fst_abs by auto
  moreover have "sorted_snd ps'"
    using assms(3) ps'_def sorted_snd_def sorted_wrt_filter by blast
  ultimately have "C0 = C0'" "C1 = C1'"
    using * find_closest_pair_code_eq Δ_def by blast+
  moreover have "C0 = c0" "C1 = c1"
    using assms(4) defs(1,3,5) apply (auto simp: combine.simps split: prod.splits) by (metis Pair_inject)+
  moreover have "C0' = c0'" "C1' = c1'"
    using assms(5) defs(2,4,6) apply (auto split: prod.splits) by (metis prod.inject)+
  ultimately show ?thesis
    by blast
qed

subsubsection "Divide and Conquer Algorithm"

(* Executable divide-and-conquer closest pair. Besides the closest pair, it
   returns the input points re-sorted by second coordinate (mergesort snd /
   merge snd) and an integer distance for the pair. The integer distance is
   related to dist_code by closest_pair_rec_code_dist_eq below.
   Inputs of at most three points go to closest_pair_bf_code. Longer inputs
   are split at position n div 2. l, the first coordinate of the first point
   of the right half, is the dividing line passed to combine_code together
   with the merged list ys.
   closest_pair_code calls this on mergesort fst ps, i.e. on points sorted by
   first coordinate. *)
function closest_pair_rec_code :: "point list  (point list * int * point * point)" where
  "closest_pair_rec_code xs = (
    let n = length xs in
    if n  3 then
      (mergesort snd xs, closest_pair_bf_code xs)
    else
      let (xsL, xsR) = split_at (n div 2) xs in
      let l = fst (hd xsR) in

      let (ysL, pL) = closest_pair_rec_code xsL in
      let (ysR, pR) = closest_pair_rec_code xsR in

      let ys = merge snd ysL ysR in
      (ys, combine_code pL pR l ys)
  )"
  by pat_completeness auto
(* Termination: both recursive calls are on lists strictly shorter than xs,
   measured by length. *)
termination closest_pair_rec_code
  by (relation "Wellfounded.measure (λxs. length xs)")
     (auto simp: split_at_take_drop_conv Let_def)

(* Unfolding of closest_pair_rec_code in the recursive case, i.e. when the
   input has more than three points. The proofs below use this equation
   instead of the defining equation closest_pair_rec_code.simps, which is
   declared simp del right after this lemma. *)
lemma closest_pair_rec_code_simps:
  assumes "n = length xs" "¬ (n  3)"
  shows "closest_pair_rec_code xs = (
    let (xsL, xsR) = split_at (n div 2) xs in
    let l = fst (hd xsR) in
    let (ysL, pL) = closest_pair_rec_code xsL in
    let (ysR, pR) = closest_pair_rec_code xsR in
    let ys = merge snd ysL ysR in
    (ys, combine_code pL pR l ys)
  )"
  using assms by (auto simp: Let_def)

(* Remove defining equations from the simp set, so that the recursion of
   closest_pair_rec_code is only unfolded where a proof asks for it
   (closest_pair_rec_code.simps in the base case, closest_pair_rec_code_simps
   in the recursive case).
   NOTE(review): in Isar, an attribute list in a declare statement attaches to
   the fact immediately preceding it. As written, [simp del] may therefore only
   affect closest_pair_rec_code.simps. Confirm; if combine.simps and
   combine_code.simps are also meant to be removed, add [simp del] after each. *)
declare combine.simps combine_code.simps closest_pair_rec_code.simps [simp del]

(* For inputs of at least two points, the integer distance returned by
   closest_pair_rec_code is the dist_code of the returned pair.
   The proof is by strong induction on the length of the list. *)
lemma closest_pair_rec_code_dist_eq:
  assumes "1 < length xs" "(ys, δ, c0, c1) = closest_pair_rec_code xs"
  shows "δ = dist_code c0 c1"
  using assms
proof (induction xs arbitrary: ys δ c0 c1 rule: length_induct)
  case (1 xs)
  let ?n = "length xs"
  show ?case
  proof (cases "?n  3")
    case True
    (* Base case: the result comes from closest_pair_bf_code. *)
    hence "(δ, c0, c1) = closest_pair_bf_code xs"
      using "1.prems"(2) closest_pair_rec_code.simps by simp
    thus ?thesis
      using "1.prems"(1) closest_pair_bf_code_dist_eq by simp
  next
    case False

    (* Recursive case: name the split, the dividing line, both recursive
       results, the merged list and the result of combine_code. *)
    obtain XSL XSR where XSLR_def: "(XSL, XSR) = split_at (?n div 2) xs"
      using prod.collapse by blast
    define L where "L = fst (hd XSR)"

    obtain YSL ΔL C0L C1L where YSC01L_def: "(YSL, ΔL, C0L, C1L) = closest_pair_rec_code XSL"
      using prod.collapse by metis
    obtain YSR ΔR C0R C1R where YSC01R_def: "(YSR, ΔR, C0R, C1R) = closest_pair_rec_code XSR"
      using prod.collapse by metis

    define YS where "YS = merge (λp. snd p) YSL YSR"
    obtain Δ C0 C1 where C01_def: "(Δ, C0, C1) = combine_code (ΔL, C0L, C1L) (ΔR, C0R, C1R) L YS"
      using prod.collapse by metis
    note defs = XSLR_def L_def YSC01L_def YSC01R_def YS_def C01_def

    (* Each half has at least two points and is strictly shorter than xs,
       so the induction hypothesis applies to it. *)
    have "1 < length XSL" "length XSL < length xs"
      using False "1.prems"(1) defs by (auto simp: split_at_take_drop_conv)
    hence IHL: "ΔL = dist_code C0L C1L"
      using "1.IH" defs by metis+

    have "1 < length XSR" "length XSR < length xs"
      using False "1.prems"(1) defs by (auto simp: split_at_take_drop_conv)
    hence IHR: "ΔR = dist_code C0R C1R"
      using "1.IH" defs by metis+

    (* Reassemble the recursive call from the named parts and transfer the
       invariant through combine_code. *)
    have *: "(YS, Δ, C0, C1) = closest_pair_rec_code xs"
      using False closest_pair_rec_code_simps defs by (auto simp: Let_def split: prod.split)
    moreover have "Δ = dist_code C0 C1"
      using combine_code_dist_eq IHL IHR C01_def by blast
    ultimately show ?thesis
      using "1.prems"(2) * by (metis Pair_inject)
  qed
qed

(* For inputs of at least two points, closest_pair_rec and
   closest_pair_rec_code return the same re-sorted point list.
   The proof is by strong induction on the length of the list. *)
lemma closest_pair_rec_ys_eq:
  assumes "1 < length xs"
  assumes "(ys, c0, c1) = closest_pair_rec xs"
  assumes "(ys', δ', c0', c1') = closest_pair_rec_code xs"
  shows "ys = ys'"
  using assms
proof (induction xs arbitrary: ys c0 c1 ys' δ' c0' c1' rule: length_induct)
  case (1 xs)
  let ?n = "length xs"
  show ?case
  proof (cases "?n  3")
    case True
    (* Base case: both versions return mergesort snd xs. *)
    hence "ys = mergesort snd xs"
      using "1.prems"(2) closest_pair_rec.simps by simp
    moreover have "ys' = mergesort snd xs"
      using "1.prems"(3) closest_pair_rec_code.simps by (simp add: True)
    ultimately show ?thesis
      using "1.prems"(1) by simp
  next
    case False

    (* Recursive case: name the intermediate values of both versions side by
       side (abstract version first, code version primed). *)
    obtain XSL XSR where XSLR_def: "(XSL, XSR) = split_at (?n div 2) xs"
      using prod.collapse by blast
    define L where "L = fst (hd XSR)"

    obtain YSL C0L C1L YSL' ΔL' C0L' C1L' where YSC01L_def:
      "(YSL, C0L, C1L) = closest_pair_rec XSL"
      "(YSL', ΔL', C0L', C1L') = closest_pair_rec_code XSL"
      using prod.collapse by metis
    obtain YSR C0R C1R YSR' ΔR' C0R' C1R' where YSC01R_def:
      "(YSR, C0R, C1R) = closest_pair_rec XSR"
      "(YSR', ΔR', C0R', C1R') = closest_pair_rec_code XSR"
      using prod.collapse by metis

    define YS YS' where YS_def:
      "YS = merge (λp. snd p) YSL YSR"
      "YS' = merge (λp. snd p) YSL' YSR'"
    obtain C0 C1 Δ' C0' C1' where C01_def:
      "(C0, C1) = combine (C0L, C1L) (C0R, C1R) L YS"
      "(Δ', C0', C1') = combine_code (ΔL', C0L', C1L') (ΔR', C0R', C1R') L YS'"
      using prod.collapse by metis
    (* defs(1,2,3,5,7,9) are the abstract-side facts,
       defs(1,2,4,6,8,10) the code-side facts. *)
    note defs = XSLR_def L_def YSC01L_def YSC01R_def YS_def C01_def

    (* Both halves satisfy the induction hypothesis. *)
    have "1 < length XSL" "length XSL < length xs"
      using False "1.prems"(1) defs by (auto simp: split_at_take_drop_conv)
    hence IHL: "YSL = YSL'"
      using "1.IH" defs by metis

    have "1 < length XSR" "length XSR < length xs"
      using False "1.prems"(1) defs by (auto simp: split_at_take_drop_conv)
    hence IHR: "YSR = YSR'"
      using "1.IH" defs by metis

    (* Reassemble both recursive calls; the merged lists agree because
       the halves agree. *)
    have "(YS, C0, C1) = closest_pair_rec xs"
      using False closest_pair_rec_simps defs(1,2,3,5,7,9)
      by (auto simp: Let_def split: prod.split)
    moreover have "(YS', Δ', C0', C1') = closest_pair_rec_code xs"
      using False closest_pair_rec_code_simps defs(1,2,4,6,8,10)
      by (auto simp: Let_def split: prod.split)
    moreover have "YS = YS'"
      using IHL IHR YS_def by simp
    ultimately show ?thesis
      by (metis "1.prems"(2,3) Pair_inject)
  qed
qed

(* Refinement of the recursive algorithm: for inputs of at least two points,
   closest_pair_rec and closest_pair_rec_code return the same pair.
   The proof is by strong induction on the length of the list. It reuses
   closest_pair_rec_ys_eq, closest_pair_rec_code_dist_eq and combine_code_eq. *)
lemma closest_pair_rec_code_eq:
  assumes "1 < length xs"
  assumes "(ys, c0, c1) = closest_pair_rec xs"
  assumes "(ys', δ', c0', c1') = closest_pair_rec_code xs"
  shows "c0 = c0'  c1 = c1'"
  using assms
proof (induction xs arbitrary: ys c0 c1 ys' δ' c0' c1' rule: length_induct)
  case (1 xs)
  let ?n = "length xs"
  show ?case
  proof (cases "?n  3")
    case True
    (* Base case: reduce to the brute-force refinement closest_pair_bf_code_eq. *)
    hence "(c0, c1) = closest_pair_bf xs"
      using "1.prems"(2) closest_pair_rec.simps by simp
    moreover have "(δ', c0', c1') = closest_pair_bf_code xs"
      using "1.prems"(3) closest_pair_rec_code.simps by (simp add: True)
    ultimately show ?thesis
      using "1.prems"(1) closest_pair_bf_code_eq by simp
  next
    case False

    (* Recursive case: name the intermediate values of both versions side by
       side (abstract version first, code version primed). *)
    obtain XSL XSR where XSLR_def: "(XSL, XSR) = split_at (?n div 2) xs"
      using prod.collapse by blast
    define L where "L = fst (hd XSR)"

    obtain YSL C0L C1L YSL' ΔL' C0L' C1L' where YSC01L_def:
      "(YSL, C0L, C1L) = closest_pair_rec XSL"
      "(YSL', ΔL', C0L', C1L') = closest_pair_rec_code XSL"
      using prod.collapse by metis
    obtain YSR C0R C1R YSR' ΔR' C0R' C1R' where YSC01R_def:
      "(YSR, C0R, C1R) = closest_pair_rec XSR"
      "(YSR', ΔR', C0R', C1R') = closest_pair_rec_code XSR"
      using prod.collapse by metis

    define YS YS' where YS_def:
      "YS = merge (λp. snd p) YSL YSR"
      "YS' = merge (λp. snd p) YSL' YSR'"
    obtain C0 C1 Δ' C0' C1' where C01_def:
      "(C0, C1) = combine (C0L, C1L) (C0R, C1R) L YS"
      "(Δ', C0', C1') = combine_code (ΔL', C0L', C1L') (ΔR', C0R', C1R') L YS'"
      using prod.collapse by metis
    (* defs(1,2,3,5,7,9) are the abstract-side facts,
       defs(1,2,4,6,8,10) the code-side facts. *)
    note defs = XSLR_def L_def YSC01L_def YSC01R_def YS_def C01_def

    (* Both halves satisfy the induction hypothesis. *)
    have "1 < length XSL" "length XSL < length xs"
      using False "1.prems"(1) defs by (auto simp: split_at_take_drop_conv)
    hence IHL: "C0L = C0L'" "C1L = C1L'"
      using "1.IH" defs by metis+

    have "1 < length XSR" "length XSR < length xs"
      using False "1.prems"(1) defs by (auto simp: split_at_take_drop_conv)
    hence IHR: "C0R = C0R'" "C1R = C1R'"
      using "1.IH" defs by metis+

    (* Establish the preconditions of combine_code_eq: the merged list is
       sorted by second coordinate, both versions merge the same lists, and
       the recursive distances are dist_code values. *)
    have "sorted_snd YSL" "sorted_snd YSR"
      using closest_pair_rec_set_length_sorted_snd YSC01L_def(1) YSC01R_def(1) by blast+
    hence "sorted_snd YS"
      using sorted_merge sorted_snd_def YS_def by blast
    moreover have "YS = YS'"
      using defs 1 < length XSL 1 < length XSR closest_pair_rec_ys_eq by blast
    moreover have "ΔL' = dist_code C0L' C1L'" "ΔR' = dist_code C0R' C1R'"
      using defs 1 < length XSL 1 < length XSR closest_pair_rec_code_dist_eq by blast+
    ultimately have "C0 = C0'" "C1 = C1'"
      using combine_code_eq IHL IHR C01_def by blast+
    (* Reassemble both recursive calls from the named parts. *)
    moreover have "(YS, C0, C1) = closest_pair_rec xs"
      using False closest_pair_rec_simps defs(1,2,3,5,7,9)
      by (auto simp: Let_def split: prod.split)
    moreover have "(YS', Δ', C0', C1') = closest_pair_rec_code xs"
      using False closest_pair_rec_code_simps defs(1,2,4,6,8,10)
      by (auto simp: Let_def split: prod.split)
    ultimately show ?thesis
      using "1.prems"(2,3) by (metis Pair_inject)
  qed
qed

(* Put the defining equations of closest_pair into the simp set (the plain
   simp attribute means simp add); closest_pair_code_eq below lets the
   simplifier unfold closest_pair. *)
declare closest_pair.simps [simp]

(* Executable entry point: sort the points by first coordinate, run
   closest_pair_rec_code, and return only the pair, discarding the re-sorted
   list and the integer distance. Lists with fewer than two points have no
   pair and yield undefined. *)
fun closest_pair_code :: "point list  (point * point)" where
  "closest_pair_code [] = undefined"
| "closest_pair_code [_] = undefined"
| "closest_pair_code ps = (let (_, _, c0, c1) = closest_pair_rec_code (mergesort fst ps) in (c0, c1))"

(* The executable algorithm computes exactly the same pair as the verified
   abstract algorithm closest_pair, for every input list. *)
lemma closest_pair_code_eq:
  "closest_pair ps = closest_pair_code ps"
proof (induction ps rule: induct_list012)
  (* Lists with at least two points: reduce to closest_pair_rec_code_eq on the
     x-sorted list. The empty and singleton cases are closed by auto at qed. *)
  case (3 x y zs)
  obtain ys c0 c1 ys' δ' c0' c1' where *:
    "(ys, c0, c1) = closest_pair_rec (mergesort fst (x # y # zs))"
    "(ys', δ', c0', c1') = closest_pair_rec_code (mergesort fst (x # y # zs))"
    by (metis prod_cases3)
  (* Sorting preserves the length, so the sorted list still has two or more points. *)
  moreover have "1 < length (mergesort fst (x # y # zs))"
    using length_mergesort[of fst "x # y # zs"] by simp
  ultimately have "c0 = c0'" "c1 = c1'"
    using closest_pair_rec_code_eq by blast+
  thus ?case
    using * by (auto split: prod.splits)
qed auto

(* Generate OCaml code for closest_pair_code (and its dependencies) into a
   module named Verified. *)
export_code closest_pair_code in OCaml
  module_name Verified

end