Theory Randomised_BSTs

(*
  File:      Randomised_BSTs.thy
  Author:    Manuel Eberl (TU München)

  A formalisation of the randomised binary search trees described by Martínez & Roura.
*)
section ‹Randomised Binary Search Trees›
theory Randomised_BSTs
  imports "Random_BSTs.Random_BSTs" "Monad_Normalisation.Monad_Normalisation"
begin

subsection ‹Auxiliary facts›

text ‹
  First of all, we need some fairly simple auxiliary lemmas.
›

(* return_pmf commutes with a conditional: proved by splitting on the condition P. *)
lemma return_pmf_if: "return_pmf (if P then a else b) = (if P then return_pmf a else return_pmf b)"
  by (cases P) simp_all

(* Support of the Bernoulli distribution: True is possible iff p > 0, False iff p < 1.
   Proved via the pmf_as_function transfer setup, interpreted only locally here. *)
context
begin

interpretation pmf_as_function .

lemma True_in_set_bernoulli_pmf_iff [simp]:
  "True ∈ set_pmf (bernoulli_pmf p) ⟷ p > 0"
  by transfer auto

lemma False_in_set_bernoulli_pmf_iff [simp]:
  "False ∈ set_pmf (bernoulli_pmf p) ⟷ p < 1"
  by transfer auto

end

(* Elements in the support of a uniform distribution over a finite, non-empty set
   belong to that set. *)
lemma in_set_pmf_of_setD: "x ∈ set_pmf (pmf_of_set A) ⟹ finite A ⟹ A ≠ {} ⟹ x ∈ A"
  by (subst (asm) set_pmf_of_set) auto

(* One unfolding step of random_bst for a finite, non-empty set: choose the root x
   uniformly from A, then build random BSTs of the elements below and above x. *)
lemma random_bst_reduce:
  "finite A ⟹ A ≠ {} ⟹
     random_bst A = do {x ← pmf_of_set A; l ← random_bst {y∈A. y < x};
                        r ← random_bst {y∈A. y > x}; return_pmf ⟨l, x, r⟩}"
  by (subst random_bst.simps) auto

(* Probability of y after binding a Bernoulli coin with parameter x ∈ [0, 1]:
   the x-weighted mixture of the two continuations. *)
lemma pmf_bind_bernoulli:
  assumes "x ∈ {0..1}"
  shows   "pmf (bernoulli_pmf x ⤜ f) y = x * pmf (f True) y + (1 - x) * pmf (f False) y"
  using assms by (simp add: pmf_bind)

(* Preimage of a function on bool × bool written as an explicit finite union over
   the four possible arguments. *)
lemma vimage_bool_pair:
  "f -` A = (⋃x∈{True, False}. ⋃y∈{True, False}. if f (x, y) ∈ A then {(x, y)} else {})"
  (is "?lhs = ?rhs") unfolding set_eq_iff
proof
  fix x :: "bool × bool"
  obtain a b where [simp]: "x = (a, b)" by (cases x)
  show "x ∈ ?lhs ⟷ x ∈ ?rhs"
    by (cases a; cases b) auto
qed

(* The empty tree is a possible outcome of random_bst A exactly when A is empty or infinite. *)
lemma Leaf_in_set_random_bst_iff [simp]:
  "Leaf ∈ set_pmf (random_bst A) ⟷ A = {} ∨ ¬finite A"
  by (subst random_bst.simps) auto

(* Tree_Set.insert preserves the BST property (via the sorted-inorder characterisation). *)
lemma bst_insert [intro]: "bst t ⟹ bst (Tree_Set.insert x t)"
  by (simp add: bst_iff_sorted_wrt_less inorder_insert sorted_ins_list)

(* bst_of_list always yields a BST.  The auxiliary statement shows, by induction on xs,
   that folding Tree_Set.insert over xs starting from any BST t yields a BST; the result
   then follows from bst_of_list_altdef. *)
lemma bst_bst_of_list [intro]: "bst (bst_of_list xs)"
proof -
  have "bst (fold Tree_Set.insert xs t)" if "bst t" for t
    using that
  proof (induction xs arbitrary: t)
    case (Cons y xs)
    show ?case by (auto intro!: Cons bst_insert)
  qed auto
  thus ?thesis by (simp add: bst_of_list_altdef)
qed

(* Every tree in the support of random_bst A is a BST.  For finite A, random_bst A is
   bst_of_list applied to a uniformly random permutation (random_bst_altdef); for infinite A
   it is the point mass on the empty tree. *)
lemma bst_random_bst:
  assumes "t ∈ set_pmf (random_bst A)"
  shows   "bst t"
proof (cases "finite A")
  case True
  have "random_bst A = map_pmf bst_of_list (pmf_of_set (permutations_of_set A))"
    by (rule random_bst_altdef) fact+
  also have "set_pmf … = bst_of_list ` permutations_of_set A"
    using True by auto
  finally show ?thesis using assms by auto
next
  case False
  hence "random_bst A = return_pmf ⟨⟩"
    by (simp add: random_bst.simps)
  with assms show ?thesis by simp
qed

(* For finite A, every tree in the support of random_bst A contains exactly the elements of A. *)
lemma set_random_bst:
  assumes "t ∈ set_pmf (random_bst A)" "finite A"
  shows   "set_tree t = A"
proof -
  have "random_bst A = map_pmf bst_of_list (pmf_of_set (permutations_of_set A))"
    by (rule random_bst_altdef) fact+
  also have "set_pmf … = bst_of_list ` permutations_of_set A"
    using assms by auto
  finally show ?thesis using assms
    by (auto simp: permutations_of_setD)
qed

(* On a BST, the search function isin agrees with membership in set_tree. *)
lemma isin_bst:
  assumes "bst t"
  shows   "isin t x ⟷ x ∈ set_tree t"
  using assms
  by (subst isin_set) (auto simp: bst_iff_sorted_wrt_less)

(* Searching a random BST over a finite set A answers membership in A. *)
lemma isin_random_bst:
  assumes "finite A" "t ∈ set_pmf (random_bst A)"
  shows   "isin t x ⟷ x ∈ A"
proof -
  from assms have "bst t" by (auto dest: bst_random_bst)
  with assms show ?thesis by (simp add: isin_bst set_random_bst)
qed

(* Splitting a finite set at one of its elements x: the elements below x, the elements
   above x, and x itself. *)
lemma card_3way_split:
  assumes "x ∈ (A :: 'a :: linorder set)" "finite A"
  shows   "card A = card {y∈A. y < x} + card {y∈A. y > x} + 1"
proof -
  from assms have "A = insert x ({y∈A. y < x} ∪ {y∈A. y > x})"
    by auto
  also have "card … = card {y∈A. y < x} + card {y∈A. y > x} + 1"
    using assms by (subst card_insert_disjoint) (auto intro: card_Un_disjoint)
  finally show ?thesis .
qed


text ‹
  The following theorem allows replacing a uniformly random choice from a union of two disjoint
  sets by first tossing a coin to decide on one of the constituent sets and then choosing an
  element from it uniformly at random.
›
(* Proof: compare point probabilities.  The left-hand side mixes the two uniform
   distributions with weights p and 1 - p.  A case analysis on where x lies
   (in A, in B, or in neither) shows that this mixture equals 1 / card (A ∪ B) or 0. *)
lemma pmf_of_set_union_split:
  assumes "finite A" "finite B" "A ∩ B = {}" "A ∪ B ≠ {}"
  assumes "p = card A / (card A + card B)"
  shows   "do {b ← bernoulli_pmf p; if b then pmf_of_set A else pmf_of_set B} = pmf_of_set (A ∪ B)"
            (is "?lhs = ?rhs")
proof (rule pmf_eqI)
  fix x :: 'a
  from assms have p: "p ∈ {0..1}"
    by (auto simp: divide_simps assms(5) split: if_splits)

  have "pmf ?lhs x = pmf (pmf_of_set A) x * p + pmf (pmf_of_set B) x * (1 - p)"
    unfolding pmf_bind using p by (subst integral_bernoulli_pmf) auto
  also consider "x ∈ A" "B ≠ {}" | "x ∈ B" "A ≠ {}" | "x ∈ A" "B = {}" | "x ∈ B" "A = {}" |
                "x ∉ A" "x ∉ B"
    using assms by auto
  hence "pmf (pmf_of_set A) x * p + pmf (pmf_of_set B) x * (1 - p) = pmf ?rhs x"
  proof cases
    assume "x ∉ A" "x ∉ B"
    thus ?thesis using assms by (cases "A = {}"; cases "B = {}") auto
  next
    assume "x ∈ A" and [simp]: "B ≠ {}"
    have "pmf (pmf_of_set A) x * p + pmf (pmf_of_set B) x * (1 - p) = p / real (card A)"
      using ‹x ∈ A› assms(1-4) by (subst (1 2) pmf_of_set) (auto simp: indicator_def)
    also have "… = pmf ?rhs x"
      using assms ‹x ∈ A› by (subst pmf_of_set) (auto simp: card_Un_disjoint)
    finally show ?thesis .
  next
    assume "x ∈ B" and [simp]: "A ≠ {}"
    from assms have *: "card (A ∪ B) > 0" by (subst card_gt_0_iff) auto
    have "pmf (pmf_of_set A) x * p + pmf (pmf_of_set B) x * (1 - p) = (1 - p) / real (card B)"
      using ‹x ∈ B› assms(1-4) by (subst (1 2) pmf_of_set) (auto simp: indicator_def)
    also have "… = pmf ?rhs x"
      using assms ‹x ∈ B› *
      by (subst pmf_of_set) (auto simp: card_Un_disjoint assms(5) divide_simps)
    finally show ?thesis .
  qed (insert assms(1-4), auto simp: assms(5))
  finally show "pmf ?lhs x = pmf ?rhs x" .
qed

(* Instance of pmf_of_set_union_split for the disjoint decomposition of B into
   A ∩ B and B - A. *)
lemma pmf_of_set_split_inter_diff:
  assumes "finite A" "finite B" "A ≠ {}" "B ≠ {}"
  assumes "p = card (A ∩ B) / card B"
  shows   "do {b ← bernoulli_pmf p; if b then pmf_of_set (A ∩ B) else pmf_of_set (B - A)} =
             pmf_of_set B" (is "?lhs = ?rhs")
proof -
  have eq: "B = (A ∩ B) ∪ (B - A)" by auto
  have card_eq: "card B = card (A ∩ B) + card (B - A)"
    using assms by (subst eq, subst card_Un_disjoint) auto
  have "?lhs = pmf_of_set ((A ∩ B) ∪ (B - A))"
    using assms by (intro pmf_of_set_union_split) (auto simp: card_eq)
  with eq show ?thesis by simp
qed

text ‹
  Similarly to the above rule, we can split up a uniformly random choice from the disjoint
  union of three sets. This could be done with two coin flips, but it is more convenient to
  choose a natural number uniformly at random instead and then do a case distinction on it.
›
(* Proof: compare point probabilities.  Both sides are written as a sum divided by
   card A.  On the right, the index range {..<card A} is split into three consecutive
   intervals of sizes card A1, card A2 and card A3.  On the left, the sum over A is
   split along the disjoint union A1 ∪ A2 ∪ A3. *)
lemma pmf_of_set_3way_split:
  fixes f g h :: "'a ⇒ 'b pmf"
  assumes "finite A" "A ≠ {}" "A1 ∩ A2 = {}" "A1 ∩ A3 = {}" "A2 ∩ A3 = {}" "A1 ∪ A2 ∪ A3 = A"
  shows   "do {x ← pmf_of_set A; if x ∈ A1 then f x else if x ∈ A2 then g x else h x} =
           do {i ← pmf_of_set {..<card A};
               if i < card A1 then pmf_of_set A1 ⤜ f
               else if i < card A1 + card A2 then pmf_of_set A2 ⤜ g
               else pmf_of_set A3 ⤜ h}" (is "?lhs = ?rhs")
proof (intro pmf_eqI)
  fix x :: 'b
  define m n l where "m = card A1" and "n = card A2" and "l = card A3"
  have [simp]: "finite A1" "finite A2" "finite A3"
    by (rule finite_subset[of _ A]; use assms in force)+
  from assms have card_pos: "card A > 0" by auto
  have A_eq: "A = A1 ∪ A2 ∪ A3" using assms by simp
  have card_A_eq: "card A = card A1 + card A2 + card A3"
    using assms unfolding A_eq by (subst card_Un_disjoint, simp, simp, force)+ auto
  have card_A_eq': "{..<card A} = {..<m} ∪ {m..<m + n} ∪ {m + n..<card A}"
    by (auto simp: m_def n_def card_A_eq)
  let ?M = "λi. if i < m then pmf_of_set A1 ⤜ f else if i < m + n then
                  pmf_of_set A2 ⤜ g else pmf_of_set A3 ⤜ h"

  (* card X times the probability after a uniform choice from X equals the plain sum
     over X (trivially also for X = {}). *)
  have card_times_pmf_of_set_bind:
      "card X * pmf (pmf_of_set X ⤜ f) x = (∑y∈X. pmf (f y) x)"
      if "finite X" for X :: "'a set" and f :: "'a ⇒ 'b pmf"
    using that by (cases "X = {}") (auto simp: pmf_bind_pmf_of_set)

  have "pmf ?rhs x = (∑i<card A. pmf (?M i) x) / card A"
    (is "_ = ?S / _") using assms card_pos unfolding m_def n_def
    by (subst pmf_bind_pmf_of_set) auto
  also have "?S = (real m * pmf (pmf_of_set A1 ⤜ f) x +
                   real n * pmf (pmf_of_set A2 ⤜ g) x +
                   real l * pmf (pmf_of_set A3 ⤜ h) x)" unfolding card_A_eq'
    by (subst sum.union_disjoint, simp, simp, force)+ (auto simp: card_A_eq m_def n_def l_def)
  also have "… = (∑y∈A1. pmf (f y) x) + (∑y∈A2. pmf (g y) x) + (∑y∈A3. pmf (h y) x)"
    unfolding m_def n_def l_def by (subst (1 2 3) card_times_pmf_of_set_bind) auto
  also have "… = (∑y∈A1 ∪ A2 ∪ A3.
                       pmf (if y ∈ A1 then f y else if y ∈ A2 then g y else h y) x)"
    using assms(1-5)
    by (subst sum.union_disjoint, simp, simp, force)+
       (intro arg_cong2[of _ _ _ _ "(+)"] sum.cong, auto)
  also have "… / card A = pmf ?lhs x"
    using assms by (simp add: pmf_bind_pmf_of_set)
  finally show "pmf ?lhs x = pmf ?rhs x"
    unfolding m_def n_def l_def card_A_eq ..
qed


subsection ‹Partitioning a BST›

text ‹
  The split operation takes a search parameter ‹x› and partitions a BST into two BSTs
  containing all the values that are smaller than ‹x› and those that are greater than ‹x›,
  respectively. Note that ‹x› need not be an element of the tree.
›

(* split_bst x t returns (elements of t below x, elements of t above x) as two trees.
   If a node with key x is encountered, that key is dropped and its two subtrees
   are returned directly. *)
fun split_bst :: "'a :: linorder ⇒ 'a tree ⇒ 'a tree × 'a tree" where
  "split_bst _ ⟨⟩ = (⟨⟩, ⟨⟩)"
| "split_bst x ⟨l, y, r⟩ =
     (if y < x then
        case split_bst x r of (t1, t2) ⇒ (⟨l, y, t1⟩, t2)
      else if y > x then
        case split_bst x l of (t1, t2) ⇒ (t1, ⟨t2, y, r⟩)
      else
        (l, r))"

(* Variant of split_bst that additionally returns a flag that is True exactly when a
   node with key x was found.  split_bst'_altdef below relates the flag to isin. *)
fun split_bst' :: "'a :: linorder ⇒ 'a tree ⇒ bool × 'a tree × 'a tree" where
  "split_bst' _ ⟨⟩ = (False, ⟨⟩, ⟨⟩)"
| "split_bst' x ⟨l, y, r⟩ =
     (if y < x then
        case split_bst' x r of (b, t1, t2) ⇒ (b, ⟨l, y, t1⟩, t2)
      else if y > x then
        case split_bst' x l of (b, t1, t2) ⇒ (b, t1, ⟨t2, y, r⟩)
      else
        (True, l, r))"

(* split_bst' is isin paired with split_bst; proved by induction along split_bst. *)
lemma split_bst'_altdef: "split_bst' x t = (isin t x, split_bst x t)"
  by (induction x t rule: split_bst.induct) (auto simp: case_prod_unfold)

(* Component-wise simp rules derived from split_bst'_altdef. *)
lemma fst_split_bst' [simp]: "fst (split_bst' x t) = isin t x"
  and snd_split_bst' [simp]: "snd (split_bst' x t) = split_bst x t"
  by (simp_all add: split_bst'_altdef)


(* Neither half produced by split_bst is larger than the input tree.  These are declared
   [termination_simp] for use in termination proofs. *)
lemma size_fst_split_bst [termination_simp]: "size (fst (split_bst x t)) ≤ size t"
  by (induction t) (auto simp: case_prod_unfold)

lemma size_snd_split_bst [termination_simp]: "size (snd (split_bst x t)) ≤ size t"
  by (induction t) (auto simp: case_prod_unfold)

lemmas size_split_bst = size_fst_split_bst size_snd_split_bst

(* On a BST, split_bst x yields exactly the elements below x and above x, respectively,
   and both halves are again BSTs. *)
lemma set_split_bst1: "bst t ⟹ set_tree (fst (split_bst x t)) = {y ∈ set_tree t. y < x}"
  by (induction t) (auto split: prod.splits)

lemma set_split_bst2: "bst t ⟹ set_tree (snd (split_bst x t)) = {y ∈ set_tree t. y > x}"
  by (induction t) (auto split: prod.splits)

lemma bst_split_bst1 [intro]: "bst t ⟹ bst (fst (split_bst x t))"
  by (induction t) (auto simp: case_prod_unfold set_split_bst1)

lemma bst_split_bst2 [intro]: "bst t ⟹ bst (snd (split_bst x t))"
  by (induction t) (auto simp: case_prod_unfold set_split_bst2)

text ‹
  Splitting a random BST produces two random BSTs:
›
(* Proof outline (induction along random_bst.induct):
   1. Unfold one step of random_bst and push split_bst inside, separating the cases
      y < x, y > x and y = x for the uniformly chosen root y.
   2. Apply the induction hypothesis to the recursive splits.
   3. Regroup the uniform root choice with pmf_of_set_3way_split into choices from
      A1 = {y∈A. y < x}, A2 = {y∈A. y > x} and the remaining set (x itself, if x ∈ A).
   4. Each branch then collapses to pair_pmf (random_bst A1) (random_bst A2). *)
theorem split_random_bst:
  assumes "finite A"
  shows   "map_pmf (split_bst x) (random_bst A) =
             pair_pmf (random_bst {y∈A. y < x}) (random_bst {y∈A. y > x})"
  using assms
proof (induction A rule: random_bst.induct)
  case (1 A)
  define A1 A2 where "A1 = {y∈A. y < x}" and "A2 = {y∈A. y > x}"
  have [simp]: "¬x ∈ A2" if "x ∈ A1" for x using that by (auto simp: A1_def A2_def)
  from ‹finite A› have [simp]: "finite A1" "finite A2" by (auto simp: A1_def A2_def)
  include monad_normalisation

  show ?case
  proof (cases "A = {}")
    case True
    thus ?thesis by (auto simp: pair_return_pmf1)
  next
    case False

    have "map_pmf (split_bst x) (random_bst A) =
            do {y ← pmf_of_set A;
                if y < x then
                  do {
                    l ← random_bst {z∈A. z < y};
                    (t1, t2) ← map_pmf (split_bst x) (random_bst {z∈A. z > y});
                    return_pmf (⟨l, y, t1⟩, t2)
                  }
                else if y > x then
                  do {
                    (t1, t2) ← map_pmf (split_bst x) (random_bst {z∈A. z < y});
                    r ← random_bst {z∈A. z > y};
                    return_pmf (t1, ⟨t2, y, r⟩)
                  }
                else
                  do {
                    l ← random_bst {z∈A. z < y};
                    r ← random_bst {z∈A. z > y};
                    return_pmf (l, r)
                  }
               }"
      using "1.prems" False
      by (subst random_bst.simps)
         (simp add: map_bind_pmf bind_map_pmf return_pmf_if case_prod_unfold cong: if_cong)
    (* apply the induction hypothesis to the two recursive splits *)
    also have "… = do {y ← pmf_of_set A;
                        if y < x then
                          do {
                            l ← random_bst {z∈A. z < y};
                            (t1, t2) ← pair_pmf (random_bst {z∈{z∈A. z > y}. z < x})
                                                 (random_bst {z∈{z∈A. z > y}. z > x});
                            return_pmf (⟨l, y, t1⟩, t2)
                          }
                        else if y > x then
                          do {
                            (t1, t2) ← pair_pmf (random_bst {z∈{z∈A. z < y}. z < x})
                                                 (random_bst {z∈{z∈A. z < y}. z > x});
                            r ← random_bst {z∈A. z > y};
                            return_pmf (t1, ⟨t2, y, r⟩)
                          }
                         else
                           do {
                             l ← random_bst {z∈A. z < y};
                             r ← random_bst {z∈A. z > y};
                             return_pmf (l, r)
                           }
                       }"
      using ‹finite A› and ‹A ≠ {}›
      by (intro bind_pmf_cong if_cong refl "1.IH") auto
    also have "… = do {y ← pmf_of_set A;
                        if y < x then
                          do {
                            l ← random_bst {z∈A. z < y};
                            t1 ← random_bst {z∈{z∈A. z > y}. z < x};
                            t2 ← random_bst {z∈{z∈A. z > y}. z > x};
                            return_pmf (⟨l, y, t1⟩, t2)
                          }
                        else if y > x then
                          do {
                            t1 ← random_bst {z∈{z∈A. z < y}. z < x};
                            t2 ← random_bst {z∈{z∈A. z < y}. z > x};
                            r ← random_bst {z∈A. z > y};
                            return_pmf (t1, ⟨t2, y, r⟩)
                          }
                         else
                           do {
                             l ← random_bst {z∈A. z < y};
                             r ← random_bst {z∈A. z > y};
                             return_pmf (l, r)
                           }
                       }"
      by (simp add: pair_pmf_def cong: if_cong)
    (* rephrase the nested comprehension sets in terms of A1 and A2 *)
    also have "… = do {y ← pmf_of_set A;
                        if y ∈ A1 then
                          do {
                            l ← random_bst {z∈A1. z < y};
                            t1 ← random_bst {z∈A1. z > y};
                            t2 ← random_bst A2;
                            return_pmf (⟨l, y, t1⟩, t2)
                          }
                        else if y ∈ A2 then
                          do {
                            t1 ← random_bst A1;
                            t2 ← random_bst {z∈A2. z < y};
                            r ← random_bst {z∈A2. z > y};
                            return_pmf (t1, ⟨t2, y, r⟩)
                          }
                         else
                           pair_pmf (random_bst A1) (random_bst A2)
                       }"
      using ‹finite A› ‹A ≠ {}›
      by (intro bind_pmf_cong refl if_cong arg_cong[of _ _ random_bst])
         (auto simp: A1_def A2_def pair_pmf_def)
    also have "… = do {i ← pmf_of_set {..<card A};
                        if i < card A1 then
                          do {
                            y ← pmf_of_set A1;
                            l ← random_bst {z∈A1. z < y};
                            t1 ← random_bst {z∈A1. z > y};
                            t2 ← random_bst A2;
                            return_pmf (⟨l, y, t1⟩, t2)
                          }
                        else if i < card A1 + card A2 then
                          do {
                            y ← pmf_of_set A2;
                            t1 ← random_bst A1;
                            t2 ← random_bst {z∈A2. z < y};
                            r ← random_bst {z∈A2. z > y};
                            return_pmf (t1, ⟨t2, y, r⟩)
                          }
                         else do {
                           y ← pmf_of_set (if x ∈ A then {x} else {});
                           pair_pmf (random_bst A1) (random_bst A2)
                         }
                       }" using ‹finite A› ‹A ≠ {}›
      by (intro pmf_of_set_3way_split) (auto simp: A1_def A2_def not_less_iff_gr_or_eq)
    (* in the first two branches, the inner choice is one unfolding of random_bst A1
       (resp. random_bst A2) *)
    also have "… = do {i ← pmf_of_set {..<card A};
                        if i < card A1 then
                          pair_pmf (random_bst A1) (random_bst A2)
                        else if i < card A1 + card A2 then
                          pair_pmf (random_bst A1) (random_bst A2)
                         else
                          pair_pmf (random_bst A1) (random_bst A2)
                       }"
      using ‹finite A› ‹A ≠ {}›
    proof (intro bind_pmf_cong refl if_cong, goal_cases)
      case (1 i)
      hence "A1 ≠ {}" by auto
      thus ?case using ‹finite A› by (simp add: pair_pmf_def random_bst_reduce)
    next
      case (2 i)
      hence "A2 ≠ {}" by auto
      thus ?case using ‹finite A› by (simp add: pair_pmf_def random_bst_reduce)
    qed auto
    also have "… = pair_pmf (random_bst A1) (random_bst A2)"
      by (simp cong: if_cong)
    finally show ?thesis by (simp add: A1_def A2_def)
  qed
qed


subsection ‹Joining›

text ‹
  The ``join'' operation computes the union of two BSTs ‹l› and ‹r› where all the values in
  ‹l› are strictly smaller than those in ‹r›.
›
(* If either tree is empty, the other one is returned.  Otherwise a coin with bias
   size t1 / (size t1 + size t2) is tossed:
   - on True, the root of t1 stays the root and its right subtree is joined with t2;
   - on False, the root of t2 stays the root and t1 is joined with its left subtree. *)
fun mrbst_join :: "'a tree ⇒ 'a tree ⇒ 'a tree pmf" where
  "mrbst_join t1 t2 =
     (if t1 = ⟨⟩ then return_pmf t2
      else if t2 = ⟨⟩ then return_pmf t1
      else do {
        b ← bernoulli_pmf (size t1 / (size t1 + size t2));
        if b then
          (case t1 of ⟨l, x, r⟩ ⇒ map_pmf (λr'. ⟨l, x, r'⟩) (mrbst_join r t2))
        else
          (case t2 of ⟨l, x, r⟩ ⇒ map_pmf (λl'. ⟨l', x, r⟩) (mrbst_join t1 l))
      })"

(* Joining with an empty tree on either side deterministically returns the other tree.
   Both follow from mrbst_join.simps, which is still in the simpset here (it is removed
   further below). *)
lemma mrbst_join_Leaf_left [simp]: "mrbst_join ⟨⟩ = return_pmf"
  by (simp add: fun_eq_iff)

lemma mrbst_join_Leaf_right [simp]: "mrbst_join t ⟨⟩ = return_pmf t"
  by (simp add: fun_eq_iff)

(* Unfolding equation for mrbst_join when both trees are non-empty.  Afterwards
   mrbst_join.simps is removed from the simpset; later proofs unfold via this lemma. *)
lemma mrbst_join_reduce:
  "t1 ≠ ⟨⟩ ⟹ t2 ≠ ⟨⟩ ⟹ mrbst_join t1 t2 =
     do {
        b ← bernoulli_pmf (size t1 / (size t1 + size t2));
        if b then
          (case t1 of ⟨l, x, r⟩ ⇒ map_pmf (λr'. ⟨l, x, r'⟩) (mrbst_join r t2))
        else
          (case t2 of ⟨l, x, r⟩ ⇒ map_pmf (λl'. ⟨l', x, r⟩) (mrbst_join t1 l))
      }"
  by (subst mrbst_join.simps) auto

lemmas [simp del] = mrbst_join.simps

(* Every possible result of joining two BSTs, all of whose left elements lie below all
   right elements, is a BST containing exactly the union of both element sets.
   Proof: strong induction on size t1 + size t2, with one case per coin outcome. *)
lemma
  assumes "t' ∈ set_pmf (mrbst_join t1 t2)" "bst t1" "bst t2"
  assumes "⋀x y. x ∈ set_tree t1 ⟹ y ∈ set_tree t2 ⟹ x < y"
  shows   bst_mrbst_join: "bst t'"
    and   set_mrbst_join: "set_tree t' = set_tree t1 ∪ set_tree t2"
proof -
  have "bst t' ∧ set_tree t' = set_tree t1 ∪ set_tree t2"
  using assms
  proof (induction "size t1 + size t2" arbitrary: t1 t2 t' rule: less_induct)
    case (less t1 t2 t')
    show ?case
    proof (cases "t1 = ⟨⟩ ∨ t2 = ⟨⟩")
      case False
      hence "t' ∈ set_pmf (case t1 of ⟨l, x, r⟩ ⇒ map_pmf (Node l x) (mrbst_join r t2)) ∨
             t' ∈ set_pmf (case t2 of ⟨l, x, r⟩ ⇒ map_pmf (λl'. ⟨l', x, r⟩) (mrbst_join t1 l))"
        using less.prems by (subst (asm) mrbst_join_reduce) (auto split: if_splits)
      thus ?thesis
      proof
        assume "t' ∈ set_pmf (case t1 of ⟨l, x, r⟩ ⇒ map_pmf (Node l x) (mrbst_join r t2))"
        then obtain l x r r'
          where *: "t1 = ⟨l, x, r⟩" "r' ∈ set_pmf (mrbst_join r t2)" "t' = ⟨l, x, r'⟩"
          using False by (auto split: tree.splits)
        from * and less.prems have "bst r' ∧ set_tree r' = set_tree r ∪ set_tree t2"
          by (intro less) auto
        with * and less.prems show ?thesis by auto
      next
        assume "t' ∈ set_pmf (case t2 of ⟨l, x, r⟩ ⇒ map_pmf (λl'. ⟨l', x, r⟩) (mrbst_join t1 l))"
        then obtain l x r l'
          where *: "t2 = ⟨l, x, r⟩" "l' ∈ set_pmf (mrbst_join t1 l)" "t' = ⟨l', x, r⟩"
          using False by (auto split: tree.splits)
        from * and less.prems have "bst l' ∧ set_tree l' = set_tree t1 ∪ set_tree l"
          by (intro less) auto
        with * and less.prems show ?thesis by auto
      qed
    qed (insert less.prems, auto)
  qed
  thus "bst t'" "set_tree t' = set_tree t1 ∪ set_tree t2" by auto
qed

text ‹
  Joining two random BSTs that satisfy the necessary preconditions again yields a random BST.
›
(* Proof outline (finite_psubset_induct on A ∪ B):
   1. Unfold mrbst_join once.  The coin bias becomes p = card A / (card A + card B)
      because random BSTs have size_random_bst.
   2. Move the coin toss to the front.  In each branch, unfold random_bst once and apply
      the induction hypothesis to the strictly smaller recursive join.
   3. The root is then drawn uniformly from A (coin True) or B (coin False).
      pmf_of_set_union_split merges this into a uniform choice from A ∪ B, which is one
      unfolding of random_bst (A ∪ B). *)
theorem mrbst_join_correct:
  fixes A B :: "'a :: linorder set"
  assumes "finite A" "finite B" "⋀x y. x ∈ A ⟹ y ∈ B ⟹ x < y"
  shows   "do {t1 ← random_bst A; t2 ← random_bst B; mrbst_join t1 t2} = random_bst (A ∪ B)"
proof -
  from assms have "finite (A ∪ B)" by simp
  from this and assms show ?thesis
  proof (induction "A ∪ B" arbitrary: A B rule: finite_psubset_induct)
    case (psubset A B)
    define m n where "m = card A" and "n = card B"
    define p where "p = m / (m + n)"

    include monad_normalisation
    show ?case
    proof (cases "A = {} ∨ B = {}")
      case True
      thus ?thesis by auto
    next
      case False
      have AB: "A ≠ {}" "B ≠ {}" "finite A" "finite B"
        using False psubset.prems by auto
      have p_pos: "A ≠ {}" if "p > 0"
        using that AB by (auto simp: p_def m_def n_def)
      have p_lt1: "B ≠ {}" if "p < 1"
        using AB by (auto simp: p_def m_def n_def)

      have "do {t1 ← random_bst A; t2 ← random_bst B; mrbst_join t1 t2} =
            do {t1 ← random_bst A;
                t2 ← random_bst B;
                b ← bernoulli_pmf (size t1 / (size t1 + size t2));
                if b then
                  case t1 of ⟨l, x, r⟩ ⇒ map_pmf (λr'. ⟨l, x, r'⟩) (mrbst_join r t2)
                else
                  case t2 of ⟨l, x, r⟩ ⇒ map_pmf (λl'. ⟨l', x, r⟩) (mrbst_join t1 l)
               }"
        using AB
        by (intro bind_pmf_cong refl, subst mrbst_join_reduce) auto
      also have "… = do {t1 ← random_bst A;
                          t2 ← random_bst B;
                          b ← bernoulli_pmf p;
                          if b then
                            case t1 of ⟨l, x, r⟩ ⇒ map_pmf (λr'. ⟨l, x, r'⟩) (mrbst_join r t2)
                          else
                            case t2 of ⟨l, x, r⟩ ⇒ map_pmf (λl'. ⟨l', x, r⟩) (mrbst_join t1 l)
                         }"
        using AB by (intro bind_pmf_cong refl arg_cong[of _ _ bernoulli_pmf])
                    (auto simp: p_def m_def n_def size_random_bst)
      also have "… = do {
                        b ← bernoulli_pmf p;
                        if b then do {
                          t1 ← random_bst A;
                          t2 ← random_bst B;
                          case t1 of ⟨l, x, r⟩ ⇒ map_pmf (λr'. ⟨l, x, r'⟩) (mrbst_join r t2)
                        } else do {
                          t1 ← random_bst A;
                          t2 ← random_bst B;
                          case t2 of ⟨l, x, r⟩ ⇒ map_pmf (λl'. ⟨l', x, r⟩) (mrbst_join t1 l)
                        }
                      }"
        by simp
      also have "… = do {
                        b ← bernoulli_pmf p;
                        if b then do {
                          x ← pmf_of_set A;
                          l ← random_bst {y∈A ∪ B. y < x};
                          r ← random_bst {y∈A ∪ B. y > x};
                          return_pmf ⟨l, x, r⟩
                        } else do {
                          x ← pmf_of_set B;
                          l ← random_bst {y∈A ∪ B. y < x};
                          r ← random_bst {y∈A ∪ B. y > x};
                          return_pmf ⟨l, x, r⟩
                        }
                      }"
      proof (intro bind_pmf_cong refl if_cong, goal_cases)
        case (1 b)
        hence [simp]: "A ≠ {}" using p_pos by auto
        have "do {t1 ← random_bst A; t2 ← random_bst B;
                  case t1 of ⟨l, x, r⟩ ⇒ map_pmf (λr'. ⟨l, x, r'⟩) (mrbst_join r t2)} =
              do {
                x ← pmf_of_set A;
                l ← random_bst {y∈A. y < x};
                r ← do {r ← random_bst {y∈A. y > x}; t2 ← random_bst B; mrbst_join r t2};
                return_pmf ⟨l, x, r⟩
              }"
          using AB by (subst random_bst_reduce) (auto simp: map_pmf_def)
        also have "… = do {
                          x ← pmf_of_set A;
                          l ← random_bst {y∈A. y < x};
                          r ← random_bst ({y∈A. y > x} ∪ B);
                          return_pmf ⟨l, x, r⟩
                        }"
          using AB psubset.prems
          by (intro bind_pmf_cong refl psubset arg_cong[of _ _ random_bst]) auto
        also have "… = do {
                          x ← pmf_of_set A;
                          l ← random_bst {y∈A ∪ B. y < x};
                          r ← random_bst {y∈A ∪ B. y > x};
                          return_pmf ⟨l, x, r⟩
                        }"
          using AB psubset.prems
          by (intro bind_pmf_cong refl arg_cong[of _ _ random_bst]; force)
        finally show ?case .
      next
        case (2 b)
        hence [simp]: "B ≠ {}" using p_lt1 by auto
        have "do {t1 ← random_bst A; t2 ← random_bst B;
                  case t2 of ⟨l, x, r⟩ ⇒ map_pmf (λl'. ⟨l', x, r⟩) (mrbst_join t1 l)} =
              do {
                x ← pmf_of_set B;
                l ← do {t1 ← random_bst A; l ← random_bst {y∈B. y < x}; mrbst_join t1 l};
                r ← random_bst {y∈B. y > x};
                return_pmf ⟨l, x, r⟩
              }"
          using AB by (subst random_bst_reduce) (auto simp: map_pmf_def)
        also have "… = do {
                          x ← pmf_of_set B;
                          l ← random_bst (A ∪ {y∈B. y < x});
                          r ← random_bst {y∈B. y > x};
                          return_pmf ⟨l, x, r⟩
                        }"
          using AB psubset.prems
          by (intro bind_pmf_cong refl psubset arg_cong[of _ _ random_bst]) auto
        also have "… = do {
                          x ← pmf_of_set B;
                          l ← random_bst {y∈A ∪ B. y < x};
                          r ← random_bst {y∈A ∪ B. y > x};
                          return_pmf ⟨l, x, r⟩
                        }"
          using AB psubset.prems
          by (intro bind_pmf_cong refl arg_cong[of _ _ random_bst]; force)
        finally show ?case .
      qed
      also have "… = do {
                        b ← bernoulli_pmf p;
                        x ← (if b then pmf_of_set A else pmf_of_set B);
                        l ← random_bst {y∈A ∪ B. y < x};
                        r ← random_bst {y∈A ∪ B. y > x};
                        return_pmf ⟨l, x, r⟩
                      }"
        by (intro bind_pmf_cong) simp_all
      also have "… = do {
                        x ← do {b ← bernoulli_pmf p; if b then pmf_of_set A else pmf_of_set B};
                        l ← random_bst {y∈A ∪ B. y < x};
                        r ← random_bst {y∈A ∪ B. y > x};
                        return_pmf ⟨l, x, r⟩
                      }"
        by simp
      also have "do {b ← bernoulli_pmf p; if b then pmf_of_set A else pmf_of_set B} =
                   pmf_of_set (A ∪ B)"
        using AB psubset.prems by (intro pmf_of_set_union_split) (auto simp: p_def m_def n_def)
      also have "do {
                   x ← pmf_of_set (A ∪ B);
                   l ← random_bst {y∈A ∪ B. y < x};
                   r ← random_bst {y∈A ∪ B. y > x};
                   return_pmf ⟨l, x, r⟩
                 } = random_bst (A ∪ B)"
        using AB by (intro random_bst_reduce [symmetric]) auto
      finally show ?thesis .
    qed
  qed
qed


subsection ‹Pushdown›

text ‹
  The ``push down'' operation ``forgets'' information about the root of a tree in the following
  sense: It takes a non-empty tree whose root is some known fixed value and whose children are
  random BSTs and shuffles the root in such a way that the resulting tree is a random BST.
›
(* Draws k uniformly from {0..size l + size r}:
   - k < size l: the root of l becomes the root, and x is pushed down into l's right
     subtree together with r;
   - size l ≤ k < size l + size r: the root of r becomes the root, and x is pushed down
     into r's left subtree together with l;
   - otherwise x stays at the root.
   Afterwards mrbst_push_down.simps is removed from the simpset; later proofs unfold it
   explicitly with subst. *)
fun mrbst_push_down :: "'a tree ⇒ 'a ⇒ 'a tree ⇒ 'a tree pmf" where
  "mrbst_push_down l x r =
     do {
       k ← pmf_of_set {0..size l + size r};
       if k < size l then
         case l of
           ⟨ll, y, lr⟩ ⇒ map_pmf (λr'. ⟨ll, y, r'⟩) (mrbst_push_down lr x r)
       else if k < size l + size r then
         case r of
           ⟨rl, y, rr⟩ ⇒ map_pmf (λl'. ⟨l', y, rr⟩) (mrbst_push_down l x rl)
       else
         return_pmf ⟨l, x, r⟩
     }"

lemmas [simp del] = mrbst_push_down.simps

(* Every possible result of pushing x down between BSTs t1 (all elements below x) and
   t2 (all elements above x) is a BST with element set {x} ∪ set_tree t1 ∪ set_tree t2.
   Proof: strong induction on size t1 + size t2, with one case per branch of
   mrbst_push_down. *)
lemma
  assumes "t' ∈ set_pmf (mrbst_push_down t1 x t2)" "bst t1" "bst t2"
  assumes "⋀y. y ∈ set_tree t1 ⟹ y < x" "⋀y. y ∈ set_tree t2 ⟹ y > x"
  shows   bst_mrbst_push_down: "bst t'"
    and   set_mrbst_push_down: "set_tree t' = {x} ∪ set_tree t1 ∪ set_tree t2"
proof -
  have "bst t' ∧ set_tree t' = {x} ∪ set_tree t1 ∪ set_tree t2"
  using assms
  proof (induction "size t1 + size t2" arbitrary: t1 t2 t' rule: less_induct)
    case (less t1 t2 t')
    have "t1 ≠ ⟨⟩ ∧ t' ∈ set_pmf (case t1 of ⟨l, y, r⟩ ⇒
                            map_pmf (Node l y) (mrbst_push_down r x t2)) ∨
          t2 ≠ ⟨⟩ ∧ t' ∈ set_pmf (case t2 of ⟨l, y, r⟩ ⇒
                            map_pmf (λl'. ⟨l', y, r⟩) (mrbst_push_down t1 x l)) ∨
          t' = ⟨t1, x, t2⟩"
      using less.prems by (subst (asm) mrbst_push_down.simps) (auto split: if_splits)
    thus ?case
    proof (elim disjE, goal_cases)
      case 1
      then obtain l y r r'
        where *: "t1 = ⟨l, y, r⟩" "r' ∈ set_pmf (mrbst_push_down r x t2)" "t' = ⟨l, y, r'⟩"
        by (auto split: tree.splits)
      from * and less.prems have "bst r' ∧ set_tree r' = {x} ∪ set_tree r ∪ set_tree t2"
        by (intro less) auto
      with * and less.prems show ?case by force
    next
      case 2
      then obtain l y r l'
        where *: "t2 = ⟨l, y, r⟩" "l' ∈ set_pmf (mrbst_push_down t1 x l)" "t' = ⟨l', y, r⟩"
        by (auto split: tree.splits)
      from * and less.prems have "bst l' ∧ set_tree l' = {x} ∪ set_tree t1 ∪ set_tree l"
        by (intro less) auto
      with * and less.prems show ?case by force
    qed (insert less.prems, auto)
  qed
  thus "bst t'" "set_tree t' = {x} ∪ set_tree t1 ∪ set_tree t2" by auto
qed

(* Distributional correctness of mrbst_push_down.  Assume every element of A is below x and
   every element of B is above x.  Then pushing x down into independent random BSTs for A and B
   yields a random BST for the set formed by x together with A and B.  The proof is a
   finite-proper-subset induction on the union of A and B.  It consists of one calculation that
   unfolds push_down once, applies the induction hypothesis and folds random_bst_reduce back. *)
theorem mrbst_push_down_correct:
  fixes A B :: "'a :: linorder set"
  assumes "finite A" "finite B" "y. y  A  y < x" "y. y  B  x < y"
  shows   "do {l  random_bst A; r  random_bst B; mrbst_push_down l x r} =
             random_bst ({x}  A  B)"
proof -
  from assms have "finite (A  B)" by simp
  from this and assms show ?thesis
  proof (induction "A  B" arbitrary: A B rule: finite_psubset_induct)
    case (psubset A B)
    define m n where "m = card A" and "n = card B"
    have A_ne: "A  {}" if "m > 0"
      using that by (auto simp: m_def)
    have B_ne: "B  {}" if "n > 0"
      using that by (auto simp: n_def)

    include monad_normalisation
    (* Step 1: unfold one step of mrbst_push_down.  k is drawn uniformly from {0..m + n}.
       If k < m, the root of the left tree is kept and x is pushed into its right subtree.
       If m <= k < m + n, the root of the right tree is kept and x is pushed into its left subtree.
       Otherwise x becomes the root. *)
    have "do {l  random_bst A; r  random_bst B; mrbst_push_down l x r} =
          do {l  random_bst A;
              r  random_bst B;
              k  pmf_of_set {0..m + n};
              if k < m then
                case l of ll, y, lr  map_pmf (λr'. ll, y, r') (mrbst_push_down lr x r)
              else if k < m + n then
                case r of rl, y, rr  map_pmf (λl'. l', y, rr) (mrbst_push_down l x rl)
              else
                return_pmf l, x, r
             }"
      using psubset.prems
      by (subst mrbst_push_down.simps, intro bind_pmf_cong refl)
         (auto simp: size_random_bst m_def n_def)
    (* Step 2: move the draw of k to the front and repeat both tree draws in every branch. *)
    also have " = do {k  pmf_of_set {0..m + n};
                        if k < m then do {
                          l  random_bst A;
                          r  random_bst B;
                          case l of ll, y, lr  map_pmf (λr'. ll, y, r') (mrbst_push_down lr x r)
                        } else if k < m + n then do {
                          l  random_bst A;
                          r  random_bst B;
                          case r of rl, y, rr  map_pmf (λl'. l', y, rr) (mrbst_push_down l x rl)
                        } else do {
                          l  random_bst A;
                          r  random_bst B;
                          return_pmf l, x, r
                        }
                       }"
      by (simp cong: if_cong)
    (* Step 3: unfold random_bst A in the first branch and random_bst B in the second branch.
       The chosen root y then becomes explicit.  The remaining computation has the shape of the
       induction hypothesis's left-hand side. *)
    also have " = do {k  pmf_of_set {0..m + n};
                        if k < m then do {
                          y  pmf_of_set A;
                          ll  random_bst {zA. z < y};
                          r'  do {lr  random_bst {zA. z > y};
                                    r  random_bst B;
                                    mrbst_push_down lr x r};
                          return_pmf ll, y, r'
                        } else if k < m + n then do {
                          y  pmf_of_set B;
                          l'  do {l  random_bst A;
                                    rl  random_bst {zB. z < y};
                                    mrbst_push_down l x rl};
                          rr  random_bst {zB. z > y};
                          return_pmf l', y, rr
                        } else do {
                          l  random_bst A;
                          r  random_bst B;
                          return_pmf l, x, r
                        }
                       }"
    proof (intro bind_pmf_cong refl if_cong, goal_cases)
      case (1 k)
      hence "A  {}" by (auto simp: m_def)
      with finite A show ?case by (simp add: random_bst_reduce map_pmf_def)
    next
      case (2 k)
      hence "B  {}" by (auto simp: m_def n_def)
      with finite B show ?case by (simp add: random_bst_reduce map_pmf_def)
    qed
    (* Step 4: apply the induction hypothesis to the nested push_down computations.  The side
       goals below show that the sets involved form a proper subset of the union of A and B. *)
    also have " = do {k  pmf_of_set {0..m + n};
                        if k < m then do {
                          y  pmf_of_set A;
                          ll  random_bst {zA. z < y};
                          r'  random_bst ({x}  {zA. z > y}  B);
                          return_pmf ll, y, r'
                        } else if k < m + n then do {
                          y  pmf_of_set B;
                          l'  random_bst ({x}  A  {zB. z < y});
                          rr  random_bst {zB. z > y};
                          return_pmf l', y, rr
                        } else do {
                          l  random_bst A;
                          r  random_bst B;
                          return_pmf l, x, r
                        }
                       }"
      using psubset.prems A_ne B_ne
    proof (intro bind_pmf_cong refl if_cong psubset)
      fix k y assume "k < m" "y  set_pmf (pmf_of_set A)"
      thus "{zA. z > y}  B  A  B"
        using psubset.prems A_ne by (fastforce dest!: in_set_pmf_of_setD)
    next
      fix k y assume "¬k < m" "k < m + n" "y  set_pmf (pmf_of_set B)"
      thus "A  {zB. z < y}  A  B"
        using psubset.prems B_ne by (fastforce dest!: in_set_pmf_of_setD)
    qed auto
    (* Step 5: using the ordering assumptions on A and B, rewrite every set argument of
       random_bst as the elements of {x}, A and B below or above the respective root. *)
    also have " = do {k  pmf_of_set {0..m + n};
                        if k < m then do {
                          y  pmf_of_set A;
                          ll  random_bst {z{x}  A  B. z < y};
                          r'  random_bst {z{x}  A  B. z > y};
                          return_pmf ll, y, r'
                        } else if k < m + n then do {
                          y  pmf_of_set B;
                          l'  random_bst {z{x}  A  B. z < y};
                          rr  random_bst {z{x}  A  B. z > y};
                          return_pmf l', y, rr
                        } else do {
                          l  random_bst {z{x}  A  B. z < x};
                          r  random_bst {z{x}  A  B. z > x};
                          return_pmf l, x, r
                        }
                       }"
      using psubset.prems A_ne B_ne
      by (intro bind_pmf_cong if_cong refl arg_cong[of _ _ random_bst];
          force dest: psubset.prems(3,4))
    (* Step 6: write the fixed root x of the last branch as a draw from pmf_of_set {x}.
       All three branches then have the same shape; the whole term is abbreviated ?X. *)
    also have " = do {k  pmf_of_set {0..m + n};
                        if k < m then do {
                          y  pmf_of_set A;
                          ll  random_bst {z{x}  A  B. z < y};
                          r'  random_bst {z{x}  A  B. z > y};
                          return_pmf ll, y, r'
                        } else if k < m + n then do {
                          y  pmf_of_set B;
                          l'  random_bst {z{x}  A  B. z < y};
                          rr  random_bst {z{x}  A  B. z > y};
                          return_pmf l', y, rr
                        } else do {
                          y  pmf_of_set {x};
                          l  random_bst {z{x}  A  B. z < y};
                          r  random_bst {z{x}  A  B. z > y};
                          return_pmf l, x, r
                        }
                       }" (is "_ = ?X {0..m+n}")
      by (simp add: pmf_of_set_singleton cong: if_cong)
    (* Step 7: the index range has exactly as many elements as A, B and {x} together.
       This follows from card_Un_disjoint, since the three sets are pairwise disjoint. *)
    also have "{0..m + n} = {..<card (A  B  {x})}" using psubset.prems
      by (subst card_Un_disjoint, simp, simp, force)+
         (auto simp: m_def n_def)
    (* Step 8: by pmf_of_set_3way_split, the three branches together amount to drawing the root
       uniformly from the whole set. *)
    also have "?X  = do {y  pmf_of_set ({x}  A  B);
                           l  random_bst {z{x}  A  B. z < y};
                           r  random_bst {z{x}  A  B. z > y};
                           return_pmf l, y, r}"
      unfolding m_def n_def using psubset.prems
      by (subst pmf_of_set_3way_split [symmetric])
         (auto dest!: psubset.prems(3,4) cong: if_cong intro: bind_pmf_cong)
    (* Step 9: fold the recursive characterisation of random_bst back. *)
    also have " = random_bst ({x}  A  B)"
      using psubset.prems by (simp add: random_bst_reduce)
    finally show ?case .
  qed
qed

(* Special case of mrbst_push_down_correct.  Split a finite set A at one of its own elements x.
   Pushing x down into random BSTs of the elements below and above x yields a random BST of A. *)
lemma mrbst_push_down_correct':
  assumes "finite (A :: 'a :: linorder set)" "x  A"
  shows   "do {l  random_bst {yA. y < x}; r  random_bst {yA. y > x}; mrbst_push_down l x r} =
             random_bst A" (is "?lhs = ?rhs")
proof -
  have "?lhs = random_bst ({x}  {yA. y < x}  {yA. y > x})"
    using assms by (intro mrbst_push_down_correct) auto
  (* x together with the two halves is A again, because x is an element of A *)
  also have "{x}  {yA. y < x}  {yA. y > x} = A"
    using assms by auto
  finally show ?thesis .
qed


subsection ‹Intersection and Difference›

text ‹
  The algorithms for intersection and difference of two trees are almost identical; the only
  difference is that the ``if'' statement at the end of the recursive case is flipped. We
  therefore introduce a generic intersection/difference operation first and prove its correctness
  to avoid duplication.
›
(* Generic randomised intersection/difference; b selects the operation.
   A Leaf as first tree gives Leaf.  For a node with root x, the second tree t2 is split at x
   with split_bst'.  This yields a flag sep (whether x occurs in t2) and the two halves l2, r2.
   The operation recurses on the left and right parts.  x is kept as the root iff sep = b;
   otherwise the two recursive results are combined with mrbst_join. *)
fun mrbst_inter_diff where
  "mrbst_inter_diff _ ⟨⟩ _ = return_pmf ⟨⟩"
| "mrbst_inter_diff b l1, x, r1 t2 =
     (case split_bst' x t2 of (sep, l2, r2) 
        do {
          l  mrbst_inter_diff b l1 l2;
          r  mrbst_inter_diff b r1 r2;
          if sep = b then return_pmf l, x, r else mrbst_join l r
        })"

(* The recursive equation of mrbst_inter_diff for a node, stated as an equation between
   functions of the second tree, so that it can be used with subst on partial applications. *)
lemma mrbst_inter_diff_reduce:
  "mrbst_inter_diff b l1, x, r1 =
     (λt2. case split_bst' x t2 of (sep, l2, r2) 
        do {
           l  mrbst_inter_diff b l1 l2;
           r  mrbst_inter_diff b r1 r2;
           if sep = b then return_pmf l, x, r else mrbst_join l r
         })"
  by (rule ext) simp

(* With a Leaf as first tree, the result is always Leaf; stated for the partially applied function. *)
lemma mrbst_inter_diff_Leaf_left [simp]:
  "mrbst_inter_diff b ⟨⟩ = (λ_. return_pmf ⟨⟩)"
  by (simp add: fun_eq_iff)

(* With a Leaf as second tree the result is deterministic: Leaf for intersection (b = True),
   and t1 unchanged for difference (b = False).  Proved by induction on t1. *)
lemma mrbst_inter_diff_Leaf_right [simp]:
  "mrbst_inter_diff b (t1 :: 'a :: linorder tree) ⟨⟩ = return_pmf (if b then ⟨⟩ else t1)"
  by (induction t1) (auto simp: bind_return_pmf)

(* Structural correctness of mrbst_inter_diff.  Every tree t' it can return from two BSTs is a
   BST.  Its element set is setop applied to the input sets: intersection if b, difference
   otherwise.  Both facts are proved together by structural induction on t1, with t2 and t'
   arbitrary. *)
lemma
  fixes t1 t2 :: "'a :: linorder tree" and b :: bool
  defines "setop  (if b then (∩) else (-) :: 'a set  _)"
  assumes "t'  set_pmf (mrbst_inter_diff b t1 t2)" "bst t1" "bst t2"
  shows   bst_mrbst_inter_diff: "bst t'"
    and   set_mrbst_inter_diff: "set_tree t' = setop (set_tree t1) (set_tree t2)"
proof -
  write setop (infixl  80)
  have "bst t'  set_tree t' = set_tree t1  set_tree t2"
  using assms(2-)
  proof (induction t1 arbitrary: t2 t')
    case (Node l1 x r1 t2)
    note bst = bst l1, x, r1 bst t2
    (* l2 and r2 are the two halves of t2 split at the root x of the first tree *)
    define l2 r2 where "l2 = fst (split_bst x t2)" and "r2 = snd (split_bst x t2)"
    obtain l r
      where lr: "l  set_pmf (mrbst_inter_diff b l1 l2)" "r  set_pmf (mrbst_inter_diff b r1 r2)"
        and t': "t'  (if x  set_tree t2  b then {l, x, r} else set_pmf (mrbst_join l r))"
      using Node.prems by (force simp: case_prod_unfold l2_def r2_def isin_bst split: if_splits)
    from lr have lr': "bst l  set_tree l = set_tree l1  set_tree l2"
                      "bst r  set_tree r = set_tree r1  set_tree r2"
      using Node.prems by (intro Node.IH; force simp: l2_def r2_def)+

    (* The set of t' is that of l and r, plus x exactly when the membership of x in t2 agrees
       with b. *)
    have "set_tree t' = set_tree l  set_tree r  (if x  set_tree t2  b then {x} else {})"
    proof (cases "x  set_tree t2  b")
      case False
      (* every element of l lies below every element of r, as needed for set_mrbst_join *)
      have "x < y" if "x  set_tree l" "y  set_tree r" for x y
        using that lr' bst by (force simp: setop_def split: if_splits)
      hence set_t': "set_tree t' = set_tree l  set_tree r"
        using t' set_mrbst_join[of t' l r] False lr' by auto
      with False show ?thesis by simp
    qed (use t' in auto)
    also have " = set_tree l1, x, r1  set_tree t2"
      using lr' bst by (auto simp: setop_def l2_def r2_def set_split_bst1 set_split_bst2)
    finally have "set_tree t' = set_tree l1, x, r1  set_tree t2" .
    moreover from lr' t' bst have "bst t'"
      by (force split: if_splits simp: setop_def intro!: bst_mrbst_join[of t' l r])
    ultimately show ?case by auto
  qed (auto simp: setop_def)
  thus "bst t'" and "set_tree t' = set_tree t1  set_tree t2" by auto
qed

(* Distributional correctness of mrbst_inter_diff.  Run it on independent random BSTs of finite
   sets A and B.  The result is distributed as a random BST of setop A B: intersection if b,
   difference otherwise.  The proof is a finite-proper-subset induction on A, with B arbitrary. *)
theorem mrbst_inter_diff_correct:
  fixes A B :: "'a :: linorder set" and b :: bool
  defines "setop  (if b then (∩) else (-) :: 'a set  _)"
  assumes "finite A" "finite B"
  shows   "do {t1  random_bst A; t2  random_bst B; mrbst_inter_diff b t1 t2} =
             random_bst (setop A B)"
  using assms(2-)
proof (induction A arbitrary: B rule: finite_psubset_induct)
  case (psubset A B)
  write setop (infixl  80)
  include monad_normalisation
  show ?case
  proof (cases "A = {}")
    case True
    thus ?thesis by (auto simp: setop_def)
  next
    case False
    (* Abbreviations for the random BSTs of the elements of A below and above a pivot x. *)
    define R1 R2 where "R1 = (λx. random_bst {yA. y < x})" "R2 = (λx. random_bst {yA. y > x})"

    (* A splits into its elements inside B and those outside B.  Hence a uniform draw from A
       is a Bernoulli choice between uniform draws from the two parts (eq). *)
    have A_eq: "A = (A  B)  (A - B)" by auto
    have card_A_eq: "card A = card (A  B) + card (A - B)"
      using finite A finite B by (subst A_eq, subst card_Un_disjoint) auto
    have eq: "pmf_of_set A =
                do {b  bernoulli_pmf (card (A  B) / card A);
                    if b then pmf_of_set (A  B) else pmf_of_set (A - B)}"
      using psubset.prems False finite A A_eq card_A_eq
      by (subst A_eq, intro pmf_of_set_union_split [symmetric]) auto
    have "card A > 0"
      using finite A A  {} by (subst card_gt_0_iff) auto
    have not_subset: "¬A  B" if "card (A  B) < card A"
    proof
      assume "A  B"
      hence "A  B = A" by auto
      with that show False by simp
    qed

    (* Step 1: unfold random_bst A and one step of mrbst_inter_diff. *)
    have "do {t1  random_bst A; t2  random_bst B; mrbst_inter_diff b t1 t2} =
          do {
            x  pmf_of_set A;
            l1  random_bst {yA. y < x};
            r1  random_bst {yA. y > x};
            t2  random_bst B;
            let (l2, r2) = split_bst x t2;
            l  mrbst_inter_diff b l1 l2;
            r  mrbst_inter_diff b r1 r2;
            if isin t2 x = b then return_pmf l, x, r else mrbst_join l r
          }"
      using finite A A  {}
      by (subst random_bst_reduce)
         (auto simp: mrbst_inter_diff_reduce map_pmf_def split_bst'_altdef)
    (* Step 2: for t2 drawn from random_bst B, isin t2 x is just membership of x in B. *)
    also have " = do {
                      x  pmf_of_set A;
                      l1  random_bst {yA. y < x};
                      r1  random_bst {yA. y > x};
                      t2  random_bst B;
                      let (l2, r2) = split_bst x t2;
                      l  mrbst_inter_diff b l1 l2;
                      r  mrbst_inter_diff b r1 r2;
                      if x  B = b then return_pmf l, x, r else mrbst_join l r
                    }"
      unfolding Let_def case_prod_unfold using finite B
      by (intro bind_pmf_cong refl) (auto simp: isin_random_bst)
    (* Step 3: write the split of t2 as a map over the distribution of t2. *)
    also have " = do {
                      x  pmf_of_set A;
                      l1  random_bst {yA. y < x};
                      r1  random_bst {yA. y > x};
                      (l2, r2)  map_pmf (split_bst x) (random_bst B);
                      l  mrbst_inter_diff b l1 l2;
                      r  mrbst_inter_diff b r1 r2;
                      if x  B = b then return_pmf l, x, r else mrbst_join l r
                    }"
      by (simp add: Let_def map_pmf_def)
    (* Step 4: by split_random_bst, splitting a random BST of B at x gives independent random
       BSTs of the elements of B below and above x. *)
    also have " = do {
                      x  pmf_of_set A;
                      l1  random_bst {yA. y < x};
                      r1  random_bst {yA. y > x};
                      (l2, r2)  pair_pmf (random_bst {yB. y < x}) (random_bst {yB. y > x});
                      l  mrbst_inter_diff b l1 l2;
                      r  mrbst_inter_diff b r1 r2;
                      if x  B = b then return_pmf l, x, r else mrbst_join l r
                    }"
      by (intro bind_pmf_cong refl split_random_bst finite B)
    (* Step 5: unfold pair_pmf into two sequential draws and use the abbreviations R1, R2. *)
    also have " = do {
                      x  pmf_of_set A;
                      l1  R1 x;
                      r1  R2 x;
                      l2  random_bst {yB. y < x};
                      r2  random_bst {yB. y > x};
                      l  mrbst_inter_diff b l1 l2;
                      r  mrbst_inter_diff b r1 r2;
                      if x  B = b then return_pmf l, x, r else mrbst_join l r
                    }"
      unfolding pair_pmf_def bind_assoc_pmf R1_R2_def by simp
    (* Step 6: regroup the draws into one computation per subtree. *)
    also have " = do {
                      x  pmf_of_set A;
                      l  do {l1  R1 x; l2  random_bst {yB. y < x}; mrbst_inter_diff b l1 l2};
                      r  do {r1  R2 x; r2  random_bst {yB. y > x}; mrbst_inter_diff b r1 r2};
                      if x  B = b then return_pmf l, x, r else mrbst_join l r
                    }"
      unfolding bind_assoc_pmf by (intro bind_pmf_cong[OF refl]) simp
    (* Step 7: apply the induction hypothesis to both subtree computations. *)
    also have " = do {
                      x  pmf_of_set A;
                      l  random_bst ({yA. y < x}  {yB. y < x});
                      r  random_bst ({yA. y > x}  {yB. y > x});
                      if x  B = b then return_pmf l, x, r else mrbst_join l r
                    }"
      using finite A finite B A  {} unfolding R1_R2_def
      by (intro bind_pmf_cong refl psubset.IH) auto
    (* Step 8: move the case distinction on the membership of x in B above the subtree draws. *)
    also have " = do {
                      x  pmf_of_set A;
                      if x  B = b then do {
                        l  random_bst ({yA. y < x}  {yB. y < x});
                        r  random_bst ({yA. y > x}  {yB. y > x});
                        return_pmf l, x, r
                      } else do {
                        l  random_bst ({yA. y < x}  {yB. y < x});
                        r  random_bst ({yA. y > x}  {yB. y > x});
                        mrbst_join l r
                      }
                    }"
      by simp
    (* Step 9: by mrbst_join_correct, joining the two random BSTs gives a single random BST. *)
    also have " = do {
                      x  pmf_of_set A;
                      if x  B = b then do {
                        l  random_bst ({yA. y < x}  {yB. y < x});
                        r  random_bst ({yA. y > x}  {yB. y > x});
                        return_pmf l, x, r
                      } else do {
                        random_bst ({yA. y < x}  {yB. y < x}  {yA. y > x}  {yB. y > x})
                      }
                    }"
      using finite A finite B
      by (intro bind_pmf_cong refl mrbst_join_correct if_cong) (auto simp: setop_def)
    (* Step 10: simplify the set arguments in terms of the set operation. *)
    also have " = do {
                      x  pmf_of_set A;
                      if x  B = b then do {
                        l  random_bst ({yA  B. y < x});
                        r  random_bst ({yA  B. y > x});
                        return_pmf l, x, r
                      } else do {
                        random_bst (A  B)
                      }
                    }" (is "_ = pmf_of_set A  ?f")
      using finite A A  {}
      by (intro bind_pmf_cong refl if_cong arg_cong[of _ _ random_bst])
         (auto simp: order.strict_iff_order setop_def)
    (* Step 11: replace the uniform draw from A by the Bernoulli mixture eq.  The membership
       test on x is then decided by the Bernoulli outcome b'. *)
    also have " = do {
                      b'  bernoulli_pmf (card (A  B) / card A);
                      x  (if b' then pmf_of_set (A  B) else pmf_of_set (A - B));
                      if b' = b then do {
                        l  random_bst ({yA  B. y < x});
                        r  random_bst ({yA  B. y > x});
                        return_pmf l, x, r
                      } else do {
                        random_bst (A  B)
                      }
                    }"
      unfolding bind_assoc_pmf eq using card A > 0 finite A finite B not_subset
      by (intro bind_pmf_cong refl if_cong)
         (auto intro: bind_pmf_cong split: if_splits simp: divide_simps card_gt_0_iff
               dest!: in_set_pmf_of_setD)
    (* Step 12: the pivot x is only used in the branch b' = b, so its draw moves into that
       branch. *)
    also have " = do {
                      b'  bernoulli_pmf (card (A  B) / card A);
                      if b' = b then do {
                        x  pmf_of_set (A  B);
                        l  random_bst ({yA  B. y < x});
                        r  random_bst ({yA  B. y > x});
                        return_pmf l, x, r
                      } else do {
                        random_bst (A  B)
                      }
                    }"
      by (intro bind_pmf_cong) (auto simp: setop_def)
    (* Step 13: fold random_bst_reduce in the first branch; both branches are now equal. *)
    also have " = do {
                      b'  bernoulli_pmf (card (A  B) / card A);
                      if b' = b then do {
                        random_bst (A  B)
                      } else do {
                        random_bst (A  B)
                      }
                    }"
      using finite A finite B A  {} not_subset card A > 0
      by (intro bind_pmf_cong refl if_cong random_bst_reduce [symmetric])
         (auto simp: setop_def field_simps)
    also have " = random_bst (A  B)" by simp
    finally show ?thesis .
  qed
qed


text ‹
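  We now define intersection and difference directly and show that each of them is an instance
  of the generic operation, so that their correctness follows from the generic results: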
›

(* Randomised intersection.  Split t2 at the root x of the first tree; sep records whether x
   occurs in t2.  Recurse on both halves.  Keep x as the root if sep holds, otherwise join the
   two results with mrbst_join.  mrbst_inter_altdef shows this equals mrbst_inter_diff True. *)
fun mrbst_inter where
  "mrbst_inter ⟨⟩ _ = return_pmf ⟨⟩"
| "mrbst_inter l1, x, r1 t2 =
     (case split_bst' x t2 of (sep, l2, r2) 
        do {
          l  mrbst_inter l1 l2;
          r  mrbst_inter r1 r2;
          if sep then return_pmf l, x, r else mrbst_join l r
        })"

(* Intersection with an empty first tree is always Leaf; stated for the partially applied function. *)
lemma mrbst_inter_Leaf_left [simp]:
  "mrbst_inter ⟨⟩ = (λ_. return_pmf ⟨⟩)"
  by (simp add: fun_eq_iff)

(* Intersection with an empty second tree is always Leaf; proved by induction on t1. *)
lemma mrbst_inter_Leaf_right [simp]:
  "mrbst_inter (t1 :: 'a :: linorder tree) ⟨⟩ = return_pmf ⟨⟩"
  by (induction t1) (auto simp: bind_return_pmf)

(* The recursive equation of mrbst_inter for a node, stated as an equation between functions
   of the second tree. *)
lemma mrbst_inter_reduce:
  "mrbst_inter l1, x, r1 =
     (λt2. case split_bst' x t2 of (sep, l2, r2) 
        do {
           l  mrbst_inter l1 l2;
           r  mrbst_inter r1 r2;
           if sep then return_pmf l, x, r else mrbst_join l r
         })"
  by (rule ext) simp

(* mrbst_inter is the b = True instance of the generic operation; proved by induction on the
   first tree. *)
lemma mrbst_inter_altdef: "mrbst_inter = mrbst_inter_diff True"
proof (intro ext)
  fix t1 t2 :: "'a tree"
  show "mrbst_inter t1 t2 = mrbst_inter_diff True t1 t2"
    by (induction t1 arbitrary: t2) auto
qed

(* Structural correctness of mrbst_inter: every result is a BST containing exactly the common
   elements.  Transferred from the generic lemmas via mrbst_inter_altdef. *)
corollary
  fixes t1 t2 :: "'a :: linorder tree"
  assumes "t'  set_pmf (mrbst_inter t1 t2)" "bst t1" "bst t2"
  shows   bst_mrbst_inter: "bst t'"
    and   set_mrbst_inter: "set_tree t' = set_tree t1  set_tree t2"
  using bst_mrbst_inter_diff[of t' True t1 t2] set_mrbst_inter_diff[of t' True t1 t2] assms
  by (simp_all add: mrbst_inter_altdef)

(* Distributional correctness of mrbst_inter, obtained from mrbst_inter_diff_correct with
   b = True. *)
corollary mrbst_inter_correct:
  fixes A B :: "'a :: linorder set"
  assumes "finite A" "finite B"
  shows   "do {t1  random_bst A; t2  random_bst B; mrbst_inter t1 t2} = random_bst (A  B)"
  using assms unfolding mrbst_inter_altdef by (subst mrbst_inter_diff_correct) simp_all


(* Randomised difference.  Same recursion as mrbst_inter, but the branches of the final if are
   swapped.  If the root x occurs in t2 (sep), the recursive results are joined; otherwise x is
   kept as the root.  mrbst_diff_altdef shows this equals mrbst_inter_diff False. *)
fun mrbst_diff where
  "mrbst_diff ⟨⟩ _ = return_pmf ⟨⟩"
| "mrbst_diff l1, x, r1 t2 =
     (case split_bst' x t2 of (sep, l2, r2) 
        do {
          l  mrbst_diff l1 l2;
          r  mrbst_diff r1 r2;
          if sep then mrbst_join l r else return_pmf l, x, r
        })"

(* Difference of an empty first tree is always Leaf; stated for the partially applied function. *)
lemma mrbst_diff_Leaf_left [simp]:
  "mrbst_diff ⟨⟩ = (λ_. return_pmf ⟨⟩)"
  by (simp add: fun_eq_iff)

(* Removing an empty tree returns the first tree unchanged; proved by induction on t1. *)
lemma mrbst_diff_Leaf_right [simp]:
  "mrbst_diff (t1 :: 'a :: linorder tree) ⟨⟩ = return_pmf t1"
  by (induction t1) (auto simp: bind_return_pmf)

(* The recursive equation of mrbst_diff for a node, stated as an equation between functions
   of the second tree. *)
lemma mrbst_diff_reduce:
  "mrbst_diff l1, x, r1 =
     (λt2. case split_bst' x t2 of (sep, l2, r2) 
        do {
           l  mrbst_diff l1 l2;
           r  mrbst_diff r1 r2;
           if sep then mrbst_join l r else return_pmf l, x, r
         })"
  by (rule ext) simp

(* A negated if-condition can be removed by swapping the branches.  Used in mrbst_diff_altdef
   to match the swapped branches of mrbst_diff against mrbst_inter_diff False. *)
lemma If_not: "(if ¬b then x else y) = (if b then y else x)"
  by auto

(* mrbst_diff is the b = False instance of the generic operation; proved by induction on the
   first tree. *)
lemma mrbst_diff_altdef: "mrbst_diff = mrbst_inter_diff False"
proof (intro ext)
  fix t1 t2 :: "'a tree"
  show "mrbst_diff t1 t2 = mrbst_inter_diff False t1 t2"
    by (induction t1 arbitrary: t2) (auto simp: If_not)
qed

(* Structural correctness of mrbst_diff: every result is a BST whose elements are those of t1
   not in t2.  Transferred from the generic lemmas via mrbst_diff_altdef. *)
corollary
  fixes t1 t2 :: "'a :: linorder tree"
  assumes "t'  set_pmf (mrbst_diff t1 t2)" "bst t1" "bst t2"
  shows   bst_mrbst_diff: "bst t'"
    and   set_mrbst_diff: "set_tree t' = set_tree t1 - set_tree t2"
  using bst_mrbst_inter_diff[of t' False t1 t2] set_mrbst_inter_diff[of t' False t1 t2] assms
  by (simp_all add: mrbst_diff_altdef)

(* Distributional correctness of mrbst_diff, obtained from mrbst_inter_diff_correct with
   b = False. *)
corollary mrbst_diff_correct:
  fixes A B :: "'a :: linorder set"
  assumes "finite A" "finite B"
  shows   "do {t1  random_bst A; t2  random_bst B; mrbst_diff t1 t2} = random_bst (A - B)"
  using assms unfolding mrbst_diff_altdef by (subst mrbst_inter_diff_correct) simp_all


subsection ‹Union›

text ‹
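  The algorithm for the union of two trees is by far the most complicated one. It involves a
  random choice of which root becomes the root of the result. If the two trees have sizes $m$
  and $n$, the root of the first tree is chosen with probability $m/(m+n)$, and the second tree
  is split at it. Otherwise, the root $y$ of the second tree is chosen and the first tree is
  split at $y$. If $y$ also occurs in the first tree, the recursively computed subtrees are
  combined by pushing $y$ down into them, instead of placing them directly below $y$.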
›

(*<*)
(* This context temporarily declares case_prod_unfold as a termination_simp rule and if_splits
   as a split rule while mrbst_union is defined.  Presumably this is needed for the automatic
   termination proof of the definition below. *)
context
  notes
    case_prod_unfold [termination_simp]
    if_splits [split]
begin
(*>*)

(* Randomised union.  If either tree is a Leaf, the other tree is returned.  Otherwise, with
   probability m / (m + n) (the sizes of the two trees), the root x of the first tree is kept.
   The second tree is split at x and the union is computed recursively on both sides.
   Otherwise the root y of the second tree is kept and the first tree is split at y with
   split_bst'; sep records whether y also occurs in the first tree.  In that case the
   recursive results are combined with mrbst_push_down; otherwise y is placed directly at the
   root. *)
fun mrbst_union where
  "mrbst_union ⟨⟩ t2 = return_pmf t2"
| "mrbst_union t1 ⟨⟩ = return_pmf t1"
| "mrbst_union l1, x, r1 l2, y, r2 =
     do {
       let m = size l1, x, r1; let n = size l2, y, r2;
       b  bernoulli_pmf (m / (m + n));
       if b then do {
         let (l2', r2') = split_bst x l2, y, r2;
         l  mrbst_union l1 l2';
         r  mrbst_union r1 r2';
         return_pmf l, x, r
       } else do {
         let (sep, l1', r1') = split_bst' y l1, x, r1;
         l  mrbst_union l1' l2;
         r  mrbst_union r1' r2;
         if sep then
           mrbst_push_down l y r
         else
           return_pmf l, y, r
       }
     }"

(*<*)
end
(*>*)

(* Union with an empty first tree returns the second tree; stated for the partially applied function. *)
lemma mrbst_union_Leaf_left [simp]: "mrbst_union ⟨⟩ = return_pmf"
  by (rule ext) simp

(* Union with an empty second tree returns the first tree.  The case split on t1 is needed
   because the defining equations are sequential: the Leaf/Leaf case is covered by the first
   equation, not the second. *)
lemma mrbst_union_Leaf_right [simp]: "mrbst_union t1 ⟨⟩ = return_pmf t1"
  by (cases t1) simp_all

(* Structural correctness of mrbst_union: every result obtained from two BSTs is a BST whose
   element set is the union of the input sets.  Proved by strong induction on
   size t1 + size t2.
   NOTE(review): the fixed variable b :: bool does not occur in the statement and looks like a
   leftover; it is left untouched here. *)
lemma
  fixes t1 t2 :: "'a :: linorder tree" and b :: bool
  assumes "t'  set_pmf (mrbst_union t1 t2)" "bst t1" "bst t2"
  shows   bst_mrbst_union: "bst t'"
    and   set_mrbst_union: "set_tree t' = set_tree t1  set_tree t2"
proof -
  have "bst t'  set_tree t' = set_tree t1  set_tree t2"
  using assms
  proof (induction "size t1 + size t2" arbitrary: t1 t2 t' rule: less_induct)
    case (less t1 t2 t')
    show ?case
    proof (cases "t1 = ⟨⟩  t2 = ⟨⟩")
      case False
      then obtain l1 x r1 l2 y r2 where t1: "t1 = l1, x, r1" and t2: "t2 = l2, y, r2"
        by (cases t1; cases t2) auto
      (* The two cases match the two outcomes of the Bernoulli choice in mrbst_union.  Either
         the root x of t1 is kept, or the root y of t2 is kept. *)
      from less.prems consider l r where
        "l  set_pmf (mrbst_union l1 (fst (split_bst x t2)))"
        "r  set_pmf (mrbst_union r1 (snd (split_bst x t2)))"
        "t' = l, x, r"
      | l r where
        "l  set_pmf (mrbst_union (fst (split_bst y t1)) l2)"
        "r  set_pmf (mrbst_union (snd (split_bst y t1)) r2)"
        "t'  (if isin l1, x, r1 y then set_pmf (mrbst_push_down l y r) else {l, y, r})"
        by (auto simp: case_prod_unfold t1 t2 Let_def
                 simp del: split_bst.simps split_bst'.simps isin.simps split: if_splits)
      thus ?thesis
      proof cases
        case 1
        hence lr: "bst l  set_tree l = set_tree l1  set_tree (fst (split_bst x t2))"
                  "bst r  set_tree r = set_tree r1  set_tree (snd (split_bst x t2))"
          using less.prems size_split_bst[of x t2]
          by (intro less; force simp: t1)+
        thus ?thesis
          using 1 less.prems by (auto simp: t1 set_split_bst1 set_split_bst2)
      next
        case 2
        hence lr: "bst l  set_tree l = set_tree (fst (split_bst y t1))  set_tree l2"
                  "bst r  set_tree r = set_tree (snd (split_bst y t1))  set_tree r2"
          using less.prems size_split_bst[of y t1]
          by (intro less; force simp: t2)+
        show ?thesis
        proof (cases "isin l1, x, r1 y")
          case False
          thus ?thesis using 2 less.prems lr
            by (auto simp del: isin.simps simp: t2 set_split_bst1 set_split_bst2)
        next
          case True
          (* every element of l is below y and every element of r is above y, as required by
             the push-down lemmas *)
          have bst': "zset_tree l. z < y" "zset_tree r. z > y"
            using lr less.prems by (auto simp: set_split_bst1 set_split_bst2 t2)
          from True and 2 have t': "t'  set_pmf (mrbst_push_down l y r)"
            by (auto simp del: isin.simps)
          from t' have "bst t'"
            by (rule bst_mrbst_push_down) (use lr bst' in auto)
          moreover from t' have "set_tree t' = {y}  set_tree l  set_tree r"
            by (rule set_mrbst_push_down) (use lr bst' in auto)
          ultimately show ?thesis using less.prems lr
            by (auto simp del: isin.simps simp: t2 set_split_bst1 set_split_bst2)
        qed
      qed
    qed (use less.prems in auto)
  qed
  thus "bst t'" and "set_tree t' = set_tree t1  set_tree t2" by auto
qed

theorem mrbst_union_correct:
  assumes "finite A" "finite B"
  shows   "do {t1  random_bst A; t2  random_bst B; mrbst_union t1 t2} =
             random_bst (A  B)"
proof -
  from assms have "finite (A  B)" by simp
  thus ?thesis
  proof (induction "A  B" arbitrary: A B rule: finite_psubset_induct)
    case (psubset A B)
    show ?case
    proof (cases "A = {}  B = {}")
      case True
      thus ?thesis including monad_normalisation by auto
    next
      case False
      with psubset.hyps have AB: "finite A" "finite B" "A  {}" "B  {}" by auto
      define m n l where "m = card A" and "n = card B" and "l = card (A  B)"
      define p q where "p = m / (m + n)" and "q = l / n"
      define r where "r = p / (1 - (1 - p) * q)"
      from AB have mn: "m > 0" "n > 0" by (auto simp: m_def n_def)
      have pq: "p  {0..1}" "q  {0..1}"
        using AB by (auto simp: p_def q_def m_def n_def l_def divide_simps intro: card_mono)
      moreover have "p  0"
        using AB by (auto simp: p_def m_def n_def divide_simps add_nonneg_eq_0_iff)
      ultimately have "p > 0" by auto

      have "B - A = B - (A  B)" by auto
      also have "card  = n - l"
        using AB unfolding n_def l_def by (intro card_Diff_subset) auto
      finally have [simp]: "card (B - A) = n - l" .
      from AB have "l  n" unfolding l_def n_def by (intro card_mono) auto

      have "p  1 - (1 - p) * q"
        using mn l  n by (auto simp: p_def q_def divide_simps)
      hence r_aux: "(1 - p) * q  {0..1 - p}"
        using pq by auto

      include monad_normalisation
      define RA1 RA2 RB1 RB2
        where "RA1 = (λx. random_bst {zA. z < x})" and "RA2 = (λx. random_bst {zA. z > x})"
          and "RB1 = (λx. random_bst {zB. z < x})" and "RB2 = (λx. random_bst {zB. z > x})"

      have "do {t1  random_bst A; t2  random_bst B; mrbst_union t1 t2} =
              do {
                x  pmf_of_set A;
                l1  random_bst {zA. z < x};
                r1  random_bst {zA. z > x};
                y  pmf_of_set B;
                l2  random_bst {zB. z < y};
                r2  random_bst {zB. z > y};
                let m = size l1, x, r1;
                let n = size l2, y, r2;
                b  bernoulli_pmf (m / (m + n));
                if b then do {
                  l  mrbst_union l1 (fst (split_bst x l2, y, r2));
                  r  mrbst_union r1 (snd (split_bst x l2, y, r2));
                  return_pmf l, x, r
                } else do {
                  l  mrbst_union (fst (split_bst y l1, x, r1)) l2;
                  r  mrbst_union (snd (split_bst y l1, x, r1)) r2;
                  if isin l1, x, r1 y then
                    mrbst_push_down l y r
                  else
                    return_pmf l, y, r
                }
              }" using AB
        by (simp add: random_bst_reduce split_bst'_altdef Let_def case_prod_unfold cong: if_cong)
      also have " = do {
                        x  pmf_of_set A;
                        l1  random_bst {zA. z < x};
                        r1  random_bst {zA. z > x};
                        y  pmf_of_set B;
                        l2  random_bst {zB. z < y};
                        r2  random_bst {zB. z > y};
                        b  bernoulli_pmf p;
                        if b then do {
                          l  mrbst_union l1 (fst (split_bst x l2, y, r2));
                          r  mrbst_union r1 (snd (split_bst x l2, y, r2));
                          return_pmf l, x, r
                        } else do {
                          l  mrbst_union (fst (split_bst y l1, x, r1)) l2;
                          r  mrbst_union (snd (split_bst y l1, x, r1)) r2;
                          if y  A then
                            mrbst_push_down l y r
                          else
                            return_pmf l, y, r
                        }
                      }"
        unfolding Let_def
      proof (intro bind_pmf_cong refl if_cong)
        fix l1 x r1 y
        assume "l1  set_pmf (random_bst {zA. z < x})" "r1  set_pmf (random_bst {zA. z > x})"
               "x  set_pmf (pmf_of_set A)"
        thus "isin l1, x, r1 y  (y  A)"
          using AB by (subst isin_bst) (auto simp: bst_random_bst set_random_bst)
      qed (insert AB,
           auto simp: size_random_bst m_def n_def p_def isin_random_bst dest!: card_3way_split)
      also have " = do {
                        b  bernoulli_pmf p;
                        if b then do {
                          x  pmf_of_set A;
                          (l1, r1)  pair_pmf (random_bst {zA. z < x}) (random_bst {zA. z > x});
                          (l2, r2)  map_pmf (split_bst x) (random_bst B);
                          l  mrbst_union l1 l2;
                          r  mrbst_union r1 r2;
                          return_pmf l, x, r
                        } else do {
                          y  pmf_of_set B;
                          (l1, r1)  map_pmf (split_bst y) (random_bst A);
                          (l2, r2)  pair_pmf (random_bst {zB. z < y}) (random_bst {zB. z > y});
                          l  mrbst_union l1 l2;
                          r  mrbst_union r1 r2;
                          if y  A then
                            mrbst_push_down l y r
                          else
                            return_pmf l, y, r
                        }
                      }" using AB
        by (simp add: random_bst_reduce map_pmf_def case_prod_unfold pair_pmf_def cong: if_cong)
      also have " = do {
                        b  bernoulli_pmf p;
                        if b then do {
                          x  pmf_of_set A;
                          (l1, r1)  pair_pmf (RA1 x) (RA2 x);
                          (l2, r2)  pair_pmf (RB1 x) (RB2 x);
                          l  mrbst_union l1 l2;
                          r  mrbst_union r1 r2;
                          return_pmf l, x, r
                        } else do {
                          y  pmf_of_set B;
                          (l1, r1)  pair_pmf (RA1 y) (RA2 y);
                          (l2, r2)  pair_pmf (RB1 y) (RB2 y);
                          l  mrbst_union l1 l2;
                          r  mrbst_union r1 r2;
                          if y  A then
                            mrbst_push_down l y r
                          else
                            return_pmf l, y, r
                        }
                      }"
        unfolding case_prod_unfold RA1_def RA2_def RB1_def RB2_def
        by (intro bind_pmf_cong refl if_cong split_random_bst AB)
      also have " = do {
                        b  bernoulli_pmf p;
                        if b then do {
                          x  pmf_of_set A;
                          l  do {l1  RA1 x; l2  RB1 x; mrbst_union l1 l2};
                          r  do {r1  RA2 x; r2  RB2 x; mrbst_union r1 r2};
                          return_pmf l, x, r
                        } else do {
                          y  pmf_of_set B;
                          l  do {l1  RA1 y; l2  RB1 y; mrbst_union l1 l2};
                          r  do {r1  RA2 y; r2  RB2 y; mrbst_union r1 r2};
                          if y  A then
                            mrbst_push_down l y r
                          else
                            return_pmf l, y, r
                        }
                      }"
        by (simp add: pair_pmf_def cong: if_cong)
      also have " = do {
                        b  bernoulli_pmf p;
                        if b then do {
                          x  pmf_of_set A;
                          l  random_bst ({zA. z < x}  {zB. z < x});
                          r  random_bst ({zA. z > x}  {zB. z > x});
                          return_pmf l, x, r
                        } else do {
                          y  pmf_of_set B;
                          l  random_bst ({zA. z < y}  {zB. z < y});
                          r  random_bst ({zA. z > y}  {zB. z > y});
                          if y  A then
                            mrbst_push_down l y r
                          else
                            return_pmf l, y, r
                        }
                      }"
        unfolding RA1_def RA2_def RB1_def RB2_def using AB
        by (intro bind_pmf_cong if_cong refl psubset) auto
      also have " = do {
                        b  bernoulli_pmf p;
                        if b then do {
                          x  pmf_of_set A;
                          l  random_bst {zA  B. z < x};
                          r  random_bst {zA  B. z > x};
                          return_pmf l, x, r
                        } else do {
                          y  pmf_of_set B;
                          l  random_bst {zA  B. z < y};
                          r  random_bst {zA  B. z > y};
                          if y  A then
                            mrbst_push_down l y r
                          else
                            return_pmf l, y, r
                        }
                      }"
        by (intro bind_pmf_cong if_cong refl arg_cong[of _ _ random_bst]) auto
      also have " = do {
                        b  bernoulli_pmf p;
                        if b then do {
                          x  pmf_of_set A;
                          l  random_bst {zA  B. z < x};
                          r  random_bst {zA  B. z > x};
                          return_pmf l, x, r
                        } else do {
                          b'  bernoulli_pmf q;
                          if b' then do {
                            y  pmf_of_set (A  B);
                            random_bst (A  B)
                          } else do {
                            y  pmf_of_set (B - A);
                            l  random_bst {zA  B. z < y};
                            r  random_bst {zA  B. z > y};
                            return_pmf l, y, r
                          }
                        }
                      }"
      proof (intro bind_pmf_cong refl if_cong, goal_cases)
        case (1 b)
        have q_pos: "A  B  {}" if "q > 0" using that by (auto simp: q_def l_def)
        have q_lt1: "B - A  {}" if "q < 1"
        proof
          assume "B - A = {}"
          hence "A  B = B" by auto
          thus False using that AB by (auto simp: q_def l_def n_def)
        qed

        have eq: "pmf_of_set B = do {b'  bernoulli_pmf q;
                                     if b' then pmf_of_set (A  B) else pmf_of_set (B - A)}"
          using AB by (intro pmf_of_set_split_inter_diff [symmetric])
                      (auto simp: q_def l_def n_def)
        have "do {y  pmf_of_set B;
                  l  random_bst {zA  B. z < y};
                  r  random_bst {zA  B. z > y};
                  if y  A then
                    mrbst_push_down l y r
                  else
                    return_pmf l, y, r
                 } =
              do {
                b'  bernoulli_pmf q;
                y  (if b' then pmf_of_set (A  B) else pmf_of_set (B - A));
                l  random_bst {zA  B. z < y};
                r  random_bst {zA  B. z > y};
                if b' then
                  mrbst_push_down l y r
                else
                  return_pmf l, y, r
              }" unfolding eq bind_assoc_pmf using AB q_pos q_lt1
          by (intro bind_pmf_cong refl if_cong) (auto split: if_splits)
        also have " = do {
                          b'  bernoulli_pmf q;
                          if b' then do {
                            y  pmf_of_set (A  B);
                            do {l  random_bst {zA  B. z < y};
                                r  random_bst {zA  B. z > y};
                                mrbst_push_down l y r}
                          } else do {
                            y  pmf_of_set (B - A);
                            l  random_bst {zA  B. z < y};
                            r  random_bst {zA  B. z > y};
                            return_pmf l, y, r
                          }
                        }" by (simp cong: if_cong)
        also have " = do {
                          b'  bernoulli_pmf q;
                          if b' then do {
                            y  pmf_of_set (A  B);
                            random_bst (A  B)
                          } else do {
                            y  pmf_of_set (B - A);
                            l  random_bst {zA  B. z < y};
                            r  random_bst {zA  B. z > y};
                            return_pmf l, y, r
                          }
                        }"
          using AB q_pos by (intro bind_pmf_cong if_cong refl mrbst_push_down_correct') auto
        finally show ?case .
      qed
      also have " = do {
                        b  bernoulli_pmf p;
                        b'  bernoulli_pmf q;
                        if b then do {
                          x  pmf_of_set A;
                          l  random_bst {zA  B. z < x};
                          r  random_bst {zA  B. z > x};
                          return_pmf l, x, r
                        } else if b' then do {
                            random_bst (A  B)
                        } else do {
                          y  pmf_of_set (B - A);
                          l  random_bst {zA  B. z < y};
                          r  random_bst {zA  B. z > y};
                          return_pmf l, y, r
                        }
                      }"
        by (simp cong: if_cong)
      also have " = do {
                        (b, b')  pair_pmf (bernoulli_pmf p) (bernoulli_pmf q);
                        if b  ¬b' then do {
                          x  (if b then pmf_of_set A else pmf_of_set (B - A));
                          l  random_bst {zA  B. z < x};
                          r  random_bst {zA  B. z > x};
                          return_pmf l, x, r
                        } else do {
                            random_bst (A  B)
                        }
                      }" unfolding pair_pmf_def bind_assoc_pmf
        by (intro bind_pmf_cong) auto
      also have " = do {
                        (b, b')  map_pmf (λ(b, b'). (b  ¬b', b))
                                    (pair_pmf (bernoulli_pmf p) (bernoulli_pmf q));
                        if b then do {
                          x  (if b' then pmf_of_set A else pmf_of_set (B - A));
                          l  random_bst {zA  B. z < x};
                          r  random_bst {zA  B. z > x};
                          return_pmf l, x, r
                        } else do {
                            random_bst (A  B)
                        }
                      }" (is "_ = bind_pmf _ ?f")
        by (simp add: bind_map_pmf case_prod_unfold cong: if_cong)
      also have "map_pmf (λ(b, b'). (b  ¬b', b))
                   (pair_pmf (bernoulli_pmf p) (bernoulli_pmf q)) =
                 do {
                   b  bernoulli_pmf (1 - (1 - p) * q);
                   b'  (if b then bernoulli_pmf r else return_pmf False);
                   return_pmf (b, b')
                 }" (is "?lhs = ?rhs")
      proof (intro pmf_eqI)
        fix bb' :: "bool × bool" 
        obtain b b' where [simp]: "bb' = (b, b')" by (cases bb')
        thus "pmf ?lhs bb' = pmf ?rhs bb'"
          using pq r_aux p > 0
          by (cases b; cases b')
             (auto simp: pmf_map pmf_bind_bernoulli measure_measure_pmf_finite 
                         vimage_bool_pair pmf_pair r_def field_simps)
      qed
      also have "  ?f = do {
                              b  bernoulli_pmf (1 - (1 - p) * q);
                              if b then do {
                                x  do {b'  bernoulli_pmf r;
                                         if b' then pmf_of_set A else pmf_of_set (B - A)};
                                l  random_bst {zA  B. z < x};
                                r  random_bst {zA  B. z > x};
                                return_pmf l, x, r
                              } else do {
                                random_bst (A  B)
                              }
                            }"
        by (simp cong: if_cong)
      also have " = do {
                        b  bernoulli_pmf (1 - (1 - p) * q);
                        if b then do {
                          x  pmf_of_set (A  (B - A));
                          l  random_bst {zA  B. z < x};
                          r  random_bst {zA  B. z > x};
                          return_pmf l, x, r
                        } else do {
                          random_bst (A  B)
                        }
                      }" (is "_ = ?f (A  (B - A))")
        using AB pq l  n mn
        by (intro bind_pmf_cong if_cong refl pmf_of_set_union_split)
           (auto simp: m_def [symmetric] n_def [symmetric] r_def p_def q_def divide_simps)
      also have "A  (B - A) = A  B" by auto
      also have "?f  = random_bst (A  B)"
        using AB by (simp add: random_bst_reduce cong: if_cong)
      finally show ?thesis .
    qed
  qed
qed


subsection ‹Insertion and Deletion›

text ‹
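  The insertion and deletion operations are simple special cases of the union
  and difference operations in which one of the two trees is a singleton tree:
  inserting an element amounts to taking the union of the singleton tree
  containing it with the given tree, and deleting an element amounts to taking
  the difference of the given tree and that singleton tree.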
›
(*
  Randomised insertion of x into a tree, producing a probability
  distribution over trees.
  - On a leaf, the result is (with probability 1) the one-node tree holding x.
  - On a node with root y and subtrees l and r, a coin with success
    probability 1 / (size l + size r + 2) is tossed. Assuming size counts
    nodes, the denominator is the node's size plus one, i.e. presumably the
    size after inserting a new element, so x becomes the new root with
    probability 1/(new size) -- TODO confirm against Martinez & Roura.
    On success the tree is split at x with split_bst, and x is placed at the
    root between the two resulting parts.
  - Otherwise the insertion descends: into l if x < y, into r if x > y,
    and the node is rebuilt around the recursive result via map_pmf.
  - If neither x < y nor x > y (x = y for a linear order), the result is
    mrbst_push_down l y r.
  mrbst_insert_altdef shows that this coincides with mrbst_union applied to
  the singleton tree of x and the given tree.
*)
fun mrbst_insert where
  "mrbst_insert x ⟨⟩ = return_pmf ⟨⟩, x, ⟨⟩"
| "mrbst_insert x l, y, r =
     do {
       b  bernoulli_pmf (1 / real (size l + size r + 2));
       if b then do {
         let (l', r') = split_bst x l, y, r;
         return_pmf l', x, r'
       } else if x < y then do {
         map_pmf (λl'. l', y, r) (mrbst_insert x l)
       } else if x > y then do {
         map_pmf (λr'. l, y, r') (mrbst_insert x r)
       } else do {
         mrbst_push_down l y r
       }
     }"

(*
  Insertion is the special case of mrbst_union in which the first argument
  is the singleton tree containing x. Proved by induction following the
  recursion of mrbst_insert.
*)
lemma mrbst_insert_altdef: "mrbst_insert x t = mrbst_union ⟨⟩, x, ⟨⟩ t"
  by (induction x t rule: mrbst_insert.induct)
     (simp_all add: Let_def map_pmf_def bind_return_pmf case_prod_unfold cong: if_cong)

(*
  Structural correctness of insertion: if t is a BST, then every tree t'
  in the support of mrbst_insert x t is again a BST, and its element set is
  that of t with x added. Both facts are transferred from the corresponding
  union results (bst_mrbst_union, set_mrbst_union) via mrbst_insert_altdef.
*)
corollary
  fixes t :: "'a :: linorder tree"
  assumes "t'  set_pmf (mrbst_insert x t)" "bst t"
  shows   bst_mrbst_insert: "bst t'"
    and   set_mrbst_insert: "set_tree t' = insert x (set_tree t)"
  using bst_mrbst_union[of t' "⟨⟩, x, ⟨⟩" t] set_mrbst_union[of t' "⟨⟩, x, ⟨⟩" t] assms
  by (simp_all add: mrbst_insert_altdef)

(*
  Distributional correctness of insertion: for a finite set A, drawing a
  random BST for A and then inserting x with mrbst_insert yields exactly the
  distribution random_bst (insert x A). Obtained from mrbst_union_correct
  instantiated with the singleton set {x} and A.
*)
corollary mrbst_insert_correct:
  assumes "finite A"
  shows   "random_bst A  mrbst_insert x = random_bst (insert x A)"
  using mrbst_union_correct[of "{x}" A] assms
  by (simp add: mrbst_insert_altdef[abs_def] bind_return_pmf)


(*
  Deletion of x from a tree, producing a probability distribution over trees.
  The search for x involves no coin toss of its own. It descends into the
  left subtree if x < y and into the right subtree if x > y, rebuilding the
  node around the recursive result via map_pmf. A leaf is returned unchanged,
  so if no visited node's key is "neither smaller nor larger" than x, the
  tree is returned as it was, with probability 1.
  When such a node is reached (x = y for a linear order), it is replaced by
  mrbst_join l r, which combines its two subtrees. That call is the only
  branch that does more than rebuild the tree.
  Only the ord class is required here; the correctness results below
  assume a linorder.
*)
fun mrbst_delete :: "'a :: ord  'a tree  'a tree pmf" where
  "mrbst_delete x ⟨⟩ = return_pmf ⟨⟩"
| "mrbst_delete x l, y, r = (
     if x < y then
       map_pmf (λl'. l', y, r) (mrbst_delete x l)
     else if x > y then
       map_pmf (λr'. l, y, r') (mrbst_delete x r)
     else 
       mrbst_join l r)"

(*
  Deletion is the special case of mrbst_diff in which the second argument
  is the singleton tree containing x. Proved by structural induction on t.
*)
lemma mrbst_delete_altdef: "mrbst_delete x t = mrbst_diff t ⟨⟩, x, ⟨⟩"
  by (induction t) (auto simp: bind_return_pmf map_pmf_def)

(*
  Structural correctness of deletion: if t is a BST, then every tree t'
  in the support of mrbst_delete x t is again a BST, and its element set is
  that of t with x removed. Both facts are transferred from the corresponding
  difference results (bst_mrbst_diff, set_mrbst_diff) via mrbst_delete_altdef.
*)
corollary
  fixes t :: "'a :: linorder tree"
  assumes "t'  set_pmf (mrbst_delete x t)" "bst t"
  shows   bst_mrbst_delete: "bst t'"
    and   set_mrbst_delete: "set_tree t' = set_tree t - {x}"
  using bst_mrbst_diff[of t' t "⟨⟩, x, ⟨⟩"] set_mrbst_diff[of t' t "⟨⟩, x, ⟨⟩"] assms
  by (simp_all add: mrbst_delete_altdef)

(*
  Distributional correctness of deletion: for a finite set A, drawing a
  random BST for A and then deleting x with mrbst_delete yields exactly the
  distribution random_bst (A - {x}). Obtained from mrbst_diff_correct
  instantiated with A and the singleton set {x}.
*)
corollary mrbst_delete_correct:
  "finite A  do {t  random_bst A; mrbst_delete x t} = random_bst (A - {x})"
  using mrbst_diff_correct[of A "{x}"] by (simp add: mrbst_delete_altdef bind_return_pmf)

end