Theory OptBST
subsection ‹Optimal Binary Search Trees›
text ‹
This section presents a simple but non-optimal algorithm
(cubic instead of quadratic in the number of keys).
It has since been superseded by the AFP entry ‹Optimal_BST›.
It is kept here as an easier-to-understand example and for archival purposes.
›
theory OptBST
imports
"HOL-Library.Tree"
"HOL-Library.Code_Target_Numeral"
"../state_monad/State_Main"
"../heap_monad/Heap_Default"
Example_Misc
"HOL-Library.Product_Lexorder"
"HOL-Library.RBT_Mapping"
begin
subsubsection ‹Function ‹argmin››
text ‹Function ‹argmin› iterates over a list and returns the rightmost element
that minimizes a given function:›
(* argmin f xs returns the rightmost element of xs at which f is minimal.
   Ties are resolved by the strict comparison f x < f m: when f x = f m,
   the element m chosen from the tail wins, so later elements are preferred.
   There is no equation for [], so argmin f [] is unspecified; the lemmas
   below assume xs ≠ []. *)
fun argmin :: "('a ⇒ ('b::linorder)) ⇒ 'a list ⇒ 'a" where
"argmin f (x#xs) =
(if xs = [] then x else
let m = argmin f xs in if f x < f m then x else m)"
text ‹Note that @{term arg_min_list} is similar but returns the leftmost element.›
(* argmin returns an element of the (non-empty) list. Hence any property
   shared by all list elements also holds for the argmin. *)
lemma argmin_forall: "xs ≠ [] ⟹ (⋀x. x∈set xs ⟹ P x) ⟹ P (argmin f xs)"
by(induction xs) (auto simp: Let_def)
(* The value of f at the argmin is the minimum of f over the list elements. *)
lemma argmin_Min: "xs ≠ [] ⟹ f (argmin f xs) = Min (f ` set xs)"
by(induction xs) (auto simp: min_def intro!: antisym)
subsubsection ‹Misc›
(* Reassembling an integer range around an interior point j (i ≤ j ≤ k):
   the range below j, then j, then the range above j gives [i..k]. *)
lemma upto_join: "⟦ i ≤ j; j ≤ k ⟧ ⟹ [i..j-1] @ j # [j+1..k] = [i..k]"
using upto_rec1 upto_split1 by auto
(* Splitting a closed integer interval {i..j} at k, where i ≤ k ≤ j. *)
lemma atLeastAtMost_split:
"{i..j} = {i..k} ∪ {k+1..j}" if "i ≤ k" "k ≤ j" for i j k :: int
using that by auto
(* Separating the upper endpoint k from a non-empty closed integer
   interval {i..k}. *)
lemma atLeastAtMost_split_insert:
"{i..k} = insert k {i..k-1}" if "k ≥ i" for i :: int
using that by auto
subsubsection ‹Definitions›
(* This context, up to its matching `end` after opt_bst_is_optimal, is
   parameterised by a weight function W on key intervals. *)
context
fixes W :: "int ⇒ int ⇒ nat"
begin
(* wpl i j t: weighted path length of a tree t whose keys are meant to be
   the interval i..j. A node with root key k contributes W i j. The root
   splits the interval into i..k-1 for the left subtree and k+1..j for the
   right subtree. wpl itself does not check that the keys of t really are
   i..j; the correctness lemmas below assume inorder t = [i..j]. *)
fun wpl :: "int ⇒ int ⇒ int tree ⇒ nat" where
"wpl i j Leaf = 0"
| "wpl i j (Node l k r) = wpl i (k-1) l + wpl (k+1) j r + W i j"
(* min_wpl i j: the least weighted path length achievable for the key
   interval i..j.
   - For the empty interval (i > j) it is 0.
   - Otherwise it tries every root k ∈ [i..j] and combines the optimal values
     of the two sub-intervals with the weight W i j.
   Evaluated literally, the recursion recomputes the same sub-intervals many
   times; memoised variants are derived below. *)
function min_wpl :: "int ⇒ int ⇒ nat" where
"min_wpl i j =
(if i > j then 0
else min_list (map (λk. min_wpl i (k-1) + min_wpl (k+1) j + W i j) [i..j]))"
by auto
(* Termination: the interval length j - i + 1 strictly decreases. *)
termination by (relation "measure (λ(i,j) . nat(j-i+1))") auto
(* The defining equation unfolds into itself, so it is removed from the
   simpset and only used explicitly via (simp add: min_wpl.simps). *)
declare min_wpl.simps[simp del]
(* opt_bst i j: an optimal search tree for the key interval i..j.
   For each candidate root k ∈ [i..j], it builds the tree with k at the root
   and recursively built subtrees for i..k-1 and k+1..j. It then uses argmin
   to pick a candidate with minimal wpl i j (the rightmost one on ties). *)
function opt_bst :: "int ⇒ int ⇒ int tree" where
"opt_bst i j =
(if i > j then Leaf else argmin (wpl i j) [⟨opt_bst i (k-1), k, opt_bst (k+1) j⟩. k ← [i..j]])"
by auto
(* Termination: the interval length j - i + 1 strictly decreases. *)
termination by (relation "measure (λ(i,j) . nat(j-i+1))") auto
(* Keep the self-unfolding equation out of the simpset. *)
declare opt_bst.simps[simp del]
subsubsection ‹Functional Memoization›
(* Heap-monad memoisation of min_wpl, parameterised by a size n that fixes
   the index bound of the memo table. This context (fixes n) is closed by
   the `end` that follows min_wpl_heap. *)
context
fixes n :: nat
begin
(* Inner context: the memoised function is parameterised by the memo
   array mem. *)
context fixes
mem :: "nat option array"
begin
(* Generate min_wpl⇩T from min_wpl.simps, storing results in mem for keys
   within Bound (0, 0) (int n, int n).
   NOTE(review): the treatment of keys outside this bound (presumably they
   are simply not cached) is determined by dp_consistency_heap_default;
   please confirm. *)
memoize_fun min_wpl⇩T: min_wpl
with_memory dp_consistency_heap_default where bound = "Bound (0, 0) (int n, int n)" and mem="mem"
monadifies (heap) min_wpl.simps
(* Diagnostic only: displays the generated monadic equations. *)
context includes heap_monad_syntax begin
thm min_wpl⇩T'.simps min_wpl⇩T_def
end
(* Prove the memoised function consistent with min_wpl. *)
memoize_correct
by memoize_prover
(* Export the framework's correctness result for an initially empty memory;
   min_wpl_heap below instantiates it. *)
lemmas memoized_empty = min_wpl⇩T.memoized_empty
end
(* Running the memoised function on a freshly allocated memo table. *)
context
includes heap_monad_syntax
notes [simp del] = min_wpl⇩T'.simps
begin
(* min_wpl⇩h i j allocates an empty memo table of size n * n and runs
   min_wpl⇩T' on it for the arguments i j. *)
definition "min_wpl⇩h ≡ λ i j. Heap_Monad.bind (mem_empty (n * n)) (λ mem. min_wpl⇩T' mem i j)"
(* Run on the empty heap, the heap-monad version computes min_wpl. The
   statement has no premise on i and j, so it holds for all arguments,
   not only for those inside the memo bound. *)
lemma min_wpl_heap:
"min_wpl i j = result_of (min_wpl⇩h i j) Heap.empty"
unfolding min_wpl⇩h_def
using memoized_empty[of _ "λ m. λ (a, b). min_wpl⇩T' m a b" "(i, j)", OF min_wpl⇩T.crel]
by (simp add: index_size_defs)
end
(* Closes the context fixing n. *)
end
(* Memoisation in the state monad, using a mapping as memory. *)
context includes state_monad_syntax begin
memoize_fun min_wpl⇩m: min_wpl with_memory dp_consistency_mapping monadifies (state) min_wpl.simps
(* Diagnostic only: displays the generated monadic equations. *)
thm min_wpl⇩m'.simps
memoize_correct
by memoize_prover
print_theorems
(* Declare the memoised characterisation as a code equation, so code
   generated for min_wpl goes through the memoised state-monad version. *)
lemmas [code] = min_wpl⇩m.memoized_correct
(* The same construction for opt_bst. *)
memoize_fun opt_bst⇩m: opt_bst with_memory dp_consistency_mapping monadifies (state) opt_bst.simps
thm opt_bst⇩m'.simps
memoize_correct
by memoize_prover
print_theorems
lemmas [code] = opt_bst⇩m.memoized_correct
end
subsubsection ‹Correctness Proof›
(* Lower bound: min_wpl i j is at most the weighted path length of any tree
   whose in-order traversal is exactly [i..j]. The proof is by induction
   following the recursion of wpl. *)
lemma min_wpl_minimal:
"inorder t = [i..j] ⟹ min_wpl i j ≤ wpl i j t"
proof(induction i j t rule: wpl.induct)
case (1 i j)
then show ?case by (simp add: min_wpl.simps)
next
case (2 i j l k r)
then show ?case
proof cases
assume "i > j" thus ?thesis by(simp add: min_wpl.simps)
next
assume [arith]: "¬ i > j"
(* The root key k is one of the keys i..j. *)
have kk_ij: "k∈set[i..j]" using 2
by (metis set_inorder tree.set_intros(2))
(* ?M: the set of values that min_wpl i j minimises over.
   ?w: the entry of ?M for root k. *)
let ?M = "((λk. min_wpl i (k-1) + min_wpl (k+1) j + W i j) ` {i..j})"
let ?w = "min_wpl i (k-1) + min_wpl (k+1) j + W i j"
have aux_min:"Min ?M ≤ ?w"
proof (rule Min_le)
show "finite ?M" by simp
show "?w ∈ ?M" using kk_ij by auto
qed
(* Split [i..j] at k and match it against the in-order traversal. This
   determines the key ranges of the two subtrees. *)
have"inorder ⟨l,k,r⟩ = inorder l @k#inorder r" by auto
from this have C:"[i..j] = inorder l @ k#inorder r" using 2 by auto
have D: "[i..j] = [i..k-1]@k#[k+1..j]" using kk_ij upto_rec1 upto_split1
by (metis atLeastAtMost_iff set_upto)
have l_inorder: "inorder l = [i..k-1]"
by (smt C D append_Cons_eq_iff atLeastAtMost_iff set_upto)
have r_inorder: "inorder r = [k+1..j]"
by (smt C D append_Cons_eq_iff atLeastAtMost_iff set_upto)
(* Chain: min_wpl i j = Min ?M ≤ ?w ≤ wpl of the tree. The last step uses
   the induction hypotheses for both subtrees. *)
have "min_wpl i j = Min ?M" by (simp add: min_wpl.simps min_list_Min)
also have "... ≤ ?w" by (rule aux_min)
also have "... ≤ wpl i (k-1) l + wpl (k+1) j r + W i j" using l_inorder r_inorder "2.IH" by simp
also have "... = wpl i j ⟨l,k,r⟩" by simp
finally show ?thesis .
qed
qed
(* opt_bst i j is a search tree on exactly the keys i..j: its in-order
   traversal is [i..j]. Every argmin candidate has this property
   (argmin_forall), and upto_join reassembles the range around the root. *)
lemma opt_bst_correct: "inorder (opt_bst i j) = [i..j]"
by (induction i j rule: opt_bst.induct)
(clarsimp simp: opt_bst.simps upto_join | rule argmin_forall)+
(* opt_bst attains the optimum: its weighted path length equals min_wpl.
   The induction follows the recursion of min_wpl. *)
lemma wpl_opt_bst: "wpl i j (opt_bst i j) = min_wpl i j"
proof(induction i j rule: min_wpl.induct)
case (1 i j)
show ?case
proof cases
assume "i > j" thus ?thesis by(simp add: min_wpl.simps opt_bst.simps)
next
assume *[arith]: "¬ i > j"
(* ?ts: the candidate trees that opt_bst chooses from.
   ?M: the candidate values that min_wpl minimises over. *)
let ?ts = "[⟨opt_bst i (k-1), k, opt_bst (k+1) j⟩. k <- [i..j]]"
let ?M = "((λk. min_wpl i (k-1) + min_wpl (k+1) j + W i j) ` {i..j})"
(* ?ts is non-empty because i ≤ j, so argmin_Min applies. *)
have "?ts ≠ []" by (auto simp add: upto.simps)
have "wpl i j (opt_bst i j) = wpl i j (argmin (wpl i j) ?ts)" by (simp add: opt_bst.simps)
also have "… = Min (wpl i j ` (set ?ts))" by (rule argmin_Min[OF ‹?ts ≠ []›])
(* Mapping wpl i j over the candidates gives exactly ?M, by the induction
   hypotheses for the subtrees. *)
also have "… = Min ?M"
proof (rule arg_cong[where f=Min])
show "wpl i j ` (set ?ts) = ?M"
by (fastforce simp: Bex_def image_iff 1[OF *])
qed
also have "… = min_wpl i j" by (simp add: min_wpl.simps min_list_Min)
finally show ?thesis .
qed
qed
(* Main result for a fixed weight function W: opt_bst i j has minimal
   weighted path length among all trees with in-order traversal [i..j]. *)
lemma opt_bst_is_optimal:
"inorder t = [i..j] ⟹ wpl i j (opt_bst i j) ≤ wpl i j t"
by (simp add: min_wpl_minimal wpl_opt_bst)
(* Closes the context fixing W. *)
end
subsubsection ‹Access Frequencies›
text ‹Usually, the problem is phrased in terms of access frequencies.
We now interpret @{term wpl} in this view and show that it computes the right thing,
namely the frequency-weighted cost of a search tree.›
(* This context, up to its matching `end` after opt_bst'_correct, fixes an
   access frequency p for each key. *)
context
fixes p :: "int ⇒ nat"
begin
(* cost t: frequency-weighted cost of t. A node with key k contributes p k
   plus the total frequency of all keys in each of its subtrees. A key is
   therefore charged once at its own node and once at each ancestor.
   weight_correct shows that, for distinct keys, this equals the sum of
   depth x * p x. *)
fun cost :: "int tree ⇒ nat" where
"cost Leaf = 0"
| "cost (Node l k r) = sum p (set_tree l) + cost l + p k + cost r + sum p (set_tree r)"
(* W i j: total frequency of the keys in {i..j}; it is 0 when the interval
   is empty. `qualified` keeps the short name W from being exported
   unqualified. *)
qualified definition W where
"W i j = sum p {i..j}"
(* Recursive characterisation of W: peel off the upper key j. *)
lemma W_rec:
"W i j = (if j ≥ i then W i (j - 1) + p j else 0)"
unfolding W_def by (simp add: atLeastAtMost_split_insert)
(* With W as defined above, wpl coincides with cost on every tree whose
   in-order traversal is [i..j]. *)
lemma inorder_wpl_correct:
"inorder t = [i..j] ⟹ wpl W i j t = cost t"
proof (induction t arbitrary: i j)
case Leaf
show ?case
by simp
next
case (Node l k r)
(* The root key k lies within i..j, ... *)
from ‹inorder ⟨l, k, r⟩ = [i..j]› have *: "i ≤ k" "k ≤ j"
by - (simp, metis atLeastAtMost_iff in_set_conv_decomp set_upto)+
(* ... and the subtrees hold exactly the keys below and above k. *)
moreover from ‹i ≤ k› ‹k ≤ j› have "inorder l = [i..k-1]" "inorder r = [k+1..j]"
using ‹inorder ⟨l, k, r⟩ = [i..j]›[symmetric] by (simp add: upto_split3 append_Cons_eq_iff)+
(* Split {i..j} at k and distribute the sum over the disjoint parts. *)
ultimately show ?case
by (simp add: Node.IH, subst W_def, subst atLeastAtMost_split)
(simp add: sum.union_disjoint atLeastAtMost_split_insert flip: set_inorder)+
qed
text ‹The optimal binary search tree has minimal cost among all binary search trees.›
(* Optimality of opt_bst, restated in terms of the frequency cost. *)
lemma opt_bst_has_optimal_cost:
"inorder t = [i..j] ⟹ cost (opt_bst W i j) ≤ cost t"
using inorder_wpl_correct opt_bst_is_optimal opt_bst_correct by metis
text ‹
The function @{term min_wpl} correctly computes the minimal cost among all binary search trees:
▪ Its result is a lower bound for the cost of every binary search tree over the given keys
▪ Its result is the actual cost of an optimal binary search tree
›
(* min_wpl W i j is a lower bound on the cost of every tree over [i..j]. *)
lemma min_wpl_minimal_cost:
"inorder t = [i..j] ⟹ min_wpl W i j ≤ cost t"
using inorder_wpl_correct min_wpl_minimal by metis
(* The bound is attained: opt_bst W i j has cost exactly min_wpl W i j. *)
lemma min_wpl_tree:
"cost (opt_bst W i j) = min_wpl W i j"
using wpl_opt_bst opt_bst_correct inorder_wpl_correct by metis
paragraph ‹An alternative view of costs.›
(* depth x t: depth of key x in t, counting the root as depth 1. It is ∞ if
   x does not occur in t. If x occurs several times, the smallest depth is
   taken. *)
fun depth :: "'a ⇒ 'a tree ⇒ nat extended" where
"depth x Leaf = ∞"
| "depth x (Node l k r) = (if x = k then 1 else min (depth x l) (depth x r) + 1)"
(* Extract the value of a finite extended number; the result is undefined
   for ∞ and -∞. *)
fun the_fin where
"the_fin (Fin x) = x" | "the_fin _ = undefined"
(* Alternative cost: the sum, over all keys x of t, of depth x t * p x. *)
definition cost' :: "int tree ⇒ nat" where
"cost' t = sum (λx. the_fin (depth x t) * p x) (set_tree t)"
(* On extended numbers, 1 is Fin 1 (one_extended_def), so the_fin 1 = 1. *)
lemma [simp]:
"the_fin 1 = 1"
by (simp add: one_extended_def)
(* Keys that do not occur in t have infinite depth. *)
lemma set_tree_depth:
assumes "x ∉ set_tree t"
shows "depth x t = ∞"
using assms by (induction t) auto
(* Exactly the keys that do not occur in t have infinite depth. The
   remaining Node subgoals are closed by case analysis on the depths of x
   in both subtrees. *)
lemma depth_inf_iff:
"depth x t = ∞ ⟷ x ∉ set_tree t"
apply (induction t)
apply (auto simp: one_extended_def)
subgoal for t1 k t2
by (cases "depth x t1"; cases "depth x t2") auto
subgoal for t1 k t2
by (cases "depth x t1"; cases "depth x t2") auto
subgoal for t1 k t2
by (cases "depth x t1"; cases "depth x t2") auto
subgoal for t1 k t2
by (cases "depth x t1"; cases "depth x t2") auto
done
(* depth never yields -∞ (used as a simp rule). *)
lemma depth_not_neg_inf[simp]:
"depth x t = -∞ ⟷ False"
apply (induction t)
apply (auto simp: one_extended_def)
subgoal for t1 k t2
by (cases "depth x t1"; cases "depth x t2") auto
done
(* Keys that occur in t have a finite depth Fin d. *)
lemma depth_FinD:
assumes "x ∈ set_tree t"
obtains d where "depth x t = Fin d"
using assms by (cases "depth x t") (auto simp: depth_inf_iff)
(* The empty tree has cost' 0. *)
lemma cost'_Leaf[simp]:
"cost' Leaf = 0"
unfolding cost'_def by simp
(* For a tree with distinct keys, cost' satisfies the same recursion as
   cost. Distinctness makes the keys of l, the root x, and the keys of r
   pairwise different. The sum therefore splits over disjoint key sets, and
   the depth of a subtree key is one more than its depth in that subtree. *)
lemma cost'_Node:
"distinct (inorder ⟨l, x, r⟩) ⟹
cost' ⟨l, x, r⟩ = sum p (set_tree l) + cost' l + p x + cost' r + sum p (set_tree r)"
unfolding cost'_def
apply simp
(* Split the sum over the disjoint parts of the key set. *)
apply (subst sum.union_disjoint)
apply (simp; fail)+
(* Keys of l: the depth in the whole tree is the depth in l plus 1. *)
apply (subst sum.cong[OF HOL.refl, where h = "λx. (the_fin (depth x l) + 1) * p x"])
subgoal for k
using set_tree_depth by (force simp: one_extended_def elim: depth_FinD)
(* The same for the keys of r. *)
apply (subst (2) sum.cong[OF HOL.refl, where h = "λx. (the_fin (depth x r) + 1) * p x"])
subgoal
using set_tree_depth by (force simp: one_extended_def elim: depth_FinD)
(* Distribute (d + 1) * p x over the sums and rearrange. *)
apply (simp add: sum.distrib)
done
(* For trees with distinct keys, the two cost functions agree. *)
lemma weight_correct:
"distinct (inorder t) ⟹ cost' t = cost t"
by (induction t; simp add: cost'_Node)
subsubsection ‹Memoizing Weights›
(* W_fun: a recursive, executable formulation of W that can be memoised.
   It peels off the upper key j until the interval is empty. *)
function W_fun where
"W_fun i j = (if i > j then 0 else W_fun i (j - 1) + p j)"
by auto
(* Termination: the interval length j - i + 1 strictly decreases. *)
termination
by (relation "measure (λ(i::int, j::int). nat (j - i + 1))") auto
(* The recursive formulation agrees with the sum-based definition of W. *)
lemma W_fun_correct:
"W_fun i j = W i j"
by (induction rule: W_fun.induct) (simp add: W_def atLeastAtMost_split_insert)
(* Memoise W_fun in the state monad, with a mapping as memory, and prove the
   memoised version consistent with W_fun. *)
memoize_fun W⇩m: W_fun
with_memory dp_consistency_mapping
monadifies (state) W_fun.simps
memoize_correct
by memoize_prover
(* compute_W n runs the memoised W⇩m' i n for every i ∈ [0..n], starting
   from an empty mapping. Only the final memo table is returned (snd of
   run_state); the computed values are discarded.
   NOTE(review): since W_fun recurses on j, the table presumably also holds
   the intermediate entries (i, j') for j' < n; confirm. *)
definition
"compute_W n = snd (run_state (State_Main.map⇩T' (λi. W⇩m' i n) [0..n]) Mapping.empty)"
(* Shorthand `crel` for the consistency relation of W⇩m. W⇩m_crel is W⇩m.crel
   specialised to argument pairs (m, x) and (m, y) that share the same first
   component m, with the resulting pair case split unfolded. *)
notation W⇩m.crel_vs (‹crel›)
lemmas W⇩m_crel = W⇩m.crel[unfolded W⇩m.consistentDP_def, THEN rel_funD,
of "(m, x)" "(m, y)" for m x y, unfolded prod.case]
(* Every entry stored in the table compute_W n is the correct value of W. *)
lemma compute_W_correct:
assumes "Mapping.lookup (compute_W n) (i, j) = Some x"
shows "W i j = x"
proof -
include state_monad_syntax and app_syntax and lifting_syntax
(* ?p: the monadic computation inside compute_W n.
   ?q: its pure counterpart. *)
let ?p = "State_Main.map⇩T' (λi. W⇩m' i n) [0..n]"
let ?q = "map (λi. W i n) [0..n]"
(* Rewrite both sides into the lifted-application form that the
   memoisation prover matches on. *)
have "?q = map $ ⟪(λi. W_fun i n)⟫ $ ⟪[0..n]⟫"
unfolding Wrap_def App_def W_fun_correct ..
have "?p = State_Main.map⇩T . ⟨λi. W⇩m' i n⟩ . ⟨[0..n]⟩"
unfolding State_Monad_Ext.fun_app_lifted_def State_Main.map⇩T_def bind_left_identity ..
(* The monadic map is related to the pure map by W⇩m's consistency
   relation. *)
have "W⇩m.crel_vs (list_all2 (=)) ?q ?p"
unfolding ‹?p = _› ‹?q = _›
apply (subst Transfer.Rel_def[symmetric])
apply memoize_prover_match_step+
apply (subst Rel_def, rule W⇩m_crel, rule HOL.refl)
done
(* Hence the resulting memory satisfies W⇩m.cmem; from this, cmem_elim
   derives that a stored entry equals W_fun on its key. *)
then have "W⇩m.cmem (compute_W n)"
unfolding compute_W_def by (elim W⇩m.crel_vs_elim[OF _ W⇩m.cmem_empty]; simp del: W⇩m'.simps)
with assms show ?thesis
unfolding W_fun_correct[symmetric] by (elim W⇩m.cmem_elim) (simp)+
qed
(* min_wpl' i j: min_wpl with the weights read from the table compute_W j.
   Pairs that are not in the table fall back to computing W directly.
   Within the let, the new binding W shadows the definition W. Because HOL's
   let is not recursive, the fallback `W i j` refers to the definition. *)
definition
"min_wpl' i j ≡
let
M = compute_W j;
W = (λi j. case Mapping.lookup M (i, j) of None ⇒ W i j | Some x ⇒ x)
in min_wpl W i j"
(* The table lookup with fallback always yields W i j (by compute_W_correct). *)
lemma W_compute: "W i j = (case Mapping.lookup (compute_W n) (i, j) of None ⇒ W i j | Some x ⇒ x)"
by (auto dest: compute_W_correct split: option.split)
(* Using the precomputed weight table does not change the result. *)
lemma min_wpl'_correct:
"min_wpl' i j = min_wpl W i j"
using W_compute unfolding min_wpl'_def by simp
(* opt_bst' i j: opt_bst with the weights read from the table compute_W j,
   falling back to computing W directly for missing pairs. As in min_wpl',
   the fallback `W i j` refers to the definition W, since let is not
   recursive. *)
definition
"opt_bst' i j ≡
let
M = compute_W j;
W = (λi j. case Mapping.lookup M (i, j) of None ⇒ W i j | Some x ⇒ x)
in opt_bst W i j"
(* Using the precomputed weight table does not change the tree. *)
lemma opt_bst'_correct:
"opt_bst' i j = opt_bst W i j"
using W_compute unfolding opt_bst'_def by simp
(* Closes the context fixing the frequencies p. *)
end
subsubsection ‹Test Case›
text ‹Functional Implementations›
(* Evaluation checks with weight W i j = nat (i + j) on keys 0..4: the
   optimal weighted path length and the corresponding optimal tree. *)
lemma "min_wpl (λi j. nat(i+j)) 0 4 = 10"
by eval
lemma "opt_bst (λi j. nat(i+j)) 0 4 = ⟨⟨⟨⟨⟨⟨⟩, 0, ⟨⟩⟩, 1, ⟨⟩⟩, 2, ⟨⟩⟩, 3, ⟨⟩⟩, 4, ⟨⟩⟩"
by eval
text ‹Using Frequencies›
(* Turn a list of frequencies into a frequency function on keys. Keys are
   1-based: key i gets xs ! (i - 1), and keys outside 1..length xs get 0. *)
definition
"list_to_p xs (i::int) = (if i - 1 ≥ 0 ∧ nat (i - 1) < length xs then xs ! nat (i - 1) else 0)"
(* Example frequencies for keys 1..5 (to be used via list_to_p). *)
definition
"ex_p_1 = [10, 30, 15, 25, 20]"
(* Expected optimal tree for ex_p_1: root 4, a left subtree rooted at 2 with
   children 1 and 3, and a right child 5. *)
definition
"opt_tree_1 =
⟨
⟨
⟨⟨⟩, 1::int, ⟨⟩⟩,
2,
⟨⟨⟩, 3, ⟨⟩⟩
⟩,
4,
⟨⟨⟩, 5, ⟨⟩⟩
⟩"
(* With frequencies ex_p_1 on keys 1..5, opt_bst' yields opt_tree_1. *)
lemma "opt_bst' (list_to_p ex_p_1) 1 5 = opt_tree_1"
by eval
text ‹Imperative Implementation›
(* Imperative (heap-monad) version, executed via generated ML code.
   code_thms only displays the code equations used for min_wpl. *)
code_thms min_wpl
(* min_wpl_test: min_wpl⇩h with weight W i j = nat (i + j), table size
   parameter n = 4, and keys 0..4.
   NOTE(review): the memo bound is Bound (0, 0) (int 4, int 4) while j = 4.
   If the upper bound is exclusive, entries with index 4 are not cached.
   The result is still correct by min_wpl_heap; confirm whether n = 5 was
   intended. *)
definition "min_wpl_test = min_wpl⇩h (λi j. nat(i+j)) 4 0 4"
code_reflect Test functions min_wpl_test
(* Evaluate the reflected test in ML; the generated function takes a unit
   argument. *)
ML ‹Test.min_wpl_test ()›
end