(* Theory OptBST

   theory OptBST
   imports Tree Code_Target_Numeral State_Main Heap_Default Example_Misc *)
subsection ‹Optimal Binary Search Trees›

theory OptBST
imports
  "HOL-Library.Tree"
  "HOL-Library.Code_Target_Numeral"
  "../state_monad/State_Main" 
  "../heap_monad/Heap_Default"
  Example_Misc
begin

subsubsection ‹Misc›

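(* FIXME: everything from here down to the "end mv" marker is general list material and should be moved to the List theory *)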
(* Induction rule for lists with two base cases: the property must hold for
   the empty list and for singletons, and for x # y # zs whenever it already
   holds for the tail y # zs. *)
lemma induct_list012:
  "⟦P []; ⋀x. P [x]; ⋀x y zs. P (y # zs) ⟹ P (x # y # zs)⟧ ⟹ P xs"
by induction_schema (pat_completeness, lexicographic_order)

(* Splits the integer interval [i..k] at an inner point j: the left part
   ends at j-1 and the right part starts at j. *)
lemma upto_split1: 
  "i ≤ j ⟹ j ≤ k ⟹ [i..k] = [i..j-1] @ [j..k]"
proof (induction j rule: int_ge_induct)
  (* both cases are closed by unrolling the interval with the upto_rec lemmas *)
  case base thus ?case by (simp add: upto_rec1)
next
  case step thus ?case using upto_rec1 upto_rec2 by simp
qed

(* Variant of upto_split1 in which the split point j stays in the left part
   and the right part starts at j+1. *)
lemma upto_split2: 
  "i ≤ j ⟹ j ≤ k ⟹ [i..k] = [i..j] @ [j+1..k]"
using upto_rec1 upto_rec2 upto_split1 by auto

(* The interval [i..j] is empty exactly when its upper end lies below its
   lower end. *)
lemma upto_Nil: "[i..j] = [] ⟷ j < i"
by (simp add: upto.simps)

(* Same statement as upto_Nil with the two sides of the equation swapped, for
   goals in which [] appears on the left. *)
lemma upto_Nil2: "[] = [i..j] ⟷ j < i"
by (simp add: upto.simps)

(* Reassembles [i..k] from three pieces: the elements before j, the element j
   itself, and the elements after j. *)
lemma upto_join3: "⟦ i ≤ j; j ≤ k ⟧ ⟹ [i..j-1] @ j # [j+1..k] = [i..k]"
using upto_rec1 upto_split1 by auto

(* arg_min_list f xs returns an element of xs at which f is least.
   Only non-empty lists are covered: there is no equation for [].
   On a tie (f x ≤ f m) the earlier element x is returned. *)
fun arg_min_list :: "('a ⇒ ('b::linorder)) ⇒ 'a list ⇒ 'a" where
"arg_min_list f [x] = x" |
"arg_min_list f (x#y#zs) = (let m = arg_min_list f (y#zs) in if f x ≤ f m then x else m)"

(* For a non-empty list, the f-value of the element chosen by arg_min_list is
   the minimum of all f-values occurring in the list. *)
lemma f_arg_min_list_f: "xs ≠ [] ⟹ f (arg_min_list f xs) = Min (f ` (set xs))"
by(induction f xs rule: arg_min_list.induct) (auto simp: min_def intro!: antisym)
(* end mv *)


subsubsection ‹Definitions›

(* All definitions and lemmas in this context are relative to a fixed function
   W.  The only way W is used below is as the summand W i j that wpl and
   min_wpl add for a tree over the key range i..j.  Outside the context, W
   becomes an explicit leading argument of the exported constants. *)
context
fixes W :: "int ⇒ int ⇒ nat"
begin

(* wpl i j t: cost of tree t relative to the key range i..j (the name
   presumably abbreviates "weighted path length").
   A leaf costs 0.  A node with root k costs W i j plus the cost of its left
   subtree over i..k-1 and of its right subtree over k+1..j. *)
fun wpl :: "int ⇒ int ⇒ int tree ⇒ nat" where
   "wpl i j Leaf = 0"
 | "wpl i j (Node l k r) = wpl i (k-1) l + wpl (k+1) j r + W i j"

(* min_wpl i j: 0 for an empty range (i > j); otherwise the least value, over
   all candidate roots k in [i..j], of
   min_wpl i (k-1) + min_wpl (k+1) j + W i j. *)
function min_wpl :: "int ⇒ int ⇒ nat" where
"min_wpl i j =
  (if i > j then 0
   else min_list (map (λk. min_wpl i (k-1) + min_wpl (k+1) j + W i j) [i..j]))"
by auto
(* terminates because the measure nat(j-i+1) decreases in every recursive call *)
termination by (relation "measure (λ(i,j) . nat(j-i+1))") auto
(* the defining equation is taken out of the default simp set; the proofs
   below add min_wpl.simps explicitly where they need to unfold it *)
declare min_wpl.simps[simp del]

(* opt_bst (i,j): Leaf for an empty range (i > j); otherwise, for every
   candidate root k in [i..j], build the tree whose subtrees are opt_bst of
   the ranges left and right of k, and pick among these trees one with least
   wpl i j (via arg_min_list).  The range is passed as a pair. *)
function opt_bst :: "int * int ⇒ int tree" where
"opt_bst (i,j) =
  (if i > j then Leaf else arg_min_list (wpl i j) [⟨opt_bst (i,k-1), k, opt_bst (k+1,j)⟩. k ← [i..j]])"
by auto
(* same termination measure as for min_wpl *)
termination by (relation "measure (λ(i,j) . nat(j-i+1))") auto
(* kept out of the default simp set; unfolded explicitly via opt_bst.simps *)
declare opt_bst.simps[simp del]


subsubsection ‹Functional Memoization›

(* Heap-based memoization of min_wpl.  The fixed n determines the memo table:
   keys are bounded by Bound (0, 0) (int n, int n) and the table is created
   with mem_empty (n * n). *)
context
  fixes n :: nat
begin

(* The memoized function is first set up relative to one fixed array mem. *)
context fixes
  mem :: "nat option array"
begin

(* Derives the heap-monad version of min_wpl from min_wpl.simps; this
   introduces min_wplT' and the min_wplT.* facts used below. *)
memoize_fun min_wplT: min_wpl
  with_memory dp_consistency_heap_default where bound = "Bound (0, 0) (int n, int n)" and mem="mem"
  monadifies (heap) min_wpl.simps

(* diagnostic only: display the generated equations with heap-monad syntax *)
context includes heap_monad_syntax begin
thm min_wplT'.simps min_wplT_def
end

(* correctness of the memoized version, proved automatically *)
memoize_correct
  by memoize_prover

(* re-export under a short name so it is available after the array context
   has been closed *)
lemmas memoized_empty = min_wplT.memoized_empty

end (* Fixed array *)

context
  includes heap_monad_syntax
  notes [simp del] = min_wplT'.simps
begin

(* min_wplh i j: create a fresh memo table with mem_empty (n * n) and run the
   memoized function min_wplT' on it. *)
definition "min_wplh ≡ λ i j. Heap_Monad.bind (mem_empty (n * n)) (λ mem. min_wplT' mem i j)"

(* Running min_wplh on the empty heap yields exactly min_wpl i j. *)
lemma min_wpl_heap:
  "min_wpl i j = result_of (min_wplh i j) Heap.empty"
  unfolding min_wplh_def
  using memoized_empty[of _ "λ m. λ (a, b). min_wplT' m a b" "(i, j)", OF min_wplT.crel]
  by (simp add: index_size_defs)

end

end (* Bound *)

(* State-monad memoization of min_wpl with a mapping as memory. *)
context begin
(* derives the state-monad version from min_wpl.simps; introduces min_wplm'
   and the min_wplm.* facts *)
memoize_fun min_wplm: min_wpl with_memory dp_consistency_mapping monadifies (state) min_wpl.simps
(* diagnostic only: display the generated equations *)
thm min_wplm'.simps

(* correctness of the memoized version, proved automatically *)
memoize_correct
  by memoize_prover
print_theorems
(* register the correctness theorem as a code equation *)
lemmas [code] = min_wplm.memoized_correct

end


subsubsection ‹Correctness Proof›

(* min_wpl i j is a lower bound for the wpl of every tree whose inorder
   traversal is exactly the key range [i..j].  The proof follows the
   recursion of wpl. *)
lemma min_wpl_minimal:
  "inorder t = [i..j] ⟹ min_wpl i j ≤ wpl i j t"
proof(induction i j t rule: wpl.induct)
  (* Leaf: its inorder is [], so [i..j] is empty and min_wpl i j unfolds to 0 *)
  case (1 i j)
  then show ?case by (simp add: min_wpl.simps upto_Nil2)
next
  (* Node l k r *)
  case (2 i j l k r)
  then show ?case 
  proof cases
    (* empty range: min_wpl i j is 0 *)
    assume "i > j" thus ?thesis by(simp add: min_wpl.simps)
  next
    assume [arith]: "¬ i > j"
    (* the root k occurs in the inorder traversal, hence in [i..j] *)
    have kk_ij: "k∈set[i..j]" using 2 
        by (metis set_inorder tree.set_intros(2))
        
    (* ?M: costs of all candidate roots; ?w: cost for the actual root k *)
    let ?M = "((λk. min_wpl i (k-1) + min_wpl (k+1) j + W i j) ` {i..j})"
    let ?w = "min_wpl i (k-1) + min_wpl (k+1) j + W i j"
 
    have aux_min:"Min ?M ≤ ?w"
    proof (rule Min_le)
      show "finite ?M" by simp
      show "?w ∈ ?M" using kk_ij by auto
    qed

    (* C and D are two decompositions of [i..j] around k; comparing them
       identifies the inorder traversals of the two subtrees *)
    have"inorder ⟨l,k,r⟩ = inorder l @k#inorder r" by auto
    from this have C:"[i..j] = inorder l @ k#inorder r" using 2 by auto
    have D: "[i..j] = [i..k-1]@k#[k+1..j]" using kk_ij upto_rec1 upto_split1
      by (metis atLeastAtMost_iff set_upto) 

    have l_inorder: "inorder l = [i..k-1]"
      by (smt C D append_Cons_eq_iff atLeastAtMost_iff set_upto)
    have r_inorder: "inorder r = [k+1..j]" 
      by (smt C D append_Cons_eq_iff atLeastAtMost_iff set_upto)

    (* chain: unfold min_wpl, pick root k, apply the induction hypotheses to
       both subtrees, fold wpl back *)
    have "min_wpl i j = Min ?M" by (simp add: min_wpl.simps min_list_Min upto_Nil)
    also have "... ≤ ?w" by (rule aux_min)    
    also have "... ≤ wpl i (k-1) l + wpl (k+1) j r + W i j" using l_inorder r_inorder "2.IH" by simp
    also have "... = wpl i j ⟨l,k,r⟩" by simp
    finally show ?thesis .
  qed
qed

(* Any property shared by all elements of a non-empty list also holds for the
   element selected by arg_min_list. *)
lemma P_arg_min_list: "(⋀x. x ∈ set xs ⟹ P x) ⟹ xs ≠ [] ⟹ P(arg_min_list f xs)"
by(induction f xs rule: arg_min_list.induct) (auto simp: Let_def)

(* The tree built by opt_bst (i,j) contains exactly the keys i..j in order.
   P_arg_min_list reduces the claim to the candidate trees, whose inorder
   traversals are joined with upto_join3. *)
lemma opt_bst_correct: "inorder (opt_bst (i,j)) = [i..j]"
apply(induction "(i,j)" arbitrary: i j rule: opt_bst.induct)
by (force simp: opt_bst.simps upto_Nil upto_join3 intro: P_arg_min_list)

(* The cost of the tree built by opt_bst equals min_wpl.  The proof follows
   the recursion of min_wpl. *)
lemma wpl_opt_bst: "wpl i j (opt_bst (i,j)) = min_wpl i j"
proof(induction i j rule: min_wpl.induct)
  case (1 i j)
  show ?case
  proof cases
    (* empty range: opt_bst yields Leaf and min_wpl yields 0 *)
    assume "i > j" thus ?thesis by(simp add: min_wpl.simps opt_bst.simps)
  next
    assume *[arith]: "¬ i > j"
    (* ?ts: one candidate tree per root k; ?M: the corresponding min_wpl costs *)
    let ?ts = "[⟨opt_bst (i,k-1), k, opt_bst (k+1,j)⟩. k <- [i..j]]"
    let ?M = "((λk. min_wpl i (k-1) + min_wpl (k+1) j + W i j) ` {i..j})"
    have "?ts ≠ []" by (auto simp add: upto.simps)
    have "wpl i j (opt_bst (i,j)) = wpl i j (arg_min_list (wpl i j) ?ts)" by (simp add: opt_bst.simps)
    also have "… = Min (wpl i j ` (set ?ts))" by(rule f_arg_min_list_f[OF ‹?ts ≠ []›])
    also have "… = Min ?M"
    proof (rule arg_cong[where f=Min])
      (* the induction hypothesis rewrites the cost of each candidate's
         subtrees into the matching min_wpl values *)
      show "wpl i j ` (set ?ts) = ?M"
        by (fastforce simp: Bex_def image_iff 1[OF *])
    qed
    also have "… = min_wpl i j" by (simp add: min_wpl.simps min_list_Min upto_Nil)
    finally show ?thesis .
  qed
qed

(* Optimality: no tree over the key range [i..j] has a smaller wpl than the
   tree built by opt_bst.  The left-hand side is first rewritten to
   min_wpl i j, after which the claim is an instance of min_wpl_minimal. *)
lemma opt_bst_is_optimal:
  "inorder t = [i..j] ⟹ wpl i j (opt_bst (i,j)) ≤ wpl i j t"
  unfolding wpl_opt_bst
  by (rule min_wpl_minimal)

end

subsubsection ‹Test Case›

(* diagnostic only: show the code equations in effect for min_wpl *)
code_thms min_wpl

(* Sample run of the heap-memoized version.  Outside the contexts above,
   min_wplh takes the formerly fixed W and n as leading arguments, so this
   call uses W = (λi j. nat(i+j)), n = 2, i = 3, j = 3.
   NOTE(review): i = j = 3 lies beyond the bound (int n, int n) = (2, 2)
   given to the memo table - confirm that this is intended. *)
definition "min_wpl_test = min_wplh (λi j. nat(i+j)) 2 3 3"

(* generate an ML structure Test that contains min_wpl_test *)
code_reflect Test functions min_wpl_test

(* run the generated code *)
ML ‹Test.min_wpl_test ()›

end