# Theory Optimal_BST

theory Optimal_BST
imports Weighted_Path_Length
(* Author: Tobias Nipkow, based on work by Daniel Somogyi *)

section ‹Optimal BSTs: The `Cubic' Algorithm\label{sec:cubic}›

theory Optimal_BST
imports Weighted_Path_Length
begin

subsection ‹Function ‹argmin››

text ‹Function ‹argmin› iterates over a list and returns the rightmost element
that minimizes a given function:›

(* argmin f xs: an element of the non-empty list xs whose f-value is minimal.
   Only the pattern x#xs is covered; there is no defining equation for [].
   The head x replaces the recursive result m only if f x is strictly
   smaller, so among equally small elements the later one is returned. *)
fun argmin :: "('a ⇒ ('b::linorder)) ⇒ 'a list ⇒ 'a" where
"argmin f (x#xs) =
(if xs = [] then x else
let m = argmin f xs in if f x < f m then x else m)"

text ‹An optimized version that avoids repeated computation of ‹f x›:›

(* argmin2 f xs: same recursion scheme as argmin, but the result is the chosen
   element paired with its f-value. f x is bound once as fx, and the value of
   the recursive result is read back from the pair (snd mfm) instead of being
   recomputed. As for argmin, there is no defining equation for []. *)
fun argmin2 :: "('a ⇒ ('b::linorder)) ⇒ 'a list ⇒ 'a * 'b" where
"argmin2 f (x#xs) =
(let fx = f x
in if xs = [] then (x, fx)
else let mfm = argmin2 f xs
in if fx < snd mfm then (x,fx) else mfm)"

(* On non-empty lists argmin2 returns exactly the argmin element together with
   its f-value. Proof: list induction, unfolding let-bindings via Let_def. *)
lemma argmin2_argmin: "xs ≠ [] ⟹ argmin2 f xs = (argmin f xs, f(argmin f xs))"
by (induction xs) (auto simp: Let_def)

(* Code equation: generated code evaluates argmin through argmin2.
   The empty list is mapped to undefined explicitly, since argmin has no
   defining equation for []. *)
lemma argmin_argmin2[code]: "argmin f xs = (if xs = [] then undefined else fst(argmin2 f xs))"
apply(auto simp: argmin2_argmin)
apply (meson argmin.elims list.distinct(1))
done

(* Any property that holds for every element of a non-empty list also holds
   for the element selected by argmin. *)
lemma argmin_forall: "xs ≠ [] ⟹ (⋀x. x∈set xs ⟹ P x) ⟹ P (argmin f xs)"
by(induction xs) (auto simp: Let_def)

(* The argmin of a non-empty list is one of its elements; obtained by
   instantiating argmin_forall with the membership predicate. *)
lemma argmin_in: "xs ≠ [] ⟹ argmin f xs ∈ set xs"
using argmin_forall[of xs "λx. x∈set xs"] by blast

(* The f-value of the argmin equals the minimum of the f-values of all list
   elements. The right-hand side takes Min over the image of set xs under f;
   the image operator had been lost from the statement and is restored here. *)
lemma argmin_Min: "xs ≠ [] ⟹ f (argmin f xs) = Min (f ` set xs)"
by(induction xs) (auto simp: min_def intro!: antisym)

(* Pairing every element with its f-value and then minimising over the second
   component yields the argmin paired with its f-value.
   Proof: induction along the recursion of argmin. *)
lemma argmin_pairs: "xs ≠ [] ⟹
(argmin f xs,f (argmin f xs)) = argmin snd (map (λx. (x,f x)) xs)"
by (induction f xs rule:argmin.induct) (auto, smt snd_conv)

(* argmin commutes with map: minimising c over the mapped list gives the image
   under f of the element that minimises c o f over the original list.
   Proof: list induction, unfolding the let-binding in argmin via Let_def. *)
lemma argmin_map: "xs ≠ [] ⟹ argmin c (map f xs) = f(argmin (c o f) xs)"
by(induction xs) (simp_all add: Let_def)

subsection ‹The `Cubic' Algorithm›

text ‹We hide the details of the access frequencies ‹a› and ‹b› by working with an abstract
version of function ‹w› defined above (summing ‹a› and ‹b›). Later we interpret ‹w› accordingly.›

(* Abstract setting for the algorithm: w is an arbitrary function from two
   integer bounds to a natural-number weight. The locale fixes w only and
   makes no assumptions about it. *)
locale Optimal_BST =
fixes w :: "int ⇒ int ⇒ nat"
begin

subsubsection ‹Functions ‹wpl› and ‹min_wpl››

(* Interpret locale wpl (declared outside this view, presumably in
   Weighted_Path_Length) with this locale's w, so that its constants and
   lemmas are available here. *)
sublocale wpl where w = w .

text ‹Function ‹min_wpl i j› computes the minimal weighted path length of any tree ‹t›
where @{prop"inorder t = [i..j]"}. It simply tries all possible indices between ‹i› and ‹j›
as the root. Thus it implicitly constructs all possible trees.›

(* min_wpl i j:
     - 0 if the range is empty (i > j);
     - otherwise the minimum, over every candidate root k in {i..j}, of the
       sum of the results for the two sub-ranges [i..k-1] and [k+1..j],
       plus w i j.
   The minimum is taken over the image of {i..j} under the cost function;
   the image operator had been lost from the definition and is restored.
   conj_cong is registered as a congruence rule for the function package
   before the definition. Termination: the measure nat(j-i+1) decreases in
   both recursive calls. The defining equation is removed from the simpset;
   the guarded rules min_wpl_simps are used instead. *)
declare conj_cong [fundef_cong]
function min_wpl :: "int ⇒ int ⇒ nat" where
"min_wpl i j =
(if i > j then 0
else Min ((λk. min_wpl i (k-1) + min_wpl (k+1) j) ` {i..j}) + w i j)"
by auto
termination by (relation "measure (λ(i,j). nat(j-i+1))") auto
declare min_wpl.simps[simp del]

text ‹Note that for efficiency reasons we have pulled ‹+ w i j› out of ‹Min›.
In the lemma below this is reversed because it simplifies the proofs.
Similar optimizations are possible in other functions below.›

(* Guarded simp rules for min_wpl, one per branch of its definition.
   In the second rule w i j is moved inside the Min (over the image of
   {i..j}); Min_add_commute justifies pulling the constant summand in or out.
   The statement had lost its image operator and its proof; both are
   restored here. *)
lemma min_wpl_simps[simp]:
"i > j ⟹ min_wpl i j = 0"
"i ≤ j ⟹ min_wpl i j =
Min ((λk. min_wpl i (k-1) + min_wpl (k+1) j + w i j) ` {i..j})"
by(auto simp add: min_wpl.simps[of i j] Min_add_commute)

(* Splits an integer interval list at a point j with i ≤ j ≤ k:
   [i..k] is [i..j-1] followed by [j..k].
   Proof: induction on j upwards from i (int_ge_induct). *)
lemma upto_split1:
"⟦ i ≤ j;  j ≤ k ⟧ ⟹ [i..k] = [i..j-1] @ [j..k]"
proof (induction j rule: int_ge_induct)
case base thus ?case by (simp add: upto_rec1)
next
case step thus ?case using upto_rec1 upto_rec2 by simp
qed

text‹Function @{const min_wpl} returns a lower bound for all possible BSTs:›

(* Lower bound: every tree t whose inorder traversal is [i..j] has weighted
   path length at least min_wpl i j.
   Proof: induction along the recursion of wpl. In the Node case with root k,
   the cost of choosing k as root is one member of the set that min_wpl
   minimises over, and the induction hypotheses bound the two subtrees.
   The set ?M is the image of {i..j} under the cost function; the image
   operator had been lost and is restored here. *)
theorem min_wpl_is_optimal:
"inorder t = [i..j] ⟹ min_wpl i j ≤ wpl i j t"
proof(induction i j t rule: wpl.induct)
case 1
thus ?case by(simp add: upto.simps split: if_splits)
next
case (2 i j l k r)
then show ?case
proof cases
assume "i > j" thus ?thesis by(simp)
next
assume [arith]: "¬ i > j"

(* facts obtained from splitting the inorder list at the root k *)
note inorder = inorder_upto_split[OF "2.prems"]

(* ?M: costs for all candidate roots in {i..j}; ?w: cost for the root k *)
let ?M = "(λk. min_wpl i (k-1) + min_wpl (k+1) j + w i j) ` {i..j}"
let ?w = "min_wpl i (k-1) + min_wpl (k+1) j + w i j"

have aux_min:"Min ?M ≤ ?w"
proof (rule Min_le)
show "finite ?M" by simp
show "?w ∈ ?M" using inorder(3,4) by simp
qed

have "min_wpl i j = Min ?M" by(simp)
also have "... ≤ ?w" by (rule aux_min)
also have "... ≤ wpl i (k-1) l + wpl (k+1) j r + w i j"
using inorder(1,2) "2.IH" by simp
also have "... = wpl i j ⟨l,k,r⟩" by simp
finally show ?thesis .
qed
qed

text ‹Now we show that the lower bound computed by @{const min_wpl}
is the wpl of an optimal tree that can be computed in the same manner.›

subsubsection ‹Function ‹opt_bst››

text‹This is the functional equivalent of the standard cubic imperative algorithm.
Unless it is memoized, the complexity is again exponential.
The pattern of recursion is the same as for @{const min_wpl} but instead of the minimal weight
it computes a tree with the minimal weight:›

(* opt_bst i j: Leaf for the empty range (i > j); otherwise, for every k in
   [i..j], build the tree with root k and recursively computed subtrees for
   [i..k-1] and [k+1..j], and select among these candidates with
   argmin (wpl i j).
   Termination: the measure nat(j-i+1) decreases in both recursive calls.
   The defining equation is removed from the simpset; the guarded rules
   opt_bst_simps are used instead. *)
function opt_bst :: "int ⇒ int ⇒ int tree" where
"opt_bst i j =
(if i > j then Leaf
else argmin (wpl i j) [⟨opt_bst i (k-1), k, opt_bst (k+1) j⟩. k ← [i..j]])"
by auto
termination by (relation "measure (λ(i,j) . nat(j-i+1))") auto
declare opt_bst.simps[simp del]

(* Guarded simp rules for opt_bst, one per branch of its definition,
   derived by unfolding the defining equation at i j. *)
corollary opt_bst_simps[simp]:
"i > j ⟹ opt_bst i j = Leaf"
"i ≤ j ⟹ opt_bst i j =
(argmin (wpl i j) [⟨opt_bst i (k-1), k, opt_bst (k+1) j⟩. k ← [i..j]])"
by(auto simp add: opt_bst.simps[of i j])

text ‹As promised, @{const opt_bst} computes a tree with the minimal wpl:›

(* The tree computed by opt_bst has weighted path length exactly min_wpl i j.
   Proof: induction along the recursion of min_wpl. In the non-empty case,
   argmin_Min turns the wpl of the selected candidate into the Min over the
   wpl values of all candidates, and the induction hypotheses identify that
   set with the set ?M minimised by min_wpl.
   Both sets are images (of set ?ts under wpl i j, and of {i..j} under the
   cost function); the image operators had been lost and are restored. *)
theorem wpl_opt_bst: "wpl i j (opt_bst i j) = min_wpl i j"
proof(induction i j rule: min_wpl.induct)
case (1 i j)
show ?case
proof cases
assume "i > j"
thus ?thesis by(simp)
next
assume [arith]: "¬ i > j"
(* ?ts: candidate trees, one per root k; ?M: the costs min_wpl minimises *)
let ?ts = "[⟨opt_bst i (k-1), k, opt_bst (k+1) j⟩. k ← [i..j]]"
let ?M = "((λk. min_wpl i (k-1) + min_wpl (k+1) j + w i j) ` {i..j})"
have 1: "?ts ≠ []" by (auto simp add: upto.simps)
have "wpl i j (opt_bst i j) = wpl i j (argmin (wpl i j) ?ts)" by simp
also have "… = Min (wpl i j ` (set ?ts))"
by(rule argmin_Min[OF 1])
also have "… = Min ?M"
proof (rule arg_cong[where f=Min])
show "wpl i j ` (set ?ts) = ?M" using "1.IH"
by (force simp: Bex_def image_iff "1.IH")
qed
also have "… = min_wpl i j" by simp
finally show ?thesis .
qed
qed

(* Optimality of opt_bst: its result has weighted path length no larger than
   that of any tree with the same inorder traversal [i..j].
   Follows by rewriting with wpl_opt_bst and applying min_wpl_is_optimal. *)
corollary opt_bst_is_optimal:
"inorder t = [i..j] ⟹ wpl i j (opt_bst i j) ≤ wpl i j t"
by (simp add: min_wpl_is_optimal wpl_opt_bst)

subsubsection ‹Function ‹opt_bst_wpl››

text ‹Function @{const opt_bst} is simplistic because it computes the wpl
of each tree anew rather than returning it with the tree. That is what ‹opt_bst_wpl› does:›

(* opt_bst_wpl i j: returns a pair of a tree and a natural number.
   For the empty range (i > j) the result is (Leaf, 0). Otherwise, for every
   k in [i..j], the two recursive results (t1,c1) and (t2,c2) are combined
   into the candidate (tree with root k, c1 + c2 + w i j), and the candidate
   with the smallest second component is selected via argmin snd.
   Termination: the measure nat(j-i+1) decreases in both recursive calls.
   The defining equation is removed from the simpset. *)
function opt_bst_wpl :: "int ⇒ int ⇒ int tree × nat" where
"opt_bst_wpl i j =
(if i > j then (Leaf, 0)
else argmin snd [let (t1,c1) = opt_bst_wpl i (k-1);
(t2,c2) = opt_bst_wpl (k+1) j
in (⟨t1,k,t2⟩, c1 + c2 + w i j). k ← [i..j]])"
by auto
termination
by (relation "measure (λ(i,j). nat(j-i+1))")(auto)
declare opt_bst_wpl.simps[simp del]

text‹Function @{const opt_bst_wpl} returns an optimal tree and its wpl:›

(* opt_bst_wpl returns the tree computed by opt_bst paired with that tree's
   weighted path length.
   Proof: induction along the recursion of opt_bst_wpl; the defining equation
   at i j is added to the simpset locally. The non-empty case relies on
   argmin_pairs together with the induction hypotheses. *)
lemma opt_bst_wpl_eq_pair:
"opt_bst_wpl i j = (opt_bst i j, wpl i j (opt_bst i j))"
proof(induction i j rule: opt_bst_wpl.induct)
case (1 i j)
note [simp] = opt_bst_wpl.simps[of i j]
show ?case
proof cases
assume "i > j" thus ?thesis using "1.prems" by auto
next
assume "¬ i > j"
thus ?thesis by (simp add: argmin_pairs comp_def "1.IH" cong: list.map_cong_simp)
qed
qed

(* Variant of opt_bst_wpl_eq_pair with the second component expressed as
   min_wpl i j; obtained by rewriting with wpl_opt_bst. *)
corollary opt_bst_wpl_eq_pair': "opt_bst_wpl i j = (opt_bst i j, min_wpl i j)"
by (simp add: opt_bst_wpl_eq_pair wpl_opt_bst)
`