(* Theory DL_Deep_Model_Poly *)

section ‹Polynomials representing the Deep Network Model›

theory DL_Deep_Model_Poly
imports DL_Deep_Model Polynomials.More_MPoly_Type Jordan_Normal_Form.Determinant
begin

(* Determinants preserve the polyfun property: if A x is always an n-by-n
   matrix and each of its entries is polyfun N as a function of x, then
   the determinant of A x is polyfun N as well. *)
lemma polyfun_det:
assumes "⋀x. (A x) ∈ carrier_mat n n"
assumes "⋀x i j. i<n ⟹ j<n ⟹ polyfun N (λx. (A x) \$\$ (i,j))"
shows "polyfun N (λx. det (A x))"
proof -
(* Step 1: for one fixed permutation p of the index range, the signed
   product of the selected entries is polyfun. *)
{
fix p assume "p∈ {p. p permutes {0..<n}}"
then have "p permutes {0..<n}" by auto
(* p maps indices below n to indices below n, so the entry assumption applies. *)
then have "⋀x. x < n ⟹ p x < n" using permutes_in_image by auto
then have "polyfun N (λx. ∏i = 0..<n. A x \$\$ (i, p i))"
using polyfun_Prod[of "{0..<n}" N "λi x. A x \$\$ (i, p i)"] assms by simp
(* The sign of p does not depend on x; it is handled via polyfun_const and polyfun_mult. *)
then have "polyfun N (λx. signof p * (∏i = 0..<n. A x \$\$ (i, p i)))" using polyfun_const polyfun_mult by blast
}
(* Step 2: the set of permutations is finite, so polyfun_Sum applies to the
   expansion of det given by det_def'. *)
moreover have "finite {i. i permutes {0..<n}}" by (simp add: finite_permutations)
ultimately show ?thesis  unfolding det_def'[OF assms(1)]
using polyfun_Sum[OF ‹finite {i. i permutes {0..<n}}›, of N "λp x. signof p * (∏i = 0..<n. A x \$\$ (i, p i))"]
by blast
qed

(* Entry (i,j) of the m-by-n matrix extracted from the shifted function
   (λi. f (i + a)) is polyfun in f over the index set {..<a + (m * n + c)}.
   The entry is first rewritten with index_extract_matrix, then the goal is
   closed by rule polyfun_single; the remaining side condition is discharged
   by simp using two_digit_le. The summand c leaves room for additional
   indices beyond the m * n used by the matrix. *)
lemma polyfun_extract_matrix:
assumes "i<m" "j<n"
shows "polyfun {..<a + (m * n + c)} (λf. extract_matrix (λi. f (i + a)) m n \$\$ (i,j))"
unfolding index_extract_matrix[OF assms] apply (rule polyfun_single) using two_digit_le[OF assms] by simp

(* Matrix-vector products preserve the polyfun property: if all entries of
   the vector v x (dimension n) and of the matrix A x (dimension m by n) are
   polyfun N in x, then so is component j (with j < m) of A x *⇩v v x. *)
lemma polyfun_mult_mat_vec:
assumes "⋀x. v x ∈ carrier_vec n"
assumes "⋀j. j<n ⟹ polyfun N (λx. v x \$ j)"
assumes "⋀x. A x ∈ carrier_mat m n"
assumes "⋀i j. i<m ⟹ j<n ⟹ polyfun N (λx. A x \$\$ (i,j))"
assumes "j < m"
shows "polyfun N (λx. ((A x) *⇩v (v x)) \$ j)"
proof -
(* Dimension facts needed to unfold the product and the scalar product below. *)
have "⋀x. j < dim_row (A x)" using ‹j < m› assms(3) carrier_matD(1) by force
have "⋀x. n = dim_vec (v x)" using assms(1) carrier_vecD by fastforce
(* Each summand row (A x) j \$ i * v x \$ i of the scalar product is polyfun. *)
{
fix i assume "i ∈ {0..<n}"
then have "i < n" by auto
{
fix x
have "i < dim_vec (v x)" using assms(1) carrier_vecD ‹i<n› by fastforce
have "j < dim_row (A x)" using ‹j < m› assms(3) carrier_matD(1) by force
have "dim_col (A x) = dim_vec (v x)" by (metis assms(1) assms(3) carrier_matD(2) carrier_vecD)
(* Row access coincides with direct matrix indexing for in-range indices. *)
then have "row (A x) j \$ i = A x \$\$ (j,i)" "i<n" using ‹j < dim_row (A x)› ‹i<n› by (simp_all add: ‹i < dim_vec (v x)›)
}
then have "polyfun N (λx. row (A x) j \$ i * v x \$ i)"
using polyfun_mult assms(4)[OF ‹j < m›] assms(2) by fastforce
}
(* Component j of the product is the scalar product of row j with v x,
   i.e. a finite sum over {0..<n}; polyfun_Sum finishes the proof. *)
then show ?thesis unfolding index_mult_mat_vec[OF ‹⋀x. j < dim_row (A x)›] scalar_prod_def
using polyfun_Sum[of "{0..<n}" N "λi x. row (A x) j \$ i * v x \$ i"] finite_atLeastLessThan[of 0 n] ‹⋀x. n = dim_vec (v x)›
by simp
qed

(* Output component j of a network, evaluated with weights read from f at an
   offset a, is polyfun in f over the index set {..<a + count_weights s m}.
   The offset a has been inserted here to make the induction work: in the
   Conv and Pool cases the sub-networks read their weights at shifted
   positions, so the induction hypothesis is needed for arbitrary a. *)
lemma polyfun_evaluate_net_plus_a:
assumes "map dim_vec inputs = input_sizes m"
assumes "valid_net m"
assumes "j < output_size m"
shows "polyfun {..<a + count_weights s m} (λf. evaluate_net (insert_weights s m (λi. f (i + a))) inputs \$ j)"
using assms proof (induction m arbitrary:inputs j a)
(* Input layer: after unfolding the definitions, polyfun_const closes the goal. *)
case (Input)
then show ?case unfolding insert_weights.simps evaluate_net.simps using polyfun_const by metis
next
(* Conv layer: the output is a matrix-vector product of the extracted weight
   matrix with the output of the sub-network, so polyfun_mult_mat_vec applies.
   The show statements below discharge its premises. *)
case (Conv x m)
then obtain x1 x2 where "x=(x1,x2)" by fastforce
show ?case unfolding ‹x=(x1,x2)› insert_weights.simps evaluate_net.simps drop_map unfolding list_of_vec_index
proof (rule polyfun_mult_mat_vec)
(* Premise: the sub-network output is a vector of dimension output_size m. *)
{
fix f
have 1:"valid_net' (insert_weights s m (λi. f (i + x1 * x2)))"
using ‹valid_net (Conv x m)› valid_net.simps by (metis
convnet.distinct(1) convnet.distinct(5) convnet.inject(2) remove_insert_weights)
(* Inserting weights does not change the input sizes, so the premise on the
   inputs carries over. This step previously had no terminal proof method. *)
have 2:"map dim_vec inputs = input_sizes (insert_weights s m (λi. f (i + x1 * x2)))"
using input_sizes_remove_weights remove_insert_weights
by (simp add: Conv.prems(1))
have "dim_vec (evaluate_net (insert_weights s m (λi. f (i + x1 * x2))) inputs) = output_size m"
using output_size_correct[OF 1 2] using remove_insert_weights by auto
then show "evaluate_net (insert_weights s m (λi. f (i + x1 * x2))) inputs ∈ carrier_vec (output_size m)"
using carrier_vec_def by (metis (full_types) mem_Collect_eq)
}

(* Premise: every component of the sub-network output is polyfun; this is the
   induction hypothesis instantiated with the offset x1 * x2 + a. *)
have "map dim_vec inputs = input_sizes m" by (simp add: Conv.prems(1))
have "valid_net m" using Conv.prems(2) valid_net.cases by fastforce
show "⋀j. j < output_size m ⟹  polyfun {..<a + count_weights s (Conv (x1, x2) m)}
(λf. evaluate_net (insert_weights s m (λi. f (i + x1 * x2 + a))) inputs \$ j)"
unfolding vec_of_list_index count_weights.simps
using Conv(1)[OF ‹map dim_vec inputs = input_sizes m› ‹valid_net m›, of _ "x1 * x2 + a"]
by blast

(* Premise: the extracted weight matrix has dimension x1 by output_size m. *)
have "output_size m = x2" using Conv.prems(2) ‹x = (x1, x2)› valid_net.cases by fastforce
show "⋀f. extract_matrix (λi. f (i + a)) x1 x2 ∈ carrier_mat x1 (output_size m)" unfolding ‹output_size m = x2› using dim_extract_matrix
using carrier_matI by (metis (no_types, lifting))

(* Premise: every entry of the extracted weight matrix is polyfun. *)
show "⋀i j. i < x1 ⟹ j < output_size m ⟹ polyfun {..<a + count_weights s (Conv (x1, x2) m)} (λf. extract_matrix (λi. f (i + a)) x1 x2 \$\$ (i, j))"
unfolding ‹output_size m = x2› count_weights.simps using polyfun_extract_matrix[of _ x1 _ x2 a "count_weights s m"] by blast

(* Premise: the requested output component is in range. *)
show "j < x1" using Conv.prems(3) ‹x = (x1, x2)› by auto
qed
next
(* Pool layer: the output is the componentwise product of the outputs of the
   two sub-networks, so polyfun_mult combines both induction hypotheses. *)
case (Pool m1 m2 inputs j a)
(* A2/B2: the input list splits into the inputs of m1 (take) and m2 (drop). *)
have A2:"⋀f. map dim_vec (take (length (input_sizes (insert_weights s m1 (λi. f (i + a))))) inputs) = input_sizes m1"
by (metis Pool.prems(1)  append_eq_conv_conj input_sizes.simps(3) input_sizes_remove_weights remove_insert_weights take_map)
have B2:"⋀f. map dim_vec (drop (length (input_sizes (insert_weights s m1 (λi. f (i + a))))) inputs) = input_sizes m2"
using Pool.prems(1) append_eq_conv_conj input_sizes.simps(3) input_sizes_remove_weights remove_insert_weights by (metis drop_map)
(* A3/B3: both sub-networks are valid. A4/B4: j is in range for both outputs. *)
have A3:"valid_net m1" and B3:"valid_net m2" using ‹valid_net (Pool m1 m2)› valid_net.simps by blast+
have "output_size (Pool m1 m2) = output_size m2" unfolding output_size.simps
using ‹valid_net (Pool m1 m2)› "valid_net.cases" by fastforce
then have A4:"j < output_size m1" and B4:"j < output_size m2" using ‹j < output_size (Pool m1 m2)› by simp_all

(* Abbreviations for the two sub-network evaluations. With shared weights
   (s true) the second sub-network reads from offset a, otherwise from the
   offset following the weights of m1. *)
let ?net1 = "λf. evaluate_net (insert_weights s m1 (λi. f (i + a)))
(take (length (input_sizes (insert_weights s m1 (λi. f (i + a))))) inputs)"
let ?net2 = "λf. evaluate_net (insert_weights s m2 (if s then λi. f (i + a) else (λi. f (i + count_weights s m1 + a))))
(drop (length (input_sizes (insert_weights s m1 (λi. f (i + a))))) inputs)"
have length1: "⋀f. output_size m1 = dim_vec (?net1 f)"
by (metis A2 A3 input_sizes_remove_weights output_size_correct remove_insert_weights)
then have jlength1:"⋀f. j < dim_vec (?net1 f)" using A4 by metis
have length2: "⋀f. output_size m2 = dim_vec (?net2 f)"
by (metis B2 B3 input_sizes_remove_weights output_size_correct remove_insert_weights)
then have jlength2:"⋀f. j < dim_vec (?net2 f)" using B4 by metis
(* cong1/cong2: the take/drop lengths do not depend on the weight function,
   which lets the induction hypotheses be rewritten into the ?net1/?net2 form. *)
have cong1:"⋀xf. (λf. evaluate_net (insert_weights s m1 (λi. f (i + a)))
(take (length (input_sizes (insert_weights s m1 (λi. xf (i + a))))) inputs) \$ j)
= (λf. ?net1 f \$ j)"
using input_sizes_remove_weights remove_insert_weights by auto
have cong2:"⋀xf. (λf. evaluate_net (insert_weights s m2 (λi. f (i + (a + (if s then 0 else count_weights s m1)))))
(drop (length (input_sizes (insert_weights s m1 (λi. xf (i + a))))) inputs) \$ j)
= (λf. ?net2 f \$ j)"
using input_sizes_remove_weights remove_insert_weights by auto

(* Each induction hypothesis yields polyfun on a smaller index set;
   polyfun_subset enlarges it to the index set of the Pool network. *)
show ?case unfolding insert_weights.simps evaluate_net.simps  count_weights.simps
unfolding  index_component_mult[OF jlength1 jlength2]
apply (rule polyfun_mult)
using Pool.IH(1)[OF A2 A3 A4, of a, unfolded cong1]
apply (simp add:polyfun_subset[of "{..<a + count_weights s m1}" "{..<a + (if s then max (count_weights s m1) (count_weights s m2) else count_weights s m1 + count_weights s m2)}"])
using Pool.IH(2)[OF B2 B3 B4, of "a + (if s then 0 else count_weights s m1)", unfolded cong2 semigroup_add_class.add.assoc[of a]]
by (simp add:polyfun_subset[of "{..<a + ((if s then 0 else count_weights s m1) + count_weights s m2)}" "{..<a + (if s then max (count_weights s m1) (count_weights s m2) else count_weights s m1 + count_weights s m2)}"])
qed

(* Output component j of a valid network m, evaluated on fixed inputs with
   weights taken from f, is polyfun in f over {..<count_weights s m}.
   This is the special case a = 0 of polyfun_evaluate_net_plus_a. *)
lemma polyfun_evaluate_net:
assumes "map dim_vec inputs = input_sizes m"
assumes "valid_net m"
assumes "j < output_size m"
shows "polyfun {..<count_weights s m} (λf. evaluate_net (insert_weights s m f) inputs \$ j)"
using polyfun_evaluate_net_plus_a[where a=0, OF assms] by simp

(* Every entry (at a valid index is) of the j-th tensor represented by a valid
   network m is polyfun in the weight function f over {..<count_weights s m}.
   The entry is rewritten via lookup_tensors_from_net into an evaluation of
   the network, to which polyfun_evaluate_net applies. *)
lemma polyfun_tensors_from_net:
assumes "valid_net m"
assumes "is ⊲ input_sizes m"
assumes "j < output_size m"
shows "polyfun {..<count_weights s m} (λf. Tensor.lookup (tensors_from_net (insert_weights s m f) \$ j) is)"
proof -
(* Facts 1-3 are the premises of lookup_tensors_from_net, each transferred
   from m to the network with inserted weights. *)
have 1:"⋀f. valid_net' (insert_weights s m f)" by (simp add: assms(1) remove_insert_weights)
have input_sizes:"⋀f. input_sizes (insert_weights s m f) = input_sizes m"
unfolding input_sizes_remove_weights by (simp add: remove_insert_weights)
have 2:"⋀f. is ⊲ input_sizes (insert_weights s m f)"
unfolding input_sizes using assms(2) by blast
(* Proved like fact 1: removing the inserted weights gives back m. This
   statement previously had no proof attached. *)
have 3:"⋀f. j < output_size' (insert_weights s m f)"
by (simp add: assms(3) remove_insert_weights)
(* base_input depends on the network only through its input sizes, so it is
   independent of the chosen weight function. *)
have "⋀f1 f2. base_input (insert_weights s m f1) is = base_input (insert_weights s m f2) is"
unfolding base_input_def by (simp add: input_sizes)
then have "⋀xf. (λf. evaluate_net (insert_weights s m f) (base_input (insert_weights s m xf) is) \$ j)
= (λf. evaluate_net (insert_weights s m f) (base_input (insert_weights s m f) is) \$ j)"
by metis
then show ?thesis unfolding lookup_tensors_from_net[OF 1 2 3]
using polyfun_evaluate_net[OF base_input_length[OF 2, unfolded input_sizes, symmetric] assms(1) assms(3), of s]
by simp
qed

(* Matricization preserves the polyfun property: if all tensors T x share the
   dimensions ds and every lookup at a valid index is polyfun N in x, then
   every in-range entry (i,j) of matricize I (T x) is polyfun N in x. *)
lemma polyfun_matricize:
assumes "⋀x. dims (T x) = ds"
assumes "⋀is. is ⊲ ds ⟹ polyfun N (λx. Tensor.lookup (T x) is)"
assumes "⋀x. dim_row (matricize I (T x)) = nr"
assumes "⋀x. dim_col (matricize I (T x)) = nc"
assumes "i < nr"
assumes "j < nc"
shows "polyfun N (λx. matricize I (T x) \$\$ (i,j))"
proof -
(* The tensor index belonging to matrix position (i,j): the digit encodings of
   i and j are woven together according to I. The bound variable x is unused. *)
let ?weave = "λ x. (weave I
(digit_encode (nths ds I ) i)
(digit_encode (nths ds (-I )) j))"
(* Fact 1: the matrix entry is the tensor lookup at the woven index. *)
have 1:"⋀x. matricize I (T x) \$\$ (i,j) = Tensor.lookup (T x) (?weave x)" unfolding matricize_def
by (metis (no_types, lifting) assms(1) assms(3) assms(4) assms(5) assms(6) case_prod_conv
dim_col_mat(1) dim_row_mat(1) index_mat(1) matricize_def)
(* The woven index is valid for ds, so the lookup assumption applies to it. *)
have "⋀x. ?weave x ⊲ ds"
using valid_index_weave(1) assms(2) digit_encode_valid_index dim_row_mat(1) matricize_def
using assms digit_encode_valid_index matricize_def by (metis dim_col_mat(1))
then have "polyfun N (λx. Tensor.lookup (T x) (?weave x))" using assms(2) by simp
then show ?thesis unfolding 1 using assms(1) by blast
qed

(* Auxiliary fact on natural numbers: a is not below b exactly when b ≤ a.
   The lemma carries no name, so it cannot be referenced by name later. *)
lemma "(¬ (a::nat) < b) = (a ≥ b)"
by (metis not_le)

(* Taking submatrices preserves the polyfun property: if A x is always an
   m-by-n matrix with polyfun entries, then each in-range entry (i,j) of
   submatrix (A x) I J is polyfun as well. The index sets I and J are
   required to be infinite. *)
lemma polyfun_submatrix:
assumes "⋀x. (A x) ∈ carrier_mat m n"
assumes "⋀x i j. i<m ⟹ j<n ⟹ polyfun N (λx. (A x) \$\$ (i,j))"
assumes "i < card {i. i < m ∧ i ∈ I}"
assumes "j < card {j. j < n ∧ j ∈ J}"
assumes "infinite I" "infinite J"
shows "polyfun N (λx. (submatrix (A x) I J) \$\$ (i,j))"
proof -
(* Fact 1: a submatrix entry is the entry of A x at the picked row and column. *)
have 1:"⋀x. (submatrix (A x) I J) \$\$ (i,j) = (A x) \$\$ (pick I i, pick J j)"
using submatrix_index by (metis (no_types, lifting) Collect_cong assms(1) assms(3) assms(4) carrier_matD(1) carrier_matD(2))
(* The picked indices lie within the matrix dimensions; this step uses
   card_le_pick_inf, which is where the infiniteness assumptions enter. *)
have "pick I i < m"  "pick J j < n" using card_le_pick_inf[OF ‹infinite I›] card_le_pick_inf[OF ‹infinite J›]
‹i < card {i. i < m ∧ i ∈ I}›[unfolded set_le_in] ‹j < card {j. j < n ∧ j ∈ J}›[unfolded set_le_in] not_less by metis+
then show ?thesis unfolding 1 by (simp add: assms(2))
qed

(* The results below are stated inside the locale deep_model_correct_params_y.
   They specialise the general polyfun lemmas to the deep model and end with
   the statement that the determinant of the witness submatrix is polyfun in
   the weights. Names such as rs, r, y, A, A', N_half, rows_with_1,
   shared_weights and weight_space_dim are not defined in this theory text;
   they come from the locale or from imported theories. *)
context deep_model_correct_params_y
begin

(* The square submatrix of A' f obtained by selecting the index set
   rows_with_1 for both the rows and the columns. *)
definition witness_submatrix where
"witness_submatrix f = submatrix (A' f) rows_with_1 rows_with_1"

(* Every entry (at a valid index) of the y-th tensor of the deep model is
   polyfun in the weight function f; instance of polyfun_tensors_from_net. *)
lemma polyfun_tensor_deep_model:
assumes "is ⊲ input_sizes (deep_model_l rs)"
shows "polyfun {..<weight_space_dim}
(λf. Tensor.lookup (tensors_from_net (insert_weights shared_weights (deep_model_l rs) f) \$ y) is)"
proof -
have 1:"⋀f. remove_weights (insert_weights shared_weights (deep_model_l rs) f) = deep_model_l rs"
using remove_insert_weights by metis
(* y is a valid output index of the deep model. *)
then have "y < output_size ( deep_model_l rs)" using valid_deep_model y_valid length_output_deep_model by force
have 0:"{..<weight_space_dim} = set [0..<weight_space_dim]" by auto
then show ?thesis unfolding weight_space_dim_def using polyfun_tensors_from_net assms(1) valid_deep_model
‹y < output_size ( deep_model_l rs )› by metis
qed

(* The deep model has 2 * N_half inputs, each of size last rs. The fact
   input_sizes_deep_model used in the proof is presumably an earlier fact of
   the same name from an imported theory - confirm. *)
lemma input_sizes_deep_model: "input_sizes (deep_model_l rs) = replicate (2 * N_half) (last rs)"
unfolding N_half_def using input_sizes_deep_model deep
by (metis (no_types, lifting) Nitpick.size_list_simp(2) One_nat_def Suc_1 Suc_le_lessD diff_Suc_Suc length_tl less_imp_le_nat list.size(3) not_less_eq numeral_3_eq_3 power_eq_if)

(* Every in-range entry (i,j) of the matrix A' f is polyfun in the weight
   function f; obtained from polyfun_matricize and the tensor result above. *)
lemma polyfun_matrix_deep_model:
assumes "i<(last rs) ^ N_half"
assumes "j<(last rs) ^ N_half"
shows "polyfun {..<weight_space_dim} (λf. A' f \$\$ (i,j))"
proof -
have 0:"y < output_size ( deep_model_l rs )" using valid_deep_model y_valid length_output_deep_model by force
have 1:"⋀f. remove_weights (insert_weights shared_weights (deep_model_l rs) f) = deep_model_l rs"
using remove_insert_weights by metis
(* Fact 2: lookups in the tensor A x are polyfun at every valid index. *)
have 2:"(⋀f is. is ⊲ replicate (2 * N_half) (last rs) ⟹
polyfun {..<weight_space_dim} (λx. Tensor.lookup (A x) is))"
unfolding A_def using polyfun_tensor_deep_model[unfolded input_sizes_deep_model] 0 by blast
show ?thesis
unfolding A'_def A_def apply (rule polyfun_matricize)
using dims_tensor_deep_model[OF 1] 2[unfolded A_def]
using dims_A'_pow[unfolded A'_def A_def] ‹i<(last rs) ^ N_half› ‹j<(last rs) ^ N_half›
by auto
qed

(* Every in-range entry (i,j) of the witness submatrix is polyfun in the
   weight function f; instance of polyfun_submatrix, whose premises are
   discharged by the show statements below. *)
lemma polyfun_submatrix_deep_model:
assumes "i < r ^ N_half"
assumes "j < r ^ N_half"
shows "polyfun {..<weight_space_dim} (λf. witness_submatrix f \$\$ (i,j))"
unfolding witness_submatrix_def
proof (rule polyfun_submatrix)
have 1:"⋀f. remove_weights (insert_weights shared_weights (deep_model_l rs) f) = deep_model_l rs"
using remove_insert_weights by metis
(* A' f is a square matrix of dimension (last rs) ^ N_half. *)
show "⋀f. A' f ∈ carrier_mat ((last rs) ^ N_half) ((last rs) ^ N_half)"
using "1" dims_A'_pow using weight_space_dim_def by auto
(* The entries of A' f are polyfun by polyfun_matrix_deep_model. *)
show "⋀f i j. i < last rs ^ N_half ⟹ j < last rs ^ N_half ⟹
polyfun {..<weight_space_dim} (λf. A' f \$\$ (i, j))"
using polyfun_matrix_deep_model weight_space_dim_def by force
(* The indices i and j are in range for the submatrix; uses card_rows_with_1. *)
show "i < card {i. i < last rs ^ N_half ∧ i ∈ rows_with_1}"
using assms(1) card_rows_with_1 dims_Aw'_pow set_le_in by metis
show "j < card {i. i < last rs ^ N_half ∧ i ∈ rows_with_1}"
using assms(2) card_rows_with_1 dims_Aw'_pow set_le_in by metis
show "infinite rows_with_1" "infinite rows_with_1" by (simp_all add: infinite_rows_with_1)
qed

(* The determinant of the witness submatrix is polyfun in the weight
   function f; instance of polyfun_det for a square matrix of dimension
   r ^ N_half. *)
lemma polyfun_det_deep_model:
shows "polyfun {..<weight_space_dim} (λf. det (witness_submatrix f))"
proof (rule polyfun_det)
fix f
have "remove_weights (insert_weights shared_weights (deep_model_l rs) f) = deep_model_l rs"
using remove_insert_weights by metis

(* The witness submatrix is square of dimension r ^ N_half. *)
show "witness_submatrix f ∈ carrier_mat (r ^ N_half) (r ^ N_half)"
unfolding witness_submatrix_def apply (rule carrier_matI) unfolding dim_submatrix[unfolded set_le_in]
unfolding dims_A'_pow[unfolded weight_space_dim_def] using card_rows_with_1 dims_Aw'_pow by simp_all
(* Its entries are polyfun by polyfun_submatrix_deep_model. *)
show "⋀i j. i < r ^ N_half ⟹ j < r ^ N_half ⟹ polyfun {..<weight_space_dim} (λf. witness_submatrix f \$\$ (i, j))"
using polyfun_submatrix_deep_model by blast
qed

end

end
