# Theory Constructive_Chernoff_Bound

subsection ‹Constructive Chernoff Bound\label{sec:constructive_chernoff_bound}›

text ‹This section formalizes Theorem~5 by Impagliazzo and Kabanets~\cite{impagliazzo2010}. It is
a general result from which Chernoff-type tail bounds for various kinds of weakly dependent random
variables can be obtained. The result is stated here in its general form and will be applied in
Section~\ref{sec:random_walks} to random walks in expander graphs.›

theory Constructive_Chernoff_Bound
imports
"HOL-Probability.Probability_Measure"
Universal_Hash_Families.Universal_Hash_Families_More_Product_PMF
Weighted_Arithmetic_Geometric_Mean.Weighted_Arithmetic_Geometric_Mean
begin

(* Antitonicity of the real power in the exponent for bases in the interval (0,1]:
   a larger exponent gives a smaller (or equal) value. *)
lemma powr_mono_rev:
  fixes x :: real
  assumes "a ≤ b" and  "x > 0" "x ≤ 1"
  shows "x powr b ≤ x powr a"
proof -
  (* Reflection identity: move to the reciprocal base, which is at least 1. *)
  have flip: "x powr c = (1/x) powr (-c)" for c
    using assms by (simp add: powr_divide powr_minus_divide)
  have "x powr b = (1/x) powr (-b)"
    by (rule flip)
  (* For the base 1/x the ordinary monotonicity of powr applies, with -b ≤ -a. *)
  also have "... ≤ (1/x) powr (-a)"
    using assms by (intro powr_mono) auto
  also have "... = x powr a"
    by (rule flip[symmetric])
  finally show ?thesis .
qed

(* Raising an exponential to a real power multiplies the exponent. *)
lemma exp_powr: "(exp x) powr y = exp (x*y)" for x :: real
  by (simp add: powr_def)

(* Integrability from a uniform bound: if the absolute value of f is at most C on the support
of the PMF p, then f is integrable with respect to measure_pmf p.
Despite the name, only this one implication is stated. *)
lemma integrable_pmf_iff_bounded:
fixes f :: "'a ⇒ real"
assumes "⋀x. x ∈ set_pmf p ⟹ abs (f x) ≤ C"
shows "integrable (measure_pmf p) f"
proof -
(* The support of a PMF is non-empty; the bound at any support point forces C ≥ 0. *)
obtain x where "x ∈ set_pmf p"
using set_pmf_not_empty by fast
hence "C ≥ 0" using assms(1) by fastforce
(* Bound the non-negative integral of the absolute value of f by the integral of the constant C. *)
hence " (∫⇧+ x. ennreal (abs (f x)) ∂measure_pmf p) ≤ (∫⇧+ x. C ∂measure_pmf p)"
using assms ennreal_le_iff
by (intro nn_integral_mono_AE AE_pmfI) auto
also have "... = C"
by simp
also have "... < Orderings.top"
by simp
(* A finite non-negative integral of the absolute value is what integrable_iff_bounded needs. *)
finally have "(∫⇧+ x. ennreal (abs (f x)) ∂measure_pmf p) < Orderings.top" by simp
thus ?thesis
by (intro iffD2[OF integrable_iff_bounded]) auto
qed

(* Decomposition of the probability of a set S under a product PMF: integrate, over the first
   component a drawn from A, the probability under B of the section of S at a. *)
lemma split_pair_pmf:
  "measure_pmf.prob (pair_pmf A B) S = integral⇧L A (λa. measure_pmf.prob B {b. (a,b) ∈ S})"
  (is "?L = ?R")
proof -
  (* The section probabilities lie in [0,1], hence the integrand is integrable. *)
  have a:"integrable (measure_pmf A) (λx. measure_pmf.prob B {b. (x, b) ∈ S})"
    by (intro integrable_pmf_iff_bounded[where C="1"]) simp

  (* Write the probability as a non-negative integral of the indicator of S. *)
  have "?L = (∫⇧+x. indicator S x ∂(measure_pmf (pair_pmf A B)))"
    by (simp add: measure_pmf.emeasure_eq_measure)
  (* Iterated integration over the two components of the pair. *)
  also have "... = (∫⇧+x. (∫⇧+y. indicator S (x,y) ∂B) ∂A)"
    by (simp add: nn_integral_pair_pmf')
  also have "... = (∫⇧+x. (∫⇧+y. indicator {b. (x,b) ∈ S} y ∂B) ∂A)"
    by (simp add:indicator_def)
  (* The inner integral is the probability of the section. *)
  also have "... = (∫⇧+x. (measure_pmf.prob B {b. (x,b) ∈ S}) ∂A)"
    by (simp add: measure_pmf.emeasure_eq_measure)
  (* Switch from the non-negative integral to the Lebesgue integral using integrability. *)
  also have "... = ?R"
    using a
    by (subst nn_integral_eq_integral) auto
  finally show ?thesis by simp
qed

(* Variant of split_pair_pmf that integrates over the second component: the probability of S
under pair_pmf A B is the integral over B of the probability, under A, of the section of S
at the given second coordinate. *)
lemma split_pair_pmf_2:
"measure(pair_pmf A B) S = integral⇧L B (λa. measure_pmf.prob A {b. (b,a) ∈ S})"
(is "?L = ?R")
proof -
(* Swap the two components of the pair, then apply split_pair_pmf to the swapped product. *)
have "?L = measure (pair_pmf B A) {ω. (snd ω, fst ω) ∈ S}"
by (subst pair_commute_pmf) (simp add:vimage_def case_prod_beta)
also have "... = ?R"
unfolding split_pair_pmf by simp
finally show ?thesis by simp
qed

(* KL_div p q follows the formula of the binary Kullback-Leibler divergence with parameters
p and q. It is defined on all reals; no range restriction on p or q is imposed here. *)
definition KL_div :: "real ⇒ real ⇒ real"
where "KL_div p q = p * ln (p/q) + (1-p) * ln ((1-p)/(1-q))"

(* Constructive Chernoff bound on a PMF.
   Y i are boolean random variables on p for i < n, and δ i ∈ [0,1] are such that for every
   index set S ⊆ {..<n} the probability that all Y i with i ∈ S hold is at most the product
   of the δ i over S. With δ_avg the average of the δ i and γ ∈ [δ_avg, 1], the probability
   that at least γ n of the Y i hold is at most exp (- n KL_div γ δ_avg).

   Proof outline: sample an independent random index set by including each index with
   probability q (the product of Bernoulli PMFs named Ξ below). The probability that all
   sampled indices satisfy Y is bounded below by (1-q)^r times the target probability and
   above, using the assumption and the AM-GM inequality, by (δ_avg q + (1-q))^n.
   The choice of q turns the resulting ratio into the KL-divergence bound. *)
theorem impagliazzo_kabanets_pmf:
  fixes Y :: "nat ⇒ 'a ⇒ bool"
  fixes p :: "'a pmf"
  assumes "n > 0"
  assumes "⋀i. i ∈ {..<n} ⟹ δ i ∈ {0..1}"
  assumes "⋀S. S ⊆ {..<n} ⟹ measure p {ω. (∀i ∈ S. Y i ω)} ≤ (∏i ∈ S. δ i)"
  defines "δ_avg ≡ (∑i∈ {..<n}. δ i)/n"
  assumes "γ ∈ {δ_avg..1}"
  assumes "δ_avg > 0"
  shows "measure p {ω. real (card {i ∈ {..<n}. Y i ω}) ≥ γ * n} ≤ exp (-real n * KL_div γ δ_avg)"
    (is "?L ≤ ?R")
proof -
  let ?n = "real n"
  (* Inclusion probability of the random index set. *)
  define q :: real where "q = (if γ = 1 then 1 else (γ-δ_avg)/(γ*(1-δ_avg)))"

  (* g ω counts the indices below n at which Y fails. *)
  define g where "g ω = card {i. i < n ∧ ¬Y i ω}" for ω
  let ?E = "(λω. real (card {i. i < n ∧ Y i ω}) ≥ γ * n)"
  let ?Ξ = "prod_pmf {..<n} (λ_. bernoulli_pmf q)"

  have q_range:"q ∈{0..1}"
  proof (cases "γ < 1")
    case True
    then show ?thesis
      using assms(5,6)
      unfolding q_def by (auto intro!:divide_nonneg_pos simp add:algebra_simps)
  next
    case False
    hence "γ = 1" using assms(5) by simp
    then show ?thesis unfolding q_def by simp
  qed

  have abs_pos_le_1I: "abs x ≤ 1" if "x ≥ 0" "x ≤ 1" for x :: real
    using that by auto

  have γ_n_nonneg: "γ*?n ≥ 0"
    using assms(1,5,6) by simp
  (* r bounds the number of failing indices on the event ?E. *)
  define r where "r = n - nat ⌈γ*n⌉"

  (* On the event ?E at most r indices fail, so (1-q)^r is a lower bound for (1-q)^(g ω). *)
  have 2:"(1-q) ^ r ≤ (1-q)^ g ω" if "?E ω" for ω
  proof -
    have "g ω = card ({i. i < n} - {i. i < n ∧ Y i ω})"
      unfolding g_def by (intro arg_cong[where f="λx. card x"]) auto
    also have "... = card {i. i < n} - card {i. i < n ∧ Y i ω}"
      by (subst card_Diff_subset, auto)
    also have "... ≤ card {i. i < n} - nat ⌈γ*n⌉"
      using that γ_n_nonneg by (intro diff_le_mono2) simp
    also have "... = r"
      unfolding r_def by simp
    finally have "g ω ≤ r" by simp
    thus "(1-q) ^ r ≤ (1-q) ^ (g ω)"
      using q_range by (intro power_decreasing) auto
  qed

  have γ_gt_0: "γ > 0"
    using assms(5,6) by simp

  have q_lt_1: "q < 1" if "γ < 1"
  proof -
    have "δ_avg < 1" using assms(5) that by simp
    hence "(γ - δ_avg) / (γ * (1 - δ_avg)) < 1"
      using γ_gt_0 assms(6) that
      by (subst pos_divide_less_eq) (auto simp add:algebra_simps)
    thus "q < 1"
      unfolding q_def using that by simp
  qed

  (* Algebraic identity behind the choice of q (case γ < 1). *)
  have 5: "(δ_avg * q + (1-q)) / (1-q) powr (1-γ) = exp (- KL_div γ δ_avg)" (is "?L1 = ?R1")
    if "γ < 1"
  proof -
    have δ_avg_range: "δ_avg ∈ {0<..<1}"
      using that assms(5,6)  by simp

    have "?L1 = (1 - (1-δ_avg) * q) / (1-q) powr (1-γ)"
      by (simp add:algebra_simps)
    also have "... = (1 - (γ-δ_avg) / γ ) / (1-q) powr (1-γ)"
      unfolding q_def using that γ_gt_0 δ_avg_range by simp
    also have "... = (δ_avg / γ) / (1-q) powr (1-γ)"
      using γ_gt_0 by (simp add:divide_simps)
    also have "... = (δ_avg / γ) * (1/(1-q)) powr (1-γ)"
      using q_lt_1[OF that] by (subst powr_divide, simp_all)
    also have "... = (δ_avg / γ) * (1/((γ*(1-δ_avg)-(γ-δ_avg))/(γ*(1-δ_avg)))) powr (1-γ)"
      using γ_gt_0 δ_avg_range unfolding q_def by (simp add:divide_simps)
    also have "... = (δ_avg / γ) * ((γ / δ_avg) *((1-δ_avg)/(1-γ))) powr (1-γ)"
      by (simp add:algebra_simps)
    also have "... = (δ_avg / γ) * (γ / δ_avg) powr (1-γ) *((1-δ_avg)/(1-γ)) powr (1-γ)"
      using γ_gt_0  δ_avg_range that by (subst powr_mult, auto)
    also have "... = (δ_avg / γ) powr 1 * (δ_avg / γ) powr -(1-γ) *((1-δ_avg)/(1-γ)) powr (1-γ)"
      using γ_gt_0 δ_avg_range that unfolding powr_minus_divide by (simp add:powr_divide)
    (* Combine the two powers of the same base: the exponents 1 and -(1-γ) add up to γ. *)
    also have "... = (δ_avg / γ) powr γ *((1-δ_avg)/(1-γ)) powr (1-γ)"
      by (subst powr_add[symmetric]) simp
    also have "... = exp ( ln ((δ_avg / γ) powr γ *((1-δ_avg)/(1-γ)) powr (1-γ)))"
      using γ_gt_0 δ_avg_range that by (intro exp_ln[symmetric] mult_pos_pos) auto
    also have "... =  exp ((ln ((δ_avg / γ) powr γ) + ln (((1 - δ_avg) / (1 - γ)) powr (1-γ))))"
      using γ_gt_0 δ_avg_range that by (subst ln_mult) auto
    also have "... =  exp ((γ * ln (δ_avg / γ) + (1 - γ) * ln ((1 - δ_avg) / (1 - γ))))"
      using γ_gt_0 δ_avg_range that by (simp add:ln_powr algebra_simps)
    also have "... =  exp (- (γ * ln (γ / δ_avg) + (1 - γ) * ln ((1 - γ) / (1 - δ_avg))))"
      using γ_gt_0 δ_avg_range that by (simp add: ln_div algebra_simps)
    also have "... = ?R1"
      unfolding KL_div_def by simp

    finally show ?thesis by simp
  qed

  (* The ratio obtained at the end of the main argument is bounded by the claimed right-hand side. *)
  have 3: "(δ_avg * q + (1-q)) ^ n / (1-q) ^ r ≤ exp (- ?n* KL_div γ δ_avg)" (is "?L1 ≤ ?R1")
  proof (cases "γ < 1")
    case True
    have "γ * real n ≤ 1 * real n"
      using True by (intro mult_right_mono) auto
    hence "r = real n - real (nat ⌈γ * real n⌉)"
      unfolding r_def by (subst of_nat_diff) auto
    also have "... = real n - ⌈γ * real n⌉"
      using γ_n_nonneg by (subst of_nat_nat, auto)
    also have "... ≤ ?n - γ * ?n"
      by (intro diff_mono) auto
    also have "... = (1-γ) *?n" by (simp add:algebra_simps)
    finally have r_bound: "r ≤ (1-γ)*n" by simp

    have "?L1 = (δ_avg * q + (1-q)) ^ n / (1-q) powr r"
      using q_lt_1[OF True] assms(1) by (simp add: powr_realpow)
    (* The base δ_avg q + (1-q) is positive since q < 1, so the natural power is a real power. *)
    also have "... = (δ_avg * q + (1-q)) powr n / (1-q) powr r"
      using q_lt_1[OF True] assms(6) q_range
      by (subst powr_realpow[symmetric], auto intro!:add_nonneg_pos)
    also have "... ≤ (δ_avg * q + (1-q)) powr n / (1-q) powr ((1-γ)*n)"
      using q_range q_lt_1[OF True] by (intro divide_left_mono powr_mono_rev r_bound) auto
    also have "... = (δ_avg * q + (1-q)) powr n / ((1-q) powr (1-γ)) powr n"
      unfolding powr_powr by simp
    also have "... = ((δ_avg * q + (1-q)) / (1-q) powr (1-γ)) powr n"
      using assms(6) q_range by (subst powr_divide) auto
    also have "... = exp (- KL_div γ δ_avg) powr real n"
      unfolding 5[OF True] by simp
    also have "... = ?R1"
      unfolding exp_powr by simp
    finally show ?thesis by simp
  next
    case False
    hence γ_eq_1: "γ=1" using assms(5) by simp
    have "?L1 = δ_avg ^ n"
      using γ_eq_1 r_def q_def by simp
    also have "... = exp( - KL_div 1 δ_avg) ^ n"
      unfolding KL_div_def using assms(6) by (simp add:ln_div)
    also have "... = ?R1"
      using γ_eq_1 by (simp add: powr_realpow[symmetric] exp_powr)
    finally show ?thesis by simp
  qed

  (* Positivity of the factor (1-q)^r, needed to divide by it at the end. *)
  have 4: "(1 - q) ^ r > 0"
  proof (cases "γ < 1")
    case True
    then show ?thesis using q_lt_1[OF True] by simp
  next
    case False
    hence "γ=1" using assms(5) by simp
    hence "r=0" unfolding r_def by simp
    then show ?thesis by simp
  qed

  (* Main chain: lower-bound side. *)
  have "(1-q) ^ r * ?L = (∫ω. indicator {ω. ?E ω} ω * (1-q) ^ r ∂p)"
    by simp
  also have "... ≤ (∫ω. indicator {ω. ?E ω} ω * (1-q) ^ g ω ∂p)"
    using q_range 2 by (intro integral_mono_AE integrable_pmf_iff_bounded[where C="1"]
        abs_pos_le_1I mult_le_one power_le_one AE_pmfI) (simp_all split:split_indicator)
  also have "... = (∫ω. indicator {ω. ?E ω} ω * (∏i ∈ {i. i < n ∧ ¬Y i ω}. (1-q)) ∂p)"
    unfolding g_def using q_range
    by (intro integral_cong_AE AE_pmfI, simp_all add:powr_realpow)
  (* (1-q)^(g ω) is the probability that no failing index is selected by the random set. *)
  also have "... = (∫ω. indicator {ω. ?E ω} ω * measure ?Ξ ({j. j < n ∧ ¬Y j ω} → {False}) ∂p)"
    using q_range by (subst prob_prod_pmf') (auto simp add:measure_pmf_single)
  also have "... = (∫ω. measure ?Ξ {ξ. ?E ω ∧ (∀i∈{j. j < n ∧ ¬Y j ω}. ¬ξ i)} ∂p)"
    by (intro integral_cong_AE AE_pmfI, simp_all add:Pi_def split:split_indicator)
  also have "... = (∫ω. measure ?Ξ {ξ. ?E ω ∧ (∀i∈{..<n}. ξ i ⟶ Y i ω)} ∂p)"
    by (intro integral_cong_AE AE_pmfI measure_eq_AE) auto
  also have "... = measure (pair_pmf p ?Ξ) {φ.?E (fst φ)∧(∀i ∈ {..<n}. snd φ i ⟶ Y i (fst φ))}"
    unfolding split_pair_pmf by simp
  (* Drop the event ?E and switch the order of integration: upper-bound side. *)
  also have "... ≤ measure (pair_pmf p ?Ξ) {φ. (∀i ∈ {j. j < n ∧ snd φ j}. Y i (fst φ))}"
    by (intro pmf_mono, auto)
  also have "... = (∫ξ. measure p {ω. ∀i∈{j. j< n ∧ ξ j}. Y i ω} ∂ ?Ξ)"
    unfolding split_pair_pmf_2 by simp
  (* Apply the assumed product bound to each fixed selected index set. *)
  also have "... ≤ (∫a. (∏i ∈ {j. j < n ∧ a j}. δ i) ∂ ?Ξ)"
    using assms(2) by (intro integral_mono_AE AE_pmfI assms(3) subsetI  prod_le_1 prod_nonneg
        integrable_pmf_iff_bounded[where C="1"] abs_pos_le_1I) auto
  also have "... = (∫a. (∏i ∈ {..<n}. δ i^ of_bool(a i)) ∂ ?Ξ)"
    unfolding of_bool_def by (intro integral_cong_AE AE_pmfI)
      (auto simp add:if_distrib prod.If_cases Int_def)
  (* Independence of the coordinates of the product PMF. *)
  also have "... = (∏i<n. (∫a. (δ i ^ of_bool a) ∂(bernoulli_pmf q)))"
    using assms(2) by (intro expectation_prod_Pi_pmf integrable_pmf_iff_bounded[where C="1"]) auto
  also have "... = (∏i<n. δ i * q + (1-q))"
    using q_range by simp
  (* Inequality of arithmetic and geometric means. *)
  also have "... = (root (card {..<n}) (∏i<n. δ i * q + (1-q))) ^ (card {..<n})"
    using assms(1,2) q_range by (intro real_root_pow_pos2[symmetric] prod_nonneg) auto
  also have "... ≤ ((∑i<n. δ i * q + (1-q))/card{..<n})^(card {..<n})"
    using assms(1,2) q_range by (intro power_mono arithmetic_geometric_mean)
      (auto intro: prod_nonneg)
  also have "... = ((∑i<n. δ i * q)/n + (1-q))^n"
    using assms(1) by (simp add:sum.distrib divide_simps mult.commute)
  also have "... = (δ_avg * q + (1-q))^n"
    unfolding δ_avg_def by (simp add: sum_distrib_right[symmetric])
  finally have "(1-q) ^ r * ?L ≤ (δ_avg * q + (1-q)) ^ n" by simp
  hence "?L ≤ (δ_avg * q + (1-q)) ^ n / (1-q) ^ r"
    using 4 by (subst pos_le_divide_eq) (auto simp add:algebra_simps)
  also have "... ≤ ?R"
    by (intro 3)
  finally show ?thesis by simp
qed

text ‹The distribution of a random variable with a countable range is a discrete probability space,
i.e., induces a PMF. Using this it is possible to generalize the previous result to arbitrary
probability spaces.›

(* The distribution of a random variable f with countable range satisfies the three
   conditions listed in the set comprehension of the conclusion: it is a probability space,
   every set is measurable, and almost every point has non-zero mass. *)
lemma (in prob_space) establish_pmf:
  fixes f :: "'a ⇒ 'b"
  assumes rv: "random_variable discrete f"
  assumes "countable (f ` space M)"
  shows "distr M discrete f ∈ {M. prob_space M ∧ sets M = UNIV ∧ (AE x in M. measure M {x} ≠ 0)}"
proof -
  (* N collects the points whose fibre under f has probability zero;
     I collects the corresponding values in the range of f. *)
  define N where "N = {x ∈ space M.¬ prob (f -` {f x} ∩ space M) ≠ 0}"
  define I where "I = {z ∈ (f ` space M). prob (f -` {z}  ∩ space M) = 0}"

  have countable_I: " countable I"
    unfolding I_def by (intro countable_subset[OF _ assms(2)]) auto

  have disj: "disjoint_family_on (λy. f -` {y} ∩ space M) I"
    unfolding disjoint_family_on_def by auto

  (* N is a countable disjoint union of null fibres, hence a null set. *)
  have N_alt_def: "N = (⋃y ∈ I. f -` {y} ∩ space M)"
    unfolding N_def I_def by (auto simp add:set_eq_iff)
  have "emeasure M N = ∫⇧+ y. emeasure M (f -` {y} ∩ space M) ∂count_space I"
    using rv countable_I unfolding N_alt_def
    by (subst emeasure_UN_countable) (auto simp add:disjoint_family_on_def)
  also have "... =  ∫⇧+ y. 0 ∂count_space I"
    unfolding I_def using emeasure_eq_measure ennreal_0
    by (intro nn_integral_cong) auto
  also have "... = 0" by simp
  finally have 0:"emeasure M N = 0" by simp

  have 1:"N ∈ events"
    unfolding N_alt_def using rv
    by (intro  sets.countable_UN'' countable_I) simp

  (* Outside the null set N every fibre has positive probability; transfer this to distr. *)
  have " AE x in M. prob (f -` {f x} ∩ space M) ≠ 0"
    using 0 1 by (subst AE_iff_measurable[OF _ N_def[symmetric]])
  hence " AE x in M. measure (distr M discrete f) {f x} ≠ 0"
    by (subst measure_distr[OF rv], auto)
  hence "AE x in distr M discrete f. measure (distr M discrete f) {x} ≠ 0"
    by (subst AE_distr_iff[OF rv], auto)
  thus ?thesis
    using prob_space_distr rv by auto
qed

(* The singleton subsets of T are subsets of T. *)
lemma singletons_image_eq:
  "(λx. {x}) ` T ⊆ Pow T"
  by auto

(* Constructive Chernoff bound on an arbitrary probability space.
   Same statement as impagliazzo_kabanets_pmf, with the boolean random variables Y i living on
   the probability space M. The proof packs the first n variables into one random variable f
   with finite range, takes its distribution as a PMF, and applies the PMF version. *)
theorem (in prob_space) impagliazzo_kabanets:
  fixes Y :: "nat ⇒ 'a ⇒ bool"
  assumes "n > 0"
  assumes "⋀i. i ∈ {..<n} ⟹ random_variable discrete (Y i)"
  assumes "⋀i. i ∈ {..<n} ⟹ δ i ∈ {0..1}"
  assumes "⋀S. S ⊆ {..<n} ⟹ 𝒫(ω in M. (∀i ∈ S. Y i ω)) ≤ (∏i ∈ S. δ i)"
  defines "δ_avg ≡ (∑i∈ {..<n}. δ i)/n"
  assumes "γ ∈ {δ_avg..1}" "δ_avg > 0"
  shows "𝒫(ω in M. real (card {i ∈ {..<n}. Y i ω}) ≥ γ * n) ≤ exp (-real n * KL_div γ δ_avg)"
    (is "?L ≤ ?R")
proof -
  (* f maps a sample point to the vector of the first n outcomes, padded with False;
     g truncates a vector in the same way; T is the set of such padded vectors. *)
  define f where "f = (λω i. if i < n then Y i ω else False)"
  define g where "g = (λω i. if i < n then ω i else False)"
  define T where "T = {ω. (∀i. ω i ⟶ i < n)}"

  have g_idem: "g ∘ f = f" unfolding f_def g_def by (simp add:comp_def)

  have f_range: " f ∈ space M → T"
    unfolding T_def f_def by simp

  (* T is finite, so the range of f is countable. *)
  have "T = PiE_dflt {..<n} False (λ_. UNIV)"
    unfolding T_def PiE_dflt_def by auto
  hence "finite T"
    using finite_PiE_dflt by auto
  hence countable_T: "countable T"
    by (intro countable_finite)
  moreover have "f ` space M ⊆ T"
    using f_range by auto
  ultimately have countable_f: "countable (f ` space M)"
    using countable_subset by auto

  (* Preimages of singletons of T are finite intersections of events. *)
  have "f -` y ∩ space M ∈ events" if t:"y ∈ (λx. {x}) ` T" for y
  proof -
    obtain t where "y = {t}" and t_range: "t ∈ T" using t by auto
    hence "f -` y ∩ space M = {ω ∈ space M. f ω = t}"
      by (auto simp add:vimage_def)
    also have "... = {ω ∈ space M. (∀i < n. Y i ω = t i)}"
      using t_range unfolding f_def T_def by auto
    also have "... = (⋂i ∈ {..<n}. {ω ∈ space M. Y i ω = t i})"
      using assms(1) by auto
    also have "... ∈ events"
      using assms(1,2)
      by (intro sets.countable_INT) auto
    finally show ?thesis by simp
  qed

  (* Measurability of f into count_space T, then into discrete by composing with g. *)
  hence "random_variable (count_space T) f"
    using sigma_sets_singletons[OF countable_T] singletons_image_eq f_range
    by (intro measurable_sigma_sets[where Ω="T" and A=" (λx. {x}) ` T"]) simp_all
  moreover have "g ∈ measurable discrete (count_space T)"
    unfolding g_def T_def by simp
  ultimately have "random_variable discrete (g ∘ f)"
    by simp
  hence rv:"random_variable discrete f"
    unfolding g_idem by simp

  define M' :: "(nat ⇒ bool) measure"
    where "M' = distr M discrete f"

  (* The distribution of f is a PMF by establish_pmf. *)
  define Ω where "Ω = Abs_pmf M'"
  have a:"measure_pmf (Abs_pmf M') = M'"
    unfolding M'_def
    by (intro Abs_pmf_inverse[OF establish_pmf] rv countable_f)

  have b:"{i. (i < n ⟶ Y i x) ∧ i < n} = {i. i < n ∧ Y i x}" for x
    by auto

  (* The product bound assumed on M carries over to the PMF Ω. *)
  have c: "measure Ω {ω. ∀i∈S. ω i} ≤ prod δ S" (is "?L1 ≤ ?R1") if "S ⊆ {..<n}" for S
  proof -
    have d: "i ∈ S ⟹ i < n" for i
      using that by auto
    have "?L1 = measure M' {ω. ∀i∈S. ω i}"
      unfolding Ω_def a by simp
    also have "... = 𝒫(ω in M. (∀i ∈ S. Y i ω))"
      unfolding M'_def using that d
      by (subst measure_distr[OF rv]) (auto simp add:f_def Int_commute Int_def)
    also have "... ≤ ?R1"
      using that assms(4) by simp
    finally show ?thesis by simp
  qed

  (* Rewrite the target probability over Ω and apply the PMF version of the theorem. *)
  have "?L = measure M' {ω. real (card {i. i < n ∧ ω i}) ≥ γ * n}"
    unfolding M'_def by (subst measure_distr[OF rv])
      (auto simp add:f_def algebra_simps Int_commute Int_def b)
  also have "... = measure_pmf.prob Ω {ω. real (card {i ∈ {..<n}. ω i}) ≥ γ * n}"
    unfolding Ω_def a by simp
  also have "... ≤ ?R"
    using assms(1,3,6,7) c unfolding δ_avg_def
    by (intro impagliazzo_kabanets_pmf) auto
  finally show ?thesis by simp
qed

text ‹Bounds and properties of @{term "KL_div"}›

(* Monotonicity of the function q ↦ KL_div p q - 2 (p-q)^2 to the right of p:
for 0 ≤ p ≤ q ≤ q' < 1 its value at q is at most its value at q'.
Both cases show that the derivative is non-negative on the interval from q to q'. *)
lemma KL_div_mono_right_aux_1:
assumes "0 ≤ p" "p ≤ q" "q ≤ q'" "q' < 1"
shows "KL_div p q-2*(p-q)^2 ≤ KL_div p q'-2*(p-q')^2"
proof (cases "p = 0")
case True
(* Case p = 0: the function reduces to ln (1/(1-q)) - 2 q^2, with derivative f'. *)
define f' :: "real ⇒ real" where "f' = (λx. 1/(1-x) - 4 * x)"

have deriv: "((λq. ln (1/(1-q)) - 2*q^2) has_real_derivative (f' x)) (at x)"
if "x ∈ {q..q'}" for x
proof -
have "x ∈ {0..<1}" using assms that by auto
thus ?thesis unfolding f'_def by (auto intro!: derivative_eq_intros)
qed

(* Non-negativity of f' follows from 4 x (1-x) ≤ 1, shown by completing the square. *)
have deriv_nonneg: "f' x ≥ 0" if "x ∈ {q..q'}" for x
proof -
have 0:"x ∈ {0..<1}" using assms that by auto
have "4 * x*(1-x) = 1 - 4*(x-1/2)^2" by (simp add:power2_eq_square field_simps)
also have "... ≤ 1" by simp
finally have "4*x*(1-x) ≤ 1" by simp
hence "1/(1-x) ≥ 4*x" using 0 by (simp add: pos_le_divide_eq)
thus ?thesis unfolding f'_def by auto
qed

have "ln (1 / (1 - q)) - 2 * q^2 ≤ ln (1 / (1 - q')) - 2 * q'^2"
using deriv deriv_nonneg by (intro DERIV_nonneg_imp_nondecreasing[OF assms(3)]) auto
thus ?thesis using True unfolding KL_div_def by simp
next
case False
hence p_gt_0: "p > 0" using assms by auto

(* Case p > 0: derivative of the full expression with respect to the second argument. *)
define f' :: "real ⇒ real" where "f' = (λx. (1-p)/(1-x) - p/x + 4 * (p-x))"

have deriv: "((λq. KL_div p q - 2*(p-q)^2) has_real_derivative (f' x)) (at x)" if "x ∈ {q..q'}"
for x
proof -
have "0 < p /x" " 0 < (1 - p) / (1 - x)" using that assms p_gt_0 by auto
thus ?thesis unfolding KL_div_def f'_def by (auto intro!: derivative_eq_intros)
qed

(* Again 4 x (1-x) ≤ 1, here in the form 1/(x (1-x)) ≥ 4 on the open unit interval. *)
have f'_part_nonneg: "(1/(x*(1-x)) - 4) ≥ 0" if "x ∈ {0<..<1}" for x :: real
proof -
have "4 * x * (1-x) = 1 - 4 * (x-1/2)^2" by (simp add:power2_eq_square algebra_simps)
also have "... ≤ 1" by simp
finally have "4 * x * (1-x) ≤ 1" by simp
hence "1/(x*(1-x)) ≥ 4" using that by (subst pos_le_divide_eq) auto
thus ?thesis by simp
qed

(* Factor the derivative as (x-p) times the non-negative term above. *)
have f'_alt: "f' x = (x-p)*(1/(x*(1-x)) - 4)" if "x ∈ {0<..<1}" for x
proof -
have "f' x = (x-p)/(x*(1-x)) + 4 * (p-x)" using that unfolding f'_def by (simp add:field_simps)
also have "... = (x-p)*(1/(x*(1-x)) - 4)" by (simp add:algebra_simps)
finally show ?thesis by simp
qed

(* On the interval from q to q' we have x ≥ p, so both factors are non-negative. *)
have deriv_nonneg: "f' x ≥ 0" if "x ∈ {q..q'}" for x
proof -
have "x ∈ {0<..<1}" using assms that p_gt_0 by auto
have "f' x =(x-p)*(1/(x*(1-x)) - 4)" using that assms p_gt_0 by (subst f'_alt) auto
also have "... ≥ 0" using that f'_part_nonneg assms p_gt_0 by (intro mult_nonneg_nonneg) auto
finally show ?thesis by simp
qed

show ?thesis using deriv deriv_nonneg
by (intro DERIV_nonneg_imp_nondecreasing[OF assms(3)]) auto
qed

(* KL_div is invariant under replacing both arguments by their complements. *)
lemma KL_div_swap: "KL_div (1-p) (1-q) = KL_div p q"
  by (auto simp add: KL_div_def)

(* Mirror image of KL_div_mono_right_aux_1 for second arguments to the left of p:
   for 0 < q' ≤ q ≤ p ≤ 1 the value at q is at most the value at q'. *)
lemma KL_div_mono_right_aux_2:
  assumes "0 < q'" "q' ≤ q" "q ≤ p" "p ≤ 1"
  shows "KL_div p q-2*(p-q)^2 ≤ KL_div p q'-2*(p-q')^2"
proof -
  (* Apply the right-hand version to the complemented arguments ... *)
  have "KL_div (1-p) (1-q)-2*((1-p)-(1-q))^2 ≤ KL_div (1-p) (1-q')-2*((1-p)-(1-q'))^2"
    by (rule KL_div_mono_right_aux_1) (use assms in auto)
  (* ... and undo the complement with KL_div_swap. *)
  thus ?thesis
    unfolding KL_div_swap by (auto simp:algebra_simps power2_commute)
qed

(* Combination of the two one-sided auxiliary lemmas: moving the second argument further
   away from p, on either side, does not decrease KL_div p q - 2 (p-q)^2. *)
lemma KL_div_mono_right_aux:
  assumes "(0 ≤ p ∧ p ≤ q ∧ q ≤ q' ∧ q' < 1) ∨ (0 < q' ∧ q' ≤ q ∧ q ≤ p ∧ p ≤ 1)"
  shows "KL_div p q-2*(p-q)^2 ≤ KL_div p q'-2*(p-q')^2"
proof -
  from assms consider
    (right) "0 ≤ p" "p ≤ q" "q ≤ q'" "q' < 1" |
    (left) "0 < q'" "q' ≤ q" "q ≤ p" "p ≤ 1"
    by auto
  then show ?thesis
  proof cases
    case right
    then show ?thesis by (rule KL_div_mono_right_aux_1)
  next
    case left
    then show ?thesis by (rule KL_div_mono_right_aux_2)
  qed
qed

(* KL_div p q does not decrease when the second argument moves further away from p
   (on either side, within the stated ranges). *)
lemma KL_div_mono_right:
  assumes "(0 ≤ p ∧ p ≤ q ∧ q ≤ q' ∧ q' < 1) ∨ (0 < q' ∧ q' ≤ q ∧ q ≤ p ∧ p ≤ 1)"
  shows "KL_div p q ≤ KL_div p q'" (is "?L ≤ ?R")
proof -
  (* The squared distance to p grows as well. *)
  consider (a) "0 ≤ p" "p ≤ q" "q ≤ q'" "q' < 1" | (b) "0 < q'" "q' ≤ q" "q ≤ p" "p ≤ 1"
    using assms by auto
  hence sq_mono: "(p - q)⇧2 ≤ (p - q')⇧2"
  proof (cases)
    case a
    hence "(q-p)^2 ≤ (q' - p)^2" by auto
    thus ?thesis by (simp add: power2_commute)
  next
    case b thus ?thesis by simp
  qed
  (* Add the monotone quadratic term back to the auxiliary inequality. *)
  have "KL_div p q - 2*(p-q)^2 ≤ KL_div p q' - 2*(p-q')^2"
    by (rule KL_div_mono_right_aux[OF assms])
  moreover have "2 * (p-q)^2 ≤ 2 * (p-q')^2"
    using sq_mono by simp
  ultimately show ?thesis by linarith
qed

(* Quadratic lower bound: KL_div p q is at least twice the squared difference of its
   arguments, for p in the closed and q in the open unit interval. *)
lemma KL_div_lower_bound:
  assumes "p ∈ {0..1}" "q ∈ {0<..<1}"
  shows "2*(p-q)^2 ≤ KL_div p q"
proof -
  (* At q = p the difference is non-negative ... *)
  have "0 ≤ KL_div p p - 2 * (p-p)^2"
    unfolding KL_div_def by simp
  (* ... and it does not decrease when the second argument moves from p to q. *)
  moreover have "KL_div p p - 2 * (p-p)^2 ≤ KL_div p q - 2 * (p-q)^2"
    using assms by (intro KL_div_mono_right_aux) auto
  ultimately show ?thesis by linarith
qed

end