(* Theory HOL-Data_Structures.RBT_Set *)

(* Author: Tobias Nipkow *)

section ‹Red-Black Tree Implementation of Sets›

theory RBT_Set
imports
  Complex_Main
  RBT
  Cmp
  Isin2
begin

(* The empty set is represented by the empty tree. *)
definition empty :: "'a rbt" where
"empty = Leaf"

(* Insert x below the root. A new key becomes a red node with two leaves.
   Keys are compared with cmp. Below a black node the modified subtree is
   rebalanced with baliL/baliR. Below a red node it is rebuilt as a red node
   without rebalancing. An equal key leaves the node unchanged.
   The result is only guaranteed to satisfy the weaker invariant invc2
   (lemma invc_ins), which is why insert repaints the root black. *)
fun ins :: "'a::linorder ⇒ 'a rbt ⇒ 'a rbt" where
"ins x Leaf = R Leaf x Leaf" |
"ins x (B l a r) =
  (case cmp x a of
     LT ⇒ baliL (ins x l) a r |
     GT ⇒ baliR l a (ins x r) |
     EQ ⇒ B l a r)" |
"ins x (R l a r) =
  (case cmp x a of
    LT ⇒ R (ins x l) a r |
    GT ⇒ R l a (ins x r) |
    EQ ⇒ R l a r)"

(* Set insertion: run ins, then repaint the root black. *)
definition insert :: "'a::linorder ⇒ 'a rbt ⇒ 'a rbt" where
"insert x t = paint Black (ins x t)"

(* Colour of the root node; leaves count as black. *)
fun color :: "'a rbt ⇒ color" where
"color Leaf = Black" |
"color (Node _ (_, c) _) = c"

(* Delete x below the root. Keys are compared with cmp.
   - Recursing into a non-leaf subtree with a black root: the parent is
     rebalanced with baldL/baldR, because that subtree's black height can
     drop by one (lemma inv_del).
   - Any other recursive case: the node is rebuilt as red, and its original
     colour is discarded.
   - A node whose key equals x: replaced by join of its two subtrees. *)
fun del :: "'a::linorder ⇒ 'a rbt ⇒ 'a rbt" where
"del x Leaf = Leaf" |
"del x (Node l (a, _) r) =
  (case cmp x a of
     LT ⇒ if l ≠ Leaf ∧ color l = Black
           then baldL (del x l) a r else R (del x l) a r |
     GT ⇒ if r ≠ Leaf ∧ color r = Black
           then baldR l a (del x r) else R l a (del x r) |
     EQ ⇒ join l r)"

(* Set deletion: run del, then repaint the root black. *)
definition delete :: "'a::linorder ⇒ 'a rbt ⇒ 'a rbt" where
"delete x t = paint Black (del x t)"


subsection "Functional Correctness Proofs"

(* Repainting a tree does not change its in-order traversal. *)
lemma inorder_paint: "inorder(paint c t) = inorder t"
by(cases t) (auto)

(* The insertion rebalancing function baliL preserves the in-order sequence. *)
lemma inorder_baliL:
  "inorder(baliL l a r) = inorder l @ a # inorder r"
by(cases "(l,a,r)" rule: baliL.cases) (auto)

(* Mirror image: baliR preserves the in-order sequence. *)
lemma inorder_baliR:
  "inorder(baliR l a r) = inorder l @ a # inorder r"
by(cases "(l,a,r)" rule: baliR.cases) (auto)

(* On a tree with sorted in-order list, ins behaves like ins_list on that list. *)
lemma inorder_ins:
  "sorted(inorder t) ⟹ inorder(ins x t) = ins_list x (inorder t)"
by(induction x t rule: ins.induct)
  (auto simp: ins_list_simps inorder_baliL inorder_baliR)

(* Functional correctness of insert; the final repaint is irrelevant (inorder_paint). *)
lemma inorder_insert:
  "sorted(inorder t) ⟹ inorder(insert x t) = ins_list x (inorder t)"
by (simp add: insert_def inorder_ins inorder_paint)

(* The deletion rebalancing function baldL preserves the in-order sequence. *)
lemma inorder_baldL:
  "inorder(baldL l a r) = inorder l @ a # inorder r"
by(cases "(l,a,r)" rule: baldL.cases)
  (auto simp:  inorder_baliL inorder_baliR inorder_paint)

(* Mirror image: baldR preserves the in-order sequence. *)
lemma inorder_baldR:
  "inorder(baldR l a r) = inorder l @ a # inorder r"
by(cases "(l,a,r)" rule: baldR.cases)
  (auto simp:  inorder_baliL inorder_baliR inorder_paint)

(* join concatenates the in-order lists of its two arguments. *)
lemma inorder_join:
  "inorder(join l r) = inorder l @ inorder r"
by(induction l r rule: join.induct)
  (auto simp: inorder_baldL inorder_baldR split: tree.split color.split)

(* On a tree with sorted in-order list, del behaves like del_list on that list. *)
lemma inorder_del:
 "sorted(inorder t) ⟹ inorder(del x t) = del_list x (inorder t)"
by(induction x t rule: del.induct)
  (auto simp: del_list_simps inorder_join inorder_baldL inorder_baldR)

(* Functional correctness of delete; the final repaint is irrelevant (inorder_paint). *)
lemma inorder_delete:
  "sorted(inorder t) ⟹ inorder(delete x t) = del_list x (inorder t)"
by (auto simp: delete_def inorder_del inorder_paint)


subsection ‹Structural invariants›

(* A colour that is not Black is Red; registered as a simp rule. *)
lemma neq_Black[simp]: "(c ≠ Black) = (c = Red)"
by (cases c) auto

text‹The proofs are due to Markus Reiter and Alexander Krauss.›

(* Black height, measured along the left spine only: a black node adds 1,
   a red node adds nothing, and a leaf has height 0. The invariant invh
   requires the left and right subtrees of every node to agree on this value. *)
fun bheight :: "'a rbt ⇒ nat" where
"bheight Leaf = 0" |
"bheight (Node l (x, c) r) = (if c = Black then bheight l + 1 else bheight l)"

(* Colour invariant: every red node has children whose colour is Black
   (leaves count as black), checked recursively in both subtrees. *)
fun invc :: "'a rbt ⇒ bool" where
"invc Leaf = True" |
"invc (Node l (a,c) r) =
  ((c = Red ⟶ color l = Black ∧ color r = Black) ∧ invc l ∧ invc r)"

text ‹Weaker version of invc that skips the colour check at the root:›
(* invc of the tree after repainting its root black. Because the root is then
   black, the red-node condition is not checked at the root itself. This
   assumes paint recolours only the root; paint is defined in RBT. *)
abbreviation invc2 :: "'a rbt ⇒ bool" where
"invc2 t ≡ invc(paint Black t)"

(* Height invariant: at every node both subtrees have the same black height. *)
fun invh :: "'a rbt ⇒ bool" where
"invh Leaf = True" |
"invh (Node l (x, c) r) = (bheight l = bheight r ∧ invh l ∧ invh r)"

(* The full colour invariant implies the weaker one. *)
lemma invc2I: "invc t ⟹ invc2 t"
by (cases t rule: tree2_cases) simp+

(* A red-black tree: colour and height invariants hold and the root is black. *)
definition rbt :: "'a rbt ⇒ bool" where
"rbt t = (invc t ∧ invh t ∧ color t = Black)"

(* Repainting black always yields a black root (leaves are black anyway). *)
lemma color_paint_Black: "color (paint Black t) = Black"
by (cases t) auto

(* Only the last repaint matters. *)
lemma paint2: "paint c2 (paint c1 t) = paint c2 t"
by (cases t) auto

(* Repainting preserves the height invariant. *)
lemma invh_paint: "invh t ⟹ invh (paint c t)"
by (cases t) auto

(* baliL yields a tree satisfying invc even if its left argument only
   satisfies the weaker invc2. *)
lemma invc_baliL:
  "⟦invc2 l; invc r⟧ ⟹ invc (baliL l a r)"
by (induct l a r rule: baliL.induct) auto

(* Mirror image for baliR with a weakened right argument. *)
lemma invc_baliR:
  "⟦invc l; invc2 r⟧ ⟹ invc (baliR l a r)"
by (induct l a r rule: baliR.induct) auto

(* Rebalancing two subtrees of equal black height produces a tree one black
   level higher. *)
lemma bheight_baliL:
  "bheight l = bheight r ⟹ bheight (baliL l a r) = Suc (bheight l)"
by (induct l a r rule: baliL.induct) auto

lemma bheight_baliR:
  "bheight l = bheight r ⟹ bheight (baliR l a r) = Suc (bheight l)"
by (induct l a r rule: baliR.induct) auto

(* Rebalancing preserves the height invariant when both subtrees have equal
   black height. *)
lemma invh_baliL:
  "⟦ invh l; invh r; bheight l = bheight r ⟧ ⟹ invh (baliL l a r)"
by (induct l a r rule: baliL.induct) auto

lemma invh_baliR:
  "⟦ invh l; invh r; bheight l = bheight r ⟧ ⟹ invh (baliR l a r)"
by (induct l a r rule: baliR.induct) auto

text ‹All in one:›

(* The invc, invh and bheight facts for baliR in a single statement. *)
lemma inv_baliR: "⟦ invh l; invh r; invc l; invc2 r; bheight l = bheight r ⟧
 ⟹ invc (baliR l a r) ∧ invh (baliR l a r) ∧ bheight (baliR l a r) = Suc (bheight l)"
by (induct l a r rule: baliR.induct) auto

(* The same for baliL, with the weakened argument on the left. *)
lemma inv_baliL: "⟦ invh l; invh r; invc2 l; invc r; bheight l = bheight r ⟧
 ⟹ invc (baliL l a r) ∧ invh (baliL l a r) ∧ bheight (baliL l a r) = Suc (bheight l)"
by (induct l a r rule: baliL.induct) auto

subsubsection ‹Insertion›

(* ins yields invc2 in general, and the full invc if the input root is black. *)
lemma invc_ins: "invc t ⟹ invc2 (ins x t) ∧ (color t = Black ⟶ invc (ins x t))"
by (induct x t rule: ins.induct) (auto simp: invc_baliL invc_baliR invc2I)

(* ins preserves the height invariant and does not change the black height. *)
lemma invh_ins: "invh t ⟹ invh (ins x t) ∧ bheight (ins x t) = bheight t"
by(induct x t rule: ins.induct)
  (auto simp: invh_baliL invh_baliR bheight_baliL bheight_baliR)

(* insert preserves the red-black tree invariant. *)
theorem rbt_insert: "rbt t ⟹ rbt (insert x t)"
by (simp add: invc_ins invh_ins color_paint_Black invh_paint rbt_def insert_def)

text ‹All in one:›

(* invc_ins and invh_ins combined into one statement. *)
lemma inv_ins: "⟦ invc t; invh t ⟧ ⟹
  invc2 (ins x t) ∧ (color t = Black ⟶ invc (ins x t)) ∧
  invh(ins x t) ∧ bheight (ins x t) = bheight t"
by (induct x t rule: ins.induct) (auto simp: inv_baliL inv_baliR invc2I)

(* rbt_insert again, this time derived from the combined lemma inv_ins. *)
theorem rbt_insert2: "rbt t ⟹ rbt (insert x t)"
by (simp add: inv_ins color_paint_Black invh_paint rbt_def insert_def)


subsubsection ‹Deletion›

(* Repainting a tree with a black root red lowers its black height by one. *)
lemma bheight_paint_Red:
  "color t = Black ⟹ bheight (paint Red t) = bheight t - 1"
by (cases t) auto

(* baldL restores the height invariant when the left subtree is one black
   level lower than the right one (variant assuming invc r). *)
lemma invh_baldL_invc:
  "⟦ invh l;  invh r;  bheight l + 1 = bheight r;  invc r ⟧
   ⟹ invh (baldL l a r) ∧ bheight (baldL l a r) = bheight r"
by (induct l a r rule: baldL.induct)
   (auto simp: invh_baliR invh_paint bheight_baliR bheight_paint_Red)

(* Like invh_baldL_invc, but assuming a black right root instead of invc r. *)
lemma invh_baldL_Black:
  "⟦ invh l;  invh r;  bheight l + 1 = bheight r;  color r = Black ⟧
   ⟹ invh (baldL l a r) ∧ bheight (baldL l a r) = bheight r"
by (induct l a r rule: baldL.induct) (auto simp add: invh_baliR bheight_baliR)

(* Colour invariants after baldL: full invc if the right root is black. *)
lemma invc_baldL: "⟦invc2 l; invc r; color r = Black⟧ ⟹ invc (baldL l a r)"
by (induct l a r rule: baldL.induct) (simp_all add: invc_baliR)

(* ... and the weaker invc2 without any assumption on the right root. *)
lemma invc2_baldL: "⟦ invc2 l; invc r ⟧ ⟹ invc2 (baldL l a r)"
by (induct l a r rule: baldL.induct) (auto simp: invc_baliR paint2 invc2I)

(* Mirror of invh_baldL_invc: the right subtree is one black level lower. *)
lemma invh_baldR_invc:
  "⟦ invh l;  invh r;  bheight l = bheight r + 1;  invc l ⟧
  ⟹ invh (baldR l a r) ∧ bheight (baldR l a r) = bheight l"
by(induct l a r rule: baldR.induct)
  (auto simp: invh_baliL bheight_baliL invh_paint bheight_paint_Red)

(* Colour invariants after baldR: full invc if the left root is black. *)
lemma invc_baldR: "⟦invc l; invc2 r; color l = Black⟧ ⟹ invc (baldR l a r)"
by (induct l a r rule: baldR.induct) (simp_all add: invc_baliL)

(* ... and the weaker invc2 without any assumption on the left root. *)
lemma invc2_baldR: "⟦ invc l; invc2 r ⟧ ⟹ invc2 (baldR l a r)"
by (induct l a r rule: baldR.induct) (auto simp: invc_baliL paint2 invc2I)

(* Joining two subtrees of equal black height preserves invh and that height. *)
lemma invh_join:
  "⟦ invh l; invh r; bheight l = bheight r ⟧
  ⟹ invh (join l r) ∧ bheight (join l r) = bheight l"
by (induct l r rule: join.induct)
   (auto simp: invh_baldL_Black split: tree.splits color.splits)

(* join yields invc2 in general, and full invc when both roots are black. *)
lemma invc_join:
  "⟦ invc l; invc r ⟧ ⟹
  (color l = Black ∧ color r = Black ⟶ invc (join l r)) ∧ invc2 (join l r)"
by (induct l r rule: join.induct)
   (auto simp: invc_baldL invc2I split: tree.splits color.splits)

text ‹All in one:›

(* Height and colour facts for baldL combined into one statement. *)
lemma inv_baldL:
  "⟦ invh l;  invh r;  bheight l + 1 = bheight r; invc2 l; invc r ⟧
   ⟹ invh (baldL l a r) ∧ bheight (baldL l a r) = bheight r
  ∧ invc2 (baldL l a r) ∧ (color r = Black ⟶ invc (baldL l a r))"
by (induct l a r rule: baldL.induct)
   (auto simp: inv_baliR invh_paint bheight_baliR bheight_paint_Red paint2 invc2I)

(* Mirror image for baldR. *)
lemma inv_baldR:
  "⟦ invh l;  invh r;  bheight l = bheight r + 1; invc l; invc2 r ⟧
   ⟹ invh (baldR l a r) ∧ bheight (baldR l a r) = bheight l
  ∧ invc2 (baldR l a r) ∧ (color l = Black ⟶ invc (baldR l a r))"
by (induct l a r rule: baldR.induct)
   (auto simp: inv_baliL invh_paint bheight_baliL bheight_paint_Red paint2 invc2I)

(* invh_join and invc_join combined into one statement. *)
lemma inv_join:
  "⟦ invh l; invh r; bheight l = bheight r; invc l; invc r ⟧
  ⟹ invh (join l r) ∧ bheight (join l r) = bheight l
  ∧ invc2 (join l r) ∧ (color l = Black ∧ color r = Black ⟶ invc (join l r))"
by (induct l r rule: join.induct)
   (auto simp: invh_baldL_Black inv_baldL invc2I split: tree.splits color.splits)

(* A non-leaf tree is a Node; used as a dest! rule in the proof of inv_del. *)
lemma neq_LeafD: "t ≠ Leaf ⟹ ∃l x c r. t = Node l (x,c) r"
by(cases t rule: tree2_cases) auto

(* del preserves invh. From a red-rooted tree it keeps the black height and
   invc. From a black-rooted tree it lowers the black height by one and
   keeps only invc2. *)
lemma inv_del: "⟦ invh t; invc t ⟧ ⟹
   invh (del x t) ∧
   (color t = Red ⟶ bheight (del x t) = bheight t ∧ invc (del x t)) ∧
   (color t = Black ⟶ bheight (del x t) = bheight t - 1 ∧ invc2 (del x t))"
by(induct x t rule: del.induct)
  (auto simp: inv_baldL inv_baldR inv_join dest!: neq_LeafD)

(* delete preserves the red-black tree invariant. *)
theorem rbt_delete: "rbt t ⟹ rbt (delete x t)"
by (metis delete_def rbt_def color_paint_Black inv_del invh_paint)

text ‹Overall correctness:›

(* Red-black trees with invariant rbt implement the Set_by_Ordered interface.
   Case 1 discharges the empty-set property from empty_def.
   Cases 2-4 use the functional-correctness lemmas for isin, insert and delete.
   Case 5 shows that empty satisfies rbt.
   Cases 6-7 are invariant preservation by insert and delete. *)
interpretation S: Set_by_Ordered
where empty = empty and isin = isin and insert = insert and delete = delete
and inorder = inorder and inv = rbt
proof (standard, goal_cases)
  case 1 show ?case by (simp add: empty_def)
next
  case 2 thus ?case by(simp add: isin_set_inorder)
next
  case 3 thus ?case by(simp add: inorder_insert)
next
  case 4 thus ?case by(simp add: inorder_delete)
next
  case 5 thus ?case by (simp add: rbt_def empty_def) 
next
  case 6 thus ?case by (simp add: rbt_insert) 
next
  case 7 thus ?case by (simp add: rbt_delete) 
qed


subsection ‹Height-Size Relation›

(* The height is at most twice the black height, plus one if the root is red. *)
lemma rbt_height_bheight_if: "invc t ⟹ invh t ⟹
  height t ≤ 2 * bheight t + (if color t = Black then 0 else 1)"
by(induction t) (auto split: if_split_asm)

(* For a red-black tree (black root), half the height is bounded by the
   black height. *)
lemma rbt_height_bheight: "rbt t ⟹ height t / 2 ≤ bheight t"
by(auto simp: rbt_def dest: rbt_height_bheight_if)

(* size1 grows at least exponentially in the black height. *)
lemma bheight_size_bound:  "invc t ⟹ invh t ⟹ 2 ^ (bheight t) ≤ size1 t"
by (induction t) auto

(* Under the height invariant, the black height never exceeds the minimal height. *)
lemma bheight_le_min_height:  "invh t ⟹ bheight t ≤ min_height t"
by (induction t) auto

(* Logarithmic height bound for red-black trees. The proof chains
   2 powr (height/2) ≤ 2 powr bheight ≤ size1 and takes logarithms. *)
lemma rbt_height_le: assumes "rbt t" shows "height t ≤ 2 * log 2 (size1 t)"
proof -
  have "2 powr (height t / 2) ≤ 2 powr bheight t"
    using rbt_height_bheight[OF assms] by simp
  also have "… ≤ size1 t" using assms
    by (simp add: powr_realpow bheight_size_bound rbt_def)
  finally have "2 powr (height t / 2) ≤ size1 t" .
  hence "height t / 2 ≤ log 2 (size1 t)"
    by (simp add: le_log_iff size1_size del: divide_le_eq_numeral1(1))
  thus ?thesis by simp
qed

(* The same logarithmic bound, proved via min_height instead of powr reasoning. *)
lemma rbt_height_le2: assumes "rbt t" shows "height t ≤ 2 * log 2 (size1 t)"
proof -
  have "height t ≤ 2 * bheight t"
    using rbt_height_bheight_if assms[simplified rbt_def] by fastforce
  also have "… ≤ 2 * min_height t"
    using bheight_le_min_height assms[simplified rbt_def] by auto
  also have "… ≤ 2 * log 2 (size1 t)"
    using le_log2_of_power min_height_size1 by auto
  finally show ?thesis by simp
qed

end