(* Author: Tobias Nipkow, based on work by Daniel Somogyi *)

section ‹Optimal BSTs: The `Cubic' Algorithm\label{sec:cubic}›

(* Theory header: a single `theory ... imports ... begin` declaration opens the theory body. *)
theory Optimal_BST
imports Weighted_Path_Length
begin

subsection ‹Function ‹argmin››

text ‹Function ‹argmin› iterates over a list and returns the rightmost element
that minimizes a given function:›

(* argmin f xs: an element of the list xs at which f is minimal.
   The head x is chosen only if f x is strictly smaller than the value at the
   minimizer m of the tail, so ties are resolved in favour of the later element.
   Only the pattern x#xs is covered; no equation is given for the empty list. *)
fun argmin :: "('a ⇒ ('b::linorder)) ⇒ 'a list ⇒ 'a" where
"argmin f (x#xs) =
  (if xs = [] then x else
   let m = argmin f xs in if f x < f m then x else m)"

text ‹An optimized version that avoids repeated computation of ‹f x›:›

(* argmin2 f xs: like argmin, but returns the minimizing element together with its f-value.
   f x is bound once to fx, and the recursive result mfm already carries the value of its
   minimizer, so the comparison reads snd mfm instead of applying f again.
   Only the pattern x#xs is covered; no equation is given for the empty list. *)
fun argmin2 :: "('a ⇒ ('b::linorder)) ⇒ 'a list ⇒ 'a * 'b" where
"argmin2 f (x#xs) =
  (let fx = f x
   in if xs = [] then (x, fx)
      else let mfm = argmin2 f xs
           in if fx < snd mfm then (x,fx) else mfm)"

(* On non-empty lists argmin2 yields exactly the argmin result paired with its f-value.
   Proved by induction on the list, unfolding the let-bindings. *)
lemma argmin2_argmin: "xs ≠ [] ⟹ argmin2 f xs = (argmin f xs, f(argmin f xs))"
by (induction xs) (auto simp: Let_def)

(* Code equation for argmin: it is expressed via the first component of argmin2.
   The empty-list case is stated explicitly with undefined.
   The apply-script is closed with `done`. *)
lemma argmin_argmin2[code]: "argmin f xs = (if xs = [] then undefined else fst(argmin2 f xs))"
apply(auto simp: argmin2_argmin)
apply (meson argmin.elims list.distinct(1))
done

(* Any property P that holds for every element of a non-empty list also holds for
   the element selected by argmin. Proved by induction on the list. *)
lemma argmin_forall: "xs ≠ [] ⟹ (⋀x. x∈set xs ⟹ P x) ⟹ P (argmin f xs)"
by(induction xs) (auto simp: Let_def)

(* The result of argmin on a non-empty list is a member of that list.
   Obtained from argmin_forall with P instantiated to membership in xs. *)
lemma argmin_in: "xs ≠ [] ⟹ argmin f xs ∈ set xs"
using argmin_forall[of xs "λx. x∈set xs"] by blast

(* The f-value at the argmin result equals the minimum of f over the elements of the list. *)
lemma argmin_Min: "xs ≠ [] ⟹ f (argmin f xs) = Min (f ` set xs)"
by(induction xs) (auto simp: min_def intro!: antisym)

(* Pairing the argmin result with its f-value gives the same pair as first pairing every
   element with its f-value and then minimizing over the second component.
   Proved along the recursion scheme of argmin. *)
lemma argmin_pairs: "xs ≠ [] ⟹
  (argmin f xs,f (argmin f xs)) = argmin snd (map (λx. (x,f x)) xs)"
by (induction f xs rule:argmin.induct) (auto, smt snd_conv)

(* argmin commutes with map: minimizing c over the mapped list is f applied to the
   minimizer of the composed function c o f over the original list. *)
lemma argmin_map: "xs ≠ [] ⟹ argmin c (map f xs) = f(argmin (c o f) xs)"
by(induction xs) (simp_all add: Let_def)

subsection ‹The `Cubic' Algorithm›

text ‹We hide the details of the access frequencies ‹a› and ‹b› by working with an abstract
version of function ‹w› defined above (summing ‹a› and ‹b›). Later we interpret ‹w› accordingly.›

(* Locale fixing an abstract weight function w on pairs of integer indices.
   `begin` opens the locale context in which the following definitions and lemmas live;
   it is closed by the `end (* locale Optimal_BST *)` further down. *)
locale Optimal_BST =
fixes w :: "int ⇒ int ⇒ nat"
begin

subsubsection ‹Functions ‹wpl› and ‹min_wpl››

(* Make locale wpl available here, instantiating its parameter w with the w of this locale.
   The terminal `.` closes the resulting proof obligation immediately. *)
sublocale wpl where w = w .

text ‹Function ‹min_wpl i j› computes the minimal weighted path length of any tree ‹t›
where @{prop"inorder t = [i..j]"}. It simply tries all possible indices between ‹i› and ‹j›
as the root. Thus it implicitly constructs all possible trees.›

(* Register conj_cong as a congruence rule for the function package before the definition.
   NOTE(review): presumably needed so the generated termination conditions carry enough
   context about the recursive calls — confirm. *)
declare conj_cong [fundef_cong]
(* min_wpl i j: 0 when the index range is empty (i > j); otherwise the minimum, over every
   root candidate k in {i..j}, of the values for the ranges left (i..k-1) and right (k+1..j)
   of k, plus w i j added once outside the Min. *)
function min_wpl :: "int ⇒ int ⇒ nat" where
"min_wpl i j =
  (if i > j then 0
   else Min ((λk. min_wpl i (k-1) + min_wpl (k+1) j) ` {i..j}) + w i j)"
by auto
(* Termination measure: the size of the index range, nat(j-i+1). *)
termination by (relation "measure (λ(i,j). nat(j-i+1))") auto
(* Remove the unconditional defining equation from the simpset. *)
declare min_wpl.simps[simp del]

text ‹Note that for efficiency reasons we have pulled ‹+ w i j› out of ‹Min›.
In the lemma below this is reversed because it simplifies the proofs.
Similar optimizations are possible in other functions below.›

(* Conditional simp rules for min_wpl, one per branch of its definition.
   In the non-empty case w i j is moved inside the Min (using Min_add_commute),
   so each candidate root k contributes its complete cost. *)
lemma min_wpl_simps[simp]:
  "i > j ⟹ min_wpl i j = 0"
  "i ≤ j ⟹ min_wpl i j =
     Min ((λk. min_wpl i (k-1) + min_wpl (k+1) j + w i j) ` {i..j})"
by(auto simp add: min_wpl.simps[of i j] Min_add_commute)

(* An integer interval list [i..k] can be split at any j with i ≤ j ≤ k into the part
   strictly below j and the part from j onwards. Proved by induction on j starting at i;
   the two cases are separated by `next` and the proof is closed by `qed`. *)
lemma upto_split1:
  "⟦ i ≤ j;  j ≤ k ⟧ ⟹ [i..k] = [i..j-1] @ [j..k]"
proof (induction j rule: int_ge_induct)
  case base thus ?case by (simp add: upto_rec1)
next
  case step thus ?case using upto_rec1 upto_rec2 by simp
qed

text‹Function @{const min_wpl} returns a lower bound for all possible BSTs:›

(* min_wpl i j is a lower bound on the weighted path length of every tree whose inorder
   traversal is [i..j]. Proved along the recursion of wpl: the leaf case is direct; for a
   node ⟨l,k,r⟩ the root k is one of the candidates minimized over by min_wpl, and the
   induction hypotheses bound the two subtrees. *)
theorem min_wpl_is_optimal:
  "inorder t = [i..j] ⟹ min_wpl i j ≤ wpl i j t"
proof(induction i j t rule: wpl.induct)
  case 1
  thus ?case by(simp add: upto.simps split: if_splits)
next
  case (2 i j l k r)
  then show ?case
  proof cases
    assume "i > j" thus ?thesis by(simp)
  next
    assume [arith]: "¬ i > j"

    (* Facts obtained by splitting the inorder list at the root k. *)
    note inorder = inorder_upto_split[OF "2.prems"]
    (* ?M: the set of candidate costs; ?w: the candidate cost for the actual root k. *)
    let ?M = "(λk. min_wpl i (k-1) + min_wpl (k+1) j + w i j) ` {i..j}"
    let ?w = "min_wpl i (k-1) + min_wpl (k+1) j + w i j"
    have aux_min:"Min ?M ≤ ?w"
    proof (rule Min_le)
      show "finite ?M" by simp
      show "?w ∈ ?M" using inorder(3,4) by simp
    qed

    have "min_wpl i j = Min ?M" by(simp)
    also have "... ≤ ?w" by (rule aux_min)
    also have "... ≤ wpl i (k-1) l + wpl (k+1) j r + w i j"
      using inorder(1,2) "2.IH" by simp
    also have "... = wpl i j ⟨l,k,r⟩" by simp
    finally show ?thesis .
  qed
qed

text ‹Now we show that the lower bound computed by @{const min_wpl}
is the wpl of an optimal tree that can be computed in the same manner.›

subsubsection ‹Function ‹opt_bst››

text‹This is the functional equivalent of the standard cubic imperative algorithm.
Unless it is memoized, the complexity is again exponential.
The pattern of recursion is the same as for @{const min_wpl} but instead of the minimal weight
it computes a tree with the minimal weight:›

(* opt_bst i j: Leaf when the index range is empty (i > j); otherwise, for every k in [i..j],
   build the tree with root k whose subtrees are opt_bst of the ranges left and right of k,
   and select among these candidates with argmin (wpl i j). *)
function opt_bst :: "int ⇒ int ⇒ int tree" where
"opt_bst i j =
  (if i > j then Leaf
   else argmin (wpl i j) [⟨opt_bst i (k-1), k, opt_bst (k+1) j⟩. k ← [i..j]])"
by auto
(* Termination measure: the size of the index range, nat(j-i+1). *)
termination by (relation "measure (λ(i,j) . nat(j-i+1))") auto
(* Remove the unconditional defining equation from the simpset. *)
declare opt_bst.simps[simp del]

(* Conditional simp rules for opt_bst, one per branch of its definition. *)
corollary opt_bst_simps[simp]:
  "i > j ⟹ opt_bst i j = Leaf"
  "i ≤ j ⟹ opt_bst i j =
     (argmin (wpl i j) [⟨opt_bst i (k-1), k, opt_bst (k+1) j⟩. k ← [i..j]])"
by(auto simp add: opt_bst.simps[of i j])

text ‹As promised, @{const opt_bst} computes a tree with the minimal wpl:›

(* The weighted path length of the tree built by opt_bst equals min_wpl.
   Proved along the recursion of min_wpl: in the non-empty case the wpl of the argmin
   candidate is the Min over all candidate trees (argmin_Min), and by the induction
   hypotheses that set of values coincides with the set minimized over by min_wpl. *)
theorem wpl_opt_bst: "wpl i j (opt_bst i j) = min_wpl i j"
proof(induction i j rule: min_wpl.induct)
  case (1 i j)
  show ?case
  proof cases
    assume "i > j"
    thus ?thesis by(simp)
  next
    assume [arith]: "¬ i > j"
    (* ?ts: the candidate trees, one per root k; ?M: the candidate costs used by min_wpl. *)
    let ?ts = "[⟨opt_bst i (k-1), k, opt_bst (k+1) j⟩. k ← [i..j]]"
    let ?M = "((λk. min_wpl i (k-1) + min_wpl (k+1) j + w i j) ` {i..j})"
    have 1: "?ts ≠ []" by (auto simp add: upto.simps)
    have "wpl i j (opt_bst i j) = wpl i j (argmin (wpl i j) ?ts)" by simp
    also have "… = Min (wpl i j ` (set ?ts))"
      by(rule argmin_Min[OF 1])
    also have "… = Min ?M"
    proof (rule arg_cong[where f=Min])
      show "wpl i j ` (set ?ts) = ?M" using "1.IH"
        by (force simp: Bex_def image_iff "1.IH")
    qed
    also have "… = min_wpl i j" by simp
    finally show ?thesis .
  qed
qed

(* The tree built by opt_bst has a weighted path length no larger than that of any tree
   with inorder traversal [i..j]; combines min_wpl_is_optimal and wpl_opt_bst. *)
corollary opt_bst_is_optimal:
  "inorder t = [i..j] ⟹ wpl i j (opt_bst i j) ≤ wpl i j t"
by (simp add: min_wpl_is_optimal wpl_opt_bst)

subsubsection ‹Function ‹opt_bst_wpl››

text ‹Function @{const opt_bst} is simplistic because it computes the wpl
of each tree anew rather than returning it with the tree. That is what ‹opt_bst_wpl› does:›

(* opt_bst_wpl i j: returns a tree together with a cost. For an empty range (i > j) this is
   (Leaf, 0); otherwise, for every k in [i..j], the results (t1,c1) and (t2,c2) for the
   ranges left and right of k are combined into (⟨t1,k,t2⟩, c1 + c2 + w i j), and the
   candidate with the smallest second component is selected by argmin snd. *)
function opt_bst_wpl :: "int ⇒ int ⇒ int tree × nat" where
"opt_bst_wpl i j =
  (if i > j then (Leaf, 0)
   else argmin snd [let (t1,c1) = opt_bst_wpl i (k-1);
                        (t2,c2) = opt_bst_wpl (k+1) j
                     in (⟨t1,k,t2⟩, c1 + c2 + w i j). k ← [i..j]])"
by auto
(* Termination measure: the size of the index range, nat(j-i+1). *)
termination
  by (relation "measure (λ(i,j). nat(j-i+1))")(auto)
(* Remove the unconditional defining equation from the simpset. *)
declare opt_bst_wpl.simps[simp del]

text‹Function @{const opt_bst_wpl} returns an optimal tree and its wpl:›

(* opt_bst_wpl returns exactly the tree of opt_bst paired with that tree's wpl.
   Proved along the recursion of opt_bst_wpl, with a case split on whether the range is
   empty; the non-empty case uses argmin_pairs and the induction hypotheses. *)
lemma opt_bst_wpl_eq_pair:
  "opt_bst_wpl i j = (opt_bst i j, wpl i j (opt_bst i j))"
proof(induction i j rule: opt_bst_wpl.induct)
  case (1 i j)
  (* Unfold the defining equation only at the current arguments i and j. *)
  note [simp] = opt_bst_wpl.simps[of i j]
  show ?case
  proof cases
    assume "i > j" thus ?thesis using "1.prems" by auto
  next
    assume "¬ i > j"
    thus ?thesis by (simp add: argmin_pairs comp_def "1.IH" cong: list.map_cong_simp)
  qed
qed

(* Variant of opt_bst_wpl_eq_pair in which the second component is stated as min_wpl i j,
   obtained by rewriting with wpl_opt_bst. *)
corollary opt_bst_wpl_eq_pair': "opt_bst_wpl i j = (opt_bst i j, min_wpl i j)"
by (simp add: opt_bst_wpl_eq_pair wpl_opt_bst)

end (* locale Optimal_BST *)