Theory Fishers_Inequality

(* Title: Fishers_Inequality.thy
   Author: Chelsea Edmonds
*)

section ‹Fisher's Inequality›

text ‹This theory presents proofs of Fisher's Inequality cite"fisherExaminationDifferentPossible1940a"
 on BIBDs (i.e. the uniform Fisher's Inequality) and of the generalised non-uniform Fisher's Inequality ›
theory Fishers_Inequality imports Rank_Argument_General Linear_Bound_Argument
begin

subsection ‹ Uniform Fisher's Inequality ›
context ordered_bibd
begin

text ‹Row/Column transformation steps ›

text‹Following design theory lecture notes from MATH3301 at the University of Queensland
 cite"HerkeLectureNotes2016", a simple transformation to produce an upper triangular matrix using elementary
row and column operations is to (1) subtract the first row from each other row, and (2) add all columns to the first column›

(* Step (1) of the triangularisation of M = N * N^T: subtract row 0 from every
   other row, i.e. add_row_to_multiple (-1) [1..<dim_row M] 0 M.
   For i < dim_row M and j < dim_col M the conclusions give, in order:
     (1) i = 0, j = 0                  : int r            (top-left entry)
     (2) i non-zero, j = 0             : int Λ - int r    (rest of first column)
     (3) i = 0, j non-zero             : int Λ            (rest of first row)
     (4) i, j non-zero, i = j          : int r - int Λ    (rest of the diagonal)
     (5) i, j non-zero, i and j differ : 0                (everything else)
   Entries of N * N^T are read from transpose_N_mult_diag (diagonal, int r)
   and transpose_N_mult_off_diag (off-diagonal, int Λ). *)
lemma transform_N_step1_vals: 
  defines mdef: "M  (N * NT)"
  assumes "i < dim_row M" 
  assumes "j < dim_col M"
  shows "i = 0  j = 0  (add_row_to_multiple (-1) [1..<dim_row M] 0 M) $$ (i, j) = (int 𝗋)" ― ‹ top left elem ›
  and "i  0  j = 0  (add_row_to_multiple (-1) [1..<dim_row M] 0 M) $$ (i, j) = (int Λ) - (int 𝗋)" ― ‹ first column ex. 1 ›
  and "i = 0  j  0  (add_row_to_multiple (-1) [1..<dim_row M] 0 M) $$ (i, j) = (int Λ)" ― ‹ first row ex. 1 ›
  and "i  0  j  0  i = j  (add_row_to_multiple (-1) [1..<dim_row M] 0 M) $$ (i, j) = (int 𝗋) - (int Λ)" ― ‹ diagonal ex. 1 ›
  and "i  0  j  0  i  j  (add_row_to_multiple (-1) [1..<dim_row M] 0 M) $$ (i, j) = 0" ― ‹ everything else ›
  using transpose_N_mult_diag v_non_zero assms
proof (simp) 
  (* Case (1) is not stated below; it is left to the initial simp call.
     Case (3): row 0 is not in [1..<dim_row M], so it is unchanged and keeps
     the off-diagonal value int Λ. *)
  show "i = 0  j  0  (add_row_to_multiple (-1) [1..<dim_row M] 0 M) $$ (i, j) = (int Λ)" 
    using transpose_N_mult_off_diag v_non_zero assms transpose_N_mult_dim(2) by force
next
  (* Case (2): the entry int Λ gets (-1) times the row-0 entry int r added. *)
  assume a: "j = 0" "i0"
  then have ail: "((-1) * M $$(0, j)) = -(int 𝗋)"
    using transpose_N_mult_diag v_non_zero mdef by auto 
  then have ijne: "j  i" using a by simp
  then have aij: "M $$ (i, j) = (int Λ)" using assms(2) mdef transpose_N_mult_off_diag a v_non_zero
    by (metis transpose_N_mult_dim(1)) 
  then have "add_row_to_multiple (-1) [1..<dim_row M] 0 M $$ (i, j) = (-1)*(int 𝗋) + (int Λ)" 
    using ail add_first_row_to_multiple_index(2) assms(2) assms(3) a by (metis mult_minus1) 
  then show "(add_row_to_multiple (-1) [1..<dim_row M] 0 M) $$ (i, j) = (int Λ) - (int 𝗋)"
    by linarith 
next 
  (* Cases (4) and (5): for non-zero j the row-0 entry is int Λ, so subtracting
     it leaves int r - int Λ on the diagonal and int Λ - int Λ = 0 elsewhere. *)
  assume a: "i  0" "j  0" 
  have ail: "((-1) * M $$(0, j)) = -(int Λ)" 
    using assms transpose_N_mult_off_diag a v_non_zero transpose_N_mult_dim(1) by auto  
  then have  "i = j  M $$ (i, j) = (int 𝗋)" 
    using assms transpose_N_mult_diag a v_non_zero by (metis transpose_N_mult_dim(1)) 
  then show "i = j  (add_row_to_multiple (-1) [1..<dim_row M] 0 M) $$ (i, j) = (int 𝗋) - (int Λ)"
    using ail add_first_row_to_multiple_index(2) assms a by (metis uminus_add_conv_diff) 
  then have "i  j  M $$ (i, j) = (int Λ)" using assms transpose_N_mult_off_diag a v_non_zero
    by (metis transpose_N_mult_dim(1) transpose_N_mult_dim(2)) 
  then show "i  j  (add_row_to_multiple (-1) [1..<dim_row M] 0 M) $$ (i, j) = 0" 
    using ail add_first_row_to_multiple_index(2) assms a by (metis add.commute add.right_inverse)
qed

(* Step (2) of the triangularisation: starting from the step-(1) matrix M,
   add every column in [1..<dim_col M] into column 0
   (add_multiple_cols 1 0 [1..<dim_col M] M). Conclusions, in order:
     (1) i = 0, j = 0                  : int r + int Λ * (v - 1)
     (2) i = 0, j non-zero             : int Λ
     (3) i non-zero, i = j             : int r - int Λ
     (4) i non-zero, i and j differ    : 0
   so every row other than row 0 keeps only its diagonal entry. *)
lemma transform_N_step2_vals: 
  defines mdef: "M  (add_row_to_multiple (-1) [1..<dim_row (N * NT)] 0 (N * NT))"
  assumes "i < dim_row (M)"
  assumes "j < dim_col (M)"
  shows "i = 0  j = 0  add_multiple_cols 1 0 [1..<dim_col M] M $$ (i, j) = 
    (int 𝗋) + (int Λ) * (𝗏 - 1)" ― ‹ top left element ›
  and "i = 0  j  0  add_multiple_cols 1 0 [1..<dim_col M] M $$ (i, j)  = (int Λ)" ― ‹ top row ›
  and "i  0  i = j  add_multiple_cols 1 0 [1..<dim_col M] M $$ (i, j) = (int 𝗋) - (int Λ)" ― ‹ Diagonal ›
  and "i  0  i  j   add_multiple_cols 1 0 [1..<dim_col M] M $$ (i, j) = 0" ― ‹Everything else›
proof -
  (* Cases (2) and (3): these entries lie outside column 0, so the step-(1)
     values from transform_N_step1_vals carry over. *)
  show "i = 0  j  0  add_multiple_cols 1 0 [1..<dim_col M] M $$ (i, j)  = (int Λ)"
    using add_all_cols_to_first assms transform_N_step1_vals(3) by simp
  show "i  0  i = j  add_multiple_cols 1 0 [1..<dim_col M] M $$ (i, j) = (int 𝗋) - (int Λ)"
    using add_all_cols_to_first assms transform_N_step1_vals(4) by simp
next
  (* Case (1): the top-left entry int r gains the v - 1 row-0 entries int Λ. *)
  assume a: "i = 0" "j =0"
  then have size: "card {1..<dim_col M} = 𝗏 - 1" using assms by simp 
  have val: " l . l  {1..<dim_col M}  M $$ (i, l) = (int Λ)" 
    using mdef transform_N_step1_vals(3) by (simp add: a(1))
  have "add_multiple_cols 1 0 [1..<dim_col M] M $$ (i, j) = (l{1..<dim_col M}.  M $$(i,l)) + M$$(i,0)" 
    using a assms add_all_cols_to_first by blast 
  also have "... = (l{1..<dim_col M}.  (int Λ)) + M$$(i,0)" using val by simp
  also have "... = (𝗏 - 1) * (int Λ) + M$$(i,0)" using size by (metis sum_constant) 
  finally show "add_multiple_cols 1 0 [1..<dim_col M] M $$ (i, j) = (int 𝗋) + (int Λ) * (𝗏 - 1)" 
    using transform_N_step1_vals(1) a(1) a(2) assms(1) assms(2) by simp 
next
  (* Case (4): split on whether the column j is 0. *)
  assume a: "i  0"  "i  j"
  then show "add_multiple_cols 1 0 [1..<dim_col M] M $$ (i, j) = 0" 
  proof (cases "j  0")
    case True
    (* Column j is not column 0: the step-(1) value 0 is unchanged. *)
    then show ?thesis using add_all_cols_to_first assms a transform_N_step1_vals(5) by simp
  next 
    case False
    (* Column 0 of row i: the new entry is the row sum. The diagonal term is
       int r - int Λ, the old column-0 entry is int Λ - int r, and all other
       terms are 0, so the total is 0. *)
    then have iin: "i  {1..<dim_col M}" using a(1) assms by simp
    have cond: " l . l  {1..<dim_col M}  l <dim_col (N * NT)  l  0" using assms by simp 
    then have val: " l . l  {1..<dim_col M } - {i}  M $$ (i, l) = 0" 
      using assms(3) transform_N_step1_vals(5) a False assms(1)
      by (metis DiffE iin index_mult_mat(2) index_mult_mat(3) index_transpose_mat(3) insertI1) 
    have val2: "M $$ (i, i) = (int 𝗋) - (int Λ)" using mdef transform_N_step1_vals(4) a False
       assms(1) transpose_N_mult_dim(1) transpose_N_mult_dim(2)
      by (metis cond iin) 
    have val3: "M$$ (i, 0) = (int Λ) - (int 𝗋)" 
      using assms(3) transform_N_step1_vals(2) a False assms(1) assms(2)
      by (metis add_row_to_multiple_dim(1) transpose_N_mult_dim(2) v_non_zero)
    have "add_multiple_cols 1 0 [1..<dim_col M] M $$ (i, j) = (l{1..<dim_col M}.  M $$(i,l)) + M$$(i,0)" 
      using assms add_all_cols_to_first False by blast
    also have "... = M $$ (i, i)  + (l{1..<dim_col M} - {i}.  M $$(i,l)) + M$$(i,0)"
      by (metis iin finite_atLeastLessThan sum.remove)
    finally show ?thesis using val val2 val3 by simp
  qed
qed

text ‹Transformed matrix is upper triangular ›
(* Applying both transformation steps to N * N^T gives an upper triangular
   matrix. Conclusion (4) of transform_N_step2_vals makes every entry (i, j)
   with i non-zero and i different from j equal to 0. That covers every entry
   below the diagonal. *)
lemma transform_upper_triangular: 
  defines mdef: "M  (add_row_to_multiple (-1) [1..<dim_row (N * NT)] 0 (N * NT))"
  shows "upper_triangular (add_multiple_cols 1 0 [1..<dim_col M] M)"
  using transform_N_step2_vals(4) by (intro upper_triangularI) (simp add: assms)

text ‹Find the determinant of the $NN^T$ matrix using transformed matrix values›
(* Closed form of det (N * N^T): (r + Λ * (v - 1)) * (r - Λ)^(v - 1).
   The proof applies the two transformation steps (neither changes the
   determinant), then takes the product of the diagonal of the resulting
   upper triangular matrix. *)
lemma determinant_inc_mat_square: "det (N * NT) = (𝗋 + Λ * (𝗏 - 1))* (𝗋 - Λ)^(𝗏 - 1)"
proof -
― ‹ The transformed matrix is upper triangular, so its determinant is the product of its diagonal entries ›
  have cm: "(N * NT)  carrier_mat 𝗏 𝗏"
    using transpose_N_mult_dim(1) transpose_N_mult_dim(2) by blast  
  (* C: step (1), subtract row 0 from every other row; the determinant is
     preserved (add_row_to_multiple_det). *)
  define C where "C (add_row_to_multiple (-1) [1..<dim_row (N * NT)] 0 (N * NT))"
  have "0  set [1..<dim_row (N * NT)]" by simp
  then have detbc: "det (N * NT) = det C" 
    using C_def add_row_to_multiple_det v_non_zero by (metis cm) 
  (* D: step (2), add all other columns to column 0; the determinant is
     preserved (add_multiple_cols_det, used for det D below). *)
  define D where "D  add_multiple_cols 1 0 [1..<dim_col C] C"
  have d00: "D $$ (0, 0) = ((int 𝗋) + (int Λ) * (𝗏 - 1))" using transform_N_step2_vals(1)
    by (simp add: C_def D_def v_non_zero) 
  have ine0: " i. i  {1..<dim_row D}  i  0" by simp
  have " i. i  {1..<dim_row D}  i < dim_row (N * NT)" using D_def C_def by simp       
  then have diagnon0: " i. i  {1..<𝗏}  D $$ (i, i) = (int 𝗋) - (int Λ)"   
    using transform_N_step2_vals(3) ine0 D_def C_def by simp (* Slow *)
  have alll: " l. l  set [1..<dim_col C]  l < 𝗏" using C_def by simp
  have cmc: "C  carrier_mat 𝗏 𝗏" using cm C_def
    by (simp add: add_row_to_multiple_carrier)
  have dimgt2: "dim_row D  2"
    using t_lt_order D_def C_def by (simp)  
  (* NOTE(review): fstterm is not referenced again in this proof. *)
  then have fstterm: "0  { 0 ..< dim_row D}" by (simp add: points_list_length)
  have "0  set [1..<dim_col C]" by simp
  then have "det (N * NT) = det D" using add_multiple_cols_det alll cmc D_def C_def
    by (metis detbc) 
  also have "... = prod_list (diag_mat D)" using det_upper_triangular
    using transform_upper_triangular D_def C_def by fastforce 
  also have "... = ( i = 0 ..< dim_row D. D $$ (i,i))" using prod_list_diag_prod by blast  
  also have "... = ( i = 0 ..< 𝗏. D $$ (i,i))"  by (simp add: D_def C_def)  
  (* Split the (0,0) factor off the diagonal product. *)
  finally have "det (N * NT) = D $$ (0, 0) * ( i =  1 ..< 𝗏. D $$ (i,i))" 
    using dimgt2 by (simp add: prod.atLeast_Suc_lessThan v_non_zero) 
  then have "det (N * NT) = (𝗋 + Λ * (𝗏 - 1)) * ((int 𝗋) - (int Λ))^(𝗏 - 1)"
    using d00 diagnon0 by simp
  (* index_lt_replication (presumably Λ < r) justifies rewriting the int
     difference as the nat subtraction r - Λ via of_nat_diff. *)
  then have "det (N * NT) = (𝗋 + Λ * (𝗏 - 1)) * ( 𝗋 - Λ)^(𝗏 - 1)"
    using index_lt_replication
    by (metis (no_types, lifting) less_imp_le_nat of_nat_diff of_nat_mult of_nat_power)
  then show ?thesis by blast 
qed

text ‹Fisher's Inequality using the rank argument. 
Note that to use the rank argument we must first map N to a real matrix. It is useful to explicitly
include the parameters which should be used in the application of the @{thm [source] "rank_argument_det"} lemma ›
(* Uniform Fisher's inequality: v is at most b.
   rank_argument_det is instantiated with the real matrix map_mat real_of_int N
   and the dimensions v and b. The obligations discharged below are:
     - N has dimensions v by length ℬs;
     - det of (real N) times its transpose is non-zero.
   That determinant is the closed form from determinant_inc_mat_square. Both
   of its factors are positive: r + Λ * (v - 1) by r_gzero, and r - Λ by
   index_lt_replication. *)
theorem Fishers_Inequality_BIBD: "𝗏  𝖻"
proof (intro rank_argument_det[of "map_mat real_of_int N" "𝗏" "𝖻"], simp_all)
  show "N  carrier_mat 𝗏 (length ℬs)" using blocks_list_length N_carrier_mat by simp
  let ?B = "map_mat (real_of_int) (N * NT)"
  (* Mapping to reals commutes with the product and the transpose. *)
  have b_split: "?B = map_mat (real_of_int) N * (map_mat (real_of_int) N)T"
    using semiring_hom.mat_hom_mult  of_int_hom.semiring_hom_axioms transpose_carrier_mat map_mat_transpose
    by (metis (no_types, lifting) N_carrier_mat) 
  have db: "det ?B = (𝗋 + Λ * (𝗏 - 1))* (𝗋 - Λ)^(𝗏 - 1)"
    using determinant_inc_mat_square by simp
  have lhn0: "(𝗋 + Λ * (𝗏 - 1)) > 0"
    using r_gzero by blast 
  have "(𝗋 - Λ) > 0"
    using index_lt_replication zero_less_diff by blast  
  then have det_not_0:  "det ?B  0" using lhn0 db
    by (metis gr_implies_not0 mult_is_0 of_nat_eq_0_iff power_not_zero) 
  thus "det (of_int_hom.mat_hom N * (of_int_hom.mat_hom N)T)  (0:: real)" using b_split by simp
qed

end

subsection ‹Generalised Fisher's Inequality ›

context simp_ordered_const_intersect_design
begin

text ‹Lemma to reason on sum coefficients ›
(* Coefficient lemma for the generalised Fisher's inequality.
   Hypotheses: b is at least 2, m is positive, and
     0 = (sum over j < b of c_j^2 * (|B_j| - m)) + m * (sum over j < b of c_j)^2.
   Conclusion: c_j' = 0 for any j' < b.
   Proof by contradiction, assuming c_j' is non-zero:
     - each term of the first sum is shown non-negative (via
       inter_num_le_block_size), so both summands must be 0;
     - if |B_j'| = m, then lhs0inner and const_intersect_block_size_diff give
       c_j = 0 for every other j, so the sum of the c_j equals c_j', which is
       non-zero. This contradicts the second summand being 0;
     - otherwise |B_j'| - m is non-zero, so the j'-th term of the first sum
       is non-zero. This contradicts lhs0inner. *)
lemma sum_split_coeffs_0: 
  fixes c :: "nat  real"
  assumes "𝖻  2"
  assumes "𝗆 > 0"
  assumes "j' < 𝖻"
  assumes "0 = ( j  {0..<𝖻} . (c j)^2 * ((card (ℬs ! j))- (int 𝗆))) +
           𝗆 * (( j  {0..<𝖻} . c j)^2)"
  shows "c j' = 0"
proof (rule ccontr)
  assume cine0: "c j'  0"
  have innerge: " j . j < 𝖻   (c j)^2 * (card (ℬs ! j) - (int 𝗆))   0" 
    using inter_num_le_block_size assms(1) by simp
  then have lhsge: "( j  {0..<𝖻} . (c j)^2 * ((card (ℬs ! j))- (int 𝗆)))  0"
    using sum_bounded_below[of "{0..<𝖻}" "0" "λ i. (c i)^2 * ((card (ℬs ! i))- (int 𝗆))"] by simp
  have "𝗆 * (( j  {0..<𝖻} . c j)^2)  0" by simp
  (* Both non-negative summands of the hypothesis must vanish. *)
  then have rhs0: "𝗆 * (( j  {0..<𝖻} . c j)^2) = 0" 
    using assms(2) assms(4) lhsge by linarith
  then have "( j  {0..<𝖻} . (c j)^2 * ((card (ℬs ! j))- (int 𝗆))) = 0" 
    using assms by linarith
  then have lhs0inner: " j . j < 𝖻   (c j)^2 * (card (ℬs ! j) - (int 𝗆)) = 0" 
    using innerge sum_nonneg_eq_0_iff[of "{0..<𝖻}" "λ j . (c j)^2 * (card (ℬs ! j) - (int 𝗆))"] 
    by simp
  thus False proof (cases "card (ℬs ! j') = 𝗆")
    case True
    (* Block j' has size m: all other coefficients are 0, so the coefficient
       sum is c j', which cannot be 0. *)
    then have cj0: " j. j  {0..<𝖻} - {j'}  (c j) = 0"
      using lhs0inner const_intersect_block_size_diff assms True by auto
    then have "( i  {0..<𝖻} . c i)  0" 
      using sum.remove[of "{0..<𝖻}" "j'" "λ i. c i"] assms(3) cine0 cj0 by simp
    then show ?thesis using rhs0 assms by simp
  next
    case False
    (* Block j' has size different from m: its term in the first sum is
       non-zero. *)
    then have ne: "(card (ℬs ! j') - (int 𝗆))  0"
      by linarith  
    then have "(c j')^2 * (card (ℬs ! j') - (int 𝗆))  0" using cine0
      by auto 
    then show ?thesis using lhs0inner assms(3) by auto
  qed
qed

text ‹The general non-uniform version of Fisher's inequality is also known as the "Block town problem".
In this case we are working in a constant intersection design, hence the inequality is the opposite
way around compared to the BIBD version. The theorem below is the more traditional set-theoretic 
approach. This proof is based on one by Jukna cite"juknaExtremalCombinatorics2011"›
(* Generalised (non-uniform) Fisher's inequality: b is at most v, for the
   simple ordered constant-intersection design with intersection number m.
   - The degenerate case "m = 0 or b = 1" is handled by
     empty_inter_implies_b_lt_v.
   - Otherwise lin_bound_argument2 is applied to NR = lift_01_mat N, a real
     matrix with dimensions v by b. The obligations are:
       * the columns of NR are distinct;
       * any combination (sum over j of f(col j) times col j) equal to the
         zero vector has every coefficient f v = 0.
   - For the second obligation, the combination's scalar product with itself
     is expanded to
       (sum of c_j^2 * (|B_j| - m)) + m * (sum of c_j)^2.
     sum_split_coeffs_0 is then applied to that expansion. *)
theorem general_fishers_inequality:  "𝖻  𝗏"
proof (cases "𝗆 = 0  𝖻 = 1")
  case True
  then show ?thesis using empty_inter_implies_b_lt_v v_non_zero by linarith
next
  case False
  then have mge: "𝗆 > 0" by simp 
  then have bge: "𝖻  2" using b_positive False blocks_list_length by linarith
  (* NR: the incidence matrix N lifted to a real matrix via lift_01_mat. *)
  define NR :: "real mat" where "NR  lift_01_mat N"
  show ?thesis 
  proof (intro lin_bound_argument2[of NR])
    show "distinct (cols NR)" using lift_01_distinct_cols_N NR_def by simp
    show nrcm: "NR  carrier_mat 𝗏 𝖻" using NR_def N_carrier_mat_01_lift by simp 
    (* Column scalar products: column i with itself gives the size of block i;
       two distinct columns give the intersection number m. *)
    have scalar_prod_real1: " i. i <𝖻   ((col NR i)  (col NR i)) = card (ℬs ! i)"
      using scalar_prod_block_size_lift_01 NR_def by auto 
    have scalar_prod_real2: " i j. i <𝖻  j <𝖻  i  j  ((col NR i)  (col NR j)) = 𝗆"
      using scalar_prod_inter_num_lift_01 NR_def indexed_const_intersect by auto
    show "f. vec 𝗏 (λi. j = 0..<𝖻. f (col NR j) * (col NR j) $ i) = 0v 𝗏  vset (cols NR). f v = 0"
    proof (intro ballI)
      fix f v
      assume eq0: "vec 𝗏 (λi. j = 0..<𝖻. f (col NR j) * col NR j $ i) = 0v 𝗏"
      assume vin: "v  set (cols NR)"
      (* c j is the coefficient attached to column j. *)
      define c where "c  (λ j. f (col NR j))"
      obtain j' where v_def: "col NR j' = v" and jvlt: "j' < dim_col NR"
        using vin by (metis cols_length cols_nth index_less_size_conv nth_index)
      have dim_col: "j. j  {0..< 𝖻}  dim_vec (col NR j) = 𝗏" using nrcm by auto
      ― ‹ Summation reasoning to obtain conclusion on coefficients ›
      have "0 = (vec 𝗏 (λi. j = 0..<𝖻. c j * (col NR j) $ i))  (vec 𝗏 (λi. j = 0..<𝖻. c j * (col NR j) $ i))" 
        using vec_prod_zero eq0 c_def by simp
      also have "... = ( j1  {0..<𝖻} . c j1 * c j1 * ((col NR j1)  (col NR j1))) + ( j1  {0..<𝖻} . 
        ( j2  ({0..< 𝖻} - {j1}) . c j1 * c j2 * ((col NR j1)  (col NR j2))))" 
        using scalar_prod_double_sum_fn_vec[of 𝖻 "col NR" 𝗏 c] dim_col by simp
      also have "... = ( j1  {0..<𝖻} . (c j1) * (c j1) * (card (ℬs ! j1))) + ( j1  {0..<𝖻} . 
        ( j2  ({0..< 𝖻} - {j1}) . c j1 * c j2 * ((col NR j1)  (col NR j2))))"
        using scalar_prod_real1 by simp
      also have "... = ( j1  {0..<𝖻} . (c j1)^2 * (card (ℬs ! j1))) + ( j1  {0..<𝖻} . 
        ( j2  ({0..< 𝖻} - {j1}) . c j1 * c j2 * ((col NR j1)  (col NR j2))))"
        by (metis power2_eq_square) 
      also have "... = ( j1  {0..<𝖻} . (c j1)^2 * (card (ℬs ! j1))) + ( j1  {0..<𝖻} . 
        ( j2  ({0..< 𝖻} - {j1}) . c j1 * c j2 * 𝗆))" using scalar_prod_real2  by auto
      also have "... = ( j1  {0..<𝖻} . (c j1)^2 * (card (ℬs ! j1))) + 
         𝗆 * ( j1  {0..<𝖻} . ( j2  ({0..< 𝖻} - {j1}) . c j1 * c j2))" 
        using double_sum_mult_hom[of "𝗆" "λ i j . c i * c j" "λ i.{0..<𝖻} - {i}" "{0..<𝖻}"]
        by (metis (no_types, lifting) mult_of_nat_commute sum.cong) 
      (* Rewrite the off-diagonal double sum as the square of the sum minus
         the sum of squares. *)
      also have "... = ( j  {0..<𝖻} . (c j)^2 * (card (ℬs ! j))) + 
         𝗆 * (( j  {0..<𝖻} . c j)^2 - ( j  {0..<𝖻} . c j * c j))" 
        using double_sum_split_square_diff by auto 
      also have "... = ( j  {0..<𝖻} . (c j)^2 * (card (ℬs ! j))) + (-𝗆) * ( j  {0..<𝖻} . (c j)^2) + 
         𝗆 * (( j  {0..<𝖻} . c j)^2)" by (simp add: algebra_simps power2_eq_square)
      also have "... = ( j  {0..<𝖻} . (c j)^2 * (card (ℬs ! j))) + ( j  {0..<𝖻} . (-𝗆)* (c j)^2) + 
         𝗆 * (( j  {0..<𝖻} . c j)^2)" by (simp add: sum_distrib_left) 
      also have "... = ( j  {0..<𝖻} . (c j)^2 * (card (ℬs ! j))+ (-𝗆)* (c j)^2) + 
         𝗆 * (( j  {0..<𝖻} . c j)^2)" by (metis (no_types) sum.distrib)
      (* This is exactly the hypothesis shape required by sum_split_coeffs_0. *)
      finally have sum_rep: "0 = ( j  {0..<𝖻} . (c j)^2 * ((card (ℬs ! j))- (int 𝗆))) + 
         𝗆 * (( j  {0..<𝖻} . c j)^2)" by (simp add: algebra_simps)
      thus "f v = 0" using sum_split_coeffs_0[of "j'" "c"] mge bge jvlt nrcm c_def v_def by simp
    qed
  qed
qed

end

text ‹Using the dual design concept, it is easy to translate the set-theoretic general statement
of Fisher's inequality to a more traditional design-theoretic version on pairwise balanced designs. 
Two versions of this are given, each using a different trivial (but crucial) condition on the design properties›
context ordered_pairwise_balance
begin

(* Non-uniform Fisher's inequality: v is at most b, for an ordered pairwise
   balanced design with index Λ > 0 in which every block is incomplete.
   The dual design has blocks ℬs* over the points [0..<length ℬs]. It is
   interpreted as a simple ordered constant-intersection design with
   intersection number Λ (dual_is_simp_const_inter_des). Its
   general_fishers_inequality then bounds the number of dual blocks by
   length ℬs. dual_blocks_len and points_list_length turn that bound into
   v at most b. *)
corollary general_nonuniform_fishers: ― ‹only valid on incomplete designs ›
  assumes "Λ > 0" 
  assumes " bl. bl ∈#   incomplete_block bl" 
    ― ‹ i.e. not a super trivial design with only complete blocks ›
  shows "𝗏  𝖻"
proof -
  have "mset (ℬs*) = dual_blocks 𝒱 ℬs" using dual_blocks_ordered_eq by simp
  then interpret des: simple_const_intersect_design "set [0..<(length ℬs)]" "mset (ℬs*)" Λ 
    using assms dual_is_simp_const_inter_des by simp
  interpret odes: simp_ordered_const_intersect_design "[0..<length ℬs]" "ℬs*" Λ 
    using distinct_upt des.wellformed by (unfold_locales) (blast)
  have "length (ℬs*)  length [0..<length ℬs]" using odes.general_fishers_inequality
    using odes.blocks_list_length odes.points_list_length by presburger
  then have "𝗏  length ℬs"
    by (simp add: dual_blocks_len points_list_length) 
  then show ?thesis by auto
qed

(* Variant of general_nonuniform_fishers whose side condition is that the
   number of complete blocks (copies of the point set 𝒱) is below Λ.
   B is the block collection with every copy of 𝒱 removed
   (removeAll_mset). Then:
     - B is no larger than the original collection;
     - each block of B has fewer than v points;
     - B is a pairwise balanced design with the positive index
       Λ - (number of complete blocks).
   Applying general_nonuniform_fishers to an ordering Bs of B gives v at most
   size B, and hence v at most b. *)
corollary general_nonuniform_fishers_comp: 
  assumes "Λ > 0" 
  assumes "count  𝒱 < Λ" ― ‹ i.e. not a super trivial design with only complete blocks and single blocks ›
  shows "𝗏  𝖻"
proof -
  define B where "B = (removeAll_mset 𝒱 )"
  have b_smaller: "size B  𝖻" using B_def removeAll_size_lt by simp
  then have b_incomp: " bl. bl ∈# B  card bl < 𝗏"
    using wellformed B_def by (simp add: psubsetI psubset_card_mono) 
  have index_gt: "(Λ - (count  𝒱)) > 0" using assms by simp 
  interpret pbd: pairwise_balance 𝒱 B "(Λ - (count  𝒱))"
    using remove_all_complete_blocks_pbd B_def assms(2) by blast 
  (* Choose a list ordering of B so the ordered locale can be used. *)
  obtain Bs where m: "mset Bs = B"
    using ex_mset by blast 
  interpret opbd: ordered_pairwise_balance 𝒱s Bs "(Λ - (count  𝒱))" 
    by (intro pbd.ordered_pbdI) (simp_all add: m distinct)
  have "𝗏  (size B)" using b_incomp opbd.general_nonuniform_fishers
    using index_gt m by blast 
  then show ?thesis using b_smaller m by auto
qed

end
end