(* Theory More_Matrix *)

theory More_Matrix
  imports "Jordan_Normal_Form.Matrix"
    "Jordan_Normal_Form.DL_Rank"
    "Jordan_Normal_Form.VS_Connect"
    "Jordan_Normal_Form.Gauss_Jordan_Elimination"
begin

section "Kronecker Product"

(* Kronecker product of an (ra x ca) matrix A and an (rb x cb) matrix B:
   the (ra*rb) x (ca*cb) matrix whose entry (i, j) is
   A $$ (i div rb, j div cb) * B $$ (i mod rb, j mod cb).
   Equivalently, block (r, s) of the result has entries A $$ (r, s) * B $$ (v, w)
   (see kronecker_inverse_index). *)
definition kronecker_product :: "'a :: ring mat ⇒ 'a mat ⇒ 'a mat" where
  "kronecker_product A B =
  (let ra = dim_row A; ca = dim_col A;
       rb = dim_row B; cb = dim_col B
  in
    mat (ra*rb) (ca*cb)
    (λ(i,j).
      A $$ (i div rb, j div cb) *
      B $$ (i mod rb, j mod cb)
  ))"

(* Arithmetic helper: if d < a and c < b, the mixed-radix code b*d+c is below a*b.
   Used below to show that Kronecker-product indices are in range
   (kronecker_inverse_index, sum_sum_mod_div). *)
lemma arith:
  assumes "d < a"
  assumes "c < b"
  shows "b*d+c < a*(b::nat)"
proof -
  (* c < b bounds the code by the start of the next block, b*(d+1) *)
  have "b*d+c < b*(d+1)"
    by (simp add: assms(2))
  (* and d < a gives d+1 <= a *)
  thus ?thesis
    by (metis One_nat_def Suc_leI add.right_neutral add_Suc_right assms(1) less_le_trans mult.commute mult_le_cancel2)
qed

(* Dimensions of the Kronecker product: row counts multiply and column counts
   multiply. Declared [simp]. *)
lemma dim_kronecker[simp]:
  "dim_row (kronecker_product A B) = dim_row A * dim_row B"
  "dim_col (kronecker_product A B) = dim_col A * dim_col B"
  by (simp_all add: kronecker_product_def Let_def)

(* Inverse of the index encoding used by kronecker_product: entry
   (dim_row B * r + v, dim_col B * s + w) of the Kronecker product of A and B
   is A $$ (r, s) * B $$ (v, w), for in-range r, s, v, w. *)
lemma kronecker_inverse_index:
  assumes "r < dim_row A" "s < dim_col A"
  assumes "v < dim_row B" "w < dim_col B"
  shows "kronecker_product A B $$ (dim_row B*r+v, dim_col B*s+w) = A $$ (r,s) * B $$ (v,w)"
proof -
  (* the encoded row and column indices lie within the product dimensions *)
  from arith[OF assms(1) assms(3)]
  have "dim_row B*r+v < dim_row A * dim_row B" .
  moreover from arith[OF assms(2) assms(4)]
  have "dim_col B * s + w < dim_col A * dim_col B" .
  (* so the mat entry rule applies and auto simplifies the div/mod of the codes *)
  ultimately show ?thesis
    unfolding kronecker_product_def Let_def
    using assms by auto
qed

(* The Kronecker product distributes over matrix addition in its right argument;
   B and C must have equal dimensions. *)
lemma kronecker_distr_left:
  assumes "dim_row B = dim_row C" "dim_col B = dim_col C"
  shows "kronecker_product A (B+C) = kronecker_product A B + kronecker_product A C"
  unfolding kronecker_product_def Let_def
  using assms apply (auto simp add: mat_eq_iff) 
  (* remaining entrywise goal: left distributivity of * over + in 'a *)
  by (metis (no_types, lifting) distrib_left index_add_mat(1) mod_less_divisor mult_eq_0_iff neq0_conv not_less_zero)

(* The Kronecker product distributes over matrix addition in its left argument;
   B and C must have equal dimensions. *)
lemma kronecker_distr_right:
  assumes "dim_row B = dim_row C" "dim_col B = dim_col C"
  shows "kronecker_product (B+C) A = kronecker_product B A + kronecker_product C A"
  unfolding kronecker_product_def Let_def
  (* entrywise: the div-indices stay in range of B / C, then distrib_right *)
  using assms by (auto simp add: mat_eq_iff less_mult_imp_div_less distrib_right)

(* Entries of mat nr nc f at indices already reduced mod nr / mod nc are given
   directly by f; the positivity premises make the reduced indices valid. *)
lemma index_mat_mod[simp]: "nr > 0 & nc > 0 ⟹ mat nr nc f $$ (i mod nr,j mod nc) = f (i mod nr,j mod nc)"
  by auto

(* The Kronecker product is associative. *)
lemma kronecker_assoc:
  shows "kronecker_product A (kronecker_product B C) = kronecker_product (kronecker_product A B) C"
  unfolding kronecker_product_def Let_def
  (* case split on whether the inner product of B and C has positive row and
     column counts *)
  apply (case_tac "dim_row B * dim_row C > 0 & dim_col B * dim_col C > 0")
   apply (auto simp add: mat_eq_iff less_mult_imp_div_less)
  by (smt (verit, best) div_less_iff_less_mult div_mult2_eq kronecker_inverse_index linordered_semiring_strict_class.mult_pos_pos mod_less_divisor mod_mult2_eq mult.assoc mult.commute mult_div_mod_eq)

(* Reindex a double sum over {0..<x} x {0..<y} as a single sum over {0..<x*y},
   using the bijection ia |-> (ia div y, ia mod y). *)
lemma sum_sum_mod_div:
  "(∑ia = 0::nat..<x. ∑ja = 0..<y. f ia ja) =
   (∑ia = 0..<x*y. f (ia div y) (ia mod y))"
proof -
  (* the decoding map is injective on {0..<x*y} ... *)
  have 1: "inj_on (λia. (ia div y, ia mod y)) {0..<x * y}"
    by (smt (verit, best) Pair_inject div_mod_decomp inj_onI)
  (* ... and its image is exactly the product index set (21 and 22) *)
  have 21: "{0..<x} × {0..<y} ⊆ (λia. (ia div y, ia mod y)) ` {0..<x * y}"
  proof clarsimp
    fix a b
    assume *:"a < x" "b < y"
    (* (a, b) is hit by the encoded index a * y + b *)
    have "a * y +  b ∈ {0..<x*y}"
      by (metis arith * atLeastLessThan_iff le0 mult.commute)
    thus "(a, b) ∈ (λia. (ia div y, ia mod y)) ` {0..<x * y}"
      using * by (auto simp add: image_iff)
        (metis ‹a * y + b ∈ {0..<x * y}› add.commute add.right_neutral div_less div_mult_self1 less_zeroE mod_eq_self_iff_div_eq_0 mod_mult_self1)
  qed
  have 22:"(λia. (ia div y, ia mod y)) ` {0..<x * y} ⊆ {0..<x} × {0..<y}"
    using less_mult_imp_div_less apply auto
    by (metis mod_less_divisor mult.commute neq0_conv not_less_zero)
  have 2: "{0..<x} × {0..<y} = (λia. (ia div y, ia mod y)) ` {0..<x * y}"
    using 21 22 by auto
  (* turn the nested sum into a sum over the product set *)
  have *: "(∑ia = 0::nat..<x. ∑ja = 0..<y. f ia ja) =
        (∑(x, y)∈{0..<x} × {0..<y}. f x y)"
    by (auto simp add: sum.cartesian_product)
  (* and reindex that sum along the bijection *)
  show ?thesis unfolding *
    apply (intro sum.reindex_cong[of "λia. (ia div y, ia mod y)"])
    using 1 2 by auto
qed

(* Mixed-product property: kronecker_product A B * kronecker_product C D
   = kronecker_product (A * C) (B * D) *)
lemma kronecker_of_mult:
  assumes "dim_col (A :: 'a :: comm_ring mat) = dim_row C"
  assumes "dim_col B = dim_row D"
  shows "kronecker_product A B * kronecker_product C D = kronecker_product (A * C) (B * D)"
  unfolding kronecker_product_def Let_def mat_eq_iff
proof clarsimp
  fix i j
  assume ij: "i < dim_row A * dim_row B" "j < dim_col C * dim_col D"
  (* 1, 2: the two factors of the right-hand entry, written as scalar products *)
  have 1: "(A * C) $$ (i div dim_row B, j div dim_col D) =
    row A (i div dim_row B) ∙ col C (j div dim_col D)"
    using ij less_mult_imp_div_less by (auto intro!: index_mult_mat)
  have 2: "(B * D) $$ (i mod dim_row B, j mod dim_col D) =
    row B (i mod dim_row B) ∙ col D (j mod dim_col D)"
    using ij apply (auto intro!: index_mult_mat)
    using gr_implies_not0 apply fastforce
    using gr_implies_not0 by fastforce
  (* 3: each summand of the left-hand entry equals the matching product of row and
     column components; the rearrangement uses mult.left_commute, which is why 'a
     is required to be a comm_ring *)
  have 3: "⋀x. x < dim_row C * dim_row D ⟹
         A $$ (i div dim_row B, x div dim_row D) *
         B $$ (i mod dim_row B, x mod dim_row D) *
         (C $$ (x div dim_row D, j div dim_col D) *
          D $$ (x mod dim_row D, j mod dim_col D)) =
         row A (i div dim_row B) $ (x div dim_row D) *
         col C (j div dim_col D) $ (x div dim_row D) *
         (row B (i mod dim_row B) $ (x mod dim_row D) *
          col D (j mod dim_col D) $ (x mod dim_row D))"
  proof -
    fix x
    assume *:"x < dim_row C * dim_row D"
    have 1: "row A (i div dim_row B) $ (x div dim_row D) = A $$ (i div dim_row B, x div dim_row D)"
      by (simp add: * assms(1) less_mult_imp_div_less row_def)
    have 2: "row B (i mod dim_row B) $ (x mod dim_row D) = B $$ (i mod dim_row B, x mod dim_row D)"
      by (metis "*" assms(2) ij(1) index_row(1) mod_less_divisor nat_0_less_mult_iff neq0_conv not_less_zero)
    have 3: "col C (j div dim_col D) $ (x div dim_row D) = C $$ (x div dim_row D, j div dim_col D)"
      by (simp add: "*" ij(2) less_mult_imp_div_less)
    have 4: "col D (j mod dim_col D) $ (x mod dim_row D) = D $$ (x mod dim_row D, j mod dim_col D)"
      by (metis "*" bot_nat_0.not_eq_extremum ij(2) index_col mod_less_divisor mult_zero_right not_less_zero)
    show "A $$ (i div dim_row B, x div dim_row D) *
         B $$ (i mod dim_row B, x mod dim_row D) *
         (C $$ (x div dim_row D, j div dim_col D) *
          D $$ (x mod dim_row D, j mod dim_col D)) =
         row A (i div dim_row B) $ (x div dim_row D) *
         col C (j div dim_col D) $ (x div dim_row D) *
         (row B (i mod dim_row B) $ (x mod dim_row D) *
          col D (j mod dim_col D) $ (x mod dim_row D))" unfolding 1 2 3 4
      by (simp add: mult.assoc mult.left_commute)
  qed
  (* *: the product of the two scalar products, reindexed into a single sum over
     {0..<dim_row C * dim_row D} via sum_sum_mod_div *)
  have *: "(A * C) $$ (i div dim_row B, j div dim_col D) *
        (B * D) $$ (i mod dim_row B, j mod dim_col D) =
    (∑ia = 0..<dim_row C * dim_row D.
               A $$ (i div dim_row B, ia div dim_row D) *
               B $$ (i mod dim_row B, ia mod dim_row D) *
               (C $$ (ia div dim_row D, j div dim_col D) *
                D $$ (ia mod dim_row D, j mod dim_col D)))"
    unfolding 1 2 scalar_prod_def sum_product sum_sum_mod_div
    apply (auto simp add: sum_product sum_sum_mod_div intro!: sum.cong)
    using 3 by presburger
  show "vec (dim_col A * dim_col B)
          (λj. A $$ (i div dim_row B, j div dim_col B) *
               B $$ (i mod dim_row B, j mod dim_col B)) ∙
       vec (dim_row C * dim_row D)
          (λi. C $$ (i div dim_row D, j div dim_col D) *
               D $$ (i mod dim_row D, j mod dim_col D)) =
        (A * C) $$ (i div dim_row B, j div dim_col D) *
        (B * D) $$ (i mod dim_row B, j mod dim_col D)"
    unfolding * scalar_prod_def
    by (auto simp add: assms sum_product sum_sum_mod_div intro!: sum.cong)
qed

(* If A is square and A, B are inverse to each other on both sides, then B has
   the same dimensions as A. *)
lemma inverts_mat_length:
  assumes "square_mat A" "inverts_mat A B" "inverts_mat B A"
  shows "dim_row B = dim_row A" "dim_col B = dim_col A"
   (* rows: compare the dimensions of the identity B * A (assms(3)) *)
   apply (metis assms(1) assms(3) index_mult_mat(3) index_one_mat(3) inverts_mat_def square_mat.simps)
  (* columns: compare the dimensions of the identity A * B (assms(2)) *)
  by (metis assms(1) assms(2) index_mult_mat(3) index_one_mat(3) inverts_mat_def square_mat.simps)

(* m < n * i forces i to be nonzero, hence m mod i < i.
   Companion of less_mult_imp_div_less; used in kronecker_one. *)
lemma less_mult_imp_mod_less:
  "m mod i < i" if "m < n * i" for m n i :: nat
  using gr_implies_not_zero that by fastforce

(* The Kronecker product of identity matrices is the identity matrix of the
   product size. *)
lemma kronecker_one:
  shows "kronecker_product ((1⇩m x)::'a :: ring_1 mat) (1⇩m y) = 1⇩m (x*y)"
  unfolding kronecker_product_def Let_def
  apply  (auto simp add:mat_eq_iff less_mult_imp_div_less less_mult_imp_mod_less)
  (* the leftover goal is closed via the div/mod decomposition of the index *)
  by (metis div_mult_mod_eq)

(* The Kronecker product of two invertible matrices is invertible; its inverse is
   the Kronecker product of the inverses (by kronecker_of_mult and kronecker_one). *)
lemma kronecker_invertible:
  assumes "invertible_mat (A :: 'a :: comm_ring_1 mat)" "invertible_mat B"
  shows "invertible_mat (kronecker_product A B)"
proof -
  (* two-sided inverses Ai of A and Bi of B *)
  obtain Ai where Ai: "inverts_mat A Ai" "inverts_mat Ai A" using assms invertible_mat_def by blast
  obtain Bi where Bi: "inverts_mat B Bi" "inverts_mat Bi B" using assms invertible_mat_def by blast
  (* the product is square because A and B are *)
  have "square_mat (kronecker_product A B)"
    by (metis (no_types, lifting) assms(1) assms(2) dim_col_mat(1) dim_row_mat(1) invertible_mat_def kronecker_product_def square_mat.simps)
  (* (A (x) B) * (Ai (x) Bi) = (A * Ai) (x) (B * Bi) = 1 (x) 1 = 1, and symmetrically *)
  moreover have "inverts_mat (kronecker_product A B) (kronecker_product Ai Bi)"
    using Ai Bi unfolding inverts_mat_def
    by (metis (no_types, lifting) dim_kronecker(1) index_mult_mat(3) index_one_mat(3) kronecker_of_mult kronecker_one)
  moreover have "inverts_mat (kronecker_product Ai Bi) (kronecker_product A B)"
    using Ai Bi unfolding inverts_mat_def
    by (metis (no_types, lifting) dim_kronecker(1) index_mult_mat(3) index_one_mat(3) kronecker_of_mult kronecker_one)
  ultimately show ?thesis unfolding invertible_mat_def by blast
qed

section "More DL Rank"

(* conjugate matrices *)
(* Entrywise conjugation lifts the conjugate class from entries to matrices. *)
instantiation mat :: (conjugate) conjugate
begin

(* conjugate m applies conjugate to every entry and keeps the dimensions *)
definition conjugate_mat :: "'a :: conjugate mat ⇒ 'a mat"
  where "conjugate m = mat (dim_row m) (dim_col m) (λ(i,j). conjugate (m $$ (i,j)))"

lemma dim_row_conjugate[simp]: "dim_row (conjugate m) = dim_row m"
  unfolding conjugate_mat_def by auto

lemma dim_col_conjugate[simp]: "dim_col (conjugate m) = dim_col m"
  unfolding conjugate_mat_def by auto

(* NOTE: despite its name, this lemma is about carrier_mat. Jordan_Normal_Form.Matrix
   already has a lemma with this name, which later proofs in this theory therefore
   refer to as Matrix.carrier_vec_conjugate. *)
lemma carrier_vec_conjugate[simp]: "m ∈ carrier_mat nr nc ⟹ conjugate m ∈ carrier_mat nr nc"
  by (auto)

lemma mat_index_conjugate[simp]:
  shows "i < dim_row m ⟹ j < dim_col m ⟹ conjugate m  $$ (i,j) = conjugate (m $$ (i,j))"
  unfolding conjugate_mat_def by auto

lemma row_conjugate[simp]: "i < dim_row m ⟹ row (conjugate m) i = conjugate (row m i)"
  by (auto)

lemma col_conjugate[simp]: "i < dim_col m ⟹ col (conjugate m) i = conjugate (col m i)"
  by (auto)

lemma rows_conjugate: "rows (conjugate m) = map conjugate (rows m)"
  by (simp add: list_eq_iff_nth_eq)

lemma cols_conjugate: "cols (conjugate m) = map conjugate (cols m)"
  by (simp add: list_eq_iff_nth_eq)

instance
proof
  fix a b :: "'a mat"
  (* conjugation is an involution ... *)
  show "conjugate (conjugate a) = a"
    unfolding mat_eq_iff by auto
  (* ... and injective: compare dimensions and entries *)
  show "conjugate a = conjugate b ⟷ a = b"
    by (metis dim_col_conjugate dim_row_conjugate mat_index_conjugate conjugate_cancel_iff mat_eq_iff)
qed

end

(* Conjugate transpose: transpose, then conjugate every entry. Written A⇧H. *)
abbreviation conjugate_transpose :: "'a::conjugate mat ⇒ 'a mat"
  where "conjugate_transpose A ≡ conjugate (A⇧T)"

notation conjugate_transpose ("(_⇧H)" [1000])

(* Conjugation commutes with transposition: (conjugate A)⇧T = conjugate (A⇧T). *)
lemma transpose_conjugate:
  shows "(conjugate A)⇧T = A⇧H"
  unfolding conjugate_mat_def
  by auto

(* The zero vector of dimension dim_row A lies in the span of the columns of A.
   The module module_vec TYPE('a) (dim_row A) is written out as a record here,
   because vec_module_col uses the lemma in this unfolded form. *)
lemma vec_module_col_helper:
  fixes A:: "('a :: field) mat"
  shows "(0⇩v (dim_row A) ∈ LinearCombinations.module.span class_ring ⦇carrier = carrier_vec (dim_row A), mult = undefined, one = undefined, zero = 0⇩v (dim_row A), add = (+), smult = (⋅⇩v)⦈ (set (cols A)))"
proof -
  have "⋀v. (0::'a) ⋅⇩v v + v = v"
    by auto
  then show "0⇩v (dim_row A) ∈ LinearCombinations.module.span class_ring ⦇carrier = carrier_vec (dim_row A), mult = undefined, one = undefined, zero = 0⇩v (dim_row A), add = (+), smult = (⋅⇩v)⦈ (set (cols A))"
    by (metis cols_dim module_vec_def right_zero_vec smult_carrier_vec vec_space.prod_in_span zero_carrier_vec)
qed

(* The span of the columns of A (module structure written out as a record) is
   closed under scalar multiplication. The shape matches a goal arising in
   vec_module_col. The distributivity premise is not used by the proof. *)
lemma vec_module_col_helper2:
  fixes A:: "('a :: field) mat"
  shows "⋀a x. x ∈ LinearCombinations.module.span class_ring
                ⦇carrier = carrier_vec (dim_row A), mult = undefined, one = undefined,
                   zero = 0⇩v (dim_row A), add = (+), smult = (⋅⇩v)⦈
                (set (cols A)) ⟹
           (⋀a b v. (a + b) ⋅⇩v v = a ⋅⇩v v + b ⋅⇩v v) ⟹
           a ⋅⇩v x
           ∈ LinearCombinations.module.span class_ring
               ⦇carrier = carrier_vec (dim_row A), mult = undefined, one = undefined,
                  zero = 0⇩v (dim_row A), add = (+), smult = (⋅⇩v)⦈
               (set (cols A))"
proof -
  fix a :: 'a and x :: "'a vec"
  assume "x ∈ LinearCombinations.module.span class_ring ⦇carrier = carrier_vec (dim_row A), mult = undefined, one = undefined, zero = 0⇩v (dim_row A), add = (+), smult = (⋅⇩v)⦈ (set (cols A))"
  then show "a ⋅⇩v x ∈ LinearCombinations.module.span class_ring ⦇carrier = carrier_vec (dim_row A), mult = undefined, one = undefined, zero = 0⇩v (dim_row A), add = (+), smult = (⋅⇩v)⦈ (set (cols A))"
    by (metis (full_types) cols_dim idom_vec.smult_in_span module_vec_def)
qed

(* The vector module of dimension dim_row A, with its carrier restricted to the
   span of the columns of A, is again a module over class_ring. The abelian-group
   part is established first, via an interpretation. *)
lemma vec_module_col: "module (class_ring :: 'a :: field ring)
  (module_vec TYPE('a) 
    (dim_row A)
      ⦇carrier :=
         LinearCombinations.module.span
          class_ring (module_vec TYPE('a) (dim_row A)) (set (cols A))⦈)"
proof -
  interpret abelian_group "module_vec TYPE('a) (dim_row A)
      ⦇carrier :=
         LinearCombinations.module.span
          class_ring (module_vec TYPE('a) (dim_row A)) (set (cols A))⦈"
    apply (unfold_locales)
          apply (auto simp add:module_vec_def)
          apply (metis cols_dim module_vec_def partial_object.select_convs(1) ring.simps(2) vec_vs vectorspace.span_add1)
         apply (metis assoc_add_vec cols_dim module_vec_def vec_space.cV vec_vs vectorspace.span_closed)
    using vec_module_col_helper[of A] apply (auto)
       apply (metis cols_dim left_zero_vec module_vec_def partial_object.select_convs(1) vec_vs vectorspace.span_closed)
      apply (metis cols_dim module_vec_def partial_object.select_convs(1) right_zero_vec vec_vs vectorspace.span_closed)
     apply (metis cols_dim comm_add_vec module_vec_def vec_space.cV vec_vs vectorspace.span_closed)
    (* additive inverses: the negation of a span element stays in the span *)
    unfolding Units_def apply auto
    by (metis (no_types, opaque_lifting) cols_dim comm_add_vec module_vec_def partial_object.select_convs(1) uminus_l_inv_vec vec_space.vec_neg vec_vs vectorspace.span_closed vectorspace.span_neg)
  (* remaining module axioms: scalar distributivity and closure of the span *)
  show ?thesis
    apply (unfold_locales)
    unfolding class_ring_simps apply auto
    unfolding module_vec_simps using add_smult_distrib_vec apply auto
     apply (auto simp add:module_vec_def)
    using vec_module_col_helper2
     apply blast
    by (smt (verit) cols_dim module_vec_def smult_add_distrib_vec vec_space.cV vec_vs vectorspace.span_closed)
qed

(* The span of the columns of a matrix forms a vector space *)
(* Over a field, the column span from vec_module_col is a vector space. *)
lemma vec_vs_col: "vectorspace (class_ring :: 'a :: field ring)
  (module_vec TYPE('a) (dim_row A)
      ⦇carrier :=
         LinearCombinations.module.span
          class_ring
          (module_vec TYPE('a)
            (dim_row A))
          (set (cols A))⦈)"
  unfolding vectorspace_def
  using vec_module_col class_field 
  by (auto simp: class_field_def)

(* Column k of A * B is A applied to column k of B. *)
lemma cols_mat_mul_map:
  shows "cols (A * B) = map ((*⇩v) A) (cols B)"
  unfolding list_eq_iff_nth_eq
  by auto

(* Set version of cols_mat_mul_map. *)
lemma cols_mat_mul:
  shows "set (cols (A * B)) = (*⇩v) A ` set (cols B)"
  by (simp add: cols_mat_mul_map)

(* Any subset S of the elements of ls can be enumerated by a distinct list ss.
   Despite the name, ss is not claimed to be a sublist of ls. *)
lemma set_obtain_sublist:
  assumes "S ⊆ set ls"
  obtains ss where "distinct ss" "S = set ss"
  using assms finite_distinct_list infinite_super by blast

(* Multiplying A by the matrix with columns cs gives the matrix whose columns
   are A *⇩v c for each c in cs. *)
lemma mul_mat_of_cols:
  assumes "A ∈ carrier_mat nr n"
  assumes "⋀j. j < length cs ⟹ cs ! j ∈ carrier_vec n"
  shows "A * (mat_of_cols n cs) = mat_of_cols nr (map ((*⇩v) A) cs)"
  unfolding mat_eq_iff
  using assms apply auto
  apply (subst mat_of_cols_index)
  by auto

(* Reassociation x * (y * z) = y * x * z in a commutative ring.
   Used by cscalar_prod_conjugate_transpose. *)
lemma helper:
  fixes x y z ::"'a :: {conjugatable_ring, comm_ring}"
  shows "x * (y * z) = y * x * z"
  by (simp add: ac_simps)

(* Adjoint identity for the conjugate scalar product:
   x ∙c (A *⇩v y) = (A⇧H *⇩v x) ∙c y. *)
lemma cscalar_prod_conjugate_transpose:
  fixes x y ::"'a :: {conjugatable_ring, comm_ring} vec"
  assumes "A ∈ carrier_mat nr nc"
  assumes "x ∈ carrier_vec nr"
  assumes "y ∈ carrier_vec nc"
  shows "x ∙c (A *⇩v y) = (A⇧H *⇩v x) ∙c y"
  unfolding mult_mat_vec_def scalar_prod_def
  (* expand both sides into double sums ... *)
  using assms apply (auto simp add: sum_distrib_left sum_distrib_right sum_conjugate conjugate_dist_mul)
  (* ... and swap the order of summation *)
  apply (subst sum.swap)
  by (meson helper mult.assoc mult.left_commute sum.cong)

(* If A *⇩v (A⇧H *⇩v v) = 0 then A⇧H *⇩v v = 0. With w = A⇧H *⇩v v we get
   w ∙c w = (A *⇩v w) ∙c v = 0, and a vector with w ∙c w = 0 is zero. *)
lemma mat_mul_conjugate_transpose_vec_eq_0:
  fixes v ::"'a :: {conjugatable_ordered_ring,semiring_no_zero_divisors,comm_ring} vec"
  assumes "A ∈ carrier_mat nr nc"
  assumes "v ∈ carrier_vec nr"
  assumes "A *⇩v (A⇧H *⇩v v) = 0⇩v nr"
  shows "A⇧H *⇩v v = 0⇩v nc"
proof -
  (* move A across the conjugate scalar product (cscalar_prod_conjugate_transpose) *)
  have "(A⇧H *⇩v v) ∙c (A⇧H *⇩v v) = (A *⇩v (A⇧H *⇩v v)) ∙c v"
    by (metis (mono_tags, lifting) Matrix.carrier_vec_conjugate assms(1) assms(2) assms(3) carrier_matD(2) conjugate_zero_vec cscalar_prod_conjugate_transpose dim_row_conjugate index_transpose_mat(2) mult_mat_vec_def scalar_prod_left_zero scalar_prod_right_zero vec_carrier)
  also have "... = 0"
    by (simp add: assms(2) assms(3))
  ultimately have "(A⇧H *⇩v v) ∙c (A⇧H *⇩v v) = 0" by auto
  (* the final step uses conjugate_square_eq_0_vec (w ∙c w = 0 only for w = 0);
     NOTE(review): presumably this is what needs the ordered-ring and
     no-zero-divisor sort constraints on 'a *)
  thus ?thesis
    apply (subst conjugate_square_eq_0_vec[symmetric])
    using assms(1) carrier_dim_vec apply fastforce
    by auto
qed

(* Row i of the matrix with columns ls collects the i-th components of the columns. *)
lemma row_mat_of_cols:
  assumes "i < nr"
  shows "row (mat_of_cols nr ls) i = vec (length ls) (λj. (ls ! j) $i)"
  by (simp add: assms mat_of_cols_index vec_eq_iff)

(* Matrix-vector product with a prepended column: the new column a is scaled by
   the head m of the coefficient vector vCons m v. *)
lemma mat_of_cols_cons_mat_vec:
  fixes v ::"'a::comm_ring vec"
  assumes "v ∈ carrier_vec (length ls)"
  assumes "dim_vec a = nr"
  shows
    "mat_of_cols nr (a # ls) *⇩v (vCons m v) =
   m ⋅⇩v a + mat_of_cols nr ls *⇩v v"
  unfolding mult_mat_vec_def vec_eq_iff
  using assms by
    (auto simp add: row_mat_of_cols vec_Suc o_def mult.commute)

(* Scaling by 0 gives the zero vector of the same dimension. *)
lemma smult_vec_zero:
  fixes v ::"'a::ring vec"
  shows "0 ⋅⇩v v = 0⇩v (dim_vec v)"
  unfolding smult_vec_def vec_eq_iff
  by (auto)

(* Prepending columns ls, together with the same number of zero coefficients,
   does not change a matrix-vector product. Proved by induction on ls. *)
lemma helper2:
  fixes A ::"'a::comm_ring mat"
  fixes v ::"'a vec"
  assumes "v ∈ carrier_vec (length ss)"
  assumes "⋀x. x ∈ set ls ⟹ dim_vec x = nr"
  shows
    "mat_of_cols nr ss *⇩v v =
   mat_of_cols nr (ls @ ss) *⇩v (0⇩v (length ls) @⇩v v)"
  using assms(2)
proof (induction ls)
  case Nil
  then show ?case by auto
next
  case (Cons a ls)
  (* the new head column a contributes 0 ⋅⇩v a *)
  then show ?case apply (auto simp add:zero_vec_Suc)
    apply (subst mat_of_cols_cons_mat_vec)
    by (auto simp add:assms smult_vec_zero)
qed

(* Permuting the columns and the coefficient list by the same permutation f
   leaves the matrix-vector product unchanged. *)
lemma mat_of_cols_mult_mat_vec_permute_list:
  fixes v ::"'a::comm_ring list"
  assumes "f permutes {..<length ss}"
  assumes "length ss = length v"
  shows
    "mat_of_cols nr (permute_list f ss) *⇩v vec_of_list (permute_list f v) =
     mat_of_cols nr ss *⇩v vec_of_list v"
  unfolding mat_of_cols_def mult_mat_vec_def vec_eq_iff scalar_prod_def
proof clarsimp
  fix i
  assume "i < nr"
  (* reindex the i-th component sum along the permutation f *)
  from sum.permute[OF assms(1)]
  have "(∑ia<length ss. ss ! f ia $ i * v ! f ia) =
  sum ((λia. ss ! f ia $ i * v ! f ia) ∘ f) {..<length ss}" .
  also have "... = (∑ia = 0..<length v. ss ! f ia $ i * v ! f ia)"
    using assms(2) calculation lessThan_atLeast0 by auto
  ultimately have *: "(∑ia = 0..<length v.
             ss ! f ia $ i * v ! f ia) =
         (∑ia = 0..<length v.
             ss ! ia $ i * v ! ia)"
    by (metis (mono_tags, lifting) ‹⋀g. sum g {..<length ss} = sum (g ∘ f) {..<length ss}› assms(2) comp_apply lessThan_atLeast0 sum.cong)
  show "(∑ia = 0..<length v.
         vec (length ss) (λj. permute_list f ss ! j $ i) $ ia *
         vec_of_list (permute_list f v) $ ia) =
         (∑ia = 0..<length v. vec (length ss) (λj. ss ! j $ i) $ ia * vec_of_list v $ ia)"
    using assms * by (auto simp add: permute_list_nth vec_of_list_index)
qed

(* Permute the entries of ls at the indices listed in ss to the back (in the order of ss),
   keeping all other entries in their original order in front *)
lemma subindex_permutation:
  assumes "distinct ss" "set ss ⊆ {..<length ls}"
  obtains f where "f permutes {..<length ls}"
    "permute_list f ls = map ((!) ls) (filter (λi. i ∉ set ss) [0..<length ls]) @ map ((!) ls) ss"
proof -
  (* the index list [0..<length ls] and the rearranged index list contain the same
     distinct elements, so they have equal multisets ... *)
  have "set [0..<length ls] = set (filter (λi. i ∉ set ss) [0..<length ls] @ ss)"
    using assms unfolding multiset_eq_iff by auto
  then have "mset [0..<length ls] = mset (filter (λi. i ∉ set ss) [0..<length ls] @ ss)"
    apply (subst set_eq_iff_mset_eq_distinct[symmetric])
    using assms by auto
  (* ... hence so do ls and the rearranged list of entries *)
  then have "mset ls = mset (map ((!) ls)
           (filter (λi. i ∉ set ss)
             [0..<length ls]) @ map ((!) ls) ss)"
    by (metis map_append map_nth mset_map)
  (* equal multisets give a permutation *)
  thus ?thesis
    by (metis mset_eq_permutation that)
qed

(* Converse direction of subindex_permutation: ls itself is a permutation of the
   rearranged list (entries outside ss first, then the entries at ss). *)
lemma subindex_permutation2:
  assumes "distinct ss" "set ss ⊆ {..<length ls}"
  obtains f where "f permutes {..<length ls}"
    "ls = permute_list f (map ((!) ls) (filter (λi. i ∉ set ss) [0..<length ls]) @ map ((!) ls) ss)"
  using subindex_permutation
  by (metis assms(1) assms(2) length_permute_list mset_eq_permutation mset_permute_list)

(* A distinct list ss whose elements all occur in ls can be written as
   map (nth ls) ids for a distinct list ids of valid indices into ls.
   For each element, an index is picked with Hilbert choice (SOME / @). *)
lemma distinct_list_subset_nths:
  assumes "distinct ss" "set ss ⊆ set ls"
  obtains ids where "distinct ids" "set ids ⊆ {..<length ls}" "ss = map ((!) ls) ids"
proof -
  let ?ids = "map (λi. @j. j < length ls ∧ ls!j = i ) ss"
  have 1: "distinct ?ids" unfolding distinct_map
    using assms apply (auto simp add: inj_on_def)
    by (smt (verit) in_set_conv_nth someI subset_eq)
  have 2: "set ?ids ⊆ {..<length ls}"
    using assms apply (auto)
    by (metis (mono_tags, lifting) in_mono in_set_conv_nth tfl_some)
  have 3: "ss = map ((!) ls) ?ids"
    using assms apply (auto simp add: list_eq_iff_nth_eq)
    by (smt (verit, best) in_set_conv_nth someI subsetD)
  show "(⋀ids. distinct ids ⟹
            set ids ⊆ {..<length ls} ⟹
            ss = map ((!) ls) ids ⟹ thesis) ⟹
    thesis" using 1 2 3 by blast
qed

(* A linear combination of distinct columns ss of A equals A *⇩v c for some c of
   length nc. The proof pads the combination with zero coefficients for the other
   columns (helper2), then permutes back into the column order of A. *)
lemma helper3: 
  fixes A ::"'a::comm_ring mat"
  assumes A: "A ∈ carrier_mat nr nc"
  assumes ss:"distinct ss" "set ss ⊆ set (cols A)"
  assumes "v ∈ carrier_vec (length ss)"
  obtains c where "mat_of_cols nr ss *⇩v v = A *⇩v c" "dim_vec c = nc"
proof -
  from distinct_list_subset_nths[OF ss]
  obtain ids where ids: "distinct ids" "set ids ⊆ {..<length (cols A)}"
    and ss: "ss = map ((!) (cols A)) ids" by blast
  (* ?ls: the columns of A that are not selected by ids *)
  let ?ls = " map ((!) (cols A)) (filter (λi. i ∉ set ids) [0..<length (cols A)])"
  from subindex_permutation2[OF ids] obtain f where
    f: "f permutes {..<length (cols A)}"
    "cols A = permute_list f (?ls @ ss)" using ss by blast
  have *: "⋀x. x ∈ set ?ls ⟹ dim_vec x = nr"
    using A by auto
  (* ?cs1: coefficient list with zeros for ?ls followed by v *)
  let ?cs1 = "(list_of_vec (0⇩v (length ?ls) @⇩v v))"
  from helper2[OF assms(4) ]
  have "mat_of_cols nr ss *⇩v v = mat_of_cols nr (?ls @ ss) *⇩v vec_of_list (?cs1)"
    using *
    by (metis vec_list)
  also have "... = mat_of_cols nr (permute_list f (?ls @ ss)) *⇩v vec_of_list (permute_list f ?cs1)"
    apply (auto intro!: mat_of_cols_mult_mat_vec_permute_list[symmetric])
     apply (metis cols_length f(1) f(2) length_append length_map length_permute_list)
    using assms(4) by auto
  also have "... =  A *⇩v vec_of_list (permute_list f ?cs1)" using f(2) assms by auto
  ultimately show
    "(⋀c. mat_of_cols nr ss *⇩v v = A *⇩v c ⟹ dim_vec c = nc ⟹ thesis) ⟹ thesis"
    by (metis A assms(4) carrier_matD(2) carrier_vecD cols_length dim_vec_of_list f(2) index_append_vec(2) index_zero_vec(2) length_append length_list_of_vec length_permute_list)
qed

(* Variant of mat_mul_conjugate_transpose_vec_eq_0 for a combination of a distinct
   subset ss of the columns of A⇧H: if A maps it to 0, it is itself 0. *)
lemma mat_mul_conjugate_transpose_sub_vec_eq_0:
  fixes A ::"'a :: {conjugatable_ordered_ring,semiring_no_zero_divisors,comm_ring} mat"
  assumes "A ∈ carrier_mat nr nc"
  assumes "distinct ss" "set ss ⊆ set (cols (A⇧H))"
  assumes "v ∈ carrier_vec (length ss)"
  assumes "A *⇩v (mat_of_cols nc ss *⇩v v) = 0⇩v nr"
  shows "(mat_of_cols nc ss *⇩v v) = 0⇩v nc"
proof -
  have "A⇧H ∈ carrier_mat nc nr" using assms(1) by auto
  (* write the combination as A⇧H *⇩v c (helper3) ... *)
  from  helper3[OF this assms(2-4)]
  obtain c where c: "mat_of_cols nc ss *⇩v v = A⇧H *⇩v c" "dim_vec c = nr" by blast
  have 1: "c ∈ carrier_vec nr"
    using c carrier_vec_dim_vec by blast
  have 2: "A *⇩v (A⇧H *⇩v c) = 0⇩v nr" using c assms(5) by auto
  (* ... and apply the single-vector version *)
  from mat_mul_conjugate_transpose_vec_eq_0[OF assms(1) 1 2]
  have "A⇧H *⇩v c = 0⇩v nc" .
  thus ?thesis unfolding c(1)[symmetric] .
qed

(* A unit of the matrix ring ring_mat TYPE('a) n b is an invertible matrix. *)
lemma Units_invertible:
  fixes A:: "'a::semiring_1 mat"
  assumes "A ∈ Units (ring_mat TYPE('a) n b)"
  shows "invertible_mat A"
  using assms unfolding Units_def invertible_mat_def
  apply (auto simp add: ring_mat_def)
  using inverts_mat_def by blast

(* Conversely, an invertible matrix is a unit of the matrix ring of its row size. *)
lemma invertible_Units:
  fixes A:: "'a::semiring_1 mat"
  assumes "invertible_mat A"
  shows "A ∈ Units (ring_mat TYPE('a) (dim_row A) b)"
  using assms unfolding Units_def invertible_mat_def
  apply (auto simp add: ring_mat_def)
  by (metis assms carrier_mat_triv invertible_mat_def inverts_mat_def inverts_mat_length(1) inverts_mat_length(2))

(* Over a field, a square matrix is invertible iff its determinant is nonzero. *)
lemma invertible_det:
  fixes A:: "'a::field mat"
  assumes "A ∈ carrier_mat n n"
  shows "invertible_mat A ⟷ det A ≠ 0"
  apply auto
  (* invertible implies unit implies nonzero determinant *)
  using invertible_Units unit_imp_det_non_zero apply fastforce
  (* nonzero determinant implies unit implies invertible *)
  using assms by (auto intro!: Units_invertible det_non_zero_imp_unit)

context vec_space begin

(* In a distinct list, the element ss ! i occurs only at index i. *)
lemma find_indices_distinct:
  assumes "distinct ss"
  assumes "i < length ss"
  shows "find_indices (ss ! i) ss = [i]"
proof -
  (* as a set, the occurrence indices of ss ! i are exactly {i} ... *)
  have "set (find_indices (ss ! i) ss) = {i}"
    using assms apply auto by (simp add: assms(1) assms(2) nth_eq_iff_index_eq)
  (* ... and find_indices returns a distinct list, so it is [i] *)
  thus ?thesis
    by (metis distinct.simps(2) distinct_find_indices empty_iff empty_set insert_iff list.exhaust list.simps(15)) 
qed

(* For a distinct, linearly independent list ss of n-vectors, a vanishing
   lincomb_list f ss forces every coefficient f i (i < length ss) to be 0. *)
lemma lin_indpt_lin_comb_list:
  assumes "distinct ss"
  assumes "lin_indpt (set ss)"
  assumes "set ss ⊆ carrier_vec n"
  assumes "lincomb_list f ss = 0⇩v n"
  assumes "i < length ss"
  shows "f i = 0"
proof -
  (* rewrite the list combination as a combination over set ss, in which the
     coefficient of v is the sum of f over the indices where v occurs in ss *)
  from lincomb_list_as_lincomb[OF assms(3)]
  have "lincomb_list f ss = lincomb (mk_coeff ss f) (set ss)" .
  also have "... = lincomb  (λv. sum f (set (find_indices v ss))) (set ss)"
    unfolding mk_coeff_def
    apply (subst R.sumlist_map_as_finsum)
    by (auto simp add: distinct_find_indices)
  ultimately have "lincomb_list f ss = lincomb  (λv. sum f (set (find_indices v ss))) (set ss)" by auto
  then have *:"lincomb (λv. sum f (set (find_indices v ss))) (set ss) = 0⇩v n"
    using assms(4) by auto
  have "finite (set ss)" by simp
  (* linear independence: every coefficient of this vanishing combination is 0 *)
  from not_lindepD[OF assms(2) this _ _ *]
  have "(λv. sum f (set (find_indices v ss))) ∈ set ss → {0}"
    by auto
  from funcset_mem[OF this]
  have "sum f (set (find_indices (nth ss i) ss)) ∈ {0}"
    using assms(5) by auto
  (* by distinctness, ss ! i occurs only at index i *)
  thus ?thesis unfolding find_indices_distinct[OF assms(1) assms(5)]
    by auto
qed

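(* Note: inside this locale, rank A refers to the rank for the locale parameter n,
   so it coincides with vec_space.rank (dim_row A) A whenever dim_row A = n, e.g.:
lemma foo:
  assumes "dim_row A = n"
  shows "rank A = vec_space.rank (dim_row A) A"
  by (simp add: assms) *)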

(* The column span of A * B is contained in the column span of A. *)
lemma span_mat_mul_subset:
  assumes "A ∈ carrier_mat n d"
  assumes "B ∈ carrier_mat d nc"
  shows "span (set (cols (A * B))) ⊆ span (set (cols A))"
proof -
  (* every combination of the columns of A * B with coefficients v is a
     combination of the columns of A with coefficients B *⇩v vec nc v *)
  have *: "⋀v. ∃ca. lincomb_list v (cols (A * B)) =
              lincomb_list ca  (cols A)"
  proof -
    fix v
    have "lincomb_list v (cols (A * B)) = (A * B) *⇩v vec nc v"
      apply (subst lincomb_list_as_mat_mult)
       apply (metis assms(1) carrier_dim_vec carrier_matD(1) cols_dim index_mult_mat(2) subset_code(1))
      by (metis assms(1) assms(2) carrier_matD(1) carrier_matD(2) cols_length index_mult_mat(2) index_mult_mat(3) mat_of_cols_cols)
    also have "... = A *⇩v (B *⇩v vec nc v)"
      using assms(1) assms(2) by auto
    also have "... = lincomb_list (λi. (B *⇩v vec nc v) $ i) (cols A)"
      apply (subst lincomb_list_as_mat_mult)
      using assms(1) carrier_dim_vec cols_dim apply blast
      by (metis assms(1) assms(2) carrier_matD(1) carrier_matD(2) cols_length dim_mult_mat_vec dim_vec eq_vecI index_vec mat_of_cols_cols)
    ultimately have "lincomb_list v (cols (A * B)) =
              lincomb_list (λi. (B *⇩v vec nc v) $ i) (cols A)" by auto
    thus "∃ca. lincomb_list v (cols (A * B)) = lincomb_list ca (cols A)" by auto
  qed
  (* compare both spans as span_list, i.e. as sets of list combinations *)
  show ?thesis
    apply (subst span_list_as_span[symmetric])
     apply (metis assms(1) carrier_matD(1) cols_dim index_mult_mat(2))
    apply (subst span_list_as_span[symmetric])
    using assms(1) cols_dim apply blast
    by (auto simp add:span_list_def *)
qed

(* Multiplying on the right cannot increase the rank: rank (A * B) ≤ rank A. *)
lemma rank_mat_mul_right:
  assumes "A ∈ carrier_mat n d"
  assumes "B ∈ carrier_mat d nc"
  shows "rank (A * B) ≤ rank A"
proof -
  (* the column span of A * B is a subspace of the column space of A ... *)
  have "subspace class_ring (local.span (set (cols (A*B))))
        (vs (local.span (set (cols A))))"
    unfolding subspace_def
    by (metis assms(1) assms(2) carrier_matD(1) cols_dim index_mult_mat(2) nested_submodules span_is_submodule vec_space.span_mat_mul_subset vec_vs_col)
  (* ... so its dimension is bounded by that of A's (finite-dimensional) column space *)
  from vectorspace.subspace_dim[OF _ this]
  have "vectorspace.dim class_ring
   (vs (local.span (set (cols A)))
    ⦇carrier := local.span (set (cols (A * B)))⦈) ≤
  vectorspace.dim class_ring
      (vs (local.span (set (cols A))))"
    apply auto
    by (metis (no_types) assms(1) carrier_matD(1) fin_dim_span_cols index_mult_mat(2) mat_of_cols_carrier(1) mat_of_cols_cols vec_vs_col)
  thus ?thesis unfolding rank_def
    by auto
qed

(* Zero vectors can be dropped from a sum of n-dimensional vectors. *)
lemma sumlist_drop:
  assumes "⋀v. v ∈ set ls ⟹ dim_vec v = n"
  shows "sumlist ls = sumlist (filter (λv. v ≠ 0⇩v n) ls)"
  using assms
proof (induction ls)
  case Nil
  then show ?case by auto
next
  case (Cons a ls)
  then show ?case using dim_sumlist by auto
qed

(* lincomb_list c s written as a sumlist over the zipped coefficient list
   [c 0, ..., c (length s - 1)] and index list [0..<length s]. *)
lemma lincomb_list_alt:
  shows "lincomb_list c s =
    sumlist (map2 (λi j. i ⋅⇩v s ! j) (map (λi. c i) [0..<length s]) [0..<length s])"
  unfolding lincomb_list_def
  by (smt (verit, ccfv_SIG) length_map map2_map_map map_nth nth_equalityI nth_map)

(* In the zipped form of lincomb_list_alt, terms with coefficient 0 can be
   filtered out (for any index list ls whose entries index into s). *)
lemma lincomb_list_alt2:
  assumes "⋀v. v ∈ set s ⟹ dim_vec v = n"
  assumes "⋀i. i ∈ set ls ⟹ i < length s"
  shows "
    sumlist (map2 (λi j. i ⋅⇩v s ! j) (map (λi. c i) ls) ls) =
    sumlist (map2 (λi j. i ⋅⇩v s ! j) (map (λi. c i) (filter (λi. c i ≠ 0) ls)) (filter (λi. c i ≠ 0) ls))"
  using assms(2)
proof (induction ls)
  case Nil
  then show ?case by auto
next
  case (Cons a s)
  (* a zero coefficient contributes 0 ⋅⇩v _ = 0⇩v n, which is a left unit *)
  then show ?case
    apply auto
    apply (subst smult_l_null)
     apply (simp add: assms(1) carrier_vecI)
    apply (subst left_zero_vec)
     apply (subst sumlist_carrier)
      apply auto
    by (metis (no_types, lifting) assms(1) carrier_dim_vec mem_Collect_eq nth_mem set_filter set_zip_rightD)
qed 

(* A distinct list with the same elements as [a, b], where a ≠ b,
   is either [a, b] or [b, a]. *)
lemma two_set:
  assumes "distinct ls"
  assumes "set ls = set [a,b]"
  assumes "a ≠ b"
  shows "ls = [a,b] ∨ ls = [b,a]"
  apply (cases ls)
  using assms(2) apply auto[1]
proof -
  fix x xs
  assume ls:"ls = x # xs"
  (* ls has at least two elements ... *)
  obtain y ys where xs:"xs = y # ys"
    by (metis (no_types) ‹ls = x # xs› assms(2) assms(3) list.set_cases list.set_intros(1) list.set_intros(2) set_ConsD)
  have 1:"x = a ∨ x = b"
    using ‹ls = x # xs› assms(2) by auto
  have 2:"y = a ∨ y = b"
    using ‹ls = x # xs› ‹xs = y # ys› assms(2) by auto
  (* ... and, being distinct over a two-element set, exactly two *)
  have 3:"ys = []"
    by (metis (no_types) ‹ls = x # xs› ‹xs = y # ys› assms(1) assms(2) distinct.simps(2) distinct_length_2_or_more in_set_member member_rec(2) neq_Nil_conv set_ConsD)
  show "ls = [a, b] ∨ ls = [b, a]" using ls xs 1 2 3 assms
    by auto
qed

(* Filtering [0..<length ls] with a predicate that holds exactly at i and j
   (for distinct valid i and j) leaves [i, j] or [j, i]. *)
lemma filter_disj_inds:
  assumes "i < length ls" "j < length ls" "i ≠ j"
  shows "filter (λia. ia ≠ j ⟶ ia = i) [0..<length ls] = [i, j] ∨
  filter (λia. ia ≠ j ⟶ ia = i) [0..<length ls] = [j,i]"
proof -
  have 1: "distinct (filter (λia. ia = i ∨ ia = j) [0..<length ls])"
    using distinct_filter distinct_upt by blast
  have 2:"set (filter (λia. ia = i ∨ ia = j) [0..<length ls]) = {i, j}"
    using assms by auto
  (* apply two_set, identifying the two equivalent filter predicates *)
  show ?thesis using two_set[OF 1]
    using assms(3) empty_set filter_cong list.simps(15)
    by (smt(verit, ccfv_SIG) "2" assms(3) empty_set filter_cong list.simps(15))
qed

(* If only the trivial lincomb_list of ls vanishes, then ls is distinct.
   A repeated entry ls ! i = ls ! j with i ≠ j would give the nontrivial
   vanishing combination with coefficient 1 at i and -1 at j. *)
lemma lincomb_list_indpt_distinct:
  assumes "⋀v. v ∈ set ls ⟹ dim_vec v = n"
  assumes
    "⋀c. lincomb_list c ls = 0⇩v n ⟹ (∀i. i < (length ls) ⟶ c i = 0)"
  shows "distinct ls"
  unfolding distinct_conv_nth
proof clarsimp
  fix i j
  assume ij: "i < length ls" "j < length ls" "i ≠ j" 
  assume lsij: "ls ! i = ls ! j"
  (* the combination with coefficients 1 at i, -1 at j and 0 elsewhere *)
  have "lincomb_list (λk. if k = i then 1 else if k = j then -1 else 0) ls =
     (ls ! i) - (ls ! j)"
    unfolding lincomb_list_alt
    apply (subst lincomb_list_alt2[OF assms(1)])
      apply auto
    using  filter_disj_inds[OF ij]
    apply auto
    using ij(3) apply force
    using assms(1) ij(2) apply auto[1]
    using ij(3) apply blast
    using assms(1) ij(2) by auto
  also have "...  = 0⇩v n" unfolding lsij
    apply (rule minus_cancel_vec)
    using ‹j < length ls› assms(1)
    using carrier_vec_dim_vec nth_mem by blast
  ultimately have "lincomb_list (λk. if k = i then 1 else if k = j then -1 else 0) ls = 0⇩v n" by auto
  (* but its coefficient at i is 1, contradicting assms(2) *)
  from assms(2)[OF this]
  show False
    using ‹i < length ls› by auto
qed

end

(* A vec_space context (vectors of dimension n) whose scalars form a
   conjugatable ordered field.  The lemmas below relate ranks of matrices
   to the ranks of their conjugates and conjugate transposes. *)
locale conjugatable_vec_space = vec_space f_ty n for
  f_ty::"'a::conjugatable_ordered_field itself"
  and n
begin                                                           

(* For A with n rows and nc columns, the rank of its conjugate transpose AH
   (computed in dimension nc) is bounded by the rank of A * AH.
   Proof idea: pick a maximal linearly independent set S of columns of AH.
   Multiplying by A sends S to columns of A * AH (4); this is injective on S
   (3) and keeps the images linearly independent (5).  Hence
   rank AH = card S <= rank (A * AH). *)
lemma transpose_rank_mul_conjugate_transpose:
  fixes A :: "'a mat"
  assumes "A  carrier_mat n nc"
  shows "vec_space.rank nc AH  rank (A * AH)"
proof -
  have 1: "AH  carrier_mat nc n" using assms by auto
  have 2: "A * AH  carrier_mat n n" using assms by auto
      (* S is a maximal linearly independent subset of the columns of AH *)
  let ?P = "(λT. T  set (cols AH)  module.lin_indpt class_ring (module_vec TYPE('a) nc) T)"
  (* candidate sets are finite with cardinality bounded by n,
     which is what maximal_exists needs *)
  have *:"A. ?P A  finite A  card A  n"
  proof clarsimp
    fix S
    assume S: "S  set (cols AH)"
    have "card S  card (set (cols AH))" using S
      using card_mono by blast
    also have "...  length (cols AH)" using card_length by blast
    also have "...  n" using assms by auto
    ultimately show "finite S  card S  n"
      by (meson List.finite_set S dual_order.trans finite_subset)
  qed
  have **:"?P {}"
    apply (subst module.lin_dep_def)
    by (auto simp add: vec_module)
  from maximal_exists[OF *]
  obtain S where S: "maximal S ?P" using **
    by (metis (no_types, lifting)) 
      (* Some properties of S *)
  from vec_space.rank_card_indpt[OF 1 S]
  have rankeq: "vec_space.rank nc AH = card S" .

  have s_hyp: "S  set (cols AH)"
    using S unfolding maximal_def by simp
  have modhyp: "module.lin_indpt class_ring (module_vec TYPE('a) nc) S" 
    using S unfolding maximal_def by simp

(* switch to a list representation *)
  obtain ss where ss: "set ss = S" "distinct ss"
    by (metis (mono_tags) S maximal_def set_obtain_sublist)
  have ss2: "set (map ((*v) A) ss) = (*v) A ` S"
    by (simp add: ss(1))
  have rw_hyp: "cols (mat_of_cols n (map ((*v) A) ss)) = cols  (A * mat_of_cols nc ss)" 
    unfolding cols_def apply (auto)
    using mat_vec_as_mat_mat_mult[of A n nc]
    by (metis (no_types, lifting) "1" assms carrier_matD(1) cols_dim mul_mat_of_cols nth_mem s_hyp ss(1) subset_code(1))
  then have rw: "mat_of_cols n (map ((*v) A) ss) = A * mat_of_cols nc ss"
    by (metis assms carrier_matD(1) index_mult_mat(2) mat_of_cols_carrier(2) mat_of_cols_cols) 
  (* the list of images A *v s (s in ss) is linearly independent as a list *)
  have indpt: "c. lincomb_list c (map ((*v) A) ss) = 0v n 
      i. (i < (length ss)  c i = 0)"
  proof clarsimp
    fix c i
    assume *: "lincomb_list c (map ((*v) A) ss) = 0v n"
    assume i: "i < length ss"
    have "wset (map ((*v) A) ss). dim_vec w = n"
      using assms by auto
    from lincomb_list_as_mat_mult[OF this]
    have "A * mat_of_cols nc ss *v  vec (length ss) c = 0v n"
      using * rw by auto
    then have hq: "A *v (mat_of_cols nc ss *v vec (length ss) c) =  0v n"
      by (metis assms assoc_mult_mat_vec mat_of_cols_carrier(1) vec_carrier)

    (* A annihilates this combination of columns of AH, so it is zero
       (mat_mul_conjugate_transpose_sub_vec_eq_0) *)
    then have eq1: "(mat_of_cols nc ss *v vec (length ss) c) =  0v nc"
      apply (intro mat_mul_conjugate_transpose_sub_vec_eq_0)
      using assms ss s_hyp by auto

(* Rewrite the inner vector back to a lincomb_list *)
    have dim_hyp2: "wset ss. dim_vec w = nc"
      using ss(1) s_hyp
      by (metis "1" carrier_matD(1) carrier_vecD cols_dim subsetD) 
    from vec_module.lincomb_list_as_mat_mult[OF this, symmetric]
    have "mat_of_cols nc ss *v vec (length ss) c = module.lincomb_list (module_vec TYPE('a) nc) c ss" .
    then have "module.lincomb_list (module_vec TYPE('a) nc) c ss = 0v nc" using eq1 by auto
    from vec_space.lin_indpt_lin_comb_list[OF ss(2) _ _ this i]
    show "c i = 0" using modhyp ss s_hyp
      using "1" cols_dim by blast
  qed
  (* independence implies the images are pairwise different, so multiplying
     by A does not change the cardinality of S *)
  have distinct: "distinct (map ((*v) A) ss)"
    by (metis (no_types, lifting) assms carrier_matD(1) dim_mult_mat_vec imageE indpt length_map lincomb_list_indpt_distinct ss2)
  then have 3: "card S = card ((*v) A ` S)"
    by (metis ss distinct_card image_set length_map)
  then have 4: "(*v) A ` S  set (cols (A * AH))"
    using cols_mat_mul S  set (cols AH) by blast
  (* transfer list-independence back to set-independence *)
  have 5: "lin_indpt ((*v) A ` S)"
  proof clarsimp
    assume ld:"lin_dep ((*v) A ` S)"
    have *: "finite ((*v) A ` S)"
      by (metis List.finite_set ss2)
    have **: "(*v) A ` S  carrier_vec n"
      using "2" "4" cols_dim by blast
    from finite_lin_dep[OF * ld **]
    obtain a v where
      a: "lincomb a ((*v) A ` S) = 0v n" and
      v: "v  (*v) A ` S" "a v  0" by blast
    obtain i where i:"v = map ((*v) A) ss ! i" "i < length ss"
      using v unfolding ss2[symmetric]
      using find_first_le nth_find_first by force
    from ss2[symmetric]
    have "set (map ((*v) A) ss) carrier_vec n" using ** ss2 by auto
    from lincomb_as_lincomb_list_distinct[OF this distinct] have
      "lincomb_list
     (λi. a (map ((*v) A) ss ! i))  (map ((*v) A) ss) = 0v n"
      using a ss2 by auto
    from indpt[OF this]
    show False using v i by simp
  qed
  from rank_ge_card_indpt[OF 2 4 5]
  have "card ((*v) A ` S)  rank (A * AH)" .
  thus ?thesis using rankeq 3 by linarith
qed

(* The rank of the conjugate transpose AH is at most the rank of A.
   Obtained from transpose_rank_mul_conjugate_transpose together with
   rank_mat_mul_right applied to the product A * AH. *)
lemma conjugate_transpose_rank_le:
  assumes "A  carrier_mat n nc"
  shows "vec_space.rank nc (AH)  rank A"
  by (metis assms carrier_matD(2) carrier_mat_triv dim_row_conjugate dual_order.trans index_transpose_mat(2) rank_mat_mul_right transpose_rank_mul_conjugate_transpose)

(* Conjugation commutes with finite vector sums:
   conjugate (sum of f over U) = sum of (conjugate o f) over U,
   for f mapping U into carrier_vec n.  Proved by finite-set induction;
   the infinite and empty cases close by auto, and the insert step uses
   conjugate_add_vec. *)
lemma conjugate_finsum:
  assumes f: "f : U  carrier_vec n"
  shows "conjugate (finsum V f U) = finsum V (conjugate  f) U"
  using f
proof (induct U rule: infinite_finite_induct)
  case (infinite A)
  then show ?case by auto
next
  case empty
  then show ?case by auto
next
  case (insert u U)
  hence f: "f : U  carrier_vec n" "f u : carrier_vec n"  by auto
  (* conjugating keeps the summands inside carrier_vec n *)
  then have cf: "conjugate  f : U  carrier_vec n"
    "(conjugate  f) u : carrier_vec n"
     apply (simp add: Pi_iff)
    by (simp add: f(2))
  then show ?case
    unfolding finsum_insert[OF insert(1) insert(2) f]
    unfolding finsum_insert[OF insert(1) insert(2) cf ]
    apply (subst conjugate_add_vec[of _ n])
    using f(2) apply blast
    using M.finsum_closed f(1) apply blast
    by (simp add: comp_def f(1) insert.hyps(3))
qed

(* The rank of conjugate A is at most the rank of A.
   Proof: take a maximal linearly independent set S of columns of
   conjugate A.  Then conjugate ` S is a set of columns of A (1), is
   linearly independent (2) and has the same cardinality as S (4). *)
lemma rank_conjugate_le:
  assumes A:"A  carrier_mat n d"
  shows "rank (conjugate (A))  rank A"
proof -
  (* S is a maximal linearly independent set of columns of (conjugate A) *)
  let ?P = "(λT. T  set (cols (conjugate A))  lin_indpt T)"
  have *:"A. ?P A  finite A  card A  d"
    by (metis List.finite_set assms card_length card_mono carrier_matD(2) cols_length dim_col_conjugate dual_order.trans rev_finite_subset)
  have **:"?P {}"
    by (simp add: finite_lin_indpt2)
  from maximal_exists[OF *]
  obtain S where S: "maximal S ?P" using **
    by (metis (no_types, lifting))
  have s_hyp: "S  set (cols (conjugate A))" "lin_indpt S"
    using S unfolding maximal_def
     apply blast
    by (metis (no_types, lifting) S maximal_def)
  from rank_card_indpt[OF _ S, of d]
  have rankeq: "rank (conjugate A) = card S" using assms by auto 
  have 1:"conjugate ` S  set (cols A)"
    using S apply auto
    by (metis (no_types, lifting) cols_conjugate conjugate_id image_eqI in_mono list.set_map s_hyp(1))
  (* a dependency among conjugate ` S would, after conjugating the
     coefficients and vectors back, give a dependency among S *)
  have 2: "lin_indpt (conjugate ` S)"
    apply (rule ccontr)
    apply (auto simp add: lin_dep_def)
  proof -
    fix T c v
    assume T: "T  conjugate ` S" "finite T" and
      lc:"lincomb c T = 0v n" and "v  T"  "c v  0"
    let ?T = "conjugate ` T"
    let ?c = "conjugate  c  conjugate"
    have 1: "finite ?T"  using T by auto
    have 2: "?T  S"  using T by auto
    have 3: "?c  ?T  UNIV" by auto
    have "lincomb ?c ?T = (VxT. conjugate (c x) v conjugate x)"
      unfolding lincomb_def
      apply (subst finsum_reindex)
        apply auto
       apply (metis "2" carrier_vec_conjugate assms carrier_matD(1) cols_dim image_eqI s_hyp(1) subsetD)
      by (meson conjugate_cancel_iff inj_onI)
    also have "... = (VxT. conjugate (c x v x)) "
      by (simp add: conjugate_smult_vec)
    also have "... = conjugate (VxT. (c x v x))"
      apply(subst conjugate_finsum[of "λx.(c x v x)" T])
       apply (auto simp add:o_def)
      by (smt (verit, ccfv_SIG) Matrix.carrier_vec_conjugate Pi_I' T(1) assms carrier_matD(1) cols_dim dim_row_conjugate imageE s_hyp(1) smult_carrier_vec subset_eq) 
    also have "... = conjugate (lincomb c T)"
      using lincomb_def by presburger
    ultimately have "lincomb ?c ?T = conjugate (lincomb c T)" by auto
    then have 4:"lincomb ?c ?T = 0v n" using lc by auto
    from not_lindepD[OF s_hyp(2) 1 2 3 4]
    have "conjugate  c  conjugate  conjugate ` T  {0}" .
    then have "c v = 0"
      by (simp add: Pi_iff v  T)
    thus False using c v  0 by auto
  qed
  from rank_ge_card_indpt[OF A 1 2]
  have 3:"card (conjugate ` S)  rank A" .
  (* conjugate is injective, so it preserves cardinality *)
  have 4: "card (conjugate ` S) = card S"
    apply (auto intro!: card_image)
    by (meson conjugate_cancel_iff inj_onI)
  show ?thesis using rankeq 3 4 by auto
qed

(* Conjugation preserves rank: rank_conjugate_le applied to A and to
   conjugate A (using conjugate_id) gives both inequalities. *)
lemma rank_conjugate:
  assumes "A  carrier_mat n d"
  shows "rank (conjugate A) = rank A"
  using  rank_conjugate_le
  by (metis carrier_vec_conjugate assms conjugate_id dual_order.antisym)

end (* exit the context *)

(* Outside the locale: a matrix and its conjugate transpose AH have the same
   rank.  conjugate_transpose_rank_le is applied in both directions, using
   the fact that conjugate transposition is an involution. *)
lemma conjugate_transpose_rank:
  fixes A::"'a::{conjugatable_ordered_field} mat"
  shows "vec_space.rank (dim_row A) A = vec_space.rank (dim_col A) (AH)"
  using  conjugatable_vec_space.conjugate_transpose_rank_le
  by (metis (no_types, lifting) Matrix.transpose_transpose carrier_matI conjugate_id dim_col_conjugate dual_order.antisym index_transpose_mat(2) transpose_conjugate)

(* A matrix and its transpose AT have the same rank; follows from
   conjugate_transpose_rank and rank_conjugate. *)
lemma transpose_rank:
  fixes A::"'a::{conjugatable_ordered_field} mat"
  shows "vec_space.rank (dim_row A) A = vec_space.rank (dim_col A) (AT)"
  by (metis carrier_mat_triv conjugatable_vec_space.rank_conjugate conjugate_transpose_rank index_transpose_mat(2))

(* rank (A * B) is bounded by rank B.  Derived from rank_mat_mul_right by
   transposing the product (transpose_mult) and using that transposition
   and conjugation preserve rank. *)
lemma rank_mat_mul_left:
  fixes A::"'a::{conjugatable_ordered_field} mat"
  assumes "A  carrier_mat n d"
  assumes "B  carrier_mat d nc"
  shows "vec_space.rank n (A * B)  vec_space.rank d B"
  by (metis (no_types, lifting) Matrix.transpose_transpose assms(1) assms(2) carrier_matD(1) carrier_matD(2) carrier_mat_triv conjugatable_vec_space.rank_conjugate conjugate_transpose_rank index_mult_mat(3) index_transpose_mat(3) transpose_mult vec_space.rank_mat_mul_right)

section "Results on Invertibility"

(* take_cols A inds: the matrix (with dim_row A rows) whose columns are the
   columns of A at the positions listed in inds, in that order.  Indices
   that are not below dim_col A are dropped by the filter. *)
definition take_cols :: "'a mat  nat list  'a mat"
  where "take_cols A inds = mat_of_cols (dim_row A) (map ((!) (cols A)) (filter ((>) (dim_col A)) inds))"

(* Variant of take_cols without index filtering: every index in inds is used
   directly with (!).  NOTE(review): out-of-range indices are not removed,
   so callers presumably must ensure all indices are below dim_col A. *)
definition take_cols_var :: "'a mat  nat list  'a mat"
  where "take_cols_var A inds = mat_of_cols (dim_row A) (map ((!) (cols A)) (inds))"

(* take_rows A inds: the row analogue of take_cols.  It keeps the rows of A
   at the positions in inds (in that order), dropping indices not below
   dim_row A; the result has dim_col A columns. *)
definition take_rows :: "'a mat  nat list  'a mat"
  where "take_rows A inds = mat_of_rows (dim_col A) (map ((!) (rows A)) (filter ((>) (dim_row A)) inds))"

(* Congruence rule for mat_of_cols (equal column lists give equal matrices);
   used as an intro rule in take_cols_mat_mul. *)
lemma cong1:
  "x = y   mat_of_cols n x = mat_of_cols n y"
  by auto

(* Every valid index into filter P ls picks an element satisfying P. *)
lemma nth_filter:
  assumes "j < length (filter P ls)"
  shows "P  ((filter P ls) ! j)"
  by (simp add: assms list_ball_nth)

(* Left multiplication commutes with column selection:
   A * take_cols B inds = take_cols (A * B) inds. *)
lemma take_cols_mat_mul:
  assumes "A  carrier_mat nr n"
  assumes "B  carrier_mat n nc"
  shows "A * take_cols B inds = take_cols (A * B) inds"
proof -
  (* every selected column of B has dimension n, as mul_mat_of_cols requires *)
  have "j. j < length (map ((!) (cols B)) (filter ((>) nc) inds)) 
      (map ((!) (cols B)) (filter ((>) nc) inds)) ! j  carrier_vec n"
    using assms apply auto
    apply (subst cols_nth)
    using nth_filter by auto
  from mul_mat_of_cols[OF assms(1) this]
  have "A *  take_cols B inds = mat_of_cols nr (map (λx. A *v cols B ! x) (filter ((>) (dim_col B)) inds))"
    unfolding take_cols_def using assms by (auto simp add: o_def)
  also have "... = take_cols (A * B) inds"
    unfolding take_cols_def using assms by (auto intro!: cong1)
  ultimately show ?thesis by auto
qed

(* take_cols keeps the row count nr; the column count is some n. *)
lemma take_cols_carrier_mat:
  assumes "A  carrier_mat nr nc"
  obtains n where "take_cols A inds  carrier_mat nr n"
  unfolding take_cols_def
  using assms
  by fastforce

(* When every index in inds is a valid column index, nothing is filtered
   out, so take_cols A inds has exactly length inds columns. *)
lemma take_cols_carrier_mat_strict:
  assumes "A  carrier_mat nr nc"
  assumes "i. i  set inds  i < nc"
  shows "take_cols A inds  carrier_mat nr (length inds)"
  unfolding take_cols_def
  using assms by auto

(* Running gauss_jordan on A with right-hand side take_cols A inds produces
   (C, D) with D = take_cols C inds.  By gauss_jordan_transform both outputs
   are the same invertible P times the inputs, and take_cols commutes with
   left multiplication (take_cols_mat_mul). *)
lemma gauss_jordan_take_cols:  
  assumes "gauss_jordan A (take_cols A inds) = (C,D)"
  shows "D = take_cols C inds"
proof -
  obtain nr nc where A: "A   carrier_mat nr nc" by auto
  from take_cols_carrier_mat[OF this]
  obtain n where B: "take_cols A inds  carrier_mat nr n" by auto
  from gauss_jordan_transform[OF A B assms, of undefined]
  obtain P where PP:"PUnits (ring_mat TYPE('a) nr undefined)" and
    CD: "C = P * A" "D = P * take_cols A inds" by blast
  (* units of the matrix ring are nr x nr matrices *)
  have P: "P  carrier_mat nr nr"
    by (metis (no_types, lifting) Units_def PP mem_Collect_eq partial_object.select_convs(1) ring_mat_def)
  from take_cols_mat_mul[OF P A]
  have "P * take_cols A inds = take_cols (P * A) inds" .
  thus ?thesis using CD by blast  
qed

(* Column count of take_cols when all indices are valid column indices. *)
lemma dim_col_take_cols:
  assumes "j. j  set inds  j < dim_col A"
  shows "dim_col (take_cols A inds) = length inds"
  unfolding take_cols_def
  using assms by auto

(* take_rows does not change the number of columns. *)
lemma dim_col_take_rows[simp]:
  shows "dim_col (take_rows A inds) = dim_col A"
  unfolding take_rows_def by auto

(* Every column of take_cols A inds is a column of A. *)
lemma cols_take_cols_subset:
  shows "set (cols (take_cols A inds))  set (cols A)"
  unfolding take_cols_def
  apply (subst cols_mat_of_cols)
   apply auto
  using in_set_conv_nth by fastforce

(* take_cols does not change the number of rows. *)
lemma dim_row_take_cols[simp]:
  shows "dim_row (take_cols A ls) = dim_row A"
  by (simp add: take_cols_def)

(* Stacking A above B adds the row counts. *)
lemma dim_row_append_rows[simp]:
  shows "dim_row (A @r B) = dim_row A + dim_row B"
  by (simp add: append_rows_def)

(* Two matrices with the same number of columns and the same list of rows
   are equal. *)
lemma rows_inj:
  assumes "dim_col A = dim_col B"
  assumes "rows A = rows B"
  shows "A = B"
  unfolding mat_eq_iff
  apply auto
    apply (metis assms(2) length_rows)
  using assms(1) apply blast
  by (metis assms(1) assms(2) mat_of_rows_rows)

(* Entry (i, j) of A stacked above B: taken from A for i < dim_row A,
   otherwise from B at row i - dim_row A. *)
lemma append_rows_index:
  assumes "dim_col A = dim_col B"
  assumes "i < dim_row A + dim_row B"
  assumes "j < dim_col A"
  shows "(A @r B) $$ (i,j) = (if i < dim_row A then A $$ (i,j) else B $$ (i-dim_row A,j))"
  unfolding append_rows_def
  apply (subst index_mat_four_block)
  using assms by auto

(* Row-level version of append_rows_index. *)
lemma row_append_rows:
  assumes "dim_col A = dim_col B"
  assumes "i < dim_row A + dim_row B"
  shows "row (A @r B) i = (if i < dim_row A then row A i else row B (i-dim_row A))"
  unfolding vec_eq_iff
  using assms by (auto simp add: append_rows_def)

(* Right multiplication distributes over row stacking:
   (A @r B) * C = (A * C) @r (B * C). *)
lemma append_rows_mat_mul:
  assumes "dim_col A = dim_col B"
  shows "(A @r B) * C = A * C @r B * C"
  unfolding mat_eq_iff
  apply auto
   apply (simp add: append_rows_def)
  apply (subst index_mult_mat)
    apply auto
   apply (simp add: append_rows_def)
  apply (subst  append_rows_index)
     apply auto
    apply (simp add: append_rows_def)
  (* the remaining two subgoals are the upper block (rows of A) and
     the lower block (rows of B) *)
   apply (metis add.right_neutral append_rows_def assms index_mat_four_block(3) index_mult_mat(1) index_mult_mat(3) index_zero_mat(3) row_append_rows trans_less_add1)
  by (metis add_cancel_right_right add_diff_inverse_nat append_rows_def assms index_mat_four_block(3) index_mult_mat(1) index_mult_mat(3) index_zero_mat(3) nat_add_left_cancel_less row_append_rows)

(* Cardinality bound for the set of naturals below n. *)
lemma cardlt:
  shows "card  {i. i < (n::nat)}  n"
  by simp

(* For A in row echelon form (under the dimension assumption dim_asm),
   A is its first length (pivot_positions A) rows stacked above a zero
   block: every row after the pivot rows is zero.
   h1 handles the top block, h2 the bottom block, and h3/h4 the
   dimensions; mat_eq_iff then finishes the proof. *)
lemma row_echelon_form_zero_rows:
  assumes row_ech: "row_echelon_form A"
  assumes dim_asm: "dim_col A  dim_row A"
  shows "take_rows A [0..<length (pivot_positions A)] @r  0m (dim_row A - length (pivot_positions A))  (dim_col A) = A"
proof -
  have ex_pivot_fun: " f. pivot_fun A f (dim_col A)" using row_ech unfolding row_echelon_form_def by auto
  (* the number of pivots equals the number of nonzero rows *)
  have len_help: "length (pivot_positions A) = card {i. i < dim_row A  row A i  0v (dim_col A)}"
    using ex_pivot_fun pivot_positions[where A = "A",where nr = "dim_row A", where nc = "dim_col A"]
    by auto
  then have len_help2: "length (pivot_positions A)  dim_row A"
    by (metis (no_types, lifting) card_mono cardlt finite_Collect_less_nat le_trans mem_Collect_eq subsetI)
  (* hence the row filter inside take_rows removes nothing *)
  have fileq: "filter (λy. y < dim_row A) [0..< length (pivot_positions A)] = [0..<length (pivot_positions A)]"
    apply (rule filter_True)
    using len_help2 by auto
  (* NOTE(review): this derivation of pivot_len re-proves the same
     statement as len_help2 above *)
  have "n. card {i. i < n   row A i  0v (dim_col A)}  n"
  proof clarsimp 
    fix n
    have h: "x. x  {i. i < n  row A i  0v (dim_col A)}  x{..<n}"
      by simp
    then have h1: "{i. i < n   row A i  0v (dim_col A)}  {..<n}"
      by blast
    then have h2: "(card {i. i < n   row A i  0v (dim_col A)}::nat)  (card {..<n}::nat)"
      using card_mono by blast 
    then show "(card {i. i < n  row A i  0v (dim_col A)}::nat)  (n::nat)" using h2 card_lessThan[of n]
      by auto
  qed
  then have pivot_len: "length (pivot_positions A)  dim_row A "  using len_help
    by simp
  have alt_char: "mat_of_rows (dim_col A)
         (map ((!) (rows A)) (filter (λy. y < dim_col A) [0..<length (pivot_positions A)])) = 
      mat_of_rows (dim_col A) (map ((!) (rows A))  [0..<length (pivot_positions A)])"
    using pivot_len dim_asm
    by auto
  (* top block: entries of the selected rows agree with A *)
  have h1: "i j. i < dim_row A 
           j < dim_col A 
           i < dim_row (take_rows A [0..<length (pivot_positions A)]) 
           take_rows A [0..<length (pivot_positions A)] $$ (i, j) = A $$ (i, j)"
  proof - 
    fix i 
    fix j
    assume "i < dim_row A"
    assume j_lt: "j < dim_col A"
    assume i_lt: "i < dim_row (take_rows A [0..<length (pivot_positions A)])" 
    have lt: "length (pivot_positions A)  dim_row A" using pivot_len by auto
    have h1: "take_rows A [0..<length (pivot_positions A)] $$ (i, j) = (row (take_rows A [0..<length (pivot_positions A)]) i)$j"
      by (simp add: i_lt j_lt)
    then have h2: "(row (take_rows A [0..<length (pivot_positions A)]) i)$j = (row A i)$j"
      using lt alt_char i_lt unfolding take_rows_def by auto
    show "take_rows A [0..<length (pivot_positions A)] $$ (i, j) = A $$ (i, j)"
      using h1 h2
      by (simp add: i < dim_row A j_lt) 
  qed
  let ?nc = "dim_col A"
  let ?nr = "dim_row A"
  (* bottom block: rows at or past the pivot count are zero in A *)
  have h2: "i j. i < dim_row A 
           j < dim_col A 
           ¬ i < dim_row (take_rows A [0..<length (pivot_positions A)]) 
           0m (dim_row A - length (pivot_positions A)) (dim_col A) $$
           (i - dim_row (take_rows A [0..<length (pivot_positions A)]), j) =
           A $$ (i, j)"
  proof - 
    fix i
    fix j
    assume lt_i: "i < dim_row A"
    assume lt_j: "j < dim_col A"
    assume not_lt: "¬ i < dim_row (take_rows A [0..<length (pivot_positions A)])"
    let ?ip = "i+1"
    (* the pivot function maps such a row i to dim_col A (no pivot) *)
    have h0: "f. pivot_fun A f (dim_col A)  f i = ?nc"
    proof -  
      have half1: "f. pivot_fun A f (dim_col A)" using assms unfolding row_echelon_form_def
        by blast
      have half2: "f. pivot_fun A f (dim_col A)  f i = ?nc " 
      proof clarsimp
        fix f
        assume is_piv: "pivot_fun A f (dim_col A)"
        have len_pp: "length (pivot_positions A) = card {i. i < ?nr  row A i  0v ?nc}" using is_piv pivot_positions[of A ?nr ?nc f]
          by auto
        have  "i. (i < ?nr  row A i  0v ?nc)   (i < ?nr  f i  ?nc)"
          using is_piv pivot_fun_zero_row_iff[of A f ?nc ?nr]
          by blast
        then have len_pp_var: "length (pivot_positions A) = card {i. i < ?nr  f i  ?nc}" 
          using len_pp  by auto 
        have allj_hyp: "j < ?nr. f j = ?nc  ((Suc j) < ?nr  f (Suc j) = ?nc)" 
          using is_piv unfolding pivot_fun_def 
          using lt_i
          by (metis le_antisym le_less) 
        (* if row i has a pivot, so do all rows j up to i *)
        have if_then_bad: "f i  ?nc  (j. j  i  f j  ?nc)"
        proof clarsimp 
          fix j
          assume not_i: "f i  ?nc"
          assume j_leq: "j  i"
          assume bad_asm: "f j = ?nc"
          (* once f j = nc, every later row k < nr also has f k = nc,
             propagated upward from j via allj_hyp (inc_induct) *)
          have "k. k  j   k < ?nr  f k = ?nc"
          proof -
            fix k :: nat
            assume a1: "j  k"
            assume a2: "k < dim_row A"
            have f3: "n. ¬ n < dim_row A  f n  f j  ¬ Suc n < dim_row A  f (Suc n) = f j"
              using allj_hyp bad_asm by presburger
            obtain nn :: "nat  nat  (nat  bool)  nat" where
              f4: "n na p nb nc. (¬ n  na  Suc n  Suc na)  (¬ p nb  ¬ nc  nb  ¬ p (nn nc nb p)  p nc)  (¬ p nb  ¬ nc  nb  p nc  p (Suc (nn nc nb p)))"
              using inc_induct by (metis Suc_le_mono)
            then have f5: "p. ¬ p k  p j  p (Suc (nn j k p))"
              using a1 by presburger
            have f6: "p. ¬ p k  ¬ p (nn j k p)  p j"
              using f4 a1 by meson
            { assume "nn j k (λn. n < dim_row A  f n  dim_col A) < dim_row A  f (nn j k (λn. n < dim_row A  f n  dim_col A))  dim_col A"
              moreover
              { assume "(nn j k (λn. n < dim_row A  f n  dim_col A) < dim_row A  f (nn j k (λn. n < dim_row A  f n  dim_col A))  dim_col A)  (¬ j < dim_row A  f j = dim_col A)"
                then have "¬ k < dim_row A  f k = dim_col A"
                  using f6
                  by (metis (mono_tags, lifting)) }
              ultimately have "(¬ j < dim_row A  f j = dim_col A)  (¬ Suc (nn j k (λn. n < dim_row A  f n  dim_col A)) < dim_row A  f (Suc (nn j k (λn. n < dim_row A  f n  dim_col A))) = dim_col A)  ¬ k < dim_row A  f k = dim_col A"
                using bad_asm
                by blast }
            moreover
            { assume "(¬ j < dim_row A  f j = dim_col A)  (¬ Suc (nn j k (λn. n < dim_row A  f n  dim_col A)) < dim_row A  f (Suc (nn j k (λn. n < dim_row A  f n  dim_col A))) = dim_col A)"
              then have "¬ k < dim_row A  f k = dim_col A"
                using f5
              proof -
                have "¬ (Suc (nn j k (λn. n < dim_row A  f n  dim_col A)) < dim_row A  f (Suc (nn j k (λn. n < dim_row A  f n  dim_col A)))  dim_col A)  ¬ (j < dim_row A  f j  dim_col A)"
                  using (¬ j < dim_row A  f j = dim_col A)  (¬ Suc (nn j k (λn. n < dim_row A  f n  dim_col A)) < dim_row A  f (Suc (nn j k (λn. n < dim_row A  f n  dim_col A))) = dim_col A) by linarith
                then have "¬ (k < dim_row A  f k  dim_col A)"
                  by (metis (mono_tags, lifting) a2 bad_asm f5 le_less)
                then show ?thesis
                  by meson
              qed }
            ultimately show "f k = dim_col A"
              using f3 a2 by (metis (lifting) Suc_lessD bad_asm)
          qed
          then show "False" using lt_i not_i
            using j_leq by blast 
        qed
        (* so a pivot in row i would give at least i+1 pivot rows,
           contradicting i >= length (pivot_positions A) *)
        have "f i  ?nc  ({0..<?ip}  {y. y < ?nr  f y  dim_col A})"
        proof -
          have h1: "f i  dim_col A  (ji. j < ?nr  f j  dim_col A)"
            using if_then_bad lt_i by auto
          then show ?thesis by auto
        qed
        then have gteq: "f i  ?nc  (card {i. i < ?nr  f i  dim_col A}  (i+1))"
          using card_lessThan[of ?ip] card_mono[where B = "{i. i < ?nr  f i  dim_col A} ", where A = "{0..<?ip}"]
          by auto
        then have clear: "dim_row (take_rows A [0..<length (pivot_positions A)]) = length (pivot_positions A)"
          unfolding take_rows_def using dim_asm fileq by (auto)
        have "i + 1 > length (pivot_positions A)" using not_lt clear by auto
        then show "f i = ?nc" using gteq len_pp_var by auto
      qed
      show ?thesis using half1 half2
        by blast 
    qed
    then have h1a: "row A i =  0v (dim_col A)" 
      using pivot_fun_zero_row_iff[of A _ ?nc ?nr]
      using lt_i by blast
    then have h1: "A $$ (i, j) = 0"
      using index_row(1) lt_i lt_j by fastforce 
    have h2a: "i - dim_row (take_rows A [0..<length (pivot_positions A)]) < dim_row A - length (pivot_positions A)"
      using pivot_len lt_i not_lt
      by (simp add: take_rows_def)
    then have h2: "0m (dim_row A - length (pivot_positions A)) (dim_col A) $$
           (i - dim_row (take_rows A [0..<length (pivot_positions A)]), j) = 0 " 
      unfolding zero_mat_def using pivot_len lt_i lt_j
      using index_mat(1) by blast 
    then show "0m (dim_row A - length (pivot_positions A)) (dim_col A) $$
           (i - dim_row (take_rows A [0..<length (pivot_positions A)]), j) =
           A $$ (i, j)" using h1 h2
      by simp 
  qed
  (* the two blocks together have exactly dim_row A rows *)
  have h3: "(dim_row (take_rows A [0..<length (pivot_positions A)])::nat) + ((dim_row A::nat) - (length (pivot_positions A)::nat)) =
    dim_row A"
  proof - 
    have h0: "dim_row (take_rows A [0..<length (pivot_positions A)]) = (length (pivot_positions A)::nat)" 
      by (simp add: take_rows_def fileq)
    then show ?thesis using add_diff_inverse_nat  pivot_len
      by linarith
  qed
  have h4: " i j. i < dim_row A 
           j < dim_col A 
           i < dim_row (take_rows A [0..<length (pivot_positions A)]) +
               (dim_row A - length (pivot_positions A))"
    using pivot_len
    by (simp add: h3) 
  then show ?thesis apply (subst mat_eq_iff)
    using h1 h2 h3 h4 by (auto simp add: append_rows_def)
qed

(* In row echelon form there are at most dim_row A pivots: the pivot count
   equals the number of nonzero rows (pivot_positions(4)). *)
lemma length_pivot_positions_dim_row:
  assumes "row_echelon_form A"
  shows "length (pivot_positions A)  dim_row A"
proof -
  have 1: "A  carrier_mat (dim_row A) (dim_col A)" by auto
  obtain f where 2: "pivot_fun A f (dim_col A)"
    using assms row_echelon_form_def by blast
  from pivot_positions(4)[OF 1 2] have
    "length (pivot_positions A) = card {i. i < dim_row A  row A i  0v (dim_col A)}" .
  also have "...  card {i. i < dim_row A}"
    apply (rule card_mono)
    by auto
  ultimately show ?thesis by auto
qed

(* Every pivot position (i, j) of an nr x nc row-echelon matrix is a valid
   matrix index: i < nr and j < nc. *)
lemma rref_pivot_positions:
  assumes "row_echelon_form R"
  assumes R: "R  carrier_mat nr nc"
  shows "i j. (i,j)  set (pivot_positions R)  i < nr  j < nc"
proof -
  obtain f where f: "pivot_fun R f nc"
    using assms(1) assms(2) row_echelon_form_def by blast
  have *: "i. i < nr  f i  nc" using f
    using R pivot_funD(1) by blast
  from pivot_positions[OF R f]
  have "set (pivot_positions R) = {(i, f i) |i. i < nr  f i  nc}" by auto
  (* combined with the bound *, pivot columns are strictly below nc *)
  then have **: "set (pivot_positions R) = {(i, f i) |i. i < nr  f i < nc}"
    using *
    by fastforce
  fix i j
  assume "(i, j)  set (pivot_positions R)"
  thus "i < nr  j < nc" using **
    by simp
qed

(* A pivot function is monotone along the rows: for i < k < nr, f i is
   bounded by f k.  Proved by induction on k using pivot_funD. *)
lemma pivot_fun_monoton: 
  assumes pf: "pivot_fun A f (dim_col A)"
  assumes dr: "dim_row A = nr"
  shows " i. i < nr  ( k. ((k < nr  i < k)  f i  f k))"
proof -
  fix i
  assume "i < nr"
  show "( k. ((k < nr  i < k)  f i  f k))"
  proof -
    fix k
    show "((k < nr  i < k)  f i  f k)"
    proof (induct k)
      case 0
      then show ?case
        by blast 
    next
      case (Suc k)
      then show ?case 
        by (smt (verit, ccfv_SIG) dr le_less_trans less_Suc_eq less_imp_le_nat pf pivot_funD(1) pivot_funD(3))
    qed
  qed
qed

(* For each i below the number of pivots, (i, f i) is a pivot position.
   The key step i_nc shows f i < dim_col A for such i: if f i = dim_col A,
   every row from i on would be zero, leaving fewer than i+1 nonzero rows. *)
lemma pivot_positions_contains:
  assumes row_ech: "row_echelon_form A"
  assumes dim_h: "dim_col A  dim_row A"
  assumes "pivot_fun A f (dim_col A)"
  shows "i < (length (pivot_positions A)). (i, f i)  set (pivot_positions A)"
proof - 
  let ?nr = "dim_row A"
  let ?nc = "dim_col A"
  let ?pp = "pivot_positions A"          
  have i_nr: "i < (length ?pp). i < ?nr" using rref_pivot_positions assms
    using length_pivot_positions_dim_row less_le_trans by blast 
  have i_nc: "i < (length ?pp). f i < ?nc"
  proof clarsimp 
    fix i
    assume i_lt: "i < length ?pp"
    (* no pivot in row i implies no pivot in any later row *)
    have fis_nc: "f i = ?nc  ( k > i. k < ?nr  f k = ?nc)"
    proof -
      assume is_nc: "f i = ?nc"
      show "( k > i. k < ?nr  f k = ?nc)" 
      proof clarsimp
        fix k
        assume k_gt: "k > i"
        assume k_lt: "k < ?nr"
        have fk_lt: "f k  ?nc" using pivot_funD(1)[of A ?nr f ?nc k] k_lt apply (auto)
          using pivot_fun A f (dim_col A) by blast 
        show "f k = ?nc"
          using fk_lt is_nc k_gt k_lt assms pivot_fun_monoton[of A f ?nr i k]
          using pivot_fun A f (dim_col A) by auto 
      qed
    qed
    (* ... hence all rows from i on are zero rows *)
    have ncimp: "f i = ?nc  ( k i. k  { i. i < ?nr  row A i  0v ?nc})"
    proof -
      assume nchyp: "f i = ?nc"
      show "( k i. k  { i. i < ?nr  row A i  0v ?nc})"
      proof clarsimp 
        fix k
        assume i_lt: "i  k" 
        assume k_lt: "k < dim_row A"
        show "row A k = 0v (dim_col A) "
          using i_lt k_lt fis_nc
          using pivot_fun_zero_row_iff[of A f ?nc ?nr]
          using pivot_fun A f (dim_col A) le_neq_implies_less nchyp by blast 
      qed
    qed
    (* ... so the nonzero rows all lie in {0..<i} *)
    then have "f i = ?nc  card { i. i < ?nr  row A i  0v ?nc}  i"
    proof - 
      assume nchyp: "f i = ?nc"
      have h: "{ i. i < ?nr  row A i  0v ?nc}  {0..<i}"
        using atLeast0LessThan le_less_linear nchyp ncimp by blast
      then show " card { i. i < ?nr  row A i  0v ?nc}  i"
        using card_lessThan
        using subset_eq_atLeast0_lessThan_card by blast 
    qed
    then show "f i < ?nc" using i_lt pivot_positions(4)[of A ?nr ?nc f]
      apply (auto)
      by (metis pivot_fun A f (dim_col A) i_nr le_neq_implies_less not_less pivot_funD(1)) 
  qed
  then show ?thesis
    using pivot_positions(1)
    by (smt (verit, ccfv_SIG) pivot_fun A f (dim_col A) carrier_matI i_nr less_not_refl mem_Collect_eq)
qed

(* Every position (a, b) produced by pivot_positions_main_gen started at
   row i has row index a at least i.  Proved by the function's own
   induction rule. *)
lemma pivot_positions_form_helper_1:
  shows "(a, b)  set (pivot_positions_main_gen z A nr nc i j)  i  a"
proof  (induct i j rule: pivot_positions_main_gen.induct[of nr nc A z])
  case (1 i j)
  then show ?case using  pivot_positions_main_gen.simps[of z A nr nc i j]
    by (metis Pair_inject Suc_leD emptyE list.set(1) nle_le set_ConsD)
qed

(* The row indices produced by pivot_positions_main_gen are strictly
   increasing (relies on pivot_positions_form_helper_1). *)
lemma pivot_positions_form_helper_2:
  shows "sorted_wrt (<) (map fst (pivot_positions_main_gen z A nr nc i j))"
proof  (induct i j rule: pivot_positions_main_gen.induct[of nr nc A z])
  case (1 i j)
  then show ?case using  pivot_positions_main_gen.simps[of z A nr nc i j] 
    by (auto simp: pivot_positions_form_helper_1 Suc_le_lessD) 
qed

(* The row indices of pivot_positions A are strictly increasing.
   NOTE(review): the `using` line repeats a fact also given to simp. *)
lemma sorted_pivot_positions:
  shows "sorted_wrt (<) (map fst (pivot_positions A))"
  using pivot_positions_form_helper_2
  by (simp add: pivot_positions_form_helper_2 pivot_positions_gen_def) 

(* In row echelon form (under dim_h), the i-th pivot position lies in row i.
   Proof by induction on i, combining pivot_positions_contains (each row
   index below the pivot count occurs) with strict sortedness of the row
   indices (sorted_hyp). *)
lemma pivot_positions_form:
  assumes row_ech: "row_echelon_form A"
  assumes dim_h: "dim_col A  dim_row A"
  shows " i < (length (pivot_positions A)). fst ((pivot_positions A) ! i) = i"
proof clarsimp 
  let ?nr = "dim_row A"
  let ?nc = "dim_col A"
  let ?pp = "pivot_positions A :: (nat × nat) list"
  fix i
  assume i_lt: "i < length (pivot_positions A)"
  have "f. pivot_fun A f ?nc" using row_ech unfolding row_echelon_form_def
    by blast
  then obtain f where pf:"pivot_fun A f ?nc"
    by blast                  
  have all_f_in: "i < (length ?pp). (i, f i)  set ?pp"
    using pivot_positions_contains pf
      assms 
    by blast   
  (* row indices strictly increase along the pivot list *)
  have sorted_hyp: " (p::nat) (q::nat). p < (length ?pp)  q < (length ?pp)  p < q  (fst (?pp ! p) < fst (?pp ! q))"
  proof -
    fix p::nat
    fix q::nat
    assume p_lt: "p < q"
    assume p_welldef: "p < (length ?pp)"
    assume q_welldef: "q < (length ?pp)"
    show "fst (?pp ! p) < fst (?pp ! q)"
      using sorted_pivot_positions p_lt p_welldef q_welldef sorted_wrt_nth_less by fastforce
  qed
  have h: "i < (length ?pp)  fst (pivot_positions A ! i) = i"
  proof (induct i)
    case 0
    (* row 0 occurs somewhere; sortedness forces it to be at index 0 *)
    have "j. fst (pivot_positions A ! j) = 0"
      by (metis all_f_in fst_conv i_lt in_set_conv_nth length_greater_0_conv list.size(3) not_less0)
    then obtain j where jth:" fst (pivot_positions A ! j) = 0"
      by blast      
    have "j  0  (fst (pivot_positions A ! 0) > 0  j  0)"
      by (smt (verit, ccfv_SIG) all_f_in fst_conv i_lt in_set_conv_nth less_nat_zero_code not_gr_zero sorted_hyp)
    then show ?case
      using jth neq0_conv by blast
  next
    case (Suc i)
    have ind_h: "i < length (pivot_positions A)  fst (pivot_positions A ! i) = i"
      using Suc.hyps by blast 
    (* row Suc i occurs at some index j > i; j > Suc i would leave no room
       for the row index at position Suc i *)
    have thesis_h: "(Suc i) < length (pivot_positions A)  fst (pivot_positions A ! (Suc i)) = (Suc i)"
    proof - 
      assume suc_i_lt: "(Suc i) < length (pivot_positions A)"
      have fst_i_is: "fst (pivot_positions A ! i) = i" using suc_i_lt ind_h
        using Suc_lessD by blast 
      have "(j < (length ?pp). fst (pivot_positions A ! j) = (Suc i))"
        by (metis suc_i_lt all_f_in fst_conv  in_set_conv_nth)
      then obtain j where jth: "j < (length ?pp)  fst (pivot_positions A ! j) = (Suc i)"
        by blast
      have "j > i"
        using sorted_hyp apply (auto)
        by (metis Suc_lessD fst (pivot_positions A ! i) = i jth less_not_refl linorder_neqE_nat n_not_Suc_n suc_i_lt)
      have "j > (Suc i)  False"
      proof -
        assume j_gt: "j > (Suc i)"
        then have h1: "fst (pivot_positions A ! (Suc i)) > i"
          using fst_i_is sorted_pivot_positions
          using sorted_hyp suc_i_lt by force
        have "fst (pivot_positions A ! j) > fst (pivot_positions A ! (Suc i))"
          using jth j_gt sorted_hyp apply (auto)
          by fastforce 
        then have h2: "fst (pivot_positions A ! (Suc i)) < (Suc i)" 
          using jth
          by simp   
        show "False" using h1 h2
          using not_less_eq by blast 
      qed
      show "fst (pivot_positions A ! (Suc i)) = (Suc i)"
        using Suc_lessI Suc i < j  False i < j jth by blast
    qed
    then show ?case
      by blast 
  qed
  then show "fst (pivot_positions A ! i) = i"
    using i_lt by auto
qed

(* take_cols_pivot_eq: for A in row echelon form with at least as many columns
   as rows (dim_h), selecting the pivot columns of A, in the order listed by
   pivot_positions A, gives an identity block of size length (pivot_positions A)
   stacked (append_rows, @r) on top of a zero block that fills the remaining
   dim_row A - length (pivot_positions A) rows.
   Proof outline: first the column dimensions of both sides are matched (h_len),
   then the entries are compared (h2).
   - For a row index below the number of pivots, the entry comes from a pivot
     column. The pivot function makes that column 1 on the pivot's own row and 0
     on every other row (pivot_funD(4) and pivot_funD(5)).
   - Rows at or beyond the number of pivots are zero rows of A
     (row_echelon_form_zero_rows).
   NOTE(review): the fact names h1 and h2 are reused for different facts at
   several nesting levels. Each reference resolves to the most recent fact of
   that name that is visible at that point. For example, the h2 used inside the
   proof of the entry-wise h2 is still the earlier dimension fact. *)
lemma take_cols_pivot_eq:
  assumes row_ech: "row_echelon_form A"
  assumes dim_h: "dim_col A  dim_row A"
  shows "take_cols A (map snd (pivot_positions A)) =
    1m (length (pivot_positions A)) @r
    0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))"
proof - 
  let ?nr = "dim_row A"
  let ?nc = "dim_col A"
  (* h1: the stacked right-hand side has one column per pivot. *)
  have h1: " dim_col
     (1m (length (pivot_positions A)) @r
      0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))) = (length (pivot_positions A))"
    by (simp add: append_rows_def)
  (* The number of pivots equals the number of nonzero rows of A. *)
  have len_pivot: "length (pivot_positions A) = card {i. i < ?nr  row A i  0v ?nc}"
    using row_ech pivot_positions(4) row_echelon_form_def by blast
  (* For any pivot function f of A: pp_leq_nc bounds f by ?nc on the rows below
     ?nr, and set (pivot_positions A) is described through f in several
     equivalent ways. The last one (pivot_set_var) uses the strict bound
     f i < ?nc. *)
  have pp_leq_nc: "f. pivot_fun A f ?nc  (i < ?nr. f i  ?nc)" unfolding pivot_fun_def
    by meson 
  have pivot_set: "f. pivot_fun A f ?nc  set (pivot_positions A) = {(i, f i) | i. i < ?nr  f i  ?nc}"
    using row_ech row_echelon_form_def pivot_positions(1)
    by (smt (verit) Collect_cong carrier_matI)
  then have pivot_set_alt: "f. pivot_fun A f ?nc  set (pivot_positions A) = {(i, f i) | i. i < ?nr  row A i  0v ?nc}"
    using pivot_positions pivot_fun_zero_row_iff Collect_cong carrier_mat_triv
    by (smt (verit, best))
  have "f. pivot_fun A f ?nc  set (pivot_positions A) = {(i, f i) | i. f i  ?nc  i < ?nr  f i  ?nc}"
    using pivot_set pp_leq_nc by auto
  then have pivot_set_var: "f. pivot_fun A f ?nc  set (pivot_positions A) = {(i, f i) | i. i < ?nr  f i < ?nc}"
    by auto
  (* length_asm: the list of pivot column indices has one entry per pivot. The
     derivation goes through distinctness (pivot_positions(3)) and cardinality. *)
  have "length (map snd (pivot_positions A)) = card (set (map snd (pivot_positions A)))" 
    using row_ech row_echelon_form_def pivot_positions(3) distinct_card[where xs = "map snd (pivot_positions A)"]
    by (metis carrier_mat_triv)
  then have "length (map snd (pivot_positions A)) = card (set (pivot_positions A))"
    by (metis card_distinct distinct_card distinct_map length_map) 
  then have "length (map snd (pivot_positions A)) = card {i. i < ?nr  row A i  0v ?nc}"
    using pivot_set_alt
    by (simp add: len_pivot) 
  then have length_asm: "length (map snd (pivot_positions A)) = length (pivot_positions A)"
    using len_pivot by linarith
  (* Every pivot column index is a valid column index of A. Consequently the
     filter inside take_cols keeps all of them (h2b), and the list of selected
     columns has one entry per pivot (h2a). *)
  then have "a. List.member (map snd (pivot_positions A)) a  a < dim_col A"
  proof clarsimp 
    fix a
    assume a_in: "List.member (map snd (pivot_positions A)) a"
    have "v  set (pivot_positions A). a = snd v" 
      using a_in in_set_member[where xs = "(pivot_positions A)"] apply (auto)
      by (metis in_set_impl_in_set_zip2 in_set_member length_map snd_conv zip_map_fst_snd) 
    then show "a < dim_col A"
      using pivot_set_var in_set_member by auto
  qed
  then have h2b: "(filter (λy. y < dim_col A) (map snd (pivot_positions A))) =  (map snd (pivot_positions A))"
    by (meson filter_True in_set_member)
  then have h2a: "length (map ((!) (cols A)) (filter (λy. y < dim_col A) (map snd (pivot_positions A)))) = length (pivot_positions A)"
    using length_asm
    by (simp add: h2b) 
  (* Dimension facts: the take_cols side has one column per pivot (h2), which
     matches the column count of the right-hand side (h_len). *)
  then have h2: "length (pivot_positions A)  dim_row A 
    dim_col (take_cols A (map snd (pivot_positions A))) = (length (pivot_positions A))" 
    unfolding take_cols_def using mat_of_cols_carrier by auto
  have h_len: "length (pivot_positions A)  dim_row A 
    dim_col (take_cols A (map snd (pivot_positions A))) =
    dim_col
     (1m (length (pivot_positions A)) @r
      0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A)))" 
    using h1 h2
    by (simp add: h1 assms length_pivot_positions_dim_row)
  (* Entry-wise agreement of both sides. Rows inside the identity block are
     handled by h1a, and rows inside the zero block by h1b. *)
  have h2: "i j. length (pivot_positions A)  dim_row A 
           i < dim_row A 
           j < dim_col
                (1m (length (pivot_positions A)) @r
                 0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))) 
           take_cols A (map snd (pivot_positions A)) $$ (i, j) =
           (1m (length (pivot_positions A)) @r
            0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))) $$
           (i, j)" 
  proof -
    fix i 
    fix j 
    let ?pp = "(pivot_positions A)"
    assume len_lt: "length (pivot_positions A)  dim_row A" 
    assume i_lt: " i < dim_row A" 
    assume j_lt: "j < dim_col
                (1m (length (pivot_positions A)) @r
                 0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A)))"
    (* ?w is the column index of the j-th pivot. *)
    let ?w = "((map snd (pivot_positions A)) ! j)"
    (* The (i, j) entry of the selected-column matrix is entry i of column ?w of A. *)
    have breaking_it_down: "mat_of_cols (dim_row A)
     (map ((!) (cols A)) (map snd (pivot_positions A))) $$ (i, j)  
     =  ((cols A) ! ?w) $ i"
      apply (auto)
      by (metis comp_apply h1 i_lt j_lt length_map mat_of_cols_index nth_map) 
    (* Case 1: row i lies inside the identity block. *)
    have h1a: "i < (length ?pp)  (mat_of_cols (dim_row A) (map ((!) (cols A)) (map snd (pivot_positions A))) $$ (i, j) 
        = (1m (length (pivot_positions A))) $$ (i, j))"
    proof - 
      (* Uses row_ech through a pivot function f. The j-th pivot is (j, f j), by
         the preceding lemma on fst of pivot_positions, which needs dim_h. Column
         f j of A is 1 in row j and 0 in every other row (pivot_funD(4) and
         pivot_funD(5)). *)
      assume "i < (length ?pp)"
      have "f. pivot_fun A f ?nc" using row_ech unfolding row_echelon_form_def
        by blast
      then obtain f where "pivot_fun A f ?nc"
        by blast
      have j_nc: "j < (length ?pp)" using j_lt
        by (simp add: h1) 
      then have j_lt_nr: "j < ?nr" using dim_h
        using len_lt by linarith 
      then have is_this_true: "(pivot_positions A) ! j = (j, f j)" 
        using pivot_positions_form pivot_positions(1)[of A ?nr ?nc f]
      proof -
        have "pivot_positions A ! j  set (pivot_positions A)"
          using j_nc nth_mem by blast
        then have "n. pivot_positions A ! j = (n, f n)  n < dim_row A  f n  dim_col A"
          using A  carrier_mat (dim_row A) (dim_col A); pivot_fun A f (dim_col A)  set (pivot_positions A) = {(i, f i) |i. i < dim_row A  f i  dim_col A} pivot_fun A f (dim_col A) by blast
        then show ?thesis
          by (metis (no_types) A. row_echelon_form A; dim_row A  dim_col A  i<length (pivot_positions A). fst (pivot_positions A ! i) = i dim_h fst_conv j_nc row_ech)
      qed
      then have w_is: "?w = f j"
        by (metis h1 j_lt nth_map snd_conv)
      (* h0: the diagonal entry is 1. The local h1 below: off-diagonal entries are 0. *)
      have h0: "i = j  ((cols A) ! ?w) $ i = 1" using w_is pivot_funD(4)[of A ?nr f ?nc i]
        by (metis a. List.member (map snd (pivot_positions A)) a  a < dim_col A i < length (pivot_positions A) pivot_fun A f (dim_col A) cols_length i_lt in_set_member length_asm mat_of_cols_cols mat_of_cols_index nth_mem)
      have h1:  "i  j  ((cols A) ! ?w) $ i = 0" using w_is pivot_funD(5)
        by (metis a. List.member (map snd (pivot_positions A)) a  a < dim_col A pivot_fun A f (dim_col A) cols_length h1 i_lt in_set_member j_lt len_lt length_asm less_le_trans mat_of_cols_cols mat_of_cols_index nth_mem)
      show "(mat_of_cols (dim_row A) (map ((!) (cols A)) (map snd (pivot_positions A))) $$ (i, j) 
        = (1m (length (pivot_positions A))) $$ (i, j))" using h0 h1 breaking_it_down
        by (metis i < length (pivot_positions A) h2 h_len index_one_mat(1) j_lt len_lt) 
    qed
    (* Case 2: row i lies inside the zero block. By row_echelon_form_zero_rows,
       row i of A is the zero vector. *)
    have h1b: "i  (length ?pp)  (mat_of_cols (dim_row A) (map ((!) (cols A)) (map snd (pivot_positions A))) $$ (i, j)  = 0)"
    proof - 
      assume i_gt: "i  (length ?pp)"
      have h0a: "((cols A) ! ((map snd (pivot_positions A)) ! j)) $ i = (row A i) $ ?w"
        by (metis a. List.member (map snd (pivot_positions A)) a  a < dim_col A cols_length h1 i_lt in_set_member index_row(1) j_lt length_asm mat_of_cols_cols mat_of_cols_index nth_mem)
      have h0b: 
        "take_rows A [0..<length (pivot_positions A)] @r 0m (dim_row A - length (pivot_positions A)) (dim_col A) = A"
        using assms row_echelon_form_zero_rows[of A]
        by blast 
      then have h0c: "(row A i) = 0v (dim_col A)"  using i_gt
        by (smt (verit, best) add_diff_cancel_left' add_diff_cancel_right' add_less_cancel_left dim_col_take_rows 
            dim_row_append_rows i_lt index_zero_mat(2) index_zero_mat(3) le_Suc_ex len_lt nat_less_le nle_le row_append_rows row_zero)
      then show ?thesis using h0a breaking_it_down apply (auto)
        by (metis a. List.member (map snd (pivot_positions A)) a  a < dim_col A h1 in_set_member index_zero_vec(1) j_lt length_asm nth_mem) 
    qed
    (* Combine both cases with append_rows_index. *)
    have h1: " mat_of_cols (dim_row A)
     (map ((!) (cols A)) (map snd (pivot_positions A))) $$ (i, j) =
           (1m (length (pivot_positions A)) @r
            0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))) $$
           (i, j) " using h1a h1b
      by (smt (verit) add_diff_inverse_nat append_rows_index diff_less_mono h1 i_lt index_one_mat(2) index_one_mat(3) index_zero_mat(1) index_zero_mat(2) index_zero_mat(3) j_lt leD len_lt not_le_imp_less)
    then show "take_cols A (map snd (pivot_positions A)) $$ (i, j) =
           (1m (length (pivot_positions A)) @r
            0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))) $$
           (i, j)" 
      unfolding take_cols_def
      by (simp add: h2b)
  qed
  (* Conclude with mat_eq_iff. length_pivot_positions_dim_row discharges the
     premise that the number of pivots is bounded by dim_row A. *)
  show ?thesis
    unfolding mat_eq_iff
    using length_pivot_positions_dim_row[OF assms(1)] h_len h2 by auto
qed

(* rref_right_mul: a matrix A in row echelon form, with at least as many
   columns as rows (the same assumptions as take_cols_pivot_eq), factors as
   (its pivot columns) * (its first length (pivot_positions A) rows). *)
lemma rref_right_mul:
  assumes "row_echelon_form A"
  assumes "dim_col A  dim_row A"
  shows
    "take_cols A (map snd (pivot_positions A)) * take_rows A [0..<length (pivot_positions A)] = A"
proof -
  (* 1: the pivot columns form an identity block stacked on a zero block. *)
  from take_cols_pivot_eq[OF assms] have
    1: "take_cols A (map snd (pivot_positions A)) =
    1m (length (pivot_positions A)) @r
    0m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))" .
  (* 2: block-wise product (append_rows_mat_mul). The identity block reproduces
     the top rows, and the zero block yields zero rows. *)
  have 2: "take_cols A (map snd (pivot_positions A)) * take_rows A [0..<length (pivot_positions A)] =
    take_rows A [0..<length (pivot_positions A)]  @r 0m (dim_row A - length (pivot_positions A)) (dim_col A)"
    unfolding 1
    apply (simp add: append_rows_mat_mul)
    by (metis (no_types, lifting) "1" add_right_imp_eq assms dim_col_take_rows dim_row_append_rows dim_row_take_cols index_one_mat(2) index_zero_mat(2) left_mult_one_mat' left_mult_zero_mat' row_echelon_form_zero_rows)
  (* Here "..." is the right-hand side of 2. row_echelon_form_zero_rows states
     that A equals its top rows stacked on zero rows. *)
  from row_echelon_form_zero_rows[OF assms] have "... = A" .
  thus ?thesis
    by (simp add: "2")
qed

(* Lemmas on linear independence, rank and column selection (take_cols),
   stated inside the conjugatable_vec_space context. *)
context conjugatable_vec_space begin

(* The columns of the n x n identity matrix are linearly independent. As a set
   they coincide with its rows (the identity is its own transpose), and
   det (1m n) is nonzero, so det_not_0_imp_lin_indpt_rows applies. *)
lemma lin_indpt_id:
  shows "lin_indpt (set (cols (1m n)::'a vec list))"
proof -
  have *: "set (cols (1m n)) = set (rows (1m n))"
    by (metis cols_transpose transpose_one)
  have "det (1m n)  0" using det_one by auto
  from det_not_0_imp_lin_indpt_rows[OF _ this]
  have "lin_indpt (set (rows (1m n)))"
    using one_carrier_mat by blast
  thus ?thesis
    by (simp add: *) 
qed

(* Any selection of columns of the identity is linearly independent, being a
   subset of the independent set of all its columns. *)
lemma lin_indpt_take_cols_id:
  shows "lin_indpt (set (cols (take_cols (1m n) inds)))"
proof - 
  have subset_h: "set (cols (take_cols (1m n) inds))  set (cols (1m n)::'a vec list)"
    using cols_take_cols_subset by blast
  then show ?thesis using lin_indpt_id subset_li_is_li by auto
qed

(* The column list of the identity 1m d is exactly unit_vecs d. *)
lemma cols_id_unit_vecs:
  shows "cols (1m d) = unit_vecs d"
  unfolding unit_vecs_def list_eq_iff_nth_eq
  by auto

(* The columns of the identity are pairwise distinct. *)
lemma distinct_cols_id:
  shows "distinct (cols (1m d)::'a vec list)"
  by (simp add: conjugatable_vec_space.cols_id_unit_vecs vec_space.unit_vecs_distinct)

(* Indexing a distinct list ls with a distinct list of in-range indices yields
   a distinct list. *)
lemma distinct_map_nth:
  assumes "distinct ls"
  assumes "distinct inds"
  assumes "j. j  set inds  j < length ls"
  shows "distinct (map ((!) ls) inds)"
  by (simp add: assms(1) assms(2) assms(3) distinct_map inj_on_nth)

(* For distinct indices, the selected columns of the identity are distinct. The
   index filter applied by take_cols preserves distinctness (distinct_filter). *)
lemma distinct_take_cols_id:
  assumes "distinct inds"
  shows "distinct (cols (take_cols (1m n) inds) :: 'a vec list)"
  unfolding take_cols_def
  apply (subst cols_mat_of_cols)
   apply (auto intro!:  distinct_map_nth simp add: distinct_cols_id)
  using assms distinct_filter by blast

(* The rank of a column selection of 1m n with distinct indices equals the
   number of indices that are < n, i.e. the number of columns actually taken.
   This follows from lin_indpt_full_rank together with the two lemmas above. *)
lemma rank_take_cols:
  assumes "distinct inds"
  shows "rank (take_cols (1m n) inds) = length (filter ((>) n) inds)"
  apply (subst lin_indpt_full_rank[of _ "length (filter ((>) n) inds)"])
     apply (auto simp add: lin_indpt_take_cols_id)
   apply (metis (full_types) index_one_mat(2) index_one_mat(3) length_map mat_of_cols_carrier(1) take_cols_def)
  by (simp add: assms distinct_take_cols_id)

(* Left multiplication by an invertible n x n matrix preserves rank.
   - rank_mat_mul_left gives rank (A * B) bounded by rank B.
   - Conversely, B = C * A * B for the inverse C of A, so rank B is bounded by
     rank (A * B). *)
lemma rank_mul_left_invertible_mat:
  fixes A::"'a mat"
  assumes "invertible_mat A"
  assumes "A  carrier_mat n n"
  assumes "B  carrier_mat n nc"
  shows "rank (A * B) = rank B"
proof -
  obtain C where C: "inverts_mat A C" "inverts_mat C A"
    using assms invertible_mat_def by blast 
  from C have ceq: "C * A = 1m n"
    by (metis assms(2) carrier_matD(2) index_mult_mat(3) index_one_mat(3) inverts_mat_def)
  then have *:"B = C*A*B"
    using assms(3) by auto
  from rank_mat_mul_left[OF assms(2-3)]
  have **: "rank (A*B)  rank B" .
  have 1: "C  carrier_mat n n" using C ceq
    by (metis assms(2) carrier_matD(1) carrier_matI index_mult_mat(3) index_one_mat(3) inverts_mat_def) 
  have 2: "A * B  carrier_mat n nc" using assms by auto  
  have "rank B = rank (C* A * B)" using * by auto
  also have "...  rank (A*B)" using rank_mat_mul_left[OF 1 2]
    using "1" assms(2) assms(3) by auto
  ultimately show ?thesis using ** by auto
qed

(* For an invertible n x n matrix A and distinct indices, the selected columns
   of A have rank equal to the number of indices below n. The proof rewrites
   take_cols A inds as A * take_cols (1m n) inds, then applies
   rank_mul_left_invertible_mat and rank_take_cols. *)
lemma invertible_take_cols_rank:
  fixes A::"'a mat"
  assumes "invertible_mat A"
  assumes "A  carrier_mat n n"
  assumes "distinct inds"
  shows "rank (take_cols A inds) = length (filter ((>) n) inds)"
proof -
  have " A = A * 1m n" using assms(2) by auto
  then have "take_cols A inds = A * take_cols (1m n) inds"
    by (metis assms(2) one_carrier_mat take_cols_mat_mul)
  then have "rank (take_cols A inds) = rank (take_cols (1m n) inds)"
    by (metis assms(1) assms(2) conjugatable_vec_space.rank_mul_left_invertible_mat one_carrier_mat take_cols_carrier_mat)
  thus ?thesis
    by (simp add: assms(3) conjugatable_vec_space.rank_take_cols)
qed

(* Selecting columns cannot increase rank. take_cols R ls equals
   R * take_cols (1m nc) ls, and rank_mat_mul_right bounds the rank of that
   product by rank R. *)
lemma rank_take_cols_leq:
  assumes R:"R  carrier_mat n nc"
  shows "rank (take_cols R ls)  rank R"
proof -
  from take_cols_mat_mul[OF R]
  have "take_cols R ls =  R * take_cols (1m nc) ls"
    by (metis assms one_carrier_mat right_mult_one_mat)
  thus ?thesis
    by (metis assms one_carrier_mat take_cols_carrier_mat vec_space.rank_mat_mul_right)
qed

(* If R factors as take_cols R ls * B, then the selected columns have rank at
   least rank R (rank_mat_mul_right applied to that factorisation).
   NOTE(review): the assumption R does not appear to be used by this proof. *)
lemma rank_take_cols_geq:
  assumes R:"R  carrier_mat n nc"
  assumes t:"take_cols R ls  carrier_mat n r"
  assumes B:"B  carrier_mat r nc"
  assumes "R = (take_cols R ls) * B"
  shows "rank (take_cols R ls)  rank R"
  by (metis B assms(4) t vec_space.rank_mat_mul_right)

(* For R in row echelon form with at least as many columns as rows (order),
   keeping only the pivot columns preserves rank. The factorisation
   R = (pivot columns) * (top rows) from rref_right_mul gives one inequality via
   rank_take_cols_geq; rank_take_cols_leq gives the other. *)
lemma rref_drop_pivots:
  assumes row_ech: "row_echelon_form R"
  assumes dims: "R  carrier_mat n nc"
  assumes order: "nc  n"
  shows "rank (take_cols R (map snd (pivot_positions R))) = rank R"
proof -
  (* ?B: the first length (pivot_positions R) rows of R. *)
  let ?B = "take_rows R [0..<length (pivot_positions R)]"
  have equa: "R = take_cols R (map snd (pivot_positions R)) * ?B" using assms rref_right_mul
    by (metis carrier_matD(1) carrier_matD(2))
  (* Dimensions: the pivot columns form an n x length (pivot_positions R)
     matrix (h1), and ?B is length (pivot_positions R) x nc (h2). *)
  have ex_r: "r. take_cols R (map snd (pivot_positions R))  carrier_mat n r  ?B  carrier_mat r nc"
  proof - 
    have h1:
      "take_cols R (map snd (pivot_positions R))  carrier_mat n (length (pivot_positions R))"
      using assms
      by (metis in_set_impl_in_set_zip2 length_map rref_pivot_positions take_cols_carrier_mat_strict zip_map_fst_snd)
    have " f. pivot_fun R f nc" using row_ech unfolding row_echelon_form_def using dims
      by blast
    then have "length (pivot_positions R) = card {i. i < n  row R i  0v nc}"
      using pivot_positions[of R n nc]
      using dims by auto 
    then have "nc  length (pivot_positions R)" using order
      using carrier_matD(1) dims dual_order.trans length_pivot_positions_dim_row row_ech by blast
    then have "dim_col R  length (pivot_positions R)" using dims by auto
    then have h2: "?B  carrier_mat (length (pivot_positions R)) nc" unfolding take_rows_def
      using dims 
      by (smt (verit) atLeastLessThan_iff carrier_matD(2) filter_True le_eq_less_or_eq length_map 
          length_pivot_positions_dim_row less_trans map_nth mat_of_cols_carrier(1) row_ech set_upt transpose_carrier_mat transpose_mat_of_rows) 
    show ?thesis using h1 h2
      by blast
  qed
  (* ex_r supplies the dimension assumptions t and B of rank_take_cols_geq, and
     equa supplies the factorisation. *)
  have "rank R   rank (take_cols R (map snd (pivot_positions R)))"
    using dims ex_r rank_take_cols_geq[where R = "R", where B = "?B", where ls = "(map snd (pivot_positions R))", where nc = "nc"]
    using equa by blast
  thus ?thesis
    using assms(2) conjugatable_vec_space.rank_take_cols_leq le_antisym by blast
qed

(* On the pivot column indices of gauss_jordan_single A, take_cols and
   take_cols_var give the same matrix. The filter (λy. y < dim_col A) used by
   take_cols keeps every one of these indices, and after unfolding both
   definitions simp closes the goal.
   NOTE(review): the order assumption does not appear to be used by this proof. *)
lemma gjs_and_take_cols_var:
  fixes A::"'a mat"
  assumes A:"A  carrier_mat n nc"
  assumes order: "nc  n"
  shows "(take_cols A (map snd (pivot_positions (gauss_jordan_single A)))) = 
  (take_cols_var A (map snd (pivot_positions (gauss_jordan_single A))))"
proof -
  let ?gjs = "(gauss_jordan_single A)"
  (* A bound on every pivot column index relative to dim_col A. *)
  have "x. List.member (map snd (pivot_positions (gauss_jordan_single A))) x  x  dim_col A"  
    using rref_pivot_positions gauss_jordan_single(3) carrier_matD(2) gauss_jordan_single(2) in_set_impl_in_set_zip2 in_set_member length_map less_irrefl less_trans not_le_imp_less zip_map_fst_snd
    by (metis (no_types, lifting) carrier_mat_triv)
  then have "(filter (λy. y < dim_col A) (map snd (pivot_positions (gauss_jordan_single A)))) = 
    (map snd (pivot_positions (gauss_jordan_single A)))"
    by (metis (no_types, lifting) A carrier_matD(2) filter_True gauss_jordan_single(2) gauss_jordan_single(3) in_set_impl_in_set_zip2 length_map rref_pivot_positions zip_map_fst_snd)
  then show ?thesis unfolding take_cols_def take_cols_var_def
    by simp
qed

(* The columns of A at the pivot positions of R = gauss_jordan_single A have the
   same rank as A. Proof outline:
   - gauss_jordan_transform gives R = P * A for an invertible P, so A = Pi * R
     for a left inverse Pi.
   - Column selection commutes with left multiplication (take_cols_mat_mul).
   - Left multiplication by an invertible matrix preserves rank.
   - rref_drop_pivots gives: rank of the pivot columns of R = rank R.
   - Finally rank R = rank A, again because R = P * A with P invertible. *)
lemma gauss_jordan_single_rank:
  fixes A::"'a mat"
  assumes A:"A  carrier_mat n nc"
  assumes order: "nc  n"
  shows "rank (take_cols A (map snd (pivot_positions (gauss_jordan_single A)))) = rank A"
proof -
  let ?R = "gauss_jordan_single A"
  obtain P where P:"PUnits (ring_mat TYPE('a) n undefined)" and
    i: "?R = P * A" using gauss_jordan_transform[OF A]
    by (metis A carrier_matD(1) fst_eqD gauss_jordan_single_def surj_pair zero_carrier_mat)
  (* P is a unit of the n x n matrix ring; hence it is square and invertible. *)
  have pcarrier: "P  carrier_mat n n" using P unfolding Units_def
    by (auto simp add: ring_mat_def)
  have "invertible_mat P" using P unfolding invertible_mat_def Units_def inverts_mat_def
    apply auto
     apply (simp add: ring_mat_simps(5))
    by (metis index_mult_mat(2) index_one_mat(2) ring_mat_simps(1) ring_mat_simps(3))
  (* Obtain an invertible left inverse Pi of P. *)
  then
  obtain Pi where Pi: "invertible_mat Pi" "Pi * P = 1m n"
  proof -
    assume a1: "Pi. invertible_mat Pi; Pi * P = 1m n  thesis"
    have "dim_row P = n"
      by (metis (no_types) A assms(1) carrier_matD(1) gauss_jordan_single(2) i index_mult_mat(2))
    then show ?thesis
      using a1 by (metis (no_types) invertible_mat P index_mult_mat(3) index_one_mat(3) invertible_mat_def inverts_mat_def square_mat.simps)
  qed
  then have pi_carrier:"Pi  carrier_mat n n"
    by (metis carrier_mat_triv index_mult_mat(2) index_one_mat(2) invertible_mat_def square_mat.simps)
  have R1:"row_echelon_form ?R"
    using assms(2) gauss_jordan_single(3) by blast
  have R2: "?R  carrier_mat n nc"
    using A assms(2) gauss_jordan_single(2) by auto
  (* The pivot columns of ?R form an n x (number of pivots) matrix. *)
  have Rcm: "take_cols ?R (map snd (pivot_positions ?R))
     carrier_mat n (length (map snd (pivot_positions ?R)))"
    apply (rule take_cols_carrier_mat_strict[OF R2])
    using rref_pivot_positions[OF R1 R2] by auto
  have "Pi * ?R = A" using i Pi
    by (smt (verit, best) A assoc_mult_mat left_mult_one_mat pcarrier pi_carrier)
  then have "rank (take_cols A (map snd (pivot_positions ?R))) = rank (take_cols (Pi * ?R) (map snd (pivot_positions ?R)))"
    by auto
  also have "... = rank ( Pi * take_cols ?R (map snd (pivot_positions ?R)))"
    by (metis A gauss_jordan_single(2) pi_carrier take_cols_mat_mul)
  also have "... = rank (take_cols ?R (map snd (pivot_positions ?R)))"
    by (intro rank_mul_left_invertible_mat[OF Pi(1) pi_carrier Rcm])
  also have "... = rank ?R"
    using assms(2) conjugatable_vec_space.rref_drop_pivots gauss_jordan_single(3)
    using R1 R2 by blast
  (* Close the chain with rank ?R = rank (P * A) = rank A. *)
  ultimately show ?thesis                                                            
    using A P  carrier_mat n n invertible_mat P conjugatable_vec_space.rank_mul_left_invertible_mat i
    by auto
qed

(* Any subset B of the columns of an invertible n x n matrix A is linearly
   independent. det A is nonzero (invertible_det), so the rows of the transpose
   of A, i.e. its columns, are linearly independent, and subsets of an
   independent set are independent. *)
lemma lin_indpt_subset_cols:
  fixes A:: "'a mat"
  fixes B:: "'a vec set"
  assumes "A  carrier_mat n n"
  assumes inv: "invertible_mat A"
  assumes "B  set (cols A)"
  shows "lin_indpt B"
proof -
  have "det A  0"
    using assms(1) inv invertible_det by blast
  then have "lin_indpt (set (rows AT))"
    using assms(1) idom_vec.lin_dep_cols_imp_det_0 by auto
  thus ?thesis using subset_li_is_li assms(3)
    by auto
qed

(* A distinct list B of columns of an invertible n x n matrix A, assembled into
   mat_of_cols n B, has rank length B. Proof outline:
   - set B is linearly independent (lin_indpt_subset_cols).
   - set B is a maximal independent subset of the column set of mat_of_cols n B.
   - rank_card_indpt gives rank = card (set B).
   - card (set B) = length B because B is distinct. *)
lemma rank_invertible_subset_cols:
  fixes A:: "'a mat"
  fixes B:: "'a vec list"
  assumes inv: "invertible_mat A"
  assumes A_square: "A  carrier_mat n n"
  assumes set_sub: "set (B)  set (cols A)"
  assumes dist_B: "distinct B"
  shows "rank (mat_of_cols n B) = length B"
proof - 
  let ?B_mat = "(mat_of_cols n B)"
  have h1: "lin_indpt (set(B))" 
    using assms lin_indpt_subset_cols[of A] by auto
  (* Every vector in B has dimension n, so mat_of_cols n B has exactly B as columns. *)
  have "set B  carrier_vec n"
    using set_sub A_square cols_dim[of A] by auto
  then have cols_B: "cols (mat_of_cols n B) = B" using cols_mat_of_cols by auto
  then have "maximal (set B) (λT. T  set (B)  lin_indpt T)" using h1
    by (simp add: maximal_def subset_antisym)
  then have h2: "maximal (set B) (λT. T  set (cols (mat_of_cols n B))  lin_indpt T)"
    using cols_B by auto
  have h3: "rank (mat_of_cols n B) = card (set B)"
    using h1 h2 rank_card_indpt[of ?B_mat]
    using mat_of_cols_carrier(1) by blast 
  then show ?thesis using assms distinct_card by auto
qed

end

end