(* Theory Multiset_Extra *)

theory Multiset_Extra
  imports
    "HOL-Library.Multiset"
    "HOL-Library.Multiset_Order"
    Nested_Multisets_Ordinals.Multiset_More
begin

(* Elimination rule: if x occurs at least once in M, then M is x added to
   some multiset M'. *)
lemma one_le_countE:
  assumes "1 ≤ count M x"
  obtains M' where "M = add_mset x M'"
  using assms by (meson count_greater_eq_one_iff multi_member_split)

(* Elimination rule: if x occurs at least twice in M, then M is two copies of
   x added to some multiset M'. *)
lemma two_le_countE:
  assumes "2 ≤ count M x"
  obtains M' where "M = add_mset x (add_mset x M')"
  using assms
  by (metis Suc_1 Suc_eq_plus1_left Suc_leD add.right_neutral count_add_mset multi_member_split
      not_in_iff not_less_eq_eq)

(* Elimination rule: if x occurs at least three times in M, then M is three
   copies of x added to some multiset M'.  The proof reduces to two_le_countE. *)
lemma three_le_countE:
  assumes "3 ≤ count M x"
  obtains M' where "M = add_mset x (add_mset x (add_mset x M'))"
  using assms
  by (metis One_nat_def Suc_1 Suc_leD add_le_cancel_left count_add_mset numeral_3_eq_3 plus_1_eq_Suc
      two_le_countE)

(* Sufficient condition for the Huet-Oppen multiset order:
   - the elements gained, B - A, form a non-empty multiset;
   - every element lost, in A - B, is R-below some gained element.
   Under these two conditions, multp⇩H⇩O R A B holds. *)
lemma one_step_implies_multp⇩H⇩O_strong:
  fixes A B J K :: "_ multiset"
  defines "J ≡ B - A" and "K ≡ A - B"
  assumes "J ≠ {#}" and "∀k ∈# K. ∃x ∈# J. R k x"
  shows "multp⇩H⇩O R A B"
  unfolding multp⇩H⇩O_def
proof (intro conjI allI impI)
  show "A ≠ B"
    using assms by force
next
  show "⋀y. count B y < count A y ⟹ ∃x. R y x ∧ count A x < count B x"
    using assms by (metis in_diff_count)
qed

(* Uniq is antitone: if Q implies P pointwise, then uniqueness of P implies
   uniqueness of Q. *)
lemma Uniq_antimono: "Q ≤ P ⟹ Uniq Q ≥ Uniq P"
  unfolding le_fun_def le_bool_def
  by (rule impI) (simp only: Uniq_I Uniq_D)

(* Rule-format variant of Uniq_antimono, with the pointwise implication as a
   meta-level premise. *)
lemma Uniq_antimono': "(⋀x. Q x ⟹ P x) ⟹ Uniq P ⟹ Uniq Q"
  by (fact Uniq_antimono[unfolded le_fun_def le_bool_def, rule_format])

(* For transitive R, M is below the singleton {#x#} exactly when every element
   of M is R-below x. *)
lemma multp_singleton_right[simp]:
  assumes "transp R"
  shows "multp R M {#x#} ⟷ (∀y ∈# M. R y x)"
proof (rule iffI)
  show "∀y ∈# M. R y x ⟹ multp R M {#x#}"
    using one_step_implies_multp[of "{#x#}" _ R "{#}", simplified] .
next
  show "multp R M {#x#} ⟹ ∀y∈#M. R y x"
    using multp_implies_one_step[OF ‹transp R›]
    by (smt (verit, del_insts) add_0 set_mset_add_mset_insert set_mset_empty single_is_union
        singletonD)
qed

(* For transitive R, the singleton {#x#} is below M exactly when {#x#} is a
   strict submultiset of M or some element of M is R-above x. *)
lemma multp_singleton_left[simp]:
  assumes "transp R"
  shows "multp R {#x#} M ⟷ ({#x#} ⊂# M ∨ (∃y ∈# M. R x y))"
proof (rule iffI)
  show "{#x#} ⊂# M ∨ (∃y∈#M. R x y) ⟹ multp R {#x#} M"
  proof (elim disjE bexE)
    show "{#x#} ⊂# M ⟹ multp R {#x#} M"
      by (simp add: subset_implies_multp)
  next
    show "⋀y. y ∈# M ⟹ R x y ⟹ multp R {#x#} M"
      using one_step_implies_multp[of M "{#x#}" R "{#}", simplified] by force
  qed
next
  show "multp R {#x#} M ⟹ {#x#} ⊂# M ∨ (∃y∈#M. R x y)"
    using multp_implies_one_step[OF ‹transp R›, of "{#x#}" M]
    by (metis (no_types, opaque_lifting) add_cancel_right_left subset_mset.gr_zeroI
        subset_mset.less_add_same_cancel2 union_commute union_is_single union_single_eq_member)
qed

(* For transitive R, singletons compare exactly as their elements do.
   The proof specialises multp_singleton_right to M = {#x#}. *)
lemma multp_singleton_singleton[simp]: "transp R ⟹ multp R {#x#} {#y#} ⟷ R x y"
  using multp_singleton_right[of R "{#x#}" y] by simp

(* For transitive R, multp R A B is preserved when the left side shrinks to a
   submultiset C of A and the right side grows to a supermultiset D of B. *)
lemma multp_subset_supersetI: "transp R ⟹ multp R A B ⟹ C ⊆# A ⟹ B ⊆# D ⟹ multp R C D"
  by (metis subset_implies_multp subset_mset.antisym_conv2 transpE transp_multp)

(* Doubling both sides preserves multp R, for transitive R.  This is the case
   n = 2 of the library lemma multp_repeat_mset_repeat_msetI. *)
lemma multp_double_doubleI:
  assumes "transp R" "multp R A B"
  shows "multp R (A + A) (B + B)"
  using multp_repeat_mset_repeat_msetI[OF ‹transp R› ‹multp R A B›, of 2]
  by (simp add: numeral_Bit0)

(* Converse of one_step_implies_multp⇩H⇩O_strong for multp.  If R is transitive
   and asymmetric and multp R A B holds, then:
   - the gained part B - A is non-empty;
   - every lost element, in A - B, is R-below some gained element.
   The proof goes through multp⇩H⇩O.  (The unused fixed variable I has been
   dropped from the fixes clause; the statement is unchanged.) *)
lemma multp_implies_one_step_strong:
  fixes A B J K :: "_ multiset"
  assumes "transp R" and "asymp R" and "multp R A B"
  defines "J ≡ B - A" and "K ≡ A - B"
  shows "J ≠ {#}" and "∀k ∈# K. ∃x ∈# J. R k x"
proof -
  from assms have "multp⇩H⇩O R A B"
    by (simp add: multp_eq_multp⇩H⇩O)

  thus "J ≠ {#}" and "∀k ∈# K. ∃x ∈# J. R k x"
    using multp⇩H⇩O_implies_one_step_strong[OF ‹multp⇩H⇩O R A B›]
    by (simp_all add: J_def K_def)
qed

(* Halving both sides preserves multp R, for transitive and asymmetric R.
   The one-step characterisation of the doubled comparison is transferred to
   A - B and B - A.  The result is then rebuilt around the common part A ∩# B. *)
lemma multp_double_doubleD:
  assumes "transp R" and "asymp R" and "multp R (A + A) (B + B)"
  shows "multp R A B"
proof -
  from assms have
    "B + B - (A + A) ≠ {#}" and
    "∀k∈#A + A - (B + B). ∃x∈#B + B - (A + A). R k x"
    using multp_implies_one_step_strong[OF assms] by simp_all

  have "multp R (A ∩# B + (A - B)) (A ∩# B + (B - A))"
  proof (rule one_step_implies_multp[of "B - A" "A - B" R "A ∩# B"])
    show "B - A ≠ {#}"
      using ‹B + B - (A + A) ≠ {#}›
      by (meson Diff_eq_empty_iff_mset mset_subset_eq_mono_add)
  next
    show "∀k∈#A - B. ∃j∈#B - A. R k j"
    proof (intro ballI)
      fix x assume "x ∈# A - B"
      hence "x ∈# A + A - (B + B)"
        by (simp add: in_diff_count)
      then obtain y where "y ∈# B + B - (A + A)" and "R x y"
        using ‹∀k∈#A + A - (B + B). ∃x∈#B + B - (A + A). R k x› by auto
      then show "∃j∈#B - A. R x j"
        by (auto simp add: in_diff_count)
    qed
  qed

  moreover have "A = A ∩# B + (A - B)"
    by (simp add: inter_mset_def)

  moreover have "B = A ∩# B + (B - A)"
    by (metis diff_intersect_right_idem subset_mset.add_diff_inverse subset_mset.inf.cobounded2)

  ultimately show ?thesis
    by argo
qed

(* For transitive and asymmetric R, doubling both sides neither creates nor
   destroys a multp R comparison. *)
lemma multp_double_double:
  "transp R ⟹ asymp R ⟹ multp R (A + A) (B + B) ⟷ multp R A B"
  using multp_double_doubleD multp_double_doubleI by metis

(* For transitive and asymmetric R, {#x, x#} is below {#y, y#} exactly when
   R x y holds. *)
lemma multp_doubleton_doubleton[simp]:
  "transp R ⟹ asymp R ⟹ multp R {#x, x#} {#y, y#} ⟷ R x y"
  using multp_double_double[of R "{#x#}" "{#y#}", simplified] by simp

(* A non-empty multiset is strictly below its double, for any relation R. *)
lemma multp_single_doubleI: "M ≠ {#} ⟹ multp R M (M + M)"
  using one_step_implies_multp[of M "{#}" _ M, simplified] by simp

(* One-step characterisation of a single mult1 step, for asymmetric r.
   The gained part B - A is non-empty, and every lost element is r-below some
   gained element.  (The "trans r" assumption is not used by this proof.) *)
lemma mult1_implies_one_step_strong:
  assumes "trans r" and "asym r" and "(A, B) ∈ mult1 r"
  shows "B - A ≠ {#}" and "∀k ∈# A - B. ∃j ∈# B - A. (k, j) ∈ r"
proof -
  from ‹(A, B) ∈ mult1 r› obtain b B' A' where
    B_def: "B = add_mset b B'" and
    A_def: "A = B' + A'" and
    "∀a. a ∈# A' ⟶ (a, b) ∈ r"
    unfolding mult1_def by auto

  have "b ∉# A'"
    by (meson ‹∀a. a ∈# A' ⟶ (a, b) ∈ r› assms(2) asym_onD iso_tuple_UNIV_I)
  then have "b ∈# B - A"
    by (simp add: A_def B_def)
  thus "B - A ≠ {#}"
    by auto

  show "∀k ∈# A - B. ∃j ∈# B - A. (k, j) ∈ r"
    by (metis A_def B_def ‹∀a. a ∈# A' ⟶ (a, b) ∈ r› ‹b ∈# B - A› ‹b ∉# A'› add_diff_cancel_left'
        add_mset_add_single diff_diff_add_mset diff_single_trivial)
qed

(* The multiset extension of an asymmetric, transitive relation is asymmetric.
   It is transferred from asymp_multp⇩H⇩O via multp_eq_multp⇩H⇩O. *)
lemma asymp_multp:
  assumes "asymp R" and "transp R"
  shows "asymp (multp R)"
  using asymp_multp⇩H⇩O[OF assms]
  unfolding multp_eq_multp⇩H⇩O[OF assms].

(* For transitive R, {#x, x#} is below {#y#} exactly when R x y holds.  The
   proof splits on x = y and relies on the simp rules for singleton
   comparisons. *)
lemma multp_doubleton_singleton: "transp R ⟹ multp R {#x, x#} {#y#} ⟷ R x y"
  by (cases "x = y") auto

(* For an injective f, two operations give the same result:
   - removing one occurrence of f a from the image of X;
   - taking the image of X with one occurrence of a removed. *)
(* NOTE(review): the fact supplied via "using" is the same equation that is then
   passed to "unfolding"; the "using" line may be redundant -- confirm before
   removing it. *)
lemma image_mset_remove1_mset: 
  assumes "inj f"  
  shows "remove1_mset (f a) (image_mset f X) = image_mset f (remove1_mset a X)"
  using image_mset_remove1_mset_if
  unfolding image_mset_remove1_mset_if inj_image_mem_iff[OF assms, symmetric]
  by simp

(* The Dershowitz-Manna multiset order is preserved by image_mset f, provided f
   is monotone from R to S on the elements of M1 + M2.
   The witnesses Y (removed part) and X (added part) of multp⇩D⇩M R M1 M2 are
   mapped through f.  A choice function g picks, for every x ∈# X, an R-larger
   element of Y. *)
lemma multp⇩D⇩M_map_strong:
  assumes
    f_mono: "monotone_on (set_mset (M1 + M2)) R S f" and
    M1_lt_M2: "multp⇩D⇩M R M1 M2"
  shows "multp⇩D⇩M S (image_mset f M1) (image_mset f M2)"
proof -
  obtain Y X where
    "Y ≠ {#}" and "Y ⊆# M2" and M1_eq: "M1 = M2 - Y + X" and
    ex_y: "∀x. x ∈# X ⟶ (∃y. y ∈# Y ∧ R x y)"
    using M1_lt_M2[unfolded multp⇩D⇩M_def Let_def mset_map] by blast

  let ?fY = "image_mset f Y"
  let ?fX = "image_mset f X"

  show ?thesis
    unfolding multp⇩D⇩M_def
  proof (intro exI conjI)
    show "image_mset f Y ≠ {#}"
      using ‹Y ≠ {#}› unfolding image_mset_is_empty_iff .
  next
    show "image_mset f Y ⊆# image_mset f M2"
      using ‹Y ⊆# M2› image_mset_subseteq_mono by metis
  next
    show "image_mset f M1 = image_mset f M2 - ?fY + ?fX"
      using M1_eq[THEN arg_cong, of "image_mset f"] ‹Y ⊆# M2›
      by (metis image_mset_Diff image_mset_union)
  next
    (* NOTE(review): "moura" is a Sledgehammer-generated skolemisation method;
       consider an explicit choice argument if it becomes unavailable. *)
    obtain g where y: "∀x. x ∈# X ⟶ g x ∈# Y ∧ R x (g x)"
      using ex_y by moura

    show "∀fx. fx ∈# ?fX ⟶ (∃fy. fy ∈# ?fY ∧ S fx fy)"
    proof (intro allI impI)
      fix x' assume "x' ∈# ?fX"
      then obtain x where x': "x' = f x" and x_in: "x ∈# X"
        by auto
      hence y_in: "g x ∈# Y" and y_gt: "R x (g x)"
        using y[rule_format, OF x_in] by blast+

      moreover have "X ⊆# M1"
        using M1_eq by simp

      ultimately have "f (g x) ∈# ?fY ∧ S (f x) (f (g x))"
        using f_mono[THEN monotone_onD, of x "g x"] ‹Y ⊆# M2› ‹X ⊆# M1› x_in
        by (metis imageI in_image_mset mset_subset_eqD union_iff)
      thus "∃fy. fy ∈# ?fY ∧ S x' fy"
        unfolding x' by auto
    qed
  qed
qed

(* The multiset extension is preserved by image_mset f, provided R is
   transitive and f is monotone from R to S on the elements of M1 + M2.
   The proof instantiates the library lemma
   monotone_on_multp_multp_image_mset. *)
lemma multp_map_strong:
  assumes
    transp: "transp R" and
    f_mono: "monotone_on (set_mset (M1 + M2)) R S f" and
    M1_lt_M2: "multp R M1 M2"
  shows "multp S (image_mset f M1) (image_mset f M2)"
  using monotone_on_multp_multp_image_mset[THEN monotone_onD, OF f_mono transp _ _ M1_lt_M2]
  by simp

(* Adding x to the left and y to the right, with R x y, preserves the
   Huet-Oppen order, for asymmetric and transitive R.
   The main part of the proof is a case split on an element x' whose count
   dropped: whether x' = x, and whether x' = y. *)
lemma multp⇩H⇩O_add_mset:
  assumes "asymp R" "transp R" "R x y" "multp⇩H⇩O R X Y"
  shows "multp⇩H⇩O R (add_mset x X) (add_mset y Y)"
  unfolding multp⇩H⇩O_def
proof (intro allI conjI impI)
  show "add_mset x X ≠ add_mset y Y"
    using assms(1, 3, 4)
    unfolding multp⇩H⇩O_def
    by (metis asympD count_add_mset lessI less_not_refl)
next
  fix x'
  assume count_x': "count (add_mset y Y) x' < count (add_mset x X) x'"
  show "∃y'. R x' y' ∧ count (add_mset x X) y' < count (add_mset y Y) y'"
  proof (cases "x' = x")
    case True
    then show ?thesis
      using assms
      unfolding multp⇩H⇩O_def
      by (metis count_add_mset irreflpD irreflp_on_if_asymp_on not_less_eq transpE)
  next
    case x'_neq_x: False
    show ?thesis
    proof (cases "y = x'")
      case True
      then show ?thesis
        using assms(1, 3, 4) count_x' x'_neq_x
        unfolding multp⇩H⇩O_def count_add_mset
        by (smt (verit) Suc_lessD asympD)
    next
      case False
      then show ?thesis
        using assms count_x' x'_neq_x
        unfolding multp⇩H⇩O_def count_add_mset
        by (smt (verit, del_insts) irreflpD irreflp_on_if_asymp_on not_less_eq transpE)
    qed
  qed
qed

(* multp version of multp⇩H⇩O_add_mset: adding x to the left and y to the right,
   with R x y, preserves multp R for asymmetric and transitive R. *)
lemma multp_add_mset:
  assumes "asymp R" "transp R" "R x y" "multp R X Y"
  shows "multp R (add_mset x X) (add_mset y Y)"
  using multp⇩H⇩O_add_mset[OF assms(1-3)] assms(4)
  unfolding multp_eq_multp⇩H⇩O[OF assms(1, 2)]
  by simp

(* Replacing one element x of a multiset by y, where R x y, yields a larger
   multiset under multp R.  Only the assumption R x y is used here, with no
   transitivity or asymmetry. *)
lemma multp_add_mset':
  assumes "R x y"  
  shows "multp R (add_mset x X) (add_mset y X)"
  using assms
  by (metis add_mset_add_single empty_iff insert_iff one_step_implies_multp set_mset_add_mset_insert 
        set_mset_empty)

(* Variant of multp_add_mset where X and Y are only related by the reflexive
   closure of multp R.
   - If X = Y, the result follows from multp_add_mset'.
   - Otherwise it follows from multp_add_mset. *)
lemma multp_add_mset_reflclp:
  assumes "asymp R" "transp R" "R x y" "(multp R)⇧=⇧= X Y"
  shows "multp R (add_mset x X) (add_mset y Y)"
  using
    assms(4)
    multp_add_mset'[of R, OF assms(3)]
    multp_add_mset[OF assms(1-3)]
  by blast

(* Adding the same element x to both sides preserves multp R X Y, for
   asymmetric and transitive R.  The proof relies on the library cancellation
   lemma multp_cancel_add_mset. *)
lemma multp_add_same:
  assumes "asymp R" "transp R" "multp R X Y"
  shows "multp R (add_mset x X) (add_mset x Y)"
  by (meson assms asymp_on_subset irreflp_on_if_asymp_on multp_cancel_add_mset top_greatest)

end