(*  Title:       Syntactic Ordinals in Cantor Normal Form
    Author:      Jasmin Blanchette <jasmin.blanchette at inria.fr>, 2016
    Author:      Mathias Fleury <mfleury at mpi-inf.mpg.de>, 2016
    Author:      Dmitriy Traytel <traytel at inf.ethz.ch>, 2016
    Maintainer:  Jasmin Blanchette <jasmin.blanchette at inria.fr>
*)

section ‹Syntactic Ordinals in Cantor Normal Form›

theory Syntactic_Ordinal
imports Hereditary_Multiset "HOL-Library.Product_Order" "HOL-Library.Extended_Nat"
begin


subsection ‹Natural (Hessenberg) Product›

(* Natural (Hessenberg) product on hereditary multisets, making hmultiset a
   commutative semiring with unit.  The notations ω^ and ω introduced here are used
   throughout the rest of the theory. *)
instantiation hmultiset :: comm_semiring_1
begin

(* ω^m is the hereditary multiset whose underlying multiset is the singleton {#m#}. *)
abbreviation ω_exp :: "hmultiset ⇒ hmultiset" ("ω^") where
  "ω^ ≡ λm. HMSet {#m#}"

(* The multiplicative unit is ω^0. *)
definition one_hmultiset :: hmultiset where
  "1 = ω^0"

(* ω itself is ω^1. *)
abbreviation ω :: hmultiset where
  "ω ≡ ω^1"

(* Product: take the multiset product (×#) of the two underlying multisets and replace
   every pair (x, y) in it by x + y. *)
definition times_hmultiset :: "hmultiset ⇒ hmultiset ⇒ hmultiset"  where
  "A * B = HMSet (image_mset (case_prod (+)) (hmsetmset A ×# hmsetmset B))"

(* The underlying multiset of a product, read directly off the definition. *)
lemma hmsetmset_times:
  "hmsetmset (m * n) = image_mset (case_prod (+)) (hmsetmset m ×# hmsetmset n)"
  unfolding times_hmultiset_def by simp

(* The class obligations; goal_cases gives each one a name so it is proved separately. *)
instance
proof (intro_classes, goal_cases assoc comm one distrib_plus zeroL zeroR zero_one)
  (* associativity: reassociate ×# and push image_mset through it *)
  case (assoc a b c)
  thus ?case
    by (auto simp: times_hmultiset_def Times_mset_image_mset1 Times_mset_image_mset2
      Times_mset_assoc ac_simps intro: multiset.map_cong)
next
  (* commutativity: swap the two factors of ×# *)
  case (comm a b)
  thus ?case
    unfolding times_hmultiset_def
    by (subst product_swap_mset[symmetric]) (auto simp: ac_simps intro: multiset.map_cong)
next
  (* unit law: 1 is a singleton, so ×# with it on the left is an image *)
  case (one a)
  thus ?case
    by (auto simp: one_hmultiset_def times_hmultiset_def Times_mset_single_left)
next
  (* distributivity over + *)
  case (distrib_plus a b c)
  thus ?case
    by (auto simp: plus_hmultiset_def times_hmultiset_def)
next
  (* 0 annihilates from the left *)
  case (zeroL a)
  thus ?case
    by (auto simp: times_hmultiset_def)
next
  (* 0 annihilates from the right *)
  case (zeroR a)
  thus ?case
    by (auto simp: times_hmultiset_def)
next
  (* the 0 / 1 distinctness obligation, from the definition of 1 *)
  case zero_one
  thus ?case
    by (auto simp: one_hmultiset_def)
qed

end

(* Simp rules for products with the empty hereditary multiset (which equals 0). *)
lemma empty_times_left_hmset[simp]: "HMSet {#} * M = 0"
  by (simp add: times_hmultiset_def)

lemma empty_times_right_hmset[simp]: "M * HMSet {#} = 0"
  by (metis mult_zero_right zero_hmultiset_def)

(* Multiplying by a single power ω^M, on either side, adds M to every element of the
   other factor's underlying multiset. *)
lemma singleton_times_left_hmset[simp]: "ω^M * N = HMSet (image_mset ((+) M) (hmsetmset N))"
  by (simp add: times_hmultiset_def Times_mset_single_left)

lemma singleton_times_right_hmset[simp]: "N * ω^M = HMSet (image_mset ((+) M) (hmsetmset N))"
  by (metis mult.commute singleton_times_left_hmset)


subsection ‹Inequalities›

(* Addition on the representation type unit nmultiset, obtained by mapping into
   hmultiset with Abs_hmultiset, adding there, and mapping back with Rep_hmultiset. *)
definition plus_nmultiset :: "unit nmultiset ⇒ unit nmultiset ⇒ unit nmultiset"  where
  "plus_nmultiset X Y = Rep_hmultiset (Abs_hmultiset X + Abs_hmultiset Y)"

(* Strict monotonicity of plus_nmultiset for the order on pairs, provided all four
   arguments satisfy no_elem.  Used as an elimination rule in times_hmultiset_monoL. *)
lemma plus_nmultiset_mono:
  assumes less: "(X, Y) < (X', Y')" and no_elem: "no_elem X" "no_elem Y" "no_elem X'" "no_elem Y'"
  shows "plus_nmultiset X Y < plus_nmultiset X' Y'"
  using less[unfolded less_le_not_le] no_elem
  by (auto simp: plus_nmultiset_def plus_hmultiset_def less_multiset_extDM_less less_eq_nmultiset_def
          union_less_mono type_definition.Abs_inverse[OF type_definition_hmultiset, simplified]
        elim!: no_elem.cases)

(* Transfer rule relating plus_nmultiset on the representation to (+) on hmultiset.
   NOTE(review): presumably needed by the transfer step in times_hmultiset_monoL. *)
lemma plus_hmultiset_transfer[transfer_rule]:
  "(rel_fun pcr_hmultiset (rel_fun pcr_hmultiset pcr_hmultiset)) plus_nmultiset (+)"
  unfolding rel_fun_def plus_nmultiset_def pcr_hmultiset_def nmultiset.rel_eq eq_OO cr_hmultiset_def
  by (auto simp: type_definition.Rep_inverse[OF type_definition_hmultiset])

(* With a fixed nonempty right factor Z, the multiset product ×# is strictly monotone in
   its left factor.  The proof uses the Dershowitz-Manna form of the multiset order
   (less_multisetDM).  From a witness (X, Y) for M < N it builds the witness
   (X ×# Z, Y ×# Z) for M ×# Z < N ×# Z. *)
lemma Times_mset_monoL:
  assumes less: "M < N" and Z_nemp: "Z ≠ {#}"
  shows "M ×# Z < N ×# Z"
proof -
  obtain Y X where
    Y_nemp: "Y ≠ {#}" and Y_sub_N: "Y ⊆# N" and M_eq: "M = N - Y + X" and
    ex_Y: "∀x. x ∈# X ⟶ (∃y. y ∈# Y ∧ x < y)"
    using less[unfolded less_multisetDM] by blast

  let ?X = "X ×# Z"
  let ?Y = "Y ×# Z"

  show ?thesis
    unfolding less_multisetDM
  proof (intro exI conjI)
    (* the decomposition carries over: ×# Z distributes over the difference *)
    show "M ×# Z = N ×# Z - ?Y + ?X"
      unfolding M_eq by (auto simp: Sigma_mset_Diff_distrib1)
  next
    (* choose, for each x ∈# X, a dominating element y x ∈# Y *)
    obtain y where y: "∀x. x ∈# X ⟶ y x ∈# Y ∧ x < y x"
      using ex_Y by moura

    (* each pair (x, z) in ?X is dominated by (y x, z) in ?Y *)
    show "∀x. x ∈# ?X ⟶ (∃y. y ∈# ?Y ∧ x < y)"
    proof (intro allI impI)
      fix x
      assume "x ∈# ?X"
      thus "∃y. y ∈# ?Y ∧ x < y"
        using y by (intro exI[of _ "(y (fst x), snd x)"]) (auto simp: less_le_not_le)
    qed
  (* remaining conjuncts: ?Y is nonempty and included in N ×# Z *)
  qed (auto simp: Z_nemp Y_nemp Y_sub_N Sigma_mset_mono)
qed

(* Multiplying on the right by a positive c preserves <.  Proved on the representation
   via transfer, using Times_mset_monoL and plus_nmultiset_mono. *)
lemma times_hmultiset_monoL:
  "a < b ⟹ 0 < c ⟹ a * c < b * c" for a b c :: hmultiset
  by (cases a, cases b, cases c, hypsubst_thin,
    unfold times_hmultiset_def zero_hmultiset_def hmultiset.sel, transfer,
    auto simp: less_multiset_extDM_less multiset.pred_set
      intro!: image_mset_strict_mono Times_mset_monoL elim!: plus_nmultiset_mono)

(* The strict-monotonicity obligation is reduced to times_hmultiset_monoL by rewriting
   the products with mult.commute. *)
instance hmultiset :: linordered_semiring_strict
  by intro_classes (subst (1 2) mult.commute, rule times_hmultiset_monoL)

(* Non-strict monotonicity of the product in each argument and in both arguments.
   These are specializations of the generic mult_right_mono, mult_left_mono and mult_mono
   without an extra premise.  NOTE(review): simp presumably discharges the generic
   nonnegativity side conditions because 0 is least (cf. zero_le_hmset) -- confirm. *)
lemma mult_le_mono1_hmset: "i ≤ j ⟹ i * k ≤ j * k" for i j k :: hmultiset
  by (simp add: mult_right_mono)

lemma mult_le_mono2_hmset: "i ≤ j ⟹ k * i ≤ k * j" for i j k :: hmultiset
  by (simp add: mult_left_mono)

lemma mult_le_mono_hmset: "i ≤ j ⟹ k ≤ l ⟹ i * k ≤ j * l" for i j k l :: hmultiset
  by (simp add: mult_mono)

(* m < n iff m + 1 ≤ n: nothing lies strictly between m and m + 1.
   Both sides are unfolded to the underlying multisets m0 and n0, where m + 1 becomes
   add_mset 0 m0.  The forward direction argues by contradiction via the
   Huet-Oppen form of the multiset order (less_multisetHO).  The backward direction
   uses transitivity with le_multiset_right_total. *)
lemma less_iff_add1_le_hmset: "m < n ⟷ m + 1 ≤ n" for m n :: hmultiset
proof (cases m n rule: hmultiset.exhaust[case_product hmultiset.exhaust])
  case (HMSet_HMSet m0 n0)
  note m = this(1) and n = this(2)

  show ?thesis
  proof (simp add: m n one_hmultiset_def plus_hmultiset_def order.order_iff_strict
      less_multiset_extDM_less, intro iffI)
    assume m0_lt_n0: "m0 < n0"
    (* split m0 < n0 into its two Huet-Oppen conjuncts *)
    note
      m0_ne_n0 = m0_lt_n0[unfolded less_multisetHO, THEN conjunct1] and
      ex_n0_gt_m0 = m0_lt_n0[unfolded less_multisetHO, THEN conjunct2, rule_format]

    (* Suppose add_mset 0 m0 > n0; we derive n0 < m0, contradicting m0 < n0. *)
    {
      assume zero_m0_gt_n0: "add_mset 0 m0 > n0"
      note
        n0_ne_0m0 = zero_m0_gt_n0[unfolded less_multisetHO, THEN conjunct1] and
        ex_0m0_gt_n0 = zero_m0_gt_n0[unfolded less_multisetHO, THEN conjunct2, rule_format]

      (* Huet-Oppen condition for n0 < m0: every y where m0 falls short of n0 is
         dominated by some x > y where m0 exceeds n0. *)
      {
        fix y
        assume m0y_lt_n0y: "count m0 y < count n0 y"

        have "∃x > y. count n0 x < count m0 x"
        proof (cases "count (add_mset 0 m0) y < count n0 y")
          (* y is still short after adding 0: take the dominating aa from the
             assumption; aa ≠ 0, so adding 0 does not change its count *)
          case True
          then obtain aa where
            aa_gt_y: "aa > y" and
            count_n0aa_lt_count_0m0aa: "count n0 aa < count (add_mset 0 m0) aa"
            using ex_0m0_gt_n0 by blast
          have "aa ≠ 0"
            by (rule gr_implies_not_zero_hmset[OF aa_gt_y])
          hence "count (add_mset 0 m0) aa = count m0 aa"
            by simp
          thus ?thesis
            using count_n0aa_lt_count_0m0aa aa_gt_y by auto
        next
          (* adding 0 closed the gap at y: hence y = 0 and count n0 0 = Suc (count m0 0);
             any bb where add_mset 0 m0 exceeds n0 is then nonzero, so bb > y *)
          case not_0m0_y_lt_n0y: False
          hence y_eq_0: "y = 0"
            by (metis count_add_mset m0y_lt_n0y)
          have sm0y_eq_n0y: "Suc (count m0 y) = count n0 y"
            using m0y_lt_n0y not_0m0_y_lt_n0y count_add_mset[of 0 _ 0] unfolding y_eq_0 by simp

          obtain bb where "count n0 bb < count (add_mset 0 m0) bb"
            using lt_imp_ex_count_lt[OF zero_m0_gt_n0] by blast
          hence n0bb_lt_m0bb: "count n0 bb < count m0 bb"
            unfolding count_add_mset by (metis (full_types) less_irrefl_nat sm0y_eq_n0y y_eq_0)
          hence "bb ≠ 0"
            using sm0y_eq_n0y y_eq_0 by auto
          thus ?thesis
            unfolding y_eq_0 using n0bb_lt_m0bb not_gr_zero_hmset by blast
        qed
      }
      hence "n0 < m0"
        unfolding less_multisetHO using m0_ne_n0 by blast
      hence False
        using m0_lt_n0 by simp
    }
    thus "add_mset 0 m0 < n0 ∨ add_mset 0 m0 = n0"
      using antisym_conv3 by blast
  next
    assume "add_mset 0 m0 < n0 ∨ add_mset 0 m0 = n0"
    thus "m0 < n0"
      using dual_order.strict_trans le_multiset_right_total by blast
  qed
qed

(* Instances of less_iff_add1_le_hmset at m = 0 and at n + 1. *)
lemma zero_less_iff_1_le_hmset: "0 < n ⟷ 1 ≤ n" for n :: hmultiset
  by (rule less_iff_add1_le_hmset[of 0, simplified])

lemma less_add_1_iff_le_hmset: "m < n + 1 ⟷ m ≤ n" for m n :: hmultiset
  by (rule less_iff_add1_le_hmset[of m "n + 1", simplified])

(* Further ordered-semiring instances for hmultiset.  Each obligation is discharged by
   intro_classes, using lemmas from this theory where needed. *)
instance hmultiset :: ordered_cancel_comm_semiring
  by intro_classes (simp add: mult_le_mono2_hmset)

instance hmultiset :: zero_less_one
  by intro_classes (simp add: zero_less_iff_neq_zero_hmset)

instance hmultiset :: linordered_semiring_1_strict
  by intro_classes

instance hmultiset :: bounded_lattice_bot
  by intro_classes

instance hmultiset :: linordered_nonzero_semiring
  by intro_classes simp

(* no zero divisors: a product of positive elements is positive *)
instance hmultiset :: semiring_no_zero_divisors
  by intro_classes (use mult_pos_pos not_gr_zero_hmset in blast)

(* Only 0 lies below 1. *)
lemma lt_1_iff_eq_0_hmset: "M < 1 ⟷ M = 0" for M :: hmultiset
  by (simp add: less_iff_add1_le_hmset)

(* A product is positive, respectively at least 1, iff both factors are. *)
lemma zero_less_mult_iff_hmset[simp]: "0 < m * n ⟷ 0 < m ∧ 0 < n" for m n :: hmultiset
  using mult_eq_0_iff not_gr_zero_hmset by blast

lemma one_le_mult_iff_hmset[simp]: "1 ≤ m * n ⟷ 1 ≤ m ∧ 1 ≤ n" for m n :: hmultiset
  by (metis lt_1_iff_eq_0_hmset mult_eq_0_iff not_le)

(* Cancelling a common factor k from a comparison of products.  The strict versions
   require k to be positive; the non-strict versions are trivially true when k is not
   positive. *)
lemma mult_less_cancel2_hmset[simp]: "m * k < n * k ⟷ 0 < k ∧ m < n" for k m n :: hmultiset
  by (metis gr_zeroI_hmset leD leI le_cases mult_right_mono mult_zero_right times_hmultiset_monoL)

lemma mult_less_cancel1_hmset[simp]: "k * m < k * n ⟷ 0 < k ∧ m < n" for k m n :: hmultiset
  by (simp add: mult.commute[of k])

lemma mult_le_cancel1_hmset[simp]: "k * m ≤ k * n ⟷ (0 < k ⟶ m ≤ n)" for k m n :: hmultiset
  by (simp add: linorder_not_less[symmetric], auto)

lemma mult_le_cancel2_hmset[simp]: "m * k ≤ n * k ⟷ (0 < k ⟶ m ≤ n)" for k m n :: hmultiset
  by (simp add: linorder_not_less[symmetric], auto)

(* Multiplying by a positive factor never decreases a value.  Multiplying by a factor
   that is at most 1 never increases it.  Versions for both sides. *)
lemma mult_le_cancel_left1_hmset: "y > 0 ⟹ x ≤ x * y" for x y :: hmultiset
  by (metis zero_less_iff_1_le_hmset mult.commute mult.left_neutral mult_le_cancel2_hmset)

lemma mult_le_cancel_left2_hmset: "y ≤ 1 ⟹ x * y ≤ x" for x y :: hmultiset
  by (metis mult.commute mult.left_neutral mult_le_cancel2_hmset)

lemma mult_le_cancel_right1_hmset: "y > 0 ⟹ x ≤ y * x" for x y :: hmultiset
  by (subst mult.commute) (fact mult_le_cancel_left1_hmset)

lemma mult_le_cancel_right2_hmset: "y ≤ 1 ⟹ y * x ≤ x" for x y :: hmultiset
  by (subst mult.commute) (fact mult_le_cancel_left2_hmset)

(* Squares and cubes dominate their base.  These hold unconditionally, including m = 0. *)
lemma le_square_hmset: "m ≤ m * m" for m :: hmultiset
  using mult_le_cancel_left1_hmset by force

lemma le_cube_hmset: "m ≤ m * (m * m)" for m :: hmultiset
  using mult_le_cancel_left1_hmset by force

(* Subtracting m and then adding n, where m < n (respectively m ≤ n), yields something
   strictly greater than (respectively at least) k. *)
lemma
  less_imp_minus_plus_hmset: "m < n ⟹ k < k - m + n" and
  le_imp_minus_plus_hmset: "m ≤ n ⟹ k ≤ k - m + n" for k m n :: hmultiset
  by (meson add_less_cancel_left leD le_minus_plus_same_hmset less_le_trans not_le_imp_less)+

(* Multiplying a positive m by a factor greater than 1 strictly increases it. *)
lemma gt_0_lt_mult_gt_1_hmset:
  fixes m n :: hmultiset
  assumes "m > 0" and "n > 1"
  shows "m < m * n"
  using assms by (metis mult.right_neutral mult_less_cancel1_hmset)

instance hmultiset :: linordered_comm_semiring_strict
  by intro_classes simp


subsection ‹Embedding of Natural Numbers›

(* The natural number n is embedded as the hereditary multiset containing n copies of 0. *)
lemma of_nat_hmset: "of_nat n = HMSet (replicate_mset n 0)"
  by (induct n) (auto simp: zero_hmultiset_def one_hmultiset_def plus_hmultiset_def)

(* The embedding is injective. *)
lemma of_nat_inject_hmset[simp]: "(of_nat m :: hmultiset) = of_nat n ⟷ m = n"
  unfolding of_nat_hmset by simp

(* The embedding commutes with subtraction on nat (truncated at 0 on the nat side). *)
lemma of_nat_minus_hmset: "of_nat (m - n) = (of_nat m :: hmultiset) - of_nat n"
  unfolding of_nat_hmset minus_hmultiset_def by simp

lemma plus_of_nat_plus_of_nat_hmset:
  "k + of_nat m + of_nat n = k + of_nat (m + n)" for k :: hmultiset
  by simp

(* Subtracting a finite part that does not exceed the finite summand added to k. *)
lemma plus_of_nat_minus_of_nat_hmset:
  fixes k :: hmultiset
  assumes "n ≤ m"
  shows "k + of_nat m - of_nat n = k + of_nat (m - n)"
  using assms by (metis add.left_commute add_diff_cancel_left' le_add_diff_inverse of_nat_add)

(* Every embedded natural number lies strictly below ω. *)
lemma of_nat_lt_ω[simp]: "of_nat n < ω"
  by (auto simp: of_nat_hmset zero_less_iff_neq_zero_hmset less_multiset_extDM_less)

lemma of_nat_ne_ω[simp]: "of_nat n ≠ ω"
  by (simp add: neq_iff)

(* The embedding of nat is an order embedding, both strictly and non-strictly. *)
lemma of_nat_less_hmset[simp]: "(of_nat M :: hmultiset) < of_nat N ⟷ M < N"
  unfolding of_nat_hmset less_multiset_extDM_less by simp

lemma of_nat_le_hmset[simp]: "(of_nat M :: hmultiset) ≤ of_nat N ⟷ M ≤ N"
  unfolding of_nat_hmset order_le_less less_multiset_extDM_less by simp

(* n * ω^m, in either order, is the hereditary multiset with n copies of m. *)
lemma of_nat_times_ω_exp: "of_nat n * ω^m = HMSet (replicate_mset n m)"
  by (induct n) (simp_all add: hmsetmset_plus one_hmultiset_def)

lemma ω_exp_times_of_nat: "ω^m * of_nat n = HMSet (replicate_mset n m)"
  using of_nat_times_ω_exp by simp


subsection ‹Embedding of Extended Natural Numbers›

(* Embedding of extended naturals: a finite enat n maps to of_nat n, and ∞ maps to ω. *)
primrec hmset_of_enat :: "enat ⇒ hmultiset" where
  "hmset_of_enat (enat n) = of_nat n"
| "hmset_of_enat ∞ = ω"

(* The embedding agrees with the usual constants 0, 1, of_nat and numerals. *)
lemma hmset_of_enat_0[simp]: "hmset_of_enat 0 = 0"
  by (simp add: zero_enat_def)

lemma hmset_of_enat_1[simp]: "hmset_of_enat 1 = 1"
  by (simp add: one_enat_def del: One_nat_def)

lemma hmset_of_enat_of_nat[simp]: "hmset_of_enat (of_nat n) = of_nat n"
  using of_nat_eq_enat by auto

lemma hmset_of_enat_numeral[simp]: "hmset_of_enat (numeral n) = numeral n"
  by (simp add: numeral_eq_enat)

(* The image of the embedding is bounded by ω, and only ∞ reaches it. *)
lemma hmset_of_enat_le_ω[simp]: "hmset_of_enat n ≤ ω"
  using of_nat_lt_ω[THEN less_imp_le] by (cases n) auto

lemma hmset_of_enat_eq_ω_iff[simp]: "hmset_of_enat n = ω ⟷ n = ∞"
  by (cases n) auto


subsection ‹Head Omega›

(* head_ω M is ω raised to the largest element of M's underlying multiset.  For M = 0
   it is 0, since the maximum of an empty set is not meaningful. *)
definition head_ω :: "hmultiset ⇒ hmultiset" where
  "head_ω M = (if M = 0 then 0 else ω^(Max (set_mset (hmsetmset M))))"

(* The head is a sub-multiset of the original. *)
lemma head_ω_subseteq: "hmsetmset (head_ω M) ⊆# hmsetmset M"
  unfolding head_ω_def by simp

(* Evaluation of head_ω on 0, 1, natural numbers, numerals and ω. *)
lemma head_ω_eq_0_iff[simp]: "head_ω m = 0 ⟷ m = 0"
  unfolding head_ω_def zero_hmultiset_def by simp

lemma head_ω_0[simp]: "head_ω 0 = 0"
  by simp

lemma head_ω_1[simp]: "head_ω 1 = 1"
  unfolding head_ω_def one_hmultiset_def by simp

lemma head_ω_of_nat[simp]: "head_ω (of_nat n) = (if n = 0 then 0 else 1)"
  unfolding head_ω_def one_hmultiset_def of_nat_hmset by simp

lemma head_ω_numeral[simp]: "head_ω (numeral n) = 1"
  by (metis head_ω_of_nat of_nat_numeral zero_neq_numeral)

lemma head_ω_ω[simp]: "head_ω ω = ω"
  unfolding head_ω_def by simp

(* head_ω is monotone. *)
lemma le_imp_head_ω_le:
  assumes m_le_n: "m ≤ n"
  shows "head_ω m ≤ head_ω n"
proof -
  (* auxiliary: every element of M is bounded by the maximum of a larger N *)
  have le_in_le_max: "⋀a M N. M ≤ N ⟹ a ∈# M ⟹ a ≤ Max (set_mset N)"
    by (metis (no_types) Max_ge finite_set_mset le_less less_eq_multisetHO linorder_not_less
      mem_Collect_eq neq0_conv order_trans set_mset_def)
  show ?thesis
    using m_le_n unfolding head_ω_def
    by (cases m, cases n,
      auto simp del: hmsetmset_le simp: head_ω_def hmsetmset_le[symmetric] zero_hmultiset_def,
      metis Max_in dual_order.antisym finite_set_mset le_in_le_max le_less set_mset_eq_empty_iff)
qed

(* A strictly smaller head forces a strictly smaller value. *)
lemma head_ω_lt_imp_lt: "head_ω m < head_ω n ⟹ m < n"
  unfolding head_ω_def hmsetmset_less[symmetric]
  by (rule all_lt_Max_imp_lt_mset, auto simp: zero_hmultiset_def split: if_splits)

(* The head of a sum is the larger of the two heads.  The proof works on the underlying
   multisets M and N and splits on which of them has the larger maximum. *)
lemma head_ω_plus[simp]: "head_ω (m + n) = sup (head_ω m) (head_ω n)"
proof (cases m n rule: hmultiset.exhaust[case_product hmultiset.exhaust])
  case m_n: (HMSet_HMSet M N)
  show ?thesis
  proof (cases "Max_mset M < Max_mset N")
    case True
    thus ?thesis
      unfolding m_n head_ω_def sup_hmultiset_def zero_hmultiset_def plus_hmultiset_def
      by (simp add: Max.union max_def dual_order.strict_implies_order)
  next
    case False
    thus ?thesis
      unfolding m_n head_ω_def sup_hmultiset_def zero_hmultiset_def plus_hmultiset_def
      by simp (metis False Max.union finite_set_mset leI max_def set_mset_eq_empty_iff sup.commute)
  qed
qed

(* head_ω is multiplicative.  For nonzero m and n, the elements of the product are the
   pairwise sums x + y (see times_hmultiset_def), and the maximum pairwise sum is the sum
   of the maxima.  The case where a factor is 0 is closed by the final auto. *)
lemma head_ω_times[simp]: "head_ω (m * n) = head_ω m * head_ω n"
proof (cases "m = 0 ∨ n = 0")
  case False
  hence m_nz: "m ≠ 0" and n_nz: "n ≠ 0"
    by simp+

  define δ where "δ = hmsetmset m"
  define ε where "ε = hmsetmset n"

  have δ_nemp: "δ ≠ {#}"
    unfolding δ_def using m_nz by simp
  have ε_nemp: "ε ≠ {#}"
    unfolding ε_def using n_nz by simp

  (* ?DE is the set of pairwise sums of elements of ?D and ?E *)
  let ?D = "set_mset δ"
  let ?E = "set_mset ε"
  let ?DE = "{z. ∃x ∈ ?D. ∃y ∈ ?E. z = x + y}"

  have max_D_in: "Max ?D ∈ ?D"
    using δ_nemp by simp
  have max_E_in: "Max ?E ∈ ?E"
    using ε_nemp by simp

  have "Max ?DE = Max ?D + Max ?E"
  proof (rule order_antisym, goal_cases le ge)
    (* every pairwise sum is bounded by the sum of the maxima *)
    case le
    have "⋀x y. x ∈ ?D ⟹ y ∈ ?E ⟹ x + y ≤ Max ?D + Max ?E"
      by (simp add: add_mono)
    hence mem_imp_le: "⋀z. z ∈ ?DE ⟹ z ≤ Max ?D + Max ?E"
      by auto
    show ?case
      by (intro mem_imp_le Max_in, simp, use δ_nemp ε_nemp in fast)
  next
    (* the sum of the maxima is itself a pairwise sum *)
    case ge
    have "{z. ∃x ∈ {Max ?D}. ∃y ∈ {Max ?E}. z = x + y} ⊆ {z. ∃x ∈# δ. ∃y ∈# ε. z = x + y}"
      using max_D_in max_E_in by fast
    thus ?case
      by simp
  qed
  thus ?thesis
    unfolding δ_def ε_def by (auto simp: head_ω_def image_def times_hmultiset_def)
qed auto


subsection ‹More Inequalities and Some Equalities›

(* ω lies strictly above 0, 1, every of_nat n and every numeral.  Below are simp rules
   for <, ≤ and ≠ in both orientations. *)
lemma zero_lt_ω[simp]: "0 < ω"
  by (metis of_nat_lt_ω of_nat_0)

lemma one_lt_ω[simp]: "1 < ω"
  by (metis enat_defs(2) hmset_of_enat.simps(1) hmset_of_enat_1 of_nat_lt_ω)

lemma numeral_lt_ω[simp]: "numeral n < ω"
  using hmset_of_enat_numeral[symmetric] hmset_of_enat.simps(1) of_nat_lt_ω numeral_eq_enat
  by presburger

lemma one_le_ω[simp]: "1 ≤ ω"
  by (simp add: less_imp_le)

lemma of_nat_le_ω[simp]: "of_nat n ≤ ω"
  by (simp add: le_less)

lemma numeral_le_ω[simp]: "numeral n ≤ ω"
  by (simp add: less_imp_le)

lemma not_ω_lt_1[simp]: "¬ ω < 1"
  by (simp add: not_less)

lemma not_ω_lt_of_nat[simp]: "¬ ω < of_nat n"
  by (simp add: not_less)

lemma not_ω_lt_numeral[simp]: "¬ ω < numeral n"
  by (simp add: not_less)

lemma not_ω_le_1[simp]: "¬ ω ≤ 1"
  by (simp add: not_le)

lemma not_ω_le_of_nat[simp]: "¬ ω ≤ of_nat n"
  by (simp add: not_le)

lemma not_ω_le_numeral[simp]: "¬ ω ≤ numeral n"
  by (simp add: not_le)

lemma zero_ne_ω[simp]: "0 ≠ ω"
  by (metis not_ω_le_1 zero_le_hmset)

lemma one_ne_ω[simp]: "1 ≠ ω"
  using not_ω_le_1 by force

lemma numeral_ne_ω[simp]: "numeral n ≠ ω"
  by (metis not_ω_le_numeral numeral_le_ω)

(* symmetric forms of the disequalities above *)
lemma
  ω_ne_0[simp]: "ω ≠ 0" and
  ω_ne_1[simp]: "ω ≠ 1" and
  ω_ne_of_nat[simp]: "ω ≠ of_nat m" and
  ω_ne_numeral[simp]: "ω ≠ numeral n"
  using zero_ne_ω one_ne_ω of_nat_ne_ω numeral_ne_ω by metis+

(* hmset_of_enat is an order embedding: it is injective and both preserves and
   reflects < and ≤.  Each case splits on finite versus infinite arguments. *)
lemma
  hmset_of_enat_inject[simp]: "hmset_of_enat m = hmset_of_enat n ⟷ m = n" and
  hmset_of_enat_less[simp]: "hmset_of_enat m < hmset_of_enat n ⟷ m < n" and
  hmset_of_enat_le[simp]: "hmset_of_enat m ≤ hmset_of_enat n ⟷ m ≤ n"
  by (cases m; cases n; simp)+

(* Everything strictly below ω is a natural number.  The proof shows that every element
   of M's underlying multiset is 0, so M is n copies of 0. *)
lemma lt_ω_imp_ex_of_nat:
  assumes M_lt_ω: "M < ω"
  shows "∃n. M = of_nat n"
proof -
  have M_lt_single_1: "hmsetmset M < {#1#}"
    by (rule M_lt_ω[unfolded hmsetmset_less[symmetric] less_multiset_extDM_less hmultiset.sel])

  (* each element N of M lies below 1, hence is 0 *)
  have "N = 0" if "N ∈# hmsetmset M" for N
  proof -
    have "0 < count (hmsetmset M) N"
      using that by auto
    hence "N < 1"
      by (metis (no_types) M_lt_single_1 count_single gr_implies_not0 less_eq_multisetHO less_one
        neq_iff not_le)
    thus ?thesis
      by (simp add: lt_1_iff_eq_0_hmset)
  qed
  then obtain n where hmmM: "M = HMSet (replicate_mset n 0)"
    using ex_replicate_mset_if_all_elems_eq by (metis hmultiset.collapse)
  show ?thesis
    unfolding hmmM of_nat_hmset by blast
qed

(* Everything up to and including ω is in the image of hmset_of_enat. *)
lemma le_ω_imp_ex_hmset_of_enat:
  assumes M_le_ω: "M ≤ ω"
  shows "∃n. M = hmset_of_enat n"
proof (cases "M = ω")
  case True
  thus ?thesis
    by (metis hmset_of_enat.simps(2))
next
  case False
  thus ?thesis
    using M_le_ω lt_ω_imp_ex_of_nat by (metis hmset_of_enat.simps(1) le_less)
qed

(* Products of finite numbers stay below ω. *)
lemma lt_ω_lt_ω_imp_times_lt_ω: "M < ω ⟹ N < ω ⟹ M * N < ω"
  by (metis lt_ω_imp_ex_of_nat of_nat_lt_ω of_nat_mult)

(* Subtracting a natural number from a multiple of ω has no effect.  The proof shows the
   two underlying multisets are disjoint, so the difference is trivial (Diff_triv_mset). *)
lemma times_ω_minus_of_nat[simp]: "m * ω - of_nat n = m * ω"
  by (auto intro!: Diff_triv_mset simp: times_hmultiset_def minus_hmultiset_def
    Times_mset_single_right of_nat_hmset disjunct_not_in image_def)

lemma times_ω_minus_numeral[simp]: "m * ω - numeral n = m * ω"
  by (metis of_nat_numeral times_ω_minus_of_nat)

(* special cases with m = 1 *)
lemma ω_minus_of_nat[simp]: "ω - of_nat n = ω"
  using times_ω_minus_of_nat[of 1] by (metis mult.left_neutral)

lemma ω_minus_1[simp]: "ω - 1 = ω"
  using ω_minus_of_nat[of 1] by simp

lemma ω_minus_numeral[simp]: "ω - numeral n = ω"
  using times_ω_minus_numeral[of 1] by (metis mult.left_neutral)

(* hmset_of_enat commutes with subtracting a finite amount. *)
lemma hmset_of_enat_minus_enat[simp]: "hmset_of_enat (m - enat n) = hmset_of_enat m - of_nat n"
  by (cases m) (auto simp: of_nat_minus_hmset)

(* Comparing an embedded natural with an embedded extended natural. *)
lemma of_nat_lt_hmset_of_enat_iff: "of_nat m < hmset_of_enat n ⟷ enat m < n"
  by (metis hmset_of_enat.simps(1) hmset_of_enat_less)

lemma of_nat_le_hmset_of_enat_iff: "of_nat m ≤ hmset_of_enat n ⟷ enat m ≤ n"
  by (metis hmset_of_enat.simps(1) hmset_of_enat_le)

lemma hmset_of_enat_lt_iff_ne_infinity: "hmset_of_enat x < ω ⟷ x ≠ ∞"
  by (cases x; simp)

(* m - (m - n) is symmetric in m and n.  On underlying multisets both sides are the
   intersection, and the proof uses multiset_inter_def and commutativity of inf. *)
lemma minus_diff_sym_hmset: "m - (m - n) = n - (n - m)" for m n :: hmultiset
  unfolding minus_hmultiset_def by simp (metis multiset_inter_def subset_mset.inf_aci(1))

(* (c - b) + b is symmetric in b and c.  f1 and f2 are auxiliary facts for the final
   metis call. *)
lemma diff_plus_sym_hmset: "(c - b) + b = (b - c) + c" for b c :: hmultiset
proof -
  have f1: "⋀h ha :: hmultiset. h - (ha + h) = 0"
    by (simp add: add.commute)
  have f2: "⋀h ha hb :: hmultiset. h + ha - (h - hb) = hb + ha - (hb - h)"
    by (metis (no_types) add_diff_cancel_right minus_diff_sym_hmset)
  have "⋀h ha hb :: hmultiset. h + (ha + hb) - hb = h + ha"
    by (metis (no_types) add.assoc add_diff_cancel_right')
  then show ?thesis
    using f2 f1 by (metis (no_types) add.commute add.right_neutral diff_diff_add_hmset)
qed

(* diff_plus_sym_hmset scaled by a common left factor a *)
lemma times_diff_plus_sym_hmset: "a * (c - b) + a * b = a * (b - c) + a * c" for a b c :: hmultiset
  by (metis distrib_left diff_plus_sym_hmset)

(* The product distributes over a difference of embedded naturals, on either side. *)
lemma times_of_nat_minus_left:
  "(of_nat m - of_nat n) * l = of_nat m * l - of_nat n * l" for l :: hmultiset
  by (induct n m rule: diff_induct) (auto simp: ring_distribs)

lemma times_of_nat_minus_right:
  "l * (of_nat m - of_nat n) = l * of_nat m - l * of_nat n" for l :: hmultiset
  by (metis times_of_nat_minus_left mult.commute)

(* The same for arbitrary values below ω, which are natural numbers by lt_ω_imp_ex_of_nat. *)
lemma lt_ω_imp_times_minus_left: "m < ω ⟹ n < ω ⟹ (m - n) * l = m * l - n * l"
  by (metis lt_ω_imp_ex_of_nat times_of_nat_minus_left)

lemma lt_ω_imp_times_minus_right: "m < ω ⟹ n < ω ⟹ l * (m - n) = l * m - l * n"
  by (metis lt_ω_imp_ex_of_nat times_of_nat_minus_right)

(* Any two hereditary multisets split as a common part k plus remainders n1 and n2.
   Either the remainders' heads differ, or both remainders are 0.  The choice is
   n1 = m1 - m2, n2 = m2 - m1 and k = m1 - n1 (which also equals m2 - n2).  The
   remainders have disjoint underlying multisets, so their maxima differ. *)
lemma hmset_pair_decompose:
  "∃k n1 n2. m1 = k + n1 ∧ m2 = k + n2 ∧ (head_ω n1 ≠ head_ω n2 ∨ n1 = 0 ∧ n2 = 0)"
proof -
  define n1 where n1: "n1 = m1 - m2"
  define n2 where n2: "n2 = m2 - m1"
  define k where k1: "k = m1 - n1"

  have k2: "k = m2 - n2"
    using k1 unfolding n1 n2 by (simp add: minus_diff_sym_hmset)

  have "m1 = k + n1"
    unfolding k1
    by (metis (no_types) n1 add_diff_cancel_left add.commute add_diff_cancel_right' diff_add_zero
      diff_diff_add minus_diff_sym_hmset)
  moreover have "m2 = k + n2"
    unfolding k2
    by (metis n2 add.commute add_diff_cancel_left add_diff_cancel_left' add_diff_cancel_right'
      diff_add_zero diff_diff_add diff_zero k2 minus_diff_sym_hmset)
  moreover have hd_n: "head_ω n1 ≠ head_ω n2" if n1_or_n2_nz: "n1 ≠ 0 ∨ n2 ≠ 0"
  proof (cases "n1 = 0" "n2 = 0" rule: bool.exhaust[case_product bool.exhaust])
    (* both remainders nonzero: compare the maxima of their underlying multisets *)
    case False_False
    note n1_nz = this(1)[simplified] and n2_nz = this(2)[simplified]

    define δ1 where "δ1 = hmsetmset n1"
    define δ2 where "δ2 = hmsetmset n2"

    have δ1_inter_δ2: "δ1 ∩# δ2 = {#}"
      unfolding δ1_def δ2_def n1 n2 minus_hmultiset_def by (simp add: diff_intersect_sym_diff)

    have δ1_ne: "δ1 ≠ {#}"
      unfolding δ1_def using n1_nz by simp
    have δ2_ne: "δ2 ≠ {#}"
      unfolding δ2_def using n2_nz by simp

    have max_δ1: "Max (set_mset δ1) ∈# δ1"
      using δ1_ne by simp
    have max_δ2: "Max (set_mset δ2) ∈# δ2"
      using δ2_ne by simp
    have max_δ1_ne_δ2: "Max (set_mset δ1) ≠ Max (set_mset δ2)"
      using δ1_inter_δ2 disjunct_not_in max_δ1 max_δ2 by force

    show ?thesis
      using n1_nz n2_nz
      by (cases n1 rule: hmultiset.exhaust_sel, cases n2 rule: hmultiset.exhaust_sel,
        auto simp: head_ω_def zero_hmultiset_def max_δ1_ne_δ2[unfolded δ1_def δ2_def])
  (* exactly one remainder is 0: its head is 0 while the other's is not *)
  qed (use n1_or_n2_nz in ‹auto simp: head_ω_def›)
  ultimately show ?thesis
    by blast
qed

(* Refinement of hmset_pair_decompose when m1 < m2: the remainders' heads are strictly
   ordered.  Both-zero is impossible because it would force m1 = m2.  A reversed head
   order is impossible because it would give m1 > m2 via head_ω_lt_imp_lt. *)
lemma hmset_pair_decompose_less:
  assumes m1_lt_m2: "m1 < m2"
  shows "∃k n1 n2. m1 = k + n1 ∧ m2 = k + n2 ∧ head_ω n1 < head_ω n2"
proof -
  obtain k n1 n2 where
    m1: "m1 = k + n1" and
    m2: "m2 = k + n2" and
    hds: "head_ω n1 ≠ head_ω n2 ∨ n1 = 0 ∧ n2 = 0"
    using hmset_pair_decompose[of m1 m2] by blast

  {
    assume "n1 = 0" and "n2 = 0"
    hence "m1 = m2"
      unfolding m1 m2 by simp
    hence False
      using m1_lt_m2 by simp
  }
  moreover
  {
    assume "head_ω n1 > head_ω n2"
    hence "n1 > n2"
      by (rule head_ω_lt_imp_lt)
    hence "m1 > m2"
      unfolding m1 m2 by simp
    hence False
      using m1_lt_m2 by simp
  }
  ultimately show ?thesis
    using m1 m2 hds by (blast elim: neqE)
qed

(* Non-strict version: m1 ≤ m2 gives strictly ordered heads, or both remainders are 0. *)
lemma hmset_pair_decompose_less_eq:
  assumes "m1 ≤ m2"
  shows "∃k n1 n2. m1 = k + n1 ∧ m2 = k + n2 ∧ (head_ω n1 < head_ω n2 ∨ n1 = 0 ∧ n2 = 0)"
  using assms
  by (metis add_cancel_right_right hmset_pair_decompose_less order.not_eq_order_implies_strict)

(* Strict rearrangement-style inequality: for A < Aa and B < Ba, the cross sum
   A * Ba + B * Aa is below the matched sum A * B + Aa * Ba.  Decompose both pairs with
   hmset_pair_decompose_less.  After cancelling the common parts, the head of
   m1 * n1 + m2 * n2 strictly exceeds that of m1 * n2 + m2 * n1. *)
lemma mono_cross_mult_less_hmset:
  fixes Aa A Ba B :: hmultiset
  assumes A_lt: "A < Aa" and B_lt: "B < Ba"
  shows "A * Ba + B * Aa < A * B + Aa * Ba"
proof -
  obtain j m1 m2 where A: "A = j + m1" and Aa: "Aa = j + m2" and hd_m: "head_ω m1 < head_ω m2"
    by (metis hmset_pair_decompose_less[OF A_lt])
  obtain k n1 n2 where B: "B = k + n1" and Ba: "Ba = k + n2" and hd_n: "head_ω n1 < head_ω n2"
    by (metis hmset_pair_decompose_less[OF B_lt])

  (* head_ω is additive (as sup) and multiplicative; head_ω m2 * head_ω n2 is not below
     either cross-term head *)
  have hd_lt: "head_ω (m1 * n2 + m2 * n1) < head_ω (m1 * n1 + m2 * n2)"
  proof simp
    have "⋀h ha :: hmultiset. 0 < h ∨ ¬ ha < h"
      by force
    hence "¬ head_ω m2 * head_ω n2 ≤ sup (head_ω m1 * head_ω n2) (head_ω m2 * head_ω n1)"
      using hd_m hd_n sup_hmultiset_def by auto
    thus "sup (head_ω m1 * head_ω n2) (head_ω m2 * head_ω n1)
      < sup (head_ω m1 * head_ω n1) (head_ω m2 * head_ω n2)"
      by (meson leI sup.bounded_iff)
  qed
  show ?thesis
    unfolding A Aa B Ba ring_distribs by (simp add: algebra_simps head_ω_lt_imp_lt[OF hd_lt])
qed

(* Identity between two sums of products of differences in the six variables
   An, Ap, Bn, Bp, Cn, Cp.  The two sides differ only in that the operands of each
   bracketed difference are swapped.
   Proof: each of the four stanzas below adds the same term to both sides
   (add_right_cancel), swaps one difference with times_diff_plus_sym_hmset, and
   reorders summands.  The subst (k) occurrence indices depend on the exact term shape
   produced by simp, so this script is likely sensitive to changes in algebra_simps.
   The for clause previously also fixed Dp and Dn.  They do not occur in the statement,
   so they are dropped; the proved theorem is unchanged. *)
lemma triple_cross_mult_hmset:
  "An * (Bn * Cn + Bp * Cp - (Bn * Cp + Cn * Bp))
   + (Cn * (An * Bp + Bn * Ap - (An * Bn + Ap * Bp))
      + (Ap * (Bn * Cp + Cn * Bp - (Bn * Cn + Bp * Cp))
         + Cp * (An * Bn + Ap * Bp - (An * Bp + Bn * Ap)))) =
   An * (Bn * Cp + Cn * Bp - (Bn * Cn + Bp * Cp))
   + (Cn * (An * Bn + Ap * Bp - (An * Bp + Bn * Ap))
      + (Ap * (Bn * Cn + Bp * Cp - (Bn * Cp + Cn * Bp))
         + Cp * (An * Bp + Bn * Ap - (An * Bn + Ap * Bp))))"
  for Ap An Bp Bn Cp Cn :: hmultiset
  apply (simp add: algebra_simps)
  apply (unfold add.assoc[symmetric])

  (* stanza 1: swap the difference multiplied by Cp *)
  apply (rule add_right_cancel[THEN iffD1, of _ "Cp * (An * Bp + Ap * Bn)"])
  apply (unfold add.assoc)
  apply (subst times_diff_plus_sym_hmset)
  apply (unfold add.assoc[symmetric])
  apply (subst (12) add.commute)
  apply (subst (11) add.commute)
  apply (unfold add.assoc[symmetric])

  (* stanza 2: swap the difference multiplied by Cn *)
  apply (rule add_right_cancel[THEN iffD1, of _ "Cn * (An * Bn + Ap * Bp)"])
  apply (unfold add.assoc)
  apply (subst times_diff_plus_sym_hmset)
  apply (unfold add.assoc[symmetric])
  apply (subst (14) add.commute)
  apply (subst (13) add.commute)
  apply (unfold add.assoc[symmetric])

  (* stanza 3: swap the difference multiplied by Ap *)
  apply (rule add_right_cancel[THEN iffD1, of _ "Ap * (Bn * Cn + Bp * Cp)"])
  apply (unfold add.assoc)
  apply (subst times_diff_plus_sym_hmset)
  apply (unfold add.assoc[symmetric])
  apply (subst (16) add.commute)
  apply (subst (15) add.commute)
  apply (unfold add.assoc[symmetric])

  (* stanza 4: swap the difference multiplied by An *)
  apply (rule add_right_cancel[THEN iffD1, of _ "An * (Bn * Cp + Bp * Cn)"])
  apply (unfold add.assoc)
  apply (subst times_diff_plus_sym_hmset)
  apply (unfold add.assoc[symmetric])
  apply (subst (18) add.commute)
  apply (subst (17) add.commute)
  apply (unfold add.assoc[symmetric])

  by (simp add: algebra_simps)


subsection ‹Conversions to Natural Numbers›

(* offset_hmset M is the multiplicity of 0 in M's underlying multiset, i.e. the number
   of ω^0 = 1 summands. *)
definition offset_hmset :: "hmultiset ⇒ nat" where
  "offset_hmset M = count (hmsetmset M) 0"

lemma offset_hmset_of_nat[simp]: "offset_hmset (of_nat n) = n"
  unfolding offset_hmset_def of_nat_hmset by simp

(* NOTE(review): the initial unfolding is undone by metis, which re-uses
   offset_hmset_def; the unfolding step looks redundant. *)
lemma offset_hmset_numeral[simp]: "offset_hmset (numeral n) = numeral n"
  unfolding offset_hmset_def by (metis offset_hmset_def offset_hmset_of_nat of_nat_numeral)

(* sum_coefs M is the size of M's underlying multiset, i.e. the total number of
   ω-power summands counted with multiplicity. *)
definition sum_coefs :: "hmultiset ⇒ nat" where
  "sum_coefs M = size (hmsetmset M)"

lemma sum_coefs_distrib_plus[simp]: "sum_coefs (M + N) = sum_coefs M + sum_coefs N"
  unfolding plus_hmultiset_def sum_coefs_def by simp

lemma sum_coefs_gt_0: "sum_coefs M > 0 ⟷ M > 0"
  by (auto simp: sum_coefs_def zero_hmultiset_def hmsetmset_less[symmetric] less_multiset_extDM_less
    nonempty_has_size[symmetric])


subsection ‹An Example›

text ‹
The following proof is based on an informal proof by Uwe Waldmann, inspired by
a similar argument by Michel Ludwig.
›

(* If α2 + β2 * γ < α1 + β1 * γ with β2 ≤ β1, the inequality persists when γ is
   replaced by a larger δ.  Proof: decompose β1/β2 and γ/δ into common parts plus
   remainders (hmset_pair_decompose_less_eq / _less).  Cancel the common part β0 from
   the hypothesis, then compare the remainder terms by their heads. *)
lemma ludwig_waldmann_less:
  fixes α1 α2 β1 β2 γ δ :: hmultiset
  assumes
    αβ2γ_lt_αβ1γ: "α2 + β2 * γ < α1 + β1 * γ" and
    β2_le_β1: "β2 ≤ β1" and
    γ_lt_δ: "γ < δ"
  shows "α2 + β2 * δ < α1 + β1 * δ"
proof -
  obtain β0 β2a β1a where
    β1: "β1 = β0 + β1a" and
    β2: "β2 = β0 + β2a" and
    hd_β2a_vs_β1a: "head_ω β2a < head_ω β1a ∨ β2a = 0 ∧ β1a = 0"
    using hmset_pair_decompose_less_eq[OF β2_le_β1] by blast

  obtain η γa δa where
    γ: "γ = η + γa" and
    δ: "δ = η + δa" and
    hd_γa_lt_δa: "head_ω γa < head_ω δa"
    using hmset_pair_decompose_less[OF γ_lt_δ] by blast

  (* cancel β0 * γ from both sides of the hypothesis *)
  have "α2 + β0 * γ + β2a * γ = α2 + β2 * γ"
    unfolding β2 by (simp add: add.commute add.left_commute distrib_left mult.commute)
  also have "… < α1 + β1 * γ"
    by (rule αβ2γ_lt_αβ1γ)
  also have "… = α1 + β0 * γ + β1a * γ"
    unfolding β1 by (simp add: add.commute add.left_commute distrib_left mult.commute)
  finally have *: "α2 + β2a * γ < α1 + β1a * γ"
    by (metis add_less_cancel_right semiring_normalization_rules(23))

  (* main chain: expand the goal's left side, apply *, then bound the leftover terms *)
  have "α2 + β2 * δ = α2 + β0 * δ + β2a * δ"
    unfolding β2 by (simp add: ab_semigroup_add_class.add_ac(1) distrib_right)
  also have "… = α2 + β0 * δ + β2a * η + β2a * δa"
    unfolding δ by (simp add: distrib_left semiring_normalization_rules(25))
  also have "… ≤ α2 + β0 * δ + β2a * η + β2a * δa + β2a * γa"
    by simp
  also have "… = α2 + β2a * γ + β0 * δ + β2a * δa"
    unfolding γ distrib_left add.assoc[symmetric] by (simp add: semiring_normalization_rules(23))
  also have "… < α1 + β1a * γ + β0 * δ + β2a * δa"
    using * by simp
  also have "… = α1 + β1a * η + β1a * γa + β0 * η + β0 * δa + β2a * δa"
    unfolding γ δ distrib_left add.assoc[symmetric] by (rule refl)
  also have "… ≤ α1 + β1a * η + β0 * η + β0 * δa + β1a * δa"
  proof -
    (* key step: β1a * γa + β2a * δa ≤ β1a * δa, trivial when both remainders are 0,
       otherwise by comparing heads *)
    have "β1a * γa + β2a * δa ≤ β1a * δa"
    proof (cases "β2a = 0 ∧ β1a = 0")
      case False
      hence "head_ω β2a < head_ω β1a"
        using hd_β2a_vs_β1a by blast
      hence "head_ω (β1a * γa + β2a * δa) < head_ω (β1a * δa)"
        using hd_γa_lt_δa by (auto intro: gr_zeroI_hmset simp: sup_hmultiset_def)
      hence "β1a * γa + β2a * δa < β1a * δa"
        by (rule head_ω_lt_imp_lt)
      thus ?thesis
        by simp
    qed simp
    thus ?thesis
      by simp
  qed
  finally show ?thesis
    unfolding β1 δ
    by (simp add: distrib_left distrib_right add.assoc[symmetric] semiring_normalization_rules(23))
qed

end