# Theory Optimal_BST2

(* Author: Tobias Nipkow, based on work by Daniel Somogyi *)

theory Optimal_BST2
imports
Optimal_BST
begin

text ‹Knuth presented an optimization of the previously known cubic dynamic programming algorithm
to a quadratic one. A simplified proof of this optimization was found by Yao~\cite{Yao80}.
Mehlhorn follows Yao closely. The core of the optimization argument is given abstractly
in theory @{theory Optimal_BST.Quadrilateral_Inequality}. In addition we first need to establish some more
properties of @{const argmin}.›

text‹An index-based specification of @{const argmin}
expressing that the last minimal list-element is picked:›

lemma argmin_takes_last: "xs ≠ [] ⟹
argmin f xs = xs ! Max {i. i < length xs ∧ (∀x ∈ set xs. f(xs!i) ≤ f x)}"
(is "_ ⟹ _ = _ ! Max (?M xs)")
proof(induction xs)
case (Cons x xs)
show ?case
proof cases
assume "xs = []" thus ?thesis by(simp cong: conj_cong)
next
assume 0: "xs ≠ []"
show ?thesis
proof cases
assume 1: "∀u ∈ set xs. f x < f u"
hence 2: "?M (x#xs) = {0}"
by (fastforce simp: not_less[symmetric] less_Suc_eq_0_disj)
have "f x < f (argmin f xs)" using 0 1 argmin_Min[of xs f] by auto
with 1 Cons.prems show ?case by(subst 2) (auto simp: Let_def)
next
assume 1: "¬(∀u ∈ set xs. f x < f u)"
have 2: "¬ f x < f (argmin f xs)" using 1 argmin_Min[of xs f] 0 by auto
have "argmin f xs : {u ∈ set xs. ∀x∈set xs. f u ≤ f x}"
using 0 argmin_Min[of xs f] by (simp add: argmin_in)
hence "{u ∈ set xs. ∀x∈set xs. f u ≤ f x} ≠ {}" by blast
hence ne: "?M xs ≠ {}" by(auto simp: in_set_conv_nth)
have "Max (?M (x#xs)) = Max (?M xs) + 1"
proof (cases "∃u ∈ set xs. f u < f x")
case True
hence "?M (x#xs) = (+) 1  ?M xs"
by (auto simp: nth_Cons' image_def less_Suc_eq_0_disj)
thus ?thesis
using mono_Max_commute[of "(+) 1" "?M xs"] ne by (auto simp: mono_def)
next
case False
hence *: "?M (x#xs) = insert 0 ((+) 1  ?M xs)"
using 1 by (auto simp: nth_Cons' image_def less_Suc_eq_0_disj)
hence "Max (?M (x#xs)) = Max ((+) 1  ?M xs)" using Max_insert ne by simp
thus ?thesis using mono_Max_commute[of "(+) 1" "?M xs"] ne by (auto simp: mono_def)
qed
with Cons 2 0 show ?case by auto
qed
qed
qed simp

lemma Min_ex: "⟦ finite F; F ≠ {} ⟧ ⟹ ∃m ∈ F. ∀n ∈ F. m ≤ (n::_::linorder)"
using eq_Min_iff[of F "Min F"] by (fastforce)

text ‹A consequence of @{thm [source] argmin_takes_last}:›

lemma argmin_Max_Args_min_on: assumes [arith]: "i ≤ j"
shows "argmin f [i..j] = Max (Args_min_on f {i..j})"
proof -
let ?min = "λk. ∀n ∈ {i..j}. f([i..j]!k) ≤ f n"
let ?M = "{k. k < nat(j-i+1) ∧ ?min k}"
let ?Max = "Max ?M"
have "?M ≠ {}" using Min_ex[of "f  {i..j}"]
apply(rule_tac x="nat (m-i)" in exI)
by simp
hence "?Max < nat(j-i+1)" by(simp add: nth_upto)
hence 1: "i + int ?Max ≤ j" by linarith
have "argmin f [i..j] = [i..j] ! ?Max"
using argmin_takes_last[of "[i..j]" f] by simp
also have "… = i + int ?Max" using 1 by(simp add: nth_upto)
also have "… = i + Max(int  {k. k < nat(j-i+1) ∧ ?min k})"
using ‹?M ≠ {}› by (simp add: monoI mono_Max_commute)
also have "… = Max ((λx. i + x)  (int  {k. k < nat(j-i+1) ∧ ?min k}))"
using ‹?M ≠ {}› by (simp add: monoI mono_Max_commute)
also have "(λx. i + x)  (int  {k. k < nat(j-i+1) ∧ ?min k}) =
{k. is_arg_min_on f {i..j} k}"
apply(auto simp: is_arg_min_on_def Ball_def nth_upto image_def cong: conj_cong)
apply(rule_tac x = "x-i" in exI)
apply auto
apply(rule_tac x = "nat(x-i)" in exI)
by auto
finally show ?thesis by(simp add: Args_min_simps)
qed

text ‹As a consequence of @{thm [source] argmin_Max_Args_min_on} the following lemma
allows us to justify the restriction of the index range of @{const argmin}
used below in the optimized (quadratic) algorithm.›

lemma argmin_red_ivl:
assumes "i ≤ i'" "argmin f [i..j] ∈ {i'..j'}" "j' ≤ j"
shows "argmin f [i'..j'] = argmin f [i..j]"
proof -
have ij[arith]: "i ≤ j" using assms by simp
have ij'[arith]: "i' ≤ j'" using assms by simp
from Min_ex[of "f  {i..j}"] have m: "∃m ∈ {i..j}. ∀n ∈ {i..j}. f m ≤ f n" by auto
note * = argmin_Max_Args_min_on[OF ij, of f]
note ** = argmin_Max_Args_min_on[OF ij', of f]
let ?M = "Args_min_on f {i..j}"
let ?M' = "Args_min_on f {i'..j'}"
have M: "finite ?M" "?M ≠ {}"
using m by (fastforce simp: Args_min_simps simp del: atLeastAtMost_iff)+
have "Max ?M ∈ ?M" by (simp add: M)
have "Max ?M ∈ ?M'" using Max_in[OF M] assms * by(auto simp: Args_min_simps)
have "?M' ⊆ ?M" using ‹Max ?M ∈ ?M› ‹Max ?M ∈ ?M'› assms(1,3)
have "finite ?M'" using M(1) ‹?M' ⊆ ?M› infinite_super by blast
hence "Max ?M ≤ Max ?M'" by (simp add: ‹Max ?M ∈ ?M'›)
have "Max ?M' ≤ Max ?M" using Max.subset_imp[OF ‹?M' ⊆ ?M› _ M(1)] ‹Max ?M ∈ ?M'› by auto
thus ?thesis using * ** ‹Max ?M ≤ Max ?M'› by force
qed

fun root:: "'a tree ⇒ 'a" where
"root ⟨_,r,_⟩ = r"

text ‹Now we can formulate and verify the improved algorithm. This requires two
assumptions on the weight function ‹w›.›

locale Optimal_BST2 = Optimal_BST +
assumes monotone_w: "⟦i ≤ i'; i' ≤ j; j ≤ j'⟧ ⟹ w i' j ≤ w i j'"
assumes QI_w: "⟦i ≤ i'; i' ≤ j; j ≤ j'⟧⟹ w i j + w i' j'≤ w i' j + w i j'"
begin

text‹When finding an optimal tree for @{term "[i..j]"} the optimization consists in reducing
the search for the root from @{term "[i..j]"} to
@{term "[root (opt_bst2 i (j-1)) .. root (opt_bst2 (i+1) j)]"}:›

function opt_bst2 :: "int ⇒ int ⇒ int tree" where
"opt_bst2 i j =
(if i > j then Leaf else
if i = j then Node Leaf i Leaf else
let left = root (opt_bst2 i (j-1)) in
let right= root (opt_bst2 (i+1) j) in
argmin (wpl i j) [⟨opt_bst2 i (k-1), k, opt_bst2 (k+1) j⟩. k ← [left..right]])"
by auto

text ‹The termination of @{const opt_bst2} is not completely obvious.
We first need to establish some functional properties of the terminating computations.
We start by showing that the root of the returned tree is always between ‹left› and ‹right›.
This is essentially equivalent to proving that ‹left ≤ right›
because otherwise @{const argmin} is applied to ‹[]›, which is undefined.›

lemma left_le_right:
"opt_bst2_dom(i,j) ⟹
(i=j ⟶ root(opt_bst2 i j) = i) ∧
(i<j ⟶ root(opt_bst2 i j) ∈ {root(opt_bst2 i (j-1)) .. root(opt_bst2 (i+1) j)})"
proof (induction rule: opt_bst2.pinduct)
case (1 i j)
let ?left = "root (opt_bst2 i (j-1))"
let ?right = "root (opt_bst2 (i+1) j)"
let ?f ="(λk. ⟨opt_bst2 i (k - 1), k, opt_bst2 (k + 1) j⟩)"
show ?case
proof cases
assume "i > j" thus ?thesis by auto
next
assume [arith]: "¬ i > j"
show ?thesis
proof cases
assume "i = j" thus ?thesis using opt_bst2.psimps[OF "1.hyps"] by simp
next
assume [arith]: "i ≠ j"
have left_le_right: "?left ≤ ?right"
proof cases
assume [arith]: "i = j-1"
have l: "root (opt_bst2 i (j - 1)) = i" using "1.IH"(1) by auto
have r: "root (opt_bst2 (i+1) j) = j" using "1.IH"(2) by auto
show ?thesis using l r by auto
next
assume "¬ i = j-1"
hence[arith]: "i < j-1" by arith
have "?left ≤ root (opt_bst2 (i + 1) (j - 1))" using "1.IH"(1) by auto
also have "... ≤ root (opt_bst2 (i+1) j)" using "1.IH"(2) by auto
finally have "?left ≤ ?right" .
thus ?thesis by auto
qed

let ?lambda = "λt. root t ∈ {?left .. ?right}"
show ?thesis
using argmin_forall[of ‹map ?f [?left..?right]› ‹?lambda› ‹wpl i j›] left_le_right
by (fastforce simp add: opt_bst2.psimps[OF "1.hyps"])
qed
qed
qed

text ‹Now we can bound the result of @{const opt_bst2} easily:›

lemma root_opt_bst2_bound:
"opt_bst2_dom (i,j) ⟹ i ≤ j ⟹ root (opt_bst2 i j) ∈ {i..j}"
proof(induction i j rule:opt_bst2.pinduct)
case (1 i j)
show ?case using  "1.prems" "1.IH"(1,2) left_le_right[OF "1.hyps"] by force
qed

text ‹Now termination follows easily:›

lemma opt_bst2_dom: "∀args. opt_bst2_dom args"
by (relation "measure (λ(i,j). nat (j-i+1))") (auto dest: root_opt_bst2_bound)

termination by(rule opt_bst2_dom)

declare opt_bst2.simps[simp del]

abbreviation "min_wpl3 i j k ≡ min_wpl i (k-1) + min_wpl (k+1) j + w i j"

text‹The correctness proof \cite{Yao} is based on a general theory of quatrilateral inequalities'
developed in locale QI that we now instantiate:›

interpretation QI
where
c = "λi j. min_wpl (i+1) j"
and c_k = "λi j. min_wpl3 (i+1) j"
and w = "λi j. w (i+1) j"
proof (standard, goal_cases)
case (1 i i' j j')
thus ?case using QI_w by simp
next
case (2 i i' j j')
thus ?case using monotone_w by simp
next
case (3 i j)
thus ?case by simp
next
case (4 i j k)
show ?case by simp
qed

lemma K_argmin: "i < j ⟹ K i j = argmin (min_wpl3 (i+1) j) [i+1..j]"

theorem opt_bst2_opt_bst: "opt_bst2 i j = opt_bst i j"
proof (induction i j rule: opt_bst2.induct)
case (1 i j)
show ?case
proof cases
assume "i ≥ j" thus ?thesis by(cases "i=j") (auto simp: opt_bst2.simps)
next
assume [arith]: "¬ i ≥ j"
let ?c = "λk. min_wpl i (k-1) + min_wpl (k+1) j + w i j"
let ?opt = "λk. ⟨opt_bst i (k-1), k, opt_bst (k+1) j⟩"
have 1: "i ≤ K (i-1) (j-1)"
using argmin_in[of "[i..j-1]"] by(auto simp add: K_argmin)
have 2: "argmin ?c [i..j] ∈ {K (i-1) (j-1)..K i j}"
using lemma_3[of "i-1" "j-1"] by(simp add: K_argmin)
have 3: "K i j ≤ j" using argmin_in[of "[i+1..j]"] by(auto simp: K_argmin)
have *: "argmin ?c [K (i-1) (j-1)..K i j] = argmin ?c [i..j]"
by(rule argmin_red_ivl[OF 1 2 3])
have "opt_bst2 i j =
argmin (wpl i j) (map ?opt [root(opt_bst2 i (j-1))..root(opt_bst2 (i+1) j)])"
using [[simp_depth_limit=3]]
by(simp add: "1.IH"(3,4)[OF _ _ refl refl] opt_bst2.simps[of i j] cong: list.map_cong_simp)
also have "… = argmin (wpl i j) (map ?opt [root(opt_bst i (j-1))..root(opt_bst (i+1) j)])"
also have "root(opt_bst i (j-1)) = K (i-1) (j-1)"
by(simp add: argmin_map wpl_opt_bst comp_def K_argmin)
also have "root(opt_bst (i+1) j) = K i j"
by(simp add: argmin_map wpl_opt_bst comp_def K_argmin)
also have "argmin (wpl i j) (map ?opt [K (i - 1) (j - 1)..K i j])
= ?opt (argmin (wpl i j o ?opt) [K (i-1) (j-1)..K i j])"
using lemma_3[of "i-1" "j-1"] by(simp add: argmin_map)
also have "… = ?opt (argmin ?c [K (i-1) (j-1)..K i j])"
also have "… = ?opt(argmin ?c [i..j])"
also have "… = ?opt(argmin (wpl i j o ?opt) [i..j])"
also have "… = argmin (wpl i j) (map ?opt [i..j])"
also have "… = opt_bst i j"
by simp
finally show ?thesis .
qed
qed

corollary opt_bst2_is_optimal: "wpl i j (opt_bst2 i j) = min_wpl i j"

function opt_bst_wpl2 :: "int ⇒ int ⇒ int tree × nat" where
"opt_bst_wpl2 i j =
(if i > j then (Leaf,0) else
if i = j then (Node Leaf i Leaf, w i i) else
let l = root(fst(opt_bst_wpl2 i (j-1)));
r = root(fst(opt_bst_wpl2 (i+1) j)) in
argmin snd
[let (tl,wl) = opt_bst_wpl2 i (k-1); (tr,wr) = opt_bst_wpl2 (k+1) j
in (⟨tl, k, tr⟩, wl + wr + w i j) . k ← [l..r]])"
by auto

lemma left_le_right2:
"opt_bst_wpl2_dom(i,j) ⟹
(i=j ⟶ root(fst(opt_bst_wpl2 i j)) = i) ∧
(i<j ⟶ root(fst(opt_bst_wpl2 i j)) ∈
{root(fst(opt_bst_wpl2 i (j-1))) .. root(fst(opt_bst_wpl2 (i+1) j))})"
proof (induction rule: opt_bst_wpl2.pinduct)
case (1 i j)
let ?l = "root (fst(opt_bst_wpl2 i (j-1)))"
let ?r = "root (fst(opt_bst_wpl2 (i+1) j))"
let ?f ="λk. let (tl,wl) = opt_bst_wpl2 i (k-1); (tr,wr) = opt_bst_wpl2 (k+1) j
in (⟨tl, k, tr⟩, wl + wr + w i j)"
show ?case
proof cases
assume "i > j" thus ?thesis by auto
next
assume [arith]: "¬ i > j"
show ?thesis
proof cases
assume "i = j" thus ?thesis using opt_bst_wpl2.psimps[OF "1.hyps"] by simp
next
assume [arith]: "i ≠ j"
have left_le_right: "?l ≤ ?r"
proof cases
assume [arith]: "i = j-1"
have l: "root (fst(opt_bst_wpl2 i (j-1))) = i" using "1.IH"(1) by auto
have r: "root (fst(opt_bst_wpl2 (i+1) j)) = j" using ‹i = j-1› "1.IH"(2)
by auto
show ?thesis using l r by auto
next
assume "¬ i = j-1"
hence[arith]: "i < j-1" by arith
have "?l ≤ root (fst(opt_bst_wpl2 (i+1) (j-1)))" using "1.IH"(1) by auto
also have "... ≤ root (fst(opt_bst_wpl2 (i+1) j))" using "1.IH"(2) by auto
finally have "?l ≤ ?r" .
thus ?thesis by auto
qed

let ?P = "λt. root (fst t) ∈ {?l .. ?r}"
show ?thesis
using argmin_forall[of ‹map ?f [?l..?r]› ?P snd] left_le_right
by (fastforce simp add: opt_bst_wpl2.psimps[OF "1.hyps"] split: prod.splits)
qed
qed
qed

text ‹Now we can bound the result of @{const opt_bst_wpl2}:›

lemma root_opt_bst_wpl2_bound:
"opt_bst_wpl2_dom (i,j) ⟹ i ≤ j ⟹ root (fst(opt_bst_wpl2 i j)) ∈ {i..j}"
proof(induction i j rule:opt_bst_wpl2.pinduct)
case (1 i j)
show ?case using  "1.prems" "1.IH"(1)  "1.IH"(2)[OF _ _ refl] left_le_right2[OF "1.hyps"]
by fastforce
qed

text ‹Now termination follows easily:›

lemma opt_bst_wpl2_dom: "∀args. opt_bst_wpl2_dom args"
by (relation "measure (λ(i,j). nat (j-i+1))") (auto dest: root_opt_bst_wpl2_bound)

termination by(rule opt_bst_wpl2_dom)

declare opt_bst_wpl2.simps[simp del]

lemma opt_bst_wpl2_eq_pair:
"opt_bst_wpl2 i j = (opt_bst2 i j, wpl i j (opt_bst2 i j))"
proof(induction i j rule: opt_bst_wpl2.induct)
case (1 i j)
note [simp] = opt_bst2.simps[of i j] opt_bst_wpl2.simps[of i j] (*opt_bst2_opt_bst*)
show ?case
proof cases
assume "i > j" thus ?thesis using "1.prems" by (simp)
next
assume [arith]: "¬ i > j"
show ?case
proof cases
assume [arith]: "i = j"
show ?thesis by(simp) (simp add: ‹i = j›)
next
assume [arith]: "i ≠ j"
let ?l = "root (opt_bst2 i (j-1))" let ?r = "root (opt_bst2 (i+1) j)"
have *: "?l ≤ ?r"
using left_le_right[of i j] by (fastforce simp: opt_bst2_opt_bst opt_bst2_dom)
let ?f = "λk. case opt_bst_wpl2 i (k-1) of
(l,wl) ⇒ case opt_bst_wpl2 (k+1) j of
(r,wr) ⇒ (⟨l,k,r⟩, wl + wr + w i j)"
let ?g = "λk. (⟨opt_bst2 i (k-1), k, opt_bst2 (k+1) j⟩,
wpl i (k-1) (opt_bst2 i (k-1)) + wpl (k+1) j (opt_bst2 (k+1) j) + w i j)"
have fg: "?f k = ?g k" if k: "k ∈ {?l..?r}" for k
proof -
have 1: "opt_bst_wpl2 i (k-1) = (opt_bst2 i (k-1), wpl i (k-1) (opt_bst2 i (k-1)))"
using k "1.IH"(3) by(simp add: "1.IH"(1,2))
have 2: "opt_bst_wpl2 (k+1) j = (opt_bst2 (k+1) j, wpl (k+1) j (opt_bst2 (k+1) j))"
using 1 k "1.IH"(4) by(simp add: "1.IH"(1,2))
show ?thesis using 1 2 by(simp)
qed
have "opt_bst_wpl2 i j =
argmin snd (map ?f [root(fst(opt_bst_wpl2 i (j-1)))..root(fst(opt_bst_wpl2 (i+1) j))])"
by(simp)
also have "… = argmin snd (map ?f [?l..?r])"
also have "… = argmin snd (map ?g [?l..?r])"
using fg by (simp cong: list.map_cong_simp)
also have "… = (opt_bst2 i j, wpl i j (opt_bst2 i j))" using *