# Theory Tensor_Product

theory Tensor_Product
imports Tensor_Scalar_Mult

section ‹Tensor Product›

theory Tensor_Product
imports Tensor_Scalar_Mult Tensor_Subtensor
begin

(* Product of tensors: makes `tensor` an instance of `semigroup_mult` for
   element types in class `ring`. The product is introduced by
   tensor_prod_def and is also written with the infix abbreviation
   tensor_prod_otimes (left-associative, priority 70).
   NOTE(review): in this copy every inner-syntax term is missing, so the
   block cannot be checked as-is. This covers definition bodies, lemma
   statements, have/show propositions, type annotations and the notation
   string. Restore them from the upstream theory. *)
instantiation tensor:: (ring) semigroup_mult
begin
(* NOTE(review): the defining equation of the product is missing here. *)
definition tensor_prod_def:

(* Infix notation for the product; its type and notation string are
   missing in this copy. *)
abbreviation tensor_prod_otimes ::  (infixl  70)
where

(* Two simp rules proved together. Judging only by their names, ?V is
   presumably the underlying vector of a product and ?D its dimensions
   (statements missing -- TODO confirm).
   Both follow by simp from tensor_prod_def once the inner fact is
   available. That inner fact is proved from length lemmas
   (length_vec_smult, concat_equal_length, length_map, prod_list.append),
   so it presumably states a length equation. *)
lemma vec_tensor_prod[simp]:  (is ?V)
and dims_tensor_prod[simp]:  (is ?D)
proof -
have
proof -
have
using length_vec_smult by force
then show ?thesis using concat_equal_length by (metis length_map length_vec prod_list.append)
qed
then show ?V ?D by (simp add: tensor_prod_def)+
qed

(* Auxiliary lemma about xss (statement missing; xss is presumably a list
   of lists). It is proved by structural induction on xss followed by auto.
   It is instantiated with ?f and ?xss inside subtensor_combine_tensor_prod. *)
lemma tensorprod_subtensor_base:
shows
by (induction xss; auto)

(* Relates subtensor_combine to the product (statement missing).
   The proof first names a function ?f and a list ?xss. Auxiliary facts
   1-4 come from length/dims lemmas and map_map. An equational chain
   (have / also / finally) then does three steps:
   - unfold subtensor_combine_def and tensor_prod_def;
   - apply tensorprod_subtensor_base to ?f and ?xss;
   - close the last step by unfolding the same definitions with facts 3
     and 4. *)
lemma subtensor_combine_tensor_prod:
assumes
shows
proof -
let ?f =
let ?xss =
have 1: by (metis assms length_vec subtensor_combine_dims subtensor_combine_vec)
have 2:
by (metis dims_tensor_prod length_vec vec_tensor_prod)
have 3:  by simp
have 4:
unfolding map_map[unfolded comp_def]  using vec_tensor by (metis (no_types, lifting)  map_eq_conv)
have
unfolding subtensor_combine_def tensor_prod_def using 1 by auto
also have
using tensorprod_subtensor_base[of ?f ?xss] by auto
also have
unfolding subtensor_combine_def tensor_prod_def using 3 4 by metis
finally show ?thesis by metis
qed

(* Subtensor of a product (statement missing). It is used as
   subtensor_tensor_prod[of A i B], so it presumably says that subtensor i
   of the product of A and B is the product of subtensor i of A with B --
   TODO confirm.
   The proof is by induction on A with subtensor_combine_induct:
   - order_0 is closed by auto;
   - order_step rewrites the product with subtensor_combine_tensor_prod and
     then extracts the subtensor with subtensor_subtensor_combine. *)
lemma subtensor_tensor_prod:
assumes  and
shows
using assms proof (induction A rule:subtensor_combine_induct)
case order_0
then show ?case by auto
next
case (order_step As ds)
have 1: using order_step by (simp add: order_step.hyps order_step.prems(1))
have 2: using order_step by auto
have
using subtensor_combine_tensor_prod order_step by metis
also have
using order_step subtensor_subtensor_combine[of   i] 1 2 by auto
also have
by (metis  length_map order_step.hyps subtensor_subtensor_combine)
finally show ?case by auto
qed

(* Simp rule for lookups in a product. Its validity assumptions are named
   is1_valid and is2_valid (propositions missing). The proof is by
   induction on A with subtensor_induct, generalising is1.
   - order_0: an entry a of A is obtained (presumably the vector of A is
     the one-element list [a]). The product then unfolds via
     tensor_prod_def and smult_def, and the lookup reduces via
     lookup_smult.
   - order_step: is1 is decomposed into i and is1'. lookup_subtensor1 and
     subtensor_tensor_prod[of A i B] then reduce the goal to the induction
     hypothesis. *)
lemma lookup_tensor_prod[simp]:
assumes is1_valid: and is2_valid:
shows
using assms proof (induction A arbitrary:is1 rule:subtensor_induct)
case (order_0 A is1)
then obtain a where
using Suc_length_conv Tensor.tensor_vec_from_lookup_Nil length_0_conv length_tensor_vec_from_lookup length_vec by metis
then have  unfolding tensor_prod_def smult_def using order_0 by simp
moreover have  by (simp add:  lookup_def order_0.hyps)
ultimately have  by (simp add: lookup_smult is2_valid)
then show ?case using  null_rec(1) order_0.hyps order_0.prems(1) by auto
next
case (order_step A is1)
then obtain i is1' where  by blast
have  using order_step
by (metis  dims_subtensor list.sel(1) list.sel(3) valid_index_dimsE)
then show
using lookup_subtensor1[of i is1' A] lookup_subtensor1[of i  ] subtensor_tensor_prod[of A i B]
Cons_eq_appendI  dims_tensor_prod is2_valid list.sel(1) order_step.hyps order_step.prems(1) valid_index_append valid_index_dimsE
by metis
qed

(* Splitting a valid index (statement missing). The proof refers to ds1 and
   ds2, and the lemma is applied to dims of products. It therefore
   presumably turns a valid index for ds1 @ ds2 into is1 valid for ds1 and
   is2 valid for ds2 whose concatenation is the original index -- TODO
   confirm.
   There are three show goals:
   - the first two validity facts go through valid_indexI; the use of
     nth_take / nth_drop suggests is1 and is2 are the take/drop parts of
     the index;
   - the last goal follows from the length fact length_is. *)
lemma valid_index_split:
assumes
obtains is1 is2 where
proof
assume a:
have length_is: using valid_index_length using assms by auto
show
apply (rule valid_indexI)
using valid_index_length using assms apply auto[1]
by (metis add_leD1 assms length_append not_less nth_append nth_take valid_index_lt)
show
apply (rule valid_indexI)
using valid_index_length using assms apply auto[1]
using nth_drop[of  ] valid_index_lt[OF assms(1)] nth_append[of ds1 ds2] length_is
by (metis length_append nat_add_left_cancel_less nat_le_iff_add nth_append_length_plus)
show  using length_is by auto
qed

(* Associativity of the product, the law required by semigroup_mult. It is
   proved pointwise with tensor_lookup_eqI:
   - a valid index is split twice with valid_index_split, giving is1, is2
     and is3;
   - is12 is defined from is1 and is2 (presumably their concatenation);
   - both sides are rewritten with lookup_tensor_prod, and the resulting
     entry products agree by mult.assoc.
   NOTE(review): the attribute arguments in lookup_tensor_prod[OF ...,
   unfolded ...] are empty in this copy and will not parse -- restore
   them. *)
instance proof
fix A B C::
show
proof (rule tensor_lookup_eqI, simp)
fix  assume
obtain is1 is23 where
by (metis (mono_tags, lifting)  Tensor_Product.dims_tensor_prod append_assoc valid_index_split)
obtain is2 is3 where
by (metis  dims_tensor_prod valid_index_split)
define is12 where
have  by (simp add:   is12_def valid_index_append)
have  by (simp add:   is12_def)
show
unfolding lookup_tensor_prod[OF  , unfolded ]
lookup_tensor_prod[OF  , unfolded ]
using     is12_def mult.assoc by fastforce
qed
qed

end

(* Distributivity of the product over addition (statement missing).
   The pointwise step uses ring_distribs(2), i.e. distrib_right,
   (a + b) * c = a * c + b * c. The statement is therefore presumably
   (A + B) * C = A * C + B * C under a dimension assumption on A and B --
   TODO confirm.
   Structure of the proof:
   - a lookup equation for every valid index: the index is split with
     valid_index_split, then lookup_plus and lookup_tensor_prod are applied;
   - further facts proved from plus_dim1, dims_tensor_prod and
     tensor_lookup;
   - tensor_from_lookup_eqI combines them. *)
lemma tensor_prod_distr_left:
assumes
shows
proof -
have
proof -
fix  assume
obtain is1 is2 where    using valid_index_split using  by blast
then show
using lookup_plus
assms plus_dim1 dims_tensor_prod lookup_tensor_prod ring_class.ring_distribs(2) valid_index_append
by fastforce
qed
moreover have

by (metis (no_types, lifting) assms plus_dim1 dims_tensor_prod tensor_lookup)+
ultimately show ?thesis using tensor_from_lookup_eqI
by (metis )
qed

(* Distributivity of the product over addition (statement missing).
   The pointwise step uses ring_distribs(1), i.e. distrib_left,
   a * (b + c) = a * b + a * c. The statement is therefore presumably
   A * (B + C) = A * B + A * C under a dimension assumption on B and C --
   TODO confirm.
   Structure of the proof:
   - a lookup equation for every valid index: the index is split with
     valid_index_split, then lookup_plus and lookup_tensor_prod are applied;
   - further facts proved from plus_dim1, dims_tensor_prod and
     tensor_lookup;
   - tensor_from_lookup_eqI combines them. *)
lemma tensor_prod_distr_right:
assumes
shows
proof -
have
proof -
fix  assume
obtain is1 is2 where    using valid_index_split using  by blast
then show
using lookup_plus
using   assms plus_dim1 dims_tensor_prod lookup_tensor_prod ring_class.ring_distribs(1) valid_index_append
by fastforce
qed
moreover have

by (metis (no_types, lifting) assms plus_dim1 dims_tensor_prod tensor_lookup)+
ultimately show ?thesis using tensor_from_lookup_eqI
by (metis )
qed

(* Unit tensor: makes `tensor` an instance of `monoid_mult` for element
   types in class `ring_1`.
   NOTE(review): the defining equation of tensor_one_def is missing in this
   copy, as are all propositions and type annotations below. *)
instantiation tensor :: (ring_1) monoid_mult
begin
definition tensor_one_def:

(* Restates the unit via tensor_from_lookup. Proved by unfolding
   tensor_one_def, applying tensor_eqI, and simp with
   tensor_from_lookup_def. *)
lemma tensor_one_from_lookup:
unfolding tensor_one_def by (rule tensor_eqI; simp_all add: tensor_from_lookup_def )

(* The two neutral-element laws. Each is proved pointwise with
   tensor_lookup_eqI and lookup_tensor_prod:
   - the first uses append_Nil2 and mult.right_neutral
     (presumably A * 1 = A);
   - the second uses append_Nil and mult.left_neutral
     (presumably 1 * A = A). *)
instance proof
fix A::
show  unfolding tensor_one_from_lookup
by (rule tensor_lookup_eqI;metis lookup_tensor_prod[of _   ]
lookup_tensor_from_lookup valid_index.Nil append_Nil2 dims_tensor dims_tensor_prod
length_tensor_vec_from_lookup mult.right_neutral tensor_from_lookup_def)
next
fix A::
show  unfolding tensor_one_from_lookup
by (rule tensor_lookup_eqI; metis lookup_tensor_prod[of   _ ]
lookup_tensor_from_lookup valid_index.Nil List.append.append_Nil dims_tensor dims_tensor_prod
length_tensor_vec_from_lookup mult.left_neutral tensor_from_lookup_def)
qed
end

(* Order of the unit tensor, proved by unfolding tensor_one_def and simp
   (statement missing). It is presumably that the unit has order 0, since
   order_0_multiple_of_one uses it together with length_0_conv -- TODO
   confirm. *)
lemma order_tensor_one:  unfolding tensor_one_def by simp

(* Presumably moves the scalar a out of a product (statement and the type
   of a are missing; which factor carries the scalar is not visible here).
   The proof is pointwise with tensor_lookup_eqI:
   - the dims goal is closed by simp;
   - for a valid index, valid_index_split yields is1 and is2;
   - lookup_tensor_prod and lookup_smult then reduce both sides to the same
     entry expression, closed by simp. *)
lemma smult_prod_extract1:
fixes a::
shows
proof (rule tensor_lookup_eqI)
show  by simp
fix  assume
then have  by auto
then obtain is1 is2 where    by (metis dims_tensor_prod valid_index_split)
then have  by auto
show
using lookup_tensor_prod[OF  ] lookup_tensor_prod[OF  ]
lookup_smult[OF ] lookup_smult[OF ]  by simp
qed

(* Presumably moves the scalar a out of a product (statement and the type
   of a are missing; which factor carries the scalar is not visible here).
   The proof is pointwise with tensor_lookup_eqI:
   - the dims goal is closed by simp;
   - for a valid index, valid_index_split yields is1 and is2;
   - lookup_tensor_prod and lookup_smult then reduce both sides to the same
     entry expression, closed by simp. *)
lemma smult_prod_extract2:
fixes a::
shows
proof (rule tensor_lookup_eqI)
show  by simp
fix  assume
then have  by auto
then obtain is1 is2 where    by (metis dims_tensor_prod valid_index_split)
then have  by auto
show
using lookup_tensor_prod[OF  ] lookup_tensor_prod[OF  ]
lookup_smult[OF ] lookup_smult[OF ]  by simp
qed

(* Presumably: a tensor of order 0 is a scalar multiple a of the unit
   tensor (statement and assumption missing). The proof runs as follows:
   - a length fact is derived from assms via length_vec;
   - a single entry a is obtained via One_nat_def and Suc_length_conv;
   - unfolding smult_def and tensor_one_def computes the scaled unit;
   - tensor_eqI, with dims_smult and order_tensor_one, identifies the two
     tensors. *)
lemma order_0_multiple_of_one:
assumes
obtains a where
proof
assume
have  using assms by (simp add:length_vec)
then obtain a where  by (metis One_nat_def Suc_length_conv length_0_conv)
moreover have  unfolding smult_def tensor_one_def by (simp add: vec_smult_def)
ultimately have  using tensor_eqI by (metis assms dims_smult length_0_conv order_tensor_one)
then show  using  by auto
qed

(* Scaling by the scalar 1 (statement and the type of A are missing).
   It is presumably that smult by 1 leaves A unchanged -- TODO confirm.
   Proof: tensor_eqI; the length goal is closed with length_vec and
   length_vec_smult; the entries are compared via lookup_smult and
   mult.left_neutral. *)
lemma smult_1:
fixes A::
shows  unfolding smult_def tensor_one_def
apply (rule tensor_eqI)
apply (simp add: length_vec length_vec_smult)
by (metis dims_tensor length_vec length_vec_smult lookup_smult mult.left_neutral smult_def tensor_lookup_eqI)

(* Simp rule: product with a zero tensor (tensor0) on the right (statement
   missing; presumably the result is the zero tensor of the combined
   dims). The proof is pointwise with tensor_lookup_eqI: split the index
   with valid_index_split, rewrite with lookup_tensor_prod and
   lookup_tensor0, and close with mult_zero_right. *)
lemma tensor0_prod_right[simp]:
proof (rule tensor_lookup_eqI,simp)
fix  assume
then obtain is1 is2 where
by (metis dims_tensor0 dims_tensor_prod valid_index_split)
then show
by (metis (no_types, lifting)  dims_tensor0 dims_tensor_prod lookup_tensor0 lookup_tensor_prod mult_zero_right)
qed

(* Simp rule: product with a zero tensor (tensor0) on the left (statement
   missing; presumably the result is the zero tensor of the combined
   dims). The proof is pointwise with tensor_lookup_eqI: split the index
   with valid_index_split, rewrite with lookup_tensor_prod and
   lookup_tensor0, and close with mult_zero_left. *)
lemma tensor0_prod_left[simp]:
proof (rule tensor_lookup_eqI,simp)
fix  assume
then obtain is1 is2 where
by (metis dims_tensor0 dims_tensor_prod valid_index_split)
then show
by (metis (no_types, lifting)  dims_tensor0 dims_tensor_prod lookup_tensor0 lookup_tensor_prod mult_zero_left)
qed

(* Subtensor of a product whose first factor is presumably a vector
   (statement missing). assms(1) is combined with Suc_length_conv and
   length_0_conv on dims, which suggests the dims of that factor form a
   one-element list -- TODO confirm. Presumably subtensor i of the product
   is the second factor scaled by the i-th entry of the first.
   The proof uses tensor_lookup_eqI:
   - the dims goal goes through dims_smult and dims_subtensor;
   - the entry goal goes through lookup_subtensor1, lookup_tensor_prod and
     lookup_smult.
   NOTE(review): several attribute arguments (e.g. [OF ...]) are empty in
   this copy and will not parse. *)
lemma subtensor_prod_with_vec:
assumes
shows
proof (rule tensor_lookup_eqI)
have  using assms(1) by auto
have
by (metis One_nat_def Suc_length_conv append_Cons assms(1) dims_tensor_prod list.sel(1))
show
unfolding dims_smult dims_subtensor[OF  [unfolded ] ]
by (metis One_nat_def Suc_length_conv append.simps(2) append_self_conv2 assms(1) dims_tensor_prod length_0_conv list.sel(3))
next
fix  assume
have  using assms(1) by auto
have
by (metis One_nat_def Suc_length_conv append_Cons assms(1) dims_tensor_prod list.sel(1))
then have
using [unfolded dims_subtensor[OF  [unfolded ] ]]
by (metis One_nat_def Suc_length_conv append_self_conv2 assms(1) dims_tensor_prod length_0_conv list.sel(3) list.simps(3) tl_append2)
have  using assms by (metis One_nat_def Suc_length_conv length_0_conv list.sel(1) valid_index.Nil valid_index.simps)
then have  using  dims_subtensor valid_index.Cons by auto
then show
unfolding lookup_subtensor1[OF ]
using lookup_tensor_prod[OF  ] lookup_smult
using append_Cons by fastforce
qed

end