File ‹stream_fusion.ML›
signature STREAM_FUSION =
sig
  (* querying the declared fusion data *)
  val get_rules: Proof.context -> thm list
  val get_conspats: Proof.context -> (term * thm) list
  val get_unstream: Proof.context -> string list
  val match_consumer: Proof.context -> term -> bool

  (* declaring / undeclaring fusion rules and unstream constants *)
  val add_fusion_rule: thm -> Context.generic -> Context.generic
  val del_fusion_rule: thm -> Context.generic -> Context.generic
  val add_unstream: string -> Context.generic -> Context.generic
  val del_unstream: string -> Context.generic -> Context.generic
  val fusion_add: attribute
  val fusion_del: attribute

  (* rewriting with the fusion rules *)
  val trace: bool Config.T
  val fusion_conv: Proof.context -> conv
  val fusion_simproc: Proof.context -> cterm -> thm option
end;
structure Stream_Fusion : STREAM_FUSION =
struct

(** context data **)

(*rules: fusion rewrite rules, stored as the symmetric form of the declared
    equation (see register below);
  conspats: patterns recognising consumer applications, each paired with the
    declared theorem it was derived from;
  unstream: names of the constants whose occurrence marks unfused stream code*)
type fusion_rules =
  { rules : thm Item_Net.T,
    conspats : (term * thm) Item_Net.T,
    unstream : string list
  }

fun map_fusion_rules f1 f2 f3
  {rules, conspats, unstream}
  =
  {rules = f1 rules,
   conspats = f2 conspats,
   unstream = f3 unstream};

fun map_rules f = map_fusion_rules f I I;
fun map_conspats f = map_fusion_rules I f I;
fun map_unstream f = map_fusion_rules I I f;

structure Fusion_Rules = Generic_Data
(
  type T = fusion_rules;
  (*conspats are indexed by their pattern and compared by their theorem*)
  val empty = {rules = Thm.item_net,
    conspats = Item_Net.init (Thm.eq_thm_prop o apply2 snd) (single o fst),
    unstream = []};
  fun merge
    ({rules = r, conspats = cp, unstream = u},
     {rules = r', conspats = cp', unstream = u'}) =
    {rules = Item_Net.merge (r, r'),
     conspats = Item_Net.merge (cp, cp'),
     unstream = Library.merge (op =) (u, u')}
);

val get_rules = Item_Net.content o #rules o Fusion_Rules.get o Context.Proof;
val get_conspats = Item_Net.content o #conspats o Fusion_Rules.get o Context.Proof;
val get_unstream = #unstream o Fusion_Rules.get o Context.Proof;

(*true iff some registered consumer pattern matches t*)
fun match_consumer ctxt t =
  Context.Proof ctxt
  |> Fusion_Rules.get
  |> #conspats
  |> (fn net => Item_Net.retrieve_matching net t)
  |> not o null


(** classification of declared rules **)

datatype classification = ProducerTransformer | Consumer of term

(*Does a constant whose name is in ts occur anywhere in t?  Every occurrence
  counts, including unapplied ones (e.g. a constant passed as an argument).*)
fun occur_in ts = Term.exists_Const (fn (c, _) => member (op =) ts c);

(*Turn rhs = f a1 ... ad into the pattern f ?x1 ... ?xd: the head is kept and
  every argument is replaced by a fresh schematic variable.*)
fun mk_conspat rhs ctxt =
  let
    val (f, args) = strip_comb rhs
    val d = length args
    (*the head may accept more arguments than rhs supplies (when rhs has a
      function type); only the first d binder types belong to applied arguments*)
    val types = take d (binder_types (fastype_of f))
    val (vfixes, ctxt1) = Variable.variant_fixes (replicate d "x") ctxt
  in
    singleton (Variable.export_terms ctxt1 ctxt)
      (list_comb (f, map Free (vfixes ~~ types)))
  end

(*Classify a HOL equation lhs = rhs:
  - an unstream constant occurs in lhs: producer/transformer rule;
  - otherwise, one occurs in rhs: consumer rule, with a pattern built from rhs;
  - otherwise, or if the theorem is not a HOL equation: NONE.*)
fun classify ctxt thm = case Thm.full_prop_of thm
  of (@{const Trueprop} $ (Const (@{const_name "HOL.eq"}, _) $ lhs $ rhs)) =>
    let val unstream = get_unstream ctxt in
      if occur_in unstream lhs then SOME ProducerTransformer
      else if occur_in unstream rhs then SOME (Consumer (mk_conspat rhs ctxt))
      else NONE
    end
  | _ => NONE;

fun sym thm = thm RS @{thm sym}

(*warn (without failing) about a theorem that cannot be used as a fusion rule*)
fun format_error ctxt thm =
  warning (Pretty.string_of (Pretty.block [
    Pretty.str "Wrong format for fusion rule: ",
    Pretty.brk 2,
    Syntax.pretty_term (Context.proof_of ctxt) (Thm.prop_of thm)]))


(** declarations **)

(*Rules are stored reversed (sym thm); consumer rules additionally register
  their pattern paired with the original theorem.  Ill-formed theorems only
  produce a warning and leave the context unchanged.*)
fun register thm NONE = (fn ctxt =>
      let
        val _ = format_error ctxt thm
      in
        ctxt
      end)
  | register thm (SOME ProducerTransformer) = Fusion_Rules.map (
      map_rules (Item_Net.update (sym thm)))
  | register thm (SOME (Consumer cp)) = Fusion_Rules.map (
      map_rules (Item_Net.update (sym thm)) o map_conspats (Item_Net.update (cp, thm)));

(*inverse of register*)
fun unregister thm NONE = (fn ctxt =>
      let
        val _ = format_error ctxt thm
      in
        ctxt
      end)
  | unregister thm (SOME ProducerTransformer) = Fusion_Rules.map (
      map_rules (Item_Net.remove (sym thm)))
  | unregister thm (SOME (Consumer cp)) = Fusion_Rules.map (
      map_rules (Item_Net.remove (sym thm)) o map_conspats (Item_Net.remove (cp, thm)));

(*Classification depends on the unstream constants currently declared, so
  those must be declared before the rules that mention them.*)
fun add_fusion_rule thm ctxt = register thm (classify (Context.proof_of ctxt) thm) ctxt
fun del_fusion_rule thm ctxt = unregister thm (classify (Context.proof_of ctxt) thm) ctxt

fun add_unstream c = Fusion_Rules.map (map_unstream (insert (op =) c))
fun del_unstream c = Fusion_Rules.map (map_unstream (remove (op =) c))

val fusion_add = Thm.declaration_attribute add_fusion_rule;
val fusion_del = Thm.declaration_attribute del_fusion_rule;

val _ =
  Theory.setup
    (Attrib.setup @{binding "stream_fusion"} (Attrib.add_del fusion_add fusion_del)
      "declaration of a rule for stream fusion" #>
     Global_Theory.add_thms_dynamic
      (@{binding "stream_fusion"}, Item_Net.content o #rules o Fusion_Rules.get));


(** rewriting **)

val trace = Attrib.setup_config_bool @{binding "stream_fusion_trace"} (K false)

(*emit msg () as tracing output only when stream_fusion_trace is enabled*)
fun tracing ctxt msg = if Config.get ctxt trace then Output.tracing (msg ()) else ()

(*rewrite with exactly the declared fusion rules on top of HOL_basic_ss*)
fun fusion_conv ctxt = Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps get_rules ctxt)

(*Simproc: on a term matching a registered consumer pattern, rewrite with the
  fusion rules.  Returns NONE when no pattern matches, or when rewriting either
  changes nothing or leaves an unstream constant behind (fusion incomplete).*)
fun fusion_simproc ctxt ct =
  let
    val matches = match_consumer ctxt (Thm.term_of ct)
  in
    if matches then
      let
        val _ = tracing ctxt (fn _ => Pretty.string_of (Pretty.block
          [Pretty.str "Trying stream fusion on ",
           Pretty.brk 2,
           Syntax.pretty_term ctxt (Thm.term_of ct)]))
        val thm = fusion_conv ctxt ct
        val failed = Thm.is_reflexive thm orelse occur_in (get_unstream ctxt) (Thm.term_of (Thm.rhs_of thm))
        val _ = tracing ctxt (fn _ => Pretty.string_of (Pretty.block
          [Pretty.str (if failed then "FAILED: " else "SUCCEEDED: "),
           Pretty.brk 2,
           Syntax.pretty_term ctxt (Thm.prop_of thm)]))
      in
        if failed then NONE else SOME thm
      end
    else NONE
  end

end;