(* Theory Card_Datatype *)

(*  Title:      Containers/Card_Datatype.thy
    Author:     Andreas Lochbihler, ETH Zurich *)

theory Card_Datatype
imports "HOL-Library.Cardinality"
begin

section ‹Definitions to prove equations about the cardinality of data types›

subsection ‹Specialised @{term range} constants›

(* rangeIt x f is the set of all iterates (f ^^ n) x for n :: nat,
   i.e. x, f x, f (f x), ... *)
definition rangeIt :: "'a ⇒ ('a ⇒ 'a) ⇒ 'a set"
where "rangeIt x f = range (λn. (f ^^ n) x)"

(* rangeC F is the union of the ranges of all functions in F
   (e.g. the set of all values built by a family of constructors). *)
definition rangeC :: "('a ⇒ 'b) set ⇒ 'b set"
where "rangeC F = (⋃f ∈ F. range f)"

(* If f is injective and x is never a value of f, then the iterates
   x, f x, f (f x), ... are pairwise distinct, so rangeIt x f is infinite. *)
lemma infinite_rangeIt: 
  assumes inj: "inj f"
  and x: "∀y. x ≠ f y"
  shows "¬ finite (rangeIt x f)"
proof -
  (* the iteration map n ↦ (f ^^ n) x is injective *)
  have "inj (λn. (f ^^ n) x)"
  proof(rule injI)
    fix n m
    assume "(f ^^ n) x = (f ^^ m) x"
    thus "n = m"
    proof(induct n arbitrary: m)
      case 0
      thus ?case using x by(cases m)(auto intro: sym)
    next
      case (Suc n)
      thus ?case using x by(cases m)(auto intro: sym dest: injD[OF inj])
    qed
  qed
  (* an injective image of the infinite type nat is infinite *)
  thus ?thesis
    by(auto simp add: rangeIt_def dest: finite_imageD)
qed

(* Introduction rule: every value of a function in A lies in rangeC A. *)
lemma in_rangeC: "f ∈ A ⟹ f x ∈ rangeC A"
by(auto simp add: rangeC_def)

(* Elimination rule: every element of rangeC A is f x for some f ∈ A. *)
lemma in_rangeCE: assumes "y ∈ rangeC A"
  obtains f x where "f ∈ A" "y = f x"
using assms by(auto simp add: rangeC_def)

(* Special case of in_rangeC for a singleton family. *)
lemma in_rangeC_singleton: "f x ∈ rangeC {f}"
by(auto simp add: rangeC_def)

(* A nullary constructor, encoded as a constant function, hits its value. *)
lemma in_rangeC_singleton_const: "x ∈ rangeC {λ_. x}"
by(rule in_rangeC_singleton)

(* Membership for nested rangeC, as arises for curried multi-argument
   constructors. *)
lemma rangeC_rangeC: "f ∈ rangeC A ⟹ f x ∈ rangeC (rangeC A)"
by(auto simp add: rangeC_def)

(* rangeC A is empty exactly when the family A is empty
   (every function has a nonempty range). *)
lemma rangeC_eq_empty: "rangeC A = {} ⟷ A = {}"
by(auto simp add: rangeC_def)

(* Bounded universal quantification over rangeC A ranges over all
   functions in A and all their arguments. *)
lemma Ball_rangeC_iff:
  "(∀x ∈ rangeC A. P x) ⟷ (∀f ∈ A. ∀x. P (f x))"
by(auto intro: in_rangeC elim: in_rangeCE)

(* Ball_rangeC_iff specialised to a singleton family. *)
lemma Ball_rangeC_singleton:
  "(∀x ∈ rangeC {f}. P x) ⟷ (∀x. P (f x))"
by(simp add: Ball_rangeC_iff)

(* Ball_rangeC_iff peeled one level for a nested rangeC. *)
lemma Ball_rangeC_rangeC:
  "(∀x ∈ rangeC (rangeC A). P x) ⟷ (∀f ∈ rangeC A. ∀x. P (f x))"
by(simp add: Ball_rangeC_iff)

(* For a family A of injective functions with pairwise disjoint ranges,
   rangeC A is finite iff A is finite and, when A is nonempty, the argument
   type 'a is finite.
   "⟹" by finite_psubset_induct on rangeC A: removing one function f from A
   strictly shrinks rangeC A, because range f is nonempty and disjoint from
   the ranges of the others. "⟸" is closed by the final qed(auto ...). *)
lemma finite_rangeC:
  assumes inj: "∀f ∈ A. inj f"
  and disjoint: "∀f ∈ A. ∀g ∈ A. f ≠ g ⟶ (∀x y. f x ≠ g y)"
  shows "finite (rangeC (A :: ('a ⇒ 'b) set)) ⟷ finite A ∧ (A ≠ {} ⟶ finite (UNIV :: 'a set))"
  (is "?lhs ⟷ ?rhs")
proof
  assume ?lhs
  thus ?rhs using inj disjoint
  proof(induct "rangeC A" arbitrary: A rule: finite_psubset_induct)
    case (psubset A)
    show ?case
    proof(cases "A = {}")
      case True thus ?thesis by simp
    next
      case False
      then obtain f A' where A: "A = insert f A'" and f: "f ∈ A" "f ∉ A'"
        by(fastforce dest: mk_disjoint_insert)
      from A have rA: "rangeC A = rangeC A' ∪ range f"
        by(auto simp add: rangeC_def)

      (* range f is not already covered by the remaining functions *)
      have "¬ range f ⊆ rangeC A'" 
      proof
        assume "range f ⊆ rangeC A'"
        moreover obtain x where x: "x ∈ range f" by auto
        ultimately have "x ∈ rangeC A'" by auto
        then obtain g where "g ∈ A'" "x ∈ range g" by(auto simp add: rangeC_def)
        with ‹f ∉ A'› A have "g ∈ A" "f ≠ g" by auto
        with ‹f ∈ A› have "∀x y. f x ≠ g y" by(rule psubset.prems[rule_format])
        thus False using x ‹x ∈ range g› by auto
      qed
      (* hence a strict subset, so the induction hypothesis applies to A' *)
      hence "rangeC A' ⊂ rangeC A" unfolding rA by auto
      hence "finite A' ∧ (A' ≠ {} ⟶ finite (UNIV :: 'a set))"
        using psubset.prems
        by -(erule psubset.hyps, auto simp add: A)
      with A have "finite A" by simp
      (* range f is an injective image of UNIV inside the finite rangeC A *)
      moreover with ‹finite (rangeC A)› A ‹∀f ∈ A. inj f›
      have "finite (UNIV :: 'a set)"
        by(auto simp add: rangeC_def dest: finite_imageD)
      ultimately show ?thesis by blast
    qed
  qed
qed(auto simp add: rangeC_def)

(* rangeC of a single constant function (a nullary constructor) is finite;
   its range is just the one constant value. *)
lemma finite_rangeC_singleton_const:
  "finite (rangeC {λ_. x})"
by(auto simp add: rangeC_def image_def)

(* Inclusion-exclusion for two finite sets, stated with natural-number
   subtraction (the subtrahend never exceeds the sum); derived from
   card_Un_Int. *)
lemma card_Un:
  "⟦ finite A; finite B ⟧ ⟹ card (A ∪ B) = card (A) + card (B) - card(A ∩ B)"
by(subst card_Un_Int) simp_all

(* rangeC of a single constant function has exactly one element. *)
lemma card_rangeC_singleton_const:
  "card (rangeC {λ_. f}) = 1"
by(simp add: rangeC_def image_def)

(* Under the same hypotheses as finite_rangeC (injective functions with
   pairwise disjoint ranges), card (rangeC A) = CARD('a) * card A.
   The infinite case reduces to finite_rangeC (both sides are 0); the
   finite case sums card (range f) = CARD('a) over A via card_UN_disjoint. *)
lemma card_rangeC:
  assumes inj: "∀f ∈ A. inj f"
  and disjoint: "∀f ∈ A. ∀g ∈ A. f ≠ g ⟶ (∀x y. f x ≠ g y)"
  shows "card (rangeC (A :: ('a ⇒ 'b) set)) = CARD('a) * card A"
  (is "?lhs = ?rhs")
proof(cases "finite (UNIV :: 'a set) ∧ finite A")
  case False
  thus ?thesis using False finite_rangeC[OF assms]
    by(auto simp add: card_eq_0_iff rangeC_eq_empty)
next
  case True
  { fix f
    assume "f ∈ A"
    hence "card (range f) = CARD('a)" using inj by(simp add: card_image) }
  thus ?thesis using disjoint True unfolding rangeC_def
    by(subst card_UN_disjoint) auto
qed

(* Two families whose functions never produce a common value have
   disjoint rangeC sets. *)
lemma rangeC_Int_rangeC:
  "⟦ ∀f ∈ A. ∀g ∈ B. ∀x y. f x ≠ g y ⟧ ⟹ rangeC A ∩ rangeC B = {}"
by(auto simp add: rangeC_def)

(* Collects the rangeC lemmas above into one fact list; it is activated
   as [simp] by the card_datatype bundle. Some entries (finite_rangeC,
   card_rangeC, rangeC_Int_rangeC) carry premises and act as conditional
   rewrite rules. *)
lemmas rangeC_simps =
  in_rangeC_singleton
  in_rangeC_singleton_const
  rangeC_rangeC
  rangeC_eq_empty
  Ball_rangeC_singleton
  Ball_rangeC_rangeC
  finite_rangeC
  finite_rangeC_singleton_const
  card_rangeC_singleton_const
  card_rangeC
  rangeC_Int_rangeC

(* Bundle of simp/cong declarations for cardinality proofs about datatypes:
   it adds rangeC_simps, card_Un, fun_eq_iff, the Int/Un distribution laws,
   card_eq_0_iff, imageI and infinite_rangeIt to the simpset, removes
   image_eqI from it, and declares conj_cong as a congruence rule. *)
bundle card_datatype =
  rangeC_simps [simp]
  card_Un [simp]
  fun_eq_iff [simp]
  Int_Un_distrib [simp]
  Int_Un_distrib2 [simp]
  card_eq_0_iff [simp]
  imageI [simp] image_eqI [simp del]
  conj_cong [cong]
  infinite_rangeIt [simp]

subsection ‹Cardinality primitives for polymorphic HOL types›

(* Declares the named theorem collection card_simps; the CARD_* lemmas
   below are registered in it via the [card_simps] attribute. *)
ML ‹
structure Card_Simp_Rules = Named_Thms
(
  val name = @{binding card_simps}
  val description = "Simplification rules for cardinality of types"
)
›
setup ‹Card_Simp_Rules.setup›

(* card_fun a b is the cardinality of the function type 'a ⇒ 'b given
   CARD('a) = a and CARD('b) = b: b ^ a when both are nonzero or b = 1,
   otherwise 0. This is the case split of card_fun used in the proof. *)
definition card_fun :: "nat ⇒ nat ⇒ nat"
where "card_fun a b = (if a ≠ 0 ∧ b ≠ 0 ∨ b = 1 then b ^ a else 0)"

lemma CARD_fun [card_simps]:
  "CARD('a ⇒ 'b) = card_fun CARD('a) CARD('b)"
by(simp add: card_fun card_fun_def)

(* card_sum a b is the cardinality of 'a + 'b: a + b when both are
   nonzero, otherwise 0 (as established via card_UNIV_sum). *)
definition card_sum :: "nat ⇒ nat ⇒ nat"
where "card_sum a b = (if a = 0 ∨ b = 0 then 0 else a + b)"

lemma CARD_sum [card_simps]:
  "CARD('a + 'b) = card_sum CARD('a) CARD('b)"
by(simp add: card_UNIV_sum card_sum_def)

(* card_option n is the cardinality of 'a option: Suc n (one extra element,
   None) when n is nonzero, otherwise 0. *)
definition card_option :: "nat ⇒ nat"
where "card_option n = (if n = 0 then 0 else Suc n)"

lemma CARD_option [card_simps]:
  "CARD('a option) = card_option CARD('a)"
by(simp add: card_option_def card_UNIV_option)

(* card_prod a b is the cardinality of 'a * 'b: the product a * b. *)
definition card_prod :: "nat ⇒ nat ⇒ nat"
where "card_prod a b = a * b"

lemma CARD_prod [card_simps]:
  "CARD('a * 'b) = card_prod CARD('a) CARD('b)"
by(simp add: card_prod_def)

(* card_list is constantly 0: the proof of CARD_list uses
   infinite_UNIV_listI, i.e. 'a list is an infinite type, whose CARD is 0. *)
definition card_list :: "nat ⇒ nat"
where "card_list _ = 0"

lemma CARD_list [card_simps]: "CARD('a list) = card_list CARD('a)"
by(simp add: card_list_def infinite_UNIV_listI)

end