(* Theory Nested_Multiset *)

(*  Title:       Nested Multisets
    Author:      Dmitriy Traytel <traytel at inf.ethz.ch>, 2016
    Author:      Jasmin Blanchette <jasmin.blanchette at inria.fr>, 2016
    Maintainer:  Dmitriy Traytel <traytel at inf.ethz.ch>
*)

section ‹Nested Multisets›

theory Nested_Multiset
imports "HOL-Library.Multiset_Order"
begin

(* Use the multiset map-composition law as a simp rule and the multiset
   map congruence as a congruence rule for the rest of this theory. *)
declare multiset.map_comp [simp] and multiset.map_cong [cong]


subsection ‹Type Definition›

(* A nested multiset is either an atom (Elem) or a multiset whose
   elements are themselves nested multisets (MSet). *)
datatype 'a nmultiset = Elem 'a | MSet "'a nmultiset multiset"

(* no_elem X holds when X is built only from MSet constructors. The single
   introduction rule concludes no_elem (MSet M) from no_elem of every element
   of M, so no_elem never holds for an Elem leaf at any nesting level. *)
inductive no_elem :: "'a nmultiset ⇒ bool" where
  "(⋀X. X ∈# M ⟹ no_elem X) ⟹ no_elem (MSet M)"

(* Immediate-subterm relation: X is related to MSet M whenever X ∈# M.
   It is shown well-founded below and is used in the termination proof
   of less_nmultiset. *)
inductive_set sub_nmset :: "('a nmultiset × 'a nmultiset) set" where
  "X ∈# M ⟹ (X, MSet M) ∈ sub_nmset"

(* The immediate-subterm relation is well-founded. Proof: via wfUNIVI,
   reduce to structural induction on the nested multiset, discharging each
   case with the induction hypothesis IH. *)
lemma wf_sub_nmset[simp]: "wf sub_nmset"
proof (rule wfUNIVI)
  fix P :: "'a nmultiset ⇒ bool" and M :: "'a nmultiset"
  assume IH: "∀M. (∀N. (N, M) ∈ sub_nmset ⟶ P N) ⟶ P M"
  show "P M"
    by (induct M; rule IH[rule_format]) (auto simp: sub_nmset.simps)
qed

(* Nesting depth, written |X|. An Elem leaf has depth 0, and so does the
   empty MSet. A nonempty MSet has depth one more than the maximum depth of
   its elements. *)
primrec depth_nmset :: "'a nmultiset ⇒ nat" ("|_|") where
  "|Elem a| = 0"
| "|MSet M| = (let X = set_mset (image_mset depth_nmset M) in if X = {} then 0 else Suc (Max X))"

(* Every element of an MSet is strictly shallower than the MSet itself. *)
lemma depth_nmset_MSet: "x ∈# M ⟹ |x| < |MSet M|"
  by (auto simp: less_Suc_eq_le)

(* Take the MSet equation of depth_nmset (the let/Max form) out of the
   simpset. Later proofs unfold it explicitly via depth_nmset.simps(2)
   where it is needed. *)
declare depth_nmset.simps(2)[simp del]


subsection ‹Dershowitz and Manna's Nested Multiset Order›

text ‹The Dershowitz--Manna extension:›

(* Dershowitz--Manna extension of an arbitrary relation R to multisets.
   M is below N iff M is obtained from N by removing a nonempty
   sub-multiset X and adding a multiset Y, where every element of Y is
   R-below some element of X. The relation is a parameter, rather than the
   type's (<), so that less_nmultiset can be defined recursively through it.
   For R = (<) it coincides with the library order (see
   less_multiset_extDM_less). *)
definition less_multiset_extDM :: "('a ⇒ 'a ⇒ bool) ⇒ 'a multiset ⇒ 'a multiset ⇒ bool" where
  "less_multiset_extDM R M N ⟷
   (∃X Y. X ≠ {#} ∧ X ⊆# N ∧ M = (N - X) + Y ∧ (∀k. k ∈# Y ⟶ (∃a. a ∈# X ∧ R k a)))"

(* Soundness of the DM extension with respect to mult. If M and N range
   over A, then less_multiset_extDM R M N implies that (M, N) is in mult of
   R restricted to A. *)
lemma less_multiset_extDM_imp_mult:
  assumes
    N_A: "set_mset N ⊆ A" and M_A: "set_mset M ⊆ A" and less: "less_multiset_extDM R M N"
  shows "(M, N) ∈ mult {(x, y). x ∈ A ∧ y ∈ A ∧ R x y}"
proof -
  (* Extract the DM witnesses X (removed) and Y (added). *)
  from less obtain X Y where
    "X ≠ {#}" and "X ⊆# N" and "M = N - X + Y" and "∀k. k ∈# Y ⟶ (∃a. a ∈# X ∧ R k a)"
    unfolding less_multiset_extDM_def by blast
  (* One DM step is an instance of one_step_implies_mult. *)
  with N_A M_A have "(N - X + Y, N - X + X) ∈ mult {(x, y). x ∈ A ∧ y ∈ A ∧ R x y}"
    by (intro one_step_implies_mult, blast,
      metis (mono_tags, lifting) case_prodI mem_Collect_eq mset_subset_eqD mset_subset_eq_add_right
        subsetCE)
  (* N - X + X = N because X ⊆# N. *)
  with ‹M = N - X + Y› ‹X ⊆# N› show "(M, N) ∈ mult {(x, y). x ∈ A ∧ y ∈ A ∧ R x y}"
    by (simp add: subset_mset.diff_add)
qed

(* Completeness of the DM extension with respect to mult, assuming R is
   transitive on A. Proved by induction over mult, unfolded via mult_def
   and handled with base and step cases built from mult1 steps. *)
lemma mult_imp_less_multiset_extDM:
  assumes
    N_A: "set_mset N ⊆ A" and M_A: "set_mset M ⊆ A" and
    trans: "∀x ∈ A. ∀y ∈ A. ∀z ∈ A. R x y ⟶ R y z ⟶ R x z" and
    in_mult: "(M, N) ∈ mult {(x, y). x ∈ A ∧ y ∈ A ∧ R x y}"
  shows "less_multiset_extDM R M N"
  using in_mult N_A M_A unfolding mult_def less_multiset_extDM_def
proof induct
  case (base N)
  (* A single mult1 step replaces one element y by a multiset X of smaller
     elements; X = {#y#} is the removed part. *)
  then obtain y M0 X where "N = add_mset y M0" and "M = M0 + X" and "∀a. a ∈# X ⟶ R a y"
    unfolding mult1_def by auto
  thus ?case
    by (auto intro: exI[of _ "{#y#}"])
next
  case (step N N')
  note N_N'_in_mult1 = this(2) and ih = this(3) and N'_A = this(4) and M_A = this(5)

  have N_A: "set_mset N ⊆ A"
    using N_N'_in_mult1 N'_A unfolding mult1_def by auto

  (* DM witnesses for (M, N) from the induction hypothesis. *)
  obtain Y X where y_nemp: "Y ≠ {#}" and y_sub_N: "Y ⊆# N" and M_eq: "M = N - Y + X" and
    ex_y: "∀x. x ∈# X ⟶ (∃y. y ∈# Y ∧ R x y)"
    using ih[OF N_A M_A] by blast

  (* The last mult1 step: N' = M0 + {#z#} and N = M0 + Ya with Ya below z. *)
  obtain z M0 Ya where N'_eq: "N' = M0 + {#z#}" and N_eq: "N = M0 + Ya" and
    z_gt: "∀y. y ∈# Ya ⟶ y ∈ A ∧ z ∈ A ∧ R y z"
    using N_N'_in_mult1[unfolded mult1_def] by auto

  (* Candidate witnesses for (M, N'). *)
  let ?Za = "Y - Ya + {#z#}"
  let ?Xa = "X + Ya + (Y - Ya) - Y"

  have xa_sub_x_ya: "set_mset ?Xa ⊆ set_mset (X + Ya)"
    by (metis diff_subset_eq_self in_diffD subsetI subset_mset.diff_diff_right)

  have x_A: "set_mset X ⊆ A"
    using M_A M_eq by auto
  have ya_A: "set_mset Ya ⊆ A"
    by (simp add: subsetI z_gt)

  (* Every element of X + Ya is R-below some element of ?Za; transitivity
     on A is needed when the dominating element of Y lies in Ya. *)
  have ex_y': "∃y. y ∈# Y - Ya + {#z#} ∧ R x y" if x_in: "x ∈# X + Ya" for x
  proof (cases "x ∈# X")
    case True
    then obtain y where y_in: "y ∈# Y" and y_gt_x: "R x y"
      using ex_y by blast
    show ?thesis
    proof (cases "y ∈# Ya")
      case False
      hence "y ∈# Y - Ya + {#z#}"
        using y_in count_greater_zero_iff in_diff_count by fastforce
      thus ?thesis
        using y_gt_x by blast
    next
      case True
      hence "y ∈ A" and "z ∈ A" and "R y z"
        using z_gt by blast+
      hence "R x z"
        using trans y_gt_x x_A ya_A x_in by (meson subsetCE union_iff)
      thus ?thesis
        by auto
    qed
  next
    case False
    hence "x ∈# Ya"
      using x_in by auto
    hence "x ∈ A" and "z ∈ A" and "R x z"
      using z_gt by blast+
    thus ?thesis
      by auto
  qed

  show ?case
  proof (rule exI[of _ ?Za], rule exI[of _ ?Xa], intro conjI)
    show "Y - Ya + {#z#} ⊆# N'"
      using mset_subset_eq_mono_add subset_eq_diff_conv y_sub_N N_eq N'_eq
      by (simp add: subset_eq_diff_conv)
  next
    show "M = N' - (Y - Ya + {#z#}) + (X + Ya + (Y - Ya) - Y)"
      unfolding M_eq N_eq N'_eq by (auto simp: multiset_eq_iff)
  next
    show "∀x. x ∈# X + Ya + (Y - Ya) - Y ⟶ (∃y. y ∈# Y - Ya + {#z#} ∧ R x y)"
      using ex_y' xa_sub_x_ya by blast
  qed auto
qed

(* For R transitive on A and M, N over A, the DM extension coincides with
   mult of R restricted to A. *)
lemma less_multiset_extDM_iff_mult:
  assumes
    N_A: "set_mset N ⊆ A" and M_A: "set_mset M ⊆ A" and
    trans: "∀x ∈ A. ∀y ∈ A. ∀z ∈ A. R x y ⟶ R y z ⟶ R x z"
  shows "less_multiset_extDM R M N ⟷ (M, N) ∈ mult {(x, y). x ∈ A ∧ y ∈ A ∧ R x y}"
  using mult_imp_less_multiset_extDM[OF assms] less_multiset_extDM_imp_mult[OF N_A M_A] by blast

instantiation nmultiset :: (preorder) preorder
begin

(* Congruence rule for the function package. less_multiset_extDM R M N
   depends on R only at pairs (k, a) arising from a DM witness (X, Y)
   with k ∈# Y. *)
lemma less_multiset_extDM_cong[fundef_cong]:
  "(⋀X Y k a. X ≠ {#} ⟹ X ⊆# N ⟹ M = (N - X) + Y ⟹ k ∈# Y ⟹ R k a = S k a) ⟹
  less_multiset_extDM R M N = less_multiset_extDM S M N"
  unfolding less_multiset_extDM_def by metis

(* The nested order:
   - Elem values compare by the base order;
   - every Elem is below every MSet;
   - two MSets compare by the Dershowitz--Manna extension of
     less_nmultiset itself.
   Termination uses the lexicographic product of sub_nmset with itself. *)
function less_nmultiset :: "'a nmultiset ⇒ 'a nmultiset ⇒ bool" where
  "less_nmultiset (Elem a) (Elem b) ⟷ a < b"
| "less_nmultiset (Elem a) (MSet M) ⟷ True"
| "less_nmultiset (MSet M) (Elem a) ⟷ False"
| "less_nmultiset (MSet M) (MSet N) ⟷ less_multiset_extDM less_nmultiset M N"
  by pat_completeness auto
termination
  by (relation "sub_nmset <*lex*> sub_nmset", fastforce,
    metis sub_nmset.simps in_lex_prod mset_subset_eqD mset_subset_eq_add_right)

lemmas less_nmultiset_induct =
  less_nmultiset.induct[case_names Elem_Elem Elem_MSet MSet_Elem MSet_MSet]

lemmas less_nmultiset_cases =
  less_nmultiset.cases[case_names Elem_Elem Elem_MSet MSet_Elem MSet_MSet]

(* Transitivity, by strong induction on the largest depth of X, Y, Z.
   In the all-MSet case, transitivity of the elements (from the induction
   hypothesis) lets the DM extension be replaced by mult, which is
   transitive. *)
lemma trans_less_nmultiset: "X < Y ⟹ Y < Z ⟹ X < Z" for X Y Z :: "'a nmultiset"
proof (induct "Max {|X|, |Y|, |Z|}" arbitrary: X Y Z
    rule: less_induct)
  case less
  from less(2,3) show ?case
  proof (cases X; cases Y; cases Z)
    fix M N N' :: "'a nmultiset multiset"
    define A where "A = set_mset M ∪ set_mset N ∪ set_mset N'"
    assume XYZ: "X = MSet M" "Y = MSet N" "Z = MSet N'"
    then have trans: "∀x ∈ A. ∀y ∈ A. ∀z ∈ A. x < y ⟶ y < z ⟶ x < z"
      by (auto elim!: less(1)[rotated -1] dest!: depth_nmset_MSet simp add: A_def)
    have "set_mset M ⊆ A" "set_mset N ⊆ A" "set_mset N' ⊆ A"
      unfolding A_def by auto
    with less(2,3) XYZ show "X < Z"
      by (auto simp: less_multiset_extDM_iff_mult[OF _ _ trans] mult_def)
  qed (auto elim: less_trans)
qed

(* Irreflexivity, by structural induction. For MSet M, a DM witness with
   M = M - X + Y forces X = Y. The relation XY is finite, transitive and
   irreflexive (by the induction hypothesis), hence acyclic and therefore
   well-founded. That contradicts every element of Y being below some
   element of X = Y. *)
lemma irrefl_less_nmultiset:
  fixes X :: "'a nmultiset"
  shows "X < X ⟹ False"
proof (induct X)
  case (MSet M)
  from MSet(2) show ?case
  unfolding less_nmultiset.simps less_multiset_extDM_def
  proof safe
    fix X Y :: "'a nmultiset multiset"
    define XY where "XY = {(x, y). x ∈# X ∧ y ∈# Y ∧ y < x}"
    then have fin: "finite XY" and trans: "trans XY"
      by (auto simp: trans_def intro: trans_less_nmultiset
        finite_subset[OF _ finite_cartesian_product])
    assume "X ≠ {#}" "X ⊆# M" "M = M - X + Y"
    then have "X = Y"
      by (auto simp: mset_subset_eq_exists_conv)
    with MSet(1) ‹X ⊆# M› have "irrefl XY"
      unfolding XY_def by (force dest: mset_subset_eqD simp: irrefl_def)
    with trans have "acyclic XY"
      by (simp add: acyclic_irrefl)
    moreover
    assume "∀k. k ∈# Y ⟶ (∃a. a ∈# X ∧ k < a)"
    with ‹X = Y› ‹X ≠ {#}› have "¬ acyclic XY"
      by (intro notI, elim finite_acyclic_wf[OF fin, elim_format])
        (auto dest!: spec[of _ "set_mset Y"] simp: wf_eq_minimal XY_def)
    ultimately show False by blast
  qed
qed simp

(* Asymmetry follows from transitivity and irreflexivity. *)
lemma antisym_less_nmultiset:
  fixes X Y :: "'a nmultiset"
  shows "X < Y ⟹ Y < X ⟹ False"
  using trans_less_nmultiset irrefl_less_nmultiset by blast

(* The non-strict order is the reflexive closure of the strict one. *)
definition less_eq_nmultiset :: "'a nmultiset ⇒ 'a nmultiset ⇒ bool" where
  "less_eq_nmultiset X Y = (X < Y ∨ X = Y)"

instance
proof (intro_classes, goal_cases less_def refl trans)
  case (less_def x y)
  then show ?case
    unfolding less_eq_nmultiset_def by (metis irrefl_less_nmultiset antisym_less_nmultiset)
next
  case (refl x)
  then show ?case
    unfolding less_eq_nmultiset_def by blast
next
  case (trans x y z)
  then show ?case
    unfolding less_eq_nmultiset_def by (metis trans_less_nmultiset)
qed

(* With R = (<), the generic DM extension is the library's multiset order
   (less_multiset⇩D⇩M from HOL-Library.Multiset_Order). *)
lemma less_multiset_extDM_less: "less_multiset_extDM (<) = (<)"
  unfolding fun_eq_iff less_multiset_extDM_def less_multiset⇩D⇩M by blast

end

(* Partial order on nested multisets when the base type is ordered. Only
   antisymmetry remains to be shown: after unfolding less_eq_nmultiset_def,
   it follows from transitivity and irreflexivity of the strict order. *)
instantiation nmultiset :: (order) order
begin

instance
proof (intro_classes, goal_cases antisym)
  case (antisym x y)
  then show ?case
    unfolding less_eq_nmultiset_def by (metis trans_less_nmultiset irrefl_less_nmultiset)
qed

end

instantiation nmultiset :: (linorder) linorder
begin

(* Trichotomy for the strict order, by induction along the equations of
   less_nmultiset. In the MSet/MSet case, the order is rewritten to the
   library's less_multiset⇩H⇩O characterization of the multiset order. *)
lemma total_less_nmultiset:
  fixes X Y :: "'a nmultiset"
  shows "¬ X < Y ⟹ Y ≠ X ⟹ Y < X"
proof (induct X Y rule: less_nmultiset_induct)
  case (MSet_MSet M N)
  then show ?case
    unfolding nmultiset.inject less_nmultiset.simps less_multiset_extDM_less less_multiset⇩H⇩O
    by (metis add_diff_cancel_left' count_inI diff_add_zero in_diff_count less_imp_not_less
      mset_subset_eq_multiset_union_diff_commute subset_mset.refl)
qed auto

(* Totality of ≤ follows from trichotomy of <. *)
instance
proof (intro_classes, goal_cases total)
  case (total x y)
  then show ?case
    unfolding less_eq_nmultiset_def by (metis total_less_nmultiset)
qed

end

(* A strictly smaller depth implies strictly smaller in the nested order. *)
lemma less_depth_nmset_imp_less_nmultiset: "|X| < |Y| ⟹ X < Y"
proof (induct X Y rule: less_nmultiset_induct)
  case (MSet_MSet M N)
  then show ?case
  proof (cases "M = {#}")
    case False
    (* Take all of N as the removed part of the DM witness. *)
    with MSet_MSet show ?thesis
      by (auto 0 4 simp: depth_nmset.simps(2) less_multiset_extDM_def not_le Max_gr_iff
        intro: exI[of _ N] split: if_splits)
  qed (auto simp: depth_nmset.simps(2) less_multiset_extDM_less split: if_splits)
qed simp_all

(* Conversely, being strictly smaller in the nested order implies a depth
   that is at most as large. *)
lemma less_nmultiset_imp_le_depth_nmset: "X < Y ⟹ |X| ≤ |Y|"
proof (induct X Y rule: less_nmultiset_induct)
  case (MSet_MSet M N)
  then have "M < N" by (simp add: less_multiset_extDM_less)
  then show ?case
  proof (cases "M = {#}" "N = {#}" rule: bool.exhaust[case_product bool.exhaust])
    case [simp]: False_False
    show ?thesis
    unfolding depth_nmset.simps(2) Let_def False_False Suc_le_mono set_image_mset image_is_empty
      set_mset_eq_empty_iff if_False
    proof (intro iffD2[OF Max_le_iff] ballI iffD2[OF Max_ge_iff]; (elim imageE)?; simp)
      fix X
      assume [simp]: "X ∈# M"
      (* Every element of M is matched by an element of N at least as deep. *)
      with MSet_MSet(1)[of N M X, simplified] ‹M < N› show "∃Y∈#N. |X| ≤ |Y|"
        by (meson ex_gt_imp_less_multiset less_asym' less_depth_nmset_imp_less_nmultiset
          not_le_imp_less)
    qed
  qed (auto simp: depth_nmset.simps(2))
qed simp_all

(* If f X < f Y implies R X Y and R is antisymmetric, then R (as a set of
   pairs) equals the mlex product of the measure f with R restricted to
   pairs of equal f-value. *)
lemma eq_mlex_I:
  fixes f :: "'a ⇒ nat" and R :: "'a ⇒ 'a ⇒ bool"
  assumes "⋀X Y. f X < f Y ⟹ R X Y" and "antisymp R"
  shows "{(X, Y). R X Y} = f <*mlex*> {(X, Y). f X = f Y ∧ R X Y}"
proof safe
  fix X Y
  assume "R X Y"
  show "(X, Y) ∈ f <*mlex*> {(X, Y). f X = f Y ∧ R X Y}"
  proof (cases "f X" "f Y" rule: linorder_cases)
    case less
    with ‹R X Y› show ?thesis
      by (elim mlex_less)
  next
    case equal
    with ‹R X Y› show ?thesis
      by (intro mlex_leq) auto
  next
    (* f Y < f X gives R Y X, which with R X Y and antisymmetry contradicts
       f X ≠ f Y. *)
    case greater
    from ‹R X Y› assms(1)[OF greater] ‹antisymp R› greater show ?thesis
      unfolding antisymp_def by auto
  qed
next
  fix X Y
  assume "(X, Y) ∈ f <*mlex*> {(X, Y). f X = f Y ∧ R X Y}"
  then show "R X Y"
    unfolding mlex_prod_def by (auto simp: assms(1))
qed

instantiation nmultiset :: (wellorder) wellorder
begin

(* Depth 0 exactly characterizes Elem leaves and the empty MSet. *)
lemma depth_nmset_eq_0[simp]: "|X| = 0 ⟷ (X = MSet {#} ∨ (∃x. X = Elem x))"
  by (cases X; simp add: depth_nmset.simps(2))

(* Depth Suc n characterizes MSets whose elements all have depth ≤ n,
   with at least one element of depth exactly n. *)
lemma depth_nmset_eq_Suc[simp]: "|X| = Suc n ⟷
  (∃N. X = MSet N ∧ (∃Y ∈# N. |Y| = n) ∧ (∀Y ∈# N. |Y| ≤ n))"
  by (cases X; auto simp add: depth_nmset.simps(2) intro!: Max_eqI)
    (metis (no_types, lifting) Max_in finite_imageI finite_set_mset imageE image_is_empty
      set_mset_eq_empty_iff)

(* For each fixed depth i, < restricted to elements of depth i is
   well-founded. Proof by strong induction on i: the order on all elements
   of depth < i (set A) is well-founded, so mult over A is well-founded, and
   MSets of depth Suc n embed into it via the DM-to-mult lemma. *)
lemma wf_less_nmultiset_depth:
  "wf {(X :: 'a nmultiset, Y). |X| = i ∧ |Y| = i ∧ X < Y}"
proof (induct i rule: less_induct)
  case (less i)
  define A :: "'a nmultiset set" where "A = {X. |X| < i}"
  from less have "wf ((depth_nmset :: 'a nmultiset ⇒ nat) <*mlex*>
      (⋃j < i. {(X, Y). |X| = j ∧ |Y| = j ∧ X < Y}))"
    by (intro wf_UN wf_mlex) auto
  then have *: "wf (mult {(X :: 'a nmultiset, Y). X ∈ A ∧ Y ∈ A ∧ X < Y})"
    by (intro wf_mult, elim wf_subset) (force simp: A_def mlex_prod_def not_less_iff_gr_or_eq
      dest!: less_depth_nmset_imp_less_nmultiset)
  show ?case
  proof (cases i)
    case 0
    (* Depth 0: Elem leaves, ordered by the base well-order, plus MSet {#}
       above every Elem. *)
    then show ?thesis
      by (auto simp: inj_on_def intro!: wf_subset[OF
        wf_Un[OF wf_map_prod_image[OF wf, of Elem] wf_UN[of "Elem ` UNIV" "λx. {(x, MSet {#})}"]]])
  next
    case (Suc n)
    then show ?thesis
      by (intro wf_subset[OF wf_map_prod_image[OF *, of MSet]])
        (auto 0 4 simp: map_prod_def image_iff inj_on_def A_def
          dest!: less_multiset_extDM_imp_mult[of _ A, rotated -1] split: prod.splits)
  qed
qed

(* The whole order is well-founded: by eq_mlex_I it is the mlex product of
   depth with the union of the per-depth orders. *)
lemma wf_less_nmultiset: "wf {(X :: 'a nmultiset, Y :: 'a nmultiset). X < Y}" (is "wf ?R")
proof -
  have "?R = depth_nmset <*mlex*> {(X, Y). |X| = |Y| ∧ X < Y}"
    by (rule eq_mlex_I) (auto simp: antisymp_def less_depth_nmset_imp_less_nmultiset)
  also have "{(X, Y). |X| = |Y| ∧ X < Y} = (⋃i. {(X, Y). |X| = i ∧ |Y| = i ∧ X < Y})"
    by auto
  finally show ?thesis
    by (fastforce intro: wf_mlex wf_Union wf_less_nmultiset_depth)
qed

instance using wf_less_nmultiset unfolding wf_def mem_Collect_eq prod.case by intro_classes metis

end

end