(* Theory Interval_Tree *)

(*
  File: Interval_Tree.thy
  Author: Bohua Zhan
*)

section ‹Interval tree›

theory Interval_Tree
  imports Lists_Ex Interval
begin

text ‹
  Functional version of interval tree. This is an augmented data
  structure on top of regular binary search trees (see BST.thy).
  See @{cite ‹Section 14.3› in "cormen2009introduction"} for a reference.
›

subsection ‹Definition of an interval tree›

(* Interval tree: either the empty tree Tip, or a Node holding a left
   subtree (lsub), an indexed interval (val), a cached nat field (tmax)
   and a right subtree (rsub). The selector tmax is fixed to 0 on Tip. *)
datatype interval_tree =
   Tip
 | Node (lsub: interval_tree) (val: "nat idx_interval") (tmax: nat) (rsub: interval_tree)
where
  "tmax Tip = 0"

(* Register the generated datatype facts as auto2 proof steps. The ML
   argument of setup must be delimited, hence the cartouches. *)
setup ‹add_resolve_prfstep @{thm interval_tree.distinct(1)}›
setup ‹fold add_rewrite_rule @{thms interval_tree.sel}›
setup ‹add_forward_prfstep @{thm interval_tree.collapse}›
setup ‹add_var_induct_rule @{thm interval_tree.induct}›

subsection ‹Inorder traversal, and set of elements of a tree›

(* Inorder traversal: elements of the left subtree, then the node's
   interval, then elements of the right subtree. The tmax field is ignored. *)
fun in_traverse :: "interval_tree ⇒ nat idx_interval list" where
  "in_traverse Tip = []"
| "in_traverse (Node l it m r) = in_traverse l @ it # in_traverse r"
setup ‹fold add_rewrite_rule @{thms in_traverse.simps}›

(* Set of intervals stored in a tree: the node's interval together with
   the sets of both subtrees. *)
fun tree_set :: "interval_tree ⇒ nat idx_interval set" where
  "tree_set Tip = {}"
| "tree_set (Node l it m r) = {it} ∪ tree_set l ∪ tree_set r"
setup ‹fold add_rewrite_rule @{thms tree_set.simps}›

(* Search-tree ordering: every element of the left subtree is strictly
   below the node's interval, every element of the right subtree is
   strictly above it, and both subtrees are themselves sorted. *)
fun tree_sorted :: "interval_tree ⇒ bool" where
  "tree_sorted Tip = True"
| "tree_sorted (Node l it m r) = ((∀x∈tree_set l. x < it) ∧ (∀x∈tree_set r. it < x)
                                   ∧ tree_sorted l ∧ tree_sorted r)"
setup ‹fold add_rewrite_rule @{thms tree_sorted.simps}›

(* Sortedness of a node passes down to both of its subtrees. *)
lemma tree_sorted_lr [forward]:
  "tree_sorted (Node l it m r) ⟹ tree_sorted l ∧ tree_sorted r" by auto2

(* In a sorted node, every element of the left subtree is below the root. *)
lemma tree_sortedD1 [forward]:
  "tree_sorted (Node l it m r) ⟹ x ∈ tree_set l ⟹ x < it" by auto2

(* In a sorted node, every element of the right subtree is above the root. *)
lemma tree_sortedD2 [forward]:
  "tree_sorted (Node l it m r) ⟹ x ∈ tree_set r ⟹ x > it" by auto2

(* The element set of a tree equals the set of its inorder traversal.
   Proved by induction on the tree; registered as a rewrite rule. *)
lemma inorder_preserve_set [rewrite]:
  "tree_set t = set (in_traverse t)"
@proof @induct t @qed

(* A tree is sorted exactly when its inorder traversal is strictly sorted.
   Stated as an equivalence so that it can serve as a rewrite rule. *)
lemma inorder_sorted [rewrite]:
  "tree_sorted t ⟷ strict_sorted (in_traverse t)"
@proof @induct t @qed

text ‹Use definition in terms of in\_traverse from now on.›
(* Unregister the defining equations of tree_set and tree_sorted as proof
   steps; the ML argument of setup must be delimited by a cartouche. *)
setup ‹fold del_prfstep_thm (@{thms tree_set.simps} @ @{thms tree_sorted.simps})›

subsection ‹Invariant on the maximum›

(* Maximum of the high endpoint of an indexed interval and two further
   naturals (used with the tmax values of the two subtrees). *)
definition max3 :: "nat idx_interval ⇒ nat ⇒ nat ⇒ nat" where [rewrite]:
  "max3 it b c = max (high (int it)) (max b c)"

(* Invariant on the cached maximum: at every node, the stored value m
   equals max3 of the node's interval and the tmax of both subtrees,
   and the invariant holds recursively in both subtrees. *)
fun tree_max_inv :: "interval_tree ⇒ bool" where
  "tree_max_inv Tip = True"
| "tree_max_inv (Node l it m r) ⟷ (tree_max_inv l ∧ tree_max_inv r ∧ m = max3 it (tmax l) (tmax r))"
setup ‹fold add_rewrite_rule @{thms tree_max_inv.simps}›

(* Under the max invariant, tmax bounds the high endpoint of every
   interval stored in the tree from above. *)
lemma tree_max_is_max [resolve]:
  "tree_max_inv t ⟹ it ∈ tree_set t ⟹ high (int it) ≤ tmax t"
@proof @induct t @qed

(* Under the max invariant, tmax of a non-empty tree is attained: some
   stored interval has exactly that high endpoint. The induction step
   splits on which of the two subtrees are empty. *)
lemma tmax_exists [backward]:
  "tree_max_inv t ⟹ t ≠ Tip ⟹ ∃p∈tree_set t. high (int p) = tmax t"
@proof @induct t @with
  @subgoal "t = Node l it m r"
    @case "l = Tip" @with @case "r = Tip" @end
    @case "r = Tip"
  @endgoal @end
@qed

text ‹For insertion›
(* With both extra arguments 0, max3 reduces to the interval's high endpoint. *)
lemma max3_insert [rewrite]: "max3 it 0 0 = high (int it)" by auto2

(* Stop unfolding max3 from here on; the ML argument of setup must be
   delimited by a cartouche. *)
setup ‹del_prfstep_thm @{thm max3_def}›

subsection ‹Condition on the values›

(* Every element stored in the tree wraps a well-formed interval
   (is_interval holds of its int component). *)
definition tree_interval_inv :: "interval_tree ⇒ bool" where [rewrite]:
  "tree_interval_inv t ⟷ (∀p∈tree_set t. is_interval (int p))"

(* Full interval-tree invariant: sortedness, the cached-maximum
   invariant, and well-formedness of all stored intervals. *)
definition is_interval_tree :: "interval_tree ⇒ bool" where [rewrite]:
  "is_interval_tree t ⟷ (tree_sorted t ∧ tree_max_inv t ∧ tree_interval_inv t)"

(* The interval-tree invariant of a node passes down to both subtrees. *)
lemma is_interval_tree_lr [forward]:
  "is_interval_tree (Node l x m r) ⟹ is_interval_tree l ∧ is_interval_tree r" by auto2

subsection ‹Insertion on trees›

(* Insert x into a tree. An empty tree becomes a single node whose cached
   maximum is the high endpoint of x. If x is already at the root, the
   tree is returned unchanged. Otherwise x is inserted into the left or
   right subtree according to the order, and the cached maximum of the
   root is recomputed from the updated subtree. *)
fun insert :: "nat idx_interval ⇒ interval_tree ⇒ interval_tree" where
  "insert x Tip = Node Tip x (high (int x)) Tip"
| "insert x (Node l y m r) =
    (if x = y then Node l y m r
     else if x < y then
       let l' = insert x l in
           Node l' y (max3 y (tmax l') (tmax r)) r
     else
       let r' = insert x r in
           Node l y (max3 y (tmax l) (tmax r')) r')"
setup ‹fold add_rewrite_rule @{thms insert.simps}›

(* On a sorted tree, insertion corresponds to ordered_insert on the
   inorder traversal. *)
lemma tree_insert_in_traverse [rewrite]:
  "tree_sorted t ⟹ in_traverse (insert x t) = ordered_insert x (in_traverse t)"
@proof @induct t @qed

(* Insertion preserves the cached-maximum invariant. *)
lemma tree_insert_max_inv [forward]:
  "tree_max_inv t ⟹ tree_max_inv (insert x t)"
@proof @induct t @qed

text ‹Correctness of insertion.›
(* Inserting a well-formed interval preserves the full interval-tree invariant. *)
theorem tree_insert_all_inv [forward]:
  "is_interval_tree t ⟹ is_interval (int it) ⟹ is_interval_tree (insert it t)" by auto2

(* On a sorted tree, insertion adds exactly the inserted element to the
   element set. *)
theorem tree_insert_on_set [rewrite]:
  "tree_sorted t ⟹ tree_set (insert it t) = {it} ∪ tree_set t" by auto2

subsection ‹Deletion on trees›

(* Remove the leftmost element of a non-empty tree. Returns the removed
   element paired with the remaining tree; cached maxima along the left
   spine are recomputed. The result on Tip is left undefined, so only the
   Node equation is registered as a rewrite rule, and "t ≠ Tip" is
   registered as the well-formedness condition for the term del_min t. *)
fun del_min :: "interval_tree ⇒ nat idx_interval × interval_tree" where
  "del_min Tip = undefined"
| "del_min (Node lt v m rt) =
   (if lt = Tip then (v, rt) else
    let lt' = snd (del_min lt) in
    (fst (del_min lt), Node lt' v (max3 v (tmax lt') (tmax rt)) rt))"
setup ‹add_rewrite_rule @{thm del_min.simps(2)}›
setup ‹register_wellform_data ("del_min t", ["t ≠ Tip"])›

(* For a non-empty tree, the element returned by del_min followed by the
   traversal of the remaining tree is the traversal of the original tree.
   Registered as a forward step that fires only when the term
   in_traverse (snd (del_min ?t)) is present. *)
lemma delete_min_del_hd:
  "t ≠ Tip ⟹ fst (del_min t) # in_traverse (snd (del_min t)) = in_traverse t"
@proof @induct t @qed
setup ‹add_forward_prfstep_cond @{thm delete_min_del_hd} [with_term "in_traverse (snd (del_min ?t))"]›

(* Removing the minimum of a non-empty tree preserves the cached-maximum
   invariant. *)
lemma delete_min_max_inv [forward_arg]:
  "tree_max_inv t ⟹ t ≠ Tip ⟹ tree_max_inv (snd (del_min t))"
@proof @induct t @qed

(* For a non-empty tree, the removed minimum together with the elements
   of the remaining tree gives back the original element set. Registered
   as a forward step that fires only when the term
   tree_set (snd (del_min ?t)) is present. *)
lemma delete_min_on_set:
  "t ≠ Tip ⟹ {fst (del_min t)} ∪ tree_set (snd (del_min t)) = tree_set t" by auto2
setup ‹add_forward_prfstep_cond @{thm delete_min_on_set} [with_term "tree_set (snd (del_min ?t))"]›

(* Removing the minimum of a non-empty tree keeps all stored intervals
   well-formed. *)
lemma delete_min_interval_inv [forward_arg]:
  "tree_interval_inv t ⟹ t ≠ Tip ⟹ tree_interval_inv (snd (del_min t))" by auto2

(* Removing the minimum of a non-empty tree preserves the full
   interval-tree invariant. *)
lemma delete_min_all_inv [forward_arg]:
  "is_interval_tree t ⟹ t ≠ Tip ⟹ is_interval_tree (snd (del_min t))" by auto2

(* Delete the root element of a non-empty tree. If one subtree is empty,
   the other one is the result. Otherwise the minimum of the right subtree
   (obtained with del_min) becomes the new root, and the cached maximum is
   recomputed for it. The result on Tip is left undefined, so only the
   Node equation is registered as a rewrite rule. *)
fun delete_elt_tree :: "interval_tree ⇒ interval_tree" where
  "delete_elt_tree Tip = undefined"
| "delete_elt_tree (Node lt x m rt) =
    (if lt = Tip then rt else if rt = Tip then lt else
     let x' = fst (del_min rt);
         rt' = snd (del_min rt);
         m' = max3 x' (tmax lt) (tmax rt') in
       Node lt (fst (del_min rt)) m' rt')"
setup ‹add_rewrite_rule @{thm delete_elt_tree.simps(2)}›

(* Deleting the root of a node leaves the traversal of the left subtree
   followed by the traversal of the right subtree. *)
lemma delete_elt_in_traverse [rewrite]:
  "in_traverse (delete_elt_tree (Node lt x m rt)) = in_traverse lt @ in_traverse rt" by auto2

(* Deleting the root of a non-empty tree preserves the cached-maximum
   invariant. *)
lemma delete_elt_max_inv [forward_arg]:
  "tree_max_inv t ⟹ t ≠ Tip ⟹ tree_max_inv (delete_elt_tree t)" by auto2

(* Deleting the root of a node leaves exactly the elements of its two
   subtrees. NOTE(review): the premise mentions t, which does not occur
   in the conclusion; it is kept as-is so that the registered rewrite
   rule is unchanged. *)
lemma delete_elt_on_set [rewrite]:
  "t ≠ Tip ⟹ tree_set (delete_elt_tree (Node lt x m rt)) = tree_set lt ∪ tree_set rt" by auto2

(* Deleting the root of a non-empty tree keeps all stored intervals
   well-formed. *)
lemma delete_elt_interval_inv [forward_arg]:
  "tree_interval_inv t ⟹ t ≠ Tip ⟹ tree_interval_inv (delete_elt_tree t)" by auto2

(* Deleting the root of a non-empty tree preserves the full interval-tree
   invariant. *)
lemma delete_elt_all_inv [forward_arg]:
  "is_interval_tree t ⟹ t ≠ Tip ⟹ is_interval_tree (delete_elt_tree t)" by auto2

(* Delete x from a tree. The empty tree is returned unchanged. If x is at
   the root, the root is removed with delete_elt_tree. Otherwise deletion
   continues in the left or right subtree according to the order, and the
   cached maximum of the root is recomputed from the updated subtree. *)
fun delete :: "nat idx_interval ⇒ interval_tree ⇒ interval_tree" where
  "delete x Tip = Tip"
| "delete x (Node l y m r) =
    (if x = y then delete_elt_tree (Node l y m r)
     else if x < y then
       let l' = delete x l;
           m' = max3 y (tmax l') (tmax r) in Node l' y m' r
     else
       let r' = delete x r;
           m' = max3 y (tmax l) (tmax r') in Node l y m' r')"
setup ‹fold add_rewrite_rule @{thms delete.simps}›

(* On a sorted tree, deletion corresponds to remove_elt_list on the
   inorder traversal. *)
lemma tree_delete_in_traverse [rewrite]:
  "tree_sorted t ⟹ in_traverse (delete x t) = remove_elt_list x (in_traverse t)"
@proof @induct t @qed

(* Deletion preserves the cached-maximum invariant. *)
lemma tree_delete_max_inv [forward]:
  "tree_max_inv t ⟹ tree_max_inv (delete x t)"
@proof @induct t @qed

text ‹Correctness of deletion.›
(* Deletion preserves the full interval-tree invariant. The proof first
   records that the element set can only shrink, so every remaining
   interval was already present in t. *)
theorem tree_delete_all_inv [forward]:
  "is_interval_tree t ⟹ is_interval_tree (delete x t)"
@proof @have "tree_set (delete x t) ⊆ tree_set t" @qed

(* On a sorted tree, deletion removes exactly the deleted element from
   the element set. *)
theorem tree_delete_on_set [rewrite]:
  "tree_sorted t ⟹ tree_set (delete x t) = tree_set t - {x}" by auto2

subsection ‹Search on interval trees›

(* Search for a stored interval overlapping x. Returns True if the root's
   interval overlaps x. Otherwise it descends into the left subtree when
   that subtree is non-empty and its cached maximum reaches at least
   low x, and into the right subtree in all other cases. *)
fun search :: "interval_tree ⇒ nat interval ⇒ bool" where
  "search Tip x = False"
| "search (Node l y m r) x =
   (if is_overlap (int y) x then True
    else if l ≠ Tip ∧ tmax l ≥ low x then search l x
    else search r x)"
setup ‹fold add_rewrite_rule @{thms search.simps}›

text ‹Correctness of search›
(* On a valid interval tree and a well-formed query interval, search
   returns True exactly when some stored interval overlaps the query.
   In the Node case the proof splits on the branches of search; when the
   left subtree is chosen, it picks an element p of l attaining tmax l
   and splits on whether p overlaps x. *)
theorem search_correct [rewrite]:
  "is_interval_tree t ⟹ is_interval x ⟹ search t x ⟷ has_overlap (tree_set t) x"
@proof
  @induct t @with
    @subgoal "t = Node l y m r"
      @let "t = Node l y m r"
      @case "is_overlap (int y) x"
      @case "l ≠ Tip ∧ tmax l ≥ low x" @with
        @obtain "p∈tree_set l" where "high (int p) = tmax l"
        @case "is_overlap (int p) x"
      @end
      @case "l = Tip"
    @endgoal
  @end
@qed

end