(* Theory List_Vector *)

(*
Author:  Florian Messner <florian.g.messner@uibk.ac.at>
Author:  Julian Parsert <julian.parsert@gmail.com>
Author:  Jonas Schöpf <jonas.schoepf@uibk.ac.at>
Author:  Christian Sternagel <c.sternagel@gmail.com>
License: LGPL
*)

section ‹Vectors as Lists of Naturals›

theory List_Vector
  imports Main
begin

(*TODO: move*)
― ‹Lists related by ‹lex P› always have the same length.›
lemma lex_lengthD: "(x, y) ∈ lex P ⟹ length x = length y"
  by (auto simp: lexord_lex)

(*TODO: move*)
― ‹Decompose ‹(xs, ys) ∈ lex r›: the lists have equal length and there is an index ‹i›
  below that length such that both lists share the prefix of length ‹i› and their
  ‹i›-th entries are related by ‹r›.›
lemma lex_take_index:
  assumes "(xs, ys) ∈ lex r"
  obtains i where "length ys = length xs"
    and "i < length xs" and "take i xs = take i ys"
    and "(xs ! i, ys ! i) ∈ r"
proof -
  obtain n us x xs' y ys' where "(xs, ys) ∈ lexn r n" and "length xs = n" and "length ys = n"
    and "xs = us @ x # xs'" and "ys = us @ y # ys'" and "(x, y) ∈ r"
    using assms by (fastforce simp: lex_def lexn_conv)
  ― ‹The common prefix ‹us› determines the witnessing index.›
  then show ?thesis by (intro that [of "length us"]) auto
qed

(*TODO: move*)
― ‹For naturals ‹w < v›, if ‹v * b› and ‹w * b› leave the same remainder modulo ‹a›,
  then ‹(v - w) * b› is a multiple of ‹a›.›
lemma mods_with_nats:
  assumes "(v::nat) > w"
    and "(v * b) mod a = (w * b) mod a"
  shows "((v - w) * b) mod a = 0"
  using assms by (simp add: mod_eq_dvd_iff_nat algebra_simps)

― ‹The 0-vector of length ‹n›.›
abbreviation zeroes :: "nat ⇒ nat list"
  where
    "zeroes n ≡ replicate n 0"

― ‹Updating position ‹i› of the 0-vector to ‹a› yields a vector whose entries are all ‹0›
  except (within range) the one at position ‹i›, which is ‹a›.›
lemma rep_upd_unit:
  assumes "x = (zeroes n)[i := a]"
  shows "∀j < length x. (j ≠ i ⟶ x ! j = 0) ∧ (j = i ⟶ x ! j = a)"
  using assms by simp

― ‹A vector is ‹nonzero› if at least one of its entries differs from ‹0›.›
definition nonzero_iff: "nonzero xs ⟷ (∃x∈set xs. x ≠ 0)"

lemma nonzero_append [simp]:
  "nonzero (xs @ ys) ⟷ nonzero xs ∨ nonzero ys" by (auto simp: nonzero_iff)


subsection ‹The Inner Product›

text ‹The inner product sums the pointwise products ‹xs ! i * ys ! i› over all indices below
  the length of the shorter list; surplus entries of the longer list are ignored.›
definition dotprod :: "nat list ⇒ nat list ⇒ nat" (infixl ‹∙› 70)
  where
    "xs ∙ ys = (∑i<min (length xs) (length ys). xs ! i * ys ! i)"

― ‹Executable characterization via ‹zip›, registered as code equation.›
lemma dotprod_code [code]:
  "xs ∙ ys = sum_list (map (λ(x, y). x * y) (zip xs ys))"
  by (auto simp: dotprod_def sum_list_sum_nth lessThan_atLeast0)

lemma dotprod_commute:
  assumes "length xs = length ys"
  shows "xs ∙ ys = ys ∙ xs"
  using assms by (auto simp: dotprod_def mult.commute)

lemma dotprod_Nil [simp]: "[] ∙ [] = 0"
  by (simp add: dotprod_def)

lemma dotprod_Cons [simp]:
  "(x # xs) ∙ (y # ys) = x * y + xs ∙ ys"
  unfolding dotprod_def and length_Cons and min_Suc_Suc and sum.lessThan_Suc_shift by auto

― ‹Multiplying with the all-ones vector sums the entries.›
lemma dotprod_1_right [simp]:
  "xs ∙ replicate (length xs) 1 = sum_list xs"
  by (induct xs) (simp_all)

lemma dotprod_0_right [simp]:
  "xs ∙ zeroes (length xs) = 0"
  by (induct xs) (simp_all)

― ‹The inner product with a vector that is ‹0› except at position ‹k› selects entry ‹k›.›
lemma dotprod_unit [simp]:
  assumes "length a = n"
    and "k < n"
  shows "a ∙ (zeroes n)[k := zk] = a ! k * zk"
  using assms by (induct a arbitrary: k n) (auto split: nat.splits)

lemma dotprod_gt0:
  assumes "length x = length y" and "∃i<length y. x ! i > 0 ∧ y ! i > 0"
  shows "x ∙ y > 0"
  using assms by (induct x y rule: list_induct2) (fastforce simp: nth_Cons split: nat.splits)+

lemma dotprod_gt0D:
  assumes "length x = length y"
    and "x ∙ y > 0"
  shows "∃i<length y. x ! i > 0 ∧ y ! i > 0"
  using assms by (induct x y rule: list_induct2) (auto simp: Ex_less_Suc2)

― ‹For equal-length vectors, the inner product is positive iff some position is positive
  in both vectors.›
lemma dotprod_gt0_iff [iff]:
  assumes "length x = length y"
  shows "x ∙ y > 0 ⟷ (∃i<length y. x ! i > 0 ∧ y ! i > 0)"
  using assms and dotprod_gt0D and dotprod_gt0 by blast

lemma dotprod_append:
  assumes "length a = length b"
  shows "(a @ x) ∙ (b @ y) = a ∙ b + x ∙ y"
  using assms by (induct a b rule: list_induct2) auto

lemma dotprod_le_take:
  assumes "length a = length b"
    and "k ≤ length a"
  shows "take k a ∙ take k b ≤ a ∙ b"
  using assms and append_take_drop_id [of k a] and append_take_drop_id [of k b]
  by (metis add_right_cancel leI length_append length_drop not_add_less1 dotprod_append)

lemma dotprod_le_drop:
  assumes "length a = length b"
    and "k ≤ length a"
  shows "drop k a ∙ drop k b ≤ a ∙ b"
  using assms and append_take_drop_id [of k a] and append_take_drop_id [of k b]
  by (metis dotprod_append length_take order_refl trans_le_add2)

lemma dotprod_is_0 [simp]:
  assumes "length x = length y"
  shows "x ∙ y = 0 ⟷ (∀i<length y. x ! i = 0 ∨ y ! i = 0)"
  using assms by (metis dotprod_gt0_iff neq0_conv)

― ‹If ‹a› has no zero entries, ‹x ∙ a› vanishes exactly when ‹x› is all zeroes.›
lemma dotprod_eq_0_iff:
  assumes "length x = length a"
    and "0 ∉ set a"
  shows "x ∙ a = 0 ⟷ (∀e ∈ set x. e = 0)"
  using assms by (fastforce simp: in_set_conv_nth)

lemma dotprod_eq_nonzero_iff:
  assumes "a ∙ x = b ∙ y" and "length x = length a" and "length y = length b"
    and "0 ∉ set a" and "0 ∉ set b"
  shows "nonzero x ⟷ nonzero y"
  using assms by (auto simp: nonzero_iff) (metis dotprod_commute dotprod_eq_0_iff neq0_conv)+

lemma eq_0_iff:
  "xs = zeroes n ⟷ length xs = n ∧ (∀x∈set xs. x = 0)"
  using in_set_replicate [of _ n 0] and replicate_eqI [of xs n 0] by auto

lemma not_nonzero_iff: "¬ nonzero x ⟷ x = zeroes (length x)"
  by (auto simp: nonzero_iff replicate_length_same eq_0_iff)

lemma neq_0_iff':
  "xs ≠ zeroes n ⟷ length xs ≠ n ∨ (∃x∈set xs. x > 0)"
  by (auto simp: eq_0_iff)

― ‹Each summand is bounded by the whole inner product.›
lemma dotprod_pointwise_le:
  assumes "length as = length xs"
    and "i < length as"
  shows "as ! i * xs ! i ≤ as ∙ xs"
proof -
  have "as ∙ xs = (∑i<min (length as) (length xs). as ! i * xs ! i)"
    by (simp add: dotprod_def)
  then show ?thesis
    using assms by (auto intro: member_le_sum)
qed

lemma replicate_dotprod:
  assumes "length y = n"
  shows "replicate n x ∙ y = x * sum_list y"
proof -
  have "x * (∑i<length y. y ! i) = (∑i<length y. x * y ! i)"
    using sum_distrib_left by blast
  then show ?thesis
    using assms by (auto simp: dotprod_def sum_list_sum_nth atLeast0LessThan)
qed


subsection ‹The Pointwise Order on Vectors›

text ‹‹xs ≤⇩v ys› holds iff both vectors have the same length and every entry of ‹xs› is
  bounded by the corresponding entry of ‹ys›; ‹<⇩v› is the strict part of ‹≤⇩v›.›
definition less_eq :: "nat list ⇒ nat list ⇒ bool" (‹_/ ≤⇩v _› [51, 51] 50)
  where
    "xs ≤⇩v ys ⟷ length xs = length ys ∧ (∀i<length xs. xs ! i ≤ ys ! i)"

definition less :: "nat list ⇒ nat list ⇒ bool" (‹_/ <⇩v _› [51, 51] 50)
  where
    "xs <⇩v ys ⟷ xs ≤⇩v ys ∧ ¬ ys ≤⇩v xs"

― ‹The pointwise order is a partial order; its facts are available under ‹order_vec›.›
interpretation order_vec: order less_eq less
  by (standard, auto simp add: less_def less_eq_def dual_order.antisym nth_equalityI) (force)

lemma less_eqI [intro?]: "length xs = length ys ⟹ ∀i<length xs. xs ! i ≤ ys ! i ⟹ xs ≤⇩v ys"
  by (auto simp: less_eq_def)

lemma le0 [simp, intro]: "zeroes (length xs) ≤⇩v xs" by (simp add: less_eq_def)

lemma le_list_update [simp]:
  assumes "xs ≤⇩v ys" and "i < length ys" and "z ≤ ys ! i"
  shows "xs[i := z] ≤⇩v ys"
  using assms by (auto simp: less_eq_def nth_list_update)

lemma le_Cons: "x # xs ≤⇩v y # ys ⟷ x ≤ y ∧ xs ≤⇩v ys"
  by (auto simp add: less_eq_def nth_Cons split: nat.splits)

lemma zero_less:
  assumes "nonzero x"
  shows "zeroes (length x) <⇩v x"
  using assms and eq_0_iff order_vec.dual_order.strict_iff_order
  by (auto simp: nonzero_iff)

lemma le_append:
  assumes "length xs = length vs"
  shows "xs @ ys ≤⇩v vs @ ws ⟷ xs ≤⇩v vs ∧ ys ≤⇩v ws"
  using assms
  by (auto simp: less_eq_def nth_append)
    (metis add.commute add_diff_cancel_left' nat_add_left_cancel_less not_add_less2)

lemma less_Cons:
  "(x # xs) <⇩v (y # ys) ⟷ length xs = length ys ∧ (x ≤ y ∧ xs <⇩v ys ∨ x < y ∧ xs ≤⇩v ys)"
  by (simp add: less_def less_eq_def All_less_Suc2) (auto dest: leD)

lemma le_length [dest]:
  assumes "xs ≤⇩v ys"
  shows "length xs = length ys"
  using assms by (simp add: less_eq_def)

lemma less_length [dest]:
  assumes "x <⇩v y"
  shows "length x = length y"
  using assms by (auto simp: less_def)

lemma less_append:
  assumes "xs <⇩v vs" and "ys ≤⇩v ws"
  shows "xs @ ys <⇩v vs @ ws"
proof -
  have "length xs = length vs"
    using assms by blast
  then show ?thesis
    using assms by (induct xs vs rule: list_induct2) (auto simp: less_Cons le_append le_length)
qed

lemma less_appendD:
  assumes "xs @ ys <⇩v vs @ ws"
    and "length xs = length vs"
  shows "xs <⇩v vs ∨ ys <⇩v ws"
  by (auto) (metis (no_types, lifting) assms le_append order_vec.order.strict_iff_order)

― ‹A strict decrease of a concatenation is strict in at least one of the two parts.›
lemma less_append_cases:
  assumes "xs @ ys <⇩v vs @ ws" and "length xs = length vs"
  obtains "xs <⇩v vs" and "ys ≤⇩v ws" | "xs ≤⇩v vs" and "ys <⇩v ws"
  using assms and that
  by (metis le_append less_appendD order_vec.order.strict_implies_order)

lemma less_append_swap:
  assumes "x @ y <⇩v u @ v"
    and "length x = length u"
  shows "y @ x <⇩v v @ u"
  using assms(2, 1)
  by (induct x u rule: list_induct2)
    (auto simp: order_vec.order.strict_iff_order le_Cons le_append le_length)

lemma le_sum_list_less:
  assumes "xs ≤⇩v ys"
    and "sum_list xs < sum_list ys"
  shows "xs <⇩v ys"
proof -
  have "length xs = length ys" and "∀i<length ys. xs ! i ≤ ys ! i"
    using assms by (auto simp: less_eq_def)
  then show ?thesis
    using ‹sum_list xs < sum_list ys›
    by (induct xs ys rule: list_induct2)
      (auto simp: less_Cons All_less_Suc2 less_eq_def)
qed

― ‹The inner product is monotone w.r.t. ‹≤⇩v› in either argument.›
lemma dotprod_le_right:
  assumes "v ≤⇩v w"
    and "length b = length w"
  shows "b ∙ v ≤ b ∙ w"
  using assms by (auto simp: dotprod_def less_eq_def intro: sum_mono)

lemma dotprod_pointwise_le_right:
  assumes "length z = length u"
    and "length u = length v"
    and "∀i<length v. u ! i ≤ v ! i"
  shows "z ∙ u ≤ z ∙ v"
  using assms by (intro dotprod_le_right) (auto intro: less_eqI)

lemma dotprod_le_left:
  assumes "v ≤⇩v w"
    and "length b = length w"
  shows "v ∙ b ≤ w ∙ b"
  using assms by (simp add: dotprod_le_right dotprod_commute le_length)

lemma dotprod_le:
  assumes "x ≤⇩v u" and "y ≤⇩v v"
    and "length y = length x" and "length v = length u"
  shows "x ∙ y ≤ u ∙ v"
  using assms by (metis dotprod_le_left dotprod_le_right le_length le_trans)

― ‹Strict monotonicity needs a coefficient vector ‹b› without zero entries.›
lemma dotprod_less_left:
  assumes "length b = length w"
    and "0 ∉ set b"
    and "v <⇩v w"
  shows "v ∙ b < w ∙ b"
proof -
  have "length v = length w" using assms
    using less_eq_def order_vec.order.strict_implies_order by blast
  then show ?thesis
    using assms
  proof (induct v w arbitrary: b rule: list_induct2)
    case (Cons x xs y ys)
    then show ?case
      by (cases b) (auto simp: less_Cons add_mono_thms_linordered_field dotprod_le_left)
  qed simp
qed

lemma le_append_swap:
  assumes "length y = length v"
    and "x @ y ≤⇩v w @ v"
  shows "y @ x ≤⇩v v @ w"
proof -
  have "length w = length x" using assms by auto
  with assms show ?thesis
    by (induct y v arbitrary: x w rule: list_induct2) (auto simp: le_Cons le_append)
qed

lemma le_append_swap_iff:
  assumes "length y = length v"
  shows "y @ x ≤⇩v v @ w ⟷ x @ y ≤⇩v w @ v"
  using assms and le_append_swap
  by (auto) (metis (no_types, lifting) add_left_imp_eq le_length length_append)

― ‹Any vector strictly below ‹(zeroes n)[i := b]› is below ‹b› at position ‹i› and zero
  at every other position below ‹n›.›
lemma unit_less:
  assumes "i < n"
    and "x <⇩v (zeroes n)[i := b]"
  shows "x ! i < b ∧ (∀j<n. j ≠ i ⟶ x ! j = 0)"
proof
  show "x ! i < b"
    using assms less_def by fastforce
next
  have "x ≤⇩v (zeroes n)[i := b]" by (simp add: assms order_vec.less_imp_le)
  then show "∀j<n. j ≠ i ⟶ x ! j = 0" by (auto simp: less_eq_def)
qed

lemma le_sum_list_mono:
  assumes "xs ≤⇩v ys"
  shows "sum_list xs ≤ sum_list ys"
  using assms and sum_list_mono [of "[0..<length ys]" "(!) xs" "(!) ys"]
  by (auto simp: less_eq_def) (metis map_nth)

lemma sum_list_less_diff_Ex:
  assumes "u ≤⇩v y"
    and "sum_list u < sum_list y"
  shows "∃i<length y. u ! i < y ! i"
proof -
  have "length u = length y" and "∀i<length y. u ! i ≤ y ! i"
    using ‹u ≤⇩v y› by (auto simp: less_eq_def)
  then show ?thesis
    using ‹sum_list u < sum_list y›
    by (induct u y rule: list_induct2) (force simp: Ex_less_Suc2 All_less_Suc2)+
qed

― ‹The sum of entries strictly increases along ‹<⇩v›.›
lemma less_vec_sum_list_less:
  assumes "v <⇩v w"
  shows "sum_list v < sum_list w"
proof -
  have "length v = length w"
    using assms less_eq_def less_imp_le by blast
  then show ?thesis
    using assms
  proof (induct v w rule: list_induct2)
    case (Cons x xs y ys)
    then show ?case
      using length_replicate less_Cons order_vec.order.strict_iff_order by force
  qed simp
qed

text ‹‹maxne0 x a› is the largest entry ‹a ! i› over the positions ‹i› with ‹x ! i ≠ 0›.
  It is ‹0› if ‹x› and ‹a› differ in length or ‹x› has no nonzero entry.›
definition maxne0 :: "nat list ⇒ nat list ⇒ nat"
  where
    "maxne0 x a =
      (if length x = length a ∧ (∃i<length a. x ! i ≠ 0)
      then Max {a ! i | i. i < length a ∧ x ! i ≠ 0}
      else 0)"

lemma maxne0_le_Max:
  "maxne0 x a ≤ Max (set a)"
  by (auto simp: maxne0_def nonzero_iff in_set_conv_nth) simp

lemma maxne0_Nil [simp]:
  "maxne0 [] as = 0"
  "maxne0 xs [] = 0"
  by (auto simp: maxne0_def)

― ‹Recursive characterization of ‹maxne0› on non-empty lists.›
lemma maxne0_Cons [simp]:
  "maxne0 (x # xs) (a # as) =
    (if length xs = length as then
      (if x = 0 then maxne0 xs as else max a (maxne0 xs as))
    else 0)"
proof -
  let ?a = "a # as" and ?x = "x # xs"
  ― ‹Split the index set into position ‹0› and the positions of the tails.›
  have eq: "{?a ! i | i. i < length ?a ∧ ?x ! i ≠ 0} =
    (if x > 0 then {a} else {}) ∪ {as ! i | i. i < length as ∧ xs ! i ≠ 0}"
    by (auto simp: nth_Cons split: nat.splits) (metis Suc_pred)+
  show ?thesis
    unfolding maxne0_def and eq
    by (auto simp: less_Suc_eq_0_disj nth_Cons' intro: Max_insert2)
qed

― ‹Despite the name, this is a non-strict bound: ‹b ∙ ys ≤ maxne0 ys b * sum_list ys›.›
lemma maxne0_times_sum_list_gt_dotprod:
  assumes "length b = length ys"
  shows "maxne0 ys b * sum_list ys ≥ b ∙ ys"
  using assms
  apply (induct b ys rule: list_induct2)
   apply (auto simp: max_def ring_distribs add_mono_thms_linordered_semiring(1))
  by (meson leI le_trans mult_less_cancel2 nat_less_le)

lemma max_times_sum_list_gt_dotprod:
  assumes "length b = length ys"
  shows "Max (set b) * sum_list ys ≥ b ∙ ys"
proof -
  have "∀e ∈ set b. Max (set b) ≥ e" by simp
  then have "replicate (length ys) (Max (set b)) ∙ ys ≥ b ∙ ys" (is "?rep ≥ _")
    by (metis assms dotprod_pointwise_le_right dotprod_commute
        length_replicate nth_mem nth_replicate)
  moreover have "Max (set b) * sum_list ys = ?rep"
    using replicate_dotprod [of ys _ "Max (set b)"] by auto
  ultimately show ?thesis
    by (simp add: assms)
qed

lemma maxne0_mono:
  assumes "y ≤⇩v x"
  shows "maxne0 y a ≤ maxne0 x a"
proof (cases "length y = length a")
  case True
  have "length y = length x" using assms by (auto)
  then show ?thesis
    using assms and True
  proof (induct y x arbitrary: a rule: list_induct2)
    case (Cons x xs y ys)
    then show ?case by (cases a) (force simp: less_eq_def All_less_Suc2 le_max_iff_disj)+
  qed simp
next
  case False
  ― ‹Then ‹maxne0 y a = 0› by definition.›
  then show ?thesis
    using assms by (auto simp: maxne0_def)
qed

lemma all_leq_Max:
  assumes "x ≤⇩v y"
    and "x ≠ []"
  shows "∀xi ∈ set x. xi ≤ Max (set y)"
  by (metis (no_types, lifting) List.finite_set Max_ge_iff
      assms in_set_conv_nth length_0_conv less_eq_def set_empty)

lemma le_not_less_replicate:
  "∀x∈set xs. x ≤ b ⟹ ¬ xs <⇩v replicate (length xs) b ⟹ xs = replicate (length xs) b"
  by (induct xs) (auto simp: less_Cons)

lemma le_replicateI: "∀x∈set xs. x ≤ b ⟹ xs ≤⇩v replicate (length xs) b"
  by (induct xs) (auto simp: le_Cons)

lemma le_take:
  assumes "x ≤⇩v y" and "i ≤ length x" shows "take i x ≤⇩v take i y"
  using assms by (auto simp: less_eq_def)

― ‹‹<⇩v› is well-founded: it is contained in ‹measure sum_list› by the lemma
  less_vec_sum_list_less.›
lemma wf_less:
  "wf {(x, y). x <⇩v y}"
proof -
  have "wf (measure sum_list)" ..
  moreover have "{(x, y). x <⇩v y} ⊆ measure sum_list"
    by (auto simp: less_vec_sum_list_less)
  ultimately show "wf {(x, y). x <⇩v y}"
    by (rule wf_subset)
qed


subsection ‹Pointwise Subtraction›

text ‹‹w -⇩v v› subtracts ‹v› from ‹w› entry by entry, using truncated subtraction on ‹nat›;
  the result has the length of ‹w›. Entries ‹v ! i› with ‹i ≥ length v› are unspecified, so
  the operation is meant for ‹v› at least as long as ‹w› (typically equal lengths).›
definition vdiff :: "nat list ⇒ nat list ⇒ nat list" (infixl ‹-⇩v› 65)
  where
    "w -⇩v v = map (λi. w ! i - v ! i) [0 ..< length w]"

lemma vdiff_Nil [simp]: "[] -⇩v [] = []" by (simp add: vdiff_def)

― ‹Helper: peel the first element off a non-empty interval ‹[j..<n]›.›
lemma upt_Cons_conv:
  assumes "j < n"
  shows "[j..<n] = j # [j+1..<n]"
  by (simp add: assms upt_eq_Cons_conv)

― ‹Helper: shift an interval starting at a successor into the mapped function.›
lemma map_upt_Suc: "map f [Suc m ..< Suc n] = map (f ∘ Suc) [m ..< n]"
  by (fold list.map_comp [of f "Suc" "[m ..< n]"]) (simp add: map_Suc_upt)

lemma vdiff_Cons [simp]:
  "(x # xs) -⇩v (y # ys) = (x - y) # (xs -⇩v ys)"
  by (simp add: vdiff_def upt_Cons_conv [OF zero_less_Suc] map_upt_Suc del: upt_Suc)

lemma vdiff_alt_def:
  assumes "length w = length v"
  shows "w -⇩v v = map (λ(x, y). x - y) (zip w v)"
  using assms by (induct rule: list_induct2) simp_all

― ‹Truncated subtraction distributes over the inner product when ‹v ≤⇩v w›, since then no
  entry is truncated.›
lemma vdiff_dotprod_distr:
  assumes "length b = length w"
    and "v ≤⇩v w"
  shows "(w -⇩v v) ∙ b = w ∙ b - v ∙ b"
proof -
  have "length v = length w" and "∀i<length w. v ! i ≤ w ! i"
    using assms less_eq_def by auto
  then show ?thesis
    using ‹length b = length w›
  proof (induct v w arbitrary: b rule: list_induct2)
    case (Cons x xs y ys)
    then show ?case
      by (cases b) (auto simp: All_less_Suc2 diff_mult_distrib
          dotprod_commute dotprod_pointwise_le_right)
  qed simp
qed

lemma sum_list_vdiff_distr [simp]:
  assumes "v ≤⇩v u"
  shows "sum_list (u -⇩v v) = sum_list u - sum_list v"
  by (metis (no_types, lifting) assms diff_zero dotprod_1_right
      length_map length_replicate length_upt
      less_eq_def vdiff_def vdiff_dotprod_distr)

lemma vdiff_le:
  assumes "v ≤⇩v w"
    and "length v = length x"
  shows "v -⇩v x ≤⇩v w"
  using assms by (auto simp add: less_eq_def vdiff_def)

― ‹Vector analogue of mods_with_nats, with the coefficient vector on the right.›
lemma mods_with_vec:
  assumes "v <⇩v w"
    and "0 ∉ set b"
    and "length b = length w"
    and "(v ∙ b) mod a = (w ∙ b) mod a"
  shows "((w -⇩v v) ∙ b) mod a = 0"
proof -
  have *: "v ∙ b < w ∙ b"
    using dotprod_less_left and assms by blast
  have "v ≤⇩v w"
    using assms by auto
  from vdiff_dotprod_distr [OF assms(3) this]
  have "((w -⇩v v) ∙ b) mod a = (w ∙ b - v ∙ b) mod a"
    by simp
  also have "… = 0 mod a"
    ― ‹Variables of mods_with_nats are instantiated in the order ‹w›, ‹v›, ‹b›, ‹a›,
      because ‹v > w› abbreviates ‹w < v›.›
    using mods_with_nats [of "v ∙ b" "w ∙ b" "1" a, OF *] assms by auto
  finally show ?thesis by simp
qed

― ‹Same as mods_with_vec, with the coefficient vector on the left.›
lemma mods_with_vec_2:
  assumes "v <⇩v w"
    and "0 ∉ set b"
    and "length b = length w"
    and "(b ∙ v) mod a = (b ∙ w) mod a"
  shows "(b ∙ (w -⇩v v)) mod a = 0"
  by (metis (no_types, lifting) assms diff_zero dotprod_commute
      length_map length_upt less_eq_def order_vec.less_imp_le
      mods_with_vec vdiff_def)


subsection ‹The Lexicographic Order on Vectors›

text ‹‹xs <⇩l⇩e⇩x ys› is the lexicographic order induced by ‹<› on ‹nat›; it only relates
  lists of equal length (see lex_lengthD). ‹<⇩r⇩l⇩e⇩x› applies it to the reversed lists,
  i.e. it compares from the last position towards the first.›
abbreviation lex_less_than (‹_/ <⇩l⇩e⇩x _› [51, 51] 50)
  where
    "xs <⇩l⇩e⇩x ys ≡ (xs, ys) ∈ lex less_than"

definition rlex (infix ‹<⇩r⇩l⇩e⇩x› 50)
  where
    "xs <⇩r⇩l⇩e⇩x ys ⟷ rev xs <⇩l⇩e⇩x rev ys"

lemma rev_le [simp]:
  "rev xs ≤⇩v rev ys ⟷ xs ≤⇩v ys"
proof -
  { fix i assume i: "i < length ys" and [simp]: "length xs = length ys"
      and "∀i < length ys. rev xs ! i ≤ rev ys ! i"
    ― ‹Position ‹i› of the original lists is position ‹length ys - i - 1› of the reversed ones.›
    then have "rev xs ! (length ys - i - 1) ≤ rev ys ! (length ys - i - 1)" by auto
    then have "xs ! i ≤ ys ! i" using i by (auto simp: rev_nth) }
  then show ?thesis by (auto simp: less_eq_def rev_nth)
qed

lemma rev_less [simp]:
  "rev xs <⇩v rev ys ⟷ xs <⇩v ys"
  by (simp add: less_def)

― ‹The pointwise strict order refines into the lexicographic order.›
lemma less_imp_lex:
  assumes "xs <⇩v ys" shows "xs <⇩l⇩e⇩x ys"
proof -
  have "length ys = length xs" using assms by auto
  then show ?thesis using assms
    by (induct rule: list_induct2) (auto simp: less_Cons)
qed

lemma less_imp_rlex:
  assumes "xs <⇩v ys" shows "xs <⇩r⇩l⇩e⇩x ys"
  using assms and less_imp_lex [of "rev xs" "rev ys"]
  by (simp add: rlex_def)

lemma lex_not_sym:
  assumes "xs <⇩l⇩e⇩x ys"
  shows "¬ ys <⇩l⇩e⇩x xs"
proof
  assume "ys <⇩l⇩e⇩x xs"
  then obtain i where "i < length xs" and "take i xs = take i ys"
    and "ys ! i < xs ! i" by (elim lex_take_index) auto
  moreover obtain j where "j < length xs" and "length ys = length xs" and "take j xs = take j ys"
    and "xs ! j < ys ! j" using assms by (elim lex_take_index) auto
  ― ‹Both ‹i› and ‹j› must be the first difference, which contradicts the opposite orders.›
  ultimately show False by (metis le_antisym nat_less_le nat_neq_iff nth_take)
qed

lemma rlex_not_sym:
  assumes "xs <⇩r⇩l⇩e⇩x ys"
  shows "¬ ys <⇩r⇩l⇩e⇩x xs"
proof
  assume ass: "ys <⇩r⇩l⇩e⇩x xs"
  ― ‹The assumptions are already contradictory via lex_not_sym; the obtained facts just
    make the final step explicit.›
  then obtain i where "i < length xs" and "take i xs = take i ys"
    and "ys ! i > xs ! i" using assms lex_not_sym rlex_def by blast
  moreover obtain j where "j < length xs" and "length ys = length xs" and "take j xs = take j ys"
    and "xs ! j > ys ! j" using assms rlex_def ass lex_not_sym by blast
  ultimately show False
    by (metis leD nat_less_le nat_neq_iff nth_take)
qed

lemma lex_trans:
  assumes "x <⇩l⇩e⇩x y" and "y <⇩l⇩e⇩x z"
  shows "x <⇩l⇩e⇩x z"
  using assms by (auto simp: antisym_def intro: transD [OF lex_transI])

lemma rlex_trans:
  assumes "x <⇩r⇩l⇩e⇩x y" and "y <⇩r⇩l⇩e⇩x z"
  shows "x <⇩r⇩l⇩e⇩x z"
  using assms lex_trans rlex_def by blast

lemma lex_append_rightD:
  assumes "xs @ us <⇩l⇩e⇩x ys @ vs" and "length xs = length ys"
    and "¬ xs <⇩l⇩e⇩x ys"
  shows "ys = xs ∧ us <⇩l⇩e⇩x vs"
  using assms(2,1,3)
  by (induct xs ys rule: list_induct2) auto

― ‹Since ‹<⇩r⇩l⇩e⇩x› compares from the end, the head elements only matter when the tails agree.›
lemma rlex_Cons:
  "x # xs <⇩r⇩l⇩e⇩x y # ys ⟷ xs <⇩r⇩l⇩e⇩x ys ∨ ys = xs ∧ x < y" (is "?A = ?B")
  by (cases "length ys = length xs")
    (auto simp: rlex_def intro: lex_append_rightI lex_append_leftI dest: lex_append_rightD lex_lengthD)

lemma rlex_irrefl:
  "¬ x <⇩r⇩l⇩e⇩x x"
  by (induct x) (auto simp: rlex_def dest: lex_append_rightD)


subsection ‹Code Equations›

text ‹‹exists2 d P xs ys› holds if ‹P› holds for the entries at some common position of ‹xs›
  and ‹ys›; if no such position exists before one list runs out while the other does not,
  the result is the default ‹d›.›
fun exists2
  where
    "exists2 d P [] [] ⟷ False"
  | "exists2 d P (x#xs) (y#ys) ⟷ P x y ∨ exists2 d P xs ys"
  | "exists2 d P _ _ ⟷ d"

― ‹Executable negation of ‹≤⇩v›: differing lengths, or some entry of ‹xs› exceeding the
  corresponding entry of ‹ys›.›
lemma not_le_code [code_unfold]: "¬ xs ≤⇩v ys ⟷ exists2 True (>) xs ys"
  by (induct "True" "(>) :: nat ⇒ nat ⇒ bool" xs ys rule: exists2.induct) (auto simp: le_Cons)

end