(* Theory Isabelle_Marries_Dirac.Quantum *)

(*
Authors: 

  Anthony Bordg, University of Cambridge, apdb3@cam.ac.uk
  Yijun He, University of Cambridge, yh403@cam.ac.uk
  with contributions by Hanna Lachnitt
*)

section ‹Qubits and Quantum Gates›

theory Quantum
imports
  Jordan_Normal_Form.Matrix
  "HOL-Library.Nonpos_Ints"
  Basics
  Binary_Nat
begin


subsection ‹Qubits›

text‹In this theory @{text cpx} stands for @{text complex}.›

(* Euclidean length of a complex vector: the square root of the sum, over all indices i, of
   the squared modulus of the i-th entry.
   NOTE(review): in this copy the type arrow, the definitional equality, the summation sign,
   the square and the mixfix norm delimiters appear to have been stripped; restore them from
   the upstream theory before building. *)
definition cpx_vec_length :: "complex vec  real" (_) where
"cpx_vec_length v  sqrt(i<dim_vec v. (cmod (v $ i))2)"

(* Length of a vector built with vec_of_list, expressed directly over the elements of the
   list l (length l plays the role of dim_vec). *)
lemma cpx_length_of_vec_of_list [simp]:
  "vec_of_list l = sqrt(i<length l. (cmod (l ! i))2)"
  by (auto simp: cpx_vec_length_def vec_of_list_def vec_of_list_index)
    (metis (no_types, lifting) dim_vec_of_list sum.cong vec_of_list.abs_eq vec_of_list_index)

(* Moduli of the entries of the standard unit vector unit_vec n i at an index j < n:
   0 when j differs from i (the inequality symbol is stripped in this copy). *)
lemma norm_vec_index_unit_vec_is_0 [simp]:
  assumes "j < n" and "j  i"
  shows "cmod ((unit_vec n i) $ j) = 0"
  using assms by (simp add: unit_vec_def)

(* ... and 1 at the index j = i itself. *)
lemma norm_vec_index_unit_vec_is_1 [simp]:
  assumes "j < n" and "j = i"
  shows "cmod ((unit_vec n i) $ j) = 1"
proof -
  (* The entry itself is 1, and the modulus of 1 is 1. *)
  have f:"(unit_vec n i) $ j = 1"
    using assms by simp
  thus ?thesis
    by (simp add: f cmod_def) 
qed

(* A standard unit vector unit_vec n i with i < n has length 1: every squared modulus in the
   defining sum vanishes except the one at index i, which equals 1. *)
lemma unit_cpx_vec_length [simp]:
  assumes "i < n"
  shows "unit_vec n i = 1"
proof -
  have "(j<n. (cmod((unit_vec n i) $ j))2) = (j<n. if j = i then 1 else 0)"
    using norm_vec_index_unit_vec_is_0 norm_vec_index_unit_vec_is_1
    by (smt (verit) lessThan_iff one_power2 sum.cong zero_power2) 
  (* Only the summand j = i survives, and it lies in the range because i < n. *)
  also have " = 1"
    using assms by simp
  finally have "sqrt (j<n. (cmod((unit_vec n i) $ j))2) = 1" 
    by simp
  thus ?thesis
    using cpx_vec_length_def by simp
qed

(* Scaling a vector by a real scalar x multiplies its length by x.
   The assumption on x is presumably x >= 0 (the relation symbol is stripped in this copy);
   it is what allows sqrt of x squared to be simplified to x in the final step.
   Proof outline: pull the factor x squared out of every summand, factor it out of the sum,
   then split the square root of the product. *)
lemma smult_vec_length [simp]:
  assumes "x  0"
  shows "complex_of_real(x) v v = x * v"
proof-
  have "(λi::nat.(cmod (complex_of_real x * v $ i))2) = (λi::nat. (cmod (v $ i))2 * x2)" 
    by (auto simp: norm_mult power_mult_distrib)
  then have "(i<dim_vec v. (cmod (complex_of_real x * v $ i))2) = 
             (i<dim_vec v. (cmod (v $ i))2 * x2)" by meson
  moreover have "(i<dim_vec v. (cmod (v $ i))2 * x2) = x2 * (i<dim_vec v. (cmod (v $ i))2)"
    by (metis (no_types) mult.commute sum_distrib_right)
  moreover have "sqrt(x2 * (i<dim_vec v. (cmod (v $ i))2)) = 
                 sqrt(x2) * sqrt (i<dim_vec v. (cmod (v $ i))2)" 
    using real_sqrt_mult by blast
  ultimately show ?thesis
    by(simp add: cpx_vec_length_def assms)
qed

(* A state of an n-qubit system: a complex matrix with a single column and 2^n rows, whose
   column has length 1. *)
locale state =
  fixes n:: nat and v:: "complex mat"
  assumes is_column [simp]: "dim_col v = 1"
    and dim_row [simp]: "dim_row v = 2^n"
    and is_normal [simp]: "col v 0 = 1"

text‹ 
Below, the natural number n is the number of qubits: states are the elements of norm 1 of the
complex vector space of dimension 2^n. 
›

(* For i < 2^n, the standard unit vector unit_vec (2^n) i has dimension 2^n and length 1.
   NOTE(review): the comprehension {v| n v::complex vec. ...} lists n among its bound
   variables, so the n inside the set is not the outer n. As stated, the set is all
   unit-length vectors of SOME power-of-two dimension. A stronger statement would bind only v.
   Confirm whether callers rely on the current form before changing it. *)
lemma unit_vec_of_right_length_is_state [simp]:
  assumes "i < 2^n"
  shows "unit_vec (2^n) i  {v| n v::complex vec. dim_vec v = 2^n  v = 1}"
proof-
  have "dim_vec (unit_vec (2^n) i) = 2^n" 
    by simp
  moreover have "unit_vec (2^n) i = 1"
    using assms by simp
  ultimately show ?thesis 
    by simp
qed

(* The (vector) states of n qubits: complex vectors of dimension 2^n and length 1. *)
definition state_qbit :: "nat  complex vec set" where
"state_qbit n  {v| v:: complex vec. dim_vec v = 2^n  v = 1}"

(* Within the state locale, the single column of v is an element of state_qbit n. *)
lemma (in state) state_to_state_qbit [simp]:
  shows "col v 0  state_qbit n"
  using state_def state_qbit_def by simp

subsection "The Hermitian Conjugation"

text ‹The Hermitian conjugate of a complex matrix is the complex conjugate of its transpose. ›

(* Hermitian conjugate: the result has the dimensions of M swapped, and its (i,j) entry is
   the complex conjugate of the (j,i) entry of M.
   NOTE(review): the mixfix annotation looks truncated in this copy; the cartouche is not
   closed, so presumably a superscript dagger and the closing bracket were lost. *)
definition dagger :: "complex mat  complex mat" (‹_) where
  "M  mat (dim_col M) (dim_row M) (λ(i,j). cnj(M $$ (j,i)))"

text ‹We introduce the type of complex square matrices.›

(* Square complex matrices as a type. The set is non-empty because the identity matrix
   1m n is square for every n. *)
typedef cpx_sqr_mat = "{M | M::complex mat. square_mat M}"
proof-
  have "square_mat (1m n)" for n
    using one_mat_def by simp
  thus ?thesis by blast
qed

(* Forget squareness: the representation function of the type cpx_sqr_mat. *)
definition cpx_sqr_mat_to_cpx_mat :: "cpx_sqr_mat => complex mat" where
"cpx_sqr_mat_to_cpx_mat M  Rep_cpx_sqr_mat M"

text ‹
We introduce a coercion from the type of complex square matrices to the type of complex 
matrices.
›

(* Let a cpx_sqr_mat be used wherever a complex mat is expected. *)
declare [[coercion cpx_sqr_mat_to_cpx_mat]]

(* Dimensions of the Hermitian conjugate are those of M swapped. *)
lemma dim_row_of_dagger [simp]:
  "dim_row (M) = dim_col M"
  using dagger_def by simp

lemma dim_col_of_dagger [simp]:
  "dim_col (M) = dim_row M"
  using dagger_def by simp

(* Column j of the Hermitian conjugate is the conjugate of row j of M. *)
lemma col_of_dagger [simp]:
  assumes "j < dim_row M"
  shows "col (M) j = vec (dim_col M) (λi. cnj (M $$ (j,i)))"
  using assms col_def dagger_def by simp

(* Row i of the Hermitian conjugate is the conjugate of column i of M. *)
lemma row_of_dagger [simp]:
  assumes "i < dim_col M"
  shows "row (M) i = vec (dim_row M) (λj. cnj (M $$ (j,i)))"
  using assms row_def dagger_def by simp

(* The Hermitian conjugate is an involution: applying it twice gives back M.
   Proved by matching both dimensions and then comparing every entry. *)
lemma dagger_of_dagger_is_id:
  fixes M :: "complex Matrix.mat"
  shows "(M) = M"
proof
  show "dim_row ((M)) = dim_row M" by simp
  show "dim_col ((M)) = dim_col M" by simp
  fix i j assume a0:"i < dim_row M" and a1:"j < dim_col M"
  then show "(M) $$ (i,j) = M $$ (i,j)"
  proof-
    show ?thesis
      using dagger_def a0 a1 by auto
  qed
qed

(* The Hermitian conjugate of a square matrix (a cpx_sqr_mat coerced to a complex mat) is
   square, since the dagger swaps two equal dimensions. *)
lemma dagger_of_sqr_is_sqr [simp]:
  "square_mat ((M::cpx_sqr_mat))"
proof-
  have "square_mat M"
    using cpx_sqr_mat_to_cpx_mat_def Rep_cpx_sqr_mat by simp
  then have "dim_row M = dim_col M" by simp
  then have "dim_col (M) = dim_row (M)" by simp
  thus "square_mat (M)" by simp
qed

(* Identity matrices are self-adjoint. *)
lemma dagger_of_id_is_id [simp]:
  "(1m n) = 1m n"
  using dagger_def one_mat_def by auto

subsection "Unitary Matrices and Quantum Gates"

(* A complex matrix is unitary when its Hermitian conjugate is both a left and a right
   inverse: the products in both orders are identity matrices of the matching size. *)
definition unitary :: "complex mat  bool" where
"unitary M  (M) * M = 1m (dim_col M)  M * (M) = 1m (dim_row M)"

(* Every identity matrix is unitary. *)
lemma id_is_unitary [simp]:
  "unitary (1m n)"
  by (simp add: unitary_def)

(* A quantum gate acting on n qubits: a square, unitary complex matrix with 2^n rows. *)
locale gate =
  fixes n:: nat and A:: "complex mat"
  assumes dim_row [simp]: "dim_row A = 2^n"
    and square_mat [simp]: "square_mat A"
    and unitary [simp]: "unitary A"

text ‹
We prove that a quantum gate is invertible and its inverse is given by its Hermitian conjugate.
›

(* A unitary matrix and its Hermitian conjugate invert each other, in both orders. *)
lemma mat_unitary_mat [intro]:
  assumes "unitary M"
  shows "inverts_mat M (M)"
  using assms by (simp add: unitary_def inverts_mat_def)

lemma unitary_mat_mat [intro]:
  assumes "unitary M"
  shows "inverts_mat (M) M"
  using assms by (simp add: unitary_def inverts_mat_def)

(* Consequently every quantum gate is an invertible matrix. *)
lemma (in gate) gate_is_inv:
  "invertible_mat A"
  using square_mat unitary invertible_mat_def by blast

subsection "Relations Between Complex Conjugation, Hermitian Conjugation, Transposition and Unitarity"

(* Postfix notation for transposition.
   NOTE(review): the notation string appears to have lost its superscript marker in this
   copy. *)
notation transpose_mat ((_t))

(* For i below the number of rows of M, column i of the transpose is row i of M.
   NOTE(review): the name misspells "transpose"; it is kept as is because other theories
   may refer to it. *)
lemma col_tranpose [simp]:
  assumes "dim_row M = n" and "i < n"
  shows "col (Mt) i = row M i"
proof
  show "dim_vec (col (Mt) i) = dim_vec (row M i)"
    by (simp add: row_def col_def transpose_mat_def)
next
  show "j. j < dim_vec (row M i)  col Mt i $ j = row M i $ j"
    using assms by (simp add: transpose_mat_def)
qed

(* Dually, for i below the number of columns of M, row i of the transpose is column i of M. *)
lemma row_transpose [simp]:
  assumes "dim_col M = n" and "i < n"
  shows "row (Mt) i = col M i"
  using assms by simp

(* Entrywise complex conjugate of a matrix; the dimensions are unchanged. *)
definition cpx_mat_cnj :: "complex mat  complex mat" ((_)) where
"cpx_mat_cnj M  mat (dim_row M) (dim_col M) (λ(i,j). cnj (M $$ (i,j)))"

(* Conjugating an identity matrix leaves it unchanged. *)
lemma cpx_mat_cnj_id [simp]:
  "(1m n) = 1m n" 
  by (auto simp: cpx_mat_cnj_def)

(* Entrywise conjugation is an involution. *)
lemma cpx_mat_cnj_cnj [simp]:
  "(M) = M"
  by (auto simp: cpx_mat_cnj_def)

(* Dimensions of a product of two conjugated matrices.
   NOTE(review): "cjn" in these two names presumably stands for "cnj"; the names are kept for
   compatibility with existing references. *)
lemma dim_row_of_cjn_prod [simp]: 
  "dim_row ((M) * (N)) = dim_row M"
  by (simp add: cpx_mat_cnj_def)

lemma dim_col_of_cjn_prod [simp]: 
  "dim_col ((M) * (N)) = dim_col N"
  by (simp add: cpx_mat_cnj_def)

(* Entrywise conjugation distributes over the product of matrices of compatible dimensions.
   Proof: after matching dimensions, both sides are shown to have the (i,j) entry
   sum over k < dim_row N of cnj(M $$ (i,k)) * cnj(N $$ (k,j)). *)
lemma cpx_mat_cnj_prod:
  assumes "dim_col M = dim_row N"
  shows "(M * N) = (M) * (N)"
proof
  show "dim_row (M * N) = dim_row ((M) * (N))" 
    by (simp add: cpx_mat_cnj_def)
next
  show "dim_col ((M * N)) = dim_col ((M) * (N))" 
    by (simp add: cpx_mat_cnj_def)
next 
  fix i j::nat
  assume a1:"i < dim_row ((M) * (N))" and a2:"j < dim_col ((M) * (N))"
  then have "(M * N) $$ (i,j) = cnj (k<(dim_row N). M $$ (i,k) * N $$ (k,j))"
    using assms cpx_mat_cnj_def index_mat times_mat_def scalar_prod_def row_def col_def 
dim_row_of_cjn_prod dim_col_of_cjn_prod
    by simp
  also have " = (k<(dim_row N). cnj(M $$ (i,k)) * cnj(N $$ (k,j)))" by simp
  (* The entry of the right-hand side is the same sum; the calculation chain then relates
     the two sides. *)
  also have "((M) * (N)) $$ (i,j) = 
    (k<(dim_row N). cnj(M $$ (i,k)) * cnj(N $$ (k,j)))"
    using assms a1 a2 cpx_mat_cnj_def index_mat times_mat_def scalar_prod_def row_def col_def
    by (smt case_prod_conv dim_col dim_col_mat(1) dim_row_mat(1) index_vec lessThan_atLeast0 
        lessThan_iff sum.cong)
  finally show "(M * N) $$ (i, j) = ((M) * (N)) $$ (i, j)" by simp
qed

(* Transposition reverses products: (M * N)^t = N^t * M^t for compatible M and N.
   Proof: compare entries, commuting the two factors in each summand. *)
lemma transpose_of_prod:
  fixes M N::"complex Matrix.mat"
  assumes "dim_col M = dim_row N"
  shows "(M * N)t = Nt * (Mt)"
proof
  fix i j::nat
  assume a0: "i < dim_row (Nt * (Mt))" and a1: "j < dim_col (Nt * (Mt))"  
  then have "(M * N)t $$ (i,j) = (M * N) $$ (j,i)" by auto
  also have "... = (k<dim_row Mt.  M $$ (j,k) * N $$ (k,i))"
    using assms a0 a1 by auto
  also have "... = (k<dim_row Mt. N $$ (k,i) * M $$ (j,k))"
   by (simp add: semiring_normalization_rules(7))
  also have "... = (k<dim_row Mt. ((Nt) $$ (i,k)) * (Mt) $$ (k,j))" 
    using assms a0 a1 by auto
  finally show "((M * N)t) $$ (i,j) = (Nt * (Mt)) $$ (i,j)" 
    using assms a0 a1 by auto
next
  show "dim_row ((M * N)t) = dim_row (Nt * (Mt))" by auto
next
  show "dim_col ((M * N)t) = dim_col (Nt * (Mt))" by auto
qed

(* Conjugating the transpose of M gives its Hermitian conjugate. Proved on dimensions and then
   entrywise from the definitions of conjugation, transposition and dagger. *)
lemma transpose_cnj_is_dagger [simp]:
  "(Mt) = (M)"
proof
  show f1:"dim_row ((Mt)) = dim_row (M)"
    by (simp add: cpx_mat_cnj_def transpose_mat_def dagger_def)
next
  show f2:"dim_col ((Mt)) = dim_col (M)" 
    by (simp add: cpx_mat_cnj_def transpose_mat_def dagger_def)
next
  fix i j::nat
  assume "i < dim_row M" and "j < dim_col M"
  then show "Mt $$ (i, j) = M $$ (i, j)" 
    by (simp add: cpx_mat_cnj_def transpose_mat_def dagger_def)
qed

(* Transposing the conjugate of M also gives its Hermitian conjugate. *)
lemma cnj_transpose_is_dagger [simp]:
  "(M)t = (M)"
proof
  show "dim_row ((M)t) = dim_row (M)" 
    by (simp add: transpose_mat_def cpx_mat_cnj_def dagger_def)
next
  show "dim_col ((M)t) = dim_col (M)" 
    by (simp add: transpose_mat_def cpx_mat_cnj_def dagger_def)
next
  fix i j::nat
  assume "i < dim_row M" and "j < dim_col M"
  then show "Mt $$ (i, j) = M $$ (i, j)" 
    by (simp add: transpose_mat_def cpx_mat_cnj_def dagger_def)
qed

(* Hence the Hermitian conjugate of the transpose is the entrywise conjugate. *)
lemma dagger_of_transpose_is_cnj [simp]:
  "(Mt) = (M)"
  by (metis transpose_transpose transpose_cnj_is_dagger)

(* The Hermitian conjugate reverses products: dagger (M * N) = dagger N * dagger M for
   compatible M and N.
   Proof: write the dagger as the transpose of the conjugate, distribute the conjugation over
   the product, then let the transposition reverse the product. *)
lemma dagger_of_prod:
  fixes M N::"complex Matrix.mat"
  assumes "dim_col M = dim_row N"
  shows "(M * N) = N * (M)"
proof-
  have "(M * N) = ((M * N))t" by auto
  also have "... = ((M) * (N))t" using assms cpx_mat_cnj_prod by auto
  also have "... = (N)t * ((M)t)" using assms transpose_of_prod 
    by (metis cnj_transpose_is_dagger dim_col_of_dagger dim_row_of_dagger index_transpose_mat(2) index_transpose_mat(3))
  finally show "(M * N) = N * (M)" by auto
qed

text ‹The product of two quantum gates is a quantum gate.›

(* The product of two gates on n qubits is again a gate on n qubits.
   The left-inverse identity is obtained by reassociating
   dagger (G1 * G2) * (G1 * G2) = dagger G2 * ((dagger G1 * G1) * G2) and cancelling with the
   unitarity of G1 and then G2. The right-inverse identity follows for square matrices from
   mat_mult_left_right_inverse. *)
lemma prod_of_gate_is_gate: 
  assumes "gate n G1" and "gate n G2"
  shows "gate n (G1 * G2)"
proof
  show "dim_row (G1 * G2) = 2^n" using assms by (simp add: gate_def)
next
  show "square_mat (G1 * G2)" 
    using assms gate.dim_row gate.square_mat by simp
next
  show "unitary (G1 * G2)" 
  proof-
    have "((G1 * G2)) * (G1 * G2) = 1m (dim_col (G1 * G2))" 
    proof-
      (* All matrices involved lie in carrier_mat (2^n) (2^n); needed to reassociate. *)
      have f0: "G1  carrier_mat (2^n) (2^n)  G2  carrier_mat (2^n) (2^n)
               G1  carrier_mat (2^n) (2^n)  G2  carrier_mat (2^n) (2^n)
               G1 * G2  carrier_mat (2^n) (2^n)" 
        using assms gate.dim_row gate.square_mat by auto
      have "((G1 * G2)) * (G1 * G2) = ((G2) * (G1)) * (G1 * G2)" 
        using assms dagger_of_prod gate.dim_row gate.square_mat by simp
      also have "... = (G2) * ((G1) * (G1 * G2))" 
        using assms f0 by auto
      also have "... = (G2) * (((G1) * G1) * G2)" 
        using assms f0 f0 by auto
      also have "... = (G2) * ((1m (dim_col G1)) * G2)" 
        using gate.unitary[of n G1] assms unitary_def[of G1] by simp
      also have "... = (G2) * ((1m (dim_col G2)) * G2)" 
        using assms f0 by (metis carrier_matD(2))
      also have "... = (G2) * G2" 
        using f0 by (metis carrier_matD(2) left_mult_one_mat)
      finally show "((G1 * G2)) * (G1 * G2) = 1m (dim_col (G1 * G2))" 
        using assms gate.unitary unitary_def by simp
    qed
    (* For square matrices a left inverse is also a right inverse. *)
    moreover have "(G1 * G2) * ((G1 * G2)) = 1m (dim_row (G1 * G2))"
      using assms calculation
      by (smt (verit) carrier_matI dim_col_of_dagger dim_row_of_dagger gate.dim_row gate.square_mat index_mult_mat(2) index_mult_mat(3) 
          mat_mult_left_right_inverse square_mat.elims(2))
    ultimately show ?thesis using unitary_def by simp
  qed
qed

(* Left-inverse identity for the transpose of a unitary U. It is obtained by conjugating
   entrywise the unitarity equation of U, using that conjugation distributes over products
   and fixes identity matrices. *)
lemma left_inv_of_unitary_transpose [simp]:
  assumes "unitary U"
  shows "(Ut) * (Ut) =  1m(dim_row U)"
proof -
  have "dim_col U = dim_row ((Ut))" by simp
  then have "(U * ((Ut))) = (U) * (Ut)"
    using cpx_mat_cnj_prod cpx_mat_cnj_cnj by presburger
  also have " = (Ut) * (Ut)" by simp
  finally show ?thesis 
    using assms by (metis transpose_cnj_is_dagger cpx_mat_cnj_id unitary_def)
qed

(* Right-inverse identity for the transpose of a unitary U, obtained the same way. *)
lemma right_inv_of_unitary_transpose [simp]:
  assumes "unitary U"
  shows "Ut * ((Ut)) = 1m(dim_col U)"
proof -
  have "dim_col ((Ut)) = dim_row U" by simp
  then have "Ut * ((Ut)) = (((Ut) * U))"
    using cpx_mat_cnj_cnj cpx_mat_cnj_prod dagger_of_transpose_is_cnj by presburger
  also have " = (U * U)" by simp
  finally show ?thesis
    using assms by (metis cpx_mat_cnj_id unitary_def)
qed

(* Hence the transpose of a unitary matrix is unitary. *)
lemma transpose_of_unitary_is_unitary [simp]:
  assumes "unitary U"
  shows "unitary (Ut)" 
  using unitary_def assms left_inv_of_unitary_transpose right_inv_of_unitary_transpose by simp


subsection "The Inner Product"

text ‹We introduce a coercion between complex vectors and (column) complex matrices.›

(* View a complex vector as a single-column matrix (a ket). *)
definition ket_vec :: "complex vec  complex mat" (|_) where
"|v  mat (dim_vec v) 1 (λ(i,j). v $ i)"

(* Entry (i,0) of a ket is entry i of the vector. *)
lemma ket_vec_index [simp]:
  assumes "i < dim_vec v"
  shows "|v $$ (i,0) = v $ i"
  using assms ket_vec_def by simp

(* The only column of a ket is the vector itself. *)
lemma ket_vec_col [simp]:
  "col |v 0 = v"
  by (auto simp: col_def ket_vec_def)

(* Scalar multiplication commutes with forming a ket. *)
lemma smult_ket_vec [simp]:
  "|x v v = x m |v"
  by (auto simp: ket_vec_def)

(* Version of smult_vec_length stated through a ket. For a real scalar x satisfying the stated
   sign assumption (presumably x >= 0; the symbol is stripped in this copy), the column of
   x times |v> has length x times the length of v. *)
lemma smult_vec_length_bis [simp]:
  assumes "x  0"
  shows "col (complex_of_real(x) m |v) 0 = x * v"
  using assms smult_ket_vec smult_vec_length ket_vec_col by metis

(* From here on, complex vectors are coerced to column matrices through ket_vec. *)
declare [[coercion ket_vec]]

(* View a complex vector as a single-row matrix. *)
definition row_vec :: "complex vec  complex mat" where
"row_vec v  mat 1 (dim_vec v) (λ(i,j). v $ j)" 

(* The bra of a vector: the entrywise conjugate of its row matrix. *)
definition bra_vec :: "complex vec  complex mat" where
"bra_vec v  (row_vec v)"

(* The only row of bra_vec v is the entrywise conjugate of v. *)
lemma row_bra_vec [simp]:
  "row (bra_vec v) 0 = vec (dim_vec v) (λi. cnj(v $ i))"
  by (auto simp: row_def bra_vec_def cpx_mat_cnj_def row_vec_def)

text ‹We introduce a definition called @{term "bra"} that turns a column matrix into a row matrix, conjugating its entries.›

(* Bra of a matrix v: a one-row matrix whose j-th entry is the conjugate of v $$ (j,0).
   Only the first column of v is used. *)
definition bra :: "complex mat  complex mat" (_|) where
"v|  mat 1 (dim_row v) (λ(i,j). cnj(v $$ (j,i)))"

text ‹The relation between @{term "bra"}, @{term "bra_vec"} and @{term "ket_vec"} is given as follows.›

(* The bra of the ket of a vector is its bra_vec. *)
lemma bra_bra_vec [simp]:
  "bra (ket_vec v) = bra_vec v"
  by (auto simp: bra_def ket_vec_def bra_vec_def cpx_mat_cnj_def row_vec_def)

(* The only row of the bra of a vector (coerced to a ket) is the conjugated vector. *)
lemma row_bra [simp]:
  fixes v::"complex vec"
  shows "row v| 0 = vec (dim_vec v) (λi. cnj (v $ i))" by simp

text ‹We introduce the inner product of two complex vectors in @{text "ℂ⇧n"}.›

(* Inner product: the sum over i < dim_vec v of cnj(u $ i) * (v $ i). It is conjugate-linear
   in its first argument. The sum ranges over the indices of v only, which is why the lemmas
   below assume dim_vec u = dim_vec v. *)
definition inner_prod :: "complex vec  complex vec  complex" (_|_) where
"inner_prod u v   i  {0..< dim_vec v}. cnj(u $ i) * (v $ i)"

(* The inner product as a scalar product with the row of bra_vec u. *)
lemma inner_prod_with_row_bra_vec [simp]:
  assumes "dim_vec u = dim_vec v"
  shows "u|v = row (bra_vec u) 0  v"
  using assms inner_prod_def scalar_prod_def row_bra_vec index_vec
  by (smt (verit) lessThan_atLeast0 lessThan_iff sum.cong)

(* The same, phrased with the bra row of u and the ket column of v. *)
lemma inner_prod_with_row_bra_vec_col_ket_vec [simp]:
  assumes "dim_vec u = dim_vec v"
  shows "u|v = (row u| 0)  (col |v 0)"
  using assms by (simp add: inner_prod_def scalar_prod_def)

(* The inner product is entry (0,0) of the matrix product of the bra of u and the ket of v. *)
lemma inner_prod_with_times_mat [simp]:
  assumes "dim_vec u = dim_vec v"
  shows "u|v = (u| * |v) $$ (0,0)"
  using assms inner_prod_with_row_bra_vec_col_ket_vec 
  by (simp add: inner_prod_def times_mat_def ket_vec_def bra_def)

(* Distinct standard unit vectors of the same dimension are orthogonal. *)
lemma orthogonal_unit_vec [simp]:
  assumes "i < n" and "j < n" and "i  j"
  shows "unit_vec n i|unit_vec n j = 0"
proof-
  have "unit_vec n i|unit_vec n j = unit_vec n i  unit_vec n j"
    using assms unit_vec_def inner_prod_def scalar_prod_def
    by (smt (verit) complex_cnj_zero index_unit_vec(3) index_vec inner_prod_with_row_bra_vec row_bra_vec 
        scalar_prod_right_unit)
  thus ?thesis
    using assms scalar_prod_def unit_vec_def by simp 
qed

text ‹We prove that our inner product is linear in its second argument.›

(* Entry j of a linear combination of two vectors of equal dimension. *)
lemma vec_index_is_linear [simp]:
  assumes "dim_vec u = dim_vec v" and "j < dim_vec u"
  shows "(k v u + l v v) $ j = k * (u $ j) + l * (v $ j)"
  using assms vec_index_def smult_vec_def plus_vec_def by simp

(* Linearity of the inner product in its second argument, for a combination of two vectors
   v 0 and v 1 (each of the same dimension as u) with coefficients l 0 and l 1. *)
lemma inner_prod_is_linear [simp]:
  fixes u::"complex vec" and v::"nat  complex vec" and l::"nat  complex"
  assumes "i{0, 1}. dim_vec u = dim_vec (v i)"
  shows "u|l 0 v v 0 + l 1 v v 1 = (i1. l i * u|v i)"
proof -
  have f1:"dim_vec (l 0 v v 0 + l 1 v v 1) = dim_vec u"
    using assms by simp
  then have "u|l 0 v v 0 + l 1 v v 1 = (i{0 ..< dim_vec u}. cnj (u $ i) * ((l 0 v v 0 + l 1 v v 1) $ i))"
    by (simp add: inner_prod_def)
  also have " = (i{0 ..< dim_vec u}. cnj (u $ i) * (l 0 * v 0 $ i + l 1 * v 1 $ i))"
    using assms by simp
  (* Distribute the sum and pull the coefficients out of it. *)
  also have " = l 0 * (i{0 ..< dim_vec u}. cnj(u $ i) * (v 0 $ i)) + l 1 * (i{0 ..< dim_vec u}. cnj(u $ i) * (v 1 $ i))"
    by (auto simp: algebra_simps)
      (simp add: sum.distrib sum_distrib_left)
  also have " = l 0 * u|v 0 + l 1 * u|v 1"
    using assms inner_prod_def by auto
  finally show ?thesis by simp
qed

(* Swapping the arguments conjugates the inner product (for vectors of equal dimension). *)
lemma inner_prod_cnj:
  assumes "dim_vec u = dim_vec v"
  shows "v|u = cnj (u|v)"
  by (simp add: assms inner_prod_def algebra_simps)

(* Hence the inner product of a vector with itself has zero imaginary part... *)
lemma inner_prod_with_itself_Im [simp]:
  "Im (u|u) = 0"
  using inner_prod_cnj by (metis Reals_cnj_iff complex_is_Real_iff)

(* ...and is therefore real. *)
lemma inner_prod_with_itself_real [simp]:
  "u|u  "
  using inner_prod_with_itself_Im by (simp add: complex_is_Real_iff)

(* The inner product of the zero vector with itself is 0. *)
lemma inner_prod_with_itself_eq0 [simp]:
  assumes "u = 0v (dim_vec u)"
  shows "u|u = 0"
  using assms inner_prod_def zero_vec_def
  by (smt (verit) atLeastLessThan_iff complex_cnj_zero index_zero_vec(1) mult_zero_left sum.neutral)

(* The real part of the inner product of u with itself is non-negative: the real part of each
   summand cnj(u $ i) * (u $ i) is Re(u $ i)^2 + Im(u $ i)^2. *)
lemma inner_prod_with_itself_Re:
  "Re (u|u)  0"
proof -
  have "Re (u|u) = (i<dim_vec u. Re (cnj(u $ i) * (u $ i)))"
    by (simp add: inner_prod_def lessThan_atLeast0)
  moreover have " = (i<dim_vec u. (Re (u $ i))2 + (Im (u $ i))2)"
    using complex_mult_cnj
    by (metis (no_types, lifting) Re_complex_of_real semiring_normalization_rules(7))
  ultimately show "Re (u|u)  0" by (simp add: sum_nonneg)
qed

(* Hence the inner product of a vector with itself is a non-negative real. *)
lemma inner_prod_with_itself_nonneg_reals:
  fixes u::"complex vec"
  shows "u|u  nonneg_Reals"
  using inner_prod_with_itself_real inner_prod_with_itself_Re complex_nonneg_Reals_iff 
inner_prod_with_itself_Im by auto

(* For a non-zero vector u, the real part of the inner product of u with itself is strictly
   positive: some entry of u is non-zero and contributes a positive summand, while all the
   other summands are non-negative. *)
lemma inner_prod_with_itself_Re_non0:
  assumes "u  0v (dim_vec u)"
  shows "Re (u|u) > 0"
proof -
  obtain i where a1:"i < dim_vec u" and "u $ i  0"
    using assms zero_vec_def by (metis dim_vec eq_vecI index_zero_vec(1))
  then have f1:"Re (cnj (u $ i) * (u $ i)) > 0"
    by (metis Re_complex_of_real complex_mult_cnj complex_neq_0 mult.commute)
  moreover have f2:"Re (u|u) = (i<dim_vec u. Re (cnj(u $ i) * (u $ i)))"
    using inner_prod_def by (simp add: lessThan_atLeast0)
  moreover have f3:"i<dim_vec u. Re (cnj(u $ i) * (u $ i))  0"
    using complex_mult_cnj by simp
  ultimately show ?thesis
    using a1 inner_prod_def lessThan_iff
    by (metis (no_types, lifting) finite_lessThan sum_pos2)
qed

(* Hence the inner product of a non-zero vector with itself is non-zero. *)
lemma inner_prod_with_itself_nonneg_reals_non0:
  assumes "u  0v (dim_vec u)"
  shows "u|u  0"
  using assms inner_prod_with_itself_Re_non0 by fastforce

(* The squared length of v, viewed as a complex number, is the inner product of v with
   itself. *)
lemma cpx_vec_length_inner_prod [simp]:
  "v2 = v|v"
proof -
  have "v2 = (i<dim_vec v. (cmod (v $ i))2)"
    using cpx_vec_length_def complex_of_real_def
    by (metis (no_types, lifting) real_sqrt_power real_sqrt_unique sum_nonneg zero_le_power2)
  also have " = (i<dim_vec v. cnj (v $ i) * (v $ i))"
    using complex_norm_square mult.commute by (smt (verit) of_real_sum sum.cong)
  finally show ?thesis
    using inner_prod_def by (simp add: lessThan_atLeast0)
qed

(* The complex square root of the inner product of v with itself is the length of v. *)
lemma inner_prod_csqrt [simp]:
  "csqrt v|v = v"
  using inner_prod_with_itself_Re inner_prod_with_itself_Im csqrt_of_real_nonneg cpx_vec_length_def
  by (metis (no_types, lifting) Re_complex_of_real cpx_vec_length_inner_prod real_sqrt_ge_0_iff 
      real_sqrt_unique sum_nonneg zero_le_power2)


subsection "Unitary Matrices and Length-Preservation"

subsubsection "Unitary Matrices are Length-Preserving"

text ‹The bra-vector @{text "⟨A * v|"} is given by @{text "⟨v| * A⇧†"}›

(* The Hermitian conjugate of the ket of v is the bra of v. *)
lemma dagger_of_ket_is_bra:
  fixes v:: "complex vec"
  shows "( |v ) = v|"
  by (simp add: bra_def dagger_def ket_vec_def)

(* For dim_col A = dim_vec v, the bra of A * |v> is <v| times the Hermitian conjugate of A.
   Proved entrywise by conjugating the scalar product of a row of A with v. *)
lemma bra_mat_on_vec:
  fixes v::"complex vec" and A::"complex mat"
  assumes "dim_col A = dim_vec v"
  shows "A * v| = v| * (A)"
proof
  show "dim_row A * v| = dim_row (v| * (A))"
    by (simp add: bra_def times_mat_def)
next
  show "dim_col A * v| = dim_col (v| * (A))"
    by (simp add: bra_def times_mat_def)
next
  fix i j::nat
  assume a1:"i < dim_row (v| * (A))" and a2:"j < dim_col (v| * (A))" 
  then have "cnj((A * v) $$ (j,0)) = cnj (row A j  v)"
    using bra_def times_mat_def ket_vec_col ket_vec_def by simp
  also have f7:"= (i{0 ..< dim_vec v}. cnj(v $ i) * cnj(A $$ (j,i)))"
    using row_def scalar_prod_def cnj_sum complex_cnj_mult mult.commute
    by (smt assms index_vec lessThan_atLeast0 lessThan_iff sum.cong)
  moreover have f8:"(row v| 0)  (col (A) j) = 
    vec (dim_vec v) (λi. cnj (v $ i))  vec (dim_col A) (λi. cnj (A $$ (j,i)))"
    using a2 by simp 
  ultimately have "cnj((A * v) $$ (j,0)) = (row v| 0)  (col (A) j)"
    using assms scalar_prod_def
    by (smt (verit) dim_vec index_vec lessThan_atLeast0 lessThan_iff sum.cong)
  then have "A * v| $$ (0,j) = (v| * (A)) $$ (0,j)"
    using bra_def times_mat_def a2 by simp
  thus "A * |v| $$ (i, j) = (v| * (A)) $$ (i, j)" 
    using a1 by (simp add: times_mat_def bra_def)
qed

(* A * |v> equals the ket of its own first column. *)
lemma mat_on_ket:
  fixes v:: "complex vec" and A:: "complex mat"
  assumes "dim_col A = dim_vec v"
  shows "A * |v = |col (A * v) 0"
  using assms ket_vec_def by auto

(* The Hermitian conjugate of A * |v> is <v| times the Hermitian conjugate of A. *)
lemma dagger_of_mat_on_ket:
  fixes v:: "complex vec" and A :: "complex mat"
  assumes "dim_col A = dim_vec v"
  shows "(A * |v ) = v| * (A)"
  using assms by (metis bra_mat_on_vec dagger_of_ket_is_bra mat_on_ket)

(* The first column of a matrix, as a vector. *)
definition col_fst :: "'a mat  'a vec" where 
  "col_fst A = vec (dim_row A) (λ i. A $$ (i,0))"

(* col_fst M coincides with col M 0. *)
lemma col_fst_is_col [simp]:
  "col_fst M = col M 0"
  by (simp add: col_def col_fst_def)

text ‹
We need to declare @{term "col_fst"} as a coercion from matrices to vectors in order to see a column 
matrix as a vector. 
›

(* Switch coercions: stop coercing vectors to matrices through ket_vec, and instead coerce
   matrices to vectors by taking their first column. *)
declare 
  [[coercion_delete ket_vec]]
  [[coercion col_fst]]

(* For i < n = dim_col A, column i of A is A applied to the ket of the i-th standard unit
   vector (the product is coerced back to a vector via col_fst). *)
lemma unit_vec_to_col:
  assumes "dim_col A = n" and "i < n"
  shows "col A i = A * |unit_vec n i"
proof
  show "dim_vec (col A i) = dim_vec (A * |unit_vec n i)"
    using col_def times_mat_def by simp
next
  fix j::nat
  assume "j < dim_vec (col_fst (A * |unit_vec n i))"
  then show "col A i $ j = (A * |unit_vec n i) $ j"
    using assms times_mat_def ket_vec_def
    by (smt (verit) col_fst_is_col dim_col dim_col_mat(1) index_col index_mult_mat(1) index_mult_mat(2) 
index_row(1) ket_vec_col less_numeral_extra(1) scalar_prod_right_unit)
qed

(* Turning A * |v> into a vector (via the col_fst coercion) and back into a ket gives
   A * |v> again, because the product has a single column. *)
lemma mult_ket_vec_is_ket_vec_of_mult:
  fixes A::"complex mat" and v::"complex vec"
  assumes "dim_col A = dim_vec v"
  shows "|A * |v  = A * |v"
  using assms ket_vec_def
  by (metis One_nat_def col_fst_is_col dim_col dim_col_mat(1) index_mult_mat(3) ket_vec_col less_Suc0 
mat_col_eqI)

(* Unitary matrices preserve squared lengths.
   Proof: express the inner product of U|v> with itself as entry (0,0) of
   <v| * dagger U * U * |v>, reassociate, cancel dagger U * U using unitarity, and relate the
   remaining inner product of v with itself to its squared length. *)
lemma unitary_is_sq_length_preserving [simp]:
  assumes "unitary U" and "dim_vec v = dim_col U"
  shows "U * |v2 = v2"
proof -
  have "U * |v|U * |v  = (|v| * (U) * |U * |v) $$ (0,0)"
    using assms(2) bra_mat_on_vec
    by (metis inner_prod_with_times_mat mult_ket_vec_is_ket_vec_of_mult)
  then have "U * |v|U * |v  = (|v| * (U) * (U * |v)) $$ (0,0)"
    using assms(2) mult_ket_vec_is_ket_vec_of_mult by simp
  moreover have f1:"dim_col |v| = dim_vec v"
    using ket_vec_def bra_def by simp
  moreover have "dim_row (U) = dim_vec v"
    using assms(2) by simp
  (* Reassociate so that dagger U * U appears as a factor. *)
  ultimately have "U * |v|U * |v  = (|v| * ((U) * U) * |v) $$ (0,0)"
    using assoc_mult_mat
    by(smt (verit, ccfv_threshold) carrier_mat_triv dim_row_mat(1) dagger_def ket_vec_def mat_carrier times_mat_def)
  then have "U * |v|U * |v  = (|v| * |v) $$ (0,0)"
    using assms f1 unitary_def by simp
  thus ?thesis
    using cpx_vec_length_inner_prod by(metis Re_complex_of_real inner_prod_with_times_mat)
qed

(* A single-column matrix is the ket of its column. *)
lemma col_ket_vec [simp]:
  assumes "dim_col M = 1"
  shows "|col M 0 = M"
  using eq_matI assms ket_vec_def by auto

(* If v is a one-qubit state, so is the ket of its first column. *)
lemma state_col_ket_vec:
  assumes "state 1 v"
  shows "state 1 |col v 0"
  using assms by (simp add: state_def)

(* Entries of the ket of the first column of v agree with the first-column entries of v. *)
lemma col_ket_vec_index [simp]:
  assumes "i < dim_row v"
  shows "|col v 0 $$ (i,0) = v $$ (i,0)"
  using assms ket_vec_def by (simp add: col_def)

(* For a single-column matrix, the entries of its column are its matrix entries. *)
lemma col_index_of_mat_col [simp]:
  assumes "dim_col v = 1" and "i < dim_row v"
  shows "col v 0 $ i = v $$ (i,0)"
  using assms by simp

(* Matrix version of unitary_is_sq_length_preserving. For a single-column matrix v whose row
   count matches the column count of the unitary U, the column of U * v has the same squared
   length as the column of v. *)
lemma unitary_is_sq_length_preserving_bis [simp]:
  assumes "unitary U" and "dim_row v = dim_col U" and "dim_col v = 1"
  shows "col (U * v) 02 = col v 02"
proof -
  have "dim_vec (col v 0) = dim_col U"
    using assms(2) by simp
  then have "col_fst (U * |col v 0)2 = col v 02"
    using unitary_is_sq_length_preserving[of "U" "col v 0"] assms(1) by simp
  thus ?thesis
    using assms(3) by simp
qed

text ‹ 
A unitary matrix is length-preserving, i.e. it acts on a vector to produce another vector of the 
same length. 
›

(* Taking square roots of the squared-length statement gives length preservation for
   single-column matrices... *)
lemma unitary_is_length_preserving_bis [simp]:
  fixes U::"complex mat" and v::"complex mat"
  assumes "unitary U" and "dim_row v = dim_col U" and "dim_col v = 1"
  shows "col (U * v) 0 = col v 0"
  using assms unitary_is_sq_length_preserving_bis
  by (metis cpx_vec_length_inner_prod inner_prod_csqrt of_real_hom.injectivity)

(* ...and for vectors. *)
lemma unitary_is_length_preserving [simp]:
  fixes U:: "complex mat" and v:: "complex vec"
  assumes "unitary U" and "dim_vec v = dim_col U"
  shows "U * |v = v"
  using assms unitary_is_sq_length_preserving
  by (metis cpx_vec_length_inner_prod inner_prod_csqrt of_real_hom.injectivity)


subsubsection "Length-Preserving Matrices are Unitary"

(* If A * B is an identity (inverts_mat A B) and B is square with dim_row B = dim_col A, then
   B * A is an identity as well.
   Proof outline: A and B are both n x n. B has non-zero determinant, since otherwise some
   non-zero v with B v = 0 would give v = (A * B) v = A 0 = 0. Hence
   (1/det B) * adj_mat B is a right inverse of B, and multiplying A * B = 1 on the right by it
   shows that A equals it. *)
lemma inverts_mat_sym:
  fixes A B:: "complex mat"
  assumes "inverts_mat A B" and "dim_row B = dim_col A" and "square_mat B"
  shows "inverts_mat B A"
proof-
  define n where d0:"n = dim_row B"
  have "A * B = 1m (dim_row A)" using assms(1) inverts_mat_def by auto
  moreover have "dim_col B = dim_col (A * B)" using times_mat_def by simp
  ultimately have "dim_col B = dim_row A" by simp
  then have c0:"A  carrier_mat n n" using assms(2,3) d0 by auto
  have c1:"B  carrier_mat n n" using assms(3) d0 by auto
  have f0:"A * B = 1m n" using inverts_mat_def c0 c1 assms(1) by auto
  (* det B is non-zero, by contradiction via a non-zero vector in the kernel of B. *)
  have f1:"det B  0"
  proof
    assume "det B = 0"
    then have "v. v  carrier_vec n  v  0v n  B *v v = 0v n"
      using det_0_iff_vec_prod_zero assms(3) c1 by blast
    then obtain v where d1:"v  carrier_vec n  v  0v n  B *v v = 0v n" by auto
    then have d2:"dim_vec v = n" by simp
    have "B * |v = |0v n"
    proof
      show "dim_row (B * |v) = dim_row |0v n" using ket_vec_def d0 by simp
    next
      show "dim_col (B * |v) = dim_col |0v n" using ket_vec_def d0 by simp
    next
      fix i j assume "i < dim_row |0v n" and "j < dim_col |0v n"
      then have f2:"i < n  j = 0" using ket_vec_def by simp
      moreover have "vec (dim_row B) (($) v) = v" using d0 d1 by auto
      moreover have "(B *v v) $ i = (ia = 0..<dim_row B. row B i $ ia * v $ ia)"
        using d0 d2 f2 by (auto simp add: scalar_prod_def)
      ultimately show "(B * |v) $$ (i, j) = |0v n $$ (i, j)"
        using ket_vec_def d0 d1 times_mat_def mult_mat_vec_def by (auto simp add: scalar_prod_def)
    qed
    moreover have "|v  carrier_mat n 1" using d2 ket_vec_def by simp
    ultimately have "(A * B) * |v = A * |0v n" using c0 c1 by simp
    then have f3:"|v = A * |0v n" using d2 f0 ket_vec_def by auto
    have "v = 0v n"
    proof
      show "dim_vec v = dim_vec (0v n)" using d2 by simp
    next
      fix i assume f4:"i < dim_vec (0v n)"
      then have "|v $$ (i,0) = v $ i" using d2 ket_vec_def by simp
      moreover have "(A * |0v n) $$ (i, 0) = 0"
        using ket_vec_def times_mat_def scalar_prod_def f4 c0 by auto
      ultimately show "v $ i = 0v n $ i" using f3 f4 by simp
    qed
    then show False using d1 by simp
  qed
  (* (1/det B) * adj_mat B is a right inverse of B; it must coincide with A. *)
  have f5:"adj_mat B  carrier_mat n n  B * adj_mat B = det B m 1m n" using c1 adj_mat by auto
  then have c2:"((1/det B) m adj_mat B)  carrier_mat n n" by simp
  have f6:"B * ((1/det B) m adj_mat B) = 1m n" using c1 f1 f5 mult_smult_distrib[of "B"] by auto
  then have "A = (A * B) * ((1/det B) m adj_mat B)" using c0 c1 c2 by simp
  then have "A = (1/det B) m adj_mat B" using f0 c2 by auto
  then show ?thesis using c0 c1 f6 inverts_mat_def by auto
qed

(* For distinct indices i, j < n, the squared length of unit_vec n i + c * unit_vec n j is
   1 + cnj c * c. *)
lemma sum_of_unit_vec_length:
  fixes i j n:: nat and c:: complex
  assumes "i < n" and "j < n" and "i  j"
  shows "unit_vec n i + c v unit_vec n j2 = 1 + cnj(c) * c"
proof-
  define v where d0:"v = unit_vec n i + c v unit_vec n j"
  (* Entries of v: 1 at i, c at j, 0 elsewhere. *)
  have "k<n. v $ k = (if k = i then 1 else (if k = j then c else 0))"
    using d0 assms(1,2,3) by auto
  then have "k<n. cnj (v $ k) * v $ k = (if k = i then 1 else 0) + (if k = j then cnj(c) * c else 0)"
    using assms(3) by auto
  moreover have "v2 = (k = 0..<n. cnj (v $ k) * v $ k)"
    using d0 assms cpx_vec_length_inner_prod inner_prod_def by simp
  ultimately show ?thesis
    using d0 assms by (auto simp add: sum.distrib)
qed

(* For i, j < n = dim_col A: col A i + c * col A j equals A applied to the ket of
   unit_vec n i + c * unit_vec n j (the product is coerced back to a vector via col_fst). *)
lemma sum_of_unit_vec_to_col:
  assumes "dim_col A = n" and "i < n" and "j < n"
  shows "col A i + c v col A j = A * |unit_vec n i + c v unit_vec n j"
proof
  show "dim_vec (col A i + c v col A j) = dim_vec (col_fst (A * |unit_vec n i + c v unit_vec n j))"
    using assms(1) by auto
next
  fix k assume "k < dim_vec (col_fst (A * |unit_vec n i + c v unit_vec n j))"
  then have f0:"k < dim_row A" using assms(1) by auto
  have "(col A i + c v col A j) $ k = A $$ (k, i) + c * A $$ (k, j)"
    using f0 assms(1-3) by auto
  (* Split the row-times-column sum into the contributions of the two unit vectors. *)
  moreover have "(x<n. A $$ (k, x) * ((if x = i then 1 else 0) + c * (if x = j then 1 else 0))) = 
                 (x<n. A $$ (k, x) * (if x = i then 1 else 0)) + 
                 (x<n. A $$ (k, x) * c * (if x = j then 1 else 0))"
    by (auto simp add: sum.distrib algebra_simps)
  moreover have "x<n. A $$ (k, x) * (if x = i then 1 else 0) = (if x = i then A $$ (k, x) else 0)"
    by simp
  moreover have "x<n. A $$ (k, x) * c * (if x = j then 1 else 0) = (if x = j then A $$ (k, x) * c else 0)"
    by simp
  ultimately show "(col A i + c v col A j) $ k = col_fst (A * |unit_vec n i + c v unit_vec n j) $ k"
    using f0 assms(1-3) times_mat_def scalar_prod_def ket_vec_def by auto
qed

(* Sesquilinearity: expansion of the inner product of two linear combinations of
   n-dimensional vectors, in which the coefficients of the first argument appear conjugated.
   Proof: apply linearity in the second argument, then use conjugate symmetry so that
   linearity can be applied to the first argument as well. *)
lemma inner_prod_is_sesquilinear:
  fixes u1 u2 v1 v2:: "complex vec" and c1 c2 c3 c4:: complex and n:: nat
  assumes "dim_vec u1 = n" and "dim_vec u2 = n" and "dim_vec v1 = n" and "dim_vec v2 = n"
  shows "c1 v u1 + c2 v u2|c3 v v1 + c4 v v2 = cnj (c1) * c3 * u1|v1 + cnj (c2) * c3 * u2|v1 + 
                                                 cnj (c1) * c4 * u1|v2 + cnj (c2) * c4 * u2|v2"
proof-
  have "c1 v u1 + c2 v u2|c3 v v1 + c4 v v2 = c3 * c1 v u1 + c2 v u2|v1 + c4 * c1 v u1 + c2 v u2|v2"
    using inner_prod_is_linear[of "c1 v u1 + c2 v u2" "λi. if i = 0 then v1 else v2" 
                                  "λi. if i = 0 then c3 else c4"] assms
    by simp
  also have "... = c3 * cnj(v1|c1 v u1 + c2 v u2) + c4 * cnj(v2|c1 v u1 + c2 v u2)"
    using assms inner_prod_cnj[of "v1" "c1 v u1 + c2 v u2"] inner_prod_cnj[of "v2" "c1 v u1 + c2 v u2"] 
    by simp
  also have "... = c3 * cnj(c1 * v1|u1 + c2 * v1|u2) + c4 * cnj(c1 * v2|u1 + c2 * v2|u2)"
    using inner_prod_is_linear[of "v1" "λi. if i = 0 then u1 else u2" "λi. if i = 0 then c1 else c2"] 
          inner_prod_is_linear[of "v2" "λi. if i = 0 then u1 else u2" "λi. if i = 0 then c1 else c2"] assms
    by simp
  also have "... = c3 * (cnj(c1) * u1|v1 + cnj(c2) * u2|v1) + 
                   c4 * (cnj(c1) * u1|v2 + cnj(c2) * u2|v2)"
    using inner_prod_cnj[of "v1" "u1"] inner_prod_cnj[of "v1" "u2"] 
          inner_prod_cnj[of "v2" "u1"] inner_prod_cnj[of "v2" "u2"] assms
    by simp
  finally show ?thesis
    by (auto simp add: algebra_simps)
qed

text ‹
A length-preserving square matrix is unitary. Combined with the converse (unitary matrices preserve
the length of vectors), this shows that unitary matrices are exactly the length-preserving matrices.
›

(* A square complex matrix U that preserves the length of every vector of dimension dim_col U
   is unitary. The proof compares the entries of the product of the dagger of U with U
   against the identity matrix:
   - diagonal entries come from the images of the unit vectors;
   - off-diagonal entries come from the images of e_i + e_j and e_i + 𝗂 e_j.
   Squareness then turns this left inverse into a two-sided inverse. *)
lemma length_preserving_is_unitary:
  fixes U:: "complex mat"
  assumes "square_mat U" and "v::complex vec. dim_vec v = dim_col U  U * |v = v"
  shows "unitary U"
proof-
  define n where "n = dim_col U"
  then have c0:"U  carrier_mat n n" using assms(1) by auto
  then have c1:"U  carrier_mat n n" using assms(1) dagger_def by auto
  (* Main step: dagger(U) * U is the identity. The dimensions are checked first,
     then the entries. *)
  have f0:"(U) * U = 1m (dim_col U)"
  proof
    show "dim_row (U * U) = dim_row (1m (dim_col U))" using c0 by simp
  next
    show "dim_col (U * U) = dim_col (1m (dim_col U))" using c0 by simp
  next
    fix i j assume "i < dim_row (1m (dim_col U))" and "j < dim_col (1m (dim_col U))"
    then have a0:"i < n  j < n" using c0 by simp
    (* Diagonal entries. Column l of U is the image of the l-th unit vector, so its squared
       length (the sum over k of cnj (U(k,l)) * U(k,l)) equals that of the unit vector, i.e. 1. *)
    have f1:"l. l<n  (k<n. cnj (U $$ (k, l)) * U $$ (k, l)) = 1"
    proof
      fix l assume a1:"l<n"
      define v::"complex vec" where d1:"v = unit_vec n l"
      have "col U l2 = (k<n. cnj (U $$ (k, l)) * U $$ (k, l))"
        using c0 a1 cpx_vec_length_inner_prod inner_prod_def lessThan_atLeast0 by simp
      moreover have "col U l2 = v2" using c0 d1 a1 assms(2) unit_vec_to_col by simp
      moreover have "v2 = 1" using d1 a1 cpx_vec_length_inner_prod by simp
      ultimately show "(k<n. cnj (U $$ (k, l)) * U $$ (k, l)) = 1" by simp
    qed
    (* Off-diagonal entries (i and j distinct). Length preservation is applied to two test
       vectors:
       - v1 = e_i + e_j yields <col U j|col U i> + <col U i|col U j> = 0 (fact f2);
       - v2 = e_i + 𝗂 e_j yields <col U j|col U i> - <col U i|col U j> = 0.
       Together these force the (i, j) entry of dagger(U) * U to be 0. *)
    moreover have "i  j  (k<n. cnj (U $$ (k, i)) * U $$ (k, j)) = 0"
    proof
      assume a2:"i  j"
      define v1::"complex vec" where d1:"v1 = unit_vec n i + 1 v unit_vec n j"
      define v2::"complex vec" where d2:"v2 = unit_vec n i + 𝗂 v unit_vec n j"
      (* First test vector: its squared length is 2, and so is that of its image. *)
      have "v12 = 1 + cnj 1 * 1" using d1 a0 a2 sum_of_unit_vec_length by blast
      then have "v12 = 2"
        by (metis complex_cnj_one cpx_vec_length_inner_prod mult.left_neutral of_real_eq_iff 
            of_real_numeral one_add_one)
      then have "U * |v12 = 2" using c0 d1 assms(2) unit_vec_to_col by simp
      moreover have "col U i + 1 v col U j = U * |v1"
        using c0 d1 a0 sum_of_unit_vec_to_col by blast
      moreover have "col U i + 1 v col U j = col U i + col U j" by simp
      ultimately have "col U i + col U j|col U i + col U j = 2"
        using cpx_vec_length_inner_prod by (metis of_real_numeral)
      moreover have "col U i + col U j|col U i + col U j = 
               col U i|col U i + col U j|col U i + col U i|col U j + col U j|col U j"
        using inner_prod_is_sesquilinear[of "col U i" "dim_row U" "col U j" "col U i" "col U j" "1" "1" "1" "1"]
        by simp
      (* The diagonal terms each equal 1 by f1, which leaves only the cross terms. *)
      ultimately have f2:"col U j|col U i + col U i|col U j = 0"
        using c0 a0 f1 inner_prod_def lessThan_atLeast0 by simp

      (* Second test vector: the imaginary coefficient makes the cross terms appear with
         opposite signs, -𝗂 and 𝗂. *)
      have "v22 = 1 + cnj 𝗂 * 𝗂" using a0 a2 d2 sum_of_unit_vec_length by simp
      then have "v22 = 2"
        by (metis Re_complex_of_real complex_norm_square mult.commute norm_ii numeral_Bit0 
            numeral_One numeral_eq_one_iff of_real_numeral one_power2)
      moreover have "U * |v22 = v22" using c0 d2 assms(2) unit_vec_to_col by simp
      moreover have "col U i + 𝗂 v col U j|col U i + 𝗂 v col U j = U * |v22"
        using c0 a0 d2 sum_of_unit_vec_to_col cpx_vec_length_inner_prod by auto
      moreover have "col U i + 𝗂 v col U j|col U i + 𝗂 v col U j = 
                     col U i|col U i + (-𝗂) * col U j|col U i + 𝗂 * col U i|col U j + col U j|col U j"
        using inner_prod_is_sesquilinear[of "col U i" "dim_row U" "col U j" "col U i" "col U j" "1" "𝗂" "1" "𝗂"]
        by simp
      ultimately have "col U j|col U i - col U i|col U j = 0"
        using c0 a0 f1 inner_prod_def lessThan_atLeast0 by auto
      (* Adding this to f2 shows that <col U i|col U j>, the (i, j) entry, is 0. *)
      then show "(k<n. cnj (U $$ (k, i)) * U $$ (k, j)) = 0"
        using c0 a0 f2 lessThan_atLeast0 inner_prod_def by auto
    qed
    (* Combine the diagonal and off-diagonal results with the definitions of the dagger
       and of the identity matrix. *)
    ultimately show "(U * U) $$ (i, j) = 1m (dim_col U) $$ (i, j)"
      using c0 assms(1) a0 one_mat_def dagger_def by auto
qed
  (* For square matrices a left inverse is also a right inverse (inverts_mat_sym), so
     U * dagger(U) is the identity as well, as the definition of unitary requires. *)
  then have "(U) * U = 1m n" using c0 by simp
  then have "inverts_mat (U) U" using c1 inverts_mat_def by auto
  then have "inverts_mat U (U)" using c0 c1 inverts_mat_sym by simp
  then have "U * (U) = 1m (dim_row U)" using c0 inverts_mat_def by auto
  then show ?thesis using f0 unitary_def by simp
qed

(* Unitary matrices preserve the inner product: applying U to both u and v leaves
   their inner product unchanged. Both vectors must have dimension dim_col U. *)
lemma inner_prod_with_unitary_mat [simp]:
  assumes "unitary U" and "dim_vec u = dim_col U" and "dim_vec v = dim_col U"
  shows "U * |u|U * |v = u|v"
proof -
  (* Express the inner product as the (0,0) entry of the 1x1 matrix
     bra(u) * dagger(U) * U * ket(v). *)
  have f1:"U * |u|U * |v = (|u| * (U) * U * |v) $$ (0,0)"
    using assms(2-3) bra_mat_on_vec mult_ket_vec_is_ket_vec_of_mult
    by (smt (verit, ccfv_threshold) assoc_mult_mat carrier_mat_triv col_fst_def dim_vec dim_col_of_dagger index_mult_mat(2) 
        index_mult_mat(3) inner_prod_with_times_mat ket_vec_def mat_carrier)
  (* Carrier facts needed to re-associate the matrix product. *)
  moreover have f2:"|u|  carrier_mat 1 (dim_vec v)"
    using bra_def ket_vec_def assms(2-3) by simp
  moreover have f3:"U  carrier_mat (dim_col U) (dim_row U)"
    using dagger_def by simp
  (* Group dagger(U) * U into a single factor. *)
  ultimately have "U * |u|U * |v = (|u| * (U * U) * |v) $$ (0,0)"
    using assms(3) assoc_mult_mat by (metis carrier_mat_triv)
  (* dagger(U) * U is the identity because U is unitary. *)
  also have " = (|u| * |v) $$ (0,0)"
    using assms(1) unitary_def
    by (simp add: assms(2) bra_def ket_vec_def)
  (* Read the (0,0) entry back as the inner product of u and v. *)
  finally show ?thesis
    using assms(2-3) inner_prod_with_times_mat by presburger
qed

text ‹As a consequence, we prove that the columns and rows of a unitary matrix are orthonormal vectors.›

(* Each column of a unitary matrix has length 1. Column i is the image of the i-th unit
   vector (unit_vec_to_col), and unitary matrices preserve length. *)
lemma unitary_unit_col [simp]:
  assumes "unitary U" and "dim_col U = n" and "i < n"
  shows "col U i = 1"
  using assms unit_vec_to_col unitary_is_length_preserving by simp

(* Each row of a unitary matrix has length 1. The proof reduces to the column case
   through the transpose, which is unitary again (transpose_of_unitary_is_unitary). *)
lemma unitary_unit_row [simp]:
  assumes "unitary U" and "dim_row U = n" and "i < n"
  shows "row U i = 1"
proof -
  (* Row i of U is column i of the transpose of U. *)
  have "row U i = col (Ut) i"
    using  assms(2-3) by simp
  thus ?thesis
    using assms transpose_of_unitary_is_unitary unitary_unit_col
    by (metis index_transpose_mat(3))
qed

(* Distinct columns of a unitary matrix are orthogonal. *)
lemma orthogonal_col_of_unitary [simp]:
  assumes "unitary U" and "dim_col U = n" and "i < n" and "j < n" and "i  j"
  shows "col U i|col U j = 0"
proof -
  (* Each column is the image of the corresponding unit vector. *)
  have "col U i|col U j = U * |unit_vec n i| U * |unit_vec n j"
    using assms(2-4) unit_vec_to_col by simp
  (* U preserves inner products. *)
  also have " = unit_vec n i |unit_vec n j"
    using assms(1-2) inner_prod_with_unitary_mat index_unit_vec(3) by simp
  (* Distinct unit vectors are orthogonal. *)
  finally show ?thesis
    using assms(3-5) by simp
qed

(* Distinct rows of a unitary matrix are orthogonal. The proof is the column case applied
   to the transpose, which is unitary again. *)
lemma orthogonal_row_of_unitary [simp]:
  fixes U::"complex mat"
  assumes "unitary U" and "dim_row U = n" and "i < n" and "j < n" and "i  j"
  shows "row U i|row U j = 0"
  using assms orthogonal_col_of_unitary transpose_of_unitary_is_unitary col_transpose
  by (metis index_transpose_mat(3))


text‹
As a consequence, we prove that a quantum gate acting on a state of a system of n qubits gives 
another state of that same system.
›

(* Applying an n-qubit gate to an n-qubit state yields an n-qubit state. The three goals
   below are 2^n rows, exactly one column, and a column 0 of length 1. *)
lemma gate_on_state_is_state [intro, simp]:
  assumes a1:"gate n A" and a2:"state n v"
  shows "state n (A * v)"
proof
  (* The row count comes from the gate. *)
  show "dim_row (A * v) = 2^n"
    using gate_def state_def a1 by simp
next
  (* The column count comes from the state. *)
  show "dim_col (A * v) = 1"
    using state_def a2 by simp
next
  (* Normalisation: A is square of size 2^n and unitary, so it preserves the length of
     column 0 of v, which has length 1. *)
  have "square_mat A"
    using a1 gate_def by simp
  then have "dim_col A = 2^n"
    using a1 gate.dim_row by simp
  then have "dim_col A = dim_row v"
    using a2 state.dim_row by simp
  then have "col (A * v) 0 = col v 0"
    using unitary_is_length_preserving_bis assms gate_def state_def by simp
  thus"col (A * v) 0 = 1"
    using a2 state.is_normal by simp
qed


subsection ‹A Few Well-known Quantum Gates›

text ‹
Any unitary operation on n qubits can be implemented exactly by composing single-qubit gates and
CNOT gates (controlled-NOT gates). However, no straightforward method is known to implement all of
these gates in a fashion that is resistant to errors. By contrast, the Hadamard gate, the phase gate,
the CNOT gate and the @{text "π/8"} gate are also universal for quantum computation, i.e. any quantum
circuit on n qubits can be approximated to arbitrary accuracy by using only these gates, and these
gates can be implemented in a fault-tolerant way.
›

text ‹We introduce a coercion from real matrices to complex matrices.›

(* Entrywise embedding of a real matrix into a complex matrix with the same numbers of rows
   and columns.
   NOTE(review): the conversion of each entry A $$ (i,j) to complex presumably relies on
   Isabelle's automatic real-to-complex coercion; confirm. *)
definition real_to_cpx_mat:: "real mat  complex mat" where
"real_to_cpx_mat A  mat (dim_row A) (dim_col A) (λ(i,j). A $$ (i,j))"

text ‹Our first quantum gate: the identity matrix! Arguably, not a very interesting one though!›

(* Identity gate on n qubits: the identity matrix of dimension 2^n. *)
definition Id :: "nat  complex mat" where
"Id n  1m (2^n)"

(* The identity is an n-qubit gate. The three goals are: 2^n rows, square, and unitary. *)
lemma id_is_gate [simp]:
  "gate n (Id n)"
proof
  show "dim_row (Id n) = 2^n"
    using Id_def by simp
next
  show "square_mat (Id n)"
    using Id_def by simp
next
  show "unitary (Id n)" 
    by (simp add: Id_def)
qed

text ‹More interesting: the Pauli matrices.›

(* Pauli X matrix (the quantum NOT gate): 2x2, with 0 on the diagonal and 1 off the diagonal. *)
definition X ::"complex mat" where
"X  mat 2 2 (λ(i,j). if i=j then 0 else 1)"

text‹
Be aware that @{text "gate n A"} means that the matrix A has dimension @{text "2^n * 2^n"}.
For instance, with this convention a 2-by-2 unitary matrix A satisfies @{text "gate 1 A"}
but not @{text "gate 2 A"}, as one might otherwise have expected.
›

(* X equals its own dagger (conjugate transpose), i.e. X is Hermitian. *)
lemma dagger_of_X [simp]:
  "X = X"
  using dagger_def by (simp add: X_def cong_mat)

(* X is its own inverse: X * X is the 2x2 identity. The proof expands the matrix
   product and compares the entries. *)
lemma X_inv [simp]:
  "X * X = 1m 2"
  apply(simp add: X_def times_mat_def one_mat_def)
  apply(rule cong_mat)
  by(auto simp: scalar_prod_def)

(* X is a one-qubit gate: a 2x2 unitary matrix. *)
lemma X_is_gate [simp]:
  "gate 1 X"
  by (simp add: gate_def unitary_def)
    (simp add: X_def)

(* Pauli Y matrix: 2x2, with 0 on the diagonal, minus the imaginary unit in entry (0,1)
   and the imaginary unit in entry (1,0). *)
definition Y ::"complex mat" where
"Y  mat 2 2 (λ(i,j). if i=j then 0 else (if i=0 then -𝗂 else 𝗂))"

(* Y equals its own dagger (conjugate transpose), i.e. Y is Hermitian. *)
lemma dagger_of_Y [simp]:
  "Y = Y"
  using dagger_def by (simp add: Y_def cong_mat)

(* Y is its own inverse: Y * Y is the 2x2 identity. The proof expands the matrix
   product and compares the entries. *)
lemma Y_inv [simp]:
  "Y * Y = 1m 2"
  apply(simp add: Y_def times_mat_def one_mat_def)
  apply(rule cong_mat)
  by(auto simp: scalar_prod_def)

(* Y is a one-qubit gate: a 2x2 unitary matrix. *)
lemma Y_is_gate [simp]:
  "gate 1 Y"
  by (simp add: gate_def unitary_def)
    (simp add: Y_def)

(* Pauli Z matrix: 2x2 diagonal, with entries 1 at (0,0) and -1 at (1,1). *)
definition Z ::"complex mat" where
"Z  mat 2 2 (λ(i,j). if ij then 0 else (if i=0 then 1 else -1))"

(* Z equals its own dagger (conjugate transpose), i.e. Z is Hermitian. *)
lemma dagger_of_Z [simp]:
  "Z = Z"
  using dagger_def by (simp add: Z_def cong_mat)

(* Z is its own inverse: Z * Z is the 2x2 identity. The proof expands the matrix
   product and compares the entries. *)
lemma Z_inv [simp]:
  "Z * Z = 1m 2"
  apply(simp add: Z_def times_mat_def one_mat_def)
  apply(rule cong_mat)
  by(auto simp: scalar_prod_def)

(* Z is a one-qubit gate: a 2x2 unitary matrix. *)
lemma Z_is_gate [simp]:
  "gate 1 Z"
  by (simp add: gate_def unitary_def)
    (simp add: Z_def)

text ‹The Hadamard gate›

(* Hadamard gate: the scalar 1/sqrt 2 times the 2x2 matrix with rows (1, 1) and (1, -1).
   The off-diagonal entries and entry (0,0) are 1; entry (1,1) is -1. *)
definition H ::"complex mat" where
"H  1/sqrt(2) m (mat 2 2 (λ(i,j). if ij then 1 else (if i=0 then 1 else -1)))"

(* H written entrywise, with the scalar 1/sqrt 2 pushed into every entry. *)
lemma H_without_scalar_prod:
  "H = mat 2 2 (λ(i,j). if ij then 1/sqrt(2) else (if i=0 then 1/sqrt(2) else -(1/sqrt(2))))"
  using cong_mat by (auto simp: H_def)

(* H equals its own dagger (conjugate transpose), i.e. H is Hermitian. *)
lemma dagger_of_H [simp]:
  "H = H"
  using dagger_def by (auto simp: H_def cong_mat)

(* H is its own inverse: H * H is the 2x2 identity. The proof expands the matrix product
   and compares the entries; complex_eqI handles equalities that involve sqrt 2. *)
lemma H_inv [simp]:
  "H * H = 1m 2"
  apply(simp add: H_def times_mat_def one_mat_def)
  apply(rule cong_mat)
  by(auto simp: scalar_prod_def complex_eqI)

(* H is a one-qubit gate: a 2x2 unitary matrix. *)
lemma H_is_gate [simp]:
  "gate 1 H"
  by (simp add: gate_def unitary_def)
    (simp add: H_def)

(* Every entry of H other than the bottom-right one, (1,1), equals 1/sqrt 2. *)
lemma H_values:
  fixes i j:: nat
  assumes "i < dim_row H" and "j < dim_col H" and "i  1  j  1" 
  shows "H $$ (i,j) = 1/sqrt 2"
proof-
  (* H is 2x2, so both indices are 0 or 1. *)
  have "i < 2"
    using assms(1) by (simp add: H_without_scalar_prod less_2_cases)
  moreover have "j < 2"
    using assms(2) by (simp add: H_without_scalar_prod less_2_cases)
  (* Case split on the possible index values, excluding (1,1) by assms(3). *)
  ultimately show ?thesis 
    using assms(3) H_without_scalar_prod by (smt (verit) One_nat_def index_mat(1) less_2_cases old.prod.case)
qed

(* The bottom-right entry (1,1) of H is -1/sqrt 2. *)
lemma H_values_right_bottom:
  fixes i j:: nat
  assumes "i = 1  j = 1"
  shows "H $$ (i,j) = - 1/sqrt 2"     
  using assms by (simp add: H_without_scalar_prod)

text ‹The controlled-NOT gate›

(* Controlled-NOT gate: the 4x4 permutation matrix that fixes basis indices 0 and 1 and
   swaps indices 2 and 3.
   NOTE(review): if the basis is ordered |00>, |01>, |10>, |11>, with the first qubit most
   significant, this flips the second qubit exactly when the first is 1. Confirm the file's
   qubit-ordering convention. *)
definition CNOT ::"complex mat" where
"CNOT  mat 4 4 
  (λ(i,j). if i=0  j=0 then 1 else 
    (if i=1  j=1 then 1 else 
      (if i=2  j=3 then 1 else 
        (if i=3  j=2 then 1 else 0))))"

(* CNOT equals its own dagger (conjugate transpose): it is a real symmetric matrix. *)
lemma dagger_of_CNOT [simp]:
  "CNOT = CNOT"
  using dagger_def by (simp add: CNOT_def cong_mat)

(* CNOT is its own inverse: CNOT * CNOT is the 4x4 identity. The proof expands the
   matrix product and compares the entries. *)
lemma CNOT_inv [simp]:
  "CNOT * CNOT = 1m 4"
  apply(simp add: CNOT_def times_mat_def one_mat_def)
  apply(rule cong_mat)
  by(auto simp: scalar_prod_def)

(* CNOT is a two-qubit gate: a 4x4 (= 2^2 x 2^2) unitary matrix. *)
lemma CNOT_is_gate [simp]:
  "gate 2 CNOT"
  by (simp add: gate_def unitary_def)
    (simp add: CNOT_def)

text ‹The phase gate, also known as the S-gate›

(* Phase gate S: the 2x2 diagonal matrix with entry 1 at (0,0), the imaginary unit at (1,1),
   and 0 elsewhere. *)
definition S ::"complex mat" where
"S  mat 2 2 (λ(i,j). if i=0  j=0 then 1 else (if i=1  j=1 then 𝗂 else 0))"

text ‹The @{text "π/8"} gate, also known as the T-gate›

(* The pi/8 gate T: the 2x2 diagonal matrix with entry 1 at (0,0), exp(i*pi/4) at (1,1),
   and 0 elsewhere. *)
definition T ::"complex mat" where
"T  mat 2 2 (λ(i,j). if i=0  j=0 then 1 else (if i=1  j=1 then exp(𝗂*(pi/4)) else 0))"

text ‹A few relations between the Hadamard gate and the Pauli matrices›

(* Conjugating X by the Hadamard gate gives Z. The proof expands both products and
   compares the entries. *)
lemma HXH_is_Z [simp]:
  "H * X * H = Z" 
  apply(simp add: X_def Z_def H_def times_mat_def)
  apply(rule cong_mat)
  by(auto simp add: scalar_prod_def complex_eqI)

(* Conjugating Y by the Hadamard gate gives -Y. The proof expands the products and
   proves the matrices equal entrywise with eq_matI. *)
lemma HYH_is_minusY [simp]:
  "H * Y * H = - Y" 
  apply(simp add: Y_def H_def times_mat_def)
  apply(rule eq_matI)
  by(auto simp add: scalar_prod_def complex_eqI)

(* Conjugating Z by the Hadamard gate gives X. The proof expands both products and
   compares the entries. *)
lemma HZH_is_X [simp]:
  shows "H * Z * H = X"  
  apply(simp add: X_def Z_def H_def times_mat_def)
  apply(rule cong_mat)
  by(auto simp add: scalar_prod_def complex_eqI)


subsection ‹The Bell States›

text ‹
We introduce below the so-called Bell states, also known as EPR pairs (EPR stands for Einstein,
Podolsky and Rosen).
›

(* First Bell state: 1/sqrt 2 times the column matrix (ket) of the length-4 vector with
   entry 1 at indices 0 and 3 and 0 elsewhere, i.e. (|00> + |11>)/sqrt 2. *)
definition bell00 ::"complex mat" (00) where
"bell00  1/sqrt(2) m |vec 4 (λi. if i=0  i=3 then 1 else 0)"

definition bell01 ::"complex mat" (01) where
"bell01  1/sqrt(2) m |vec 4 (λi. if i=1  i=2 then 1 else 0)