Theory AutoCorres2.Tuple_Tools
chapter "Proof Tools"
theory Tuple_Tools
imports Main
"Misc_Antiquotation"
TermPatternAntiquote
ML_Fun_Cache
begin
section ‹Tools for handling Tuples›
text ‹
The tools are supposed to help simplifiying expressions containing @{const case_prod}
constructions, like @{term "λ(w, x, y, z). f w x y z"}.
▪ Simprocs for single step splitting and simplification of tuples / ‹case_prods› instead of repeated
application of the built-in variants for pairs
▪ Looper for the simplifier to instantiate @{term f} in @{term "f (a,b,c)"} with
@{term "λ(a,b,c). g a b c"}
›
subsection ‹Antiquotations for terms with schematic variables›
setup ‹
let
val parser = Args.context -- Scan.lift Parse.embedded_inner_syntax
fun pattern (ctxt, str) =
str |> Proof_Context.read_term_pattern ctxt
|> ML_Syntax.print_term
|> ML_Syntax.atomic
in
ML_Antiquotation.inline @{binding "pattern"} (parser >> pattern)
end
›
ML ‹@{pattern "λ(a,b,c). ?g a b c"}›
setup ‹
let
type style = term -> term;
fun pretty_term_style ctxt (style: style, t) =
Document_Output.pretty_term ctxt (style t);
val basic_entity = Document_Output.antiquotation_pretty_source;
in
(basic_entity \<^binding>‹pattern› (Term_Style.parse -- Args.term_pattern) pretty_term_style)
end
›
text ‹Unification only works modulo ordinary beta reduction but does not succeed on
tupled-beta reduction. For example the simplifier will not find an instatiation for the variable
@{pattern "?c"}
›
schematic_goal "⋀a b. (case (a, b) of (x, y) ⇒ f x y) ≡ ?c (a, b)"
apply simp
oops
text ‹Such equations can typically occur in congruence rules when a pair is splitted by
@{thm split_paired_all}.›
lemma tuple_up_eq_trivial: "f r ≡ (λx. f x) r"
by simp
thm tuple_up_eq_trivial
lemma tuple_up_eq2: "f (fst r) (snd r) ≡ (case r of (x1, x2) ⇒ f x1 x2)"
by (simp add: split_def)
text ‹Constant to explictly trigger splitting by simproc ▩‹SPLIT_simproc› below›
definition SPLIT :: "'a::{} ⇒ 'a"
where "SPLIT P ≡ P"
lemma SPLIT_cong: "PROP SPLIT P ≡ PROP SPLIT P"
by simp
lemma case_prod_out: "(λr. f ((λ(a,b). (g a b)) r)) = (λ(a, b). f (g a b))"
apply (rule ext)
apply (simp add: split_paired_all)
done
lemma case_prod_eta_contract: "(λx. (case_prod s) x) ≡ (case_prod s)"
by simp
definition ETA_TUPLED :: "('a ⇒ 'b) ⇒ ('a ⇒ 'b)"
where "ETA_TUPLED f ≡ f"
lemma ETA_TUPLED_trans: "f ≡ g ⟹ ETA_TUPLED f ≡ g"
by (simp add: ETA_TUPLED_def)
text ‹Add the congruence rule @{thm SPLIT_cong} to the simpset together with the simproc
to avoid descending into @{term P} before the split.›
ML ‹
structure Bin =
struct
type Bin = bool list
fun bin_of_int n =
if n < 0 then error ("bin: " ^ "cannot convert negativ number '" ^ string_of_int n ^ "'")
else map (fn 0 => false | _ => true) (radixpand (2,n));
fun int_of_bin bs =
let
fun int_of pos [] = 0
| int_of pos (false::bs) = int_of (2 * pos) bs
| int_of pos (true::bs) = pos + int_of (2 * pos) bs
in int_of 1 (rev bs) end
fun string_of_bin bs =
let
fun bit true = #"1"
| bit false = #"0"
in
"0b" ^ String.implode (map bit bs)
end
val bin_ord = list_ord bool_ord;
structure Bintab = Table(type key = Bin val ord = bin_ord);
val _ = @{assert} (int_of_bin (bin_of_int 1) = 1)
val _ = @{assert} (int_of_bin (bin_of_int 22) = 22)
val _ = @{assert} (int_of_bin (bin_of_int 43) = 43)
val _ = @{assert} (is_none (try bin_of_int ~1))
val _ = @{assert} (string_of_bin (bin_of_int 0) = "0b0")
val _ = @{assert} (string_of_bin (bin_of_int 1) = "0b1")
val _ = @{assert} (string_of_bin (bin_of_int 7) = "0b111")
val _ = @{assert} (string_of_bin (bin_of_int 23) = "0b10111")
end
›
ML ‹
structure Tuple_Tools = struct
fun assert_cterm' ctxt t ct =
let
val t1 = Envir.beta_eta_contract (Thm.term_of (Thm.cterm_of ctxt t))
val t2 = Envir.beta_eta_contract (Thm.term_of ct)
in @{assert} (Term.aconv_untyped (t1, t2)) end
val assert_cterm = assert_cterm' @{context}
fun assert_term t1 t2 = @{assert} (t1 = t2)
fun gen_mk_prod (t1, T1) (t2, T2) =
(HOLogic.pair_const T1 T2 $ t1 $ t2, HOLogic.mk_prodT (T1, T2))
fun gen_mk_tuple [] = (@{term "()"}, @{typ "unit"})
| gen_mk_tuple [(t, T)] = (t, T)
| gen_mk_tuple (t::ts) = gen_mk_prod t (gen_mk_tuple ts)
fun mk_elT_named n i = TFree (n ^ string_of_int i, @{sort "type"});
val mk_elT = mk_elT_named "'a"
fun mk_elT' Ts i = nth Ts (i-1);
fun mk_tupleT' Ts i = HOLogic.mk_tupleT (drop (i - 1) Ts);
fun mk_tupleT n i = mk_tupleT' (map mk_elT (1 upto n)) i;
fun mk_el_name_named x i = x ^ string_of_int i;
val mk_el_name = mk_el_name_named "x"
fun mk_el' Ts i = Free (mk_el_name i, mk_elT' Ts i);
fun mk_el_named' x Ts i = Free (mk_el_name_named x i, mk_elT' Ts i);
fun mk_el_named x i = Free (mk_el_name_named x i, mk_elT i);
fun mk_el i = Free (mk_el_name i, mk_elT i);
fun mk_tuple_bounds n = (1 upto n) |> map (fn i => (Bound (n - i), mk_elT i)) |> gen_mk_tuple |> fst
val mk_tuple_packed_name = "r";
fun mk_tuple n i = HOLogic.mk_tuple (map mk_el (i upto n));
fun mk_tuple_named x n i = HOLogic.mk_tuple (map (mk_el_named x) (i upto n));
fun mk_tuple' Ts n i = HOLogic.mk_tuple (map (mk_el' Ts) (i upto n))
fun mk_tuple_named' x Ts n i = HOLogic.mk_tuple (map (mk_el_named' x Ts) (i upto n))
fun mk_tuple_packed n = Free (mk_tuple_packed_name, mk_tupleT n 1);
fun mk_P n = Free ("P", map mk_elT (1 upto n) ---> @{typ "bool"});
fun mk_P_frees n = list_comb (mk_P n, map mk_el (1 upto n));
fun mk_P_bounds n = list_comb (mk_P n, map Bound (n-1 downto 0));
fun mk_Prop_P n = Free ("P", mk_tupleT n 1 --> @{typ "prop"});
fun mk_fun (name,resT) i n = Free (name, map mk_elT (i upto (i+n-1)) ---> resT);
fun mk_fun_bounds (name, resT) i n = list_comb (mk_fun (name, resT) i n, map Bound (n-1 downto 0));
fun mk_fun_with_types Ts (name,resT) = Free (name, map (mk_elT' Ts) (1 upto (length Ts)) ---> resT);
fun mk_fun_bounds_with_types Ts (name, resT) = list_comb (mk_fun_with_types Ts (name, resT), map Bound ((length Ts - 1) downto 0));
val mk_f_resT = @{typ "'b"};
val mk_f_name = "f";
fun mk_f n = Free (mk_f_name, map mk_elT (1 upto n) ---> mk_f_resT);
fun mk_f_frees n = list_comb (mk_f n, map mk_el (1 upto n));
fun mk_f_bounds n = list_comb (mk_f n, map Bound (n-1 downto 0));
fun mk_f_tupled n = Free (mk_f_name, mk_tupleT n 1 --> mk_f_resT)
fun mk_selT n i = mk_tupleT n 1 --> mk_elT i;
fun mk_selT' Ts i = mk_tupleT' Ts i --> mk_elT' Ts i;
fun mk_snd n i = Const (@{const_name ‹snd›}, mk_tupleT n i --> mk_tupleT n (i+1));
fun mk_snd' Ts i = Const (@{const_name ‹snd›}, mk_tupleT' Ts i --> mk_tupleT' Ts (i+1));
fun mk_fst n i = Const (@{const_name ‹fst›}, mk_tupleT n i --> mk_elT i);
fun mk_fst' Ts i = Const (@{const_name ‹fst›}, mk_tupleT' Ts i --> mk_elT' Ts i);
fun mk_snds' Ts i r =
let
val n = length Ts;
in
if n = 1 then r
else
if i = (n-1) then mk_snd' Ts 1 $ r
else mk_snd' Ts (n - i) $ (mk_snds' Ts (i+1) r)
end;
fun mk_snds n i r =
if i = (n-1) then mk_snd n 1 $ r
else mk_snd n (n - i) $ (mk_snds n (i+1) r)
fun mk_sel' Ts r i =
let
val n = length Ts;
in
if n = 1 then r
else
if i = 1 then mk_fst' Ts 1 $ r
else if i = n then mk_snds' Ts 1 r
else mk_fst' Ts i $ mk_snds' Ts (n - i + 1) r
end;
val _ = assert_term (mk_sel' ([@{typ 'a}, @{typ 'b}, @{typ 'c}, @{typ 'd}]) (Bound 0) 4)
(Const (@{const_name snd}, @{typ "'c × 'd ⇒ 'd"}) $
(Const (@{const_name snd}, @{typ "'b × 'c × 'd ⇒ 'c × 'd"}) $
(Const (@{const_name snd}, @{typ "'a × 'b × 'c × 'd ⇒ 'b × 'c × 'd"}) $ Bound 0)))
val _ = assert_term (mk_sel' ([@{typ 'a}, @{typ 'b}, @{typ 'c}, @{typ 'd}]) (Bound 0) 3)
(Const (@{const_name fst}, @{typ "'c × 'd ⇒ 'c"}) $
(Const (@{const_name snd}, @{typ "'b × 'c × 'd ⇒ 'c × 'd"}) $
(Const (@{const_name snd}, @{typ "'a × 'b × 'c × 'd ⇒ 'b × 'c × 'd"}) $ Bound 0)))
fun mk_sel n r i =
if i = 1 then mk_fst n 1 $ r
else if i = n then mk_snds n 1 r
else mk_fst n i $ mk_snds n (n - i + 1) r
fun mk_eq n r i =
HOLogic.mk_eq (mk_sel n r i, mk_el i);
fun mk_case_prod_from T t i n =
let
fun mk_cp n i =
if (n - i) <= 1 then
HOLogic.case_prod_const (mk_elT i, mk_elT n, T) $
(Abs (mk_el_name i, mk_elT i, Abs (mk_el_name n, mk_elT n, t)))
else
HOLogic.case_prod_const (mk_elT i, mk_tupleT n (i+1), T) $
(Abs (mk_el_name i, mk_elT i, mk_cp n (i+1)))
in mk_cp (n+i-1) i end;
fun mk_case_prod T t n =
mk_case_prod_from T t 1 n
fun mk_case_prod_with_types Ts T t =
let
val n = length Ts
fun mk_cp n i =
if (n - i) <= 1 then
HOLogic.case_prod_const (mk_elT' Ts i, mk_elT' Ts n, T) $
(Abs (mk_el_name i, mk_elT' Ts i, Abs (mk_el_name n, mk_elT' Ts n, t)))
else
HOLogic.case_prod_const (mk_elT' Ts i, mk_tupleT' Ts (i+1), T) $
(Abs (mk_el_name i, mk_elT' Ts i, mk_cp n (i+1)))
in mk_cp n 1 end;
fun mk_case_prod_P n = mk_case_prod @{typ bool} (mk_P_bounds n) n;
fun mk_case_prod_f n = mk_case_prod mk_f_resT (mk_f_bounds n) n;
fun mk_case_prod_fun (name, resT) i n = mk_case_prod_from resT (mk_fun_bounds (name, resT) i n) i n;
fun mk_case_prod_fun_with_types Ts (name, resT) = mk_case_prod_with_types Ts resT (mk_fun_bounds_with_types Ts (name, resT));
fun mk_case_prod_f_tupled_bounds n = mk_case_prod mk_f_resT (mk_f_tupled n $ mk_tuple_bounds n) n
fun argTs_of_TVar n [(TVar ((x,_),_))] = map (mk_elT_named x) (1 upto n)
| argTs_of_TVar _ Ts = Ts
fun mk_case_prod' name T n =
if n <= 1 then Free (name, T)
else
let
val (domT, rangeT) = dest_funT T
val argTs = HOLogic.strip_tupleT domT |> argTs_of_TVar n
in
if length argTs = n
then mk_case_prod_fun_with_types argTs (name, rangeT)
else mk_case_prod_fun (name, rangeT) 1 n
end
fun mk_tuple_from_type name T n =
if n <= 1 then Free (name, T)
else
let
val Ts = HOLogic.strip_tupleT T
in
if length Ts = n
then mk_tuple_named' name Ts n 1
else mk_tuple_named name n 1
end
fun mk_tuple_or_case_prod name T n =
if is_some (try dest_funT T)
then mk_case_prod' name T n
else mk_tuple_from_type name T n
fun mk_f_sels n r = list_comb (mk_f n, map (mk_sel n r) (1 upto n));
fun mk_tuple_up_eq n =
let
val r = mk_tuple_packed n;
val lhs = mk_f_sels n r;
val rhs = mk_case_prod_f n $ r;
in Logic.mk_equals (lhs, rhs) end
val _ = assert_cterm (mk_tuple_up_eq 3)
@{cterm "f (fst r) (fst (snd r)) (snd (snd r)) ≡ case r of (x1, x2, x3) ⇒ f x1 x2 x3"}
fun mk_tuple_up_eq_thm ctxt n =
let
val fixes = [mk_tuple_packed n,mk_f n]
|> map (fn (Free (x, _)) => x);
in
Goal.prove ctxt fixes [] (mk_tuple_up_eq n) (fn _ =>
simp_tac (put_simpset HOL_basic_ss ctxt addsimps [mk_meta_eq @{thm split_def}]) 1)
end
fun mk_tuple_case_eq n =
let
val lhs = mk_case_prod_f n $ mk_tuple n 1
val rhs = mk_f_frees n
in
Logic.mk_equals (lhs, rhs)
end
val _ = assert_cterm (mk_tuple_case_eq 3)
@{cterm "case (x1, x2, x3) of (x1, x2, x3) ⇒ f x1 x2 x3 ≡ f x1 x2 x3"}
fun mk_tuple_case_eq_thm ctxt n =
let
val fixes = mk_f n::map mk_el (1 upto n)
|> map (fn (Free (x, _)) => x)
in
Goal.prove ctxt fixes [] (mk_tuple_case_eq n) (fn _ =>
simp_tac (put_simpset HOL_basic_ss ctxt addsimps [@{thm Product_Type.prod.case}]) 1)
end
fun num_case_prod t =
case strip_comb t of
(Const (@{const_name case_prod},_) , (Abs (_, _, cp))::_) => 1 + num_case_prod cp
| (Const (@{const_name case_prod},_) , _) => 1
| _ => 0;
val strip_uu = prefix "strip_" Name.uu_
fun strip_case_prod t =
case strip_comb t of
(Const (@{const_name case_prod},_), Abs (x, xT, (Abs (y, yT, _)))::_) =>
[(x,xT), (y, yT)]
| (Const (@{const_name case_prod},cpT), Abs (x, xT, cp)::_) =>
(case strip_comb cp of
(Const (@{const_name case_prod},_), _) => (x,xT)::strip_case_prod cp
| _ =>
([(x,xT), (strip_uu, nth (binder_types cpT) 1)]))
| (Const (@{const_name case_prod}, cpT), cp::_) =>
let
val [xT, yT] = take 2 (binder_types (domain_type cpT))
in [(strip_uu, xT), (strip_uu, yT)] end
| _ => [];
fun strip_case_prod' (Abs (x, xT, _)) = [(x, xT)]
| strip_case_prod' t = strip_case_prod t
fun mk_split_tupled_all_lhs n =
Logic.list_all ([(mk_tuple_packed_name, mk_tupleT n 1)], (mk_Prop_P n) $ Bound 0)
fun mk_prod_bound (i,iT) (t,T) =
HOLogic.pair_const iT T $ Bound i $ t
fun mk_prod_bounds [(i, iT), (j, jT)] = mk_prod_bound (i,iT) (Bound j, jT)
| mk_prod_bounds ((i,iT)::xs) =
let
val t = mk_prod_bounds xs
in mk_prod_bound (i,iT) (t, fastype_of t) end
fun mk_split_tupled_all_rhs n =
let
val _ = @{assert} (n >= 2)
in
Logic.list_all (map (fn i => (mk_el_name i, mk_elT i)) (1 upto n),
(mk_Prop_P n) $ mk_prod_bounds (map (fn i => (n-i, mk_elT i)) (1 upto n)))
end
fun mk_split_tupled_all n =
Logic.mk_equals (mk_split_tupled_all_lhs n, mk_split_tupled_all_rhs n)
val _ = assert_cterm (mk_split_tupled_all 3)
@{cterm "(⋀r. PROP P r) ≡ (⋀x1 x2 x3. PROP P (x1, x2, x3))"}
fun mk_split_tupled_all_thm ctxt n =
Goal.prove ctxt ["P"] [] (mk_split_tupled_all n) (fn _ => simp_tac
(put_simpset HOL_basic_ss ctxt addsimps @{thms split_paired_all}) 1)
fun mk_eta_tupled_eq n =
let
val lhs = mk_f_tupled n
val rhs = mk_case_prod_f_tupled_bounds n
val eq = Logic.mk_equals (lhs, rhs)
in
eq
end
fun mk_eta_tupled_eq_thm ctxt n =
Goal.prove @{context} [mk_f_name] [] (mk_eta_tupled_eq n) (fn _ =>
simp_tac @{context} 1) |> Thm.transfer' ctxt
structure Data = Generic_Data(
type T = (thm * thm * thm * thm) list;
val empty = [];
fun merge (thms1, thms2) = if length thms1 >= length thms2 then thms1 else thms2;
);
fun get_tuple_up_eq_thm ctxt n =
if n <= 1 then @{thm tuple_up_eq_trivial}
else
Thm.transfer' ctxt (#1 (nth (Data.get (Context.Proof ctxt)) (n - 2)))
handle Subscript => mk_tuple_up_eq_thm ctxt n
fun get_split_tupled_all_thm ctxt n =
if n <= 1 then error ("get_split_tupled_all only makes sense for n >= 2")
else
Thm.transfer' ctxt (#2 (nth (Data.get (Context.Proof ctxt)) (n - 2)))
handle Subscript => mk_split_tupled_all_thm ctxt n
fun get_tuple_case_eq_thm ctxt n =
if n <= 1 then error ("get_tuple_case_eq_thm only makes sense for n >= 2")
else
Thm.transfer' ctxt (#3 (nth (Data.get (Context.Proof ctxt)) (n - 2)))
handle Subscript => mk_tuple_case_eq_thm ctxt n
fun get_eta_tupled_eq_thm ctxt n =
if n <= 1 then @{thm reflexive}
else
Thm.transfer' ctxt (#4 (nth (Data.get (Context.Proof ctxt)) (n - 2)))
handle Subscript => mk_eta_tupled_eq_thm ctxt n
fun liberal_zip (x::xs) (y::ys) = (x,y)::liberal_zip xs ys
| liberal_zip _ [] = []
| liberal_zip [] _ = []
fun get_first_subterm P t =
if P (Term.head_of t)
then SOME t
else
case t of
(t1 $ t2) => (case get_first_subterm P t1 of
NONE => get_first_subterm P t2
| some => some)
| Abs (_, _ , t) => get_first_subterm P t
| _ => NONE
val split_tupled_all_simproc =
Simplifier.make_simproc @{context}{name = "split_tupled_all_simproc", kind = Simproc, identifier = [],
lhss = [Proof_Context.read_term_pattern @{context} "⋀r. PROP ?P r"],
proc = fn _ => fn ctxt => fn ct =>
let
fun get_tuple_arity (Const (@{const_name "Pure.all"},_) $ (Abs (_, T, _))) =
length (HOLogic.flatten_tupleT T)
| get_tuple_arity _ = 0;
val t = Thm.term_of ct
val n = get_tuple_arity t
in
if n >= 2
then
let
fun is_case_prod (Const (@{const_name case_prod}, _)) = true
| is_case_prod _ = false
fun guess_names t = case get_first_subterm is_case_prod t of
SOME t' => map fst (strip_case_prod t') |> filter_out (fn n => n = strip_uu)
| _ => []
val thm = get_split_tupled_all_thm ctxt n
|> Drule.rename_bvars (liberal_zip (map mk_el_name (1 upto n)) (guess_names t))
in SOME thm end
else NONE
end
}
val tuple_case_simproc =
Simplifier.make_simproc @{context} {name = "tuple_simproc", kind = Simproc, identifier = [],
lhss = [Proof_Context.read_term_pattern @{context} "case_prod ?X (?x,?y)"],
proc = fn _ => fn ctxt => fn ct =>
let
fun get_tuple_arity
(t as (Const (@{const_name Product_Type.prod.case_prod}, _) $ _ $
(p as (Const (@{const_name Pair}, _) $ _ $ _)))) =
Int.min (num_case_prod t + 1, length (HOLogic.strip_tuple p))
| get_tuple_arity _ = 0;
val n = get_tuple_arity (Thm.term_of ct)
in
if n >= 2
then SOME (get_tuple_case_eq_thm ctxt n)
else NONE
end
}
local
fun strip_all (Const (@{const_name "Pure.all"}, _ ) $ Abs (x,T,t)) =
let
val (vTs, bdy) = strip_all t
in ((x,T)::vTs, bdy) end
| strip_all t = ([], t)
fun dest_pair_bound (Const (@{const_name "Product_Type.Pair"}, _) $ Bound i $ p) =
i::dest_pair_bound p
| dest_pair_bound (Bound i) = [i]
fun lookup env i = nth env i
fun mk_name env (Bound i) = fst (lookup env i)
| mk_name env (Var ((x,_),_)) = x
| mk_name env (Free (x,_)) = x
| mk_name env _ = Name.uu_;
fun mk_var env (Bound i) = SOME (lookup env i)
| mk_var _ _ = NONE
fun is_Pair (Const (@{const_name "Product_Type.Pair"}, _) $ _ $ _) = true
| is_Pair _ = false;
fun dest_tuple env (Const (@{const_name "Product_Type.Pair"}, _) $ t $ p)
= mk_name env t :: (if is_Pair p then dest_tuple env p else [mk_name env p])
fun dest_tuple' env (Const (@{const_name "Product_Type.Pair"}, _) $ t $ p)
= mk_var env t :: (if is_Pair p then dest_tuple' env p else [mk_var env p])
fun mk_case_prod' T seed [(n1,T1), (n2, T2)] =
(Const (@{const_name "case_prod"}, (T1 --> T2 --> T) --> HOLogic.mk_prodT (T1, T2) --> T)$
Abs (n1, T1, Abs (n2, T2, seed)), HOLogic.mk_prodT (T1, T2))
| mk_case_prod' T seed ((n1,T1)::bs) =
let
val (t',T2) = mk_case_prod' T seed bs
in (Const (@{const_name "case_prod"}, (T1 --> T2 --> T) --> HOLogic.mk_prodT (T1, T2) --> T)$
Abs (n1, T1, t'), HOLogic.mk_prodT (T1, T2))
end
fun get_tuple_var env (Var v $
(p as (Const (@{const_name "Product_Type.Pair"}, _)$ _$ _))) =
[(env, v, dest_tuple env p)]
| get_tuple_var env (t1 $ t2) =
(case get_tuple_var env t1 of
[] => get_tuple_var env t2
| xs => xs)
| get_tuple_var env (Abs (x, T, t)) =
get_tuple_var ((x,T)::env) t
| get_tuple_var env _ = [];
fun dest_tuple'' env t = case try (dest_tuple' env) t of
SOME xs => xs
| NONE => [NONE]
fun flatten_upto_first_tuple xxs =
let
val (upto_one, more) = chop_prefix (fn xs => length xs <= 1) xxs
in
if null more then ([],[])
else (flat (upto_one), hd more)
end
fun get_tuple_vars env (t as (_ $ _)) =
(case strip_comb t of
(Var v, ts) =>
let
val (args, tuple_args) = map (dest_tuple'' env) ts |> flatten_upto_first_tuple
val rest = map (get_tuple_vars env) ts |> flat
in
if null tuple_args then rest
else (env, v, (args, tuple_args)) :: rest
end
| (t, ts) => map (get_tuple_vars env) (t::ts) |> flat)
| get_tuple_vars env (Abs (x, T, t)) =
get_tuple_vars ((x,T)::env) t
| get_tuple_vars _ _ = [];
fun get_distinct_tuple_vars env t =
t |> get_tuple_vars env |> distinct (fn ((_,v1,_),(_,v2,_)) => v1 = v2)
in
val list_abs = fold_rev Term.abs
fun map_filter2 _ [] [] = []
| map_filter2 f (x :: xs) (y :: ys) =
let
val vs = map_filter2 f xs ys
in case f x y of SOME v => v :: vs | _ => vs end
| map_filter2 _ _ _ = raise ListPair.UnequalLengths;
fun gen_calc_inst bound get_tuple_vars ctxt max_idx t = \<^try>‹
get_tuple_vars [] t
|> map (fn (_,((x,i), pT), (args, bs)) =>
let
val (all_argTs, finalT) = strip_type pT
val (argTs, ([tupleT], rest_argTs)) = all_argTs |> chop (length args) ||> chop 1
fun mk_vars bound vs Ts = (vs ~~ Ts)
|> map (fn (SOME (x, _), T) => (true, (x, T))
| (NONE, T) => if bound then (false, (Name.uu_, T)) else (true, (Name.uu_, T)))
fun filter_tagged tags xs = map_filter2 (fn b => fn x => if b then SOME x else NONE) tags xs
val tagged_args = mk_vars false args argTs
val args = map snd tagged_args
val tagged_tuple_args = mk_vars bound bs (HOLogic.flatten_tupleT tupleT)
val tuple_args = map snd tagged_tuple_args
val tagged_vars = tagged_args @ tagged_tuple_args
val tags = map fst tagged_vars
val vars = map snd tagged_vars
val bnds = length vars - 1 downto 0 |> filter_tagged tags
val Ts = map snd vars |> filter_tagged tags
val rT = rest_argTs ---> finalT
val newVar = Var ((x,max_idx+1), Ts ---> rT)
val seed = Term.list_comb (newVar, map Bound bnds)
val case_prod = fst (mk_case_prod' rT seed tuple_args)
val inst = list_abs args case_prod
in ((x,i), Thm.cterm_of ctxt inst) end
) catch _ => []›
fun calc_first_inst bound = gen_calc_inst bound (fn env => take 1 o get_tuple_vars env)
val calc_inst = calc_first_inst true
val calc_inst' = calc_first_inst false
val calc_insts = gen_calc_inst false get_distinct_tuple_vars
end
fun cond_trace ctxt s =
if Config.get ctxt Simplifier.simp_trace then tracing s else ()
fun tuple_inst_tac ctxt i = SUBGOAL (fn (trm, _) => fn st =>
let
val max_idx = Thm.maxidx_of_cterm (Thm.cprop_of st)
in
case calc_inst ctxt max_idx trm of
[] => (cond_trace ctxt "tuple_inst_tac: no instantiation found"; no_tac st)
| inst => let
val _ = cond_trace ctxt ("tuple_inst_tac: " ^ @{make_string} inst);
in PRIMITIVE (Drule.infer_instantiate ctxt inst) st end
end) i;
fun tuple_insts ctxt thm =
let
val max_idx = Thm.maxidx_of_cterm (Thm.cprop_of thm)
in
case calc_insts ctxt max_idx (Thm.prop_of thm) of
[] => thm
| inst => let
val _ = cond_trace ctxt ("tuple_inst_tac: " ^ @{make_string} inst);
in Drule.infer_instantiate ctxt inst thm end
end;
fun is_case_prod_app_Pair t =
case Term.strip_comb t of
(Const (@{const_name "case_prod"}, _), _ :: maybe_pair :: _) => (case Term.strip_comb maybe_pair of
(Const (@{const_name "Pair"}, _), _) => true | _ => false)
| _ => false
fun mksimps ctxt thm =
let
val thm' = (if exists_subterm is_case_prod_app_Pair (Thm.prop_of thm)
then
let
val thm' = Simplifier.simplify (put_simpset HOL_ss ctxt addsimps @{thms Product_Type.prod.case}) thm
val _ = cond_trace ctxt ("Tuple_Tools.mksimps: " ^ @{make_string} thm')
in thm' end
else thm)
in (Simpdata.mksimps Simpdata.mksimps_pairs) ctxt thm'
end
val SPLIT_simproc =
Simplifier.make_simproc @{context} {name = "SPLIT_simproc", kind = Simproc, identifier = [],
lhss = [Proof_Context.read_term_pattern @{context} "PROP SPLIT ?P",
Proof_Context.read_term_pattern @{context} "SPLIT ?P"],
proc = fn _ => fn ctxt => fn ct =>
let
fun get_tuple_arity (Const (@{const_name "SPLIT"}, _)$
(Const (@{const_name "Pure.all"},_) $ (Abs (_, T, _)))) =
length (HOLogic.flatten_tupleT T)
| get_tuple_arity _ = 0;
val t = Thm.term_of ct
val n = get_tuple_arity t
in
if n >= 2
then
let
fun is_case_prod (Const (@{const_name case_prod}, _)) = true
| is_case_prod _ = false
fun guess_names t = case get_first_subterm is_case_prod t of
SOME t' => map fst (strip_case_prod t') |> filter_out (fn n => n = strip_uu)
| _ => []
val split_thm = get_split_tupled_all_thm ctxt n
|> Drule.rename_bvars (liberal_zip (map mk_el_name (1 upto n)) (guess_names t))
val case_eq_thm = get_tuple_case_eq_thm ctxt n
val conv = Conv.rewr_conv @{thm SPLIT_def}
then_conv Conv.rewr_conv split_thm
then_conv (Conv.bottom_conv (K (Conv.try_conv (Conv.rewr_conv case_eq_thm))) ctxt)
in SOME (conv ct) end
else SOME @{thm SPLIT_def}
end
}
fun asserts [] x = x
| asserts ((cond, msg)::xs) x = if cond x then asserts xs x else error (msg x)
structure Intlisttab = Table(type key = int list val ord = list_ord int_ord);
fun comb_product [] = [[]]
| comb_product (xs:: yss) =
let
val prods = comb_product yss
val prodss = map (fn x => map (fn ps => x::ps) prods) xs
in
flat prodss
end
val _ = @{assert} (comb_product [[1,2],[3,4],[5,6]] =
[[1, 3, 5], [1, 3, 6], [1, 4, 5], [1, 4, 6], [2, 3, 5], [2, 3, 6], [2, 4, 5], [2, 4, 6]])
fun split_rule_simproc ctxt name pattern rule =
let
val pat = Proof_Context.read_term_pattern ctxt pattern
fun dest_eq thm =
let
val @{term_pat "Trueprop (?lhs = ?rhs)"} = thm |> Thm.concl_of
in (lhs, rhs) end
val lhs = rule |> dest_eq |> fst
val (tenv, term_env) = Pattern.match (Proof_Context.theory_of ctxt) (pat, lhs) (Vartab.empty, Vartab.empty)
fun get_vars env (Abs (x, T, t)) = get_vars (env @ [(x,T)]) t
| get_vars env t =
(case Term.strip_comb t of
(Var x, args) => (env, x, args) :: flat (map (get_vars env) args)
| (_, args) => flat (map (get_vars env) args))
fun msg s x = "split_rule_simproc: " ^ s
fun is_dummy ((n, _)) = (n = "_dummy_")
val rule_subterms = Vartab.dest term_env |> filter_out (is_dummy o fst) |> map (snd o snd)
val rule_vars = flat (map (get_vars []) rule_subterms)
|> map (asserts [
(fn (env, var, args) => not (null env), fn _ => "matched subterm for splitting has to be abstraction, use dummy pattern '_' to skip subterms" ),
(fn (env, var, args) => not (null args), fn (env, var, args) =>
("Variable '" ^ Term.string_of_vname' (fst var) ^
"' has to be applied to bound variable, about to be splitted")),
(fn (env, var, args) => hd args = Bound (length env - 1), fn (env, var, args) =>
("first argument of variable '" ^ Term.string_of_vname' (fst var) ^
"'expected to be bound variable '" ^ (fst (hd env)) ^ "'"))
])
|> map (#2)
fun split_rule arities =
let
val (names, ctxt') = Variable.add_fixes (map (fst o fst) rule_vars) ctxt
fun mk_inst ((var, T), (n, name)) i = ((var, Thm.cterm_of ctxt (mk_case_prod_fun (name, range_type T) i n)), i+n)
val insts = fold_map mk_inst (rule_vars ~~ (arities ~~ names)) 1 |> fst
val [rule'] = Drule.infer_instantiate ctxt' insts rule
|> Simplifier.simplify (put_simpset HOL_basic_ss ctxt' addsimps @{thms case_prod_out})
|> mk_meta_eq
|> single
|> Proof_Context.export ctxt' ctxt
in
rule'
end
val (split_rule, handler) = Fun_Cache.create_handler (Binding.map_name (fn name => name ^ "-cache") name)
(fn tab => @{make_string} (Intlisttab.dest tab))
Intlisttab.empty Intlisttab.lookup Intlisttab.update split_rule
fun split_rule_names arity_names =
let
val arities = map fst arity_names
val names = map snd arity_names
val rule = split_rule arities
fun sum ns = fold (fn n => fn m => n + m) ns 0
val bound_names = map (mk_el_name) (1 upto (sum arities))
val renamings = (bound_names ~~ flat (names)) |> filter_out (fn (_, n) => n = strip_uu)
in
Drule.rename_bvars renamings rule
end
fun trace_cache_info () =
tracing (@{make_string} handler)
val _ = map split_rule (comb_product (map (K (1 upto 3)) rule_vars))
val _ = trace_cache_info ()
in
Simplifier.make_simproc ctxt {name = Binding.name_of name, kind = Simproc, identifier = [],
lhss = [pat],
proc = fn _ => fn ctxt => fn ct =>
let
val t = Thm.term_of ct
val (tenv, term_env) = Pattern.match (Proof_Context.theory_of ctxt) (pat, t) (Vartab.empty, Vartab.empty)
val subterms = Vartab.dest term_env |> filter_out (is_dummy o fst) |> map (snd o snd)
val arity_names = map ((fn vars => (length vars, map fst vars)) o strip_case_prod) subterms
val _ = asserts [(fn _ => length arity_names = length rule_vars, fn _ => "split_rule_simproc: unexpected number of matched subterms")]
in
if exists (fn (n,_) => n > 1) arity_names
then
let
val splitted_rule = split_rule_names arity_names
in SOME splitted_rule end
else NONE
end
}
end
fun tuple_inst_simp_tac ctxt i =
tuple_inst_tac ctxt i THEN
simp_tac (put_simpset HOL_ss ctxt addsimps @{thms Product_Type.prod.case}) i
fun tuple_refl_tac ctxt i =
resolve_tac ctxt @{thms refl} i
ORELSE (
tuple_inst_simp_tac ctxt i THEN
resolve_tac ctxt @{thms refl} i
)
fun split_rule_simprocs ctxt = map (fn (name, rule, pattern) => split_rule_simproc ctxt name pattern rule)
fun gen_split_rule ctxt0 name_arities thm =
let
val ctxt = Variable.set_body false ctxt0
val case_prod_eta_contract_thm = @{thm case_prod_eta_contract}
val case_prod_conv = Conv.bottom_conv (
K (Conv.try_conv (Conv.rewr_conv case_prod_eta_contract_thm)))
val name_arities = name_arities |> sort_by fst
val names = map fst name_arities
val arities = map snd name_arities
fun eta_contract ctxt thm = Conv.fconv_rule (case_prod_conv ctxt) thm
val vars = Term.add_vars (Thm.prop_of thm) [] |>
filter (fn ((n,_) ,_) => member (op =) names n) |> distinct (op =) |> sort_by (fst o fst)
val names' = map (fst o fst) vars
val insts = ((names' ~~ arities) ~~ vars)
|> map (fn ((f, n), (var,T)) =>
(var, mk_tuple_or_case_prod f T n))
val new_names = [] |> fold Term.add_frees (map snd insts) |> map fst
val (new_names', ctxt') = Variable.add_fixes new_names ctxt
val _ = @{assert} (new_names = new_names')
val (_, ctxt'') = ctxt' |> Variable.import_terms false (map snd insts)
|> apsnd (Context_Position.set_visible false)
val insts = map (apsnd (Thm.cterm_of ctxt'')) insts
val split_arities = arities |> filter (fn n => n > 1)
val splits = map (get_split_tupled_all_thm ctxt) split_arities
val case_eqs = map (get_tuple_case_eq_thm ctxt) split_arities
in
thm
|> Drule.infer_instantiate ctxt'' insts
|> Simplifier.asm_full_simplify (put_simpset HOL_basic_ss ctxt'' addsimps
(@{thms case_prod_out} @ case_eqs @ splits))
|> eta_contract ctxt''
|> single
|> Proof_Context.export ctxt'' ctxt
|> hd
|> Drule.zero_var_indexes
end
fun split_rule ctxt names thm n =
if n <= 1 then thm
else gen_split_rule ctxt (map (fn name => (name, n)) names) thm
fun split_rule_bin ctxt names thm bin =
split_rule ctxt names thm (Bin.int_of_bin bin)
val case_prod_conv = mk_meta_eq @{thm case_prod_conv}
fun print_conv msg (ct:cterm) =
let
val _ = tracing (msg ^ ": " ^ @{make_string} ct);
in Conv.all_conv ct end
fun bounded_top_conv i conv ctxt ct =
if i <= 0 then Conv.all_conv ct
else
(conv ctxt then_conv Conv.sub_conv (bounded_top_conv (i - 1) conv) ctxt) ct;
fun bounded_bottom_conv i conv ctxt ct =
if i <= 0 then Conv.all_conv ct
else
(Conv.sub_conv (bounded_bottom_conv (i - 1) conv) ctxt then_conv conv ctxt) ct;
fun bounded_top_rewrs_conv i rewrs = bounded_top_conv i (K (Conv.try_conv (Conv.rewrs_conv rewrs)));
fun bounded_bottom_rewrs_conv i rewrs = bounded_bottom_conv i (K
(Conv.try_conv (Conv.rewrs_conv rewrs)));
fun eta_expand_tupled_conv ctxt ct =
let
val arity_of = HOLogic.flatten_tupleT #> length
val domT = Thm.typ_of_cterm ct |> domain_type
val arity = domT |> arity_of
val name_types = strip_case_prod' (Thm.term_of ct)
val case_prod_arity = length name_types
val renamings = map (fn (n, T) => if arity_of T = 1 then SOME n else NONE) name_types
val base_conv =
if case_prod_arity > 0 andalso case_prod_arity < arity then
let
val eta_contract_rule = get_eta_tupled_eq_thm ctxt case_prod_arity
|> Thm.symmetric
in fn ct => Conv.rewr_conv eta_contract_rule ct handle CTERM _ => Conv.all_conv ct end
else Conv.all_conv
val conv =
if arity > 1 andalso case_prod_arity < arity then
base_conv then_conv
Conv.rewr_conv (Drule.rename_bvars' renamings (get_eta_tupled_eq_thm ctxt arity)) then_conv
bounded_bottom_rewrs_conv (arity * 2) [case_prod_conv] ctxt
else Conv.all_conv
in
conv ct
end
handle Match => Conv.all_conv ct
fun eta_expand_tuple ctxt = eta_expand_tupled_conv ctxt #> Thm.rhs_of
end
›
simproc_setup ETA_TUPLED (‹ETA_TUPLED f›) = ‹fn _ => fn ctxt => fn ct =>
let
val arity = Thm.typ_of_cterm ct |> domain_type |> HOLogic.flatten_tupleT |> length
in
if arity <= 1 then
SOME @{thm ETA_TUPLED_def}
else
SOME (@{thm ETA_TUPLED_trans} OF [Tuple_Tools.get_eta_tupled_eq_thm ctxt arity])
end›
declare [[simproc del: ETA_TUPLED]]
text ‹Set up the theorem cache.
We could use "dynamic programming" here and rewrite "thm n" with one step of "thm (n - 1)" and @{thm split_def}›
setup ‹
Context.theory_map (Tuple_Tools.Data.map (fn _ => map (fn i =>
(Thm.trim_context (Tuple_Tools.mk_tuple_up_eq_thm @{context} i),
Thm.trim_context (Tuple_Tools.mk_split_tupled_all_thm @{context} i),
Thm.trim_context (Tuple_Tools.mk_tuple_case_eq_thm @{context} i),
Thm.trim_context (Tuple_Tools.mk_eta_tupled_eq_thm @{context} i))) (2 upto 50)))
›
attribute_setup split_tuple = ‹
let
val parse_arity = Args.$$$ "arity" |-- Args.colon |-- Parse.nat
val parse_names = Parse.and_list1 Parse.short_ident
in
Scan.lift (parse_names -- parse_arity) >> (fn (names, arity) =>
let
in
Thm.rule_attribute [] (fn context => fn thm =>
Tuple_Tools.split_rule (Context.proof_of context) names thm arity)
end)
end
›
thm HOL.ext [split_tuple f and g arity: 3]
ML_val ‹
val test = @{pattern "⋀a b c s. f a b c ≡ ?c (a,b,k) s"}
val x = Tuple_Tools.calc_inst' @{context} 2 test
›
text ‹Especially in the case of congruence rules we prefer the more restrictive instantiation, where
the freshly introduced variable does not depend on 'non-bound' tuple components ›
ML_val ‹
val test = @{pattern "⋀a b c s. f a b c ≡ ?c (a,b,k) s"}
val x = Tuple_Tools.calc_inst @{context} 2 test
›
ML_val ‹
val thm = Tuple_Tools.get_tuple_up_eq_thm @{context} 1
›
ML_val ‹
val thm = Tuple_Tools.get_tuple_up_eq_thm @{context} 0
›
ML_val ‹
val thm = Tuple_Tools.get_tuple_up_eq_thm @{context} 3
›
ML_val ‹
val thm = Tuple_Tools.get_split_tupled_all_thm @{context} 30
›
ML_val ‹
val thm = Tuple_Tools.get_split_tupled_all_thm @{context} 2
›
ML_val ‹
val thm = Tuple_Tools.get_tuple_case_eq_thm @{context} 2
›
ML_val ‹
val thm = Tuple_Tools.get_eta_tupled_eq_thm @{context} 2
›
declare [[simp_trace=false]]
text ‹Variable names should be preserved here, since all abstractions stay in place.›
lemma "⋀r. (λ(x,y). (x::nat) < y ∧ y < 200) r"
apply (tactic ‹
simp_tac (put_simpset HOL_basic_ss @{context} addsimprocs [Tuple_Tools.split_tupled_all_simproc, Tuple_Tools.tuple_case_simproc]
delsimps @{thms Product_Type.prod.case Product_Type.case_prod_conv}) 1
›)
oops
text ‹Variable names are not preserved due to eta contraction:
@{lemma "(λ(x,y). (x::nat) < y ) = case_prod (<)" by (rule refl)}›
lemma "⋀r. (λ(x,y). (x::nat) < y ) r"
apply (tactic ‹
simp_tac (put_simpset HOL_basic_ss @{context} addsimprocs [Tuple_Tools.split_tupled_all_simproc, Tuple_Tools.tuple_case_simproc]
delsimps @{thms Product_Type.prod.case Product_Type.case_prod_conv}) 1
›)
oops
lemma "(λ(x,y,p,z,f). (x::nat) < y) (a,b,c) = (case c of (p,z,f) ⇒ a < b)"
apply (tactic ‹
simp_tac (put_simpset HOL_basic_ss @{context} addsimprocs [Tuple_Tools.split_tupled_all_simproc, Tuple_Tools.tuple_case_simproc]
delsimps @{thms Product_Type.prod.case Product_Type.case_prod_conv}) 1
›)
done
lemma "⋀r. (λ(x,y,p,z,f). (p::nat) < f) r"
apply (tactic ‹
simp_tac (put_simpset HOL_basic_ss @{context} addsimprocs [Tuple_Tools.split_tupled_all_simproc, Tuple_Tools.tuple_case_simproc]
delsimps @{thms Product_Type.prod.case Product_Type.case_prod_conv}) 1
›)
oops
lemma "(λ(x,y,p). (x::nat) < y) (a,b,c,d,e) = (a < b)"
apply (tactic ‹
simp_tac (put_simpset HOL_basic_ss @{context} addsimprocs [Tuple_Tools.split_tupled_all_simproc, Tuple_Tools.tuple_case_simproc]
delsimps @{thms Product_Type.prod.case Product_Type.case_prod_conv}) 1
›)
done
ML_val ‹
val test = @{pattern "⋀a b c d e s. f a b c ≡ ?c a b (c,d,e) s"}
val x = Tuple_Tools.calc_inst @{context} 2 test
›
ML_val ‹
val test = @{pattern "⋀a b c s. P x ⟹ f a b c ≡ ?c (a,b,c) s"}
val x = Tuple_Tools.calc_inst @{context} 2 test
›
ML_val ‹
val test = @{pattern "⋀a b c s. P x =simp=> f a b c ≡ ?c (a,b,c) s"}
val x = Tuple_Tools.calc_inst @{context} 2 test
›
ML_val ‹
val test = @{pattern "⋀a b c s. f a b c ≡ ?c (a,b,k) s"}
val x = Tuple_Tools.calc_inst @{context} 2 test
›
ML_val ‹
val test = @{pattern "⋀a b c s. f a b c ≡ ?c (a,b,k) s"}
val x = Tuple_Tools.calc_inst @{context} 2 test
›
ML_val ‹
val test = @{pattern "⋀a b c s. f a b c ≡ ?c (a,x+y,x + y + z) s"}
val x = Tuple_Tools.calc_inst @{context} 2 test
›
ML_val ‹
val test = @{pattern "⋀a b c s. f a b c ≡ ?c (a,x+y,x + y + z) s"}
val x = Tuple_Tools.calc_inst @{context} 2 test
›
ML_val ‹
val test = @{pattern "(⋀a b c s. P x ⟹ f a b c ≡ ?c (a,b,c) s) ⟹ (⋀a b c s. Q x)"}
val x = Tuple_Tools.calc_inst @{context} 2 test
›
lemma "SPLIT (⋀r. ((λz. Q z ∧ (Q z ⟶ P z)) r))
≡ XXX"
apply (tactic ‹
simp_tac ( @{context}
addsimprocs [Tuple_Tools.SPLIT_simproc, Tuple_Tools.tuple_case_simproc]
|> Simplifier.add_cong @{thm SPLIT_cong}
|> Simplifier.add_cong @{thm conj_cong}) 1
›)
oops
lemma "PROP SPLIT (⋀r. ((λ(x,y,z). y < z ∧ z=s) r) ⟹ P r)
≡ (⋀x y z. y < z ∧ z = s ⟹ P (x, y, s))"
apply (tactic ‹
asm_full_simp_tac (put_simpset HOL_basic_ss @{context}
addsimprocs [Tuple_Tools.SPLIT_simproc, Tuple_Tools.tuple_case_simproc] |> Simplifier.add_cong @{thm SPLIT_cong}) 1
›)
done
schematic_goal "ETA_TUPLED (λx::('a ×'b × 'c). f x) = ?XXX"
supply [[simp_trace]]
thm Product_Type.cond_case_prod_eta
supply [[simproc add: ETA_TUPLED]]
apply (simp)
done
context
fixes f::"('a × 'b × 'c × 'd) ⇒ nat"
fixes c::"('a × 'b × 'c × 'd) ⇒ 's ⇒ bool"
begin
ML_val ‹
val t1 = Tuple_Tools.eta_expand_tupled_conv @{context} @{cterm "f"}
val t2 = Tuple_Tools.eta_expand_tupled_conv @{context} @{cterm "λ(a, b). f (a, b)"}
val t3 = Tuple_Tools.eta_expand_tupled_conv @{context} @{cterm "λ(a, b, c). f (a, b, c)"}
val c1 = Tuple_Tools.eta_expand_tupled_conv @{context} @{cterm "c"}
val c2 = Tuple_Tools.eta_expand_tupled_conv @{context} @{cterm "λ(a, b) s. c (a,b) s"}
val d1 = Tuple_Tools.eta_expand_tupled_conv @{context} @{cterm "λ(a, (b::nat × nat)) s. a < 2"}
val d2 = Tuple_Tools.eta_expand_tupled_conv @{context} @{cterm "λ(a, (b::nat × nat × nat)) s. a < 2"}
val d3 = Tuple_Tools.eta_expand_tupled_conv @{context} @{cterm "λ(a, (b::nat × nat × nat × nat)) s. a < 2"}
›
end
end