(* Theory SQ_MTX *)

(*  Title:       Square Matrices
    Author:      Jonathan Julián Huerta y Munive, 2020
    Maintainer:  Jonathan Julián Huerta y Munive <jonjulian23@gmail.com>
*)

section ‹ Square Matrices ›

text ‹ The general solution for affine systems of ODEs involves the exponential function. 
Unfortunately, this operation is only available in Isabelle for the type class ``banach''. 
Hence, we define a type of square matrices and prove that it is an instance of this class. ›

theory SQ_MTX
  imports MTX_Norms

begin

subsection ‹ Definition ›

(* The type of square matrices indexed by 'm, introduced as a copy of the whole
   type real^'m^'m.  to_vec maps a square matrix to its underlying vector of
   vectors and to_mtx maps back. *)
typedef 'm sq_mtx = "UNIV::(real^'m^'m) set"
  morphisms to_vec to_mtx by simp

(* Both round-trip laws of the morphisms are registered as simplification rules. *)
declare to_mtx_inverse [simp]
    and to_vec_inverse [simp]

(* Registers the type with the lifting package so that lift_definition and
   transfer can be used below. *)
setup_lifting type_definition_sq_mtx

(* Basic operations on square matrices, each obtained by lifting the
   corresponding operation on real^'m^'m. *)

(* A $$ i: lifted vector indexing ($), i.e. the i-th component of to_vec A. *)
lift_definition sq_mtx_ith :: "'m sq_mtx ⇒ 'm ⇒ (real^'m)" (infixl "$$" 90) is "($)" .

(* A *⇩V x: lifted matrix-vector product ( *v ). *)
lift_definition sq_mtx_vec_mult :: "'m sq_mtx ⇒ (real^'m) ⇒ (real^'m)" (infixl "*⇩V" 90) is "(*v)" .

(* Lifted vector-matrix product ( v* ); no infix notation is attached. *)
lift_definition vec_sq_mtx_prod :: "(real^'m) ⇒ 'm sq_mtx ⇒ (real^'m)" is "(v*)" .

(* 𝖽𝗂𝖺𝗀 i. f i: lifted diag_mat, written with binder syntax. *)
lift_definition sq_mtx_diag :: "(('m::finite) ⇒ real) ⇒ ('m::finite) sq_mtx" (binder "𝖽𝗂𝖺𝗀 " 10) 
  is diag_mat .

(* A⇧†: lifted transpose. *)
lift_definition sq_mtx_transpose :: "('m::finite) sq_mtx ⇒ 'm sq_mtx" ("_⇧†") is transpose .

(* A⇧-⇧1: lifted matrix_inv. *)
lift_definition sq_mtx_inv :: "('m::finite) sq_mtx ⇒ 'm sq_mtx" ("_⇧-⇧1" [90]) is matrix_inv .

(* 𝗋𝗈𝗐 i A: lifted row. *)
lift_definition sq_mtx_row :: "'m ⇒ ('m::finite) sq_mtx ⇒ real^'m" ("𝗋𝗈𝗐") is row .

(* 𝖼𝗈𝗅 i A: lifted column. *)
lift_definition sq_mtx_col :: "'m ⇒ ('m::finite) sq_mtx ⇒ real^'m" ("𝖼𝗈𝗅")  is column .

(* Indexing the representation with $ agrees with indexing the square matrix with $$. *)
lemma to_vec_eq_ith: "(to_vec A) $ i = A $$ i"
  by transfer simp

(* Indexing an abstracted matrix to_mtx A gives the components of A itself,
   both at row level and at entry level. *)
lemma to_mtx_ith[simp]: 
  "(to_mtx A) $$ i1 = A $ i1"
  "(to_mtx A) $$ i1 $ i2 = A $ i1 $ i2"
  by (transfer, simp)+

(* Entry (i1, i2) of a matrix given by a double vector lambda is x i1 i2. *)
lemma to_mtx_vec_lambda_ith[simp]: "to_mtx (χ i j. x i j) $$ i1 $ i2 = x i1 i2"
  by (simp add: sq_mtx_ith_def)

(* Extensionality: two square matrices are equal iff all their entries agree,
   equivalently iff all their $$-components agree. *)
lemma sq_mtx_eq_iff:
  shows "A = B = (∀i j. A $$ i $ j = B $$ i $ j)"
    and "A = B = (∀i. A $$ i = B $$ i)"
  by (transfer, simp add: vec_eq_iff)+

(* Entries of a diagonal matrix: f i on the diagonal, 0 elsewhere; its i-th
   $$-component is axis i (f i). *)
lemma sq_mtx_diag_simps[simp]:
  "i = j ⟹ sq_mtx_diag f $$ i $ j = f i"
  "i ≠ j ⟹ sq_mtx_diag f $$ i $ j = 0"
  "sq_mtx_diag f $$ i = axis i (f i)"
  unfolding sq_mtx_diag_def by (simp_all add: axis_def vec_eq_iff)

(* A diagonal matrix acts on a vector by scaling each coordinate. *)
lemma sq_mtx_diag_vec_mult: "(𝖽𝗂𝖺𝗀 i. f i) *⇩V s = (χ i. f i * s$i)"
  by (simp add: matrix_vector_mul_diag_mat sq_mtx_diag.abs_eq sq_mtx_vec_mult.abs_eq)

(* Special case of the previous lemma for a vector of the form axis i k. *)
lemma sq_mtx_vec_mult_diag_axis: "(𝖽𝗂𝖺𝗀 i. f i) *⇩V (axis i k) = axis i (f i * k)"
  unfolding sq_mtx_diag_vec_mult axis_def by auto

(* Matrix-vector product written out as an explicit sum over the index type. *)
lemma sq_mtx_vec_mult_eq: "m *⇩V x = (χ i. sum (λj. (m $$ i $ j) * (x $ j)) UNIV)"
  by (transfer, simp add: matrix_vector_mult_def)

(* Transposition is an involution. *)
lemma sq_mtx_transpose_transpose[simp]: "(A⇧†)⇧† = A"
  by (transfer, simp)

(* Multiplying the transpose by 𝖾 i yields row i of A.
   NOTE(review): 𝖾 is not declared in this file; presumably the canonical
   basis vector notation from MTX_Norms -- confirm. *)
lemma transpose_mult_vec_canon_row[simp]: "(A⇧†) *⇩V (𝖾 i) = 𝗋𝗈𝗐 i A"
  by transfer (simp add: row_def transpose_def axis_def matrix_vector_mult_def)

(* Row i of A coincides with the $$-component A $$ i. *)
lemma row_ith[simp]: "𝗋𝗈𝗐 i A = A $$ i"
  by transfer (simp add: row_def)

(* Multiplying A by 𝖾 i yields column i of A. *)
lemma mtx_vec_mult_canon: "A *⇩V (𝖾 i) = 𝖼𝗈𝗅 i A" 
  by (transfer, simp add: matrix_vector_mult_basis)


subsection ‹ Ring of square matrices ›

(* Square matrices over a finite index type form a ring: every operation is
   lifted from real^'a^'a, with matrix-matrix multiplication ( ** ) as the
   ring product. *)
instantiation sq_mtx :: (finite) ring 
begin

lift_definition plus_sq_mtx :: "'a sq_mtx ⇒ 'a sq_mtx ⇒ 'a sq_mtx" is "(+)" .

lift_definition zero_sq_mtx :: "'a sq_mtx" is "0" .

lift_definition uminus_sq_mtx :: "'a sq_mtx ⇒ 'a sq_mtx" is "uminus" .

lift_definition minus_sq_mtx :: "'a sq_mtx ⇒ 'a sq_mtx ⇒ 'a sq_mtx" is "(-)" .

lift_definition times_sq_mtx :: "'a sq_mtx ⇒ 'a sq_mtx ⇒ 'a sq_mtx" is "(**)" .

(* Let simp push to_vec through sums and differences. *)
declare plus_sq_mtx.rep_eq [simp]
    and minus_sq_mtx.rep_eq [simp]

(* Each ring axiom is transferred to the representation type and closed by
   simp with the matrix associativity and distributivity facts. *)
instance apply intro_classes
  by(transfer, simp add: algebra_simps matrix_mul_assoc matrix_add_rdistrib matrix_add_ldistrib)+

end

(* Every $$-component of the zero matrix is the zero vector. *)
lemma sq_mtx_zero_ith[simp]: "0 $$ i = 0"
  by (transfer, simp)

(* Every entry of the zero matrix is 0. *)
lemma sq_mtx_zero_nth[simp]: "0 $$ i $ j = 0"
  by transfer simp

(* Addition is entry-wise. *)
lemma sq_mtx_plus_eq: "A + B = to_mtx (χ i j. A$$i$j + B$$i$j)"
  by transfer (simp add: vec_eq_iff)

(* Addition commutes with taking the i-th $$-component. *)
lemma sq_mtx_plus_ith[simp]:"(A + B) $$ i = A $$ i + B $$ i"
  unfolding sq_mtx_plus_eq by (simp add: vec_eq_iff)

(* Negation is entry-wise. *)
lemma sq_mtx_uminus_eq: "- A = to_mtx (χ i j. - A$$i$j)"
  by transfer (simp add: vec_eq_iff)

(* Subtraction is entry-wise. *)
lemma sq_mtx_minus_eq: "A - B = to_mtx (χ i j. A$$i$j - B$$i$j)"
  by transfer (simp add: vec_eq_iff)

(* Subtraction commutes with taking the i-th $$-component. *)
lemma sq_mtx_minus_ith[simp]:"(A - B) $$ i = A $$ i - B $$ i"
  unfolding sq_mtx_minus_eq by (simp add: vec_eq_iff)

(* Entry (i, j) of a product is the sum over k of A(i,k) * B(k,j). *)
lemma sq_mtx_times_eq: "A * B = to_mtx (χ i j. sum (λk. A$$i$k * B$$k$j) UNIV)"
  by transfer (simp add: matrix_matrix_mult_def)

(* Sums, differences, finite sums and products of diagonal matrices are again
   diagonal, with the operation applied to the diagonal entries. *)
lemma sq_mtx_plus_diag_diag[simp]: "sq_mtx_diag f + sq_mtx_diag g = (𝖽𝗂𝖺𝗀 i. f i + g i)"
  by (subst sq_mtx_eq_iff) (simp add: axis_def)

lemma sq_mtx_minus_diag_diag[simp]: "sq_mtx_diag f - sq_mtx_diag g = (𝖽𝗂𝖺𝗀 i. f i - g i)"
  by (subst sq_mtx_eq_iff) (simp add: axis_def)

lemma sum_sq_mtx_diag[simp]: "(∑n<m. sq_mtx_diag (g n)) = (𝖽𝗂𝖺𝗀 i. ∑n<m. (g n i))" for m::nat
  by (induct m, simp, subst sq_mtx_eq_iff, simp_all)

lemma sq_mtx_mult_diag_diag[simp]: "sq_mtx_diag f * sq_mtx_diag g = (𝖽𝗂𝖺𝗀 i. f i * g i)"
  by (simp add: matrix_mul_diag_diag sq_mtx_diag.abs_eq times_sq_mtx.abs_eq)

(* Left multiplication by a diagonal matrix scales row i by f i. *)
lemma sq_mtx_mult_diagl: "(𝖽𝗂𝖺𝗀 i. f i) * A = to_mtx (χ i j. f i * A $$ i $ j)"
  by transfer (simp add: matrix_mul_diag_matl)

(* Right multiplication by a diagonal matrix scales column j by f j. *)
lemma sq_mtx_mult_diagr: "A * (𝖽𝗂𝖺𝗀 i. f i) = to_mtx (χ i j. A $$ i $ j * f j)"
  by transfer (simp add: matrix_matrix_mul_diag_matr)

(* Algebraic laws of the matrix-vector product *⇩V. *)

(* The zero matrix sends every vector to 0. *)
lemma mtx_vec_mult_0l[simp]: "0 *⇩V x = 0"
  by (simp add: sq_mtx_vec_mult.abs_eq zero_sq_mtx_def)

(* Every matrix sends the zero vector to 0. *)
lemma mtx_vec_mult_0r[simp]: "A *⇩V 0 = 0"
  by (transfer, simp)

(* Distributivity over addition in the matrix argument. *)
lemma mtx_vec_mult_add_rdistr: "(A + B) *⇩V x = A *⇩V x + B *⇩V x"
  unfolding plus_sq_mtx_def 
  apply(transfer)
  by (simp add: matrix_vector_mult_add_rdistrib)

(* Distributivity over addition in the vector argument. *)
lemma mtx_vec_mult_add_rdistl: "A *⇩V (x + y) = A *⇩V x + A *⇩V y"
  unfolding plus_sq_mtx_def 
  apply transfer
  by (simp add: matrix_vector_right_distrib)

(* Distributivity over subtraction in the matrix argument. *)
lemma mtx_vec_mult_minus_rdistrib: "(A - B) *⇩V x = A *⇩V x - B *⇩V x"
  unfolding minus_sq_mtx_def by(transfer, simp add: matrix_vector_mult_diff_rdistrib)

(* Distributivity over subtraction in the vector argument. *)
lemma mtx_vec_mult_minus_ldistrib: "A *⇩V (x - y) =  A *⇩V x -  A *⇩V y"
  by (metis (no_types, lifting) add_diff_cancel diff_add_cancel 
      matrix_vector_right_distrib sq_mtx_vec_mult.rep_eq)

(* Compatibility of the ring product with the action on vectors. *)
lemma sq_mtx_times_vec_assoc: "(A * B) *⇩V x = A *⇩V (B *⇩V x)"
  by (transfer, simp add: matrix_vector_mul_assoc)

(* A *⇩V x is the linear combination of the columns of A with the
   coordinates of x as coefficients. *)
lemma sq_mtx_vec_mult_sum_cols: "A *⇩V x = sum (λi. x $ i *⇩R 𝖼𝗈𝗅 i A) UNIV"
  by(transfer) (simp add: matrix_mult_sum scalar_mult_eq_scaleR)


subsection ‹ Real normed vector space of square matrices ›

(* Square matrices form a real normed vector space.  The norm of A is the
   operator norm ‖_‖⇩o⇩p of its representation; sgn, dist, uniformity and open
   are then defined from that norm in the way the type classes require. *)
instantiation sq_mtx :: (finite) real_normed_vector 
begin

definition norm_sq_mtx :: "'a sq_mtx ⇒ real" where "‖A‖ = ‖to_vec A‖⇩o⇩p"

(* Scalar multiplication is lifted from real^'a^'a. *)
lift_definition scaleR_sq_mtx :: "real ⇒ 'a sq_mtx ⇒ 'a sq_mtx" is scaleR .

definition sgn_sq_mtx :: "'a sq_mtx ⇒ 'a sq_mtx" 
  where "sgn_sq_mtx A = (inverse (‖A‖)) *⇩R A"

definition dist_sq_mtx :: "'a sq_mtx ⇒ 'a sq_mtx ⇒ real" 
  where "dist_sq_mtx A B = ‖A - B‖" 

definition uniformity_sq_mtx :: "('a sq_mtx × 'a sq_mtx) filter" 
  where "uniformity_sq_mtx = (INF e∈{0<..}. principal {(x, y). dist x y < e})"

definition open_sq_mtx :: "'a sq_mtx set ⇒ bool" 
  where "open_sq_mtx U = (∀x∈U. ∀⇩F (x', y) in uniformity. x' = x ⟶ y ∈ U)"

(* The goals produced by intro_classes are reordered with prefer, so the
   positions 10 and 9 depend on the order in which the class axioms are
   generated; the two preferred goals are closed with op_norm_triangle and
   op_norm_eq_0, the rest with op_norm_scaleR and algebra_simps. *)
instance apply intro_classes 
  unfolding sgn_sq_mtx_def open_sq_mtx_def dist_sq_mtx_def uniformity_sq_mtx_def
            prefer 10 
            apply(transfer, simp add: norm_sq_mtx_def op_norm_triangle)
           prefer 9 
           apply(simp_all add: norm_sq_mtx_def zero_sq_mtx_def op_norm_eq_0)
  by (transfer, simp add: norm_sq_mtx_def op_norm_scaleR algebra_simps)+

end

(* Scalar multiplication is entry-wise. *)
lemma sq_mtx_scaleR_eq: "c *⇩R A = to_mtx (χ i j. c *⇩R A $$ i $ j)"
  by transfer (simp add: vec_eq_iff)

(* Entry of a scaled abstracted matrix. *)
lemma scaleR_to_mtx_ith[simp]: "c *⇩R (to_mtx A) $$ i1 $ i2 = c * A $ i1 $ i2"
  by transfer (simp add: scaleR_vec_def)

(* Scalar multiplication commutes with taking the i-th $$-component. *)
lemma sq_mtx_scaleR_ith[simp]: "(c *⇩R A) $$ i = (c  *⇩R (A $$ i))"
  by (unfold scaleR_sq_mtx_def, transfer, simp)

(* Scaling a diagonal matrix scales its diagonal entries. *)
lemma scaleR_sq_mtx_diag: "c *⇩R sq_mtx_diag f = (𝖽𝗂𝖺𝗀 i. c * f i)"
  by (subst sq_mtx_eq_iff, simp add: axis_def)

(* Scalars can be moved out of the matrix argument of *⇩V. *)
lemma scaleR_mtx_vec_assoc: "(c *⇩R A) *⇩V x = c *⇩R (A *⇩V x)"
  unfolding scaleR_sq_mtx_def sq_mtx_vec_mult_def apply simp
  by (simp add: scaleR_matrix_vector_assoc)

(* Scalars can be moved out of the vector argument of *⇩V. *)
lemma mtx_vec_scaleR_commute: "A *⇩V (c *⇩R x) = c *⇩R (A *⇩V x)"
  unfolding scaleR_sq_mtx_def sq_mtx_vec_mult_def apply(simp, transfer)
  by (simp add: vector_scaleR_commute)

(* Scalars can be moved out of the right factor of a matrix product. *)
lemma mtx_times_scaleR_commute: "A * (c *⇩R B) = c *⇩R (A * B)" for A::"('n::finite) sq_mtx"
  unfolding sq_mtx_scaleR_eq sq_mtx_times_eq 
  apply(simp add: to_mtx_inject)
  apply(simp add: vec_eq_iff fun_eq_iff)
  by (simp add: semiring_normalization_rules(19) vector_space_over_itself.scale_sum_right)

(* The norm of A bounds ‖A *⇩V x‖ for every unit vector x. *)
lemma le_mtx_norm: "m ∈ {‖A *⇩V x‖ |x. ‖x‖ = 1} ⟹ m ≤ ‖A‖"
  using cSup_upper[of _ "{‖(to_vec A) *v x‖ | x. ‖x‖ = 1}"]
  by (simp add: op_norm_set_proptys(2) op_norm_def norm_sq_mtx_def sq_mtx_vec_mult.rep_eq)

(* Compatibility of the matrix norm with the vector norm. *)
lemma norm_vec_mult_le: "‖A *⇩V x‖ ≤ (‖A‖) * (‖x‖)"
  by (simp add: norm_matrix_le_mult_op_norm norm_sq_mtx_def sq_mtx_vec_mult.rep_eq)

(* The matrix-vector product is a bounded bilinear map (with constant 1). *)
lemma bounded_bilinear_sq_mtx_vec_mult: "bounded_bilinear (λA s. A *⇩V s)"
  apply (rule bounded_bilinear.intro, simp_all add: mtx_vec_mult_add_rdistr 
      mtx_vec_mult_add_rdistl scaleR_mtx_vec_assoc mtx_vec_scaleR_commute)
  by (rule_tac x=1 in exI, auto intro!: norm_vec_mult_le)

(* Alternative characterisation: supremum over unit vectors. *)
lemma norm_sq_mtx_def2: "‖A‖ = Sup {‖A *⇩V x‖ |x. ‖x‖ = 1}"
  unfolding norm_sq_mtx_def op_norm_def sq_mtx_vec_mult_def by simp

(* Alternative characterisation: supremum of the quotients ‖A *⇩V x‖ / ‖x‖. *)
lemma norm_sq_mtx_def3: "‖A‖ = (SUP x. (‖A *⇩V x‖) / (‖x‖))"
  unfolding norm_sq_mtx_def onorm_def sq_mtx_vec_mult_def by simp

(* The norm of a diagonal matrix is the largest absolute diagonal entry. *)
lemma norm_sq_mtx_diag: "‖sq_mtx_diag f‖ = Max {¦f i¦ |i. i ∈ UNIV}"
  unfolding norm_sq_mtx_def apply transfer
  by (rule op_norm_diag_mat_eq)

(* The norm is bounded by the sum of the norms of the columns. *)
lemma sq_mtx_norm_le_sum_col: "‖A‖ ≤ (∑i∈UNIV. ‖𝖼𝗈𝗅 i A‖)"
  using op_norm_le_sum_column[of "to_vec A"] 
  apply(simp add: norm_sq_mtx_def)
  by(transfer, simp add: op_norm_le_sum_column)

(* Transposition does not decrease the norm ... *)
lemma norm_le_transpose: "‖A‖ ≤ ‖A⇧†‖"
  unfolding norm_sq_mtx_def by transfer (rule op_norm_le_transpose)

(* ... and, applying that to both A and its transpose, it preserves the norm. *)
lemma norm_eq_norm_transpose[simp]: "‖A⇧†‖ = ‖A‖"
  using norm_le_transpose[of A] and norm_le_transpose[of "A⇧†"] by simp

(* Each $$-component A $$ i has norm at most ‖A‖ (obtained via the transpose). *)
lemma norm_column_le_norm: "‖A $$ i‖ ≤ ‖A‖"
  using norm_vec_mult_le[of "A⇧†" "𝖾 i"] by simp


subsection ‹ Real normed algebra of square matrices ›

(* Square matrices form a real normed algebra with unit: the unit is the
   identity matrix mat 1, its norm is 1, and the norm is submultiplicative. *)
instantiation sq_mtx :: (finite) real_normed_algebra_1
begin

lift_definition one_sq_mtx :: "'a sq_mtx" is "to_mtx (mat 1)" .

(* 1 is a left and right neutral element for matrix multiplication. *)
lemma sq_mtx_one_idty: "1 * A = A" "A * 1 = A" for A :: "'a sq_mtx"
  by(transfer, transfer, unfold mat_def matrix_matrix_mult_def, simp add: vec_eq_iff)+

(* The identity matrix has norm 1. *)
lemma sq_mtx_norm_1: "‖(1::'a sq_mtx)‖ = 1"
  unfolding one_sq_mtx_def norm_sq_mtx_def 
  apply(simp add: op_norm_def)
  apply(subst cSup_eq[of _ 1])
  using ex_norm_eq_1 by auto

(* Submultiplicativity of the norm. *)
lemma sq_mtx_norm_times: "‖A * B‖ ≤ (‖A‖) * (‖B‖)" for A :: "'a sq_mtx"
  unfolding norm_sq_mtx_def times_sq_mtx_def by(simp add: op_norm_matrix_matrix_mult_le)

instance 
  apply intro_classes 
  apply(simp_all add: sq_mtx_one_idty sq_mtx_norm_1 sq_mtx_norm_times)
  apply(simp_all add: to_mtx_inject vec_eq_iff one_sq_mtx_def zero_sq_mtx_def mat_def)
  by(transfer, simp add: scalar_matrix_assoc matrix_scalar_ac)+

end

(* Entries of the identity matrix: 1 on the diagonal, 0 elsewhere. *)
lemma sq_mtx_one_ith_simps[simp]: "1 $$ i $ i = 1" "i ≠ j ⟹ 1 $$ i $ j = 0"
  unfolding one_sq_mtx_def mat_def by simp_all

(* The matrix denoted by a natural number m is the diagonal matrix with
   constant diagonal m. *)
lemma of_nat_eq_sq_mtx_diag[simp]: "of_nat m = (𝖽𝗂𝖺𝗀 i. m)"
  by (induct m) (simp, subst sq_mtx_eq_iff, simp add: axis_def)+

(* The identity matrix acts trivially on vectors. *)
lemma mtx_vec_mult_1[simp]: "1 *⇩V s = s"
  by (auto simp: sq_mtx_vec_mult_def one_sq_mtx_def 
      mat_def vec_eq_iff matrix_vector_mult_def)

(* The diagonal matrix with constant diagonal 1 is the identity. *)
lemma sq_mtx_diag_one[simp]: "(𝖽𝗂𝖺𝗀 i. 1) = 1"
  by (subst sq_mtx_eq_iff, simp add: one_sq_mtx_def mat_def axis_def)

(* A square matrix is invertible when its representation is. *)
abbreviation "mtx_invertible A ≡ invertible (to_vec A)"

(* Characterisation inside sq_mtx: existence of a two-sided inverse. *)
lemma mtx_invertible_def: "mtx_invertible A ⟷ (∃A'. A' * A = 1 ∧ A * A' = 1)"
  apply (unfold sq_mtx_inv_def times_sq_mtx_def one_sq_mtx_def invertible_def, clarsimp, safe)
   apply(rule_tac x="to_mtx A'" in exI, simp)
  by (rule_tac x="to_vec A'" in exI, simp add: to_mtx_inject)

(* Introduction rule: exhibiting a two-sided inverse B suffices. *)
lemma mtx_invertibleI:
  assumes "A * B = 1" and "B * A = 1"
  shows "mtx_invertible A"
  using assms unfolding mtx_invertible_def by auto

(* For an invertible A, A⇧-⇧1 is a left and right inverse. *)
lemma mtx_invertibleD[simp]:
  assumes "mtx_invertible A" 
  shows "A⇧-⇧1 * A = 1" and "A * A⇧-⇧1 = 1"
  apply (unfold sq_mtx_inv_def times_sq_mtx_def one_sq_mtx_def)
  using assms by simp_all

(* The inverse of an invertible matrix is invertible. *)
lemma mtx_invertible_inv[simp]: "mtx_invertible A ⟹ mtx_invertible (A⇧-⇧1)"
  using mtx_invertibleD mtx_invertibleI by blast

(* The identity matrix is invertible. *)
lemma mtx_invertible_one[simp]: "mtx_invertible 1"
  by (simp add: one_sq_mtx.rep_eq)

(* Any two-sided inverse of A equals A⇧-⇧1. *)
lemma sq_mtx_inv_unique:
  assumes "A * B = 1" and "B * A = 1"
  shows "A⇧-⇧1 = B"
  by (metis (no_types, lifting) assms mtx_invertibleD(2) 
      mtx_invertibleI mult.assoc sq_mtx_one_idty(1))

(* Inverting twice gives back the original matrix (despite the lemma's name,
   this is an involution property). *)
lemma sq_mtx_inv_idempotent[simp]: "mtx_invertible A ⟹ A⇧-⇧1⇧-⇧1 = A"
  using mtx_invertibleD sq_mtx_inv_unique by blast

(* The inverse of a product is the product of the inverses in reverse order. *)
lemma sq_mtx_inv_mult:
  assumes "mtx_invertible A" and "mtx_invertible B"
  shows "(A * B)⇧-⇧1 = B⇧-⇧1 * A⇧-⇧1"
  by (simp add: assms matrix_inv_matrix_mul sq_mtx_inv_def times_sq_mtx_def)

(* The identity matrix is its own inverse. *)
lemma sq_mtx_inv_one[simp]: "1⇧-⇧1 = 1"
  by (simp add: sq_mtx_inv_unique)

(* A ∼ B: A is similar to B, i.e. A = P⇧-⇧1 * B * P for some invertible P. *)
definition similar_sq_mtx :: "('n::finite) sq_mtx ⇒ 'n sq_mtx ⇒ bool" (infixr "∼" 25)
  where "(A ∼ B) ⟷ (∃ P. mtx_invertible P ∧ A = P⇧-⇧1 * B * P)"

(* Similarity on sq_mtx agrees with similar_matrix on the representations. *)
lemma similar_sq_mtx_matrix: "(A ∼ B) = similar_matrix (to_vec A) (to_vec B)"
  apply(unfold similar_matrix_def similar_sq_mtx_def, safe)
   apply (metis sq_mtx_inv.rep_eq times_sq_mtx.rep_eq)
  by (metis UNIV_I sq_mtx_inv.abs_eq times_sq_mtx.abs_eq to_mtx_inverse to_vec_inverse)

(* Similarity is reflexive (witness P = 1) ... *)
lemma similar_sq_mtx_refl[simp]: "A ∼ A"
  by (unfold similar_sq_mtx_def, rule_tac x="1" in exI, simp)

(* ... symmetric (witness P⇧-⇧1) ... *)
lemma similar_sq_mtx_simm: "A ∼ B ⟹ B ∼ A"
  apply(unfold similar_sq_mtx_def, clarsimp)
  apply(rule_tac x="P⇧-⇧1" in exI, simp add: mult.assoc)
  by (metis mtx_invertibleD(2) mult.assoc mult.left_neutral)

(* ... and transitive (via similar_matrix_trans). *)
lemma similar_sq_mtx_trans: "A ∼ B ⟹ B ∼ C ⟹ A ∼ C"
  unfolding similar_sq_mtx_matrix using similar_matrix_trans by blast

(* Powers of a diagonal matrix are computed entry-wise on the diagonal. *)
lemma power_sq_mtx_diag: "(sq_mtx_diag f)^n = (𝖽𝗂𝖺𝗀 i. f i^n)"
  by (induct n, simp_all)

(* Powers of a matrix conjugate to a diagonal matrix: the conjugating matrix
   stays outside and only the diagonal entries are raised to the power.  In
   the inductive step the inner factor P * P⇧-⇧1 cancels. *)
lemma power_similiar_sq_mtx_diag_eq:
  assumes "mtx_invertible P"
      and "A = P⇧-⇧1 * (sq_mtx_diag f) * P"
    shows "A^n = P⇧-⇧1 * (𝖽𝗂𝖺𝗀 i. f i^n) * P"
proof(induct n, simp_all add: assms)
  fix n::nat
  have "P⇧-⇧1 * sq_mtx_diag f * P * (P⇧-⇧1 * (𝖽𝗂𝖺𝗀 i. f i ^ n) * P) = 
  P⇧-⇧1 * sq_mtx_diag f * (𝖽𝗂𝖺𝗀 i. f i ^ n) * P"
    by (metis (no_types, lifting) assms(1) mtx_invertibleD(2) mult.assoc mult.right_neutral)
  also have "... = P⇧-⇧1 * (𝖽𝗂𝖺𝗀 i. f i * f i ^ n) * P"
    by (simp add: mult.assoc) 
  finally show "P⇧-⇧1 * sq_mtx_diag f * P * (P⇧-⇧1 * (𝖽𝗂𝖺𝗀 i. f i ^ n) * P) = 
  P⇧-⇧1 * (𝖽𝗂𝖺𝗀 i. f i * f i ^ n) * P" .
qed

(* Similarity to a diagonal matrix is preserved by taking powers. *)
lemma power_similar_sq_mtx_diag:
  assumes "A ∼ (sq_mtx_diag f)"
  shows "A^n ∼ (𝖽𝗂𝖺𝗀 i. f i^n)"
  using assms power_similiar_sq_mtx_diag_eq 
  unfolding similar_sq_mtx_def by blast


subsection ‹ Banach space of square matrices ›

(* If a sequence of matrices is Cauchy, so is the sequence of its i-th
   columns: the distance between the columns of X m and X n is
   ‖(X m - X n) *⇩V 𝖾 i‖, which le_mtx_norm bounds by ‖X m - X n‖. *)
lemma Cauchy_cols:
  fixes X :: "nat ⇒ ('a::finite) sq_mtx" 
  assumes "Cauchy X"
  shows "Cauchy (λn. 𝖼𝗈𝗅 i (X n))" 
proof(unfold Cauchy_def dist_norm, clarsimp)
  fix ε::real assume "ε > 0"
  then obtain M where M_def:"∀m≥M. ∀n≥M. ‖X m - X n‖ < ε"
    using ‹Cauchy X› unfolding Cauchy_def by(simp add: dist_sq_mtx_def) metis
  {fix m n assume "m ≥ M" and "n ≥ M"
    hence "ε > ‖X m - X n‖" 
      using M_def by blast
    moreover have "‖X m - X n‖ ≥ ‖(X m - X n) *⇩V 𝖾 i‖"
      by(rule le_mtx_norm[of _ "X m - X n"], force)
    moreover have "‖(X m - X n) *⇩V 𝖾 i‖ = ‖X m *⇩V 𝖾 i - X n *⇩V 𝖾 i‖"
      by (simp add: mtx_vec_mult_minus_rdistrib)
    moreover have "... = ‖𝖼𝗈𝗅 i (X m) - 𝖼𝗈𝗅 i (X n)‖"
      by (simp add: mtx_vec_mult_minus_rdistrib mtx_vec_mult_canon)
    ultimately have "‖𝖼𝗈𝗅 i (X m) - 𝖼𝗈𝗅 i (X n)‖ < ε" 
      by linarith}
  thus "∃M. ∀m≥M. ∀n≥M. ‖𝖼𝗈𝗅 i (X m) - 𝖼𝗈𝗅 i (X n)‖ < ε" 
    by blast
qed

(* If every column sequence converges to L $ i, the matrices converge to the
   matrix whose columns are the L $ i, i.e. to_mtx (transpose L).  Each column
   is brought within ε / CARD('a) of its limit by a common index N, and the
   norm is bounded by the sum of the column norms (sq_mtx_norm_le_sum_col). *)
lemma col_convergence:
  assumes "∀i. (λn. 𝖼𝗈𝗅 i (X n)) ⇢ L $ i" 
  shows "X ⇢ to_mtx (transpose L)"
proof(unfold LIMSEQ_def dist_norm, clarsimp)
  let ?L = "to_mtx (transpose L)"
  let ?a = "CARD('a)" fix ε::real assume "ε > 0"
  hence "ε / ?a > 0" by simp
  hence "∀i. ∃ N. ∀n≥N. ‖𝖼𝗈𝗅 i (X n) - L $ i‖ < ε/?a"
    using assms unfolding LIMSEQ_def dist_norm convergent_def by blast
  then obtain N where "∀i. ∀n≥N. ‖𝖼𝗈𝗅 i (X n) - L $ i‖ < ε/?a"
    using finite_nat_minimal_witness[of "λ i n. ‖𝖼𝗈𝗅 i (X n) - L $ i‖ < ε/?a"] by blast
  also have "⋀i n. (𝖼𝗈𝗅 i (X n) - L $ i) = (𝖼𝗈𝗅 i (X n - ?L))"
    unfolding minus_sq_mtx_def by(transfer, simp add: transpose_def vec_eq_iff column_def)
  ultimately have N_def:"∀i. ∀n≥N. ‖𝖼𝗈𝗅 i (X n - ?L)‖ < ε/?a" 
    by auto
  have "∀n≥N. ‖X n - ?L‖ < ε"
  proof(rule allI, rule impI)
    fix n::nat assume "N ≤ n"
    hence "∀ i. ‖𝖼𝗈𝗅 i (X n - ?L)‖ < ε/?a"
      using N_def by blast
    hence "(∑i∈UNIV. ‖𝖼𝗈𝗅 i (X n - ?L)‖) < (∑(i::'a)∈UNIV. ε/?a)"
      using sum_strict_mono[of _ "λi. ‖𝖼𝗈𝗅 i (X n - ?L)‖"] by force
    moreover have "‖X n - ?L‖ ≤ (∑i∈UNIV. ‖𝖼𝗈𝗅 i (X n - ?L)‖)"
      using sq_mtx_norm_le_sum_col by blast
    moreover have "(∑(i::'a)∈UNIV. ε/?a) = ε" 
      by force
    ultimately show "‖X n - ?L‖ < ε" 
      by linarith
  qed
  thus "∃no. ∀n≥no. ‖X n - ?L‖ < ε" 
    by blast
qed

(* Square matrices form a Banach space: for a Cauchy sequence X each column
   sequence is Cauchy (Cauchy_cols) and hence has a unique limit; collecting
   these limits in L, col_convergence gives a limit for X itself. *)
instance sq_mtx :: (finite) banach
proof(standard)
  fix X :: "nat ⇒ 'a sq_mtx"
  assume "Cauchy X"
  hence "⋀i. Cauchy (λn. 𝖼𝗈𝗅 i (X n))"
    using Cauchy_cols by blast
  hence obs: "∀i. ∃! L. (λn. 𝖼𝗈𝗅 i (X n)) ⇢ L"
    using Cauchy_convergent convergent_def LIMSEQ_unique by fastforce
  define L where "L = (χ i. lim (λn. 𝖼𝗈𝗅 i (X n)))"
  hence "∀i. (λn. 𝖼𝗈𝗅 i (X n)) ⇢ L $ i" 
    using obs theI_unique[of "λL. (λn. 𝖼𝗈𝗅 _ (X n)) ⇢ L" "L $ _"] by (simp add: lim_def)
  thus "convergent X"
    using col_convergence unfolding convergent_def by blast
qed

(* The exponential of a matrix conjugate to a diagonal matrix is the conjugate
   of the exponential of that diagonal matrix.  After unfolding exp as a power
   series and rewriting the powers with power_similiar_sq_mtx_diag_eq, the
   constant factors P⇧-⇧1 and P are pulled out of the infinite sum. *)
lemma exp_similiar_sq_mtx_diag_eq:
  assumes "mtx_invertible P"
      and "A = P⇧-⇧1 * (𝖽𝗂𝖺𝗀 i. f i) * P"
    shows "exp A = P⇧-⇧1 * exp (𝖽𝗂𝖺𝗀 i. f i) * P"
proof(unfold exp_def power_similiar_sq_mtx_diag_eq[OF assms])
  have "(∑n. P⇧-⇧1 * (𝖽𝗂𝖺𝗀 i. f i ^ n) * P /⇩R fact n) = 
  (∑n. P⇧-⇧1 * ((𝖽𝗂𝖺𝗀 i. f i ^ n) /⇩R fact n) * P)"
    by simp
  also have "... = (∑n. P⇧-⇧1 * ((𝖽𝗂𝖺𝗀 i. f i ^ n) /⇩R fact n)) * P"
    apply(subst suminf_multr[OF bounded_linear.summable[OF bounded_linear_mult_right]])
    unfolding power_sq_mtx_diag[symmetric] by (simp_all add: summable_exp_generic)
  also have "... = P⇧-⇧1 * (∑n. (𝖽𝗂𝖺𝗀 i. f i ^ n) /⇩R fact n) * P"
    apply(subst suminf_mult[of _ "P⇧-⇧1"])
    unfolding power_sq_mtx_diag[symmetric] 
    by (simp_all add: summable_exp_generic)
  finally show "(∑n. P⇧-⇧1 * (𝖽𝗂𝖺𝗀 i. f i ^ n) * P /⇩R fact n) = 
  P⇧-⇧1 * (∑n. sq_mtx_diag f ^ n /⇩R fact n) * P"
    unfolding power_sq_mtx_diag by simp
qed

(* Similarity to a diagonal matrix is preserved by the exponential. *)
lemma exp_similiar_sq_mtx_diag:
  assumes "A ∼ sq_mtx_diag f"
  shows "exp A ∼ exp (sq_mtx_diag f)"
  using assms exp_similiar_sq_mtx_diag_eq 
  unfolding similar_sq_mtx_def by blast

(* An infinite sum of diagonal matrices is the diagonal matrix of the
   entry-wise infinite sums, provided each entry series sums to its suminf.
   By norm_sq_mtx_diag the distance to the limit is a Max over the index type;
   the finite index type yields a common threshold N for all entries. *)
lemma suminf_sq_mtx_diag:
  assumes "∀i. (λn. f n i) sums (suminf (λn. f n i))"
  shows "(∑n. (𝖽𝗂𝖺𝗀 i. f n i)) = (𝖽𝗂𝖺𝗀 i. ∑n. f n i)"
proof(rule suminfI, unfold sums_def LIMSEQ_iff, clarsimp simp: norm_sq_mtx_diag)
  let ?g = "λn i. ¦(∑n<n. f n i) - (∑n. f n i)¦"
  fix r::real assume "r > 0"
  have "∀i. ∃no. ∀n≥no. ?g n i < r"
    using assms ‹r > 0› unfolding sums_def LIMSEQ_iff by clarsimp 
  then obtain N where key: "∀i. ∀n≥N. ?g n i < r"
    using finite_nat_minimal_witness[of "λi n. ?g n i < r"] by blast
  {fix n::nat
    assume "n ≥ N"
    obtain i where i_def: "Max {x. ∃i. x = ?g n i} = ?g n i"
      using cMax_finite_ex[of "{x. ∃i. x = ?g n i}"] by auto
    hence "?g n i < r"
      using key ‹n ≥ N› by blast
    hence "Max {x. ∃i. x = ?g n i} < r"
      unfolding i_def[symmetric] .}
  thus "∃N. ∀n≥N. Max {x. ∃i. x = ?g n i} < r"
    by blast
qed

(* The exponential of a diagonal matrix is the diagonal matrix of the
   exponentials of its entries. *)
lemma exp_sq_mtx_diag: "exp (sq_mtx_diag f) = (𝖽𝗂𝖺𝗀 i. exp (f i))"
  apply(unfold exp_def, simp add: power_sq_mtx_diag scaleR_sq_mtx_diag)
  apply(rule suminf_sq_mtx_diag)
  using exp_converges[of "f _"] 
  unfolding sums_def LIMSEQ_iff exp_def by force

(* Closed form of exp (t *⇩R A) when A = P⇧-⇧1 * D * P with D diagonal. *)
lemma exp_scaleR_diagonal1:
  assumes "mtx_invertible P" and "A = P⇧-⇧1 * (𝖽𝗂𝖺𝗀 i. f i) * P"
    shows "exp (t *⇩R A) = P⇧-⇧1 * (𝖽𝗂𝖺𝗀 i. exp (t * f i)) * P"
proof-
  have "exp (t *⇩R A) = exp (P⇧-⇧1 * (t *⇩R sq_mtx_diag f) * P)"
    using assms by simp
  also have "... = P⇧-⇧1 * (𝖽𝗂𝖺𝗀 i. exp (t * f i)) * P"
    by (metis assms(1) exp_similiar_sq_mtx_diag_eq exp_sq_mtx_diag scaleR_sq_mtx_diag)
  finally show "exp (t *⇩R A) = P⇧-⇧1 * (𝖽𝗂𝖺𝗀 i. exp (t * f i)) * P" .
qed

(* The same closed form when A = P * D * P⇧-⇧1, obtained from the previous
   lemma by rewriting P as the inverse of P⇧-⇧1. *)
lemma exp_scaleR_diagonal2:
  assumes "mtx_invertible P" and "A = P * (𝖽𝗂𝖺𝗀 i. f i) * P⇧-⇧1"
    shows "exp (t *⇩R A) = P * (𝖽𝗂𝖺𝗀 i. exp (t * f i)) * P⇧-⇧1"
  apply(subst sq_mtx_inv_idempotent[OF assms(1), symmetric])
  apply(rule exp_scaleR_diagonal1)
  by (simp_all add: assms)


subsection ‹ Examples ›

(* Builds a square matrix from a list of lists: vector is mapped over the
   inner lists, applied again to the outer list, and the result is wrapped
   with to_mtx. *)
definition "mtx A = to_mtx (vector (map vector A))"

(* Component i of vector A, expressed as a foldr over the list A that builds
   a function by successive updates, started at position 1. *)
lemma vector_nth_eq: "(vector A) $ i = foldr (λx f n. (f (n + 1))(n := x)) A (λn x. 0) 1 i"
  unfolding vector_def by simp

(* Entry (i, j) of mtx A in the same foldr form; as a simp rule it lets
   entries of concrete list-built matrices be computed. *)
lemma mtx_ith_eq[simp]: "mtx A $$ i $ j = foldr (λx f n. (f (n + 1))(n := x))
  (map (λl. vec_lambda (foldr (λx f n. (f (n + 1))(n := x)) l (λn x. 0) 1)) A) (λn x. 0) 1 i $ j"
  unfolding mtx_def vector_def by (simp add: vector_nth_eq)

subsubsection ‹ 2x2 matrices ›

(* Two list-built 2x2 matrices are equal iff their four entries agree. *)
lemma mtx2_eq_iff: "(mtx 
  ([a1, b1] # 
   [c1, d1] # []) :: 2 sq_mtx) = mtx 
  ([a2, b2] # 
   [c2, d2] # []) ⟷ a1 = a2 ∧ b1 = b2 ∧ c1 = c2 ∧ d1 = d2"
  apply(simp add: sq_mtx_eq_iff, safe)
  using exhaust_2 by force+

(* A list-built 2x2 matrix as an explicit case distinction on the indices. *)
lemma mtx2_to_mtx: "mtx 
  ([a, b] # 
   [c, d] # []) = 
  to_mtx (χ i j::2. if i=1 ∧ j=1 then a 
  else (if i=1 ∧ j=2 then b 
  else (if i=2 ∧ j=1 then c 
  else d)))"
  apply(subst sq_mtx_eq_iff)
  using exhaust_2 by force

(* The 2x2 diagonal matrix with the given diagonal entries, in list form. *)
abbreviation diag2 :: "real ⇒ real ⇒ 2 sq_mtx" 
  where "diag2 ι1 ι2 ≡ mtx 
   ([ι1, 0] # 
    [0, ι2] # [])"

(* diag2 agrees with the general diagonal-matrix constructor. *)
lemma diag2_eq: "diag2 (ι 1) (ι 2) = (𝖽𝗂𝖺𝗀 i. ι i)"
  apply(simp add: sq_mtx_eq_iff)
  using exhaust_2 by (force simp: axis_def)

(* The 2x2 identity and zero matrices in list form. *)
lemma one_mtx2: "(1::2 sq_mtx) = diag2 1 1"
  apply(subst sq_mtx_eq_iff)
  using exhaust_2 by force

lemma zero_mtx2: "(0::2 sq_mtx) = diag2 0 0"
  by (simp add: sq_mtx_eq_iff)

(* Entry-wise formulas for scaling, negation, sum and difference, and the
   row-by-column formula for the product, of list-built 2x2 matrices. *)
lemma scaleR_mtx2: "k *⇩R mtx 
  ([a, b] # 
   [c, d] # []) = mtx 
  ([k*a, k*b] # 
   [k*c, k*d] # [])"
  by (simp add: sq_mtx_eq_iff)

lemma uminus_mtx2: "-mtx 
  ([a, b] # 
   [c, d] # []) = (mtx 
  ([-a, -b] # 
   [-c, -d] # [])::2 sq_mtx)"
  by (simp add: sq_mtx_uminus_eq sq_mtx_eq_iff)

lemma plus_mtx2: "mtx 
  ([a1, b1] # 
   [c1, d1] # []) + mtx 
  ([a2, b2] # 
   [c2, d2] # []) = ((mtx 
  ([a1+a2, b1+b2] # 
   [c1+c2, d1+d2] # []))::2 sq_mtx)"
  by (simp add: sq_mtx_eq_iff)

lemma minus_mtx2: "mtx 
  ([a1, b1] # 
   [c1, d1] # []) - mtx 
  ([a2, b2] # 
   [c2, d2] # []) = ((mtx 
  ([a1-a2, b1-b2] # 
   [c1-c2, d1-d2] # []))::2 sq_mtx)"
  by (simp add: sq_mtx_eq_iff)

lemma times_mtx2: "mtx 
  ([a1, b1] # 
   [c1, d1] # []) * mtx 
  ([a2, b2] # 
   [c2, d2] # []) = ((mtx 
  ([a1*a2+b1*c2, a1*b2+b1*d2] # 
   [c1*a2+d1*c2, c1*b2+d1*d2] # []))::2 sq_mtx)"
  unfolding sq_mtx_times_eq UNIV_2
  by (simp add: sq_mtx_eq_iff)

subsubsection ‹ 3x3 matrices ›

(* A list-built 3x3 matrix as an explicit case distinction on the indices. *)
lemma mtx3_to_mtx: "mtx 
  ([a11, a12, a13] # 
   [a21, a22, a23] # 
   [a31, a32, a33] # []) = 
  to_mtx (χ i j::3. if i=1 ∧ j=1 then a11
  else (if i=1 ∧ j=2 then a12 
  else (if i=1 ∧ j=3 then a13 
  else (if i=2 ∧ j=1 then a21
  else (if i=2 ∧ j=2 then a22 
  else (if i=2 ∧ j=3 then a23 
  else (if i=3 ∧ j=1 then a31 
  else (if i=3 ∧ j=2 then a32 
  else a33))))))))"
  apply(simp add: sq_mtx_eq_iff)
  using exhaust_3 by force

(* The 3x3 diagonal matrix with the given diagonal entries, in list form. *)
abbreviation diag3 :: "real ⇒ real ⇒ real ⇒ 3 sq_mtx" 
  where "diag3 ι1 ι2 ι3 ≡ mtx 
  ([ι1, 0, 0] # 
   [0, ι2, 0] # 
   [0, 0, ι3] # [])"

(* diag3 agrees with the general diagonal-matrix constructor. *)
lemma diag3_eq: "diag3 (ι 1) (ι 2) (ι 3) = (𝖽𝗂𝖺𝗀 i. ι i)"
  apply(simp add: sq_mtx_eq_iff)
  using exhaust_3 by (force simp: axis_def)

(* The 3x3 identity and zero matrices in list form. *)
lemma one_mtx3: "(1::3 sq_mtx) = diag3 1 1 1"
  apply(subst sq_mtx_eq_iff)
  using exhaust_3 by force

lemma zero_mtx3: "(0::3 sq_mtx) = diag3 0 0 0"
  by (simp add: sq_mtx_eq_iff)

(* Entry-wise formulas for scaling, sum and difference, and the row-by-column
   formula for the product, of list-built 3x3 matrices. *)
lemma scaleR_mtx3: "k *⇩R mtx 
  ([a11, a12, a13] # 
   [a21, a22, a23] # 
   [a31, a32, a33] # []) = mtx 
  ([k*a11, k*a12, k*a13] # 
   [k*a21, k*a22, k*a23] # 
   [k*a31, k*a32, k*a33] # [])"
  by (simp add: sq_mtx_eq_iff)

lemma plus_mtx3: "mtx 
  ([a11, a12, a13] # 
   [a21, a22, a23] # 
   [a31, a32, a33] # []) + mtx 
  ([b11, b12, b13] # 
   [b21, b22, b23] # 
   [b31, b32, b33] # []) = (mtx 
  ([a11+b11, a12+b12, a13+b13] # 
   [a21+b21, a22+b22, a23+b23] # 
   [a31+b31, a32+b32, a33+b33] # [])::3 sq_mtx)"
  by (subst sq_mtx_eq_iff) simp

lemma minus_mtx3: "mtx 
  ([a11, a12, a13] # 
   [a21, a22, a23] # 
   [a31, a32, a33] # []) - mtx 
  ([b11, b12, b13] # 
   [b21, b22, b23] # 
   [b31, b32, b33] # []) = (mtx 
  ([a11-b11, a12-b12, a13-b13] # 
   [a21-b21, a22-b22, a23-b23] # 
   [a31-b31, a32-b32, a33-b33] # [])::3 sq_mtx)"
  by (simp add: sq_mtx_eq_iff)

lemma times_mtx3: "mtx 
  ([a11, a12, a13] # 
   [a21, a22, a23] # 
   [a31, a32, a33] # []) * mtx 
  ([b11, b12, b13] # 
   [b21, b22, b23] # 
   [b31, b32, b33] # []) = (mtx 
  ([a11*b11+a12*b21+a13*b31, a11*b12+a12*b22+a13*b32, a11*b13+a12*b23+a13*b33] # 
   [a21*b11+a22*b21+a23*b31, a21*b12+a22*b22+a23*b32, a21*b13+a22*b23+a23*b33] # 
   [a31*b11+a32*b21+a33*b31, a31*b12+a32*b22+a33*b32, a31*b13+a32*b23+a33*b33] # [])::3 sq_mtx)"
  unfolding sq_mtx_times_eq
  unfolding UNIV_3 by (simp add: sq_mtx_eq_iff)

end