(* Theory DL_Deep_Model *)
section ‹Deep Network Model›
theory DL_Deep_Model
imports DL_Network Tensor_Matricization Jordan_Normal_Form.DL_Submatrix DL_Concrete_Matrices
DL_Missing_Finite_Set Jordan_Normal_Form.DL_Missing_Sublist Jordan_Normal_Form.Determinant
begin
hide_const(open) Polynomial.order
hide_const (open) Matrix.unit_vec
(* Mutually recursive construction of the deep network shape.
   deep_model' Y rs: an Input Y when rs is empty; for rs = r # rs' a Pool node
   whose two branches are the same sub-network deep_model Y r rs'.
   deep_model Y r rs: a Conv node labelled with the pair (Y, r) placed on top of
   deep_model' r rs.
   The order of the equations fixes the numbering of the generated simps rules,
   which later proofs refer to by index. *)
fun deep_model and deep_model' where
"deep_model' Y [] = Input Y" |
"deep_model' Y (r # rs) = Pool (deep_model Y r rs) (deep_model Y r rs)" |
"deep_model Y r rs = Conv (Y,r) (deep_model' r rs)"
(* List-based shorthand: the head of rs is the first argument of deep_model',
   the remaining elements are its list argument. *)
abbreviation deep_model'_l where
"deep_model'_l rs ≡ deep_model' (rs!0) (tl rs)"
(* List-based shorthand: elements 0 and 1 of rs are the first two arguments of
   deep_model, everything after them is its list argument. *)
abbreviation deep_model_l where
"deep_model_l rs ≡ deep_model (rs!0) (rs!1) (tl (tl rs))"
(* Every network built by deep_model satisfies valid_net.
   Induction over rs with Y and r generalised; both cases are discharged from the
   introduction rules of valid_net. *)
lemma valid_deep_model: "valid_net (deep_model Y r rs)"
apply (induction rs arbitrary: Y r)
apply (simp add: valid_net.intros(1) valid_net.intros(2))
using valid_net.intros(2) valid_net.intros(3) by auto
(* Every network built by deep_model' satisfies valid_net.
   Induction over rs with r generalised. *)
lemma valid_deep_model': "valid_net (deep_model' r rs)"
apply (induction rs arbitrary: r)
apply (simp add: valid_net.intros(1))
by (metis deep_model'.elims deep_model'.simps(2) deep_model.elims output_size.simps valid_net.simps)
(* For a non-empty list rs, the input sizes of deep_model'_l rs consist of
   2^(length rs - 1) copies of last rs.
   The proof is by induction on butlast rs (with rs generalised). *)
lemma input_sizes_deep_model':
assumes "length rs ≥ 1"
shows "input_sizes (deep_model'_l rs) = replicate (2^(length rs - 1)) (last rs)"
using assms proof (induction "butlast rs" arbitrary:rs)
case Nil
(* Base case: butlast rs is empty, so rs is a one-element list and the network is a single Input. *)
then have "rs = [rs!0]"
by (metis One_nat_def diff_diff_cancel diff_zero length_0_conv length_Suc_conv length_butlast nth_Cons_0)
then have "input_sizes (deep_model'_l rs) = [last rs]"
by (metis deep_model'.simps(1) input_sizes.simps(1) last.simps list.sel(3))
then show "input_sizes (deep_model'_l rs) = replicate (2 ^ (length rs - 1)) (last rs)"
by (metis One_nat_def ‹[] = butlast rs› empty_replicate length_butlast list.size(3) power_0 replicate.simps(2))
next
case (Cons r rs' rs)
(* Induction hypothesis instantiated for tl rs; note that last (tl rs) is rewritten to last rs. *)
then have IH: "input_sizes (deep_model'_l (tl rs)) = replicate (2 ^ (length (tl rs) - 1)) (last rs)"
by (metis (no_types, lifting) One_nat_def butlast_tl diff_is_0_eq' last_tl length_Cons
length_butlast length_tl list.sel(3) list.size(3) nat_le_linear not_one_le_zero)
have "rs = r # (tl rs)" by (metis Cons.hyps(2) Cons.prems One_nat_def append_Cons append_butlast_last_id length_greater_0_conv less_le_trans list.sel(3) zero_less_Suc)
(* The top node is a Pool of two identical sub-networks ... *)
then have "deep_model'_l rs = Pool (deep_model_l rs) (deep_model_l rs)"
by (metis Cons.hyps(2) One_nat_def butlast.simps(2) deep_model'.elims list.sel(3) list.simps(3) nth_Cons_0 nth_Cons_Suc)
(* ... so its input sizes are the sub-network's input sizes appended to themselves. *)
then have "input_sizes (deep_model'_l rs) = input_sizes (deep_model_l rs) @ input_sizes (deep_model_l rs)"
using input_sizes.simps(3) by metis
(* The Conv node at the top of deep_model_l rs does not change the input sizes. *)
also have "... = input_sizes (deep_model'_l (tl rs)) @ input_sizes (deep_model'_l (tl rs))"
by (metis (no_types, lifting) Cons.hyps(2) One_nat_def deep_model.elims input_sizes.simps(2)
length_Cons length_butlast length_greater_0_conv length_tl list.sel(2) list.sel(3) list.size(3)
nth_tl one_neq_zero)
also have "... = replicate (2 ^ (length (tl rs) - 1)) (last rs) @ replicate (2 ^ (length (tl rs) - 1)) (last rs)"
using IH by auto
(* Appending two equal replicate lists doubles the count, i.e. raises the exponent by one. *)
also have "... = replicate (2 ^ (length rs - 1)) (last rs)"
using replicate_add[of "2 ^ (length (tl rs) - 1)" "2 ^ (length (tl rs) - 1)" "last rs"]
by (metis Cons.hyps(2) One_nat_def butlast_tl length_butlast list.sel(3) list.size(4) mult_2_right
power_add power_one_right)
finally show ?case by auto
qed
(* For a list rs with at least two elements, the input sizes of deep_model_l rs
   consist of 2^(length rs - 2) copies of last rs.
   Reduces to input_sizes_deep_model' applied to tl rs. *)
lemma input_sizes_deep_model:
assumes "length rs ≥ 2"
shows "input_sizes (deep_model_l rs) = replicate (2^(length rs - 2)) (last rs)"
proof -
(* The outer Conv node leaves the input sizes unchanged. *)
have "input_sizes (deep_model_l rs) = input_sizes (deep_model'_l (tl rs))"
by (metis One_nat_def Suc_1 assms hd_Cons_tl deep_model.elims input_sizes.simps(2) length_Cons
length_greater_0_conv lessI linorder_not_le list.size(3) not_numeral_le_zero nth_tl)
also have "... = replicate (2^(length rs - 2)) (last rs)" using input_sizes_deep_model'
by (metis (no_types, lifting) One_nat_def Suc_1 Suc_eq_plus1 assms diff_diff_left hd_Cons_tl
last_tl length_Cons length_tl linorder_not_le list.size(3) not_less_eq not_numeral_le_zero
numeral_le_one_iff semiring_norm(69))
finally show ?thesis by auto
qed
(* Evaluation of a Conv layer carrying id_matrix nr (output_size' m) on top of m:
   for j < nr, component j of the result is component j of the output of m when
   j < output_size' m, and 0 otherwise.
   Requires m to be valid and the input vector dimensions to match input_sizes m. *)
lemma evaluate_net_Conv_id:
assumes "valid_net' m"
and "input_sizes m = map dim_vec input"
and "j<nr"
shows "evaluate_net (Conv (id_matrix nr (output_size' m)) m) input $ j
= (if j<output_size' m then evaluate_net m input $ j else 0)"
unfolding evaluate_net.simps output_size_correct[OF assms(1) assms(2)[symmetric]]
using mult_id_matrix[OF ‹j<nr›, of "evaluate_net m input", unfolded dim_vec_of_list]
by metis
(* Tensor counterpart of evaluate_net_Conv_id: for i < nr, the i-th tensor of the
   network with an id_matrix Conv layer on top of m is the i-th tensor of m when
   i < output_size' m, and the zero tensor of dimensions input_sizes m otherwise.
   Proved by tensor_lookup_eqI: equal dimensions, then equal entries. *)
lemma tensors_from_net_Conv_id:
assumes "valid_net' m"
and "i<nr"
shows "tensors_from_net (Conv (id_matrix nr (output_size' m)) m) $ i
= (if i<output_size' m then tensors_from_net m $ i else tensor0 (input_sizes m))"
(is "?a $ i = ?b")
proof (rule tensor_lookup_eqI)
(* Part 1: both sides have dimensions input_sizes m. *)
have "Tensor.dims (?a $ i) = input_sizes m" by (metis assms(1) assms(2) dims_tensors_from_net
id_matrix_dim(1) id_matrix_dim(2) input_sizes.simps(2) output_size.simps(2)
output_size_correct_tensors remove_weights.simps(2) valid_net.intros(2) vec_setI)
moreover have "Tensor.dims (?b) = input_sizes m" using dims_tensors_from_net
output_size_correct_tensors[OF assms(1)] dims_tensor0 by (simp add: vec_setI)
ultimately show "Tensor.dims (?a $ i) = Tensor.dims (?b)" by auto
(* Part 2: entries agree for every valid index is. Convm names the extended network. *)
define Convm where "Convm = Conv (id_matrix nr (output_size' m)) m"
fix "is"
assume "is ⊲ Tensor.dims (?a$i)"
then have "is ⊲ input_sizes m" using ‹Tensor.dims (?a$i) = input_sizes m› by auto
have "valid_net' Convm" by (simp add: assms id_matrix_dim valid_net.intros(2) Convm_def)
have "base_input m is = base_input Convm is" by (simp add: Convm_def base_input_def)
have "i < output_size' Convm" unfolding Convm_def remove_weights.simps output_size.simps
id_matrix_dim using assms by metis
have "is ⊲ input_sizes (Conv (id_matrix nr (output_size' m)) m)"
by (metis ‹is ⊲ input_sizes m› input_sizes.simps(2))
(* A tensor entry is the network output on the corresponding base input. *)
then have f1: "lookup (tensors_from_net (Conv (id_matrix nr (output_size' m)) m) $ i) is = evaluate_net (Conv (id_matrix nr (output_size' m)) m) (base_input (Conv (id_matrix nr (output_size' m)) m) is) $ i"
using Convm_def ‹i < output_size' Convm› ‹valid_net' Convm› lookup_tensors_from_net by blast
have "lookup (tensor0 (input_sizes m)) is = (0::real)"
by (meson ‹is ⊲ input_sizes m› lookup_tensor0)
then show "Tensor.lookup (?a $ i) is = Tensor.lookup ?b is"
using Convm_def ‹base_input m is = base_input Convm is› ‹is ⊲ input_sizes m› assms(1) assms(2)
base_input_length evaluate_net_Conv_id f1 lookup_tensors_from_net by auto
qed
(* Evaluation of a Conv layer carrying copy_first_matrix nr (output_size' m) on
   top of m: for j < nr, component j of the result equals component 0 of the
   output of m. Requires m to be valid with a non-zero output size and the input
   vector dimensions to match input_sizes m. *)
lemma evaluate_net_Conv_copy_first:
assumes "valid_net' m"
and "input_sizes m = map dim_vec input"
and "j<nr"
and "output_size' m > 0"
shows "evaluate_net (Conv (copy_first_matrix nr (output_size' m)) m) input $ j
= evaluate_net m input $ 0"
unfolding evaluate_net.simps output_size_correct[OF assms(1) assms(2)[symmetric]]
using mult_copy_first_matrix[OF ‹j<nr›, of "evaluate_net m input", unfolded dim_vec_of_list]
assms(3) copy_first_matrix_dim(1) by (metis ‹output_size' m = dim_vec (evaluate_net m input)› assms(4))
(* Tensor counterpart of evaluate_net_Conv_copy_first: for i < nr, the i-th tensor
   of the network with a copy_first_matrix Conv layer on top of m equals the
   0-th tensor of m, provided output_size' m > 0.
   Proved by tensor_lookup_eqI: equal dimensions, then equal entries. *)
lemma tensors_from_net_Conv_copy_first:
assumes "valid_net' m"
and "i<nr"
and "output_size' m > 0"
shows "tensors_from_net (Conv (copy_first_matrix nr (output_size' m)) m) $ i = tensors_from_net m $ 0"
(is "?a $ i = ?b")
proof (rule tensor_lookup_eqI)
(* Part 1: both sides have dimensions input_sizes m. *)
have "Tensor.dims (?a$i) = input_sizes m"
by (metis assms(1) assms(2) copy_first_matrix_dim(1) copy_first_matrix_dim(2) dims_tensors_from_net
input_sizes.simps(2) output_size.simps(2) output_size_correct_tensors remove_weights.simps(2)
valid_net.intros(2) vec_setI)
moreover have "Tensor.dims (?b) = input_sizes m" using dims_tensors_from_net
output_size_correct_tensors[OF assms(1)] using assms(3) by (simp add: vec_setI)
ultimately show "Tensor.dims (?a$i) = Tensor.dims (?b)" by auto
(* Part 2: entries agree for every valid index is. Convm names the extended network. *)
define Convm where "Convm = Conv (copy_first_matrix nr (output_size' m)) m"
fix "is"
assume "is ⊲ Tensor.dims (?a$i)"
then have "is ⊲ input_sizes m" using ‹Tensor.dims (?a$i) = input_sizes m› by auto
have "valid_net' Convm" by (simp add: assms copy_first_matrix_dim valid_net.intros(2) Convm_def)
have "base_input m is = base_input Convm is" by (simp add: Convm_def base_input_def)
have "i < output_size' Convm" unfolding Convm_def remove_weights.simps output_size.simps
copy_first_matrix_dim using assms by metis
show "Tensor.lookup (?a $ i) is = Tensor.lookup ?b is"
by (metis Convm_def ‹base_input m is = base_input Convm is› ‹i < output_size' Convm›
‹is ⊲ input_sizes m› ‹valid_net' Convm› assms(1) assms(2) assms(3) base_input_length
evaluate_net_Conv_copy_first input_sizes.simps(2) lookup_tensors_from_net)
qed
(* Evaluation of a Conv layer carrying all1_matrix nr (output_size' m) on top of
   m: for i < nr, component i of the result equals the sum of all components of
   the output of m. Requires m to be valid and the input vector dimensions to
   match input_sizes m. *)
lemma evaluate_net_Conv_all1:
assumes "valid_net' m"
and "input_sizes m = map dim_vec input"
and "i<nr"
shows "evaluate_net (Conv (all1_matrix nr (output_size' m)) m) input $ i
= Groups_List.sum_list (list_of_vec (evaluate_net m input))"
unfolding evaluate_net.simps output_size_correct[OF assms(1) assms(2)[symmetric]]
using mult_all1_matrix[OF ‹i<nr›, of "evaluate_net m input", unfolded dim_vec_of_list]
assms(3) all1_matrix_dim(1) by metis
(* Tensor counterpart of evaluate_net_Conv_all1: for i < nr, the i-th tensor of
   the network with an all1_matrix Conv layer on top of m equals the listsum
   (over dimensions input_sizes m) of all tensors of m.
   Proved by tensor_lookup_eqI: equal dimensions, then a chain of equalities on
   the entry at an arbitrary valid index is. *)
lemma tensors_from_net_Conv_all1:
assumes "valid_net' m"
and "i<nr"
shows "tensors_from_net (Conv (all1_matrix nr (output_size' m)) m) $ i
= listsum (input_sizes m) (list_of_vec (tensors_from_net m))"
(is "?a $ i = ?b")
proof (rule tensor_lookup_eqI)
(* Part 1: dimensions agree. *)
have "i < dim_vec ?a" by (metis assms all1_matrix_dim output_size.simps(2)
output_size_correct_tensors remove_weights.simps(2) valid_net.intros(2))
then show "Tensor.dims (?a $ i) = Tensor.dims (?b)"
using dims_tensors_from_net input_sizes.simps(2) listsum_dims
by (metis index_vec_of_list in_set_conv_nth length_list_of_vec vec_list vec_setI)
(* Part 2: entries agree. Convm names the extended network. *)
define Convm where "Convm = Conv (all1_matrix nr (output_size' m)) m"
fix "is" assume "is ⊲ Tensor.dims (?a $ i)"
then have "is ⊲ input_sizes m"
using ‹i < dim_vec ?a› dims_tensors_from_net input_sizes.simps(2) by (metis vec_setI)
then have "is ⊲ input_sizes Convm" by (simp add: Convm_def)
have "valid_net' Convm" by (simp add: Convm_def assms all1_matrix_dim valid_net.intros(2))
have "i< output_size' Convm" using Convm_def ‹i < dim_vec ?a› ‹valid_net' Convm›
output_size_correct_tensors by presburger
have "base_input Convm is = base_input m is" unfolding base_input_def Convm_def input_sizes.simps by metis
(* Step a: the entry is the output of Convm on the base input for is. *)
have "Tensor.lookup (?a $ i) is = evaluate_net Convm (base_input Convm is) $ i"
using lookup_tensors_from_net[OF ‹valid_net' Convm› ‹is ⊲ input_sizes Convm› ‹i< output_size' Convm›]
by (metis Convm_def )
(* Step b: by evaluate_net_Conv_all1 this is the sum of the outputs of m. *)
also have "... = monoid_add_class.sum_list (list_of_vec (evaluate_net m (base_input Convm is)))"
using evaluate_net_Conv_all1 Convm_def ‹is ⊲ input_sizes Convm› assms base_input_length ‹i < nr›
by simp
(* Step c: each output of m is the entry at is of the corresponding tensor of m. *)
also have "... = monoid_add_class.sum_list (list_of_vec (map_vec (λA. lookup A is)(tensors_from_net m)))"
unfolding ‹base_input Convm is = base_input m is›
using lookup_tensors_from_net[OF ‹valid_net' m› ‹is ⊲ input_sizes m›]
base_input_length[OF ‹is ⊲ input_sizes m›] output_size_correct[OF assms(1)] output_size_correct_tensors[OF assms(1)]
eq_vecI[of "evaluate_net m (base_input m is)" "map_vec (λA. lookup A is) (tensors_from_net m)"] index_map_vec(1) index_map_vec(2)
by force
(* Step d: move the map from the vector level to the list level. *)
also have "... = monoid_add_class.sum_list (map (λA. lookup A is) (list_of_vec (tensors_from_net m)))"
using eq_vecI[of "vec_of_list (list_of_vec (map_vec (λA. lookup A is)(tensors_from_net m)))"
"vec_of_list (map (λA. lookup A is) (list_of_vec (tensors_from_net m)))"] dim_vec_of_list
nth_list_of_vec length_map list_vec nth_map index_map_vec(1) index_map_vec(2) vec_list
by (metis (no_types, lifting))
(* Step e: lookup commutes with listsum. *)
also have "... = Tensor.lookup ?b is" using dims_tensors_from_net set_list_of_vec
using lookup_listsum[OF ‹is ⊲ input_sizes m›, of "list_of_vec (tensors_from_net m)"]
by metis
finally show "Tensor.lookup (?a $ i) is = Tensor.lookup ?b is" by blast
qed
(* Mutually recursive construction of the witness network. It mirrors the
   recursion of deep_model / deep_model', but every Conv node carries a concrete
   matrix applied to the arguments Y and r:
     id_matrix         when rs is empty,
     all1_matrix       when rs has length 1,
     copy_first_matrix otherwise.
   The order of the equations fixes the numbering of the generated simps rules. *)
fun witness and witness' where
"witness' Y [] = Input Y" |
"witness' Y (r # rs) = Pool (witness Y r rs) (witness Y r rs)" |
"witness Y r rs = Conv ((if length rs = 0 then id_matrix else (if length rs = 1 then all1_matrix else copy_first_matrix)) Y r) (witness' r rs)"
(* List-based shorthand: elements 0 and 1 of rs are the first two arguments of
   witness, everything after them is its list argument. *)
abbreviation witness_l where
"witness_l rs ≡ witness (rs!0) (rs!1) (tl (tl rs))"
(* List-based shorthand: the head of rs is the first argument of witness',
   the remaining elements are its list argument. *)
abbreviation witness'_l where
"witness'_l rs ≡ witness' (rs!0) (tl rs)"
(* Stripping the weights from witness Y r rs yields exactly deep_model Y r rs.
   Induction over rs with Y and r generalised; the step case first records the
   row and column dimensions of the matrix chosen by witness. *)
lemma witness_is_deep_model: "remove_weights (witness Y r rs) = deep_model Y r rs"
proof (induction rs arbitrary: Y r)
case Nil
then show ?case unfolding witness.simps witness'.simps deep_model.simps deep_model'.simps
by (simp add: id_matrix_dim)
next
case (Cons r' rs Y r)
(* For a non-empty list only all1_matrix or copy_first_matrix can be selected. *)
have "dim_row ((if length (r' # rs) = 0 then id_matrix else (if length (r' # rs) = 1 then all1_matrix else copy_first_matrix)) Y r) = Y"
"dim_col ((if length (r' # rs) = 0 then id_matrix else (if length (r' # rs) = 1 then all1_matrix else copy_first_matrix)) Y r) = r"
by (simp_all add: all1_matrix_dim copy_first_matrix_dim)
then show ?case unfolding witness.simps unfolding witness'.simps unfolding remove_weights.simps
using Cons by simp
qed
(* Stripping the weights from witness' Y rs yields exactly deep_model' Y rs.
   Induction over rs with Y generalised; the step case first records the row and
   column dimensions of the matrix chosen by witness. *)
lemma witness'_is_deep_model: "remove_weights (witness' Y rs) = deep_model' Y rs"
proof (induction rs arbitrary: Y)
case Nil
then show ?case unfolding witness.simps witness'.simps deep_model.simps deep_model'.simps
by (simp add: id_matrix_dim)
next
case (Cons r rs Y)
(* Here the remaining list rs may be empty, so id_matrix has to be covered as well. *)
have "dim_row ((if length rs = 0 then id_matrix else (if length rs = 1 then all1_matrix else copy_first_matrix)) Y r) = Y"
"dim_col ((if length rs = 0 then id_matrix else (if length rs = 1 then all1_matrix else copy_first_matrix)) Y r) = r"
by (simp_all add: all1_matrix_dim copy_first_matrix_dim id_matrix_dim)
then show ?case unfolding witness'.simps unfolding witness.simps unfolding remove_weights.simps
using Cons by simp
qed
(* The witness network is valid; follows from valid_deep_model via witness_is_deep_model. *)
lemma witness_valid: "valid_net' (witness Y r rs)"
using valid_deep_model witness_is_deep_model by auto
(* The witness' network is valid; follows from valid_deep_model' via witness'_is_deep_model. *)
lemma witness'_valid: "valid_net' (witness' Y rs)"
using valid_deep_model' witness'_is_deep_model by auto
(* The witness network satisfies shared_weight_net.
   Induction over rs with Y and r generalised; both cases unfold the witness
   definitions and apply the shared_weight_net rules for Conv, Input and Pool. *)
lemma shared_weight_net_witness: "shared_weight_net (witness Y r rs)"
proof (induction rs arbitrary:Y r)
case Nil
then show ?case unfolding witness.simps witness'.simps by (simp add: shared_weight_net_Conv shared_weight_net_Input)
next
case (Cons a rs)
then show ?case unfolding witness.simps witness'.simps
by (simp add: shared_weight_net_Conv shared_weight_net_Input shared_weight_net_Pool)
qed
(* Explicit form of witness' for a one-element list: a Pool of two identical
   Conv (id_matrix Y M) layers placed directly on Input M. *)
lemma witness_l0': "witness' Y [M] =
(Pool
(Conv (id_matrix Y M) (Input M))
(Conv (id_matrix Y M) (Input M))
)"
unfolding witness'.simps witness.simps by simp
(* Explicit form of witness for a one-element list: a Conv (all1_matrix Y r0)
   layer on top of witness' r0 [M]. *)
lemma witness_l1: "witness Y r0 [M] =
Conv (all1_matrix Y r0) (witness' r0 [M])"
unfolding witness'.simps by simp
(* Tensors of the bottom Conv layer: for j < r0, the j-th tensor of
   Conv (id_matrix r0 M) (Input M) is unit_vec M j when j < M and the zero tensor
   of dimensions [M] otherwise. Instance of tensors_from_net_Conv_id. *)
lemma tensors_ht_l0:
assumes "j<r0"
shows "tensors_from_net (Conv (id_matrix r0 M) (Input M)) $ j
= (if j<M then unit_vec M j else tensor0 [M])"
by (metis assms input_sizes.simps(1) output_size.simps(1) remove_weights.simps(1) tensors_from_net.simps(1)
tensors_from_net_Conv_id valid_net.intros(1) index_vec)
(* The tensor product of unit_vec M j with itself is the tensor of dimensions
   [M,M] whose entry is 1 at index [j,j] and 0 at every other index.
   Proved by tensor_lookup_eqI: equal dimensions, then equal entries. *)
lemma tensor_prod_unit_vec:
"unit_vec M j ⊗ unit_vec M j = tensor_from_lookup [M,M] (λis. if is=[j,j] then 1 else 0)" (is "?A=?B")
proof (rule tensor_lookup_eqI)
show "Tensor.dims ?A = Tensor.dims ?B"
by (metis append_Cons self_append_conv2 dims_unit_vec dims_tensor_prod dims_tensor_from_lookup)
fix "is" assume is_valid:"is ⊲ Tensor.dims (unit_vec M j ⊗ unit_vec M j)"
then have "is ⊲ [M,M]" by (metis append_Cons self_append_conv2 dims_unit_vec dims_tensor_prod)
(* A valid index for [M,M] is a two-element list; split it into its components. *)
then obtain i1 i2 where is_split: "is = [i1, i2]" "i1 < M" "i2 < M" using list.distinct(1) by blast
then have "[i1] ⊲ Tensor.dims (unit_vec M j)" "[i2] ⊲ Tensor.dims (unit_vec M j)"
by (simp_all add: valid_index.Cons valid_index.Nil dims_unit_vec)
have "is = [i1] @ [i2]" by (simp add: is_split(1))
(* The lookup in a tensor product is rewritten via lookup_tensor_prod for the two index parts. *)
show "Tensor.lookup ?A is = Tensor.lookup ?B is"
unfolding ‹is = [i1] @ [i2]›
lookup_tensor_prod[OF ‹[i1] ⊲ Tensor.dims (unit_vec M j)› ‹[i2] ⊲ Tensor.dims (unit_vec M j)›]
lookup_tensor_from_lookup[OF ‹is ⊲ [M, M]›, unfolded ‹is = [i1] @ [i2]›]
lookup_unit_vec[OF ‹i1 < M›] lookup_unit_vec[OF ‹i2 < M›] by fastforce
qed
(* Tensors of witness' r0 [M]: for j < r0, the j-th tensor is
   unit_vec M j ⊗ unit_vec M j when j < M and the zero tensor of dimensions
   [M,M] otherwise. Obtained by unfolding the Pool node (witness_l0') and using
   tensors_ht_l0 for both branches. *)
lemma tensors_ht_l0':
assumes "j<r0"
shows "tensors_from_net (witness' r0 [M]) $ j
= (if j<M then unit_vec M j ⊗ unit_vec M j else tensor0 [M,M])" (is "_ = ?b")
proof -
have "valid_net' (Conv (id_matrix r0 M) (Input M))"
by (metis convnet.inject(3) list.discI witness'.elims witness_l0' witness_valid)
(* j is a valid position in the tensor vector of each branch. *)
have j_le:"j < dim_vec (tensors_from_net (Conv (id_matrix r0 M) (Input M)))"
using output_size_correct_tensors[OF ‹valid_net' (Conv (id_matrix r0 M) (Input M))›,
unfolded remove_weights.simps output_size.simps id_matrix_dim]
assms by simp
show ?thesis
unfolding tensors_from_net.simps(3) witness_l0' index_component_mult[OF j_le j_le] tensors_ht_l0[OF assms]
by auto
qed
(* Entries of the tensors of witness' r0 [M]: for j < r0 and a valid index is for
   [M,M], the entry of the j-th tensor at is is 1 if is = [j,j] and 0 otherwise.
   Case distinction on j < M. *)
lemma lookup_tensors_ht_l0':
assumes "j<r0"
and "is ⊲ [M,M]"
shows "(Tensor.lookup (tensors_from_net (witness' r0 [M]) $ j)) is = (if is=[j,j] then 1 else 0)"
proof (cases "j<M")
assume "j<M"
show ?thesis unfolding tensors_ht_l0'[OF assms(1)] tensor_prod_unit_vec
apply (cases "is = [j,j]") using ‹j<M› assms(2)
by (simp_all add:lookup_tensor_from_lookup)
next
assume "¬j<M"
(* If j is not below M, the index [j,j] is not valid for [M,M], so is differs from it. *)
then have "is ≠ [j, j]" using assms(2) using list.distinct(1) nth_Cons_0 valid_index.simps by blast
show ?thesis unfolding tensors_ht_l0'[OF assms(1)] tensor_prod_unit_vec
using ‹¬j<M› by (simp add: lookup_tensor0[OF assms(2)] ‹is ≠ [j, j]›)
qed
(* Entries of the tensors of witness r1 r0 [M]: for j < r1 and a valid index is
   for [M,M], the entry of the j-th tensor at is is 1 if both components of is
   are equal and smaller than r0, and 0 otherwise.
   The all1_matrix Conv layer sums the r0 tensors of witness' r0 [M]; at most the
   summand with position is!0 contributes a non-zero entry. *)
lemma lookup_tensors_ht_l1:
assumes "j < r1"
and "is ⊲ [M,M]"
shows "Tensor.lookup (tensors_from_net (witness r1 r0 [M]) $ j) is
= (if is!0 = is!1 ∧ is!0<r0 then 1 else 0)"
proof -
(* Basic facts about the sub-network witness' r0 [M]. *)
have witness_l0'_valid: "valid_net' (witness' r0 [M])" unfolding witness_l0'
by (simp add: id_matrix_dim valid_net.intros)
have "input_sizes (witness' r0 [M]) = [M,M]" unfolding witness_l0' by simp
have "output_size' (witness' r0 [M]) = r0" unfolding witness_l0' using witness_l0'_valid
by (simp add: id_matrix_dim)
have "dim_vec (tensors_from_net (witness' r0 [M])) = r0"
using ‹output_size' (witness' r0 [M]) = r0› witness_l0'_valid output_size_correct_tensors by fastforce
(* Every tensor at a position other than is!0 has entry 0 at is. *)
have all0_but1:"⋀i. i≠is!0 ⟹ i<r0 ⟹ Tensor.lookup (tensors_from_net (witness' r0 [M]) $ i) is = 0"
using lookup_tensors_ht_l0' ‹is ⊲ [M, M]› by auto
(* The all1_matrix layer turns the output tensor into the listsum of the sub-network tensors. *)
have "tensors_from_net (witness r1 r0 [M]) $ j =
Tensor_Plus.listsum [M,M] (list_of_vec (tensors_from_net (witness' r0 [M])))"
unfolding witness_l1 using tensors_from_net_Conv_all1[OF witness_l0'_valid assms(1)]
witness_l0' ‹output_size' (witness' r0 [M]) = r0› by simp
then have "Tensor.lookup (tensors_from_net (witness r1 r0 [M]) $ j) is
= monoid_add_class.sum_list (map (λA. Tensor.lookup A is) (list_of_vec (tensors_from_net (witness' r0 [M]))))"
using lookup_listsum[OF ‹is ⊲ [M, M]›] ‹input_sizes (witness' r0 [M]) = [M, M]›
dims_tensors_from_net by (metis set_list_of_vec)
(* Re-index the list sum by positions 0..<r0, then turn it into a set sum. *)
also have "... = monoid_add_class.sum_list (map (λi. lookup (tensors_from_net (witness' r0 [M]) $ i) is) [0..<r0])"
using map_map[of "(λA. Tensor.lookup A is)" "λi. (tensors_from_net (witness' r0 [M]) $ i)" "[0..<r0]"]
using list_of_vec_map ‹dim_vec (tensors_from_net (witness' r0 [M])) = r0› by (metis (mono_tags, lifting) comp_apply map_eq_conv)
also have "... = (∑i<r0. Tensor.lookup ((tensors_from_net (witness' r0 [M])) $ i) is)"
using sum_set_upt_conv_sum_list_nat atLeast0LessThan by (metis atLeast_upt)
also have "... = (if is!0 = is!1 ∧ is!0<r0 then 1 else 0)"
proof (cases "is!0<r0")
case True
(* Only the summand at position is!0 can be non-zero; isolate it with sum.remove. *)
have "finite {0..<r0}" by auto
have "is!0 ∈ {0..<r0}" using True by auto
have "(∑i<r0. Tensor.lookup ((tensors_from_net (witness' r0 [M])) $ i) is)
= Tensor.lookup (tensors_from_net (witness' r0 [M]) $ (is!0)) is"
using ‹dim_vec (tensors_from_net (witness' r0 [M])) = r0›
using sum.remove[OF ‹finite {0..<r0}› ‹is!0 ∈ {0..<r0}›,
of "λi. (Tensor.lookup (tensors_from_net (witness' r0 [M])$i) is)"]
using all0_but1 atLeast0LessThan by force
then show ?thesis using lookup_tensors_ht_l0' ‹is ! 0 < r0› ‹is ⊲ [M, M]› by fastforce
next
case False
(* No position below r0 equals is!0, so every summand is 0. *)
then show ?thesis using all0_but1 atLeast0LessThan sum.neutral by force
qed
finally show ?thesis by auto
qed
(* A network m whose weight-free shape is deep_model_l rs has rs ! 0 output tensors. *)
lemma length_output_deep_model:
assumes "remove_weights m = deep_model_l rs"
shows "dim_vec (tensors_from_net m) = rs ! 0"
using output_size_correct_tensors valid_deep_model
deep_model.elims output_size.simps(2) by (metis assms)
(* A network m whose weight-free shape is deep_model'_l rs has rs ! 0 output tensors. *)
lemma length_output_deep_model':
assumes "remove_weights m = deep_model'_l rs"
shows "dim_vec (tensors_from_net m) = rs ! 0"
using output_size_correct_tensors valid_deep_model'
deep_model'.elims output_size.simps by (metis assms deep_model.elims)
(* The witness network witness_l rs has rs ! 0 output tensors
   (instance of length_output_deep_model). *)
lemma length_output_witness:
"dim_vec (tensors_from_net (witness_l rs)) = rs ! 0"
using length_output_deep_model witness_is_deep_model by blast
(* The network witness'_l rs has rs ! 0 output tensors
   (instance of length_output_deep_model'). *)
lemma length_output_witness':
"dim_vec (tensors_from_net (witness'_l rs)) = rs ! 0"
using length_output_deep_model' witness'_is_deep_model by blast
(* For a network m whose weight-free shape is deep_model_l rs (length rs ≥ 2) and
   an output position j < rs!0, the j-th tensor has dimensions
   2^(length rs - 2) copies of last rs.
   NOTE(review): the positivity assumption on the elements of rs is not among the
   facts passed to the proof method here - confirm whether it is only kept for
   uniformity with callers. *)
lemma dims_output_deep_model:
assumes "length rs ≥ 2"
and "⋀r. r∈set rs ⟹ r > 0"
and "j < rs!0"
and "remove_weights m = deep_model_l rs"
shows "Tensor.dims (tensors_from_net m $ j) = replicate (2^(length rs - 2)) (last rs)"
using dims_tensors_from_net input_sizes_deep_model[OF assms(1)] output_size_correct_tensors valid_deep_model
assms(3) assms(4) input_sizes_remove_weights length_output_witness witness_is_deep_model by (metis vec_setI)
(* Instance of dims_output_deep_model for the witness network: for length rs ≥ 2
   and j < rs!0, the j-th tensor of witness_l rs has dimensions
   2^(length rs - 2) copies of last rs. *)
lemma dims_output_witness:
assumes "length rs ≥ 2"
and "⋀r. r∈set rs ⟹ r > 0"
and "j < rs!0"
shows "Tensor.dims (tensors_from_net (witness_l rs) $ j) = replicate (2^(length rs - 2)) (last rs)"
using dims_output_deep_model witness_is_deep_model assms by blast
(* For a network m whose weight-free shape is deep_model'_l rs (length rs ≥ 1) and
   an output position j < rs!0, the j-th tensor has dimensions
   2^(length rs - 1) copies of last rs.
   NOTE(review): the positivity assumption on the elements of rs is not referenced
   in the proof text below - confirm whether it is only kept for uniformity. *)
lemma dims_output_deep_model':
assumes "length rs ≥ 1"
and "⋀r. r∈set rs ⟹ r > 0"
and "j < rs!0"
and "remove_weights m = deep_model'_l rs"
shows "Tensor.dims (tensors_from_net m $ j) = replicate (2^(length rs - 1)) (last rs)"
proof -
(* j is a valid position in the tensor vector of m. *)
have "dim_vec (tensors_from_net m) > j"
using length_output_deep_model' ‹remove_weights m = deep_model'_l rs› ‹j < rs!0› by auto
(* Every output tensor has the input sizes of the network as its dimensions. *)
then have "Tensor.dims (tensors_from_net m $ j) = input_sizes m"
using dims_tensors_from_net[of _ m] output_size_correct_tensors
vec_setI by metis
then show ?thesis
using assms(1) input_sizes_deep_model'
input_sizes_remove_weights[of m, unfolded ‹remove_weights m = deep_model'_l rs›] by auto
qed
(* Instance of dims_output_deep_model' for witness'_l rs: for length rs ≥ 1 and
   j < rs!0, the j-th tensor has dimensions 2^(length rs - 1) copies of last rs. *)
lemma dims_output_witness':
assumes "length rs ≥ 1"
and "⋀r. r∈set rs ⟹ r > 0"
and "j < rs!0"
shows "Tensor.dims (tensors_from_net (witness'_l rs) $ j) = replicate (2^(length rs - 1)) (last rs)"
using dims_output_deep_model' assms witness'_is_deep_model by blast
(* Shorthand for matricize applied with the set of even natural numbers. *)
abbreviation ten2mat where
"ten2mat ≡ matricize {n. even n}"
(* Shorthand for dematricize applied with the set of even natural numbers. *)
abbreviation mat2ten where
"mat2ten ≡ dematricize {n. even n}"
(* Parameters shared by the statements about the deep model:
   a flag shared_weights and a list rs of natural numbers that has at least three
   elements, all of them positive. *)
locale deep_model_correct_params =
fixes shared_weights :: bool
and rs :: "nat list"
assumes deep: "length rs ≥ 3"
and no_zeros: "⋀r. r ∈ set rs ⟹ 0 < r"
begin
(* The smaller of the last two elements of rs. *)
definition r where "r = min (last rs) (last (butlast rs))"
(* Two to the power of (length rs - 3). *)
definition N_half where "N_half = 2 ^ (length rs - 3)"
(* Number of weights of deep_model_l rs as counted by count_weights under the shared_weights flag. *)
definition weight_space_dim where "weight_space_dim = count_weights shared_weights (deep_model_l rs)"
end
locale deep_model_correct_params_y = deep_model_correct_params +
fixes y::nat
assumes y_valid:"y < rs ! 0"
begin
(* The y-th output tensor of the deep model deep_model_l rs after inserting the
   weights ws (via insert_weights with the shared_weights flag). *)
definition A :: "(nat ⇒ real) ⇒ real tensor"
where "A ws = tensors_from_net (insert_weights shared_weights (deep_model_l rs) ws) $ y"
(* The matricization (ten2mat) of the tensor A ws. *)
definition A' :: "(nat ⇒ real) ⇒ real mat"
where "A' ws = ten2mat (A ws)"
(* For any network m whose weight-free shape is deep_model_l rs, the y-th output
   tensor has dimensions 2 * N_half copies of last rs.
   First obtained with exponent length rs - 2 from dims_output_deep_model, then
   rewritten using the definition of N_half and length rs ≥ 3. *)
lemma dims_tensor_deep_model:
assumes "remove_weights m = deep_model_l rs"
shows "dims (tensors_from_net m $ y) = replicate (2 * N_half) (last rs)"
proof -
have "dims (tensors_from_net m $ y) = replicate (2 ^ (length rs - 2)) (last rs)"
using dims_output_deep_model[OF _ no_zeros y_valid assms] using less_imp_le_nat Suc_le_lessD deep numeral_3_eq_3
by auto
then show ?thesis using N_half_def by (metis One_nat_def Suc_1 Suc_eq_plus1 Suc_le_lessD deep
diff_diff_left less_numeral_extra(3) numeral_3_eq_3 power_eq_if zero_less_diff)
qed
(* For any network m whose weight-free shape is deep_model_l rs, the y-th output
   tensor has order 2 * N_half. *)
lemma order_tensor_deep_model:
assumes "remove_weights m = deep_model_l rs"
shows "order (tensors_from_net m $ y) = 2 * N_half"
using dims_tensor_deep_model by (simp add: assms)
(* The tensor A ws has dimensions 2 * N_half copies of last rs, for every ws. *)
lemma dims_A:
shows "Tensor.dims (A ws) = replicate (2 * N_half) (last rs)"
unfolding A_def
using dims_tensor_deep_model remove_insert_weights by blast
(* The tensor A ws has order 2 * N_half. *)
lemma order_A:
shows "order (A ws) = 2 * N_half" using dims_A length_replicate by auto
(* Row and column counts of the matrix A' ws: the product of the dimensions of
   A ws at even positions, respectively at odd positions. *)
lemma dims_A':
shows "dim_row (A' ws) = prod_list (nths (Tensor.dims (A ws)) {n. even n})"
and "dim_col (A' ws) = prod_list (nths (Tensor.dims (A ws)) {n. odd n})"
unfolding A'_def matricize_def by (simp_all add: A_def Collect_neg_eq)
(* The matrix A' ws is square: both its row and column count equal
   (last rs) ^ N_half. *)
lemma dims_A'_pow:
shows "dim_row (A' ws) = (last rs) ^ N_half" "dim_col (A' ws) = (last rs) ^ N_half"
unfolding dims_A' dims_A nths_replicate set_le_in card_even card_odd prod_list_replicate
by simp_all
(* The y-th output tensor of the witness network witness_l rs. *)
definition Aw where
"Aw = tensors_from_net (witness_l rs) $ y"
(* The matricization (ten2mat) of the witness tensor Aw. *)
definition Aw' where
"Aw' = ten2mat Aw"
(* The weights extracted from the witness network witness_l rs by extract_weights
   under the shared_weights flag. *)
definition witness_weights where
"witness_weights = extract_weights shared_weights (witness_l rs)"
(* Inserting witness_weights into the deep model reproduces the witness network.
   The proof uses the insert/extract congruence lemmas for the shared and the
   unshared case together with shared_weight_net_witness. *)
lemma witness_weights:"witness_l rs = insert_weights shared_weights (deep_model_l rs) witness_weights"
by (metis (full_types) insert_extract_weights_cong_shared insert_extract_weights_cong_unshared shared_weight_net_witness witness_is_deep_model witness_weights_def)
(* Alternative characterisation: Aw is the tensor A evaluated at witness_weights. *)
lemma Aw_def': "Aw = A witness_weights" unfolding Aw_def A_def using witness_weights by auto
(* Alternative characterisation: Aw' is the matrix A' evaluated at witness_weights. *)
lemma Aw'_def': "Aw' = A' witness_weights" unfolding Aw'_def A'_def Aw_def' by auto
(* The witness tensor Aw has dimensions 2 * N_half copies of last rs (instance of dims_A). *)
lemma dims_Aw: "Tensor.dims Aw = replicate (2 * N_half) (last rs)"
unfolding Aw_def' using dims_A by auto
(* The witness tensor Aw has order 2 * N_half (instance of order_A). *)
lemma order_Aw: "order Aw = 2 * N_half"
unfolding Aw_def' using order_A by auto
(* Row and column counts of the witness matrix Aw': the product of the dimensions
   of Aw at even positions, respectively at odd positions (instance of dims_A'). *)
lemma dims_Aw':
"dim_row Aw' = prod_list (nths (Tensor.dims Aw) {n. even n})"
"dim_col Aw' = prod_list (nths (Tensor.dims Aw) {n. odd n})"
unfolding Aw'_def' Aw_def' using dims_A' by auto
(* The witness matrix Aw' is square with (last rs) ^ N_half rows and columns
   (instance of dims_A'_pow). *)
lemma dims_Aw'_pow: "dim_row Aw' = (last rs) ^ N_half" "dim_col Aw' = (last rs) ^ N_half"
unfolding Aw'_def' Aw_def' using dims_A'_pow by auto
lemma witness_tensor:
assumes "is ⊲ Tensor.dims Aw"
shows "Tensor.lookup Aw is
= (if nths is {n. even n} = nths is {n. odd n} ∧ (∀i∈set is. i < last (butlast rs)) then 1 else 0)"
using assms deep no_zeros y_valid unfolding Aw_def proof (induction "butlast (butlast (butlast rs))" arbitrary:rs "is" y)
case Nil
have "length rs = 3"
by (rule antisym, metis Nil.hyps One_nat_def Suc_1 Suc_eq_plus1 add_2_eq_Suc' diff_diff_left
length_butlast less_numeral_extra(3) list.size(3) not_le numeral_3_eq_3 zero_less_diff, metis ‹3 ≤ length rs›)
then have "rs = [rs!0, rs!1, rs!2]" by (metis (no_types, lifting) Cons_nth_drop_Suc One_nat_def Suc_eq_plus1
append_Nil id_take_nth_drop length_0_conv length_tl lessI list.sel(3) list.size(4) not_le numeral_3_eq_3
numeral_le_one_iff one_add_one semiring_norm(70) take_0 zero_less_Suc)
have "input_sizes (witness_l [rs ! 0, rs ! 1, rs ! 2]) = [rs!2, rs!2]"
using witness.simps witness'.simps input_sizes.simps by auto
then have "Tensor.dims (tensors_from_net (witness_l rs) $ y) = [rs!2, rs!2]"
using dims_tensors_from_net[of "tensors_from_net (witness_l rs) $ y" "witness_l rs"]
Nil.prems(4) length_output_witness ‹rs = [rs ! 0, rs ! 1, rs ! 2]› vec_setI by metis
then have "is ⊲ [rs!2, rs!2]" using Nil.prems by metis
then have "Tensor.lookup ((tensors_from_net (witness_l rs))$y) is
= (if is ! 0 = is ! 1 ∧ is ! 0 < rs ! 1 then 1 else 0)"
using Nil.prems(4) ‹rs = [rs ! 0, rs ! 1, rs ! 2]› by (metis list.sel(3) lookup_tensors_ht_l1)
have "is ! 0 = is ! 1 ∧ is ! 0 < rs ! 1
⟷ nths is {n. even n} = nths is {n. odd n} ∧ (∀i∈set is. i < last (butlast rs))"
proof -
have "length is = 2" by (metis One_nat_def Suc_eq_plus1 ‹is ⊲ [rs ! 2, rs ! 2]› list.size(3) list.size(4) numeral_2_eq_2 valid_index_length)
have "nths is {n. even n} = [is!0]"
apply (rule nths_only_one)
using subset_antisym less_2_cases ‹length is = 2› by fastforce
have "nths is {n. odd n} = [is!1]"
apply (rule nths_only_one)
using subset_antisym less_2_cases ‹length is = 2› by fastforce
have "last (butlast rs) = rs!1" by (metis One_nat_def Suc_eq_plus1 ‹rs = [rs ! 0, rs ! 1, rs ! 2]›
append_butlast_last_id last_conv_nth length_butlast length_tl lessI list.sel(3) list.simps(3)
list.size(3) list.size(4) nat.simps(3) nth_append)
show ?thesis unfolding ‹last (butlast rs) = rs!1›
apply (rule iffI; rule conjI)
apply (simp add: ‹nths is (Collect even) = [is ! 0]› ‹nths is {n. odd n} = [is ! 1]›)
apply (metis ‹length is = 2› One_nat_def in_set_conv_nth less_2_cases)
apply (simp add: ‹nths is (Collect even) = [is ! 0]› ‹nths is {n. odd n} = [is ! 1]›)
apply (simp add: ‹length is = 2›)
done
qed
then show ?case unfolding ‹Tensor.lookup (tensors_from_net (witness_l rs) $ y) is = (if is ! 0 = is ! 1 ∧ is ! 0 < rs ! 1 then 1 else 0)›
using witness_is_deep_model witness_valid ‹rs = [rs ! 0, rs ! 1, rs ! 2]› by auto
next
case (Cons r rs' rs "is" j)
text ‹We prove the Induction Hypothesis for "tl rs" and j=0:›
have "rs = r # tl rs" by (metis Cons.hyps(2) append_butlast_last_id butlast.simps(1) hd_append2 list.collapse list.discI list.sel(1))
have 1:"rs' = butlast (butlast (butlast (tl rs)))" by (metis Cons.hyps(2) butlast_tl list.sel(3))
have 2:"3 ≤ length (tl rs)" by (metis (no_types, lifting) Cons.hyps(2) Cons.prems(2)
Nitpick.size_list_simp(2) One_nat_def Suc_eq_plus1 ‹rs = r # tl rs› ‹rs' = butlast (butlast (butlast (tl rs)))›
diff_diff_left diff_self_eq_0 gr0_conv_Suc le_Suc_eq length_butlast length_tl less_numeral_extra(3) list.simps(3) numeral_3_eq_3)
have 3:"⋀r. r ∈ set (tl rs) ⟹ 0 < r" by (metis Cons.prems(3) list.sel(2) list.set_sel(2))
have 4:"0 < (tl rs) ! 0" using "2" "3" by auto
have IH: "⋀is'. is' ⊲ Tensor.dims (tensors_from_net (witness_l (tl rs)) $ 0)
⟹ Tensor.lookup (tensors_from_net (witness_l (tl rs)) $ 0) is' =
(if nths is' (Collect even) = nths is' {n. odd n} ∧ (∀i∈set is'. i < last (butlast (tl rs))) then 1 else 0)"
using "1" "2" "3" 4 Cons.hyps(1) by blast
text ‹The list "is" can be split in two parts:›
have "is ⊲ replicate (2^(length rs - 2)) (last rs)"
using Cons.prems(3) dims_output_witness 2 by (metis (no_types, lifting) Cons.prems(1) Cons.prems(3)
Cons.prems(4) Nitpick.size_list_simp(2) One_nat_def diff_diff_left diff_is_0_eq length_tl
nat_le_linear not_numeral_le_zero numeral_le_one_iff one_add_one semiring_norm(70))
then have "is ⊲ replicate (2^(length (tl rs) - 2)) (last rs) @ replicate (2^(length (tl rs) - 2)) (last rs)"
using Cons.prems dims_output_witness by (metis "2" Nitpick.size_list_simp(2) One_nat_def
diff_diff_left length_tl mult_2 not_numeral_le_zero numeral_le_one_iff one_add_one
power.simps(2) replicate_add semiring_norm(70))
then obtain is1 is2 where "is = is1 @ is2" and
is1_replicate: "is1 ⊲ replicate (2^(length (tl rs) - 2)) (last rs)" and
is2_replicate: "is2 ⊲ replicate (2^(length (tl rs) - 2)) (last rs)" by (metis valid_index_split)
then have
is1_valid: "is1 ⊲ Tensor.dims (tensors_from_net (witness_l (tl rs)) $ 0)" (is ?is1) and
is2_valid: "is2 ⊲ Tensor.dims (tensors_from_net (witness_l (tl rs)) $ 0)" (is ?is2)
proof -
have "last (tl rs) = last rs" by (metis "2" ‹rs = r # tl rs› last_ConsR list.size(3) not_numeral_le_zero)
then show ?is1 ?is2 using dims_output_witness[of "tl rs"]
using dims_output_witness[of "tl rs"] 2 3 is1_replicate is2_replicate ‹last (tl rs) = last rs› by auto
qed
text ‹A shorthand for the condition to find a "1" in the tensor:›
let ?cond = "λis rs. nths is {n. even n} = nths is {n. odd n} ∧ (∀i∈set is. i < last (butlast rs))"
text ‹We can use the IH on our newly created is1 and is2:›
have IH_is12:
"Tensor.lookup (tensors_from_net (witness_l (tl rs)) $ 0) is1 =
(if (?cond is1 (tl rs)) then 1 else 0)"
"Tensor.lookup (tensors_from_net (witness_l (tl rs)) $ 0) is2 =
(if (?cond is2 (tl rs)) then 1 else 0)"
using IH is1_valid is2_valid by fast+
text ‹In the induction step we have to add two layers: first the Pool layer, then the Conv layer.
The Pool layer connects the two subtrees. Therefore the two conditions on is1 and is2 become
one, and we have to prove that they are equivalent:›
have "?cond is1 (tl rs) ∧ ?cond is2 (tl rs) ⟷ ?cond is rs"
proof -
have "length is1 = 2 ^ (length (tl rs) - 2)"
"length is2 = 2 ^ (length (tl rs) - 2)"
using is1_replicate is2_replicate by (simp_all add: valid_index_length)
then have "even (length is1)" "even (length is2)"
by (metis Cons.hyps(2) One_nat_def add_gr_0 diff_diff_left even_numeral even_power
length_butlast length_tl list.size(4) one_add_one zero_less_Suc)+
then have "{j. j + length is1 ∈ {n. even n}} = {n. even n}"
"{j. j + length is1 ∈ {n. odd n}} = {n. odd n}" by simp_all
have "length (nths is2 (Collect even)) = length (nths is2 (Collect odd))"
using length_nths_even ‹even (length is2)› by blast
have cond1_iff: "(nths is1 (Collect even) = nths is1 {n. odd n} ∧ nths is2 (Collect even) = nths is2 {n. odd n})
= (nths is (Collect even) = nths is {n. odd n})"
unfolding ‹is = is1 @ is2› nths_append
‹{j. j + length is1 ∈ {n. odd n}} = {n. odd n}› ‹{j. j + length is1 ∈ {n. even n}} = {n. even n}›
by (simp add: ‹length (nths is2 (Collect even)) = length (nths is2 (Collect odd))›)
have "last (butlast (tl rs)) = last (butlast rs)" using Nitpick.size_list_simp(2) ‹even (length is1)›
‹length is1 = 2 ^ (length (tl rs) - 2)› butlast_tl last_tl length_butlast length_tl not_less_eq zero_less_diff
by (metis (full_types) Cons.hyps(2) length_Cons less_nat_zero_code)
have cond2_iff: "(∀i∈set is1. i < last (butlast (tl rs))) ∧ (∀i∈set is2. i < last (butlast (tl rs))) ⟷ (∀i∈set is. i < last (butlast rs))"
unfolding ‹last (butlast (tl rs)) = last (butlast rs)› ‹is = is1 @ is2› set_append by blast
then show ?thesis using cond1_iff cond2_iff by blast
qed
text ‹Now we can make the Pool layer step: ›
have lookup_witness': "Tensor.lookup ((tensors_from_net (witness' (rs ! 1) (tl (tl rs)))) $ 0) is =
(if ?cond is rs then 1 else 0)"
proof -
have lookup_prod: "Tensor.lookup ((tensors_from_net (witness_l (tl rs)) $ 0) ⊗ (tensors_from_net (witness_l (tl rs))) $ 0) is =
(if ?cond is rs then 1 else 0)"
using ‹?cond is1 (tl rs) ∧ ?cond is2 (tl rs) ⟷ ?cond is rs›
unfolding ‹is = is1 @ is2› lookup_tensor_prod[OF is1_valid is2_valid] IH_is12
by auto
have witness_l_tl: "witness_l (tl rs) = witness (rs ! 1) (rs ! 2) (tl (tl (tl rs)))"
by (metis One_nat_def Suc_1 ‹rs = r # tl rs› nth_Cons_Suc)
have tl_tl:"(tl (tl rs)) = ((rs ! 2) # tl (tl (tl rs)))"
proof -
have "length (tl (tl rs)) ≠ 0"
by (metis One_nat_def Suc_eq_plus1 diff_diff_left diff_is_0_eq length_tl not_less_eq_eq
Cons.prems(2) numeral_3_eq_3)
then have "tl (tl rs) ≠ []"
by fastforce
then show ?thesis
by (metis list.exhaust_sel nth_Cons_0 nth_Cons_Suc numeral_2_eq_2 tl_Nil)
qed
have length_gt0:"dim_vec (tensors_from_net (witness (rs ! 1) (rs ! 2) (tl (tl (tl rs))))) > 0"
using output_size_correct_tensors[of "witness (rs ! 1) (rs ! 2) (tl (tl (tl rs)))"]
witness_is_deep_model[of "rs ! 1" "rs ! 2" "tl (tl (tl rs))"]
valid_deep_model[of "rs ! 1" "rs ! 2" "tl (tl (tl rs))"] output_size.simps witness.simps
by (metis "2" "3" One_nat_def ‹rs = r # tl rs› deep_model.elims length_greater_0_conv list.size(3)
not_numeral_le_zero nth_Cons_Suc nth_mem)
then have "tensors_from_net (witness' (rs ! 1) ((rs ! 2) # tl (tl (tl rs)))) $ 0
= (tensors_from_net (witness_l (tl rs)) $ 0) ⊗ (tensors_from_net (witness_l (tl rs)) $ 0)"
unfolding witness'.simps tensors_from_net.simps witness_l_tl using index_component_mult by blast
then show ?thesis using lookup_prod tl_tl by simp
qed
text ‹Then we can make the Conv layer step: ›
show ?case
proof -
have "valid_net' (witness' (rs ! 1) (tl (tl rs)))" by (simp add: witness'_valid)
have "output_size' (witness' (rs ! 1) (tl (tl rs))) = rs ! 1"
by (metis "2" Nitpick.size_list_simp(2) diff_diff_left diff_is_0_eq hd_Cons_tl deep_model'.simps(2) deep_model.elims length_tl not_less_eq_eq numeral_2_eq_2 numeral_3_eq_3 one_add_one output_size.simps(2) output_size.simps(3) tl_Nil witness'_is_deep_model)
have if_resolve:"(if length (tl (tl rs)) = 0 then id_matrix else if length (tl (tl rs)) = 1 then all1_matrix else copy_first_matrix) = copy_first_matrix"
by (metis "2" Cons.prems(2) Nitpick.size_list_simp(2) One_nat_def Suc_n_not_le_n not_numeral_le_zero numeral_3_eq_3)
have "tensors_from_net (Conv (copy_first_matrix (rs ! 0) (rs ! 1)) (witness' (rs ! 1) (tl (tl rs)))) $ j =
tensors_from_net (witness' (rs ! 1) (tl (tl rs))) $ 0"
using tensors_from_net_Conv_copy_first[OF ‹valid_net' (witness' (rs ! 1) (tl (tl rs)))› ‹j < rs ! 0›, unfolded ‹output_size' (witness' (rs ! 1) (tl (tl rs))) = rs ! 1›]
using "4" One_nat_def ‹rs = r # tl rs› nth_Cons_Suc by metis
then show ?thesis unfolding witness.simps if_resolve ‹output_size' (witness' (rs ! 1) (tl (tl rs))) = rs ! 1›
using lookup_witness' ‹valid_net' (witness' (rs ! 1) (tl (tl rs)))› hd_conv_nth output_size_correct_tensors
by fastforce
qed
qed
(* Entry-wise description of the matrix Aw'.
   For in-range indices (i, j) the entry is 1 exactly when i = j and every
   digit of the digit encoding of i (taken with respect to the even-position
   dimensions of Aw) is below last (butlast rs); otherwise the entry is 0. *)
lemma witness_matricization:
assumes "i < dim_row Aw'" and "j < dim_col Aw'"
shows "Aw' $$ (i, j)
= (if i=j ∧ (∀i0∈set (digit_encode (nths (Tensor.dims Aw) {n. even n}) i). i0 < last (butlast rs)) then 1 else 0)"
proof -
(* The tensor index "is" is built by weaving the digit encoding of i with the
   digit encoding of j along the set of even positions. *)
define "is" where "is = weave {n. even n}
(digit_encode (nths (Tensor.dims Aw) {n. even n}) i)
(digit_encode (nths (Tensor.dims Aw) {n. odd n}) j)"
(* The matrix entry at (i, j) is the tensor entry of Aw at the woven index. *)
have lookup_eq: "Aw' $$ (i, j) = Tensor.lookup Aw is"
using Aw'_def matricize_def dims_Aw'(1)[symmetric, unfolded A_def] dims_Aw'(2)[symmetric, unfolded A_def Collect_neg_eq]
index_mat(1)[OF ‹i < dim_row Aw'› ‹j < dim_col Aw'›] is_def Collect_neg_eq case_prod_conv
by (metis (no_types) Aw'_def Collect_neg_eq case_prod_conv is_def matricize_def)
(* The woven index is a valid index into Aw. *)
have "is ⊲ Tensor.dims Aw"
using is_def valid_index_weave A_def Collect_neg_eq assms digit_encode_valid_index
dims_Aw' by metis
have "even (order Aw)"
unfolding Aw_def using assms dims_output_witness even_numeral le_eq_less_or_eq numeral_2_eq_2 numeral_3_eq_3 deep no_zeros y_valid by fastforce
(* The dimension list of Aw is a replicate list of even length, hence its
   even-position and odd-position sublists coincide. *)
have nths_dimsAw: "nths (Tensor.dims Aw) (Collect even) = nths (Tensor.dims Aw) {n. odd n}"
proof -
have 0:"Tensor.dims (tensors_from_net (witness_l rs) $ y) = replicate (2 ^ (length rs - 2)) (last rs)"
using dims_output_witness[OF _ no_zeros y_valid] using deep by linarith
show ?thesis unfolding A_def
using nths_replicate
by (metis (no_types, lifting) "0" Aw_def ‹even (order Aw)› length_replicate length_nths_even)
qed
(* Row index and column index agree iff the even-position part and the
   odd-position part of the woven index agree. Both directions go through
   nths_weave, which recovers the two woven lists from "is". *)
have "i = j ⟷ nths is (Collect even) = nths is {n. odd n}"
proof
have eq_lengths: "length (digit_encode (nths (Tensor.dims Aw) (Collect even)) i)
= length (digit_encode (nths (Tensor.dims Aw) {n. odd n}) j)"
unfolding length_digit_encode by (metis ‹even (order Aw)› length_nths_even)
then show "i = j ⟹ nths is (Collect even) = nths is {n. odd n}" unfolding is_def
using nths_weave[of "digit_encode (nths (Tensor.dims Aw) (Collect even)) i"
"Collect even" "digit_encode (nths (Tensor.dims Aw) {n. odd n}) j", unfolded eq_lengths, unfolded Collect_neg_eq[symmetric] card_even mult_2[symmetric] card_odd]
nths_dimsAw by simp
show "nths is (Collect even) = nths is {n. odd n} ⟹ i = j" unfolding is_def
using nths_weave[of "digit_encode (nths (Tensor.dims Aw) (Collect even)) i"
"Collect even" "digit_encode (nths (Tensor.dims Aw) {n. odd n}) j", unfolded eq_lengths, unfolded Collect_neg_eq[symmetric] card_even mult_2[symmetric] card_odd]
using ‹nths (Tensor.dims Aw) (Collect even) = nths (Tensor.dims Aw) {n. odd n}›
deep no_zeros y_valid assms digit_decode_encode dims_Aw'
by auto (metis digit_decode_encode_lt)
qed
(* On the diagonal, the digits of i are exactly the elements of "is", so the
   bound on the digits of i is the same as the bound on all elements of "is". *)
have "i=j ⟹ set (digit_encode (nths (Tensor.dims Aw) {n. even n}) i) = set is"
unfolding is_def nths_dimsAw
using set_weave[of "(digit_encode (nths (Tensor.dims Aw) {n. odd n}) j)" "Collect even"
"(digit_encode (nths (Tensor.dims Aw) {n. odd n}) j)",
unfolded mult_2[symmetric] card_even Collect_neg_eq[symmetric] card_odd]
Un_absorb card_even card_odd mult_2 by blast
(* Conclude with witness_tensor applied to the valid index "is". *)
then show ?thesis unfolding lookup_eq
using witness_tensor[OF ‹is ⊲ Tensor.dims Aw›]
by (simp add: A_def ‹(i = j) = (nths is (Collect even) = nths is {n. odd n})›)
qed
(* rows_with_1 collects the indices whose digit encoding, taken with respect to
   the even-position dimensions of Aw, consists only of digits below
   last (butlast rs). By witness_matricization these are the in-range rows
   of Aw' that carry a 1 on the diagonal. *)
definition "rows_with_1 = {row. ∀d∈set (digit_encode (nths (Tensor.dims Aw) {n. even n}) row). d < last (butlast rs)}"
(* Counting lemma for mixed-radix digit encodings.
   Given a positive bound m and a list of bases ds, each of which is at least m,
   the number of values i below prod_list ds whose digit encoding with respect
   to ds uses only digits below m is m ^ (length ds). *)
lemma card_low_digits:
assumes "m>0" "⋀d. d∈set ds ⟹ m ≤ d"
shows "card {i. i<prod_list ds ∧ (∀i0∈set (digit_encode ds i). i0 < m)} = m ^ (length ds)"
using assms proof (induction ds)
case Nil
then show ?case using prod_list.Nil by simp
next
case (Cons d ds)
(* Local shorthand for the counted property, parameterised by the base list. *)
define low_digits
where "low_digits ds i ⟷ i < prod_list ds ∧ (∀i0∈set (digit_encode ds i). i0 < m)" for ds i
(* Induction hypothesis, restated in terms of low_digits. *)
have "card {i. low_digits ds i} = m ^ (length ds)" unfolding low_digits_def
by (simp add: Cons.IH Cons.prems(1) Cons.prems(2))
(* The low-digit numbers for d # ds are in bijection with pairs consisting of
   a leading digit below m and a low-digit number for ds. *)
have "card {i. low_digits (d # ds) i} = card ({..<m} × {i. low_digits ds i})"
proof -
(* f combines a leading digit (first component) with a remaining value
   (second component) into a single number. *)
define f where "f p = fst p + d * snd p" for p
(* Injectivity: both components are recovered from f p via mod d and div d,
   using that the first component is below m and m is at most d. *)
have "inj_on f ({..<m} × {i. low_digits ds i})"
proof (rule inj_onI)
fix x y assume "x ∈ {..<m} × {i. low_digits ds i}" "y ∈ {..<m} × {i. low_digits ds i}" "f x = f y"
then have "fst x<m" "fst y<m" by auto
then have "fst x<d" "fst y<d" using Cons(3) by (meson list.set_intros(1) not_le order_trans)+
then have "f x mod d = fst x" "f y mod d = fst y" unfolding f_def by simp_all
have "f x div d = snd x" "f y div d = snd y" using ‹f x = f y› ‹f x mod d = fst x› ‹fst y < d› f_def by auto
show "x = y" using ‹f x = f y› ‹f x div d = snd x› ‹f x mod d = fst x› ‹f y div d = snd y› ‹f y mod d = fst y› prod_eqI by fastforce
qed
(* Surjectivity onto the low-digit numbers for d # ds, shown as two inclusions. *)
have "f ` ({..<m} × {i. low_digits ds i}) = {i. low_digits (d # ds) i}"
proof (rule subset_antisym; rule subsetI)
(* First inclusion: every value in the image of f has the low-digit property. *)
fix x assume "x ∈ f ` ({..<m} × {i. low_digits ds i})"
then obtain i0 i1 where "x = i0 + d * i1" "i0 < m" "low_digits ds i1" using f_def by force
then have "i0<d" using Cons(3) by (meson list.set_intros(1) not_le order_trans)
show "x ∈ {i. low_digits (d # ds) i}" unfolding low_digits_def
proof (rule; rule conjI)
have "i1 < prod_list ds" "∀i0∈set (digit_encode ds i1). i0 < m"
using ‹low_digits ds i1› low_digits_def by auto
show "x < prod_list (d # ds)" unfolding prod_list.Cons ‹x = i0 + d * i1› using ‹i0<d› ‹i1 < prod_list ds›
proof -
have "d ≠ 0"
by (metis ‹i0 < d› gr_implies_not0)
then have "(i0 + d * i1) div (d * prod_list ds) = 0"
using ‹i0 < d› ‹i1 < prod_list ds› by (simp add: div_mult2_eq)
then show "i0 + d * i1 < d * prod_list ds"
by (metis (no_types) ‹i0 < d› ‹i1 < prod_list ds› div_eq_0_iff gr_implies_not0 no_zero_divisors)
qed
show "∀i0∈set (digit_encode (d # ds) x). i0 < m"
using ‹∀i0∈set (digit_encode ds i1). i0 < m› ‹i0 < d› ‹i0 < m› ‹x = i0 + d * i1› by auto
qed
next
(* Second inclusion: a low-digit number x is the image of (x mod d, x div d). *)
fix x assume "x ∈ {i. low_digits (d # ds) i}"
then have "x < prod_list (d # ds)" "∀i0∈set (digit_encode (d # ds) x). i0 < m" using low_digits_def by auto
have "x mod d < m" using ‹∀i0∈set (digit_encode (d # ds) x). i0 < m›[unfolded digit_encode.simps] by simp
have "x div d < prod_list ds" using ‹x < prod_list (d # ds)›[unfolded prod_list.Cons]
by (metis div_eq_0_iff div_mult2_eq mult_0_right not_less0)
have "∀i0∈set (digit_encode ds (x div d)). i0 < m" by (simp add: ‹∀i0∈set (digit_encode (d # ds) x). i0 < m›)
have "f ((x mod d),(x div d)) = x" by (simp add: f_def)
show "x ∈ f ` ({..<m} × {i. low_digits ds i})" by (metis SigmaI ‹∀i0∈set (digit_encode ds (x div d)). i0 < m› ‹f (x mod d, x div d) = x› ‹x div d < prod_list ds› ‹x mod d < m› image_eqI lessThan_iff low_digits_def mem_Collect_eq)
qed
then have "bij_betw f ({..<m} × {i. low_digits ds i}) {i. low_digits (d # ds) i}"
by (simp add: ‹inj_on f ({..<m} × {i. low_digits ds i})› bij_betw_def)
then show ?thesis by (simp add: bij_betw_same_card)
qed
(* The cardinality of the product set is m times the induction hypothesis. *)
then show ?case unfolding ‹card {i. low_digits ds i} = m ^ (length ds)› card_cartesian_product using low_digits_def by simp
qed
(* The number of elements of rows_with_1 that are valid row indices of Aw'
   is r ^ N_half. The proof rewrites the set into the shape expected by
   card_low_digits (fact 1) and then discharges the side conditions of that
   lemma (facts 2 and 3) and the length of the base list (fact 4).
   NOTE(review): r and N_half are defined outside this section; the uses of
   r_def below suggest r is the minimum of last rs and last (butlast rs) -
   confirm against the locale definitions. *)
lemma card_rows_with_1: "card {i∈rows_with_1. i<dim_row Aw'} = r ^ N_half"
proof -
(* Fact 1: the counted set equals the low-digit set with bound r over the
   even-position dimensions of Aw. *)
have 1:"{i∈rows_with_1. i<dim_row Aw'} = {i. i < prod_list (nths (Tensor.dims Aw) (Collect even)) ∧
(∀i0∈set (digit_encode (nths (Tensor.dims Aw) (Collect even)) i). i0 < r)}" (is "?A = ?B")
proof (rule subset_antisym; rule subsetI)
fix i assume "i ∈ ?A"
then have "i < dim_row Aw'" "∀i0∈set (digit_encode (nths (Tensor.dims Aw) {n. even n}) i). i0 < last (butlast rs)"
using rows_with_1_def by auto
then have "i < prod_list (nths (dims Aw) (Collect even))" using dims_Aw' by linarith
then have "digit_encode (nths (dims Aw) (Collect even)) i ⊲ nths (dims Aw) (Collect even)"
using digit_encode_valid_index by auto
(* Each digit is below last (butlast rs) by membership in rows_with_1, and
   below last rs because the encoding is a valid index; together this gives
   the bound r. *)
have "∀i0∈set (digit_encode (nths (Tensor.dims Aw) {n. even n}) i). i0 < r"
proof
fix i0 assume 1:"i0 ∈ set (digit_encode (nths (dims Aw) (Collect even)) i)"
then obtain k where "k < length (digit_encode (nths (dims Aw) (Collect even)) i)"
"digit_encode (nths (dims Aw) (Collect even)) i ! k = i0" by (meson in_set_conv_nth)
have "i0 < last (butlast rs)"
using ‹∀i0∈set (digit_encode (nths (dims Aw) (Collect even)) i). i0 < last (butlast rs)› 1 by blast
have "set (nths (dims Aw) (Collect even)) ⊆ {last rs}" unfolding dims_Aw using subset_eq by fastforce
then have "nths (dims Aw) (Collect even) ! k = last rs"
using ‹digit_encode (nths (dims Aw) (Collect even)) i ⊲ nths (dims Aw) (Collect even)›
‹k < length (digit_encode (nths (dims Aw) (Collect even)) i)›
nth_mem valid_index_length by auto
then have "i0 < last rs"
using valid_index_lt ‹digit_encode (nths (dims Aw) (Collect even)) i ! k = i0›
‹digit_encode (nths (dims Aw) (Collect even)) i ⊲ nths (dims Aw) (Collect even)›
‹k < length (digit_encode (nths (dims Aw) (Collect even)) i)› valid_index_length by fastforce
then show "i0 < r" unfolding r_def by (simp add: ‹i0 < last (butlast rs)›)
qed
then show "i ∈ ?B" using ‹i < prod_list (nths (dims Aw) (Collect even))› by blast
next
fix i assume "i∈?B"
then show "i∈?A" by (simp add: dims_Aw' r_def rows_with_1_def)
qed
(* Fact 2: every even-position dimension of Aw is at least r. *)
have 2:"⋀d. d ∈ set (nths (Tensor.dims Aw) (Collect even)) ⟹ r ≤ d"
proof -
fix d assume "d ∈ set (nths (Tensor.dims Aw) (Collect even))"
then have "d ∈ set (Tensor.dims Aw)" using in_set_nthsD by fast
then have "d = last rs" using dims_Aw by simp
then show "r ≤ d" by (simp add: r_def)
qed
(* Fact 3: r is positive. *)
have 3:"0 < r" unfolding r_def by (metis deep diff_diff_cancel diff_zero dual_order.trans in_set_butlastD last_in_set length_butlast list.size(3) min_def nat_le_linear no_zeros not_numeral_le_zero numeral_le_one_iff rel_simps(3))
(* Fact 4: there are N_half even-position dimensions. *)
have 4: "length (nths (Tensor.dims Aw) (Collect even)) = N_half"
unfolding length_nths order_Aw using card_even[of N_half]
by (metis (mono_tags, lifting) Collect_cong)
then show ?thesis using card_low_digits[of "r" "nths (Tensor.dims Aw) (Collect even)"] 1 2 3 4 by metis
qed
(* rows_with_1 is an infinite set of natural numbers.
   Every multiple of listpr (the product of the even-position dimensions of Aw)
   belongs to rows_with_1, and multiplication by the positive number listpr is
   injective, so rows_with_1 contains an injective image of all naturals. *)
lemma infinite_rows_with_1: "infinite rows_with_1"
proof -
define listpr where "listpr = prod_list (nths (Tensor.dims Aw) {n. even n})"
(* Multiples of listpr have an all-zero digit encoding (digit_encode_0), and
   0 is below last (butlast rs). *)
have "⋀i. listpr dvd i ⟹ i ∈ rows_with_1"
proof -
fix i assume dvd_i: "listpr dvd i"
{
fix i0::nat
assume "i0∈set (digit_encode (nths (Tensor.dims Aw) {n. even n}) i)"
then have "i0=0" using digit_encode_0 dvd_i listpr_def by auto
then have "i0 < last (butlast rs)" using deep no_zeros
by (metis Nitpick.size_list_simp(2) One_nat_def Suc_le_lessD in_set_butlastD last_in_set length_butlast length_tl not_numeral_less_zero numeral_2_eq_2 numeral_3_eq_3 numeral_le_one_iff semiring_norm(70))
}
then show "i∈rows_with_1" by (simp add: rows_with_1_def)
qed
(* The dimension list of Aw is a replicate list of last rs. *)
have 0:"Tensor.dims Aw = replicate (2 ^ (length rs - 2)) (last rs)" unfolding Aw_def
using dims_output_witness[OF _ no_zeros y_valid] using deep by linarith
(* listpr is positive, which makes multiplication by it injective. *)
then have "listpr > 0" unfolding listpr_def 0
by (metis "0" deep last_in_set length_greater_0_conv less_le_trans no_zeros dims_Aw'_pow(1) dims_Aw'(1)
zero_less_numeral zero_less_power)
then have "inj ((*) listpr)" by (metis injI mult_left_cancel neq0_conv)
then show ?thesis using ‹⋀i. listpr dvd i ⟹ i ∈ rows_with_1›
by (meson dvd_triv_left image_subset_iff infinite_iff_countable_subset)
qed
(* Selecting the rows and columns of Aw' indexed by rows_with_1 yields the
   identity matrix of size r ^ N_half. The proof is by matrix extensionality:
   the row dimension, the column dimension, and then every entry. *)
lemma witness_submatrix: "submatrix Aw' rows_with_1 rows_with_1 = 1⇩m (r^N_half)"
proof
(* The row dimension of the submatrix is the number of selected in-range rows,
   which card_rows_with_1 gives as r ^ N_half. *)
show "dim_row (submatrix Aw' rows_with_1 rows_with_1) = dim_row (1⇩m (r ^ N_half))"
unfolding index_one_mat(2) dim_submatrix(1)
by (metis (full_types) set_le_in card_rows_with_1)
(* The column dimension follows from the row dimension and dims_Aw'_pow. *)
show "dim_col (submatrix Aw' rows_with_1 rows_with_1) = dim_col (1⇩m (r ^ N_half))"
by (metis ‹dim_row (submatrix Aw' rows_with_1 rows_with_1) = dim_row (1⇩m (r ^ N_half))› dim_submatrix(1) dim_submatrix(2) index_one_mat(2) index_one_mat(3) dims_Aw'_pow)
show "⋀i j. i < dim_row (1⇩m (r ^ N_half)) ⟹
j < dim_col (1⇩m (r ^ N_half)) ⟹ submatrix Aw' rows_with_1 rows_with_1 $$ (i, j) = 1⇩m (r ^ N_half) $$ (i, j)"
proof -
fix i j assume "i < dim_row (1⇩m (r ^ N_half))" "j < dim_col (1⇩m (r ^ N_half))"
then have "i < r ^ N_half" "j < r ^ N_half" by auto
then have "i < card {i ∈ rows_with_1. i < dim_row Aw'}" "j < card {i ∈ rows_with_1. i < dim_col Aw'}"
using card_rows_with_1 dims_Aw'_pow by auto
(* The i-th and j-th picked elements of rows_with_1 are in range for Aw'. *)
then have "pick rows_with_1 i < dim_row Aw'" "pick rows_with_1 j < dim_col Aw'"
using card_le_pick_inf[OF infinite_rows_with_1, of "dim_row Aw'" i]
using card_le_pick_inf[OF infinite_rows_with_1, of "dim_col Aw'" j] by force+
(* A picked element lies in rows_with_1, so its digits satisfy the bound and
   the digit condition in witness_matricization is always met. *)
have "∀i0∈set (digit_encode (nths (dims Aw) (Collect even)) (pick rows_with_1 i)). i0 < last (butlast rs)"
using infinite_rows_with_1 pick_in_set_inf rows_with_1_def by auto
then have "Aw' $$ (pick rows_with_1 i, pick rows_with_1 j) = (if pick rows_with_1 i = pick rows_with_1 j then 1 else 0)"
using witness_matricization[OF ‹pick rows_with_1 i < dim_row Aw'› ‹pick rows_with_1 j < dim_col Aw'›] by simp
then have "submatrix Aw' rows_with_1 rows_with_1 $$ (i, j) = (if pick rows_with_1 i = pick rows_with_1 j then 1 else 0)"
using submatrix_index by (metis (no_types, lifting)
‹dim_col (submatrix Aw' rows_with_1 rows_with_1) = dim_col (1⇩m (r ^ N_half))›
‹dim_row (submatrix Aw' rows_with_1 rows_with_1) = dim_row (1⇩m (r ^ N_half))›
‹i < dim_row (1⇩m (r ^ N_half))› ‹j < r ^ N_half› dim_submatrix(1) dim_submatrix(2) index_one_mat(3))
(* Equality of picked elements is equivalent to equality of their positions
   (pick_eq_iff_inf), which turns the entry into an identity-matrix entry. *)
then have "submatrix Aw' rows_with_1 rows_with_1 $$ (i, j) = (if i = j then 1 else 0)"
using pick_eq_iff_inf[OF infinite_rows_with_1] by auto
then show "submatrix Aw' rows_with_1 rows_with_1 $$ (i, j) = 1⇩m (r ^ N_half) $$ (i, j)"
by (simp add: ‹i < r ^ N_half› ‹j < r ^ N_half›)
qed
qed
(* The submatrix of Aw' selected by rows_with_1 has a non-zero determinant:
   by witness_submatrix it is an identity matrix, and the simplifier
   evaluates the determinant of an identity matrix. *)
lemma witness_det: "det (submatrix Aw' rows_with_1 rows_with_1) ≠ 0"
by (simp add: witness_submatrix)
end
(* Concrete interpretation of the locale deep_model_correct_params with the
   arguments False and [10,10,10]; the locale assumptions are discharged by
   unfolding the locale predicate and simplifying.
   NOTE(review): presumably a sanity check that the locale assumptions are
   satisfiable - the meaning of the two arguments is defined with the locale,
   outside this section. *)
interpretation example : deep_model_correct_params False "[10,10,10]"
unfolding deep_model_correct_params_def by simp
(* Same arguments for the extended locale deep_model_correct_params_y, with 1
   as the additional argument. *)
interpretation example : deep_model_correct_params_y False "[10,10,10]" 1
unfolding deep_model_correct_params_y_def deep_model_correct_params_y_axioms_def
deep_model_correct_params_def by simp
end