(* Theory Commuting_Hermitian *)

(*
Author: 
  Mnacho Echenim, Université Grenoble Alpes
*)

theory Commuting_Hermitian imports Spectral_Theory_Complements Commuting_Hermitian_Misc
"Projective_Measurements.Linear_Algebra_Complements" 
"Projective_Measurements.Projective_Measurements" begin 

section ‹Additional results on block decompositions of matrices›

subsection ‹Split block results›

(* For a square matrix D and a split point a ≤ n, the call
   split_block D a a yields an upper-left block D1 of dimension a × a and a
   lower-right block D4 of dimension (n-a) × (n-a). *)
lemma split_block_diag_carrier:
  assumes "D ∈ carrier_mat n n"
  and "a ≤ n"
  and "split_block D a a = (D1, D2, D3, D4)"
shows "D1 ∈ carrier_mat a a" "D4 ∈ carrier_mat (n-a) (n-a)"
proof -
  show "D1 ∈ carrier_mat a a" using assms unfolding split_block_def
    by (metis Pair_inject mat_carrier)
  show "D4 ∈ carrier_mat (n-a) (n-a)" using assms unfolding split_block_def
    by (metis Pair_inject carrier_matD(1) carrier_matD(2) mat_carrier)
qed

(* Splitting a square diagonal matrix D at (a, a) yields diagonal upper-left
   and lower-right blocks D1 and D4. *)
lemma split_block_diagonal:
  assumes "diagonal_mat D"
  and "D ∈ carrier_mat n n"
  and "a ≤ n"
  and "split_block D a a = (D1, D2, D3, D4)"
shows "diagonal_mat D1 ∧ diagonal_mat D4" unfolding diagonal_mat_def
proof (intro allI conjI impI)
  have "D1 ∈ carrier_mat a a" using assms unfolding split_block_def Let_def 
    by fastforce
  fix i j
  assume "i < dim_row D1"
  and "j < dim_col D1"
  and "i ≠ j"
  (* entries of D1 are entries of D at the same position *)
  have "D1 $$ (i,j) = D $$(i,j)" using assms unfolding split_block_def Let_def
    using ‹i < dim_row D1› ‹j < dim_col D1› by fastforce
  also have "... = 0" using assms ‹i ≠ j› ‹D1 ∈ carrier_mat a a› 
    ‹i < dim_row D1› ‹j < dim_col D1› unfolding diagonal_mat_def by fastforce
  finally show "D1 $$(i,j) = 0" .
next
  have "D4 ∈ carrier_mat (n-a) (n-a)" using assms 
    unfolding split_block_def Let_def by fastforce
  fix i j
  assume "i < dim_row D4"
  and "j < dim_col D4"
  and "i ≠ j"
  (* entries of D4 are entries of D with both indices shifted by a *)
  have "D4 $$ (i,j) = D $$(i + a,j + a)" using assms unfolding split_block_def Let_def
    using ‹i < dim_row D4› ‹j < dim_col D4› by fastforce
  also have "... = 0" using assms ‹i ≠ j› ‹D4 ∈ carrier_mat (n-a) (n-a)› 
    ‹i < dim_row D4› ‹j < dim_col D4› unfolding diagonal_mat_def by fastforce
  finally show "D4 $$(i,j) = 0" .
qed

(* When D is diagonal and both D and B are split at (a, a), the entries of the
   lower-right products B4 * D4 and D4 * B4 at (i, j) equal the entries of
   B * D and D * B at (i+a, j+a). *)
lemma split_block_times_diag_index:
  fixes B::"'a::comm_ring Matrix.mat"
  assumes "diagonal_mat D"
  and "D ∈ carrier_mat n n"
  and "B ∈ carrier_mat n n"
  and "a ≤ n"
  and "split_block B a a = (B1, B2, B3, B4)"
  and "split_block D a a = (D1, D2, D3, D4)"
  and "i < dim_row (D4 * B4)"
  and "j < dim_col (D4 * B4)"
shows "(B4 * D4) $$ (i, j) = (B*D) $$ (i+a, j+a)"
      "(D4 * B4) $$ (i, j) = (D*B) $$ (i+a, j+a)"
proof -
  have d4: "D4 ∈ carrier_mat (n-a) (n-a)" using assms  
      split_block(4)[of D] by simp
  have b4: "B4 ∈ carrier_mat (n-a) (n-a)" using assms  
      split_block(4)[of B] by simp
  have "diagonal_mat D4" using assms split_block_diagonal[of D] by blast
  have "i < n-a" using ‹i < dim_row (D4 * B4)› b4 d4 by simp
  have "j < n-a" using ‹j < dim_col (D4 * B4)› b4 d4 by simp
  (* right multiplication by the diagonal D4 scales column j by D4(j,j) *)
  have "(B4 * D4) $$ (i, j) = D4 $$ (j,j) * B4 $$ (i,j)" 
  proof (rule  diagonal_mat_mult_index') 
    show "diagonal_mat D4" using ‹diagonal_mat D4› .
    show "B4 ∈ carrier_mat (n-a) (n-a)" using b4 .
    show "D4 ∈ carrier_mat (n - a) (n - a)" using d4 .
    show "i < n-a" using ‹i < n-a› .
    show "j < n-a" using ‹j < n-a› .
  qed
  also have "... = D $$ (j+a, j+a) * B $$ (i+a, j+a)" 
    using assms ‹i < n-a› ‹j < n-a› 
    unfolding split_block_def Let_def by fastforce 
  also have "... = (B*D) $$ (i+a, j+a)" using diagonal_mat_mult_index' assms
    by (metis ‹i < n - a› ‹j < n - a› less_diff_conv)
  finally show "(B4 * D4) $$ (i, j) = (B*D) $$ (i+a, j+a)" .
  (* left multiplication by the diagonal D4 scales row i by D4(i,i) *)
  have "(D4 * B4) $$ (i, j) = D4 $$ (i,i) * B4 $$ (i,j)" 
    using diagonal_mat_mult_index ‹diagonal_mat D4› ‹i < n - a› ‹j < n - a› b4 d4 
    by blast
  also have "... = D $$ (i+a, i+a) * B $$ (i+a, j+a)" 
    using assms ‹i < n-a› ‹j < n-a› 
    unfolding split_block_def Let_def by fastforce
  also have "... = (D*B) $$ (i+a, j+a)" using diagonal_mat_mult_index assms
    by (metis ‹i < n - a› ‹j < n - a› less_diff_conv)
  finally show "(D4 * B4) $$ (i, j) = (D*B) $$ (i+a, j+a)" .
qed

(* If B commutes with the diagonal matrix D, then the lower-right blocks B4
   and D4 of their (a, a)-splits commute as well. *)
lemma split_block_commute_subblock:
  fixes B::"'a::comm_ring Matrix.mat"
  assumes "diagonal_mat D"
  and "D ∈ carrier_mat n n"
  and "B ∈ carrier_mat n n"
  and "a ≤ n"
  and "split_block B a a = (B1, B2, B3, B4)"
  and "split_block D a a = (D1, D2, D3, D4)"
  and "B * D = D * B"
shows "B4 * D4 = D4 * B4"
proof
  have d4: "D4 ∈ carrier_mat (n-a) (n-a)" using assms  
      split_block(4)[of D] by simp
  have b4: "B4 ∈ carrier_mat (n-a) (n-a)" using assms  
      split_block(4)[of B] by simp
  have "diagonal_mat D4" using assms split_block_diagonal[of D] by blast
  show "dim_row (B4 * D4) = dim_row (D4 * B4)" using d4 b4 by simp
  show "dim_col (B4 * D4) = dim_col (D4 * B4)" using d4 b4 by simp
  fix i j
  assume "i < dim_row (D4 * B4)"
  and "j < dim_col (D4 * B4)"
  (* move to the shifted entries of the full products, commute there, and
     move back *)
  have "(B4*D4) $$(i,j) = (B*D) $$(i+a, j+a)"
    using split_block_times_diag_index[of D n B a] assms
      ‹i < dim_row (D4 * B4)› ‹j < dim_col (D4 * B4)› by blast
  also have "... = (D*B) $$ (i+a, j+a)" using assms by simp
  also have "... = (D4*B4) $$ (i, j)"
    using split_block_times_diag_index[of D n B a] assms
    by (metis ‹i < dim_row (D4 * B4)› ‹j < dim_col (D4 * B4)›)
  finally show "(B4*D4) $$(i,j) = (D4*B4) $$ (i, j)" .
qed

(* If B commutes with the diagonal matrix D over a field and the diagonal
   entries D(i,i) and D(j,j) differ, then the entry B(i,j) is zero. *)
lemma commute_diag_mat_zero_comp:
  fixes D::"'a::{field} Matrix.mat"
  assumes "diagonal_mat D"
  and "D ∈ carrier_mat n n"
  and "B ∈ carrier_mat n n"
  and "B * D = D * B"
  and "i < n"
  and "j < n"
  and "D$$(i,i) ≠ D$$(j,j)"
shows "B $$(i,j) = 0"
proof -
  have "B$$(i,j) * D$$(j,j) = (B*D) $$(i,j)" 
    using diagonal_mat_mult_index'[of B n D] assms by simp
  also have "... = (D*B) $$ (i,j)" using assms by simp
  also have "... = B$$(i,j) * D$$(i,i)" 
    using diagonal_mat_mult_index  assms
    by (metis Groups.mult_ac(2))
  finally have "B$$(i,j) * D$$(j,j) = B$$(i,j) * D$$(i,i)" .
  hence "B$$(i,j) * (D$$(j,j) - D$$(i,i)) = 0" by auto
  (* the second factor is non-zero by the last assumption *)
  thus "B$$(i,j) = 0" using assms by simp
qed

(* Suppose B commutes with the diagonal matrix D, and no diagonal entry of D
   with index below k equals a diagonal entry with index in [k, n). Then the
   off-diagonal blocks B2 and B3 of the (k, k)-split of B are zero. *)
lemma commute_diag_mat_split_block:
  fixes D::"'a::{field} Matrix.mat"
  assumes "diagonal_mat D"
  and "D ∈ carrier_mat n n"
  and "B ∈ carrier_mat n n"
  and "B * D = D * B"
  and "k ≤ n"
  and "∀i j. (i < k ∧ k ≤ j ∧ j < n) ⟶ D$$(i,i) ≠ D$$(j,j)"
  and "(B1, B2, B3, B4) = split_block B k k"
shows "B2 = 0⇩m k (n-k)" "B3 = 0⇩m (n-k) k"
proof (intro eq_matI)
  show "dim_row B2 = dim_row (0⇩m k (n - k))" 
    using assms unfolding split_block_def Let_def by simp
  show "dim_col B2 = dim_col (0⇩m k (n - k))" 
    using assms unfolding split_block_def Let_def by simp
  fix i j
  assume "i < dim_row (0⇩m k (n - k))"
  and "j < dim_col (0⇩m k (n - k))" note ijprop = this
  (* B2 reads B with the column index shifted by k *)
  have "B2 $$ (i, j) = B $$ (i, j+k)" using assms ijprop 
    unfolding split_block_def Let_def by simp
  also have "... = 0" 
  proof (rule commute_diag_mat_zero_comp[of D n], (auto simp add: assms))
    show "i < n" using ijprop assms by simp
    show "j + k < n" using ijprop assms by simp
    show "D $$ (i, i) = D $$ (j + k, j + k) ⟹ False" using ijprop assms
      by (metis ‹j + k < n› index_zero_mat(2) le_add2)
  qed
  finally show "B2 $$ (i, j) = 0⇩m k (n - k) $$ (i, j)" using ijprop by simp
next 
  show "B3 = 0⇩m (n-k) k"
  proof (intro eq_matI)
    show "dim_row B3 = dim_row (0⇩m (n - k) k)" 
      using assms unfolding split_block_def Let_def by simp
    show "dim_col B3 = dim_col (0⇩m (n - k) k)" 
      using assms unfolding split_block_def Let_def by simp
    fix i j
    assume "i < dim_row (0⇩m (n - k) k)"
    and "j < dim_col (0⇩m (n - k) k)" note ijprop = this
    (* B3 reads B with the row index shifted by k *)
    have "B3 $$ (i, j) = B $$ (i+k, j)" using assms ijprop 
      unfolding split_block_def Let_def by simp
    also have "... = 0" 
    proof (rule commute_diag_mat_zero_comp[of D n], (auto simp add: assms))
      show "i + k < n" using ijprop assms by simp
      show "j < n" using ijprop assms by simp
      show "D $$ (i+k, i+k) = D $$ (j, j) ⟹ False" using ijprop assms
        by (metis ‹i + k < n› index_zero_mat(3) le_add2)
    qed
    finally show "B3 $$ (i, j) = 0⇩m (n - k) k $$ (i, j)" using ijprop by simp
  qed
qed

(* The upper-left block A1 of the (n, n)-split of a Hermitian matrix A, with
   n ≤ dim_row A, is Hermitian. *)
lemma split_block_hermitian_1:
  assumes "hermitian A"
  and "n ≤ dim_row A"
  and "(A1, A2, A3, A4) = split_block A n n"
shows "hermitian A1"  unfolding hermitian_def
proof (rule eq_matI, auto)
  have "dim_row A = dim_col A" using assms
    by (metis carrier_matD(2) hermitian_square) 
  show "dim_col A1 = dim_row A1" using assms unfolding split_block_def Let_def 
    by simp
  thus "dim_row A1 = dim_col A1" by simp
  show "⋀i j. i < dim_row A1 ⟹ j < dim_col A1 ⟹ 
    Complex_Matrix.adjoint A1 $$ (i, j) = A1 $$ (i, j)"
  proof -
    fix i j
    assume "i < dim_row A1" and "j < dim_col A1" note ij = this
    have r: "dim_row A1 = n" using assms unfolding split_block_def Let_def 
      by simp
    have c: "dim_col A1 = n" using assms unfolding split_block_def Let_def 
      by simp
    have "Complex_Matrix.adjoint A1 $$ (i, j) = conjugate (A1 $$ (j,i))"
      using ij r c unfolding Complex_Matrix.adjoint_def by simp
    also have "... = conjugate (A $$ (j,i))" using assms ij r c
      unfolding split_block_def Let_def by simp
    (* hermiticity of A transfers to the entries inside the block *)
    also have "... = A $$ (i,j)" using assms ij r c ‹dim_row A = dim_col A›
      unfolding hermitian_def Complex_Matrix.adjoint_def
      by (metis adjoint_eval assms(1) hermitian_def order_less_le_trans)
    also have "... = A1 $$(i,j)" using assms ij r c 
      unfolding split_block_def Let_def by simp
    finally show "Complex_Matrix.adjoint A1 $$ (i, j) = A1 $$ (i, j)" .
  qed
qed

(* The lower-right block A4 of the (n, n)-split of a Hermitian matrix A, with
   n ≤ dim_row A, is Hermitian. *)
lemma split_block_hermitian_4:
  assumes "hermitian A"
  and "n ≤ dim_row A"
  and "(A1, A2, A3, A4) = split_block A n n"
shows "hermitian A4"  unfolding hermitian_def
proof (rule eq_matI, auto)
  have arc: "dim_row A = dim_col A" using assms
    by (metis carrier_matD(2) hermitian_square) 
  thus "dim_col A4 = dim_row A4" using assms unfolding split_block_def Let_def 
    by simp
  thus "dim_row A4 = dim_col A4" by simp
  show "⋀i j. i < dim_row A4 ⟹ j < dim_col A4 ⟹ 
    Complex_Matrix.adjoint A4 $$ (i, j) = A4 $$ (i, j)"
  proof -
    fix i j
    assume "i < dim_row A4" and "j < dim_col A4" note ij = this
    have r: "dim_row A4 = dim_row A - n" using assms 
      unfolding split_block_def Let_def by simp
    have c: "dim_col A4 = dim_col A - n" using assms 
      unfolding split_block_def Let_def by simp
    have "Complex_Matrix.adjoint A4 $$ (i, j) = conjugate (A4 $$ (j,i))"
      using ij r c arc unfolding Complex_Matrix.adjoint_def by simp
    (* A4 reads A with both indices shifted by n *)
    also have "... = conjugate (A $$ (j +n ,i+n))" using assms ij r c arc
      unfolding split_block_def Let_def by simp
    also have "... = A $$ (i+n,j+n)" using assms ij r c arc 
      unfolding hermitian_def Complex_Matrix.adjoint_def
      by (metis index_mat(1) less_diff_conv split_conv)      
    also have "... = A4 $$(i,j)" using assms ij r c 
      unfolding split_block_def Let_def by simp
    finally show "Complex_Matrix.adjoint A4 $$ (i, j) = A4 $$ (i, j)" .
  qed
qed

(* A square matrix whose off-diagonal blocks vanish after a (k, k)-split, with
   k < n, equals the block-diagonal matrix built from its two diagonal
   blocks. *)
lemma diag_block_split_block:
  assumes "B ∈ carrier_mat n n"
  and "k < n"
  and "(B1, B2, B3, B4) = split_block B k k"
  and "B2 = 0⇩m k (n-k)" 
  and "B3 = 0⇩m (n-k) k"
shows "B = diag_block_mat [B1,B4]"
proof -
  have dr: "dim_row B = k + (n-k)" using assms by simp
  have dc: "dim_col B = k + (n-k)" using assms by simp
  have c1: "B1 ∈ carrier_mat k k" using assms 
    split_block(1)[of B, OF _ dr dc] by metis
  have c4: "B4 ∈ carrier_mat (n-k) (n-k)" using assms 
    split_block(4)[of B, OF _ dr dc] by metis
  have d4: "diag_block_mat [B4] = B4" using diag_block_mat_singleton[of B4] 
    by simp
  have "B = four_block_mat B1 B2 B3 B4" using assms split_block(3)[of B k ]
    by (metis carrier_matD(1) carrier_matD(2) diff_is_0_eq 
        le_add_diff_inverse nat_le_linear semiring_norm(137) 
        split_block(5) zero_less_diff) 
  also have "... = four_block_mat B1 (0⇩m k (n-k)) (0⇩m (n-k) k) B4" 
    using assms by simp
  (* rewrite B4 as a singleton block list so diag_block_mat.simps applies *)
  also have "... = four_block_mat B1 (0⇩m k (n-k)) (0⇩m (n-k) k) 
    (diag_block_mat [B4])" using d4 by simp
  also have "... = diag_block_mat [B1, B4]" 
    using diag_block_mat.simps(2)[of B1 "[B4]"] c1 c4 
    unfolding Let_def by auto
  finally show ?thesis .
qed

subsection ‹Diagonal block matrices›

(* Block-diagonal matrix with B1 in the upper-left and B2 in the lower-right
   corner; the two off-diagonal blocks are zero matrices whose dimensions are
   taken from B1 and B2. *)
abbreviation four_block_diag where
"four_block_diag B1 B2 ≡
  (four_block_mat B1 (0⇩m (dim_row B1) (dim_col B2)) 
  (0⇩m (dim_row B2) (dim_col B1)) B2)"

(* Equal four_block_diag matrices have equal first components, provided those
   components have the same dimensions. *)
lemma four_block_diag_cong_comp:
  assumes "dim_row A1 = dim_row B1"
  and "dim_col A1 = dim_col B1"
  and "four_block_diag A1 A2 = four_block_diag B1 B2"
shows "A1 = B1"
proof (rule eq_matI, auto simp: assms)
  define fA where "fA = four_block_diag A1 A2"
  define fB where "fB = four_block_diag B1 B2"
  fix i j
  assume bounds: "i < dim_row B1" "j < dim_col B1"
  (* the indices also lie inside A1 thanks to the dimension assumptions *)
  then have "i < dim_row A1" "j < dim_col A1" using assms by auto
  then have "A1 $$ (i,j) = fA $$ (i,j)"
    unfolding fA_def four_block_mat_def Let_def by force
  also have "... = fB $$ (i,j)" using assms unfolding fA_def fB_def by simp
  also have "... = B1 $$ (i,j)"
    using bounds unfolding fB_def four_block_mat_def Let_def by force
  finally show "A1 $$ (i,j) = B1 $$ (i,j)" .
qed

(* Equal four_block_diag matrices have equal second components, provided the
   first components have the same dimensions. *)
lemma four_block_diag_cong_comp':
  assumes "dim_row A1 = dim_row B1"
  and "dim_col A1 = dim_col B1"
  and "four_block_diag A1 A2 = four_block_diag B1 B2"
shows "A2 = B2"
proof (rule eq_matI)
  define n where "n=dim_row A1"
  define m where "m = dim_col A1"
  define A where "A = four_block_diag A1 A2"
  define B where "B = four_block_diag B1 B2"
  show "dim_row A2 = dim_row B2" 
    using assms unfolding four_block_mat_def Let_def
    by (metis assms(3) diff_add_inverse index_mat_four_block(2)) 
  show "dim_col A2 = dim_col B2"
    using assms unfolding four_block_mat_def Let_def
    by (metis assms(3) diff_add_inverse index_mat_four_block(3))
  fix i j
  assume "i < dim_row B2" and "j<dim_col B2" note ij=this
  (* (i, j) in the second component is (i+n, j+m) in the whole matrix *)
  hence "i+n < dim_row A" 
    unfolding A_def n_def m_def four_block_mat_def Let_def
    by (simp add: ‹dim_row A2 = dim_row B2›)
  have "j+m < dim_col A"
    unfolding A_def n_def m_def four_block_mat_def Let_def
    by (simp add: ‹dim_col A2 = dim_col B2› ij)
  (* NOTE: a raw block exports only its last fact, so ijeq is just
     "j + m - m = j" *)
  {
    have "n ≤ i+n" by simp
    have "m ≤ j+m" by simp
    have "i + n - n = i" by simp
    have "j + m - m = j" by simp
  } note ijeq = this
  have "A2$$(i,j) = A$$(i+n, j+m)" using ijeq
    using A_def ‹i + n < dim_row A› ‹j + m < dim_col A› m_def n_def by force 
  also have "... = B$$(i+n, j+m)" using assms unfolding A_def B_def by simp
  also have "... = B2$$(i,j)" using ijeq
    by (metis A_def B_def ‹i + n < dim_row A› ‹j + m < dim_col A› 
        add_implies_diff assms(1) assms(2) assms(3) index_mat_four_block(1) 
        index_mat_four_block(2) index_mat_four_block(3) m_def n_def 
        not_add_less2)
  finally show "A2$$(i,j) = B2$$(i,j)" .
qed

(* If B1 and B2 are square and have real diagonal entries, then every diagonal
   entry of four_block_diag B1 B2 is real. *)
lemma four_block_mat_real_diag:
  assumes "∀i < dim_row B1. B1$$(i,i) ∈ Reals"
  and "∀i < dim_row B2. B2$$(i,i) ∈ Reals"
  and "dim_row B1 = dim_col B1"
  and "dim_row B2 = dim_col B2"
  and "i < dim_row (four_block_diag B1 B2)"
shows "(four_block_diag B1 B2) $$ (i,i) ∈ Reals" 
proof (cases "i < dim_row B1")
  case True  
  (* the index lies in the B1 block *)
  then show ?thesis using assms  by simp
next
  case False
  (* the index lies in the B2 block *)
  then show ?thesis using assms by force
qed

(* The four_block_diag of two square diagonal matrices is diagonal. *)
lemma four_block_diagonal:
  assumes "dim_row B1 = dim_col B1"
  and "dim_row B2 = dim_col B2"
  and "diagonal_mat B1"
  and "diagonal_mat B2"
shows "diagonal_mat (four_block_diag B1 B2)" unfolding diagonal_mat_def 
proof (intro allI impI)
  fix i j
  assume "i < dim_row (four_block_diag B1 B2)"
  and "j < dim_col (four_block_diag B1 B2)"
  and "i ≠ j" note ijprops = this
  show "(four_block_diag B1 B2) $$ (i,j) = 0" 
  proof (cases "i < dim_row B1")
    case True
    (* row in the B1 band: entry is off-diagonal in B1 or in a zero block *)
    then show ?thesis 
      using assms(3) diagonal_mat_def ijprops(2) ijprops(3)
      by (metis add_less_imp_less_left  
          ijprops(1) index_mat_four_block(1) index_mat_four_block(2) 
          index_mat_four_block(3) index_zero_mat(1) 
          linordered_semidom_class.add_diff_inverse) 
  next
    case False
    (* row in the B2 band: entry is off-diagonal in B2 or in a zero block *)
    then show ?thesis using ijprops 
      by (metis (no_types, lifting) add_less_cancel_left assms(1) 
          assms(4) diagonal_mat_def index_mat_four_block(1) 
          index_mat_four_block(2) index_mat_four_block(3) 
          index_zero_mat(1) linordered_semidom_class.add_diff_inverse)
  qed
qed

(* Adding an empty 0 × 0 lower-right block leaves a matrix unchanged. *)
lemma four_block_diag_zero:
  assumes "B ∈ carrier_mat 0 0"
  shows "four_block_diag A B = A"
proof (rule eq_matI, auto)
  show "dim_row B = 0" using assms by simp
  show "dim_col B = 0" using assms by simp
qed

(* Adding an empty 0 × 0 upper-left block leaves a matrix unchanged. *)
lemma four_block_diag_zero':
  assumes "B ∈ carrier_mat 0 0"
  shows "four_block_diag B A = A"
proof (rule eq_matI)
  show "dim_row (four_block_diag B A) = dim_row A" using assms by simp
  show "dim_col (four_block_diag B A) = dim_col A" using assms by simp
  fix i j
  assume "i < dim_row A" and "j < dim_col A"
  thus "four_block_diag B A $$ (i, j) = A $$ (i, j)"
    using ‹dim_col (four_block_diag B A) = dim_col A› 
      ‹dim_row (four_block_diag B A) = dim_row A› 
  by auto
qed

(* The product of two four_block_diag matrices with compatible component
   dimensions is the four_block_diag of the componentwise products. *)
lemma mult_four_block_diag:
  assumes "A1 ∈ carrier_mat nr1 n1" "D1 ∈ carrier_mat nr2 n2" 
  and "A2 ∈ carrier_mat n1 nc1" "D2 ∈ carrier_mat n2 nc2"
shows "four_block_diag A1 D1 * 
  four_block_diag A2  D2
  = four_block_diag (A1 * A2) (D1 * D2)" 
proof -
  define fb1 where "fb1 = four_block_mat A1 (0⇩m nr1 n2) (0⇩m nr2 n1) D1"
  define fb2 where "fb2 = four_block_mat A2 (0⇩m n1 nc2) (0⇩m n2 nc1) D2"
  (* generic block multiplication *)
  have "fb1 * fb2 = four_block_mat (A1 * A2 + 0⇩m nr1 n2 * 0⇩m n2 nc1) 
    (A1 * 0⇩m n1 nc2 + 0⇩m nr1 n2 * D2) (0⇩m nr2 n1 * A2 + D1 * 0⇩m n2 nc1)
    (0⇩m nr2 n1 * 0⇩m n1 nc2 + D1 * D2)" unfolding fb1_def fb2_def
  proof (rule mult_four_block_mat)
    show "A1 ∈ carrier_mat nr1 n1" using assms by simp
    show "D1 ∈ carrier_mat nr2 n2" using assms by simp
    show "A2 ∈ carrier_mat n1 nc1" "D2 ∈ carrier_mat n2 nc2" using assms by auto
  qed auto  
  (* all products involving a zero block vanish *)
  also have "... = four_block_mat (A1 * A2) (0⇩m nr1 nc2) (0⇩m nr2 nc1) (D1 * D2)" 
    using assms by simp
  finally show ?thesis unfolding fb1_def fb2_def  
    using assms by simp
qed

(* Taking the adjoint commutes with the four_block_diag construction: the
   adjoint of a block-diagonal matrix is the block-diagonal matrix of the
   adjoints of its components. *)
lemma four_block_diag_adjoint:
  shows "(Complex_Matrix.adjoint (four_block_diag A1 A2)) = 
    (four_block_diag (Complex_Matrix.adjoint A1) 
    (Complex_Matrix.adjoint A2))" 
  by (rule eq_matI) 
    (auto simp: four_block_mat_adjoint zero_adjoint adjoint_eval)

(* The four_block_diag of two unitary matrices is unitary. *)
lemma four_block_diag_unitary:
  assumes "unitary U1"
  and "unitary U2"
shows "unitary
  (four_block_diag U1 U2)"
(is "unitary ?fU")
  unfolding unitary_def
proof
  show "?fU ∈ carrier_mat (dim_row ?fU) (dim_row ?fU)" 
    by (metis Complex_Matrix.unitary_def assms(1) assms(2) 
        four_block_carrier_mat index_mat_four_block(2))
  show "inverts_mat ?fU (Complex_Matrix.adjoint ?fU)"
  proof -
    (* the adjoint of ?fU is block diagonal with the adjoints of U1, U2 *)
    have "(Complex_Matrix.adjoint ?fU) = 
      (four_block_mat (Complex_Matrix.adjoint U1) 
      (0⇩m (dim_col U1) (dim_row U2)) 
      (0⇩m (dim_col U2) (dim_row U1)) 
      (Complex_Matrix.adjoint U2))" 
      by (rule eq_matI, 
          auto simp: four_block_mat_adjoint zero_adjoint adjoint_eval)
    hence "?fU * (Complex_Matrix.adjoint ?fU) = 
      ?fU * (four_block_diag (Complex_Matrix.adjoint U1) 
      (Complex_Matrix.adjoint U2))"  by simp
    also have "... = four_block_diag
      (U1 * (Complex_Matrix.adjoint U1))
      (U2 * (Complex_Matrix.adjoint U2))"
      by (rule mult_four_block_diag, (auto simp add: assms))
    (* each diagonal block is an identity since U1 and U2 are unitary *)
    also have "... = four_block_mat
      (1⇩m (dim_row U1))
      (0⇩m (dim_row U1) (dim_row U2))
      (0⇩m (dim_row U2) (dim_row U1))
      (1⇩m (dim_row U2))" using assms 
      unfolding unitary_def inverts_mat_def 
      by simp
    also have "... = 1⇩m (dim_row U1 + dim_row U2)" by simp
    finally show ?thesis unfolding inverts_mat_def  by simp
  qed
qed

(* If the square matrices A1 and A2 are unitarily equivalent to B1 and B2 via
   U1 and U2, then four_block_diag A1 A2 is similar to four_block_diag B1 B2,
   with witnesses four_block_diag U1 U2 and its adjoint. *)
lemma four_block_diag_similar:
  assumes "unitarily_equiv A1 B1 U1"
  and "unitarily_equiv A2 B2 U2"
  and "dim_row A1 = dim_col A1"
  and "dim_row A2 = dim_col A2"
shows "similar_mat_wit 
  (four_block_diag A1 A2)
  (four_block_diag B1 B2)
  (four_block_diag U1 U2)
  (Complex_Matrix.adjoint (four_block_diag U1 U2))"
  unfolding similar_mat_wit_def
proof (simp add: Let_def, intro conjI)
  define n where "n = dim_row A1 + dim_row A2"
  show "four_block_diag A1 A2 ∈ carrier_mat n n" unfolding n_def using assms 
    by auto
  show "four_block_diag B1 B2 ∈ carrier_mat n n" unfolding n_def using assms
    by (metis carrier_matI four_block_carrier_mat unitarily_equiv_carrier(1))
  show u: "four_block_diag U1 U2 ∈ carrier_mat n n" unfolding n_def using assms
    by (metis carrier_matI four_block_carrier_mat unitarily_equiv_carrier(2))
  thus cu: "Complex_Matrix.adjoint (four_block_diag U1 U2) ∈ carrier_mat n n" 
    unfolding n_def using adjoint_dim' by blast
  show "four_block_diag U1 U2*Complex_Matrix.adjoint (four_block_diag U1 U2) =
    1⇩m n" unfolding n_def
    using u assms four_block_diag_unitary n_def 
      unitarily_equiv_def unitary_simps(2) by blast
  (* a right inverse of a square matrix is also a left inverse *)
  thus "Complex_Matrix.adjoint (four_block_diag U1 U2)*four_block_diag U1 U2 = 
    1⇩m n"
    using cu mat_mult_left_right_inverse u by blast 
  have "four_block_diag A1 A2 = 
    four_block_diag (U1 * B1 * (Complex_Matrix.adjoint U1))
    (U2 * B2 * (Complex_Matrix.adjoint U2))"
    using assms unitarily_equiv_eq by blast
  also have "... = (four_block_diag (U1*B1) (U2*B2)) *
    (four_block_diag (Complex_Matrix.adjoint U1)
    (Complex_Matrix.adjoint U2))"  
  proof (rule mult_four_block_diag[symmetric])
    show "U1 * B1 ∈ carrier_mat (dim_row A1) (dim_row A1)"
      by (metis assms(1) assms(3) carrier_mat_triv mult_carrier_mat 
          unitarily_equiv_carrier(1) unitarily_equiv_carrier(2))
    show "U2 * B2 ∈ carrier_mat (dim_row A2) (dim_row A2)"
      by (metis assms(2) assms(4) carrier_mat_triv mult_carrier_mat 
          unitarily_equiv_carrier(1) unitarily_equiv_carrier(2))
    show "Complex_Matrix.adjoint U1 ∈ carrier_mat (dim_row A1) (dim_row A1)"
      by (metis Complex_Matrix.unitary_def adjoint_dim assms(1) 
          index_mult_mat(2) unitarily_equivD(1) unitarily_equiv_eq)
    show "Complex_Matrix.adjoint U2 ∈ carrier_mat (dim_row A2) (dim_row A2)"
      by (meson assms(2) carrier_mat_triv similar_mat_witD2(7) 
          unitarily_equiv_def)
  qed
  also have "... = four_block_diag U1 U2 * four_block_diag B1 B2 * 
    Complex_Matrix.adjoint (four_block_diag U1 U2)"
  proof -
    have "four_block_diag (U1*B1) (U2*B2) = 
      four_block_diag U1 U2 * four_block_diag B1 B2" 
    proof (rule mult_four_block_diag[symmetric])
      show "U1 ∈ carrier_mat (dim_row A1) (dim_row A1)"
        by (metis assms(1) assms(3) carrier_mat_triv 
            unitarily_equiv_carrier(2))
      show "B1 ∈ carrier_mat (dim_row A1) (dim_row A1)"
        by (metis assms(1) assms(3) carrier_mat_triv 
            unitarily_equiv_carrier(1))
      show "U2 ∈ carrier_mat (dim_row A2) (dim_row A2)"
        by (metis assms(2) assms(4) carrier_mat_triv 
            unitarily_equiv_carrier(2))
      show "B2 ∈ carrier_mat (dim_row A2) (dim_row A2)"
        by (metis assms(2) assms(4) carrier_mat_triv 
            unitarily_equiv_carrier(1))
    qed
    moreover have "four_block_diag (Complex_Matrix.adjoint U1)
     (Complex_Matrix.adjoint U2) = 
      Complex_Matrix.adjoint (four_block_diag U1 U2)" 
      by (rule four_block_diag_adjoint[symmetric])
    ultimately show ?thesis by simp
  qed
  finally show "four_block_diag A1 A2 = 
    four_block_diag U1 U2 * four_block_diag B1 B2 * 
    Complex_Matrix.adjoint (four_block_diag U1 U2)" .
qed

(* Unitary equivalence of square matrices is preserved by four_block_diag:
   the block-diagonal unitary built from U1 and U2 relates the block-diagonal
   combinations of the A's and the B's. *)
lemma four_block_unitarily_equiv:
  assumes "unitarily_equiv A1 B1 U1"
  and "unitarily_equiv A2 B2 U2"
  and "dim_row A1 = dim_col A1"
  and "dim_row A2 = dim_col A2"
shows "unitarily_equiv 
  (four_block_diag A1 A2)
  (four_block_diag B1 B2)
  (four_block_diag U1 U2)"
(is "unitarily_equiv ?fA ?fB ?fU")
  unfolding unitarily_equiv_def
proof (rule conjI)
  show "similar_mat_wit ?fA ?fB ?fU (Complex_Matrix.adjoint ?fU)" 
    using assms four_block_diag_similar[of A1] by simp
  show "unitary ?fU"
    using four_block_diag_unitary assms unitarily_equivD(1) by blast
qed

(* Unitary diagonalizations of two square matrices combine into a unitary
   diagonalization of their four_block_diag. *)
lemma four_block_unitary_diag:
  assumes "unitary_diag A1 B1 U1"
  and "unitary_diag A2 B2 U2"
  and "dim_row A1 = dim_col A1"
  and "dim_row A2 = dim_col A2"
shows "unitary_diag 
  (four_block_diag A1 A2)
  (four_block_diag B1 B2)
  (four_block_diag U1 U2)"
(is "unitary_diag ?fA ?fB ?fU")
  unfolding unitary_diag_def
proof (rule conjI)
  show "unitarily_equiv ?fA ?fB ?fU" 
    using four_block_unitarily_equiv[of A1] assms by simp
  (* the diagonal components B1 and B2 are square *)
  have sqB1: "dim_row B1 = dim_col B1" unfolding unitary_diag_def
    by (metis assms(1) assms(3) carrier_matD(1) carrier_matD(2) 
          carrier_mat_triv unitary_diag_carrier(1))
  have sqB2: "dim_row B2 = dim_col B2" unfolding unitary_diag_def
    by (metis assms(2) assms(4) carrier_matD(1) carrier_matD(2) 
        carrier_mat_triv unitary_diag_carrier(1))
  show "diagonal_mat ?fB" using sqB1 sqB2 four_block_diagonal assms 
    unfolding unitary_diag_def by blast
qed

(* Real diagonal decompositions of two square matrices combine into a real
   diagonal decomposition of their four_block_diag. *)
lemma four_block_real_diag_decomp:
  assumes "real_diag_decomp A1 B1 U1"
  and "real_diag_decomp A2 B2 U2"
  and "dim_row A1 = dim_col A1"
  and "dim_row A2 = dim_col A2"
shows "real_diag_decomp 
  (four_block_diag A1 A2)
  (four_block_diag B1 B2)
  (four_block_diag U1 U2)"
(is "real_diag_decomp ?fA ?fB ?fU")
  unfolding real_diag_decomp_def
proof (intro conjI allI impI)
  show "unitary_diag ?fA ?fB ?fU" using four_block_unitary_diag assms 
    unfolding real_diag_decomp_def by blast
  fix i
  assume "i < dim_row ?fB" 
  show "?fB $$ (i,i) ∈ Reals" 
  proof (rule four_block_mat_real_diag)
    show "i < dim_row ?fB" using ‹i < dim_row ?fB› .
    show "∀i<dim_row B1. B1 $$ (i, i) ∈ ℝ" using assms 
      unfolding real_diag_decomp_def by simp
    show "∀i<dim_row B2. B2 $$ (i, i) ∈ ℝ" using assms 
      unfolding real_diag_decomp_def by simp
    show "dim_row B1 = dim_col B1"  unfolding unitary_diag_def
      by (metis assms(1) assms(3) carrier_matD(1) carrier_matD(2) 
          carrier_mat_triv real_diag_decompD(1) unitary_diag_carrier(1))
    show "dim_row B2 = dim_col B2"  unfolding unitary_diag_def
      by (metis assms(2) assms(4) carrier_matD(1) carrier_matD(2) 
          carrier_mat_triv real_diag_decompD(1) unitary_diag_carrier(1))
  qed
qed

(* The product of two diag_block_mat matrices, given lists of equal length
   whose corresponding blocks are multiplication-compatible, is the
   diag_block_mat of the pairwise products of the blocks. *)
lemma diag_block_mat_mult:
  assumes "length Al = length Bl"
  and "∀i < length Al. dim_col (Al!i) = dim_row (Bl!i)"
shows "diag_block_mat Al * (diag_block_mat Bl) = 
  (diag_block_mat (map2 (*) Al Bl))" using assms
proof (induct Al arbitrary: Bl)
  case Nil
  then show ?case by simp
next
  case (Cons a Al)
  define A where "A = diag_block_mat Al"
  define B where "B = diag_block_mat (tl Bl)"
  have "0 < length Bl" using Cons by auto
  hence "Bl = hd Bl # (tl Bl)" by simp
  have "length (tl Bl) = length Al" using Cons by simp
  (* the tails satisfy the dimension hypothesis, enabling the induction *)
  have dim: "∀i<length Al. dim_col (Al ! i) = dim_row (tl Bl ! i)"
  proof (intro allI impI)
    fix i
    assume "i < length Al"
    hence "dim_col (Al ! i) = dim_col ((a#Al)!(Suc i))" by simp
    also have "... = dim_row (Bl!(Suc i))" using Cons
      by (metis Suc_lessI ‹i < length Al› length_Cons less_Suc_eq)
    also have "... = dim_row (tl Bl!i)"
      by (metis ‹Bl = hd Bl # tl Bl› nth_Cons_Suc) 
    finally show "dim_col (Al ! i) = dim_row (tl Bl!i)" .
  qed
  define C where "C = map2 (*) (a # Al) Bl"
  have "hd C = a * hd Bl" using ‹Bl = hd Bl # tl Bl› unfolding C_def
    by (metis list.map(2) list.sel(1) prod.simps(2) zip_Cons_Cons)
  have "tl C = map2 (*) Al (tl Bl)"
    by (metis (no_types, lifting) C_def ‹Bl = hd Bl # tl Bl› list.sel(3) 
        map_tl zip_Cons_Cons)
  have "C = hd C # (tl C)" unfolding C_def
    by (metis Nil_eq_zip_iff Nil_is_map_conv ‹Bl = hd Bl # tl Bl› 
        list.exhaust_sel list.simps(3))
  (* the remaining blocks are multiplication-compatible as a whole *)
  have "dim_row B = sum_list (map dim_row (tl Bl))" unfolding B_def
    by (simp add: dim_diag_block_mat(1))
  also have "... = sum_list (map dim_col Al)" 
  proof (rule sum_list_cong)
    show "length (map dim_row (tl Bl)) = length (map dim_col Al)"  
      using  ‹length (tl Bl) = length Al› by simp
    show "∀i<length (map dim_row (tl Bl)). 
      map dim_row (tl Bl) ! i = map dim_col Al ! i"
      by (metis ‹length (tl Bl) = length Al› dim length_map nth_map)
  qed
  also have "... = dim_col A" unfolding A_def
    by (simp add: dim_diag_block_mat(2))
  finally have ba: "dim_row B = dim_col A" .  
  have "diag_block_mat (a#Al) * (diag_block_mat Bl) = 
    four_block_diag a A * (four_block_diag (hd Bl) B)" 
    using diag_block_mat.simps(2) ‹Bl = hd Bl # (tl Bl)› 
    unfolding Let_def A_def B_def by metis
  also have "... = four_block_diag (a * hd Bl) (A * B)"     
  proof (rule mult_four_block_diag)
    show "a ∈ carrier_mat (dim_row a) (dim_col a)" by simp 
    show "hd Bl ∈ carrier_mat (dim_col a) (dim_col (hd Bl))"
      using Cons
      by (metis ‹0 < length Bl› ‹Bl = hd Bl # tl Bl› carrier_mat_triv nth_Cons_0)
    show "A ∈ carrier_mat (dim_row A) (dim_col A)" by simp
    show "B ∈ carrier_mat (dim_col A) (dim_col B)" using ba by auto
  qed
  (* the induction hypothesis handles the product of the tails *)
  also have "... = four_block_diag (hd C) (diag_block_mat (tl C))" 
    unfolding A_def B_def 
    using C_def ‹hd C = a * hd Bl› ‹length (tl Bl) = length Al› 
      ‹tl C = map2 (*) Al (tl Bl)› dim local.Cons(1) 
    by presburger
  also have "... = diag_block_mat C" 
    using ‹C = hd C#(tl C)› diag_block_mat.simps(2) unfolding Let_def by metis
  finally show ?case unfolding C_def .
qed

(* A non-empty list Al of Hermitian matrices, each with a positive number of
   rows, admits lists Ul (unitary) and Bl whose i-th entries have the
   dimensions of Al!i. Moreover, diag_block_mat Bl and diag_block_mat Ul form
   a real diagonal decomposition of diag_block_mat Al. *)
lemma real_diag_decomp_block:
  fixes Al::"complex Matrix.mat list"
  assumes "Al ≠ []"
  and "list_all (λA. 0 < dim_row A ∧ hermitian A)  Al"
shows "∃Bl Ul. length Ul = length Al ∧
  (∀i < length Al. 
    Ul!i ∈ carrier_mat (dim_row (Al!i)) (dim_col (Al!i)) ∧ unitary (Ul!i) ∧
    Bl!i ∈ carrier_mat (dim_row (Al!i)) (dim_col (Al!i))) ∧
  real_diag_decomp (diag_block_mat Al) (diag_block_mat Bl) (diag_block_mat Ul)"
  using assms
proof (induct Al)
  case Nil
  then show ?case by simp
next
  case (Cons A Al)
  hence "hermitian A" "0 < dim_row A" by auto
  hence "A ∈ carrier_mat (dim_row A) (dim_row A)"
    by (simp add: hermitian_square)
  (* decompose the head block *)
  from this obtain B U where r: "real_diag_decomp A B U" 
    using hermitian_real_diag_decomp ‹hermitian A› ‹0 < dim_row A› by blast
  have bcar: "B ∈ carrier_mat (dim_row A) (dim_col A)" 
    using real_diag_decompD(1)
    by (metis ‹A ∈ carrier_mat (dim_row A) (dim_row A)› carrier_matD(2) r 
        unitary_diag_carrier(1))
  have ucar: "U ∈ carrier_mat (dim_row A) (dim_col A)" 
    using real_diag_decompD(1)
    by (metis ‹A ∈ carrier_mat (dim_row A) (dim_row A)› carrier_matD(2) r 
        unitary_diag_carrier(2))
  have unit: "unitary U"
    by (meson r real_diag_decompD(1) unitary_diagD(3))
  show ?case
  proof (cases "Al = []")
    case True
    (* singleton list: the witnesses are [B] and [U] *)
    hence "diag_block_mat (Cons A Al) = A" by auto
    moreover have "diag_block_mat [B] = B" by auto
    moreover have "diag_block_mat [U] = U" by auto
    moreover have "unitary U"
      using r real_diag_decompD(1) unitary_diagD(3) by blast
    ultimately have 
      "real_diag_decomp (diag_block_mat (Cons A Al)) 
        (diag_block_mat [B]) (diag_block_mat [U])"
      using ‹real_diag_decomp A B U› by auto
    moreover have "(∀i<length (A # Al).
      [U]!i ∈ carrier_mat (dim_row ((A # Al) ! i)) (dim_col ((A # Al) ! i)) ∧
      Complex_Matrix.unitary ([U] ! i) ∧ [B] ! i ∈ 
      carrier_mat (dim_row ((A # Al) ! i)) (dim_col ((A # Al) ! i)))" using True
      by (simp add: bcar ucar unit)
    ultimately show ?thesis 
      using True ‹Complex_Matrix.unitary U› bcar less_one ucar
      by (metis length_list_update list_update_code(2))
  next
    case False
    (* non-empty tail: apply the induction hypothesis and prepend B and U *)
    have "list_all (λA. 0 < dim_row A ∧ hermitian A) Al" using Cons by auto
    hence "∃Bl Ul. length Ul = length Al ∧
       (∀i<length Al.
           Ul ! i ∈ carrier_mat (dim_row (Al ! i)) (dim_col (Al ! i)) ∧ 
           unitary (Ul!i) ∧ 
            Bl ! i ∈ carrier_mat (dim_row (Al ! i)) (dim_col (Al ! i))) ∧
       real_diag_decomp (diag_block_mat Al) (diag_block_mat Bl) (diag_block_mat Ul)"
      using Cons False by simp 
    from this obtain Bl Ul where "length Ul = length Al" and  
      rl: "real_diag_decomp (diag_block_mat Al) 
      (diag_block_mat Bl) (diag_block_mat Ul)"
      and "∀i<length Al.
           Ul ! i ∈ carrier_mat (dim_row (Al ! i)) (dim_col (Al ! i)) ∧ 
            unitary (Ul!i) ∧ 
            Bl ! i ∈ carrier_mat (dim_row (Al ! i)) (dim_col (Al ! i))"
      by auto note bu = this
    have "real_diag_decomp (diag_block_mat (A # Al)) 
      (diag_block_mat (B # Bl)) (diag_block_mat (U # Ul))" 
      using four_block_real_diag_decomp[OF r rl] 
      by (metis ‹A ∈ carrier_mat (dim_row A) (dim_row A)› 
          carrier_matD(2) diag_block_mat.simps(2) hermitian_square 
          real_diag_decomp_hermitian rl)
    moreover have "length (U#Ul) = length (A#Al)" using bu by simp
    moreover have "∀i<length (A # Al).
           (U#Ul) ! i ∈ carrier_mat (dim_row ((A # Al) ! i)) (dim_col ((A # Al) ! i)) ∧
           unitary ((U#Ul)!i) ∧
           (B#Bl) ! i ∈ carrier_mat (dim_row ((A # Al) ! i)) (dim_col ((A # Al) ! i))" 
    proof (intro allI impI)
      fix i
      assume "i < length (A#Al)"
      show "(U # Ul) ! i ∈ carrier_mat (dim_row ((A # Al) ! i)) 
        (dim_col ((A # Al) ! i)) ∧ unitary ((U#Ul)!i) ∧
        (B # Bl) ! i ∈ carrier_mat (dim_row ((A # Al) ! i)) 
        (dim_col ((A # Al) ! i))"
      proof (cases "i = 0")
        case True
        then show ?thesis by (simp add: bcar ucar unit) 
      next
        case False
        hence "∃j. i = Suc j" by (simp add: not0_implies_Suc)
        from this obtain j where j: "i = Suc j" by auto
        hence "j < length Al" using ‹i < length (A#Al)› by simp
        have "(A#Al)!i = Al!j" "(U # Ul) ! i = Ul!j" "(B#Bl) ! i = Bl!j" 
          using j by auto
        then show ?thesis using Cons ‹j < length Al› bu(3) by presburger
      qed
    qed
    ultimately show ?thesis by blast
  qed
qed

(* The adjoint of a block-diagonal matrix is the block-diagonal matrix of the
   adjoints of its blocks.  Induction on the block list: the empty case is
   closed with zero_adjoint; in the Cons case diag_block_mat (a # Al) is
   rewritten as four_block_diag a (diag_block_mat Al), the adjoint is pushed
   inside with four_block_diag_adjoint, and the induction hypothesis handles
   the tail. *)
lemma diag_block_mat_adjoint:
  shows "Complex_Matrix.adjoint (diag_block_mat Al) =
    diag_block_mat (map Complex_Matrix.adjoint Al)"
proof (induct Al)
  case Nil
  then show ?case using zero_adjoint by simp
next
  case (Cons a Al)
  have "Complex_Matrix.adjoint (diag_block_mat (a # Al)) =
    Complex_Matrix.adjoint (four_block_diag a (diag_block_mat Al))" 
    using diag_block_mat.simps(2)[of a] unfolding Let_def by simp
  also have "... = four_block_diag (Complex_Matrix.adjoint a)
    (Complex_Matrix.adjoint (diag_block_mat Al))" 
    using four_block_diag_adjoint[of a] by simp
  also have "... = four_block_diag (Complex_Matrix.adjoint a)
    (diag_block_mat (map Complex_Matrix.adjoint Al))" using Cons by simp
  also have "... = diag_block_mat (map Complex_Matrix.adjoint (a#Al))" 
    using diag_block_mat.simps(2) unfolding Let_def 
    by (metis (no_types) diag_block_mat.simps(2) list.map(2))
  finally show ?case .
qed

(* Conjugating a block-diagonal matrix by another one can be done block-wise:
   mat_conj (diag_block_mat Al) (diag_block_mat Bl) is the block-diagonal
   matrix whose i-th block is mat_conj (Al!i) (Bl!i).
   Hypotheses: the lists have the same length, the column count of each Al!i
   equals the row count of Bl!i, and every Bl!i is square.  The proof unfolds
   mat_conj_def, turns the adjoint of diag_block_mat Al into the block list of
   adjoints (diag_block_mat_adjoint), and performs both products block-wise
   with diag_block_mat_mult. *)
lemma diag_block_mat_mat_conj:
  assumes "length Al = length Bl"
  and "\<forall>i < length Al. dim_col (Al!i) = dim_row (Bl!i)"
  and "\<forall>i < length Al. dim_row (Bl!i) = dim_col (Bl!i)"
  shows "mat_conj (diag_block_mat Al) (diag_block_mat Bl) =
    diag_block_mat (map2 mat_conj Al Bl)"
proof -
  have "mat_conj (diag_block_mat Al) (diag_block_mat Bl) =
    diag_block_mat Al * diag_block_mat Bl * 
    diag_block_mat (map Complex_Matrix.adjoint Al)" 
    using diag_block_mat_adjoint[of Al] unfolding mat_conj_def by simp
  (* first product, computed block by block *)
  also have "... = diag_block_mat (map2 (*) Al Bl) * 
    diag_block_mat (map Complex_Matrix.adjoint Al)" 
    using diag_block_mat_mult[OF assms(1) assms(2)] by simp
  (* second product; the dimension side condition needs the squareness of
     each Bl!i, since the row count of an adjoint block is the column count
     of the original block *)
  also have "... = diag_block_mat (map2 (*) (map2 (*) Al Bl)
    (map Complex_Matrix.adjoint Al))"
  proof (rule diag_block_mat_mult)
    show "length (map2 (*) Al Bl) = length (map Complex_Matrix.adjoint Al)"
      by (simp add: assms(1))
    show "\<forall>i<length (map2 (*) Al Bl). dim_col (map2 (*) Al Bl ! i) = 
      dim_row (map Complex_Matrix.adjoint Al ! i)"
      by (simp add: assms(2) assms(3))
  qed
  also have "... = diag_block_mat (map2 mat_conj Al Bl)" 
    using map2_mat_conj_exp[OF assms(1)] by simp
  finally show ?thesis .
qed

(* Block-diagonal matrices built from pairwise commuting blocks commute.
   Hypotheses: equal list lengths, Al!i and Bl!i commute for every index i,
   and corresponding blocks have compatible dimensions for both products
   Al!i * Bl!i and Bl!i * Al!i.  Both sides are computed block-wise with
   diag_block_mat_mult, and the two lists of block products are identified
   with map2_commute. *)
lemma diag_block_mat_commute:
  assumes "length Al = length Bl"
  and "\<forall>i < length Al. Al!i * (Bl!i) = Bl!i * (Al!i)"
  and "\<forall>i<length Al. dim_col (Al ! i) = dim_row (Bl ! i)"
  and "\<forall>i<length Al. dim_col (Bl ! i) = dim_row (Al ! i)"
shows "diag_block_mat Al * (diag_block_mat Bl) = 
  diag_block_mat Bl * (diag_block_mat Al)"
proof -
  have "diag_block_mat Al * diag_block_mat Bl =
    diag_block_mat (map2 (*) Al Bl)" 
    using diag_block_mat_mult[of Al Bl] assms by simp
  also have "... = diag_block_mat (map2 (*) Bl Al)" 
  proof -
    (* the block products agree index by index by assms(2) *)
    have "map2 (*) Al Bl = map2 (*) Bl Al"     
      by (rule map2_commute, auto simp add: assms)
    thus ?thesis by simp
  qed
  also have "... = diag_block_mat Bl * (diag_block_mat Al)"
    using diag_block_mat_mult[of Bl Al] assms by simp
  finally show ?thesis .
qed

(* A block list of length one yields its single block: the list is first
   rewritten as [Al!0], after which diag_block_mat_singleton applies. *)
lemma diag_block_mat_length_1:
  assumes "length Al = 1"
  shows "diag_block_mat Al = Al!0" 
proof -
  have "Al = [Al!0]" using assms
    by (metis One_nat_def length_0_conv length_Suc_conv nth_Cons_0)
  thus ?thesis
    by (metis diag_block_mat_singleton) 
qed

(* Two non-empty block lists of the same length that produce the same
   block-diagonal matrix, and whose first blocks have the same dimensions,
   have equal first blocks.  Both matrices are unfolded as
   four_block_diag (hd _) (diag_block_mat (tl _)) and compared with
   four_block_diag_cong_comp. *)
lemma diag_block_mat_cong_hd:
  assumes "0 < length Al"
  and "length Al = length Bl"
  and "dim_row (hd Al) = dim_row (hd Bl)"
  and "dim_col (hd Al) = dim_col (hd Bl)"
  and "diag_block_mat Al = diag_block_mat Bl"
shows "hd Al = hd Bl" 
proof -
  have "Al \<noteq> []" using assms by blast
  hence "Al = hd Al#(tl Al)" by simp
  hence da:"diag_block_mat Al = 
    four_block_diag (hd Al) (diag_block_mat (tl Al))"
    using diag_block_mat.simps(2)[of "hd Al" "tl Al"] unfolding Let_def by simp
  have  "Bl \<noteq> []" using assms by simp
  hence "Bl = hd Bl#(tl Bl)" by simp
  hence "diag_block_mat Bl = four_block_diag (hd Bl) (diag_block_mat (tl Bl))"
    using diag_block_mat.simps(2)[of "hd Bl" "tl Bl"] unfolding Let_def by simp
  hence "four_block_diag (hd Al) (diag_block_mat (tl Al)) = 
    four_block_diag (hd Bl) (diag_block_mat (tl Bl))" using da assms by simp
  thus ?thesis using four_block_diag_cong_comp assms by metis
qed

(* Companion of diag_block_mat_cong_hd: under the same hypotheses, the
   block-diagonal matrices built from the tails of the two lists coincide.
   Both matrices are unfolded as four_block_diag (hd _) (diag_block_mat (tl _))
   and compared with four_block_diag_cong_comp'. *)
lemma diag_block_mat_cong_tl:
  assumes "0 < length Al"
  and "length Al = length Bl"
  and "dim_row (hd Al) = dim_row (hd Bl)"
  and "dim_col (hd Al) = dim_col (hd Bl)"
  and "diag_block_mat Al = diag_block_mat Bl"
shows "diag_block_mat (tl Al) = diag_block_mat (tl Bl)" 
proof -
  have "Al \<noteq> []" using assms by blast
  hence "Al = hd Al#(tl Al)" by simp
  hence da:"diag_block_mat Al = 
    four_block_diag (hd Al) (diag_block_mat (tl Al))"
    using diag_block_mat.simps(2)[of "hd Al" "tl Al"] unfolding Let_def by simp
  have  "Bl \<noteq> []" using assms by simp
  hence "Bl = hd Bl#(tl Bl)" by simp
  hence "diag_block_mat Bl = four_block_diag (hd Bl) (diag_block_mat (tl Bl))"
    using diag_block_mat.simps(2)[of "hd Bl" "tl Bl"] unfolding Let_def by simp
  hence "four_block_diag (hd Al) (diag_block_mat (tl Al)) = 
    four_block_diag (hd Bl) (diag_block_mat (tl Bl))" using da assms by simp
  thus ?thesis using four_block_diag_cong_comp' assms by metis
qed

(* Block-wise injectivity of diag_block_mat: two lists of equal length whose
   corresponding blocks have the same dimensions and which produce the same
   block-diagonal matrix agree at every valid index j.
   Induction on Al, generalising over Bl and j: index 0 is handled by
   diag_block_mat_cong_hd, and a successor index Suc k is reduced to index k
   of the tails, whose block-diagonal matrices agree by
   diag_block_mat_cong_tl. *)
lemma diag_block_mat_cong_comp:
  assumes "length Al = length Bl"
  and "\<forall>i<length Al. dim_row (Al ! i) = dim_row (Bl ! i)"
  and "\<forall>i<length Al. dim_col (Al ! i) = dim_col (Bl ! i)"
  and "diag_block_mat Al = diag_block_mat Bl"
and "j < length Al"
shows "Al!j = Bl!j" using assms
proof (induct Al arbitrary: Bl j)
  case Nil
  then show ?case by simp
next
  case (Cons a Al)
  hence "0 <length Bl" by linarith
  hence "Bl = hd Bl#(tl Bl)" by simp
  then show ?case 
  proof (cases "j = 0")
    case True
    (* first block: compare the heads *)
    hence "(a#Al)!j = hd(a#Al)" by simp
    have "Bl!j= hd Bl" using \<open>j = 0\<close>
      by (metis \<open>Bl = hd Bl # tl Bl\<close> nth_Cons_0)
    have da: "diag_block_mat (a#Al) = four_block_diag a (diag_block_mat Al)"
      using diag_block_mat.simps(2)[of a Al] unfolding Let_def by simp
    have db: "diag_block_mat (hd Bl#(tl Bl)) = 
      four_block_diag (hd Bl) (diag_block_mat (tl Bl))"
      using diag_block_mat.simps(2)[of "hd Bl" "tl Bl"] 
      unfolding Let_def by simp
    have "hd (a#Al) = hd Bl"
    proof (rule diag_block_mat_cong_hd)
      show "0 < length (a # Al)" by simp
      show "length (a # Al) = length Bl" using Cons by simp
      show "diag_block_mat (a # Al) = diag_block_mat Bl" using Cons by simp
      show "dim_row (hd (a # Al)) = dim_row (hd Bl)"
        by (metis True \<open>0 < length Bl\<close> \<open>Bl ! j = hd Bl\<close> list.sel(1) Cons(2) 
            Cons(3) nth_Cons_0)
      show "dim_col (hd (a # Al)) = dim_col (hd Bl)"
        by (metis True \<open>0 < length Bl\<close> \<open>Bl ! j = hd Bl\<close> list.sel(1) Cons(2) 
            Cons(4) nth_Cons_0)
    qed
    thus "(a # Al) ! j = Bl ! j" using \<open>j = 0\<close> \<open>Bl ! j = hd Bl\<close> by fastforce
  next
    case False
    (* later block: shift the index by one and recurse on the tails *)
    hence "\<exists>k. j = Suc k" by (simp add: not0_implies_Suc) 
    from this obtain k where "j = Suc k" by auto
    hence "(a#Al)!j = Al!k" by simp
    have "Bl!j = (tl Bl)!k" using \<open>j = Suc k\<close> \<open>Bl = hd Bl#(tl Bl)\<close>
      by (metis nth_Cons_Suc)
    have "Al!k = (tl Bl)!k"
    proof (rule Cons(1))
      show "length Al = length (tl Bl)" using Cons
        by (metis diff_Suc_1 length_Cons length_tl)
      show "k < length Al"
        by (metis Cons.prems(5) Suc_less_SucD \<open>j = Suc k\<close> length_Cons) 
      show "\<forall>i<length Al. dim_row (Al ! i) = dim_row (tl Bl ! i)"
        by (metis Suc_less_eq \<open>length Al = length (tl Bl)\<close> length_Cons 
            local.Cons(3) nth_Cons_Suc nth_tl)
      show "\<forall>i<length Al. dim_col (Al ! i) = dim_col (tl Bl ! i)"
        by (metis Suc_mono \<open>Bl = hd Bl # tl Bl\<close> length_Cons local.Cons(4) 
            nth_Cons_Suc)
      have "diag_block_mat (tl (a#Al)) = diag_block_mat (tl Bl)"
      proof (rule diag_block_mat_cong_tl)
        show "length (a # Al) = length Bl" using Cons by simp
        show "dim_row (hd (a # Al)) = dim_row (hd Bl)"
          by (metis \<open>Bl = hd Bl # tl Bl\<close> length_Cons list.sel(1) local.Cons(3) 
              nth_Cons_0 zero_less_Suc)
        show "dim_col (hd (a # Al)) = dim_col (hd Bl)"
          by (metis \<open>0 < length Bl\<close> \<open>Bl = hd Bl # tl Bl\<close> list.sel(1) 
              local.Cons(2) local.Cons(4) nth_Cons_0)
        show "diag_block_mat (a # Al) = diag_block_mat Bl" using Cons by simp
        show "0 < length (a#Al)" by simp
      qed
      thus "diag_block_mat Al = diag_block_mat (tl Bl)" by simp
    qed
    then show ?thesis
      by (simp add: \<open>(a # Al) ! j = Al ! k\<close> \<open>Bl ! j = tl Bl ! k\<close>) 
  qed
qed

(* Converse of diag_block_mat_commute: if two block-diagonal matrices commute,
   and their lists have equal length, square blocks in Al, and corresponding
   blocks of equal dimensions, then each pair of corresponding blocks
   commutes.  Both products are rewritten block-wise with diag_block_mat_mult,
   and the resulting equality of block-diagonal matrices is transferred to
   index i with diag_block_mat_cong_comp. *)
lemma diag_block_mat_commute_comp:
  assumes "length Al = length Bl"
  and "\<forall>i<length Al. dim_row (Al ! i) = dim_col (Al ! i)"
  and "\<forall>i<length Al. dim_row (Al ! i) = dim_row (Bl ! i)"
  and "\<forall>i<length Al. dim_col (Al ! i) = dim_col (Bl ! i)"
  and "diag_block_mat Al * (diag_block_mat Bl) = 
    diag_block_mat Bl * (diag_block_mat Al)"
  and "i < length Al"
shows "Al!i * Bl!i = Bl!i * Al!i" 
proof -
  have "diag_block_mat (map2 (*) Al Bl)=diag_block_mat Al * diag_block_mat Bl"
    using diag_block_mat_mult[of Al] assms by simp
  also have "... = diag_block_mat Bl * diag_block_mat Al" using assms by simp
  also have "... = diag_block_mat (map2 (*) Bl Al)" 
    using diag_block_mat_mult[of Bl] assms by simp
  finally have eq: "diag_block_mat (map2 (*) Al Bl) = 
    diag_block_mat (map2 (*) Bl Al)" .
  (* the two lists of block products have matching block dimensions, so they
     agree at index i *)
  have "(map2 (*) Al Bl)!i = (map2 (*) Bl Al)!i" 
  proof (rule diag_block_mat_cong_comp) 
    show "length (map2 (*) Al Bl) = length (map2 (*) Bl Al)" 
      using map2_length assms by metis
    show "i < length (map2 (*) Al Bl)" using map2_length assms by metis 
    show "diag_block_mat (map2 (*) Al Bl) = diag_block_mat (map2 (*) Bl Al)"
      using eq .
    show "\<forall>i<length (map2 (*) Al Bl). dim_row (map2 (*) Al Bl ! i) = 
      dim_row (map2 (*) Bl Al ! i)"
      by (simp add: assms(3))
    show "\<forall>i<length (map2 (*) Al Bl). dim_col (map2 (*) Al Bl ! i) = 
      dim_col (map2 (*) Bl Al ! i)"
      by (simp add: assms(4))
  qed
  moreover have "(map2 (*) Al Bl)!i = Al!i * Bl!i" using assms by simp
  moreover have "(map2 (*) Bl Al)!i = Bl!i * Al!i" using assms by simp
  ultimately show ?thesis by simp
qed

(* Two block lists of the same length whose corresponding blocks have the
   same number of rows produce block-diagonal matrices with the same number
   of rows: by dim_diag_block_mat(1), the row count of diag_block_mat is the
   sum of the row counts of the blocks. *)
lemma diag_block_mat_dim_row_cong:
  assumes "length Ul = length Bl"
  and "\<forall>i < length Bl. dim_row (Bl!i) = dim_row (Ul!i)"
  shows "dim_row (diag_block_mat Ul) = dim_row (diag_block_mat Bl)"
proof -
  have "dim_row (diag_block_mat Ul) = sum_list (map dim_row Ul)" 
    by (simp add: dim_diag_block_mat(1))
  also have "... = sum_list (map dim_row Bl)" using assms 
    by (metis nth_map_conv)
  also have "... = dim_row (diag_block_mat Bl)"
    by (simp add: dim_diag_block_mat(1))
  finally show ?thesis .
qed

(* Column analogue of diag_block_mat_dim_row_cong: pointwise equal column
   counts give equal column counts of the block-diagonal matrices, since by
   dim_diag_block_mat(2) the column count of diag_block_mat is the sum of the
   column counts of the blocks. *)
lemma diag_block_mat_dim_col_cong:
  assumes "length Ul = length Bl"
  and "\<forall>i < length Bl. dim_col (Bl!i) = dim_col (Ul!i)"
  shows "dim_col (diag_block_mat Ul) = dim_col (diag_block_mat Bl)"
proof -
  have "dim_col (diag_block_mat Ul) = sum_list (map dim_col Ul)" 
    by (simp add: dim_diag_block_mat(2))
  also have "... = sum_list (map dim_col Bl)" using assms 
    by (metis nth_map_conv)
  also have "... = dim_col (diag_block_mat Bl)"
    by (simp add: dim_diag_block_mat(2))
  finally show ?thesis .
qed

(* A block-diagonal matrix built from square blocks is square: its row and
   column counts are the sums of the row and column counts of the blocks
   (dim_diag_block_mat), and these sums agree block by block. *)
lemma diag_block_mat_dim_row_col_eq:
  assumes "\<forall>i < length Al. dim_row (Al!i) = dim_col (Al!i)"
  shows "dim_row (diag_block_mat Al) = dim_col (diag_block_mat Al)"
proof -
  have "dim_row (diag_block_mat Al) = sum_list (map dim_row Al)"
    by (simp add:dim_diag_block_mat(1))
  also have "... = sum_list (map dim_col Al)" using assms
    by (metis nth_map_conv)
  also have "... = dim_col (diag_block_mat Al)"
    by (simp add:dim_diag_block_mat(2))
  finally show ?thesis .
qed

section ‹Block matrix decomposition›

subsection ‹Subdiagonal extraction›
text ‹\verb+extract_subdiags B l+ returns the list of sub-blocks lying along the diagonal of
\verb+B+, whose sizes are given, in order, by the list of natural numbers \verb+l+.›

(* extract_subdiags B (x # xs) splits B with split_block B x x, keeps the
   top-left component B1, and recurses on the bottom-right component B4 with
   the remaining sizes xs; an empty size list yields the empty list.  The
   let-form of the second equation is relied upon by the proofs below, which
   unfold it with Let_def. *)
fun extract_subdiags where
  "extract_subdiags B [] = []"
| "extract_subdiags B (x#xs) = 
    (let (B1, B2, B3, B4) = (split_block B x x) in 
      B1 # (extract_subdiags B4 xs))"

(* Unfolding rules for a non-empty size list: if split_block B x x yields
   (B1, B2, B3, B4), then extract_subdiags B (x # l) has head B1 and tail
   extract_subdiags B4 l. *)
lemma extract_subdiags_not_emp:
  fixes x::nat and l::"nat list"
  assumes "(B1, B2, B3, B4) = (split_block B x x)"
  shows "hd (extract_subdiags B (x#l)) = B1" 
    "tl (extract_subdiags B (x#l)) = extract_subdiags B4 l" 
proof -
  show "hd (extract_subdiags B (x#l)) = B1" unfolding  Let_def 
    by (metis (no_types) assms extract_subdiags.simps(2) list.sel(1) split_conv) 
  show "tl (extract_subdiags B (x # l)) = extract_subdiags B4 l" 
    using assms extract_subdiags.simps(2) unfolding Let_def
    by (metis (no_types, lifting) list.sel(3) split_conv)
qed

(* A non-empty size list always produces a non-empty list of blocks: the
   second defining equation of extract_subdiags conses a block in front. *)
lemma extract_subdiags_neq_Nil:
  shows "extract_subdiags B (a#l) \<noteq> []" 
  using extract_subdiags.simps(2)[of B] 
  unfolding Let_def split_block_def by simp

(* extract_subdiags returns exactly one block per entry of the size list.
   Induction on l, generalising over B; in the Cons case the four components
   of split_block B a a are named so that the let-binding in the defining
   equation can be simplified away. *)
lemma extract_subdiags_length:
  shows "length (extract_subdiags B l) = length l"
proof (induct l arbitrary: B)
  case Nil
  then show ?case by simp
next
  case (Cons a l)
  define B1 where "B1 = fst (split_block B a a)"
  define B2 where "B2 = fst (snd (split_block B a a))"
  define B3 where "B3 = fst (snd (snd (split_block B a a)))"
  define B4 where "B4 = snd (snd (snd (split_block B a a)))"
  have sp: "split_block B a a = (B1, B2, B3, B4)" using fst_conv snd_conv 
    unfolding B1_def B2_def B3_def B4_def by simp
  then show ?case using Cons extract_subdiags.simps(2)[of B a l] 
    unfolding Let_def by simp 
qed

(* The i-th block returned by extract_subdiags B l is an (l!i) x (l!i)
   matrix.  No assumption on the dimensions of B is needed: the top-left
   component of split_block B a a is shown to lie in carrier_mat a a directly
   from split_block_def.  Induction on i, generalising over l and B; the
   successor case moves to the tail of l and the bottom-right component B4. *)
lemma extract_subdiags_carrier:
  assumes "i < length l"
  shows "(extract_subdiags B l)!i \<in> carrier_mat (l!i) (l!i)" using assms
proof (induct i arbitrary: l B)  
  case 0
  define B1 where "B1 = fst (split_block B (hd l) (hd l))"
  define B2 where "B2 = fst (snd (split_block B (hd l) (hd l)))"
  define B3 where "B3 = fst (snd (snd (split_block B (hd l) (hd l))))"
  define B4 where "B4 = snd (snd (snd (split_block B (hd l) (hd l))))"
  have sp: "split_block B (hd l) (hd l) = (B1, B2, B3, B4)" using fst_conv snd_conv 
    unfolding B1_def B2_def B3_def B4_def by simp
  have "l = hd l # (tl l)" using 0 by auto
  have "(extract_subdiags B l)!0 = B1" 
    using extract_subdiags.simps(2)[of B "hd l" "tl l"] \<open>l = hd l # tl l\<close> sp
    unfolding Let_def by auto
  also have "... \<in> carrier_mat (hd l) (hd l)" 
    unfolding B1_def split_block_def Let_def by simp
  finally show ?case
    by (metis \<open>l = hd l # tl l\<close> hd_conv_nth list.sel(2) not_Cons_self) 
next
  case (Suc i)  
  define B1 where "B1 = fst (split_block B (hd l) (hd l))"
  define B2 where "B2 = fst (snd (split_block B (hd l) (hd l)))"
  define B3 where "B3 = fst (snd (snd (split_block B (hd l) (hd l))))"
  define B4 where "B4 = snd (snd (snd (split_block B (hd l) (hd l))))"
  have sp: "split_block B (hd l) (hd l) = (B1, B2, B3, B4)" using fst_conv snd_conv 
    unfolding B1_def B2_def B3_def B4_def by simp
  have "l = hd l # (tl l)" using Suc
    by (metis Cons_nth_drop_Suc drop_Nil list.exhaust_sel not_Cons_self)
  hence "l! Suc i = (tl l)!i" by (metis nth_Cons_Suc)
  have "tl (extract_subdiags B l) = extract_subdiags B4 (tl l)" 
    using extract_subdiags_not_emp(2)[OF sp[symmetric]] \<open>l = hd l # (tl l)\<close> 
    by metis
  hence "extract_subdiags B l = B1 # extract_subdiags B4 (tl l)" 
    using extract_subdiags_not_emp(1)[OF sp[symmetric]]
    by (metis \<open>l = hd l # tl l\<close> extract_subdiags_neq_Nil list.exhaust_sel)
  hence "extract_subdiags B l ! Suc i = (extract_subdiags B4 (tl l))!i" 
    using nth_Cons_Suc by simp
  (* induction hypothesis, applied to B4 and the tail of l *)
  also have "... \<in> carrier_mat (tl l!i) (tl l!i)" using Suc
    by (metis \<open>l = hd l # tl l\<close> length_Cons not_less_eq)
  also have "... = carrier_mat (l!Suc i) (l! Suc i)" 
    using nth_Cons_Suc[of "hd l" "tl l" i] \<open>l = hd l # tl l\<close> by simp
  finally show ?case .
qed

(* If B is a diagonal n x n matrix and the sizes in the non-empty list l sum
   to at most n, then every block returned by extract_subdiags B l is
   diagonal.  Induction on i, generalising over l, B and n: the first block
   is the top-left component of split_block B a a (diagonal by
   split_block_diagonal), and later blocks come from the bottom-right
   component B4, which is a diagonal (n - a) x (n - a) matrix. *)
lemma extract_subdiags_diagonal:
  assumes "diagonal_mat B"
  and "B \<in> carrier_mat n n"
  and "l \<noteq> []"
  and "sum_list l \<le> n"
  and "i < length l"
shows "diagonal_mat ((extract_subdiags B l)!i)" using assms
proof (induct i arbitrary: l B n)
  case 0
  define a where "a = hd l"
  have "l = a#(tl l)" unfolding a_def using 0 by simp
  (* the first size is bounded by the total size, hence by n *)
  have "a \<le> n" using 0 unfolding a_def
    by (metis a_def dual_order.strict_trans2 elem_le_sum_list 
        hd_conv_nth less_le_not_le nat_le_linear)
  define B1 where "B1 = fst (split_block B a a)"
  define B2 where "B2 = fst (snd (split_block B a a))"
  define B3 where "B3 = fst (snd (snd (split_block B a a)))"
  define B4 where "B4 = snd (snd (snd (split_block B a a)))"
  have sp: "split_block B a a = (B1, B2, B3, B4)" using fst_conv snd_conv 
    unfolding B1_def B2_def B3_def B4_def by simp
  hence "extract_subdiags B l!0 = B1" unfolding a_def 
    using hd_conv_nth 0 
    by (metis \<open>l = a # tl l\<close> sp extract_subdiags_neq_Nil 
        extract_subdiags_not_emp(1))
  moreover have "diagonal_mat B1" using sp split_block_diagonal assms \<open>a \<le> n\<close> 0
    by blast
  ultimately show ?case by simp
next
  case (Suc i)
  show ?case
  proof (cases "length l = 1")
    case True
    (* impossible: Suc i < length l = 1 *)
    hence "Suc i = 0" using Suc by presburger
    then show ?thesis by simp
  next
    case False
    define a where "a = hd l"
    have "l = a#(tl l)" unfolding a_def using Suc by simp
    have "a \<le> n" using Suc unfolding a_def
      by (metis dual_order.trans elem_le_sum_list hd_conv_nth 
          length_greater_0_conv)
    define B1 where "B1 = fst (split_block B a a)"
    define B2 where "B2 = fst (snd (split_block B a a))"
    define B3 where "B3 = fst (snd (snd (split_block B a a)))"
    define B4 where "B4 = snd (snd (snd (split_block B a a)))"
    have sp: "split_block B a a = (B1, B2, B3, B4)" using fst_conv snd_conv 
      unfolding B1_def B2_def B3_def B4_def by simp
    have "extract_subdiags B l ! Suc i = 
      extract_subdiags B4 (tl l)! i"  using sp
      by (metis Suc(6) Suc_less_SucD \<open>l = a # tl l\<close> length_Cons nth_tl 
          extract_subdiags_length extract_subdiags_not_emp(2))
    (* induction hypothesis, applied to B4, n - a and the tail of l *)
    moreover have "diagonal_mat (extract_subdiags B4 (tl l)! i)"
    proof (rule Suc(1))
      show "tl l \<noteq> []" using False Suc
        by (metis \<open>l = a # tl l\<close> length_Cons list.size(3) numeral_nat(7)) 
      show "i < length (tl l)" using False Suc
        by (metis Suc_lessD \<open>l = a # tl l\<close> le_neq_implies_less length_Cons 
            less_Suc_eq_le)
      show "B4 \<in> carrier_mat (n-a) (n-a)" 
        using sp split_block_diag_carrier(2) Suc(3) \<open>a \<le> n\<close> by blast 
      show "diagonal_mat B4" 
        using split_block_diagonal sp Suc \<open>a \<le> n\<close> by blast
      show "sum_list (tl l) \<le> n - a" using Suc(5) \<open>a \<le> n\<close> sum_list_tl_leq
        by (simp add: Suc(4) a_def)
    qed
    ultimately show ?thesis by simp
  qed
qed

lemma extract_subdiags_diag_elem:
  fixes B::"complex Matrix.mat"
  assumes "B carrier_mat n n"
  and "0 < n"
  and "l  []"
  and "i < length l"
  and "j< l!i"
  and "sum_list l  n"
  and "j < length l. 0 < l!j"
  shows "extract_subdiags B l!i $$ (j,j) = 
    diag_mat B!(n_sum i l + j)" using assms
proof (induct i arbitrary: l B n)
  case 0
  define a where "a = hd l"
  have "l = a#(tl l)" unfolding a_def using 0 by simp
  have "a  n" using 0 unfolding a_def
    by (metis a_def dual_order.strict_trans2 elem_le_sum_list 
        hd_conv_nth less_le_not_le nat_le_linear)
  define B1 where "B1 = fst (split_block B a a)"
  define B2 where "B2 = fst (snd (split_block B a a))"
  define B3 where "B3 = fst (snd (snd (split_block B a a)))"
  define B4 where "B4 = snd (snd