(* Theory Splay_Tree *)

section "Splay Tree"

theory Splay_Tree
imports
  "HOL-Library.Tree"
  "HOL-Data_Structures.Set_Specs"
  "HOL-Data_Structures.Cmp"
begin

(* Remove the cons equation of sorted_wrt (its second defining equation) from
   the simpset.  Presumably this is so that the proofs below reason about
   sortedness through the list lemmas (sorted_lems, sorted_Cons_le, ...)
   instead of unfolding sorted_wrt -- TODO confirm. *)
declare sorted_wrt.simps(2)[simp del]

text‹Splay trees were invented by Sleator and Tarjan~@{cite "SleatorT-JACM85"}.›

subsection "Function ‹splay›"

(* The splay operation of Sleator and Tarjan.
   For a sorted tree t, "splay x t" rearranges t without changing its inorder
   traversal (see inorder_splay).  Afterwards x is at the root exactly when
   x occurs in t (see splay_elemsD).

   The equations are organised by the position of x relative to the root
   and its children:
   - x is already at the root (equation 2), or at a child (equations 3
     and 9): at most one rotation;
   - the search path reaches a Leaf: the last node visited is rotated to
     the root, and there is no recursion;
   - otherwise x has to be searched in a grandchild subtree.  That subtree
     is splayed recursively (it must be non-empty), followed by a zig-zig
     or zig-zag step.

   The result of each recursive call is destructured with fresh pattern
   variables (A1/A2, B1/B2, C1/C2, D1/D2).  They therefore never shadow the
   subtrees named on the left-hand side. *)
function splay :: "'a::linorder ⇒ 'a tree ⇒ 'a tree" where
"splay x Leaf = Leaf" |
"splay x (Node AB x CD) = Node AB x CD" |
"x<b ⟹ splay x (Node (Node A x B) b CD) = Node A x (Node B b CD)" |
"x<b ⟹ splay x (Node Leaf b CD) = Node Leaf b CD" |
"x<a ⟹ x<b ⟹ splay x (Node (Node Leaf a B) b CD) = Node Leaf a (Node B b CD)" |
"x<a ⟹ x<b ⟹ A ≠ Leaf ⟹
 splay x (Node (Node A a B) b CD) =
 (case splay x A of Node A1 a' A2 ⇒ Node A1 a' (Node A2 a (Node B b CD)))" |
"a<x ⟹ x<b ⟹ splay x (Node (Node A a Leaf) b CD) = Node A a (Node Leaf b CD)" |
"a<x ⟹ x<b ⟹ B ≠ Leaf ⟹
 splay x (Node (Node A a B) b CD) =
 (case splay x B of Node B1 b' B2 ⇒ Node (Node A a B1) b' (Node B2 b CD))" |
"b<x ⟹ splay x (Node AB b (Node C x D)) = Node (Node AB b C) x D" |
"b<x ⟹ splay x (Node AB b Leaf) = Node AB b Leaf" |
"b<x ⟹ x<c ⟹ C ≠ Leaf ⟹
 splay x (Node AB b (Node C c D)) =
 (case splay x C of Node C1 c' C2 ⇒ Node (Node AB b C1) c' (Node C2 c D))" |
"b<x ⟹ x<c ⟹ splay x (Node AB b (Node Leaf c D)) = Node (Node AB b Leaf) c D" |
"b<x ⟹ c<x ⟹ splay x (Node AB b (Node C c Leaf)) = Node (Node AB b C) c Leaf" |
"a<x ⟹ c<x ⟹ D ≠ Leaf ⟹
 splay x (Node AB a (Node C c D)) =
 (case splay x D of Node D1 d' D2 ⇒ Node (Node (Node AB a C) c D1) d' D2)"
(* Pattern completeness and compatibility of the guarded equations. *)
apply(atomize_elim)
apply(auto)
(* 1 subgoal *)
apply (subst (asm) neq_Leaf_iff)
apply(auto)
apply (metis tree.exhaust le_less_linear less_linear)+
done

termination splay
by lexicographic_order

(* A single equation for splay on non-empty trees, organised by cmp.
   The recursive calls are guarded by explicit emptiness tests.  It agrees
   with the defining equations of splay, proved by case splitting on trees.
   Each "if ... else case ..." arm is parenthesised.  This keeps the
   following "| GT ⇒" branch from being read as a branch of the inner,
   single-pattern case. *)
lemma splay_code: "splay x (Node AB b CD) =
  (case cmp x b of
   EQ ⇒ Node AB b CD |
   LT ⇒ (case AB of
          Leaf ⇒ Node AB b CD |
          Node A a B ⇒
            (case cmp x a of
               EQ ⇒ Node A a (Node B b CD) |
               LT ⇒ (if A = Leaf then Node A a (Node B b CD)
                     else case splay x A of
                       Node A1 x' A2 ⇒ Node A1 x' (Node A2 a (Node B b CD))) |
               GT ⇒ (if B = Leaf then Node A a (Node B b CD)
                     else case splay x B of
                       Node B1 x' B2 ⇒ Node (Node A a B1) x' (Node B2 b CD)))) |
   GT ⇒ (case CD of
          Leaf ⇒ Node AB b CD |
          Node C c D ⇒
            (case cmp x c of
               EQ ⇒ Node (Node AB b C) c D |
               LT ⇒ (if C = Leaf then Node (Node AB b C) c D
                     else case splay x C of
                       Node C1 x' C2 ⇒ Node (Node AB b C1) x' (Node C2 c D)) |
               GT ⇒ (if D = Leaf then Node (Node AB b C) c D
                     else case splay x D of
                       Node D1 x' D2 ⇒ Node (Node (Node AB b C) c D1) x' D2))))"
by(auto split!: tree.split)

(* is_root x t holds iff t is a Node whose root label is x. *)
definition is_root :: "'a ⇒ 'a tree ⇒ bool" where
"is_root x t = (case t of Leaf ⇒ False | Node l a r ⇒ x = a)"

(* Membership test: splay t at x and check whether x ended up at the root.
   The splayed tree itself is discarded. *)
definition isin :: "'a::linorder tree ⇒ 'a ⇒ bool" where
"isin t x = is_root x (splay x t)"

(* The empty splay tree, written with the ⟨⟩ notation for Leaf. *)
definition empty :: "'a tree" where
"empty = ⟨⟩"

(* Free the unqualified name insert.  From here on, HOL's set insertion is
   only accessible as Set.insert. *)
hide_const (open) insert

(* insert x t: splay t at x.
   - If x is then at the root (EQ), the tree is returned unchanged.
   - Otherwise x becomes the new root, and the splayed tree is split around
     it according to cmp x a.
   The inner case is parenthesised so that its branches cannot be confused
   with branches of the outer case. *)
fun insert :: "'a::linorder ⇒ 'a tree ⇒ 'a tree" where
"insert x t =
  (if t = Leaf then Node Leaf x Leaf
   else case splay x t of
     Node l a r ⇒
      (case cmp x a of
        EQ ⇒ Node l a r |
        LT ⇒ Node l x (Node Leaf a r) |
        GT ⇒ Node (Node l a Leaf) x r))"


(* splay_max t: splay the rightmost (largest) element of t to the root.
   The resulting right subtree is Leaf (see inorder_splay_maxD). *)
fun splay_max :: "'a tree ⇒ 'a tree" where
"splay_max Leaf = Leaf" |
"splay_max (Node A a Leaf) = Node A a Leaf" |
"splay_max (Node A a (Node B b CD)) =
  (if CD = Leaf then Node (Node A a B) b Leaf
   else case splay_max CD of
     Node C c D ⇒ Node (Node (Node A a B) b C) c D)"

(* splay_max as a single equation, by case analysis on t and its right
   child.  Presumably intended for code generation -- TODO confirm. *)
lemma splay_max_code: "splay_max t = (case t of
  Leaf ⇒ t |
  Node la a ra ⇒ (case ra of
    Leaf ⇒ t |
    Node lb b rb ⇒
      (if rb=Leaf then Node (Node la a lb) b rb
       else case splay_max rb of
              Node lc c rc ⇒ Node (Node (Node la a lb) b lc) c rc)))"
by(auto simp: neq_Leaf_iff split: tree.split)

(* delete x t: first splay t at x.
   - If x is not at the root afterwards, then for sorted t it was not in t
     (see splay_elemsD), and the splayed tree is returned.
   - Otherwise the root is removed.  An empty left subtree l leaves just r.
   - Else the maximum of l is splayed to the top of l.  Its right subtree r'
     is Leaf (inorder_splay_maxD), so r can take its place. *)
definition delete :: "'a::linorder ⇒ 'a tree ⇒ 'a tree" where
"delete x t =
  (if t = Leaf then Leaf
   else case splay x t of Node l a r ⇒
     if x ≠ a then Node l a r
     else if l = Leaf then r else case splay_max l of Node l' m r' ⇒ Node l' m r)"



subsection "Functional Correctness Proofs I"

text ‹This subsection follows the automated method by Nipkow~@{cite "Nipkow-ITP16"}.›

(* Splaying yields the empty tree only for the empty tree. *)
lemma splay_Leaf_iff[simp]: "splay a t = Leaf ⟷ t = Leaf"
apply(induction a t rule: splay.induct)
apply(auto split: tree.splits)
done

(* splay_max yields the empty tree only for the empty tree. *)
lemma splay_max_Leaf_iff[simp]: "splay_max t = Leaf ⟷ t = Leaf"
apply(induction t rule: splay_max.induct)
apply(auto split: tree.splits)
done


subsubsection "Verification of @{const isin}"

(* For sorted t, splaying at x puts x at the root exactly when x occurs in t. *)
lemma splay_elemsD:
  "splay x t = Node l a r ⟹ sorted(inorder t) ⟹
  x ∈ set (inorder t) ⟷ x=a"
by(induction x t arbitrary: l a r rule: splay.induct)
  (auto simp: isin_simps ball_Un split: tree.splits)

(* On a sorted tree, isin coincides with membership in the inorder list. *)
lemma isin_set: "sorted(inorder t) ⟹ isin t x = (x ∈ set (inorder t))"
by (auto simp: isin_def is_root_def dest: splay_elemsD split: tree.splits)


subsubsection "Verification of @{const insert}"

(* Splaying only rotates: the inorder traversal is unchanged. *)
lemma inorder_splay: "inorder(splay x t) = inorder t"
apply(induction x t rule: splay.induct)
apply(auto simp: neq_Leaf_iff split: tree.split)
done

(* After splaying a sorted tree at x, x fits between the left and right
   subtrees of the new root (no assumption that x occurs in t). *)
lemma sorted_splay:
  "sorted(inorder t) ⟹ splay x t = Node l a r ⟹
  sorted(inorder l @ x # inorder r)"
unfolding inorder_splay[of x t, symmetric]
by(induction x t arbitrary: l a r rule: splay.induct)
  (auto simp: sorted_lems sorted_Cons_le sorted_snoc_le split: tree.splits)

(* On sorted trees, insert implements ins_list on the inorder traversal. *)
lemma inorder_insert:
  "sorted(inorder t) ⟹ inorder(insert x t) = ins_list x (inorder t)"
using inorder_splay[of x t, symmetric] sorted_splay[of t x]
by(auto simp: ins_list_simps ins_list_Cons ins_list_snoc neq_Leaf_iff split: tree.split)


subsubsection "Verification of @{const delete}"

(* splay_max on a sorted tree puts the last inorder element at the root,
   and leaves an empty right subtree. *)
lemma inorder_splay_maxD:
  "splay_max t = Node l a r ⟹ sorted(inorder t) ⟹
  inorder l @ [a] = inorder t ∧ r = Leaf"
by(induction t arbitrary: l a r rule: splay_max.induct)
  (auto simp: sorted_lems split: tree.splits if_splits)

(* On sorted trees, delete implements del_list on the inorder traversal. *)
lemma inorder_delete:
  "sorted(inorder t) ⟹ inorder(delete x t) = del_list x (inorder t)"
using inorder_splay[of x t, symmetric] sorted_splay[of t x]
by (auto simp: del_list_simps del_list_sorted_app delete_def
  del_list_notin_Cons inorder_splay_maxD split: tree.splits)


subsubsection "Overall Correctness"

(* Splay trees implement the ordered-set specification Set_by_Ordered from
   HOL-Data_Structures.Set_Specs, with the trivial invariant λ_. True. *)
interpretation splay: Set_by_Ordered
where empty = empty and isin = isin and insert = insert
and delete = delete and inorder = inorder and inv = "λ_. True"
proof (standard, goal_cases)
  (* correctness of isin, via isin_set *)
  case 2 thus ?case by(simp add: isin_set)
next
  (* correctness of insert, via inorder_insert; insert.simps is removed,
     presumably so that simp does not unfold insert before inorder_insert
     can be used *)
  case 3 thus ?case by(simp add: inorder_insert del: insert.simps)
next
  (* correctness of delete, via inorder_delete *)
  case 4 thus ?case by(simp add: inorder_delete)
qed (auto simp: empty_def)  (* remaining goals, including the one about empty *)

text ‹Corollaries:›

(* Splaying preserves the BST property (via sortedness of the inorder list). *)
lemma bst_splay: "bst t ⟹ bst (splay x t)"
by (simp add: bst_iff_sorted_wrt_less inorder_splay)

(* insert preserves the BST property; a corollary of the interpretation. *)
lemma bst_insert: "bst t ⟹ bst(insert x t)"
using splay.invar_insert[of t x] by (simp add: bst_iff_sorted_wrt_less splay.invar_def)

(* delete preserves the BST property; a corollary of the interpretation. *)
lemma bst_delete: "bst t ⟹ bst(delete x t)"
using splay.invar_delete[of t x] by (simp add: bst_iff_sorted_wrt_less splay.invar_def)

(* After splaying a BST at a, everything in the new left subtree is below a. *)
lemma splay_bstL: "bst t ⟹ splay a t = Node l e r ⟹ x ∈ set_tree l ⟹ x < a"
by (metis bst_iff_sorted_wrt_less list.set_intros(1) set_inorder sorted_splay sorted_wrt_append)

(* After splaying a BST at a, everything in the new right subtree is above a. *)
lemma splay_bstR: "bst t ⟹ splay a t = Node l e r ⟹ x ∈ set_tree r ⟹ a < x"
by (metis bst_iff_sorted_wrt_less sorted_Cons_iff set_inorder sorted_splay sorted_wrt_append)


subsubsection "Size lemmas"

(* Splaying preserves the number of nodes. *)
lemma size_splay[simp]: "size (splay a t) = size t"
apply(induction a t rule: splay.induct)
apply auto
 (* goals left by auto are closed by case splitting on trees *)
 apply(force split: tree.split)+
done

(* The two subtrees of a splayed tree hold all nodes except the root. *)
lemma size_if_splay: "splay a t = Node l u r ⟹ size t = size l + size r + 1"
by (metis One_nat_def size_splay tree.size(4))

(* Splaying a non-empty tree yields a Node. *)
lemma splay_not_Leaf: "t ≠ Leaf ⟹ ∃l x r. splay a t = Node l x r"
by (metis neq_Leaf_iff splay_Leaf_iff)

(* splay_max preserves the number of nodes. *)
lemma size_splay_max: "size(splay_max t) = size t"
apply(induction t rule: splay_max.induct)
  (* the two non-recursive equations: Leaf and Node A a Leaf *)
  apply(simp)
 apply(simp)
(* the recursive equation: split on the result of the recursive call *)
apply(clarsimp split: tree.split)
done

(* The two subtrees after splay_max hold all nodes except the root. *)
lemma size_if_splay_max: "splay_max t = Node l u r ⟹ size t = size l + size r + 1"
by (metis One_nat_def size_splay_max tree.size(4))

(*
subsection "Functional Correctness Proofs II"

text ‹This subsection follows the traditional approach, is less automated
and is retained more for historic reasons.›

lemma set_splay: "set_tree(splay a t) = set_tree t"
proof(induction a t rule: splay.induct)
  case (6 a)
  with splay_not_Leaf[OF 6(3), of a] show ?case by(fastforce)
next
  case (8 _ a)
  with splay_not_Leaf[OF 8(3), of a] show ?case by(fastforce)
next
  case (11 _ a)
  with splay_not_Leaf[OF 11(3), of a] show ?case by(fastforce)
next
  case (14 _ a)
  with splay_not_Leaf[OF 14(3), of a] show ?case by(fastforce)
qed auto

lemma splay_bstL: "bst t ⟹ splay a t = Node l e r ⟹ x ∈ set_tree l ⟹ x < a"
apply(induction a t arbitrary: l x r rule: splay.induct)
apply (auto split: tree.splits)
apply auto
done

lemma splay_bstR: "bst t ⟹ splay a t = Node l e r ⟹ x ∈ set_tree r ⟹ a < x"
apply(induction a t arbitrary: l e x r rule: splay.induct)
apply auto
apply (fastforce split!: tree.splits)+
done

lemma bst_splay: "bst t ⟹ bst(splay a t)"
proof(induction a t rule: splay.induct)
  case (6 a _ _ ll)
  with splay_not_Leaf[OF 6(3), of a] set_splay[of a ll,symmetric]
  show ?case by (fastforce)
next
  case (8 _ a _ t)
  with splay_not_Leaf[OF 8(3), of a] set_splay[of a t,symmetric]
  show ?case by fastforce
next
  case (11 _ a _ t)
  with splay_not_Leaf[OF 11(3), of a] set_splay[of a t,symmetric]
  show ?case by fastforce
next
  case (14 _ a _ t)
  with splay_not_Leaf[OF 14(3), of a] set_splay[of a t,symmetric]
  show ?case by fastforce
qed auto

lemma splay_to_root: "⟦ bst t;  splay a t = t' ⟧ ⟹
  a ∈ set_tree t ⟷ (∃l r. t' = Node l a r)"
proof(induction a t arbitrary: t' rule: splay.induct)
  case (6 a)
  with splay_not_Leaf[OF 6(3), of a] show ?case by auto
next
  case (8 _ a)
  with splay_not_Leaf[OF 8(3), of a] show ?case by auto
next
  case (11 _ a)
  with splay_not_Leaf[OF 11(3), of a] show ?case by auto
next
  case (14 _ a)
  with splay_not_Leaf[OF 14(3), of a] show ?case by auto
qed fastforce+


subsubsection "Verification of Is-in Test"

text‹To test if an element ‹a› is in ‹t›, first perform
@{term"splay a t"}, then check if the root is ‹a›. One could
put this into one function that returns both a new tree and the test result.›

lemma is_root_splay: "bst t ⟹ is_root a (splay a t) ⟷ a ∈ set_tree t"
by(auto simp add: is_root_def splay_to_root split: tree.split)


subsubsection "Verification of @{const insert}"

lemma set_insert: "set_tree(insert a t) = Set.insert a (set_tree t)"
apply(cases t)
 apply simp
using set_splay[of a t]
by(simp split: tree.split) fastforce

lemma bst_insert: "bst t ⟹ bst(insert a t)"
apply(cases t)
 apply simp
using bst_splay[of t a] splay_bstL[of t a] splay_bstR[of t a]
by(auto simp: ball_Un split: tree.split)


subsubsection "Verification of ‹splay_max›"

lemma set_splay_max: "set_tree(splay_max t) = set_tree t"
apply(induction t rule: splay_max.induct)
   apply(simp)
  apply(simp)
apply(force split: tree.split)
done

lemma bst_splay_max: "bst t ⟹ bst (splay_max t)"
proof(induction t rule: splay_max.induct)
  case (3 l b rl c rr)
  { fix rrl' d' rrr'
    have "splay_max rr = Node rrl' d' rrr'
       ⟹ ∀x ∈ set_tree(Node rrl' d' rrr'). c < x" 
      using "3.prems" set_splay_max[of rr]
      by (clarsimp split: tree.split simp: ball_Un)
  }
  with 3 show ?case by (fastforce split: tree.split simp: ball_Un)
qed auto

lemma splay_max_Leaf: "splay_max t = Node l a r ⟹ r = Leaf"
by(induction t arbitrary: l rule: splay_max.induct)
  (auto split: tree.splits if_splits)

text‹For sanity purposes only:›

lemma splay_max_eq_splay:
  "bst t ⟹ ∀x ∈ set_tree t. x ≤ a ⟹ splay_max t = splay a t"
proof(induction a t rule: splay.induct)
  case (2 a l r)
  show ?case
  proof (cases r)
    case Leaf with 2 show ?thesis by simp
  next
    case Node with 2 show ?thesis by(auto)
  qed
qed (auto simp: neq_Leaf_iff)

lemma splay_max_eq_splay_ex: assumes "bst t" shows "∃a. splay_max t = splay a t"
proof(cases t)
  case Leaf thus ?thesis by simp
next
  case Node
  hence "splay_max t = splay (Max(set_tree t)) t"
    using assms by (auto simp: splay_max_eq_splay)
  thus ?thesis by auto
qed


subsubsection "Verification of @{const delete}"

lemma set_delete: assumes "bst t"
shows "set_tree (delete a t) = set_tree t - {a}"
proof(cases t)
  case Leaf thus ?thesis by(simp add: delete_def)
next
  case (Node l x r)
  obtain l' x' r' where sp[simp]: "splay a (Node l x r) = Node l' x' r'"
    by (metis neq_Leaf_iff splay_Leaf_iff)
  show ?thesis
  proof cases
    assume [simp]: "x' = a"
    show ?thesis
    proof cases
      assume "l' = Leaf"
      thus ?thesis
        using Node assms set_splay[of a "Node l x r"] bst_splay[of "Node l x r" a]
        by(simp add: delete_def split: tree.split prod.split)(fastforce)
    next
      assume "l' ≠ Leaf"
      moreover then obtain l'' m r'' where "splay_max l' = Node l'' m r''"
        using splay_max_Leaf_iff tree.exhaust by blast
      moreover have "a ∉ set_tree l'"
        by (metis (no_types) Node assms less_irrefl sp splay_bstL)
      ultimately show ?thesis
        using Node assms set_splay[of a "Node l x r"] bst_splay[of "Node l x r" a]
          splay_max_Leaf[of l' l'' m r''] set_splay_max[of l']
        by(clarsimp simp: delete_def split: tree.split) auto
    qed
  next
    assume "x' ≠ a"
    thus ?thesis using Node assms set_splay[of a "Node l x r"] splay_to_root[OF _ sp]
      by (simp add: delete_def)
  qed
qed

lemma bst_delete: assumes "bst t" shows "bst (delete a t)"
proof(cases t)
  case Leaf thus ?thesis by(simp add: delete_def)
next
  case (Node l x r)
  obtain l' x' r' where sp[simp]: "splay a (Node l x r) = Node l' x' r'"
    by (metis neq_Leaf_iff splay_Leaf_iff)
  show ?thesis
  proof cases
    assume [simp]: "x' = a"
    show ?thesis
    proof cases
      assume "l' = Leaf"
      thus ?thesis using Node assms bst_splay[of "Node l x r" a]
        by(simp add: delete_def split: tree.split prod.split)
    next
      assume "l' ≠ Leaf"
      thus ?thesis
        using Node assms set_splay[of a "Node l x r"] bst_splay[of "Node l x r" a]
          bst_splay_max[of l'] set_splay_max[of l']
        by(clarsimp simp: delete_def split: tree.split)
          (metis (no_types) insertI1 less_trans)
    qed
  next
    assume "x' ≠ a"
    thus ?thesis using Node assms bst_splay[of "Node l x r" a]
      by(auto simp: delete_def split: tree.split prod.split)
  qed
qed
*)

end