(* Theory Tensor_Product
   Author: Alexander Bentkamp, Universität des Saarlandes
*)
section ‹Tensor Product›

theory Tensor_Product
imports Tensor_Scalar_Mult Tensor_Subtensor
begin

instantiation tensor:: (ring) semigroup_mult
begin
  (* Tensor product, installed as the multiplication "*" of this instance.
     The result is built with tensor_from_vec: its dimension list is dims A @ dims B,
     and its entry list is the concatenation, over the entries a of vec A in order,
     of vec_smult a (vec B) (presumably vec B scaled entrywise by a; vec_smult is
     defined outside this theory). *)
  definition tensor_prod_def:"A * B = tensor_from_vec (dims A @ dims B) (concat (map (λa. vec_smult a (vec B)) (vec A)))"

  (* Infix notation "⊗" (left-associative, priority 70) for the tensor product.
     It is a pure abbreviation: "A ⊗ B" is just another way of writing "A * B". *)
  abbreviation tensor_prod_otimes :: "'a tensor ⇒ 'a tensor ⇒ 'a tensor" (infixl "⊗" 70)
    where "A ⊗ B ≡ A * B"


  (* Projections of a product, both declared as simp rules:
       - vec (A ⊗ B) is the concatenation of vec_smult a (vec B) for each entry a of vec A;
       - dims (A ⊗ B) is dims A @ dims B.
     The two statements are abbreviated ?V and ?D and proved together. *)
  lemma vec_tensor_prod[simp]: "vec (A ⊗ B) = concat (map (λa. vec_smult a (vec B)) (vec A))" (is ?V)
  and dims_tensor_prod[simp]: "dims (A ⊗ B) = dims A @ dims B" (is ?D)
  proof -
    (* Key side condition: the concatenated entry list has exactly
       prod_list (dims A @ dims B) elements. *)
    have "length (concat (map (λa. vec_smult a (vec B)) (vec A))) = prod_list (dims A @ dims B)"
    proof -
      (* Every chunk being concatenated has the same length as vec B. *)
      have "⋀xs. xs ∈ set (map (λa. vec_smult a (vec B)) (vec A)) ⟹ length xs = length (vec B)"
        using length_vec_smult by force
      then show ?thesis using concat_equal_length by (metis length_map length_vec prod_list.append)
    qed
    (* With the length fact chained in, unfolding the definition and simplifying
       settles both goals. *)
    then show ?V ?D by (simp add: tensor_prod_def)+
  qed

  (* Pure list fact: mapping f over a flattened list of lists and flattening the
     result is the same as doing map-then-flatten inside each inner list and
     flattening once at the end.  Proved by structural induction on xss. *)
  lemma tensorprod_subtensor_base:
  shows "concat (map f (concat xss)) = concat (map (λxs. concat (map f xs)) xss)"
  proof (induction xss)
    case Nil
    (* Both sides collapse to the empty list. *)
    show ?case by auto
  next
    case (Cons ys yss)
    (* Split off the head list and use the induction hypothesis for the rest. *)
    then show ?case by auto
  qed

  (* Right-multiplication by B commutes with subtensor_combine: if every tensor in As
     has dimension list ds, then combining As and multiplying by B equals combining
     the products A ⊗ B (A in As) at dimension list ds @ dims B. *)
  lemma subtensor_combine_tensor_prod:
  assumes "⋀A. A∈set As ⟹ dims A = ds"
  shows "subtensor_combine ds As ⊗ B = subtensor_combine (ds @ dims B) (map (λA. A ⊗ B) As)"
  proof -
    (* ?f scales vec B by one entry; ?xss is the list of entry vectors of As. *)
    let ?f = "λa. vec_smult a (Tensor.vec B)"
    let ?xss = "map Tensor.vec As"
    (* 1: length side condition for the combined tensor on the left-hand side. *)
    have 1:"prod_list (length As # ds) = length (concat ?xss)" by (metis assms length_vec subtensor_combine_dims subtensor_combine_vec)
    (* 2: length side condition for each individual product A ⊗ B. *)
    have 2:"⋀A. A∈set As ⟹ prod_list (dims A @ dims B) = length (concat (map ?f (Tensor.vec A)))"
      by (metis dims_tensor_prod length_vec vec_tensor_prod)
    (* 3: mapping over As does not change the list length, so the leading dimension agrees. *)
    have 3: "length As # ds @ dims B = (length (map (λA. tensor_from_vec (dims A @ dims B)
      (concat (map (λa. vec_smult a (vec B)) (vec A)))) As) # ds @ dims B)" by simp
    (* 4: the entry lists agree once the products are written out via tensor_from_vec. *)
    have 4:"(concat (map (λxs. concat (map (λa. vec_smult a (vec B)) xs)) (map vec As)))
         = (concat (map vec (map (λA. tensor_from_vec (dims A @ dims B) (concat (map (λa. vec_smult a (vec B)) (vec A))))  As)))"
      unfolding map_map[unfolded comp_def]  using vec_tensor by (metis (no_types, lifting) "2" map_eq_conv)
    (* Calculation: unfold the left-hand side, push the map through the inner concat
       (tensorprod_subtensor_base), then fold back into the right-hand side. *)
    have "subtensor_combine ds As ⊗ B = tensor_from_vec (length As # ds @ dims B) (concat (map ?f (concat (?xss))))"
      unfolding subtensor_combine_def tensor_prod_def using 1 by auto
    also have "... = tensor_from_vec (length As # ds @ dims B) (concat (map (λxs. concat (map ?f xs)) ?xss))"
      using tensorprod_subtensor_base[of ?f ?xss] by auto
    also have "... =subtensor_combine (ds @ dims B) (map (λA. A ⊗ B) As)"
      unfolding subtensor_combine_def tensor_prod_def using 3 4 by metis
    finally show ?thesis by metis
  qed

  (* Possibly better to prove this without induction? Then the lemma above would be
     unnecessary. But that might also be difficult. *)
  (* Taking the i-th subtensor commutes with right-multiplication by B, for a
     tensor A with non-empty dimension list and i below its first dimension. *)
  lemma subtensor_tensor_prod:
  assumes "dims A ≠ []" and "i < hd (dims A)"
  shows "subtensor (A ⊗ B) i = subtensor A i ⊗ B"
  using assms proof (induction A rule:subtensor_combine_induct)
    case order_0
    then show ?case by auto
  next
    (* Here A is presented as subtensor_combine ds As. *)
    case (order_step As ds)
    (* 1: i is a valid position in the list of products. *)
    have 1:"i < length (map (λA. A ⊗ B) As)" using order_step by (simp add: order_step.hyps order_step.prems(1))
    (* 2: every product in that list has dimension list ds @ dims B. *)
    have 2:"(⋀A. A ∈ set (map (λA. A ⊗ B) As) ⟹ dims A = ds @ dims B)" using order_step by auto
    (* Move the product inside subtensor_combine, select component i, then rewrite
       As ! i back as a subtensor of the combined tensor. *)
    have "subtensor (subtensor_combine ds As ⊗ B) i = subtensor (subtensor_combine (ds @ dims B) (map (λA. A ⊗ B) As)) i"
      using subtensor_combine_tensor_prod order_step by metis
    also have "... = As ! i ⊗ B"
      using order_step subtensor_subtensor_combine[of "(map (λA. A ⊗ B) As)" "ds @ dims B" i] 1 2 by auto
    also have "... = subtensor (subtensor_combine ds As) i ⊗ B"
      by (metis "1" length_map order_step.hyps subtensor_subtensor_combine)
    finally show ?case by auto
  qed

  (* Entrywise description of the product (a simp rule): for is1 a valid index of A
     and is2 a valid index of B, the entry of A ⊗ B at is1 @ is2 is the product of
     the entry of A at is1 and the entry of B at is2.
     Proof by induction on A via subtensor_induct, with is1 generalised. *)
  lemma lookup_tensor_prod[simp]:
  assumes is1_valid:"is1 ⊲ dims A" and is2_valid:"is2 ⊲ dims B"
  shows "lookup (A ⊗ B) (is1 @ is2) = lookup A is1 * lookup B is2"
  using assms proof (induction A arbitrary:is1 rule:subtensor_induct)
    case (order_0 A is1)
    (* Base case: A has a single entry a, so A ⊗ B is the scalar multiple a ⋅ B. *)
    then obtain a where "vec A = [a]"
      using Suc_length_conv Tensor.tensor_vec_from_lookup_Nil length_0_conv length_tensor_vec_from_lookup length_vec by metis
    then have "A ⊗ B = a ⋅ B" unfolding tensor_prod_def smult_def using order_0 by simp
    moreover have "lookup A [] = a" by (simp add: ‹Tensor.vec A = [a]› lookup_def order_0.hyps)
    ultimately have "lookup (A ⊗ B) (is2) = a * lookup B is2" by (simp add: lookup_smult is2_valid)
    then show ?case using ‹lookup A [] = a› null_rec(1) order_0.hyps order_0.prems(1) by auto
  next
    case (order_step A is1)
    (* Step case: split the index as i # is1' and descend into subtensor A i. *)
    then obtain i is1' where "i # is1' = is1" by blast
    (* Induction hypothesis instantiated for the i-th subtensor. *)
    have "lookup (subtensor A i ⊗ B) (is1' @ is2) = lookup (subtensor A i) is1' * lookup B is2" using order_step
      by (metis ‹i # is1' = is1› dims_subtensor list.sel(1) list.sel(3) valid_index_dimsE)
    (* Relate lookups in A and A ⊗ B to lookups in their i-th subtensors, and use
       subtensor_tensor_prod to identify subtensor (A ⊗ B) i with subtensor A i ⊗ B. *)
    then show "lookup (A ⊗ B) (is1 @ is2) = lookup A is1 * lookup B is2"
      using lookup_subtensor1[of i is1' A] lookup_subtensor1[of i "is1' @ is2" "A⊗B"] subtensor_tensor_prod[of A i B]
      Cons_eq_appendI ‹i # is1' = is1› dims_tensor_prod is2_valid list.sel(1) order_step.hyps order_step.prems(1) valid_index_append valid_index_dimsE
      by metis
  qed

  (* Elimination rule: an index that is valid for the concatenated dimension list
     ds1 @ ds2 can be split as is1 @ is2 with is1 valid for ds1 and is2 valid for ds2.
     The witnesses used are take (length ds1) is and drop (length ds1) is. *)
  lemma valid_index_split:
  assumes "is ⊲ ds1 @ ds2"
  obtains is1 is2 where "is1 @ is2 = is" "is1 ⊲ ds1" "is2 ⊲ ds2"
  proof
    assume a: "⋀is1 is2. is1 @ is2 = is ⟹ is1 ⊲ ds1 ⟹ is2 ⊲ ds2 ⟹ thesis"
    (* The index has as many components as both dimension lists together. *)
    have length_is:"length is = length ds1 + length ds2" using valid_index_length using assms by auto
    (* First part: the leading length ds1 components are valid for ds1. *)
    show "take (length ds1) is ⊲ ds1"
      apply (rule valid_indexI)
      using valid_index_length using assms apply auto[1]
      by (metis add_leD1 assms length_append not_less nth_append nth_take valid_index_lt)
    (* Second part: the remaining components are valid for ds2. *)
    show "drop (length ds1) is ⊲ ds2"
      apply (rule valid_indexI)
      using valid_index_length using assms apply auto[1]
      using nth_drop[of "length ds1" "is"] valid_index_lt[OF assms(1)] nth_append[of ds1 ds2] length_is
      by (metis length_append nat_add_left_cancel_less nat_le_iff_add nth_append_length_plus)
    (* The two parts reassemble to the original index. *)
    show "take (length ds1) is @ drop (length ds1) is = is" using length_is by auto
  qed

  (* Instance obligation for semigroup_mult: the tensor product is associative.
     Shown entrywise with tensor_lookup_eqI; the accompanying simp discharges the
     remaining (non-lookup) subgoal of that rule. *)
  instance proof
    fix A B C::"'a::ring tensor"
    show "(A ⊗ B) ⊗ C = A ⊗ (B ⊗ C)"
    proof (rule tensor_lookup_eqI, simp)
      fix "is" assume "is ⊲ dims ((A ⊗ B) ⊗ C)"
      (* Split the index along the bracketing A ⊗ (B ⊗ C) ... *)
      obtain is1 is23 where "is1 ⊲ dims A" "is23 ⊲ dims (B ⊗ C)" "is1 @ is23 = is"
        by (metis (mono_tags, lifting) ‹is ⊲ dims ((A ⊗ B) ⊗ C)› Tensor_Product.dims_tensor_prod append_assoc valid_index_split)
      (* ... and split the tail again into the parts belonging to B and C. *)
      obtain is2 is3 where "is2 ⊲ dims B" "is3 ⊲ dims C" "is2 @ is3 = is23"
        by (metis ‹is23 ⊲ dims (local.tensor_prod_otimes B C)› dims_tensor_prod valid_index_split)
      (* Regroup the same three pieces along the bracketing (A ⊗ B) ⊗ C. *)
      define is12 where "is12 = is1 @ is2"
      have "is12 ⊲ dims (A ⊗ B)" by (simp add: ‹is1 ⊲ dims A› ‹is2 ⊲ dims B› is12_def valid_index_append)
      have "is12 @ is3 = is" by (simp add: ‹is1 @ is23 = is› ‹is2 @ is3 = is23› is12_def)
      (* Rewrite both sides with lookup_tensor_prod; what remains follows from
         associativity of the element multiplication (mult.assoc). *)
      show "lookup ((A ⊗ B) ⊗ C) is = lookup (A ⊗ (B ⊗ C)) is"
        unfolding lookup_tensor_prod[OF ‹is1 ⊲ dims A› ‹is23 ⊲ dims (B ⊗ C)›, unfolded ‹is1 @ is23 = is›]
        lookup_tensor_prod[OF ‹is12 ⊲ dims (A ⊗ B)› ‹is3 ⊲ dims C›, unfolded ‹is12 @ is3 = is›]
        using ‹is1 ⊲ dims A› ‹is2 @ is3 = is23› ‹is2 ⊲ dims B› ‹is3 ⊲ dims C› is12_def mult.assoc by fastforce
    qed
  qed

end

(* Distributivity of ⊗ over + in the left argument, for summands A and B with
   equal dimension lists. *)
lemma tensor_prod_distr_left:
assumes "dims A = dims B"
shows "(A + B) ⊗ C = (A ⊗ C) + (B ⊗ C)"
proof -
  (* Both sides have the same entry at every valid index of dims A @ dims C. *)
  have "⋀is. is ⊲ dims A @ dims C ⟹ lookup ((A + B) ⊗ C) is = lookup (A ⊗ C + B ⊗ C) is"
  proof -
    fix "is" assume "is ⊲ dims A @ dims C"
    (* Split the index into the part addressing the sum and the part addressing C. *)
    obtain is1 is2 where "is = is1 @ is2" "is1 ⊲ dims A" "is2 ⊲ dims C" using valid_index_split using ‹is ⊲ dims A @ dims C› by blast
    (* Entrywise this is distributivity in the underlying ring (ring_distribs(2)). *)
    then show "lookup ((A + B) ⊗ C) is = lookup ((A ⊗ C) + (B ⊗ C)) is"
     using lookup_plus
     ‹is1 ⊲ dims A› ‹is2 ⊲ dims C› assms plus_dim1 dims_tensor_prod lookup_tensor_prod ring_class.ring_distribs(2) valid_index_append
     by fastforce
  qed
  (* Each side is rebuilt by tensor_from_lookup from its own lookup function at
     dimension list dims A @ dims C. *)
  moreover have "tensor_from_lookup (dims A @ dims C) (lookup ((A + B) ⊗ C)) = (A + B) ⊗ C"
       "tensor_from_lookup (dims A @ dims C) (lookup ((A ⊗ C) + (B ⊗ C))) = (A ⊗ C) + (B ⊗ C)"
    by (metis (no_types, lifting) assms plus_dim1 dims_tensor_prod tensor_lookup)+
  (* Equal lookups at the same dimension list give equal tensors. *)
  ultimately show ?thesis using tensor_from_lookup_eqI
    by (metis ‹⋀is. is ⊲ dims A @ dims C ⟹ lookup ((A + B) ⊗ C) is = lookup (A ⊗ C + B ⊗ C) is›)
qed

(* Distributivity of ⊗ over + in the right argument, for summands A and B with
   equal dimension lists.  Mirror image of tensor_prod_distr_left. *)
lemma tensor_prod_distr_right:
assumes "dims A = dims B"
shows "C ⊗ (A + B) = (C ⊗ A) + (C ⊗ B)"
proof -
  (* Both sides have the same entry at every valid index of dims C @ dims A. *)
  have "⋀is. is ⊲ dims C @ dims A ⟹ lookup (C ⊗ (A + B)) is = lookup (C ⊗ A + C ⊗ B) is"
  proof -
    fix "is" assume "is ⊲ dims C @ dims A"
    (* Split the index into the part addressing C and the part addressing the sum. *)
    obtain is1 is2 where "is = is1 @ is2" "is1 ⊲ dims C" "is2 ⊲ dims A" using valid_index_split using ‹is ⊲ dims C @ dims A› by blast
    (* Entrywise this is distributivity in the underlying ring (ring_distribs(1)). *)
    then show "lookup (C ⊗ (A + B)) is = lookup ((C ⊗ A) + (C ⊗ B)) is"
     using lookup_plus
     using ‹is2 ⊲ dims A› ‹is1 ⊲ dims C› assms plus_dim1 dims_tensor_prod lookup_tensor_prod ring_class.ring_distribs(1) valid_index_append
     by fastforce
  qed
  (* Each side is rebuilt by tensor_from_lookup from its own lookup function at
     dimension list dims C @ dims A. *)
  moreover have "tensor_from_lookup (dims C @ dims A) (lookup (C ⊗ (A + B))) = C ⊗ (A + B)"
       "tensor_from_lookup (dims C @ dims A) (lookup ((C ⊗ A) + (C ⊗ B))) = (C ⊗ A) + (C ⊗ B)"
    by (metis (no_types, lifting) assms plus_dim1 dims_tensor_prod tensor_lookup)+
  (* Equal lookups at the same dimension list give equal tensors. *)
  ultimately show ?thesis using tensor_from_lookup_eqI
    by (metis ‹⋀is. is ⊲ dims C @ dims A ⟹ lookup (C ⊗ (A + B)) is = lookup (C ⊗ A + C ⊗ B) is›)
qed

(* Tensors over a ring with unit form a multiplicative monoid under the tensor
   product. *)
instantiation tensor :: (ring_1) monoid_mult
begin
  (* The unit: the tensor with empty dimension list whose only entry is 1. *)
  definition tensor_one_def:"1 = tensor_from_vec [] [1]"

  (* Same unit, expressed through tensor_from_lookup with the constant function 1;
     this form is the one used in the instance proof below. *)
  lemma tensor_one_from_lookup: "1 = tensor_from_lookup [] (λ_. 1)"
    unfolding tensor_one_def by (rule tensor_eqI; simp_all add: tensor_from_lookup_def )

  (* Instance obligations: 1 is a right and a left neutral element.  Both are proved
     entrywise (tensor_lookup_eqI) using lookup_tensor_prod with the empty index []
     on the side of the unit. *)
  instance proof
    fix A::"'a::ring_1 tensor"
    show "A * 1 = A" unfolding tensor_one_from_lookup
      by (rule tensor_lookup_eqI;metis lookup_tensor_prod[of _ "A" "[]" "tensor_from_lookup [] (λ_. 1)"]
          lookup_tensor_from_lookup valid_index.Nil append_Nil2 dims_tensor dims_tensor_prod
          length_tensor_vec_from_lookup mult.right_neutral tensor_from_lookup_def)
  next
    fix A::"'a::ring_1 tensor"
    show "1 * A = A" unfolding tensor_one_from_lookup
      by (rule tensor_lookup_eqI; metis lookup_tensor_prod[of "[]" "tensor_from_lookup [] (λ_. 1)" _ "A"]
          lookup_tensor_from_lookup valid_index.Nil List.append.append_Nil dims_tensor dims_tensor_prod
          length_tensor_vec_from_lookup mult.left_neutral tensor_from_lookup_def)
  qed
end

(* The unit tensor has order 0. *)
lemma order_tensor_one: "order 1 = 0"
  by (simp add: tensor_one_def)

(* A scalar factor in front of a product can be moved onto the left factor.
   Stated for commutative rings with unit. *)
lemma smult_prod_extract1:
fixes a::"'a::comm_ring_1"
shows "a ⋅ (A ⊗ B) = (a ⋅ A) ⊗ B"
proof (rule tensor_lookup_eqI)
  (* Dimension lists agree. *)
  show "dims (a ⋅ (A ⊗ B)) = dims ((a ⋅ A) ⊗ B)" by simp
  (* Entries agree at every valid index. *)
  fix "is" assume "is ⊲ dims (a ⋅ (A ⊗ B))"
  then have "is ⊲ dims (A ⊗ B)" by auto
  (* Split the index into the parts addressing A and B. *)
  then obtain is1 is2 where "is1 ⊲ dims A" "is2 ⊲ dims B" "is = is1 @ is2" by (metis dims_tensor_prod valid_index_split)
  then have "is1 ⊲ dims (a ⋅ A)" by auto
  (* Expand both sides with lookup_tensor_prod and lookup_smult; simp closes the rest. *)
  show "lookup (a ⋅ (A ⊗ B)) is = lookup (a ⋅ A ⊗ B) is"
  using lookup_tensor_prod[OF ‹is1 ⊲ dims A› ‹is2 ⊲ dims B›] lookup_tensor_prod[OF ‹is1 ⊲ dims (a ⋅ A)› ‹is2 ⊲ dims B›]
        lookup_smult[OF ‹is ⊲ dims (A ⊗ B)›] lookup_smult[OF ‹is1 ⊲ dims A›] ‹is = is1 @ is2› by simp
qed

(* A scalar factor in front of a product can be moved onto the right factor.
   Stated for commutative rings with unit; counterpart of smult_prod_extract1. *)
lemma smult_prod_extract2:
fixes a::"'a::comm_ring_1"
shows "a ⋅ (A ⊗ B) = A ⊗ (a ⋅ B)"
proof (rule tensor_lookup_eqI)
  (* Dimension lists agree. *)
  show "dims (a ⋅ (A ⊗ B)) = dims (A ⊗ (a ⋅ B))" by simp
  (* Entries agree at every valid index. *)
  fix "is" assume "is ⊲ dims (a ⋅ (A ⊗ B))"
  then have "is ⊲ dims (A ⊗ B)" by auto
  (* Split the index into the parts addressing A and B. *)
  then obtain is1 is2 where "is1 ⊲ dims A" "is2 ⊲ dims B" "is = is1 @ is2" by (metis dims_tensor_prod valid_index_split)
  then have "is2 ⊲ dims (a ⋅ B)" by auto
  (* Expand both sides with lookup_tensor_prod and lookup_smult; simp closes the rest. *)
  show "lookup (a ⋅ (A ⊗ B)) is = lookup (A ⊗ (a ⋅ B)) is"
  using lookup_tensor_prod[OF ‹is1 ⊲ dims A› ‹is2 ⊲ dims B›] lookup_tensor_prod[OF ‹is1 ⊲ dims A› ‹is2 ⊲ dims (a ⋅ B)›]
        lookup_smult[OF ‹is ⊲ dims (A ⊗ B)›] lookup_smult[OF ‹is2 ⊲ dims B›] ‹is = is1 @ is2› by simp
qed


(* Elimination rule: every tensor of order 0 is a scalar multiple of the unit
   tensor.  The witness supplied is hd (vec A), the single entry of A. *)
lemma order_0_multiple_of_one:
assumes "order A = 0"
obtains a where "A = a ⋅ 1"
proof
  assume "(⋀a. A = a ⋅ 1 ⟹ thesis)"
  (* An order-0 tensor has exactly one entry; call it a. *)
  have "length (vec A) = 1" using assms by (simp add:length_vec)
  then obtain a where "vec A = [a]" by (metis One_nat_def Suc_length_conv length_0_conv)
  (* a ⋅ 1 has the same entry list, and the same dimension list, hence equals A. *)
  moreover have "vec (a ⋅ 1) = [a]" unfolding smult_def tensor_one_def by (simp add: vec_smult_def)
  ultimately have "A = a ⋅ 1" using tensor_eqI by (metis assms dims_smult length_0_conv order_tensor_one)
  then show "A = hd (vec A) ⋅ 1" using ‹vec A = [a]› by auto
qed

(* Scalar multiplication by 1 leaves a tensor unchanged (stated as A = 1 ⋅ A).
   NOTE(review): the 1 in the statement appears to be the scalar of the element type,
   not the unit tensor, so unfolding tensor_one_def is presumably a no-op here — confirm. *)
lemma smult_1:
fixes A::"'a::ring_1 tensor"
shows "A = 1 ⋅ A" unfolding smult_def tensor_one_def
apply (rule tensor_eqI)
apply (simp add: length_vec length_vec_smult)
by (metis dims_tensor length_vec length_vec_smult lookup_smult mult.left_neutral smult_def tensor_lookup_eqI)


(* Multiplying by tensor0 ds on the right gives tensor0 of the concatenated dimension
   list dims A @ ds (simp rule).  Proved entrywise; the simp in the initial proof step
   discharges the remaining (non-lookup) subgoal of tensor_lookup_eqI. *)
lemma tensor0_prod_right[simp]: "A ⊗ tensor0 ds = tensor0 (dims A @ ds)"
proof (rule tensor_lookup_eqI,simp)
  fix "is" assume "is ⊲ dims (A ⊗ tensor0 ds)"
  (* Split the index into the parts addressing A and tensor0 ds. *)
  then obtain is1 is2 where "is1 ⊲ dims A" "is2 ⊲ dims (tensor0 ds)" "is = is1 @ is2"
    by (metis dims_tensor0 dims_tensor_prod valid_index_split)
  (* The entry is a product with a tensor0 entry on the right (mult_zero_right). *)
  then show "lookup (A ⊗ tensor0 ds) is = lookup (tensor0 (dims A @ ds)) is"
    by (metis (no_types, lifting) ‹is ⊲ dims (A ⊗ tensor0 ds)› dims_tensor0 dims_tensor_prod lookup_tensor0 lookup_tensor_prod mult_zero_right)
qed

(* Multiplying by tensor0 ds on the left gives tensor0 of the concatenated dimension
   list ds @ dims A (simp rule).  Proved entrywise; the simp in the initial proof step
   discharges the remaining (non-lookup) subgoal of tensor_lookup_eqI. *)
lemma tensor0_prod_left[simp]: "tensor0 ds ⊗ A = tensor0 (ds @ dims A)"
proof (rule tensor_lookup_eqI,simp)
  fix "is" assume "is ⊲ dims (tensor0 ds ⊗ A)"
  (* Split the index into the parts addressing tensor0 ds and A. *)
  then obtain is1 is2 where "is1 ⊲ dims (tensor0 ds)" "is2 ⊲ dims A" "is = is1 @ is2"
    by (metis dims_tensor0 dims_tensor_prod valid_index_split)
  (* The entry is a product with a tensor0 entry on the left (mult_zero_left). *)
  then show "lookup (tensor0 ds ⊗ A) is = lookup (tensor0 (ds @ dims A)) is"
    by (metis (no_types, lifting) ‹is ⊲ dims (tensor0 ds ⊗ A)› dims_tensor0 dims_tensor_prod lookup_tensor0 lookup_tensor_prod mult_zero_left)
qed

(* For a tensor A of order 1 and i below its (only) dimension, the i-th subtensor
   of A ⊗ B is B scaled by the i-th entry of A. *)
lemma subtensor_prod_with_vec:
assumes "order A = 1" "i < hd (dims A)"
shows "subtensor (A ⊗ B) i = lookup A [i] ⋅ B"
proof (rule tensor_lookup_eqI)
  (* Part 1: dimension lists agree.  The two auxiliary facts make dims_subtensor
     applicable to A ⊗ B. *)
  have "dims (A ⊗ B) ≠ []" using assms(1) by auto
  have "hd (dims A) =  hd (dims (A ⊗ B))"
    by (metis One_nat_def Suc_length_conv append_Cons assms(1) dims_tensor_prod list.sel(1))
  show "dims (subtensor (A ⊗ B) i) = dims (lookup A [i] ⋅ B)"
    unfolding dims_smult dims_subtensor[OF ‹dims (A ⊗ B) ≠ []› ‹i < hd (dims A)›[unfolded ‹hd (dims A) =  hd (dims (A ⊗ B))›] ]
    by (metis One_nat_def Suc_length_conv append.simps(2) append_self_conv2 assms(1) dims_tensor_prod length_0_conv list.sel(3))
next
  (* Part 2: entries agree at every valid index.  The auxiliary facts are
     re-established because this is a separate proof branch. *)
  fix "is" assume "is ⊲ dims (subtensor (A ⊗ B) i)"
  have "dims (A ⊗ B) ≠ []" using assms(1) by auto
  have "hd (dims A) =  hd (dims (A ⊗ B))"
    by (metis One_nat_def Suc_length_conv append_Cons assms(1) dims_tensor_prod list.sel(1))
  (* An index of the subtensor is an index of B. *)
  then have "is ⊲ dims B"
    using ‹is ⊲ dims (subtensor (A ⊗ B) i)›[unfolded dims_subtensor[OF ‹dims (A ⊗ B) ≠ []› ‹i < hd (dims A)›[unfolded ‹hd (dims A) =  hd (dims (A ⊗ B))›] ]]
    by (metis One_nat_def Suc_length_conv append_self_conv2 assms(1) dims_tensor_prod length_0_conv list.sel(3) list.simps(3) tl_append2)
  (* [i] is a valid index of A, so i # is is a valid index of A ⊗ B. *)
  have "[i] ⊲ dims A" using assms by (metis One_nat_def Suc_length_conv length_0_conv list.sel(1) valid_index.Nil valid_index.simps)
  then have "i # is ⊲ dims (A ⊗ B)" using ‹is ⊲ dims (subtensor (A ⊗ B) i)› dims_subtensor valid_index.Cons by auto
  (* Rewrite the subtensor lookup as a lookup in A ⊗ B at i # is, then apply
     lookup_tensor_prod and lookup_smult. *)
  then show "lookup (subtensor (A ⊗ B) i) is = lookup (lookup A [i] ⋅ B) is"
  unfolding lookup_subtensor1[OF ‹i # is ⊲ dims (A ⊗ B)›]
    using lookup_tensor_prod[OF ‹[i] ⊲ dims A› ‹is ⊲ dims B›] lookup_smult
    ‹is ⊲ dims B› using append_Cons by fastforce
qed

end