Theory Tensor_Scalar_Mult

theory Tensor_Scalar_Mult
imports Tensor_Plus
(* Author: Alexander Bentkamp, Universität des Saarlandes
*)
section ‹Tensor Scalar Multiplication›

theory Tensor_Scalar_Mult
imports Tensor_Plus Tensor_Subtensor
begin

(* Scalar multiplication on the list level: vec_smult α β maps left
   multiplication by α over every element of the list β. *)
definition vec_smult::"'a::ring ⇒ 'a list ⇒ 'a list" where
"vec_smult α β = map ((*) α) β"

(* Scaling a list by 0 yields vec0 of the same length; shown by induction
   on the list after exposing both definitions to the simplifier. *)
lemma vec_smult0:
  shows "vec_smult 0 as = vec0 (length as)"
  by (induction as) (auto simp add: vec_smult_def vec0_def)

(* Scaling by a sum of two scalars equals the vec_plus of the two separately
   scaled lists. Both definitions are unfolded first; induction on the list
   then leaves goals that simp closes with right distributivity. *)
lemma vec_smult_distr_right:
  "vec_smult (α + β) as = vec_plus (vec_smult α as) (vec_smult β as)"
  unfolding vec_plus_def vec_smult_def
  by (induction as) (simp_all add: distrib_right)

(* Unfolding equation for vec_smult on a non-empty list: the head is
   multiplied by α and the tail is scaled recursively. *)
lemma vec_smult_Cons:
  "vec_smult α (a # as) = (α * a) # vec_smult α as"
  unfolding vec_smult_def by simp

(* Unfolding equation for vec_plus on two non-empty lists: the heads are
   added and the tails are combined recursively. *)
lemma vec_plus_Cons:
  "vec_plus (a # as) (b # bs) = (a+b) # vec_plus as bs"
  unfolding vec_plus_def by simp

(* Scaling distributes over vec_plus of two lists of equal length.
   Proved by induction on as, with bs generalised so that the induction
   hypothesis can be applied to the tail of bs. *)
lemma vec_smult_distr_left:
assumes "length as = length bs"
shows "vec_smult α (vec_plus as bs) = vec_plus (vec_smult α as) (vec_smult α bs)"
using assms proof (induction as arbitrary:bs)
  case Nil
  then show ?case unfolding vec_smult_def vec_plus_def by simp
next
  case (Cons a as')
  (* bs has the same length as a # as', so it splits into a head and a tail *)
  then obtain b bs' where "bs = b # bs'" by (metis Suc_length_conv)
  (* fact 0: push the scaling through the first element of the sum *)
  then have 0:"vec_smult α (vec_plus (a # as') bs) = (α*(a+b)) # vec_smult α (vec_plus as' bs')"
    unfolding vec_smult_def vec_plus_def using Cons.IH[of bs'] by simp
  (* the tails have equal length, which is the premise the induction hypothesis needs *)
  have "length bs' = length as'" using Cons.prems ‹bs = b # bs'› by auto
  then show ?case unfolding 0 unfolding  ‹bs = b # bs'› vec_smult_Cons vec_plus_Cons
    by (simp add: Cons.IH distrib_left)
qed

(* vec_smult preserves the length of the list it scales. *)
lemma length_vec_smult:
  "length (vec_smult α v) = length v"
  by (simp add: vec_smult_def)

(* Scalar multiplication on tensors, written with the infix operator ⋅
   (left-associative, priority 70). The result is built by tensor_from_vec
   from the dimensions of A and the entry list of A scaled with vec_smult. *)
definition smult::"'a::ring ⇒ 'a tensor ⇒ 'a tensor" (infixl "⋅" 70) where
"smult α A = (tensor_from_vec (dims A) (vec_smult α (vec A)))"


(* Scaling a tensor by 0 gives tensor0 with the dimensions of A.
   After unfolding the tensor-level definitions, metis combines vec_smult0
   with length_vec; vec_smult_def is passed so the unfolded form can be
   related back to vec_smult. *)
lemma tensor_smult0: fixes A::"'a::ring tensor"
shows "0 ⋅ A = tensor0 (dims A)"
  unfolding smult_def tensor0_def vec_smult_def using vec_smult0 length_vec
  by (metis (no_types) vec_smult_def)

(* Two simp rules describing the components of a scaled tensor:
   dims_smult - scaling leaves the dimensions unchanged;
   vec_smult  - the entry list of the scaled tensor is the entry list of A
                with left multiplication by α mapped over it.
   Both goals are closed by the same simp call, repeated with +. *)
lemma dims_smult[simp]:"dims (α ⋅ A) = dims A"
and   vec_smult[simp]: "vec  (α ⋅ A) = map ((*) α) (vec A)"
  unfolding smult_def vec_smult_def by (simp add: length_vec)+

(* Scaling a tensor by a sum of two scalars equals the sum of the two
   separately scaled tensors. Tensor addition is unfolded (plus_def,
   plus_base_def); the goals left by auto are closed by metis using the
   list-level lemma vec_smult_distr_right. *)
lemma tensor_smult_distr_right: "(α + β) ⋅ A = α ⋅ A  + β ⋅ A"
  unfolding plus_def plus_base_def
  by (auto; metis smult_def vec_smult_def vec_smult_distr_right)

(* Scaling distributes over the sum of two tensors with equal dimensions.
   The proof collects a few length and dimension facts and then reduces to
   the list-level lemma vec_smult_distr_left. *)
lemma tensor_smult_distr_left: "dims A = dims B ⟹ α ⋅ (A + B) = α ⋅ A  + α ⋅ B"
proof -
  assume a1: "dims A = dims B"
  (* f2: the vec_plus of the two entry lists is as long as vec A *)
  then have f2: "length (vec_plus (vec A) (vec B)) = length (vec A)"
    by (simp add: length_vec vec_plus_def)
  (* f3: rebuilding the scaled entries of A with the dimensions of B keeps those dimensions *)
  have f3: "dims (tensor_from_vec (dims B) (vec_smult α (vec A))) = dims B"
    using a1 by (simp add: length_vec vec_smult_def)
  (* f4: the entry list of the scaled tensor, restated in terms of vec_smult *)
  have f4: "vec (α ⋅ A) = vec_smult α (vec A)"
    by (simp add: vec_smult_def)
  (* scaling does not change the length of the entry list of B *)
  have "length (vec_smult α (vec B)) = length (vec B)"
    by (simp add: vec_smult_def)
  then show ?thesis
    unfolding plus_def plus_base_def using f4 f3 f2 a1
    by (simp add: length_vec smult_def vec_smult_distr_left)
qed

(* Taking a fixed-length sublist commutes with scaling the list.
   The simplifier unfolds both definitions and moves the map past
   take and drop via take_map and drop_map. *)
lemma smult_fixed_length_sublist:
  assumes "length xs = l * c" and "i<c"
  shows "fixed_length_sublist (vec_smult α xs) l i = vec_smult α (fixed_length_sublist xs l i)"
  by (simp add: fixed_length_sublist_def vec_smult_def take_map drop_map)

(* Scaling commutes with taking the i-th subtensor, provided A has
   non-empty dimensions and i is below the first dimension.
   Proved with tensor_eqI by showing that both sides have the same
   dimensions and the same entry list. *)
lemma smult_subtensor:
assumes "dims A ≠ []" "i < hd (dims A)"
shows "α ⋅ subtensor A i = subtensor (α ⋅ A) i"
proof (rule tensor_eqI)
  (* dimensions agree: scaling does not change dims (dims_smult) *)
  show "dims (α ⋅ subtensor A i) = dims (subtensor (α ⋅ A) i)"
    using dims_smult dims_subtensor assms(1) assms(2) by simp
  (* entry lists agree: rewrite with vec_subtensor, then move the map
     past take and drop *)
  show "vec (α ⋅ subtensor A i) = vec (subtensor (α ⋅ A) i)"
    unfolding vec_smult
    unfolding vec_subtensor[OF ‹dims A ≠ []› ‹i < hd (dims A)›]
    using vec_subtensor[of "α ⋅ A" i]
    by (simp add: assms(1) assms(2) drop_map fixed_length_sublist_def take_map)
qed

(* Looking up an entry of a scaled tensor gives α times the corresponding
   entry of the original tensor, assuming is ⊲ dims A.
   Proved by the induction rule subtensor_induct, which yields the two
   cases order_0 and order_step handled below. *)
lemma lookup_smult:
assumes "is ⊲ dims A"
shows "lookup (α ⋅ A) is = α * lookup A is"
using assms proof (induction A arbitrary:"is" rule:subtensor_induct)
  case (order_0 A "is")
  (* in this case the entry list has exactly one element *)
  then have "length (vec A) = 1" by (simp add: length_vec)
  then have "hd (vec_smult α (vec A)) = α * hd (vec A)" unfolding vec_smult_def by (metis list.map_sel(1) list.size(3) zero_neq_one)
  moreover have "is = []" using order_0 by auto
  ultimately show ?case unfolding smult_def by (auto simp add: ‹length (Tensor.vec A) = 1› lookup_def length_vec_smult order_0.hyps)
next
  case (order_step A "is")
  (* split the index list into its first component and the remainder *)
  then obtain i is' where "is = i # is'" by blast
  (* induction hypothesis applied to the i-th subtensor and the remaining indices *)
  then have "lookup (α ⋅ subtensor A i) is' = α * lookup (subtensor A i) is'"
    by (metis (no_types, lifting) dims_subtensor list.sel(1) list.sel(3) order_step.IH order_step.hyps order_step.prems valid_index_dimsE)
  (* smult_subtensor swaps scaling and subtensor so the lookup can be lifted back to A *)
  then show ?case using smult_subtensor ‹is = i # is'› dims_smult lookup_subtensor1
    list.sel(1) order_step.hyps order_step.prems valid_index_dimsE
    by metis
qed

(* Scaling twice equals scaling once by the product of the two scalars.
   The rule tensor_lookup_eqI is applied first; simp then handles the
   first resulting goal and metis finishes using lookup_smult, dims_smult
   and associativity of multiplication (mult.assoc). *)
lemma tensor_smult_assoc:
fixes A::"'a::ring tensor"
shows "α ⋅ (β ⋅ A) = (α * β) ⋅ A"
by (rule tensor_lookup_eqI, simp, metis lookup_smult dims_smult mult.assoc)

end