(* Theory Randomised_BSTs *)
section ‹Randomised Binary Search Trees›
theory Randomised_BSTs
imports "Random_BSTs.Random_BSTs" "Monad_Normalisation.Monad_Normalisation"
begin
subsection ‹Auxiliary facts›
text ‹
First of all, we need some fairly simple auxiliary lemmas.
›
(* return_pmf commutes with a conditional: it can be pushed into both branches.
   Proved by a case distinction on the condition P. *)
lemma return_pmf_if: "return_pmf (if P then a else b) = (if P then return_pmf a else return_pmf b)"
  by (cases P) simp_all
(* Support of bernoulli_pmf. Both facts are proved by transfer, which needs the
   pmf_as_function interpretation; the anonymous context keeps that interpretation local
   to these two lemmas. *)
context
begin
interpretation pmf_as_function .
(* True lies in the support of bernoulli_pmf p iff p > 0. *)
lemma True_in_set_bernoulli_pmf_iff [simp]:
"True ∈ set_pmf (bernoulli_pmf p) ⟷ p > 0"
by transfer auto
(* False lies in the support of bernoulli_pmf p iff p < 1. *)
lemma False_in_set_bernoulli_pmf_iff [simp]:
"False ∈ set_pmf (bernoulli_pmf p) ⟷ p < 1"
by transfer auto
end
(* Destruction rule: an element of the support of the uniform distribution on a finite,
   non-empty set A is an element of A. *)
lemma in_set_pmf_of_setD: "x ∈ set_pmf (pmf_of_set A) ⟹ finite A ⟹ A ≠ {} ⟹ x ∈ A"
by (subst (asm) set_pmf_of_set) auto
(* One unfolding step of random_bst for a finite, non-empty set A: the root x is drawn
   uniformly from A, and the left/right subtrees are random BSTs over the elements of A
   below/above x. *)
lemma random_bst_reduce:
"finite A ⟹ A ≠ {} ⟹
random_bst A = do {x ← pmf_of_set A; l ← random_bst {y∈A. y < x};
r ← random_bst {y∈A. y > x}; return_pmf ⟨l, x, r⟩}"
by (subst random_bst.simps) auto
(* Probability mass of a bind over a coin with bias x ∈ [0,1]: the convex combination of the
   masses of the True branch (weight x) and the False branch (weight 1 - x). *)
lemma pmf_bind_bernoulli:
assumes "x ∈ {0..1}"
shows "pmf (bernoulli_pmf x ⤜ f) y = x * pmf (f True) y + (1 - x) * pmf (f False) y"
using assms by (simp add: pmf_bind)
(* The preimage of A under a function on bool × bool, written as an explicit finite union
   over the four possible arguments. *)
lemma vimage_bool_pair:
"f -` A = (⋃x∈{True, False}. ⋃y∈{True, False}. if f (x, y) ∈ A then {(x, y)} else {})"
(is "?lhs = ?rhs") unfolding set_eq_iff
proof
fix x :: "bool × bool"
obtain a b where [simp]: "x = (a, b)" by (cases x)
(* Membership is checked pointwise by case analysis on both components. *)
show "x ∈ ?lhs ⟷ x ∈ ?rhs"
by (cases a; cases b) auto
qed
(* random_bst A can produce the empty tree iff A is empty or infinite. *)
lemma Leaf_in_set_random_bst_iff [simp]:
"Leaf ∈ set_pmf (random_bst A) ⟷ A = {} ∨ ¬finite A"
by (subst random_bst.simps) auto
(* Insertion preserves the BST invariant. The proof goes through the characterisation of
   BSTs as trees whose in-order traversal is strictly sorted. *)
lemma bst_insert [intro]: "bst t ⟹ bst (Tree_Set.insert x t)"
by (simp add: bst_iff_sorted_wrt_less inorder_insert sorted_ins_list)
(* bst_of_list always produces a BST. *)
lemma bst_bst_of_list [intro]: "bst (bst_of_list xs)"
proof -
(* Generalised statement: folding insertions over any BST start tree t keeps it a BST.
   The induction needs t to be arbitrary. *)
have "bst (fold Tree_Set.insert xs t)" if "bst t" for t
using that
proof (induction xs arbitrary: t)
case (Cons y xs)
show ?case by (auto intro!: Cons bst_insert)
qed auto
thus ?thesis by (simp add: bst_of_list_altdef)
qed
(* Every tree in the support of random_bst A is a BST. *)
lemma bst_random_bst:
assumes "t ∈ set_pmf (random_bst A)"
shows "bst t"
proof (cases "finite A")
case True
(* Finite case: random_bst A is bst_of_list applied to a uniformly random permutation of A,
   so its support consists of trees built by bst_of_list. *)
have "random_bst A = map_pmf bst_of_list (pmf_of_set (permutations_of_set A))"
by (rule random_bst_altdef) fact+
also have "set_pmf … = bst_of_list ` permutations_of_set A"
using True by auto
finally show ?thesis using assms by auto
next
case False
(* Infinite case: random_bst A is the empty tree. *)
hence "random_bst A = return_pmf ⟨⟩"
by (simp add: random_bst.simps)
with assms show ?thesis by simp
qed
(* For finite A, every tree in the support of random_bst A contains exactly the elements of A. *)
lemma set_random_bst:
assumes "t ∈ set_pmf (random_bst A)" "finite A"
shows "set_tree t = A"
proof -
(* As in bst_random_bst: the support consists of bst_of_list applied to permutations of A. *)
have "random_bst A = map_pmf bst_of_list (pmf_of_set (permutations_of_set A))"
by (rule random_bst_altdef) fact+
also have "set_pmf … = bst_of_list ` permutations_of_set A"
using assms by auto
finally show ?thesis using assms
by (auto simp: permutations_of_setD)
qed
(* On BSTs, the search function isin agrees with set membership. *)
lemma isin_bst:
assumes "bst t"
shows "isin t x ⟷ x ∈ set_tree t"
using assms
by (subst isin_set) (auto simp: bst_iff_sorted_wrt_less)
(* Searching a random BST over a finite set A is exactly a membership test for A. *)
lemma isin_random_bst:
assumes "finite A" "t ∈ set_pmf (random_bst A)"
shows "isin t x ⟷ x ∈ A"
proof -
from assms have "bst t" by (auto dest: bst_random_bst)
with assms show ?thesis by (simp add: isin_bst set_random_bst)
qed
(* A pivot x ∈ A splits the finite linearly ordered set A into the elements below x, the
   elements above x, and x itself; the cardinalities add up accordingly. *)
lemma card_3way_split:
assumes "x ∈ (A :: 'a :: linorder set)" "finite A"
shows "card A = card {y∈A. y < x} + card {y∈A. y > x} + 1"
proof -
from assms have "A = insert x ({y∈A. y < x} ∪ {y∈A. y > x})"
by auto
(* The three parts are pairwise disjoint, so their cardinalities add up. *)
also have "card … = card {y∈A. y < x} + card {y∈A. y > x} + 1"
using assms by (subst card_insert_disjoint) (auto intro: card_Un_disjoint)
finally show ?thesis .
qed
text ‹
The following theorem allows replacing a uniformly random choice from the union of two disjoint
sets by first tossing a (suitably biased) coin to decide on one of the constituent sets and then
choosing an element from it uniformly at random.
›
(* Uniform choice from a disjoint union A ∪ B equals: toss a coin with bias
   p = |A| / (|A| + |B|), then choose uniformly from A (on True) or from B (on False). *)
lemma pmf_of_set_union_split:
assumes "finite A" "finite B" "A ∩ B = {}" "A ∪ B ≠ {}"
assumes "p = card A / (card A + card B)"
shows "do {b ← bernoulli_pmf p; if b then pmf_of_set A else pmf_of_set B} = pmf_of_set (A ∪ B)"
(is "?lhs = ?rhs")
proof (rule pmf_eqI)
fix x :: 'a
(* p is a valid probability, needed to evaluate the integral over bernoulli_pmf p. *)
from assms have p: "p ∈ {0..1}"
by (auto simp: divide_simps assms(5) split: if_splits)
have "pmf ?lhs x = pmf (pmf_of_set A) x * p + pmf (pmf_of_set B) x * (1 - p)"
unfolding pmf_bind using p by (subst integral_bernoulli_pmf) auto
(* Compare pointwise, by case analysis on where x lies and on which set may be empty. *)
also consider "x ∈ A" "B ≠ {}" | "x ∈ B" "A ≠ {}" | "x ∈ A" "B = {}" | "x ∈ B" "A = {}" |
"x ∉ A" "x ∉ B"
using assms by auto
hence "pmf (pmf_of_set A) x * p + pmf (pmf_of_set B) x * (1 - p) = pmf ?rhs x"
proof cases
assume "x ∉ A" "x ∉ B"
thus ?thesis using assms by (cases "A = {}"; cases "B = {}") auto
next
assume "x ∈ A" and [simp]: "B ≠ {}"
have "pmf (pmf_of_set A) x * p + pmf (pmf_of_set B) x * (1 - p) = p / real (card A)"
using ‹x ∈ A› assms(1-4) by (subst (1 2) pmf_of_set) (auto simp: indicator_def)
also have "… = pmf ?rhs x"
using assms ‹x ∈ A› by (subst pmf_of_set) (auto simp: card_Un_disjoint)
finally show ?thesis .
next
assume "x ∈ B" and [simp]: "A ≠ {}"
from assms have *: "card (A ∪ B) > 0" by (subst card_gt_0_iff) auto
have "pmf (pmf_of_set A) x * p + pmf (pmf_of_set B) x * (1 - p) = (1 - p) / real (card B)"
using ‹x ∈ B› assms(1-4) by (subst (1 2) pmf_of_set) (auto simp: indicator_def)
also have "… = pmf ?rhs x"
using assms ‹x ∈ B› *
by (subst pmf_of_set) (auto simp: card_Un_disjoint assms(5) divide_simps)
finally show ?thesis .
qed (insert assms(1-4), auto simp: assms(5))
finally show "pmf ?lhs x = pmf ?rhs x" .
qed
(* Instance of pmf_of_set_union_split for the partition of B into A ∩ B and B - A: a uniform
   element of B is obtained by tossing a coin with bias |A ∩ B| / |B| and then choosing
   uniformly from the corresponding part. *)
lemma pmf_of_set_split_inter_diff:
assumes "finite A" "finite B" "A ≠ {}" "B ≠ {}"
assumes "p = card (A ∩ B) / card B"
shows "do {b ← bernoulli_pmf p; if b then pmf_of_set (A ∩ B) else pmf_of_set (B - A)} =
pmf_of_set B" (is "?lhs = ?rhs")
proof -
have eq: "B = (A ∩ B) ∪ (B - A)" by auto
have card_eq: "card B = card (A ∩ B) + card (B - A)"
using assms by (subst eq, subst card_Un_disjoint) auto
have "?lhs = pmf_of_set ((A ∩ B) ∪ (B - A))"
using assms by (intro pmf_of_set_union_split) (auto simp: card_eq)
with eq show ?thesis by simp
qed
text ‹
Similarly to the above rule, we can split up a uniformly random choice from the disjoint
union of three sets. This could be done with two coin flips, but it is more convenient to
choose a natural number uniformly at random instead and then do a case distinction on it.
›
(* Uniform choice from a partition A = A1 ∪ A2 ∪ A3 (pairwise disjoint), followed by a
   part-dependent continuation, equals: draw i uniformly from {..<card A}; the ranges
   [0, |A1|), [|A1|, |A1|+|A2|) and [|A1|+|A2|, |A|) select the part, and then an element
   of that part is chosen uniformly. *)
lemma pmf_of_set_3way_split:
fixes f g h :: "'a ⇒ 'b pmf"
assumes "finite A" "A ≠ {}" "A1 ∩ A2 = {}" "A1 ∩ A3 = {}" "A2 ∩ A3 = {}" "A1 ∪ A2 ∪ A3 = A"
shows "do {x ← pmf_of_set A; if x ∈ A1 then f x else if x ∈ A2 then g x else h x} =
do {i ← pmf_of_set {..<card A};
if i < card A1 then pmf_of_set A1 ⤜ f
else if i < card A1 + card A2 then pmf_of_set A2 ⤜ g
else pmf_of_set A3 ⤜ h}" (is "?lhs = ?rhs")
proof (intro pmf_eqI)
fix x :: 'b
define m n l where "m = card A1" and "n = card A2" and "l = card A3"
have [simp]: "finite A1" "finite A2" "finite A3"
by (rule finite_subset[of _ A]; use assms in force)+
from assms have card_pos: "card A > 0" by auto
have A_eq: "A = A1 ∪ A2 ∪ A3" using assms by simp
have card_A_eq: "card A = card A1 + card A2 + card A3"
using assms unfolding A_eq by (subst card_Un_disjoint, simp, simp, force)+ auto
(* The index range splits into three consecutive intervals of lengths m, n and l. *)
have card_A_eq': "{..<card A} = {..<m} ∪ {m..<m + n} ∪ {m + n..<card A}"
by (auto simp: m_def n_def card_A_eq)
let ?M = "λi. if i < m then pmf_of_set A1 ⤜ f else if i < m + n then
pmf_of_set A2 ⤜ g else pmf_of_set A3 ⤜ h"
(* Helper: |X| times the mass of a uniform bind over X is the sum of the branch masses. *)
have card_times_pmf_of_set_bind:
"card X * pmf (pmf_of_set X ⤜ f) x = (∑y∈X. pmf (f y) x)"
if "finite X" for X :: "'a set" and f :: "'a ⇒ 'b pmf"
using that by (cases "X = {}") (auto simp: pmf_bind_pmf_of_set)
(* Compute the mass of ?rhs as an average over the indices, split by interval, rewrite
   each part as a sum over the corresponding set, and recombine into the mass of ?lhs. *)
have "pmf ?rhs x = (∑i<card A. pmf (?M i) x) / card A"
(is "_ = ?S / _") using assms card_pos unfolding m_def n_def
by (subst pmf_bind_pmf_of_set) auto
also have "?S = (real m * pmf (pmf_of_set A1 ⤜ f) x +
real n * pmf (pmf_of_set A2 ⤜ g) x +
real l * pmf (pmf_of_set A3 ⤜ h) x)" unfolding card_A_eq'
by (subst sum.union_disjoint, simp, simp, force)+ (auto simp: card_A_eq m_def n_def l_def)
also have "… = (∑y∈A1. pmf (f y) x) + (∑y∈A2. pmf (g y) x) + (∑y∈A3. pmf (h y) x)"
unfolding m_def n_def l_def by (subst (1 2 3) card_times_pmf_of_set_bind) auto
also have "… = (∑y∈A1 ∪ A2 ∪ A3.
pmf (if y ∈ A1 then f y else if y ∈ A2 then g y else h y) x)"
using assms(1-5)
by (subst sum.union_disjoint, simp, simp, force)+
(intro arg_cong2[of _ _ _ _ "(+)"] sum.cong, auto)
also have "… / card A = pmf ?lhs x"
using assms by (simp add: pmf_bind_pmf_of_set)
finally show "pmf ?lhs x = pmf ?rhs x"
unfolding m_def n_def l_def card_A_eq ..
qed
subsection ‹Partitioning a BST›
text ‹
The split operation takes a search parameter ‹x› and partitions a BST into two BSTs
containing all the values that are smaller than ‹x› and those that are greater than ‹x›,
respectively. Note that ‹x› need not be an element of the tree.
›
(* split_bst x t returns the pair (elements of t below x, elements of t above x), descending
   along the search path for x. If x itself is found at a node, it is dropped and that node's
   subtrees are returned unchanged. The set-level specification (set_split_bst1/2 below)
   assumes that t is a BST. *)
fun split_bst :: "'a :: linorder ⇒ 'a tree ⇒ 'a tree × 'a tree" where
"split_bst _ ⟨⟩ = (⟨⟩, ⟨⟩)"
| "split_bst x ⟨l, y, r⟩ =
(if y < x then
case split_bst x r of (t1, t2) ⇒ (⟨l, y, t1⟩, t2)
else if y > x then
case split_bst x l of (t1, t2) ⇒ (t1, ⟨t2, y, r⟩)
else
(l, r))"
(* Variant of split_bst that also reports whether x was encountered on the search path.
   The Boolean is True exactly in the case where x is found at a node; see
   split_bst'_altdef for the relation to isin and split_bst. *)
fun split_bst' :: "'a :: linorder ⇒ 'a tree ⇒ bool × 'a tree × 'a tree" where
"split_bst' _ ⟨⟩ = (False, ⟨⟩, ⟨⟩)"
| "split_bst' x ⟨l, y, r⟩ =
(if y < x then
case split_bst' x r of (b, t1, t2) ⇒ (b, ⟨l, y, t1⟩, t2)
else if y > x then
case split_bst' x l of (b, t1, t2) ⇒ (b, t1, ⟨t2, y, r⟩)
else
(True, l, r))"
(* split_bst' combines the search isin t x with the split split_bst x t in one traversal. *)
lemma split_bst'_altdef: "split_bst' x t = (isin t x, split_bst x t)"
by (induction x t rule: split_bst.induct) (auto simp: case_prod_unfold)
(* Projections of split_bst' as simp rules, derived from split_bst'_altdef. *)
lemma fst_split_bst' [simp]: "fst (split_bst' x t) = isin t x"
and snd_split_bst' [simp]: "snd (split_bst' x t) = split_bst x t"
by (simp_all add: split_bst'_altdef)
(* Neither half produced by split_bst is larger than the input tree. These facts are tagged
   termination_simp so that termination proofs of functions recursing on split results can
   use them. *)
lemma size_fst_split_bst [termination_simp]: "size (fst (split_bst x t)) ≤ size t"
by (induction t) (auto simp: case_prod_unfold)
lemma size_snd_split_bst [termination_simp]: "size (snd (split_bst x t)) ≤ size t"
by (induction t) (auto simp: case_prod_unfold)
lemmas size_split_bst = size_fst_split_bst size_snd_split_bst
(* On a BST, the left result of split_bst x contains exactly the elements below x ... *)
lemma set_split_bst1: "bst t ⟹ set_tree (fst (split_bst x t)) = {y ∈ set_tree t. y < x}"
by (induction t) (auto split: prod.splits)
(* ... and the right result contains exactly the elements above x. *)
lemma set_split_bst2: "bst t ⟹ set_tree (snd (split_bst x t)) = {y ∈ set_tree t. y > x}"
by (induction t) (auto split: prod.splits)
(* Both halves produced by split_bst from a BST are again BSTs. *)
lemma bst_split_bst1 [intro]: "bst t ⟹ bst (fst (split_bst x t))"
by (induction t) (auto simp: case_prod_unfold set_split_bst1)
lemma bst_split_bst2 [intro]: "bst t ⟹ bst (snd (split_bst x t))"
by (induction t) (auto simp: case_prod_unfold set_split_bst2)
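text ‹
Splitting a random BST produces two independent random BSTs: one over the elements smaller than
the splitting value and one over the elements greater than it.
›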
(* Splitting a random BST over a finite set A at an arbitrary value x yields two independent
   random BSTs over the elements of A below and above x.

   Proof outline (induction along the recursion of random_bst):
   1. unfold random_bst A once and push split_bst into the recursive structure,
   2. replace the recursive split by the induction hypothesis,
   3. unfold pair_pmf and restate the sets in terms of A⇩1 and A⇩2,
   4. split the uniform choice of the root into three cases (root in A⇩1, in A⇩2, or equal
      to x) using pmf_of_set_3way_split,
   5. fold each case back into a pair of random BSTs and collapse the identical branches. *)
theorem split_random_bst:
assumes "finite A"
shows "map_pmf (split_bst x) (random_bst A) =
pair_pmf (random_bst {y∈A. y < x}) (random_bst {y∈A. y > x})"
using assms
proof (induction A rule: random_bst.induct)
case (1 A)
define A⇩1 A⇩2 where "A⇩1 = {y∈A. y < x}" and "A⇩2 = {y∈A. y > x}"
have [simp]: "¬x ∈ A⇩2" if "x ∈ A⇩1" for x using that by (auto simp: A⇩1_def A⇩2_def)
from ‹finite A› have [simp]: "finite A⇩1" "finite A⇩2" by (auto simp: A⇩1_def A⇩2_def)
include monad_normalisation
show ?case
proof (cases "A = {}")
case True
thus ?thesis by (auto simp: pair_return_pmf1)
next
case False
(* Step 1: unfold random_bst A once; split_bst recurses only into the subtree on x's side. *)
have "map_pmf (split_bst x) (random_bst A) =
do {y ← pmf_of_set A;
if y < x then
do {
l ← random_bst {z∈A. z < y};
(t1, t2) ← map_pmf (split_bst x) (random_bst {z∈A. z > y});
return_pmf (⟨l, y, t1⟩, t2)
}
else if y > x then
do {
(t1, t2) ← map_pmf (split_bst x) (random_bst {z∈A. z < y});
r ← random_bst {z∈A. z > y};
return_pmf (t1, (⟨t2, y, r⟩))
}
else
do {
l ← random_bst {z∈A. z < y};
r ← random_bst {z∈A. z > y};
return_pmf (l, r)
}
}"
using "1.prems" False
by (subst random_bst.simps)
(simp add: map_bind_pmf bind_map_pmf return_pmf_if case_prod_unfold cong: if_cong)
(* Step 2: apply the induction hypothesis to the recursive splits. *)
also have "… = do {y ← pmf_of_set A;
if y < x then
do {
l ← random_bst {z∈A. z < y};
(t1, t2) ← pair_pmf (random_bst {z∈{z∈A. z > y}. z < x})
(random_bst {z∈{z∈A. z > y}. z > x});
return_pmf (⟨l, y, t1⟩, t2)
}
else if y > x then
do {
(t1, t2) ← pair_pmf (random_bst {z∈{z∈A. z < y}. z < x})
(random_bst {z∈{z∈A. z < y}. z > x});
r ← random_bst {z∈A. z > y};
return_pmf (t1, (⟨t2, y, r⟩))
}
else
do {
l ← random_bst {z∈A. z < y};
r ← random_bst {z∈A. z > y};
return_pmf (l, r)
}
}"
using ‹finite A› and ‹A ≠ {}›
by (intro bind_pmf_cong if_cong refl "1.IH") auto
(* Step 3a: unfold pair_pmf into two independent draws. *)
also have "… = do {y ← pmf_of_set A;
if y < x then
do {
l ← random_bst {z∈A. z < y};
t1 ← random_bst {z∈{z∈A. z > y}. z < x};
t2 ← random_bst {z∈{z∈A. z > y}. z > x};
return_pmf (⟨l, y, t1⟩, t2)
}
else if y > x then
do {
t1 ← random_bst {z∈{z∈A. z < y}. z < x};
t2 ← random_bst {z∈{z∈A. z < y}. z > x};
r ← random_bst {z∈A. z > y};
return_pmf (t1, (⟨t2, y, r⟩))
}
else
do {
l ← random_bst {z∈A. z < y};
r ← random_bst {z∈A. z > y};
return_pmf (l, r)
}
}"
by (simp add: pair_pmf_def cong: if_cong)
(* Step 3b: restate all the nested comprehensions in terms of A⇩1 and A⇩2. *)
also have "… = do {y ← pmf_of_set A;
if y ∈ A⇩1 then
do {
l ← random_bst {z∈A⇩1. z < y};
t1 ← random_bst {z∈A⇩1. z > y};
t2 ← random_bst A⇩2;
return_pmf (⟨l, y, t1⟩, t2)
}
else if y ∈ A⇩2 then
do {
t1 ← random_bst A⇩1;
t2 ← random_bst {z∈A⇩2. z < y};
r ← random_bst {z∈A⇩2. z > y};
return_pmf (t1, (⟨t2, y, r⟩))
}
else
pair_pmf (random_bst A⇩1) (random_bst A⇩2)
}"
using ‹finite A› ‹A ≠ {}›
by (intro bind_pmf_cong refl if_cong arg_cong[of _ _ random_bst])
(auto simp: A⇩1_def A⇩2_def pair_pmf_def)
(* Step 4: replace the uniform root by a uniform index that selects A⇩1, A⇩2 or {x}. *)
also have "… = do {i ← pmf_of_set {..<card A};
if i < card A⇩1 then
do {
y ← pmf_of_set A⇩1;
l ← random_bst {z∈A⇩1. z < y};
t1 ← random_bst {z∈A⇩1. z > y};
t2 ← random_bst A⇩2;
return_pmf (⟨l, y, t1⟩, t2)
}
else if i < card A⇩1 + card A⇩2 then
do {
y ← pmf_of_set A⇩2;
t1 ← random_bst A⇩1;
t2 ← random_bst {z∈A⇩2. z < y};
r ← random_bst {z∈A⇩2. z > y};
return_pmf (t1, (⟨t2, y, r⟩))
}
else do {
y ← pmf_of_set (if x ∈ A then {x} else {});
pair_pmf (random_bst A⇩1) (random_bst A⇩2)
}
}" using ‹finite A› ‹A ≠ {}›
by (intro pmf_of_set_3way_split) (auto simp: A⇩1_def A⇩2_def not_less_iff_gr_or_eq)
(* Step 5a: in the first two branches the inner draws recombine into random_bst A⇩1 resp.
   random_bst A⇩2 (random_bst_reduce read backwards). *)
also have "… = do {i ← pmf_of_set {..<card A};
if i < card A⇩1 then
pair_pmf (random_bst A⇩1) (random_bst A⇩2)
else if i < card A⇩1 + card A⇩2 then
pair_pmf (random_bst A⇩1) (random_bst A⇩2)
else
pair_pmf (random_bst A⇩1) (random_bst A⇩2)
}"
using ‹finite A› ‹A ≠ {}›
proof (intro bind_pmf_cong refl if_cong, goal_cases)
case (1 i)
hence "A⇩1 ≠ {}" by auto
thus ?case using ‹finite A› by (simp add: pair_pmf_def random_bst_reduce)
next
case (2 i)
hence "A⇩2 ≠ {}" by auto
thus ?case using ‹finite A› by (simp add: pair_pmf_def random_bst_reduce)
qed auto
(* Step 5b: all branches coincide, so the index draw is irrelevant. *)
also have "… = pair_pmf (random_bst A⇩1) (random_bst A⇩2)"
by (simp cong: if_cong)
finally show ?thesis by (simp add: A⇩1_def A⇩2_def)
qed
qed
subsection ‹Joining›
text ‹
The ``join'' operation computes the union of two BSTs ‹l› and ‹r› where all the values in
‹l› are strictly smaller than those in ‹r›.
›
(* Randomised join of two trees (intended for the case that all elements of t1 lie below all
   elements of t2, cf. bst_mrbst_join). If either tree is empty, the other one is returned.
   Otherwise a coin with bias size t1 / (size t1 + size t2) decides which root becomes the
   new root:
   - True:  keep t1's root and left subtree, and join t1's right subtree with t2;
   - False: keep t2's root and right subtree, and join t1 with t2's left subtree. *)
fun mrbst_join :: "'a tree ⇒ 'a tree ⇒ 'a tree pmf" where
"mrbst_join t1 t2 =
(if t1 = ⟨⟩ then return_pmf t2
else if t2 = ⟨⟩ then return_pmf t1
else do {
b ← bernoulli_pmf (size t1 / (size t1 + size t2));
if b then
(case t1 of ⟨l, x, r⟩ ⇒ map_pmf (λr'. ⟨l, x, r'⟩) (mrbst_join r t2))
else
(case t2 of ⟨l, x, r⟩ ⇒ map_pmf (λl'. ⟨l', x, r⟩) (mrbst_join t1 l))
})"
(* Joining with an empty tree on either side returns the other tree deterministically. *)
lemma mrbst_join_Leaf_left [simp]: "mrbst_join ⟨⟩ = return_pmf"
by (simp add: fun_eq_iff)
lemma mrbst_join_Leaf_right [simp]: "mrbst_join t ⟨⟩ = return_pmf t"
by (simp add: fun_eq_iff)
(* The recursive equation of mrbst_join for two non-empty trees. It is stated separately so
   that the general equation can be removed from the simpset right afterwards. *)
lemma mrbst_join_reduce:
"t1 ≠ ⟨⟩ ⟹ t2 ≠ ⟨⟩ ⟹ mrbst_join t1 t2 =
do {
b ← bernoulli_pmf (size t1 / (size t1 + size t2));
if b then
(case t1 of ⟨l, x, r⟩ ⇒ map_pmf (λr'. ⟨l, x, r'⟩) (mrbst_join r t2))
else
(case t2 of ⟨l, x, r⟩ ⇒ map_pmf (λl'. ⟨l', x, r⟩) (mrbst_join t1 l))
}"
by (subst mrbst_join.simps) auto
(* From here on, mrbst_join is only unfolded explicitly (via the Leaf lemmas or
   mrbst_join_reduce). *)
lemmas [simp del] = mrbst_join.simps
(* Structural correctness of mrbst_join: if t1 and t2 are BSTs and every element of t1 is
   below every element of t2, then every possible result is a BST containing exactly the
   elements of both trees. *)
lemma
assumes "t' ∈ set_pmf (mrbst_join t1 t2)" "bst t1" "bst t2"
assumes "⋀x y. x ∈ set_tree t1 ⟹ y ∈ set_tree t2 ⟹ x < y"
shows bst_mrbst_join: "bst t'"
and set_mrbst_join: "set_tree t' = set_tree t1 ∪ set_tree t2"
proof -
(* Both claims are proved together by strong induction on size t1 + size t2, because each
   recursive call removes one node from t1 or t2. *)
have "bst t' ∧ set_tree t' = set_tree t1 ∪ set_tree t2"
using assms
proof (induction "size t1 + size t2" arbitrary: t1 t2 t' rule: less_induct)
case (less t1 t2 t')
show ?case
proof (cases "t1 = ⟨⟩ ∨ t2 = ⟨⟩")
case False
(* Both trees are non-empty: t' comes from one of the two branches of the coin flip. *)
hence "t' ∈ set_pmf (case t1 of ⟨l, x, r⟩ ⇒ map_pmf (Node l x) (mrbst_join r t2)) ∨
t' ∈ set_pmf (case t2 of ⟨l, x, r⟩ ⇒ map_pmf (λl'. ⟨l', x, r⟩) (mrbst_join t1 l))"
using less.prems by (subst (asm) mrbst_join_reduce) (auto split: if_splits)
thus ?thesis
proof
(* t1's root stays on top; the induction hypothesis covers the join of its right subtree
   with t2. *)
assume "t' ∈ set_pmf (case t1 of ⟨l, x, r⟩ ⇒ map_pmf (Node l x) (mrbst_join r t2))"
then obtain l x r r'
where *: "t1 = ⟨l, x, r⟩" "r' ∈ set_pmf (mrbst_join r t2)" "t' = ⟨l, x, r'⟩"
using False by (auto split: tree.splits)
from * and less.prems have "bst r' ∧ set_tree r' = set_tree r ∪ set_tree t2"
by (intro less) auto
with * and less.prems show ?thesis by auto
next
(* t2's root stays on top; the induction hypothesis covers the join of t1 with its left
   subtree. *)
assume "t' ∈ set_pmf (case t2 of ⟨l, x, r⟩ ⇒ map_pmf (λl'. ⟨l', x, r⟩) (mrbst_join t1 l))"
then obtain l x r l'
where *: "t2 = ⟨l, x, r⟩" "l' ∈ set_pmf (mrbst_join t1 l)" "t' = ⟨l', x, r⟩"
using False by (auto split: tree.splits)
from * and less.prems have "bst l' ∧ set_tree l' = set_tree t1 ∪ set_tree l"
by (intro less) auto
with * and less.prems show ?thesis by auto
qed
qed (insert less.prems, auto)
qed
thus "bst t'" "set_tree t' = set_tree t1 ∪ set_tree t2" by auto
qed
text ‹
Joining two random BSTs that satisfy the necessary preconditions again yields a random BST.
›
(* Distributional correctness of mrbst_join: joining independent random BSTs over finite sets
   A and B, where every element of A is below every element of B, gives a random BST over
   A ∪ B.

   Proof outline (well-founded induction on A ∪ B w.r.t. strict subsets):
   - the coin bias equals p = |A| / (|A| + |B|), because sizes of random BSTs are the
     cardinalities of their sets;
   - move the coin flip to the front, unfold the random BST whose root is kept, and apply the
     induction hypothesis to the recursive join;
   - the coin followed by a uniform element of A or B is a uniform element of A ∪ B
     (pmf_of_set_union_split), which folds back into random_bst (A ∪ B). *)
theorem mrbst_join_correct:
fixes A B :: "'a :: linorder set"
assumes "finite A" "finite B" "⋀x y. x ∈ A ⟹ y ∈ B ⟹ x < y"
shows "do {t1 ← random_bst A; t2 ← random_bst B; mrbst_join t1 t2} = random_bst (A ∪ B)"
proof -
from assms have "finite (A ∪ B)" by simp
from this and assms show ?thesis
proof (induction "A ∪ B" arbitrary: A B rule: finite_psubset_induct)
case (psubset A B)
define m n where "m = card A" and "n = card B"
define p where "p = m / (m + n)"
include monad_normalisation
show ?case
proof (cases "A = {} ∨ B = {}")
case True
thus ?thesis by auto
next
case False
have AB: "A ≠ {}" "B ≠ {}" "finite A" "finite B"
using False psubset.prems by auto
(* Non-emptiness facts in the form used when discharging the two coin outcomes below. *)
have p_pos: "A ≠ {}" if "p > 0"
using AB by (auto simp: p_def m_def n_def)
have p_lt1: "B ≠ {}" if "p < 1"
using AB by (auto simp: p_def m_def n_def)
(* Unfold one step of mrbst_join (both trees are non-empty). *)
have "do {t1 ← random_bst A; t2 ← random_bst B; mrbst_join t1 t2} =
do {t1 ← random_bst A;
t2 ← random_bst B;
b ← bernoulli_pmf (size t1 / (size t1 + size t2));
if b then
case t1 of ⟨l, x, r⟩ ⇒ map_pmf (λr'. ⟨l, x, r'⟩) (mrbst_join r t2)
else
case t2 of ⟨l, x, r⟩ ⇒ map_pmf (λl'. ⟨l', x, r⟩) (mrbst_join t1 l)
}"
using AB
by (intro bind_pmf_cong refl, subst mrbst_join_reduce) auto
(* The bias does not depend on the sampled trees: their sizes are |A| and |B|. *)
also have "… = do {t1 ← random_bst A;
t2 ← random_bst B;
b ← bernoulli_pmf p;
if b then
case t1 of ⟨l, x, r⟩ ⇒ map_pmf (λr'. ⟨l, x, r'⟩) (mrbst_join r t2)
else
case t2 of ⟨l, x, r⟩ ⇒ map_pmf (λl'. ⟨l', x, r⟩) (mrbst_join t1 l)
}"
using AB by (intro bind_pmf_cong refl arg_cong[of _ _ bernoulli_pmf])
(auto simp: p_def m_def n_def size_random_bst)
(* Move the coin flip to the front. *)
also have "… = do {
b ← bernoulli_pmf p;
if b then do {
t1 ← random_bst A;
t2 ← random_bst B;
case t1 of ⟨l, x, r⟩ ⇒ map_pmf (λr'. ⟨l, x, r'⟩) (mrbst_join r t2)
} else do {
t1 ← random_bst A;
t2 ← random_bst B;
case t2 of ⟨l, x, r⟩ ⇒ map_pmf (λl'. ⟨l', x, r⟩) (mrbst_join t1 l)
}
}"
by simp
(* In each branch, unfold the random BST whose root is kept and use the induction
   hypothesis for the recursive join. *)
also have "… = do {
b ← bernoulli_pmf p;
if b then do {
x ← pmf_of_set A;
l ← random_bst {y∈A ∪ B. y < x};
r ← random_bst {y∈A ∪ B. y > x};
return_pmf ⟨l, x, r⟩
} else do {
x ← pmf_of_set B;
l ← random_bst {y∈A ∪ B. y < x};
r ← random_bst {y∈A ∪ B. y > x};
return_pmf ⟨l, x, r⟩
}
}"
proof (intro bind_pmf_cong refl if_cong, goal_cases)
case (1 b)
hence [simp]: "A ≠ {}" using p_pos by auto
have "do {t1 ← random_bst A; t2 ← random_bst B;
case t1 of ⟨l, x, r⟩ ⇒ map_pmf (λr'. ⟨l, x, r'⟩) (mrbst_join r t2)} =
do {
x ← pmf_of_set A;
l ← random_bst {y∈A. y < x};
r ← do {r ← random_bst {y∈A. y > x}; t2 ← random_bst B; mrbst_join r t2};
return_pmf ⟨l, x, r⟩
}"
using AB by (subst random_bst_reduce) (auto simp: map_pmf_def)
also have "… = do {
x ← pmf_of_set A;
l ← random_bst {y∈A. y < x};
r ← random_bst ({y∈A. y > x} ∪ B);
return_pmf ⟨l, x, r⟩
}"
using AB psubset.prems
by (intro bind_pmf_cong refl psubset arg_cong[of _ _ random_bst]) auto
also have "… = do {
x ← pmf_of_set A;
l ← random_bst {y∈A ∪ B. y < x};
r ← random_bst {y∈A ∪ B. y > x};
return_pmf ⟨l, x, r⟩
}"
using AB psubset.prems
by (intro bind_pmf_cong refl arg_cong[of _ _ random_bst]; force)
finally show ?case .
next
case (2 b)
hence [simp]: "B ≠ {}" using p_lt1 by auto
have "do {t1 ← random_bst A; t2 ← random_bst B;
case t2 of ⟨l, x, r⟩ ⇒ map_pmf (λl'. ⟨l', x, r⟩) (mrbst_join t1 l)} =
do {
x ← pmf_of_set B;
l ← do {t1 ← random_bst A; l ← random_bst {y∈B. y < x}; mrbst_join t1 l};
r ← random_bst {y∈B. y > x};
return_pmf ⟨l, x, r⟩
}"
using AB by (subst random_bst_reduce) (auto simp: map_pmf_def)
also have "… = do {
x ← pmf_of_set B;
l ← random_bst (A ∪ {y∈B. y < x});
r ← random_bst {y∈B. y > x};
return_pmf ⟨l, x, r⟩
}"
using AB psubset.prems
by (intro bind_pmf_cong refl psubset arg_cong[of _ _ random_bst]) auto
also have "… = do {
x ← pmf_of_set B;
l ← random_bst {y∈A ∪ B. y < x};
r ← random_bst {y∈A ∪ B. y > x};
return_pmf ⟨l, x, r⟩
}"
using AB psubset.prems
by (intro bind_pmf_cong refl arg_cong[of _ _ random_bst]; force)
finally show ?case .
qed
(* The two branches differ only in where the root is drawn from. *)
also have "… = do {
b ← bernoulli_pmf p;
x ← (if b then pmf_of_set A else pmf_of_set B);
l ← random_bst {y∈A ∪ B. y < x};
r ← random_bst {y∈A ∪ B. y > x};
return_pmf ⟨l, x, r⟩
}"
by (intro bind_pmf_cong) simp_all
also have "… = do {
x ← do {b ← bernoulli_pmf p; if b then pmf_of_set A else pmf_of_set B};
l ← random_bst {y∈A ∪ B. y < x};
r ← random_bst {y∈A ∪ B. y > x};
return_pmf ⟨l, x, r⟩
}"
by simp
(* The coin followed by a uniform element of A or B is a uniform element of A ∪ B. *)
also have "do {b ← bernoulli_pmf p; if b then pmf_of_set A else pmf_of_set B} =
pmf_of_set (A ∪ B)"
using AB psubset.prems by (intro pmf_of_set_union_split) (auto simp: p_def m_def n_def)
also have "do {
x ← pmf_of_set (A ∪ B);
l ← random_bst {y∈A ∪ B. y < x};
r ← random_bst {y∈A ∪ B. y > x};
return_pmf ⟨l, x, r⟩
} = random_bst (A ∪ B)"
using AB by (intro random_bst_reduce [symmetric]) auto
finally show ?thesis .
qed
qed
qed
subsection ‹Pushdown›
text ‹
The ``push down'' operation ``forgets'' information about the root of a tree in the following
sense: It takes a non-empty tree whose root is some known fixed value and whose children are
random BSTs and shuffles the root in such a way that the resulting tree is a random BST.
›
(* mrbst_push_down l x r builds a tree from l, a new value x and r (intended: elements of l
   below x, elements of r above x). An index k is drawn uniformly from {0..size l + size r}
   (size l + size r + 1 values):
   - k < size l:                   l's root becomes the root, and x is pushed down into l's
                                   right subtree together with r;
   - size l ≤ k < size l + size r: r's root becomes the root, and x is pushed down into r's
                                   left subtree together with l;
   - otherwise:                    x becomes the root of ⟨l, x, r⟩. *)
fun mrbst_push_down :: "'a tree ⇒ 'a ⇒ 'a tree ⇒ 'a tree pmf" where
"mrbst_push_down l x r =
do {
k ← pmf_of_set {0..size l + size r};
if k < size l then
case l of
⟨ll, y, lr⟩ ⇒ map_pmf (λr'. ⟨ll, y, r'⟩) (mrbst_push_down lr x r)
else if k < size l + size r then
case r of
⟨rl, y, rr⟩ ⇒ map_pmf (λl'. ⟨l', y, rr⟩) (mrbst_push_down l x rl)
else
return_pmf ⟨l, x, r⟩
}"
(* The defining equation is recursive; it is only unfolded explicitly from here on. *)
lemmas [simp del] = mrbst_push_down.simps
(* Structural correctness of mrbst_push_down: if t1 and t2 are BSTs with all elements of t1
   below x and all elements of t2 above x, every possible result is a BST containing exactly
   x and the elements of t1 and t2. *)
lemma
assumes "t' ∈ set_pmf (mrbst_push_down t1 x t2)" "bst t1" "bst t2"
assumes "⋀y. y ∈ set_tree t1 ⟹ y < x" "⋀y. y ∈ set_tree t2 ⟹ y > x"
shows bst_mrbst_push_down: "bst t'"
and set_mrbst_push_down: "set_tree t' = {x} ∪ set_tree t1 ∪ set_tree t2"
proof -
(* Both claims are proved together by strong induction on size t1 + size t2. *)
have "bst t' ∧ set_tree t' = {x} ∪ set_tree t1 ∪ set_tree t2"
using assms
proof (induction "size t1 + size t2" arbitrary: t1 t2 t' rule: less_induct)
case (less t1 t2 t')
(* t' arises from one of the three branches of mrbst_push_down. *)
have "t1 ≠ ⟨⟩ ∧ t' ∈ set_pmf (case t1 of ⟨l, y, r⟩ ⇒
map_pmf (Node l y) (mrbst_push_down r x t2)) ∨
t2 ≠ ⟨⟩ ∧ t' ∈ set_pmf (case t2 of ⟨l, y, r⟩ ⇒
map_pmf (λl'. ⟨l', y, r⟩) (mrbst_push_down t1 x l)) ∨
t' = ⟨t1, x, t2⟩"
using less.prems by (subst (asm) mrbst_push_down.simps) (auto split: if_splits)
thus ?case
proof (elim disjE, goal_cases)
case 1
(* t1's root is on top; apply the induction hypothesis to its right subtree. *)
then obtain l y r r'
where *: "t1 = ⟨l, y, r⟩" "r' ∈ set_pmf (mrbst_push_down r x t2)" "t' = ⟨l, y, r'⟩"
by (auto split: tree.splits)
from * and less.prems have "bst r' ∧ set_tree r' = {x} ∪ set_tree r ∪ set_tree t2"
by (intro less) auto
with * and less.prems show ?case by force
next
case 2
(* t2's root is on top; apply the induction hypothesis to its left subtree. *)
then obtain l y r l'
where *: "t2 = ⟨l, y, r⟩" "l' ∈ set_pmf (mrbst_push_down t1 x l)" "t' = ⟨l', y, r⟩"
by (auto split: tree.splits)
from * and less.prems have "bst l' ∧ set_tree l' = {x} ∪ set_tree t1 ∪ set_tree l"
by (intro less) auto
with * and less.prems show ?case by force
qed (insert less.prems, auto)
qed
thus "bst t'" "set_tree t' = {x} ∪ set_tree t1 ∪ set_tree t2" by auto
qed
(* Distributional correctness of mrbst_push_down: pushing x between independent random BSTs
   over finite A (all elements below x) and B (all elements above x) gives a random BST over
   {x} ∪ A ∪ B.

   Proof outline (well-founded induction on A ∪ B w.r.t. strict subsets):
   - the index range {0..m+n} uses m = |A| and n = |B|, since sizes of random BSTs are the
     cardinalities of their sets;
   - in the first two branches unfold the random BST whose root is kept and apply the
     induction hypothesis to the recursive push-down;
   - restate all subtree sets relative to {x} ∪ A ∪ B and recombine the three branches with
     pmf_of_set_3way_split into one uniform root choice, i.e. random_bst_reduce. *)
theorem mrbst_push_down_correct:
fixes A B :: "'a :: linorder set"
assumes "finite A" "finite B" "⋀y. y ∈ A ⟹ y < x" "⋀y. y ∈ B ⟹ x < y"
shows "do {l ← random_bst A; r ← random_bst B; mrbst_push_down l x r} =
random_bst ({x} ∪ A ∪ B)"
proof -
from assms have "finite (A ∪ B)" by simp
from this and assms show ?thesis
proof (induction "A ∪ B" arbitrary: A B rule: finite_psubset_induct)
case (psubset A B)
define m n where "m = card A" and "n = card B"
have A_ne: "A ≠ {}" if "m > 0"
using that by (auto simp: m_def)
have B_ne: "B ≠ {}" if "n > 0"
using that by (auto simp: n_def)
include monad_normalisation
(* Unfold one step; the sizes of the sampled trees are m and n. *)
have "do {l ← random_bst A; r ← random_bst B; mrbst_push_down l x r} =
do {l ← random_bst A;
r ← random_bst B;
k ← pmf_of_set {0..m + n};
if k < m then
case l of ⟨ll, y, lr⟩ ⇒ map_pmf (λr'. ⟨ll, y, r'⟩) (mrbst_push_down lr x r)
else if k < m + n then
case r of ⟨rl, y, rr⟩ ⇒ map_pmf (λl'. ⟨l', y, rr⟩) (mrbst_push_down l x rl)
else
return_pmf ⟨l, x, r⟩
}"
using psubset.prems
by (subst mrbst_push_down.simps, intro bind_pmf_cong refl)
(auto simp: size_random_bst m_def n_def)
(* Move the index draw to the front. *)
also have "… = do {k ← pmf_of_set {0..m + n};
if k < m then do {
l ← random_bst A;
r ← random_bst B;
case l of ⟨ll, y, lr⟩ ⇒ map_pmf (λr'. ⟨ll, y, r'⟩) (mrbst_push_down lr x r)
} else if k < m + n then do {
l ← random_bst A;
r ← random_bst B;
case r of ⟨rl, y, rr⟩ ⇒ map_pmf (λl'. ⟨l', y, rr⟩) (mrbst_push_down l x rl)
} else do {
l ← random_bst A;
r ← random_bst B;
return_pmf ⟨l, x, r⟩
}
}"
by (simp cong: if_cong)
(* Unfold the random BST whose root is kept (A resp. B is non-empty in these branches). *)
also have "… = do {k ← pmf_of_set {0..m + n};
if k < m then do {
y ← pmf_of_set A;
ll ← random_bst {z∈A. z < y};
r' ← do {lr ← random_bst {z∈A. z > y};
r ← random_bst B;
mrbst_push_down lr x r};
return_pmf ⟨ll, y, r'⟩
} else if k < m + n then do {
y ← pmf_of_set B;
l' ← do {l ← random_bst A;
rl ← random_bst {z∈B. z < y};
mrbst_push_down l x rl};
rr ← random_bst {z∈B. z > y};
return_pmf ⟨l', y, rr⟩
} else do {
l ← random_bst A;
r ← random_bst B;
return_pmf ⟨l, x, r⟩
}
}"
proof (intro bind_pmf_cong refl if_cong, goal_cases)
case (1 k)
hence "A ≠ {}" by (auto simp: m_def)
with ‹finite A› show ?case by (simp add: random_bst_reduce map_pmf_def)
next
case (2 k)
hence "B ≠ {}" by (auto simp: m_def n_def)
with ‹finite B› show ?case by (simp add: random_bst_reduce map_pmf_def)
qed
(* Apply the induction hypothesis; the recursive calls work on strictly smaller sets. *)
also have "… = do {k ← pmf_of_set {0..m + n};
if k < m then do {
y ← pmf_of_set A;
ll ← random_bst {z∈A. z < y};
r' ← random_bst ({x} ∪ {z∈A. z > y} ∪ B);
return_pmf ⟨ll, y, r'⟩
} else if k < m + n then do {
y ← pmf_of_set B;
l' ← random_bst ({x} ∪ A ∪ {z∈B. z < y});
rr ← random_bst {z∈B. z > y};
return_pmf ⟨l', y, rr⟩
} else do {
l ← random_bst A;
r ← random_bst B;
return_pmf ⟨l, x, r⟩
}
}"
using psubset.prems A_ne B_ne
proof (intro bind_pmf_cong refl if_cong psubset)
fix k y assume "k < m" "y ∈ set_pmf (pmf_of_set A)"
thus "{z∈A. z > y} ∪ B ⊂ A ∪ B"
using psubset.prems A_ne by (fastforce dest!: in_set_pmf_of_setD)
next
fix k y assume "¬k < m" "k < m + n" "y ∈ set_pmf (pmf_of_set B)"
thus "A ∪ {z∈B. z < y} ⊂ A ∪ B"
using psubset.prems B_ne by (fastforce dest!: in_set_pmf_of_setD)
qed auto
(* Express every subtree set as a comprehension over {x} ∪ A ∪ B. *)
also have "… = do {k ← pmf_of_set {0..m + n};
if k < m then do {
y ← pmf_of_set A;
ll ← random_bst {z∈{x} ∪ A ∪ B. z < y};
r' ← random_bst {z∈{x} ∪ A ∪ B. z > y};
return_pmf ⟨ll, y, r'⟩
} else if k < m + n then do {
y ← pmf_of_set B;
l' ← random_bst {z∈{x} ∪ A ∪ B. z < y};
rr ← random_bst {z∈{x} ∪ A ∪ B. z > y};
return_pmf ⟨l', y, rr⟩
} else do {
l ← random_bst {z∈{x} ∪ A ∪ B. z < x};
r ← random_bst {z∈{x} ∪ A ∪ B. z > x};
return_pmf ⟨l, x, r⟩
}
}"
using psubset.prems A_ne B_ne
by (intro bind_pmf_cong if_cong refl arg_cong[of _ _ random_bst];
force dest: psubset.prems(3,4))
(* Make the third branch a uniform choice from {x}, so all three branches have the same
   shape. *)
also have "… = do {k ← pmf_of_set {0..m + n};
if k < m then do {
y ← pmf_of_set A;
ll ← random_bst {z∈{x} ∪ A ∪ B. z < y};
r' ← random_bst {z∈{x} ∪ A ∪ B. z > y};
return_pmf ⟨ll, y, r'⟩
} else if k < m + n then do {
y ← pmf_of_set B;
l' ← random_bst {z∈{x} ∪ A ∪ B. z < y};
rr ← random_bst {z∈{x} ∪ A ∪ B. z > y};
return_pmf ⟨l', y, rr⟩
} else do {
y ← pmf_of_set {x};
l ← random_bst {z∈{x} ∪ A ∪ B. z < y};
r ← random_bst {z∈{x} ∪ A ∪ B. z > y};
return_pmf ⟨l, x, r⟩
}
}" (is "_ = ?X {0..m+n}")
by (simp add: pmf_of_set_singleton cong: if_cong)
(* The index range has exactly |{x} ∪ A ∪ B| = m + n + 1 elements. *)
also have "{0..m + n} = {..<card (A ∪ B ∪ {x})}" using psubset.prems
by (subst card_Un_disjoint, simp, simp, force)+
(auto simp: m_def n_def)
(* Recombine the three branches into one uniform root choice. *)
also have "?X … = do {y ← pmf_of_set ({x} ∪ A ∪ B);
l ← random_bst {z∈{x} ∪ A ∪ B. z < y};
r ← random_bst {z∈{x} ∪ A ∪ B. z > y};
return_pmf ⟨l, y, r⟩}"
unfolding m_def n_def using psubset.prems
by (subst pmf_of_set_3way_split [symmetric])
(auto dest!: psubset.prems(3,4) cong: if_cong intro: bind_pmf_cong)
also have "… = random_bst ({x} ∪ A ∪ B)"
using psubset.prems by (simp add: random_bst_reduce)
finally show ?case .
qed
qed
(* Variant of mrbst_push_down_correct for a pivot x ∈ A: pushing x between random BSTs over
   the elements of A below and above x gives a random BST over A. *)
lemma mrbst_push_down_correct':
assumes "finite (A :: 'a :: linorder set)" "x ∈ A"
shows "do {l ← random_bst {y∈A. y < x}; r ← random_bst {y∈A. y > x}; mrbst_push_down l x r} =
random_bst A" (is "?lhs = ?rhs")
proof -
have "?lhs = random_bst ({x} ∪ {y∈A. y < x} ∪ {y∈A. y > x})"
using assms by (intro mrbst_push_down_correct) auto
(* Since x ∈ A, the three parts reassemble A. *)
also have "{x} ∪ {y∈A. y < x} ∪ {y∈A. y > x} = A"
using assms by auto
finally show ?thesis .
qed
subsection ‹Intersection and Difference›
text ‹
The algorithms for intersection and difference of two trees are almost identical; the only
difference is that the ``if'' statement at the end of the recursive case is flipped. We
therefore introduce a generic intersection/difference operation first and prove its correctness
to avoid duplication.
›
(* Generic randomised intersection/difference. For t1 = ⟨l1, x, r1⟩, split t2 at x into
   (sep, l2, r2), where sep tells whether x occurs in t2 (split_bst'_altdef). Then recurse on
   (l1, l2) and (r1, r2). The node x is kept as root iff sep = b; otherwise the two recursive
   results are combined with mrbst_join.
   b = True gives intersection (keep x iff x ∈ t2); b = False gives difference (keep x iff
   x ∉ t2), cf. the setop definition in the lemmas below. *)
fun mrbst_inter_diff where
"mrbst_inter_diff _ ⟨⟩ _ = return_pmf ⟨⟩"
| "mrbst_inter_diff b ⟨l1, x, r1⟩ t2 =
(case split_bst' x t2 of (sep, l2, r2) ⇒
do {
l ← mrbst_inter_diff b l1 l2;
r ← mrbst_inter_diff b r1 r2;
if sep = b then return_pmf ⟨l, x, r⟩ else mrbst_join l r
})"
(* The Node equation of mrbst_inter_diff in point-free form (as a function of t2), obtained
   by extensionality. *)
lemma mrbst_inter_diff_reduce:
"mrbst_inter_diff b ⟨l1, x, r1⟩ =
(λt2. case split_bst' x t2 of (sep, l2, r2) ⇒
do {
l ← mrbst_inter_diff b l1 l2;
r ← mrbst_inter_diff b r1 r2;
if sep = b then return_pmf ⟨l, x, r⟩ else mrbst_join l r
})"
by (rule ext) simp
(* An empty first argument always yields the empty tree. *)
lemma mrbst_inter_diff_Leaf_left [simp]:
"mrbst_inter_diff b ⟨⟩ = (λ_. return_pmf ⟨⟩)"
by (simp add: fun_eq_iff)
(* An empty second argument yields the empty tree for intersection (b = True) and t1
   unchanged for difference (b = False). *)
lemma mrbst_inter_diff_Leaf_right [simp]:
"mrbst_inter_diff b (t1 :: 'a :: linorder tree) ⟨⟩ = return_pmf (if b then ⟨⟩ else t1)"
by (induction t1) (auto simp: bind_return_pmf)
(* Structural correctness of mrbst_inter_diff: for BSTs t1 and t2, every possible result is
   a BST whose elements are set_tree t1 ∩ set_tree t2 (b = True) or set_tree t1 - set_tree t2
   (b = False). *)
lemma
fixes t1 t2 :: "'a :: linorder tree" and b :: bool
defines "setop ≡ (if b then (∩) else (-) :: 'a set ⇒ _)"
assumes "t' ∈ set_pmf (mrbst_inter_diff b t1 t2)" "bst t1" "bst t2"
shows bst_mrbst_inter_diff: "bst t'"
and set_mrbst_inter_diff: "set_tree t' = setop (set_tree t1) (set_tree t2)"
proof -
(* Infix notation for the chosen set operation, used only inside this proof. *)
write setop (infixl ‹⋄› 80)
(* Both claims are proved together by structural induction on t1. *)
have "bst t' ∧ set_tree t' = set_tree t1 ⋄ set_tree t2"
using assms(2-)
proof (induction t1 arbitrary: t2 t')
case (Node l1 x r1 t2)
note bst = ‹bst ⟨l1, x, r1⟩› ‹bst t2›
define l2 r2 where "l2 = fst (split_bst x t2)" and "r2 = snd (split_bst x t2)"
(* Decompose t' into the recursive results l, r and the final combination step. *)
obtain l r
where lr: "l ∈ set_pmf (mrbst_inter_diff b l1 l2)" "r ∈ set_pmf (mrbst_inter_diff b r1 r2)"
and t': "t' ∈ (if x ∈ set_tree t2 ⟷ b then {⟨l, x, r⟩} else set_pmf (mrbst_join l r))"
using Node.prems by (force simp: case_prod_unfold l2_def r2_def isin_bst split: if_splits)
from lr have lr': "bst l ∧ set_tree l = set_tree l1 ⋄ set_tree l2"
"bst r ∧ set_tree r = set_tree r1 ⋄ set_tree r2"
using Node.prems by (intro Node.IH; force simp: l2_def r2_def)+
(* Elements of t': those of l and r, plus x exactly when x is kept as root. *)
have "set_tree t' = set_tree l ∪ set_tree r ∪ (if x ∈ set_tree t2 ⟷ b then {x} else {})"
proof (cases "x ∈ set_tree t2 ⟷ b")
case False
(* Precondition of set_mrbst_join: everything in l is below everything in r. *)
have "x < y" if "x ∈ set_tree l" "y ∈ set_tree r" for x y
using that lr' bst by (force simp: setop_def split: if_splits)
hence set_t': "set_tree t' = set_tree l ∪ set_tree r"
using t' set_mrbst_join[of t' l r] False lr' by auto
with False show ?thesis by simp
qed (use t' in auto)
also have "… = set_tree ⟨l1, x, r1⟩ ⋄ set_tree t2"
using lr' bst by (auto simp: setop_def l2_def r2_def set_split_bst1 set_split_bst2)
finally have "set_tree t' = set_tree ⟨l1, x, r1⟩ ⋄ set_tree t2" .
moreover from lr' t' bst have "bst t'"
by (force split: if_splits simp: setop_def intro!: bst_mrbst_join[of t' l r])
ultimately show ?case by auto
qed (auto simp: setop_def)
thus "bst t'" and "set_tree t' = set_tree t1 ⋄ set_tree t2" by auto
qed
(* Main correctness theorem for the generic operation: if t1 and t2 are independent random BSTs
   for the finite sets A and B, then mrbst_inter_diff b t1 t2 is distributed exactly like a
   random BST for A ∩ B (b = True) or A - B (b = False). The proof is by strong induction over
   finite subsets of A (finite_psubset_induct). It rewrites the monadic expression in a chain of
   equational steps. *)
theorem mrbst_inter_diff_correct:
fixes A B :: "'a :: linorder set" and b :: bool
defines "setop ≡ (if b then (∩) else (-) :: 'a set ⇒ _)"
assumes "finite A" "finite B"
shows "do {t1 ← random_bst A; t2 ← random_bst B; mrbst_inter_diff b t1 t2} =
random_bst (setop A B)"
using assms(2-)
proof (induction A arbitrary: B rule: finite_psubset_induct)
case (psubset A B)
write setop (infixl ‹⋄› 80)
include monad_normalisation
show ?case
proof (cases "A = {}")
case True
thus ?thesis by (auto simp: setop_def)
next
case False
(* R1 x and R2 x: random BSTs for the elements of A below and above the pivot x. *)
define R1 R2 where "R1 = (λx. random_bst {y∈A. y < x})" "R2 = (λx. random_bst {y∈A. y > x})"
have A_eq: "A = (A ∩ B) ∪ (A - B)" by auto
have card_A_eq: "card A = card (A ∩ B) + card (A - B)"
using ‹finite A› ‹finite B› by (subst A_eq, subst card_Un_disjoint) auto
(* Drawing uniformly from A is the same as first flipping a coin with success probability
   card (A ∩ B) / card A and then drawing uniformly from A ∩ B or from A - B. *)
have eq: "pmf_of_set A =
do {b ← bernoulli_pmf (card (A ∩ B) / card A);
if b then pmf_of_set (A ∩ B) else pmf_of_set (A - B)}"
using psubset.prems False ‹finite A› A_eq card_A_eq
by (subst A_eq, intro pmf_of_set_union_split [symmetric]) auto
have "card A > 0"
using ‹finite A› ‹A ≠ {}› by (subst card_gt_0_iff) auto
(* If the coin's success probability is below 1, then A is not contained in B. *)
have not_subset: "¬A ⊆ B" if "card (A ∩ B) < card A"
proof
assume "A ⊆ B"
hence "A ∩ B = A" by auto
with that show False by simp
qed
(* Step 1: unfold random_bst A once (random_bst_reduce) and mrbst_inter_diff once. *)
have "do {t1 ← random_bst A; t2 ← random_bst B; mrbst_inter_diff b t1 t2} =
do {
x ← pmf_of_set A;
l1 ← random_bst {y∈A. y < x};
r1 ← random_bst {y∈A. y > x};
t2 ← random_bst B;
let (l2, r2) = split_bst x t2;
l ← mrbst_inter_diff b l1 l2;
r ← mrbst_inter_diff b r1 r2;
if isin t2 x = b then return_pmf ⟨l, x, r⟩ else mrbst_join l r
}"
using ‹finite A› ‹A ≠ {}›
by (subst random_bst_reduce)
(auto simp: mrbst_inter_diff_reduce map_pmf_def split_bst'_altdef)
(* Step 2: for t2 drawn from random_bst B, the test isin t2 x equals x ∈ B. *)
also have "… = do {
x ← pmf_of_set A;
l1 ← random_bst {y∈A. y < x};
r1 ← random_bst {y∈A. y > x};
t2 ← random_bst B;
let (l2, r2) = split_bst x t2;
l ← mrbst_inter_diff b l1 l2;
r ← mrbst_inter_diff b r1 r2;
if x ∈ B = b then return_pmf ⟨l, x, r⟩ else mrbst_join l r
}"
unfolding Let_def case_prod_unfold using ‹finite B›
by (intro bind_pmf_cong refl) (auto simp: isin_random_bst)
(* Step 3: express the splitting of t2 at x as a map_pmf over random_bst B. *)
also have "… = do {
x ← pmf_of_set A;
l1 ← random_bst {y∈A. y < x};
r1 ← random_bst {y∈A. y > x};
(l2, r2) ← map_pmf (split_bst x) (random_bst B);
l ← mrbst_inter_diff b l1 l2;
r ← mrbst_inter_diff b r1 r2;
if x ∈ B = b then return_pmf ⟨l, x, r⟩ else mrbst_join l r
}"
by (simp add: Let_def map_pmf_def)
(* Step 4: splitting a random BST for B at x gives an independent pair (pair_pmf) of random
   BSTs for the elements of B below and above x (split_random_bst). *)
also have "… = do {
x ← pmf_of_set A;
l1 ← random_bst {y∈A. y < x};
r1 ← random_bst {y∈A. y > x};
(l2, r2) ← pair_pmf (random_bst {y∈B. y < x}) (random_bst {y∈B. y > x});
l ← mrbst_inter_diff b l1 l2;
r ← mrbst_inter_diff b r1 r2;
if x ∈ B = b then return_pmf ⟨l, x, r⟩ else mrbst_join l r
}"
by (intro bind_pmf_cong refl split_random_bst ‹finite B›)
(* Step 5: unfold pair_pmf into two binds and use the abbreviations R1 and R2. *)
also have "… = do {
x ← pmf_of_set A;
l1 ← R1 x;
r1 ← R2 x;
l2 ← random_bst {y∈B. y < x};
r2 ← random_bst {y∈B. y > x};
l ← mrbst_inter_diff b l1 l2;
r ← mrbst_inter_diff b r1 r2;
if x ∈ B = b then return_pmf ⟨l, x, r⟩ else mrbst_join l r
}"
unfolding pair_pmf_def bind_assoc_pmf R1_R2_def by simp
(* Step 6: group the binds so that each recursive call sits together with the two random
   trees it consumes. *)
also have "… = do {
x ← pmf_of_set A;
l ← do {l1 ← R1 x; l2 ← random_bst {y∈B. y < x}; mrbst_inter_diff b l1 l2};
r ← do {r1 ← R2 x; r2 ← random_bst {y∈B. y > x}; mrbst_inter_diff b r1 r2};
if x ∈ B = b then return_pmf ⟨l, x, r⟩ else mrbst_join l r
}"
unfolding bind_assoc_pmf by (intro bind_pmf_cong[OF refl]) simp
(* Step 7: apply the induction hypothesis to the strictly smaller sets {y∈A. y < x} and
   {y∈A. y > x}. *)
also have "… = do {
x ← pmf_of_set A;
l ← random_bst ({y∈A. y < x} ⋄ {y∈B. y < x});
r ← random_bst ({y∈A. y > x} ⋄ {y∈B. y > x});
if x ∈ B = b then return_pmf ⟨l, x, r⟩ else mrbst_join l r
}"
using ‹finite A› ‹finite B› ‹A ≠ {}› unfolding R1_R2_def
by (intro bind_pmf_cong refl psubset.IH) auto
(* Step 8: push the two binds into both branches of the conditional. *)
also have "… = do {
x ← pmf_of_set A;
if x ∈ B = b then do {
l ← random_bst ({y∈A. y < x} ⋄ {y∈B. y < x});
r ← random_bst ({y∈A. y > x} ⋄ {y∈B. y > x});
return_pmf ⟨l, x, r⟩
} else do {
l ← random_bst ({y∈A. y < x} ⋄ {y∈B. y < x});
r ← random_bst ({y∈A. y > x} ⋄ {y∈B. y > x});
mrbst_join l r
}
}"
by simp
(* Step 9: in the else branch, joining the two random trees gives a random BST for the union
   of their sets (mrbst_join_correct). *)
also have "… = do {
x ← pmf_of_set A;
if x ∈ B = b then do {
l ← random_bst ({y∈A. y < x} ⋄ {y∈B. y < x});
r ← random_bst ({y∈A. y > x} ⋄ {y∈B. y > x});
return_pmf ⟨l, x, r⟩
} else do {
random_bst ({y∈A. y < x} ⋄ {y∈B. y < x} ∪ {y∈A. y > x} ⋄ {y∈B. y > x})
}
}"
using ‹finite A› ‹finite B›
by (intro bind_pmf_cong refl mrbst_join_correct if_cong) (auto simp: setop_def)
(* Step 10: simplify the sets. In the else branch the pivot x is not in A ⋄ B, so the
   lower and upper parts together make up all of A ⋄ B. *)
also have "… = do {
x ← pmf_of_set A;
if x ∈ B = b then do {
l ← random_bst ({y∈A ⋄ B. y < x});
r ← random_bst ({y∈A ⋄ B. y > x});
return_pmf ⟨l, x, r⟩
} else do {
random_bst (A ⋄ B)
}
}" (is "_ = pmf_of_set A ⤜ ?f")
using ‹finite A› ‹A ≠ {}›
by (intro bind_pmf_cong refl if_cong arg_cong[of _ _ random_bst])
(auto simp: order.strict_iff_order setop_def)
(* Step 11: replace pmf_of_set A by the coin-flip decomposition eq. The condition
   x ∈ B = b then becomes b' = b. *)
also have "… = do {
b' ← bernoulli_pmf (card (A ∩ B) / card A);
x ← (if b' then pmf_of_set (A ∩ B) else pmf_of_set (A - B));
if b' = b then do {
l ← random_bst ({y∈A ⋄ B. y < x});
r ← random_bst ({y∈A ⋄ B. y > x});
return_pmf ⟨l, x, r⟩
} else do {
random_bst (A ⋄ B)
}
}"
unfolding bind_assoc_pmf eq using ‹card A > 0› ‹finite A› ‹finite B› not_subset
by (intro bind_pmf_cong refl if_cong)
(auto intro: bind_pmf_cong split: if_splits simp: divide_simps card_gt_0_iff
dest!: in_set_pmf_of_setD)
(* Step 12: in the branch b' = b, the pivot is drawn uniformly from A ⋄ B. *)
also have "… = do {
b' ← bernoulli_pmf (card (A ∩ B) / card A);
if b' = b then do {
x ← pmf_of_set (A ⋄ B);
l ← random_bst ({y∈A ⋄ B. y < x});
r ← random_bst ({y∈A ⋄ B. y > x});
return_pmf ⟨l, x, r⟩
} else do {
random_bst (A ⋄ B)
}
}"
by (intro bind_pmf_cong) (auto simp: setop_def)
(* Step 13: fold the then branch back into random_bst (A ⋄ B) via random_bst_reduce. *)
also have "… = do {
b' ← bernoulli_pmf (card (A ∩ B) / card A);
if b' = b then do {
random_bst (A ⋄ B)
} else do {
random_bst (A ⋄ B)
}
}"
using ‹finite A› ‹finite B› ‹A ≠ {}› not_subset ‹card A > 0›
by (intro bind_pmf_cong refl if_cong random_bst_reduce [symmetric])
(auto simp: setop_def field_simps)
(* Step 14: both branches coincide, so the coin flip can be dropped. *)
also have "… = random_bst (A ⋄ B)" by simp
finally show ?thesis .
qed
qed
text ‹
We now derive intersection and difference as the instances ‹b = True› and ‹b = False› of the
generic operation ‹mrbst_inter_diff›:
›
(* Randomised intersection. For each node of the first tree, the second tree is split at the
   root x with split_bst'; sep records whether x occurs in the second tree (cf. the use of
   split_bst'_altdef above). Both halves are intersected recursively. The node is kept if sep
   holds; otherwise the two recursive results are merged with mrbst_join. *)
fun mrbst_inter where
"mrbst_inter ⟨⟩ _ = return_pmf ⟨⟩"
| "mrbst_inter ⟨l1, x, r1⟩ t2 =
(case split_bst' x t2 of (sep, l2, r2) ⇒
do {
l ← mrbst_inter l1 l2;
r ← mrbst_inter r1 r2;
if sep then return_pmf ⟨l, x, r⟩ else mrbst_join l r
})"
(* Intersecting an empty tree with anything deterministically yields the empty tree. *)
lemma mrbst_inter_Leaf_left [simp]:
"mrbst_inter ⟨⟩ = (λ_. return_pmf ⟨⟩)"
by (simp add: fun_eq_iff)
(* Intersecting with an empty second tree deterministically yields the empty tree. *)
lemma mrbst_inter_Leaf_right [simp]:
"mrbst_inter (t1 :: 'a :: linorder tree) ⟨⟩ = return_pmf ⟨⟩"
by (induction t1) (auto simp: bind_return_pmf)
(* The node equation of mrbst_inter, stated as an equation between functions of the second
   tree. *)
lemma mrbst_inter_reduce:
"mrbst_inter ⟨l1, x, r1⟩ =
(λt2. case split_bst' x t2 of (sep, l2, r2) ⇒
do {
l ← mrbst_inter l1 l2;
r ← mrbst_inter r1 r2;
if sep then return_pmf ⟨l, x, r⟩ else mrbst_join l r
})"
by (rule ext) simp
(* mrbst_inter is the instance b = True of the generic operation mrbst_inter_diff. *)
lemma mrbst_inter_altdef: "mrbst_inter = mrbst_inter_diff True"
proof (intro ext)
fix t1 t2 :: "'a tree"
show "mrbst_inter t1 t2 = mrbst_inter_diff True t1 t2"
by (induction t1 arbitrary: t2) auto
qed
(* Every possible result of intersecting two BSTs is a BST containing exactly the common
   elements. This is the instance b = True of bst_mrbst_inter_diff / set_mrbst_inter_diff. *)
corollary
fixes t1 t2 :: "'a :: linorder tree"
assumes "t' ∈ set_pmf (mrbst_inter t1 t2)" "bst t1" "bst t2"
shows bst_mrbst_inter: "bst t'"
and set_mrbst_inter: "set_tree t' = set_tree t1 ∩ set_tree t2"
using bst_mrbst_inter_diff[of t' True t1 t2] set_mrbst_inter_diff[of t' True t1 t2] assms
by (simp_all add: mrbst_inter_altdef)
(* Distributional correctness of intersection, obtained from mrbst_inter_diff_correct with
   b = True. *)
corollary mrbst_inter_correct:
fixes A B :: "'a :: linorder set"
assumes "finite A" "finite B"
shows "do {t1 ← random_bst A; t2 ← random_bst B; mrbst_inter t1 t2} = random_bst (A ∩ B)"
using assms unfolding mrbst_inter_altdef by (subst mrbst_inter_diff_correct) simp_all
(* Randomised difference. This is the same recursion as mrbst_inter, but the root x is dropped
   (the recursive results are merged with mrbst_join) exactly when it occurs in the second
   tree, as recorded by sep. *)
fun mrbst_diff where
"mrbst_diff ⟨⟩ _ = return_pmf ⟨⟩"
| "mrbst_diff ⟨l1, x, r1⟩ t2 =
(case split_bst' x t2 of (sep, l2, r2) ⇒
do {
l ← mrbst_diff l1 l2;
r ← mrbst_diff r1 r2;
if sep then mrbst_join l r else return_pmf ⟨l, x, r⟩
})"
(* Removing anything from an empty tree deterministically yields the empty tree. *)
lemma mrbst_diff_Leaf_left [simp]:
"mrbst_diff ⟨⟩ = (λ_. return_pmf ⟨⟩)"
by (simp add: fun_eq_iff)
(* Removing an empty tree deterministically returns the first tree unchanged. *)
lemma mrbst_diff_Leaf_right [simp]:
"mrbst_diff (t1 :: 'a :: linorder tree) ⟨⟩ = return_pmf t1"
by (induction t1) (auto simp: bind_return_pmf)
(* The node equation of mrbst_diff, stated as an equation between functions of the second
   tree. *)
lemma mrbst_diff_reduce:
"mrbst_diff ⟨l1, x, r1⟩ =
(λt2. case split_bst' x t2 of (sep, l2, r2) ⇒
do {
l ← mrbst_diff l1 l2;
r ← mrbst_diff r1 r2;
if sep then mrbst_join l r else return_pmf ⟨l, x, r⟩
})"
by (rule ext) simp
(* A negated if-condition can be removed by swapping the two branches. This is used to match
   the branch order of mrbst_diff against mrbst_inter_diff False. *)
lemma If_not: "(if ¬b then x else y) = (if b then y else x)"
by (cases b) simp_all
(* mrbst_diff is the instance b = False of the generic operation mrbst_inter_diff. If_not is
   needed to reconcile the order of the if-branches. *)
lemma mrbst_diff_altdef: "mrbst_diff = mrbst_inter_diff False"
proof (intro ext)
fix t1 t2 :: "'a tree"
show "mrbst_diff t1 t2 = mrbst_inter_diff False t1 t2"
by (induction t1 arbitrary: t2) (auto simp: If_not)
qed
(* Every possible result of subtracting one BST from another is a BST containing exactly the
   elements of the first tree that are not in the second. This is the instance b = False of
   bst_mrbst_inter_diff / set_mrbst_inter_diff. *)
corollary
fixes t1 t2 :: "'a :: linorder tree"
assumes "t' ∈ set_pmf (mrbst_diff t1 t2)" "bst t1" "bst t2"
shows bst_mrbst_diff: "bst t'"
and set_mrbst_diff: "set_tree t' = set_tree t1 - set_tree t2"
using bst_mrbst_inter_diff[of t' False t1 t2] set_mrbst_inter_diff[of t' False t1 t2] assms
by (simp_all add: mrbst_diff_altdef)
(* Distributional correctness of difference, obtained from mrbst_inter_diff_correct with
   b = False. *)
corollary mrbst_diff_correct:
fixes A B :: "'a :: linorder set"
assumes "finite A" "finite B"
shows "do {t1 ← random_bst A; t2 ← random_bst B; mrbst_diff t1 t2} = random_bst (A - B)"
using assms unfolding mrbst_diff_altdef by (subst mrbst_inter_diff_correct) simp_all
subsection ‹Union›
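text ‹
The algorithm for the union of two trees is by far the most complicated one. It involves a
random choice, weighted by the sizes of the two trees, of whether the root of the first or of
the second tree becomes the root of the result. The other tree is then split at that root, and
the corresponding halves are merged recursively. If the root of the second tree is chosen and
it also occurs in the first tree, the result is assembled with ‹mrbst_push_down› instead of
simply placing that element at the root.
›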
(* Randomised union. With probability m / (m + n), where m and n are the sizes of the two
   trees, the root x of the first tree becomes the root: the second tree is split at x and the
   halves are merged recursively. Otherwise the root y of the second tree is used: the first
   tree is split at y with split_bst', and if y occurs in the first tree (sep), the result is
   built with mrbst_push_down instead of a plain node.
   The unnamed context adds case_prod_unfold to termination_simp and if_splits to the split
   rules only while mrbst_union is being defined. *)
context
notes
case_prod_unfold [termination_simp]
if_splits [split]
begin
fun mrbst_union where
"mrbst_union ⟨⟩ t2 = return_pmf t2"
| "mrbst_union t1 ⟨⟩ = return_pmf t1"
| "mrbst_union ⟨l1, x, r1⟩ ⟨l2, y, r2⟩ =
do {
let m = size ⟨l1, x, r1⟩; let n = size ⟨l2, y, r2⟩;
b ← bernoulli_pmf (m / (m + n));
if b then do {
let (l2', r2') = split_bst x ⟨l2, y, r2⟩;
l ← mrbst_union l1 l2';
r ← mrbst_union r1 r2';
return_pmf ⟨l, x, r⟩
} else do {
let (sep, l1', r1') = split_bst' y ⟨l1, x, r1⟩;
l ← mrbst_union l1' l2;
r ← mrbst_union r1' r2;
if sep then
mrbst_push_down l y r
else
return_pmf ⟨l, y, r⟩
}
}"
end
(* Merging an empty first tree deterministically returns the second tree. *)
lemma mrbst_union_Leaf_left [simp]: "mrbst_union ⟨⟩ = return_pmf"
by (rule ext) simp
(* Merging with an empty second tree deterministically returns the first tree. *)
lemma mrbst_union_Leaf_right [simp]: "mrbst_union t1 ⟨⟩ = return_pmf t1"
by (cases t1) simp_all
(* Structural properties of mrbst_union on BST inputs: every possible output t' is again a BST,
   and its element set is the union of the element sets of t1 and t2. Both facts are proved
   together by strong induction on size t1 + size t2. *)
lemma
fixes t1 t2 :: "'a :: linorder tree"
assumes "t' ∈ set_pmf (mrbst_union t1 t2)" "bst t1" "bst t2"
shows bst_mrbst_union: "bst t'"
and set_mrbst_union: "set_tree t' = set_tree t1 ∪ set_tree t2"
proof -
have "bst t' ∧ set_tree t' = set_tree t1 ∪ set_tree t2"
using assms
proof (induction "size t1 + size t2" arbitrary: t1 t2 t' rule: less_induct)
case (less t1 t2 t')
show ?case
proof (cases "t1 = ⟨⟩ ∨ t2 = ⟨⟩")
case False
then obtain l1 x r1 l2 y r2 where t1: "t1 = ⟨l1, x, r1⟩" and t2: "t2 = ⟨l2, y, r2⟩"
by (cases t1; cases t2) auto
(* One unfolding step of mrbst_union. Case 1: the root x of t1 becomes the root and t2 is
   split at x. Case 2: the root y of t2 is used, t1 is split at y, and mrbst_push_down is
   used if y also occurs in t1. *)
from less.prems consider l r where
"l ∈ set_pmf (mrbst_union l1 (fst (split_bst x t2)))"
"r ∈ set_pmf (mrbst_union r1 (snd (split_bst x t2)))"
"t' = ⟨l, x, r⟩"
| l r where
"l ∈ set_pmf (mrbst_union (fst (split_bst y t1)) l2)"
"r ∈ set_pmf (mrbst_union (snd (split_bst y t1)) r2)"
"t' ∈ (if isin ⟨l1, x, r1⟩ y then set_pmf (mrbst_push_down l y r) else {⟨l, y, r⟩})"
by (auto simp: case_prod_unfold t1 t2 Let_def
simp del: split_bst.simps split_bst'.simps isin.simps split: if_splits)
thus ?thesis
proof cases
case 1
(* Induction hypothesis for both recursive calls; size_split_bst provides the size bound
   that the strong induction requires. *)
hence lr: "bst l ∧ set_tree l = set_tree l1 ∪ set_tree (fst (split_bst x t2))"
"bst r ∧ set_tree r = set_tree r1 ∪ set_tree (snd (split_bst x t2))"
using less.prems size_split_bst[of x t2]
by (intro less; force simp: t1)+
thus ?thesis
using 1 less.prems by (auto simp: t1 set_split_bst1 set_split_bst2)
next
case 2
hence lr: "bst l ∧ set_tree l = set_tree (fst (split_bst y t1)) ∪ set_tree l2"
"bst r ∧ set_tree r = set_tree (snd (split_bst y t1)) ∪ set_tree r2"
using less.prems size_split_bst[of y t1]
by (intro less; force simp: t2)+
show ?thesis
proof (cases "isin ⟨l1, x, r1⟩ y")
case False
thus ?thesis using 2 less.prems lr
by (auto simp del: isin.simps simp: t2 set_split_bst1 set_split_bst2)
next
case True
(* y occurs in t1, so t' comes from mrbst_push_down. Its bst/set lemmas need all elements
   of l below y and all elements of r above y. *)
have bst': "∀z∈set_tree l. z < y" "∀z∈set_tree r. z > y"
using lr less.prems by (auto simp: set_split_bst1 set_split_bst2 t2)
from True and 2 have t': "t' ∈ set_pmf (mrbst_push_down l y r)"
by (auto simp del: isin.simps)
from t' have "bst t'"
by (rule bst_mrbst_push_down) (use lr bst' in auto)
moreover from t' have "set_tree t' = {y} ∪ set_tree l ∪ set_tree r"
by (rule set_mrbst_push_down) (use lr bst' in auto)
ultimately show ?thesis using less.prems lr
by (auto simp del: isin.simps simp: t2 set_split_bst1 set_split_bst2)
qed
qed
(* Remaining case: one of the trees is a leaf, so t' is the other tree. *)
qed (use less.prems in auto)
qed
thus "bst t'" and "set_tree t' = set_tree t1 ∪ set_tree t2" by auto
qed
(* Main correctness theorem for union: for finite A and B, merging independent random BSTs for
   A and B with mrbst_union yields a random BST for A ∪ B. The proof is by strong induction over
   finite subsets of A ∪ B (generalising A and B). It rewrites the monadic expression in a chain
   of equational steps. *)
theorem mrbst_union_correct:
assumes "finite A" "finite B"
shows "do {t1 ← random_bst A; t2 ← random_bst B; mrbst_union t1 t2} =
random_bst (A ∪ B)"
proof -
from assms have "finite (A ∪ B)" by simp
thus ?thesis
proof (induction "A ∪ B" arbitrary: A B rule: finite_psubset_induct)
case (psubset A B)
show ?case
proof (cases "A = {} ∨ B = {}")
case True
(* If one set is empty, mrbst_union simply returns the other tree. *)
thus ?thesis including monad_normalisation by auto
next
case False
with psubset.hyps have AB: "finite A" "finite B" "A ≠ {}" "B ≠ {}" by auto
(* m, n, l: cardinalities of A, B and A ∩ B. p is the probability that the root of the first
   tree is chosen, and q is the fraction of B that lies in A. r is the conditional
   probability used when the two coin flips are recombined at the end of the proof. *)
define m n l where "m = card A" and "n = card B" and "l = card (A ∩ B)"
define p q where "p = m / (m + n)" and "q = l / n"
define r where "r = p / (1 - (1 - p) * q)"
(* Arithmetic facts about m, n, l, p, q needed for the probability calculations below. *)
from AB have mn: "m > 0" "n > 0" by (auto simp: m_def n_def)
have pq: "p ∈ {0..1}" "q ∈ {0..1}"
using AB by (auto simp: p_def q_def m_def n_def l_def divide_simps intro: card_mono)
moreover have "p ≠ 0"
using AB by (auto simp: p_def m_def n_def divide_simps add_nonneg_eq_0_iff)
ultimately have "p > 0" by auto
have "B - A = B - (A ∩ B)" by auto
also have "card … = n - l"
using AB unfolding n_def l_def by (intro card_Diff_subset) auto
finally have [simp]: "card (B - A) = n - l" .
from AB have "l ≤ n" unfolding l_def n_def by (intro card_mono) auto
have "p ≤ 1 - (1 - p) * q"
using mn ‹l ≤ n› by (auto simp: p_def q_def divide_simps)
hence r_aux: "(1 - p) * q ∈ {0..1 - p}"
using pq by auto
(* Bring the monad normalisation simp setup into scope, so that simp can normalise the
   monadic expressions in the steps below. *)
include monad_normalisation
(* Abbreviations for the random subtrees of A and B below / above a pivot. *)
define RA1 RA2 RB1 RB2
where "RA1 = (λx. random_bst {z∈A. z < x})" and "RA2 = (λx. random_bst {z∈A. z > x})"
and "RB1 = (λx. random_bst {z∈B. z < x})" and "RB2 = (λx. random_bst {z∈B. z > x})"
(* Step 1: unfold random_bst A and random_bst B once (random_bst_reduce), then one step
   of mrbst_union. *)
have "do {t1 ← random_bst A; t2 ← random_bst B; mrbst_union t1 t2} =
do {
x ← pmf_of_set A;
l1 ← random_bst {z∈A. z < x};
r1 ← random_bst {z∈A. z > x};
y ← pmf_of_set B;
l2 ← random_bst {z∈B. z < y};
r2 ← random_bst {z∈B. z > y};
let m = size ⟨l1, x, r1⟩;
let n = size ⟨l2, y, r2⟩;
b ← bernoulli_pmf (m / (m + n));
if b then do {
l ← mrbst_union l1 (fst (split_bst x ⟨l2, y, r2⟩));
r ← mrbst_union r1 (snd (split_bst x ⟨l2, y, r2⟩));
return_pmf ⟨l, x, r⟩
} else do {
l ← mrbst_union (fst (split_bst y ⟨l1, x, r1⟩)) l2;
r ← mrbst_union (snd (split_bst y ⟨l1, x, r1⟩)) r2;
if isin ⟨l1, x, r1⟩ y then
mrbst_push_down l y r
else
return_pmf ⟨l, y, r⟩
}
}" using AB
by (simp add: random_bst_reduce split_bst'_altdef Let_def case_prod_unfold cong: if_cong)
(* Step 2: the size-based coin becomes bernoulli_pmf p (sizes of random BSTs are the
   cardinalities of their sets, size_random_bst), and the test isin ⟨l1, x, r1⟩ y
   becomes y ∈ A. *)
also have "… = do {
x ← pmf_of_set A;
l1 ← random_bst {z∈A. z < x};
r1 ← random_bst {z∈A. z > x};
y ← pmf_of_set B;
l2 ← random_bst {z∈B. z < y};
r2 ← random_bst {z∈B. z > y};
b ← bernoulli_pmf p;
if b then do {
l ← mrbst_union l1 (fst (split_bst x ⟨l2, y, r2⟩));
r ← mrbst_union r1 (snd (split_bst x ⟨l2, y, r2⟩));
return_pmf ⟨l, x, r⟩
} else do {
l ← mrbst_union (fst (split_bst y ⟨l1, x, r1⟩)) l2;
r ← mrbst_union (snd (split_bst y ⟨l1, x, r1⟩)) r2;
if y ∈ A then
mrbst_push_down l y r
else
return_pmf ⟨l, y, r⟩
}
}"
unfolding Let_def
proof (intro bind_pmf_cong refl if_cong)
fix l1 x r1 y
assume "l1 ∈ set_pmf (random_bst {z∈A. z < x})" "r1 ∈ set_pmf (random_bst {z∈A. z > x})"
"x ∈ set_pmf (pmf_of_set A)"
thus "isin ⟨l1, x, r1⟩ y ⟷ (y ∈ A)"
using AB by (subst isin_bst) (auto simp: bst_random_bst set_random_bst)
qed (insert AB,
auto simp: size_random_bst m_def n_def p_def isin_random_bst dest!: card_3way_split)
(* Step 3: flip the coin first. In each branch, recombine the unused tree into a random BST
   that is then split at the chosen pivot (map_pmf (split_bst _)). *)
also have "… = do {
b ← bernoulli_pmf p;
if b then do {
x ← pmf_of_set A;
(l1, r1) ← pair_pmf (random_bst {z∈A. z < x}) (random_bst {z∈A. z > x});
(l2, r2) ← map_pmf (split_bst x) (random_bst B);
l ← mrbst_union l1 l2;
r ← mrbst_union r1 r2;
return_pmf ⟨l, x, r⟩
} else do {
y ← pmf_of_set B;
(l1, r1) ← map_pmf (split_bst y) (random_bst A);
(l2, r2) ← pair_pmf (random_bst {z∈B. z < y}) (random_bst {z∈B. z > y});
l ← mrbst_union l1 l2;
r ← mrbst_union r1 r2;
if y ∈ A then
mrbst_push_down l y r
else
return_pmf ⟨l, y, r⟩
}
}" using AB
by (simp add: random_bst_reduce map_pmf_def case_prod_unfold pair_pmf_def cong: if_cong)
(* Step 4: splitting a random BST at a pivot gives an independent pair of random BSTs for the
   elements below and above the pivot (split_random_bst). *)
also have "… = do {
b ← bernoulli_pmf p;
if b then do {
x ← pmf_of_set A;
(l1, r1) ← pair_pmf (RA1 x) (RA2 x);
(l2, r2) ← pair_pmf (RB1 x) (RB2 x);
l ← mrbst_union l1 l2;
r ← mrbst_union r1 r2;
return_pmf ⟨l, x, r⟩
} else do {
y ← pmf_of_set B;
(l1, r1) ← pair_pmf (RA1 y) (RA2 y);
(l2, r2) ← pair_pmf (RB1 y) (RB2 y);
l ← mrbst_union l1 l2;
r ← mrbst_union r1 r2;
if y ∈ A then
mrbst_push_down l y r
else
return_pmf ⟨l, y, r⟩
}
}"
unfolding case_prod_unfold RA1_def RA2_def RB1_def RB2_def
by (intro bind_pmf_cong refl if_cong split_random_bst AB)
(* Step 5: group the binds so that each recursive call sits together with the two random
   trees it consumes. *)
also have "… = do {
b ← bernoulli_pmf p;
if b then do {
x ← pmf_of_set A;
l ← do {l1 ← RA1 x; l2 ← RB1 x; mrbst_union l1 l2};
r ← do {r1 ← RA2 x; r2 ← RB2 x; mrbst_union r1 r2};
return_pmf ⟨l, x, r⟩
} else do {
y ← pmf_of_set B;
l ← do {l1 ← RA1 y; l2 ← RB1 y; mrbst_union l1 l2};
r ← do {r1 ← RA2 y; r2 ← RB2 y; mrbst_union r1 r2};
if y ∈ A then
mrbst_push_down l y r
else
return_pmf ⟨l, y, r⟩
}
}"
by (simp add: pair_pmf_def cong: if_cong)
(* Step 6: apply the induction hypothesis to the recursive merges. Their sets exclude the
   pivot and are therefore strictly smaller than A ∪ B. *)
also have "… = do {
b ← bernoulli_pmf p;
if b then do {
x ← pmf_of_set A;
l ← random_bst ({z∈A. z < x} ∪ {z∈B. z < x});
r ← random_bst ({z∈A. z > x} ∪ {z∈B. z > x});
return_pmf ⟨l, x, r⟩
} else do {
y ← pmf_of_set B;
l ← random_bst ({z∈A. z < y} ∪ {z∈B. z < y});
r ← random_bst ({z∈A. z > y} ∪ {z∈B. z > y});
if y ∈ A then
mrbst_push_down l y r
else
return_pmf ⟨l, y, r⟩
}
}"
unfolding RA1_def RA2_def RB1_def RB2_def using AB
by (intro bind_pmf_cong if_cong refl psubset) auto
(* Step 7: merge the set comprehensions over A and B into ones over A ∪ B. *)
also have "… = do {
b ← bernoulli_pmf p;
if b then do {
x ← pmf_of_set A;
l ← random_bst {z∈A ∪ B. z < x};
r ← random_bst {z∈A ∪ B. z > x};
return_pmf ⟨l, x, r⟩
} else do {
y ← pmf_of_set B;
l ← random_bst {z∈A ∪ B. z < y};
r ← random_bst {z∈A ∪ B. z > y};
if y ∈ A then
mrbst_push_down l y r
else
return_pmf ⟨l, y, r⟩
}
}"
by (intro bind_pmf_cong if_cong refl arg_cong[of _ _ random_bst]) auto
(* Step 8: in the second branch, split the choice of y from B by a coin with probability q
   into a choice from A ∩ B or from B - A. For y ∈ A ∩ B, mrbst_push_down yields
   random_bst (A ∪ B) (mrbst_push_down_correct'). *)
also have "… = do {
b ← bernoulli_pmf p;
if b then do {
x ← pmf_of_set A;
l ← random_bst {z∈A ∪ B. z < x};
r ← random_bst {z∈A ∪ B. z > x};
return_pmf ⟨l, x, r⟩
} else do {
b' ← bernoulli_pmf q;
if b' then do {
y ← pmf_of_set (A ∩ B);
random_bst (A ∪ B)
} else do {
y ← pmf_of_set (B - A);
l ← random_bst {z∈A ∪ B. z < y};
r ← random_bst {z∈A ∪ B. z > y};
return_pmf ⟨l, y, r⟩
}
}
}"
proof (intro bind_pmf_cong refl if_cong, goal_cases)
case (1 b)
(* If q is positive then A ∩ B is non-empty; if q is below 1 then B - A is non-empty. *)
have q_pos: "A ∩ B ≠ {}" if "q > 0" using that by (auto simp: q_def l_def)
have q_lt1: "B - A ≠ {}" if "q < 1"
proof
assume "B - A = {}"
hence "A ∩ B = B" by auto
thus False using that AB by (auto simp: q_def l_def n_def)
qed
(* Drawing uniformly from B is the same as first flipping a coin with success probability
   q and then drawing uniformly from A ∩ B or from B - A. *)
have eq: "pmf_of_set B = do {b' ← bernoulli_pmf q;
if b' then pmf_of_set (A ∩ B) else pmf_of_set (B - A)}"
using AB by (intro pmf_of_set_split_inter_diff [symmetric])
(auto simp: q_def l_def n_def)
(* Substitute this decomposition. The test y ∈ A then becomes the coin b'. *)
have "do {y ← pmf_of_set B;
l ← random_bst {z∈A ∪ B. z < y};
r ← random_bst {z∈A ∪ B. z > y};
if y ∈ A then
mrbst_push_down l y r
else
return_pmf ⟨l, y, r⟩
} =
do {
b' ← bernoulli_pmf q;
y ← (if b' then pmf_of_set (A ∩ B) else pmf_of_set (B - A));
l ← random_bst {z∈A ∪ B. z < y};
r ← random_bst {z∈A ∪ B. z > y};
if b' then
mrbst_push_down l y r
else
return_pmf ⟨l, y, r⟩
}" unfolding eq bind_assoc_pmf using AB q_pos q_lt1
by (intro bind_pmf_cong refl if_cong) (auto split: if_splits)
(* Push the choice of y into each branch of the coin b'. *)
also have "… = do {
b' ← bernoulli_pmf q;
if b' then do {
y ← pmf_of_set (A ∩ B);
do {l ← random_bst {z∈A ∪ B. z < y};
r ← random_bst {z∈A ∪ B. z > y};
mrbst_push_down l y r}
} else do {
y ← pmf_of_set (B - A);
l ← random_bst {z∈A ∪ B. z < y};
r ← random_bst {z∈A ∪ B. z > y};
return_pmf ⟨l, y, r⟩
}
}" by (simp cong: if_cong)
(* For y ∈ A ∩ B, pushing y down into random BSTs for the lower and upper parts of A ∪ B
   yields random_bst (A ∪ B) (mrbst_push_down_correct'). *)
also have "… = do {
b' ← bernoulli_pmf q;
if b' then do {
y ← pmf_of_set (A ∩ B);
random_bst (A ∪ B)
} else do {
y ← pmf_of_set (B - A);
l ← random_bst {z∈A ∪ B. z < y};
r ← random_bst {z∈A ∪ B. z > y};
return_pmf ⟨l, y, r⟩
}
}"
using AB q_pos by (intro bind_pmf_cong if_cong refl mrbst_push_down_correct') auto
finally show ?case .
qed
(* Step 9: flip the second coin b' up front, next to b. *)
also have "… = do {
b ← bernoulli_pmf p;
b' ← bernoulli_pmf q;
if b then do {
x ← pmf_of_set A;
l ← random_bst {z∈A ∪ B. z < x};
r ← random_bst {z∈A ∪ B. z > x};
return_pmf ⟨l, x, r⟩
} else if b' then do {
random_bst (A ∪ B)
} else do {
y ← pmf_of_set (B - A);
l ← random_bst {z∈A ∪ B. z < y};
r ← random_bst {z∈A ∪ B. z > y};
return_pmf ⟨l, y, r⟩
}
}"
by (simp cong: if_cong)
(* Step 10: pair the two coins. Whenever b ∨ ¬b', the result is a node whose root is drawn
   from A (if b) or from B - A (otherwise). *)
also have "… = do {
(b, b') ← pair_pmf (bernoulli_pmf p) (bernoulli_pmf q);
if b ∨ ¬b' then do {
x ← (if b then pmf_of_set A else pmf_of_set (B - A));
l ← random_bst {z∈A ∪ B. z < x};
r ← random_bst {z∈A ∪ B. z > x};
return_pmf ⟨l, x, r⟩
} else do {
random_bst (A ∪ B)
}
}" unfolding pair_pmf_def bind_assoc_pmf
by (intro bind_pmf_cong) auto
(* Step 11: reparametrise the pair of coins as (b ∨ ¬b', b) via map_pmf. *)
also have "… = do {
(b, b') ← map_pmf (λ(b, b'). (b ∨ ¬b', b))
(pair_pmf (bernoulli_pmf p) (bernoulli_pmf q));
if b then do {
x ← (if b' then pmf_of_set A else pmf_of_set (B - A));
l ← random_bst {z∈A ∪ B. z < x};
r ← random_bst {z∈A ∪ B. z > x};
return_pmf ⟨l, x, r⟩
} else do {
random_bst (A ∪ B)
}
}" (is "_ = bind_pmf _ ?f")
by (simp add: bind_map_pmf case_prod_unfold cong: if_cong)
(* Step 12: compute the distribution of the reparametrised coins. The first component is a
   Bernoulli coin with probability 1 - (1 - p) * q. Given that it is True, the second is a
   Bernoulli coin with probability r; otherwise the second is False. *)
also have "map_pmf (λ(b, b'). (b ∨ ¬b', b))
(pair_pmf (bernoulli_pmf p) (bernoulli_pmf q)) =
do {
b ← bernoulli_pmf (1 - (1 - p) * q);
b' ← (if b then bernoulli_pmf r else return_pmf False);
return_pmf (b, b')
}" (is "?lhs = ?rhs")
proof (intro pmf_eqI)
fix bb' :: "bool × bool"
obtain b b' where [simp]: "bb' = (b, b')" by (cases bb')
thus "pmf ?lhs bb' = pmf ?rhs bb'"
using pq r_aux ‹p > 0›
by (cases b; cases b')
(auto simp: pmf_map pmf_bind_bernoulli measure_measure_pmf_finite
vimage_bool_pair pmf_pair r_def field_simps)
qed
(* Step 13: substitute this distribution and move the second coin into the choice of the
   root. *)
also have "… ⤜ ?f = do {
b ← bernoulli_pmf (1 - (1 - p) * q);
if b then do {
x ← do {b' ← bernoulli_pmf r;
if b' then pmf_of_set A else pmf_of_set (B - A)};
l ← random_bst {z∈A ∪ B. z < x};
r ← random_bst {z∈A ∪ B. z > x};
return_pmf ⟨l, x, r⟩
} else do {
random_bst (A ∪ B)
}
}"
by (simp cong: if_cong)
(* Step 14: mixing pmf_of_set A and pmf_of_set (B - A) with weight r gives
   pmf_of_set (A ∪ (B - A)) (pmf_of_set_union_split). *)
also have "… = do {
b ← bernoulli_pmf (1 - (1 - p) * q);
if b then do {
x ← pmf_of_set (A ∪ (B - A));
l ← random_bst {z∈A ∪ B. z < x};
r ← random_bst {z∈A ∪ B. z > x};
return_pmf ⟨l, x, r⟩
} else do {
random_bst (A ∪ B)
}
}" (is "_ = ?f (A ∪ (B - A))")
using AB pq ‹l ≤ n› mn
by (intro bind_pmf_cong if_cong refl pmf_of_set_union_split)
(auto simp: m_def [symmetric] n_def [symmetric] r_def p_def q_def divide_simps)
(* Step 15: A ∪ (B - A) = A ∪ B. The then branch then folds back into random_bst (A ∪ B)
   via random_bst_reduce, and both branches coincide. *)
also have "A ∪ (B - A) = A ∪ B" by auto
also have "?f … = random_bst (A ∪ B)"
using AB by (simp add: random_bst_reduce cong: if_cong)
finally show ?thesis .
qed
qed
qed
subsection ‹Insertion and Deletion›
text ‹
The insertion and deletion operations are simple special cases of the union
and difference operations where one of the trees is a singleton tree.
›
(* Randomised insertion of x. At a node ⟨l, y, r⟩, a coin with probability
   1 / (size l + size r + 2) decides whether x becomes the new root, in which case the tree is
   split at x. Otherwise the procedure descends towards x. If x equals the root y, the result
   is built with mrbst_push_down l y r. mrbst_insert_altdef below shows this equals merging
   with the singleton tree for x. *)
fun mrbst_insert where
"mrbst_insert x ⟨⟩ = return_pmf ⟨⟨⟩, x, ⟨⟩⟩"
| "mrbst_insert x ⟨l, y, r⟩ =
do {
b ← bernoulli_pmf (1 / real (size l + size r + 2));
if b then do {
let (l', r') = split_bst x ⟨l, y, r⟩;
return_pmf ⟨l', x, r'⟩
} else if x < y then do {
map_pmf (λl'. ⟨l', y, r⟩) (mrbst_insert x l)
} else if x > y then do {
map_pmf (λr'. ⟨l, y, r'⟩) (mrbst_insert x r)
} else do {
mrbst_push_down l y r
}
}"
(* Insertion is union with the singleton tree containing x. *)
lemma mrbst_insert_altdef: "mrbst_insert x t = mrbst_union ⟨⟨⟩, x, ⟨⟩⟩ t"
by (induction x t rule: mrbst_insert.induct)
(simp_all add: Let_def map_pmf_def bind_return_pmf case_prod_unfold cong: if_cong)
(* Every possible result of inserting x into a BST is a BST whose elements are those of t plus
   x. This follows from the union lemmas via mrbst_insert_altdef. *)
corollary
fixes t :: "'a :: linorder tree"
assumes "t' ∈ set_pmf (mrbst_insert x t)" "bst t"
shows bst_mrbst_insert: "bst t'"
and set_mrbst_insert: "set_tree t' = insert x (set_tree t)"
using bst_mrbst_union[of t' "⟨⟨⟩, x, ⟨⟩⟩" t] set_mrbst_union[of t' "⟨⟨⟩, x, ⟨⟩⟩" t] assms
by (simp_all add: mrbst_insert_altdef)
(* Inserting x into a random BST for a finite set A yields a random BST for insert x A. This is
   mrbst_union_correct applied to the sets {x} and A. *)
corollary mrbst_insert_correct:
assumes "finite A"
shows "random_bst A ⤜ mrbst_insert x = random_bst (insert x A)"
using mrbst_union_correct[of "{x}" A] assms
by (simp add: mrbst_insert_altdef[abs_def] bind_return_pmf)
(* Randomised deletion of x. The procedure descends deterministically towards x and replaces
   the node holding x by mrbst_join of its two subtrees. If x does not occur, the descent ends
   at a leaf and the tree is rebuilt unchanged. *)
fun mrbst_delete :: "'a :: ord ⇒ 'a tree ⇒ 'a tree pmf" where
"mrbst_delete x ⟨⟩ = return_pmf ⟨⟩"
| "mrbst_delete x ⟨l, y, r⟩ = (
if x < y then
map_pmf (λl'. ⟨l', y, r⟩) (mrbst_delete x l)
else if x > y then
map_pmf (λr'. ⟨l, y, r'⟩) (mrbst_delete x r)
else
mrbst_join l r)"
(* Deletion is difference with the singleton tree containing x. *)
lemma mrbst_delete_altdef: "mrbst_delete x t = mrbst_diff t ⟨⟨⟩, x, ⟨⟩⟩"
by (induction t) (auto simp: bind_return_pmf map_pmf_def)
(* Every possible result of deleting x from a BST is a BST whose elements are those of t
   without x. This follows from the difference lemmas via mrbst_delete_altdef. *)
corollary
fixes t :: "'a :: linorder tree"
assumes "t' ∈ set_pmf (mrbst_delete x t)" "bst t"
shows bst_mrbst_delete: "bst t'"
and set_mrbst_delete: "set_tree t' = set_tree t - {x}"
using bst_mrbst_diff[of t' t "⟨⟨⟩, x, ⟨⟩⟩"] set_mrbst_diff[of t' t "⟨⟨⟩, x, ⟨⟩⟩"] assms
by (simp_all add: mrbst_delete_altdef)
(* Deleting x from a random BST for a finite set A yields a random BST for A - {x}. This is
   mrbst_diff_correct applied to the sets A and {x}. *)
corollary mrbst_delete_correct:
"finite A ⟹ do {t ← random_bst A; mrbst_delete x t} = random_bst (A - {x})"
using mrbst_diff_correct[of A "{x}"] by (simp add: mrbst_delete_altdef bind_return_pmf)
end