(* Theory Gauss_Jordan_Elimination *)

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Gauss-Jordan Algorithm›

text ‹We define the elementary row operations and use them to implement the
  Gauss-Jordan algorithm to transform matrices into row-echelon-form. 
  This algorithm is used to implement the inverse of a matrix and to derive
  certain results on determinants, as well as determine a basis of the kernel
  of a matrix.› 

theory Gauss_Jordan_Elimination
imports Matrix
begin

subsection ‹Row Operations›

(* Generic row scaling: the result has the dimensions of A; every entry of
   row k becomes mul a (entry), all other rows are copied unchanged.  The
   multiplication mul is a parameter. *)
definition mat_multrow_gen :: "('a ⇒ 'a ⇒ 'a) ⇒ nat ⇒ 'a ⇒ 'a mat ⇒ 'a mat" where
  "mat_multrow_gen mul k a A = mat (dim_row A) (dim_col A)
     (λ (i,j). if k = i then mul a (A $$ (i,j)) else A $$ (i,j))"

(* Row scaling instantiated with the semiring multiplication; written multrow.
   mat_multrow_def is an alias for the defining equation of the generic version. *)
abbreviation mat_multrow :: "nat ⇒ 'a :: semiring_1 ⇒ 'a mat ⇒ 'a mat" ("multrow") where
  "multrow ≡ mat_multrow_gen ((*))"

lemmas mat_multrow_def = mat_multrow_gen_def

(* Elementary n x n matrix for row scaling: entry a at position (k,k), 1 on the
   remaining diagonal positions and 0 everywhere else. *)
definition multrow_mat :: "nat ⇒ nat ⇒ 'a :: semiring_1 ⇒ 'a mat" where
  "multrow_mat n k a = mat n n
     (λ (i,j). if k = i ∧ k = j then a else if i = j then 1 else 0)"

(* Row exchange: the result has the dimensions of A; row k takes the entries
   of row l, row l takes the entries of row k, all other rows are unchanged. *)
definition mat_swaprows :: "nat ⇒ nat ⇒ 'a mat ⇒ 'a mat" ("swaprows") where
  "swaprows k l A = mat (dim_row A) (dim_col A)
    (λ (i,j). if k = i then A $$ (l,j) else if l = i then A $$ (k,j) else A $$ (i,j))"

(* Elementary n x n matrix for exchanging rows k and l: 1 at (k,l), at (l,k)
   and at every diagonal position (i,i) with i different from k and l;
   0 everywhere else. *)
definition swaprows_mat :: "nat ⇒ nat ⇒ nat ⇒ 'a :: semiring_1 mat" where
  "swaprows_mat n k l = mat n n
    (λ (i,j). if k = i ∧ l = j ∨ k = j ∧ l = i ∨ i = j ∧ i ≠ k ∧ i ≠ l then 1 else 0)"

(* Generic row addition: the result has the dimensions of A; each entry (k,j)
   becomes ad (mul a (A $$ (l,j))) (A $$ (k,j)), i.e. a times row l is combined
   into row k.  Other rows are unchanged.  Both operations are parameters. *)
definition mat_addrow_gen :: "('a ⇒ 'a ⇒ 'a) ⇒ ('a ⇒ 'a ⇒ 'a) ⇒ 'a ⇒ nat ⇒ nat ⇒ 'a mat ⇒ 'a mat" where
  "mat_addrow_gen ad mul a k l A = mat (dim_row A) (dim_col A)
    (λ (i,j). if k = i then ad (mul a (A $$ (l,j))) (A $$ (i,j)) else A $$ (i,j))"

(* Row addition instantiated with the semiring addition and multiplication;
   written addrow.  mat_addrow_def is an alias for the defining equation of the
   generic version. *)
abbreviation mat_addrow :: "'a :: semiring_1 ⇒ nat ⇒ nat ⇒ 'a mat ⇒ 'a mat" ("addrow") where
  "addrow ≡ mat_addrow_gen (+) ((*))"

lemmas mat_addrow_def = mat_addrow_gen_def

(* Elementary n x n matrix for row addition: the identity pattern (1 on the
   diagonal, 0 elsewhere) where additionally a is added to the entry at (k,l). *)
definition addrow_mat :: "nat ⇒ 'a :: semiring_1 ⇒ nat ⇒ nat ⇒ 'a mat" where
  "addrow_mat n a k l = mat n n (λ (i,j).
    (if k = i ∧ l = j then (+) a else id) (if i = j then 1 else 0))"

(* Simp rules for row scaling: entry lookup for in-range indices (general form,
   the scaled row, and any other row) and preservation of both dimensions. *)
lemma index_mat_multrow[simp]:
  "i < dim_row A ⟹ j < dim_col A ⟹ mat_multrow_gen mul k a A $$ (i,j) = (if k = i then mul a (A $$ (i,j)) else A $$ (i,j))"
  "i < dim_row A ⟹ j < dim_col A ⟹ mat_multrow_gen mul i a A $$ (i,j) = mul a (A $$ (i,j))"
  "i < dim_row A ⟹ j < dim_col A ⟹ k ≠ i ⟹ mat_multrow_gen mul k a A $$ (i,j) = A $$ (i,j)"
  "dim_row (mat_multrow_gen mul k a A) = dim_row A" "dim_col (mat_multrow_gen mul k a A) = dim_col A"
  unfolding mat_multrow_def by auto

(* Simp rules for the scaling matrix: entry lookup for in-range indices and
   both dimensions equal to n. *)
lemma index_mat_multrow_mat[simp]:
  "i < n ⟹ j < n ⟹ multrow_mat n k a $$ (i,j) = (if k = i ∧ k = j then a else if i = j
     then 1 else 0)"
  "dim_row (multrow_mat n k a) = n" "dim_col (multrow_mat n k a) = n"
  unfolding multrow_mat_def by auto

(* Simp rules for row exchange: entry lookup for in-range indices and
   preservation of both dimensions. *)
lemma index_mat_swaprows[simp]:
  "i < dim_row A ⟹ j < dim_col A ⟹ swaprows k l A $$ (i,j) = (if k = i then A $$ (l,j) else
    if l = i then A $$ (k,j) else A $$ (i,j))"
  "dim_row (swaprows k l A) = dim_row A" "dim_col (swaprows k l A) = dim_col A"
  unfolding mat_swaprows_def by auto

(* Simp rules for the exchange matrix: entry lookup for in-range indices and
   both dimensions equal to n. *)
lemma index_mat_swaprows_mat[simp]:
  "i < n ⟹ j < n ⟹ swaprows_mat n k l $$ (i,j) =
    (if k = i ∧ l = j ∨ k = j ∧ l = i ∨ i = j ∧ i ≠ k ∧ i ≠ l then 1 else 0)"
  "dim_row (swaprows_mat n k l) = n" "dim_col (swaprows_mat n k l) = n"
  unfolding swaprows_mat_def by auto

(* Simp rules for row addition: entry lookup for in-range indices (general
   form, the modified row, and any other row) and preservation of both
   dimensions. *)
lemma index_mat_addrow[simp]:
  "i < dim_row A ⟹ j < dim_col A ⟹ mat_addrow_gen ad mul a k l A $$ (i,j) = (if k = i then
    ad (mul a (A $$ (l,j))) (A $$ (i,j)) else A $$ (i,j))"
  "i < dim_row A ⟹ j < dim_col A ⟹ mat_addrow_gen ad mul a i l A $$ (i,j) = ad (mul a (A $$ (l,j))) (A $$ (i,j))"
  "i < dim_row A ⟹ j < dim_col A ⟹ k ≠ i ⟹ mat_addrow_gen ad mul a k l A $$ (i,j) = A $$ (i,j)"
  "dim_row (mat_addrow_gen ad mul a k l A) = dim_row A" "dim_col (mat_addrow_gen ad mul a k l A) = dim_col A"
  unfolding mat_addrow_def by auto

(* Simp rules for the row-addition matrix: entry lookup for in-range indices
   and both dimensions equal to n. *)
lemma index_mat_addrow_mat[simp]:
  "i < n ⟹ j < n ⟹ addrow_mat n a k l $$ (i,j) =
    (if k = i ∧ l = j then (+) a else id) (if i = j then 1 else 0)"
  "dim_row (addrow_mat n a k l) = n" "dim_col (addrow_mat n a k l) = n"
  unfolding addrow_mat_def by auto

(* Row scaling does not change carrier membership. *)
lemma multrow_carrier[simp]: "(mat_multrow_gen mul k a A ∈ carrier_mat n nc) = (A ∈ carrier_mat n nc)"
  unfolding carrier_mat_def by fastforce

(* The scaling matrix is an n x n matrix. *)
lemma multrow_mat_carrier[simp]: "multrow_mat n k a ∈ carrier_mat n n"
  unfolding carrier_mat_def by auto

(* The row-addition matrix is an n x n matrix. *)
lemma addrow_mat_carrier[simp]: "addrow_mat n a k l ∈ carrier_mat n n"
  unfolding carrier_mat_def by auto

(* The exchange matrix is an n x n matrix. *)
lemma swaprows_mat_carrier[simp]: "swaprows_mat n k l ∈ carrier_mat n n"
  unfolding carrier_mat_def by auto

(* Row exchange does not change carrier membership. *)
lemma swaprows_carrier[simp]: "(swaprows k l A ∈ carrier_mat n nc) = (A ∈ carrier_mat n nc)"
  unfolding carrier_mat_def by fastforce

(* Row addition does not change carrier membership. *)
lemma addrow_carrier[simp]: "(mat_addrow_gen ad mul a k l A ∈ carrier_mat n nc) = (A ∈ carrier_mat n nc)"
  unfolding carrier_mat_def by fastforce

(* Rows of the scaling matrix: row k is a times the k-th unit vector, every
   other in-range row is the corresponding unit vector. *)
lemma row_multrow:  "k ≠ i ⟹ i < n ⟹ row (multrow_mat n k a) i = unit_vec n i"
  "k < n ⟹ row (multrow_mat n k a) k = a ⋅⇩v unit_vec n k"
  by (rule eq_vecI, auto)

(* Row scaling coincides with left-multiplication by the scaling matrix. *)
lemma multrow_mat: assumes A: "A ∈ carrier_mat n nc"
  shows "multrow k a A = multrow_mat n k a * A"
  by (rule eq_matI, insert A, auto simp: row_multrow smult_scalar_prod_distrib[of _ n])

(* Rows of the row-addition matrix: row k is a times the l-th unit vector plus
   the k-th unit vector; every other in-range row is the corresponding unit
   vector. *)
lemma row_addrow:
  "k ≠ i ⟹ i < n ⟹ row (addrow_mat n a k l) i = unit_vec n i"
  "k < n ⟹ l < n ⟹ row (addrow_mat n a k l) k = a ⋅⇩v unit_vec n l + unit_vec n k"
  by (rule eq_vecI, auto)

(* Row addition coincides with left-multiplication by the row-addition matrix,
   provided the source row index l is in range. *)
lemma addrow_mat: assumes A: "A ∈ carrier_mat n nc"
  and l: "l < n"
  shows "addrow a k l A = addrow_mat n a k l * A"
  by (rule eq_matI, insert l A, auto simp: row_addrow
  add_scalar_prod_distrib[of _ n] smult_scalar_prod_distrib[of _ n])

(* Rows of the exchange matrix are unit vectors: rows k and l carry each
   other's unit vector, every other in-range row its own.  The first statement
   covers the degenerate case k = l. *)
lemma row_swaprows:
  "l < n ⟹ row (swaprows_mat n l l) l = unit_vec n l"
  "i ≠ k ⟹ i ≠ l ⟹ i < n ⟹ row (swaprows_mat n k l) i = unit_vec n i"
  "k < n ⟹ l < n ⟹ row (swaprows_mat n k l) l = unit_vec n k"
  "k < n ⟹ l < n ⟹ row (swaprows_mat n k l) k = unit_vec n l"
  by (rule eq_vecI, auto)

(* Row exchange coincides with left-multiplication by the exchange matrix,
   provided both row indices are in range. *)
lemma swaprows_mat: assumes A: "A ∈ carrier_mat n nc" and k: "k < n" "l < n"
  shows "swaprows k l A = swaprows_mat n k l * A"
  by (rule eq_matI, insert A k, auto simp: row_swaprows)

(* The exchange matrix is its own inverse.  The proof rewrites both factors as
   row exchanges applied to the identity matrix and compares entries. *)
lemma swaprows_mat_inv: assumes k: "k < n" and l: "l < n"
  shows "swaprows_mat n k l * swaprows_mat n k l = 1⇩m n"
proof -
  have "swaprows_mat n k l * swaprows_mat n k l =
    swaprows_mat n k l * (swaprows_mat n k l * 1⇩m n)"
    by (simp add: right_mult_one_mat[of _ n])
  also have "swaprows_mat n k l * 1⇩m n = swaprows k l (1⇩m n)"
    by (rule swaprows_mat[symmetric, OF _ k l, of _ n], simp)
  also have "swaprows_mat n k l * … = swaprows k l …"
    by (rule swaprows_mat[symmetric, of _ _ n], insert k l, auto)
  also have "… = 1⇩m n"
    by (rule eq_matI, insert k l, auto)
  finally show ?thesis .
qed

(* The exchange matrix is a unit of the matrix ring; the witness for the
   inverse is the exchange matrix itself. *)
lemma swaprows_mat_Unit: assumes k: "k < n" and l: "l < n"
  shows "swaprows_mat n k l ∈ Units (ring_mat TYPE('a :: semiring_1) n b)"
proof -
  interpret m: semiring "ring_mat TYPE('a) n b" by (rule semiring_mat)
  show ?thesis unfolding Units_def
    by (rule, rule conjI[OF _ bexI[of _ "swaprows_mat n k l"]],
    auto simp: ring_mat_def swaprows_mat_inv[OF k l] swaprows_mat_inv[OF l k])
qed

(* For distinct in-range rows k and l, the row-addition matrix for -a is a
   right inverse of the one for a.  The proof rewrites both factors as row
   additions applied to the identity matrix and compares entries. *)
lemma addrow_mat_inv: assumes k: "k < n" and l: "l < n" and neq: "k ≠ l"
  shows "addrow_mat n a k l * addrow_mat n (- (a :: 'a :: comm_ring_1)) k l = 1⇩m n"
proof -
  have "addrow_mat n a k l * addrow_mat n (- a) k l =
    addrow_mat n a k l * (addrow_mat n (- a) k l * 1⇩m n)"
    by (simp add: right_mult_one_mat[of _ n])
  also have "addrow_mat n (- a) k l * 1⇩m n = addrow (- a) k l (1⇩m n)"
    by (rule addrow_mat[symmetric, of _ _ n], insert k l, auto)
  also have "addrow_mat n a k l * … = addrow a k l …"
    by (rule addrow_mat[symmetric, of _ _ n], insert k l, auto)
  also have "… = 1⇩m n"
    by (rule eq_matI, insert k l neq, auto, algebra)
  finally show ?thesis .
qed

(* For distinct in-range rows, the row-addition matrix is a unit of the matrix
   ring; the witness for the inverse is the row-addition matrix for -a. *)
lemma addrow_mat_Unit: assumes k: "k < n" and l: "l < n" and neq: "k ≠ l"
  shows "addrow_mat n a k l ∈ Units (ring_mat TYPE('a :: comm_ring_1) n b)"
proof -
  interpret m: semiring "ring_mat TYPE('a) n b" by (rule semiring_mat)
  show ?thesis unfolding Units_def
    by (rule, rule conjI[OF _ bexI[of _ "addrow_mat n (- a) k l"]], insert neq,
    auto simp: ring_mat_def addrow_mat_inv[OF k l neq],
    rule trans[OF _ addrow_mat_inv[OF k l neq, of "- a"]], auto)
qed

(* For a non-zero scalar a of a division ring, the scaling matrix for
   inverse a is a right inverse of the one for a.  The proof rewrites both
   factors as row scalings applied to the identity matrix and compares entries. *)
lemma multrow_mat_inv: assumes k: "k < n" and a: "(a :: 'a :: division_ring) ≠ 0"
  shows "multrow_mat n k a * multrow_mat n k (inverse a) = 1⇩m n"
proof -
  have "multrow_mat n k a * multrow_mat n k (inverse a) =
    multrow_mat n k a * (multrow_mat n k (inverse a) * 1⇩m n)"
    using k by (simp add: right_mult_one_mat[of _ n])
  also have "multrow_mat n k (inverse a) * 1⇩m n = multrow k (inverse a) (1⇩m n)"
    by (rule multrow_mat[symmetric, of _ _ n], insert k, auto)
  also have "multrow_mat n k a * … = multrow k a …"
    by (rule multrow_mat[symmetric, of _ _ n], insert k, auto)
  also have "… = 1⇩m n"
    by (rule eq_matI, insert a k a, auto)
  finally show ?thesis .
qed

(* For a non-zero scalar, the scaling matrix is a unit of the matrix ring; the
   witness for the inverse is the scaling matrix for inverse a. *)
lemma multrow_mat_Unit: assumes k: "k < n" and a: "(a :: 'a :: division_ring) ≠ 0"
  shows "multrow_mat n k a ∈ Units (ring_mat TYPE('a) n b)"
proof -
  from a have ia: "inverse a ≠ 0" by auto
  interpret m: semiring "ring_mat TYPE('a) n b" by (rule semiring_mat)
  show ?thesis unfolding Units_def
    by (rule, rule conjI[OF _ bexI[of _ "multrow_mat n k (inverse a)"]], insert a,
    auto simp: ring_mat_def multrow_mat_inv[OF k],
    rule trans[OF _ multrow_mat_inv[OF k ia]], insert a, auto)
qed

subsection ‹Gauss-Jordan Elimination›

(* Processes a list of (factor, row) pairs from left to right: for each pair
   (a, i') it adds a times row i of the current matrix to row i'. *)
fun eliminate_entries_rec where
  "eliminate_entries_rec B i [] = B"
| "eliminate_entries_rec B i ((ai'j,i') # is) = (
  eliminate_entries_rec (mat_addrow_gen ((+) :: 'b :: ring_1 ⇒ 'b ⇒ 'b) (*) ai'j i' i B) i is)"

(* One-step elimination, generic in the subtraction and multiplication that
   are fixed by this context. *)
context
  fixes minus :: "'a ⇒ 'a ⇒ 'a"
  and times :: "'a ⇒ 'a ⇒ 'a"
begin

(* Every row i other than I is replaced by row i minus (v i) times row I;
   row I itself is kept.  The argument J does not occur in the right-hand side. *)
definition eliminate_entries_gen :: "(nat ⇒ 'a) ⇒ 'a mat ⇒ nat ⇒ nat ⇒ 'a mat" where
  "eliminate_entries_gen v A I J = mat (dim_row A) (dim_col A) (λ (i, j).
     if i ≠ I then minus (A $$ (i,j)) (times (v i) (A $$ (I,j))) else A $$ (i,j))"

(* Both elimination variants preserve the dimensions of their input. *)
lemma dim_eliminate_entries_gen[simp]: "dim_row (eliminate_entries_gen v B i as) = dim_row B"
  "dim_col (eliminate_entries_gen v B i as) = dim_col B"
  unfolding eliminate_entries_gen_def by auto

lemma dimc_eliminate_entries_rec[simp]: "dim_col (eliminate_entries_rec B i as) = dim_col B"
  by (induct as arbitrary: B, auto simp: Let_def)

lemma dimr_eliminate_entries_rec[simp]: "dim_row (eliminate_entries_rec B i as) = dim_row B"
  by (induct as arbitrary: B, auto simp: Let_def)

lemma carrier_eliminate_entries: "A ∈ carrier_mat nr nc ⟹ eliminate_entries_gen v A i bs ∈ carrier_mat nr nc"
  "B ∈ carrier_mat nr nc ⟹ eliminate_entries_rec B i as ∈ carrier_mat nr nc"
  unfolding carrier_mat_def by auto
end

(* One-step elimination instantiated with ring subtraction and multiplication. *)
abbreviation "eliminate_entries ≡ eliminate_entries_gen (-) ((*) :: 'a :: ring_1 ⇒ 'a ⇒ 'a)"

(* The one-step elimination with factors taken from column J of A equals the
   recursive variant run on the list of pairs (- A $$ (i, J), i) for all rows
   i other than I.  The auxiliary function one_go describes the effect of the
   recursive variant on an arbitrary distinct list of row indices; the main
   claim is proved by induction on that list. *)
lemma eliminate_entries_convert:
  assumes jA: "J < dim_col A" and *: "I < dim_row A" "dim_row B = dim_row A"
  shows "eliminate_entries (λ i. A $$ (i,J)) B I J =
    eliminate_entries_rec B I (map (λ i. (- A $$ (i, J), i)) (filter (λ i. i ≠ I) [0 ..< dim_row A]))"
proof -
  let ?ais = "λ is. map (λ i. (- A $$ (i, J), i)) (filter (λ i. i ≠ I) is)"
  define one_go where "one_go = (λ B is. mat (dim_row B) (dim_col B) (λ (i, j).
    if i ≠ I ∧ i ∈ set is then B $$ (i,j) - (A $$ (i,J)) * B $$ (I,j) else B $$ (i,j)))"
  {
    fix "is" :: "nat list"
    assume "distinct is"
    from * this have "eliminate_entries_rec B I (?ais is) = one_go B is"
    proof (induct "is" arbitrary: B)
      case Nil
      show ?case unfolding one_go_def
        by (rule eq_matI, auto)
    next
      case (Cons i "is")
      note I = Cons(2) note dim = Cons(3)
      note II = Cons(2)[folded dim]
      let ?B = "addrow (- A $$ (i, J)) i I B"
      from Cons(4) I dim have "I < dim_row A" "dim_row ?B = dim_row A" and dist: "distinct is" by auto
      note IH = Cons(1)[OF this]
      from Cons(4) have i: "i ∉ set is" by auto
      show ?case
      proof (cases "i = I")
        case False
        hence id: "?ais (i # is) = (- A $$ (i, J), i) # ?ais is" by simp
        show ?thesis unfolding id eliminate_entries_rec.simps IH
          unfolding one_go_def index_mat_addrow
        proof (rule eq_matI, goal_cases)
          case (1 ii jj)
          hence ii: "ii < dim_row B" and jj: "jj < dim_col B" and iiA: "ii < dim_row A" using dim by auto
          show ?case unfolding index_mat[OF ii jj] split
            index_mat_addrow(1)[OF ii jj] index_mat_addrow(1)[OF II jj]
            using i False by auto
        qed auto
      next
        (* The pivot row itself is filtered out, so the list of pairs is unchanged. *)
        case True
        hence id: "?ais (i # is) = ?ais is" by simp
        show ?thesis unfolding id Cons(1)[OF I dim dist]
          unfolding one_go_def True by auto
      qed
    qed
  } note main = this
  show ?thesis
    by (subst main, force, unfold one_go_def eliminate_entries_gen_def, rule eq_matI,
    insert *, auto)
qed

(* If every target row in the list is in range and different from the pivot
   row i, the recursive elimination is left-multiplication by a single unit P
   of the matrix ring, uniformly for all matrices with nr rows.  Proved by
   induction on the list, composing with one row-addition matrix per step. *)
lemma Unit_prod_eliminate_entries: "i < nr ⟹ (⋀ a i'. (a, i') ∈ set is ⟹ i' < nr ∧ i' ≠ i)
  ⟹ ∃ P ∈ Units (ring_mat TYPE('a :: comm_ring_1) nr b) . ∀ B nc. B ∈ carrier_mat nr nc ⟶ eliminate_entries_rec B i is = P * B"
proof (induct "is")
  case Nil
  thus ?case by (intro bexI[of _ "1⇩m nr"], auto simp: Units_def ring_mat_def)
next
  case (Cons ai' "is")
  obtain a i' where ai': "ai' = (a,i')" by force
  let ?U = "Units (ring_mat TYPE('a) nr b)"
  interpret m: ring "ring_mat TYPE('a) nr b" by (rule ring_mat)
  from Cons(1)[OF Cons(2-3)]
  obtain P where P: "P ∈ ?U" and id: "⋀ B nc . B ∈ carrier_mat nr nc ⟹
    eliminate_entries_rec B i is = P * B" by force
  let ?Add = "addrow_mat nr a i' i"
  have Add: "?Add ∈ ?U"
    by (rule addrow_mat_Unit, insert Cons ai', auto)
  from m.Units_m_closed[OF P Add] have PI: "P * ?Add ∈ ?U" unfolding ring_mat_def by simp
  from m.Units_closed[OF P] have P: "P ∈ carrier_mat nr nr" unfolding ring_mat_def by simp
  show ?case
  proof (rule bexI[OF _ PI], intro allI impI)
    fix B :: "'a mat" and nc
    assume BB: "B ∈ carrier_mat nr nc"
    let ?B = "addrow a i' i B"
    from BB have B: "?B ∈ carrier_mat nr nc" by simp
    from id[OF B] have id: "eliminate_entries_rec ?B i is = P * ?B" .
    have id2: "eliminate_entries_rec B i (ai' # is) = eliminate_entries_rec ?B i is" unfolding ai' by simp
    show "eliminate_entries_rec B i (ai' # is) = P * ?Add * B"
      unfolding id2 id unfolding addrow_mat[OF BB Cons(2)]
      by (rule assoc_mult_mat[symmetric, OF P _ BB], auto)
  qed
qed

(* Main loop of Gauss-Jordan elimination on A, applying the same row operations
   to the companion matrix B.  (i,j) is the current position:
   - out of range: return (A,B);
   - entry is 0: swap row i with the first later row that has a non-zero entry
     in column j, or move on to column j+1 if there is none;
   - entry is 1: eliminate column j in all other rows, continue at (i+1,j+1);
   - otherwise: scale row i by the inverse of the entry and retry at (i,j). *)
function gauss_jordan_main :: "'a :: field mat ⇒ 'a mat ⇒ nat ⇒ nat ⇒ 'a mat × 'a mat" where
  "gauss_jordan_main A B i j = (let nr = dim_row A; nc = dim_col A in
    if i < nr ∧ j < nc then let aij = A $$ (i,j) in if aij = 0 then
      (case [ i' . i' <- [Suc i ..< nr],  A $$ (i',j) ≠ 0]
        of [] ⇒ gauss_jordan_main A B i (Suc j)
         | (i' # _) ⇒ gauss_jordan_main (swaprows i i' A) (swaprows i i' B) i j)
      else if aij = 1 then let
        v = (λ i. A $$ (i,j)) in
        gauss_jordan_main
        (eliminate_entries v A i j) (eliminate_entries v B i j) (Suc i) (Suc j)
      else let iaij = inverse aij in gauss_jordan_main (multrow i iaij A) (multrow i iaij B) i j
    else (A,B))"
  by pat_completeness auto

(* Termination of gauss_jordan_main via a lexicographic measure: first the
   number of remaining columns, then a rank of the current entry
   (0 entry: 2, other entry: 1, entry 1: 0).  Only the swap case needs an
   explicit argument; the remaining cases are solved by auto. *)
termination
proof -
  let ?R = "measures [λ (A :: 'a :: field mat,B,i,j). dim_col A - j,
    λ (A,B,i,j). if A $$ (i,j) = 0 then 2 else if A $$ (i,j) = 1 then 0 else 1]"
  show ?thesis
  proof
    show "wf ?R" by auto
  next
    fix A B :: "'a mat" and i j nr nc a i' "is"
    assume *: "nr = dim_row A" "nc = dim_col A" "i < nr ∧ j < nc" "a = A $$ (i, j)" "a = 0"
      and ne: "[ i' . i' <- [Suc i ..< nr],  A $$ (i',j) ≠ 0] = i' # is"
    from ne have "i' ∈ set ([ i' . i' <- [Suc i ..< nr],  A $$ (i',j) ≠ 0])" by auto
    with *
    show "((swaprows i i' A, swaprows i i' B, i, j), A, B, i, j) ∈ ?R" by auto
  qed auto
qed

(* The recursive equation is removed from the simpset; proofs unfold it
   explicitly where needed. *)
declare gauss_jordan_main.simps[simp del]

(* Gauss-Jordan elimination: the main loop started at position (0,0). *)
definition "gauss_jordan A B ≡ gauss_jordan_main A B 0 0"

(* Both results of gauss_jordan arise from the inputs by left-multiplication
   with one common unit P of the matrix ring.  The proof follows the recursion
   of gauss_jordan_main and, in each case, composes the unit obtained from the
   induction hypothesis with the elementary matrix (or product of elementary
   matrices) of the row operation performed in that step. *)
lemma gauss_jordan_transform: assumes A: "A ∈ carrier_mat nr nc" and B: "B ∈ carrier_mat nr nc'"
  and res: "gauss_jordan (A :: 'a :: field mat) B = (A',B')"
  shows "∃ P ∈ Units (ring_mat TYPE('a) nr b). A' = P * A ∧ B' = P * B"
proof -
  let ?U = "Units (ring_mat TYPE('a) nr b)"
  interpret m: ring "ring_mat TYPE('a) nr b" by (rule ring_mat)
  {
    fix i j :: nat
    assume "gauss_jordan_main A B i j = (A',B')"
    with A B
    have "∃ P ∈ ?U. A' = P * A ∧ B' = P * B"
    proof (induction A B i j rule: gauss_jordan_main.induct)
      case (1 A B i j)
      note A = 1(5)
      hence dim: "dim_row A = nr" "dim_col A = nc" by auto
      note B = 1(6)
      hence dimB: "dim_row B = nr" "dim_col B = nc'" by auto
      note IH = 1(1-4)[OF dim[symmetric]]
      note res = 1(7)
      note simp = gauss_jordan_main.simps[of A B i j] Let_def
      let ?g = "gauss_jordan_main A B i j"
      show ?case
      proof (cases "i < nr ∧ j < nc")
        (* Position out of range: nothing changes, the identity matrix works. *)
        case False
        with res have res: "A' = A" "B' = B" unfolding simp dim by auto
        show ?thesis unfolding res
          by (rule bexI[of _ "1⇩m nr"], insert A B, auto simp: Units_def ring_mat_def)
      next
        case True note valid = this
        note IH = IH[OF valid refl]
        show ?thesis
        proof (cases "A $$ (i,j) = 0")
          case False note nZ = this
          note IH = IH(3-4)[OF nZ]
          show ?thesis
          proof (cases "A $$ (i,j) = 1")
            (* Entry neither 0 nor 1: the step scales row i by its inverse. *)
            case False note nO = this
            let ?inv = "inverse (A $$ (i,j))"
            from nO nZ valid res
            have "gauss_jordan_main (multrow i ?inv A) (multrow i ?inv B) i j = (A',B')"
              unfolding simp dim by simp
            note IH = IH(2)[OF nO refl, unfolded multrow_carrier, OF A B this]
            from IH obtain P where P: "P ∈ ?U" and
              id: "A' = P * multrow i ?inv A" "B' = P * multrow i ?inv B" by blast
            let ?Inv = "multrow_mat nr i ?inv"
            from nZ valid have "i < nr" "?inv ≠ 0" by auto
            from multrow_mat_Unit[OF this]
            have Inv: "?Inv ∈ ?U" .
            from m.Units_m_closed[OF P Inv] have PI: "P * ?Inv ∈ ?U" unfolding ring_mat_def by simp
            from m.Units_closed[OF P] have P: "P ∈ carrier_mat nr nr" unfolding ring_mat_def by simp
            show ?thesis unfolding id unfolding multrow_mat[OF A] multrow_mat[OF B]
              by (rule bexI[OF _ PI], intro conjI,
                rule assoc_mult_mat[symmetric, OF P _ A], simp,
                rule assoc_mult_mat[symmetric, OF P _ B], simp)
          next
            (* Entry is 1: the step eliminates column j in all other rows. *)
            case True note O = this
            let ?is = "filter (λ i'. i' ≠ i) [0 ..< nr]"
            let ?ais = "map (λ i'. (-A $$ (i',j), i')) ?is"
            let ?E = "λ B. eliminate_entries (λ i. A $$ (i,j)) B i j"
            let ?EE = "λ B. eliminate_entries_rec B i ?ais"
            let ?A = "?E A"
            let ?B = "?E B"
            let ?AA = "?EE A"
            let ?BB = "?EE B"
            from O nZ valid res have "gauss_jordan_main ?A ?B (Suc i) (Suc j) = (A',B')"
              unfolding simp dim by simp
            note IH = IH(1)[OF O refl carrier_eliminate_entries(1)[OF A] carrier_eliminate_entries(1)[OF B] this]
            from IH obtain P where P: "P ∈ ?U" and id: "A' = P * ?A" "B' = P * ?B" by blast
            have *: "j < dim_col A" "i < dim_row A" by (auto simp add: dim valid)
            have "∃P∈?U. ∀ B nc. B ∈ carrier_mat nr nc ⟶ ?EE B = P * B"
              by (rule Unit_prod_eliminate_entries, insert valid, auto)
            then obtain Q where Q: "Q ∈ ?U" and QQ: "⋀ B nc. B ∈ carrier_mat nr nc ⟹ ?EE B = Q * B" by auto
            {
              fix B :: "'a mat" and nc
              assume B: "B ∈ carrier_mat nr nc"
              with dim have "dim_row B = dim_row A" by auto
              from eliminate_entries_convert[OF * this]
              have "?E B = ?EE B" using dim by simp
              also have "… = Q * B" using QQ[OF B] by simp
              finally have "?E B = Q * B" .
            } note QQ = this
            have id3: "?A = Q * A" by (rule QQ[OF A])
            have id4: "?B = Q * B" by (rule QQ[OF B])
            from m.Units_closed[OF P] have Pc: "P ∈ carrier_mat nr nr" unfolding ring_mat_def by simp
            from m.Units_closed[OF Q] have Qc: "Q ∈ carrier_mat nr nr" unfolding ring_mat_def by simp
            from m.Units_m_closed[OF P Q] have PQ: "P * Q ∈ ?U" unfolding ring_mat_def by simp
            show ?thesis unfolding id unfolding id3 id4
              by (rule bexI[OF _ PQ], rule conjI,
              rule assoc_mult_mat[symmetric, OF Pc Qc A],
              rule assoc_mult_mat[symmetric, OF Pc Qc B])
          qed
        next
          (* Entry is 0: either skip the column or swap in a later row. *)
          case True note Z = this
          note IH = IH(1-2)[OF Z]
          let ?is = "[ i' . i' <- [Suc i ..< nr],  A $$ (i',j) ≠ 0]"
          show ?thesis
          proof (cases ?is)
            case Nil
            from Z valid res have id: "gauss_jordan_main A B i (Suc j) = (A',B')" unfolding simp dim Nil by simp
            from IH(1)[OF Nil A B this] show ?thesis unfolding id .
          next
            case (Cons i' iis)
            from Z valid res have "gauss_jordan_main (swaprows i i' A) (swaprows i i' B) i j = (A',B')"
              unfolding simp dim Cons by simp
            from IH(2)[OF Cons, unfolded swaprows_carrier, OF A B this]
            obtain P where P: "P ∈ ?U" and
              id: "A' = P * swaprows i i' A" "B' = P * swaprows i i' B" by blast
            let ?Swap = "swaprows_mat nr i i'"
            from Cons have "i' ∈ set ?is" by auto
            with valid have i': "i < nr" "i' < nr" by auto
            from swaprows_mat_Unit[OF this] have Swap: "?Swap ∈ ?U" .
            from m.Units_m_closed[OF P Swap] have PI: "P * ?Swap ∈ ?U" unfolding ring_mat_def by simp
            from m.Units_closed[OF P] have P: "P ∈ carrier_mat nr nr" unfolding ring_mat_def by simp
            show ?thesis unfolding id swaprows_mat[OF A i'] swaprows_mat[OF B i']
              by (rule bexI[OF _ PI], rule conjI,
              rule assoc_mult_mat[symmetric, OF P _ A], simp,
              rule assoc_mult_mat[symmetric, OF P _ B], simp)
          qed
        qed
      qed
    qed
  }
  from this[of 0 0, folded gauss_jordan_def, OF res] show ?thesis .
qed

(* gauss_jordan preserves the dimensions of both matrices.  Derived from
   gauss_jordan_transform: each result is a product of an nr x nr matrix with
   the corresponding input. *)
lemma gauss_jordan_carrier: assumes A: "(A :: 'a :: field mat) ∈ carrier_mat nr nc"
  and B: "B ∈ carrier_mat nr nc'"
  and res: "gauss_jordan A B = (A',B')"
  shows "A' ∈ carrier_mat nr nc" "B' ∈ carrier_mat nr nc'"
proof -
  from gauss_jordan_transform[OF A B res, of undefined]
  obtain P where P: "P ∈ Units (ring_mat TYPE('a) nr undefined)"
    and id: "A' = P * A" "B' = P * B" by auto
  from P have P: "P ∈ carrier_mat nr nr" unfolding Units_def ring_mat_def by auto
  show "A' ∈ carrier_mat nr nc" "B' ∈ carrier_mat nr nc'" unfolding id
    using P A B by auto
qed


(* pivot_fun A f nc states that f assigns to every row i of A a column index
   f i with f i ≤ nc such that
   - if f i < nc, the entry at (i, f i) is 1 and every other row has 0 in
     column f i;
   - all entries of row i to the left of column f i are 0;
   - from one row to the next, f strictly increases or takes the value nc. *)
definition pivot_fun :: "'a :: {zero,one} mat ⇒ (nat ⇒ nat) ⇒ nat ⇒ bool" where
  "pivot_fun A f nc ≡ let nr = dim_row A in
    (∀ i < nr. f i ≤ nc ∧
      (f i < nc ⟶ A $$ (i, f i) = 1 ∧ (∀ i' < nr. i' ≠ i ⟶ A $$ (i',f i) = 0)) ∧
      (∀ j < f i. A $$ (i, j) = 0) ∧
      (Suc i < nr ⟶ f (Suc i) > f i ∨ f (Suc i) = nc))"

(* Introduction rule for pivot_fun with the row count fixed to nr and the
   conjuncts of the definition split into separate premises. *)
lemma pivot_funI: assumes d: "dim_row A = nr"
  and *: "⋀ i. i < nr ⟹ f i ≤ nc"
      "⋀ i j. i < nr ⟹ j < f i ⟹ A $$ (i,j) = 0"
      "⋀ i. i < nr ⟹ Suc i < nr ⟹ f (Suc i) > f i ∨ f (Suc i) = nc"
      "⋀ i. i < nr ⟹ f i < nc ⟹ A $$ (i, f i) = 1"
      "⋀ i i'. i < nr ⟹ f i < nc ⟹ i' < nr ⟹ i' ≠ i ⟹ A $$ (i',f i) = 0"
  shows "pivot_fun A f nc"
  unfolding pivot_fun_def Let_def d using * by blast

(* Destruction rules for pivot_fun: the five properties in the same order and
   shape as the premises of pivot_funI. *)
lemma pivot_funD: assumes d: "dim_row A = nr"
  and p: "pivot_fun A f nc"
  shows "⋀ i. i < nr ⟹ f i ≤ nc"
      "⋀ i j. i < nr ⟹ j < f i ⟹ A $$ (i,j) = 0"
      "⋀ i. i < nr ⟹ Suc i < nr ⟹ f (Suc i) > f i ∨ f (Suc i) = nc"
      "⋀ i. i < nr ⟹ f i < nc ⟹ A $$ (i, f i) = 1"
      "⋀ i i'. i < nr ⟹ f i < nc ⟹ i' < nr ⟹ i' ≠ i ⟹ A $$ (i',f i) = 0"
  using p unfolding pivot_fun_def Let_def d by blast+

(* Scaling a row i0 whose pivot value is the bound jj preserves pivot_fun. *)
lemma pivot_fun_multrow: assumes p: "pivot_fun A f jj"
  and d: "dim_row A = nr" "dim_col A = nc"
  and fi: "f i0 = jj"
  and jj: "jj ≤ nc"
  shows "pivot_fun (multrow i0 a A) f jj"
proof -
  note p = pivot_funD[OF d(1) p]
  let ?A = "multrow i0 a A"
  have "dim_row ?A = nr" using d by simp
  thus ?thesis
  proof (rule pivot_funI)
    fix i
    assume i: "i < nr"
    note p = p[OF i]
    show "f i ≤ jj" by fact
    show "Suc i < nr ⟹ f i < f (Suc i) ∨ f (Suc i) = jj" by fact
    {
      fix i'
      assume *: "f i < jj" "i' < nr" "i' ≠ i"
      from p(5)[OF this]
      show "?A $$ (i', f i) = 0"
        by (subst index_mat_multrow(1), insert * d jj, auto)
    }
    {
      assume *: "f i < jj"
      from p(4)[OF this] have A: "A $$ (i, f i) = 1" by auto
      show "?A $$ (i, f i) = 1"
        by (subst index_mat_multrow(1), insert * d i A jj fi, auto)
    }
    {
      fix j
      assume j: "j < f i"
      from p(2)[OF j]
      show "?A $$ (i, j) = 0"
        by (subst index_mat_multrow(1), insert j d i p jj fi, auto)
    }
  qed
qed

(* Exchanging two in-range rows l and k whose pivot values both equal the
   bound jj preserves pivot_fun. *)
lemma pivot_fun_swaprows: assumes p: "pivot_fun A f jj"
  and d: "dim_row A = nr" "dim_col A = nc"
  and flk: "f l = jj" "f k = jj"
  and nr: "l < nr" "k < nr"
  and jj: "jj ≤ nc"
  shows "pivot_fun (swaprows l k A) f jj"
proof -
  note pivot = pivot_funD[OF d(1) p]
  let ?A = "swaprows l k A"
  have "dim_row ?A = nr" using d by simp
  thus ?thesis
  proof (rule pivot_funI)
    fix i
    assume i: "i < nr"
    note p = pivot[OF i]
    show "f i ≤ jj" by fact
    show "Suc i < nr ⟹ f i < f (Suc i) ∨ f (Suc i) = jj" by fact
    {
      fix i'
      assume *: "f i < jj" "i' < nr" "i' ≠ i"
      (* Row i has a pivot below jj, so it is neither of the swapped rows. *)
      from *(1) flk have diff: "l ≠ i" "k ≠ i" by auto
      from p(5)[OF *] p(5)[OF *(1) nr(1) diff(1)] p(5)[OF *(1) nr(2) diff(2)]
      show "?A $$ (i', f i) = 0"
        by (subst index_mat_swaprows(1), insert * d jj, auto)
    }
    {
      assume *: "f i < jj"
      from p(4)[OF this] have A: "A $$ (i, f i) = 1" by auto
      show "?A $$ (i, f i) = 1"
        by (subst index_mat_swaprows(1), insert * d i A jj flk, auto)
    }
    {
      fix j
      assume j: "j < f i"
      with p(1) flk have le: "j < f l" "j < f k" by auto
      from p(2)[OF j] pivot(2)[OF nr(1) le(1)] pivot(2)[OF nr(2) le(2)]
      show "?A $$ (i, j) = 0"
        by (subst index_mat_swaprows(1), insert j d i p jj, auto)
    }
  qed
qed

(* Eliminating with an in-range row l whose pivot value is the bound jj
   preserves pivot_fun, for arbitrary factors vs.  The auxiliary fact hint
   records that row l is 0 in every column left of some row's pivot. *)
lemma pivot_fun_eliminate_entries: assumes p: "pivot_fun A f jj"
  and d: "dim_row A = nr" "dim_col A = nc"
  and fl: "f l = jj"
  and nr: "l < nr"
  and jj: "jj ≤ nc"
shows "pivot_fun (eliminate_entries vs A l j) f jj"
proof -
  note pD = pivot_funD[OF d(1) p]
  {
    fix i j
    assume *: "i < nr" "j < f i"
    from pD(1)[OF this(1)] this(2) jj have j: "j < nc" by auto
    from pD nr fl * j have "A $$ (l, j) = 0" by (meson less_le_trans)
    note j this
  } note hint = this
  show ?thesis by (rule pivot_funI, insert fl nr jj pD, auto simp: eliminate_entries_gen_def d hint)
qed
    
(* A matrix is in row echelon form if some pivot function exists for it with
   the number of columns as bound. *)
definition row_echelon_form :: "'a :: {zero,one} mat ⇒ bool" where
  "row_echelon_form A ≡ ∃ f. pivot_fun A f (dim_col A)"

(* With bound 0, the constant-zero function is a pivot function of any matrix. *)
lemma pivot_fun_init: "pivot_fun A (λ _. 0) 0"
  by (rule pivot_funI, auto)

lemma gauss_jordan_main_row_echelon: 
  assumes 
    "A  carrier_mat nr nc"
    "gauss_jordan_main A B i j = (A',B')"
    "pivot_fun A f j" 
    " i'. i' < i  f i' < j" " i'. i'  i  f i' = j"
    "i  nr" "j  nc"
  shows "row_echelon_form A'"
proof -
  fix b
  interpret m: ring "ring_mat TYPE('a) nr b" by (rule ring_mat)
  show ?thesis
    using assms
  proof (induct A B i j arbitrary: f rule: gauss_jordan_main.induct)
    case (1 A B i j f)
    note A = 1(5)
    hence dim: "dim_row A = nr" "dim_col A = nc" by auto
    note res = 1(6)
    note pivot = 1(7)
    note f = 1(8-9)
    note ij = 1(10-11)
    note IH = 1(1-4)[OF dim[symmetric]]
    note simp = gauss_jordan_main.simps[of A B i j] Let_def
    let ?g = "gauss_jordan_main A B i j"
    show ?case 
    proof (cases "i < nr  j < nc")
      case False note nij = this
      with res have id: "A' = A" unfolding simp dim by auto
      have "pivot_fun A f nc"
      proof (cases "j  nc")
        case True
        with ij have j: "j = nc" by auto
        with pivot show "pivot_fun A f nc" by simp
      next
        case False
        hence j: "j < nc" by simp
        from False nij ij have i: "i = nr" by auto
        note f = f[unfolded i]
        note p = pivot_funD[OF dim(1) pivot]
        show ?thesis
        proof (rule pivot_funI[OF dim(1)])
          fix i
          assume i: "i < nr"
          note p = p[OF i]
          from p(1) j show "f i  nc" by simp
          from f(1)[OF i] have fij: "f i < j" .
          from p(4)[OF fij] show "A $$ (i, f i) = 1" .
          from p(5)[OF fij] show " i'. i' < nr  i'  i  A $$ (i', f i) = 0" .
          show " j. j < f i  A $$ (i, j) = 0" by (rule p(2))
          assume "Suc i < nr"
          with p(3)[OF this] f
          show "f i < f (Suc i)  f (Suc i) = nc" by auto
        qed          
      qed
      thus ?thesis using pivot unfolding id row_echelon_form_def dim by blast
    next
      case True note valid = this
      hence sij: "Suc i  nr" "Suc j  nc" by auto
      note IH = IH[OF valid refl]
      show ?thesis 
      proof (cases "A $$ (i,j) = 0")
        case False note nZ = this
        note IH = IH(3-4)[OF nZ]
        show ?thesis
        proof (cases "A $$ (i,j) = 1")
          case False note nO = this
          let ?inv = "inverse (A $$ (i,j))"
          let ?A = "multrow i ?inv A"
          from nO nZ valid res have id: "gauss_jordan_main (multrow i ?inv A) (multrow i ?inv B) i j = (A', B')"
            unfolding simp dim by simp
          have "pivot_fun ?A f j"
            by (rule pivot_fun_multrow[OF pivot dim f(2) ij(2)], auto)
          note IH = IH(2)[OF nO refl, unfolded multrow_carrier, OF A id this f ij]
          show ?thesis unfolding id using IH .
        next
          case True note O = this
          let ?E = "λ B. eliminate_entries (λ i. A $$ (i,j)) B i j" 
          let ?A = "?E A"
          let ?B = "?E B"
          define E where "E = ?A"
          let ?f = "λ i'. if i' = i then j else if i' > i then Suc j else f i'"
          have pivot: "pivot_fun E f j" unfolding E_def          
            by (rule pivot_fun_eliminate_entries[OF pivot dim f(2)], insert valid, auto)
          {
            fix i'
            assume i': "i' < nr"
            have "E $$ (i', j) = (if i' = i then 1 else 0)"
              unfolding E_def eliminate_entries_gen_def using dim O i' valid by auto
          } note Ej = this
          have E: "E  carrier_mat nr nc" unfolding E_def by (rule carrier_eliminate_entries[OF A])
          hence dimE: "dim_row E = nr" "dim_col E = nc" by auto
          note pivot = pivot_funD[OF dimE(1) pivot]
          have "pivot_fun E ?f (Suc j)"
          proof (rule pivot_funI[OF dimE(1)])
            fix ii
            assume ii: "ii < nr"
            note p = pivot[OF ii]
            show "?f ii  Suc j" using p(1) by simp
            {
              fix jj
              assume jj: "jj < ?f ii"
              show "E $$ (ii,jj) = 0"
              proof (cases "ii < i")
                case True
                with jj have "jj < f ii" by auto
                from p(2)[OF this] show ?thesis .
              next
                case False note ge = this
                with f have fiij: "f ii = j" by simp 
                show ?thesis
                proof (cases "i < ii")
                  case True
                  with jj have jj: "jj  j" by auto
                  show ?thesis
                  proof (cases "jj < j")
                    case True
                    with p(2)[of jj] fiij show ?thesis by auto
                  next
                    case False
                    with jj have jj: "jj = j" by auto
                    with Ej[OF ii] True show ?thesis by auto
                  qed
                next
                  case False
                  with ge have ii: "ii = i" by simp
                  with jj have jj: "jj < j" by simp
                  from p(2)[of jj] ii jj fiij show ?thesis by auto
                qed
              qed
            }
            {
              assume "Suc ii < nr"
              from p(3)[OF this] f
              show "?f (Suc ii) > ?f ii  ?f (Suc ii) = Suc j" by auto
            }
            {
              assume fii: "?f ii < Suc j"
              show "E $$ (ii, ?f ii) = 1"
              proof (cases "ii = i")
                case True
                with Ej[of i] valid show ?thesis by auto
              next
                case False
                with fii have ii: "ii < i" by (auto split: if_splits)
                from f(1)[OF this] have "f ii < j" by auto
                from p(4)[OF this] ii show ?thesis by simp
              qed
            }
            {
               fix i'
               assume *: "?f ii < Suc j" "i' < nr" "i'  ii"
               show "E $$ (i', ?f ii) = 0"
               proof (cases "ii = i")
                 case False
                 with *(1) have iii: "ii < i" by (auto split: if_splits)
                 from f(1)[OF this] have "f ii < j" by auto
                 from p(5)[OF this *(2-3)] show ?thesis using iii by simp
               next
                 case True
                 with *(2-3) Ej[of i'] show ?thesis by auto
               qed
            }
          qed 
          note IH = IH(1)[OF O refl, folded E_def, OF E _ this _ _ sij]     
          from O nZ valid res have "gauss_jordan_main E ?B (Suc i) (Suc j) = (A', B')"
            unfolding E_def simp dim by simp
          note IH = IH[OF this]
          show ?thesis  
          proof (rule IH)
            fix i'
            assume "i' < Suc i"
            thus "?f i' < Suc j" using f[of i'] by (cases "i' < i", auto)
          qed auto
        qed
      next
        case True note Z = this
        note IH = IH(1-2)[OF Z]
        let ?is = "[ i' . i' <- [Suc i ..< nr],  A $$ (i',j)  0]"
        show ?thesis
        proof (cases ?is)
          case Nil
          {
            fix i'
            assume "i  i'" and "i' < nr"
            hence "i' = i  i'  {Suc i ..< nr}" by auto
            from this arg_cong[OF Nil, of set] Z have "A $$ (i',j) = 0" by auto
          } note zero = this
          let ?f = "λ i'. if i' < i then f i' else Suc j"
          note p = pivot_funD[OF dim(1) pivot]
          have "pivot_fun A ?f (Suc j)"
          proof (rule pivot_funI[OF dim(1)])
            fix ii
            assume ii: "ii < nr"
            note p = p[OF this]
            show "?f ii  Suc j" using p(1) by simp
            {
              fix jj
              assume jj: "jj < ?f ii"              
              show "A $$ (ii,jj) = 0"
              proof (cases "ii < i")
                case True
                with jj have "jj < f ii" by auto
                from p(2)[OF this] show ?thesis .
              next
                case False
                with jj have ii': "ii  i" and jjj: "jj  j" by auto
                from zero[OF ii' ii] have Az: "A $$ (ii,j) = 0" .
                show ?thesis
                proof (cases "jj < j")
                  case False
                  with jjj have "jj = j" by auto
                  with Az show ?thesis by simp
                next
                  case True
                  show ?thesis
                    by (rule p(2), insert True False f, auto)
                qed
              qed
            }
            {
              assume sii: "Suc ii < nr"
              show "?f ii < ?f (Suc ii)  ?f (Suc ii) = Suc j"
                using p(3)[OF sii] f by auto
            }
            {
              assume fii: "?f ii < Suc j"
              thus "A $$ (ii, ?f ii) = 1"
                using p(4) f by (cases "ii < i", auto)
              fix i'
              assume "i' < nr" "i'  ii"
              from p(5)[OF _ this] f fii
              show "A $$ (i', ?f ii) = 0" 
                by (cases "ii < i", auto)
            }
          qed
          note IH = IH(1)[OF Nil A _ this _ _ ij(1) sij(2)]
          from Z valid res have "gauss_jordan_main A B i (Suc j) = (A',B')" unfolding simp dim Nil by simp
          note IH = IH[OF this]
          show ?thesis  
            by (rule IH, insert f, force+)
        next
          case (Cons i' iis)
          from arg_cong[OF this, of set] have i': "i'  Suc i" "i' < nr" by auto
          from f[of i] f[of "i'"] i' have fij: "f i = j" "f i' = j" by auto 
          let ?A = "swaprows i i' A"
          let ?B = "swaprows i i' B"
          have "pivot_fun ?A f j"
            by (rule pivot_fun_swaprows[OF pivot dim fij], insert i' ij, auto)
          note IH = IH(2)[OF Cons, unfolded swaprows_carrier, OF A _ this f ij]
          from Z valid res have id: "gauss_jordan_main ?A ?B i j = (A', B')" unfolding simp dim Cons by simp
          note IH = IH[OF this]
          show ?thesis using IH .
        qed
      qed
    qed
  qed
qed

(* Soundness of the top-level algorithm with respect to the echelon shape:
   for an nr x nc input matrix A, the first component A' of the pair returned
   by gauss_jordan A B is in row-echelon form.
   The proof unfolds gauss_jordan_def and applies the invariant lemma
   gauss_jordan_main_row_echelon, started with the fact pivot_fun_init. *)
lemma gauss_jordan_row_echelon: 
  assumes A: "A ∈ carrier_mat nr nc" 
  and res: "gauss_jordan A B = (A', B')"
  shows "row_echelon_form A'"
  by (rule gauss_jordan_main_row_echelon[OF A res[unfolded gauss_jordan_def] pivot_fun_init], auto)

(* Growth bound for a pivot function f of A with column bound n:
   going j rows down from row i (while staying below nr), the pivot column
   has either reached the bound n or advanced by at least j columns.
   Proof by induction on j.  The step case combines the induction hypothesis
   with two facts delivered by pivot_funD: p(1), the upper bound on f, and
   p(3), the relation between the pivot columns of consecutive rows. *)
lemma pivot_bound: assumes dim: "dim_row A = nr"
  and pivot: "pivot_fun A f n"
  shows "i + j < nr ⟹ f (i + j) = n ∨ f (i + j) ≥ j + f i"
proof (induct j)
  case (Suc j)
  hence IH: "f (i + j) = n ∨ j + f i ≤ f (i + j)" 
    and lt: "i + j < nr" "Suc (i + j) < nr" by auto
  note p = pivot_funD[OF dim pivot]
  from p(3)[OF lt] IH p(1)[OF lt(2)] show ?case by auto
qed simp

(* Generic scan for pivot positions.  The context fixes the element `zero`
   that entries are compared against, the matrix A, and the row/column
   bounds nr and nc. *)
context
  fixes zero :: 'a
  and A :: "'a mat"
  and nr nc :: nat
begin
(* pivot_positions_main_gen i j walks through A starting at position (i,j):
   - while A $$ (i,j) equals zero, it moves one column to the right and
     stays in row i;
   - at the first entry different from zero it records (i,j) and continues
     at (Suc i, Suc j);
   - it returns [] as soon as i reaches nr or j reaches nc. *)
function pivot_positions_main_gen :: "nat ⇒ nat ⇒ (nat × nat) list" where
  "pivot_positions_main_gen i j = (
     if i < nr then
       if j < nc then 
         if A $$ (i,j) = zero then 
           pivot_positions_main_gen i (Suc j)
         else (i,j) # pivot_positions_main_gen (Suc i) (Suc j)
       else []
     else [])" by pat_completeness auto

(* Termination via a lexicographic measure: first Suc nr - i, then Suc nc - j. *)
termination by (relation "measures [(λ (i,j). Suc nr - i), (λ (i,j). Suc nc - j)]", auto)

(* The recursive equation is taken out of the simpset; proofs that need it
   add explicit instances of pivot_positions_main_gen.simps. *)
declare pivot_positions_main_gen.simps[simp del]
end

(* Specialisation of the generic pivot scan to semiring_1 entries, with the
   matrix A and the bounds nr and nc fixed for the whole context. *)
context
  fixes A :: "'a :: semiring_1 mat"
  and nr nc :: nat
begin

(* The scan instantiated with 0 as the element to skip. *)
abbreviation "pivot_positions_main ≡ pivot_positions_main_gen (0 :: 'a) A nr nc"

(* Characterisation of the scan with respect to a pivot function f of A.
   If the scan starts at (i,j) with j not beyond the pivot column of row i
   (or with i already outside the matrix), then
   - the collected positions are exactly the pairs (i', f i') for the rows
     i' from i up to nr whose pivot column is different from nc, and
   - both the collected columns and the collected rows are distinct.
   Proof by the induction rule of pivot_positions_main_gen, with case
   distinctions on i < nr, j < nc and A $$ (i,j) = 0; pivot_bound supplies
   the monotonicity of f that is needed for distinctness. *)
lemma pivot_positions_main: assumes A: "A ∈ carrier_mat nr nc"
  and pivot: "pivot_fun A f nc"
  shows "j ≤ f i ∨ i ≥ nr ⟹ 
    set (pivot_positions_main i j) = {(i', f i') | i'. i ≤ i' ∧ i' < nr} - UNIV × {nc}
    ∧ distinct (map snd (pivot_positions_main i j))
    ∧ distinct (map fst (pivot_positions_main i j))"
proof (induct i j rule: pivot_positions_main_gen.induct[of nr nc A 0])
  case (1 i j)
  let ?a = "A $$ (i,j)"
  let ?pivot = "λ i j. pivot_positions_main i j"
  let ?set = "λ i. {(i',f i') | i'. i ≤ i' ∧ i' < nr}"
  let ?s = "?set i"
  let ?p = "?pivot i j"
  from A have dA: "dim_row A = nr" by simp
  (* only the instance of the defining equation at (i,j) is used for simplification *)
  note [simp] = pivot_positions_main_gen.simps[of 0 A nr nc i j]
  show ?case
  proof (cases "i < nr")
    case True note i = this
    note IH = 1(1-2)[OF True]
    have jfi: "j ≤ f i" using 1(3) i by auto
    note pivotB = pivot_bound[OF dA pivot]
    note pivot' = pivot_funD[OF dA pivot]
    note pivot = pivot'[OF True]
    have id1: "[i ..< nr] = i # [Suc i ..< nr]" using i by (rule upt_conv_Cons)
    show ?thesis
    proof (cases "j < nc")
      case True note j = this
      note IH = IH(1-2)[OF True]
      show ?thesis
      proof (cases "?a = 0")
        (* zero entry: the scan stays in row i and moves to column Suc j *)
        case True note a = this
        from i j a have p: "?p = ?pivot i (Suc j)" by simp
        {
          (* j cannot be the pivot column of row i, since the pivot entry is 1 *)
          assume "f i = j"
          with pivot(4) j have "?a = 1" by simp
          with a have False by simp
        }
        with jfi have "Suc j ≤ f i ∨ i ≥ nr" by fastforce
        note IH = IH(1)[OF True this]
        thus ?thesis unfolding p .
      next
        (* non-zero entry: (i,j) is recorded, and j must be the pivot column of row i *)
        case False note a = this
        from i j a have p: "?p = (i,j) # ?pivot (Suc i) (Suc j)" by simp
        from pivot(2)[of j] jfi a have jfi: "j = f i" by force
        from pivotB[of i "Suc 0"] jfi have "Suc j ≤ f (Suc i) ∨ nr ≤ Suc i" 
          using Suc_le_eq j leI by auto
        note IH = IH(2)[OF False this]
        {
          (* no later row shares the pivot column of row i *)
          fix i'
          assume *: "f i = f i'" "Suc i ≤ i'" "i' < nr" 
          hence "i + (i' - i) = i'" by auto
          from pivotB[of i "i' - i", unfolded this] * jfi j have False by auto
        } note distinct = this
        have id2: "?s = insert (i,j) (?set (Suc i))" using i jfi not_less_eq_eq
          by fastforce
        show ?thesis using IH j jfi i unfolding p id1 id2 by (auto intro: distinct)
      qed
    next
      (* column bound reached: row i and all later rows have pivot column nc *)
      case False note j = this
      from pivot(1) j jfi have *: "f i = nc" "nc = j" by auto
      from i j have p: "?p = []" by simp
      from pivotB[of i "Suc 0"] * have "j ≤ f (Suc i) ∨ nr ≤ Suc i" by auto
      {
        fix i'
        assume **: "i ≤ i'" "i' < nr" 
        hence "i + (i' - i) = i'" by auto
        from pivotB[of i "i' - i", unfolded this] ** * have "nc ≤ f i'" by auto
        with pivot'(1)[OF ‹i' < nr›] have "f i' = nc" by auto
      }
      thus ?thesis using IH unfolding p id1 by auto
    qed
  qed auto
qed
end

(* For a pivot function f of an nr x nc matrix A and a row index i < nr:
   the pivot column of row i equals the bound nc exactly when row i is the
   zero vector.
   Left to right: pivot(2) states that all entries left of the pivot column
   are 0, so with f i = nc the whole row is 0.
   Right to left: if f i were different from nc, pivot(1) would give
   f i < nc and pivot(4) would make the entry at column f i equal to 1,
   which contradicts the row being zero. *)
lemma pivot_fun_zero_row_iff: assumes pivot: "pivot_fun (A :: 'a :: semiring_1 mat) f nc"
  and A: "A ∈ carrier_mat nr nc"
  and i: "i < nr"
  shows "f i = nc ⟷ row A i = 0⇩v nc"
proof -
  from A have dim: "dim_row A = nr" by auto
  note pivot = pivot_funD[OF dim pivot i]
  {
    assume "f i = nc"
    from pivot(2)[unfolded this]
    have "row A i = 0⇩v nc"
      by (intro eq_vecI, insert A, auto simp: row_def)
  }
  moreover
  {
    assume row: "row A i = 0⇩v nc"
    assume "f i ≠ nc"
    with pivot(1) have "f i < nc" by auto
    with pivot(4)[OF this] i A arg_cong[OF row, of "λ v. v $ f i"] have False by auto
  }
  ultimately show ?thesis by auto
qed

(* Pivot positions of a whole matrix: runs the generic scan
   pivot_positions_main_gen with the matrix' own dimensions, starting at the
   top-left position (0,0).  The parameter zer is the element that the scan
   skips over. *)
definition pivot_positions_gen :: "'a ⇒ 'a mat ⇒ (nat × nat) list" where
  "pivot_positions_gen zer A ≡ pivot_positions_main_gen zer A (dim_row A) (dim_col A) 0 0"

(* Instance for semiring_1 entries, skipping the constant 0. *)
abbreviation pivot_positions :: "'a :: semiring_1 mat ⇒ (nat × nat) list" where
  "pivot_positions ≡ pivot_positions_gen 0"

(* Alias so that proofs can unfold the abbreviation under its own name. *)
lemmas pivot_positions_def = pivot_positions_gen_def

(* Properties of pivot_positions A for an nr x nc matrix A with pivot
   function f:
   1. the result consists exactly of the pairs (i, f i) for the rows i < nr
      whose pivot column is different from nc;
   2. the row components are distinct;
   3. the column components are distinct;
   4. the length of the result equals the number of non-zero rows of A.
   Facts 1-3 are the instance of pivot_positions_main at start position
   (0,0).  Fact 4 follows by counting the distinct row components and
   rewriting with pivot_fun_zero_row_iff. *)
lemma pivot_positions: assumes A: "A ∈ carrier_mat nr nc"
  and pivot: "pivot_fun A f nc"
  shows 
    "set (pivot_positions A) = {(i, f i) | i. i < nr ∧ f i ≠ nc}"
    "distinct (map fst (pivot_positions A))"
    "distinct (map snd (pivot_positions A))"
    "length (pivot_positions A) = card { i. i < nr ∧ row A i ≠ 0⇩v nc}"
proof -
  from A have dim: "dim_row A = nr" by auto
  let ?pp = "pivot_positions A"
  show id: "set ?pp = {(i, f i) | i. i < nr ∧ f i ≠ nc}"
    and dist: "distinct (map fst ?pp)"
    and "distinct (map snd ?pp)"  
  using pivot_positions_main[OF A pivot, of 0 0] A
  unfolding pivot_positions_def by auto
  have "length ?pp = length (map fst ?pp)" by simp
  also have "… = card (fst ` set ?pp)" using distinct_card[OF dist] by simp
  also have "fst ` set ?pp = { i. i < nr ∧ f i ≠ nc}" unfolding id by force
  also have "… = { i. i < nr ∧ row A i ≠ 0⇩v nc}"
    using pivot_fun_zero_row_iff[OF pivot A] by auto
  finally show "length ?pp = card {i. i < nr ∧ row A i ≠ 0⇩v nc}" .
qed

(* Generic construction of candidate kernel vectors, parametric in the
   negation function and in the constants used for zero and one. *)
context 
  fixes uminus :: "'a ⇒ 'a"
  and zero :: 'a
  and one :: 'a
begin
(* non_pivot_base_gen A pivots qj builds a vector of length dim_col A for the
   column index qj.  `invers` maps a pivot column back to its pivot row (the
   pairs in `pivots` are swapped before being turned into a lookup).
   Entry i of the vector is
   - one, if i = qj;
   - uminus (A $$ (j,qj)), if i is a pivot column with pivot row j;
   - zero, otherwise.
   The let-bound nr is not used in the body. *)
definition non_pivot_base_gen :: "'a mat ⇒ (nat × nat)list ⇒ nat ⇒ 'a vec" where
  "non_pivot_base_gen A pivots ≡ let nr = dim_row A; nc = dim_col A; 
     invers = map_of (map prod.swap pivots)
     in (λ qj. vec nc (λ i. 
     if i = qj then one else (case invers i of Some j => uminus (A $$ (j,qj)) | None ⇒ zero)))"

(* find_base_vectors_gen A computes the pivot positions pp of A, selects the
   column indices below dim_col A that do not occur as a pivot column, and
   returns the vector non_pivot_base_gen A pp qj for each such column qj, in
   increasing column order. *)
definition find_base_vectors_gen :: "'a mat ⇒ 'a vec list" where
  "find_base_vectors_gen A ≡ 
    let 
      pp = pivot_positions_gen zero A;     
      cands = filter (λ j. j ∉ set (map snd pp)) [0 ..< dim_col A]
    in map (non_pivot_base_gen A pp) cands"
end

(* Instances of the generic constructions for comm_ring_1 entries, using the
   ring's own negation, 0 and 1.  The lemma aliases allow proofs to unfold
   the abbreviations under their own names. *)
abbreviation "non_pivot_base ≡ non_pivot_base_gen uminus 0 (1 :: 'a :: comm_ring_1)"
abbreviation "find_base_vectors ≡ find_base_vectors_gen uminus 0 (1 :: 'a :: comm_ring_1)"

lemmas non_pivot_base_def = non_pivot_base_gen_def
lemmas find_base_vectors_def = find_base_vectors_gen_def

text ‹The soundness of @{const find_base_vectors} is proven in theory Matrix-Kern,
  where it is shown that the vectors returned by @{const find_base_vectors} form a basis
  of the kernel of $A$.›

(* find_base_vector A returns a single vector: non_pivot_base applied to the
   first column index below dim_col A that is not a pivot column of A.
   NOTE(review): the column is chosen with hd, so when every column is a
   pivot column the candidate list is empty and the result is hd [], about
   which nothing useful can be proved; callers presumably establish that a
   non-pivot column exists - confirm at the use sites. *)
definition find_base_vector :: "'a :: comm_ring_1 mat ⇒ 'a vec" where
  "find_base_vector A ≡ 
    let 
      pp = pivot_positions A;     
      cands = filter (λ j. j ∉ set (map snd pp)) [0 ..< dim_col A]
    in non_pivot_base A pp (hd cands)"

context
  fixes A :: "'a :: field mat" and nr nc :: nat and p :: "nat  nat"
  assumes ref: "row_echelon_form A"
  and A: "A  carrier_mat nr nc"
begin

lemma non_pivot_base:
  defines pp: "pp  pivot_positions A"
  assumes qj: "qj < nc" "qj  snd ` set pp" 
  shows "non_pivot_base A pp qj  carrier_vec nc"
    "non_pivot_base A pp qj $ qj = 1"
    "A *v non_pivot_base A pp qj = 0v nr"
    " qj'. qj' < nc  qj'  snd ` set pp  qj  qj'  non_pivot_base A pp qj $ qj' = 0"
proof -
  from A have dim: "dim_row A = nr" "dim_col A = nc" by auto
  from ref[unfolded row_echelon_form_def] obtain p 
  where pivot: "pivot_fun A p nc" using dim by auto
  note pivot' = pivot_funD[OF dim(1) pivot]
  note pp = pivot_positions[OF A pivot, folded pp]
  let ?p = "λ i. i < nr  p i = nc  i = nr"
  let ?spp = "map prod.swap pp"
  let ?map = "map_of ?spp"
  define I where "I = (λ i. case map_of (map prod.swap pp) i of Some j  - A $$ (j,qj) | None  0)"
  have d: "non_pivot_base A pp qj = vec nc (λ i. if i = qj then 1 else I i)"
    unfolding non_pivot_base_def Let_def dim I_def ..
  from pp have dist: "distinct (map fst ?spp)" 
    unfolding map_map o_def prod.swap_def by auto
  let ?r = "set (map snd pp)"
  have r: "?r = p ` {0 ..< nr} - {nc}" unfolding set_map pp by force
  let ?l = "set (map fst pp)"
  from qj have qj': "qj  p ` {0 ..< nr}" using r by auto
  let ?v = "non_pivot_base A pp qj"
  let ?P = "p ` {0 ..< nr}"
  have dimv: "dim_vec ?v = nc" unfolding d by simp
  thus "?v  carrier_vec nc" unfolding carrier_vec_def by auto
  show vqj: "?v $ qj = 1" unfolding d using qj by auto
  { 
    fix qj'
    assume *: "qj' < nc" "qj  qj'" "qj'  snd ` set pp"
    hence "?map qj' = None" unfolding map_of_eq_None_iff by force
    hence "I qj' = 0" unfolding I_def by simp
    with * show "non_pivot_base A pp qj $ qj' = 0" 
      unfolding d by simp
  }    
  {
    fix i
    assume i: "i < nr"
    let ?I = "{j. ?map j = Some i}"
    have "row A i  ?v = 0" 
    proof -
      have id: "({0..<nc}  ?P)  ({0..<nc} - ?P) = {0..<nc}" by auto
      let ?e = "λ j. row A i $ j * ?v $ j"
      let ?e' = "λ j. (if ?map j = Some i then - A $$ (i, qj) else 0)"
      {
        fix j
        assume j: "j < nc" "j  ?P"
        then obtain ii where ii: "ii < nr" and jpi: "j = p ii" and pii: "p ii < nc" by auto
        hence mem: "(ii,j)  set pp" and "(j,ii)  set ?spp" by (auto simp: pp)        
        from map_of_is_SomeI[OF dist this(2)] 
        have map: "?map j = Some ii" by auto
        from mem j qj have jqj: "j  qj" by force
        note p = pivot'(4-5)[OF ii pii]
        define start where "start = ?e j"
        have "start = A $$ (i,j) * ?v $ j" using j i A by (auto simp: start_def)
        also have "A $$ (i,j) = A $$ (i, p ii)" unfolding jpi ..
        also have " = (if i = ii then 1 else 0)" using p(1) p(2)[OF i] by auto
        also have " * ?v $ j = (if i = ii then ?v $ j else 0)" by simp
        also have "?v $ j = I j" unfolding d 
          using j jqj A by auto
        also have "I j = - A $$ (ii, qj)" unfolding I_def map by simp
        finally have "?e j = ?e' j" 
          unfolding start_def map by auto
      } note piv = this
      have "row A i  ?v = ( j = 0..<nc. ?e j)" unfolding row_def scalar_prod_def dimv ..
      also have " = sum ?e ({0..<nc}  ?P) + sum ?e ({0..<nc} - ?P)"
        by (subst sum.union_disjoint[symmetric], auto simp: id)
      also have "sum ?e ({0..<nc} - ?P) = ?e qj + sum ?e ({0 ..<nc} - ?P - {qj})"
        by (rule sum.remove, insert qj qj', auto)
      also have "?e qj = row A i $ qj" unfolding vqj by simp
      also have "row A i $ qj = A $$ (i, qj)" using i A qj by auto
      also have "sum ?e ({0 ..<nc} - ?P - {qj}) = 0"
      proof (rule sum.neutral, intro ballI)
        fix j
        assume "j  {0 ..<nc} - ?P - {qj}"
        hence j: "j < nc" "j  ?P" "j  qj" "j  ?r" unfolding r by auto
        hence id: "map_of ?spp j = None" unfolding map_of_eq_None_iff by force
        have "?v $ j = I j" unfolding d using j by simp
        also have " = 0" unfolding I_def id by simp 
        finally show "row A i $ j * ?v $ j = 0" by simp
      qed
      also have "A $$ (i, qj) + 0 = A $$ (i, qj)" by simp
      also have "sum ?e ({0..<nc}  ?P) = sum ?e' ({0..<nc}  ?P)"
        by (rule sum.cong, insert piv, auto)
      also have "{0..<nc}  ?P = {0..<nc}  ?I  ?P  ({0..<nc} - ?I)  ?P" by auto
      also have "sum ?e' ({0..<nc}  ?I  ?P  ({0..<nc} - ?I)  ?P)
        = sum ?e' ({0..<nc}  ?I  ?P) + sum ?e' (({0..<nc} - ?I)  ?P)"
        by (rule sum.union_disjoint, auto)
      also have "sum ?e' (({0..<nc} - ?I)  ?P) = 0"
        by (rule sum.neutral, auto)
      also have "sum ?e' ({0..<nc}  ?I  ?P) = 
        sum (λ _. - A $$ (i, qj)) ({0..<nc}  ?I  ?P)"
        by (rule sum.cong, auto)
      also have " + 0 = " by simp
      also have "sum (λ _. - A $$ (i, qj)) ({0..<nc}  ?I  ?P) + A $$ (i, qj) = 0" 
      proof (cases "i  ?l")
        case False
        with pp(1) i have "p i = nc" by force
        from pivot'(2)[OF i, unfolded this, OF qj(1)] have z: "A $$ (i, qj) = 0" .
        show ?thesis 
          by (subst sum.neutral, auto simp: z)
      next
        case True
        then obtain j where mem: "(i,j)  set pp" and id: "(j,i)  set ?spp" by auto
        from map_of_is_SomeI[