File ‹Transform_Term.ML›
(* Public interface of the term-lifting machinery. *)
signature TRANSFORM_TERM =
sig
(* repeat_sweep_conv conv ctxt: repeats a top sweep of conv for as long as a
   sweep still changes the term (a sweep yielding a reflexive theorem stops it). *)
val repeat_sweep_conv: (Proof.context -> conv) -> Proof.context -> conv
(* Rewrites a term of the shape App (Wrap (Abs ...)) (Wrap ...) with the theorem
   Wrap_App_Wrap; fails as Conv.no_conv on any other term. *)
val rewrite_pureapp_beta_conv: conv
(* wrap_head monad_consts head n: head applied to n bound variables, enclosed in
   n dummy-typed abstractions, each of which is passed through the "return"
   field of monad_consts. *)
val wrap_head: Transform_Const.MONAD_CONSTS -> term -> int -> term
(* lift_equation monad_consts ctxt (lhs, rhs) memoizer_opt: lifts one equation.
   Result: (conversion obtained while lifting rhs, lifted equation as a
   Trueprop term, number of arguments of lhs). If memoizer_opt is SOME m, the
   lifted right-hand side is m applied to the tuple of lhs arguments and to
   the lifted rhs. *)
val lift_equation: Transform_Const.MONAD_CONSTS -> Proof.context ->
term * term -> term option -> (Proof.context -> conv) * term * int
end
structure Transform_Term : TRANSFORM_TERM =
struct
(* Conversion for an application "h a1 ... an": head_conv handles h and the
   i-th element of arg_convs handles ai. All given conversions are first
   instantiated with ctxt and then chained, left to right, with
   Conv.combination_conv. *)
fun list_conv (head_conv, arg_convs) ctxt =
  let
    val head_cv = head_conv ctxt
    val arg_cvs = map (fn mk_cv => mk_cv ctxt) arg_convs
  in
    fold (fn arg_cv => fn acc_cv => Conv.combination_conv acc_cv arg_cv) arg_cvs head_cv
  end
(* Conversion that leaves a term which already is an abstraction unchanged
   (Conv.abs_conv with the identity conversion on the body). If that fails, it
   applies Thm.eta_long_conversion and then Thm.eta_conversion to the body of
   the resulting abstraction. *)
fun eta_conv1 ctxt =
  let
    val keep_abstraction = Conv.abs_conv (K Conv.all_conv) ctxt
    val expand_then_contract_body =
      Thm.eta_long_conversion then_conv Conv.abs_conv (K Thm.eta_conversion) ctxt
  in
    keep_abstraction else_conv expand_then_contract_body
  end
(* n nested applications of eta_conv1: each level applies eta_conv1 to the
   current term and then continues with the next level inside the abstraction
   body. With n = 0 the result is the identity conversion. *)
fun eta_conv_n n =
  let
    fun one_level inner ctxt =
      eta_conv1 ctxt then_conv Conv.abs_conv (inner o #2) ctxt
  in
    funpow n one_level (K Conv.all_conv)
  end
(* Runs conv on ctm, but fails (as Conv.no_conv) when the resulting theorem is
   reflexive, i.e. when conv did not change the term. *)
fun conv_changed conv ctm =
  conv ctm
  |> (fn result => if Thm.is_reflexive result then Conv.no_conv ctm else result)
(* Repeats a top sweep of conv (Conv.top_sweep_conv) until a sweep no longer
   changes the term; conv_changed turns an unchanged sweep into a failure,
   which ends Conv.repeat_conv. *)
fun repeat_sweep_conv conv ctxt =
  let
    val one_sweep = Conv.top_sweep_conv conv ctxt
  in
    Conv.repeat_conv (conv_changed one_sweep)
  end
(* Rewriting conversions based on the theorems App_def and Wrap_def. *)
(* Rewrites with App_def in the reversed (symmetric) direction. *)
val app_mark_conv = Conv.rewr_conv @{thm App_def[symmetric]}
(* Rewrites with App_def in its stated direction, undoing app_mark_conv. *)
val app_unmark_conv = Conv.rewr_conv @{thm App_def}
(* Rewrites with Wrap_def in the reversed (symmetric) direction. *)
val wrap_mark_conv = Conv.rewr_conv @{thm Wrap_def[symmetric]}
(* Exposes the leading parameter of tm as a named variable via
   Term.strip_abs_eta and abstracts over it again with Term.absfree.
   The number of exposed parameters is the minimum of 1 and the number of
   argument types of tm, so at most one parameter is handled and terms of
   non-function type are returned as they come out of strip_abs_eta 0.
   NOTE(review): the cap at one argument is kept exactly as found - confirm
   that a single-step expansion is what callers expect. *)
fun eta_expand tm =
  let
    val arity = length (binder_types (fastype_of tm))
    val n_args = Integer.min 1 arity
    val (params, core) = Term.strip_abs_eta n_args tm
  in
    fold_rev Term.absfree params core
  end
(* True iff Ctr_Sugar.ctr_sugar_of finds an entry for the type name tp_name
   in ctxt. *)
fun is_ctr_sugar ctxt tp_name =
  (case Ctr_Sugar.ctr_sugar_of ctxt tp_name of
    SOME _ => true
  | NONE => false)
(* Number of argument types of tp, as split off by strip_type. *)
fun type_nargs tp =
  let
    val (arg_typs, _) = strip_type tp
  in
    length arg_typs
  end
(* Number of argument types of the type of tm. *)
fun term_nargs tm = tm |> fastype_of |> type_nargs
(* lift_type: lifts tp with lift_type' and passes the result through the
   mk_stateT field of monad_consts. *)
fun lift_type (monad_consts: Transform_Const.MONAD_CONSTS) ctxt tp =
  #mk_stateT monad_consts (lift_type' monad_consts ctxt tp)
(* lift_type': structural lifting of a type.
   - function type: the domain is lifted with lift_type', the range with
     lift_type;
   - other type constructor: if it is registered with ctr_sugar, its arguments
     are lifted with lift_type'; if it has no arguments it is kept as is;
     otherwise TYPE "not a ctr_sugar" is raised;
   - anything else (type variables) is kept as is. *)
and lift_type' monad_consts ctxt tp =
  (case tp of
    Type (@{type_name fun}, _) =>
      lift_type' monad_consts ctxt (domain_type tp) --> lift_type monad_consts ctxt (range_type tp)
  | Type (tp_name, tp_args) =>
      if is_ctr_sugar ctxt tp_name then
        Type (tp_name, map (lift_type' monad_consts ctxt) tp_args)
      else if null tp_args then tp
      else raise TYPE("not a ctr_sugar", [tp], [])
  | _ => tp)
(* A type is an "atom type" iff lift_type' leaves it unchanged. May raise TYPE
   where lift_type' does. *)
fun is_atom_type monad_consts ctxt tp =
  let
    val lifted = lift_type' monad_consts ctxt tp
  in
    tp = lifted
  end
(* True iff the result type and every argument type of tp are atom types
   (the result type is examined first). *)
fun is_1st_type monad_consts ctxt tp =
  let
    val component_typs = body_type tp :: binder_types tp
  in
    forall (is_atom_type monad_consts ctxt) component_typs
  end
(* Reads the string atom_name as a term pattern in ctxt. *)
fun orig_atom ctxt atom_name =
  atom_name |> Proof_Context.read_term_pattern ctxt
(* True iff the type of tm satisfies is_1st_type. *)
fun is_1st_term monad_consts ctxt tm =
  tm |> fastype_of |> is_1st_type monad_consts ctxt
(* True iff the term obtained by reading atom_name in ctxt (orig_atom)
   satisfies is_1st_term. *)
fun is_1st_atom monad_consts ctxt atom_name =
  atom_name |> orig_atom ctxt |> is_1st_term monad_consts ctxt
(* wrap_1st_term monad_consts tm n_args_opt inner_wrap
   Exposes n parameters of tm with Term.strip_abs_eta, where n is the value in
   n_args_opt or, if that is NONE, the number of argument types of tm. Each
   parameter is abstracted again and the abstraction is passed through the
   "return" field of monad_consts; the innermost body is passed through
   "return" as well iff inner_wrap is true.
   Result: (context-indexed conversion, wrapped term). Per parameter the
   conversion applies eta_conv1, descends into the abstraction body, and then
   applies wrap_mark_conv; for the innermost body it is wrap_mark_conv when
   inner_wrap is true and the identity conversion otherwise. *)
fun wrap_1st_term monad_consts tm n_args_opt inner_wrap =
  let
    val n_args = the_default (term_nargs tm) n_args_opt
    val (params, core) = Term.strip_abs_eta n_args tm

    (* conversion and term for the innermost body *)
    val innermost =
      if inner_wrap then (K wrap_mark_conv, #return monad_consts core)
      else (K Conv.all_conv, core)

    (* add one abstraction layer around an already wrapped inner part *)
    fun add_layer param (inner_conv, inner_tm) =
      let
        fun layer_conv ctxt =
          eta_conv1 ctxt then_conv Conv.abs_conv (inner_conv o #2) ctxt then_conv wrap_mark_conv
      in
        (layer_conv, #return monad_consts (Term.absfree param inner_tm))
      end
  in
    fold_rev add_layer params innermost
  end
(* lift_1st_atom monad_consts ctxt atom (name, tp)
   atom is the term constructor to use (Const or Free at the call sites in this
   file). The number of arguments n is taken from the term read for name in
   ctxt (orig_atom), not from tp. The first n argument types of tp are lifted
   one by one with lift_type'; the remaining function type is lifted as a
   whole with lift_type'. The atom is rebuilt at that type and handed to
   wrap_1st_term with n arguments and the inner wrap enabled. *)
fun lift_1st_atom monad_consts ctxt atom (name, tp) =
  let
    val (all_arg_typs, res_typ) = strip_type tp
    val n_args = term_nargs (orig_atom ctxt name)
    val (leading_typs, trailing_typs) = chop n_args all_arg_typs
    val leading_typs' = map (lift_type' monad_consts ctxt) leading_typs
    val rest_typ' = lift_type' monad_consts ctxt (trailing_typs ---> res_typ)
    val retyped_atom = atom (name, leading_typs' ---> rest_typ')
  in
    wrap_1st_term monad_consts retyped_atom (SOME n_args) true
  end
(* fixed_args head_n_args tm
   Splits tm into head and arguments. head_n_args maps the head to SOME n when
   it is a head of interest that expects exactly n arguments, NONE otherwise.
   Result:
   - NONE if the head is of no interest, or if tm has more than n arguments;
   - SOME (head, args) if tm has exactly n arguments;
   - raises TERM "need n args" if tm has fewer than n arguments. *)
fun fixed_args head_n_args tm =
  let
    val (tm_head, tm_args) = strip_comb tm
    val n_tm_args = length tm_args
  in
    (case head_n_args tm_head of
      NONE => NONE
    | SOME n_args =>
        if n_tm_args = n_args then SOME (tm_head, tm_args)
        else if n_tm_args < n_args then
          raise TERM("need " ^ string_of_int n_args ^ " args", [tm])
        else NONE)
  end
(* lift_abs' monad_consts ctxt (name, typ) cont lift_dict body
   Lifts an abstraction over the variable Free (name, typ) whose body has
   already been opened. The variable gets the type lift_type' typ. Unless typ
   is an atom type, lift_dict is extended by an entry mapping the original
   variable to (wrap_mark_conv, "return" of the retyped variable). The body is
   lifted by the continuation cont under that dictionary and is abstracted
   over the retyped variable with lambda.
   Result: (conversion, abstraction). The conversion applies eta_conv1 and then
   the body conversion inside the abstraction; it does not wrap the
   abstraction itself. *)
fun lift_abs' monad_consts ctxt (name, typ) cont lift_dict body =
  let
    val lifted_typ = lift_type' monad_consts ctxt typ
    val lifted_var = Free (name, lifted_typ)
    val returned_var = #return monad_consts lifted_var
    val inner_dict =
      if is_atom_type monad_consts ctxt typ then lift_dict
      else (Free (name, typ), (K wrap_mark_conv, returned_var)) :: lift_dict
    val (inner_conv, inner_tm) = cont inner_dict body
  in
    (fn ctxt' => eta_conv1 ctxt' then_conv Conv.abs_conv (inner_conv o #2) ctxt',
     lambda lifted_var inner_tm)
  end
(* lift_arg: lifting of a term in argument position; it simply delegates to
   lift_term. *)
fun lift_arg monad_consts ctxt lift_dict tm =
lift_term monad_consts ctxt lift_dict tm
(* lift_term monad_consts ctxt lift_dict tm
   Translates tm into its lifted counterpart and returns (conv, tm'), where
   conv is a context-indexed conversion and tm' the lifted term.
   lift_dict is an association list from original terms to the
   (conversion, lifted term) pair to use for them.
   The handlers in the list "checks" are tried in order and the first one that
   returns SOME wins; the last handler, choke, raises
   TERM "cannot process term". Several handlers can also raise TERM or TYPE
   themselves (via fixed_args, lift_type', check_const). *)
and lift_term monad_consts ctxt lift_dict tm = let
(* case combinators (field casex) of all ctr_sugar entries known in ctxt;
   recomputed on every call of lift_term *)
val case_terms = Ctr_Sugar.ctr_sugars_of ctxt |> map #casex
(* the registered case combinator that equals tm up to types, if any *)
fun lookup_case_term tm =
find_first (fn x => Term.aconv_untyped (x, tm)) case_terms
(* recursive entry points used by the handlers below *)
val check_cont = lift_term monad_consts ctxt
val check_cont_arg = lift_arg monad_consts ctxt
(* A constant for which Transform_Data.get_dp_info (under this monad's name)
   has an entry is replaced by the constant named in new_headT, at the type
   lift_type typ, with the identity conversion. Raises TERM "not a constant"
   if new_headT is not a Const. *)
fun check_const tm =
case tm of
Const (_, typ) => (
case Transform_Data.get_dp_info (#monad_name monad_consts) ctxt tm of
SOME {new_headT=Const (name, _), ...} => SOME (K Conv.all_conv, Const (name, lift_type monad_consts ctxt typ))
| SOME {new_headT=tm', ...} => raise TERM("not a constant", [tm'])
| NONE => NONE)
| _ => NONE
(* A Const or Free whose name, read back in ctxt, is a first-order atom
   (is_1st_atom) is lifted with lift_1st_atom. *)
fun check_1st_atom tm =
case tm of
Const (name, typ) =>
if is_1st_atom monad_consts ctxt name then SOME (lift_1st_atom monad_consts ctxt Const (name, typ)) else NONE
| Free (name, typ) =>
if is_1st_atom monad_consts ctxt name then SOME (lift_1st_atom monad_consts ctxt Free (name, typ)) else NONE
| _ => NONE
(* Dictionary lookup; keys are compared up to types (aconv_untyped). *)
fun check_dict tm =
AList.lookup Term.aconv_untyped lift_dict tm
(* "If" applied to exactly 3 arguments: every argument is lifted and the
   result is the constant named by if_termN (with dummy type) applied to the
   lifted arguments. Fewer than 3 arguments raise TERM in fixed_args; more
   than 3 make this handler return NONE. *)
fun check_if tm =
fixed_args (fn head => if Term.aconv_untyped (head, @{term If}) then SOME 3 else NONE) tm
|> Option.map (fn (_, args) =>
let
val (arg_convs, args') = map (check_cont lift_dict) args |> split_list
val conv = list_conv (K Conv.all_conv, arg_convs)
val tm' = list_comb (Const (#if_termN monad_consts, dummyT), args')
in
(conv, tm')
end)
(* Abstraction: the body is opened with Term.dest_abs_global and lifted via
   lift_abs'; the resulting abstraction is passed through "return" and the
   conversion is extended by wrap_mark_conv. *)
fun check_abs tm =
case tm of
Abs _ =>
let
val ((name', typ), body') = Term.dest_abs_global tm
val (conv, tm') = lift_abs' monad_consts ctxt (name', typ) check_cont lift_dict body'
fun convT ctxt' = conv ctxt' then_conv wrap_mark_conv
val tmT = #return monad_consts tm'
in SOME (convT, tmT) end
| _ => NONE
(* Case combinator applied to exactly (its number of arguments - 1) terms,
   i.e. to all but its last argument. A term with more arguments makes
   fixed_args return NONE, so it is left to the later handlers in "checks". *)
fun check_case tm =
fixed_args (lookup_case_term #> Option.map (fn cs => term_nargs cs - 1)) tm
|> Option.map (fn (head, args) =>
let
val (case_name, case_type) = lookup_case_term head |> the |> dest_Const
(* argument types of the case combinator without the last one *)
val ((clause_typs, _), _) =
strip_type case_type |>> split_last
(* number of parameters of each clause, read off the clause types *)
val clase_nparams = clause_typs |> map type_nargs
(* Lift one clause: expose n_param parameters with Term.strip_abs_eta, nest
   one lift_abs' per parameter around the recursive lifting check_cont, and
   run the composition on the clause body. *)
fun lift_clause n_param clause =
let
val (vars_name_typ, body) = Term.strip_abs_eta n_param clause
val abs_lift_wraps = map (lift_abs' monad_consts ctxt) vars_name_typ
val lift_wrap = Library.foldr (op o) (abs_lift_wraps, I) check_cont
val (conv, clauseT) = lift_wrap lift_dict body
in
(conv, clauseT)
end
(* the case constant is rebuilt with a dummy type *)
val head' = Const (case_name, dummyT)
val (convs, clauses') = map2 lift_clause clase_nparams args |> split_list
fun conv ctxt' = list_conv (K Conv.all_conv, convs) ctxt' then_conv wrap_mark_conv
val tm' = #return monad_consts (list_comb (head', clauses'))
in
(conv, tm')
end)
(* Application f $ x: both parts are lifted and combined with the "app" field
   of monad_consts. The conversion converts f and then applies app_mark_conv
   to it, and converts x. *)
fun check_app tm =
case tm of
f $ x =>
let
val (f_conv, tmf) = check_cont lift_dict f
val (x_conv, tmx) = check_cont_arg lift_dict x
val tm' = #app monad_consts (tmf, tmx)
fun conv ctxt' = Conv.combination_conv (f_conv ctxt' then_conv app_mark_conv) (x_conv ctxt')
in
SOME (conv, tm')
end
| _ => NONE
(* A term that contains no key of lift_dict as a subterm and whose type is
   first-order is wrapped as a whole with wrap_1st_term.
   NOTE(review): keys are compared with aconv here, i.e. including types,
   whereas check_dict compares up to types - confirm this is intended. *)
fun check_pure tm =
if tm |> exists_subterm (AList.defined (op aconv) lift_dict)
orelse not (is_1st_term monad_consts ctxt tm)
then NONE
else SOME (wrap_1st_term monad_consts tm NONE true)
(* final handler: always fails *)
fun choke tm =
raise TERM("cannot process term", [tm])
(* Handlers in priority order; the order is significant. *)
val checks = [
check_pure,
check_const,
check_case,
check_if,
check_abs,
check_app,
check_dict,
check_1st_atom,
choke
]
(* "the" cannot fail here because choke raises instead of returning NONE *)
in get_first (fn check => check tm) checks |> the end
(* Rewrites with the theorem Wrap_App_Wrap, but only when the term has the
   shape App (Wrap (Abs ...)) (Wrap ...); fails as Conv.no_conv on every other
   term. *)
fun rewrite_pureapp_beta_conv ctm =
  let
    fun is_wrapped_abs (Const (@{const_name Pure_Monad.Wrap}, _) $ Abs _) = true
      | is_wrapped_abs _ = false
    fun is_wrapped (Const (@{const_name Pure_Monad.Wrap}, _) $ _) = true
      | is_wrapped _ = false
    val matches =
      (case Thm.term_of ctm of
        Const (@{const_name Pure_Monad.App}, _) $ f $ x =>
          is_wrapped_abs f andalso is_wrapped x
      | _ => false)
  in
    if matches then Conv.rewr_conv @{thm Wrap_App_Wrap} ctm
    else Conv.no_conv ctm
  end
(* Lifts tm with an empty dictionary, discards the conversion, and runs the
   lifted term through Syntax.check_term in ctxt. *)
fun monadify monad_consts ctxt tm =
  lift_term monad_consts ctxt [] tm
  |> snd
  |> Syntax.check_term ctxt
(* wrap_head monad_consts head n_args
   Builds head applied to n_args bound variables and encloses it in n_args
   dummy-typed abstractions, each of which is passed through the "return"
   field of monad_consts. The bound-variable indices are reversed so that the
   outermost abstraction supplies the first argument. *)
fun wrap_head (monad_consts: Transform_Const.MONAD_CONSTS) head n_args =
  let
    val binder_typs = replicate n_args dummyT
    val applied_head = list_comb (head, rev (map_range Bound n_args))
    fun add_binder typ tm = #return monad_consts (absdummy typ tm)
  in
    fold_rev add_binder binder_typs applied_head
  end
(* lift_head monad_consts ctxt head n_args
   head must be a Const or a Free. Its first n_args argument types are lifted
   one by one with lift_type'; the remaining function type is lifted as a
   whole with lift_type. The lifted head is always a Free carrying the
   original name and the new type.
   Result: (lifted head, (conversion, wrapped head)), where the second part
   comes from wrap_1st_term with n_args arguments and no inner wrap. *)
fun lift_head monad_consts ctxt head n_args =
  let
    val (head_name, head_typ) =
      if is_Const head then dest_Const head else dest_Free head
    val (arg_typs, body_typ) = strip_type head_typ
    val (applied_typs, remaining_typs) = chop n_args arg_typs
    val applied_typs' = map (lift_type' monad_consts ctxt) applied_typs
    val remaining_typ' = lift_type monad_consts ctxt (remaining_typs ---> body_typ)
    val head' = Free (head_name, applied_typs' ---> remaining_typ')
  in
    (head', wrap_1st_term monad_consts head' (SOME n_args) false)
  end
(* lift_equation monad_consts ctxt (lhs, rhs) memoizer_opt
   Lifts one equation lhs = rhs.
   - The head of lhs is lifted with lift_head for as many arguments as lhs has.
   - In the arguments of lhs, constants get a dummy type and all other atoms
     get their types lifted with lift_type'.
   - The dictionary for lifting rhs maps the original head to its wrapped
     form, and every free variable of the lhs arguments whose type is not an
     atom type to "return" of the retyped variable.
   - If memoizer_opt is SOME m, the lifted rhs becomes m applied to the tuple
     of the original lhs arguments and to the lifted rhs.
   Result: (conversion from lifting rhs, lifted equation as Trueprop, number
   of lhs arguments). *)
fun lift_equation monad_consts ctxt (lhs, rhs) memoizer_opt =
  let
    val (head, args) = strip_comb lhs
    val n_args = length args
    val (head', head_entry) = lift_head monad_consts ctxt head n_args

    (* left-hand side: retype every atom of every argument *)
    fun retype_atom tm =
      if is_Const tm then map_types (K dummyT) tm
      else map_types (lift_type' monad_consts ctxt) tm
    val args' = map (map_aterms retype_atom) args
    val lhs' = list_comb (head', args')

    (* dictionary entries for the non-atom free variables of the arguments *)
    val nonatom_frees =
      filter_out (is_atom_type monad_consts ctxt o snd) (fold Term.add_frees args [])
    fun var_entry (name, typ) =
      (Free (name, typ),
       (K wrap_mark_conv,
        #return monad_consts (Free (name, lift_type' monad_consts ctxt typ))))
    val lift_dict = (head, head_entry) :: map var_entry nonatom_frees

    (* right-hand side, optionally routed through the memoizer *)
    val (rhs_conv, rhsT) = lift_term monad_consts ctxt lift_dict rhs
    val rhs_final =
      (case memoizer_opt of
        NONE => rhsT
      | SOME memoizer => memoizer $ HOLogic.mk_tuple args $ rhsT)

    val eqs' = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs', rhs_final))
  in
    (rhs_conv, eqs', n_args)
  end
end