Theory Goodstein_Sequence

theory Goodstein_Sequence
imports Syntactic_Ordinal
(*  Title:       Termination of the Goodstein Sequence
    Author:      Jasmin Blanchette <jasmin.blanchette at inria.fr>, 2017
    Maintainer:  Jasmin Blanchette <jasmin.blanchette at inria.fr>
*)

section ‹Termination of the Goodstein Sequence›

theory Goodstein_Sequence
imports Multiset_More Syntactic_Ordinal
begin

text ‹
The ‹goodstein› function returns the successive values of the Goodstein
sequence. It is defined in terms of the ‹encode› and ‹decode› functions,
which convert between natural numbers and ordinals (represented as hereditary
multisets) using hereditary base-‹n› notation. The development culminates
with a proof of Goodstein's theorem.
›


subsection ‹Lemmas about Division›

(* Dividing a natural number and multiplying back by the divisor never exceeds
   the original number. This is the library fact div_times_less_eq_dividend
   under a shorter name. *)
lemma div_mult_le: "m div n * n ≤ m" for m n :: nat
  by (fact div_times_less_eq_dividend)

(* Quotient of two powers of the same base: b ^ x div b ^ y = b ^ (x - y),
   provided b ^ y is nonzero and x ≥ y. *)
lemma power_div_same_base:
  "b ^ y ≠ 0 ⟹ x ≥ y ⟹ b ^ x div b ^ y = b ^ (x - y)" for b :: "'a::semidom_divide"
  by (metis add_diff_inverse leD nonzero_mult_div_cancel_left power_add)


subsection ‹Hereditary and Nonhereditary Base-‹n› Systems›

context
  fixes base :: nat
  assumes base_ge_2: "base ≥ 2"
begin

(* Nonhereditary well-formedness: every element of M occurs fewer than base
   times, i.e. every "digit" (multiplicity) lies in 0 .. base - 1. *)
inductive well_base :: "'a multiset ⇒ bool" where
  "(∀n. count M n < base) ⟹ well_base M"

(* Filtering a multiset cannot raise any multiplicity, so well_base is kept. *)
lemma well_base_filter: "well_base M ⟹ well_base {#m ∈# M. p m#}"
  by (auto simp: well_base.simps)

(* Mapping a function that is injective on the support of M over M does not
   raise any multiplicity, so well_base is preserved. *)
lemma well_base_image_inj: "well_base M ⟹ inj_on f (set_mset M) ⟹ well_base (image_mset f M)"
  unfolding well_base.simps by (metis count_image_mset_le_count_inj_on le_less_trans)

(* Numeral bound: if every exponent in M is below n and every exponent occurs
   fewer than base times, then (∑m ∈# M. base ^ m) < base ^ n. This is the
   base-base analogue of "a number with n digits is below 10 ^ n". *)
lemma well_base_bound:
  assumes
    "well_base M" and
    "∀m ∈# M. m < n"
  shows "(∑m ∈# M. base ^ m) < base ^ n"
  using assms
proof (induct n arbitrary: M)
  case (Suc n)
  note ih = this(1) and well_M = this(2) and in_M_lt_Sn = this(3)

  (* Split M into the occurrences of the top exponent n and all the others. *)
  let ?Meq = "{#m ∈# M. m = n#}"
  let ?Mne = "{#m ∈# M. m ≠ n#}"
  (* NOTE(review): ?K is not referenced anywhere in this proof. *)
  let ?K = "{#base ^ m. m ∈# M#}"

  have M: "M = ?Meq + ?Mne"
    by (simp)

  (* The non-top part satisfies the hypotheses of the IH for n. *)
  have well_Mne: "well_base ?Mne"
    by (rule well_base_filter[OF well_M])

  have in_Mne_lt_n: "∀m ∈# ?Mne. m < n"
    using in_M_lt_Sn by auto

  (* n occurs at most base - 1 times in M, so the top part contributes at most
     (base - 1) * base ^ n. *)
  have "sum_mset (image_mset ((^) base) ?Meq) ≤ (base - 1) * base ^ n"
    unfolding filter_eq_replicate_mset using base_ge_2
    by simp (metis Suc_pred diff_self_eq_0 le_SucE less_imp_le less_le_trans less_numeral_extra(3)
      pos2 well_M well_base.cases zero_less_diff)
  (* base ^ Suc n = base ^ n + (base - 1) * base ^ n. Together with the IH bound
     for ?Mne this keeps the total sum below base ^ Suc n. *)
  moreover have "base * base ^ n = base ^ n + (base - Suc 0) * base ^ n"
    using base_ge_2 mult_eq_if by auto
  ultimately show ?case
    using ih[OF well_Mne in_Mne_lt_n] by (subst M) (simp del: union_filter_mset_complement)
  (* Base case n = 0: no exponent is below 0, so M is empty and 0 < 1. *)
qed simp

(* Hereditary well-formedness of an ordinal M: its own multiset of exponents
   is well_base, and every exponent is itself (recursively) well_baseh. *)
inductive well_baseh :: "hmultiset ⇒ bool" where
  "(∀N ∈# hmsetmset M. well_baseh N) ⟹ well_base (hmsetmset M) ⟹ well_baseh M"

(* An ordinal whose exponent multiset is a sub-multiset of that of a well_baseh
   ordinal is itself well_baseh: its exponents and multiplicities can only shrink. *)
lemma well_baseh_mono_hmset: "well_baseh M ⟹ hmsetmset N ⊆# hmsetmset M ⟹ well_baseh N"
  by (induct rule: well_baseh.induct, rule well_baseh.intros, blast)
    (meson leD leI order_trans subseteq_mset_def well_base.simps)

(* The top level of a hereditarily well-formed ordinal is well_base. *)
lemma well_baseh_imp_well_base: "well_baseh M ⟹ well_base (hmsetmset M)"
  by (erule well_baseh.cases) simp


subsection ‹Encoding of Natural Numbers into Ordinals›

(* encode e n: the ordinal written in hereditary base-base notation for the
   number n placed at exponent e (encode_exp_0 below shows that this equals
   encode 0 (base ^ e * n)).
   - The lowest digit n mod base becomes the coefficient of ω^(encode 0 e);
     the exponent e is itself encoded hereditarily.
   - The remaining digits n div base are encoded starting at exponent e + 1. *)
function encode :: "nat ⇒ nat ⇒ hmultiset" where
  "encode e n =
   (if n = 0 then 0 else of_nat (n mod base) * ω^(encode 0 e) + encode (e + 1) (n div base))"
  by pat_completeness auto
(* Both recursive calls decrease the measure n * (base ^ e + 1). *)
termination
  using base_ge_2
proof (relation "measure (λ(e, n). n * (base ^ e + 1))"; simp)
  fix e n :: nat
  assume n_ge_0: "n > 0"

  (* Call encode 0 e: its measure is e * (1 + 1) = e + e. *)
  have "e + e ≤ 2 ^ e"
    by (induct e; simp) (metis add_diff_cancel_left' add_leD1 diff_is_0_eq' double_not_eq_Suc_double
      le_antisym mult_2 not_less_eq_eq power_eq_0_iff zero_neq_numeral)
  also have "… ≤ base ^ e"
    using base_ge_2 by (simp add: power_mono)
  also have "… ≤ n * base ^ e"
    using n_ge_0 by (simp add: Suc_leI)
  also have "… < n + n * base ^ e"
    using n_ge_0 by simp
  finally show "e + e < n + n * base ^ e"
    by assumption

  (* Call encode (e + 1) (n div base): the higher base power is paid for by
     dividing n by base, and n div base < n. *)
  have "n div base * (base * base ^ e) ≤ n * base ^ e"
    using base_ge_2 by (auto intro: div_mult_le)
  moreover have "n div base < n"
    using n_ge_0 base_ge_2 by simp
  ultimately show "n div base + n div base * (base * base ^ e) < n + n * base ^ e"
    by linarith
qed

(* The defining equation of encode unfolds unconditionally: its right-hand
   side mentions encode again whatever n is, so as a simp rule it could loop.
   It is removed from the simp set. The guarded rules encode_0 (a simp rule)
   and encode_Suc are used instead. *)
declare encode.simps[simp del]

(* Zero encodes to the zero ordinal at every exponent. *)
lemma encode_0[simp]: "encode e 0 = 0"
  by (subst encode.simps) simp

(* One unfolding step of encode for a positive argument. *)
lemma encode_Suc:
  "encode e (Suc n) = of_nat (Suc n mod base) * ω^(encode 0 e) + encode (e + 1) (Suc n div base)"
  by (subst encode.simps) simp

(* Only 0 encodes to the zero ordinal. The proof is by strong induction on n. *)
lemma encode_0_iff: "encode e n = 0 ⟷ n = 0"
proof (induct n arbitrary: e rule: less_induct)
  case (less n)
  note ih = this

  show ?case
  proof (cases n)
    case 0
    thus ?thesis
      by simp
  next
    case n: (Suc m)
    show ?thesis
    proof (cases "n mod base = 0")
      case True
      (* Lowest digit is 0: the higher digits n div base are nonzero and
         smaller than n, so the IH applies to them. *)
      hence "n div base ≠ 0"
        using div_eq_0_iff n by fastforce
      thus ?thesis
        using ih[of "Suc m div base"] n
        by (simp add: encode_Suc) (metis One_nat_def base_ge_2 div_eq_dividend_iff div_le_dividend
          leD lessI nat_neq_iff numeral_2_eq_2)
    next
      case False
      (* Lowest digit is nonzero: its term already makes the encoding nonzero. *)
      thus ?thesis
        using n plus_hmultiset_def by (simp add: encode_Suc[unfolded of_nat_times_ω_exp])
    qed
  qed
qed

(* Raising the exponent by one is the same as multiplying the number by base. *)
lemma encode_Suc_exp: "encode (Suc e) n = encode e (base * n)"
  using base_ge_2
  by (subst (1 2) encode.simps, subst (4) encode.simps, simp add: zero_hmultiset_def[symmetric])

(* Normal form of encode: exponent e amounts to the factor base ^ e. *)
lemma encode_exp_0: "encode e n = encode 0 (base ^ e * n)"
  by (induct e arbitrary: n) (simp_all add: encode_Suc_exp mult.assoc mult.commute)

(* Every exponent occurring in encode e n is the encoding encode 0 e' of some
   e' ≥ e. *)
lemma mem_hmsetmset_encodeD: "M ∈# hmsetmset (encode e n) ⟹ ∃e' ≥ e. M = encode 0 e'"
proof (induct e n rule: encode.induct)
  case (1 e n)
  note ih = this(1-2) and M_in = this(3)

  show ?case
  proof (cases n)
    case 0
    thus ?thesis
      using M_in by simp
  next
    case n: (Suc m)

    (* Exponents from the lowest-digit term are encode 0 e itself (e' = e). *)
    {
      assume "M ∈# replicate_mset (n mod base) (encode 0 e)"
      hence ?thesis
        by (meson in_replicate_mset order_refl)
    }
    moreover
    (* Exponents from the higher digits come from the IH at exponent e + 1. *)
    {
      assume "M ∈# hmsetmset (encode (e + 1) (n div base))"
      hence ?thesis
        using ih(2) le_add1 n order_trans by blast
    }
    ultimately show ?thesis
      using M_in[unfolded n encode_Suc[unfolded of_nat_times_ω_exp], folded n]
      unfolding hmsetmset_plus by auto
  qed
qed

(* encode e is strictly monotone in the number. The proof compares the lowest
   digits first. When the lowest digit of n is not smaller than that of p, it
   compares the higher digits instead. *)
lemma less_imp_encode_less: "n < p ⟹ encode e n < encode e p"
proof (induct e n arbitrary: p rule: encode.induct)
  case (1 e n)
  note ih = this(1-2) and n_lt_p = this(3)

  show ?case
  proof (cases "n = 0")
    case True
    thus ?thesis
      using n_lt_p base_ge_2 encode_0_iff[of e p] le_less by fastforce
  next
    case n_nz: False

    (* Copies of the exponent encode 0 e: those in n, those in p, and the
       surplus of n over p. *)
    let ?Ma = "replicate_mset (n mod base) (encode 0 e)"
    let ?Na = "replicate_mset (p mod base) (encode 0 e)"
    let ?Pa = "replicate_mset (n mod base - p mod base) (encode 0 e)"

    have "HMSet ?Ma + encode (Suc e) (n div base) < HMSet ?Na + encode (Suc e) (p div base)"
    proof (cases "n mod base < p mod base")
      case mod_lt: True
      (* Lowest-digit term strictly smaller; higher digits are no larger (IH). *)
      show ?thesis
        by (rule add_less_le_mono, simp add: mod_lt,
          metis ih(2)[of "p div base", OF n_nz] Suc_eq_plus1 div_le_mono le_less n_lt_p)
    next
      case mod_ge: False
      (* The lowest digit of n is not smaller, so its higher digits must be
         strictly smaller. *)
      hence div_lt: "n div base < p div base"
        by (metis add_le_cancel_left div_le_mono div_mult_mod_eq le_neq_implies_less less_imp_le
          n_lt_p nat_neq_iff)

      let ?M = "hmsetmset (encode (Suc e) (n div base))"
      let ?N = "hmsetmset (encode (Suc e) (p div base))"

      have "?M < ?N"
        by (auto intro!: ih(2)[folded Suc_eq_plus1] n_nz div_lt)
      (* Witnesses X, Y for ?M < ?N in the multiset ordering (less_multisetDM). *)
      then obtain X Y where
        X_nemp: "X ≠ {#}" and
        X_sub: "X ⊆# ?N" and
        M: "?M = ?N - X + Y" and
        ex_gt: "∀y. y ∈# Y ⟶ (∃x. x ∈# X ∧ x > y)"
        using less_multisetDM by metis

      (* Every element of X is an exponent of the higher-digit part, so it is
         encode 0 e' with e' > e and hence exceeds encode 0 e. *)
      {
        fix x
        assume x_in_X: "x ∈# X"
        hence x_in_N: "x ∈# ?N"
          using X_sub by blast
        then obtain e' where
          e'_gt: "e' > e" and
          x: "x = encode 0 e'"
          by (auto simp: Suc_le_eq dest: mem_hmsetmset_encodeD)

        have "x > encode 0 e"
          unfolding x using ih(1)[OF n_nz] e'_gt by (blast dest: Suc_lessD)
      }
      hence ex_gt_e: "∃x ∈# X. x > encode 0 e"
        using X_nemp by auto

      (* Reuse X as the witness for the whole comparison and add the surplus
         copies ?Pa to Y. X dominates them by ex_gt_e. *)
      have X_sub': "X ⊆# ?Na + ?N"
        using X_sub by (simp add: subset_mset.add_increasing)
      have mam_eq: "?Ma + ?M = ?Na + ?N - X + (Y + ?Pa)"
      proof -
        from mod_ge have "?Ma = ?Na + ?Pa"
          by (simp add: replicate_mset_plus [symmetric])
        moreover have "?Na + ?N - X = ?Na + (?N - X)"
          by (meson X_sub multiset_diff_union_assoc)
        ultimately show ?thesis
          by (simp add: M)
      qed
      have max_X: "⋀k. k ∈# Y + ?Pa ⟹ ∃a. a ∈# X ∧ k < a"
        using ex_gt mod_ge ex_gt_e by (metis in_replicate_mset union_iff)

      show ?thesis
        by (subst (4 8) hmultiset.collapse[symmetric],
          unfold HMSet_plus[symmetric] HMSet_less less_multisetDM,
          rule exI[of _ X], rule exI[of _ "Y + ?Pa"],
          intro conjI impI allI X_nemp X_sub' mam_eq, elim max_X)
    qed
    thus ?thesis
      using n_nz n_lt_p by (subst (1 2) encode.simps[unfolded of_nat_times_ω_exp]) auto
  qed
qed

(* alignede e M: every exponent of M is at least encode 0 e (ordinal-side
   alignment to exponent e). *)
inductive alignede :: "nat ⇒ hmultiset ⇒ bool" where
  "(∀m ∈# hmsetmset M. m ≥ encode 0 e) ⟹ alignede e M"

(* An encoding at exponent e only uses exponents ≥ encode 0 e. *)
lemma alignede_encode: "alignede e (encode e M)"
  by (subst encode_exp_0, rule alignede.intros,
    metis encode_exp_0 leD leI lessI less_imp_encode_less lift_Suc_mono_less_iff
      mem_hmsetmset_encodeD)

(* Every encoding is hereditarily well-formed for the base. *)
lemma well_baseh_encode: "well_baseh (encode e n)"
proof (induct e n rule: encode.induct)
  case (1 e n)
  note ih = this

  (* The exponents of the higher-digit part are well-formed (IH). *)
  have well2: "∀M ∈# hmsetmset (encode (Suc e) (n div base)). well_baseh M"
    using ih(2) well_baseh.cases by (metis Suc_eq_plus1 Zero_not_Suc count_empty div_0
      encode_0_iff hmsetmset_empty_iff in_countE)

  (* encode 0 e does not occur in the higher-digit part: all exponents there are
     ≥ encode 0 (Suc e) > encode 0 e. So its multiplicity in encode e n is
     exactly n mod base < base. *)
  have cnt1: "count (hmsetmset (encode (Suc e) (n div base))) (encode 0 e) = 0"
    using alignede_encode[unfolded alignede.simps]
      less_imp_encode_less[of n "Suc n" for n, simplified]
    by (meson count_inI leD)

  show ?case
  proof (rule well_baseh.intros)
    show "∀M ∈# hmsetmset (encode e n). well_baseh M"
      by (subst encode.simps[unfolded of_nat_times_ω_exp],
        simp add: zero_hmultiset_def hmsetmset_plus, use ih(1) well2 in blast)
  next
    show "well_base (hmsetmset (encode e n))"
      using cnt1 base_ge_2
      by (subst encode.simps[unfolded of_nat_times_ω_exp],
        simp add: well_base.simps zero_hmultiset_def hmsetmset_plus,
        metis ih(2) well_baseh.simps Suc_eq_plus1 less_numeral_extra(3) well_base.simps)
  qed
qed


subsection ‹Decoding of Natural Numbers from Ordinals›

(* decode e M reads the ordinal M as a number in the fixed base,
   hereditarily: each exponent m contributes base ^ decode 0 m. The sum is
   divided by base ^ e, which undoes the exponent shift of encode e (see
   decode_encode). *)
primrec decode :: "nat ⇒ hmultiset ⇒ nat" where
  "decode e (HMSet M) = (∑m ∈# M. base ^ decode 0 m) div base ^ e"

(* The defining equation of decode, stated for an arbitrary hmultiset via
   hmsetmset rather than the HMSet constructor. *)
lemma decode_unfold: "decode e M = (∑m ∈# hmsetmset M. base ^ decode 0 m) div base ^ e"
  by (cases M) simp

(* The zero ordinal decodes to 0 at every exponent. *)
lemma decode_0[simp]: "decode e 0 = 0"
  unfolding zero_hmultiset_def by simp

(* alignedd e M: every exponent of M decodes to at least e, so each summand
   base ^ decode 0 m in decode is a multiple of base ^ e. *)
inductive alignedd :: "nat ⇒ hmultiset ⇒ bool" where
  "(∀m ∈# hmsetmset M. decode 0 m ≥ e) ⟹ alignedd e M"

(* Every ordinal is aligned to exponent 0. *)
lemma alignedd_0[simp]: "alignedd 0 M"
  by (rule alignedd.intros) simp

(* Alignment to Suc e implies alignment to e. *)
lemma alignedd_mono_exp_Suc: "alignedd (Suc e) M ⟹ alignedd e M"
  by (auto simp: alignedd.simps)

(* Alignment is inherited by ordinals whose exponents form a sub-multiset. *)
lemma alignedd_mono_hmset:
  assumes "alignedd e M" and "hmsetmset M' ⊆# hmsetmset M"
  shows "alignedd e M'"
  using assms by (auto simp: alignedd.simps)

(* For an ordinal aligned to Suc e, decoding at exponent e gives base times the
   decoding at exponent Suc e. *)
lemma decode_exp_shift_Suc:
  assumes alignd: "alignedd (Suc e) M"
  shows "decode e M = base * decode (Suc e) M"
proof (subst (1 2) decode_unfold, subst (1 2) sum_mset_distrib_div_if_dvd)
  note align' = alignd[unfolded alignedd.simps, simplified, unfolded Suc_le_eq]

  (* Side conditions for pushing div into the sums: every summand is divisible
     by base ^ Suc e and by base ^ e. *)
  show "∀m ∈# hmsetmset M. base ^ Suc e dvd base ^ decode 0 m"
    using align' Suc_leI le_imp_power_dvd by blast

  show "∀m ∈# hmsetmset M. base ^ e dvd base ^ decode 0 m"
    using align' by (simp add: le_imp_power_dvd le_less)

  have base_e_nz: "base ^ e ≠ 0"
    using base_ge_2 by simp

  (* Term by term: dividing by base ^ e is base times dividing by base ^ Suc e. *)
  have mult_base:
    "base ^ decode 0 m div base ^ e = base * (base ^ decode 0 m div (base * base ^ e))"
    if m_in: "m ∈# hmsetmset M" for m
    using m_in align'
    by (subst power_div_same_base[OF base_e_nz], force,
      metis Suc_diff_Suc Suc_leI mult_is_0 power_Suc power_div_same_base power_not_zero)

  show "(∑m∈#hmsetmset M. base ^ decode 0 m div base ^ e) =
    base * (∑m∈#hmsetmset M. base ^ decode 0 m div base ^ Suc e)"
    by (auto simp: sum_mset_distrib_left intro!: arg_cong[of _ _ sum_mset] image_mset_cong
      elim!: mult_base)
qed

(* Iterated form of decode_exp_shift_Suc: for an ordinal aligned to e,
   decode 0 M = base ^ e * decode e M. *)
lemma decode_exp_shift:
  assumes "alignedd e M"
  shows "decode 0 M = base ^ e * decode e M"
  using assms by (induct e) (auto simp: decode_exp_shift_Suc dest: alignedd_mono_exp_Suc)

(* decode e is additive when the left summand is aligned to e. Its sum of
   summands is then divisible by base ^ e, so the division distributes over +. *)
lemma decode_plus:
  assumes alignd_M: "alignedd e M"
  shows "decode e (M + N) = decode e M + decode e N"
  using alignd_M[unfolded alignedd.simps, simplified]
  by (subst (1 2 3) decode_unfold) (auto simp: hmsetmset_plus
    intro!: le_imp_power_dvd div_plus_div_distrib_dvd_left[OF sum_mset_dvd])

(* decode e is strictly monotone on hereditarily well-formed ordinals aligned
   to e. The proof is by well-founded induction on M:
   - M and N share a common part K and differ in Ma and Na, with
     head_ω Ma < head_ω Na.
   - The leading term ω^H of Na bounds every exponent of Ma.
   - well_base_bound then gives decode 0 Ma < base ^ decode 0 H ≤ decode 0 Na. *)
lemma less_imp_decode_less:
  assumes
    "well_baseh M" and
    "alignedd e M" and
    "alignedd e N" and
    "M < N"
  shows "decode e M < decode e N"
  using assms
proof (induct M arbitrary: N e rule: less_induct)
  case (less M)
  note ih = this(1) and wellh_M = this(2) and alignd_M = this(3) and alignd_N = this(4) and
    M_lt_N = this(5)

  (* Strip the common part K. The remainders compare by their leading terms. *)
  obtain K Ma Na where
    M: "M = K + Ma" and
    N: "N = K + Na" and
    hds: "head_ω Ma < head_ω Na"
    using hmset_pair_decompose_less[OF M_lt_N] by blast

  (* ω^H is the leading term of Na, and H is one of Na's exponents. *)
  obtain H where
    H: "head_ω Na = ω^H"
    using hds head_ω_def by fastforce
  have H_in: "H ∈# hmsetmset Na"
    by (metis (no_types) H Max_in add_mset_eq_single add_mset_not_empty finite_set_mset head_ω_def
      hmsetmset_empty_iff hmultiset.simps(1) set_mset_eq_empty_iff zero_hmultiset_def)

  (* Well-formedness and alignment carry over to the parts. *)
  have wellh_Ma: "well_baseh Ma"
    by (rule well_baseh_mono_hmset[OF wellh_M]) (simp add: M hmsetmset_plus)
  have alignd_K: "alignedd e K"
    using M alignd_M alignedd_mono_hmset hmsetmset_plus by auto
  have alignd_Ma: "alignedd e Ma"
    using M alignd_M alignedd_mono_hmset hmsetmset_plus by auto
  have alignd_Na: "alignedd e Na"
    using N alignd_N alignedd_mono_hmset hmsetmset_plus by auto

  (* decode 0 is injective on the exponents of Ma. They are all below M, so the
     IH makes decode 0 strictly monotone on them. *)
  have "inj_on (decode 0) (set_mset (hmsetmset Ma))"
    unfolding inj_on_def
  proof clarify
    fix x y
    assume
      x_in: "x ∈# hmsetmset Ma" and
      y_in: "y ∈# hmsetmset Ma" and
      dec_eq: "decode 0 x = decode 0 y"

    {
      fix x y
      assume
        x_in: "x ∈# hmsetmset Ma" and
        y_in: "y ∈# hmsetmset Ma" and
        x_lt_y: "x < y"

      have x_lt_M: "x < M"
        unfolding M using mem_hmsetmset_imp_less[OF x_in] by (simp add: trans_less_add2_hmset)
      have wellh_x: "well_baseh x"
        using wellh_Ma well_baseh.simps x_in by blast

      have "decode 0 x < decode 0 y"
        by (rule ih[OF x_lt_M wellh_x alignedd_0 alignedd_0 x_lt_y])
    }
    thus "x = y"
      using x_in y_in dec_eq by (metis leI less_irrefl_nat order.not_eq_order_implies_strict)
  qed
  (* Hence the decoded exponents of Ma still form a well_base multiset. *)
  hence well_dec_Ma: "well_base (image_mset (decode 0) (hmsetmset Ma))"
    by (rule well_base_image_inj[OF well_baseh_imp_well_base[OF wellh_Ma]])

  (* Every exponent of Ma is below H, so by the IH it decodes below decode 0 H. *)
  have H_bound: "∀m ∈# hmsetmset Ma. decode 0 m < decode 0 H"
  proof
    fix m
    assume m_in: "m ∈# hmsetmset Ma"

    have "∀m ∈# hmsetmset (head_ω Ma). m < H"
      using hds[unfolded H] using head_ω_def by auto
    hence m_lt_H: "m < H"
      using m_in
      by (metis Max_less_iff empty_iff finite_set_mset head_ω_def hmultiset.sel insert_iff
        set_mset_add_mset_insert)

    have m_lt_M: "m < M"
      using mem_hmsetmset_imp_less[OF m_in] by (simp add: M trans_less_add2_hmset)

    have wellh_m: "well_baseh m"
      using m_in wellh_Ma well_baseh.cases by blast

    show "decode 0 m < decode 0 H"
      by (rule ih[OF m_lt_M wellh_m alignedd_0 alignedd_0 m_lt_H])
  qed

  (* Compare at exponent 0 first, then rescale to exponent e and add back K. *)
  have "decode 0 Ma < base ^ decode 0 H"
    using well_base_bound[OF well_dec_Ma, simplified, OF H_bound] by (subst decode_unfold) simp
  also have "… ≤ decode 0 Na"
    by (subst (2) decode_unfold, simp, rule sum_image_mset_mono_mem[OF H_in])
  finally have "decode e Ma < decode e Na"
    using decode_exp_shift[OF alignd_Ma] decode_exp_shift[OF alignd_Na] by simp
  thus "decode e M < decode e N"
    unfolding M N by (simp add: decode_plus[OF alignd_K])
qed

(* decode e is injective on hereditarily well-formed ordinals aligned to e.
   This follows from strict monotonicity (less_imp_decode_less). *)
lemma inj_decode: "inj_on (decode e) {M. well_baseh M ∧ alignedd e M}"
  unfolding inj_on_def Ball_def mem_Collect_eq
  by (metis less_imp_decode_less less_irrefl_nat neqE)

(* On well-formed, aligned ordinals, decoding yields 0 exactly for the zero
   ordinal. *)
lemma decode_0_iff: "well_baseh M ⟹ alignedd e M ⟹ decode e M = 0 ⟷ M = 0"
  by (metis alignedd_0 decode_0 decode_exp_shift encode_0 less_imp_decode_less mult_0_right neqE
    not_less_zero well_baseh_encode)

(* decode e is a left inverse of encode e. *)
lemma decode_encode: "decode e (encode e n) = n"
proof (induct e n rule: encode.induct)
  case (1 e n)
  note ih = this

  show ?case
  proof (cases "n = 0")
    case n_nz: False

    (* The lowest-digit term uses exponent encode 0 e, which decodes to e (IH). *)
    have alignd1: "alignedd e (of_nat (n mod base) * ω^(encode 0 e))"
      unfolding of_nat_times_ω_exp using n_nz by (auto simp: ih(1) alignedd.simps)
    (* The higher-digit part only uses exponents encode 0 e' with e' ≥ Suc e.
       These decode to values > e, by monotonicity of decode and encode. *)
    have alignd2: "alignedd (Suc e) (encode (Suc e) (n div base))"
      by (safe intro!: alignedd.intros, subst ih(1)[OF n_nz, symmetric],
        auto dest: mem_hmsetmset_encodeD intro!: Suc_le_eq[THEN iffD2]
          less_imp_decode_less[OF well_baseh_encode alignedd_0 alignedd_0] less_imp_encode_less)

    (* Unfold encode once, decode each part, and use n = n mod base + base * (n div base). *)
    show ?thesis
      using ih base_ge_2
      by (subst encode.simps[unfolded of_nat_times_ω_exp])
        (simp add: decode_plus[OF alignd1[unfolded of_nat_times_ω_exp]]
           decode_exp_shift_Suc[OF alignd2])
  qed simp
qed

(* Conversely, encode 0 inverts decode 0 on hereditarily well-formed ordinals.
   Both sides are well-formed and decode to the same number, and decode 0 is
   injective on such ordinals. *)
lemma encode_decode_exp_0: "well_baseh M ⟹ encode 0 (decode 0 M) = M"
  by (auto intro: inj_onD[OF inj_decode] decode_encode well_baseh_encode)

end

(* Hereditary well-formedness is preserved when the base grows: multiplicities
   below base are also below any base' ≥ base. Outside the context, base is an
   explicit argument of well_baseh. *)
lemma well_baseh_mono_base:
  assumes
    wellh: "well_baseh base M" and
    two: "2 ≤ base" and
    bases: "base ≤ base'"
  shows "well_baseh base' M"
  using two wellh
  by (induct rule: well_baseh.induct)
    (meson two bases less_le_trans order_trans well_baseh.intros well_base.simps)


subsection ‹The Goodstein Sequence and Goodstein's Theorem›

context
  fixes start :: nat
begin

(* goodstein i: the i-th value of the Goodstein sequence that starts at start.
   Each step does three things:
   - writes the current value in hereditary base i + 2 (encode),
   - reads the resulting ordinal back in base i + 3 (decode),
   - subtracts one. Subtraction on nat is truncated, so 0 stays 0. *)
primrec goodstein :: "nat ⇒ nat" where
  "goodstein 0 = start"
| "goodstein (Suc i) = decode (i + 3) 0 (encode (i + 2) 0 (goodstein i)) - 1"

(* While the sequence is nonzero, each step strictly decreases its ordinal
   encoding. The base-(i + 3) encoding of the next value is below the
   base-(i + 2) encoding of the current one. *)
lemma goodstein_step:
  assumes gi_gt_0: "goodstein i > 0"
  shows "encode (i + 2) 0 (goodstein i) > encode (i + 3) 0 (goodstein (i + 1))"
proof -
  let ?Ei = "encode (i + 2) 0 (goodstein i)"
  let ?reencode = "encode (i + 3) 0"
  let ?decoded_Ei = "decode (i + 3) 0 ?Ei"

  have two_le: "2 ≤ i + 3"
    by simp

  (* ?Ei is well-formed in base i + 2, hence also in the larger base i + 3. *)
  have "well_baseh (i + 2) ?Ei"
    by (rule well_baseh_encode) simp
  hence wellh: "well_baseh (i + 3) ?Ei"
    by (rule well_baseh_mono_base) simp_all

  (* ?Ei is nonzero because goodstein i is, so it decodes to a nonzero number. *)
  have decoded_Ei_gt_0: "?decoded_Ei > 0"
    by (metis gi_gt_0 gr0I encode_0_iff le_add2 decode_0_iff[OF _ wellh alignedd_0] two_le)

  (* Subtracting one lowers the encoding (monotonicity). Re-encoding the decoded
     value gives back ?Ei (encode/decode are inverse on well-formed ordinals). *)
  have "?reencode (?decoded_Ei - 1) < ?reencode ?decoded_Ei"
    by (rule less_imp_encode_less[OF two_le]) (use decoded_Ei_gt_0 in linarith)
  also have "… = ?Ei"
    by (simp only: encode_decode_exp_0[OF two_le wellh])
  finally show ?thesis
    by simp
qed

(* Goodstein's theorem: every Goodstein sequence eventually reaches 0.
   - The encodings ?G i cannot decrease forever, because < on hmultiset is
     well-founded (wf together with wf_iff_no_infinite_down_chain).
   - So there is some i with ¬ ?G i > ?G (i + 1).
   - goodstein_step then forces goodstein i = 0. *)
theorem goodsteins_theorem: "∃i. goodstein i = 0"
proof -
  let ?G = "λi. encode (i + 2) 0 (goodstein i)"

  obtain i where
    "¬ ?G i > ?G (i + 1)"
    using wf_iff_no_infinite_down_chain[THEN iffD1, OF wf,
      unfolded not_ex not_all mem_Collect_eq prod.case, rule_format, of ?G]
    by auto
  (* By contraposition of goodstein_step (after normalizing i + 1 + 2 = i + 3). *)
  hence "goodstein i = 0"
    using goodstein_step by (metis add.assoc gr0I one_plus_numeral semiring_norm(3))
  thus ?thesis
    by blast
qed

end

end