Theory More_Tensor

(* 
Authors: 

  Anthony Bordg, University of Cambridge, apdb3@cam.ac.uk;
  Yijun He, University of Cambridge, yh403@cam.ac.uk 
*)

section ‹Further Results on Tensor Products›

theory More_Tensor
imports
  Quantum
  Tensor
  Jordan_Normal_Form.Matrix
  Basics
begin


text ‹Explicit unfolding of the list-level tensor product mult.vec_vec_Tensor for two complex
lists of length 2: the result enumerates all pairwise products, with the index of the left
factor varying slowest.›

lemma tensor_prod_2 [simp]: 
"mult.vec_vec_Tensor (*) [x1::complex,x2] [x3, x4] = [x1 * x3, x1 * x4, x2 * x3, x2 * x4]"
proof -
  (* Establish the mult locale predicate for unit 1 and complex multiplication; the locale
     definitions used below are instantiated at the same parameters. *)
  have "Matrix_Tensor.mult (1::complex) (*)"
    by (simp add: Matrix_Tensor.mult_def)
  thus "mult.vec_vec_Tensor (*) [x1::complex,x2] [x3,x4] = [x1*x3,x1*x4,x2*x3,x2*x4]"
    using mult.vec_vec_Tensor_def[of "(1::complex)" "(*)"] mult.times_def[of "(1::complex)" "(*)"] by simp
qed

text ‹A vector in state_qbit 1 has dimension 2^1 = 2, so its list representation consists of
exactly its two entries.›

lemma list_vec [simp]: 
  assumes "v ∈ state_qbit 1"
  shows "list_of_vec v = [v $ 0, v $ 1]"
proof -
  have "Rep_vec v = (fst(Rep_vec v), snd(Rep_vec v))" by simp
  (* The dimension component is 2 because state_qbit 1 fixes dim_vec v = 2^1. *)
  also have "… = (2, vec_index v)"
    by (metis (mono_tags, lifting) assms dim_vec.rep_eq mem_Collect_eq power_one_right state_qbit_def vec_index.rep_eq)
  moreover have "[0..<2::nat] = [0,1]"
    by(simp add: upt_rec) 
  ultimately show ?thesis
    by(simp add: list_of_vec_def)
qed

text ‹Tensor product of two one-qubit state vectors, given as the vector built from the list of
the four pairwise products of their entries.›

lemma vec_tensor_prod_2 [simp]:
  assumes "v ∈ state_qbit 1" and "w ∈ state_qbit 1"
  shows "v ⨂ w = vec_of_list [v $ 0 * w $ 0, v $ 0 * w $ 1, v $ 1 * w $ 0, v $ 1 * w $ 1]"
proof -
  have "list_of_vec v = [v $ 0, v $ 1]"
    using assms by simp
  moreover have "list_of_vec w = [w $ 0, w $ 1]"
    using assms by simp
  (* With both list representations known, tensor_vec_def reduces to tensor_prod_2. *)
  ultimately show "v ⨂ w = vec_of_list [v $ 0 * w $ 0, v $ 0 * w $ 1, v $ 1 * w $ 0, v $ 1 * w $ 1]"
    by(simp add: tensor_vec_def)
qed

text ‹The dimension of vec_of_list l equals the length of the list l.›

lemma vec_dim_of_vec_of_list [simp]:
  assumes "length l = n"
  shows "dim_vec (vec_of_list l) = n"
  using assms vec_of_list_def by simp

text ‹Variant of vec_tensor_prod_2: the tensor product of two one-qubit state vectors as a
Matrix.vec of dimension 4 whose entries are selected by a case split on the index
(0, 3, 1, and otherwise 2).›

lemma vec_tensor_prod_2_bis [simp]:
  assumes "v ∈ state_qbit 1" and "w ∈ state_qbit 1"
  shows "v ⨂ w = Matrix.vec 4 (λi. if i = 0 then v $ 0 * w $ 0 else 
                                      if i = 3 then v $ 1 * w $ 1 else
                                          if i = 1 then v $ 0 * w $ 1 else v $ 1 * w $ 0)"
proof
  define u where "u = Matrix.vec 4 (λi. if i = 0 then v $ 0 * w $ 0 else 
                                          if i = 3 then v $ 1 * w $ 1 else
                                            if i = 1 then v $ 0 * w $ 1 else v $ 1 * w $ 0)"
  then show f2:"dim_vec (v ⨂ w) = dim_vec u"
    using assms by simp
  (* Entrywise comparison: each of the four indices is discharged separately. *)
  show "⋀i. i < dim_vec u ⟹ (v ⨂ w) $ i = u $ i"
    apply (auto simp: u_def)
    using assms apply auto[3]
    apply (simp add: numeral_3_eq_3)
    using u_def vec_of_list_index vec_tensor_prod_2 index_is_2 
    by (metis (no_types, lifting) One_nat_def assms nth_Cons_0 nth_Cons_Suc numeral_2_eq_2)
qed

text ‹Entry i of column j of mat_of_cols_list n l is the i-th element of the j-th list of l,
provided i < n and j < length l.›

lemma index_col_mat_of_cols_list [simp]:
  assumes "i < n" and "j < length l"
  shows "Matrix.col (mat_of_cols_list n l) j $ i = l ! j ! i"
  apply (auto simp: Matrix.col_def mat_of_cols_list_def)
  using assms less_le_trans by fastforce

text ‹For two 2 x 1 matrices given entrywise, mult.Tensor applied to their column lists yields a
single column holding the four pairwise products of their entries.›

lemma multTensor2 [simp]:
  assumes a1:"A = Matrix.mat 2 1 (λ(i,j). if i = 0 then a0 else a1)" and 
          a2:"B = Matrix.mat 2 1 (λ(i,j). if i = 0 then b0 else b1)"
  shows "mult.Tensor (*) (mat_to_cols_list A) (mat_to_cols_list B) = [[a0*b0, a0*b1, a1*b0, a1*b1]]"
proof -
  have "mat_to_cols_list A = [[a0, a1]]"
    by (auto simp: a1 mat_to_cols_list_def) (simp add: numeral_2_eq_2)
  moreover have f2:"mat_to_cols_list B = [[b0, b1]]"
    by (auto simp: a2 mat_to_cols_list_def) (simp add: numeral_2_eq_2)
  ultimately have "mult.Tensor (*) (mat_to_cols_list A) (mat_to_cols_list B) = 
                   mult.Tensor (*) [[a0,a1]] [[b0,b1]]" by simp
  (* Unfold mult.Tensor on single-column lists down to the vector case tensor_prod_2. *)
  thus ?thesis
    using mult.Tensor_def[of "(1::complex)" "(*)"] mult.times_def[of "(1::complex)" "(*)"]
    by (metis (mono_tags, lifting) append_self_conv list.simps(6) mult.Tensor.simps(2) mult.vec_mat_Tensor.simps(1) 
mult.vec_mat_Tensor.simps(2) plus_mult_cpx plus_mult_def tensor_prod_2)
qed

text ‹Version of multTensor2 for arbitrary 2 x 1 matrices A and B, stated in terms of their
entries at positions (0,0) and (1,0).›

lemma multTensor2_bis [simp]:
  assumes a1:"dim_row A = 2" and a2:"dim_col A = 1" and a3:"dim_row B = 2" and a4:"dim_col B = 1"
  shows "mult.Tensor (*) (mat_to_cols_list A) (mat_to_cols_list B) =  
[[A $$ (0,0) * B $$ (0,0), A $$ (0,0) *  B $$ (1,0), A $$ (1,0) * B $$ (0,0), A $$ (1,0) * B $$ (1,0)]]" 
proof -
  have "mat_to_cols_list A = [[A $$ (0,0), A $$ (1,0)]]"
    by (auto simp: a1 a2 mat_to_cols_list_def) (simp add: numeral_2_eq_2)
  moreover have f2:"mat_to_cols_list B = [[B $$ (0,0), B $$ (1,0)]]"
    by (auto simp: a3 a4 mat_to_cols_list_def) (simp add: numeral_2_eq_2)
  ultimately have "mult.Tensor (*) (mat_to_cols_list A) (mat_to_cols_list B) =
                   mult.Tensor (*) [[A $$ (0,0), A $$ (1,0)]] [[B $$ (0,0), B $$ (1,0)]]" by simp
  (* Same unfolding as in multTensor2, down to tensor_prod_2. *)
  thus ?thesis
    using mult.Tensor_def[of "(1::complex)" "(*)"] mult.times_def[of "(1::complex)" "(*)"]
    by (smt (verit) append_self_conv list.simps(6) mult.Tensor.simps(2) mult.vec_mat_Tensor.simps(1) 
mult.vec_mat_Tensor.simps(2) plus_mult_cpx plus_mult_def tensor_prod_2)
qed

text ‹Tensor product of two one-qubit states (column matrices) as mat_of_cols_list 4 of a single
column containing the four pairwise products of their entries.›

lemma mat_tensor_prod_2_prelim [simp]:
  assumes "state 1 v" and "state 1 w"
  shows "v ⨂ w = mat_of_cols_list 4 
[[v $$ (0,0) * w $$ (0,0), v $$ (0,0) * w $$ (1,0), v $$ (1,0) * w $$ (0,0), v $$ (1,0) * w $$ (1,0)]]"
proof
  define u where "u = mat_of_cols_list 4 
[[v $$ (0,0) * w $$ (0,0), v $$ (0,0) * w $$ (1,0), v $$ (1,0) * w $$ (0,0), v $$ (1,0) * w $$ (1,0)]]"
  then show f1:"dim_row (v ⨂ w) = dim_row u"
    using assms state_def mat_of_cols_list_def tensor_mat_def by simp
  show f2:"dim_col (v ⨂ w) = dim_col u"
    using assms state_def mat_of_cols_list_def tensor_mat_def u_def by simp
  show "⋀i j. i < dim_row u ⟹ j < dim_col u ⟹  (v ⨂ w) $$ (i, j) = u $$ (i, j)"
      using u_def tensor_mat_def assms state_def by simp
qed

text ‹The first column of the tensor product of two one-qubit states is the vector tensor product
of their first columns.›

lemma mat_tensor_prod_2_col [simp]:
  assumes "state 1 v" and "state 1 w"
  shows "Matrix.col (v ⨂ w) 0 = Matrix.col v 0 ⨂ Matrix.col w 0"
proof
  show f1:"dim_vec (Matrix.col (v ⨂ w) 0) = dim_vec (Matrix.col v 0 ⨂ Matrix.col w 0)"
    using assms vec_tensor_prod_2_bis
    by (smt (verit) Tensor.mat_of_cols_list_def dim_col dim_row_mat(1) dim_vec mat_tensor_prod_2_prelim state.state_to_state_qbit)
next
  show "⋀i. i<dim_vec (Matrix.col v 0 ⨂ Matrix.col w 0) ⟹ Matrix.col (v ⨂ w) 0 $ i = (Matrix.col v 0 ⨂ Matrix.col w 0) $ i"
  proof -
    (* The vectors have dimension 4; the four entries are compared one by one below. *)
    have "dim_vec (Matrix.col v 0 ⨂ Matrix.col w 0) = 4"
      by (metis (no_types, lifting) assms(1) assms(2) dim_vec state.state_to_state_qbit vec_tensor_prod_2_bis)
    moreover have "(Matrix.col v 0 ⨂ Matrix.col w 0) $ 0 = v $$ (0,0) * w $$ (0,0)"
      using assms vec_tensor_prod_2 state.state_to_state_qbit col_index_of_mat_col
      by (smt (verit) nth_Cons_0 power_one_right state_def vec_of_list_index zero_less_numeral)
    moreover have "… = Matrix.col (v ⨂ w) 0 $ 0"
      using assms by simp
    moreover have "(Matrix.col v 0 ⨂ Matrix.col w 0) $ 1 = v $$ (0,0) * w $$ (1,0)"
      using assms vec_tensor_prod_2 state.state_to_state_qbit col_index_of_mat_col 
      by (smt (verit) One_nat_def Suc_1 lessI nth_Cons_0 power_one_right state_def vec_index_vCons_Suc 
vec_of_list_Cons vec_of_list_index zero_less_numeral)
    moreover have "… = Matrix.col (v ⨂ w) 0 $ 1"
      using assms by simp
    moreover have "(Matrix.col v 0 ⨂ Matrix.col w 0) $ 2 = v $$ (1,0) * w $$ (0,0)"
      using assms vec_tensor_prod_2 state.state_to_state_qbit col_index_of_mat_col
      by (smt (verit) One_nat_def Suc_1 lessI nth_Cons_0 power_one_right state_def vec_index_vCons_Suc 
vec_of_list_Cons vec_of_list_index zero_less_numeral)
    moreover have "… = Matrix.col (v ⨂ w) 0 $ 2"
      using assms by simp
    moreover have "(Matrix.col v 0 ⨂ Matrix.col w 0) $ 3 = v $$ (1,0) * w $$ (1,0)"
      using assms vec_tensor_prod_2 state.state_to_state_qbit col_index_of_mat_col numeral_3_eq_3
      by (smt (verit) One_nat_def Suc_1 lessI nth_Cons_0 power_one_right state_def vec_index_vCons_Suc 
vec_of_list_Cons vec_of_list_index zero_less_numeral)
    moreover have "… = Matrix.col (v ⨂ w) 0 $ 3"
      using assms by simp
    ultimately show "⋀i. i<dim_vec (Matrix.col v 0 ⨂ Matrix.col w 0) ⟹ Matrix.col (v ⨂ w) 0 $ i = (Matrix.col v 0 ⨂ Matrix.col w 0) $ i"
      using index_sl_four by auto
  qed
qed

text ‹Explicit 4 x 1 form of the tensor product of two one-qubit states, with the entries selected
by a case split on the row index (0, 3, 1, and otherwise 2).›

lemma mat_tensor_prod_2 [simp]:
  assumes "state 1 v" and "state 1 w"
  shows "v ⨂ w = Matrix.mat 4 1 (λ(i,j). if i = 0 then v $$ (0,0) * w $$ (0,0) else 
                                            if i = 3 then v $$ (1,0) * w $$ (1,0) else
                                              if i = 1 then v $$ (0,0) * w $$ (1,0) else 
                                                v $$ (1,0) * w $$ (0,0))"
proof
  define u where "u = Matrix.mat 4 1 (λ(i,j). if i = 0 then v $$ (0,0) * w $$ (0,0) else 
                                            if i = 3 then v $$ (1,0) * w $$ (1,0) else
                                              if i = 1 then v $$ (0,0) * w $$ (1,0) else 
                                                v $$ (1,0) * w $$ (0,0))"
  then show "dim_row (v ⨂ w) = dim_row u"
    using assms tensor_mat_def state_def by(simp add: Tensor.mat_of_cols_list_def)
  also have "… = 4" by (simp add: u_def)
  show "dim_col (v ⨂ w) = dim_col u"
    using u_def assms tensor_mat_def state_def Tensor.mat_of_cols_list_def by simp
  moreover have "… = 1" by(simp add: u_def)
  ultimately show "⋀i j. i < dim_row u ⟹ j < dim_col u ⟹ (v ⨂ w) $$ (i, j) = u $$ (i,j)"
  proof -
    fix i j::nat
    assume a1:"i < dim_row u" and a2:"j < dim_col u"
    (* Only column 0 exists, so each entry is read off the first column. *)
    then have "(v ⨂ w) $$ (i, j) = Matrix.col (v ⨂ w) 0 $ i"
      using Matrix.col_def u_def assms by simp
    then have f1:"(v ⨂ w) $$ (i, j) = (Matrix.col v 0 ⨂ Matrix.col w 0) $ i"
      using assms mat_tensor_prod_2_col by simp
    have "(Matrix.col v 0 ⨂ Matrix.col w 0) $ i = 
 Matrix.vec 4 (λi. if i = 0 then Matrix.col v 0 $ 0 * Matrix.col w 0 $ 0 else 
                                      if i = 3 then Matrix.col v 0 $ 1 * Matrix.col w 0 $ 1 else
                                          if i = 1 then Matrix.col v 0 $ 0 * Matrix.col w 0 $ 1 else 
                                            Matrix.col v 0 $ 1 * Matrix.col w 0 $ 0) $ i"
      using vec_tensor_prod_2_bis assms state.state_to_state_qbit by presburger 
    thus "(v ⨂ w) $$ (i, j) = u $$ (i,j)"
      using u_def a1 a2 assms state_def by simp
  qed
qed
                         

text ‹The tensor product of two one-qubit states written as the ket of a 4-dimensional vector
(same entries as in mat_tensor_prod_2).›

lemma mat_tensor_prod_2_bis:
  assumes "state 1 v" and "state 1 w"
  shows "v ⨂ w = |Matrix.vec 4 (λi. if i = 0 then v $$ (0,0) * w $$ (0,0) else 
                                          if i = 3 then v $$ (1,0) * w $$ (1,0) else
                                            if i = 1 then v $$ (0,0) * w $$ (1,0) else 
                                              v $$ (1,0) * w $$ (0,0))⟩"
  using assms ket_vec_def mat_tensor_prod_2 by(simp add: mat_eq_iff)

text ‹Congruence rule for kets: equal vectors have equal kets. It is applied via rule in
mat_tensor_ket_vec.›

lemma eq_ket_vec:
  fixes u v:: "complex Matrix.vec"
  assumes "u = v"
  shows "|u⟩ = |v⟩"
  using assms by simp

text ‹For one-qubit states, the matrix tensor product equals the ket of the vector tensor
product of their first columns.›

lemma mat_tensor_ket_vec:
  assumes "state 1 v" and "state 1 w"
  shows "v ⨂ w = |(Matrix.col v 0) ⨂ (Matrix.col w 0)⟩"
proof -
  have "v ⨂ w = |Matrix.col v 0⟩ ⨂ |Matrix.col w 0⟩"
    using assms state_def by simp
  also have "… = 
|Matrix.vec 4 (λi. if i = 0 then |Matrix.col v 0⟩ $$ (0,0) * |Matrix.col w 0⟩ $$ (0,0) else 
                                          if i = 3 then |Matrix.col v 0⟩ $$ (1,0) * |Matrix.col w 0⟩ $$ (1,0) else
                                            if i = 1 then |Matrix.col v 0⟩ $$ (0,0) * |Matrix.col w 0⟩ $$ (1,0) else 
                                              |Matrix.col v 0⟩ $$ (1,0) * |Matrix.col w 0⟩ $$ (0,0))⟩"
    using assms mat_tensor_prod_2_bis state_col_ket_vec by simp
  also have "… = 
|Matrix.vec 4 (λi. if i = 0 then v $$ (0,0) * w $$ (0,0) else 
                                          if i = 3 then v $$ (1,0) * w $$ (1,0) else
                                            if i = 1 then v $$ (0,0) * w $$ (1,0) else 
                                              v $$ (1,0) * w $$ (0,0))⟩"
    using assms mat_tensor_prod_2_bis calculation by auto
  (* Rewrite matrix entries as first-column vector entries, inside the ket. *)
  also have "… = 
|Matrix.vec 4 (λi. if i = 0 then Matrix.col v 0 $ 0 * Matrix.col w 0 $ 0 else 
                                      if i = 3 then Matrix.col v 0 $ 1 * Matrix.col w 0 $ 1 else
                                          if i = 1 then Matrix.col v 0 $ 0 * Matrix.col w 0 $ 1 else 
                                            Matrix.col v 0 $ 1 * Matrix.col w 0 $ 0)⟩"
    apply(rule eq_ket_vec)
    apply (rule eq_vecI)
    using col_index_of_mat_col assms state_def by auto
  finally show ?thesis
    using vec_tensor_prod_2_bis assms state.state_to_state_qbit by presburger
qed

text ‹The property of being a state (resp. a gate) is preserved by the tensor product.›

text ‹The tensor product of two one-qubit states is a two-qubit state. Normalisation is shown by
expanding the norm of the four-entry first column and factoring it into the product of the
squared norms of the factors.›

lemma tensor_state2 [simp]:
  assumes "state 1 u" and "state 1 v"
  shows "state 2 (u ⨂ v)"
proof
  show "dim_col (u ⨂ v) = 1"
    using assms dim_col_tensor_mat state.is_column by presburger
  show "dim_row (u ⨂ v) = 2⇧2" 
    using assms dim_row_tensor_mat state.dim_row
    by (metis (mono_tags, lifting) power2_eq_square power_one_right)
  show "∥Matrix.col (u ⨂ v) 0∥ = 1"
  proof -
    define l where d0:"l = [u $$ (0,0) * v $$ (0,0), u $$ (0,0) * v $$ (1,0), u $$ (1,0) * v $$ (0,0), u $$ (1,0) * v $$ (1,0)]"
    then have f4:"length l = 4" by simp
    also have "u ⨂ v = mat_of_cols_list 4 
[[u $$ (0,0) * v $$ (0,0), u $$ (0,0) * v $$ (1,0), u $$ (1,0) * v $$ (0,0), u $$ (1,0) * v $$ (1,0)]]"
      using assms by simp
    then have "Matrix.col (u ⨂ v) 0 = vec_of_list [u $$ (0,0) * v $$ (0,0), u $$ (0,0) * v $$ (1,0), 
u $$ (1,0) * v $$ (0,0), u $$ (1,0) * v $$ (1,0)]"
      by (metis One_nat_def Suc_eq_plus1 add_Suc col_mat_of_cols_list list.size(3) list.size(4) 
nth_Cons_0 numeral_2_eq_2 numeral_Bit0 plus_1_eq_Suc vec_of_list_Cons zero_less_one_class.zero_less_one)    
    then have f5:"∥Matrix.col (u ⨂ v) 0∥ = sqrt(∑i<4. (cmod (l ! i))⇧2)"
      by (metis d0 f4 One_nat_def cpx_length_of_vec_of_list d0 vec_of_list_Cons)
    also have "… = sqrt ((cmod (u $$ (0,0) * v $$ (0,0)))⇧2 + (cmod(u $$ (0,0) * v $$ (1,0)))⇧2 + 
(cmod(u $$ (1,0) * v $$ (0,0)))⇧2 + (cmod(u $$ (1,0) * v $$ (1,0)))⇧2)"
    proof -
      have "(∑i<4. (cmod (l ! i))⇧2) = (cmod (l ! 0))⇧2 + (cmod (l ! 1))⇧2 + (cmod (l ! 2))⇧2 + 
(cmod (l ! 3))⇧2"
        by (simp add: numeral_eq_Suc)
      also have "… = (cmod (u $$ (0,0) * v $$ (0,0)))⇧2 + (cmod(u $$ (0,0) * v $$ (1,0)))⇧2 + 
(cmod(u $$ (1,0) * v $$ (0,0)))⇧2 + (cmod(u $$ (1,0) * v $$ (1,0)))⇧2"
        using d0 by simp
      finally show ?thesis by(simp add: f5)
    qed
    (* The modulus of a product is the product of the moduli; then factor the sum. *)
    moreover have "… = 
sqrt ((cmod (u $$ (0,0)))⇧2 * (cmod (v $$ (0,0)))⇧2 + (cmod(u $$ (0,0)))⇧2 * (cmod (v $$ (1,0)))⇧2 + 
(cmod(u $$ (1,0)))⇧2 * (cmod (v $$ (0,0)))⇧2 + (cmod(u $$ (1,0)))⇧2 * (cmod(v $$ (1,0)))⇧2)"
      by (simp add: norm_mult power_mult_distrib)
    moreover have "… = sqrt ((((cmod(u $$ (0,0)))⇧2 + (cmod(u $$ (1,0)))⇧2)) * 
(((cmod(v $$ (0,0)))⇧2 + (cmod(v $$ (1,0)))⇧2)))"
      by (simp add: distrib_left mult.commute)
    ultimately have f6:"∥Matrix.col (u ⨂ v) 0∥⇧2 = (((cmod(u $$ (0,0)))⇧2 + (cmod(u $$ (1,0)))⇧2)) * 
(((cmod(v $$ (0,0)))⇧2 + (cmod(v $$ (1,0)))⇧2))"
      by (simp add: f4)
    also have f7:"… = (∑i< 2. (cmod (u $$ (i,0)))⇧2) * (∑i<2. (cmod (v $$ (i,0)))⇧2)"
      by (simp add: numeral_2_eq_2)
    also have f8:"… = (∑i< 2.(cmod (Matrix.col u 0 $ i))⇧2) * (∑i<2.(cmod (Matrix.col v 0 $ i))⇧2)"
      using assms index_col state_def by simp
    finally show ?thesis
    proof -
      (* Both factors are 1 because u and v are normalised states. *)
      have f1:"(∑i< 2.(cmod (Matrix.col u 0 $ i))⇧2) = 1"
        using assms(1) state_def cpx_vec_length_def by auto
      have f2:"(∑i< 2.(cmod (Matrix.col v 0 $ i))⇧2) = 1"
        using assms(2) state_def cpx_vec_length_def by auto
      thus ?thesis
        using f1 f8 f5 f6 f7
        by (simp add: ‹sqrt (∑i<4. (cmod (l ! i))⇧2) = sqrt ((cmod (u $$ (0, 0) * v $$ (0, 0)))⇧2 + 
(cmod (u $$ (0, 0) * v $$ (1, 0)))⇧2 + (cmod (u $$ (1, 0) * v $$ (0, 0)))⇧2 + (cmod (u $$ (1, 0) * v $$ (1, 0)))⇧2)›)
    qed
  qed
qed

text ‹A sum over i < a*b whose summand factors through i div b and i mod b splits into the
product of a sum over i < a and a sum over j < b. Proved by induction on a.›

lemma sum_prod:
  fixes f::"nat ⇒ complex" and g::"nat ⇒ complex"
  shows "(∑i<a*b. f(i div b) * g(i mod b)) = (∑i<a. f(i)) * (∑j<b. g(j))"
proof (induction a)
  case 0
  then show ?case by simp
next
  case (Suc a)
  (* Split off the last block of b indices, on which i div b = a and i mod b = i - a*b. *)
  have "(∑i<(a+1)*b. f (i div b) * g (i mod b)) = (∑i<a*b. f (i div b) * g (i mod b)) + 
(∑i∈{a*b..<(a+1)*b}. f (i div b) * g (i mod b))"
    apply (auto simp: algebra_simps)
    by (smt (verit) add.commute mult_Suc sum.lessThan_Suc sum.nat_group)
  also have "… = (∑i<a. f(i)) * (∑j<b. g(j)) + (∑i∈{a*b..<(a+1)*b}. f (i div b) * g (i mod b))"
    by (simp add: Suc.IH)
  also have "… = (∑i<a. f(i)) * (∑j<b. g(j)) + (∑i∈{a*b..<(a+1)*b}. f (a) * g(i-a*b))" by simp
  also have "… = (∑i<a. f(i)) * (∑j<b. g(j)) + f(a) * (∑i∈{a*b..<(a+1)*b}. g(i-a*b))" 
    by(simp add: sum_distrib_left)
  also have "… = (∑i<a. f(i)) * (∑j<b. g(j)) + f(a) * (∑i∈{..<b}. g(i))"
    using sum_of_index_diff[of "g" "(a*b)" "b"] by (simp add: algebra_simps)
  ultimately show ?case  by (simp add: semiring_normalization_rules(1))
qed

text ‹The tensor product of an m-qubit state and an n-qubit state is an (m+n)-qubit state.
The squared norm of the product factors via sum_prod into the product of the squared norms.›

lemma tensor_state [simp]:
  assumes "state m u" and "state n v"
  shows "state (m + n) (u ⨂ v)"
proof
  show c1:"dim_col (u ⨂ v) = 1"
    using assms dim_col_tensor_mat state.is_column by presburger
  show c2:"dim_row (u ⨂ v) = 2^(m + n)" 
    using assms dim_row_tensor_mat state.dim_row by (metis power_add)
  have "(∑i<2^(m + n). (cmod (u $$ (i div 2 ^ n, 0) * v $$ (i mod 2 ^ n, 0)))⇧2) = 
(∑i<2^(m + n). (cmod (u $$ (i div 2 ^ n, 0)))⇧2 * (cmod (v $$ (i mod 2 ^ n, 0)))⇧2)"
    by (simp add: sqr_of_cmod_of_prod)
  also have "… = (∑i<2^m. (cmod (u $$ (i, 0)))⇧2) * (∑i<2^n. (cmod (v $$ (i, 0)))⇧2)"
  proof-
    (* sum_prod is stated over complex numbers, so the real sums are lifted with
       complex_of_real before applying it. *)
    have "… = (∑i<2^(m + n).complex_of_real((cmod (u $$ (i div 2^n,0)))⇧2) * complex_of_real((cmod (v $$ (i mod 2^n,0)))⇧2))"
      by simp
    moreover have "(∑i<2^m. (cmod (u $$ (i, 0)))⇧2) = (∑i<2^m. complex_of_real ((cmod (u $$ (i,0)))⇧2))" by simp
    moreover have "(∑i<2^n. (cmod (v $$ (i, 0)))⇧2) = (∑i<2^n. complex_of_real ((cmod (v $$ (i, 0)))⇧2))" by simp
    ultimately show ?thesis
      using sum_prod[of "λi. (cmod (u $$ (i , 0)))⇧2" "2^n" "λi. (cmod (v $$ (i , 0)))⇧2" "2^m"]
      by (smt (verit) of_real_eq_iff of_real_mult power_add)
  qed
  also have "… = 1"
  proof-
    have "(∑i<2^m. (cmod (u $$ (i, 0)))⇧2) = 1"
      using assms(1) state_def cpx_vec_length_def by auto
    moreover have "(∑i<2^n. (cmod (v $$ (i, 0)))⇧2) = 1"
      using assms(2) state_def cpx_vec_length_def by auto
    ultimately show ?thesis by simp
  qed
  ultimately show "∥Matrix.col (u ⨂ v) 0∥ = 1"
    using c1 c2 assms state_def by (auto simp add: cpx_vec_length_def)
qed

text ‹The tensor product of an m-qubit gate and an n-qubit gate has 2^(m+n) rows.›

lemma dim_row_of_tensor_gate:
  assumes "gate m G1" and "gate n G2"
  shows "dim_row (G1 ⨂ G2) = 2^(m+n)" 
  using assms dim_row_tensor_mat gate.dim_row by (simp add: power_add)

text ‹The tensor product of two gates is a square matrix.›

lemma tensor_gate_sqr_mat:
  assumes "gate m G1" and "gate n G2"
  shows "square_mat (G1 ⨂ G2)" 
  using assms gate.square_mat by simp

text ‹A row index of the identity matrix of size dim_col G1 * dim_col G2 is below 2^(m+n) when
G1 and G2 are m- and n-qubit gates.›

lemma dim_row_of_one_mat_less_pow:
  assumes "gate m G1" and "gate n G2" and "i < dim_row (1⇩m(dim_col G1 * dim_col G2))"
  shows "i < 2^(m+n)" 
  using assms gate_def by (simp add: power_add)

text ‹A column index of the identity matrix of size dim_col G1 * dim_col G2 is below 2^(m+n)
when G1 and G2 are m- and n-qubit gates.›

lemma dim_col_of_one_mat_less_pow:
  assumes "gate m G1" and "gate n G2" and "j < dim_col (1⇩m(dim_col G1 * dim_col G2))"
  shows "j < 2^(m+n)"
  using assms gate_def by (simp add: power_add)

text ‹Entrywise form of (G1 ⨂ G2)⇧† * (G1 ⨂ G2) = 1. Each entry is split by sum_prod into an
entry of G1⇧† * G1 times an entry of G2⇧† * G2; both are identity entries by unitary_def.›

lemma index_tensor_gate_unitary1:
  assumes "gate m G1" and "gate n G2" and "i < dim_row (1⇩m(dim_col G1 * dim_col G2))" and 
"j < dim_col (1⇩m(dim_col G1 * dim_col G2))"
  shows "((G1 ⨂ G2)⇧† * (G1 ⨂ G2)) $$ (i, j) = 1⇩m(dim_col G1 * dim_col G2) $$ (i, j)"
proof-
  have "⋀k. k<2^(m+n) ⟹ cnj((G1 ⨂ G2) $$ (k,i)) = 
cnj(G1 $$ (k div 2^n, i div 2^n)) * cnj(G2 $$ (k mod 2^n, i mod 2^n))"
    using assms(1-3) by (simp add: gate_def power_add)
  moreover have "⋀k. k<2^(m+n) ⟹ (G1 ⨂ G2) $$ (k,j) = 
                                   G1 $$ (k div 2^n, j div 2^n) * G2 $$ (k mod 2^n, j mod 2^n)"
    using assms(1,2,4) by (simp add: gate_def power_add)
  ultimately have "⋀k. k<2^(m+n) ⟹ cnj((G1 ⨂ G2) $$ (k,i)) * ((G1 ⨂ G2) $$ (k,j)) = 
      cnj(G1 $$ (k div 2^n, i div 2^n)) * G1 $$ (k div 2^n, j div 2^n) * 
      cnj(G2 $$ (k mod 2^n, i mod 2^n)) * G2 $$ (k mod 2^n, j mod 2^n)" by simp
  then have "(∑k<2^(m+n). cnj((G1 ⨂ G2) $$ (k,i)) * ((G1 ⨂ G2) $$ (k,j))) = 
      (∑k<2^(m+n). cnj(G1 $$ (k div 2^n, i div 2^n)) * G1 $$ (k div 2^n, j div 2^n) * 
                    cnj(G2 $$ (k mod 2^n, i mod 2^n)) * G2 $$ (k mod 2^n, j mod 2^n))" by simp
  (* The summand factors through k div 2^n and k mod 2^n, so sum_prod applies. *)
  also have "… = 
             (∑k<2^m. cnj(G1 $$ (k, i div 2^n)) * G1 $$ (k, j div 2^n)) * 
             (∑k<2^n. cnj(G2 $$ (k, i mod 2^n)) * G2 $$ (k, j mod 2^n))"
    using sum_prod[of "λx. cnj(G1 $$ (x, i div 2^n)) * G1 $$ (x, j div 2^n)" "2^n" 
"λx. cnj(G2 $$ (x, i mod 2^n)) * G2 $$ (x, j mod 2^n)" "2^m"]
    by (metis (no_types, lifting) power_add semigroup_mult_class.mult.assoc sum.cong)
  also have "((G1 ⨂ G2)⇧† * (G1 ⨂ G2)) $$ (i, j) = (∑k<2^(m+n).cnj((G1 ⨂ G2) $$ (k,i)) * ((G1 ⨂ G2) $$ (k,j)))"
    using assms index_matrix_prod[of "i" "(G1 ⨂ G2)⇧†" "j" "(G1 ⨂ G2)"] dagger_def 
dim_row_of_tensor_gate tensor_gate_sqr_mat by simp
  ultimately have "((G1 ⨂ G2)⇧† * (G1 ⨂ G2)) $$ (i,j) = 
              (∑k1<2^m. cnj(G1 $$ (k1, i div 2^n)) * G1 $$ (k1, j div 2^n)) * 
              (∑k2<2^n. cnj(G2 $$ (k2, i mod 2^n)) * G2 $$ (k2, j mod 2^n))" by simp
  moreover have "(∑k<2^m. cnj(G1 $$ (k, i div 2^n))* G1 $$ (k, j div 2^n)) = (G1⇧† * G1) $$ (i div 2^n, j div 2^n)"
      using assms gate_def dagger_def index_matrix_prod[of "i div 2^n" "G1⇧†" "j div 2^n" "G1"]
      by (simp add: less_mult_imp_div_less power_add)
  moreover have "… = 1⇩m(2^m) $$ (i div 2^n, j div 2^n)"
    using assms(1,2) gate_def unitary_def by simp
  moreover have "(∑k<2^n. cnj(G2 $$ (k, i mod 2^n))* G2 $$ (k, j mod 2^n)) = (G2⇧† * G2) $$ (i mod 2^n, j mod 2^n)"
    using assms(1,2) gate_def dagger_def index_matrix_prod[of "i mod 2^n" "G2⇧†" "j mod 2^n" "G2"] by simp
  moreover have "… = 1⇩m(2^n) $$ (i mod 2^n, j mod 2^n)"
    using assms(1,2) gate_def unitary_def by simp
  ultimately have "((G1 ⨂ G2)⇧† * (G1 ⨂ G2)) $$ (i,j) = 1⇩m (2^m) $$ (i div 2^n, j div 2^n) * 1⇩m (2^n) $$ (i mod 2^n, j mod 2^n)"
    by simp
  thus ?thesis
    using assms assms(3,4) gate_def index_one_mat_div_mod[of "i" "m" "n" "j"] by(simp add: power_add)
qed

text ‹(G1 ⨂ G2)⇧† * (G1 ⨂ G2) is the identity: matrix equality lifted from the entrywise lemma
index_tensor_gate_unitary1.›

lemma tensor_gate_unitary1 [simp]:
  assumes "gate m G1" and "gate n G2"
  shows "(G1 ⨂ G2)⇧† * (G1 ⨂ G2) = 1⇩m(dim_col G1 * dim_col G2)"
proof
  show "dim_row ((G1 ⨂ G2)⇧† * (G1 ⨂ G2)) = dim_row (1⇩m(dim_col G1 * dim_col G2))" by simp
  show "dim_col ((G1 ⨂ G2)⇧† * (G1 ⨂ G2)) = dim_col (1⇩m(dim_col G1 * dim_col G2))" by simp
  fix i j assume "i < dim_row (1⇩m(dim_col G1 * dim_col G2))" and "j < dim_col (1⇩m(dim_col G1 * dim_col G2))"
  thus "((G1 ⨂ G2)⇧† * (G1 ⨂ G2)) $$ (i, j) = 1⇩m(dim_col G1 * dim_col G2) $$ (i, j)"
    using assms index_tensor_gate_unitary1 by simp
qed

text ‹Entrywise form of (G1 ⨂ G2) * (G1 ⨂ G2)⇧† = 1. Each entry is split by sum_prod into an
entry of G1 * G1⇧† times an entry of G2 * G2⇧†; both are identity entries by unitary_def.›

lemma index_tensor_gate_unitary2 [simp]:
  assumes "gate m G1" and "gate n G2" and "i < dim_row (1⇩m (dim_col G1 * dim_col G2))" and
"j < dim_col (1⇩m (dim_col G1 * dim_col G2))"
  shows "((G1 ⨂ G2) * ((G1 ⨂ G2)⇧†)) $$ (i, j) = 1⇩m(dim_col G1 * dim_col G2) $$ (i, j)"
proof-
  have "⋀k. k<2^(m+n) ⟹ (G1 ⨂ G2) $$ (i,k) = 
G1 $$ (i div 2^n, k div 2^n) * G2 $$ (i mod 2^n, k mod 2^n)"
    using assms(1-3) by (simp add: gate_def power_add)
  moreover have "⋀k. k<2^(m+n) ⟹ cnj((G1 ⨂ G2) $$ (j,k)) = 
cnj(G1 $$ (j div 2^n, k div 2^n)) * cnj(G2 $$ (j mod 2^n, k mod 2^n))"
    using assms(1,2,4) by (simp add: gate_def power_add)
  ultimately have "⋀k. k∈{..<2^(m+n)} ⟹ (G1 ⨂ G2) $$ (i,k) * cnj((G1 ⨂ G2) $$ (j,k)) = 
                         G1 $$ (i div 2^n, k div 2^n) * cnj(G1 $$ (j div 2^n, k div 2^n)) * 
                         G2 $$ (i mod 2^n, k mod 2^n) * cnj(G2 $$ (j mod 2^n, k mod 2^n))" by simp
  then have "(∑k<2^(m+n). (G1 ⨂ G2) $$ (i,k) * cnj((G1 ⨂ G2) $$ (j,k))) = 
      (∑k<2^(m+n). G1 $$ (i div 2^n, k div 2^n) * cnj(G1 $$ (j div 2^n, k div 2^n)) * 
      G2 $$ (i mod 2^n, k mod 2^n) * cnj(G2 $$ (j mod 2^n, k mod 2^n)))" by simp
  (* The summand factors through k div 2^n and k mod 2^n, so sum_prod applies. *)
  also have "… = 
             (∑k<2^m. G1 $$ (i div 2^n, k) * cnj(G1 $$ (j div 2^n, k))) * 
             (∑k<2^n. G2 $$ (i mod 2^n, k) * cnj(G2 $$ (j mod 2^n, k)))"
    using sum_prod[of "λk. G1 $$ (i div 2^n, k) * cnj(G1 $$ (j div 2^n, k))" "2^n" 
                        "λk. G2 $$ (i mod 2^n, k) * cnj(G2 $$ (j mod 2^n, k))" "2^m"]
    by (metis (no_types, lifting) power_add semigroup_mult_class.mult.assoc sum.cong)
  also have "((G1 ⨂ G2) * ((G1 ⨂ G2)⇧†)) $$ (i, j) = (∑k<2^(m+n). (G1 ⨂ G2) $$ (i,k) * cnj((G1 ⨂ G2) $$ (j,k)))"
    using assms index_matrix_prod[of "i" "(G1 ⨂ G2)" "j" "(G1 ⨂ G2)⇧†"] dagger_def
dim_row_of_tensor_gate tensor_gate_sqr_mat by simp
  ultimately have "((G1 ⨂ G2) * ((G1 ⨂ G2)⇧†)) $$ (i,j) = 
             (∑k<2^m. G1 $$ (i div 2^n, k) * cnj(G1 $$ (j div 2^n, k))) * 
             (∑k<2^n. G2 $$ (i mod 2^n, k) * cnj(G2 $$ (j mod 2^n, k)))" by simp
  moreover have "(∑k<2^m. G1 $$ (i div 2^n, k) * cnj(G1 $$ (j div 2^n, k))) = (G1 * (G1⇧†)) $$ (i div 2^n, j div 2^n)"
    using assms gate_def dagger_def index_matrix_prod[of "i div 2^n" "G1" "j div 2^n" "G1⇧†"]
    by (simp add: less_mult_imp_div_less power_add)
  moreover have "… = 1⇩m(2^m) $$ (i div 2^n, j div 2^n)"
    using assms(1,2) gate_def unitary_def by simp
  moreover have "(∑k<2^n. G2 $$ (i mod 2^n, k) * cnj(G2 $$ (j mod 2^n, k))) = (G2 * (G2⇧†)) $$ (i mod 2^n, j mod 2^n)"
    using assms(1,2) gate_def dagger_def index_matrix_prod[of "i mod 2^n" "G2" "j mod 2^n" "G2⇧†"] by simp
  moreover have "… = 1⇩m(2^n) $$ (i mod 2^n, j mod 2^n)"
    using assms(1,2) gate_def unitary_def by simp
  ultimately have "((G1 ⨂ G2) * ((G1 ⨂ G2)⇧†)) $$ (i,j) = 1⇩m(2^m) $$ (i div 2^n, j div 2^n) * 1⇩m(2^n) $$ (i mod 2^n, j mod 2^n)"
    by simp
  thus ?thesis
    using assms gate_def index_one_mat_div_mod[of "i" "m" "n" "j"] by(simp add: power_add)
qed

text ‹(G1 ⨂ G2) * (G1 ⨂ G2)⇧† is the identity: matrix equality lifted from the entrywise lemma
index_tensor_gate_unitary2.›

lemma tensor_gate_unitary2 [simp]:
  assumes "gate m G1" and "gate n G2"
  shows "(G1 ⨂ G2) * ((G1 ⨂ G2)⇧†) = 1⇩m(dim_col G1 * dim_col G2)"
proof
  show "dim_row ((G1 ⨂ G2) * ((G1 ⨂ G2)⇧†)) = dim_row(1⇩m (dim_col G1 * dim_col G2))"
    using assms gate_def by simp
  show "dim_col ((G1 ⨂ G2) * ((G1 ⨂ G2)⇧†)) = dim_col (1⇩m(dim_col G1 * dim_col G2))"
    using assms gate_def by simp
  fix i j assume "i < dim_row (1⇩m (dim_col G1 * dim_col G2))" and "j < dim_col (1⇩m (dim_col G1 * dim_col G2))"
  thus "((G1 ⨂ G2) * ((G1 ⨂ G2)⇧†)) $$ (i, j) = 1⇩m(dim_col G1 * dim_col G2) $$ (i, j)"
    using assms index_tensor_gate_unitary2 by simp
qed

text ‹The tensor product of an m-qubit gate and an n-qubit gate is an (m+n)-qubit gate. Unitarity
follows from tensor_gate_unitary1 and tensor_gate_unitary2 via unitary_def.›

lemma tensor_gate [simp]:
  assumes "gate m G1" and "gate n G2"
  shows "gate (m + n) (G1 ⨂ G2)" 
proof
  show "dim_row (G1 ⨂ G2) = 2^(m+n)"
    using assms dim_row_tensor_mat gate.dim_row by (simp add: power_add)
  show "square_mat (G1 ⨂ G2)"
    using assms gate.square_mat by simp
  thus "unitary (G1 ⨂ G2)"
    using assms unitary_def by simp
qed

end