Theory Random_Treap

(*
  File:      Random_Treap.thy
  Authors:   Max Haslbeck
*)
section ‹Random treaps›
theory Random_Treap
imports
  Probability_Misc
  Treap_Sort_and_BSTs
begin

subsection ‹Measurability›

text ‹
  The following lemmas are only relevant for measurability.
›

(* TODO Move *)
(* tree_sigma gives the same result for two measures whenever their sigma-algebras
   coincide. The spaces are then equal as well (sets_eq_imp_space_eq). *)
lemma tree_sigma_cong:
  assumes "sets M = sets M'"
  shows   "tree_sigma M = tree_sigma M'"
  using sets_eq_imp_space_eq[OF assms] using assms by (simp add: tree_sigma_def)

(* Pushing M forward along f equals pushing K forward along f, under these conditions:
   - the sets of K are among those of M;
   - M and K agree on the sets of K;
   - M assigns measure zero to every measurable set lying outside space K. *)
lemma distr_restrict:
  assumes "sets N = sets L" "sets K ⊆ sets M"
          "⋀X. X ∈ sets K ⟹ emeasure M X = emeasure K X"
          "⋀X. X ∈ sets M ⟹ X ⊆ space M - space K ⟹ emeasure M X = 0"
          "f ∈ M →⇩M N" "f ∈ K →⇩M L"
  shows   "distr M N f = distr K L f"
proof (rule measure_eqI)
  fix X assume "X ∈ sets (distr M N f)"
  thus "emeasure (distr M N f) X = emeasure (distr K L f) X"
    using assms(1) by (intro emeasure_distr_restrict assms) simp_all
qed (use assms in auto)
(* END TODO *)

(* Over a countable base set B, the tree sigma-algebra of count_space B is discrete:
   every set of trees over B is measurable.
   - Inclusion from right to left: every singleton {t} is measurable (by induction over
     trees B), and X is the countable union of its singletons.
   - Inclusion from left to right: measurable sets lie inside the space, i.e. trees B. *)
lemma sets_tree_sigma_count_space:
  assumes "countable B"
  shows   "sets (tree_sigma (count_space B)) = Pow (trees B)"
proof (intro equalityI subsetI)
  fix X assume X: "X ∈ Pow (trees B)"
  have "{t} ∈ sets (tree_sigma (count_space B))" if "t ∈ trees B" for t
    using that
  proof (induction t)
    case (2 l r x)
    hence "{⟨la, v, ra⟩ |la v ra. (v, la, ra) ∈ {x} × {l} × {r}}
               ∈ sets (tree_sigma (count_space B))"
      by (intro Node_in_tree_sigma pair_measureI) auto
    thus ?case by simp
  qed simp_all
  with X have "(⋃t∈X. {t}) ∈ sets (tree_sigma (count_space B))"
    by (intro sets.countable_UN' countable_subset[OF _ countable_trees[OF assms]]) auto
  also have "(⋃t∈X. {t}) = X" by blast
  finally show "X ∈ sets (tree_sigma (count_space B))" .
next
  fix X assume "X ∈ sets (tree_sigma (count_space B))"
  from sets.sets_into_space[OF this] show "X ∈ Pow (trees B)"
    by (simp add: space_tree_sigma)
qed

(* height restated with the primitive recursor rec_tree. Leaf has height 0, and a node has
   one more than the larger height of its subtrees. This form is meant for measurability
   reasoning. *)
lemma height_primrec: "height = rec_tree 0 (λ_ _ _ a b. Suc (max a b))"
proof
  fix t :: "'a tree"
  show "height t = rec_tree 0 (λ_ _ _ a b. Suc (max a b)) t"
    by (induction t) simp_all
qed

(* Internal path length (ipl) restated with rec_tree, for measurability reasoning.
   A node contributes the sizes of both subtrees plus their recursive ipl values. *)
lemma ipl_primrec: "ipl = rec_tree 0 (λl _ r a b. size l + size r + a + b)"
proof 
  fix t :: "'a tree"
  show "ipl t = rec_tree 0 (λl _ r a b. size l + size r + a + b) t"
    by (induction t) auto
qed

(* Tree size restated with rec_tree, for measurability reasoning.
   Leaf has size 0, and a node counts 1 plus the sizes of its subtrees. *)
lemma size_primrec: "size = rec_tree 0 (λ_ _ _ a b. 1 + a + b)"
proof 
  fix t :: "'a tree"
  show "size t = rec_tree 0 (λ_ _ _ a b. 1 + a + b) t"
    by (induction t) auto
qed

(* Relabelling the nodes with map_tree does not change the internal path length.
   The lemma is a simp rule. *)
lemma ipl_map_tree[simp]: "ipl (map_tree f t) = ipl t"
by (induction t) auto

(* For a finite set A, every tree in the support of random_bst A is a tree over A. *)
lemma set_pmf_random_bst: "finite A ⟹ set_pmf (random_bst A) ⊆ trees A"
  by (subst random_bst_altdef) 
     (auto intro!: bst_of_list_trees simp add: bst_of_list_trees permutations_of_setD)

(* trees is monotone in its base set. *)
lemma trees_mono: "A ⊆ B ⟹ trees A ⊆ trees B"
proof
  fix t
  assume "A ⊆ B" "t ∈ trees A"
  then show "t ∈ trees B"
    by (induction t) auto
qed

(* Treap insertion with real priorities, restated with the primitive recursor rec_tree.
   - l' and r' are the recursive results for the left and right subtrees.
   - measurable_ins is proved by unfolding this lemma.
   - The Leaf arms of the inner case expressions are presumably never taken, because the
     recursive result is not a Leaf (ins_neq_Leaf is used in the proof). *)
lemma ins_primrec:
  "ins k (p::real) t = rec_tree 
    (Node Leaf (k,p) Leaf)
    (λl z r l' r'. case z of (k1, p1) ⇒
      if k < k1 then
        (case l' of
          Leaf ⇒ Leaf
        | Node l2 (k2,p2) r2 ⇒
            if 0 ≤ p2 - p1 then Node (Node l2 (k2,p2) r2) (k1,p1) r
            else Node l2 (k2,p2) (Node r2 (k1,p1) r))
      else if k > k1 then
        (case r' of
          Leaf ⇒ Leaf
        | Node l2 (k2,p2) r2 ⇒
            if 0 ≤ p2 - p1 then Node l (k1,p1) (Node l2 (k2,p2) r2)
            else Node (Node l (k1,p1) l2) (k2,p2) r2)
      else Node l (k1,p1) r
      ) t"
proof (induction k p t rule: ins.induct)
  case (2 k p l k1 p1 r)
  thus ?case 
    by (cases "k < k1") (auto simp add: case_prod_beta ins_neq_Leaf split: tree.splits if_splits)
qed auto

(* Comparing (with <) two measurable functions into a countable count_space gives a
   measurable predicate. The proof uses that count_space (A × A) coincides with the pair
   measure of count_space A with itself when A is countable. *)
lemma measurable_less_count_space [measurable (raw)]:
  assumes "countable A"
  assumes [measurable]: "a ∈ B →⇩M count_space A"
  assumes [measurable]: "b ∈ B →⇩M count_space A"
  shows   "Measurable.pred B (λx. a x < b x)"
proof -
  have "Measurable.pred (count_space (A × A)) (λx. fst x < snd x)" by simp
  also have "count_space (A × A) = count_space A ⨂⇩M count_space A"
    using assms(1) by (simp add: pair_measure_countable)
  finally have "Measurable.pred B ((λx. fst x < snd x) ∘ (λx. (a x, b x)))"
    by measurable
  thus ?thesis by (simp add: o_def)
qed

(* Treap insertion is measurable jointly in the key, the real priority and the tree.
   - Keys come from a countable A with the discrete sigma-algebra.
   - Priorities use lborel.
   - The proof unfolds ins_primrec. *)
lemma measurable_ins [measurable (raw)]:
  assumes [measurable]: "countable A"
  assumes [measurable]: "k ∈ B →⇩M count_space A"
  assumes [measurable]: "x ∈ B →⇩M (lborel :: real measure)"
  assumes [measurable]: "t ∈ B →⇩M tree_sigma (count_space A ⨂⇩M lborel)"
  shows   "(λy. ins (k y) (x y) (t y)) ∈ B →⇩M tree_sigma (count_space A ⨂⇩M lborel)"
  unfolding ins_primrec by measurable

(* map_tree restated with rec_tree. Each node label a is replaced by f a, and the
   recursive results for the subtrees are reused. *)
lemma map_tree_primrec: "map_tree f t = rec_tree ⟨⟩ (λl a r l' r'. ⟨l', f a, r'⟩) t"
  by (induction t) auto

(* 𝒰 a b: the uniform probability measure on the real interval {a..b}, obtained by
   normalising lborel to that interval. The definition is a simp rule, so 𝒰 is unfolded
   automatically. *)
definition 𝒰 where "𝒰 = (λa b::real. uniform_measure lborel {a..b})"

declare 𝒰_def[simp]

(* insR x t A: the distribution of the treap obtained by inserting key x into t with a
   priority drawn from 𝒰 0 1. The key set A determines the sigma-algebra of the result. *)
fun insR:: "'a::linorder ⇒ ('a × real) tree ⇒ 'a set ⇒ ('a × real) tree measure" where
  "insR x t A  = distr (𝒰 0 1) (tree_sigma (count_space A ⨂⇩M lborel)) (λp. ins x p t)"

(* rinss xs t A: inserts the keys of xs from left to right into t. Each insertion (insR)
   draws its own random priority, and the steps are chained with the measure-monad
   bind ⤜. *)
fun rinss :: "'a::linorder list ⇒ ('a × real) tree ⇒ 'a set ⇒ ('a × real) tree measure" where
  "rinss [] t A =  return (tree_sigma (count_space A ⨂⇩M lborel)) t" |
  "rinss (x#xs) t A = insR x t A ⤜ (λt. rinss xs t A)"

(* Let the starting tree have keys in B and let all inserted keys also lie in B. Then the
   sigma-algebra of rinss ys t B is the tree sigma-algebra over count_space B ⨂⇩M lborel. *)
lemma sets_rinss':
  assumes "countable B" "set ys ⊆ B"
  shows "t ∈ trees (B × UNIV) ⟹ sets (rinss ys t B) = sets (tree_sigma (count_space B ⨂⇩M lborel))"
  using assms proof(induction ys arbitrary: t)
  case (Cons y ys)
  then show ?case
    by (subst rinss.simps, subst sets_bind) (auto simp add: space_tree_sigma space_pair_measure)
qed auto

(* A left fold over a fixed list xs is measurable under two conditions:
   - the initial value f is measurable;
   - for every element c of xs, the step (λ(a,b). g a b c) is measurable on A ⨂⇩M B.
   The proof is by induction on xs, generalising the initial value. *)
lemma measurable_foldl [measurable]:
  assumes "f ∈ A →⇩M B" "set xs ⊆ space C"
  assumes "⋀c. c ∈ set xs ⟹ (λ(a,b). g a b c) ∈ (A ⨂⇩M B) →⇩M B"
  shows   "(λx. foldl (g x) (f x) xs) ∈ A →⇩M B"
  using assms
proof (induction xs arbitrary: f)
  case Nil
  thus ?case by simp
next
  case (Cons x xs)
  note [measurable] = Cons.prems(1)
  from Cons.prems have [measurable]: "x ∈ space C" by simp
  have "(λa. (a, f a)) ∈ A →⇩M A ⨂⇩M B"
    by measurable
  hence "(λ(a,b). g a b x) ∘ (λa. (a, f a)) ∈ A →⇩M B"
    by (rule measurable_comp) (rule Cons.prems, auto)
  hence "(λa. g a (f a) x) ∈ A →⇩M B" by (simp add: o_def)
  hence "(λxa. foldl (g xa) (g xa (f xa) x) xs) ∈ A →⇩M B"
    by (rule Cons.IH) (use Cons.prems in auto)
  thus ?case by simp
qed

(* Inserting a key/priority pair taken from A keeps a tree inside trees A. *)
lemma ins_trees: "t ∈ trees A ⟹ (x,y) ∈ A ⟹ ins x y t ∈ trees A"
  by (induction x y t rule: ins.induct)
     (auto split: tree.splits simp: ins_neq_Leaf)


subsection ‹Main result›

text ‹
  In our setting, we have some countable set of values that may appear in the input and
  a concrete list consisting only of those values, without repetitions.

  We further define an abbreviation for the uniform distribution over the permutations of that list.
›

(* Fixed setting for the main result:
   - A is a countable key set;
   - xs is a list of distinct elements of A;
   - random_perm xs is the uniform distribution over the permutations of set xs. *)
context
  fixes xs::"'a::linorder list" and A::"'a set" and random_perm :: "'a list ⇒ 'a list measure"
  assumes con_assms: "countable A" "set xs ⊆ A" "distinct xs"
  defines "random_perm ≡ (λxs. uniform_measure (count_space (permutations_of_set (set xs)))
                                 (permutations_of_set (set xs)))"
begin

text ‹
  Again, we first need some facts about measurability.
›
(* sets_rinss' specialised to the fixed xs and A. It is stated with borel instead of
   lborel, which has the same sets. The lemma is a simp rule. *)
lemma sets_rinss [simp]: 
  assumes "t ∈ trees (A × UNIV)" 
  shows "sets (rinss xs t A) = tree_sigma (count_space A ⨂⇩M borel)"
proof -
  have "tree_sigma (count_space A ⨂⇩M (lborel::real measure)) = tree_sigma (count_space A ⨂⇩M borel)"
    by (intro tree_sigma_cong sets_pair_measure_cong) auto
  then show ?thesis
    using assms con_assms by (subst sets_rinss') auto
qed

(* bst_of_list is measurable from lists over A (discrete) to the tree sigma-algebra over A.
   Every result lies in trees A (bst_of_list_trees). *)
lemma bst_of_list_measurable [measurable]:
  "bst_of_list ∈ measurable (count_space (lists A)) (tree_sigma (count_space A))"
  by (subst measurable_count_space_eq1)
    (auto simp: space_tree_sigma intro!: bst_of_list_trees)

(* Sorting the fixed list xs by a relation on A is measurable as a map between discrete
   spaces: relations on A go to lists over A. *)
lemma insort_wrt_measurable [measurable]:
  "(λx. insort_wrt x xs) ∈ count_space (Pow (A × A)) →⇩M count_space (lists A)"
  using con_assms by auto

(* Building a BST from xs sorted by real keys x is measurable in x. The proof rewrites
   sort_key into insort_wrt with the order induced by the keys (linorder_from_keys).
   NOTE: the lemma name misspells "measurable". It is kept unchanged because it is part of
   the theory's interface. *)
lemma bst_of_list_sort_meaurable [measurable]:
  "(λx. bst_of_list (sort_key x xs)) ∈
     PiM (set xs) (λi. borel::real measure) →⇩M tree_sigma (count_space A)"
proof -
  note measurable_linorder_from_keys_restrict'[measurable]
  have "(0::real) < 1"
    by auto
  then have [measurable]: "(λx. bst_of_list (insort_wrt (linorder_from_keys (set xs) x) xs))
                       ∈ PiM (set xs) (λi. borel :: real measure) →⇩M tree_sigma (count_space A)"
    using con_assms by measurable
  show ?thesis
    by (subst insort_wrt_sort_key[symmetric]) (measurable, auto)
qed


text ‹
  In a first step, we show that the bulk insertion operation is equivalent to first choosing
  all priorities i.\,i.\,d.\ ahead of time and then inserting the elements deterministically,
  each with its associated priority.
›
(* Inserting xs one key at a time, each with its own 𝒰 0 1 priority (rinss), has the same
   distribution as two steps:
   - draw all priorities at once from the product measure Π⇩M x∈set xs. 𝒰 0 1;
   - insert deterministically with foldl.
   The proof is by induction on xs. The Cons case splits off the priority of the head with
   bind_assoc and pair_measure_bind, and then merges it back with distr_pair_PiM_eq_PiM. *)
lemma random_treap_fold:
  assumes "t ∈ space (tree_sigma (count_space A ⨂⇩M lborel))"
  shows "rinss xs t A = distr (Π⇩M x∈set xs. 𝒰 0 1)
                              (tree_sigma (count_space A ⨂⇩M lborel))
                              (λp. foldl (λt x. ins x (p x) t) t xs)"
proof -
  let ?U = "uniform_measure lborel {0::real..1}"
  have "set xs ⊆ space (count_space A)" "c ∈ set xs ⟹ c ∈ space (count_space A)" for c
    using con_assms by auto
  then have *[intro]: "(λp. foldl (λt x. ins x (p x) t) t xs) ∈
    PiM (set xs) (λx. ?U) →⇩M tree_sigma (count_space A ⨂⇩M lborel)"
    if "t ∈ space (tree_sigma (count_space A ⨂⇩M lborel))" for t
   using that con_assms by measurable
  have insR': 
    "insR x t A = ?U ⤜ (λu. return (tree_sigma (count_space A ⨂⇩M lborel)) (ins x u t))"
    if "x ∈ A" "t ∈ space (tree_sigma (count_space A ⨂⇩M lborel))" for t x
    using con_assms assms that by (auto simp add: bind_return_distr' 𝒰_def)
  have "rinss xs t A = (Π⇩M x∈set xs. ?U) ⤜
     (λp. return (tree_sigma (count_space A ⨂⇩M lborel)) (foldl (λt x. ins x (p x) t) t xs))"
    using con_assms(2,3) assms proof (induction xs arbitrary: t)
  case Nil
  then show ?case
    by (intro measure_eqI) (auto simp add: space_PiM_empty emeasure_distr bind_return_distr')
next
  case (Cons x xs)
  note insR.simps[simp del]
  let ?treap_sigma  = "tree_sigma (count_space A ⨂⇩M lborel)"
  have [measurable]: "set xs ⊆ space (count_space A)" "x ∈ A"
                     "c ∈ A ⟹ c ∈ space (count_space A)" for c
    using Cons by auto
  have [intro!]: "ins k p t ∈ space ?treap_sigma" if "t ∈ space ?treap_sigma" "k ∈ A"
    for k t and p::real
    using that
    by (auto intro!: ins_trees simp add: space_tree_sigma space_pair_measure)
  have [measurable]: "PiM (set xs) (λx. ?U) ∈ space (prob_algebra (PiM (set xs) (λi. ?U)))"
    unfolding space_prob_algebra by (auto intro!: prob_space_uniform_measure prob_space_PiM)
  have [measurable]: "PiM (set xs) (λx. ?U) ∈ space (subprob_algebra (PiM (set xs) (λi. ?U)))"
    unfolding space_subprob_algebra
    by (auto intro!: prob_space_imp_subprob_space prob_space_uniform_measure prob_space_PiM)
  have [measurable]: "(λx. x) ∈ (?treap_sigma ⨂⇩M PiM (set xs) (λi. ?U)) ⨂⇩M ?treap_sigma →⇩M
            (?treap_sigma ⨂⇩M PiM (set xs) (λi. borel)) ⨂⇩M ?treap_sigma"
    by (auto intro!: measurable_ident_sets sets_pair_measure_cong sets_PiM_cong simp add: 𝒰_def)
  have [simp]: "(λw. PiM (set xs) (λx. ?U) ⤜
        (λp. return ?treap_sigma (foldl (λt x. ins x (p x) t) w xs)))
        ∈ ?treap_sigma →⇩M subprob_algebra ?treap_sigma"
  proof -
    have [measurable]: "c ∈ set xs ⟹ c ∈ A" for c
      using Cons by auto
    show ?thesis
      using con_assms by measurable
  qed
  have [measurable]: "?U ∈ space (prob_algebra (?U))"
    by (simp add: prob_space_uniform_measure space_prob_algebra)
  have [measurable, intro]: "(λt. rinss xs t A) ∈ ?treap_sigma →⇩M subprob_algebra ?treap_sigma"
    if "set xs ⊆ A" for xs
    using that proof (induction xs)
    case (Cons x xs)
    then have [measurable]: "x ∈ A" "set xs ⊆ A"
      by auto
    have [measurable]: "(λy. x) ∈ tree_sigma (count_space A ⨂⇩M lborel) ⨂⇩M ?U →⇩M count_space A"
      using Cons by measurable
    have [measurable]: "(λx. x) ∈ ?treap_sigma ⨂⇩M ?U →⇩M ?treap_sigma ⨂⇩M borel"
      unfolding 𝒰_def by auto
    have [measurable]: "(λt. distr (?U) (tree_sigma (count_space A ⨂⇩M lborel)) (λp. ins x p t))
      ∈ ?treap_sigma →⇩M subprob_algebra ?treap_sigma"
      using con_assms by (intro measurable_prob_algebraD) measurable
    from Cons show ?case
      unfolding rinss.simps insR.simps 𝒰_def by measurable
  qed auto
  have [intro]: "(λu. return ?treap_sigma (ins x u t)) ∈ ?U →⇩M subprob_algebra ?treap_sigma"
    using con_assms Cons by measurable
  have [simp]: "space (?U ⨂⇩M PiM (set xs) (λx. ?U)) ≠ {}"
    by (simp add: prob_space.not_empty prob_space_PiM prob_space_pair prob_space_uniform_measure)
  from Cons have "rinss (x # xs) t A = (?U ⤜
                                       (λu. return ?treap_sigma (ins x u t))) ⤜
                                       (λt. rinss xs t A)"
    by (simp add: insR')
  also have "… = ?U ⤜ (λu. return ?treap_sigma (ins x u t) ⤜ (λt. rinss xs t A))"
    using con_assms Cons by (subst bind_assoc) auto
  also have "… = ?U ⤜ (λu. rinss xs (ins x u t) A)"
    using con_assms Cons by (subst bind_return) auto
  also have "… = ?U ⤜
                 (λu. PiM (set xs) (λx. ?U) ⤜
                 (λp. return ?treap_sigma (foldl (λt x. ins x (p x) t) (ins x u t) xs)))"
    using Cons by (subst Cons) (auto simp add: treap_ins keys_ins)
  also have "… = ?U ⨂⇩M PiM (set xs) (λx. ?U) ⤜
                  (λ(u,p). return ?treap_sigma (foldl (λt x. ins x (p x) t) (ins x u t) xs))"
  proof -
    have [measurable]: "pair_prob_space (?U) (PiM (set xs) (λx. ?U))"
      by (simp add: 𝒰_def pair_prob_space_def pair_sigma_finite.intro prob_space_PiM 
          prob_space_imp_sigma_finite prob_space_uniform_measure)
    note this[unfolded 𝒰_def, measurable]
    have [measurable]: "c ∈ set xs ⟹ c ∈ A" for c
      using Cons by auto
    show ?thesis
      using con_assms Cons by (subst pair_prob_space.pair_measure_bind) measurable
  qed
  also have "… = distr (?U ⨂⇩M PiM (set xs) (λx. ?U)) (tree_sigma (count_space A ⨂⇩M lborel))
                  (λ(u, f). foldl (λt x. ins x (f x) t) (ins x u t) xs)"
  proof -
    have [simp]: "c ∈ set xs ⟹ c ∈ A" for c
      using Cons by auto
    have "(λxa. foldl (λt x. ins x (snd xa x) t) (ins x (fst xa) t) xs) = 
          (λ(u, f). foldl (λt x. ins x (f x) t) (ins x u t) xs)"
      by (auto simp add: case_prod_beta')
    then show ?thesis
      using con_assms Cons by (subst case_prod_beta', subst bind_return_distr') measurable
  qed
  also have
    "… = distr (?U ⨂⇩M PiM (set xs) (λi. ?U)) ?treap_sigma
          (λf. foldl (λt y. ins y (if y = x then fst f else snd f y) t) (ins x (fst f) t) xs)"
  proof -
    have "foldl (λt y. ins y (snd f y) t) (ins x (fst f) t) xs =
          foldl (λt y. ins y (if y = x then fst f else snd f y) t) (ins x (fst f) t) xs" for f
      using Cons by (intro foldl_cong) auto
    then show ?thesis
      by (auto simp add: case_prod_beta')
  qed
  also have "… = distr (?U ⨂⇩M PiM (set xs) (λi. ?U)) (PiM (insert x (set xs)) (λi. ?U)) 
                          (λ(r, f). f(x := r)) ⤜
                          (λp. return ?treap_sigma (foldl (λt x. ins x (p x) t) (ins x (p x) t) xs))"
    using con_assms  Cons 
    by (subst bind_distr_return) (measurable, auto simp add: case_prod_beta')
  also have "… = PiM (insert x (set xs)) (λx. ?U) ⤜
                  (λp. return ?treap_sigma (foldl (λt x. ins x (p x) t) (ins x (p x) t) xs))"
    by (subst distr_pair_PiM_eq_PiM) (auto simp add: prob_space_uniform_measure)
  finally show ?case
    by (simp)
qed
  then show ?thesis
    using assms by (subst bind_return_distr'[symmetric]) (auto simp add: bind_return_distr')
qed

(* random_treap_fold specialised to the empty starting tree. *)
corollary random_treap_fold_Leaf:
  shows "rinss xs Leaf A =
         distr (Π⇩M x∈set xs. 𝒰 0 1)
               (tree_sigma (count_space A ⨂⇩M lborel))
               (λp. foldl (λt x. ins x (p x) t) Leaf xs)"
  by (auto simp add: random_treap_fold)

text ‹
  Next, we show that additionally forgetting the priorities at the end yields
  the same distribution as inserting the elements into a BST in order of ascending priority.
›
(* Forgetting the priorities (map_tree fst) after rinss gives the same distribution as
   bst_of_list applied to xs sorted by i.i.d. 𝒰 0 1 priorities.
   - Almost surely the priorities are pairwise distinct (inj_on f (set xs)).
   - On that event, fold_ins_bst_of_list identifies the two trees.
   - distr_cong_AE then finishes the proof. *)
lemma rinss_bst_of_list:
      "distr (rinss xs Leaf A) (tree_sigma (count_space A)) (map_tree fst) =
       distr (PiM (set xs) (λx. 𝒰 0 1)) (tree_sigma (count_space A))
             (λp. bst_of_list (sort_key p xs))" (is "?lhs = ?rhs")
proof -
  have [measurable]: "set xs ⊆ space (count_space A)"
    "c ∈ set xs ⟹ c ∈ space (count_space A)" for c
    using con_assms by auto
  have [simp]: "map_tree fst ∘ (λp. foldl (λt x. ins x (p x) t) ⟨⟩ xs)
                 ∈ PiM (set xs) (λx. uniform_measure lborel {0::real..1}) →⇩M
                  tree_sigma (count_space A)"
    unfolding 𝒰_def map_tree_primrec using con_assms by measurable
  have "AE f in PiM (set xs) (λi. 𝒰 0 1). inj_on f (set xs)"
    unfolding 𝒰_def by (rule almost_everywhere_avoid_finite) auto
  then have "AE f in PiM (set xs) (λx. 𝒰 0 1).
             map_tree fst (foldl (λt (k,p). ins k p t) ⟨⟩ (map (λx. (x, f x)) xs)) =
             bst_of_list (sort_key f xs)"
    by (eventually_elim) (use con_assms in auto simp add: fold_ins_bst_of_list)
  then have [simp]: "AE f in PiM (set xs) (λx. 𝒰 0 1).
             map_tree fst (foldl (λt k. ins k (f k) t) ⟨⟩ xs) = bst_of_list (sort_key f xs)"
    by (simp add: foldl_map)
  have "?lhs = distr (PiM (set xs) (λx. 𝒰 0 1)) (tree_sigma (count_space A))
                     (map_tree fst ∘ (λp. foldl (λt x. ins x (p x) t) ⟨⟩ xs))"
    unfolding random_treap_fold_Leaf 𝒰_def map_tree_primrec using con_assms
    by (subst distr_distr) measurable
  also have "… = ?rhs"
    by (intro distr_cong_AE) (auto simp add: 𝒰_def)
  finally show ?thesis .
qed

text ‹
  This, in turn, is the same as choosing a uniformly random permutation of the input list and
  inserting its elements into a BST in that order.
›
(* Sorting xs by i.i.d. 𝒰 0 1 priorities and building a BST has the same distribution as
   building a BST from a uniformly random permutation of set xs (random_perm xs).
   The proof goes in four steps:
   - pass from the priorities to the induced random linear order (random_linorder_by_prios);
   - push that order forward along the bijection between linear orders on set xs and
     permutations of set xs;
   - this bijection is given by insort_wrt;
   - the pushed-forward measure is uniform because the bijection is injective
     (distr_uniform_measure_count_space_inj). *)
lemma lborel_permutations_of_set_bst_of_list:
  shows "distr (PiM (set xs) (λx. 𝒰 0 1)) (tree_sigma (count_space A))
               (λp. bst_of_list (sort_key p xs)) =
         distr (random_perm xs) (tree_sigma (count_space A)) bst_of_list" (is "?lhs = ?rhs")
proof -
  have [measurable]: "(0::real) < 1"
    by auto
  have "insort_wrt R xs = insort_wrt R (remdups xs)" for R
    using con_assms distinct_remdups_id by metis
  then have *: "insort_wrt R xs = sorted_wrt_list_of_set R (set xs)"
    if "linorder_on (set xs) R" for R
    using that by (subst sorted_wrt_list_set) auto
  have [measurable]: "(λx. x) ∈ count_space (permutations_of_set (set xs)) →⇩M count_space (lists A)"
    using con_assms permutations_of_setD by fastforce
  have [measurable]: "(λR. insort_wrt R xs) ∈
                      count_space (Pow (A × A)) →⇩M count_space (permutations_of_set (set xs))"
    using con_assms by (simp add: permutations_of_setI)
  have "?lhs 
   = distr (PiM (set xs) (λx. 𝒰 0 1)) (tree_sigma (count_space A))
           (λp. bst_of_list (insort_wrt (linorder_from_keys (set xs) p) xs))"
    unfolding Let_def by (simp add: insort_wrt_sort_key)
 also have "… = 
  distr (distr (PiM (set xs) (λx. uniform_measure lborel {0::real..1}))
    (count_space (Pow (A × A))) (linorder_from_keys (set xs)))
  (tree_sigma (count_space A)) (λR. bst_of_list (insort_wrt R xs))"
   unfolding 𝒰_def using con_assms by (subst distr_distr) (measurable, metis comp_apply)
  also have "… = 
  distr (uniform_measure (count_space (Pow (A × A))) (linorders_on (set xs)))
        (tree_sigma (count_space A)) (λR. bst_of_list (insort_wrt R xs))"
    using con_assms by (subst random_linorder_by_prios) auto
  also have "… = distr (distr (uniform_measure (count_space (Pow (A × A))) (linorders_on (set xs)))
                               (count_space (permutations_of_set (set xs))) (λR. insort_wrt R xs))
                        (tree_sigma (count_space A)) bst_of_list"
    by (subst distr_distr) (measurable, metis comp_apply)
  also have "… = distr (uniform_measure (count_space (permutations_of_set (set xs)))
                          ((λR. insort_wrt R xs) ` linorders_on (set xs)))
                    (tree_sigma (count_space A)) bst_of_list"
  proof -
    have "bij_betw (λR. insort_wrt R xs) (linorders_on (set xs)) (permutations_of_set (set xs))"
      by (subst bij_betw_cong, fastforce simp add: * linorders_on_def bij_betw_cong)
         (use bij_betw_linorders_on' in blast)
    then have "inj_on (λR. insort_wrt R xs) (linorders_on (set xs))"
      by (rule bij_betw_imp_inj_on)
    then have "distr (uniform_measure (count_space (Pow (A × A))) (linorders_on (set xs)))
                     (count_space (permutations_of_set (set xs))) (λR. insort_wrt R xs)
               = uniform_measure (count_space (permutations_of_set (set xs)))
                                 ((λR. insort_wrt R xs) ` linorders_on (set xs))"
      using con_assms by (intro distr_uniform_measure_count_space_inj)
        (auto simp add: linorders_on_def linorder_on_def refl_on_def)
    then show ?thesis by auto
  qed
  also have "… = distr (random_perm xs) (tree_sigma (count_space A)) bst_of_list"
  proof -
    have "((λR. insort_wrt R xs) ` linorders_on (set xs)) = permutations_of_set (set xs)"
      by (intro bij_betw_imp_surj_on, subst bij_betw_cong, rule *)
         (fastforce simp add: linorders_on_def,  use bij_betw_linorders_on' in blast)
   then show ?thesis by (simp add: random_perm_def)
 qed
  finally show ?thesis .
qed

(* The target sigma-algebra tree_sigma (count_space A) can be replaced by
   count_space (trees A). Because A is countable, both have the same sets
   (sets_tree_sigma_count_space). *)
lemma distr_bst_of_list_tree_sigma_count_space: "
   distr (random_perm xs) (tree_sigma (count_space A)) bst_of_list =
     distr (random_perm xs) (count_space (trees A)) bst_of_list"
  using con_assms by (intro distr_cong)  (auto intro!: sets_tree_sigma_count_space)

text ‹
  This is the same as a \emph{random BST}.
›
(* Building a BST from a uniformly random permutation of set xs gives the random_bst
   distribution on set xs, restricted to trees A.
   The proof unfolds random_bst via random_bst_altdef and then moves between
   count_space UNIV and count_space (trees A) using distr_restrict and restrict_distr. *)
lemma distr_bst_of_list_random_bst: "
  distr (random_perm xs) (count_space (trees A)) bst_of_list =
    restrict_space (random_bst (set xs)) (trees A)" (is "?lhs = ?rhs")
proof -
  have "?rhs = restrict_space (distr (uniform_measure (count_space UNIV)
                 (permutations_of_set (set xs))) (count_space UNIV) bst_of_list) (trees A)"
    by (auto simp: random_bst_altdef measure_pmf_of_set map_pmf_rep_eq)
  also have "distr (uniform_measure (count_space UNIV) (permutations_of_set (set xs))) 
                   (count_space UNIV) bst_of_list = 
               distr (random_perm xs) (count_space UNIV) bst_of_list"
    by (intro distr_restrict) (auto simp: random_perm_def)
  also have "restrict_space … (trees A) =
               distr (random_perm xs) (count_space (trees A)) bst_of_list"
    using con_assms
    by (subst restrict_distr)
       (auto simp: random_perm_def bst_of_list_trees restrict_count_space permutations_of_setD)
  finally show ?thesis ..
qed

text ‹
  We put everything together and obtain our main result:
›
(* Main theorem. Insert xs into an empty treap with independent 𝒰 0 1 priorities and then
   forget the priorities. The resulting tree is distributed as random_bst (set xs),
   restricted to trees A. The proof chains four lemmas:
   rinss_bst_of_list, lborel_permutations_of_set_bst_of_list,
   distr_bst_of_list_tree_sigma_count_space and distr_bst_of_list_random_bst. *)
theorem rinss_random_bst:
  "distr (rinss xs ⟨⟩ A) (tree_sigma (count_space A)) (map_tree fst) =
     restrict_space (measure_pmf (random_bst (set xs))) (trees A)"
  by (simp only: rinss_bst_of_list lborel_permutations_of_set_bst_of_list
                 distr_bst_of_list_tree_sigma_count_space distr_bst_of_list_random_bst)

end
end