File ‹Tools/primrec_package.ML›
(*Interface of the primitive recursion package: two entry points that differ
  only in whether equations and attributes are still unparsed source.*)
signature PRIMREC_PACKAGE =
sig
(*External variant: each equation is given as a source string and its
  attributes as Token.src; both are read in the given theory before the
  definition is made.  Returns the theorems stored as "simps" together with
  the updated theory.*)
val primrec: ((binding * string) * Token.src list) list -> theory -> thm list * theory
(*Internal variant: equations are already terms and attributes are already
  internalized.  Returns the theorems stored as "simps" together with the
  updated theory.*)
val primrec_i: ((binding * term) * attribute list) list -> theory -> thm list * theory
end;
structure PrimrecPackage : PRIMREC_PACKAGE =
struct
(*Internal failure signal raised while analysing a single equation; the
  handler at the end of process_eqn turns it into a user-level error that
  also prints the offending equation.*)
exception RecError of string;
(*Split an object-level equation proposition into (lhs, rhs): first strip the
  judgment wrapper, then destruct the equality.*)
fun dest_eqn prop = FOLogic.dest_eq (prop |> \<^dest_judgment>);
(*Abort with a user error whose text is prefixed by the package banner.*)
fun primrec_err msg =
  let val text = "Primrec definition error:\n" ^ msg
  in error text end;
(*Abort with a package error that additionally prints the equation in which
  the problem was detected, pretty-printed in the given theory.*)
fun primrec_eq_err sign s eq =
  let
    val shown_eq = Syntax.string_of_term_global sign eq;
  in
    primrec_err (s ^ "\nin equation\n" ^ shown_eq)
  end;
(*Analyse one equation of the form

      f ls (C cargs) rs = rhs

  where ls/rs are runs of free variables and C is a known constructor, and
  merge it into the accumulated description rec_fn_opt of the function being
  defined (NONE for the first equation processed).

  The result is SOME (fname, ftype, ls, rs, con_info, eqns), where eqns maps
  each constructor name seen so far to (rhs, cargs, original equation).

  Every check raises RecError internally; the final handler converts that
  into a user error mentioning the equation.  The checks are performed in a
  fixed order, which determines the message reported when several apply.*)
fun process_eqn thy (eq, rec_fn_opt) =
  let
    fun fail msg = raise RecError msg;
    fun reject bad msg = if bad then fail msg else ();

    (*schematic variables are rejected before the equation is taken apart*)
    val _ = reject (not (null (Term.add_vars eq []))) "illegal schematic variable(s)";
    val (lhs, rhs) = dest_eqn eq handle TERM _ => fail "not a proper equation";

    val (head, args) = strip_comb lhs;
    val (fname, ftype) = dest_Const head
      handle TERM _ => fail "function is not declared as constant in theory";

    (*args = leading frees @ middle @ trailing frees*)
    val (left_frees, remainder) = chop_prefix is_Free args;
    val (middle, right_frees) = chop_suffix is_Free remainder;

    val (constr, constr_arg_tms) =
      (case middle of
        [] => fail "constructor missing"
      | pat :: _ => strip_comb pat);
    val cname = dest_Const_name constr handle TERM _ => fail "ill-formed constructor";

    val con_info =
      (case Symtab.lookup (ConstructorsData.get thy) cname of
        SOME info => info
      | NONE => fail "cannot determine datatype associated with function");

    val (ls, cargs, rs) =
      (map dest_Free left_frees, map dest_Free constr_arg_tms, map dest_Free right_frees)
        handle TERM _ => fail "illegal argument in pattern";

    val pattern_vars = ls @ rs @ cargs;
    val new_eqn = (cname, (rhs, cargs, eq));

    val _ = reject (has_duplicates (op =) pattern_vars) "repeated variable name in pattern";
    val _ =
      reject (not (subset (op =) (Term.add_frees rhs [], pattern_vars)))
        "extra variables on rhs";
    val _ = reject (length middle > 1) "more than one non-variable in pattern";
  in
    (case rec_fn_opt of
      NONE => SOME (fname, ftype, ls, rs, con_info, [new_eqn])
    | SOME (fname', _, ls', rs', con_info': constructor_info, eqns) =>
       (reject (AList.defined (op =) eqns cname)
          "constructor already occurred as pattern";
        reject (ls <> ls' orelse rs <> rs')
          "non-recursive arguments are inconsistent";
        reject (#big_rec_name con_info <> #big_rec_name con_info')
          ("Mixed datatypes for function " ^ fname);
        reject (fname <> fname')
          ("inconsistent functions for datatype " ^ #big_rec_name con_info);
        SOME (fname, ftype, ls, rs, con_info, new_eqn :: eqns)))
  end
  handle RecError s => primrec_eq_err thy s eq;
(*Instantiate the right-hand side of a recursor equation: the arguments of
  the constructor application found on its left-hand side are replaced,
  position by position, by the user's constructor variables cargs'.*)
fun inst_recursor ((_ $ constr, rhs), cargs') =
  let
    val (_, formal_args) = strip_comb constr;
    val inst = formal_args ~~ map Free cargs';
  in
    subst_atomic inst rhs
  end;
(*Build the definition of the function described by the accumulated
  equations.  Returns (name, prop), where name is
  "<f>_<big_rec_name>_def" (base names) and prop is the meta-equation

      f ls rec_arg rs == recursor case_1 ... case_n rec_arg

  with one case per constructor of the datatype, in constructor order.
  A constructor without an equation yields a warning and the default
  right-hand side zero.  A case in which f itself still occurs is rejected
  as an illegal recursive call.*)
fun process_fun thy (fname, ftype, ls, rs, con_info: constructor_info, eqns) =
  let
    val {big_rec_name, constructors, rec_rewrites, ...} = con_info;

    val fconst = Const (fname, ftype);
    (*f applied to its parameters, with Bound 0 in the recursive position*)
    val fabs = list_comb (fconst, map Free ls @ [Bound 0] @ map Free rs);

    (*an application  _ $ t  stands for the call of f at t; other terms are kept*)
    fun use_fabs (_ $ t) = subst_bound (t, fabs)
      | use_fabs t = t;

    val cnames = map dest_Const_name constructors;
    val recursor_pairs = map (dest_eqn o Thm.concl_of) rec_rewrites;

    fun abs_over (Free x) body = absfree x body
      | abs_over t body = Abs ("rec", \<^Type>‹i›, abstract_over (t, body));

    (*the message is only constructed when tracing is switched on*)
    fun trace mk_msg = if !Ind_Syntax.trace then writeln (mk_msg ()) else ();

    fun add_case (cname, recursor_pair) cases =
      let
        val (rhs, recursor_rhs, eq) =
          (case AList.lookup (op =) eqns cname of
            NONE =>
              let
                val _ =
                  warning ("no equation for constructor " ^ cname ^
                    "\nin definition of function " ^ fname);
              in (\<^Const>‹zero›, snd recursor_pair, \<^Const>‹zero›) end
          | SOME (rhs, cargs', eq) =>
              (rhs, inst_recursor (recursor_pair, cargs'), eq));
        val allowed_terms = map use_fabs (snd (strip_comb recursor_rhs));
        val case_abs = fold_rev abs_over allowed_terms rhs;
        val _ =
          trace (fn () =>
            "recursor_rhs = " ^
            Syntax.string_of_term_global thy recursor_rhs ^
            "\nabs = " ^ Syntax.string_of_term_global thy case_abs);
      in
        if Logic.occs (fconst, case_abs) then
          primrec_eq_err thy ("illegal recursive occurrences of " ^ fname) eq
        else case_abs :: cases
      end;

    val recursor = head_of (fst (hd recursor_pairs));

    (*name of the recursive argument: chosen distinct from the names in ls and rs*)
    val used_names = map fst (ls @ rs);
    val rec_arg =
      Free (singleton (Name.variant_list used_names) (Long_Name.base_name big_rec_name),
        \<^Type>‹i›);

    val cases = fold_rev add_case (cnames ~~ recursor_pairs) [];
    val def_tm =
      Logic.mk_equals
        (subst_bound (rec_arg, fabs), list_comb (recursor, cases) $ rec_arg);

    val _ = trace (fn () => "primrec def:\n" ^ Syntax.string_of_term_global thy def_tm);

    val def_name =
      Long_Name.base_name fname ^ "_" ^ Long_Name.base_name big_rec_name ^ "_def";
  in
    (def_name, def_tm)
  end;
(*Define a primitive recursive function from internalized equations.

  args: list of ((binding, equation), attributes), one entry per equation.
  Result: the theorems stored under "simps", and the extended theory.

  Steps: analyse all equations (process_eqn), build and add the definition
  (process_fun) inside a name space path named after the function, prove
  each user equation from the definition and the datatype's recursor
  rewrites, store the equations under their given names and collectively as
  "simps" with the simp attribute, then leave the path again.

  An empty list of equations is reported as a regular package error instead
  of escaping as a Bind exception from a refutable pattern.*)
fun primrec_i args thy =
  let
    val ((eqn_names, eqn_terms), eqn_atts) = apfst split_list (split_list args);
    (*the fold yields NONE exactly when there are no equations at all*)
    val (fname, ftype, ls, rs, con_info, eqns) =
      (case List.foldr (process_eqn thy) NONE eqn_terms of
        SOME spec => spec
      | NONE => primrec_err "no equations given");
    val def = process_fun thy (fname, ftype, ls, rs, con_info, eqns);
    val (def_thm, thy1) = thy
      |> Sign.add_path (Long_Name.base_name fname)
      |> Global_Theory.add_def (apfst Binding.name def);
    (*rewrite with the new definition first, then with the recursor equations*)
    val rewrites = def_thm :: map (mk_meta_eq o Thm.transfer thy1) (#rec_rewrites con_info)
    val eqn_thms0 =
      eqn_terms |> map (fn t =>
        Goal.prove_global thy1 [] [] (Ind_Syntax.traceIt "next primrec equation = " thy1 t)
          (fn {context = ctxt, ...} =>
            EVERY [rewrite_goals_tac ctxt rewrites, resolve_tac ctxt @{thms refl} 1]));
  in
    thy1
    |> Global_Theory.add_thms ((eqn_names ~~ eqn_thms0) ~~ eqn_atts)
    |-> (fn eqn_thms =>
      Global_Theory.add_thmss [((Binding.name "simps", eqn_thms), [Simplifier.simp_add])])
    |>> the_single
    ||> Sign.parent_path
  end;
(*External entry point: read every equation as a proposition and every
  attribute source in the given theory, then hand over to primrec_i.*)
fun primrec args thy =
  let
    fun read_spec ((name, s), srcs) =
      let
        val prop = Syntax.read_prop_global thy s;
        val atts = map (Attrib.attribute_cmd_global thy) srcs;
      in ((name, prop), atts) end;
  in
    primrec_i (map read_spec args) thy
  end;
(*Outer syntax: the theory command "primrec" takes one or more equations,
  each with an optional theorem name (and attributes) before a colon.*)
val _ =
  let
    val equation = Parse_Spec.opt_thm_name ":" -- Parse.prop;
    (*the parser yields ((name, attributes), prop); primrec expects
      ((name, prop), attributes)*)
    fun rearrange ((name, atts), prop) = ((name, prop), atts);
    (*only the resulting theory is kept; the returned theorems are dropped*)
    fun define eqns = Toplevel.theory (fn thy => snd (primrec (map rearrange eqns) thy));
  in
    Outer_Syntax.command \<^command_keyword>‹primrec› "define primitive recursive functions on datatypes"
      (Scan.repeat1 equation >> define)
  end;
end;