(* Theory Deep_Learning.Tensor_Matricization *)

(* Author: Alexander Bentkamp, Universität des Saarlandes
*)
section ‹Tensor Matricization›

theory Tensor_Matricization
imports Tensor_Plus
Jordan_Normal_Form.Matrix Jordan_Normal_Form.DL_Missing_Sublist
begin

(* Mixed-radix decoding.  digit_decode ds is reads the digit list is in the
   radices ds, least-significant digit first:
     digit_decode [d0, d1, ...] [i0, i1, ...] = i0 + d0 * (i1 + d1 * (...)).
   Equations are given only for argument lists of equal length. *)
fun digit_decode :: "nat list ⇒ nat list ⇒ nat" where
"digit_decode [] [] = 0" |
"digit_decode (d # ds) (i # is) = i + d * digit_decode ds is"

(* Mixed-radix encoding, the counterpart of digit_decode: the first digit is
   a mod d and the quotient a div d is encoded in the remaining radices.
   Exactly one digit is produced per radix, so any part of a that does not fit
   below prod_list ds is dropped (cf. digit_decode_encode). *)
fun digit_encode :: "nat list ⇒ nat ⇒ nat list" where
"digit_encode [] a = []" |
"digit_encode (d # ds) a = a mod d # digit_encode ds (a div d)"

(* Encoding is a left inverse of decoding on valid indices (is ⊲ ds).
   Proved by induction over the valid_index derivation; each step is
   arithmetic on a single digit. *)
lemma digit_encode_decode[simp]:
assumes "is ⊲ ds"
shows "digit_encode ds (digit_decode ds is) = is"
  using assms apply (induction rule:valid_index.induct)
  unfolding digit_decode.simps digit_encode.simps
  by simp_all

(* Decoding an encoding recovers the number modulo the product of all radices:
   whatever part of a does not fit below prod_list ds is lost by digit_encode. *)
lemma digit_decode_encode[simp]:
shows "digit_decode ds (digit_encode ds a) = a mod (prod_list ds)"
  by (induction ds arbitrary: a) (simp_all add: mod_mult2_eq)

(* Special case of digit_decode_encode: for a below the product of the radices
   the mod vanishes, so decoding is a left inverse of encoding. *)
lemma digit_decode_encode_lt[simp]:
assumes "a < prod_list ds"
shows "digit_decode ds (digit_encode ds a) = a"
by (simp add: assms)

(* A valid index decodes to a number strictly below the product of its radices.
   This is the range condition under which digit_decode_encode_lt applies. *)
lemma digit_decode_lt:
assumes "is ⊲ ds"
shows "digit_decode ds is < prod_list ds"
using assms proof (induction rule:valid_index.induct)
  case Nil
  then show ?case by simp
next
  case (Cons "is" ds i d)
  (* the quotient of the decoded value by the full radix product is zero,
     using the induction hypothesis and the digit bound Cons.hyps(2) *)
  have "(i + d * digit_decode ds is) div (d * prod_list ds) = 0"
    using Cons.IH Cons.hyps(2) div_mult2_eq by force
  then show ?case unfolding digit_decode.simps prod_list.Cons
    by (metis (no_types) Cons.IH Cons.hyps(2) div_eq_0_iff mult_eq_0_iff not_less0)
qed

(* Every number below the product of the radices encodes to a valid index. *)
lemma digit_encode_valid_index:
assumes "a < prod_list ds"
shows "digit_encode ds a ⊲ ds"
using assms proof (induction ds arbitrary:a)
  case Nil
  show ?case by (simp add: valid_index.Nil)
next
  case (Cons d ds a)
  then have "a < d * prod_list ds"
    by simp
  (* the quotient is in range for the remaining radices, so the IH applies *)
  then have "a div d < prod_list ds"
    by (metis div_eq_0_iff div_mult2_eq mult_0_right not_less0)
  then have "digit_encode ds (a div d) ⊲ ds"
    by (rule Cons)
  (* the head digit a mod d is below d, because d = 0 would contradict
     a < d * prod_list ds *)
  moreover have "d > 0"
    using ‹a < d * prod_list ds› by (cases "d = 0") simp_all
  then have "a mod d < d"
    by simp
  ultimately show ?case
    by (simp add: valid_index.Cons)
qed

(* digit_encode produces exactly one digit per radix, independent of a. *)
lemma length_digit_encode:
shows "length (digit_encode ds a) = length ds"
  by (induction ds arbitrary:a; simp_all)

(* Multiples of the product of the radices encode to the all-zero index. *)
lemma digit_encode_0:
"prod_list ds dvd a ⟹ digit_encode ds a = replicate (length ds) 0"
proof (induction ds arbitrary:a)
  case Nil
  then show ?case by simp
next
  case (Cons d ds a)
  (* divisibility carries over to the quotient that is encoded in the tail *)
  then have "prod_list ds dvd (a div d)" unfolding prod_list.Cons
    by (metis dvd_0_right dvd_div_iff_mult dvd_mult_left mult.commute split_div)
  then show ?case unfolding digit_encode.simps length_Cons replicate_Suc prod_list.Cons using Cons
    using dvd_imp_mod_0 dvd_mult_left prod_list.Cons by force
qed
  
(* Weaving a valid index for the A-part of ds with a valid index for the
   (-A)-part gives a valid index for ds.  Parts (2) and (3) show that nths
   splits the woven index back into is1 and is2. *)
lemma valid_index_weave:
assumes "is1 ⊲ (nths ds A)"
and     "is2 ⊲ (nths ds (-A))"
shows "weave A is1 is2 ⊲ ds"
and "nths (weave A is1 is2) A = is1"
and "nths (weave A is1 is2) (-A) = is2"
proof -
  have length_ds: "length is1 + length is2 = length ds"
    using valid_index_length[OF assms(1)] valid_index_length[OF assms(2)]
    length_weave  weave_complementary_nthss by metis
  (* 1 and 2 are the length side conditions consumed by nths_weave below *)
  have 1:"length is1 = card {i ∈ A. i < length is1 + length is2}" unfolding length_ds
    using length_nths' assms(1) valid_index_length by auto
  have 2:"length is2 = card {i ∈ -A. i < length is1 + length is2}" unfolding length_ds
    using length_nths'[of ds "-A"] assms(2) valid_index_length by auto
  show "nths (weave A is1 is2) A = is1" "nths (weave A is1 is2) (-A) = is2" using nths_weave[OF 1 2] by blast+
  then have "nths (weave A is1 is2) A ⊲ (nths ds A)"
       "nths (weave A is1 is2) (-A) ⊲ (nths ds (-A))" using assms by auto
  then show "weave A is1 is2 ⊲ ds" using list_all2_nths valid_index_list_all2_iff by blast
qed

(* matricize rmodes T flattens the tensor T into a matrix.  The modes in rmodes
   index the rows and all remaining modes index the columns, so the matrix has
   prod_list of the row-mode dims rows and prod_list of the other dims columns.
   Entry (r, c) is the tensor entry whose rmodes-part is digit_encode of r and
   whose remaining part is digit_encode of c, merged back together by weave. *)
definition matricize :: "nat set ⇒ 'a tensor ⇒ 'a mat" where
"matricize rmodes T = mat
  (prod_list (nths (Tensor.dims T) rmodes))
  (prod_list (nths (Tensor.dims T) (-rmodes)))
  (λ(r, c). Tensor.lookup T (weave rmodes
    (digit_encode (nths (Tensor.dims T) rmodes) r)
    (digit_encode (nths (Tensor.dims T) (-rmodes)) c)
  ))
"

(* dematricize rmodes A ds rebuilds a tensor with dimensions ds from the matrix A.
   The entry at index is is the entry of A whose row is digit_decode of the
   rmodes-part of is and whose column is digit_decode of the remaining part. *)
definition dematricize::"nat set ⇒ 'a mat ⇒ nat list ⇒ 'a tensor" where
"dematricize rmodes A ds  = tensor_from_lookup ds
  (λis. A $$ (digit_decode (nths ds rmodes) (nths is rmodes),
              digit_decode (nths ds (-rmodes)) (nths is (-rmodes)))
 )
"

(* The matrix produced by matricize has as many rows as the product of the dims
   selected by rmodes, and as many columns as the product of the other dims. *)
lemma dims_matricize:
"dim_row (matricize rmodes T) = prod_list (nths (Tensor.dims T) rmodes)"
"dim_col (matricize rmodes T) = prod_list (nths (Tensor.dims T) (-rmodes))"
  unfolding matricize_def using dim_row_mat by simp_all

(* dematricize produces a tensor with exactly the requested dims ds. *)
lemma dims_dematricize: "Tensor.dims (dematricize rmodes A ds) = ds"
  by (simp add: dematricize_def dims_tensor_from_lookup)

(* Selecting the same positions A from a valid index and from its dims
   preserves validity. *)
lemma valid_index_nths:
assumes "is ⊲ ds"
shows "nths is A ⊲ nths ds A"
using assms proof (induction arbitrary:A rule:valid_index.induct)
  case Nil
  then show ?case using nths_nil valid_index.simps by blast
next
  case (Cons "is" ds i d)
  (* IH instance for the tail, with the position set shifted by one *)
  then have " nths is {j. Suc j ∈ A} ⊲ nths ds {j. Suc j ∈ A}"
    by simp
  (* whether the head digit is kept depends on whether position 0 is in A *)
  then show ?case unfolding nths_Cons
    by (cases "0∈A"; simp_all add: Cons.hyps(2) valid_index.Cons)
qed

(* dematricize undoes matricize when it is given the tensor's original dims. *)
lemma dematricize_matricize:
shows "dematricize rmodes (matricize rmodes T) (Tensor.dims T) = T"
proof (rule tensor_lookup_eqI)
  show 1:"Tensor.dims (dematricize rmodes (matricize rmodes T) (Tensor.dims T)) = Tensor.dims T"
    by (simp add: dematricize_def dims_tensor_from_lookup)
  fix "is" assume "is ⊲ Tensor.dims (dematricize rmodes (matricize rmodes T) (Tensor.dims T))"
  then have "is ⊲ Tensor.dims T" using 1 by auto
  let ?rds = "(nths (Tensor.dims T) rmodes)"
  let ?cds = "(nths (Tensor.dims T) (-rmodes))"
  (* the decoded row and column lie within the dimensions of the matricized tensor *)
  have decode_r: "digit_decode ?rds (nths is rmodes) < prod_list ?rds"
    by (simp add: ‹is ⊲ Tensor.dims T› valid_index_nths digit_decode_lt)
  have decode_c: "digit_decode ?cds (nths is (-rmodes)) < prod_list ?cds"
    by (simp add: ‹is ⊲ Tensor.dims T› valid_index_nths digit_decode_lt)
  have "(matricize rmodes T) $$
     (digit_decode ?rds (nths is rmodes),
      digit_decode ?cds (nths is (- rmodes))) =
    Tensor.lookup T is"
    unfolding matricize_def
    by (simp add: decode_r decode_c ‹is ⊲ Tensor.dims T› valid_index_nths)
  then show "Tensor.lookup (dematricize rmodes (matricize rmodes T) (Tensor.dims T)) is = Tensor.lookup T is"
    by (simp add: dematricize_def dims_tensor_from_lookup lookup_tensor_from_lookup[OF ‹is ⊲ Tensor.dims T›])
qed

(* matricize undoes dematricize, provided A has exactly the row and column
   counts that matricize produces for a tensor with dims ds. *)
lemma matricize_dematricize:
assumes " dim_row A = prod_list (nths ds rmodes)"
and " dim_col A = prod_list (nths ds (-rmodes))"
shows "matricize rmodes (dematricize rmodes A ds) = A"
proof (rule eq_matI)
  show "dim_row (matricize rmodes (dematricize rmodes A ds)) = dim_row A"
    unfolding assms(1) dematricize_def dims_tensor_from_lookup matricize_def dim_row_mat by metis
  show "dim_col (matricize rmodes (dematricize rmodes A ds)) = dim_col A"
    unfolding assms(2) dematricize_def dims_tensor_from_lookup matricize_def dim_col_mat by metis
  fix r c assume "r < dim_row A" "c < dim_col A"
  (* the encoded row and column are valid sub-indices, so weaving them is valid *)
  have valid1:"digit_encode (nths ds rmodes) r ⊲ nths ds rmodes" and
       valid2:"digit_encode (nths ds (- rmodes)) c ⊲ nths ds (- rmodes)"
    using ‹r < dim_row A› assms(1) ‹c < dim_col A› assms(2) digit_encode_valid_index by auto
  (* looking up the woven index decodes back to exactly (r, c) *)
  have 0:"Tensor.lookup (dematricize rmodes A ds)
     (weave rmodes
       (digit_encode (nths (Tensor.dims (dematricize rmodes A ds)) rmodes) r)
       (digit_encode (nths (Tensor.dims (dematricize rmodes A ds)) (- rmodes)) c)
     ) =  A $$ (r, c)"
      unfolding dematricize_def unfolding dims_tensor_from_lookup
      unfolding lookup_tensor_from_lookup[OF valid_index_weave(1)[OF valid1 valid2]]
      using digit_decode_encode_lt[OF ‹c < dim_col A›[unfolded assms(2)]]
      digit_decode_encode_lt[OF ‹r < dim_row A›[unfolded assms(1)]]
      valid_index_weave(2)[OF valid1 valid2] valid_index_weave(3)[OF valid1 valid2]
      by presburger
  from ‹r < dim_row A› have r_le: "r < prod_list (nths (Tensor.dims (dematricize rmodes A ds)) rmodes)"
    by (metis ‹dim_row (matricize rmodes (dematricize rmodes A ds)) = dim_row A› matricize_def dim_row_mat(1))
  from ‹c < dim_col A› have c_le: "c < prod_list (nths (Tensor.dims (dematricize rmodes A ds)) (- rmodes))"
    by (metis ‹dim_col (matricize rmodes (dematricize rmodes A ds)) = dim_col A› matricize_def dim_col_mat(1))
  then show "(matricize rmodes (dematricize rmodes A ds)) $$ (r, c) = A $$ (r, c)"
    unfolding matricize_def using r_le c_le 0 by simp
qed

(* matricize commutes with tensor addition, for tensors with equal dims. *)
lemma matricize_add:
assumes "dims A = dims B"
shows "matricize I A + matricize I B = matricize I (A+B)"
proof (rule eq_matI)
  show "dim_row (matricize I A + matricize I B) = dim_row (matricize I (A + B))" by (simp add: assms dims_matricize(1))
  show "dim_col (matricize I A + matricize I B) = dim_col (matricize I (A + B))" by (simp add: assms dims_matricize(2))
  fix i j assume ij_le1:"i < dim_row (matricize I (A + B))" "j < dim_col (matricize I (A + B))"
  (* restate the index bounds w.r.t. the dims of A, B and A + B, in the forms
     that index_mat and index_add_mat need below *)
  then have
    ij_le2:"i < prod_list (nths (Tensor.dims A) I)"  "j < prod_list (nths (Tensor.dims A) (-I))" and
    ij_le3:"i < prod_list (nths (Tensor.dims B) I)"  "j < prod_list (nths (Tensor.dims B) (-I))" and
    ij_le4:"i < prod_list (nths (Tensor.dims (A + B)) I)"  "j < prod_list (nths (Tensor.dims (A + B)) (-I))"
    by (simp_all add: assms dims_matricize)
  then have ij_le5:"i < dim_row (matricize I B)" "j < dim_col (matricize I B)"
    by (simp_all add: assms dims_matricize)
  (* entrywise: the woven index is valid for A, B and A + B alike *)
  show "(matricize I A + matricize I B) $$ (i, j) = matricize I (A + B) $$ (i, j)"
    unfolding index_add_mat(1)[OF ij_le5] unfolding matricize_def unfolding index_mat[OF ij_le2] index_mat[OF ij_le3] index_mat[OF ij_le4]
    using assms digit_encode_valid_index ij_le2(1) ij_le2(2) valid_index_weave(1) by auto
qed

(* matricize sends a zero tensor to the zero matrix of matching dimensions. *)
lemma matricize_0:
shows "matricize I (tensor0 ds) = 0⇩m (dim_row (matricize I (tensor0 ds))) (dim_col (matricize I (tensor0 ds)))"
proof (rule eq_matI)
  show "dim_row (matricize I (tensor0 ds)) = dim_row (0⇩m (dim_row (matricize I (tensor0 ds))) (dim_col (matricize I (tensor0 ds))))"
    unfolding zero_mat_def dim_row_mat by (simp add: dims_matricize(1))
  show "dim_col (matricize I (tensor0 ds)) = dim_col (0⇩m (dim_row (matricize I (tensor0 ds))) (dim_col (matricize I (tensor0 ds))))"
    unfolding zero_mat_def dim_row_mat by (simp add: dims_matricize(2))
  fix i j assume ij_le1: "i < dim_row (0⇩m (dim_row (matricize I (tensor0 ds))) (dim_col (matricize I (tensor0 ds))))"
                 "j < dim_col (0⇩m (dim_row (matricize I (tensor0 ds))) (dim_col (matricize I (tensor0 ds))))"
  then have ij_le2:"i < dim_row (matricize I (tensor0 ds))" "j < dim_col (matricize I (tensor0 ds))"
    unfolding zero_mat_def dim_row_mat by (simp_all add: dims_matricize)
  (* every entry is a lookup of tensor0 at a valid woven index, hence zero *)
  show "matricize I (tensor0 ds) $$ (i, j) = 0⇩m (dim_row (matricize I (tensor0 ds))) (dim_col (matricize I (tensor0 ds))) $$ (i, j)"
    unfolding zero_mat_def  index_mat[OF ij_le2] unfolding matricize_def index_mat[OF ij_le2[unfolded dims_matricize]]
    by (simp, metis lookup_tensor0 digit_encode_valid_index dims_matricize(1) dims_matricize(2) dims_tensor0
    ij_le2(1) ij_le2(2) valid_index_weave(1))
qed

end