File ‹partial_function_mr.ML›
(* Mutually recursive partial functions: several functions are reduced to a
   single partial_function definition over a sum-type input. *)
signature PARTIAL_FUNCTION_MR =
sig
(* Declaration registering a mode under the given name. Remaining arguments,
   in order: mk_monad_map (term, element-wise maps, source monad type, target
   monad type, types of the element-wise maps), mk_monadT, dest_monadT, and the
   monad_map_comp and monad_map_id theorems unfolded when deriving equations. *)
val init: string ->
(term * term list * typ * typ * typ list -> term) ->
(typ list * typ list -> typ) ->
(typ -> typ list * typ list) ->
thm list ->
thm list -> Morphism.declaration
(* Define the given mutually recursive functions in the given mode;
   returns one derived equation per function and the new local theory. *)
val add_partial_function_mr: string -> (binding * typ option * mixfix) list ->
Specification.multi_specs -> local_theory -> thm list * local_theory
(* Same as add_partial_function_mr, but on unparsed (string) input. *)
val add_partial_function_mr_cmd: string -> (binding * string option * mixfix) list ->
Specification.multi_specs_cmd -> local_theory -> thm list * local_theory
end;
structure Partial_Function_MR: PARTIAL_FUNCTION_MR =
struct
(* Configuration option partial_function_mr_trace (default: off). *)
val partial_function_mr_trace =
  Attrib.setup_config_bool @{binding partial_function_mr_trace} (K false);

(* trace ctxt msg: emit msg via tracing when partial_function_mr_trace is
   enabled in ctxt, otherwise do nothing. *)
fun trace ctxt =
  if Config.get ctxt partial_function_mr_trace then tracing else K ();
(* Setup data of one mode, as registered by init.
   mk_monad_map   -- called as mk_monad_map (t, fs, fromT, toT, fTs): converts the
                     monadic term t of type fromT into a term of type toT, using the
                     element-wise functions fs (of types fTs), one per varying
                     type argument of the monad
   mk_monadT      -- builds a monad type from (common type arguments, varying type arguments)
   dest_monadT    -- splits a monad type into (common type arguments, varying type arguments)
   monad_map_comp -- rewrite rules unfolded when deriving the equations;
                     presumably map-composition laws (TODO confirm)
   monad_map_id   -- rewrite rules unfolded when deriving the equations;
                     presumably map-identity laws (TODO confirm) *)
datatype setup_data = Setup_Data of
{mk_monad_map: term * term list * typ * typ * typ list -> term,
mk_monadT: typ list * typ list -> typ,
dest_monadT: typ -> typ list * typ list,
monad_map_comp: thm list,
monad_map_id: thm list};
(* Generic context data: the registered modes, indexed by mode name.
   The merge uses K true as equality, so a name defined in both tables is
   never reported as a conflict. *)
structure Modes = Generic_Data
(
  type T = setup_data Symtab.table;
  val empty = Symtab.empty;
  fun merge (tab1, tab2) = Symtab.merge (K true) (tab1, tab2);
)
(* Names of all modes registered in the proof context ctxt. *)
fun known_modes ctxt = Symtab.keys (Modes.get (Context.Proof ctxt));

(* Setup data of the mode with the given name in ctxt, if registered. *)
fun lookup_mode ctxt = Symtab.lookup (Modes.get (Context.Proof ctxt));
(* The constant Product_Type.curry at type ((A * B) => C) => A => B => C. *)
fun curry_const (A, B, C) =
  let
    val pairT = HOLogic.mk_prodT (A, B);
  in
    Const (@{const_name Product_Type.curry}, (pairT --> C) --> A --> B --> C)
  end;
(* Curry a function whose domain is a pair type: for f of type (S * T) => U,
   build Product_Type.curry f of type S => T => U.
   Only a genuine product domain is accepted. Any other type (including a
   domain built from a different binary type constructor, such as a function
   or sum type) raises TYPE instead of producing an ill-typed term. *)
fun mk_curry f =
  (case fastype_of f of
    Type ("fun", [Type (@{type_name prod}, [S, T]), U]) => curry_const (S, T, U) $ f
  | T => raise TYPE ("mk_curry", [T], [f]));
(* curry_n arity: turn a function on a left-nested tuple with arity components
   (as built by foldl1 HOLogic.mk_prodT) into curried form, by applying mk_curry
   arity - 1 times.
   uncurry_n arity: the converse, applying HOLogic.mk_case_prod arity - 1 times,
   which yields a function on a left-nested tuple.
   Arity 0 or 1 leaves the term unchanged. The count is clamped at 0 so that
   funpow is never given a negative repetition count. *)
fun curry_n arity = funpow (Int.max (arity - 1, 0)) mk_curry;
fun uncurry_n arity = funpow (Int.max (arity - 1, 0)) HOLogic.mk_case_prod;
(* Declaration registering the setup data of a mode under the name mode.
   The theorem lists are transported along the declaration morphism phi;
   the ML functions are stored as given. The entry is written with
   Symtab.update, so an existing entry for the same name is replaced. *)
fun init mode mk_monad_map mk_monadT dest_monadT monad_map_comp monad_map_id phi =
  let
    val transport = map (Morphism.thm phi);
    val entry =
      Setup_Data {
        mk_monad_map = mk_monad_map,
        mk_monadT = mk_monadT,
        dest_monadT = dest_monadT,
        monad_map_comp = transport monad_map_comp,
        monad_map_id = transport monad_map_id};
  in
    Modes.map (Symtab.update (mode, entry))
  end;
(* The HOL sum type lT + rT. *)
fun mk_sumT (lT, rT) = Type (@{type_name sum}, [lT, rT]);

(* Right-nested sum of a non-empty type list:
   [T1, T2, T3] gives T1 + (T2 + T3); a singleton list gives its element.
   An empty list is an error. *)
fun mk_choiceT [] = error "mk_choiceT []"
  | mk_choiceT tys = foldr1 mk_sumT tys;
(* Combined monad type for the result types mTs.
   Each mT is split with dest_monadT into (common type arguments, varying type
   arguments). The common arguments must agree across all mTs. For each position
   i, the i-th varying arguments of all mTs are combined with mk_choiceT, and the
   result is rebuilt with mk_monadT.
   All mTs must also have the same number of varying arguments. Violations of
   either condition are reported with an error here, rather than surfacing later
   as a Subscript exception or an ill-typed term. *)
fun mk_choice_resT mk_monadT dest_monadT mTs =
  let
    val (commonTss, argTss) = split_list (map dest_monadT mTs);
    val commonTs = hd commonTss;
    val n = length (hd argTss);
    val _ =
      if forall (fn Ts => Ts = commonTs) commonTss then ()
      else error "mk_choice_resT: result types differ in their common monad type arguments";
    val _ =
      if forall (fn Ts => length Ts = n) argTss then ()
      else error "mk_choice_resT: result types differ in their number of monad type arguments";
    val combinedTs =
      map (fn i => mk_choiceT (map (fn Ts => nth Ts i) argTss)) (0 upto (n - 1));
  in mk_monadT (commonTs, combinedTs) end;
(* mk_inj tys t position: inject t (of the position-th type in tys) into
   mk_choiceT tys, using Inl for position 0 and Inr followed by recursion
   otherwise. For a singleton type list, t is returned unchanged.
   An empty type list is an error. *)
fun mk_inj [] _ _ = error "mk_inj [] _ _"
  | mk_inj [_] t _ = t
  | mk_inj (ty :: tys) t position =
      let
        val restT = mk_choiceT tys;
        val sumT = mk_sumT (ty, restT);
      in
        if position <> 0
        then Const (@{const_name Inr}, restT --> sumT) $ mk_inj tys t (position - 1)
        else Const (@{const_name Inl}, ty --> sumT) $ t
      end;
(* mk_proj tys t position: project t of type mk_choiceT tys to its
   position-th summand, using projl for position 0 and projr followed by
   recursion otherwise. For a singleton type list, t is returned unchanged.
   An empty type list is an error. *)
fun mk_proj [] _ _ = error "mk_proj [] _ _"
  | mk_proj [_] t _ = t
  | mk_proj (ty :: tys) t position =
      let
        val restT = mk_choiceT tys;
        val sumT = mk_sumT (ty, restT);
      in
        if position <> 0
        then mk_proj tys (Const (@{const_name Sum_Type.projr}, sumT --> restT) $ t) (position - 1)
        else Const (@{const_name Sum_Type.projl}, sumT --> ty) $ t
      end;
(* Head symbol of the left-hand side of a specification equation, taken after
   Variable.focus has fixed the equation's quantified parameters.
   The argument has the shape (fix, (attribute binding, equation)). *)
fun get_head ctxt (_, (_, eqn)) =
  let
    val ((_, focused), _) = Variable.focus NONE eqn ctxt;
    val (lhs, _) = HOLogic.dest_eq (HOLogic.dest_Trueprop focused);
  in
    Term.head_of lhs
  end;
(* Collect the per-function data from a fix and its equation.
   Returns (fname, f_cuc, f_uc, inT, resT, ((binding, mixfix), fT), F, arity, args):
   - f_uc is a Free named fname that takes all arity arguments as one
     left-nested tuple of type inT, with result type resT;
   - f_cuc = curry_n arity f_uc;
   - F is the right-hand side abstracted over heads and then over the argument
     Frees args of the left-hand side.
   A function without arguments is rejected with an error. Without this check,
   foldl1 on the empty argument type list would raise an uninformative
   List.Empty. *)
fun get_infos lthy heads (fix, (_, eqn)) =
  let
    val ((_, plain_eqn), _) = Variable.focus NONE eqn lthy;
    val ((f_binding, fT), mixfix) = fix;
    val fname = Binding.name_of f_binding;
    val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop plain_eqn);
    val (_, args) = strip_comb lhs;
    val arity = length args;
    val _ =
      if arity > 0 then ()
      else error ("Function " ^ quote fname ^ " must take at least one argument");
    val F = fold_rev lambda (heads @ args) rhs;
    val (aTs, bTs) = chop arity (binder_types fT);
    val tupleT = foldl1 HOLogic.mk_prodT aTs;
    (* bTs are the binder types of fT beyond the first arity ones; keeping them
       means dest_funT splits fT_uc into the tuple type and the full remaining type *)
    val fT_uc = tupleT :: bTs ---> body_type fT;
    val (inT, resT) = dest_funT fT_uc;
    val f_uc = Free (fname, fT_uc);
    val f_cuc = curry_n arity f_uc;
  in
    (fname, f_cuc, f_uc, inT, resT, ((f_binding, mixfix), fT), F, arity, args)
  end;
(* A variant of name that does not clash with the names already used in ctxt. *)
fun fresh_var ctxt name =
  let val (fresh, _) = Name.variant name (Variable.names_of ctxt)
  in fresh end;
(* Shared implementation of add_partial_function_mr (prep = check_multi_specs)
   and add_partial_function_mr_cmd (prep = read_multi_specs).
   mode selects setup data registered with init. fixes_raw and eqns_raw list
   the functions and their defining equations, in the same order.
   All functions are encoded into one concealed global function:
   - its input type is the nested sum (mk_choiceT) of the tupled input types;
   - its result type is built by mk_choice_resT.
   The global function is defined with Partial_Function.add_partial_function.
   Each user function is then defined by injecting its argument into the
   global function and projecting the result back, and its equation
   f args = rhs is derived and noted.
   Returns the derived equations in input order and the final local theory. *)
fun gen_add_partial_function_mr prep mode fixes_raw eqns_raw lthy =
let
(* look up the mode; an unknown mode is an error listing the known ones *)
val setup_data = the (lookup_mode lthy mode)
handle Option.Option => error (cat_lines ["Unknown mode " ^ quote mode ^ ".",
"Known modes are " ^ commas_quote (known_modes lthy) ^ "."]);
val Setup_Data {mk_monad_map, mk_monadT, dest_monadT, monad_map_comp, monad_map_id} = setup_data;
(* sanity checks on the specification *)
val _ = if length eqns_raw < 2 then error "require at least two function definitions" else ();
val ((fixes, eq_abinding_eqns), _) = prep fixes_raw eqns_raw lthy;
val _ = if length eqns_raw = length fixes then () else error "# of eqns does not match # of constants";
val fix_eq_abinding_eqns = fixes ~~ eq_abinding_eqns;
(* the i-th equation must define the i-th fixed constant *)
val heads = map (get_head lthy) fix_eq_abinding_eqns;
val fnames = map (Binding.name_of o #1 o #1) fixes
val fnames' = map (#1 o Term.dest_Free) heads
val f_f = fnames ~~ fnames'
val _ = case find_first (fn (f,g) => not (f = g)) f_f of NONE => () | SOME _ =>
error ("list of function symbols does not match list of equations:\n"
^ @{make_string} fnames ^ "\nvs\n" ^ @{make_string} fnames')
(* per-function data from get_infos, split by component: curried and uncurried
   Frees, tupled input type, result type, (binding, mixfix) with declared type,
   right-hand side abstracted over all heads and arguments, arity, argument Frees *)
val all = map (get_infos lthy heads) fix_eq_abinding_eqns
val f_cucs = map #2 all
val f_ucs = map #3 all
val inTs = map #4 all
val resTs = map #5 all
val bindings_types = map #6 all
val Fs = map #7 all
val arities = map #8 all
val all_args = map #9 all
(* the global function: sum-typed input, combined result type, fresh name *)
val glob_inT = mk_choiceT inTs
val glob_resT = mk_choice_resT mk_monadT dest_monadT resTs
val inj = mk_inj inTs
val glob_fname = fresh_var lthy (foldl1 (fn (a,b) => a ^ "_" ^ b) (fnames @ [serial_string ()]))
val glob_constT = glob_inT --> glob_resT;
val glob_const = Free (glob_fname, glob_constT)
val nums = 0 upto (length all - 1)
(* mk_res_inj_proj n gives two term transformers for function n.
   The first converts a term of type resT (its own result type) to glob_resT;
   the second converts back. Both apply mk_monad_map with one element-wise
   function per varying monad type argument m: mk_inj for the first,
   mk_proj for the second. *)
fun mk_res_inj_proj n = let
val resT = nth resTs n
val glob_Targs = dest_monadT glob_resT |> #2
val res_Targs = dest_monadT resT |> #2
val m = length res_Targs
fun inj_proj m = let
val resTs_m = map (fn resT => nth (dest_monadT resT |> #2) m) resTs
val resT_arg = nth resTs_m n
val globT_arg = nth glob_Targs m
val x = Free ("x",resT_arg)
val y = Free ("x",globT_arg)
val inj = lambda x (mk_inj resTs_m x n)
val proj = lambda y (mk_proj resTs_m y n)
in ((inj, resT_arg --> globT_arg), (proj, globT_arg --> resT_arg))
end;
val (inj,proj) = map inj_proj (0 upto (m - 1)) |> split_list
val (t_to_ss_inj,t_to_sTs_inj) = split_list inj;
val (t_to_ss_proj,t_to_sTs_proj) = split_list proj;
in (fn mt => mk_monad_map (mt, t_to_ss_inj, resT, glob_resT, t_to_sTs_inj),
fn mt => mk_monad_map (mt, t_to_ss_proj, glob_resT, resT, t_to_sTs_proj))
end;
val (res_inj, res_proj) = map mk_res_inj_proj nums |> split_list
(* for each function n: xs is a fresh variable of its input type, and the
   global function view of n is fn xs => res_proj n (glob_const (inj xs n)) *)
fun mk_global_fun n = let
val fname = nth fnames n
val inT = nth inTs n
val xs = Free (fresh_var lthy ("x_" ^ fname), inT)
val inj_xs = inj xs n
val glob_inj_xs = glob_const $ inj_xs
val glob_inj_xs_map = nth res_proj n glob_inj_xs
val res = lambda xs glob_inj_xs_map
in
(xs,res)
end
val (xss,global_funs) = map mk_global_fun nums |> split_list
(* case n of the global definition: the n-th right-hand side, uncurried, with
   every function replaced by its global function view, applied to xs and
   converted to glob_resT with res_inj n *)
fun mk_cases n = let
val xs = nth xss n
val F = nth Fs n;
val arity = nth arities n;
val F_uc =
fold_rev lambda f_ucs (uncurry_n arity (list_comb (F, f_cucs)));
val F_uc_inst = Term.betapplys (F_uc,global_funs)
val res = lambda xs (nth res_inj n (F_uc_inst $ xs))
in res end;
val all_cases = map mk_cases nums;
(* fold the cases into nested case_sum terms; the nesting follows the
   right-nested sum built by mk_choiceT *)
fun combine_cases [cs] [_] = cs
| combine_cases (cs :: more) (inT :: moreTy) =
let
val moreT = mk_choiceT moreTy
val sumT = mk_sumT (inT, moreT)
val case_const = Const
(@{const_name case_sum},
(inT --> glob_resT) --> (moreT --> glob_resT) --> sumT --> glob_resT)
in case_const $ cs $ combine_cases more moreTy end
| combine_cases _ _ = error "combine_cases with incompatible argument lists";
(* defining equation: glob_const glob_x = case_sum c_0 (case_sum c_1 ...) glob_x *)
val glob_x_name = fresh_var lthy ("x_" ^ glob_fname)
val glob_x = Free (glob_x_name,glob_inT)
val rhs = combine_cases all_cases inTs $ glob_x;
val lhs = glob_const $ glob_x
val eq = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs,rhs))
val glob_binding = Binding.name (glob_fname) |> Binding.concealed
val glob_attrib_binding = Binding.empty_atts
val _ = trace lthy "invoking partial_function on global function"
(* define the global function with partial_function under a concealed name,
   inside a private scope; afterwards the naming of lthy is restored.
   glob_const is rebound here to the constant introduced by partial_function,
   and glob_simp_thm is its equation *)
val priv_lthy = lthy
|> Proof_Context.private_scope (Binding.new_scope())
val ((glob_const, glob_simp_thm),priv_lthy') = priv_lthy
|> Partial_Function.add_partial_function mode
[(glob_binding,SOME glob_constT,NoSyn)] (glob_attrib_binding,eq)
val glob_lthy = priv_lthy'
|> Proof_Context.restore_naming lthy
val _ = trace lthy "deriving simp rules for separate functions from global function"
(* define_f n defines the n-th user function as
   curry_n arity (fn x => res_proj n (glob_const (inj x n))),
   accumulating constants, definitions and right-hand sides;
   fold_rev over nums keeps the accumulated lists in index order *)
fun define_f n (fs, fdefs,rhss,lthy) = let
val ((fbinding,mixfix),_) = nth bindings_types n
val fname = nth fnames n
val inT = nth inTs n;
val arity = nth arities n;
val x = Free (fresh_var lthy ("x_" ^ fname), inT)
val inj_argsProd = inj x n
val call = glob_const $ inj_argsProd
val post = nth res_proj n call
val rhs = curry_n arity (lambda x post)
val ((f, (_, f_def)),lthy') =
Local_Theory.define_internal ((fbinding,mixfix), (Binding.empty_atts, rhs)) lthy
in
(f :: fs, f_def :: fdefs,rhs :: rhss,lthy')
end
val (fs,fdefs,f_rhss,local_lthy) = fold_rev define_f nums ([],[],[],glob_lthy)
(* restate the global equation with the defined constants fs in place of the
   global function views; proved by instantiating glob_simp_thm at glob_x and
   unfolding it together with the definitions fdefs *)
val glob_simp_thm' =
let
fun mk_case_new n =
let
val F = nth Fs n
val arity = nth arities n
val Finst = uncurry_n arity (Term.betapplys (F,fs))
val xs = nth xss n
val res = lambda xs (nth res_inj n (Finst $ xs))
in
res
end;
val new_cases = map mk_case_new nums;
val rhs = combine_cases new_cases inTs $ glob_x;
val lhs = glob_const $ glob_x
val eq = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs,rhs))
in
Goal.prove local_lthy [glob_x_name] [] eq (fn {prems = _, context = ctxt} =>
Thm.instantiate' [] [SOME (Thm.cterm_of ctxt glob_x)] glob_simp_thm
|> (fn simp_thm => unfold_tac ctxt [simp_thm] THEN unfold_tac ctxt fdefs))
end
(* the user-facing equation f args = F fs args for function n, obtained by trans.
   Step 1: f args = f_rhs args, by unfolding the definition of f.
   Step 2: f_rhs args = F fs args, by unfolding glob_simp_thm', then sum.simps,
   curry_def and split, then o_def with monad_map_comp, and finally
   monad_map_id with sum.sel *)
fun mk_simp_thm n =
let
val args = nth all_args n
val arg_names = map (dest_Free #> fst) args
val f = nth fs n
val F = nth Fs n
val fdef = nth fdefs n
val lhs = list_comb (f,args);
val mhs = Term.betapplys (nth f_rhss n, args)
val rhs = list_comb (list_comb (F,fs), args);
val eq1 = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs,mhs))
val eq2 = HOLogic.mk_Trueprop (HOLogic.mk_eq (mhs,rhs))
val simp_thm1 = Goal.prove local_lthy arg_names [] eq1
(fn {prems = _, context = ctxt} => unfold_tac ctxt [fdef])
val simp_thm2 = Goal.prove local_lthy arg_names [] eq2 (fn {prems = _, context = ctxt} =>
unfold_tac ctxt [glob_simp_thm']
THEN unfold_tac ctxt @{thms sum.simps curry_def split}
THEN unfold_tac ctxt (@{thm o_def} :: monad_map_comp)
THEN unfold_tac ctxt (monad_map_id @ @{thms sum.sel}))
in
@{thm trans} OF [simp_thm1,simp_thm2]
end
val simp_thms = map mk_simp_thm nums
(* register n performs three steps for the n-th equation:
   - note it under the user's binding;
   - add it to Spec_Rules as equational_recdef for f;
   - note it again as fname.simps with the code attribute.
   It returns that theorem and the updated theory *)
fun register n lthy =
let
val simp_thm = nth simp_thms n
val eq_abinding = nth eq_abinding_eqns n |> fst
val fname = nth fnames n
val f = nth fs n
in
lthy
|> Local_Theory.note (eq_abinding, [simp_thm])
|-> (fn (_, simps) =>
Spec_Rules.add Binding.empty Spec_Rules.equational_recdef [f] simps
#> Local_Theory.note
((Binding.qualify true fname (Binding.name "simps"),
@{attributes [code]}), simps) #>> snd #>> hd)
end
in
(* register all functions in order, collecting their equations *)
fold (fn i => fn (simps, lthy) => case register i lthy of
(simp, lthy') => (simps @ [simp], lthy')) nums ([], local_lthy)
end;
(* Entry point on internal specifications (types given as typ option),
   checked with Specification.check_multi_specs. *)
fun add_partial_function_mr mode fixes specs lthy =
  gen_add_partial_function_mr Specification.check_multi_specs mode fixes specs lthy;

(* Entry point on unparsed specifications (types given as string option),
   read with Specification.read_multi_specs; used by the outer syntax command. *)
fun add_partial_function_mr_cmd mode fixes specs lthy =
  gen_add_partial_function_mr Specification.read_multi_specs mode fixes specs lthy;
(* Parser for the parenthesized mode name, e.g. (option). *)
val mode = @{keyword "("} |-- Parse.name --| @{keyword ")"};

(* Outer syntax: partial_function_mr (mode) fixes where equations.
   The derived theorems are discarded here; only the local theory is kept. *)
val _ =
  Outer_Syntax.local_theory @{command_keyword partial_function_mr}
    "define mutually recursive partial functions"
    ((mode -- Parse_Spec.specification)
      >> (fn (mode, (fixes, specs)) => snd o add_partial_function_mr_cmd mode fixes specs));
end