(* Theory MFMC_Countable.Matrix_For_Marginals *)

(* Author: Andreas Lochbihler, ETH Zurich *)

section ‹Matrices for given marginals›

text ‹This theory derives the existence of matrices with given marginals from the finite
max-flow min-cut theorem, following a proof by Georg Kellerer @{cite "Kellerer1961MA"}.›

theory Matrix_For_Marginals
  imports MFMC_Misc "HOL-Library.Diagonal_Subsequence" MFMC_Finite
begin

(* Finite version of the marginal-matrix construction.
   Assumptions: f and g are nonnegative and have the same total on {..n}; R is a subset of
   {..n} times {..n}; and every X within {..n} satisfies sum f X <= sum g (R `` X)
   (assumption le).
   Result: a matrix h that is nonnegative on {..n} times {..n}, is positive there only on
   pairs in R, and has row sums (over {..n}) equal to f and column sums equal to g.
   Proof: build a flow network with source edges of capacity f, row-to-column edges of
   capacity LARGE along R', and sink edges of capacity g. Show that a minimum cut has
   capacity sum f {..n}, so a maximum flow saturates every source and sink edge. Then read
   h off the row-to-column edges.
   NOTE(review): in this listing the operator symbols inside the term strings (function
   arrow, <=, ==>, membership, subset, etc.) appear to be missing, leaving double spaces
   such as in the type of f; restore them from the upstream AFP source before checking
   this theory. *)
lemma bounded_matrix_for_marginals_finite:
  fixes f g :: "nat  real"
    and n :: nat
    and R :: "(nat × nat) set"
  assumes eq_sum: "sum f {..n} = sum g {..n}"
    and le: "X. X  {..n}  sum f X  sum g (R `` X)"
    and f_nonneg: "x. 0  f x"
    and g_nonneg: "y. 0  g y"
    and R: "R  {..n} × {..n}"
  obtains h :: "nat  nat  real"
    where "x y.  x  n; y  n   0  h x y"
    and "x y.  0 < h x y; x  n; y  n   (x, y)  R"
    and "x. x  n  f x = sum (h x) {..n}"
    and "y. y  n  g y = sum (λx. h x y) {..n}"
proof(cases "xn. f x > 0")
  (* Degenerate case: if f vanishes on {..n}, then so does g (same total, nonnegative
     terms), and the zero matrix works. *)
  case False
  hence f: "f x = 0" if "x  n" for x using f_nonneg[of x] that by(auto simp add: not_less)
  hence "sum g {..n} = 0" using eq_sum by simp
  hence "g y = 0" if "y  n" for y using g_nonneg that by(simp add: sum_nonneg_eq_0_iff)
  with f show thesis by(auto intro!: that[of "λ_ _. 0"])
next
  case True
  (* Otherwise pick a row x0 <= n with f x0 > 0. Together with the column y0 obtained
     from hop below, it witnesses that s and t are nodes of the network. *)
  then obtain x0 where x0: "x0  n" "f x0 > 0" by blast

  (* R' keeps only those pairs of R whose row has positive f and whose column has
     positive g. *)
  define R' where "R' = R  {x. f x > 0} × {y. g y > 0}"

  (* Images under R and R' lie in {..n} and are therefore finite. *)
  have [simp]: "finite (R `` A)" for A
  proof -
    have "R `` A  {..n}" using R by auto
    thus ?thesis by(rule finite_subset) auto
  qed

  have R': "R'  {..n} × {..n}" using R by(auto simp add: R'_def)
  have R'': "R' `` A  {..n}" for A using R by(auto simp add: R'_def)
  have [simp]: "finite (R' `` A)" for A using R''[of A]
    by(rule finite_subset) auto

  (* hop: every row x0 <= n with f x0 > 0 has an R-neighbour y0 <= n with g y0 > 0.
     Otherwise sum g (R `` {x0}) would be 0, contradicting assumption le for {x0}. *)
  have hop: "y0n. (x0, y0)  R  g y0 > 0" if x0: "x0  n" "f x0 > 0" for x0
  proof(rule ccontr)
    assume "¬ ?thesis"
    then have "g y0 = 0" if "(x0, y0)  R" "y0  n" for y0 using g_nonneg[of y0] that by auto
    moreover from R have "R `` {x0}  {..n}" by auto
    ultimately have "sum g (R `` {x0}) = 0" 
      using g_nonneg by(auto intro!: sum_nonneg_eq_0_iff[THEN iffD2])
    moreover have "{x0}  {..n}" using x0 by auto
    from le[OF this] x0 have "R `` {x0}  {}" "sum g (R `` {x0}) > 0" by auto
    ultimately show False by simp
  qed
  then obtain y0 where y0: "y0  n" "(x0, y0)  R'" "g y0 > 0" using x0 by(auto simp add: R'_def)

  (* LARGE exceeds the total supply sum f {..n} and is at least 1. It is the capacity of
     every row-to-column edge. *)
  define LARGE where "LARGE = sum f {..n} + 1"
  have "1  LARGE" using f_nonneg by(simp add: LARGE_def sum_nonneg)
  hence [simp]: "LARGE  0" "0  LARGE" "0 < LARGE" "0  LARGE" by simp_all

  (* Network layout:
     - row x (x <= n) is node x, and column y (y <= n) is node y + n + 1;
     - the source is s = 2 * n + 2 and the sink is t = 2 * n + 3.
     Capacities:
     - s to x is f x;
     - x to y + n + 1 is LARGE when (x, y) is in R';
     - y + n + 1 to t is g y;
     - all other capacities are 0. *)
  define s where "s = 2 * n + 2"
  define t where "t = 2 * n + 3"
  define c where "c = (λ(x, y).
     if x = s  y  n then f y
     else if x  n  n < y  y  2 * n + 1  (x, y - n - 1)  R' then LARGE
     else if n < x  x  2 * n + 1  y = t then g (x - n - 1)
     else 0)"

  (* Simplification facts that separate s and t from each other and from the row and
     column nodes. *)
  have st [simp]: "¬ s  n" "¬ s  Suc (2 * n)" "s  t" "t  s" "¬ t  n" "¬ t  Suc (2 * n)"
    by(simp_all add: s_def t_def)

  have c_simps: "c (x, y) = 
    (if x = s  y  n then f y
     else if x  n  n < y  y  2 * n + 1  (x, y - n - 1)  R' then LARGE
     else if n < x  x  2 * n + 1  y = t then g (x - n - 1)
     else 0)"
    for x y by(simp add: c_def)
  have cs [simp]: "c (s, y) = (if y  n then f y else 0)"
    and ct [simp]: "c (x, t) = (if n < x  x  2 * n + 1 then g (x - n - 1) else 0)"
    for x y by(auto simp add: c_simps)

  (* Interpret c as a graph and discharge the Network locale assumptions for source s and
     sink t. zero_cap_simp is removed from the simpset for the rest of the proof. *)
  interpret Graph c .
  note [simp del] = zero_cap_simp

  interpret Network c s t
  proof
    have "(s, x0)  E" using x0 by(simp add: E_def)
    thus "s  V" by(auto simp add: V_def)
    
    have "(y0 + n + 1, t)  E" using y0 by(simp add: E_def)
    thus "t  V" by(auto simp add: V_def)

    show "s  t" by simp
    show "u v. 0  c (u, v)" by(simp add: c_simps f_nonneg g_nonneg max_def)
    show "u. (u, s)  E" by(simp add: E_def c_simps)
    show "u. (t, u)  E" by(simp add: E_def c_simps)
    show "u v. (u, v)  E  (v, u)  E" by(simp add: E_def c_simps)

    have "isPath s [(s, x0), (x0, y0 + n + 1), (y0 + n + 1, t)] t" 
      using x0 y0 by(auto simp add: E_def c_simps)
    hence st: "connected s t" by(auto simp add: connected_def simp del: isPath.simps)

    (* Every node lies on a path from s to t:
       - a row v is reached directly from s, and leaves through hop to a column and then t;
       - a column v leads directly to t, and is reached from s through a row with positive
         f that is R-related to it. Such a row exists, because otherwise assumption le,
         applied to the rows not R-related to that column, would give
         sum f {..n} < sum f {..n}. *)
    show "vV. connected s v  connected v t"
    proof(intro strip)
      fix v
      assume v: "v  V"
      hence "v  2 * n + 3" by(auto simp add: V_def E_def c_simps t_def s_def split: if_split_asm)
      then consider (left) "v  n" | (right) "n < v" "v  2 * n + 1" | (s) "v = s" | (t) "v = t"
        by(fastforce simp add: s_def t_def)
      then show "connected s v  connected v t"
      proof cases
        case left
        have sv: "(s, v)  E" using v left
          by(fastforce simp add: E_def V_def c_simps max_def R'_def split: if_split_asm)
        hence "connected s v" by(auto simp add: connected_def intro!: exI[where x="[(s, v)]"])
        moreover from sv left f_nonneg[of v] have "f v > 0" by(simp add: E_def)
        from hop[OF left this] obtain v' where "(v, v')  R" "v'  n" "g v' > 0" by auto
        hence "isPath v [(v, v' + n + 1), (v' + n + 1, t)] t" using left f v > 0
          by(auto simp add: E_def c_simps R'_def)
        hence "connected v t" by(auto simp add: connected_def simp del: isPath.simps)
        ultimately show ?thesis ..
      next
        case right
        hence vt: "(v, t)  E" using v by(auto simp add: V_def E_def c_simps max_def R'_def split: if_split_asm)
        hence "connected v t" by(auto simp add: connected_def intro!: exI[where x="[(v, t)]"])
        moreover
        have *: "g (v - n - 1) > 0" using vt g_nonneg[of "v - n - 1"] right by(simp add: E_def )
        have "v'n. (v', v - n - 1)  R  f v' > 0"
        proof(rule ccontr)
          assume "¬ ?thesis"
          hence zero: " v'  n; (v', v - n - 1)  R   f v' = 0" for v' using f_nonneg[of v'] by auto
          have "sum f {..n} = sum f {x. x  n  x  R^-1 `` {v - n - 1}}"
            by(rule sum.mono_neutral_cong_right)(auto dest: zero)
          also have "  sum g (R `` {x. x  n  x  R^-1 `` {v - n - 1}})" by(rule le) auto
          also have "  sum g ({..n} - {v - n - 1})"
            by(rule sum_mono2)(use R in auto simp add: g_nonneg)
          also have " = sum g {..n} - g (v - n - 1)" using right by(subst sum_diff) auto
          also have " < sum g {..n}" using * by simp
          also have " = sum f {..n}" by(simp add: eq_sum)
          finally show False by simp
        qed
        then obtain v' where "v'  n" "(v', v - n - 1)  R" "f v' > 0" by auto
        with right * have "isPath s [(s, v'), (v', v)] v" by(auto simp add: E_def c_simps R'_def)
        hence "connected s v" by(auto simp add: connected_def simp del: isPath.simps)
        ultimately show ?thesis by blast
      qed(simp_all add: st)
    qed

    (* Every node is at most 2 * n + 3, so the set of nodes reachable from s is finite. *)
    have "reachableNodes s  V" using s  V by(rule reachable_ss_V)
    also have "V  {..2 * n + 3}"
      by(clarsimp simp add: V_def E_def)(simp_all add: c_simps s_def t_def split: if_split_asm)
    finally show "finite (reachableNodes s)" by(rule finite_subset) simp
  qed

  (* Fix the maximum flow max_flow. By h.fofu_II_III (max-flow min-cut) there is a cut C
     whose capacity equals the flow value. The candidate matrix ?h x y is the flow on the
     edge from row x to column node y + n + 1. *)
  interpret h: NFlow c s t max_flow by(rule max_flow)
  let ?h = "λx y. max_flow (x, y + n + 1)"
  from max_flow(2)[THEN h.fofu_II_III] obtain C where C: "NCut c s t C" 
    and eq: "h.val = NCut.cap c C" by blast
  interpret C: NCut c s t C using C .

  (* The edges leaving s have total capacity sum f {..n}. *)
  have "sum c (outgoing s) = sum (λ(_, x). f x) (Pair s ` {..n})"
    by(rule sum.mono_neutral_cong_left)(auto simp add: outgoing_def E_def)
  also have " = sum f {..n}" by(subst sum.reindex)(auto simp add: inj_on_def)
  finally have out: "sum c (outgoing s) = sum f {..n}" .

  (* No row x <= n inside the cut C has an R'-neighbour column outside C. The edge between
     them would have capacity LARGE, which exceeds the flow value, yet the cut capacity
     equals the flow value. *)
  have no_leaving: "(λy. y + n + 1) ` (R' `` (C  {..n}))  C"
  proof(rule ccontr)
    assume "¬ ?thesis"
    then obtain x y where *: "(x, y)  R'" "x  n" "x  C" "y + n + 1  C" by auto
    then have xy: "(x, y + n + 1)  E" "y  n" "c (x, y + n + 1) = LARGE"
      using R by(auto simp add: E_def c_simps R'_def)

    have "h.val  sum f {..n}" using h.val_bounded(2) out by simp
    also have " < sum c {(x, y + n + 1)}" using xy * by(simp add: LARGE_def)
    also have "  C.cap" using * xy unfolding C.cap_def
      by(intro sum_mono2[OF finite_outgoing'])(auto simp add: outgoing'_def cap_non_negative)
    also have " = h.val" by(simp add: eq)
    finally show False by simp
  qed

  (* The cut capacity is exactly sum f {..n}.
     - One direction follows from the bound on the flow value.
     - For the other, split the rows into ?L (rows in C with positive f) and the rest:
       by assumption le, the rows in ?L are covered by the sink edges of their
       R'-neighbours; the remaining rows are covered by their source edges; and all of
       these edges leave C. *)
  have "C.cap  sum f {..n}" using out h.val_bounded(2) eq by simp
  then have cap: "C.cap = sum f {..n}"
  proof(rule antisym)
    let ?L = "{x. x  n  x  C  f x > 0}"
    let ?R = "(λy. y + n + 1) ` (R' `` ?L)"
    have "sum f {..n} = sum f ?L + sum f ({..n} - ?L)" by(subst sum_diff) auto
    also have "sum f ?L  sum g (R `` ?L)" by(rule le) auto
    also have " = sum g (R' `` ?L)" using g_nonneg
      by(intro sum.mono_neutral_cong_right)(auto 4 3 simp add: R'_def Image_iff intro: antisym)
    also have " = sum c ((λy. (y + n + 1, t)) ` (R' `` ?L))" using R
      by(subst sum.reindex)(auto intro!: sum.cong simp add: inj_on_def R'_def)
    also have "sum f ({..n} - ?L) = sum c (Pair s ` ({..n} - ?L))" by(simp add: sum.reindex inj_on_def)
    also have "sum c ((λy. (y + n + 1, t)) ` (R' `` ?L)) +  = 
      sum c (((λy. (y + n + 1, t)) ` (R' `` ?L))  (Pair s ` ({..n} - ?L)))"
      by(subst sum.union_disjoint) auto
    also have "  sum c (((λy. (y + n + 1, t)) ` (R' `` ?L))  (Pair s ` ({..n} - ?L))  {(x, t) | x. x  C  n < x  x  2 * n + 1})"
      by(rule sum_mono2)(auto simp add: g_nonneg)
    also have "(((λy. (y + n + 1, t)) ` (R' `` ?L))  (Pair s ` ({..n} - ?L))  {(x, t) | x. x  C  n < x  x  2 * n + 1}) = (Pair s ` ({..n} - ?L))   {(x, t) | x. x  C  n < x  x  2 * n + 1}"
      using no_leaving R' by(fastforce simp add: Image_iff intro: rev_image_eqI)
    also have "sum c  = sum c (outgoing' C)" using C.s_in_cut C.t_ni_cut f_nonneg no_leaving
      apply(intro sum.mono_neutral_cong_right)
      apply(auto simp add: outgoing'_def E_def intro: le_neq_trans)
      apply(fastforce simp add: c_simps Image_iff intro: rev_image_eqI split: if_split_asm)+
      done
    also have " = C.cap" by(simp add: C.cap_def)
    finally show "sum f {..n}  " by simp
  qed

  (* Finally, h is the flow on the row-to-column edges. *)
  show thesis
  proof
    show "0  ?h x y" for x y by(rule h.f_non_negative)
    show "(x, y)  R" if "0 < ?h x y" "x  n" "y  n" for x y
      using h.capacity_const[rule_format, of "(x, y + n + 1)"] that
      by(simp add: c_simps R'_def split: if_split_asm)

    (* Row sum of ?h at row x: by conservation at node x it equals the flow on the source
       edge (s, x). *)
    have sum_h: "sum (?h x) {..n} = max_flow (s, x)" if "x  n" for x
    proof -
      have "sum (?h x) {..n} = sum max_flow (Pair x ` ((+) (n + 1)) ` {..n})"
        by(simp add: sum.reindex add.commute inj_on_def)
      also have " = sum max_flow (outgoing x)" using that
        apply(intro sum.mono_neutral_cong_right)
        apply(auto simp add: outgoing_def E_def)
        subgoal for y by(auto 4 3 simp add: c_simps max_def split: if_split_asm intro: rev_image_eqI[where x="y - n - 1"])
        done
      also have " = sum max_flow (incoming x)" using that by(subst h.conservation) auto
      also have " = sum max_flow {(s, x)}" using that
        by(intro sum.mono_neutral_cong_left; auto simp add: incoming_def E_def; simp add: c_simps split: if_split_asm)
      finally show ?thesis by simp
    qed
    (* Each row sum is at most f x, by the capacity of (s, x). If one were strictly smaller,
       the outflow of s would fall below sum f {..n}, which is the flow value by cap. *)
    hence le: "sum (?h x) {..n}  f x" if "x  n" for x
      using sum_h[OF that] h.capacity_const[rule_format, of "(s, x)"] that by simp
    moreover have "f x  sum (?h x) {..n}" if "x  n" for x
    proof(rule ccontr)
      assume "¬ ?thesis"
      hence "sum (?h x) {..n} < f x" by simp
      hence "sum (λx. (sum (?h x) {..n})) {..n} < sum f {..n}"
        using le that by(intro sum_strict_mono_ex1) auto
      also have "sum (λx. (sum (?h x) {..n})) {..n} = sum max_flow (Pair s ` {..n})"
        using sum_h by(simp add: sum.reindex inj_on_def)
      also have " = sum max_flow (outgoing s)"
        by(rule sum.mono_neutral_right)(auto simp add: outgoing_def E_def)
      also have " = sum f {..n}" using eq cap by(simp add: h.val_alt)
      finally show False by simp
    qed
    ultimately show "f x = sum (?h x) {..n}" if "x  n" for x using that by(auto intro: antisym)

    (* Column sums follow by the symmetric argument, using:
       - conservation at node y + n + 1;
       - the capacity of (y + n + 1, t);
       - the fact that the inflow of t equals the outflow of s. *)
    have sum_h': "sum (λx. ?h x y) {..n} = max_flow (y + n + 1, t)" if "y  n" for y
    proof -
      have "sum (λx. ?h x y) {..n} = sum max_flow ((λx. (x, y + n + 1)) ` {..n})" 
        by(simp add: sum.reindex inj_on_def)
      also have " = sum max_flow (incoming (y + n + 1))" using that
        apply(intro sum.mono_neutral_cong_right)
        apply(auto simp add: incoming_def E_def)
        apply(auto simp add: c_simps t_def split: if_split_asm)
        done
      also have " = sum max_flow (outgoing (y + n + 1))"
        using that by(subst h.conservation)(auto simp add: s_def t_def)
      also have " = sum max_flow {(y + n + 1, t)}" using that
        by(intro sum.mono_neutral_cong_left; auto simp add: outgoing_def E_def; simp add: s_def c_simps split: if_split_asm)
      finally show ?thesis by simp
    qed
    hence le': "sum (λx. ?h x y) {..n}  g y" if "y  n" for y
      using sum_h'[OF that] h.capacity_const[rule_format, of "(y + n + 1, t)"] that by simp
    moreover have "g y  sum (λx. ?h x y) {..n}" if "y  n" for y
    proof(rule ccontr)
      assume "¬ ?thesis"
      hence "sum (λx. ?h x y) {..n} < g y" by simp
      hence "sum (λy. (sum (λx. ?h x y) {..n})) {..n} < sum g {..n}"
        using le' that by(intro sum_strict_mono_ex1) auto
      also have "sum (λy. (sum (λx. ?h x y) {..n})) {..n} = sum max_flow ((λy. (y + n + 1, t)) ` {..n})"
        using sum_h' by(simp add: sum.reindex inj_on_def)
      also have " = sum max_flow (incoming t)"
        apply(rule sum.mono_neutral_right)
        apply simp
        apply(auto simp add: incoming_def E_def cong: rev_conj_cong)
        subgoal for u by(auto intro: rev_image_eqI[where x="u - n - 1"])
        done
      also have " = sum max_flow (outgoing s)" by(rule h.inflow_t_outflow_s)
      also have " = sum f {..n}" using eq cap by(simp add: h.val_alt)
      finally show False using eq_sum by simp
    qed
    ultimately show "g y = sum (λx. ?h x y) {..n}" if "y  n" for y using that by(auto intro: antisym)
  qed
qed

(* Diagonal-subsequence argument. Suppose every column x of the real array f n x is
   bounded as a sequence in n. Then a single strictly increasing index sequence k makes
   every column converge. *)
lemma convergent_bounded_family_nat:
  fixes f :: "nat  nat  real"
  assumes bounded: "x. bounded (range (λn. f n x))"
  obtains k where "strict_mono k" "x. convergent (λn. f (k n) x)"
proof -
  (* Instantiate the subseqs locale with the property: column x converges along k.
     Refinement step: any subsequence s of a bounded column is still bounded, so it has a
     further convergent subsequence (bounded_imp_convergent_subsequence). *)
  interpret subseqs "λx k. convergent (λn. f (k n) x)"
  proof(unfold_locales)
    fix x and s :: "nat  nat"
    have "bounded (range (λn. f (s n) x))" using bounded by(rule bounded_subset) auto
    from bounded_imp_convergent_subsequence[OF this]
    show "r. strict_mono r  convergent (λn. f ((s  r) n) x)"
      by(auto simp add: o_def convergent_def)
  qed
  (* Along diagseq, column k converges once the first Suc k indices are dropped, hence it
     converges outright. Together with subseq_diagseq this gives the claim. *)
  { fix k
    have "convergent (λn. f ((diagseq  (+) (Suc k)) n) k)"
      by(rule diagseq_holds)(auto dest: convergent_subseq_convergent simp add: o_def)
    hence "convergent (λn. f (diagseq n) k)" unfolding o_def
      by(subst (asm) add.commute)(simp only: convergent_ignore_initial_segment[where f="λx. f (diagseq x) k"])
  } with subseq_diagseq show ?thesis ..
qed

(* Generalisation of convergent_bounded_family_nat to a countable index set A.
   - Nonempty A: the columns are enumerated with from_nat_into A, and to_nat_on A maps
     each x in A back to an index of that enumeration.
   - Empty A: the identity sequence already works. *)
lemma convergent_bounded_family:
  fixes f :: "nat  'a  real"
  assumes bounded: "x. x  A  bounded (range (λn. f n x))"
  and A: "countable A"
  obtains k where "strict_mono k" "x. x  A  convergent (λn. f (k n) x)"
proof(cases "A = {}")
  case False
  (* Re-index the columns by natural numbers. *)
  define f' where "f' n x = f n (from_nat_into A x)" for n x
  have "bounded (range (λn. f' n x))" for x
    unfolding f'_def using from_nat_into[OF False] by(rule bounded)
  then obtain k where k: "strict_mono k" 
    and conv: "convergent (λn. f' (k n) x)" for x
    by(rule convergent_bounded_family_nat) iprover
  have "convergent (λn. f (k n) x)" if "x  A" for x
    using conv[of "to_nat_on A x"] A that by(simp add: f'_def)
  with k show thesis ..
next
  case True
  with strict_mono_id show thesis by(blast intro!: that)
qed

(* zero_on f X is f with its values on X overridden by 0 (via override_on), and unchanged
   outside X. *)
abbreviation zero_on :: "('a  'b :: zero)  'a set  'a  'b"
where "zero_on f  override_on f (λ_. 0)"

(* Pointwise comparison of zero_on f X with f. At points of X the comparison reduces to
   0 <= f x; outside X both sides agree. Declared as a simp rule. *)
lemma zero_on_le [simp]: fixes f :: "'a  'b :: {preorder, zero}" shows
  "zero_on f X x  f x  (x  X  0  f x)"
by(auto simp add: override_on_def)

(* Nonnegativity of zero_on f X at x. It holds automatically on X, where the value is 0,
   so only points outside X require 0 <= f x. *)
lemma zero_on_nonneg: fixes f :: "'a  'b :: {preorder, zero}" shows
  "0  zero_on f X x  (x  X  0  f x)"
by(auto simp add: override_on_def)

(* Zeroing out finitely many terms of a convergent series: if f sums to s and X is
   finite, then zero_on f X sums to s - sum f X. The proof subtracts the finitely
   supported series that agrees with f on X. *)
lemma sums_zero_on:
  fixes f :: "nat  'a::real_normed_vector"
  assumes f: "f sums s"
    and X: "finite X"
  shows "zero_on f X sums (s - sum f X)"
proof -
  have "(λx. if x  X then f x else 0) sums sum (λx. f x) X" using X by(rule sums_If_finite_set)
  from sums_diff[OF f this] show ?thesis
    by(simp add: sum_negf override_on_def if_distrib cong del: if_weak_cong)
qed

(* Consequences of sums_zero_on for a summable f and a finite X:
   - zero_on f X is summable (declared as a simp rule);
   - its sum is suminf f - sum f X. *)
lemma 
  fixes f :: "nat  'a::real_normed_vector"
  assumes f: "summable f"
  and X: "finite X"
  shows summable_zero_on [simp]: "summable (zero_on f X)" (is ?thesis1)
  and suminf_zero_on: "suminf (zero_on f X) = suminf f - sum f X" (is ?thesis2)
proof -
  from f obtain s where "f sums s" unfolding summable_def ..
  with sums_zero_on[OF this X] show ?thesis1 ?thesis2 
    by(auto simp add: summable_def sums_unique[symmetric])
qed

(* For a nonnegative summable f, zero_on f X is summable for an arbitrary (possibly
   infinite) set X. Its terms are nonnegative and its partial sums are bounded by
   suminf f. *)
lemma summable_zero_on_nonneg:
  fixes f :: "nat  'a :: {ordered_comm_monoid_add,linorder_topology,conditionally_complete_linorder}"
  assumes f: "summable f"
  and nonneg: "x. 0  f x"
  shows "summable (zero_on f X)"
proof(rule summableI_nonneg_bounded)
  fix n
  have "sum (zero_on f X) {..<n}  sum f {..<n}" by(rule sum_mono)(simp add: nonneg)
  also have "  suminf f" using f by(rule sum_le_suminf)(auto simp add: nonneg)
  finally show "sum (zero_on f X) {..<n}  suminf f" .
qed(simp add: zero_on_nonneg nonneg)

(* zero_on commutes with the pointwise coercion to ennreal. Declared as a simp rule. *)
lemma zero_on_ennreal [simp]: "zero_on (λx. ennreal (f x)) A = (λx. ennreal (zero_on f A x))"
by(simp add: override_on_def fun_eq_iff)
 
(* Rewrites a sum over {..<n} as the sum over {..n} minus its last term f n. The codomain
   is an additive group so that the subtraction is available. *)
lemma sum_lessThan_conv_atMost_nat:
  fixes f :: "nat  'b :: ab_group_add"
  shows "sum f {..<n} = sum f {..n} - f n"
by (metis Groups.add_ac(2) add_diff_cancel_left' lessThan_Suc_atMost sum.lessThan_Suc)

(* The set of points satisfying P misses the ray {x..} exactly when no y >= x satisfies
   P. *)
lemma Collect_disjoint_atLeast:
  "Collect P  {x..} = {}  (yx. ¬ P y)"
by(auto simp add: atLeast_def)

lemma bounded_matrix_for_marginals_nat:
  fixes f g :: "nat  real"
    and R :: "(nat × nat) set"
    and s :: real
  assumes sum_f: "f sums s" and sum_g: "g sums s"
    and f_nonneg: "x. 0  f x" and g_nonneg: "y. 0  g y"
    and f_le_g: "X. suminf (zero_on f (- X))  suminf (zero_on g (- R `` X))"
  obtains h :: "nat  nat  real"
    where "x y. 0  h x y"
    and "x y. 0 < h x y  (x, y)  R"
    and "x. h x sums f x"
    and "y. (λx. h x y) sums g y"
proof -
  have summ_f: "summable f" and summ_g: "summable g" and sum_fg: "suminf f = suminf g"
    using sum_f sum_g by(auto simp add: summable_def sums_unique[symmetric])

  have summ_zf: "summable (zero_on f A)" for A
    using summ_f f_nonneg by(rule summable_zero_on_nonneg)
  have summ_zg: "summable (zero_on g A)" for A
    using summ_g g_nonneg by(rule summable_zero_on_nonneg)

  define f' :: "nat  nat  real"
    where "f' n x = (if x  n then f x else if x = Suc n then  k. f (k + (n + 1)) else 0)" for n x
  define g' :: "nat  nat  real"
    where "g' n y = (if y  n then g y else if y = Suc n then  k. g (k + (n + 1)) else 0)" for n y
  define R' :: "nat  (nat × nat) set"
    where "R' n = 
      R  {..n} × {..n}  
      {(n + 1, y) | y. x'>n. (x', y)  R  y  n} 
      {(x, n + 1) | x. y'>n. (x, y')  R  x  n} 
      (if x>n. y>n. (x, y)  R then {(n + 1, n + 1)} else {})"
    for n
  have R'_simps [simplified, simp]:
    " x  n; y  n   (x, y)  R' n  (x, y)  R"
    "y  n  (n + 1, y)  R' n  (x'>n. (x', y)  R)"
    "x  n  (x, n + 1)  R' n  (y'>n. (x, y')  R)"
    "(n + 1, n + 1)  R' n  (x'>n. y'>n. (x', y')  R)"
    "x > n + 1  y > n + 1  (x, y)  R' n"
    for x y n by(auto simp add: R'_def)

  have R'_cases: thesis if "(x, y)  R' n"
    and " x  n; y  n; (x, y)  R   thesis"
    and "x'.  x = n + 1; y  n; n < x'; (x', y)  R   thesis"
    and "y'.  x  n; y = n + 1; n < y'; (x, y')  R   thesis"
    and "x' y'.  x = n + 1; y = n + 1; n < x'; n < y'; (x', y')  R   thesis"
    for thesis x y n using that by(auto simp add: R'_def split: if_split_asm)

  have R'_intros:
    " (x, y)  R; x  n; y  n   (x, y)  R' n"
    " (x', y)  R; n < x'; y  n   (n + 1, y)  R' n"
    " (x, y')  R; x  n; n < y'   (x, n + 1)  R' n"
    " (x', y')  R; n < x'; n < y'   (n + 1, n + 1)  R' n"
    for n x y x' y' by(auto)

  have Image_R':
    "R' n `` X = (R `` (X  {..n}))  {..n}  
    (if n + 1  X then (R `` {n+1..})  {..n} else {}) 
    (if (R `` (X  {..n}))  {n+1..} = {} then {} else {n + 1}) 
    (if n + 1  X  (R `` {n+1..})  {n+1..}  {} then {n + 1} else {})" for n X
    apply(simp add: Image_def)
    apply(safe elim!: R'_cases; auto simp add: Collect_disjoint_atLeast intro: R'_intros simp add: Suc_le_eq dest: Suc_le_lessD)
    apply(metis R'_simps(4) R'_intros(3) Suc_eq_plus1)+
    done

  { fix n
    have "sum (f' n) {..n + 1} = sum (g' n) {..n + 1}" using sum_fg
      unfolding f'_def g'_def suminf_minus_initial_segment[OF summ_f] suminf_minus_initial_segment[OF summ_g]
      by(simp)(metis (no_types, opaque_lifting) add.commute add.left_inverse atLeast0AtMost atLeast0LessThan atLeastLessThanSuc_atLeastAtMost minus_add_distrib sum.lessThan_Suc uminus_add_conv_diff)
    moreover have "sum (f' n) X  sum (g' n) (R' n `` X)" if "X  {..n + 1}" for X
    proof -
      from that have [simp]: "finite X" by(rule finite_subset) simp
      define X' where "X'  if n + 1  X then X  {n+1..} else X"

      define Y' where "Y'  if R `` X'  {n+1..} = {} then R `` X' else R `` X'  {n+1..}"
        
      have "sum (f' n) X = sum (f' n) (X - {n + 1}) + (if n + 1  X then f' n (n + 1) else 0)"
        by(simp add: sum.remove)
      also have "sum (f' n) (X - {n + 1}) = sum f (X - {n + 1})" using that
        by(intro sum.cong)(auto simp add: f'_def)
      also have " + (if n + 1  X then f' n (n + 1) else 0) = suminf (zero_on f (- X'))"
      proof(cases "n + 1  X")
        case True
        with sum_f that show ?thesis
          apply(simp add: summable_def X'_def f'_def suminf_zero_on[OF sums_summable] del: One_nat_def)
          apply(subst suminf_minus_initial_segment[OF summ_f])
          apply(simp add: algebra_simps)
          apply(subst sum.union_disjoint[symmetric])
          apply(auto simp add: sum_lessThan_conv_atMost_nat intro!: sum.cong)
          done
      next
        case False
        with sum_f show ?thesis
          by(simp add: X'_def suminf_finite[where N=X])
      qed
      also have "  suminf (zero_on g (- R `` X'))" by(rule f_le_g)
      also have "  suminf (zero_on g (- Y'))"
        by(rule suminf_le[OF _ summ_zg summ_zg])(clarsimp simp add: override_on_def g_nonneg Y'_def summ_zg)
      also have " = suminf (λk. zero_on g (- Y') (k + (n + 1))) + sum (zero_on g (- Y')) {..n}"
        by(subst suminf_split_initial_segment[OF summ_zg, of _ "n + 1"])(simp add: sum_lessThan_conv_atMost_nat)
      also have "sum (zero_on g (- Y')) {..n} = sum g (Y'  {..n})"
        by(rule sum.mono_neutral_cong_right)(auto simp add: override_on_def)
      also have " = sum (g' n) (Y'  {..n})"
        by(rule sum.cong)(auto simp add: g'_def)
      also have "suminf (λk. zero_on g (- Y') (k + (n + 1)))  (if R `` X'  {n+1..} = {} then 0 else g' n (n + 1))"
        apply(clarsimp simp add: Y'_def g'_def simp del: One_nat_def)
        apply(subst suminf_eq_zero_iff[THEN iffD2])
        apply(auto simp del: One_nat_def simp add: summable_iff_shift summ_zg zero_on_nonneg g_nonneg)
        apply(auto simp add: override_on_def)
        done
      also have " + sum (g' n) (Y'  {..n}) = sum (g' n) (R' n `` X)"
        using that by(fastforce simp add: Image_R' Y'_def X'_def atMost_Suc intro: sum.cong[OF _ refl])
      finally show ?thesis by simp
    qed
    moreover have "x. 0  f' n x" "y. 0  g' n y"
      by(auto simp add: f'_def g'_def f_nonneg g_nonneg summable_iff_shift summ_f summ_g intro!: suminf_nonneg simp del: One_nat_def)
    moreover have "R' n  {..n+1} × {..n+1}"
      by(auto elim!: R'_cases)
    ultimately obtain h 
      where "x y.  x  n + 1; y  n + 1  0  h x y"
      and "x y.  0 < h x y; x  n + 1; y  n + 1  (x, y)  R' n"
      and "x. x  n + 1  f' n x = sum (h x) {..n + 1}"
      and "y. y  n + 1  g' n y = sum (λx. h x y) {..n + 1}"
      by(rule bounded_matrix_for_marginals_finite) blast+
    hence "h. (x y. x  n + 1  y  n + 1  0  h x y) 
       (x y. 0 < h x y  x  n + 1  y  n + 1  (x, y)  R' n) 
       (x. x  n + 1  f' n x = sum (h x) {..n + 1}) 
       (y. y  n + 1  g' n y = sum (λx. h x y) {..n + 1})" by blast }
  hence "h. n. (x y. x  n + 1  y  n + 1  0  h n x y) 
       (x y. 0 < h n x y  x  n + 1  y  n + 1  (x, y)  R' n) 
       (x. x  n + 1  f' n x = sum (h n x) {..n + 1}) 
       (y. y  n + 1  g' n y = sum (λx. h n x y) {..n + 1})"
    by(subst choice_iff[symmetric]) blast
  then obtain h where h_nonneg: "n x y.  x  n + 1; y  n + 1  0  h n x y"
    and h_R: "n x y.  0 < h n x y; x  n + 1; y  n + 1  (x, y)  R' n"
    and h_f: "n x. x  n + 1  f' n x = sum (h n x) {..n + 1}"
    and h_g: "n y. y  n + 1  g' n y = sum (λx. h n x y) {..n + 1}"
    apply(rule exE)
    subgoal for h by(erule meta_allE[of _ h]) blast
    done
  
  define h' :: "nat  nat × nat  real"
    where "h' n = (λ(x, y). if x  n  y  n then h n x y else 0)" for n
  have h'_nonneg: "h' n xy  0" for n xy by(simp add: h'_def h_nonneg split: prod.split)
  
  have "h' n xy  s" for n xy
  proof(cases xy)
    case [simp]: (Pair x y)
    consider (le) "x  n" "y  n" | (beyond) "x > n  y > n" by fastforce
    then show ?thesis
    proof cases
      case le
      have "h' n xy = h n x y" by(simp add: h'_def le)
      also have "  h n x y + sum (h n x) {..<y} + sum (h n x) {y<..n + 1}"
        using h_nonneg le by(auto intro!: sum_nonneg add_nonneg_nonneg)
      also have " = sum (h n x) {..y} +  sum (h n x) {y<..n + 1}"
        by(simp add: sum_lessThan_conv_atMost_nat)
      also have " = sum (h n x) {..n+1}" using le
        by(subst sum.union_disjoint[symmetric])(auto simp del: One_nat_def intro!: sum.cong)
      also have " = f' n x" using le by(simp add: h_f)
      also have " = f x" using le by(simp add: f'_def)
      also have " = suminf (zero_on f (- {x}))"
        by(subst suminf_finite[where N="{x}"]) simp_all
      also have "  suminf f"
        by(rule suminf_le)(auto simp add: f_nonneg summ_zf summ_f)
      also have " = s" using sum_f by(simp add: sums_unique[symmetric])
      finally show ?thesis .
    next
      case beyond
      then have "h' n xy = 0" by(auto simp add: h'_def)
      also have "0  s" using summ_f by(simp add: sums_unique[OF sum_f] suminf_nonneg f_nonneg)
      finally show ?thesis .
    qed
  qed
  then have "bounded (range (λn. h' n x))" for x unfolding bounded_def
    by(intro exI[of _ 0] exI[of _ s]; simp add: h'_nonneg)
  from convergent_bounded_family[OF this, of UNIV "%x. x"] obtain k 
    where k: "strict_mono k" and conv: "xy. convergent (λn. h' (k n) xy)" by auto

  define H :: "nat  nat  real"
    where "H x y = lim (λn. h' (k n) (x, y))" for x y

  have H: "(λn. h' (k n) (x, y))  H x y" for x y
    unfolding H_def using conv[of "(x, y)"] by(simp add: convergent_LIMSEQ_iff)

  show thesis
  proof(rule that)
    show H_nonneg: "0  H x y" for x y using H[of x y] by(rule LIMSEQ_le_const)(simp add: h'_nonneg)
    show "(x, y)  R" if "0 < H x y" for x y 
    proof(rule ccontr)
      assume "(x, y)  R"
      hence "h' n (x, y) = 0" for n using h_nonneg[of x n y] h_R[of n x y]
        by(fastforce simp add: h'_def)
      hence "H x y = 0" using H[of x y] by(simp add: LIMSEQ_const_iff)
      with that show False by simp
   qed
   show "H x sums f x" for x unfolding sums_iff
   proof
     have sum_H: "sum (H x) {..<m}  f x" for m
     proof -
       have "sum (λy. h' (k n) (x, y)) {..<m}  f x" for n
       proof(cases "x  k n")
         case True
         from k have "n  k n" by(rule seq_suble)
         have "sum (λy. h' (k n) (x, y)) {..<m} = sum (λy. h' (k n) (x, y)) {..<min m (k n + 1)}"
           by(rule sum.mono_neutral_right)(auto simp add: h'_def min_def)
         also have "  sum (λy. h (k n) x y) {..k n + 1}" using True
           by(intro sum_le_included[where i="id"])(auto simp add: h'_def h_nonneg)
         also have " = f' (k n) x" using h_f True by simp
         also have " = f x" using True by(simp add: f'_def)
         finally show ?thesis .
       qed(simp add: f_nonneg h'_def)
       then show ?thesis by -((rule LIMSEQ_le_const2 tendsto_sum H)+, simp)
     qed
     with H_nonneg show summ_H: "summable (H x)" by(rule summableI_nonneg_bounded)
     hence "suminf (H x)  f x" using sum_H by(rule suminf_le_const)
     moreover
     have "(λm. sum (H x) {..<m} + suminf (λn. g (n + m)))  suminf (H x) + 0"
       by(rule tendsto_intros summable_LIMSEQ summ_H suminf_exist_split2 summ_g)+
     hence "f x  suminf (H x) + 0"
     proof(rule LIMSEQ_le_const)
       have "f x  sum (H x) {..<m} + suminf (λn. g (n + m))" for m
       proof -
         have "(λn. sum (λy. h' (k n) (x, y)) {..<m} + suminf (λi. g (i + m)))  sum (H x) {..<m} + suminf (λi. g (i + m))"
           by(rule tendsto_intros H)+
         moreover have "N. nN. f x  sum (λy. h' (k n) (x, y)) {..<m} + suminf (λi. g (i + m))"
         proof(intro exI strip)
           fix n
           assume "max x m  n"
           with seq_suble[OF k, of n] have x: "x  k n" and m: "m  k n" by auto
           have "f x = f' (k n) x" using x by(simp add: f'_def)
           also have " = sum (h (k n) x) {..k n + 1}" using x by(simp add: h_f)
           also have " = sum (h (k n) x) {..<m} + sum (h (k n) x) {m..k n + 1}"
             using x m by(subst sum.union_disjoint[symmetric])(auto intro!: sum.cong simp del: One_nat_def)
           also have "sum (h (k n) x) {..<m} = sum (λy. h' (k n) (x, y)) {..<m}"
             using x m by(auto simp add: h'_def)
           also have "sum (h (k n) x) {m..k n + 1} = sum (λy. sum (λx. h (k n) x y) {x}) {m..k n + 1}" by simp
           also have "  sum (λy. sum (λx. h (k n) x y) {..k n + 1}) {m..k n + 1}" using x
             by(intro sum_mono sum_mono2)(auto simp add: h_nonneg)
           also have " = sum (g' (k n)) {m..k n + 1}" by(simp add: h_g del: One_nat_def)
           also have " = sum g {m..k n} + suminf (λi. g (i + (k n + 1)))" using m by(simp add: g'_def)
           also have " = suminf (λi. g (i + m))" using m
             apply(subst (2) suminf_split_initial_segment[where k="k n + 1 - m"])
             apply(simp_all add: summable_iff_shift summ_g)
             apply(rule sum.reindex_cong[OF _ _ refl])
             apply(simp_all add: Suc_diff_le lessThan_Suc_atMost)
             apply(safe; clarsimp)
             subgoal for x by(rule image_eqI[where x="x - m"]) auto
             subgoal by auto 
             done
           finally show "f x  sum (λy. h' (k n) (x, y)) {..<m} + " by simp
         qed
         ultimately show ?thesis by(rule LIMSEQ_le_const)
       qed
       thus "N. nN. f x  sum (H x) {..<n} + (i. g (i + n))" by auto
     qed
     ultimately show "suminf (H x) = f x" by simp
   qed
   show "(λx. H x y) sums g y" for y unfolding sums_iff
   proof
     have sum_H: "sum (λx. H x y) {..<m}  g y" for m
     proof -
       have "sum (λx. h' (k n) (x, y)) {..<m}  g y" for n
       proof(cases "y  k n")
         case True
         from k have "n  k n" by(rule seq_suble)
         have "sum (λx. h' (k n) (x, y)) {..<m} = sum (λx. h' (k n) (x, y)) {..<min m (k n + 1)}"
           by(rule sum.mono_neutral_right)(auto simp add: h'_def min_def)
         also have "  sum (λx. h (k n) x y) {..k n + 1}" using True
           by(intro sum_le_included[where i="id"])(auto simp add: h'_def h_nonneg)
         also have " = g' (k n) y" using h_g True by simp
         also have " = g y" using True by(simp add: g'_def)
         finally show ?thesis .
       qed(simp add: g_nonneg h'_def)
       then show ?thesis by -((rule LIMSEQ_le_const2 tendsto_sum H)+, simp)
     qed
     with H_nonneg show summ_H: "summable (λx. H x  y)" by(rule summableI_nonneg_bounded)
     hence "suminf (λx. H x y)  g y" using sum_H by(rule suminf_le_const)
     moreover
     have "(λm. sum (λx. H x y) {..<m} + suminf (λn. f (n + m)))  suminf (λx. H x y) + 0"
       by(rule tendsto_intros summable_LIMSEQ summ_H suminf_exist_split2 summ_f)+
     hence "g y  suminf (λx. H x y) + 0"
     proof(rule LIMSEQ_le_const)
       have "g y  sum (λx. H x y) {..<m} + suminf (λn. f (n + m))" for m
       proof -
         have "(λn. sum (λx. h' (k n) (x, y)) {..<m} + suminf (λi. f (i + m)))  sum (λx. H x y) {..<m} + suminf (λi. f (i + m))"
           by(rule tendsto_intros H)+
         moreover have "N. nN. g y  sum (λx. h' (k n) (x, y)) {..<m} + suminf (λi. f (i + m))"
         proof(intro exI strip)
           fix n
           assume "max y m  n"
           with seq_suble[OF k, of n] have y: "y  k n" and m: "m  k n" by auto
           have "g y = g' (k n) y" using y by(simp add: g'_def)
           also have " = sum (λx. h (k n) x y) {..k n + 1}" using y by(simp add: h_g)
           also have " = sum (λx. h (k n) x y) {..<m} + sum (λx. h (k n) x y) {m..k n + 1}"
             using y m by(subst sum.union_disjoint[symmetric])(auto intro!: sum.cong simp del: One_nat_def)
           also have "sum (λx. h (k n) x y) {..<m} = sum (λx. h' (k n) (x, y)) {..<m}"
             using y m by(auto simp add: h'_def)
           also have "sum (λx. h (k n) x y) {m..k n + 1} = sum (λx. sum (λy. h (k n) x y) {y}) {m..k n + 1}" by simp
           also have "  sum (λx. sum (λy. h (k n) x y) {..k n + 1}) {m..k n + 1}" using y
             by(intro sum_mono sum_mono2)(auto simp add: h_nonneg)
           also have " = sum (f' (k n)) {m..k n + 1}" by(simp add: h_f del: One_nat_def)
           also have " = sum f {m..k n} + suminf (λi. f (i + (k n + 1)))" using m by(simp add: f'_def)
           also have " = suminf (λi. f (i + m))" using m
             apply(subst (2) suminf_split_initial_segment[where k="k n + 1 - m"])
             apply(simp_all add: summable_iff_shift summ_f)
             apply(rule sum.reindex_cong[OF _ _ refl])
             apply(simp_all add: Suc_diff_le lessThan_Suc_atMost)
             apply(safe; clarsimp)
             subgoal for x by(rule image_eqI[where x="x - m"]) auto
             subgoal by auto 
             done
           finally show "g y  sum (λx. h' (k n) (x, y)) {..<m} + " by simp
         qed
         ultimately show ?thesis by(rule LIMSEQ_le_const)
       qed
       thus "N. nN. g y  sum (λx. H x y) {..<n} + (i. f (i + n))" by auto
     qed
     ultimately show "suminf (λx. H x y) = g y" by simp
   qed
  qed
qed

(* Version of the matrix-for-marginals result for countable index sets and
   ennreal-valued marginals.

   Assumptions:
   - the total mass of f over A and of g over B agree (sum_eq) and are finite (finite);
   - for every X within A, the mass of f on X is bounded by the mass of g on
     the R-image R `` X (le);
   - A and B are countable;
   - R relates elements of A to elements of B (R).

   Conclusion: a matrix h with finite entries that is positive only on R,
   whose row sums over B give f and whose column sums over A give g.

   Proof strategy: number the elements of A and B with to_nat_on, and move
   f, g and R to the natural numbers (f', g', R').  Then apply
   bounded_matrix_for_marginals_nat there, and pull the resulting real
   matrix h' back to A and B. *)
lemma bounded_matrix_for_marginals_ennreal:
  assumes sum_eq: "(+ xA. f x) = (+ yB. g y)"
    and finite: "(+ xB. g x)  "
    and le: "X. X  A  (+ xX. f x)  (+ yR `` X. g y)"
    and countable [simp]: "countable A" "countable B"
    and R: "R  A × B"
  obtains h where "x y. 0 < h x y  (x, y)  R"
    and "x y. h x y  "
    and "x. x  A  (+ yB. h x y) = f x"
    and "y. y  B  (+ xA. h x y) = g y"
proof -
  (* Every single value g y (y in B) and f x (x in A) is bounded by the
     finite total mass.  For f, the total is rewritten to that of g via sum_eq. *)
  have fin_g [simp]: "g y  " if "y  B" for y using finite
    by(rule neq_top_trans)(rule nn_integral_ge_point[OF that])
  have fin_f [simp]: "f x  " if "x  A" for x using finite unfolding sum_eq[symmetric]
    by(rule neq_top_trans)(rule nn_integral_ge_point[OF that])

  (* Transport to nat.
     - f' n is the real value of f at from_nat_into A n when n is the number
       of an element of A, and 0 otherwise; g' is defined likewise for g and B.
     - s is the total mass of g as a real number.
     - R' is R with both components renumbered by to_nat_on. *)
  define f' where "f' x = (if x  to_nat_on A ` A then enn2real (f (from_nat_into A x)) else 0)" for x
  define g' where "g' y = (if y  to_nat_on B ` B then enn2real (g (from_nat_into B y)) else 0)" for y
  define s where "s = enn2real (+ xB. g x)"
  define R' where "R' = map_prod (to_nat_on A) (to_nat_on B) ` R" 

  (* f' and g' are nonnegative, being built from enn2real and 0. *)
  have f'_nonneg: "f' x  0" for x by(simp add: f'_def)
  have g'_nonneg: "g' y  0" for y by(simp add: g'_def)

  (* The transport preserves the total mass of f.  First restrict to the
     numbers of elements of A (f' is 0 elsewhere).  Then reindex along the
     injective numbering to_nat_on A. *)
  have "(+ x. f' x) = (+ xto_nat_on A ` A. f' x)"
    by(auto simp add: nn_integral_count_space_indicator f'_def intro!: nn_integral_cong)
  also have " = (+ xA. f x)"
    by(subst nn_integral_count_space_reindex)(auto simp add: inj_on_to_nat_on f'_def ennreal_enn2real_if intro!: nn_integral_cong)
  finally have sum_f': "(+ x. f' x) = (+ xA. f x)" .

  (* The same mass-preservation argument for g' and B. *)
  have "(+ y. g' y) = (+ yto_nat_on B ` B. g' y)"
    by(auto simp add: nn_integral_count_space_indicator g'_def intro!: nn_integral_cong)
  also have " = (+ yB. g y)"
    by(subst nn_integral_count_space_reindex)(auto simp add: inj_on_to_nat_on g'_def ennreal_enn2real_if intro!: nn_integral_cong)
  finally have sum_g': "(+ y. g' y) = (+ yB. g y)" .

  (* f' is summable: its partial sums are bounded by its total mass, which
     is finite by the assumptions finite and sum_eq. *)
  have summ_f': "summable f'"
  proof(rule summableI_nonneg_bounded)
    show "sum f' {..<n}  enn2real (+ x. f' x)" for n
    proof -
      have "sum f' {..<n} = enn2real (+ x{..<n}. f' x)" 
        by(simp add: nn_integral_count_space_finite f'_nonneg sum_nonneg)
      also have "enn2real (+ x{..<n}. f' x)  enn2real (+ x. f' x)" using finite sum_eq[symmetric]
        by(auto simp add: nn_integral_count_space_indicator sum_f'[symmetric] less_top intro!: nn_integral_mono enn2real_mono split: split_indicator)
      finally show ?thesis .
    qed
  qed(rule f'_nonneg)
  (* The series of f' therefore sums to s.  sum_eq is needed because s is
     defined from the total mass of g. *)
  have suminf_f': "suminf f' = enn2real (+ y. f' y)"
    by(simp add: nn_integral_count_space_nat suminf_ennreal2[OF f'_nonneg summ_f'] suminf_nonneg[OF summ_f'] f'_nonneg)
  with summ_f' sum_f' sum_eq have sums_f: "f' sums s" by(simp add: s_def sums_iff)
  moreover
  (* Likewise, g' is summable and its series sums to s. *)
  have summ_g': "summable g'"
  proof(rule summableI_nonneg_bounded)
    show "sum g' {..<n}  enn2real (+ y. g' y)" for n
    proof -
      have "sum g' {..<n} = enn2real (+ y{..<n}. g' y)" 
        by(simp add: nn_integral_count_space_finite g'_nonneg sum_nonneg)
      also have "enn2real (+ y{..<n}. g' y)  enn2real (+ y. g' y)" using finite
        by(auto simp add: nn_integral_count_space_indicator sum_g'[symmetric] less_top intro!: nn_integral_mono enn2real_mono split: split_indicator)
      finally show ?thesis .
    qed
  qed(rule g'_nonneg)
  have suminf_g': "suminf g' = enn2real (+ y. g' y)"
    by(simp add: nn_integral_count_space_nat suminf_ennreal2[OF g'_nonneg summ_g'] suminf_nonneg[OF summ_g'] g'_nonneg)
  with summ_g' sum_g' have sums_g: "g' sums s" by(simp add: s_def sums_iff)
  moreover note f'_nonneg g'_nonneg
  (* Hall-type condition on nat.  For every set X of numbers, the series of
     f' with the terms outside X zeroed out (zero_on f' (- X)) is at most the
     series of g' with the terms outside R' `` X zeroed out.  The steps below
     identify these series with integrals over X and R' `` X. *)
  moreover have "suminf (zero_on f' (- X))  suminf (zero_on g' (- R' `` X))" for X
  proof -
    (* X' is the set of elements of A whose numbers lie in X. *)
    define X' where "X' = from_nat_into A ` (X  to_nat_on A ` A)"
    have X': "to_nat_on A ` X' = X  (to_nat_on A ` A)"
      by(auto 4 3 simp add: X'_def intro: rev_image_eqI)

    (* The calculation below goes through these steps:
       1. pass from the real series to an ennreal integral of f' over X;
       2. move to the subset X' of A and apply le there;
       3. renumber the image R `` X' into R' `` X;
       4. pass back to the real series of g'. *)
    have "ennreal (suminf (zero_on f' (- X))) = suminf (zero_on (λx. ennreal (f' x)) (- X))"
      by(simp add: suminf_ennreal2 zero_on_nonneg f'_nonneg summable_zero_on_nonneg summ_f')
    also have " = (+ xX. f' x)"
      by(auto simp add: nn_integral_count_space_nat[symmetric] nn_integral_count_space_indicator intro!: nn_integral_cong split: split_indicator)
    also have " = (+ xto_nat_on A ` X'. f' x)" using X'
      by(auto simp add: nn_integral_count_space_indicator f'_def intro!: nn_integral_cong split: split_indicator)
    also have " = (+ x  X'. f x)"
      by(subst nn_integral_count_space_reindex)(auto simp add: X'_def inj_on_def f'_def ennreal_enn2real_if intro!: nn_integral_cong)
    also have "  (+ y  R `` X'. g y)" by(rule le)(auto simp add: X'_def)
    also have " = (+ y  to_nat_on B ` (R `` X'). g' y)" using R fin_g
      by(subst nn_integral_count_space_reindex)(auto 4 3 simp add: X'_def inj_on_def g'_def ennreal_enn2real_if simp del: fin_g intro!: nn_integral_cong from_nat_into dest: to_nat_on_inj[THEN iffD1, rotated -1])
    also have "to_nat_on B ` (R `` X') = R' `` X" using R 
      by(auto 4 4 simp add: X'_def R'_def Image_iff intro: rev_image_eqI rev_bexI intro!: imageI)
    also have "(+ y. g' y) = suminf (zero_on (λy. ennreal (g' y)) (- ))"
      by(auto simp add: nn_integral_count_space_nat[symmetric] nn_integral_count_space_indicator intro!: nn_integral_cong split: split_indicator)
    also have " = ennreal (suminf (zero_on g' (- R' `` X)))"
      by(simp add: suminf_ennreal2 zero_on_nonneg g'_nonneg summable_zero_on_nonneg summ_g')
    finally show ?thesis 
      by(simp add: suminf_nonneg summable_zero_on_nonneg[OF summ_g' g'_nonneg] zero_on_nonneg g'_nonneg)
  qed
  (* The facts collected with moreover (sums_f, sums_g, nonnegativity of f'
     and g', and the condition just shown) are the premises of
     bounded_matrix_for_marginals_nat.  It yields a nonnegative real matrix h'
     on nat with these properties:
     - h' is positive only on R';
     - each row x sums to f' x;
     - each column y sums to g' y. *)
  ultimately obtain h' where h'_nonneg: "x y. 0  h' x y"
    and dom_h': "x y. 0 < h' x y  (x, y)  R'"
    and h'_f: "x. h' x sums f' x"
    and h'_g: "y. (λx. h' x y) sums g' y"
    by(rule bounded_matrix_for_marginals_nat) blast

  (* Pull h' back along the numberings.  Entries outside A x B are 0. *)
  define h where "h x y = ennreal (if x  A  y  B then h' (to_nat_on A x) (to_nat_on B y) else 0)" for x y
  show ?thesis
  proof
    (* Support: a positive entry of h comes from a positive entry of h',
       which lies in R'.  Injectivity of to_nat_on recovers a pair of R. *)
    show "(x, y)  R" if "0 < h x y" for x y
      using that dom_h'[of "to_nat_on A x" "to_nat_on B y"] R 
      by(auto simp add: h_def R'_def dest: to_nat_on_inj[THEN iffD1, rotated -1] split: if_split_asm)
    (* Finite entries: every entry of h is ennreal of a real number. *)
    show "h x y  " for x y by(simp add: h_def)

    (* Row sums, in three steps:
       1. reindex along to_nat_on B;
       2. extend to all of nat, since by dom_h' and R'_def the row of h' is 0
          outside the numbers of B;
       3. use that this row of h' sums to f'. *)
    fix x
    assume x: "x  A"
    have "(+ yB. h x y) = (+ yto_nat_on B ` B. h' (to_nat_on A x) y)"
      by(subst nn_integral_count_space_reindex)(auto simp add: inj_on_to_nat_on h_def x intro!: nn_integral_cong)
    also have " = (+ y. h' (to_nat_on A x) y)" using dom_h'[of "to_nat_on A x"] h'_nonneg R
      by(fastforce intro!: nn_integral_cong intro: rev_image_eqI simp add: nn_integral_count_space_indicator R'_def less_le split: split_indicator)
    also have " = ennreal (suminf (h' (to_nat_on A x)))" 
      by(simp add: nn_integral_count_space_nat suminf_ennreal_eq[OF _ h'_f] h'_nonneg) 
    also have " = ennreal (f' (to_nat_on A x))" using h'_f[of "to_nat_on A x"] by(simp add: sums_iff)
    also have " = f x" using x by(simp add: f'_def ennreal_enn2real_if)
    finally show "(+ yB. h x y) = f x" .
  next
    (* Column sums: symmetric to the row sums, reindexing along to_nat_on A
       and using h'_g. *)
    fix y
    assume y: "y  B"
    have "(+ xA. h x y) = (+ xto_nat_on A ` A. h' x (to_nat_on B y))"
      by(subst nn_integral_count_space_reindex)(auto simp add: inj_on_to_nat_on h_def y intro!: nn_integral_cong)
    also have " = (+ x. h' x (to_nat_on B y))" using dom_h'[of _ "to_nat_on B y"] h'_nonneg R
      by(fastforce intro!: nn_integral_cong intro: rev_image_eqI simp add: nn_integral_count_space_indicator R'_def less_le split: split_indicator)
    also have " = ennreal (suminf (λx. h' x (to_nat_on B y)))" 
      by(simp add: nn_integral_count_space_nat suminf_ennreal_eq[OF _ h'_g] h'_nonneg) 
    also have " = ennreal (g' (to_nat_on B y))" using h'_g[of "to_nat_on B y"] by(simp add: sums_iff)
    also have " = g y" using y by(simp add: g'_def ennreal_enn2real_if)
    finally show "(+ xA. h x y) = g y" .
  qed
qed

end