(* Theory Constructive_Chernoff_Bound *)

subsection ‹Constructive Chernoff Bound\label{sec:constructive_chernoff_bound}›

text ‹This section formalizes Theorem~5 by Impagliazzo and Kabanets~\cite{impagliazzo2010}. It is
a general result with which Chernoff-type tail bounds for various kinds of weakly dependent random
variables can be obtained. The results here are general and will be applied in
Section~\ref{sec:random_walks} to random walks in expander graphs.›

theory Constructive_Chernoff_Bound
  imports
    "HOL-Probability.Probability_Measure"
    Universal_Hash_Families.Universal_Hash_Families_More_Product_PMF
    Weighted_Arithmetic_Geometric_Mean.Weighted_Arithmetic_Geometric_Mean
begin

(* Reverse monotonicity of powr in the exponent for a real base x with x > 0.
   Intended reading (relation symbols to be confirmed against the upstream
   theory): from a <= b and x <= 1 conclude x powr b <= x powr a. *)
lemma powr_mono_rev:
  fixes x :: real
  assumes "a  b" and  "x > 0" "x  1"
  shows "x powr b  x powr a"
proof -
  (* Pass to the reciprocal base 1/x with the negated exponent. *)
  have "x powr b = (1/x) powr (-b)"
    using assms by (simp add: powr_divide powr_minus_divide)
  (* Compare the exponents -b and -a for the base 1/x via powr_mono. *)
  also have "...  (1/x) powr (-a)"
    using assms by (intro powr_mono) auto
  (* Rewrite back to the base x. *)
  also have "... = x powr a"
    using assms by (simp add: powr_divide powr_minus_divide)
  finally show ?thesis by simp
qed

(* For real x: raising exp x to the power y (powr) gives exp (x*y).
   Proved by unfolding the definition of powr. *)
lemma exp_powr: "(exp x) powr y = exp (x*y)" for x :: real
  unfolding powr_def by simp

(* Integrability criterion for probability mass functions: a real-valued f
   whose absolute value is controlled by a constant C on the support
   set_pmf p (presumably abs (f x) <= C - confirm) is integrable with
   respect to measure_pmf p. *)
lemma integrable_pmf_iff_bounded:
  fixes f :: "'a  real"
  assumes "x. x  set_pmf p  abs (f x)  C"
  shows "integrable (measure_pmf p) f"
proof -
  (* The support of a pmf is non-empty; a support point is used to derive
     the sign condition on C from the assumption. *)
  obtain x where "x  set_pmf p"
    using set_pmf_not_empty by fast
  hence "C  0" using assms(1) by fastforce
  (* Compare the non-negative integral of abs f with that of the constant C;
     the comparison only needs to hold almost everywhere (AE_pmfI). *)
  hence " (+ x. ennreal (abs (f x)) measure_pmf p)  (+ x. C measure_pmf p)"
    using assms ennreal_le_iff
    by (intro nn_integral_mono_AE AE_pmfI) auto
  also have "... = C"
    by simp
  also have "... < Orderings.top"
    by simp
  finally have "(+ x. ennreal (abs (f x)) measure_pmf p) < Orderings.top" by simp
  (* Finiteness of the integral of abs f is what integrable_iff_bounded asks for. *)
  thus ?thesis
    by (intro iffD2[OF integrable_iff_bounded]) auto
qed

(* Probability of a set S under pair_pmf A B, written as an integral over the
   first component a of the B-probability of the section {b. (a,b) in S}. *)
lemma split_pair_pmf:
  "measure_pmf.prob (pair_pmf A B) S = integralL A (λa. measure_pmf.prob B {b. (a,b)  S})"
  (is "?L = ?R")
proof -
  (* The section probability is integrable over A; this is obtained from
     integrable_pmf_iff_bounded with the constant 1. *)
  have a:"integrable (measure_pmf A) (λx. measure_pmf.prob B {b. (x, b)  S})"
    by (intro integrable_pmf_iff_bounded[where C="1"]) simp

  (* Work with non-negative integrals of indicator functions, split the
     integral over the pair with nn_integral_pair_pmf', and convert back. *)
  have "?L = (+x. indicator S x (measure_pmf (pair_pmf A B)))"
    by (simp add: measure_pmf.emeasure_eq_measure)
  also have "... = (+x. (+y. indicator S (x,y) B) A)"
    by (simp add: nn_integral_pair_pmf')
  also have "... = (+x. (+y. indicator {b. (x,b)  S} y B) A)"
    by (simp add:indicator_def)
  also have "... = (+x. (measure_pmf.prob B {b. (x,b)  S}) A)"
    by (simp add: measure_pmf.emeasure_eq_measure)
  (* Fact a allows replacing the non-negative integral by the Lebesgue integral. *)
  also have "... = ?R"
    using a
    by (subst nn_integral_eq_integral) auto
  finally show ?thesis by simp
qed

(* Variant of split_pair_pmf that integrates over the second component:
   the measure of S under pair_pmf A B is the integral over B of the
   A-probability of the section {b. (b,a) in S}. *)
lemma split_pair_pmf_2:
  "measure(pair_pmf A B) S = integralL B (λa. measure_pmf.prob A {b. (b,a)  S})"
  (is "?L = ?R")
proof -
  (* Swap the two components of the pair (pair_commute_pmf) ... *)
  have "?L = measure (pair_pmf B A) {ω. (snd ω, fst ω)  S}"
    by (subst pair_commute_pmf) (simp add:vimage_def case_prod_beta)
  (* ... and apply split_pair_pmf to the swapped pair. *)
  also have "... = ?R"
    unfolding split_pair_pmf by simp
  finally show ?thesis by simp
qed

(* Binary relative-entropy expression in two real parameters p and q:
   p * ln (p/q) + (1-p) * ln ((1-p)/(1-q)).
   The definition itself imposes no range restriction on p or q; the lemmas
   that use it state their own side conditions. *)
definition KL_div :: "real  real  real"
  where "KL_div p q = p * ln (p/q) + (1-p) * ln ((1-p)/(1-q))"

(* Constructive Chernoff bound for a probability mass function p; per the
   introduction of this section it formalizes Theorem 5 of Impagliazzo and
   Kabanets.

   Setting, as far as it is visible in the statement:
   - Y i is a boolean-valued function on the sample space of p, considered
     for indices i in {..<n}, with n > 0;
   - each δ i lies in {0..1};
   - for every index set S within {..<n}, the probability that Y i holds for
     all i in S is compared with the product of the δ i over S
     (presumably bounded above by it - confirm);
   - δ_avg is the average of the δ i over {..<n}; γ lies in {δ_avg..1} and
     δ_avg > 0.
   Conclusion: the probability of the event comparing the number of indices
   i < n with Y i true against γ * n (presumably "at least γ * n") is related
   to exp (- n * KL_div γ δ_avg) (presumably bounded above by it - confirm). *)
theorem impagliazzo_kabanets_pmf:
  fixes Y :: "nat  'a  bool"
  fixes p :: "'a pmf"
  assumes "n > 0"
  assumes "i. i  {..<n}  δ i  {0..1}"
  assumes "S. S  {..<n}  measure p {ω. (i  S. Y i ω)}  (i  S. δ i)"
  defines "δ_avg  (i {..<n}. δ i)/n"
  assumes "γ  {δ_avg..1}"
  assumes "δ_avg > 0"
  shows "measure p {ω. real (card {i  {..<n}. Y i ω})  γ * n}  exp (-real n * KL_div γ δ_avg)"
    (is "?L  ?R")
proof -
  let ?n = "real n"
  (* q is the parameter of the auxiliary Bernoulli distribution used below;
     it is set to 1 in the boundary case γ = 1. *)
  define q :: real where "q = (if γ = 1 then 1 else (γ-δ_avg)/(γ*(1-δ_avg)))"

  (* g ω counts the indices i < n for which Y i ω fails. *)
  define g where "g ω = card {i. i < n  ¬Y i ω}" for ω
  (* ?E is the predicate describing the tail event whose probability is ?L. *)
  let ?E = "(λω. real (card {i. i < n  Y i ω})  γ * n)"
  (* Product pmf over {..<n} in which every component is bernoulli_pmf q. *)
  let  = "prod_pmf {..<n} (λ_. bernoulli_pmf q)"

  (* q lies in the unit interval; the case γ = 1 is separate because q is
     defined by cases. *)
  have q_range:"q {0..1}"
  proof (cases "γ < 1")
    case True
    then show ?thesis
      using assms(5,6)
      unfolding q_def by (auto intro!:divide_nonneg_pos simp add:algebra_simps)
  next
    case False
    hence "γ = 1" using assms(5) by simp
    then show ?thesis unfolding q_def by simp
  qed

  (* Helper used with integrable_pmf_iff_bounded[where C="1"] further down. *)
  have abs_pos_le_1I: "abs x  1" if "x  0" "x  1" for x :: real
    using that by auto

  have γ_n_nonneg: "γ*?n  0"
    using assms(1,5,6) by simp
  (* r is n minus a natural number derived from γ*n (presumably its ceiling -
     confirm against the upstream theory). *)
  define r where "r = n - nat γ*n"

  (* Fact 2: on the event ?E the count g ω is compared with r (intermediate
     fact below), which transfers to the powers of 1-q via power_decreasing. *)
  have 2:"(1-q) ^ r  (1-q)^ g ω" if "?E ω" for ω
  proof -
    have "g ω = card ({i. i < n} - {i. i < n  Y i ω})"
      unfolding g_def by (intro arg_cong[where f="λx. card x"]) auto
    also have "... = card {i. i < n} - card {i. i < n  Y i ω}"
      by (subst card_Diff_subset, auto)
    also have "...  card {i. i < n} - nat γ*n"
      using that γ_n_nonneg by (intro diff_le_mono2) simp
    also have "... = r"
      unfolding r_def by simp
    finally have "g ω  r" by simp
    thus "(1-q) ^ r  (1-q) ^ (g ω)"
      using q_range by (intro power_decreasing) auto
  qed

  have γ_gt_0: "γ > 0"
    using assms(5,6) by simp

  (* Away from the boundary case γ = 1 the parameter q is strictly below 1. *)
  have q_lt_1: "q < 1" if "γ < 1"
  proof -
    have "δ_avg < 1" using assms(5) that by simp
    hence "(γ - δ_avg) / (γ * (1 - δ_avg)) < 1"
      using γ_gt_0 assms(6) that
      by (subst pos_divide_less_eq) (auto simp add:algebra_simps)
    thus "q < 1"
      unfolding q_def using that by simp
  qed

  (* Fact 5: closed form for γ < 1. The calculation substitutes the definition
     of q and simplifies with powr and ln identities until the expression
     matches the definition of KL_div. *)
  have 5: "(δ_avg * q + (1-q)) / (1-q) powr (1-γ) = exp (- KL_div γ δ_avg)" (is "?L1 = ?R1")
    if "γ < 1"
  proof -
    have δ_avg_range: "δ_avg  {0<..<1}"
      using that assms(5,6)  by simp

    have "?L1 = (1 - (1-δ_avg) * q) / (1-q) powr (1-γ)"
      by (simp add:algebra_simps)
    also have "... = (1 - (γ-δ_avg) / γ ) / (1-q) powr (1-γ)"
      unfolding q_def using that γ_gt_0 δ_avg_range by simp
    also have "... = (δ_avg / γ) / (1-q) powr (1-γ)"
      using γ_gt_0 by (simp add:divide_simps)
    also have "... = (δ_avg / γ) * (1/(1-q)) powr (1-γ)"
      using q_lt_1[OF that] by (subst powr_divide, simp_all)
    also have "... = (δ_avg / γ) * (1/((γ*(1-δ_avg)-(γ-δ_avg))/(γ*(1-δ_avg)))) powr (1-γ)"
      using γ_gt_0 δ_avg_range unfolding q_def by (simp add:divide_simps)
    also have "... = (δ_avg / γ) * ((γ / δ_avg) *((1-δ_avg)/(1-γ))) powr (1-γ)"
      by (simp add:algebra_simps)
    also have "... = (δ_avg / γ) * (γ / δ_avg) powr (1-γ) *((1-δ_avg)/(1-γ)) powr (1-γ)"
      using γ_gt_0  δ_avg_range that by (subst powr_mult, auto)
    also have "... = (δ_avg / γ) powr 1 * (δ_avg / γ) powr -(1-γ) *((1-δ_avg)/(1-γ)) powr (1-γ)"
      using γ_gt_0 δ_avg_range that unfolding powr_minus_divide by (simp add:powr_divide)
    also have "... = (δ_avg / γ) powr γ *((1-δ_avg)/(1-γ)) powr (1-γ)"
      by (subst powr_add[symmetric]) simp
    also have "... = exp ( ln ((δ_avg / γ) powr γ *((1-δ_avg)/(1-γ)) powr (1-γ)))"
      using γ_gt_0 δ_avg_range that by (intro exp_ln[symmetric] mult_pos_pos) auto
    also have "... =  exp ((ln ((δ_avg / γ) powr γ) + ln (((1 - δ_avg) / (1 - γ)) powr (1-γ))))"
      using γ_gt_0 δ_avg_range that by (subst ln_mult) auto
    also have "... =  exp ((γ * ln (δ_avg / γ) + (1 - γ) * ln ((1 - δ_avg) / (1 - γ))))"
      using γ_gt_0 δ_avg_range that by (simp add:ln_powr algebra_simps)
    also have "... =  exp (- (γ * ln (γ / δ_avg) + (1 - γ) * ln ((1 - γ) / (1 - δ_avg))))"
      using γ_gt_0 δ_avg_range that by (simp add: ln_div algebra_simps)
    also have "... = ?R1"
      unfolding KL_div_def by simp

    finally show ?thesis by simp
  qed

  (* Fact 3: compares the quotient obtained at the end of the main calculation
     with the right-hand side of the theorem. For γ < 1 it combines the bound
     on r with fact 5; the case γ = 1 is computed directly. *)
  have 3: "(δ_avg * q + (1-q)) ^ n / (1-q) ^ r  exp (- ?n* KL_div γ δ_avg)" (is "?L1  ?R1")
  proof (cases "γ < 1")
    case True
    have "γ * real n  1 * real n"
      using True by (intro mult_right_mono) auto
    hence "r = real n - real (nat γ * real n)"
      unfolding r_def by (subst of_nat_diff) auto
    also have "... = real n - γ * real n"
      using γ_n_nonneg by (subst of_nat_nat, auto)
    also have "...  ?n - γ * ?n"
      by (intro diff_mono) auto
    also have "... = (1-γ) *?n" by (simp add:algebra_simps)
    finally have r_bound: "r  (1-γ)*n" by simp

    (* Switch from natural-number powers to powr so that the real exponent
       (1-γ)*n can be used. *)
    have "?L1 = (δ_avg * q + (1-q)) ^ n / (1-q) powr r"
      using q_lt_1[OF True] assms(1) by (simp add: powr_realpow)
    also have "... = (δ_avg * q + (1-q)) powr n / (1-q) powr r"
      using q_lt_1[OF True] assms(6) q_range
      by (subst powr_realpow[symmetric], auto intro!:add_nonneg_pos)
    also have "...  (δ_avg * q + (1-q)) powr n / (1-q) powr ((1-γ)*n)"
      using q_range q_lt_1[OF True] by (intro divide_left_mono powr_mono_rev r_bound) auto
    also have "... = (δ_avg * q + (1-q)) powr n / ((1-q) powr (1-γ)) powr n"
      unfolding powr_powr by simp
    also have "... = ((δ_avg * q + (1-q)) / (1-q) powr (1-γ)) powr n"
      using assms(6) q_range by (subst powr_divide) auto
    also have "... = exp (- KL_div γ δ_avg) powr real n"
      unfolding 5[OF True] by simp
    also have "... = ?R1"
      unfolding exp_powr by simp
    finally show ?thesis by simp
  next
    case False
    hence γ_eq_1: "γ=1" using assms(5) by simp
    have "?L1 = δ_avg ^ n"
      using γ_eq_1 r_def q_def by simp
    also have "... = exp( - KL_div 1 δ_avg) ^ n"
      unfolding KL_div_def using assms(6) by (simp add:ln_div)
    also have "... = ?R1"
      using γ_eq_1 by (simp add: powr_realpow[symmetric] exp_powr)
    finally show ?thesis by simp
  qed

  (* Fact 4: positivity of (1-q)^r, needed to divide by it at the very end. *)
  have 4: "(1 - q) ^ r > 0"
  proof (cases "γ < 1")
    case True
    then show ?thesis using q_lt_1[OF True] by simp
  next
    case False
    hence "γ=1" using assms(5) by simp
    hence "r=0" unfolding r_def by simp
    then show ?thesis by simp
  qed

  (* Main calculation, starting from (1-q)^r * ?L:
     - use fact 2 to replace the exponent r by g ω on the event ?E;
     - read the resulting power as a probability under the auxiliary
       Bernoulli product (prob_prod_pmf');
     - exchange the order of integration with split_pair_pmf and
       split_pair_pmf_2;
     - apply the hypothesis on index sets (assms(3));
     - evaluate the expectation coordinate-wise (expectation_prod_Pi_pmf);
     - finish with arithmetic_geometric_mean to pass to δ_avg. *)
  have "(1-q) ^ r * ?L = (ω. indicator {ω. ?E ω} ω * (1-q) ^ r p)"
    by simp
  also have "...  (ω. indicator {ω. ?E ω} ω * (1-q) ^ g ω p)"
    using q_range 2 by (intro integral_mono_AE integrable_pmf_iff_bounded[where C="1"]
        abs_pos_le_1I mult_le_one power_le_one AE_pmfI) (simp_all split:split_indicator)
  also have "... = (ω. indicator {ω. ?E ω} ω * (i  {i. i < n  ¬Y i ω}. (1-q)) p)"
    unfolding g_def using q_range
    by (intro integral_cong_AE AE_pmfI, simp_all add:powr_realpow)
  also have "... = (ω. indicator {ω. ?E ω} ω * measure  ({j. j < n  ¬Y j ω}  {False}) p)"
    using q_range by (subst prob_prod_pmf') (auto simp add:measure_pmf_single)
  also have "... = (ω. measure  {ξ. ?E ω  (i{j. j < n  ¬Y j ω}. ¬ξ i)} p)"
    by (intro integral_cong_AE AE_pmfI, simp_all add:Pi_def split:split_indicator)
  also have "... = (ω. measure  {ξ. ?E ω  (i{..<n}. ξ i  Y i ω)} p)"
    by (intro integral_cong_AE AE_pmfI measure_eq_AE) auto
  also have "... = measure (pair_pmf p ) {φ.?E (fst φ)(i  {..<n}. snd φ i  Y i (fst φ))}"
    unfolding split_pair_pmf by simp
  also have "...  measure (pair_pmf p ) {φ. (i  {j. j < n  snd φ j}. Y i (fst φ))}"
    by (intro pmf_mono, auto)
  also have "... = (ξ. measure p {ω. i{j. j< n  ξ j}. Y i ω}  )"
    unfolding split_pair_pmf_2 by simp
  also have "...  (a. (i  {j. j < n  a j}. δ i)  )"
    using assms(2) by (intro integral_mono_AE AE_pmfI assms(3) subsetI  prod_le_1 prod_nonneg
        integrable_pmf_iff_bounded[where C="1"] abs_pos_le_1I) auto
  also have "... = (a. (i  {..<n}. δ i^ of_bool(a i))  )"
    unfolding of_bool_def by (intro integral_cong_AE AE_pmfI)
      (auto simp add:if_distrib prod.If_cases Int_def)
  also have "... = (i<n. (a. (δ i ^ of_bool a) (bernoulli_pmf q)))"
    using assms(2) by (intro expectation_prod_Pi_pmf integrable_pmf_iff_bounded[where C="1"]) auto
  also have "... = (i<n. δ i * q + (1-q))"
    using q_range by simp
  also have "... = (root (card {..<n}) (i<n. δ i * q + (1-q))) ^ (card {..<n})"
    using assms(1,2) q_range by (intro real_root_pow_pos2[symmetric] prod_nonneg) auto
  also have "...  ((i<n. δ i * q + (1-q))/card{..<n})^(card {..<n})"
    using assms(1,2) q_range by (intro power_mono arithmetic_geometric_mean)
      (auto intro: prod_nonneg)
  also have "... = ((i<n. δ i * q)/n + (1-q))^n"
    using assms(1) by (simp add:sum.distrib divide_simps mult.commute)
  also have "... = (δ_avg * q + (1-q))^n"
    unfolding δ_avg_def by (simp add: sum_distrib_right[symmetric])
  finally have "(1-q) ^ r * ?L  (δ_avg * q + (1-q)) ^ n" by simp
  (* Divide by the positive factor (fact 4) and conclude with fact 3. *)
  hence "?L  (δ_avg * q + (1-q)) ^ n / (1-q) ^ r"
    using 4 by (subst pos_le_divide_eq) (auto simp add:algebra_simps)
  also have "...  ?R"
    by (intro 3)
  finally show ?thesis by simp
qed

text ‹The distribution of a random variable with a countable range is a discrete probability space,
i.e., induces a PMF. Using this it is possible to generalize the previous result to arbitrary
probability spaces.›

(* In a probability space M: for a function f that is a random variable into
   the space `discrete` and has countable image, the push-forward measure
   distr M discrete f belongs to the set of measures described in the
   conclusion (probability space, all sets measurable, and an almost-everywhere
   condition on the measure of singletons). Later in this theory the lemma is
   passed to Abs_pmf_inverse to obtain a pmf from this distribution. *)
lemma (in prob_space) establish_pmf:
  fixes f :: "'a  'b"
  assumes rv: "random_variable discrete f"
  assumes "countable (f ` space M)"
  shows "distr M discrete f  {M. prob_space M  sets M = UNIV  (AE x in M. measure M {x}  0)}"
proof -
  (* N collects the points of space M whose fibre under f fails the condition
     on its probability; I collects the image values whose fibre has
     probability 0. *)
  define N where "N = {x  space M.¬ prob (f -` {f x}  space M)  0}"
  define I where "I = {z  (f ` space M). prob (f -` {z}   space M) = 0}"

  have countable_I: " countable I"
    unfolding I_def by (intro countable_subset[OF _ assms(2)]) auto

  have disj: "disjoint_family_on (λy. f -` {y}  space M) I"
    unfolding disjoint_family_on_def by auto

  (* N is expressed through the fibres over I, so its measure can be computed
     fibre by fibre (emeasure_UN_countable); every summand is 0. *)
  have N_alt_def: "N = (y  I. f -` {y}  space M)"
    unfolding N_def I_def by (auto simp add:set_eq_iff)
  have "emeasure M N = + y. emeasure M (f -` {y}  space M) count_space I"
    using rv countable_I unfolding N_alt_def
    by (subst emeasure_UN_countable) (auto simp add:disjoint_family_on_def)
  also have "... =  + y. 0 count_space I"
    unfolding I_def using emeasure_eq_measure ennreal_0
    by (intro nn_integral_cong) auto
  also have "... = 0" by simp
  finally have 0:"emeasure M N = 0" by simp

  have 1:"N  events"
    unfolding N_alt_def using rv
    by (intro  sets.countable_UN'' countable_I) simp

  (* N is a measurable null set, which yields the almost-everywhere statement
     in M; it is then transferred to the push-forward measure with
     measure_distr and AE_distr_iff. *)
  have " AE x in M. prob (f -` {f x}  space M)  0"
    using 0 1 by (subst AE_iff_measurable[OF _ N_def[symmetric]])
  hence " AE x in M. measure (distr M discrete f) {f x}  0"
    by (subst measure_distr[OF rv], auto)
  hence "AE x in distr M discrete f. measure (distr M discrete f) {x}  0"
    by (subst AE_distr_iff[OF rv], auto)
  thus ?thesis
    using prob_space_distr rv by auto
qed

(* Relates the family of singletons of elements of T to Pow T (presumably:
   it is a subset of Pow T - confirm). Used below as a side condition when
   applying measurable_sigma_sets. *)
lemma singletons_image_eq:
  "(λx. {x}) ` T  Pow T"
  by auto

(* Version of impagliazzo_kabanets_pmf for an arbitrary probability space M
   (locale prob_space). In addition to the hypotheses of the pmf version,
   each Y i with i in {..<n} is assumed to be a random variable into
   `discrete`. The proof pushes M forward along the map f defined below,
   turns the resulting distribution into a pmf via establish_pmf, and then
   applies impagliazzo_kabanets_pmf. *)
theorem (in prob_space) impagliazzo_kabanets:
  fixes Y :: "nat  'a  bool"
  assumes "n > 0"
  assumes "i. i  {..<n}  random_variable discrete (Y i)"
  assumes "i. i  {..<n}  δ i  {0..1}"
  assumes "S. S  {..<n}  𝒫(ω in M. (i  S. Y i ω))  (i  S. δ i)"
  defines "δ_avg  (i {..<n}. δ i)/n"
  assumes "γ  {δ_avg..1}" "δ_avg > 0"
  shows "𝒫(ω in M. real (card {i  {..<n}. Y i ω})  γ * n)  exp (-real n * KL_div γ δ_avg)"
    (is "?L  ?R")
proof -
  (* f ω is the boolean sequence of the values Y i ω for i < n, set to False
     from index n on; g truncates a boolean sequence in the same way. *)
  define f where "f = (λω i. if i < n then Y i ω else False)"
  define g where "g = (λω i. if i < n then ω i else False)"
  (* T is a set of boolean sequences constrained via i < n (presumably those
     that can be True only at indices below n - confirm). *)
  define T where "T = {ω. (i. ω i  i < n)}"

  have g_idem: "g  f = f" unfolding f_def g_def by (simp add:comp_def)

  have f_range: " f  space M  T"
    unfolding T_def f_def by simp

  (* T is finite (via PiE_dflt), hence countable, and so is the image of f. *)
  have "T = PiE_dflt {..<n} False (λ_. UNIV)"
    unfolding T_def PiE_dflt_def by auto
  hence "finite T"
    using finite_PiE_dflt by auto
  hence countable_T: "countable T"
    by (intro countable_finite)
  moreover have "f ` space M  T"
    using f_range by auto
  ultimately have countable_f: "countable (f ` space M)"
    using countable_subset by auto

  (* Each fibre of f over a point t of T is an event: it is rewritten as a
     combination over i in {..<n} of the sets where Y i takes the value t i. *)
  have "f -` y  space M  events" if t:"y  (λx. {x}) ` T" for y
  proof -
    obtain t where "y = {t}" and t_range: "t  T" using t by auto
    hence "f -` y  space M = {ω  space M. f ω = t}"
      by (auto simp add:vimage_def)
    also have "... = {ω  space M. (i < n. Y i ω = t i)}"
      using t_range unfolding f_def T_def by auto
    also have "... = (i  {..<n}. {ω  space M. Y i ω = t i})"
      using assms(1) by auto
    also have "...  events"
      using assms(1,2)
      by (intro sets.countable_INT) auto
    finally show ?thesis by simp
  qed

  (* Measurability of f into count_space T follows from the fibres
     (measurable_sigma_sets); composing with g and using g_idem gives
     measurability into `discrete`. *)
  hence "random_variable (count_space T) f"
    using sigma_sets_singletons[OF countable_T] singletons_image_eq f_range
    by (intro measurable_sigma_sets[where Ω="T" and A=" (λx. {x}) ` T"]) simp_all
  moreover have "g  measurable discrete (count_space T)"
    unfolding g_def T_def by simp
  ultimately have "random_variable discrete (g  f)"
    by simp
  hence rv:"random_variable discrete f"
    unfolding g_idem by simp

  (* M' is the distribution of f; Ω is the corresponding pmf. Fact a states
     that converting back with measure_pmf returns M'. *)
  define M' :: "(nat  bool) measure"
    where "M' = distr M discrete f"

  define Ω where "Ω = Abs_pmf M'"
  have a:"measure_pmf (Abs_pmf M') = M'"
    unfolding M'_def
    by (intro Abs_pmf_inverse[OF establish_pmf] rv countable_f)

  have b:"{i. (i < n  Y i x)  i < n} = {i. i < n  Y i x}" for x
    by auto

  (* Fact c: the hypothesis on index sets (assms(4)), transported to Ω. *)
  have c: "measure Ω {ω. iS. ω i}  prod δ S" (is "?L1  ?R1") if "S  {..<n}" for S
  proof -
    have d: "i  S  i < n" for i
      using that by auto
    have "?L1 = measure M' {ω. iS. ω i}"
      unfolding Ω_def a by simp
    also have "... = 𝒫(ω in M. (i  S. Y i ω))"
      unfolding M'_def using that d
      by (subst measure_distr[OF rv]) (auto simp add:f_def Int_commute Int_def)
    also have "...  ?R1"
      using that assms(4) by simp
    finally show ?thesis by simp
  qed

  (* Rewrite the left-hand side as a probability under Ω and apply the pmf
     version of the theorem. *)
  have "?L = measure M' {ω. real (card {i. i < n  ω i})  γ * n}"
    unfolding M'_def by (subst measure_distr[OF rv])
      (auto simp add:f_def algebra_simps Int_commute Int_def b)
  also have "... = measure_pmf.prob Ω {ω. real (card {i  {..<n}. ω i})  γ * n}"
    unfolding Ω_def a by simp
  also have "...  ?R"
    using assms(1,3,6,7) c unfolding δ_avg_def
    by (intro impagliazzo_kabanets_pmf) auto
  finally show ?thesis by simp
qed

text ‹Bounds and properties of @{term "KL_div"}›

(* Auxiliary comparison of the expression KL_div p q - 2*(p-q)^2 at two values
   q and q' of the second argument, under a chain of assumptions on
   0, p, q, q' and 1 (presumably 0 <= p <= q <= q' < 1, with the expression
   not decreasing from q to q' - confirm).
   Proof idea: exhibit the derivative with respect to the second argument on
   {q..q'}, show that it is non-negative, and apply
   DERIV_nonneg_imp_nondecreasing. The case p = 0 is handled separately. *)
lemma KL_div_mono_right_aux_1:
  assumes "0  p" "p  q" "q  q'" "q' < 1"
  shows "KL_div p q-2*(p-q)^2  KL_div p q'-2*(p-q')^2"
proof (cases "p = 0")
  case True
  (* For p = 0 the expression is rewritten as ln (1/(1-q)) - 2*q^2, whose
     derivative is f'. *)
  define f' :: "real  real" where "f' = (λx. 1/(1-x) - 4 * x)"

  have deriv: "((λq. ln (1/(1-q)) - 2*q^2) has_real_derivative (f' x)) (at x)"
    if "x  {q..q'}" for x
  proof -
    have "x  {0..<1}" using assms that by auto
    thus ?thesis unfolding f'_def by (auto intro!: derivative_eq_intros)
  qed

  (* Sign of f': based on the identity 4*x*(1-x) = 1 - 4*(x-1/2)^2. *)
  have deriv_nonneg: "f' x  0" if "x  {q..q'}" for x
  proof -
    have 0:"x  {0..<1}" using assms that by auto
    have "4 * x*(1-x) = 1 - 4*(x-1/2)^2" by (simp add:power2_eq_square field_simps)
    also have "...  1" by simp
    finally have "4*x*(1-x)  1" by simp
    hence "1/(1-x)  4*x" using 0 by (simp add: pos_le_divide_eq)
    thus ?thesis unfolding f'_def by auto
  qed

  have "ln (1 / (1 - q)) - 2 * q^2  ln (1 / (1 - q')) - 2 * q'^2"
    using deriv deriv_nonneg by (intro DERIV_nonneg_imp_nondecreasing[OF assms(3)]) auto
  thus ?thesis using True unfolding KL_div_def by simp
next
  case False
  hence p_gt_0: "p > 0" using assms by auto

  (* Derivative of KL_div p x - 2*(p-x)^2 with respect to x. *)
  define f' :: "real  real" where "f' = (λx. (1-p)/(1-x) - p/x + 4 * (p-x))"

  have deriv: "((λq. KL_div p q - 2*(p-q)^2) has_real_derivative (f' x)) (at x)" if "x  {q..q'}"
    for x
  proof -
    (* Positivity of the arguments of ln is needed to differentiate; this is
       where p > 0 is used. *)
    have "0 < p /x" " 0 < (1 - p) / (1 - x)" using that assms p_gt_0 by auto
    thus ?thesis unfolding KL_div_def f'_def by (auto intro!: derivative_eq_intros)
  qed

  (* Sign of the factor 1/(x*(1-x)) - 4 on the open unit interval. *)
  have f'_part_nonneg: "(1/(x*(1-x)) - 4)  0" if "x  {0<..<1}" for x :: real
  proof -
    have "4 * x * (1-x) = 1 - 4 * (x-1/2)^2" by (simp add:power2_eq_square algebra_simps)
    also have "...  1" by simp
    finally have "4 * x * (1-x)  1" by simp
    hence "1/(x*(1-x))  4" using that by (subst pos_le_divide_eq) auto
    thus ?thesis by simp
  qed

  (* Factorisation of f' into (x-p) and the factor treated above. *)
  have f'_alt: "f' x = (x-p)*(1/(x*(1-x)) - 4)" if "x  {0<..<1}" for x
  proof -
    have "f' x = (x-p)/(x*(1-x)) + 4 * (p-x)" using that unfolding f'_def by (simp add:field_simps)
    also have "... = (x-p)*(1/(x*(1-x)) - 4)" by (simp add:algebra_simps)
    finally show ?thesis by simp
  qed

  (* Both factors have the required sign on {q..q'} (mult_nonneg_nonneg). *)
  have deriv_nonneg: "f' x  0" if "x  {q..q'}" for x
  proof -
    have "x  {0<..<1}" using assms that p_gt_0 by auto
    have "f' x =(x-p)*(1/(x*(1-x)) - 4)" using that assms p_gt_0 by (subst f'_alt) auto
    also have "...  0" using that f'_part_nonneg assms p_gt_0 by (intro mult_nonneg_nonneg) auto
    finally show ?thesis by simp
  qed

  show ?thesis using deriv deriv_nonneg
    by (intro DERIV_nonneg_imp_nondecreasing[OF assms(3)]) auto
qed

(* Symmetry of KL_div under replacing both arguments by their complements
   1-p and 1-q; immediate from the definition. *)
lemma KL_div_swap: "KL_div (1-p) (1-q) = KL_div p q"
  unfolding KL_div_def by auto

(* Mirror image of KL_div_mono_right_aux_1 for the other ordering of the
   arguments (presumably 0 < q' <= q <= p <= 1 - confirm). *)
lemma KL_div_mono_right_aux_2:
  assumes "0 < q'" "q'  q" "q  p" "p  1"
  shows "KL_div p q-2*(p-q)^2  KL_div p q'-2*(p-q')^2"
proof -
  (* Apply KL_div_mono_right_aux_1 to the complemented arguments 1-p, 1-q and
     1-q', then undo the complement with KL_div_swap. *)
  have "KL_div (1-p) (1-q)-2*((1-p)-(1-q))^2  KL_div (1-p) (1-q')-2*((1-p)-(1-q'))^2"
    using assms by (intro KL_div_mono_right_aux_1) auto
  thus ?thesis unfolding KL_div_swap by (auto simp:algebra_simps power2_commute)
qed

(* Combines KL_div_mono_right_aux_1 and KL_div_mono_right_aux_2 into a single
   statement whose assumption covers both orderings of p, q and q'
   (presumably as a disjunction of the two cases - confirm). *)
lemma KL_div_mono_right_aux:
  assumes "(0  p  p  q  q  q'  q' < 1)  (0 < q'  q'  q  q  p  p  1)"
  shows "KL_div p q-2*(p-q)^2  KL_div p q'-2*(p-q')^2"
  using KL_div_mono_right_aux_1 KL_div_mono_right_aux_2 assms by auto

(* Comparison of KL_div p q and KL_div p q' when q' lies further from p than
   q on the same side (the two orderings listed in the assumption);
   presumably KL_div p q <= KL_div p q' - confirm. *)
lemma KL_div_mono_right:
  assumes "(0  p  p  q  q  q'  q' < 1)  (0 < q'  q'  q  q  p  p  1)"
  shows "KL_div p q  KL_div p q'" (is "?L  ?R")
proof -
  consider (a) "0  p" "p  q" "q  q'" "q' < 1" | (b) "0 < q'" "q'  q" "q  p" "p  1"
    using assms by auto
  (* Fact 0: the corresponding comparison of the squared distances to p,
     shown separately in each of the two cases. *)
  hence 0: "(p - q)2  (p - q')2"
  proof (cases)
    case a
    hence "(q-p)^2  (q' - p)^2" by auto
    thus ?thesis by (simp add: power2_commute)
  next
    case b thus ?thesis by simp
  qed
  (* Split KL_div p q into the part handled by KL_div_mono_right_aux and the
     quadratic term handled by fact 0, and compare both parts. *)
  have "?L = (KL_div p q - 2*(p-q)^2) + 2 * (p-q)^2" by simp
  also have "...  (KL_div p q' - 2*(p-q')^2) + 2 * (p-q')^2"
    by (intro add_mono KL_div_mono_right_aux assms mult_left_mono 0) auto
  also have "... = ?R" by simp
  finally show ?thesis by simp
qed

(* Quadratic bound relating 2*(p-q)^2 and KL_div p q for p in {0..1} and q in
   the open unit interval (presumably 2*(p-q)^2 <= KL_div p q - confirm). *)
lemma KL_div_lower_bound:
  assumes "p  {0..1}" "q  {0<..<1}"
  shows "2*(p-q)^2  KL_div p q"
proof -
  (* Start from the value of KL_div p p - 2*(p-p)^2 at q = p and move the
     second argument from p to q with KL_div_mono_right_aux. *)
  have "0  KL_div p p - 2 * (p-p)^2" unfolding KL_div_def by simp
  also have "...  KL_div p q - 2 * (p-q)^2" using assms by (intro KL_div_mono_right_aux) auto
  finally show ?thesis by simp
qed

end