(* Theory Ix *)
theory Ix
imports
Main
begin
subsection‹ A Haskell-like ∗‹Ix› class\label{sec:Ix} ›
text‹
We allow arbitrary indexing schemes for user-facing arrays via the
‹Ix› class, which, for each choice of bounds, essentially represents a
bijection between a finite subset of an arbitrary type and an initial
segment of the naturals.
Source materials:
▪ Haskell 2010 report: 🌐‹https://www.haskell.org/onlinereport/haskell2010/haskellch19.html›
▪ GHC implementation: 🌐‹https://hackage.haskell.org/package/base-4.16.0.0/docs/src/GHC.Ix.html›
▪ Haskell pure arrays (just for colour): 🌐‹https://www.haskell.org/onlinereport/haskell2010/haskellch14.html›
▪ SML 2D arrays: 🌐‹https://smlfamily.github.io/Basis/array2.html›
Observations:
▪ we follow the Haskell convention here: intervals include both bounds
▪ could alternatively use an array of one-dimensional arrays but those are not necessarily rectangular
▪ we can't use \<^class>‹enum› as that requires the whole type to be enumerable
Limitations:
▪ the basic design assumes lazy evaluation; we never want to actually build the list of indices
▪ can be improved either by tweaking the code generator setup or changing the constants here
▪ array indices typically have partial predecessor and successor operations and are totally ordered on their domain
▪ nothing guarantees that an instance's ‹interval› is the intended one (the class axioms do not prevent off-by-one errors in instances)
›
(* The ‹Ix› class.
   ‹interval b› lists the indices lying within the bounds ‹b›. The bounds are a
   pair; the instances in this theory read it as (lower, upper) and include both
   ends. ‹index b i› is the position of ‹i› in that list.
   Together, the two axioms make ‹index b› and ‹(!) (interval b)› mutually
   inverse between ‹set (interval b)› and ‹{0..<length (interval b)}›.
   Nothing is assumed about ‹index b i› when ‹i› is outside ‹interval b›. *)
class Ix =
fixes interval :: "'a × 'a ⇒ 'a list"
fixes index :: "'a × 'a ⇒ 'a ⇒ nat"
(* Looking up the index of an in-bounds element gives back that element. *)
assumes index: "i ∈ set (interval b) ⟹ interval b ! index b i = i"
(* The interval is numbered consecutively from 0, i.e. ‹index b (interval b ! k) = k›. *)
assumes interval: "map (index b) (interval b) = [0..<length (interval b)]"
(* Every in-bounds element has an index strictly below the length of the interval. *)
lemma index_length:
assumes "i ∈ set (interval b)"
shows "index b i < length (interval b)"
proof -
(* ‹i› occurs at some valid position ‹j› of the interval ... *)
from assms[unfolded in_set_conv_nth]
obtain j where "j < length (interval b)" and "interval b ! j = i"
by blast
(* ... and taking the ‹j›-th element of both sides of axiom ‹interval› yields ‹index b i = j›. *)
with arg_cong[where f="λx. List.nth x j", OF interval[of b]] show ?thesis
by simp
qed
(* Intervals contain no duplicates. By axiom ‹interval›, mapping ‹index b› over the
   interval yields the distinct list ‹[0..<n]›, and ‹distinct_map› then forces the
   interval itself to be distinct. *)
lemma distinct_interval:
shows "distinct (interval b)"
by (metis distinct_map distinct_upt interval)
(* ‹index b› is injective on the in-bounds elements. This uses the same argument as
   ‹distinct_interval›: the image list ‹[0..<n]› is distinct, so ‹distinct_map› gives
   injectivity. *)
lemma inj_on_index:
shows "inj_on (index b) (set (interval b))"
by (metis distinct_map distinct_upt interval)
(* For two in-bounds elements, equal indices coincide with equal elements.
   The forward direction follows from axiom ‹index›. *)
lemma index_eq_conv:
assumes "i ∈ set (interval b)"
assumes "j ∈ set (interval b)"
shows "index b i = index b j ⟷ i = j"
by (metis assms index)
(* Inverting ‹index b› on the interval sends every valid position
   (‹< length (interval b)›) back to an in-bounds element. *)
lemma index_inv_into:
assumes "i < length (interval b)"
shows "inv_into (set (interval b)) (index b) i ∈ set (interval b)"
by (metis assms add.left_neutral inv_into_into length_map list.set_map interval nth_mem nth_upt)
(* Comparing in-bounds elements by their ‹index› gives a linear (total) order
   on ‹set (interval b)›. *)
lemma linear_order_on:
shows "linear_order_on (set (interval b)) {(i, j). {i, j} ⊆ set (interval b) ∧ index b i ≤ index b j}"
by (force simp: linear_order_on_def partial_order_on_def preorder_on_def refl_on_def total_on_def
intro: transI antisymI
dest: index)
(* Mapping ‹f› over the interval by position (via ‹!›) is the same as mapping it over
   the interval directly. The proof uses only a list lemma, not the class axioms. *)
lemma interval_map:
shows "map (λi. f (interval b ! i)) [0..<length (interval b)] = map f (interval b)"
by (simp add: map_equality_iff)
(* Every position below the interval's length is the ‹index› of some in-bounds element.
   In other words, ‹index b› maps the interval onto ‹{0..<length (interval b)}›. *)
lemma index_forE:
assumes "i < length (interval b)"
obtains j where "j ∈ set (interval b)" and "index b j = i"
using assms index index_length nth_eq_iff_index_eq[OF distinct_interval] nth_mem[OF assms] by blast
(* ‹unit›: whatever the bounds, the interval is the single element ‹()›,
   and its index is always 0. *)
instantiation unit :: Ix
begin
definition "interval_unit = (λ(x::unit, y::unit). [()])"
definition "index_unit = (λ(x::unit, y::unit) _::unit. 0::nat)"
instance by standard (auto simp: interval_unit_def index_unit_def)
end
(* ‹nat›: for bounds ‹(l, u)› the interval is ‹[l..<Suc u]›, i.e. ‹l› to ‹u› inclusive
   (empty when ‹u < l›). The index of ‹i› is its offset ‹i - l›. Because subtraction on
   ‹nat› truncates, elements below ‹l› also get index 0; the class places no
   requirement on out-of-bounds elements. *)
instantiation nat :: Ix
begin
definition "interval_nat = (λ(l, u::nat). [l..<Suc u])"
definition "index_nat = (λ(l, u::nat) i::nat. i - l)"
(* Helper: subtracting ‹l› from every element of ‹[l..<u]› shifts the range to start at 0. *)
lemma upt_minus:
shows "map (λi. i - l) [l..<u] = [0..<u - l]"
by (induct u) (auto simp: Suc_diff_le)
instance by standard (auto simp: interval_nat_def index_nat_def upt_minus nth_append)
end
(* ‹int›: for bounds ‹(l, u)› the interval is ‹[l..u]›, i.e. ‹l› to ‹u› inclusive
   (empty when ‹u < l›). The index of ‹i› is ‹nat (i - l)›, so elements below ‹l›
   are sent to 0. *)
instantiation int :: Ix
begin
definition "interval_int = (λ(l, u::int). [l..u])"
definition "index_int = (λ(l, u::int) i::int. nat (i - l))"
(* Helper: offsets from ‹l› of ‹[l..u]› are exactly ‹[0..<nat (u - l + 1)]›.
   The proof is by induction on the interval's length. The step peels off the upper
   end ‹u› (‹upto_rec2›) and uses the hypothesis for ‹u - 1›. *)
lemma upto_minus:
shows "map (λi. nat (i - l)) [l..u] = [0..<nat (u - l + 1)]"
proof(induct "nat(u - l + 1)" arbitrary: u)
case (Suc i)
from Suc.hyps(1)[of "u - 1"] Suc.hyps(2) show ?case
by (simp add: upto_rec2 ac_simps Suc_nat_eq_nat_zadd1 flip: upt_Suc_append)
qed simp
instance by standard (auto simp: interval_int_def index_int_def upto_minus)
end
(* Bounds for two-dimensional indices: a pair of ‹'i × 'j› points. The ‹prod›
   instance of ‹Ix› pattern-matches it as ‹((l, l'), (u, u'))›. *)
type_synonym ('i, 'j) two_dim = "('i × 'j) × ('i × 'j)"
(* Product of two index types, laid out row-major.
   For bounds ‹((l, l'), (u, u'))›, the interval is the cartesian product of the two
   component intervals. The index of ‹(i, i')› is
   ‹index (l, u) i * length (interval (l', u')) + index (l', u') i'›,
   so the second component varies fastest. *)
instantiation prod :: (Ix, Ix) Ix
begin
definition "interval_prod = (λ((l, l'), (u, u')). List.product (interval (l, u)) (interval (l', u')))"
definition "index_prod = (λ((l, l'), (u, u')) (i, i'). index (l, u) i * length (interval (l', u')) + index (l', u') i')"
(* Projections of two-dimensional bounds onto the first and second components. *)
abbreviation (input) fst_bounds :: "('a × 'b) × ('a × 'b) ⇒ ('a × 'a)" where
"fst_bounds b ≡ (fst (fst b), fst (snd b))"
abbreviation (input) snd_bounds :: "('a × 'b) × ('a × 'b) ⇒ ('b × 'b)" where
"snd_bounds b ≡ (snd (fst b), snd (snd b))"
(* The row-major index is injective on the product interval. The components can be
   recovered by ‹div›/‹mod› by the length of the second interval. This lemma is not
   referenced by the instance proof below. *)
lemma inj_on_index_prod:
shows "inj_on (index ((l, l'), (u, u'))) (set (interval ((l, l'), (u, u'))))"
by (clarsimp simp: inj_on_def interval_prod_def index_prod_def)
(metis index index_length length_pos_if_in_set add_diff_cancel_right'
div_mult_self_is_m mod_less mod_mult_self3)
instance
proof
(* Axiom ‹index›: decoding the row-major index of an in-bounds pair gives the pair back. *)
show "interval b ! index b i = i" if "i ∈ set (interval b)" for b and i :: "'a × 'b"
proof -
(* Arithmetic helper: a row-major offset of in-range coordinates stays in range. *)
have *: "i * n + j < m * n" if "i < m" and "j < n"
for i j m n :: nat
using that by (metis bot_nat_0.extremum_strict div_less div_less_iff_less_mult div_mult_self3 nat_arith.rule0 not_gr_zero)
(* The computed index is below the size of the product interval ... *)
from that
have "index (fst_bounds b) (fst i) * length (interval (snd_bounds b))
+ index (snd_bounds b) (snd i)
< length (interval (fst_bounds b)) * length (interval (snd_bounds b))"
by (clarsimp simp: interval_prod_def index_prod_def * dest!: index_length)
(* ... so ‹List.product_nth› applies and splits it back into the component indices. *)
then show ?thesis
using that length_pos_if_in_set
by (fastforce simp: interval_prod_def index_prod_def List.product_nth index index_length)
qed
(* Axiom ‹interval›: compare both sides element by element. *)
show "map (index b) (interval b) = [0..<length (interval b)]" for b :: "('a × 'b) × ('a × 'b)"
by (rule iffD2[OF list_eq_iff_nth_eq])
(clarsimp simp: interval_prod_def index_prod_def split_def product_nth ac_simps;
metis (no_types, lifting) distinct_interval index index_length length_pos_if_in_set nth_mem
less_mult_imp_div_less mod_div_mult_eq mod_less_divisor mult.commute nth_eq_iff_index_eq)
qed
end
(* Facts and definitions under the mandatory ‹Ix› name space
   (e.g. ‹Ix.prod.interval_conv›, ‹Ix.square›). *)
setup ‹Sign.mandatory_path "Ix"›
setup ‹Sign.mandatory_path "prod"›
(* ‹Ix.prod.interval_conv›: a pair lies in a product interval exactly when each
   component lies in the corresponding component interval. *)
lemma interval_conv:
shows "(x, y) ∈ set (interval b) ⟷ x ∈ set (interval (fst_bounds b)) ∧ y ∈ set (interval (snd_bounds b))"
by (force simp: interval_prod_def)
setup ‹Sign.parent_path›
(* ‹Ix.square›: two-dimensional bounds whose two components share one index type. *)
type_synonym 'i square = "('i, 'i) two_dim"
(* ‹Ix.square b› holds when the first- and second-component intervals of ‹b› are equal. *)
definition square :: "'i::Ix Ix.square ⇒ bool" where
"square = (λ((l, l'), (u, u')). Ix.interval (l, u) = Ix.interval (l', u'))"
setup ‹Sign.mandatory_path "square"›
(* ‹Ix.square.conv›: for square bounds, an element lies in the first-component interval
   iff it lies in the second-component interval. *)
lemma conv:
assumes "Ix.square b"
shows "i ∈ set (Ix.interval (fst_bounds b)) ⟷ i ∈ set (Ix.interval (snd_bounds b))"
using assms by (clarsimp simp: Ix.square_def)
setup ‹Sign.parent_path›
setup ‹Sign.parent_path›
(* Hide the short names of the class operations. From here on they must be written
   qualified, as ‹Ix.interval› and ‹Ix.index›. *)
hide_const (open) interval index
end