# Theory DL_Flatten_Matrix

theory DL_Flatten_Matrix
imports Matrix
```(* Author: Alexander Bentkamp, Universität des Saarlandes
*)
section ‹Matrix to Vector Conversion›

theory DL_Flatten_Matrix
imports Jordan_Normal_Form.Matrix
begin

definition extract_matrix :: "(nat ⇒ 'a) ⇒ nat ⇒ nat ⇒ 'a mat" where
"extract_matrix a m n = mat m n (λ(i,j). a (i*n + j))"

definition flatten_matrix :: "'a mat ⇒ (nat ⇒ 'a)" where
"flatten_matrix A k = A \$\$ (k div dim_col A, k mod dim_col A)"

lemma two_digit_le:
"i * n + j < m * n" if "i < m" "j < n" for i j :: nat

lemma extract_matrix_cong:
assumes "⋀i. i < m * n ⟹ a i = b i"
shows "extract_matrix a m n = extract_matrix b m n"
proof -
have "⋀i j. i < m ⟹ j < n ⟹ a (i*n + j) = b (i*n + j)" using two_digit_le assms by blast
then show ?thesis unfolding extract_matrix_def by auto
qed

lemma extract_matrix_flatten_matrix:
"extract_matrix (flatten_matrix A) (dim_row A) (dim_col A) = A"
unfolding extract_matrix_def flatten_matrix_def by auto

lemma extract_matrix_flatten_matrix_cong:
assumes "⋀x. x < dim_row A * dim_col A ⟹ f x = flatten_matrix A x"
shows "extract_matrix f (dim_row A) (dim_col A) = A"
unfolding extract_matrix_def
by (metis assms extract_matrix_cong extract_matrix_def extract_matrix_flatten_matrix)

lemma flatten_matrix_extract_matrix:
"flatten_matrix (extract_matrix a m n) k = a k" if "k < m * n"
proof -
from that have "m * n > 0"
by (cases "m * n = 0") simp_all
then have "m > 0" and "n > 0"
by simp_all
with that have "k div n < m"
by (metis div_eq_0_iff div_mult2_eq mult.commute neq0_conv)
moreover have "k mod n < n"
using ‹n > 0› by simp
ultimately show ?thesis
by (auto simp add: extract_matrix_def flatten_matrix_def)
qed

lemma index_extract_matrix:
assumes "i<m" "j<n"
shows "extract_matrix a m n \$\$ (i,j) = a (i*n + j)"
unfolding extract_matrix_def using assms by simp

lemma dim_extract_matrix:
shows "dim_row (extract_matrix as m n) = m"
and "dim_col (extract_matrix as m n) = n"
unfolding extract_matrix_def by simp_all

end
```