Theory OptBST

subsection ‹Optimal Binary Search Trees›

text ‹
The material presented in this section just contains a simple and non-optimal version
(cubic instead of quadratic in the number of keys).
It can now be viewed as superseded by the AFP entry ‹Optimal_BST›.
It is kept here as a more easily understandable example and for archival purposes.
›

theory OptBST
imports
  "HOL-Library.Tree"
  "HOL-Library.Code_Target_Numeral"
  "../state_monad/State_Main" 
  "../heap_monad/Heap_Default"
  Example_Misc
  "HOL-Library.Product_Lexorder"
  "HOL-Library.RBT_Mapping"
begin

subsubsection ‹Function argmin›

text ‹Function ‹argmin› iterates over a list and returns the rightmost element
that minimizes a given function:›

(* argmin f xs: the head x replaces the minimizer m computed for the tail only
   if f x is strictly smaller than f m, so among several minimizers the
   rightmost one is returned.  There is no equation for the empty list, so
   argmin f [] is left unspecified. *)
fun argmin :: "('a  ('b::linorder))  'a list  'a" where
"argmin f (x#xs) =
  (if xs = [] then x else
   let m = argmin f xs in if f x < f m then x else m)"

text ‹Note that @{term arg_min_list} is similar but returns the leftmost element.›

(* Any property shared by all elements of a non-empty list also holds for the
   element selected by argmin. *)
lemma argmin_forall: "xs  []  (x. xset xs  P x)  P (argmin f xs)"
by(induction xs) (auto simp: Let_def)

(* On a non-empty list, argmin f attains the minimum value of f over the
   elements of the list. *)
lemma argmin_Min: "xs  []  f (argmin f xs) = Min (f ` set xs)"
by(induction xs) (auto simp: min_def intro!: antisym)


subsubsection ‹Misc›

(* Splits the integer list [i..k] at a point j between i and k into the part
   below j, j itself, and the part above j.  Used by opt_bst_correct. *)
lemma upto_join: " i  j; j  k   [i..j-1] @ j # [j+1..k] = [i..k]"
  using upto_rec1 upto_split1 by auto

(* Splits the integer interval {i..j} at k into {i..k} and {k+1..j}.  Used in
   inorder_wpl_correct to split the sum defining W. *)
lemma atLeastAtMost_split:
  "{i..j} = {i..k}  {k+1..j}" if "i  k" "k  j" for i j k :: int
  using that by auto

(* Separates the upper endpoint k from a non-empty integer interval {i..k}.
   Used to unfold W one key at a time (W_rec, W_fun_correct). *)
lemma atLeastAtMost_split_insert:
  "{i..k} = insert k {i..k-1}" if "k  i" for i :: int
  using that by auto

subsubsection ‹Definitions›

(* The definitions and proofs up to the matching end are parametric in a weight
   function W.  W i j is the weight charged at the root of a subtree whose keys
   range over [i..j] (see wpl). *)
context
fixes W :: "int  int  nat"
begin

(* wpl i j t: weighted path length of t, intended for trees whose inorder
   traversal is [i..j].  A node with key k charges W i j and passes the ranges
   [i..k-1] and [k+1..j] to its left and right subtrees. *)
fun wpl :: "int  int  int tree  nat" where
   "wpl i j Leaf = 0"
 | "wpl i j (Node l k r) = wpl i (k-1) l + wpl (k+1) j r + W i j"

(* min_wpl i j: minimal weighted path length for the keys [i..j], which is 0
   for an empty range.  It is obtained by trying every key k in [i..j] as the
   root.  There are quadratically many subranges, and each takes a minimum over
   up to j - i + 1 roots, so a memoized evaluation is cubic in the number of
   keys. *)
function min_wpl :: "int  int  nat" where
"min_wpl i j =
  (if i > j then 0
   else min_list (map (λk. min_wpl i (k-1) + min_wpl (k+1) j + W i j) [i..j]))"
  by auto
(* Both recursive calls shrink the range length j - i + 1. *)
termination by (relation "measure (λ(i,j) . nat(j-i+1))") auto
(* Keep the recursive equation out of the simpset; proofs below unfold it
   explicitly with min_wpl.simps. *)
declare min_wpl.simps[simp del]

(* opt_bst i j: a tree for the keys [i..j] with minimal weighted path length.
   For every candidate root k it recursively builds optimal subtrees for
   [i..k-1] and [k+1..j].  argmin then picks a candidate minimizing wpl i j,
   taking the rightmost one on ties. *)
function opt_bst :: "int  int  int tree" where
"opt_bst i j =
  (if i > j then Leaf else argmin (wpl i j) [opt_bst i (k-1), k, opt_bst (k+1) j. k  [i..j]])"
  by auto
(* Both recursive calls shrink the range length j - i + 1. *)
termination by (relation "measure (λ(i,j) . nat(j-i+1))") auto
(* Keep the recursive equation out of the simpset; proofs below unfold it
   explicitly with opt_bst.simps. *)
declare opt_bst.simps[simp del]


subsubsection ‹Functional Memoization›

(* n bounds the indices memoized by the heap-based implementation below.  The
   memory bound is Bound (0, 0) (int n, int n), and min_wplh allocates a table
   of n * n slots. *)
context
  fixes n :: nat
begin

(* mem: the heap array of nat option slots used as memory by the generated
   monadic function. *)
context fixes
  mem :: "nat option array"
begin

(* Generate min_wplT', a heap-monadic version of min_wpl.  It memoizes its
   results in mem for indices within Bound (0, 0) (int n, int n). *)
memoize_fun min_wplT: min_wpl
  with_memory dp_consistency_heap_default where bound = "Bound (0, 0) (int n, int n)" and mem="mem"
  monadifies (heap) min_wpl.simps

(* Diagnostic output only: display the generated equations and definition. *)
context includes heap_monad_syntax begin
thm min_wplT'.simps min_wplT_def
end

(* Prove that the memoized version is consistent with min_wpl. *)
memoize_correct
  by memoize_prover

(* Re-export min_wplT.memoized_empty so it is available after this context
   ends; min_wpl_heap uses it. *)
lemmas memoized_empty = min_wplT.memoized_empty

end (* Fixed array *)

(* Heap-monad entry point.  The generated equations of min_wplT' are kept out
   of the simpset inside this context. *)
context
  includes heap_monad_syntax
  notes [simp del] = min_wplT'.simps
begin

(* min_wplh i j: allocate a fresh memory of n * n slots with mem_empty, then
   run min_wplT' on it. *)
definition "min_wplh  λ i j. Heap_Monad.bind (mem_empty (n * n)) (λ mem. min_wplT' mem i j)"

(* Evaluated on the empty heap, min_wplh i j returns min_wpl i j. *)
lemma min_wpl_heap:
  "min_wpl i j = result_of (min_wplh i j) Heap.empty"
  unfolding min_wplh_def
  using memoized_empty[of _ "λ m. λ (a, b). min_wplT' m a b" "(i, j)", OF min_wplT.crel]
  by (simp add: index_size_defs)

end

end (* Bound *)

(* Purely functional memoization: state-monad versions of min_wpl and opt_bst
   that use a Mapping as memory. *)
context includes state_monad_syntax begin

memoize_fun min_wplm: min_wpl with_memory dp_consistency_mapping monadifies (state) min_wpl.simps
thm min_wplm'.simps

memoize_correct
  by memoize_prover
print_theorems
(* Declare the correctness theorem as a code equation, so that code generated
   for min_wpl goes through the memoized version. *)
lemmas [code] = min_wplm.memoized_correct

memoize_fun opt_bstm: opt_bst with_memory dp_consistency_mapping monadifies (state) opt_bst.simps
thm opt_bstm'.simps

memoize_correct
  by memoize_prover
print_theorems
(* Code generated for opt_bst likewise goes through the memoized opt_bstm. *)
lemmas [code] = opt_bstm.memoized_correct

end


subsubsection ‹Correctness Proof›

(* min_wpl i j is a lower bound on wpl i j t for every tree t whose inorder
   traversal is [i..j]. *)
lemma min_wpl_minimal:
  "inorder t = [i..j]  min_wpl i j  wpl i j t"
proof(induction i j t rule: wpl.induct)
  (* Leaf: the range is empty, and min_wpl unfolds to 0. *)
  case (1 i j)
  then show ?case by (simp add: min_wpl.simps)
next
  (* Node with root k: compare with the candidate for root k in the minimum
     that defines min_wpl i j. *)
  case (2 i j l k r)
  then show ?case 
  proof cases
    assume "i > j" thus ?thesis by(simp add: min_wpl.simps)
  next
    assume [arith]: "¬ i > j"
    have kk_ij: "kset[i..j]" using 2 
        by (metis set_inorder tree.set_intros(2))
        
    (* ?M: the candidate values whose minimum is min_wpl i j.
       ?w: the candidate for root k. *)
    let ?M = "((λk. min_wpl i (k-1) + min_wpl (k+1) j + W i j) ` {i..j})"
    let ?w = "min_wpl i (k-1) + min_wpl (k+1) j + W i j"
 
    have aux_min:"Min ?M  ?w"
    proof (rule Min_le)
      show "finite ?M" by simp
      show "?w  ?M" using kk_ij by auto
    qed

    (* Splitting [i..j] at k shows that l covers [i..k-1] and r covers
       [k+1..j]. *)
    have"inorder l,k,r = inorder l @k#inorder r" by auto
    from this have C:"[i..j] = inorder l @ k#inorder r" using 2 by auto
    have D: "[i..j] = [i..k-1]@k#[k+1..j]" using kk_ij upto_rec1 upto_split1
      by (metis atLeastAtMost_iff set_upto) 

    have l_inorder: "inorder l = [i..k-1]"
      by (smt C D append_Cons_eq_iff atLeastAtMost_iff set_upto)
    have r_inorder: "inorder r = [k+1..j]" 
      by (smt C D append_Cons_eq_iff atLeastAtMost_iff set_upto)

    (* Chain: min_wpl i j = Min ?M, which is at most ?w, which by the
       induction hypotheses for l and r is at most the weight of the tree. *)
    have "min_wpl i j = Min ?M" by (simp add: min_wpl.simps min_list_Min)
    also have "...  ?w" by (rule aux_min)    
    also have "...  wpl i (k-1) l + wpl (k+1) j r + W i j" using l_inorder r_inorder "2.IH" by simp
    also have "... = wpl i j l,k,r" by simp
    finally show ?thesis .
  qed
qed

(* opt_bst i j is a search tree containing exactly the keys [i..j]. *)
lemma opt_bst_correct: "inorder (opt_bst i j) = [i..j]"
  by (induction i j rule: opt_bst.induct)
     (clarsimp simp: opt_bst.simps upto_join | rule argmin_forall)+

(* The tree built by opt_bst attains the value computed by min_wpl. *)
lemma wpl_opt_bst: "wpl i j (opt_bst i j) = min_wpl i j"
proof(induction i j rule: min_wpl.induct)
  case (1 i j)
  show ?case
  proof cases
    assume "i > j" thus ?thesis by(simp add: min_wpl.simps opt_bst.simps)
  next
    assume *[arith]: "¬ i > j"
    (* ?ts: the candidate trees, one per root k.
       ?M: the candidate values defining min_wpl i j.  By the induction
       hypotheses, the weights of the trees in ?ts are exactly the values
       in ?M. *)
    let ?ts = "[opt_bst i (k-1), k, opt_bst (k+1) j. k <- [i..j]]"
    let ?M = "((λk. min_wpl i (k-1) + min_wpl (k+1) j + W i j) ` {i..j})"
    have "?ts  []" by (auto simp add: upto.simps)
    have "wpl i j (opt_bst i j) = wpl i j (argmin (wpl i j) ?ts)" by (simp add: opt_bst.simps)
    also have " = Min (wpl i j ` (set ?ts))" by (rule argmin_Min[OF ?ts  []])
    also have " = Min ?M"
    proof (rule arg_cong[where f=Min])
      show "wpl i j ` (set ?ts) = ?M"
        by (fastforce simp: Bex_def image_iff 1[OF *])
    qed
    also have " = min_wpl i j" by (simp add: min_wpl.simps min_list_Min)
    finally show ?thesis .
  qed
qed

(* opt_bst i j has minimal weighted path length among all trees whose inorder
   traversal is [i..j]. *)
lemma opt_bst_is_optimal:
  "inorder t = [i..j]  wpl i j (opt_bst i j)  wpl i j t"
  by (simp add: min_wpl_minimal wpl_opt_bst)

end (* Weight function *)

subsubsection ‹Access Frequencies›
text ‹Usually, the problem is phrased in terms of access frequencies.
We now give an interpretation of @{term wpl} in this view and show that we have actually computed
the right thing.›

context
  ― ‹We are given a range ‹[i..j]› of integer keys with access frequencies ‹p›.
  These can be thought of as a probability distribution but are not required to represent one.
  This model assumes that the tree will contain all keys in the range ‹[i..j]›.
  See ‹Optimal_BST› for a model with missing keys.
  ›
  fixes p :: "int  nat"
begin

― ‹The ‹weighted path length› (or ‹cost›) of a tree.›
(* Each node contributes its own frequency p k plus the total frequency of both
   of its subtrees, since every key in a subtree lies one level deeper than the
   node.  cost' below expresses the same quantity with explicit depths. *)
fun cost :: "int tree  nat" where
  "cost Leaf = 0"
| "cost (Node l k r) = sum p (set_tree l) + cost l + p k + cost r + sum p (set_tree r)"

― ‹Deriving a weight function from ‹p›.›
(* W i j is the total frequency of the keys in {i..j}.  It is declared
   qualified, so the short name W is not exported to the global namespace. *)
qualified definition W where
  "W i j = sum p {i..j}"

― ‹We will use this later for computing ‹W› efficiently.›
(* An empty range has weight 0; otherwise the upper endpoint j is split off.
   W_fun implements this recursion. *)
lemma W_rec:
  "W i j = (if j  i then W i (j - 1) + p j else 0)"
  unfolding W_def by (simp add: atLeastAtMost_split_insert)

― ‹The weight function correctly implements costs.›
(* For a tree whose inorder traversal is [i..j], wpl W i j coincides with cost:
   the weight W charged at a node is the frequency sum of its whole subtree. *)
lemma inorder_wpl_correct:
  "inorder t = [i..j]  wpl W i j t = cost t"
proof (induction t arbitrary: i j)
case Leaf
  show ?case
    by simp
next
  case (Node l k r)
  (* The root k lies within [i..j] ... *)
  from inorder l, k, r = [i..j] have *: "i  k" "k  j"
     by - (simp, metis atLeastAtMost_iff in_set_conv_decomp set_upto)+
   (* ... hence the subtrees cover [i..k-1] and [k+1..j]. *)
   moreover from i  k k  j have "inorder l = [i..k-1]" "inorder r = [k+1..j]"
    using inorder l, k, r = [i..j][symmetric] by (simp add: upto_split3 append_Cons_eq_iff)+
  (* Split the sum defining W at k and apply the induction hypotheses. *)
  ultimately show ?case
    by (simp add: Node.IH, subst W_def, subst atLeastAtMost_split)
       (simp add: sum.union_disjoint atLeastAtMost_split_insert flip: set_inorder)+
qed

text ‹The optimal binary search tree has minimal cost among all binary search trees.›
(* Transfers opt_bst_is_optimal from wpl W to cost via inorder_wpl_correct,
   using opt_bst_correct for the inorder traversal of the optimal tree. *)
lemma opt_bst_has_optimal_cost:
  "inorder t = [i..j]  cost (opt_bst W i j)  cost t"
  using inorder_wpl_correct opt_bst_is_optimal opt_bst_correct by metis

text ‹
  The function @{term min_wpl} correctly computes the minimal cost among all binary search trees:
  \<^item> Its value is a lower bound for the cost of all binary search trees.
  \<^item> Its value is attained by an optimal binary search tree.
›
(* min_wpl W i j is a lower bound for the cost of every tree whose inorder
   traversal is [i..j]. *)
lemma min_wpl_minimal_cost:
  "inorder t = [i..j]  min_wpl W i j  cost t"
  using inorder_wpl_correct min_wpl_minimal by metis

(* The tree computed by opt_bst has cost exactly min_wpl W i j. *)
lemma min_wpl_tree:
  "cost (opt_bst W i j) = min_wpl W i j"
  using wpl_opt_bst opt_bst_correct inorder_wpl_correct by metis


paragraph ‹An alternative view of costs.›

(* depth x t: the depth of key x in t, counting the root as 1.  A Leaf yields
   infinity, so keys absent from t have infinite depth (see depth_inf_iff).
   For a node not holding x, the result is the smaller of the depths in both
   subtrees, plus one. *)
fun depth :: "'a  'a tree  nat extended" where
  "depth x Leaf = "
| "depth x (Node l k r) = (if x = k then 1 else min (depth x l) (depth x r) + 1)"

(* Extracts the value from a finite extended number (Fin x); any other value is
   mapped to undefined. *)
fun the_fin where
  "the_fin (Fin x) = x" | "the_fin _ = undefined"

(* Alternative cost: the sum, over all keys x of t, of depth x t times the
   frequency p x. *)
definition cost' :: "int tree  nat" where
  "cost' t = sum (λx. the_fin (depth x t) * p x) (set_tree t)"

(* Simp rule: the_fin maps the extended constant 1 to the natural number 1. *)
lemma [simp]:
  "the_fin 1 = 1"
  by (simp add: one_extended_def)

(* Keys that do not occur in t have infinite depth. *)
lemma set_tree_depth:
  assumes "x  set_tree t"
  shows "depth x t = "
  using assms by (induction t) auto

(* The depth is infinite exactly for keys absent from t. *)
lemma depth_inf_iff:
  "depth x t =   x  set_tree t"
  apply (induction t)
   apply (auto simp: one_extended_def)
  (* The remaining Node goals are closed by case analysis on the depths in
     both subtrees. *)
  subgoal for t1 k t2
    by (cases "depth x t1"; cases "depth x t2") auto
  subgoal for t1 k t2
    by (cases "depth x t1"; cases "depth x t2") auto
  subgoal for t1 k t2
    by (cases "depth x t1"; cases "depth x t2") auto
  subgoal for t1 k t2
    by (cases "depth x t1"; cases "depth x t2") auto
  done

(* depth never yields minus infinity (used as a simp rule). *)
lemma depth_not_neg_inf[simp]:
  "depth x t = -∞  False"
  apply (induction t)
   apply (auto simp: one_extended_def)
  (* Node case: case analysis on the depths in both subtrees. *)
  subgoal for t1 k t2
    by (cases "depth x t1"; cases "depth x t2") auto
  done

(* Elimination rule: a key occurring in t has a finite depth Fin d. *)
lemma depth_FinD:
  assumes "x  set_tree t"
  obtains d where "depth x t = Fin d"
  using assms by (cases "depth x t") (auto simp: depth_inf_iff)

(* The empty tree has cost' 0. *)
lemma cost'_Leaf[simp]:
  "cost' Leaf = 0"
  unfolding cost'_def by simp

(* For trees with distinct keys, cost' satisfies the same recursion as cost:
   moving into a subtree increases the depth of each of its keys by one. *)
lemma cost'_Node:
  "distinct (inorder l, x, r) 
  cost' l, x, r = sum p (set_tree l) + cost' l + p x + cost' r + sum p (set_tree r)"
  unfolding cost'_def
  apply simp
  (* Split the sum over the keys of the node into disjoint parts. *)
  apply (subst sum.union_disjoint)
     apply (simp; fail)+
  (* Keys of the left subtree: depth in the node equals depth in l plus one. *)
  apply (subst sum.cong[OF HOL.refl, where h = "λx. (the_fin (depth x l) + 1) * p x"])
  subgoal for k
    using set_tree_depth by (force simp: one_extended_def elim: depth_FinD)
  (* The same for the keys of the right subtree. *)
  apply (subst (2) sum.cong[OF HOL.refl, where h = "λx. (the_fin (depth x r) + 1) * p x"])
  subgoal
    using set_tree_depth by (force simp: one_extended_def elim: depth_FinD)
  (* Distribute to expose the extra frequency sums of both subtrees. *)
  apply (simp add: sum.distrib)
  done

― ‹The two variants coincide›
(* For trees with distinct keys, for example search trees, cost' and cost
   agree.  The proof is by induction using cost'_Node. *)
lemma weight_correct:
  "distinct (inorder t)  cost' t = cost t"
  by (induction t; simp add: cost'_Node)


subsubsection ‹Memoizing Weights›

(* W_fun: executable recursive version of W.  An empty range (i > j) has
   weight 0; otherwise the upper endpoint j is split off, as in W_rec. *)
function W_fun where
  "W_fun i j = (if i > j then 0 else W_fun i (j - 1) + p j)"
  by auto

(* The range length j - i + 1 decreases in the recursive call. *)
termination
  by (relation "measure (λ(i::int, j::int). nat (j - i + 1))") auto

(* W_fun computes W. *)
lemma W_fun_correct:
  "W_fun i j = W i j"
  by (induction rule: W_fun.induct) (simp add: W_def atLeastAtMost_split_insert)

(* Memoize W_fun in the state monad with a Mapping as memory.  This yields the
   monadic function Wm' used by compute_W. *)
memoize_fun Wm: W_fun
  with_memory  dp_consistency_mapping
  monadifies (state) W_fun.simps

(* Prove that the memoized version is consistent with W_fun. *)
memoize_correct
  by memoize_prover

(* compute_W n: evaluate Wm' i n for every i in [0..n], threading a single
   memory that starts as Mapping.empty.  The result is the final memory
   (snd of run_state); the computed values themselves are discarded. *)
definition
  "compute_W n = snd (run_state (State_Main.mapT' (λi. Wm' i n) [0..n]) Mapping.empty)"

(* Short syntax crel for Wm's consistency relation Wm.crel_vs. *)
notation Wm.crel_vs ("crel")

(* Wm.crel with consistentDP unfolded, instantiated at argument pairs (m, x)
   and (m, y) that share the first component.  compute_W_correct uses it to
   discharge the lifted calls to Wm'. *)
lemmas Wm_crel = Wm.crel[unfolded Wm.consistentDP_def, THEN rel_funD,
      of "(m, x)" "(m, y)" for m x y, unfolded prod.case]

(* Every value stored in the table compute_W n is the correct value of W. *)
lemma compute_W_correct:
  assumes "Mapping.lookup (compute_W n) (i, j) = Some x"
  shows "W i j = x"
proof -
  include state_monad_syntax app_syntax lifting_syntax
  (* ?p: the monadic computation inside compute_W.  ?q: the corresponding pure
     list of results.  Both are restated in lifted application form so that
     the memoization proof steps apply. *)
  let ?p = "State_Main.mapT' (λi. Wm' i n) [0..n]"
  let ?q = "map (λi. W i n) [0..n]"
  have "?q = map $ (λi. W_fun i n) $ [0..n]"
    unfolding Wrap_def App_def W_fun_correct ..
  have "?p = State_Main.mapT . λi. Wm' i n . [0..n]"
    unfolding State_Monad_Ext.fun_app_lifted_def State_Main.mapT_def bind_left_identity ..
  ― ‹Not forgetting to write @{term  "list_all2 (=)"} instead of @{term "(=)"} was the tricky part.›
  have "Wm.crel_vs (list_all2 (=)) ?q ?p"
    unfolding ?p = _ ?q = _
    apply (subst Transfer.Rel_def[symmetric])
    apply memoize_prover_match_step+
    apply (subst Rel_def, rule Wm_crel, rule HOL.refl)
    done
  (* Starting from the empty memory, the final memory is consistent (cmem). *)
  then have "Wm.cmem (compute_W n)"
    unfolding compute_W_def by (elim Wm.crel_vs_elim[OF _ Wm.cmem_empty]; simp del: Wm'.simps)
  (* A consistent memory only stores W_fun values, which equal W. *)
  with assms show ?thesis
    unfolding W_fun_correct[symmetric] by (elim Wm.cmem_elim) (simp)+
qed

(* min_wpl' i j: min_wpl with every weight looked up in the table compute_W j.
   Entries missing from the table fall back to the defining sum W.  Inside the
   let, the right-hand side W refers to the constant W, since let is not
   recursive. *)
definition
  "min_wpl' i j 
  let
    M = compute_W j;
    W = (λi j. case Mapping.lookup M (i, j) of None  W i j | Some x  x)
  in min_wpl W i j"

(* Looking up compute_W n, with fallback to W, is extensionally equal to W. *)
lemma W_compute: "W i j = (case Mapping.lookup (compute_W n) (i, j) of None  W i j | Some x  x)"
  by (auto dest: compute_W_correct split: option.split)

(* The table-based min_wpl' agrees with min_wpl W. *)
lemma min_wpl'_correct:
  "min_wpl' i j = min_wpl W i j"
  using W_compute unfolding min_wpl'_def by simp

(* opt_bst' i j: opt_bst with every weight looked up in the table compute_W j.
   Entries missing from the table fall back to the defining sum W.  Inside the
   let, the right-hand side W refers to the constant W. *)
definition
  "opt_bst' i j 
  let
    M = compute_W j;
    W = (λi j. case Mapping.lookup M (i, j) of None  W i j | Some x  x)
  in opt_bst W i j"

(* The table-based opt_bst' agrees with opt_bst W. *)
lemma opt_bst'_correct:
  "opt_bst' i j = opt_bst W i j"
  using W_compute unfolding opt_bst'_def by simp

end (* fixed p *)


subsubsection ‹Test Case›

text ‹Functional Implementations›

(* Evaluate the functional definitions on the keys [0..4] with the weight
   function nat (i + j).  The optimal tree is the left spine with roots 4, 3,
   2, 1, 0: node k covers [0..k] and is charged k, for a total of 10. *)
lemma "min_wpl (λi j. nat(i+j)) 0 4 = 10"
  by eval

(* NOTE(review): as written, the expected tree literal below is not a single
   tree term; its Node brackets appear to be missing.  It should describe the
   left spine mentioned above.  Confirm against the original. *)
lemma "opt_bst (λi j. nat(i+j)) 0 4 = ⟨⟩, 0, ⟨⟩, 1, ⟨⟩, 2, ⟨⟩, 3, ⟨⟩, 4, ⟨⟩"
  by eval

text ‹Using Frequencies›
(* list_to_p xs: frequencies given by a list, with 1-based keys.  Key i gets
   xs ! (i - 1); keys outside 1 .. length xs get frequency 0. *)
definition
  "list_to_p xs (i::int) = (if i - 1  0  nat (i - 1) < length xs then xs ! nat (i - 1) else 0)"

(* Example access frequencies for the keys 1 .. 5. *)
definition
  "ex_p_1 = [10, 30, 15, 25, 20]"

(* Expected optimal tree for ex_p_1: root 4, whose left subtree has root 2
   with children 1 and 3, and whose right child is 5.  Its cost is 200.
   NOTE(review): the Node brackets of the literal below appear to be missing;
   confirm that the term has this shape. *)
definition
  "opt_tree_1 =
  
    
      ⟨⟩, 1::int, ⟨⟩,
      2,
      ⟨⟩, 3, ⟨⟩
    ,
    4,
    ⟨⟩, 5, ⟨⟩
  "

(* The table-based opt_bst' reproduces the expected tree on the keys 1 .. 5. *)
lemma "opt_bst' (list_to_p ex_p_1) 1 5 = opt_tree_1"
  by eval


text ‹Imperative Implementation›

(* Diagnostic output only: print the code equations used to generate code for
   min_wpl. *)
code_thms min_wpl

(* Heap-based evaluation of min_wpl with weight nat (i + j) on the range
   [0..4], using table parameter n = 4.
   NOTE(review): the memory bound is Bound (0, 0) (int n, int n).  If its upper
   limit is exclusive, calls involving index 4 are computed without being
   memoized.  Correctness is unaffected (see min_wpl_heap).  Confirm whether
   n = 5 was intended. *)
definition "min_wpl_test = min_wplh (λi j. nat(i+j)) 4 0 4"

(* Compile min_wpl_test to ML and make it available as structure Test. *)
code_reflect Test functions min_wpl_test

(* Run the compiled test.
   NOTE(review): the ML command expects its source text in a cartouche; the
   delimiters around Test.min_wpl_test () appear to be missing here.  Confirm
   against the original. *)
ML Test.min_wpl_test ()

end