(* Theory Enumeration *)

(*
 * Copyright 2020, Data61, CSIRO (ABN 41 687 119 230)
 *
 * SPDX-License-Identifier: BSD-2-Clause
 * Proofs tidied by LCP, 2024-09
 *)

section "Enumeration extensions and alternative definition"

theory Enumeration
imports Main
begin

(* Short names for the constants of the `enum` class, so the rest of this
   theory can write `enum`, `enum_all` and `enum_ex` unqualified. *)
abbreviation
  "enum ≡ enum_class.enum"
abbreviation
  "enum_all ≡ enum_class.enum_all"
abbreviation
  "enum_ex ≡ enum_class.enum_ex"

(* `the_index xs y` counts the elements of `xs` in front of the first
   occurrence of `y`.  Only the `#` case has an equation (hence
   `nonexhaustive`), so the value is unspecified once the list is exhausted
   without a match, i.e. when `y` does not occur in `xs`. *)
primrec (nonexhaustive)
  the_index :: "'a list ⇒ 'a ⇒ nat"
where
  "the_index (x # xs) y = (if x = y then 0 else Suc (the_index xs y))"

(* For a member of the list, `the_index` is a valid position in that list. *)
lemma the_index_bounded:
  "x ∈ set xs ⟹ the_index xs x < length xs"
  by (induct xs, clarsimp+)

(* Looking up a member's `the_index` in the list returns that member. *)
lemma nth_the_index:
  "x ∈ set xs ⟹ xs ! the_index xs x = x"
  by (induct xs, clarsimp+)

(* On a duplicate-free list, `the_index` inverts `!` at every valid position.
   Declared [simp]. *)
lemma distinct_the_index_is_index[simp]:
  "⟦ distinct xs ; n < length xs ⟧ ⟹ the_index xs (xs ! n) = n"
  by (meson nth_eq_iff_index_eq nth_mem nth_the_index the_index_bounded)

(* In a non-empty duplicate-free list, the last element sits at
   position `length xs - 1`. *)
lemma the_index_last_distinct:
  "distinct xs ∧ xs ≠ [] ⟹ the_index xs (last xs) = length xs - 1"
  by (simp add: last_conv_nth)

context enum begin

(* These two are added for historical reasons. *)
(* `enum_surj` is a second name for `enum_UNIV` and is declared [simp];
   `enum_distinct` is additionally declared [simp]. *)
lemmas enum_surj[simp] = enum_UNIV
declare enum_distinct[simp]

(* The enumeration list is never empty; derived from `enum_surj`.
   Declared [simp]. *)
lemma enum_nonempty[simp]: "(enum :: 'a list) ≠ []"
  using enum_surj by fastforce

(* `maxBound` is the last element of the enumeration list. *)
definition
  maxBound :: 'a where
  "maxBound ≡ last enum"

(* `minBound` is the first element of the enumeration list. *)
definition
  minBound :: 'a where
  "minBound ≡ hd enum"

(* `toEnum n` is the n-th element of the enumeration list.  For `n` outside
   the list the result is `the None`, i.e. an unspecified value. *)
definition
  toEnum :: "nat ⇒ 'a" where
  "toEnum n ≡ if n < length (enum :: 'a list) then enum ! n else the None"

(* `fromEnum x` is the position of `x` in the enumeration list, computed
   with `the_index`. *)
definition
  fromEnum :: "'a ⇒ nat" where
  "fromEnum x ≡ the_index enum x"


(* The index of `maxBound` is the last valid position of the enumeration
   list, `length enum - 1`. *)
lemma maxBound_is_length:
  "fromEnum maxBound = length (enum :: 'a list) - 1"
  by (simp add: maxBound_def fromEnum_def the_index_last_distinct)

(* Being at most the index of `maxBound` is the same as being a valid
   position in the enumeration list. *)
lemma maxBound_less_length:
  "(x ≤ fromEnum maxBound) = (x < length (enum :: 'a list))"
  unfolding maxBound_is_length by (cases "length enum") auto

(* No element has a larger index than `maxBound`.  Declared [simp]. *)
lemma maxBound_is_bound [simp]:
  "fromEnum x ≤ fromEnum maxBound"
  unfolding maxBound_less_length
  by (fastforce simp: fromEnum_def intro: the_index_bounded)

(* `toEnum` undoes `fromEnum`.  Declared [simp]. *)
lemma to_from_enum [simp]:
  fixes x :: 'a
  shows "toEnum (fromEnum x) = x"
proof -
  (* every value occurs in the enumeration list *)
  have "x ∈ set enum" by simp
  then show ?thesis by (simp add: toEnum_def fromEnum_def nth_the_index the_index_bounded)
qed

(* `fromEnum` undoes `toEnum` for every index up to that of `maxBound`.
   Declared [simp]. *)
lemma from_to_enum [simp]:
  "x ≤ fromEnum maxBound ⟹ fromEnum (toEnum x) = x"
  unfolding maxBound_less_length by (simp add: toEnum_def fromEnum_def)

(* Mapping `f` over the enumeration list and reading the result at the index
   of `x` gives `f x`. *)
lemma map_enum:
  fixes x :: 'a
  shows "map f enum ! fromEnum x = f x"
proof -
  (* first establish that the index of x lies inside the list *)
  have "fromEnum x ≤ fromEnum (maxBound :: 'a)"
    by (rule maxBound_is_bound)
  then have "fromEnum x < length (enum::'a list)"
    by (simp add: maxBound_less_length)
  then have "map f enum ! fromEnum x = f (enum ! fromEnum x)" by simp
  also
  have "x ∈ set enum" by simp
  then have "enum ! fromEnum x = x"
    by (simp add: fromEnum_def nth_the_index)
  finally
  show ?thesis .
qed

(* `assocs f` lists the pairs `(x, f x)` for `x` running through the
   enumeration list, in that order. *)
definition
  assocs :: "('a ⇒ 'b) ⇒ ('a × 'b) list" where
 "assocs f ≡ map (λx. (x, f x)) enum"

end

(* For historical naming reasons. *)
(* `enum_bool` is a second name for the fact `enum_bool_def`. *)
lemmas enum_bool = enum_bool_def

(* `True` has index 1 in the bool enumeration.  Declared [simp]. *)
lemma fromEnumTrue [simp]: "fromEnum True = 1"
  by (simp add: fromEnum_def enum_bool)

(* `False` has index 0 in the bool enumeration.  Declared [simp]. *)
lemma fromEnumFalse [simp]: "fromEnum False = 0"
  by (simp add: fromEnum_def enum_bool)


(* Syntactic class for the alternative enumeration: a partial map from
   naturals to values, with no axioms attached. *)
class enum_alt =
  fixes enum_alt :: "nat ⇒ 'a option"

(* Axioms for the alternative enumeration:
     - enum_alt_one_bound: once `enum_alt` returns None it stays None at the
       successor index;
     - enum_alt_surj: every value is `enum_alt` of some index;
     - enum_alt_inj: two indices with the same result are equal, unless
       that result is None. *)
class enumeration_alt = enum_alt +
  assumes enum_alt_one_bound:
    "enum_alt x = (None :: 'a option) ⟹ enum_alt (Suc x) = (None :: 'a option)"
  assumes enum_alt_surj:
    "range enum_alt ∪ {None} = UNIV"
  assumes enum_alt_inj:
    "(enum_alt x :: 'a option) = enum_alt y ⟹ (x = y) ∨ (enum_alt x = (None :: 'a option))"
begin

(* Injectivity with the None case excluded by an explicit assumption. *)
lemma enum_alt_inj_2:
  assumes "enum_alt x = (enum_alt y :: 'a option)"
          "enum_alt x ≠ (None :: 'a option)"
  shows "x = y"
proof -
  from assms
  have "(x = y) ∨ (enum_alt x = (None :: 'a option))" by (fastforce intro!: enum_alt_inj)
  with assms show ?thesis by clarsimp
qed

(* Surjectivity stated pointwise: every value has some index. *)
lemma enum_alt_surj_2:
  "∃x. enum_alt x = Some y"
proof -
  have "Some y ∈ range enum_alt ∪ {None}" by (subst enum_alt_surj) simp
  then have "Some y ∈ range enum_alt" by simp
  then show ?thesis by auto
qed

end

(* Turns a list into a partial index function: `Some (L ! n)` for positions
   inside the list, `None` beyond its end. *)
definition
  alt_from_ord :: "'a list ⇒ nat ⇒ 'a option"
where
  "alt_from_ord L ≡ λn. if (n < length L) then Some (L ! n) else None"

(*Seemingly redundant, but heavily used elsewhere*)
(* A conditional `Some A` / `None` equals `Some B` exactly when the
   condition holds and the two payloads agree. *)
lemma handy_if_lemma: "((if P then Some A else None) = Some B) = (P ∧ (A = B))"
  by simp

(* Types carrying both enumerations, with the requirement that `enum_alt`
   is exactly the index function `alt_from_ord` of the `enum` list. *)
class enumeration_both = enum_alt + enum +
  assumes enum_alt_rel: "enum_alt = alt_from_ord enum"

(* For the enumeration list of an `enum` type, `the_index` of any value is a
   valid position; obtained from `the_index_bounded`. *)
lemma the_index_less_length: "the_index (enum::'a::enum list) x < length (enum::'a::enum list)"
  by (rule the_index_bounded, simp)

(* With `e` abbreviating the enumeration list: if the bounded lookup at `x`
   yields `Some (e ! y)` and `y` is a valid position, then `x = y`. *)
lemma enum_if_enum:
  defines "(e::'a::enum list) ≡ enum"
  shows
    "(if x < length e then Some (e ! x) else None) = Some (e ! y) ⟹
           y < length e ⟹ x = y"
  by (simp add: e_def split: if_split_asm flip: nth_eq_iff_index_eq [where xs=e])

(* Subclass relation: every `enumeration_both` type is an `enumeration_alt`
   type.  The class axioms are discharged after rewriting `enum_alt` with
   `enum_alt_rel` and `alt_from_ord_def`. *)
instance enumeration_both < enumeration_alt
  apply (intro_classes)
    apply (simp_all add: enum_alt_rel alt_from_ord_def enum_if_enum split: if_split_asm)
  apply (safe; simp)[1]
  apply (intro rev_image_eqI; simp)
   apply (rule the_index_less_length)
  apply (subst nth_the_index; simp)
  done

(* bool as an `enumeration_both` type: `enum_alt` is the index function of
   the list [False, True]. *)
instantiation bool :: enumeration_both
begin
  definition enum_alt_bool: "enum_alt ≡ alt_from_ord [False, True]"
  instance by (intro_classes, simp add: enum_bool_def enum_alt_bool)
end

(* `toEnumAlt n` unwraps `enum_alt n` with `the`; where `enum_alt n` is
   None the result is an unspecified value. *)
definition
  toEnumAlt :: "nat ⇒ ('a :: enum_alt)" where
 "toEnumAlt n ≡ the (enum_alt n)"

(* `fromEnumAlt x` is the index `n` with `enum_alt n = Some x`, selected by
   definite description (THE). *)
definition
  fromEnumAlt :: "('a :: enum_alt) ⇒ nat" where
 "fromEnumAlt x ≡ THE n. enum_alt n = Some x"

(* `[n .e. m]` lists the values whose indices run from that of `n` up to and
   including that of `m`.  The mixfix annotation is a single outer
   cartouche wrapping the block specification and the template. *)
definition
  upto_enum :: "('a :: enumeration_alt) ⇒ 'a ⇒ 'a list"
    (‹(‹indent=1 notation=‹mixfix upto_enum››[_ .e. _])›)
  where "[n .e. m] ≡ map toEnumAlt [fromEnumAlt n ..< Suc (fromEnumAlt m)]"

(* On `enumeration_both` types `fromEnumAlt` coincides with `fromEnum`.
   Declared [simp], so `fromEnumAlt` is rewritten to `fromEnum` there. *)
lemma fromEnum_alt_red[simp]:
  "fromEnumAlt = (fromEnum :: ('a :: enumeration_both) ⇒ nat)"
  apply (rule ext)
  apply (simp add: fromEnumAlt_def fromEnum_def enum_alt_rel alt_from_ord_def)
  apply (rule theI2)
    apply (rule conjI)
     apply (clarify, rule nth_the_index)
    apply (auto simp: enum_if_enum the_index_less_length)
  done

(* On `enumeration_both` types `toEnumAlt` coincides with `toEnum`.
   Declared [simp]. *)
lemma toEnum_alt_red[simp]:
  "toEnumAlt = (toEnum :: nat ⇒ 'a :: enumeration_both)"
  by (rule ext) (simp add: enum_alt_rel alt_from_ord_def toEnum_def toEnumAlt_def)

(* `[n .e. m]` on `enumeration_both` types, expressed with `toEnum` and
   `fromEnum` instead of the Alt variants. *)
lemma upto_enum_red:               
  "[(n :: ('a :: enumeration_both)) .e. m] = map toEnum [fromEnum n ..< Suc (fromEnum m)]"
  unfolding upto_enum_def by simp

(* nat as an `enumeration_alt` type: every natural is its own index, so
   `enum_alt` is `Some`. *)
instantiation nat :: enumeration_alt
begin
  definition enum_alt_nat: "enum_alt ≡ Some"
  instance by (intro_classes; simp add: enum_alt_nat UNIV_option_conv)
end

(* For the nat instance (`enum_alt_nat`), `toEnumAlt` is the identity.
   Declared [simp]. *)
lemma toEnumAlt_nat[simp]: "toEnumAlt = id"
  by (rule ext) (simp add: toEnumAlt_def enum_alt_nat)

(* For the nat instance (`enum_alt_nat`), `fromEnumAlt` is the identity.
   Declared [simp]. *)
lemma fromEnumAlt_nat[simp]: "fromEnumAlt = id"
  by (rule ext) (simp add: fromEnumAlt_def enum_alt_nat)

(* The `.e.` range rewritten to the half-open interval `[n ..< Suc m]`,
   i.e. inclusive of `m`.  Declared [simp]. *)
lemma upto_enum_nat[simp]: "[n .e. m] = [n ..< Suc m]"
  by (subst upto_enum_def) simp

(* Pairs the elements of `L` with consecutive enumeration values starting at
   `x` (enumeration values on the left of each pair). *)
definition
  zipE1 :: "'a :: enum_alt ⇒ 'b list ⇒ ('a × 'b) list"
where
  "zipE1 x L ≡ zip (map toEnumAlt [fromEnumAlt x ..< fromEnumAlt x + length L]) L"

(* Like `zipE1`, but the enumeration values advance in steps of
   `fromEnumAlt xn - fromEnumAlt x` (natural-number subtraction): the n-th
   element of `L` is paired with the value at index
   `fromEnumAlt x + step * n`. *)
definition
  zipE2 :: "'a :: enum_alt ⇒ 'a ⇒ 'b list ⇒ ('a × 'b) list"
where
  "zipE2 x xn L ≡ zip (map (λn. toEnumAlt (fromEnumAlt x + (fromEnumAlt xn - fromEnumAlt x) * n))
                      [0 ..< length L]) L"

(* Pairs the elements of `L` with consecutive enumeration values starting at
   `x` (enumeration values on the right of each pair). *)
definition
  zipE3 :: "'a list ⇒ 'b :: enum_alt ⇒ ('a × 'b) list"
where
  "zipE3 L x ≡ zip L (map toEnumAlt [fromEnumAlt x ..< fromEnumAlt x + length L])"

(* Like `zipE3`, but the enumeration values advance in steps of
   `fromEnumAlt xn - fromEnumAlt x` (natural-number subtraction): the n-th
   element of `L` is paired with the value at index
   `fromEnumAlt x + step * n`. *)
definition
  zipE4 :: "'a list ⇒ 'b :: enum_alt ⇒ 'b ⇒ ('a × 'b) list"
where
  "zipE4 L x xn ≡ zip L (map (λn. toEnumAlt (fromEnumAlt x + (fromEnumAlt xn - fromEnumAlt x) * n))
                         [0 ..< length L])"


(* `toEnumAlt` undoes `fromEnumAlt` on `enumeration_alt` types.
   Declared [simp]. *)
lemma to_from_enum_alt[simp]:
  "toEnumAlt (fromEnumAlt x) = (x :: 'a :: enumeration_alt)"
proof -
  (* helper: unwrapping a known `Some b` with `the` gives `b` *)
  have rl: "⋀a b. a = Some b ⟹ the a = b" by simp
  show ?thesis
    unfolding fromEnumAlt_def toEnumAlt_def
    by (rule rl, rule theI') (metis enum_alt_inj enum_alt_surj_2 not_None_eq)
qed

(* A range from `x` to `x` is the singleton list.  Declared [simp]. *)
lemma upto_enum_triv [simp]: "[x .e. x] = [x]"
  unfolding upto_enum_def by simp

(* For an index within bounds, `toEnum n = v` is equivalent to
   `n = fromEnum v`. *)
lemma toEnum_eq_to_fromEnum_eq:
  fixes v :: "'a :: enum"
  shows "n ≤ fromEnum (maxBound :: 'a) ⟹ (toEnum n = v) = (n = fromEnum v)"
  by auto

(* Subtracting from a natural number keeps it below any upper bound it
   already had. *)
lemma le_imp_diff_le:
  "(j::nat) ≤ k ⟹ j - n ≤ k"
  by simp

(* The n-th element of `[start .e. end]` has index `fromEnum start + n`,
   for every position `n` inside that list. *)
lemma fromEnum_upto_nth:
  fixes start :: "'a :: enumeration_both"
  assumes "n < length [start .e. end]"
  shows "fromEnum ([start .e. end] ! n) = fromEnum start + n"
proof -
  (* helper: a bound on a difference survives enlarging the minuend *)
  have less_sub: "⋀m k m' n. ⟦ (n::nat) < m - k ; m ≤ m' ⟧ ⟹ n < m' - k" by fastforce
  (* upt_Suc is removed from the simpset for the final step *)
  note upt_Suc[simp del]
  show ?thesis using assms
  by (fastforce simp: upto_enum_red
                dest: less_sub[where m'="Suc (fromEnum maxBound)"] intro: maxBound_is_bound)
qed

(* A `.e.` range is never longer than `Suc (fromEnum maxBound)`. *)
lemma length_upto_enum_le_maxBound:
  fixes start :: "'a :: enumeration_both"
  shows "length [start .e. end] ≤ Suc (fromEnum (maxBound :: 'a))"
  by (simp add: le_imp_diff_le upto_enum_red)

(* Destruction rule: a valid position in a `.e.` range is at most the index
   of `maxBound`. *)
lemma less_length_upto_enum_maxBoundD:
  fixes start :: "'a :: enumeration_both"
  assumes "n < length [start .e. end]"
  shows "n ≤ fromEnum (maxBound :: 'a)"
  using assms
  by (simp add: upto_enum_red less_Suc_eq_le
                le_trans[OF _ le_imp_diff_le[OF maxBound_is_bound[of "end"]]]
           split: if_splits)

(* `fromEnum` is injective: equal indices mean equal values. *)
lemma fromEnum_eq_iff:
  "(fromEnum e = fromEnum f) = (e = f)"
proof -
  (* both values occur in the enumeration list, so nth_the_index applies *)
  have a: "e ∈ set enum" by auto
  have b: "f ∈ set enum" by auto
  from nth_the_index[OF a] nth_the_index[OF b] show ?thesis unfolding fromEnum_def by metis
qed

(* Variant of `maxBound_is_bound` for an index that is only known, through
   an equation, to be the `fromEnum` of some value. *)
lemma maxBound_is_bound':
  "i = fromEnum (e::('a::enum)) ⟹ i ≤ fromEnum (maxBound::('a::enum))"
  by clarsimp

end