(* Author: Alexander Bentkamp, Universität des Saarlandes
*)
section ‹Deep Learning Networks›

theory DL_Network
imports Tensor_Product
  Jordan_Normal_Form.Matrix Tensor_Unit_Vec DL_Flatten_Matrix 
  Jordan_Normal_Form.DL_Missing_List
begin

text ‹The symbol ‹⊗› is used for the tensor product, so we remove the conflicting notation for group multiplication:›
no_notation Group.monoid.mult (infixl "⊗ı" 70)

notation Matrix.unit_vec ("unitv")
hide_const (open) Matrix.unit_vec 


datatype 'a convnet = Input nat | Conv 'a "'a convnet" | Pool "'a convnet" "'a convnet"

fun input_sizes :: "'a convnet ⇒ nat list" where
"input_sizes (Input M) = [M]" |
"input_sizes (Conv A m) = input_sizes m" |
"input_sizes (Pool m1 m2) = input_sizes m1 @ input_sizes m2"

fun count_weights :: "bool ⇒ (nat × nat) convnet ⇒ nat" where
"count_weights shared (Input M) = 0" |
"count_weights shared (Conv (r0, r1) m) = r0 * r1 + count_weights shared m" |
"count_weights shared (Pool m1 m2) = 
  (if shared 
    then max (count_weights shared m1) (count_weights shared m2) 
    else count_weights shared m1 + count_weights shared m2)"
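
text ‹A sanity check (not part of the original development): each ‹Conv (2,3)› layer of the
example net holds ‹2 * 3 = 6› weights. With sharing, the two pooled subnets reuse the same
weights; without sharing, their counts add up:›

lemma "count_weights True (Pool (Conv (2,3) (Input 3)) (Conv (2,3) (Input 3))) = 6"
  by eval

lemma "count_weights False (Pool (Conv (2,3) (Input 3)) (Conv (2,3) (Input 3))) = 12"
  by eval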

fun output_size :: "(nat × nat) convnet ⇒ nat" where
"output_size (Input M) = M" |
"output_size (Conv (r0,r1) m) = r0" |
"output_size (Pool m1 m2) = output_size m1"

inductive valid_net :: "(nat×nat) convnet ⇒ bool" where
"valid_net (Input M)" |
"output_size m = r1 ⟹ valid_net m ⟹ valid_net (Conv (r0,r1) m)" |
"output_size m1 = output_size m2 ⟹ valid_net m1 ⟹ valid_net m2 ⟹ valid_net (Pool m1 m2)"


fun insert_weights :: "bool ⇒ (nat × nat) convnet ⇒ (nat ⇒ real) ⇒ real mat convnet" where
"insert_weights shared (Input M) w = Input M" |
"insert_weights shared (Conv (r0,r1) m) w = Conv
  (extract_matrix w r0 r1)
  (insert_weights shared m (λi. w (i+r0*r1)))" |
"insert_weights shared (Pool m1 m2) w = Pool
  (insert_weights shared m1 w)
  (insert_weights shared m2 (if shared then w else (λi. w (i+(count_weights shared m1)))))"
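
text ‹With shared weights, both pooled subnets receive the same weight function; a minimal
sketch of this behaviour (not part of the original development):›

lemma "insert_weights True (Pool m m) w = Pool (insert_weights True m w) (insert_weights True m w)"
  by simp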

fun remove_weights :: "real mat convnet ⇒ (nat × nat) convnet" where
"remove_weights (Input M) = Input M" |
"remove_weights (Conv A m) = Conv (dim_row A, dim_col A) (remove_weights m)" |
"remove_weights (Pool m1 m2) = Pool (remove_weights m1) (remove_weights m2)"

abbreviation "output_size' == (λm. output_size (remove_weights m))"
abbreviation "valid_net' == (λm. valid_net (remove_weights m))"

fun evaluate_net :: "real mat convnet ⇒ real vec list ⇒ real vec" where
"evaluate_net (Input M) inputs = hd inputs" |
"evaluate_net (Conv A m) inputs = A *v evaluate_net m inputs" |
"evaluate_net (Pool m1 m2) inputs = component_mult
  (evaluate_net m1 (take (length (input_sizes m1)) inputs))
  (evaluate_net m2 (drop (length (input_sizes m1)) inputs))"
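
text ‹A convolutional layer multiplies its weight matrix with the output of its subnet;
a minimal sketch (not part of the original development):›

lemma "evaluate_net (Conv A (Input (dim_col A))) [v] = A *v v"
  by simp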

definition mat_tensorlist_mult :: "real mat ⇒ real tensor vec ⇒ nat list ⇒ real tensor vec"
where "mat_tensorlist_mult A Ts ds
 = Matrix.vec (dim_row A) (λj. tensor_from_lookup ds (λis. (A *v (map_vec (λT. Tensor.lookup T is) Ts)) $j))"

lemma insert_weights_cong:
assumes "(⋀i. i<count_weights s m ⟹ w1 i = w2 i)"
shows "insert_weights s m w1 = insert_weights s m w2"
using assms proof (induction m arbitrary: w1 w2)
  case Input
  then show ?case by simp
next
  case (Conv r01 m)
  then obtain r0 r1 where "r01 = (r0,r1)" by (meson surj_pair)
  have 2:"insert_weights s m (λi. w1 (i + r0 * r1)) = insert_weights s m (λi. w2 (i + r0 * r1))" using Conv
    using ‹r01 = (r0, r1)› add.commute add_less_cancel_right count_weights.simps(2) by fastforce
  then show ?case unfolding ‹r01 = (r0,r1)› insert_weights.simps
    by (metis Conv.prems ‹r01 = (r0, r1)› count_weights.simps(2) extract_matrix_cong trans_less_add1)
next
  case (Pool m1 m2)
  have 1:"insert_weights s m1 w1 = insert_weights s m1 w2"
    using Pool(1)[of w1 w2] Pool(3)[unfolded count_weights.simps] 
    by (cases s; auto)
  have shared:"s=True ⟹ insert_weights s m2 w1 = insert_weights s m2 w2"
    using Pool(2)[of w1 w2] Pool(3)[unfolded count_weights.simps] by auto
  have unshared:"s=False ⟹ insert_weights s m2 (λi. w1 (i + count_weights s m1)) = insert_weights s m2 (λi. w2 (i + count_weights s m1))"
    using Pool(2) Pool(3) count_weights.simps by fastforce
  show ?case unfolding insert_weights.simps 1 using unshared shared by simp
qed

lemma dims_mat_tensorlist_mult:
assumes "T∈setv (mat_tensorlist_mult A Ts ds)"
shows "Tensor.dims T = ds"
proof -
  obtain j where "T = tensor_from_lookup ds (λis. (A *v (map_vec (λT. Tensor.lookup T is) Ts)) $j)"
    using vec_setE[OF assms, unfolded mat_tensorlist_mult_def] by (metis dim_vec index_vec)
  then show ?thesis by (simp add: length_tensor_vec_from_lookup tensor_from_lookup_def)
qed

fun tensors_from_net :: "real mat convnet ⇒ real tensor vec" where
"tensors_from_net (Input M) = Matrix.vec M (λi. unit_vec M i)" |
"tensors_from_net (Conv A m) = mat_tensorlist_mult A (tensors_from_net m) (input_sizes m)" |
"tensors_from_net (Pool m1 m2) = component_mult (tensors_from_net m1) (tensors_from_net m2)"

lemma output_size_correct_tensors:
assumes "valid_net' m"
shows "output_size' m = dim_vec (tensors_from_net m)"
using assms proof (induction m)
  case Input
  then show ?case by simp
next
  case (Conv A m)
  then show ?case
    unfolding remove_weights.simps output_size.simps tensors_from_net.simps
    using mat_tensorlist_mult_def by auto
next
  case (Pool m1 m2)
  then show ?case by (metis convnet.distinct(3) convnet.distinct(5) convnet.inject(3) dim_component_mult
    min.idem output_size.simps(3) remove_weights.simps(3) tensors_from_net.simps(3) valid_net.simps)
qed


lemma output_size_correct:
assumes "valid_net' m"
and "map dim_vec inputs = input_sizes m"
shows "output_size' m = dim_vec (evaluate_net m inputs)"
using assms proof (induction m arbitrary:inputs)
  case Input
  then show ?case using length_Cons list.map_sel(1) list.sel(1) list.simps(8) list.size(3) nat.simps(3) by auto
next
  case (Conv A m)
  then show ?case unfolding evaluate_net.simps remove_weights.simps output_size.simps dim_mult_mat_vec
    by auto
next
  case (Pool m1 m2)
  then have "valid_net' m1" "valid_net' m2"
    using convnet.distinct(3) convnet.distinct(5) convnet.inject(3) remove_weights.simps(3) valid_net.cases by fastforce+
  moreover have "map dim_vec (take (length (input_sizes m1)) inputs) = input_sizes m1"
       "map dim_vec (drop (length (input_sizes m1)) inputs) = input_sizes m2"
    using Pool.prems(2) by (metis append_eq_conv_conj drop_map input_sizes.simps(3) take_map)+
  ultimately have
    "output_size' m1 = dim_vec (evaluate_net m1 (take (length (input_sizes m1)) inputs))"
    "output_size' m2 = dim_vec (evaluate_net m2 (drop (length (input_sizes m1)) inputs))"
    using Pool.IH by blast+
  then show ?case unfolding evaluate_net.simps remove_weights.simps output_size.simps
    by (metis Pool.prems(1) ‹valid_net' m1› ‹valid_net' m2› dim_component_mult
     output_size.simps(3) output_size_correct_tensors remove_weights.simps(3) tensors_from_net.simps(3))
qed

lemma input_sizes_remove_weights: "input_sizes m = input_sizes (remove_weights m)"
  by (induction m; simp)

lemma dims_tensors_from_net:
assumes "T ∈ setv (tensors_from_net m)"
shows "Tensor.dims T = input_sizes m"
using assms proof (induction m arbitrary:T)
  case (Input M)
  then obtain j where "T = unit_vec M j"
    using vec_setE tensors_from_net.simps(1) by (metis dim_vec index_vec)
  then show ?case by (simp add: dims_unit_vec)
next
  case (Conv A m)
  then show ?case unfolding remove_weights.simps input_sizes.simps
    using dims_mat_tensorlist_mult by (simp add: input_sizes_remove_weights)
next
  case (Pool m1 m2 T)
  then obtain i where
    "component_mult (tensors_from_net m1) (tensors_from_net m2) $ i = T"
    "i < dim_vec (tensors_from_net m1)" "i < dim_vec (tensors_from_net m2)"
    using tensors_from_net.simps vec_setE dim_component_mult by (metis min.strict_boundedE)
  then obtain T1 T2 where "T = T1 ⊗ T2" "T1 ∈ setv (tensors_from_net m1)" "T2 ∈ setv (tensors_from_net m2)"
    using vec_setI by (metis index_component_mult)
  then show ?case unfolding remove_weights.simps input_sizes.simps by (simp add: Pool.IH(1) Pool.IH(2))
qed

definition base_input :: "real mat convnet ⇒ nat list ⇒ real vec list" where
"base_input m is = (map (λ(n, i). unitv n i) (zip (input_sizes m) is))"

lemma base_input_length:
assumes "is ⊲ input_sizes m"
shows "input_sizes m = map dim_vec (base_input m is)"
proof (rule nth_equalityI)
  have "length (input_sizes m) = length is" using assms valid_index_length by auto
  then show "length (input_sizes m) = length (map dim_vec (base_input m is))" unfolding base_input_def by auto
  {
    fix i
    assume "i<length (input_sizes m)"
    then have "map (λ(n, i). unitv n i) (zip (input_sizes m) is) ! i = unitv (input_sizes m ! i) (is ! i)"
      using ‹length (input_sizes m) = length is› by auto
    then have "input_sizes m ! i = map dim_vec (base_input m is) ! i" unfolding base_input_def using index_unit_vec(3)
      using ‹i < length (input_sizes m)› ‹length (input_sizes m) = length (map dim_vec (base_input m is))›
       base_input_def assms length_map nth_map valid_index_lt by (simp add: input_sizes_remove_weights)
  }
  then show "⋀i. i < length (input_sizes m) ⟹ input_sizes m ! i = map dim_vec (base_input m is) ! i" by auto
qed

lemma nth_mat_tensorlist_mult:
assumes "⋀A. A∈setv Ts ⟹ dims A = ds"
assumes "i < dim_row A"
assumes "dim_vec Ts = dim_col A"
shows "mat_tensorlist_mult A Ts ds $ i = listsum ds (map (λj. (A $$ (i,j)) ⋅ Ts $ j) [0..<dim_vec Ts])"
  (is "_ = listsum ds ?Ts'")
proof (rule tensor_lookup_eqI)
  have dims_Ts':"⋀T. T∈set ?Ts' ⟹ dims T = ds"
  proof -
    fix T assume "T∈set ?Ts'"
    then obtain k where "T = ?Ts' ! k" and "k < length ?Ts'" "k < dim_vec Ts" using in_set_conv_nth by force
    show "dims T = ds" unfolding ‹T = ?Ts' ! k›   nth_map[OF ‹k < length ?Ts'›[unfolded length_map]]
      using assms(1) ‹k < dim_vec Ts›
      by (simp add: ‹k < length (map (λj. A $$ (i, j) ⋅ Ts $ j) [0..<dim_vec Ts])› vec_setI)
  qed
  then show dims_eq:"dims (mat_tensorlist_mult A Ts ds $ i) = dims (Tensor_Plus.listsum ds (map (λj. A $$ (i, j) ⋅ Ts $ j) [0..<dim_vec Ts]))"
    using dims_mat_tensorlist_mult assms  mat_tensorlist_mult_def listsum_dims
    by (metis (no_types, lifting) dim_vec vec_setI)

  fix "is" assume is_valid:"is ⊲ dims (mat_tensorlist_mult A Ts ds $ i)"
  then have "is ⊲ ds" using dims_eq dims_Ts' listsum_dims by (metis (no_types, lifting))

  have summand_eq: "⋀j. j ∈ {0 ..< dim_vec Ts} ⟹ row A i $ j * (map_vec (λT. Tensor.lookup T is) Ts) $ j = lookup (A $$ (i, j) ⋅ Ts $ j) is"
    using index_vec ‹i<dim_row A› row_def ‹dim_vec Ts = dim_col A›
     ‹is ⊲ ds› assms(1) lookup_smult atLeastLessThan_iff index_map_vec(1) vec_setI by metis

  have "lookup (mat_tensorlist_mult A Ts ds $ i) is = (A *v (map_vec (λT. Tensor.lookup T is) Ts)) $ i"
    unfolding mat_tensorlist_mult_def using lookup_tensor_from_lookup[OF ‹is ⊲ ds›] using ‹i<dim_row A› by auto
  also have "... = row A i ∙ map_vec (λT. Tensor.lookup T is) Ts"
    using ‹i<dim_row A› by simp
  also have "... = (∑ j ∈ {0 ..< dim_vec Ts}. row A i $ j * (map_vec (λT. Tensor.lookup T is) Ts) $ j)"
    unfolding scalar_prod_def nth_rows[OF ‹i<dim_row A›] by simp
  also have "... = (∑j∈{0..<dim_vec Ts}. lookup (A $$ (i, j) ⋅ Ts $ j) is)" using summand_eq by force
  also have "... = (∑A←?Ts'. lookup A is)" unfolding map_map
    Groups_List.sum_set_upt_conv_sum_list_nat[symmetric]  atLeastLessThan_upt[symmetric] by auto
  also have "... = lookup (listsum ds ?Ts') is" using lookup_listsum[OF ‹is ⊲ ds›] dims_Ts' by fastforce
  finally show "lookup (mat_tensorlist_mult A Ts ds $ i) is = lookup (listsum ds ?Ts') is" by metis
qed

lemma lookup_tensors_from_net:
assumes "valid_net' m"
and "is ⊲ input_sizes m"
and "j < output_size' m"
shows "Tensor.lookup (tensors_from_net m $ j) is = evaluate_net m (base_input m is) $ j"
using assms proof (induction m arbitrary:j "is")
  case (Input M)
  then have "j < M" using output_size.simps(1) using Input by auto
  then have 1:"tensors_from_net (Input M) $ j = unit_vec M j" by simp
  obtain i where "is = [i]" "i<M" using Input Suc_length_conv input_sizes.simps(1) length_0_conv list.size(3) valid_index_length by auto
  then have 2:"Tensor.lookup (tensors_from_net (Input M) $ j) is = (if i=j then 1 else 0)" using lookup_unit_vec 1 by metis
  have "evaluate_net (Input M) (map (λ(n, i). unitv n i) (zip (input_sizes (Input M)) is)) = unitv M i" using ‹is = [i]› by auto
  then show ?case using 2 ‹j < M› base_input_def by (simp add: ‹i < M›)
next
  case (Conv A m j "is")
  have is_valid:"is ⊲ input_sizes m" using Conv.prems by simp
  have valid_net:"valid_net' m" using Conv.prems(1) unfolding remove_weights.simps
    using valid_net.simps convnet.distinct(1) convnet.distinct(5) convnet.inject(2) by blast
  then have length_em: "dim_vec (evaluate_net m (base_input m is)) = output_size' m"
    using output_size_correct base_input_length is_valid by metis

  have IH':"map_vec (λT. Tensor.lookup T is) (tensors_from_net m) =
                evaluate_net m (base_input m is)"
  proof (rule eq_vecI)
    show equal_lengths: "dim_vec (map_vec (λT. lookup T is) (tensors_from_net m))
      = dim_vec (evaluate_net m (base_input m is))" using length_em
      by (simp add: output_size_correct_tensors valid_net)
    show "⋀i. i < dim_vec (evaluate_net m (base_input m is)) ⟹
         map_vec (λT. lookup T is) (tensors_from_net m) $ i = evaluate_net m (base_input m is) $ i"
    proof -
      fix i
      assume "i < dim_vec (evaluate_net m (base_input m is))"
      then have "i < output_size' m" using equal_lengths length_em by auto
      then show "map_vec (λT. lookup T is) (tensors_from_net m) $ i
        = evaluate_net m (base_input m is) $ i"
        using Conv.IH is_valid equal_lengths valid_net base_input_def length_em nth_map_upt
        length_map nth_map by auto
    qed
  qed

  have "Tensor.lookup ((tensors_from_net (Conv A m)) $ j) is =
    (A *v (map_vec (λT. Tensor.lookup T is) (tensors_from_net m))) $ j"
  proof -
    have "dim_vec (tensors_from_net (Conv A m)) = output_size' (Conv A m)"
      using Conv by (simp add: mat_tensorlist_mult_def)
    then have "j<dim_vec (tensors_from_net (Conv A m))" using Conv.prems by auto
    then have "(tensors_from_net (Conv A m)) $ j =  tensor_from_lookup (input_sizes m)
                (λis. (A *v (map_vec (λT. Tensor.lookup T is) (tensors_from_net m))) $ j)"
      unfolding tensors_from_net.simps mat_tensorlist_mult_def by fastforce
    then show ?thesis
        using lookup_tensor_from_lookup[OF is_valid] by auto
  qed
  also have "(A *v (map_vec (λT. Tensor.lookup T is) (tensors_from_net m))) $ j
    = (A *v (evaluate_net m (base_input m is))) $ j" using IH' by auto
  also have "... = evaluate_net (Conv A m) (base_input (Conv A m) is) $ j"
    unfolding base_input_def using evaluate_net.simps by auto
  finally show ?case by auto
next
  case (Pool m1 m2 j "is")

  text ‹We split ‹is› into two parts, one for each subnet:›
  obtain is1 is2 where is12_def:"is = is1 @ is2" "is1 ⊲ input_sizes m1" "is2 ⊲ input_sizes m2"
    by (metis Pool.prems(2) input_sizes.simps(3) valid_index_split)

  text ‹Apply the induction hypothesis to the subnets:›
  have IH:"Tensor.lookup (tensors_from_net m1 $ j) is1
      = evaluate_net m1 (map (λ(x, y). unitv x y) (zip (input_sizes m1) is1)) $ j"
      "Tensor.lookup (tensors_from_net m2 $ j) is2
      = evaluate_net m2 (map (λ(x, y). unitv x y) (zip (input_sizes m2) is2)) $ j"
    using Pool convnet.distinct(3) convnet.distinct(5) convnet.inject(3) remove_weights.simps(3)
    valid_net.simps  ‹is1 ⊲ input_sizes m1› ‹is2 ⊲ input_sizes m2› output_size.simps(3)
    by (metis base_input_def)+

  text ‹In the ‹Pool› layer, corresponding tensor entries are multiplied:›
  have lookup_prod: "Tensor.lookup (tensors_from_net (Pool m1 m2) $ j) is
    = Tensor.lookup (tensors_from_net m1 $ j) is1 * Tensor.lookup (tensors_from_net m2 $ j) is2"
  proof -
    have j_small: "j < dim_vec (tensors_from_net m1)"  "j < dim_vec (tensors_from_net m2)"
      by (metis Pool.prems(1) Pool.prems(3) convnet.distinct(3) convnet.inject(3) convnet.simps(9)
      output_size.simps(3) output_size_correct_tensors remove_weights.simps(3) valid_net.cases)+
    then have 0:"tensors_from_net (Pool m1 m2) $ j = tensors_from_net m1 $ j ⊗ tensors_from_net m2 $ j"
      unfolding tensors_from_net.simps using j_small index_component_mult by blast
    have "Tensor.dims (tensors_from_net m1 $ j) = input_sizes m1"
         "Tensor.dims (tensors_from_net m2 $ j) = input_sizes m2"
      using dims_tensors_from_net j_small nth_mem by (simp_all add: vec_setI)
    then have is12_valid:
              "is1 ⊲ Tensor.dims (tensors_from_net m1 $ j)"
              "is2 ⊲ Tensor.dims (tensors_from_net m2 $ j)"
      using is12_def by presburger+
    then show ?thesis
      unfolding 0 using lookup_tensor_prod[OF is12_valid] is12_def by auto
  qed

  text ‹Output values are multiplied in the ‹Pool› layer as well:›
  have "evaluate_net (Pool m1 m2) (base_input (Pool m1 m2) is) $ j
    = evaluate_net m1 (base_input m1 is1) $ j * evaluate_net m2 (base_input m2 is2) $ j"
  proof -
    have "valid_net' m1" "valid_net' m2"
      using remove_weights.simps  valid_net.simps Pool.prems
      by (metis convnet.distinct(3) convnet.distinct(5) convnet.inject(3))+
    have "input_sizes m1 = map dim_vec (base_input m1 is1)"
         "input_sizes m2 = map dim_vec (base_input m2 is2)"
      using base_input_def base_input_length base_input_def is12_def by auto
    have "j < dim_vec (evaluate_net m1 (base_input m1 is1))" "j < dim_vec (evaluate_net m2 (base_input m2 is2))"
      using Pool.prems ‹input_sizes m1 = map dim_vec (base_input m1 is1)› ‹valid_net' m1›
      output_size_correct by (auto,metis Pool.prems(1) Pool.prems(3) ‹input_sizes m2 = map dim_vec (base_input m2 is2)›
      convnet.distinct(3) convnet.distinct(5) convnet.inject(3) output_size.simps(3) output_size_correct
      remove_weights.simps(3) valid_net.cases)
    then show ?thesis unfolding evaluate_net.simps unfolding base_input_def
      using is12_def(1) is12_def(2) valid_index_length by (simp add: append_eq_conv_conj drop_map
      drop_zip index_component_mult input_sizes_remove_weights take_map take_zip)
  qed

  then show ?case using lookup_prod IH base_input_def by auto
qed

primrec extract_weights::"bool ⇒ real mat convnet ⇒ nat ⇒ real" where
  extract_weights_Input: "extract_weights shared (Input M) = (λx. 0)"
| extract_weights_Conv: "extract_weights shared (Conv A m) = 
    (λx. if x < dim_row A * dim_col A then flatten_matrix A x 
         else extract_weights shared m (x - dim_row A * dim_col A))"
| extract_weights_Pool: "extract_weights shared (Pool m1 m2) = 
    (λx. if x < count_weights shared (remove_weights m1) 
         then extract_weights shared m1 x 
         else extract_weights shared m2 (x - count_weights shared (remove_weights m1)))"

inductive balanced_net::"(nat × nat) convnet ⇒ bool" where
  balanced_net_Input: "balanced_net (Input M)"
| balanced_net_Conv: "balanced_net m ⟹ balanced_net (Conv A m)"
| balanced_net_Pool: "balanced_net m1 ⟹ balanced_net m2 ⟹ 
    count_weights True m1 = count_weights True m2 ⟹ balanced_net (Pool m1 m2)"
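
text ‹The pooled example net from above is balanced, since both subnets hold the same
number of shared weights (a sanity check, not part of the original development):›

lemma "balanced_net (Pool (Conv (2,3) (Input 3)) (Conv (2,3) (Input 3)))"
  by (auto intro!: balanced_net.intros)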

inductive shared_weight_net::"real mat convnet ⇒ bool" where
  shared_weight_net_Input: "shared_weight_net (Input M)"
| shared_weight_net_Conv: "shared_weight_net m ⟹ shared_weight_net (Conv A m)"
| shared_weight_net_Pool: "shared_weight_net m1 ⟹ shared_weight_net m2 ⟹ 
  count_weights True (remove_weights m1) = count_weights True (remove_weights m2) ⟹ 
  (⋀x. x < count_weights True (remove_weights m1) ⟹ extract_weights True m1 x = extract_weights True m2 x)
   ⟹ shared_weight_net (Pool m1 m2)"

lemma insert_extract_weights_cong_shared:
assumes "shared_weight_net m"
assumes "⋀x. x < count_weights True (remove_weights m) ⟹ f x = extract_weights True m x"
shows "m = insert_weights True (remove_weights m) f"
using assms proof (induction m arbitrary:f)
case (shared_weight_net_Input M)
  then show ?case 
    by simp
next
  case (shared_weight_net_Conv m A)
  have "extract_matrix f (dim_row A) (dim_col A) = A"
    by (simp add: extract_matrix_cong extract_matrix_flatten_matrix shared_weight_net_Conv.prems)
  then show ?case
    using shared_weight_net_Conv.IH[of "(λi. f (i + dim_row A * dim_col A))"]
    using shared_weight_net_Conv.prems by auto
next
  case (shared_weight_net_Pool m1 m2)
  have "m1 = insert_weights True (remove_weights m1) f"
    using shared_weight_net_Pool.IH(1) shared_weight_net_Pool.prems by auto
  have "m2 = insert_weights True (remove_weights m2) f"
    using local.shared_weight_net_Pool(3) shared_weight_net_Pool.IH(2) 
    shared_weight_net_Pool.hyps(4) shared_weight_net_Pool.prems by fastforce
  then show ?case
    using ‹m1 = insert_weights True (remove_weights m1) f› by auto
qed

lemma insert_extract_weights_cong_unshared:
assumes "⋀x. x < count_weights False (remove_weights m) ⟹ f x = extract_weights False m x"
shows "m = insert_weights False (remove_weights m) f"
using assms proof (induction m arbitrary:f)
case (Input M)
  then show ?case 
    by simp
next
  case (Conv A m)
  then have "extract_matrix f (dim_row A) (dim_col A) = A"
    by (metis count_weights.simps(2) extract_matrix_flatten_matrix_cong extract_weights_Conv remove_weights.simps(2) trans_less_add1)
  then show ?case 
    using Conv.IH Conv.prems by auto
next
  case (Pool m1 m2)
  then show ?case
    using Pool.IH(1) Pool.IH(2) Pool.prems by auto
qed

lemma remove_insert_weights:
shows "remove_weights (insert_weights s m w) = m"
proof (induction m arbitrary:w)
  case Input
  then show ?case by simp
next
  case (Conv r12 m)
  then obtain r1 r2 where "r12 = (r1, r2)" by fastforce
  then have "remove_weights (insert_weights s m w) = m" using Conv.IH by blast
  then have "remove_weights (insert_weights s (Conv (r1,r2) m) w) = Conv (r1,r2) m"
    unfolding insert_weights.simps remove_weights.simps
    using extract_matrix_def Conv.IH dim_extract_matrix(1) by (metis dim_col_mat(1) )
  then show ?case using ‹r12 = (r1, r2)› by blast
next
  case (Pool m1 m2 w)
  then show ?case unfolding insert_weights.simps remove_weights.simps using Pool.IH by blast
qed

lemma extract_insert_weights_shared: 
assumes "x<count_weights True m"
and "balanced_net m"
shows "extract_weights True (insert_weights True m w) x = w x"
using assms
proof (induction m arbitrary:w x)
  case (Input x)
  then show ?case 
    by simp
next
  case (Conv r01 m)
  obtain r0 r1 where "r01 = (r0,r1)" by force
  then show ?case unfolding ‹r01 = (r0,r1)› insert_weights.simps extract_weights.simps 
    apply (cases "x < dim_row (extract_matrix w r0 r1) * dim_col (extract_matrix w r0 r1)")  
     apply (auto simp add: dim_extract_matrix(1) dim_extract_matrix(2) flatten_matrix_extract_matrix)
    using Conv.IH[of _ "λi. w (i + r0 * r1)"] Conv.prems(1) Conv.prems(2) ‹r01 = (r0, r1)› balanced_net.cases by force
next
  case (Pool m1 m2)
  then show ?case unfolding insert_weights.simps extract_weights.simps  remove_insert_weights
    apply (cases "x < count_weights True m1")
     apply (metis balanced_net.simps convnet.distinct(5) convnet.inject(3) count_weights.simps(1) not_less_zero)
    by (metis (no_types, lifting) balanced_net.simps convnet.distinct(5) convnet.inject(3) count_weights.simps(1) count_weights.simps(3) less_max_iff_disj not_less_zero)
qed

lemma shared_weight_net_insert_weights: "balanced_net m ⟹ shared_weight_net (insert_weights True m w)"
proof (induction m arbitrary:w)
  case (Input x)
  then show ?case using insert_weights.simps balanced_net.simps shared_weight_net.simps by metis
next
  case (Conv r01 m)
  then obtain r0 r1 where "r01 = (r0,r1)" by force
  then show ?case  unfolding ‹r01 = (r0,r1)› insert_weights.simps  
    by (metis Conv.IH Conv.prems balanced_net.simps convnet.distinct(1) convnet.distinct(5) convnet.inject(2) shared_weight_net_Conv)
next
  case (Pool m1 m2)
  have "balanced_net m1" "balanced_net m2"
    using Pool.prems balanced_net.simps by blast+              
  have "⋀x. x < count_weights True m1 ⟹
         extract_weights True (insert_weights True m1 w) x = extract_weights True (insert_weights True m2 w) x"
    using extract_insert_weights_shared 
    by (metis Pool.prems balanced_net.simps convnet.distinct(3) convnet.distinct(5) convnet.inject(3))
  then show ?case unfolding insert_weights.simps using Pool(1)[of w] Pool(2)[of w] 
    by (metis Pool.prems balanced_net.simps convnet.distinct(3) convnet.distinct(5) convnet.inject(3) remove_insert_weights shared_weight_net_Pool)
qed

lemma finite_valid_index: "finite {is. is ⊲ ds}"
proof (induction ds)
  case Nil
  then show ?case by (metis List.finite_set finite_subset length_0_conv list.set_intros(1) mem_Collect_eq subsetI valid_index_length)
next
  case (Cons d ds)
  have "{is. is ⊲ d # ds} ⊆ (⋃i<d. {i # is |is. is ⊲ ds})"
  proof (rule subsetI)
    fix "is" assume "is ∈ {is. is ⊲ d # ds}"
    then have "is ⊲ d # ds" by auto
    then obtain i is' where "is = i # is'" by blast
    then have "i<d" using ‹is ⊲ d # ds› by blast
    have "is' ⊲ ds" using ‹is = i # is'› ‹is ⊲ d # ds› by blast
    have "is ∈ {i # is |is. is ⊲ ds}" by (simp add: ‹is = i # is'› ‹is' ⊲ ds›)
    then show "is ∈ (⋃i<d. {i # is |is. is ⊲ ds})" using ‹i < d› by blast
  qed
  moreover have "⋀i. finite {i # is |is. is ⊲ ds}" by (simp add: Cons.IH)
  ultimately show "finite {is. is ⊲ d # ds}" by (simp add: finite_subset)
qed

lemma setsum_valid_index_split:
"(∑is | is ⊲ ds1 @ ds2. f is) = (∑is1 | is1 ⊲ ds1. (∑is2 | is2 ⊲ ds2. f (is1 @ is2)))"
proof -
  have 1:"((λ(is1, is2). is1 @ is2) ` ({is1. is1 ⊲ ds1} × {is2. is2 ⊲ ds2})) = {is. is ⊲ ds1 @ ds2}" (is "?A = ?B")
  proof (rule subset_antisym; rule subsetI)
    fix x assume "x ∈ ?A"
    then show "x ∈ ?B" using valid_index_append by auto
  next
    fix x assume "x ∈ ?B"
    then have "x ⊲ ds1 @ ds2" by auto
    then obtain x1 x2 where "x = x1 @ x2" "x1 ⊲ ds1" "x2 ⊲ ds2" by (metis valid_index_split)
    then have "(x1, x2) ∈ ({is1. is1 ⊲ ds1} × {is2. is2 ⊲ ds2})" by auto
    then show "x ∈ ?A" using imageI ‹x = x1 @ x2› by blast
  qed
  have 2:"inj_on (λ(is1, is2). is1 @ is2) ({is1. is1 ⊲ ds1} × {is2. is2 ⊲ ds2})"
    by (simp add: inj_on_def valid_index_length)
  show ?thesis
    unfolding Groups_Big.comm_monoid_add_class.sum.cartesian_product[of "λis1 is2. f (is1 @ is2)"]
    using Groups_Big.comm_monoid_add_class.sum.reindex[OF 2, of f] 1
     "2" SigmaE prod.simps(2) sum.reindex_cong by (simp add: split_def)
qed

lemma prod_lessThan_split:
fixes g :: "nat ⇒ real" shows "prod g {..<n+m} = prod g {..<n} * prod (λx. g (x+n)) {..<m}"
using Groups_Big.comm_monoid_mult_class.prod.union_inter_neutral[of "{..<n}" "{n..<n+m}" g, unfolded ivl_disj_un_one(2)[OF le_add1], OF finite_lessThan finite_atLeastLessThan]
by (metis (no_types) add.commute add.left_neutral atLeast0LessThan empty_iff ivl_disj_int_one(2) prod.shift_bounds_nat_ivl)

(* This is a nice lemma, but it is never used in the proof of the Fundamental Theorem of Network Capacity: *)
lemma evaluate_net_from_tensors:
assumes "valid_net' m"
and "map dim_vec inputs = input_sizes m"
and "j < output_size' m"
shows "evaluate_net m inputs $ j
  = (∑is∈{is. is ⊲ input_sizes m}. (∏k<length inputs. inputs ! k $ (is!k)) * Tensor.lookup (tensors_from_net m $ j) is)"
using assms proof (induction m arbitrary:j "is" inputs)
  case (Input M)
  then have "length inputs = 1" "input_sizes (Input M) = [M]" by auto
  {
    fix "is" assume "is ⊲ input_sizes (Input M)"
    then have "length is = 1" by (simp add: valid_index_length)
    then have "is = [hd is]" by (metis One_nat_def length_0_conv length_Suc_conv list.sel(1))
    then have "Tensor.lookup (tensors_from_net (Input M) $ j) is = (if hd is=j then 1 else 0)"
      by (metis Input.prems(3) ‹input_sizes (Input M) = [M]› ‹is ⊲ input_sizes (Input M)› list.distinct(1)
      lookup_unit_vec nth_Cons_0 output_size.simps(1) remove_weights.simps(1) tensors_from_net.simps(1) valid_indexE index_vec)
    then have "(∏k<length inputs. inputs ! k $ (is ! k)) * lookup (tensors_from_net (Input M) $ j) is =
               (if is=[j] then (∏k<length inputs. inputs ! k $ (is ! k)) else 0)" using ‹is = [hd is]› by auto
  }
  then have "(∑is | is ⊲ input_sizes (Input M). (∏k<length inputs. inputs ! k $ (is ! k)) * lookup (tensors_from_net (Input M) $ j) is)
   = (∑is | is ⊲ input_sizes (Input M). (if is=[j] then (∏k<length inputs. inputs ! k $ (is ! k)) else 0))"  by auto
  also have "(∑is | is ⊲ input_sizes (Input M). (if is=[j] then (∏k<length inputs. inputs ! k $ (is ! k)) else 0))
   = (∏k<length inputs. inputs ! k $ ([j] ! k))" unfolding sum.delta[OF finite_valid_index]
    using Input.prems(3) valid_index.Cons valid_index.Nil by auto
  also have "... = inputs ! 0 $ j" using ‹length inputs = 1› by (simp add: prod.lessThan_Suc)
  also have "... = evaluate_net (Input M) inputs $ j" unfolding evaluate_net.simps
    by (metis ‹length inputs = 1› hd_conv_nth list.size(3) zero_neq_one)
  finally show ?case by auto
next
  case (Conv A m j)
  have "j < dim_row A" using Conv.prems(3) by auto
  have 0:"⋀is. is ⊲ input_sizes (Conv A m) ⟹
  (∏k<length inputs. inputs ! k $ (is ! k)) * lookup (tensors_from_net (Conv A m) $ j) is =
  (∑i = 0..<dim_vec (tensors_from_net m). row A j $ i * ((∏k<length inputs. inputs ! k $ (is ! k)) * lookup (tensors_from_net m $ i) is))"
  proof -
    fix "is" assume "is ⊲ input_sizes (Conv A m)"
    then have "is ⊲ input_sizes m" by simp
    have 0:"lookup (tensors_from_net (Conv A m) $ j) is =
          (∑i = 0..<dim_vec (tensors_from_net m). row A j $ i * lookup (tensors_from_net m $ i) is)"
      unfolding tensors_from_net.simps mat_tensorlist_mult_def index_vec[OF ‹j < dim_row A›]
      lookup_tensor_from_lookup[OF ‹is ⊲ input_sizes m›] index_mult_mat_vec[OF ‹j < dim_row A›] scalar_prod_def
      using index_map_vec by auto
    show "(∏k<length inputs. inputs ! k $ (is ! k)) * lookup (tensors_from_net (Conv A m) $ j) is
      = (∑i = 0..<dim_vec (tensors_from_net m). row A j $ i * ((∏k<length inputs. inputs ! k $ (is ! k)) * lookup (tensors_from_net m $ i) is))"
      unfolding 0 sum_distrib_left by (simp add: semiring_normalization_rules(19))
  qed
  have "valid_net' m" by (metis Conv.prems(1) convnet.distinct(1) convnet.distinct(5) convnet.inject(2) remove_weights.simps(2) valid_net.simps)
  have "map dim_vec inputs = input_sizes m" by (simp add: Conv.prems(2))
  have "output_size' m = dim_vec (tensors_from_net m)" by (simp add: ‹valid_net' m› output_size_correct_tensors)
  have 1:"⋀i. i<dim_vec (tensors_from_net m) ⟹ (∑is | is ⊲ input_sizes (Conv A m). ((∏k<length inputs. inputs ! k $ (is ! k)) * lookup (tensors_from_net m $ i) is)) = evaluate_net m inputs $ i" unfolding input_sizes.simps
    using Conv.IH ‹valid_net' m› ‹map dim_vec inputs = input_sizes m› ‹output_size' m = dim_vec (tensors_from_net m)› by simp

  have "(∑is | is ⊲ input_sizes (Conv A m). (∏k<length inputs. inputs ! k $ (is ! k)) * lookup (tensors_from_net (Conv A m) $ j) is)
    = (∑i = 0..<dim_vec (tensors_from_net m). (∑is | is ⊲ input_sizes (Conv A m).  row A j $ i *  ((∏k<length inputs. inputs ! k $ (is ! k)) * lookup (tensors_from_net m $ i) is)))"
    using Groups_Big.comm_monoid_add_class.sum.swap 0 by auto
  also have "... = (∑i = 0..<dim_vec (tensors_from_net m). row A j $ i * (∑is | is ⊲ input_sizes (Conv A m). ((∏k<length inputs. inputs ! k $ (is ! k)) * lookup (tensors_from_net m $ i) is)))"
    by (simp add: sum_distrib_left)
  also have "... = (∑i = 0..<dim_vec (tensors_from_net m). row A j $ i * evaluate_net m inputs $ i)" using 1 by auto
  also have "... = row A j ∙ evaluate_net m inputs"
    by (metis (full_types) ‹map dim_vec inputs = input_sizes m› ‹output_size' m = dim_vec (tensors_from_net m)›
    ‹valid_net' m› output_size_correct scalar_prod_def)
  also have "... = (A *v evaluate_net m inputs) $ j" by (simp add: ‹j < dim_row A›)
  also have "... = evaluate_net (Conv A m) inputs $ j" by simp
  finally show ?case by auto
next
  case (Pool m1 m2 j)
  have "valid_net' m1" "valid_net' m2"
    by (metis Pool.prems(1) convnet.distinct(3) convnet.inject(3) convnet.simps(9) remove_weights.simps(3) valid_net.simps)+
  have "j < output_size' m2" "j < output_size' m1"
    apply (metis Pool.prems(1) Pool.prems(3) convnet.distinct(3) convnet.inject(3) convnet.simps(9)
    output_size.simps(3) remove_weights.simps(3) valid_net.simps) using Pool.prems by auto
  then have "j < dim_vec (tensors_from_net m1)" "j < dim_vec (tensors_from_net m2)"
    by (simp_all add: ‹valid_net' m1› ‹valid_net' m2› output_size_correct_tensors)

  define inputs1 where "inputs1 = take (length (input_sizes m1)) inputs"
  define inputs2 where "inputs2 = drop (length (input_sizes m1)) inputs"
  have "map dim_vec inputs1 = input_sizes m1" "map dim_vec inputs2 = input_sizes m2"
    apply (metis Pool.prems(2) append_eq_conv_conj input_sizes.simps(3) inputs1_def take_map)
    by (metis Pool.prems(2) append_eq_conv_conj drop_map input_sizes.simps(3) inputs2_def)
  have "inputs = inputs1 @ inputs2" by (simp add: inputs1_def inputs2_def)
  {
    fix is1 is2 assume "is1 ⊲ input_sizes m1" "is2 ⊲ input_sizes m2"
    have "length is1 = length inputs1"
      using ‹is1 ⊲ input_sizes m1› ‹map dim_vec inputs1 = input_sizes m1› valid_index_length by fastforce
    have "length is2 = length inputs2"
      using ‹is2 ⊲ input_sizes m2› ‹map dim_vec inputs2 = input_sizes m2› valid_index_length by fastforce
    have 1:"(∏k<length inputs1. (inputs1 @ inputs2) ! k $ ((is1 @ is2) ! k))  = (∏k<length inputs1. inputs1 ! k $ (is1 ! k))"
      using ‹length is1 = length inputs1› ‹length is2 = length inputs2›
      nth_append by (metis (no_types, lifting) lessThan_iff prod.cong)
    have 2:"(∏x<length inputs2. (inputs1 @ inputs2) ! (x + length inputs1) $ ((is1 @ is2) ! (x + length inputs1))) =
      (∏k<length inputs2. inputs2 ! k $ (is2 ! k))"
      using ‹length is1 = length inputs1› ‹length is2 = length inputs2›
      by (metis (no_types, lifting) add.commute nth_append_length_plus)
    have "(∏k<length inputs. inputs ! k $ ((is1 @ is2) ! k)) = (∏k<length inputs1. inputs1 ! k $ (is1 ! k)) * (∏k<length inputs2. inputs2 ! k $ (is2 ! k))"
      unfolding ‹inputs = inputs1 @ inputs2› length_append prod_lessThan_split using 1 2 by metis
  }
  note 1 = this
  {
    fix is1 is2 assume "is1 ⊲ input_sizes m1" "is2 ⊲ input_sizes m2"
    then have "is1 ⊲ dims (tensors_from_net m1 $ j)" "is2 ⊲ dims (tensors_from_net m2 $ j)"
      using ‹j < dim_vec (tensors_from_net m1)›  ‹j < dim_vec (tensors_from_net m2)› dims_tensors_from_net vec_setI by force+
    have "lookup (tensors_from_net (Pool m1 m2) $ j) (is1 @ is2) = lookup (tensors_from_net m1 $ j) is1 * lookup (tensors_from_net m2 $ j) is2"
      unfolding "tensors_from_net.simps" index_component_mult[OF ‹j < dim_vec (tensors_from_net m1)› ‹j < dim_vec (tensors_from_net m2)›]
      lookup_tensor_prod[OF ‹is1 ⊲ dims (tensors_from_net m1 $ j)› ‹is2 ⊲ dims (tensors_from_net m2 $ j)›] by metis
  }
  note 2 = this

  have j_le_eval:"j < dim_vec (evaluate_net m1 (take (length (input_sizes m1)) inputs))"
                 "j < dim_vec (evaluate_net m2 (drop (length (input_sizes m1)) inputs))"
    using ‹j < output_size' m1› ‹map dim_vec inputs1 = input_sizes m1› ‹valid_net' m1› inputs1_def output_size_correct
    using ‹j < output_size' m2› ‹map dim_vec inputs2 = input_sizes m2› ‹valid_net' m2› inputs2_def by auto
  have "(∑is | is ⊲ input_sizes (Pool m1 m2). (∏k<length inputs. inputs ! k $ (is ! k)) * lookup (tensors_from_net (Pool m1 m2) $ j) is)
        = (∑is1 | is1 ⊲ input_sizes m1. ∑is2 | is2 ⊲ input_sizes m2.
          (∏k<length inputs1. inputs1 ! k $ (is1 ! k)) * (∏k<length inputs2. inputs2 ! k $ (is2 ! k)) *
          lookup (tensors_from_net m1 $ j) is1 * lookup (tensors_from_net m2 $ j) is2)"
    unfolding input_sizes.simps setsum_valid_index_split using 1 2
    using mem_Collect_eq sum.cong by (simp add: mult.assoc)
  also have "... = (∑is1 | is1 ⊲ input_sizes m1. (∏k<length inputs1. inputs1 ! k $ (is1 ! k)) * lookup (tensors_from_net m1 $ j) is1) *
                   (∑is2 | is2 ⊲ input_sizes m2. (∏k<length inputs2. inputs2 ! k $ (is2 ! k)) * lookup (tensors_from_net m2 $ j) is2)"
    unfolding sum_product by (rule sum.cong, metis, rule sum.cong, metis, simp)
  also have "... = evaluate_net (Pool m1 m2) inputs $ j" unfolding "evaluate_net.simps" index_component_mult[OF j_le_eval]
    using Pool.IH(1)[OF ‹valid_net' m1› _ ‹j < output_size' m1›] Pool.IH(2)[OF ‹valid_net' m2› _ ‹j < output_size' m2›]
    using ‹map dim_vec inputs1 = input_sizes m1› ‹map dim_vec inputs2 = input_sizes m2› inputs1_def inputs2_def by auto
  finally show ?case by metis
qed

lemma tensors_from_net_eqI:
assumes "valid_net' m1" "valid_net' m2" "input_sizes m1 = input_sizes m2"
assumes "⋀inputs. input_sizes m1 = map dim_vec inputs ⟹ evaluate_net m1 inputs = evaluate_net m2 inputs"
shows "tensors_from_net m1 = tensors_from_net m2"
proof -
  have "map dim_vec (map 0v (input_sizes m2)) = input_sizes m2"
       "map dim_vec (map 0v (input_sizes m1)) = input_sizes m1" by (auto intro: nth_equalityI)
  then have "output_size' m1 = output_size' m2" using
    output_size_correct[OF ‹valid_net' m1› ‹map dim_vec (map 0v (input_sizes m1)) = input_sizes m1›]
    output_size_correct[OF ‹valid_net' m2› ‹map dim_vec (map 0v (input_sizes m2)) = input_sizes m2›]
    assms(3) assms(4)
    by (metis (no_types))
  have "⋀is. base_input m1 is = base_input m2 is"
    unfolding base_input_def ‹input_sizes m1 = input_sizes m2› by metis
  show ?thesis by (rule eq_vecI, rule tensor_lookup_eqI; metis
    lookup_tensors_from_net[OF ‹valid_net' m1›, unfolded ‹⋀is. base_input m1 is = base_input m2 is› ‹output_size' m1 = output_size' m2›]
    lookup_tensors_from_net[OF ‹valid_net' m2›] assms(3) base_input_length
    assms(1) assms(2) dims_tensors_from_net output_size_correct_tensors vec_setI
    ‹output_size' m1 = output_size' m2› assms(4))
qed

end