(* Theory Tensor *)
section ‹Tensor Products›
theory Tensor
imports
Complex_Vectors
Matrix_Tensor.Matrix_Tensor
Jordan_Normal_Form.Matrix
begin
text ‹
There is already a formalization of tensor products in the Archive of Formal Proofs,
namely Matrix\_Tensor.thy in Tensor Product of Matrices[1] by T.V.H. Prathamesh, but it does not build
on top of the formalization of vectors and matrices given in Matrices, Jordan Normal Forms, and
Spectral Radius Theory[2] by René Thiemann and Akihisa Yamada.
The purpose of the present theory is to provide such a formalization. We reuse Prathamesh's code
as much as possible; to that end, we prove some lemmas that translate back and forth between
vectors (resp. matrices) represented as lists (resp. lists of lists) and vectors (resp. matrices)
as formalized in [2].
›
subsection ‹The Kronecker Product of Complex Vectors›
(* Kronecker product of two complex vectors, written u ⊗ v (left-associative infix, priority 63).
   Both arguments are turned into lists with list_of_vec, combined with mult.vec_vec_Tensor
   instantiated with complex multiplication, and the result is turned back into a vector
   with vec_of_list. *)
definition tensor_vec:: "complex Matrix.vec ⇒ complex Matrix.vec ⇒ complex Matrix.vec" (infixl ‹⊗› 63)
where "tensor_vec u v ≡ vec_of_list (mult.vec_vec_Tensor (*) (list_of_vec u) (list_of_vec v))"
subsection ‹The Tensor Product of Complex Matrices›
text ‹To see a matrix in the sense of [2] as a matrix in the sense of [1], we convert it into its list
of column vectors.›
(* Converts a complex Matrix.mat into a list of lists: the outer list ranges over the column
   indices j = 0 ..< dim_col A, and the j-th inner list contains the entries A $$ (i,j) for
   i = 0 ..< dim_row A, i.e. the j-th column of A. *)
definition mat_to_cols_list:: "complex Matrix.mat ⇒ complex list list" where
"mat_to_cols_list A = [[A $$ (i,j) . i <- [0..< dim_row A]] . j <- [0..< dim_col A]]"
(* The list of columns of A has exactly dim_col A elements. Declared as a simp rule. *)
lemma length_mat_to_cols_list [simp]:
"length (mat_to_cols_list A) = dim_col A"
by (simp add: mat_to_cols_list_def)
(* The list of entries A $$ (i,j) for i = 0 ..< dim_row A (the shape of an inner list in
   the definition of mat_to_cols_list) has length dim_row A. Declared as a simp rule.
   NOTE(review): the conclusion does not seem to depend on the assumption j < dim_col A;
   confirm against callers before relaxing it. *)
lemma length_cols_mat_to_cols_list [simp]:
assumes "j < dim_col A"
shows "length [A $$ (i,j) . i <- [0..< dim_row A]] = dim_row A"
using assms by simp
(* For a valid row index i, the i-th row of the list representation of A (unqualified row,
   as opposed to Matrix.row) has length dim_col A. Proved by unfolding row_def.
   Declared as a simp rule. *)
lemma length_row_mat_to_cols_list [simp]:
assumes "i < dim_row A"
shows "length (row (mat_to_cols_list A) i) = dim_col A"
using assms by (simp add: row_def)
(* For a valid column index j, the j-th column of the list representation of A (unqualified
   col, as opposed to Matrix.col) has length dim_row A. Proved by unfolding col_def and
   mat_to_cols_list_def. Declared as a simp rule. *)
lemma length_col_mat_to_cols_list [simp]:
assumes "j < dim_col A"
shows "length (col (mat_to_cols_list A) j) = dim_row A"
using assms by (simp add: col_def mat_to_cols_list_def)
(* A matrix with at least one column is converted into a non-empty list of columns.
   Declared as a simp rule. *)
lemma mat_to_cols_list_is_not_Nil [simp]:
assumes "dim_col A > 0"
shows "mat_to_cols_list A ≠ []"
using assms by (simp add: mat_to_cols_list_def)
text ‹Link between Matrix\_Tensor.row\_length and Matrix.dim\_row›
(* For a matrix with at least one column, mult.row_length of its list of columns equals
   dim_row A. Declared as a simp rule. *)
lemma row_length_mat_to_cols_list [simp]:
assumes "dim_col A > 0"
shows "mult.row_length (mat_to_cols_list A) = dim_row A"
proof -
(* The list of columns is non-empty because dim_col A > 0. *)
have "mat_to_cols_list A ≠ []" by (simp add: assms)
(* Hence row_length is the length of the first column; this uses mult.row_length_def
   instantiated with 1 and complex multiplication, cited below as a literal fact. *)
then have "mult.row_length (mat_to_cols_list A) = length (hd (mat_to_cols_list A))"
using mult.row_length_def[of "1" "(*)"]
by (simp add: ‹⋀xs. Matrix_Tensor.mult 1 (*) ⟹ mult.row_length xs ≡ if xs = [] then 0 else length (hd xs)› mult.intro)
(* upt_conv_Cons exposes the head of the outer comprehension over the column indices. *)
thus ?thesis by (simp add: assms mat_to_cols_list_def upt_conv_Cons)
qed
text ‹@{term mat_to_cols_list} is a matrix in the sense of @{theory Matrix.Matrix_Legacy}.›
(* For a matrix A with at least one column, mat_to_cols_list A satisfies the list-based
   predicate mat, with mult.row_length (mat_to_cols_list A) rows and
   length (mat_to_cols_list A) columns. Declared as a simp rule. *)
lemma mat_to_cols_list_is_mat [simp]:
assumes "dim_col A > 0"
shows "mat (mult.row_length (mat_to_cols_list A)) (length (mat_to_cols_list A)) (mat_to_cols_list A)"
proof -
(* Every column in the list satisfies Matrix_Legacy.vec for the common row length. *)
have "Ball (set (mat_to_cols_list A)) (Matrix_Legacy.vec (mult.row_length (mat_to_cols_list A)))"
using assms row_length_mat_to_cols_list mat_to_cols_list_def Ball_def set_def vec_def by fastforce
(* The claim then follows by unfolding mat_def. *)
thus ?thesis by(auto simp: mat_def)
qed
(* Builds a complex Matrix.mat from a list of columns cs: the result has nr rows and
   length cs columns, and its (i,j) entry is cs ! j ! i. The number of rows nr is passed
   explicitly rather than derived from cs. *)
definition mat_of_cols_list:: "nat ⇒ complex list list ⇒ complex Matrix.mat" where
"mat_of_cols_list nr cs = Matrix.mat nr (length cs) (λ (i,j). cs ! j ! i)"
(* Within bounds (i < nr and j < length cs), the (i,j) entry of mat_of_cols_list nr cs is
   the i-th element of the j-th list of cs. Declared as a simp rule. *)
lemma index_mat_of_cols_list [simp]:
assumes "i < nr" and "j < length cs"
shows "mat_of_cols_list nr cs $$ (i,j) = cs ! j ! i"
by (simp add: assms mat_of_cols_list_def)
(* Round trip starting from a matrix: converting A to its list of columns and back, with
   dim_row A as the number of rows, gives A again. Declared as a simp rule.
   The proof establishes equal row dimension, equal column dimension, and equal entries. *)
lemma mat_to_cols_list_to_mat [simp]:
"mat_of_cols_list (dim_row A) (mat_to_cols_list A) = A"
proof
(* Row dimensions agree. *)
show f1:"dim_row (mat_of_cols_list (dim_row A) (mat_to_cols_list A)) = dim_row A"
by (simp add: mat_of_cols_list_def)
next
(* Column dimensions agree. *)
show f2:"dim_col (mat_of_cols_list (dim_row A) (mat_to_cols_list A)) = dim_col A"
by (simp add: Tensor.mat_of_cols_list_def)
next
(* Entries agree at every index within bounds. *)
show "⋀i j. i < dim_row A ⟹ j < dim_col A ⟹
(mat_of_cols_list (dim_row A) (mat_to_cols_list A)) $$ (i, j) = A $$ (i, j)"
by (simp add: mat_of_cols_list_def mat_to_cols_list_def)
qed
(* The complex numbers with 1, multiplication, 0, addition and the additive inverse
   a_inv cpx_rng satisfy the plus_mult predicate. Declared as a simp rule.
   The goal is first split with unfold_locales and simplified with auto; the goals that
   remain are discharged in the structured proof below. *)
lemma plus_mult_cpx [simp]:
"plus_mult 1 (*) 0 (+) (a_inv cpx_rng)"
apply unfold_locales
apply (auto intro: cpx_cring_is_field simp: field_simps)
proof -
(* Adding the cpx_rng inverse of x to x gives 0. *)
show "⋀x. x + ⊖⇘cpx_rng⇙ x = 0"
using group.r_inv[of "cpx_rng"] cpx_cring_is_field field_def domain_def cpx_rng_def
by (metis UNIV_I cring.cring_simprules(17) ordered_semiring_record_simps(1)
ordered_semiring_record_simps(11) ordered_semiring_record_simps(12))
(* NOTE(review): the same statement is shown a second time with the same proof; presumably
   two syntactically identical subgoals remain after auto. Confirm before removing either. *)
show "⋀x. x + ⊖⇘cpx_rng⇙ x = 0"
using group.r_inv[of "cpx_rng"] cpx_cring_is_field field_def domain_def cpx_rng_def
by (metis UNIV_I cring.cring_simprules(17) ordered_semiring_record_simps(1)
ordered_semiring_record_simps(11) ordered_semiring_record_simps(12))
qed
(* Round trip starting from a list of lists: if l satisfies mat nr nc l, then converting
   l to a Matrix.mat with nr rows and back to a list of columns gives l again.
   Declared as a simp rule. The proof compares the two lists element by element. *)
lemma list_to_mat_to_cols_list [simp]:
fixes l::"complex list list"
assumes "mat nr nc l"
shows "mat_to_cols_list (mat_of_cols_list nr l) = l"
proof -
(* Both outer lists have the same number of columns. *)
have "length (mat_to_cols_list (mat_of_cols_list nr l)) = length l"
by (simp add: mat_of_cols_list_def)
(* Every column of l has length mult.row_length l. *)
moreover have f1:"∀j<length l. length(l ! j) = mult.row_length l"
using assms plus_mult.row_length_constant plus_mult_cpx by fastforce
(* Each column of the round-trip result equals the corresponding column of l. *)
moreover have "⋀j. j<length l ⟶ mat_to_cols_list (mat_of_cols_list nr l) ! j = l ! j"
proof
fix j
assume a:"j < length l"
(* The j-th columns have equal length ... *)
then have f2:"length (mat_to_cols_list (mat_of_cols_list nr l) ! j) = length (l ! j)"
by (metis col_def mat_def vec_def mat_of_cols_list_def assms dim_col_mat(1) dim_row_mat(1)
length_col_mat_to_cols_list nth_mem)
(* ... and equal entries at every row index. *)
then have "∀i<mult.row_length l. mat_to_cols_list (mat_of_cols_list nr l) ! j ! i = l ! j ! i"
using a mat_to_cols_list_def mat_of_cols_list_def f1 by simp
thus "mat_to_cols_list (Tensor.mat_of_cols_list nr l) ! j = l ! j"
using f2 by(simp add: nth_equalityI a f1)
qed
ultimately show ?thesis using nth_equalityI by metis
qed
(* For j < length l, the j-th Matrix.col of the matrix built from l, taking length (l ! j)
   as the number of rows, is vec_of_list (l ! j). Declared as a simp rule. *)
lemma col_mat_of_cols_list [simp]:
assumes "j < length l"
shows "Matrix.col (mat_of_cols_list (length (l ! j)) l) j = vec_of_list (l ! j)"
proof -
(* u abbreviates the left-hand side of the statement. *)
define u where "u = Matrix.col (mat_of_cols_list (length (l ! j)) l) j"
(* Both vectors have the same dimension ... *)
then have "dim_vec u = dim_vec (vec_of_list (l ! j))"
by (auto simp: u_def mat_of_cols_list_def Matrix.col_def vec_of_list_def)
(metis dim_vec_of_list vec_of_list.abs_eq)
(* ... and the same entry at every index, so they are equal by vec_eq_iff. *)
moreover have "∀i<length(l ! j). u $ i = vec_of_list (l ! j) $ i"
by (simp add: u_def vec_of_list_index mat_of_cols_list_def assms)
ultimately show ?thesis by(simp add: vec_eq_iff u_def)
qed
(* Tensor product of two complex matrices, written A ⨂ B (left-associative infix,
   priority 63). Both matrices are converted to lists of columns, combined with mult.Tensor
   instantiated with complex multiplication, and the resulting list of columns is converted
   back into a Matrix.mat with dim_row A * dim_row B rows. *)
definition tensor_mat:: "[complex Matrix.mat, complex Matrix.mat] ⇒ complex Matrix.mat" (infixl ‹⨂› 63) where
"tensor_mat A B ≡
mat_of_cols_list (dim_row A * dim_row B) (mult.Tensor (*) (mat_to_cols_list A) (mat_to_cols_list B))"
(* The number of rows of A ⨂ B is the product of the numbers of rows of A and B.
   Declared as a simp rule. *)
lemma dim_row_tensor_mat [simp]:
"dim_row (A ⨂ B) = dim_row A * dim_row B"
by (simp add: mat_of_cols_list_def tensor_mat_def)
(* The number of columns of A ⨂ B is the product of the numbers of columns of A and B.
   Declared as a simp rule. The proof uses mult.length_Tensor instantiated with 1 and
   complex multiplication, cited below as a literal fact. *)
lemma dim_col_tensor_mat [simp]:
"dim_col (A ⨂ B) = dim_col A * dim_col B"
using tensor_mat_def mat_of_cols_list_def mult.length_Tensor[of "1" "(*)"]
by(simp add: ‹⋀M2 M1. Matrix_Tensor.mult 1 (*) ⟹ length (mult.Tensor (*) M1 M2) = length M1 * length M2› mult.intro)
(* Within bounds, the i-th element of the j-th list of mat_to_cols_list A is the entry
   A $$ (i, j). Not declared as a simp rule; it is passed explicitly where needed. *)
lemma mat_to_cols_list_nth_nth_eq_index_mat:
‹mat_to_cols_list A ! j ! i = A $$ (i, j)› if ‹i < dim_row A› ‹j < dim_col A›
using that by (simp add: mat_to_cols_list_def)
(* Entry formula for the tensor product of matrices: for A of dimension rA x cA and B of
   dimension rB x cB with cA > 0 and cB > 0, and for i < rA * rB and j < cA * cB,
   the (i,j) entry of A ⨂ B is
   A $$ (i div rB, j div cB) times B $$ (i mod rB, j mod cB).
   Declared as a simp rule. *)
lemma index_tensor_mat [simp]:
assumes a1:"dim_row A = rA" and a2:"dim_col A = cA" and a3:"dim_row B = rB" and a4:"dim_col B = cB"
and a5:"i < rA * rB" and a6:"j < cA * cB" and a7:"cA > 0" and a8:"cB > 0"
shows "(A ⨂ B) $$ (i,j) = A $$ (i div rB, j div cB) * B $$ (i mod rB, j mod cB)"
proof -
(* rB is positive, since otherwise i < rA * rB could not hold. *)
from a5 have ‹rB > 0›
by (auto intro: gr0I)
(* Express the entry of A ⨂ B as an entry of the list-level tensor product. *)
have "(A ⨂ B) $$ (i,j) = (mult.Tensor (*) (mat_to_cols_list A) (mat_to_cols_list B)) ! j ! i"
using assms tensor_mat_def mat_of_cols_list_def dim_col_tensor_mat by simp
(* Collect the bounds and well-formedness facts used as premises of
   mult.matrix_Tensor_elements below. *)
moreover have f:"i < mult.row_length (mat_to_cols_list A) * mult.row_length (mat_to_cols_list B)"
by (simp add: a1 a2 a3 a4 a5 a7 a8)
moreover have "j < length (mat_to_cols_list A) * length (mat_to_cols_list B)"
by (simp add: a2 a4 a6)
moreover have "mat (mult.row_length (mat_to_cols_list A)) (length (mat_to_cols_list A)) (mat_to_cols_list A)"
using a2 a7 mat_to_cols_list_is_mat by blast
moreover have "mat (mult.row_length (mat_to_cols_list B)) (length (mat_to_cols_list B)) (mat_to_cols_list B)"
using a4 a8 mat_to_cols_list_is_mat by blast
(* List-level entry formula, obtained from mult.matrix_Tensor_elements instantiated with
   1 and complex multiplication (cited as a literal fact). *)
ultimately have "(A ⨂ B) $$ (i,j) =
(mat_to_cols_list A) ! (j div length (mat_to_cols_list B)) ! (i div mult.row_length (mat_to_cols_list B))
* (mat_to_cols_list B) ! (j mod length (mat_to_cols_list B)) ! (i mod mult.row_length (mat_to_cols_list B))"
using mult.matrix_Tensor_elements[of "1" "(*)"]
by(simp add: ‹⋀M2 M1. mult 1 (*) ⟹ ∀i j. (i<mult.row_length M1 * mult.row_length M2
∧ j<length M1 * length M2) ∧ mat (mult.row_length M1) (length M1) M1 ∧ mat (mult.row_length M2) (length M2) M2 ⟶
mult.Tensor (*) M1 M2 ! j ! i = M1 ! (j div length M2) ! (i div mult.row_length M2) * M2 ! (j mod length M2) ! (i mod mult.row_length M2)› mult.intro)
(* The quotient and remainder indices are within the bounds of A and B. *)
moreover have ‹i div rB < rA› ‹j div cB < cA› ‹i mod rB < rB›
using a5 a6 ‹rB > 0› by (simp_all add: less_mult_imp_div_less)
(* Translate the list lookups back into matrix entries. *)
ultimately show ?thesis
using a1 a2 a3 a4 a8
by (simp add: mat_to_cols_list_nth_nth_eq_index_mat)
qed
text ‹To go from @{term Matrix.row} to @{term Matrix_Legacy.row}›
(* For a valid row index i, Matrix.row A i equals vec_of_list applied to the i-th
   list-level row of mat_to_cols_list A. The proof shows that the dimensions agree and
   then that the entries agree. *)
lemma Matrix_row_is_Legacy_row:
assumes "i < dim_row A"
shows "Matrix.row A i = vec_of_list (row (mat_to_cols_list A) i)"
proof
(* Dimensions agree. *)
show "dim_vec (Matrix.row A i) = dim_vec (vec_of_list (row (mat_to_cols_list A) i))"
using length_mat_to_cols_list Matrix.dim_vec_of_list by (metis row_def index_row(2) length_map)
next
(* Entries agree at every index below the dimension. *)
show "⋀j. j<dim_vec (vec_of_list (row (mat_to_cols_list A) i)) ⟹
Matrix.row A i $ j = vec_of_list (row (mat_to_cols_list A) i) $ j"
using Matrix.row_def vec_of_list_def mat_to_cols_list_def
by (smt (verit) row_def assms dim_vec_of_list index_mat_of_cols_list index_row(1)
length_mat_to_cols_list length_row_mat_to_cols_list mat_to_cols_list_to_mat nth_map vec_of_list_index)
qed
text ‹To go from @{term Matrix_Legacy.row} to @{term Matrix.row}›
(* For i < mult.row_length A, the i-th list-level row of the list of lists A equals
   list_of_vec applied to the i-th Matrix.row of mat_of_cols_list (mult.row_length A) A.
   Proved with nth_equalityI: equal lengths, then equal elements. *)
lemma Legacy_row_is_Matrix_row:
assumes "i < mult.row_length A"
shows "row A i = list_of_vec (Matrix.row (mat_of_cols_list (mult.row_length A) A) i)"
proof (rule nth_equalityI)
(* Lengths agree. *)
show "length (row A i) = length (list_of_vec (Matrix.row (mat_of_cols_list (mult.row_length A) A) i))"
using row_def length_list_of_vec by(metis mat_of_cols_list_def dim_col_mat(1) index_row(2) length_map)
next
(* Elements agree at every position. *)
fix j:: nat
assume "j < length (row A i)"
then show "row A i ! j = list_of_vec (Matrix.row (mat_of_cols_list (mult.row_length A) A) i) ! j"
using assms index_mat_of_cols_list
by(metis row_def mat_of_cols_list_def dim_col_mat(1) dim_row_mat(1) index_row(1) length_map list_of_vec_index nth_map)
qed
text ‹To go from @{term Matrix.col} to @{term Matrix_Legacy.col}›
(* For a valid column index j, Matrix.col A j equals vec_of_list applied to the j-th
   list-level column of mat_to_cols_list A. The proof shows that the dimensions agree and
   then that the entries agree. *)
lemma Matrix_col_is_Legacy_col:
assumes "j < dim_col A"
shows "Matrix.col A j = vec_of_list (col (mat_to_cols_list A) j)"
proof
(* Dimensions agree. *)
show "dim_vec (Matrix.col A j) = dim_vec (vec_of_list (col (mat_to_cols_list A) j))"
by (simp add: col_def assms mat_to_cols_list_def)
next
(* Entries agree at every index below the dimension. *)
show "⋀i. i < dim_vec (vec_of_list (col (mat_to_cols_list A) j)) ⟹
Matrix.col A j $ i = vec_of_list (col (mat_to_cols_list A) j) $ i"
using mat_to_cols_list_def
by (metis col_def assms col_mat_of_cols_list length_col_mat_to_cols_list length_mat_to_cols_list
mat_to_cols_list_to_mat)
qed
text ‹To go from @{term Matrix_Legacy.col} to @{term Matrix.col}›
(* For j < length A such that the j-th list of A has length mult.row_length A, the j-th
   list-level column of A equals list_of_vec applied to the j-th Matrix.col of
   mat_of_cols_list (mult.row_length A) A.
   Proved with nth_equalityI: equal lengths, then equal elements. *)
lemma Legacy_col_is_Matrix_col:
assumes a1:"j < length A" and a2:"length (A ! j) = mult.row_length A"
shows "col A j = list_of_vec (Matrix.col (mat_of_cols_list (mult.row_length A) A) j)"
proof (rule nth_equalityI)
(* Chain of equalities computing the length of the right-hand side as mult.row_length A. *)
have "length (list_of_vec (Matrix.col (mat_of_cols_list (mult.row_length A) A) j)) =
dim_vec (Matrix.col (mat_of_cols_list (mult.row_length A) A) j)"
using length_list_of_vec by blast
also have "… = dim_row (mat_of_cols_list (mult.row_length A) A)"
using Matrix.col_def by simp
also have f1:"… = mult.row_length A"
by (simp add: mat_of_cols_list_def)
(* Lengths agree, using assumption a2 for the left-hand side. *)
finally show f2:"length (col A j) = length (list_of_vec (Matrix.col (mat_of_cols_list (mult.row_length A) A) j))"
using a2 by (simp add: col_def)
next
(* Elements agree at every position. *)
fix i:: nat
assume "i<length (col A j)"
then show "(col A j) ! i = (list_of_vec (Matrix.col (mat_of_cols_list (mult.row_length A) A) j)) ! i"
by (metis col_def a1 a2 col_mat_of_cols_list list_vec)
qed
text ‹Link between @{term plus_mult.scalar_product} and @{term Matrix.scalar_prod}›
(* For two complex lists of equal length, plus_mult.scalar_product (instantiated with
   complex multiplication, 0 and addition) coincides with the scalar product ∙ of the
   corresponding vectors. Declared as a simp rule.
   Both sides are shown to equal the same indexed sum over 0 ..< length v. *)
lemma scalar_prod_is_Matrix_scalar_prod [simp]:
fixes u::"complex list" and v::"complex list"
assumes "length u = length v"
shows "plus_mult.scalar_product (*) 0 (+) u v = (vec_of_list u) ∙ (vec_of_list v)"
proof -
(* The vector scalar product written as an indexed sum of list elements. *)
have f:"(vec_of_list u) ∙ (vec_of_list v) = (∑i=0..<length v. u ! i * v ! i)"
using assms scalar_prod_def[of "vec_of_list u" "vec_of_list v"] Matrix.dim_vec_of_list[of v] index_vec_of_list
by (metis (no_types, lifting) atLeastLessThan_iff sum.cong)
thus ?thesis
proof -
(* The list-level scalar product unfolds to a sum_list over the zipped lists ... *)
have "plus_mult.scalar_product (*) 0 (+) u v = semiring_0_class.scalar_prod u v"
using plus_mult.scalar_product_def[of 1 "(*)" 0 "(+)" "a_inv cpx_rng" u v] by simp
also have "… = sum_list (map (λ(x,y). x * y) (zip u v))"
by (simp add: scalar_prod)
(* ... whose i-th summand is u ! i * v ! i ... *)
moreover have "∀i<length v. (zip u v) ! i = (u ! i, v ! i)"
using assms zip_def by simp
then have "∀i<length v. (map (λ(x,y). x * y) (zip u v)) ! i = u ! i * v ! i"
by (simp add: assms)
(* ... so it equals the same indexed sum as in fact f. *)
ultimately have "plus_mult.scalar_product (*) 0 (+) u v = (∑i=0..<length v. u ! i * v ! i)"
by(metis (no_types, lifting) assms atLeastLessThan_iff length_map map_fst_zip sum.cong sum_list_sum_nth)
thus ?thesis by (simp add: f)
qed
qed
text ‹Link between @{term times} and @{term plus_mult.matrix_mult}›
(* For complex matrices A and B with dim_col A = dim_row B and at least one column each,
   the matrix product A * B equals the matrix built (with dim_row A rows) from the
   list-level product plus_mult.matrix_mult of their lists of columns.
   The proof shows equal row dimension, equal column dimension, and equal entries. *)
lemma matrix_mult_to_times_mat:
assumes "dim_col A > 0" and "dim_col B > 0" and "dim_col (A::complex Matrix.mat) = dim_row B"
shows "A * B = mat_of_cols_list (dim_row A) (plus_mult.matrix_mult (*) 0 (+) (mat_to_cols_list A) (mat_to_cols_list B))"
proof
(* M abbreviates the right-hand side of the statement. *)
define M where "M = mat_of_cols_list (dim_row A) (plus_mult.matrix_mult (*) 0 (+) (mat_to_cols_list A) (mat_to_cols_list B))"
(* Row dimensions agree. *)
then show f1:"dim_row (A * B) = dim_row M"
by (simp add: mat_of_cols_list_def times_mat_def)
(* The list-level product has dim_col B columns, hence the column dimensions agree. *)
have "length (plus_mult.matrix_mult (*) 0 (+) (mat_to_cols_list A) (mat_to_cols_list B)) = dim_col B"
by (simp add: mat_multI_def)
then show f2:"dim_col (A * B) = dim_col M"
by (simp add: M_def times_mat_def mat_of_cols_list_def)
(* Entries agree: both sides are reduced to the same list-level scalar product. *)
show "⋀i j. i < dim_row M ⟹ j < dim_col M ⟹ (A * B) $$ (i, j) = M $$ (i, j)"
proof -
fix i j
assume a1:"i < dim_row M" and a2:"j < dim_col M"
(* Entry of A * B as row-times-column scalar product ... *)
then have "(A * B) $$ (i,j) = Matrix.row A i ∙ Matrix.col B j"
using f1 f2 by simp
(* ... rewritten in terms of list-level rows and columns ... *)
also have "… = vec_of_list (row (mat_to_cols_list A) i) ∙ vec_of_list (col (mat_to_cols_list B) j)"
using f1 f2 a1 a2 by (simp add: Matrix_row_is_Legacy_row Matrix_col_is_Legacy_col)
(* ... and then as a list-level scalar product. *)
also have "… = plus_mult.scalar_product (*) 0 (+) (row (mat_to_cols_list A) i) (col (mat_to_cols_list B) j)"
using a1 a2 assms(3) f1 f2 by simp
(* The entry of M is the same list-level scalar product. *)
also have "M $$ (i,j) = plus_mult.scalar_product (*) 0 (+) (row (mat_to_cols_list A) i) (col (mat_to_cols_list B) j)"
proof-
(* Entry of M as a lookup in the list-level product. *)
have "M $$ (i,j) = (plus_mult.matrix_mult (*) 0 (+) (mat_to_cols_list A) (mat_to_cols_list B)) ! j ! i"
using M_def f1 f2
‹length (mat_mult (mult.row_length (mat_to_cols_list A)) (mat_to_cols_list A) (mat_to_cols_list B)) = dim_col B› a1 a2 by simp
(* Both lists of columns satisfy the list-level predicate mat with matching dimensions. *)
moreover have "mat (mult.row_length (mat_to_cols_list A)) (dim_col A) (mat_to_cols_list A)"
using mat_to_cols_list_is_mat assms(1) by simp
moreover have "mat (dim_col A) (dim_col B) (mat_to_cols_list B)"
using assms(2) assms(3) mat_to_cols_list_is_mat by simp
(* plus_mult.matrix_index then gives the entry as a scalar product. *)
ultimately show ?thesis
using assms(1) a1 a2 row_length_mat_to_cols_list plus_mult.matrix_index[of 1 "(*)" 0 "(+)"] plus_mult_cpx
by (smt (verit) f1 f2 index_mult_mat(2) index_mult_mat(3))
qed
finally show "(A * B) $$ (i, j) = M $$ (i, j)" by simp
qed
qed
(* For complex matrices A and B with dim_col A = dim_row B and dim_col A > 0, the list of
   columns of A * B equals the list-level product plus_mult.matrix_mult of the lists of
   columns of A and B. Declared as a simp rule.
   Proved with nth_equalityI on the outer lists. *)
lemma mat_to_cols_list_times_mat [simp]:
assumes "dim_col A = dim_row B" and "dim_col A > 0"
shows "mat_to_cols_list (A * B) = plus_mult.matrix_mult (*) 0 (+) (mat_to_cols_list A) (mat_to_cols_list B)"
proof (rule nth_equalityI)
(* M abbreviates the right-hand side of the statement. *)
define M where "M = plus_mult.matrix_mult (*) 0 (+) (mat_to_cols_list A) (mat_to_cols_list B)"
(* Both outer lists have the same number of columns. *)
then show f0:"length (mat_to_cols_list (A * B)) = length M" by (simp add: mat_multI_def)
(* Column-wise equality, used at the end to discharge the second nth_equalityI goal. *)
moreover have f1:"⋀j. j<length (mat_to_cols_list (A * B)) ⟶ mat_to_cols_list (A * B) ! j = M ! j"
proof
fix j:: nat
assume a0:"j < length (mat_to_cols_list (A * B))"
(* Both j-th columns have length dim_row A. *)
then have "length (mat_to_cols_list (A * B) ! j) = dim_row A"
by (simp add: mat_to_cols_list_def)
then also have f2:"length (M ! j) = dim_row A"
using a0 M_def mat_multI_def[of 0 "(+)" "(*)" "dim_row A" "mat_to_cols_list A" "mat_to_cols_list B"]
row_length_mat_to_cols_list assms(2)
by (metis assms(1) f0 length_greater_0_conv length_map length_mat_to_cols_list
list_to_mat_to_cols_list mat_mult mat_to_cols_list_is_mat matrix_mult_to_times_mat)
ultimately have "length (mat_to_cols_list (A * B) ! j) = length (M ! j)" by simp
(* Both j-th columns have the same entry at every row index i. *)
moreover have "⋀i. i<dim_row A ⟶ mat_to_cols_list (A * B) ! j ! i = M ! j ! i"
proof
fix i
assume a1:"i < dim_row A"
(* Both lists of columns satisfy the list-level predicate mat with matching dimensions. *)
have "mat (mult.row_length (mat_to_cols_list A)) (dim_col A) (mat_to_cols_list A)"
using mat_to_cols_list_is_mat assms(2) by simp
moreover have "mat (dim_col A) (dim_col B) (mat_to_cols_list B)"
using mat_to_cols_list_is_mat assms(1) a0 by simp
(* Entry of M as a list-level scalar product, via plus_mult.matrix_index ... *)
ultimately have "M ! j ! i = plus_mult.scalar_product (*) 0 (+) (row (mat_to_cols_list A) i) (col (mat_to_cols_list B) j)"
using plus_mult.matrix_index a0 a1 row_length_mat_to_cols_list assms(2) plus_mult_cpx M_def
by (metis index_mult_mat(3) length_mat_to_cols_list)
(* ... rewritten as a scalar product of vectors ... *)
also have "… = vec_of_list (row (mat_to_cols_list A) i) ∙ vec_of_list (col (mat_to_cols_list B) j)"
using a0 a1 assms(1) by simp
(* ... which matches the corresponding entry of A * B. *)
finally show "mat_to_cols_list (A * B) ! j ! i = M ! j ! i"
using mat_to_cols_list_def index_mult_mat(1) a0 a1
by(simp add: Matrix_row_is_Legacy_row Matrix_col_is_Legacy_col)
qed
ultimately show "mat_to_cols_list (A * B) ! j = M ! j" by(simp add: nth_equalityI f2)
qed
(* Second goal of nth_equalityI: element-wise equality of the outer lists, from f1. *)
fix i:: nat
assume "i < length (mat_to_cols_list (A * B))"
thus "mat_to_cols_list (A * B) ! i = M ! i" by (simp add: f1)
qed
text ‹
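Finally, we prove that the tensor product of complex matrices is compatible with the
multiplication of complex matrices (the mixed-product property): the tensor product of two
matrix products equals the matrix product of the corresponding tensor products.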
›
(* Compatibility of the tensor product with matrix multiplication: for matrices with
   dim_col A = dim_row B and dim_col C = dim_row D, each having at least one column,
   the tensor product of A * B and C * D equals the matrix product of A ⨂ C and B ⨂ D.
   The right-hand side is transformed step by step, at the level of lists of columns,
   into the left-hand side. *)
lemma mult_distr_tensor:
assumes a1:"dim_col A = dim_row B" and a2:"dim_col C = dim_row D" and a3:"dim_col A > 0" and
a4:"dim_col B > 0" and a5:"dim_col C > 0" and a6:"dim_col D > 0"
shows "(A * B) ⨂ (C * D) = (A ⨂ C) * (B ⨂ D)"
proof -
(* A', B', C', D' are the lists of columns of the four matrices; M and N are the two
   tensor products on the right-hand side with tensor_mat_def unfolded. *)
define A' B' C' D' M N where "A' = mat_to_cols_list A" and "B' = mat_to_cols_list B" and
"C' = mat_to_cols_list C" and "D' = mat_to_cols_list D" and
"M = mat_of_cols_list (dim_row A * dim_row C) (mult.Tensor (*) (mat_to_cols_list A) (mat_to_cols_list C))" and
"N = mat_of_cols_list (dim_row B * dim_row D) (mult.Tensor (*) (mat_to_cols_list B) (mat_to_cols_list D))"
(* Step 1: unfold the tensor products on the right-hand side. *)
then have "(A ⨂ C) * (B ⨂ D) = M * N"
by (simp add: tensor_mat_def)
(* Step 2: express the matrix product M * N through the list-level product. *)
also have "… = mat_of_cols_list (dim_row A * dim_row C) (plus_mult.matrix_mult (*) 0 (+)
(mat_to_cols_list M) (mat_to_cols_list N))"
using assms M_def N_def dim_col_tensor_mat dim_row_tensor_mat tensor_mat_def
by(simp add: matrix_mult_to_times_mat)
(* Step 3: replace the lists of columns of M and N by the list-level tensor products
   they were built from (round trip through mat_of_cols_list). *)
also have f1:"… = mat_of_cols_list (dim_row A * dim_row C) (plus_mult.matrix_mult (*) 0 (+)
(mult.Tensor (*) A' C') (mult.Tensor (*) B' D'))"
proof -
define M' N' where "M' = mult.Tensor (*) (mat_to_cols_list A) (mat_to_cols_list C)" and
"N' = mult.Tensor (*) (mat_to_cols_list B) (mat_to_cols_list D)"
(* Both list-level tensor products satisfy the list-level predicate mat. *)
then have "mat (mult.row_length M') (length M') M'"
using M'_def mult.effective_well_defined_Tensor[of 1 "(*)"] mat_to_cols_list_is_mat a3 a5
by (smt (verit) mult.length_Tensor mult.row_length_mat plus_mult_cpx plus_mult_def)
moreover have "mat (mult.row_length N') (length N') N'"
using N'_def mult.effective_well_defined_Tensor[of 1 "(*)"] mat_to_cols_list_is_mat a4 a6
by (smt (verit) mult.length_Tensor mult.row_length_mat plus_mult_cpx plus_mult_def)
ultimately show ?thesis
using list_to_mat_to_cols_list M_def N_def mult.row_length_mat row_length_mat_to_cols_list
assms(3) a4 a5 a6 A'_def B'_def C'_def D'_def by(metis M'_def N'_def plus_mult_cpx plus_mult_def)
qed
(* Step 4: apply the list-level result plus_mult.distributivity, after establishing
   plus_mult.matrix_match for the four lists of columns. *)
also have "… = mat_of_cols_list (dim_row A * dim_row C) (mult.Tensor (*)
(plus_mult.matrix_mult (*) 0 (+) A' B')
(plus_mult.matrix_mult (*) 0 (+) C' D'))"
proof -
(* Well-formedness of the four lists of columns. *)
have f2:"mat (mult.row_length A') (length A') A'"
using A'_def a3 mat_to_cols_list_is_mat by simp
moreover have "mat (mult.row_length B') (length B') B'"
using B'_def a4 mat_to_cols_list_is_mat by simp
moreover have "mat (mult.row_length C') (length C') C'"
using C'_def a5 mat_to_cols_list_is_mat by simp
moreover have "mat (mult.row_length D') (length D') D'"
using D'_def a6 mat_to_cols_list_is_mat by simp
(* Dimensions match for the two list-level products. *)
moreover have "length A' = mult.row_length B'"
using A'_def B'_def a1 a4 by simp
moreover have "length C' = mult.row_length D'"
using C'_def D'_def a2 a6 by simp
(* None of the lists is empty. *)
moreover have "A' ≠ [] ∧ B' ≠ [] ∧ C' ≠ [] ∧ D' ≠ []"
using A'_def B'_def C'_def D'_def a3 a4 a5 a6 by simp
ultimately have "plus_mult.matrix_match A' B' C' D'"
using plus_mult.matrix_match_def[of 1 "(*)" 0 "(+)" "a_inv cpx_rng"] by simp
thus ?thesis
using f1 plus_mult.distributivity plus_mult_cpx by fastforce
qed
(* Step 5: rewrite the list-level products as the lists of columns of A * B and C * D. *)
also have "… = mat_of_cols_list (dim_row A * dim_row C) (mult.Tensor (*)
(mat_to_cols_list (A * B)) (mat_to_cols_list (C * D)))"
using A'_def B'_def C'_def D'_def a1 a2 a3 a5 by simp
(* Step 6: fold tensor_mat_def back to obtain the left-hand side. *)
finally show ?thesis by(simp add: tensor_mat_def)
qed
(* Associativity of the tensor product of complex matrices. No dimension assumptions are
   made in the statement. The proof unfolds both sides to the list level and relies on
   Matrix_Tensor.mult.associativity. *)
lemma tensor_mat_is_assoc:
fixes A B C:: "complex Matrix.mat"
shows "A ⨂ (B ⨂ C) = (A ⨂ B) ⨂ C"
proof-
(* M is B ⨂ C with tensor_mat_def unfolded. *)
define M where d:"M = mat_of_cols_list (dim_row B * dim_row C) (mult.Tensor (*) (mat_to_cols_list B) (mat_to_cols_list C))"
then have *: "B ⨂ C = M"
using tensor_mat_def by simp
(* Unfold the outer tensor product on the left-hand side. *)
then have **: "A ⨂ (B ⨂ C) = mat_of_cols_list (dim_row A * (dim_row B * dim_row C))
(mult.Tensor (*) (mat_to_cols_list A) (mat_to_cols_list M))"
using tensor_mat_def d dim_row_tensor_mat by simp
(* The list of columns of M is the list-level tensor product it was built from. *)
then have ***: "mat_to_cols_list M = mult.Tensor (*) (mat_to_cols_list B) (mat_to_cols_list C)"
unfolding d using list_to_mat_to_cols_list
by (smt (verit) Tensor.mat_of_cols_list_def bot_nat_0.not_eq_extremum * d dim_col_mat(1) dim_col_tensor_mat length_greater_0_conv length_mat_to_cols_list mat_to_cols_list_is_mat mult.effective_well_defined_Tensor mult_is_0 plus_mult_cpx plus_mult_def row_length_mat_to_cols_list)
(* Left-hand side as a right-nested list-level tensor product. *)
with ** have "A ⨂ (B ⨂ C) = mat_of_cols_list (dim_row A * (dim_row B * dim_row C))
(mult.Tensor (*) (mat_to_cols_list A) (mult.Tensor (*) (mat_to_cols_list B) (mat_to_cols_list C)))" by simp
(* Re-associate both the row count and the list-level tensor product. *)
moreover have "… = mat_of_cols_list ((dim_row A * dim_row B) * dim_row C)
(mult.Tensor (*) (mult.Tensor (*) (mat_to_cols_list A) (mat_to_cols_list B)) (mat_to_cols_list C))"
using Matrix_Tensor.mult.associativity
by (smt (verit) ab_semigroup_mult_class.mult_ac(1) length_greater_0_conv length_mat_to_cols_list
mat_to_cols_list_is_mat mult.Tensor.simps(1) mult.Tensor_null plus_mult_cpx plus_mult_def)
(* Fold the left-nested expression back into the right-hand side of the statement. *)
ultimately show ?thesis
using tensor_mat_def
by (smt (verit) Tensor.mat_of_cols_list_def dim_col_mat(1) dim_col_tensor_mat dim_row_tensor_mat length_0_conv
list_to_mat_to_cols_list mat_to_cols_list_is_mat mult.well_defined_Tensor mult_is_0 neq0_conv
plus_mult_cpx plus_mult_def row_length_mat_to_cols_list)
qed
end