(* Theory BTree_Height *)

theory BTree_Height
  imports BTree
begin

section "Maximum and minimum height"

text "Textbooks usually provide some proofs relating the maximum and minimum height of a BTree
to its number of nodes. We therefore introduce a function counting the nodes and show the respective proofs."

subsection "Definition of node/size"

(* Diagnostic only: display the equations of the automatically derived size function
   and evaluate it on a sample tree, for comparison with nodes below. *)
thm BTree.btree.size
  (* the automatically derived size is a bit weird for our purposes:
     it also takes the length of the key list of each node into account *)
value "size (Node [(Leaf, (0::nat)), (Node [(Leaf, 1), (Leaf, 10)] Leaf, 12), (Leaf, 30), (Leaf, 100)] Leaf)"


text "The default size function does not suit our needs as it also takes the length of the key list
 in each node into account. We would like to count only the number of nodes in the tree,
 disregarding the number of keys."

(* we want a different counting method,
  namely only the number of nodes in a tree *)

(* TODO what if we count Leafs as nodes? *)

(* nodes t counts the Node constructors occurring in t: a Leaf contributes nothing,
   a Node contributes 1 plus the nodes of every subtree in its key list ts and of its
   last subtree t. The keys themselves are ignored. *)
fun nodes::"'a btree \<Rightarrow> nat" where
  "nodes Leaf = 0" |
  "nodes (Node ts t) = 1 + (\<Sum>t\<leftarrow>subtrees ts. nodes t) + (nodes t)"

(* sample: the root and its single inner Node are counted, so this evaluates to 2 *)
value "nodes (Node [(Leaf, (0::nat)), (Node [(Leaf, 1), (Leaf, 10)] Leaf, 12), (Leaf, 30), (Leaf, 100)] Leaf)"


(* maximum number of nodes for given height *)
subsection "Maximum number of nodes for a given height"


(* Summing n copies of the same value c yields n*c; induction on the number of copies. *)
lemma sum_list_replicate: "sum_list (replicate n c) = n*c"
  by (induction n) (auto simp add: ring_class.ring_distribs(2))

(* bound k h abbreviates (k+1)^h - 1. It is used below both with 2*k (largest
   permitted fan-out, upper bound on nodes) and with k (smallest fan-out, lower bound). *)
abbreviation "bound k h \<equiv> ((k+1)^h - 1)"

(* Upper bound on the number of nodes of a balanced tree of order k with a given height:
   nodes t * (2*k) \<le> (2*k+1)^(height t) - 1.
   Induction along nodes.induct; the Leaf case is closed by simp after qed.
   The step uses that order k limits each node to at most 2*k entries. *)
lemma nodes_height_upper_bound:
  "\<lbrakk>order k t; bal t\<rbrakk> \<Longrightarrow> nodes t * (2*k) \<le> bound (2*k) (height t)"
proof(induction t rule: nodes.induct)
  case (2 ts t)
  let ?sub_height = "((2 * k + 1) ^ height t - 1)"
  (* bound the contribution of the subtrees in the key list, each satisfying the IH *)
  have "sum_list (map nodes (subtrees ts)) * (2*k) =
        sum_list (map (\<lambda>t. nodes t * (2 * k)) (subtrees ts))"
    using sum_list_mult_const by metis
  also have "\<dots> \<le> sum_list (map (\<lambda>x.?sub_height) (subtrees ts))"
    using 2
    using sum_list_mono[of "subtrees ts" "\<lambda>t. nodes t * (2 * k)" "\<lambda>x. bound (2 * k) (height t)"]
    by (metis bal.simps(2) order.simps(2))
  also have "\<dots> = sum_list (replicate (length ts) ?sub_height)"
    using map_replicate_const[of ?sub_height "subtrees ts"] length_map
    by simp
  also have "\<dots> = (length ts)*(?sub_height)"
    using sum_list_replicate by simp
  also have "\<dots> \<le> (2*k)*(?sub_height)"
    using "2.prems"(1)
    by simp
  finally have "sum_list (map nodes (subtrees ts))*(2*k) \<le> ?sub_height*(2*k)"
    by simp
  (* the last subtree t also satisfies the IH *)
  moreover have "(nodes t)*(2*k) \<le> ?sub_height"
    using 2 by simp
  ultimately have "(nodes (Node ts t))*(2*k) \<le>
         2*k
        + ?sub_height * (2*k)
        + ?sub_height"
    unfolding nodes.simps add_mult_distrib
    by linarith
  (* pure arithmetic: collapse the right-hand side to (2k+1)^(Suc (height t)) - 1 *)
  also have "\<dots> =  2*k + (2*k)*((2 * k + 1) ^ height t) - 2*k + (2 * k + 1) ^ height t - 1"
    by (simp add: diff_mult_distrib2 mult.assoc mult.commute)
  also have "\<dots> = (2*k)*((2 * k + 1) ^ height t) + (2 * k + 1) ^ height t - 1"
    by simp
  also have "\<dots> = (2*k+1)^(Suc(height t)) - 1"
    by simp
  (* presumably height_bal_tree gives height (Node ts t) = Suc (height t) for balanced trees *)
  finally show ?case
    by (metis "2.prems"(2) height_bal_tree)
qed simp

text "To verify that this upper bound on the number of nodes (equivalently, the lower bound on the
height) is sharp, we compare it to the number of nodes of artificially constructed full trees."

(* full_node k c h: a tree of height h in which every Node holds exactly 2*k entries
   (full_node k c n, c) in its key list plus full_node k c n as last subtree,
   i.e. every node is filled to the maximum allowed by order k. *)
fun full_node::"nat \<Rightarrow> 'a \<Rightarrow> nat \<Rightarrow> 'a btree" where
  "full_node k c 0 = Leaf"|
  "full_node k c (Suc n) = (Node (replicate (2*k) ((full_node k c n),c)) (full_node k c n))"

(* sanity check: both lists coincide, cf. full_btrees_sharp *)
value "let k = (2::nat) in map (\<lambda>x. nodes x * 2*k) (map (full_node k (1::nat)) [0,1,2,3,4])"
value "let k = (2::nat) in map (\<lambda>x. ((2*k+(1::nat))^(x)-1)) [0,1,2,3,4]"

(* Iterating an idempotent function f a positive number of times yields f itself. *)
lemma compow_comp_id: "c > 0 \<Longrightarrow> f \<circ> f = f \<Longrightarrow> (f ^^ c) = f"
  apply(induction c)
   apply auto
  by fastforce

(* required only for the fold definition of height *)
(* Iterating f any number of times on a fixed point x of f yields x again. *)
lemma compow_id_point: "f x = x \<Longrightarrow> (f ^^ c) x = x"
  apply(induction c)
   apply auto
  done

(* full_node k a h has exactly height h; induction along the recursion of full_node. *)
lemma height_full_node: "height (full_node k a h) = h"
  by (induction k a h rule: full_node.induct) (auto simp add: set_replicate_conv_if)

(* Every full_node tree is balanced. *)
lemma bal_full_node: "bal (full_node k a h)"
  by (induction k a h rule: full_node.induct) auto

(* Every full_node tree built with parameter k satisfies order k. *)
lemma order_full_node: "order k (full_node k a h)"
  by (induction k a h rule: full_node.induct) auto

(* Full trees attain the node bound of nodes_height_upper_bound with equality. *)
lemma full_btrees_sharp: "nodes (full_node k a h) * (2*k) = bound (2*k) h"
  by (induction k a h rule: full_node.induct)
    (auto simp add: height_full_node algebra_simps sum_list_replicate)

(* full_node trees witness that nodes_height_upper_bound is sharp: they have height h,
   satisfy order k, are balanced, and attain bound (2*k) h exactly. *)
lemma upper_bound_sharp_node:
  "t = full_node k a h \<Longrightarrow> height t = h \<and> order k t \<and> bal t \<and> bound (2*k) h = nodes t * (2*k)"
  by (simp add: bal_full_node height_full_node order_full_node full_btrees_sharp)


(* minimum number of nodes for a given height *)
subsection "Maximum height for a given number of nodes"


(* Lower bound on the number of nodes of a balanced tree of order k with a given height:
   (k+1)^(height t) - 1 \<le> nodes t * k.
   Induction along nodes.induct; the Leaf case is closed by simp after qed.
   The step uses that order k forces each node to have at least k entries. *)
lemma nodes_height_lower_bound:
  "\<lbrakk>order k t; bal t\<rbrakk> \<Longrightarrow> bound k (height t) \<le> nodes t * k"
proof(induction t rule: nodes.induct)
  case (2 ts t)
  let ?sub_height = "((k + 1) ^ height t - 1)"
  have "k*(?sub_height) \<le> (length ts)*(?sub_height)"
    using "2.prems"(1)
    by simp
  also have "\<dots> = sum_list (replicate (length ts) ?sub_height)"
    using sum_list_replicate by simp
  also have "\<dots> = sum_list (map (\<lambda>x.?sub_height) (subtrees ts))"
    using map_replicate_const[of ?sub_height "subtrees ts"] length_map
    by simp
  (* each subtree in the key list satisfies the IH *)
  also have "\<dots> \<le> sum_list (map (\<lambda>t. nodes t * k) (subtrees ts))"
    using 2
    using sum_list_mono[of "subtrees ts" "\<lambda>x. bound k (height t)" "\<lambda>t. nodes t * k"]
    by (metis bal.simps(2) order.simps(2))
  also have "\<dots> = sum_list (map nodes (subtrees ts)) * k"
    using sum_list_mult_const[of nodes k "subtrees ts"] by auto
  finally have "sum_list (map nodes (subtrees ts))*k \<ge> ?sub_height*k"
    by simp
  (* the last subtree t also satisfies the IH *)
  moreover have "(nodes t)*k \<ge> ?sub_height"
    using 2 by simp
  ultimately have "(nodes (Node ts t))*k \<ge>
        k
        + ?sub_height * k
        + ?sub_height"
    unfolding nodes.simps add_mult_distrib
    by linarith
  (* the smaller side is restated explicitly: for a \<ge> b, \<dots> would refer to a *)
  also have
    "k + ?sub_height * k + ?sub_height =
     k + k*((k + 1) ^ height t) - k + (k + 1) ^ height t - 1"
    by (simp add: diff_mult_distrib2 mult.assoc mult.commute)
  also have "\<dots> = k*((k + 1) ^ height t) + (k + 1) ^ height t - 1"
    by simp
  also have "\<dots> = (k+1)^(Suc(height t)) - 1"
    by simp
  (* presumably height_bal_tree gives height (Node ts t) = Suc (height t) for balanced trees *)
  finally show ?case
    by (metis "2.prems"(2) height_bal_tree)
qed simp

text "To verify that this lower bound on the number of nodes (equivalently, the upper bound on the
height) is sharp, we compare it to the number of nodes of artificially constructed
minimally filled (=slim) trees."

(* slim_node k c h: a tree of height h in which every Node holds exactly k entries
   (slim_node k c n, c) in its key list plus slim_node k c n as last subtree,
   i.e. every node is filled to the minimum allowed by order k. *)
fun slim_node::"nat \<Rightarrow> 'a \<Rightarrow> nat \<Rightarrow> 'a btree" where
  "slim_node k c 0 = Leaf"|
  "slim_node k c (Suc n) = (Node (replicate k ((slim_node k c n),c)) (slim_node k c n))"

(* sanity check: both lists coincide, cf. slim_nodes_sharp *)
value "let k = (2::nat) in map (\<lambda>x. nodes x * k) (map (slim_node k (1::nat)) [0,1,2,3,4])"
value "let k = (2::nat) in map (\<lambda>x. ((k+1::nat)^(x)-1)) [0,1,2,3,4]"


(* slim_node k a h has exactly height h.
   Induct with slim_node's own rule; full_node.induct only happened to work
   because both functions recurse in the same way. *)
lemma height_slim_node: "height (slim_node k a h) = h"
  apply(induction k a h rule: slim_node.induct)
   apply (auto simp add: set_replicate_conv_if)
  done

(* Every slim_node tree is balanced.
   Induct with slim_node's own rule rather than the structurally identical full_node.induct. *)
lemma bal_slim_node: "bal (slim_node k a h)"
  apply(induction k a h rule: slim_node.induct)
   apply auto
  done

(* Every slim_node tree built with parameter k satisfies order k.
   Induct with slim_node's own rule rather than the structurally identical full_node.induct. *)
lemma order_slim_node: "order k (slim_node k a h)"
  apply(induction k a h rule: slim_node.induct)
   apply auto
  done

(* Slim trees attain the node bound of nodes_height_lower_bound with equality. *)
lemma slim_nodes_sharp: "nodes (slim_node k a h) * k = bound k h"
  by (induction k a h rule: slim_node.induct)
    (auto simp add: height_slim_node algebra_simps sum_list_replicate compow_id_point)

(* slim_node trees witness that nodes_height_lower_bound is sharp: they have height h,
   satisfy order k, are balanced, and attain bound k h exactly. *)
lemma lower_bound_sharp_node:
  "t = slim_node k a h \<Longrightarrow> height t = h \<and> order k t \<and> bal t \<and> bound k h = nodes t * k"
  by (simp add: bal_slim_node height_slim_node order_slim_node slim_nodes_sharp)

(* TODO results for root_order/bal *)
text "Since the root of a BTree is special (it only has to satisfy root_order instead of order), we need to show the bounds on the overall number of nodes separately."

(* Lower bound on the number of nodes of a whole tree whose root satisfies root_order k
   (its subtrees are bounded via nodes_height_lower_bound). The bound is stated multiplied
   by k; nodes_root_height_lower_bound_simp gives the divided form.
   The Leaf case is closed by simp after qed. *)
lemma nodes_root_height_lower_bound:
  assumes "root_order k t"
    and "bal t"
  shows "2*((k+1)^(height t - 1) - 1) + (of_bool (t \<noteq> Leaf))*k \<le> nodes t * k"
proof (cases t)
  case (Node ts t)
  let ?sub_height = "((k + 1) ^ height t - 1)"
  (* root_order is used here to obtain 1 \<le> length ts *)
  from Node have "?sub_height \<le> length ts * ?sub_height"
    using assms
    by (simp add: Suc_leI)
  also have "\<dots> = sum_list (replicate (length ts) ?sub_height)"
    using sum_list_replicate
    by simp
  also have "\<dots> = sum_list (map (\<lambda>x. ?sub_height) (subtrees ts))"
    using map_replicate_const[of ?sub_height "subtrees ts"] length_map
    by simp
  also have "\<dots> \<le> sum_list (map (\<lambda>t. nodes t * k) (subtrees ts))"
    using Node
      sum_list_mono[of "subtrees ts" "\<lambda>x. (k+1)^(height t) - 1" "\<lambda>x. nodes x * k"]
      nodes_height_lower_bound assms
    by fastforce
  also have "\<dots> = sum_list (map nodes (subtrees ts)) * k"
    using sum_list_mult_const[of nodes k "subtrees ts"] by simp
  finally have "sum_list (map nodes (subtrees ts))*k \<ge> ?sub_height"
    by simp

  (* the last subtree t is bounded the same way *)
  moreover have "(nodes t)*k \<ge> ?sub_height"
    using Node assms nodes_height_lower_bound
    by auto
  ultimately have "(nodes (Node ts t))*k \<ge>
        ?sub_height
        + ?sub_height + k"
    unfolding nodes.simps add_mult_distrib
    by linarith
  then show ?thesis
    using Node assms(2) height_bal_tree by fastforce
qed simp

(* Upper bound on the number of nodes of a whole tree whose root satisfies root_order k:
   nodes t * (2*k) \<le> (2*k+1)^(height t) - 1. The subtrees are bounded via
   nodes_height_upper_bound. The Leaf case is closed by simp after qed. *)
lemma nodes_root_height_upper_bound:
  assumes "root_order k t"
    and "bal t"
  shows "nodes t * (2*k) \<le> (2*k+1)^(height t) - 1"
proof(cases t)
  case (Node ts t)
  let ?sub_height = "((2 * k + 1) ^ height t - 1)"
  have "sum_list (map nodes (subtrees ts)) * (2*k) =
        sum_list (map (\<lambda>t. nodes t * (2 * k)) (subtrees ts))"
    using sum_list_mult_const by metis
  also have "\<dots> \<le> sum_list (map (\<lambda>x.?sub_height) (subtrees ts))"
    using Node
      sum_list_mono[of "subtrees ts" "\<lambda>x. nodes x * (2*k)"  "\<lambda>x. (2*k+1)^(height t) - 1"]
      nodes_height_upper_bound assms
    by fastforce
  also have "\<dots> = sum_list (replicate (length ts) ?sub_height)"
    using map_replicate_const[of ?sub_height "subtrees ts"] length_map
    by simp
  also have "\<dots> = (length ts)*(?sub_height)"
    using sum_list_replicate by simp
  (* root_order is used here to obtain length ts \<le> 2*k *)
  also have "\<dots> \<le> (2*k)*?sub_height"
    using assms Node
    by simp
  finally have "sum_list (map nodes (subtrees ts))*(2*k) \<le> ?sub_height*(2*k)"
    by simp
  moreover have "(nodes t)*(2*k) \<le> ?sub_height"
    using Node assms nodes_height_upper_bound
    by auto
  ultimately have "(nodes (Node ts t))*(2*k) \<le>
         2*k
        + ?sub_height * (2*k)
        + ?sub_height"
    unfolding nodes.simps add_mult_distrib
    by linarith
  (* pure arithmetic: collapse the right-hand side to (2k+1)^(Suc (height t)) - 1 *)
  also have "\<dots> =  2*k + (2*k)*((2 * k + 1) ^ height t) - 2*k + (2 * k + 1) ^ height t - 1"
    by (simp add: diff_mult_distrib2 mult.assoc mult.commute)
  also have "\<dots> = (2*k)*((2 * k + 1) ^ height t) + (2 * k + 1) ^ height t - 1"
    by simp
  also have "\<dots> = (2*k+1)^(Suc(height t)) - 1"
    by simp
  finally show ?thesis
    by (metis Node assms(2) height_bal_tree)
qed simp

(* Multiplying the node count by k and dividing by k again is the identity for trees
   satisfying root_order k. This is used to move between the multiplied and divided
   forms of the bounds. *)
lemma root_order_imp_divmuleq: "root_order k t \<Longrightarrow> (nodes t * k) div k = nodes t"
  using root_order.elims(2) by fastforce

(* Divided form of nodes_root_height_lower_bound for k > 0: the lower bound on
   nodes t itself rather than on nodes t * k. The Leaf case is closed by simp after qed. *)
lemma nodes_root_height_lower_bound_simp:
  assumes "root_order k t"
    and "bal t"
    and "k > 0"
  shows "(2*((k+1)^(height t - 1) - 1)) div k + (of_bool (t \<noteq> Leaf)) \<le> nodes t"
proof (cases t)
  case Node
  (* pull the of_bool summand (which is k-divisible after multiplying by k) into the division *)
  have "(2*((k+1)^(height t - 1) - 1)) div k + (of_bool (t \<noteq> Leaf)) =
(2*((k+1)^(height t - 1) - 1) + (of_bool (t \<noteq> Leaf))*k) div k"
    using Node assms
    using div_plus_div_distrib_dvd_left[of k k "(2 * Suc k ^ (height t - Suc 0) - Suc (Suc 0))"]
    by (auto simp add: algebra_simps simp del: height_btree.simps)
  also have "\<dots> \<le> (nodes t * k) div k"
    using nodes_root_height_lower_bound[OF assms(1,2)] div_le_mono
    by blast
  also have "\<dots> = nodes t"
    using root_order_imp_divmuleq[OF assms(1)]
    by simp
  finally show ?thesis .
qed simp

(* Divided form of nodes_root_height_upper_bound: the upper bound on nodes t itself
   rather than on nodes t * (2*k). *)
lemma nodes_root_height_upper_bound_simp:
  assumes "root_order k t"
    and "bal t"
  shows "nodes t \<le> ((2*k+1)^(height t) - 1) div (2*k)"
proof -
  have "nodes t = (nodes t * (2*k)) div (2*k)"
    using root_order_imp_divmuleq[OF assms(1)]
    by simp
  also have "\<dots> \<le> ((2*k+1)^(height t) - 1) div (2*k)"
    using div_le_mono nodes_root_height_upper_bound[OF assms] by blast
  finally show ?thesis .
qed

(* full_tree is an alias for full_node, used to state sharpness for whole trees. *)
definition full_tree :: "nat \<Rightarrow> 'a \<Rightarrow> nat \<Rightarrow> 'a btree" where
  "full_tree = full_node"

(* slim_tree k c h: Leaf for h = 0; otherwise a root whose key list holds the single entry
   (slim_node k c h, c) and whose last subtree is slim_node k c h. *)
fun slim_tree :: "nat \<Rightarrow> 'a \<Rightarrow> nat \<Rightarrow> 'a btree" where
  "slim_tree k c 0 = Leaf"
| "slim_tree k c (Suc h) = Node [(slim_node k c h, c)] (slim_node k c h)"

(* slim_tree witnesses that nodes_root_height_lower_bound is sharp: for k > 0 it has
   height h, satisfies root_order k, is balanced, and attains the bound exactly.
   Case split on h; slim_nodes_sharp supplies the node count of the subtrees. *)
lemma lower_bound_sharp:
  "k > 0 \<Longrightarrow> t = slim_tree k a h \<Longrightarrow> height t = h \<and> root_order k t \<and> bal t \<and> nodes t * k = 2*((k+1)^(height t - 1) - 1) + (of_bool (t \<noteq> Leaf))*k"
  apply (cases h)
  using slim_nodes_sharp[of k a]
   apply (auto simp add: algebra_simps bal_slim_node height_slim_node order_slim_node)
  done

(* full_tree witnesses that nodes_root_height_upper_bound is sharp: for k > 0 it has
   height h, satisfies root_order k, is balanced, and attains the bound exactly.
   k > 0 is needed to turn order k into root_order k via order_impl_root_order. *)
lemma upper_bound_sharp:
  "k > 0 \<Longrightarrow> t = full_tree k a h \<Longrightarrow> height t = h \<and> root_order k t \<and> bal t \<and> ((2*k+1)^(height t) - 1) = nodes t * (2*k)"
  unfolding full_tree_def
  using order_impl_root_order[of k t]
  by (simp add: bal_full_node height_full_node order_full_node full_btrees_sharp)


end