(* Theory Low_Dimensional_Linear_Algebra *)

(*  Title:      Three_Squares/Low_Dimensional_Linear_Algebra.thy
    Author:     Anton Danilkin
*)

section ‹Vectors and matrices, determinants and
         their properties in dimensions 2 and 3›

theory Low_Dimensional_Linear_Algebra
  imports "HOL-Library.Adhoc_Overloading"
begin

(* Two-dimensional integer vector; its components are read with vec21 and vec22. *)
datatype vec2 = vec2 (vec21 : int) (vec22 : int)

(* Three-dimensional integer vector; its components are read with vec31, vec32 and vec33. *)
datatype vec3 = vec3 (vec31 : int) (vec32 : int) (vec33 : int)

(* 2x2 integer matrix. Selector mat2ij holds the entry in row i, column j
   (the convention used by mat2_app and times_mat2). *)
datatype mat2 = mat2
  (mat211 : int) (mat212 : int)  (* row 1 *)
  (mat221 : int) (mat222 : int)  (* row 2 *)

(* 3x3 integer matrix. Selector mat3ij holds the entry in row i, column j
   (the convention used by mat3_app and times_mat3). *)
datatype mat3 = mat3
  (mat311 : int) (mat312 : int) (mat313 : int)  (* row 1 *)
  (mat321 : int) (mat322 : int) (mat323 : int)  (* row 2 *)
  (mat331 : int) (mat332 : int) (mat333 : int)  (* row 3 *)

instantiation vec2 :: ab_group_add
begin

(* Additive group structure on vec2: every operation acts componentwise. *)

definition zero_vec2 where
  "zero_vec2 = vec2 0 0"

definition uminus_vec2 where
  "uminus_vec2 v = vec2 (- vec21 v) (- vec22 v)"

definition plus_vec2 where
  "plus_vec2 v1 v2 = vec2 (vec21 v1 + vec21 v2) (vec22 v1 + vec22 v2)"

definition minus_vec2 where
  "minus_vec2 v1 v2 = vec2 (vec21 v1 - vec21 v2) (vec22 v1 - vec22 v2)"

(* Once the componentwise definitions are unfolded, simp discharges each class axiom. *)
instance
  apply intro_classes
  unfolding zero_vec2_def uminus_vec2_def plus_vec2_def minus_vec2_def
  apply simp_all
  done

end

instantiation vec3 :: ab_group_add
begin

(* Additive group structure on vec3: every operation acts componentwise. *)

definition zero_vec3 where
  "zero_vec3 = vec3 0 0 0"

definition uminus_vec3 where
  "uminus_vec3 v = vec3 (- vec31 v) (- vec32 v) (- vec33 v)"

definition plus_vec3 where
  "plus_vec3 v1 v2 =
     vec3 (vec31 v1 + vec31 v2) (vec32 v1 + vec32 v2) (vec33 v1 + vec33 v2)"

definition minus_vec3 where
  "minus_vec3 v1 v2 =
     vec3 (vec31 v1 - vec31 v2) (vec32 v1 - vec32 v2) (vec33 v1 - vec33 v2)"

(* Once the componentwise definitions are unfolded, simp discharges each class axiom. *)
instance
  apply intro_classes
  unfolding zero_vec3_def uminus_vec3_def plus_vec3_def minus_vec3_def
  apply simp_all
  done

end

instantiation mat2 :: ring_1
begin

(* Ring structure on 2x2 integer matrices:
   - negation, addition and subtraction act entrywise;
   - multiplication is the row-by-column matrix product;
   - 1 is the identity matrix. *)

definition zero_mat2 where
  "zero_mat2 = mat2 0 0 0 0"

definition one_mat2 where
  "one_mat2 = mat2 1 0 0 1"

definition uminus_mat2 where
  "uminus_mat2 m = mat2 (- mat211 m) (- mat212 m) (- mat221 m) (- mat222 m)"

definition plus_mat2 where
  "plus_mat2 m1 m2 = mat2
     (mat211 m1 + mat211 m2)
     (mat212 m1 + mat212 m2)
     (mat221 m1 + mat221 m2)
     (mat222 m1 + mat222 m2)"

definition minus_mat2 where
  "minus_mat2 m1 m2 = mat2
     (mat211 m1 - mat211 m2)
     (mat212 m1 - mat212 m2)
     (mat221 m1 - mat221 m2)
     (mat222 m1 - mat222 m2)"

(* Entry (i, j) of the product is row i of m1 dotted with column j of m2. *)
definition times_mat2 where
  "times_mat2 m1 m2 = mat2
     (mat211 m1 * mat211 m2 + mat212 m1 * mat221 m2)
     (mat211 m1 * mat212 m2 + mat212 m1 * mat222 m2)
     (mat221 m1 * mat211 m2 + mat222 m1 * mat221 m2)
     (mat221 m1 * mat212 m2 + mat222 m1 * mat222 m2)"

(* After unfolding, each ring axiom becomes an integer polynomial identity
   on the entries, which simp closes using algebra_simps. *)
instance
  apply intro_classes
  unfolding zero_mat2_def one_mat2_def uminus_mat2_def plus_mat2_def minus_mat2_def times_mat2_def
  apply (simp_all add: algebra_simps)
  done

end

instantiation mat3 :: ring_1
begin

(* Ring structure on 3x3 integer matrices:
   - negation, addition and subtraction act entrywise;
   - multiplication is the row-by-column matrix product;
   - 1 is the identity matrix. *)

definition zero_mat3 where
  "zero_mat3 = mat3 0 0 0 0 0 0 0 0 0"

definition one_mat3 where
  "one_mat3 = mat3 1 0 0 0 1 0 0 0 1"

definition uminus_mat3 where
  "uminus_mat3 m = mat3
     (- mat311 m)
     (- mat312 m)
     (- mat313 m)
     (- mat321 m)
     (- mat322 m)
     (- mat323 m)
     (- mat331 m)
     (- mat332 m)
     (- mat333 m)"

definition plus_mat3 where
  "plus_mat3 m1 m2 = mat3
     (mat311 m1 + mat311 m2)
     (mat312 m1 + mat312 m2)
     (mat313 m1 + mat313 m2)
     (mat321 m1 + mat321 m2)
     (mat322 m1 + mat322 m2)
     (mat323 m1 + mat323 m2)
     (mat331 m1 + mat331 m2)
     (mat332 m1 + mat332 m2)
     (mat333 m1 + mat333 m2)"

definition minus_mat3 where
  "minus_mat3 m1 m2 = mat3
     (mat311 m1 - mat311 m2)
     (mat312 m1 - mat312 m2)
     (mat313 m1 - mat313 m2)
     (mat321 m1 - mat321 m2)
     (mat322 m1 - mat322 m2)
     (mat323 m1 - mat323 m2)
     (mat331 m1 - mat331 m2)
     (mat332 m1 - mat332 m2)
     (mat333 m1 - mat333 m2)"

(* Entry (i, j) of the product is row i of m1 dotted with column j of m2. *)
definition times_mat3 where
  "times_mat3 m1 m2 = mat3
     (mat311 m1 * mat311 m2 + mat312 m1 * mat321 m2 + mat313 m1 * mat331 m2)
     (mat311 m1 * mat312 m2 + mat312 m1 * mat322 m2 + mat313 m1 * mat332 m2)
     (mat311 m1 * mat313 m2 + mat312 m1 * mat323 m2 + mat313 m1 * mat333 m2)
     (mat321 m1 * mat311 m2 + mat322 m1 * mat321 m2 + mat323 m1 * mat331 m2)
     (mat321 m1 * mat312 m2 + mat322 m1 * mat322 m2 + mat323 m1 * mat332 m2)
     (mat321 m1 * mat313 m2 + mat322 m1 * mat323 m2 + mat323 m1 * mat333 m2)
     (mat331 m1 * mat311 m2 + mat332 m1 * mat321 m2 + mat333 m1 * mat331 m2)
     (mat331 m1 * mat312 m2 + mat332 m1 * mat322 m2 + mat333 m1 * mat332 m2)
     (mat331 m1 * mat313 m2 + mat332 m1 * mat323 m2 + mat333 m1 * mat333 m2)"

(* After unfolding, each ring axiom becomes an integer polynomial identity
   on the entries, which simp closes using algebra_simps. *)
instance
  apply intro_classes
  unfolding zero_mat3_def one_mat3_def uminus_mat3_def plus_mat3_def minus_mat3_def times_mat3_def
  apply (simp_all add: algebra_simps)
  done

end

(* Overloaded integer dot product, written <u | v>. Adhoc overloading resolves
   it to vec2_dot or vec3_dot according to the argument type. *)
consts vec_dot :: "'a ⇒ 'a ⇒ int" (‹<_ | _>› 65)

(* Dot product on vec2: the sum of the componentwise products. *)
definition vec2_dot :: "vec2 ⇒ vec2 ⇒ int" where
"vec2_dot v1 v2 = vec21 v1 * vec21 v2 + vec22 v1 * vec22 v2"

adhoc_overloading vec_dot vec2_dot

(* Dot product on vec3: the sum of the componentwise products. *)
definition vec3_dot :: "vec3 ⇒ vec3 ⇒ int" where
"vec3_dot v1 v2 = vec31 v1 * vec31 v2 + vec32 v1 * vec32 v2 + vec33 v1 * vec33 v2"

adhoc_overloading vec_dot vec3_dot

(* The dot product of any vec2 with the zero vector is 0, whichever side the zero is on. *)
lemma vec2_dot_zero_left [simp]:
  fixes v :: vec2
  shows "<0 | v> = 0"
  unfolding vec2_dot_def zero_vec2_def by auto

lemma vec2_dot_zero_right [simp]:
  fixes v :: vec2
  shows "<v | 0> = 0"
  unfolding vec2_dot_def zero_vec2_def by auto

(* The dot product of any vec3 with the zero vector is 0, whichever side the zero is on. *)
lemma vec3_dot_zero_left [simp]:
  fixes v :: vec3
  shows "<0 | v> = 0"
  unfolding vec3_dot_def zero_vec3_def by auto

lemma vec3_dot_zero_right [simp]:
  fixes v :: vec3
  shows "<v | 0> = 0"
  unfolding vec3_dot_def zero_vec3_def by auto

(* Overloaded matrix-vector application, written m $ v. It is declared
   right-associative, so m1 $ m2 $ v means m1 $ (m2 $ v). Adhoc overloading
   resolves it to mat2_app or mat3_app. *)
consts mat_app :: "'a ⇒ 'b ⇒ 'b" (infixr ‹$› 65)

(* Component i of m $ v is row i of m dotted with v. *)
definition mat2_app :: "mat2 ⇒ vec2 ⇒ vec2" where
"mat2_app m v =
  vec2
  (mat211 m * vec21 v + mat212 m * vec22 v)
  (mat221 m * vec21 v + mat222 m * vec22 v)"

adhoc_overloading mat_app mat2_app

(* The 3x3 analogue of mat2_app. *)
definition mat3_app :: "mat3 ⇒ vec3 ⇒ vec3" where
"mat3_app m v =
  vec3
  (mat311 m * vec31 v + mat312 m * vec32 v + mat313 m * vec33 v)
  (mat321 m * vec31 v + mat322 m * vec32 v + mat323 m * vec33 v)
  (mat331 m * vec31 v + mat332 m * vec32 v + mat333 m * vec33 v)"

adhoc_overloading mat_app mat3_app

(* Every matrix maps the zero vector to the zero vector. *)
lemma mat2_app_zero [simp]:
  fixes m :: mat2
  shows "m $ 0 = 0"
  unfolding mat2_app_def zero_vec2_def by auto

lemma mat3_app_zero [simp]:
  fixes m :: mat3
  shows "m $ 0 = 0"
  unfolding mat3_app_def zero_vec3_def by auto

(* The identity matrix 1 leaves every vector unchanged. *)
lemma mat2_app_one [simp]:
  fixes v :: vec2
  shows "1 $ v = v"
  unfolding mat2_app_def one_mat2_def by auto

lemma mat3_app_one [simp]:
  fixes v :: vec3
  shows "1 $ v = v"
  unfolding mat3_app_def one_mat3_def by auto

(* Applying a matrix product equals applying the factors one after another,
   right factor first. The left-hand side is presumably parsed as (m1 * m2) $ v,
   because * binds tighter than the infixr-65 $ — confirm. *)
lemma mat2_app_mul [simp]:
  fixes m1 m2 :: mat2
  fixes v :: vec2
  shows "m1 * m2 $ v = m1 $ m2 $ v"
  unfolding times_mat2_def mat2_app_def by (simp add: algebra_simps)

lemma mat3_app_mul [simp]:
  fixes m1 m2 :: mat3
  fixes v :: vec3
  shows "m1 * m2 $ v = m1 $ m2 $ v"
  unfolding times_mat3_def mat3_app_def by (simp add: algebra_simps)

(* Overloaded determinant; adhoc overloading resolves it to mat2_det or mat3_det. *)
consts mat_det :: "'a ⇒ int"

definition mat2_det :: "mat2 ⇒ int" where
"mat2_det m = mat211 m * mat222 m - mat212 m * mat221 m"

adhoc_overloading mat_det mat2_det

(* Rule-of-Sarrus expansion: three positive diagonal products minus three
   anti-diagonal products. *)
definition mat3_det :: "mat3 ⇒ int" where
"mat3_det m =
    mat311 m * mat322 m * mat333 m
  + mat312 m * mat323 m * mat331 m
  + mat313 m * mat321 m * mat332 m
  - mat311 m * mat323 m * mat332 m
  - mat312 m * mat321 m * mat333 m
  - mat313 m * mat322 m * mat331 m"

adhoc_overloading mat_det mat3_det

(* The determinant is multiplicative: det (m1 * m2) = det m1 * det m2. *)
lemma mat2_mul_det [simp]:
  fixes m1 m2 :: mat2
  shows "mat_det (m1 * m2) = mat_det m1 * mat_det m2"
  unfolding times_mat2_def mat2_det_def by (simp; algebra)

lemma mat3_mul_det [simp]:
  fixes m1 m2 :: mat3
  shows "mat_det (m1 * m2) = mat_det m1 * mat_det m2"
  unfolding times_mat3_def mat3_det_def by (simp; algebra)

(* Overloaded symmetry predicate: every off-diagonal entry equals its mirror
   entry across the main diagonal. *)
consts mat_sym :: "'a ⇒ bool"

definition mat2_sym :: "mat2 ⇒ bool" where
"mat2_sym m = (mat212 m = mat221 m)"

adhoc_overloading mat_sym mat2_sym

definition mat3_sym :: "mat3 ⇒ bool" where
"mat3_sym m = (mat312 m = mat321 m ∧ mat313 m = mat331 m ∧ mat323 m = mat332 m)"

adhoc_overloading mat_sym mat3_sym

(* Overloaded transpose, written postfix as m⇧T. The argument priority 91
   exceeds the result priority 90, so repeated transposes need parentheses. *)
consts mat_transpose :: "'a ⇒ 'a" (‹_⇧T› [91] 90)

(* Entry (i, j) of the transpose is entry (j, i) of m. *)
definition mat2_transpose :: "mat2 ⇒ mat2" where
"mat2_transpose m =
  mat2
  (mat211 m) (mat221 m)
  (mat212 m) (mat222 m)"

adhoc_overloading mat_transpose mat2_transpose

definition mat3_transpose :: "mat3 ⇒ mat3" where
"mat3_transpose m =
  mat3
  (mat311 m) (mat321 m) (mat331 m)
  (mat312 m) (mat322 m) (mat332 m)
  (mat313 m) (mat323 m) (mat333 m)"

adhoc_overloading mat_transpose mat3_transpose

(* Transposing twice gives back the original matrix. *)
lemma mat2_transpose_involution [simp]:
  fixes m :: mat2
  shows "(m⇧T)⇧T = m"
  unfolding mat2_transpose_def
  by auto

lemma mat3_transpose_involution [simp]:
  fixes m :: mat3
  shows "(m⇧T)⇧T = m"
  unfolding mat3_transpose_def
  by auto

(* A matrix is symmetric exactly when it equals its own transpose. *)
lemma mat2_sym_criterion:
  fixes m :: mat2
  shows "mat_sym m ⟷ m⇧T = m"
  unfolding mat2_sym_def mat2_transpose_def
  by (cases m; auto)

lemma mat3_sym_criterion:
  fixes m :: mat3
  shows "mat_sym m ⟷ m⇧T = m"
  unfolding mat3_sym_def mat3_transpose_def
  by (cases m; auto)

(* The identity matrix is its own transpose. *)
lemma mat2_transpose_one [simp]: "(1 :: mat2)⇧T = 1"
  unfolding mat2_transpose_def one_mat2_def by auto

lemma mat3_transpose_one [simp]: "(1 :: mat3)⇧T = 1"
  unfolding mat3_transpose_def one_mat3_def by auto

(* Transposition reverses the order of a product. *)
lemma mat2_transpose_mul [simp]:
  fixes a b :: mat2
  shows "(a * b)⇧T = b⇧T * a⇧T"
  unfolding mat2_transpose_def times_mat2_def by auto

lemma mat3_transpose_mul [simp]:
  fixes a b :: mat3
  shows "(a * b)⇧T = b⇧T * a⇧T"
  unfolding mat3_transpose_def times_mat3_def by auto

(* Moving a 2x2 matrix from one side of the dot product to the other turns
   it into its transpose. *)
lemma vec2_dot_transpose_left:
  fixes m :: mat2
  fixes u v :: vec2
  shows "<m⇧T $ u | v> = <u | m $ v>"
  unfolding vec2_dot_def mat2_app_def mat2_transpose_def
  by (simp add: algebra_simps)

lemma vec2_dot_transpose_right:
  fixes m :: mat2
  fixes u v :: vec2
  shows "<u | m⇧T $ v> = <m $ u | v>"
  unfolding vec2_dot_def mat2_app_def mat2_transpose_def
  by (simp add: algebra_simps)

(* Moving a 3x3 matrix from one side of the dot product to the other turns
   it into its transpose. *)
lemma vec3_dot_transpose_left:
  fixes m :: mat3
  fixes u v :: vec3
  shows "<m⇧T $ u | v> = <u | m $ v>"
  unfolding vec3_dot_def mat3_app_def mat3_transpose_def
  by (simp add: algebra_simps)

lemma vec3_dot_transpose_right:
  fixes m :: mat3
  fixes u v :: vec3
  shows "<u | m⇧T $ v> = <m $ u | v>"
  unfolding vec3_dot_def mat3_app_def mat3_transpose_def
  by (simp add: algebra_simps)

(* Transposition does not change the determinant. The misspelled lemma names
   are kept for existing users; correctly spelled aliases follow each lemma. *)
lemma mat2_det_tranpose [simp]:
  fixes m :: mat2
  shows "mat_det (m⇧T) = mat_det m"
  unfolding mat2_det_def mat2_transpose_def by auto

lemmas mat2_det_transpose = mat2_det_tranpose

lemma mat3_det_tranpose [simp]:
  fixes m :: mat3
  shows "mat_det (m⇧T) = mat_det m"
  unfolding mat3_det_def mat3_transpose_def by auto

lemmas mat3_det_transpose = mat3_det_tranpose

(* Overloaded "inverse", written postfix as m⇧-⇧1. Both instances are defined
   as the adjugate matrix, which is a two-sided inverse only when the
   determinant is 1 (the *_inverse_cancel lemmas assume mat_det m = 1). *)
consts mat_inverse :: "'a ⇒ 'a" (‹_⇧-⇧1› [91] 90)

definition mat2_inverse :: "mat2 ⇒ mat2" where
"mat2_inverse m =
  mat2
    (mat222 m) (- mat212 m)
    (- mat221 m) (mat211 m)
"

adhoc_overloading mat_inverse mat2_inverse

(* Entry (i, j) is the signed (j, i) cofactor of m. *)
definition mat3_inverse :: "mat3 ⇒ mat3" where
"mat3_inverse m =
  mat3
    (mat322 m * mat333 m - mat323 m * mat332 m) (mat313 m * mat332 m - mat312 m * mat333 m) (mat312 m * mat323 m - mat313 m * mat322 m)
    (mat323 m * mat331 m - mat321 m * mat333 m) (mat311 m * mat333 m - mat313 m * mat331 m) (mat313 m * mat321 m - mat311 m * mat323 m)
    (mat321 m * mat332 m - mat322 m * mat331 m) (mat312 m * mat331 m - mat311 m * mat332 m) (mat311 m * mat322 m - mat312 m * mat321 m)
"

adhoc_overloading mat_inverse mat3_inverse

(* For a matrix with determinant 1, m⇧-⇧1 is a two-sided inverse. *)
lemma mat2_inverse_cancel:
  fixes m :: mat2
  assumes "mat_det m = 1"
  shows "m * m⇧-⇧1 = 1" "m⇧-⇧1 * m = 1"
  using assms unfolding mat2_det_def mat2_inverse_def times_mat2_def one_mat2_def
  by (auto simp add: algebra_simps)

lemma mat3_inverse_cancel:
  fixes m :: mat3
  assumes "mat_det m = 1"
  shows "m * m⇧-⇧1 = 1" "m⇧-⇧1 * m = 1"
  using assms unfolding mat3_det_def mat3_inverse_def times_mat3_def one_mat3_def
  by (auto simp add: algebra_simps)

(* Cancellation from the left for determinant-one matrices. The proof
   re-associates the products, then applies the inverse_cancel lemmas. *)
lemma mat2_inverse_cancel_left:
  fixes m a :: mat2
  assumes "mat_det m = 1"
  shows "m * (m⇧-⇧1 * a) = a" "m⇧-⇧1 * (m * a) = a"
  unfolding mult.assoc[symmetric]
  using assms mat2_inverse_cancel
  by auto

lemma mat3_inverse_cancel_left:
  fixes m a :: mat3
  assumes "mat_det m = 1"
  shows "m * (m⇧-⇧1 * a) = a" "m⇧-⇧1 * (m * a) = a"
  unfolding mult.assoc[symmetric]
  using assms mat3_inverse_cancel
  by auto

(* Multiplying on the right by m * m⇧-⇧1 or m⇧-⇧1 * m has no effect when det m = 1. *)
lemma mat2_inverse_cancel_right:
  fixes m a :: mat2
  assumes "mat_det m = 1"
  shows "a * (m * m⇧-⇧1) = a" "a * (m⇧-⇧1 * m) = a"
  using assms mat2_inverse_cancel
  by auto

lemma mat3_inverse_cancel_right:
  fixes m a :: mat3
  assumes "mat_det m = 1"
  shows "a * (m * m⇧-⇧1) = a" "a * (m⇧-⇧1 * m) = a"
  using assms mat3_inverse_cancel
  by auto

(* Left cancellation: a determinant-one matrix can be cancelled from the
   left of both sides of an equation. *)
lemma mat2_inversable_cancel_left:
  fixes m a1 a2 :: mat2
  assumes "mat_det m = 1"
  assumes "m * a1 = m * a2"
  shows "a1 = a2"
  by (metis assms mat2_inverse_cancel_left(2))

lemma mat3_inversable_cancel_left:
  fixes m a1 a2 :: mat3
  assumes "mat_det m = 1"
  assumes "m * a1 = m * a2"
  shows "a1 = a2"
  by (metis assms mat3_inverse_cancel_left(2))

(* Right cancellation: a determinant-one matrix can be cancelled from the
   right of both sides of an equation. *)
lemma mat2_inversable_cancel_right:
  fixes m a1 a2 :: mat2
  assumes "mat_det m = 1"
  assumes "a1 * m = a2 * m"
  shows "a1 = a2"
  by (metis assms mat2_inverse_cancel(1) mult.assoc mult.right_neutral)

lemma mat3_inversable_cancel_right:
  fixes m a1 a2 :: mat3
  assumes "mat_det m = 1"
  assumes "a1 * m = a2 * m"
  shows "a1 = a2"
  by (metis assms mat3_inverse_cancel(1) mult.assoc mult.right_neutral)

(* Determinant of the adjugate: equal to det m in dimension 2, and to (det m)^2
   in dimension 3. *)
lemma mat2_inverse_det [simp]:
  fixes m :: mat2
  shows "mat_det (m⇧-⇧1) = mat_det m"
  unfolding mat2_inverse_def mat2_det_def
  by auto

lemma mat3_inverse_det [simp]:
  fixes m :: mat3
  shows "mat_det (m⇧-⇧1) = (mat_det m)⇧2"
  unfolding mat3_inverse_def mat3_det_def power2_eq_square
  by (simp add: algebra_simps)

(* The "inverse" (adjugate) commutes with transposition. *)
lemma mat2_inverse_transpose:
  fixes m :: mat2
  shows "(m⇧T)⇧-⇧1 = (m⇧-⇧1)⇧T"
  unfolding mat2_inverse_def mat2_transpose_def
  by auto

lemma mat3_inverse_transpose:
  fixes m :: mat3
  shows "(m⇧T)⇧-⇧1 = (m⇧-⇧1)⇧T"
  unfolding mat3_inverse_def mat3_transpose_def
  by auto

(* A determinant-one 2x2 matrix sends a vector to 0 only if the vector is 0. *)
lemma mat2_special_preserves_zero:
  fixes u :: mat2
  fixes v :: vec2
  assumes "mat_det u = 1"
  shows "u $ v = 0 ⟷ v = 0"
proof
  assume "u $ v = 0"
  (* Apply the inverse to both sides, then re-associate into a product. *)
  hence "u⇧-⇧1 $ u $ v = 0" by auto
  hence "(u⇧-⇧1 * u) $ v = 0" by auto
  thus "v = 0" using assms mat2_inverse_cancel by auto
next
  assume "v = 0"
  thus "u $ v = 0" by auto
qed

(* A determinant-one 3x3 matrix sends a vector to 0 only if the vector is 0. *)
lemma mat3_special_preserves_zero:
  fixes u :: mat3
  fixes v :: vec3
  assumes "mat_det u = 1"
  shows "u $ v = 0 ⟷ v = 0"
proof
  assume "u $ v = 0"
  (* Apply the inverse to both sides, then re-associate into a product. *)
  hence "u⇧-⇧1 $ u $ v = 0" by auto
  hence "(u⇧-⇧1 * u) $ v = 0" by auto
  thus "v = 0" using assms mat3_inverse_cancel by auto
next
  assume "v = 0"
  thus "u $ v = 0" by auto
qed

end