# Theory Matrix_IArray_Impl

```isabelle
(*
Author:      René Thiemann
*)
section ‹Code Generation for Basic Matrix Operations›

text ‹In this theory we implement matrices as arrays of arrays.
Since IArrays are serialized to arrays of the target language, access to matrix
entries should be constant time. Hence operations like
matrix addition, multiplication, etc.~should all have their
standard complexity.

There might be room for optimizations.

To implement the infinite carrier set, we use A.\ Lochbihler's container framework
\<^cite>‹"Containers-AFP"›.›

theory Matrix_IArray_Impl
imports
Matrix
"HOL-Library.IArray"
Containers.Set_Impl
begin

(* Implementation types.  A vector is a dimension n paired with an immutable
   array of length n.  A matrix is a triple (nr, nc, m) where m is an array of
   exactly nr row arrays, each of length nc. *)
typedef 'a vec_impl = "{(n,v :: 'a iarray). IArray.length v = n}" by auto
(* Non-emptiness witness for the matrix type: the 0 x 0 matrix with no rows. *)
typedef 'a mat_impl = "{(nr,nc,m :: 'a iarray iarray).
IArray.length m = nr ∧ IArray.all (λ r. IArray.length r = nc) m}"
by (rule exI[of _ "(0,0,IArray [])"], auto)

(* Register both typedefs with the Lifting package, so that the operations
   below can be introduced with lift_definition. *)
setup_lifting type_definition_vec_impl
setup_lifting type_definition_mat_impl

(* Abstraction from the implementation type to 'a vec: the dimension n is kept
   and the entries are given by IArray.sub v, passed through mk_vec n. *)
lift_definition vec_impl :: "'a vec_impl ⇒ 'a vec" is
"λ (n,v). (n,mk_vec n (IArray.sub v))" by auto

(* Pointwise addition on the implementation level.  The result has the
   dimension n of the first argument; the dimension m of the second argument
   is not checked.
   NOTE(review): entries of w are read at indices 0..<n, so this presumably
   expects m = n -- confirm at call sites.  It is not referenced elsewhere in
   this theory and has no code equation relating it to vector addition. *)
lift_definition vec_add_impl :: "'a::plus vec_impl ⇒ 'a vec_impl ⇒ 'a vec_impl" is
"λ (n,v) (m,w).
(n, IArray.of_fun (λi. IArray.sub v i + IArray.sub w i) n)"
by auto

(* Abstraction from the implementation type to 'a mat: both dimensions are kept
   and entry (i,j) is the j-th element of the i-th row array, passed through
   mk_mat nr nc. *)
lift_definition mat_impl :: "'a mat_impl ⇒ 'a mat" is
"λ (nr,nc,m). (nr,nc,mk_mat nr nc (λ (i,j). IArray.sub (IArray.sub m i) j))" by auto

(* Implementation vector built from a list; its dimension is the list length. *)
lift_definition vec_of_list_impl :: "'a list ⇒ 'a vec_impl" is
"λ v. (length v, IArray v)" by auto

(* Entries of the underlying array as a list; the dimension component is not used. *)
lift_definition list_of_vec_impl :: "'a vec_impl ⇒ 'a list" is
"λ (n,v). IArray.list_of v" .

(* Implementation vector of dimension n obtained by tabulating f with IArray.of_fun. *)
lift_definition vec_of_fun :: "nat ⇒ (nat ⇒ 'a) ⇒ 'a vec_impl" is
"λ n f. (n, IArray.of_fun f n)" by auto

(* Implementation matrix obtained by tabulating f row by row: nr row arrays,
   where row i holds the nc values f (i,j). *)
lift_definition mat_of_fun :: "nat ⇒ nat ⇒ (nat × nat ⇒ 'a) ⇒ 'a mat_impl" is
"λ nr nc f. (nr, nc, IArray.of_fun (λ i. IArray.of_fun (λ j. f (i,j)) nc) nr)" by auto

(* Entry lookup: the dimension component is not consulted; the index is passed
   directly to IArray.sub. *)
lift_definition vec_index_impl :: "'a vec_impl ⇒ nat ⇒ 'a"
is "λ (n,v). IArray.sub v" .

(* Entry lookup for matrices.  For a row index i < nr the entry is read from
   row i.  Otherwise the row is taken to be IArray ([] ! (i - nr)), built from
   an unspecified list element.  Lemma index_mat_code shows that this agrees
   with indexing mat_impl m for every index pair. *)
lift_definition index_mat_impl :: "'a mat_impl ⇒ nat × nat ⇒ 'a"
is "λ (nr,nc,m) (i,j). if i < nr then IArray.sub (IArray.sub m i) j
else IArray.sub (IArray ([] ! (i - nr))) j" .

(* Equality test on implementation vectors: equal dimensions and equal arrays. *)
lift_definition vec_equal_impl :: "'a vec_impl ⇒ 'a vec_impl ⇒ bool"
is "λ (n1,v1) (n2,v2). n1 = n2 ∧ v1 = v2" .

(* Equality test on implementation matrices: equal dimensions and equal row arrays. *)
lift_definition mat_equal_impl :: "'a mat_impl ⇒ 'a mat_impl ⇒ bool"
is "λ (nr1,nc1,m1) (nr2,nc2,m2). nr1 = nr2 ∧ nc1 = nc2 ∧ m1 = m2" .

(* Dimension selectors: the first component of a vector representation, and the
   first (rows) and second (columns) components of a matrix representation. *)
lift_definition dim_vec_impl :: "'a vec_impl ⇒ nat" is fst .

lift_definition dim_row_impl :: "'a mat_impl ⇒ nat" is fst .
lift_definition dim_col_impl :: "'a mat_impl ⇒ nat" is "fst o snd" .

(* Declare vec_impl and mat_impl as the code constructors of 'a vec and 'a mat,
   so that the code equations below can be stated on vec_impl v / mat_impl m. *)
code_datatype vec_impl
code_datatype mat_impl

(* Code equation for vec: tabulate the entries with vec_of_fun. *)
lemma vec_code[code]: "vec n f = vec_impl (vec_of_fun n f)"
by (transfer, auto simp: mk_vec_def)

(* Code equation for mat: tabulate the entries with mat_of_fun.  After
   transfer and extensionality, the remaining goals are closed with
   undef_cong_mat. *)
lemma mat_code[code]: "mat nr nc f = mat_impl (mat_of_fun nr nc f)"
by (transfer, auto simp: mk_mat_def, intro ext, clarsimp,
auto intro: undef_cong_mat)

(* Code equation for vec_of_list: wrap the list into an array. *)
lemma vec_of_list[code]: "vec_of_list v = vec_impl (vec_of_list_impl v)"
by (transfer, auto simp: mk_vec_def)

(* Code equation for list_of_vec: read the underlying array back as a list. *)
lemma list_of_vec_code[code]: "list_of_vec (vec_impl v) = list_of_vec_impl v"
by (transfer, auto simp: mk_vec_def, case_tac b, auto intro: nth_equalityI)

(* Indexing a list at or beyond its length yields the same (unspecified) value
   as indexing the empty list at offset i - length x. *)
lemma empty_nth: "¬ i < length x ⟹ x ! i = [] ! (i - length x)"
by (metis append_Nil2 nth_append)

(* Restatement of empty_nth in terms of undef_vec (via its definition). *)
lemma undef_vec: "¬ i < length x ⟹ undef_vec (i - length x) = x ! i"
unfolding undef_vec_def by (rule empty_nth[symmetric])

(* Code equation for vector indexing ($): delegate to the array lookup
   vec_index_impl.  Indices beyond the dimension are handled by the lemma
   undef_vec. *)
lemma vec_index_code[code]: "(vec_impl v) $ i = vec_index_impl v i"
by (transfer, auto simp: mk_vec_def, case_tac b, auto simp: undef_vec)

(* Code equation for matrix indexing ($$): delegate to index_mat_impl.  The
   proof distinguishes indices inside the rows from indices beyond the last
   row, and, inside a row, indices within and beyond the row length; the
   out-of-range cases are matched up with empty_nth. *)
lemma index_mat_code[code]: "(mat_impl m) $$ ij = (index_mat_impl m ij :: 'a)"
proof (transfer, unfold o_def, clarify)
fix m :: "'a iarray iarray" and i j nc
assume all: "IArray.all (λr. IArray.length r = nc) m"
obtain mm where m: "m = IArray mm" by (cases m)
with all have all: "⋀ v. v ∈ set mm ⟹ IArray.length v = nc" by auto
show "snd (snd (IArray.length m, nc, mk_mat (IArray.length m) nc (λ(i, y). m !! i !! y))) (i, j) =
(if i < IArray.length m then m !! i !! j
else IArray ([] ! (i - IArray.length m)) !! j)" (is "?l = ?r")
(* Case split on whether the row index lies inside the matrix. *)
proof (cases "i < length mm")
case False
hence "⋀ f. ¬ i < length (map f [0..<length mm])" by simp
note [simp] = empty_nth[OF this]
have "?l = [] ! (i - length mm) ! j" using False unfolding m mk_mat_def undef_mat_def by simp
also have "… = ?r" unfolding m by (simp add: False empty_nth[OF False])
finally show ?thesis .
next
case True
obtain v where mm: "mm ! i = IArray v" by (cases "mm ! i")
with True all[of "mm ! i"] have len: "length v = nc" unfolding set_conv_nth by force
from mm True have "?l = map ((!) v) [0..<nc] ! j" (is "_ = ?m") unfolding m mk_mat_def undef_mat_def by simp
(* Within the row, split on whether the column index lies inside the row. *)
also have "?m = m !! i !! j"
proof (cases "j < length v")
case True
thus ?thesis unfolding m using mm len by auto
next
case False
hence j: "¬ j < length (map ((!) v) [0..<length v])" by simp
show ?thesis unfolding m using mm len by (auto simp: empty_nth[OF j] empty_nth[OF False])
qed
also have "… = ?r" using True m by simp
finally show ?thesis .
qed
qed

(* Partial constructor from a list of rows.  If every row has length n, the
   result is Some matrix with length rows rows and n columns, each row stored
   as an array; otherwise the result is None.
   NOTE(review): code_dt is presumably required because the lifted result is
   wrapped in an option -- see the Lifting documentation. *)
lift_definition (code_dt) mat_of_rows_list_impl :: "nat ⇒ 'a list list ⇒ 'a mat_impl option" is
"λ n rows. if list_all (λ r. length r = n) rows then Some (length rows, n, IArray (map IArray rows))
else None"
by (auto split: if_splits simp: list_all_iff)

(* Correctness of the successful case: whenever mat_of_rows_list_impl returns
   Some A, the abstraction of A is mat_of_rows_list n rs. *)
lemma mat_of_rows_list_impl: "mat_of_rows_list_impl n rs = Some A ⟹ mat_impl A = mat_of_rows_list n rs"
unfolding mat_of_rows_list_def
by (transfer, auto split: if_splits simp: list_all_iff intro!: cong_mk_mat)

(* Code equation for mat_of_rows_list.  Use the array-based construction when
   all rows have length nc.  Otherwise fall back to mat_of_rows on the vectors
   vec nc (nth v) built from the individual rows. *)
lemma mat_of_rows_list_code[code]: "mat_of_rows_list nc vs =
(case mat_of_rows_list_impl nc vs of Some A ⇒ mat_impl A
| None ⇒ mat_of_rows nc (map (λ v. vec nc (nth v)) vs))"
proof (cases "mat_of_rows_list_impl nc vs")
case (Some A)
from mat_of_rows_list_impl[OF this] show ?thesis unfolding Some by simp
next
case None
show ?thesis unfolding None unfolding mat_of_rows_list_def mat_of_rows_def
by (intro eq_matI, auto)
qed

(* Code equations for the dimension selectors: read the stored dimensions. *)
lemma dim_vec_code[code]: "dim_vec (vec_impl v) = dim_vec_impl v"
by (transfer, auto)

lemma dim_row_code[code]: "dim_row (mat_impl m) = dim_row_impl m"
by (transfer, auto)

lemma dim_col_code[code]: "dim_col (mat_impl m) = dim_col_impl m"
by (transfer, auto)

(* HOL.equal on vectors is plain HOL equality; its executable implementation
   is provided by the code equation veq_equal_code. *)
instantiation vec :: (type)equal
begin
definition equal_vec :: "'a vec ⇒ 'a vec ⇒ bool" where "equal_vec = (=)"
instance
by standard (simp add: equal_vec_def)
end

(* HOL.equal on matrices is plain HOL equality; its executable implementation
   is provided by the code equation mat_equal_code. *)
instantiation mat :: (type)equal
begin
definition equal_mat :: "'a mat ⇒ 'a mat ⇒ bool" where "equal_mat = (=)"
instance
by standard (simp add: equal_mat_def)
end

(* Code equation for equality of vectors: compare the implementations with
   vec_equal_impl. *)
lemma veq_equal_code[code]: "HOL.equal (vec_impl (v1 :: 'a vec_impl)) (vec_impl v2) = vec_equal_impl v1 v2"
proof -
{
(* Auxiliary fact for the transferred goal: two lists of equal length whose
   index functions agree are equal (via nth_equalityI). *)
fix x1 x2 :: "'a list"
assume len: "length x1 = length x2"
and index: "(λi. if i < length x2 then IArray x1 !! i else undef_vec (i - length (IArray.list_of (IArray x1)))) =
(λi. if i < length x2 then IArray x2 !! i else undef_vec (i - length (IArray.list_of (IArray x2))))"
have "x1 = x2"
proof (intro nth_equalityI[OF len])
fix i
assume "i < length x1"
with fun_cong[OF index, of i] len show "x1 ! i = x2 ! i" by simp
qed
} note * = this
(* Unfold equal_vec, move to the representation level, and finish with the
   auxiliary fact after case-splitting the remaining array variables. *)
show ?thesis unfolding equal_vec_def
by (transfer, insert *, auto simp: mk_vec_def, case_tac b, case_tac ba, auto)
qed

(* Code equation for equality of matrices: compare the implementations with
   mat_equal_impl.  After transfer, the remaining goal is that two row lists
   of equal length, whose rows all have length nc and whose mk_mat index
   functions agree, are equal.  This is shown row by row and entry by entry. *)
lemma mat_equal_code[code]: "HOL.equal (mat_impl (m1 :: 'a mat_impl)) (mat_impl m2) = mat_equal_impl m1 m2"
proof -
show ?thesis unfolding equal_mat_def
proof (transfer, auto, case_tac b, case_tac ba, auto)
fix x1 x2 :: "'a iarray list" and nc
assume len: "∀r∈set x1. length (IArray.list_of r) = nc"
"∀r∈set x2. length (IArray.list_of r) = nc"
"length x1 = length x2"
and index: "mk_mat (length x2) nc (λ(i, j). x1 ! i !! j) = mk_mat (length x2) nc (λ(i, j). x2 ! i !! j)"
show "x1 = x2"
proof (rule nth_equalityI[OF len(3)])
fix i
assume i: "i < length x1"
obtain ia1 where 1: "x1 ! i = IArray ia1" by (cases "x1 ! i")
obtain ia2 where 2: "x2 ! i = IArray ia2" by (cases "x2 ! i")
from i 1 len(1) have l1: "length ia1 = nc" using nth_mem by fastforce
from i 2 len(2-3) have l2: "length ia2 = nc" using nth_mem by fastforce
from l1 l2 have l: "length ia1 = length ia2" by simp
show "x1 ! i = x2 ! i" unfolding 1 2
proof (simp, rule nth_equalityI[OF l])
fix j
assume j: "j < length ia1"
(* (i,j) lies inside the matrix, so mk_mat returns the stored entries there. *)
from i j have ij: "i < length x2 ∧ j < nc" unfolding len(3) l1 by simp
from ij fun_cong[OF index, of "(i,j)"]
have "x1 ! i !! j = x2 ! i !! j" by (simp add: mk_mat_def)
thus "ia1 ! j = ia2 ! j" unfolding 1 2 by simp
qed
qed
qed
qed

(* Remove prod.set_conv_list from the code equations and use it as a code_unfold
   (preprocessing) rule instead.
   NOTE(review): presumably needed for the Containers setup that follows -- confirm. *)
declare prod.set_conv_list[code del, code_unfold]

(* Container-framework instances for mat and vec: ceq based on equality,
   no ccompare comparator, sets implemented as distinct lists (dlist), and no
   cenum enumeration. *)
derive (eq) ceq mat vec
derive (no) ccompare mat vec
derive (dlist) set_impl mat vec
derive (no) cenum mat vec

(* Code equations for the carrier sets, which may be infinite.  They are
   represented by their characteristic predicate via Collect_set from the
   container framework. *)
lemma carrier_mat_code[code]: "carrier_mat nr nc = Collect_set (λ A. dim_row A = nr ∧ dim_col A = nc)" by auto
lemma carrier_vec_code[code]: "carrier_vec n = Collect_set (λ v. dim_vec v = n)"
unfolding carrier_vec_def by auto

end
```