(* Theory Tensor_Product *)

(* Author: Alexander Bentkamp, Universität des Saarlandes
*)
section ‹Tensor Product›

theory Tensor_Product
imports Tensor_Scalar_Mult Tensor_Subtensor
begin

(* Tensor product as the multiplication of the semigroup_mult instance for tensors
   over a ring.  The block defines the product, introduces infix notation for it,
   proves how it interacts with dims/vec/subtensor/lookup, and finally shows
   associativity (the semigroup law).

   NOTE(review): the non-ASCII Isabelle symbols in this block had been stripped.
   They are restored here; the infix symbol for valid_index (written ⊲ below) is
   declared in an imported theory -- confirm the glyph against that declaration. *)
instantiation tensor:: (ring) semigroup_mult
begin
  (* Dimensions of the product are the concatenated dimension lists; the entry vector is,
     for each entry a of vec A in order, the vector vec B scaled by a, all concatenated. *)
  definition tensor_prod_def:"A * B = tensor_from_vec (dims A @ dims B) (concat (map (λa. vec_smult a (vec B)) (vec A)))"

  (* Infix notation for the product defined above. *)
  abbreviation tensor_prod_otimes :: "'a tensor ⇒ 'a tensor ⇒ 'a tensor" (infixl "⊗" 70)
    where "A ⊗ B ≡ A * B"


  (* Simp rules exposing the entry vector and the dimensions of a product.
     Both follow from tensor_prod_def once the length of the concatenated vector is
     shown to equal the product of the concatenated dimensions. *)
  lemma vec_tensor_prod[simp]: "vec (A ⊗ B) = concat (map (λa. vec_smult a (vec B)) (vec A))" (is ?V)
  and dims_tensor_prod[simp]: "dims (A ⊗ B) = dims A @ dims B" (is ?D)
  proof -
    have "length (concat (map (λa. vec_smult a (vec B)) (vec A))) = prod_list (dims A @ dims B)"
    proof -
      (* every scaled copy of vec B has the length of vec B *)
      have "⋀xs. xs ∈ set (map (λa. vec_smult a (vec B)) (vec A)) ⟹ length xs = length (vec B)"
        using length_vec_smult by force
      then show ?thesis using concat_equal_length by (metis length_map length_vec prod_list.append)
    qed
    then show ?V ?D by (simp add: tensor_prod_def)+
  qed

  (* List-level fact: mapping-and-concatenating over a concatenation equals doing so
     per inner list and concatenating the results. *)
  lemma tensorprod_subtensor_base:
  shows "concat (map f (concat xss)) = concat (map (λxs. concat (map f xs)) xss)"
  by (induction xss; auto)

  (* Multiplying a combined tensor by B from the right equals combining the products
     of the individual subtensors with B. *)
  lemma subtensor_combine_tensor_prod:
  assumes "⋀A. A∈set As ⟹ dims A = ds"
  shows "subtensor_combine ds As ⊗ B = subtensor_combine (ds @ dims B) (map (λA. A ⊗ B) As)"
  proof -
    let ?f = "λa. vec_smult a (Tensor.vec B)"
    let ?xss = "map Tensor.vec As"
    have 1:"prod_list (length As # ds) = length (concat ?xss)" by (metis assms length_vec subtensor_combine_dims subtensor_combine_vec)
    have 2:"⋀A. A∈set As ⟹ prod_list (dims A @ dims B) = length (concat (map ?f (Tensor.vec A)))"
      by (metis dims_tensor_prod length_vec vec_tensor_prod)
    have 3: "length As # ds @ dims B = (length (map (λA. tensor_from_vec (dims A @ dims B)
      (concat (map (λa. vec_smult a (vec B)) (vec A)))) As) # ds @ dims B)" by simp
    have 4:"(concat (map (λxs. concat (map (λa. vec_smult a (vec B)) xs)) (map vec As)))
         = (concat (map vec (map (λA. tensor_from_vec (dims A @ dims B) (concat (map (λa. vec_smult a (vec B)) (vec A))))  As)))"
      unfolding map_map[unfolded comp_def]  using vec_tensor by (metis (no_types, lifting) "2" map_eq_conv)
    have "subtensor_combine ds As ⊗ B = tensor_from_vec (length As # ds @ dims B) (concat (map ?f (concat (?xss))))"
      unfolding subtensor_combine_def tensor_prod_def using 1 by auto
    also have "... = tensor_from_vec (length As # ds @ dims B) (concat (map (λxs. concat (map ?f xs)) ?xss))"
      using tensorprod_subtensor_base[of ?f ?xss] by auto
    also have "... =subtensor_combine (ds @ dims B) (map (λA. A ⊗ B) As)"
      unfolding subtensor_combine_def tensor_prod_def using 3 4 by metis
    finally show ?thesis by metis
  qed

  (* Possibly better to prove this without induction? Then the lemma above would be
     unnecessary. But that might also be difficult. *)
  (* Taking the i-th subtensor of a product equals taking the i-th subtensor of the left
     factor and then multiplying by the right factor. *)
  lemma subtensor_tensor_prod:
  assumes "dims A ≠ []" and "i < hd (dims A)"
  shows "subtensor (A ⊗ B) i = subtensor A i ⊗ B"
  using assms proof (induction A rule:subtensor_combine_induct)
    case order_0
    then show ?case by auto
  next
    case (order_step As ds)
    have 1:"i < length (map (λA. A ⊗ B) As)" using order_step by (simp add: order_step.hyps order_step.prems(1))
    have 2:"(⋀A. A ∈ set (map (λA. A ⊗ B) As) ⟹ dims A = ds @ dims B)" using order_step by auto
    have "subtensor (subtensor_combine ds As ⊗ B) i = subtensor (subtensor_combine (ds @ dims B) (map (λA. A ⊗ B) As)) i"
      using subtensor_combine_tensor_prod order_step by metis
    also have "... = As ! i ⊗ B"
      using order_step subtensor_subtensor_combine[of "(map (λA. A ⊗ B) As)" "ds @ dims B" i] 1 2 by auto
    also have "... = subtensor (subtensor_combine ds As) i ⊗ B"
      by (metis "1" length_map order_step.hyps subtensor_subtensor_combine)
    finally show ?case by auto
  qed

  (* An entry of the product at a concatenated index is the product of the entries of
     the factors at the respective index parts, provided both parts are valid indices. *)
  lemma lookup_tensor_prod[simp]:
  assumes is1_valid:"is1 ⊲ dims A" and is2_valid:"is2 ⊲ dims B"
  shows "lookup (A ⊗ B) (is1 @ is2) = lookup A is1 * lookup B is2"
  using assms proof (induction A arbitrary:is1 rule:subtensor_induct)
    case (order_0 A is1)
    (* base case: A has a single entry a, so the product is B scaled by a *)
    then obtain a where "vec A = [a]"
      using Suc_length_conv Tensor.tensor_vec_from_lookup_Nil length_0_conv length_tensor_vec_from_lookup length_vec by metis
    then have "A ⊗ B = a ⋅ B" unfolding tensor_prod_def smult_def using order_0 by simp
    moreover have "lookup A [] = a" by (simp add: ‹Tensor.vec A = [a]› lookup_def order_0.hyps)
    ultimately have "lookup (A ⊗ B) (is2) = a * lookup B is2" by (simp add: lookup_smult is2_valid)
    then show ?case using ‹lookup A [] = a› null_rec(1) order_0.hyps order_0.prems(1) by auto
  next
    case (order_step A is1)
    (* step case: peel off the first index component and descend into the subtensor *)
    then obtain i is1' where "i # is1' = is1" by blast
    have "lookup (subtensor A i ⊗ B) (is1' @ is2) = lookup (subtensor A i) is1' * lookup B is2" using order_step
      by (metis ‹i # is1' = is1› dims_subtensor list.sel(1) list.sel(3) valid_index_dimsE)
    then show "lookup (A ⊗ B) (is1 @ is2) = lookup A is1 * lookup B is2"
      using lookup_subtensor1[of i is1' A] lookup_subtensor1[of i "is1' @ is2" "A⊗B"] subtensor_tensor_prod[of A i B]
      Cons_eq_appendI ‹i # is1' = is1› dims_tensor_prod is2_valid list.sel(1) order_step.hyps order_step.prems(1) valid_index_append valid_index_dimsE
      by metis
  qed

  (* A valid index for concatenated dimensions splits into a valid index for each part;
     the witnesses are take/drop at the length of the first dimension list. *)
  lemma valid_index_split:
  assumes "is ⊲ ds1 @ ds2"
  obtains is1 is2 where "is1 @ is2 = is" "is1 ⊲ ds1" "is2 ⊲ ds2"
  proof
    assume a: "⋀is1 is2. is1 @ is2 = is ⟹ is1 ⊲ ds1 ⟹ is2 ⊲ ds2 ⟹ thesis"
    have length_is:"length is = length ds1 + length ds2" using valid_index_length using assms by auto
    show "take (length ds1) is ⊲ ds1"
      apply (rule valid_indexI)
      using valid_index_length using assms apply auto[1]
      by (metis add_leD1 assms length_append not_less nth_append nth_take valid_index_lt)
    show "drop (length ds1) is ⊲ ds2"
      apply (rule valid_indexI)
      using valid_index_length using assms apply auto[1]
      using nth_drop[of "length ds1" "is"] valid_index_lt[OF assms(1)] nth_append[of ds1 ds2] length_is
      by (metis length_append nat_add_left_cancel_less nat_le_iff_add nth_append_length_plus)
    show "take (length ds1) is @ drop (length ds1) is = is" using length_is by auto
  qed

  (* Associativity, proved entry-wise: split a valid index into three parts and apply
     lookup_tensor_prod on both bracketings, then use associativity of ring multiplication. *)
  instance proof
    fix A B C::"'a::ring tensor"
    show "(A ⊗ B) ⊗ C = A ⊗ (B ⊗ C)"
    proof (rule tensor_lookup_eqI, simp)
      fix "is" assume "is ⊲ dims ((A ⊗ B) ⊗ C)"
      obtain is1 is23 where "is1 ⊲ dims A" "is23 ⊲ dims (B ⊗ C)" "is1 @ is23 = is"
        by (metis (mono_tags, lifting) ‹is ⊲ dims ((A ⊗ B) ⊗ C)› Tensor_Product.dims_tensor_prod append_assoc valid_index_split)
      obtain is2 is3 where "is2 ⊲ dims B" "is3 ⊲ dims C" "is2 @ is3 = is23"
        by (metis ‹is23 ⊲ dims (local.tensor_prod_otimes B C)› dims_tensor_prod valid_index_split)
      define is12 where "is12 = is1 @ is2"
      have "is12 ⊲ dims (A ⊗ B)" by (simp add: ‹is1 ⊲ dims A› ‹is2 ⊲ dims B› is12_def valid_index_append)
      have "is12 @ is3 = is" by (simp add: ‹is1 @ is23 = is› ‹is2 @ is3 = is23› is12_def)
      show "lookup ((A ⊗ B) ⊗ C) is = lookup (A ⊗ (B ⊗ C)) is"
        unfolding lookup_tensor_prod[OF ‹is1 ⊲ dims A› ‹is23 ⊲ dims (B ⊗ C)›, unfolded ‹is1 @ is23 = is›]
        lookup_tensor_prod[OF ‹is12 ⊲ dims (A ⊗ B)› ‹is3 ⊲ dims C›, unfolded ‹is12 @ is3 = is›]
        using ‹is1 ⊲ dims A› ‹is2 @ is3 = is23› ‹is2 ⊲ dims B› ‹is3 ⊲ dims C› is12_def mult.assoc by fastforce
    qed
  qed

end

(* Left distributivity of the tensor product over addition, for summands of equal
   dimensions.  Proved entry-wise: every valid index of the result splits into an index
   for the sum and one for C, and the claim reduces to ring distributivity.
   NOTE(review): stripped Isabelle symbols restored; the valid_index infix (⊲ here) is
   declared in an imported theory -- confirm the glyph. *)
lemma tensor_prod_distr_left:
assumes "dims A = dims B"
shows "(A + B) ⊗ C = (A ⊗ C) + (B ⊗ C)"
proof -
  (* both sides agree at every valid index *)
  have "⋀is. is ⊲ dims A @ dims C ⟹ lookup ((A + B) ⊗ C) is = lookup (A ⊗ C + B ⊗ C) is"
  proof -
    fix "is" assume "is ⊲ dims A @ dims C"
    obtain is1 is2 where "is = is1 @ is2" "is1 ⊲ dims A" "is2 ⊲ dims C" using valid_index_split using ‹is ⊲ dims A @ dims C› by blast
    then show "lookup ((A + B) ⊗ C) is = lookup ((A ⊗ C) + (B ⊗ C)) is"
     using lookup_plus
     ‹is1 ⊲ dims A› ‹is2 ⊲ dims C› assms plus_dim1 dims_tensor_prod lookup_tensor_prod ring_class.ring_distribs(2) valid_index_append
     by fastforce
  qed
  (* both sides are rebuilt from their lookup functions over the same dimensions *)
  moreover have "tensor_from_lookup (dims A @ dims C) (lookup ((A + B) ⊗ C)) = (A + B) ⊗ C"
       "tensor_from_lookup (dims A @ dims C) (lookup ((A ⊗ C) + (B ⊗ C))) = (A ⊗ C) + (B ⊗ C)"
    by (metis (no_types, lifting) assms plus_dim1 dims_tensor_prod tensor_lookup)+
  ultimately show ?thesis using tensor_from_lookup_eqI
    by (metis ‹⋀is. is ⊲ dims A @ dims C ⟹ lookup ((A + B) ⊗ C) is = lookup (A ⊗ C + B ⊗ C) is›)
qed

(* Right distributivity of the tensor product over addition, for summands of equal
   dimensions.  Proved entry-wise: every valid index of the result splits into an index
   for C and one for the sum, and the claim reduces to ring distributivity.
   NOTE(review): stripped Isabelle symbols restored; the valid_index infix (⊲ here) is
   declared in an imported theory -- confirm the glyph. *)
lemma tensor_prod_distr_right:
assumes "dims A = dims B"
shows "C ⊗ (A + B) = (C ⊗ A) + (C ⊗ B)"
proof -
  (* both sides agree at every valid index *)
  have "⋀is. is ⊲ dims C @ dims A ⟹ lookup (C ⊗ (A + B)) is = lookup (C ⊗ A + C ⊗ B) is"
  proof -
    fix "is" assume "is ⊲ dims C @ dims A"
    obtain is1 is2 where "is = is1 @ is2" "is1 ⊲ dims C" "is2 ⊲ dims A" using valid_index_split using ‹is ⊲ dims C @ dims A› by blast
    then show "lookup (C ⊗ (A + B)) is = lookup ((C ⊗ A) + (C ⊗ B)) is"
     using lookup_plus
     using ‹is2 ⊲ dims A› ‹is1 ⊲ dims C› assms plus_dim1 dims_tensor_prod lookup_tensor_prod ring_class.ring_distribs(1) valid_index_append
     by fastforce
  qed
  (* both sides are rebuilt from their lookup functions over the same dimensions *)
  moreover have "tensor_from_lookup (dims C @ dims A) (lookup (C ⊗ (A + B))) = C ⊗ (A + B)"
       "tensor_from_lookup (dims C @ dims A) (lookup ((C ⊗ A) + (C ⊗ B))) = (C ⊗ A) + (C ⊗ B)"
    by (metis (no_types, lifting) assms plus_dim1 dims_tensor_prod tensor_lookup)+
  ultimately show ?thesis using tensor_from_lookup_eqI
    by (metis ‹⋀is. is ⊲ dims C @ dims A ⟹ lookup (C ⊗ (A + B)) is = lookup (C ⊗ A + C ⊗ B) is›)
qed

(* Extends the multiplicative structure on tensors to a monoid (over ring_1):
   defines the unit element and proves that it is neutral on both sides. *)
instantiation tensor :: (ring_1) monoid_mult
begin
  (* The unit is the tensor with empty dimension list whose only entry is 1. *)
  definition tensor_one_def:"1 = tensor_from_vec [] [1]"

  (* Alternative characterisation of the unit via a constant lookup function;
     this is the form the instance proof below unfolds to. *)
  lemma tensor_one_from_lookup: "1 = tensor_from_lookup [] (λ_. 1)"
    unfolding tensor_one_def by (rule tensor_eqI; simp_all add: tensor_from_lookup_def )

  instance proof
    fix A::"'a::ring_1 tensor"
    (* right neutrality: entry-wise, using lookup_tensor_prod with the empty index on the right *)
    show "A * 1 = A" unfolding tensor_one_from_lookup
      by (rule tensor_lookup_eqI;metis lookup_tensor_prod[of _ "A" "[]" "tensor_from_lookup [] (λ_. 1)"]
          lookup_tensor_from_lookup valid_index.Nil append_Nil2 dims_tensor dims_tensor_prod
          length_tensor_vec_from_lookup mult.right_neutral tensor_from_lookup_def)
  next
    fix A::"'a::ring_1 tensor"
    (* left neutrality: entry-wise, using lookup_tensor_prod with the empty index on the left *)
    show "1 * A = A" unfolding tensor_one_from_lookup
      by (rule tensor_lookup_eqI; metis lookup_tensor_prod[of "[]" "tensor_from_lookup [] (λ_. 1)" _ "A"]
          lookup_tensor_from_lookup valid_index.Nil List.append.append_Nil dims_tensor dims_tensor_prod
          length_tensor_vec_from_lookup mult.left_neutral tensor_from_lookup_def)
  qed
end

(* The unit tensor has order 0; shown by unfolding tensor_one_def and simplifying. *)
lemma order_tensor_one: "order 1 = 0" unfolding tensor_one_def by simp

(* A scalar factor in front of a tensor product can be moved onto the left factor.
   Proved entry-wise via lookup_tensor_prod and lookup_smult.
   NOTE(review): stripped Isabelle symbols restored; the infix glyphs for smult (⋅ here)
   and valid_index (⊲ here) are declared in imported theories -- confirm them. *)
lemma smult_prod_extract1:
fixes a::"'a::comm_ring_1"
shows "a ⋅ (A ⊗ B) = (a ⋅ A) ⊗ B"
proof (rule tensor_lookup_eqI)
  show "dims (a ⋅ (A ⊗ B)) = dims ((a ⋅ A) ⊗ B)" by simp
  fix "is" assume "is ⊲ dims (a ⋅ (A ⊗ B))"
  then have "is ⊲ dims (A ⊗ B)" by auto
  (* split the index into the part addressing A and the part addressing B *)
  then obtain is1 is2 where "is1 ⊲ dims A" "is2 ⊲ dims B" "is = is1 @ is2" by (metis dims_tensor_prod valid_index_split)
  then have "is1 ⊲ dims (a ⋅ A)" by auto
  show "lookup (a ⋅ (A ⊗ B)) is = lookup (a ⋅ A ⊗ B) is"
  using lookup_tensor_prod[OF ‹is1 ⊲ dims A› ‹is2 ⊲ dims B›] lookup_tensor_prod[OF ‹is1 ⊲ dims (a ⋅ A)› ‹is2 ⊲ dims B›]
        lookup_smult[OF ‹is ⊲ dims (A ⊗ B)›] lookup_smult[OF ‹is1 ⊲ dims A›] ‹is = is1 @ is2› by simp
qed

(* A scalar factor in front of a tensor product can be moved onto the right factor.
   Proved entry-wise via lookup_tensor_prod and lookup_smult.
   NOTE(review): stripped Isabelle symbols restored; the infix glyphs for smult (⋅ here)
   and valid_index (⊲ here) are declared in imported theories -- confirm them. *)
lemma smult_prod_extract2:
fixes a::"'a::comm_ring_1"
shows "a ⋅ (A ⊗ B) = A ⊗ (a ⋅ B)"
proof (rule tensor_lookup_eqI)
  show "dims (a ⋅ (A ⊗ B)) = dims (A ⊗ (a ⋅ B))" by simp
  fix "is" assume "is ⊲ dims (a ⋅ (A ⊗ B))"
  then have "is ⊲ dims (A ⊗ B)" by auto
  (* split the index into the part addressing A and the part addressing B *)
  then obtain is1 is2 where "is1 ⊲ dims A" "is2 ⊲ dims B" "is = is1 @ is2" by (metis dims_tensor_prod valid_index_split)
  then have "is2 ⊲ dims (a ⋅ B)" by auto
  show "lookup (a ⋅ (A ⊗ B)) is = lookup (A ⊗ (a ⋅ B)) is"
  using lookup_tensor_prod[OF ‹is1 ⊲ dims A› ‹is2 ⊲ dims B›] lookup_tensor_prod[OF ‹is1 ⊲ dims A› ‹is2 ⊲ dims (a ⋅ B)›]
        lookup_smult[OF ‹is ⊲ dims (A ⊗ B)›] lookup_smult[OF ‹is2 ⊲ dims B›] ‹is = is1 @ is2› by simp
qed


(* Every tensor of order 0 is a scalar multiple of the unit tensor; the scalar is the
   single entry of its vector.
   NOTE(review): stripped Isabelle symbols restored; the infix glyph for smult (⋅ here)
   is declared in an imported theory -- confirm it. *)
lemma order_0_multiple_of_one:
assumes "order A = 0"
obtains a where "A = a ⋅ 1"
proof
  assume "(⋀a. A = a ⋅ 1 ⟹ thesis)"
  (* an order-0 tensor has exactly one entry *)
  have "length (vec A) = 1" using assms by (simp add:length_vec)
  then obtain a where "vec A = [a]" by (metis One_nat_def Suc_length_conv length_0_conv)
  moreover have "vec (a ⋅ 1) = [a]" unfolding smult_def tensor_one_def by (simp add: vec_smult_def)
  ultimately have "A = a ⋅ 1" using tensor_eqI by (metis assms dims_smult length_0_conv order_tensor_one)
  then show "A = hd (vec A) ⋅ 1" using ‹vec A = [a]› by auto
qed

(* Scaling a tensor by the ring's 1 leaves it unchanged.
   NOTE(review): stripped Isabelle symbol restored; the infix glyph for smult (⋅ here)
   is declared in an imported theory -- confirm it. *)
lemma smult_1:
fixes A::"'a::ring_1 tensor"
shows "A = 1 ⋅ A" unfolding smult_def tensor_one_def
apply (rule tensor_eqI)
apply (simp add: length_vec length_vec_smult)
by (metis dims_tensor length_vec length_vec_smult lookup_smult mult.left_neutral smult_def tensor_lookup_eqI)


(* Multiplying by a zero tensor on the right yields the zero tensor of the concatenated
   dimensions.  Proved entry-wise; each entry is a product with a zero entry.
   NOTE(review): stripped Isabelle symbols restored; the valid_index infix (⊲ here) is
   declared in an imported theory -- confirm the glyph. *)
lemma tensor0_prod_right[simp]: "A ⊗ tensor0 ds = tensor0 (dims A @ ds)"
proof (rule tensor_lookup_eqI,simp)
  fix "is" assume "is ⊲ dims (A ⊗ tensor0 ds)"
  then obtain is1 is2 where "is1 ⊲ dims A" "is2 ⊲ dims (tensor0 ds)" "is = is1 @ is2"
    by (metis dims_tensor0 dims_tensor_prod valid_index_split)
  then show "lookup (A ⊗ tensor0 ds) is = lookup (tensor0 (dims A @ ds)) is"
    by (metis (no_types, lifting) ‹is ⊲ dims (A ⊗ tensor0 ds)› dims_tensor0 dims_tensor_prod lookup_tensor0 lookup_tensor_prod mult_zero_right)
qed

(* Multiplying by a zero tensor on the left yields the zero tensor of the concatenated
   dimensions.  Proved entry-wise; each entry is a product with a zero entry.
   NOTE(review): stripped Isabelle symbols restored; the valid_index infix (⊲ here) is
   declared in an imported theory -- confirm the glyph. *)
lemma tensor0_prod_left[simp]: "tensor0 ds ⊗ A = tensor0 (ds @ dims A)"
proof (rule tensor_lookup_eqI,simp)
  fix "is" assume "is ⊲ dims (tensor0 ds ⊗ A)"
  then obtain is1 is2 where "is1 ⊲ dims (tensor0 ds)" "is2 ⊲ dims A" "is = is1 @ is2"
    by (metis dims_tensor0 dims_tensor_prod valid_index_split)
  then show "lookup (tensor0 ds ⊗ A) is = lookup (tensor0 (ds @ dims A)) is"
    by (metis (no_types, lifting) ‹is ⊲ dims (tensor0 ds ⊗ A)› dims_tensor0 dims_tensor_prod lookup_tensor0 lookup_tensor_prod mult_zero_left)
qed

(* For an order-1 tensor A, the i-th subtensor of the product with B is B scaled by the
   i-th entry of A.  Proved via tensor_lookup_eqI: first the dimensions agree, then the
   entries agree at every valid index.
   NOTE(review): stripped Isabelle symbols restored; the infix glyphs for smult (⋅ here)
   and valid_index (⊲ here) are declared in imported theories -- confirm them. *)
lemma subtensor_prod_with_vec:
assumes "order A = 1" "i < hd (dims A)"
shows "subtensor (A ⊗ B) i = lookup A [i] ⋅ B"
proof (rule tensor_lookup_eqI)
  have "dims (A ⊗ B) ≠ []" using assms(1) by auto
  have "hd (dims A) =  hd (dims (A ⊗ B))"
    by (metis One_nat_def Suc_length_conv append_Cons assms(1) dims_tensor_prod list.sel(1))
  show "dims (subtensor (A ⊗ B) i) = dims (lookup A [i] ⋅ B)"
    unfolding dims_smult dims_subtensor[OF ‹dims (A ⊗ B) ≠ []› ‹i < hd (dims A)›[unfolded ‹hd (dims A) =  hd (dims (A ⊗ B))›] ]
    by (metis One_nat_def Suc_length_conv append.simps(2) append_self_conv2 assms(1) dims_tensor_prod length_0_conv list.sel(3))
next
  fix "is" assume "is ⊲ dims (subtensor (A ⊗ B) i)"
  (* the two auxiliary facts are re-established because `next` discards the previous context *)
  have "dims (A ⊗ B) ≠ []" using assms(1) by auto
  have "hd (dims A) =  hd (dims (A ⊗ B))"
    by (metis One_nat_def Suc_length_conv append_Cons assms(1) dims_tensor_prod list.sel(1))
  then have "is ⊲ dims B"
    using ‹is ⊲ dims (subtensor (A ⊗ B) i)›[unfolded dims_subtensor[OF ‹dims (A ⊗ B) ≠ []› ‹i < hd (dims A)›[unfolded ‹hd (dims A) =  hd (dims (A ⊗ B))›] ]]
    by (metis One_nat_def Suc_length_conv append_self_conv2 assms(1) dims_tensor_prod length_0_conv list.sel(3) list.simps(3) tl_append2)
  have "[i] ⊲ dims A" using assms by (metis One_nat_def Suc_length_conv length_0_conv list.sel(1) valid_index.Nil valid_index.simps)
  then have "i # is ⊲ dims (A ⊗ B)" using ‹is ⊲ dims (subtensor (A ⊗ B) i)› dims_subtensor valid_index.Cons by auto
  then show "lookup (subtensor (A ⊗ B) i) is = lookup (lookup A [i] ⋅ B) is"
  unfolding lookup_subtensor1[OF ‹i # is ⊲ dims (A ⊗ B)›]
    using lookup_tensor_prod[OF ‹[i] ⊲ dims A› ‹is ⊲ dims B›] lookup_smult
    ‹is ⊲ dims B› using append_Cons by fastforce
qed

end