(* Theory Gauss-Jordan-Elim-Fun.Gauss_Jordan_Elim_Fun *)

(*  Gauss-Jordan elimination for matrices represented as functions
    Author: Tobias Nipkow
*)
section ‹Gauss-Jordan elimination algorithm›
theory Gauss_Jordan_Elim_Fun
  imports
    "HOL-Combinatorics.Transposition"
begin

text‹Matrices are functions:›

(* A matrix is a total function from a row index and a column index to an
   entry; its finite dimensions are tracked separately by natural numbers. *)
type_synonym 'a matrix = "nat ⇒ nat ⇒ 'a"

text‹In order to restrict to finite matrices, a matrix is usually combined
with one or two natural numbers indicating the maximal row and column of the
matrix.

Gauss-Jordan elimination is parameterized with a natural number ‹n›. It indicates that the coefficient part of the matrix ‹A› has ‹n› rows and ‹n› columns.
More precisely, ‹A› is the augmented matrix with ‹n+1› columns: column
‹n› is the ``right-hand side'', i.e.\ the constant vector ‹b›. The result is the unit matrix augmented with the solution in column
‹n›; see the correctness theorem below.›

(* gauss_jordan A k eliminates columns k-1, k-2, ..., 0 of A, in that order.
   In the step for column m (k = Suc m) the pivot row p is the first row in
   0..m with A p m ≠ 0; if there is none, the result is None.  Row p is divided
   by the pivot A p m, column m is cleared from every other row, and rows p
   and m are swapped before recursing on the remaining m columns.
   See gauss_jordan_correct for what a result Some B means. *)
fun gauss_jordan :: "('a::field)matrix ⇒ nat ⇒ ('a)matrix option" where
"gauss_jordan A 0 = Some(A)" |
"gauss_jordan A (Suc m) =
 (case dropWhile (λi. A i m = 0) [0..<Suc m] of
   [] ⇒ None |
   p # _ ⇒
    (let Ap' = (λj. A p j / A p m);
         A' = (λi. if i=p then Ap' else (λj. A i j - A i m * Ap' j))
     in gauss_jordan (Fun.swap p m A') m))"

text‹Some auxiliary functions:›

(* solution A n x: the vector x solves the n×n system whose coefficients are
   columns 0..n-1 of rows 0..n-1 of A and whose right-hand side is column n. *)
definition solution :: "('a::field)matrix ⇒ nat ⇒ (nat ⇒ 'a) ⇒ bool" where
"solution A n x = (∀i<n. (∑ j=0..<n. A i j * x j) = A i n)"

(* unit A m n: in every row, columns m..n-1 of A coincide with the
   corresponding columns of the identity matrix. *)
definition unit :: "('a::field)matrix ⇒ nat ⇒ nat ⇒ bool" where
"unit A m n =
 (∀i j::nat. m≤j ⟶ j<n ⟶ A i j = (if i=j then 1 else 0))"

(* Swapping two rows p1, p2 < n of the system does not change its solutions. *)
lemma solution_swap:
assumes "p1 < n" "p2 < n"
shows "solution (Fun.swap p1 p2 A) n x = solution A n x" (is "?L = ?R")
proof(cases "p1=p2")
  case True thus ?thesis by simp
next
  case False
  show ?thesis
  proof
    assume ?R thus ?L using assms False by(simp add: solution_def Fun.swap_def)
  next
   assume ?L
   show ?R
   proof(auto simp: solution_def)
     fix i assume "i<n"
     show "(∑j = 0..<n. A i j * x j) = A i n"
     proof cases
       assume "i=p1"
       with ‹?L› assms False show ?thesis
         by(fastforce simp add: solution_def Fun.swap_def)
     next
       assume "i≠p1"
       show ?thesis
       proof cases
         assume "i=p2"
         with ‹?L› assms False show ?thesis
           by(fastforce simp add: solution_def Fun.swap_def)
       next
         assume "i≠p2"
         with ‹i≠p1› ‹?L› ‹i<n› assms False show ?thesis
           by(fastforce simp add: solution_def Fun.swap_def)
       qed
     qed
   qed
 qed
qed

(* Converting the following apply scripts into structured Isar proofs makes them
   much longer - compare the Isar proof of solution_swap above *)

(* Dividing row p by a non-zero constant c does not change the solutions
   (if p ≥ n the row is not part of the system at all). *)
lemma solution_upd1:
  "c ≠ 0 ⟹ solution (A(p:=(λj. A p j / c))) n x = solution A n x"
apply(cases "p<n")
 prefer 2
 apply(simp add: solution_def)
apply(clarsimp simp add: solution_def)
apply rule
 apply clarsimp
 apply(case_tac "i=p")
  apply (simp add: sum_divide_distrib[symmetric] eq_divide_eq field_simps)
 apply simp
apply (simp add: sum_divide_distrib[symmetric] eq_divide_eq field_simps)
done

(* Keeping row p (= ap = A p) and subtracting c i times row p from every other
   row i does not change the solutions, provided p is a row of the system. *)
lemma solution_upd_but1: "⟦ ap = A p; ⋀i j. i≠p ⟹ a i j = A i j; p<n ⟧ ⟹
 solution (λi. if i=p then ap else (λj. a i j - c i * ap j)) n x =
 solution A n x"
apply(clarsimp simp add: solution_def)
apply rule
 prefer 2
 apply (simp add: field_simps sum_subtractf sum_distrib_left[symmetric])
apply(clarsimp)
apply(case_tac "i=p")
 apply simp
apply (auto simp add: field_simps sum_subtractf sum_distrib_left[symmetric] all_conj_distrib)
done

subsection‹Correctness›

text‹The correctness proof:›

(* Induction invariant: if columns m..n-1 of A are already unit columns and
   eliminating the remaining columns returns Some B, then columns 0..n-1 of B
   are unit columns and column n of B solves the system A. *)
lemma gauss_jordan_lemma: "m≤n ⟹ unit A m n ⟹ gauss_jordan A m = Some B ⟹
  unit B 0 n ∧ solution A n (λj. B j n)"
proof(induct m arbitrary: A B)
  case 0
  { fix a and b c d :: "'a"
    have "(if a then b else c) * d = (if a then b*d else c*d)" by simp
  } with 0 show ?case by(simp add: unit_def solution_def sum.If_cases)
next
  case (Suc m)
  let "?Ap' p" = "(λj. A p j / A p m)"
  let "?A' p" = "(λi. if i=p then ?Ap' p else (λj. A i j - A i m * ?Ap' p j))"
  from ‹gauss_jordan A (Suc m) = Some B›
  obtain p ks where "dropWhile (λi. A i m = 0) [0..<Suc m] = p#ks" and
    rec: "gauss_jordan (Fun.swap p m (?A' p)) m = Some B"
    by (auto split: list.splits)
  from this have p: "p≤m" "A p m ≠ 0"
    apply(simp_all add: dropWhile_eq_Cons_conv del:upt_Suc)
    by (metis set_upt atLeast0AtMost atLeastLessThanSuc_atLeastAtMost atMost_iff in_set_conv_decomp)
  have "m≤n" "m<n" using ‹Suc m ≤ n› by arith+
  have "unit (Fun.swap p m (?A' p)) m n" using Suc.prems(2) p
    unfolding unit_def Fun.swap_def Suc_le_eq by (auto simp: le_less)
  from Suc.hyps[OF ‹m≤n› this rec] ‹m<n› p
  show ?case
    by (simp only: solution_swap) (simp_all add: solution_swap solution_upd_but1 [where A = "A(p := ?Ap' p)"] solution_upd1)
qed

(* Correctness: whenever elimination succeeds, column n of the result solves
   the system.  Obtained from gauss_jordan_lemma with m = n, where the unit
   premise holds vacuously. *)
theorem gauss_jordan_correct:
  "gauss_jordan A n = Some B ⟹ solution A n (λj. B j n)"
by(simp add:gauss_jordan_lemma[of n n] unit_def  field_simps)

(* Variant of solution for an m×m system whose right-hand side is column n,
   where n is an independent parameter. *)
definition solution2 :: "('a::field)matrix ⇒ nat ⇒ nat ⇒ (nat ⇒ 'a) ⇒ bool"
where "solution2 A m n x = (∀i<m. (∑ j=0..<m. A i j * x j) = A i n)"

(* x is the unique solution: it solves the system, and every solution agrees
   with x on the first m components. *)
definition "usolution A m n x ⟷
  solution2 A m n x ∧ (∀y. solution2 A m n y ⟶ (∀j<m. y j = x j))"

(* Under a unique solution, every column q < m contains a non-zero entry in some
   row p < m: otherwise x q could be set to both 0 and 1. *)
lemma non_null_if_pivot:
  assumes "usolution A m n x" and "q < m" shows "∃p<m. A p q ≠ 0"
proof(rule ccontr)
  assume "¬(∃p<m. A p q ≠ 0)"
  hence 1: "⋀p. p<m ⟹ A p q = 0" by simp
  { fix y assume 2: "∀j. j≠q ⟶ y j = x j"
    { fix i assume "i<m"
      with assms(1) have "A i n = (∑j = 0..<m. A i j * x j)"
        by (auto simp: solution2_def usolution_def)
      with 1[OF ‹i<m›] 2
      have "(∑j = 0..<m. A i j * y j) = A i n"
        by (auto intro!: sum.cong)
    }
    hence "solution2 A m n y" by(simp add: solution2_def)
  }
  hence "solution2 A m n (x(q:=0))" and "solution2 A m n (x(q:=1))" by auto
  with assms(1) zero_neq_one ‹q < m›
  show False
    by (simp add: usolution_def)
       (metis fun_upd_same zero_neq_one)
qed

(* Pull a constant factor a (left of g x) out of a sum of products. *)
lemma lem1:
  fixes f :: "'a ⇒ 'b::field"
  shows "(∑x∈A. f x * (a * g x)) = a * (∑x∈A. f x * g x)"
  by (simp add: sum_distrib_left field_simps)

(* Pull a constant factor a (right of g x) out of a sum of products. *)
lemma lem2:
  fixes f :: "'a ⇒ 'b::field"
  shows "(∑x∈A. f x * (g x * a)) = a * (∑x∈A. f x * g x)"
  by (simp add: sum_distrib_left field_simps)

subsection‹Completeness›

(* Completeness: if the m×m system with right-hand side column n has a unique
   solution, elimination does not fail. *)
lemma gauss_jordan_complete:
  "m ≤ n ⟹ usolution A m n x ⟹ ∃B. gauss_jordan A m = Some B"
proof(induction m arbitrary: A)
  case 0 show ?case by simp
next
  case (Suc m A)
  from ‹Suc m ≤ n› have "m≤n" and "m<Suc m" by arith+
  (* uniqueness yields a non-zero entry in column m, hence a pivot row *)
  from non_null_if_pivot[OF Suc.prems(2) ‹m<Suc m›]
  obtain p' where "p'<Suc m" and "A p' m ≠ 0" by blast
  hence "dropWhile (λi. A i m = 0) [0..<Suc m] ≠ []"
    by (simp add: atLeast0LessThan) (metis lessThan_iff linorder_neqE_nat not_less_eq)
  then obtain p xs where 1: "dropWhile (λi. A i m = 0) [0..<Suc m] = p#xs"
    by (metis list.exhaust)
  from this have "p≤m" "A p m ≠ 0"
    by (simp_all add: dropWhile_eq_Cons_conv del: upt_Suc)
       (metis set_upt atLeast0AtMost atLeastLessThanSuc_atLeastAtMost atMost_iff in_set_conv_decomp)
  then have p: "p < Suc m" "A p m ≠ 0"
    by auto
  let ?Ap' = "(λj. A p j / A p m)"
  let ?A' = "(λi. if i=p then ?Ap' else (λj. A i j - A i m * ?Ap' j))"
  let ?A = "Fun.swap p m ?A'"
  have A: "solution2 A (Suc m) n x" using Suc.prems(2) by(simp add: usolution_def)
  (* x still solves the reduced system ?A *)
  { fix i assume le_m: "p < Suc m" "i < Suc m" "A p m ≠ 0"
    have "(∑j = 0..<m. (A i j - A i m * A p j / A p m) * x j) =
      ((∑j = 0..<Suc m. A i j * x j) - A i m * x m) -
      ((∑j = 0..<Suc m. A p j * x j) - A p m * x m) * A i m / A p m"
      by (simp add: field_simps sum_subtractf sum_divide_distrib
                    sum_distrib_left)
    also have "… = A i n - A p n * A i m / A p m"
      using A le_m
      by (simp add: solution2_def field_simps del: sum.op_ivl_Suc)
    finally have "(∑j = 0..<m. (A i j - A i m * A p j / A p m) * x j) =
      A i n - A p n * A i m / A p m" . }
  then have "solution2 ?A m n x" using p
    by (auto simp add: solution2_def Fun.swap_def field_simps)
  moreover
  (* every solution y of ?A extends (via ?y) to a solution of A, so it agrees with x *)
  { fix y assume a: "solution2 ?A m n y"
    let ?y = "y(m := A p n / A p m - (∑j = 0..<m. A p j * y j) / A p m)"
    have "solution2 A (Suc m) n ?y" unfolding solution2_def
    proof safe
      fix i assume "i < Suc m"
      show "(∑j=0..<Suc m. A i j * ?y j) = A i n"
      proof (cases "i = p")
        assume "i = p" with p show ?thesis by (simp add: field_simps)
      next
        assume "i ≠ p"
        show ?thesis
        proof (cases "i = m")
          assume "i = m"
          with p ‹i ≠ p› have "p < m" by simp
          with a[unfolded solution2_def, THEN spec, of p] p(2)
          have "A p m * (A m m * A p n + A p m * (∑j = 0..<m. y j * A m j)) = A p m * (A m n * A p m + A m m * (∑j = 0..<m. y j * A p j))"
            by (simp add: Fun.swap_def field_simps sum_subtractf lem1 lem2 sum_divide_distrib[symmetric]
                     split: if_splits)
          with ‹A p m ≠ 0› show ?thesis unfolding ‹i = m›
            by simp (simp add: field_simps)
        next
          assume "i ≠ m"
          then have "i < m" using ‹i < Suc m› by simp
          with a[unfolded solution2_def, THEN spec, of i] p(2)
          have "A p m * (A i m * A p n + A p m * (∑j = 0..<m. y j * A i j)) = A p m * (A i n * A p m + A i m * (∑j = 0..<m. y j * A p j))"
            by (simp add: Fun.swap_def split: if_splits)
              (simp add: field_simps sum_subtractf lem1 lem2 sum_divide_distrib [symmetric])
          with ‹A p m ≠ 0› show ?thesis
            by simp (simp add: field_simps)
        qed
      qed
    qed
    with ‹usolution A (Suc m) n x›
    have "∀j<Suc m. ?y j = x j" by (simp add: usolution_def)
    hence "∀j<m. y j = x j"
      by simp (metis less_SucI nat_neq_iff)
  } ultimately have "usolution ?A m n x" 
    by (simp add: usolution_def)
  note * = Suc.IH [OF ‹m ≤ n› this]
  from 1 show ?case
    by auto (use * in blast)
qed

text‹Future work: extend the proof to matrix inversion.›

(* Hide the unqualified name of the auxiliary predicate unit; after this,
   it is accessible only by its qualified name Gauss_Jordan_Elim_Fun.unit. *)
hide_const (open) unit

end