(* Theory Tensor_Product *)
section ‹Tensor Product›
theory Tensor_Product
imports Tensor_Scalar_Mult Tensor_Subtensor
begin
instantiation tensor:: (ring) semigroup_mult
begin
(* Tensor product, installed as the multiplication of the semigroup_mult instance.
   The result has dimension list dims A @ dims B; its entry vector is the
   concatenation, over the entries a of vec A in order, of vec_smult a (vec B). *)
definition tensor_prod_def:"A * B = tensor_from_vec (dims A @ dims B) (concat (map (λa. vec_smult a (vec B)) (vec A)))"
(* Infix notation ⊗ (left-associative, priority 70) for the tensor product.
   Being an abbreviation, A ⊗ B is just another way of writing A * B. *)
abbreviation tensor_prod_otimes :: "'a tensor ⇒ 'a tensor ⇒ 'a tensor" (infixl ‹⊗› 70)
where "A ⊗ B ≡ A * B"
(* Simp rules exposing the entry vector (?V) and the dimension list (?D) of a
   tensor product. Both are obtained from tensor_prod_def once the concatenated
   entry list is shown to have length prod_list (dims A @ dims B). *)
lemma vec_tensor_prod[simp]: "vec (A ⊗ B) = concat (map (λa. vec_smult a (vec B)) (vec A))" (is ?V)
and dims_tensor_prod[simp]: "dims (A ⊗ B) = dims A @ dims B" (is ?D)
proof -
(* Length side condition for the concatenated entry list. *)
have "length (concat (map (λa. vec_smult a (vec B)) (vec A))) = prod_list (dims A @ dims B)"
proof -
(* Every block of the concatenation has the same length as vec B. *)
have "⋀xs. xs ∈ set (map (λa. vec_smult a (vec B)) (vec A)) ⟹ length xs = length (vec B)"
using length_vec_smult by force
then show ?thesis using concat_equal_length by (metis length_map length_vec prod_list.append)
qed
then show ?V ?D by (simp add: tensor_prod_def)+
qed
(* Auxiliary list fact: mapping f over the flattened list concat xss and then
   flattening gives the same result as mapping and flattening each inner list
   first and flattening the outer list afterwards. Proved by induction on xss. *)
lemma tensorprod_subtensor_base:
shows "concat (map f (concat xss)) = concat (map (λxs. concat (map f xs)) xss)"
by (induction xss; auto)
(* If every tensor in As has dimension list ds, then multiplying the combined
   tensor by B on the right equals combining the individual products A ⊗ B,
   now with dimension list ds @ dims B. *)
lemma subtensor_combine_tensor_prod:
assumes "⋀A. A∈set As ⟹ dims A = ds"
shows "subtensor_combine ds As ⊗ B = subtensor_combine (ds @ dims B) (map (λA. A ⊗ B) As)"
proof -
(* ?f is the per-entry block builder from tensor_prod_def; ?xss lists the entry vectors of As. *)
let ?f = "λa. vec_smult a (Tensor.vec B)"
let ?xss = "map Tensor.vec As"
(* 1: length condition for the combined tensor. *)
have 1:"prod_list (length As # ds) = length (concat ?xss)" by (metis assms length_vec subtensor_combine_dims subtensor_combine_vec)
(* 2: length condition for each individual product. *)
have 2:"⋀A. A∈set As ⟹ prod_list (dims A @ dims B) = length (concat (map ?f (Tensor.vec A)))"
by (metis dims_tensor_prod length_vec vec_tensor_prod)
(* 3: mapping over As does not change its length, so the leading dimension agrees. *)
have 3: "length As # ds @ dims B = (length (map (λA. tensor_from_vec (dims A @ dims B)
(concat (map (λa. vec_smult a (vec B)) (vec A)))) As) # ds @ dims B)" by simp
(* 4: the entry vectors of the unfolded products are exactly the per-tensor concatenations. *)
have 4:"(concat (map (λxs. concat (map (λa. vec_smult a (vec B)) xs)) (map vec As)))
= (concat (map vec (map (λA. tensor_from_vec (dims A @ dims B) (concat (map (λa. vec_smult a (vec B)) (vec A)))) As)))"
unfolding map_map[unfolded comp_def] using vec_tensor by (metis (no_types, lifting) "2" map_eq_conv)
(* Calculation: unfold the left-hand side, regroup the concatenation, fold back into the right-hand side. *)
have "subtensor_combine ds As ⊗ B = tensor_from_vec (length As # ds @ dims B) (concat (map ?f (concat (?xss))))"
unfolding subtensor_combine_def tensor_prod_def using 1 by auto
also have "... = tensor_from_vec (length As # ds @ dims B) (concat (map (λxs. concat (map ?f xs)) ?xss))"
using tensorprod_subtensor_base[of ?f ?xss] by auto
also have "... =subtensor_combine (ds @ dims B) (map (λA. A ⊗ B) As)"
unfolding subtensor_combine_def tensor_prod_def using 3 4 by metis
finally show ?thesis by metis
qed
(* Taking the i-th subtensor commutes with multiplying by B on the right,
   provided A has a non-empty dimension list and i is below its first dimension.
   Proved by the induction rule subtensor_combine_induct on A. *)
lemma subtensor_tensor_prod:
assumes "dims A ≠ []" and "i < hd (dims A)"
shows "subtensor (A ⊗ B) i = subtensor A i ⊗ B"
using assms proof (induction A rule:subtensor_combine_induct)
case order_0
then show ?case by auto
next
case (order_step As ds)
(* 1: i is a valid position in the list of products. *)
have 1:"i < length (map (λA. A ⊗ B) As)" using order_step by (simp add: order_step.hyps order_step.prems(1))
(* 2: every product in the list has dimension list ds @ dims B. *)
have 2:"(⋀A. A ∈ set (map (λA. A ⊗ B) As) ⟹ dims A = ds @ dims B)" using order_step by auto
(* Push the product inside subtensor_combine, select the i-th component, then rewrite As ! i as a subtensor. *)
have "subtensor (subtensor_combine ds As ⊗ B) i = subtensor (subtensor_combine (ds @ dims B) (map (λA. A ⊗ B) As)) i"
using subtensor_combine_tensor_prod order_step by metis
also have "... = As ! i ⊗ B"
using order_step subtensor_subtensor_combine[of "(map (λA. A ⊗ B) As)" "ds @ dims B" i] 1 2 by auto
also have "... = subtensor (subtensor_combine ds As) i ⊗ B"
by (metis "1" length_map order_step.hyps subtensor_subtensor_combine)
finally show ?case by auto
qed
(* Entry-wise characterisation of the tensor product (simp rule): for an index
   is1 satisfying is1 ⊲ dims A and an index is2 satisfying is2 ⊲ dims B, the
   entry of A ⊗ B at is1 @ is2 is the product of the two individual entries.
   Proved by the induction rule subtensor_induct on A, generalising over is1. *)
lemma lookup_tensor_prod[simp]:
assumes is1_valid:"is1 ⊲ dims A" and is2_valid:"is2 ⊲ dims B"
shows "lookup (A ⊗ B) (is1 @ is2) = lookup A is1 * lookup B is2"
using assms proof (induction A arbitrary:is1 rule:subtensor_induct)
case (order_0 A is1)
(* Base case: the entry vector of A is a singleton [a], so A ⊗ B is the scalar multiple a ⋅ B. *)
then obtain a where "vec A = [a]"
using Suc_length_conv Tensor.tensor_vec_from_lookup_Nil length_0_conv length_tensor_vec_from_lookup length_vec by metis
then have "A ⊗ B = a ⋅ B" unfolding tensor_prod_def smult_def using order_0 by simp
moreover have "lookup A [] = a" by (simp add: ‹Tensor.vec A = [a]› lookup_def order_0.hyps)
ultimately have "lookup (A ⊗ B) (is2) = a * lookup B is2" by (simp add: lookup_smult is2_valid)
then show ?case using ‹lookup A [] = a› null_rec(1) order_0.hyps order_0.prems(1) by auto
next
case (order_step A is1)
(* Step case: split off the head i of is1 and apply the induction hypothesis to subtensor A i. *)
then obtain i is1' where "i # is1' = is1" by blast
have "lookup (subtensor A i ⊗ B) (is1' @ is2) = lookup (subtensor A i) is1' * lookup B is2" using order_step
by (metis ‹i # is1' = is1› dims_subtensor list.sel(1) list.sel(3) valid_index_dimsE)
(* Transfer back to A via lookup_subtensor1 and subtensor_tensor_prod. *)
then show "lookup (A ⊗ B) (is1 @ is2) = lookup A is1 * lookup B is2"
using lookup_subtensor1[of i is1' A] lookup_subtensor1[of i "is1' @ is2" "A⊗B"] subtensor_tensor_prod[of A i B]
Cons_eq_appendI ‹i # is1' = is1› dims_tensor_prod is2_valid list.sel(1) order_step.hyps order_step.prems(1) valid_index_append valid_index_dimsE
by metis
qed
(* Elimination rule: an index satisfying is ⊲ ds1 @ ds2 can be split into a
   prefix is1 with is1 ⊲ ds1 and a suffix is2 with is2 ⊲ ds2.
   The witnesses are take (length ds1) is and drop (length ds1) is. *)
lemma valid_index_split:
assumes "is ⊲ ds1 @ ds2"
obtains is1 is2 where "is1 @ is2 = is" "is1 ⊲ ds1" "is2 ⊲ ds2"
proof
assume a: "⋀is1 is2. is1 @ is2 = is ⟹ is1 ⊲ ds1 ⟹ is2 ⊲ ds2 ⟹ thesis"
have length_is:"length is = length ds1 + length ds2" using valid_index_length using assms by auto
(* The prefix relates to ds1. *)
show "take (length ds1) is ⊲ ds1"
apply (rule valid_indexI)
using valid_index_length using assms apply auto[1]
by (metis add_leD1 assms length_append not_less nth_append nth_take valid_index_lt)
(* The suffix relates to ds2. *)
show "drop (length ds1) is ⊲ ds2"
apply (rule valid_indexI)
using valid_index_length using assms apply auto[1]
using nth_drop[of "length ds1" "is"] valid_index_lt[OF assms(1)] nth_append[of ds1 ds2] length_is
by (metis length_append nat_add_left_cancel_less nat_le_iff_add nth_append_length_plus)
(* Prefix and suffix reassemble the original index. *)
show "take (length ds1) is @ drop (length ds1) is = is" using length_is by auto
qed
(* semigroup_mult instance proof: the tensor product is associative.
   Equality is shown entry-wise with tensor_lookup_eqI: the index is split into
   is1, is2, is3 for A, B, C, lookup_tensor_prod is applied to both bracketings,
   and the goal is closed using mult.assoc. *)
instance proof
fix A B C::"'a::ring tensor"
show "(A ⊗ B) ⊗ C = A ⊗ (B ⊗ C)"
proof (rule tensor_lookup_eqI, simp)
fix "is" assume "is ⊲ dims ((A ⊗ B) ⊗ C)"
(* First split: is = is1 @ is23 with is23 indexing B ⊗ C. *)
obtain is1 is23 where "is1 ⊲ dims A" "is23 ⊲ dims (B ⊗ C)" "is1 @ is23 = is"
by (metis (mono_tags, lifting) ‹is ⊲ dims ((A ⊗ B) ⊗ C)› Tensor_Product.dims_tensor_prod append_assoc valid_index_split)
(* Second split: is23 = is2 @ is3. *)
obtain is2 is3 where "is2 ⊲ dims B" "is3 ⊲ dims C" "is2 @ is3 = is23"
by (metis ‹is23 ⊲ dims (local.tensor_prod_otimes B C)› dims_tensor_prod valid_index_split)
(* Regroup as is12 @ is3 with is12 indexing A ⊗ B. *)
define is12 where "is12 = is1 @ is2"
have "is12 ⊲ dims (A ⊗ B)" by (simp add: ‹is1 ⊲ dims A› ‹is2 ⊲ dims B› is12_def valid_index_append)
have "is12 @ is3 = is" by (simp add: ‹is1 @ is23 = is› ‹is2 @ is3 = is23› is12_def)
show "lookup ((A ⊗ B) ⊗ C) is = lookup (A ⊗ (B ⊗ C)) is"
unfolding lookup_tensor_prod[OF ‹is1 ⊲ dims A› ‹is23 ⊲ dims (B ⊗ C)›, unfolded ‹is1 @ is23 = is›]
lookup_tensor_prod[OF ‹is12 ⊲ dims (A ⊗ B)› ‹is3 ⊲ dims C›, unfolded ‹is12 @ is3 = is›]
using ‹is1 ⊲ dims A› ‹is2 @ is3 = is23› ‹is2 ⊲ dims B› ‹is3 ⊲ dims C› is12_def mult.assoc by fastforce
qed
qed
end
(* The tensor product distributes over a sum in its left argument, provided the
   two summands have the same dimension list.
   Proof idea: both sides agree entry-wise on every index related to
   dims A @ dims C, and each side equals tensor_from_lookup of its own lookup. *)
lemma tensor_prod_distr_left:
assumes "dims A = dims B"
shows "(A + B) ⊗ C = (A ⊗ C) + (B ⊗ C)"
proof -
(* Entry-wise equality, obtained by splitting the index and using ring distributivity. *)
have "⋀is. is ⊲ dims A @ dims C ⟹ lookup ((A + B) ⊗ C) is = lookup (A ⊗ C + B ⊗ C) is"
proof -
fix "is" assume "is ⊲ dims A @ dims C"
obtain is1 is2 where "is = is1 @ is2" "is1 ⊲ dims A" "is2 ⊲ dims C" using valid_index_split using ‹is ⊲ dims A @ dims C› by blast
then show "lookup ((A + B) ⊗ C) is = lookup ((A ⊗ C) + (B ⊗ C)) is"
using lookup_plus
‹is1 ⊲ dims A› ‹is2 ⊲ dims C› assms plus_dim1 dims_tensor_prod lookup_tensor_prod ring_class.ring_distribs(2) valid_index_append
by fastforce
qed
(* Both sides are reconstructed from their lookup functions over dims A @ dims C. *)
moreover have "tensor_from_lookup (dims A @ dims C) (lookup ((A + B) ⊗ C)) = (A + B) ⊗ C"
"tensor_from_lookup (dims A @ dims C) (lookup ((A ⊗ C) + (B ⊗ C))) = (A ⊗ C) + (B ⊗ C)"
by (metis (no_types, lifting) assms plus_dim1 dims_tensor_prod tensor_lookup)+
ultimately show ?thesis using tensor_from_lookup_eqI
by (metis ‹⋀is. is ⊲ dims A @ dims C ⟹ lookup ((A + B) ⊗ C) is = lookup (A ⊗ C + B ⊗ C) is›)
qed
(* The tensor product distributes over a sum in its right argument, provided the
   two summands have the same dimension list.
   Mirror image of the left-argument case: entry-wise equality on indices
   related to dims C @ dims A, then reconstruction via tensor_from_lookup. *)
lemma tensor_prod_distr_right:
assumes "dims A = dims B"
shows "C ⊗ (A + B) = (C ⊗ A) + (C ⊗ B)"
proof -
(* Entry-wise equality, obtained by splitting the index and using ring distributivity. *)
have "⋀is. is ⊲ dims C @ dims A ⟹ lookup (C ⊗ (A + B)) is = lookup (C ⊗ A + C ⊗ B) is"
proof -
fix "is" assume "is ⊲ dims C @ dims A"
obtain is1 is2 where "is = is1 @ is2" "is1 ⊲ dims C" "is2 ⊲ dims A" using valid_index_split using ‹is ⊲ dims C @ dims A› by blast
then show "lookup (C ⊗ (A + B)) is = lookup ((C ⊗ A) + (C ⊗ B)) is"
using lookup_plus
using ‹is2 ⊲ dims A› ‹is1 ⊲ dims C› assms plus_dim1 dims_tensor_prod lookup_tensor_prod ring_class.ring_distribs(1) valid_index_append
by fastforce
qed
(* Both sides are reconstructed from their lookup functions over dims C @ dims A. *)
moreover have "tensor_from_lookup (dims C @ dims A) (lookup (C ⊗ (A + B))) = C ⊗ (A + B)"
"tensor_from_lookup (dims C @ dims A) (lookup ((C ⊗ A) + (C ⊗ B))) = (C ⊗ A) + (C ⊗ B)"
by (metis (no_types, lifting) assms plus_dim1 dims_tensor_prod tensor_lookup)+
ultimately show ?thesis using tensor_from_lookup_eqI
by (metis ‹⋀is. is ⊲ dims C @ dims A ⟹ lookup (C ⊗ (A + B)) is = lookup (C ⊗ A + C ⊗ B) is›)
qed
instantiation tensor :: (ring_1) monoid_mult
begin
(* Unit of the monoid_mult instance: the tensor built by tensor_from_vec from the
   empty dimension list and the single-entry vector [1]. *)
definition tensor_one_def:"1 = tensor_from_vec [] [1]"
(* Alternative description of the unit tensor via tensor_from_lookup: empty
   dimension list and the constant lookup function returning 1. *)
lemma tensor_one_from_lookup: "1 = tensor_from_lookup [] (λ_. 1)"
unfolding tensor_one_def by (rule tensor_eqI; simp_all add: tensor_from_lookup_def )
(* monoid_mult instance proof: 1 is a right and a left neutral element of the
   tensor product. Both goals rewrite 1 with tensor_one_from_lookup and then
   compare entries via tensor_lookup_eqI, using lookup_tensor_prod with the
   empty index on the side of the unit. *)
instance proof
fix A::"'a::ring_1 tensor"
(* Right neutrality. *)
show "A * 1 = A" unfolding tensor_one_from_lookup
by (rule tensor_lookup_eqI;metis lookup_tensor_prod[of _ "A" "[]" "tensor_from_lookup [] (λ_. 1)"]
lookup_tensor_from_lookup valid_index.Nil append_Nil2 dims_tensor dims_tensor_prod
length_tensor_vec_from_lookup mult.right_neutral tensor_from_lookup_def)
next
fix A::"'a::ring_1 tensor"
(* Left neutrality. *)
show "1 * A = A" unfolding tensor_one_from_lookup
by (rule tensor_lookup_eqI; metis lookup_tensor_prod[of "[]" "tensor_from_lookup [] (λ_. 1)" _ "A"]
lookup_tensor_from_lookup valid_index.Nil List.append.append_Nil dims_tensor dims_tensor_prod
length_tensor_vec_from_lookup mult.left_neutral tensor_from_lookup_def)
qed
end
(* The unit tensor has order 0; follows by simp after unfolding tensor_one_def. *)
lemma order_tensor_one: "order 1 = 0" unfolding tensor_one_def by simp
(* A scalar factor on a tensor product can be moved onto the left factor.
   Stated for entries in a comm_ring_1.
   Proof: entry-wise via tensor_lookup_eqI; the index is split with
   valid_index_split and both sides are rewritten with lookup_tensor_prod and
   lookup_smult.
   The lemma name was missing in the source text ("lemma :" is not valid Isar);
   it is restored here so the fact can be parsed and referenced. *)
lemma smult_prod_extract1:
fixes a::"'a::comm_ring_1"
shows "a ⋅ (A ⊗ B) = (a ⋅ A) ⊗ B"
proof (rule tensor_lookup_eqI)
(* Dimension lists of both sides agree. *)
show "dims (a ⋅ (A ⊗ B)) = dims ((a ⋅ A) ⊗ B)" by simp
fix "is" assume "is ⊲ dims (a ⋅ (A ⊗ B))"
then have "is ⊲ dims (A ⊗ B)" by auto
(* Split the index into the parts belonging to A and to B. *)
then obtain is1 is2 where "is1 ⊲ dims A" "is2 ⊲ dims B" "is = is1 @ is2" by (metis dims_tensor_prod valid_index_split)
then have "is1 ⊲ dims (a ⋅ A)" by auto
show "lookup (a ⋅ (A ⊗ B)) is = lookup (a ⋅ A ⊗ B) is"
using lookup_tensor_prod[OF ‹is1 ⊲ dims A› ‹is2 ⊲ dims B›] lookup_tensor_prod[OF ‹is1 ⊲ dims (a ⋅ A)› ‹is2 ⊲ dims B›]
lookup_smult[OF ‹is ⊲ dims (A ⊗ B)›] lookup_smult[OF ‹is1 ⊲ dims A›] ‹is = is1 @ is2› by simp
qed
(* A scalar factor on a tensor product can be moved onto the right factor.
   Stated for entries in a comm_ring_1.
   Proof: entry-wise via tensor_lookup_eqI; the index is split with
   valid_index_split and both sides are rewritten with lookup_tensor_prod and
   lookup_smult.
   The lemma name was missing in the source text ("lemma :" is not valid Isar);
   it is restored here so the fact can be parsed and referenced. *)
lemma smult_prod_extract2:
fixes a::"'a::comm_ring_1"
shows "a ⋅ (A ⊗ B) = A ⊗ (a ⋅ B)"
proof (rule tensor_lookup_eqI)
(* Dimension lists of both sides agree. *)
show "dims (a ⋅ (A ⊗ B)) = dims (A ⊗ (a ⋅ B))" by simp
fix "is" assume "is ⊲ dims (a ⋅ (A ⊗ B))"
then have "is ⊲ dims (A ⊗ B)" by auto
(* Split the index into the parts belonging to A and to B. *)
then obtain is1 is2 where "is1 ⊲ dims A" "is2 ⊲ dims B" "is = is1 @ is2" by (metis dims_tensor_prod valid_index_split)
then have "is2 ⊲ dims (a ⋅ B)" by auto
show "lookup (a ⋅ (A ⊗ B)) is = lookup (A ⊗ (a ⋅ B)) is"
using lookup_tensor_prod[OF ‹is1 ⊲ dims A› ‹is2 ⊲ dims B›] lookup_tensor_prod[OF ‹is1 ⊲ dims A› ‹is2 ⊲ dims (a ⋅ B)›]
lookup_smult[OF ‹is ⊲ dims (A ⊗ B)›] lookup_smult[OF ‹is2 ⊲ dims B›] ‹is = is1 @ is2› by simp
qed
(* Elimination rule: every tensor of order 0 is a scalar multiple of the unit
   tensor. The scalar exhibited by the proof is hd (vec A), the only entry of A. *)
lemma order_0_multiple_of_one:
assumes "order A = 0"
obtains a where "A = a ⋅ 1"
proof
assume "(⋀a. A = a ⋅ 1 ⟹ thesis)"
(* With order 0 the entry vector has exactly one element a. *)
have "length (vec A) = 1" using assms by (simp add:length_vec)
then obtain a where "vec A = [a]" by (metis One_nat_def Suc_length_conv length_0_conv)
moreover have "vec (a ⋅ 1) = [a]" unfolding smult_def tensor_one_def by (simp add: vec_smult_def)
(* Same entry vector and same dimension list, hence equal tensors. *)
ultimately have "A = a ⋅ 1" using tensor_eqI by (metis assms dims_smult length_0_conv order_tensor_one)
then show "A = hd (vec A) ⋅ 1" using ‹vec A = [a]› by auto
qed
(* Scalar multiplication by the ring element 1 leaves a tensor unchanged
   (stated for entries in a ring_1). After unfolding smult_def and
   tensor_one_def, tensor_eqI reduces the claim to two subgoals: the first is
   closed by simp with the length lemmas, the second by metis. *)
lemma smult_1:
fixes A::"'a::ring_1 tensor"
shows "A = 1 ⋅ A" unfolding smult_def tensor_one_def
apply (rule tensor_eqI)
apply (simp add: length_vec length_vec_smult)
by (metis dims_tensor length_vec length_vec_smult lookup_smult mult.left_neutral smult_def tensor_lookup_eqI)
(* Simp rule: multiplying by tensor0 ds on the right yields tensor0 over the
   concatenated dimension list dims A @ ds. Shown entry-wise: after splitting
   the index, lookup_tensor_prod, lookup_tensor0 and mult_zero_right close the goal. *)
lemma tensor0_prod_right[simp]: "A ⊗ tensor0 ds = tensor0 (dims A @ ds)"
proof (rule tensor_lookup_eqI,simp)
fix "is" assume "is ⊲ dims (A ⊗ tensor0 ds)"
then obtain is1 is2 where "is1 ⊲ dims A" "is2 ⊲ dims (tensor0 ds)" "is = is1 @ is2"
by (metis dims_tensor0 dims_tensor_prod valid_index_split)
then show "lookup (A ⊗ tensor0 ds) is = lookup (tensor0 (dims A @ ds)) is"
by (metis (no_types, lifting) ‹is ⊲ dims (A ⊗ tensor0 ds)› dims_tensor0 dims_tensor_prod lookup_tensor0 lookup_tensor_prod mult_zero_right)
qed
(* Simp rule: multiplying by tensor0 ds on the left yields tensor0 over the
   concatenated dimension list ds @ dims A. Shown entry-wise: after splitting
   the index, lookup_tensor_prod, lookup_tensor0 and mult_zero_left close the goal. *)
lemma tensor0_prod_left[simp]: "tensor0 ds ⊗ A = tensor0 (ds @ dims A)"
proof (rule tensor_lookup_eqI,simp)
fix "is" assume "is ⊲ dims (tensor0 ds ⊗ A)"
then obtain is1 is2 where "is1 ⊲ dims (tensor0 ds)" "is2 ⊲ dims A" "is = is1 @ is2"
by (metis dims_tensor0 dims_tensor_prod valid_index_split)
then show "lookup (tensor0 ds ⊗ A) is = lookup (tensor0 (ds @ dims A)) is"
by (metis (no_types, lifting) ‹is ⊲ dims (tensor0 ds ⊗ A)› dims_tensor0 dims_tensor_prod lookup_tensor0 lookup_tensor_prod mult_zero_left)
qed
(* For a tensor A of order 1 and a position i below its first dimension, the
   i-th subtensor of A ⊗ B is B scaled by the entry lookup A [i].
   Proved with tensor_lookup_eqI: first the dimension lists, then the entries. *)
lemma subtensor_prod_with_vec:
assumes "order A = 1" "i < hd (dims A)"
shows "subtensor (A ⊗ B) i = lookup A [i] ⋅ B"
proof (rule tensor_lookup_eqI)
(* Part 1: dimension lists agree. The product has a non-empty dimension list
   whose head is the head of dims A, so dims_subtensor applies. *)
have "dims (A ⊗ B) ≠ []" using assms(1) by auto
have "hd (dims A) = hd (dims (A ⊗ B))"
by (metis One_nat_def Suc_length_conv append_Cons assms(1) dims_tensor_prod list.sel(1))
show "dims (subtensor (A ⊗ B) i) = dims (lookup A [i] ⋅ B)"
unfolding dims_smult dims_subtensor[OF ‹dims (A ⊗ B) ≠ []› ‹i < hd (dims A)›[unfolded ‹hd (dims A) = hd (dims (A ⊗ B))›] ]
by (metis One_nat_def Suc_length_conv append.simps(2) append_self_conv2 assms(1) dims_tensor_prod length_0_conv list.sel(3))
next
(* Part 2: entries agree. The two auxiliary facts are re-established because
   facts from the first part are not visible after "next". *)
fix "is" assume "is ⊲ dims (subtensor (A ⊗ B) i)"
have "dims (A ⊗ B) ≠ []" using assms(1) by auto
have "hd (dims A) = hd (dims (A ⊗ B))"
by (metis One_nat_def Suc_length_conv append_Cons assms(1) dims_tensor_prod list.sel(1))
(* An index of the subtensor is an index of B. *)
then have "is ⊲ dims B"
using ‹is ⊲ dims (subtensor (A ⊗ B) i)›[unfolded dims_subtensor[OF ‹dims (A ⊗ B) ≠ []› ‹i < hd (dims A)›[unfolded ‹hd (dims A) = hd (dims (A ⊗ B))›] ]]
by (metis One_nat_def Suc_length_conv append_self_conv2 assms(1) dims_tensor_prod length_0_conv list.sel(3) list.simps(3) tl_append2)
(* The singleton [i] is an index of A, so i # is is an index of the product. *)
have "[i] ⊲ dims A" using assms by (metis One_nat_def Suc_length_conv length_0_conv list.sel(1) valid_index.Nil valid_index.simps)
then have "i # is ⊲ dims (A ⊗ B)" using ‹is ⊲ dims (subtensor (A ⊗ B) i)› dims_subtensor valid_index.Cons by auto
then show "lookup (subtensor (A ⊗ B) i) is = lookup (lookup A [i] ⋅ B) is"
unfolding lookup_subtensor1[OF ‹i # is ⊲ dims (A ⊗ B)›]
using lookup_tensor_prod[OF ‹[i] ⊲ dims A› ‹is ⊲ dims B›] lookup_smult
‹is ⊲ dims B› using append_Cons by fastforce
qed
end