File ‹applicative.ML›
(* Public interface of the applicative-lifting package: declaration and lookup of
   applicative functors (afun), instantiation of their type/pure/ap/set constants,
   normal-form conversions and tactics, and the proof-method wrappers. *)
signature APPLICATIVE =
sig
  (* Declared applicative functors and access to their name space. *)
  type afun
  val intern: Context.generic -> xstring -> string
  val extern: Context.generic -> string -> xstring
  val afun_of_generic: Context.generic -> string -> afun
  val afun_of: Proof.context -> string -> afun
  val afuns_of_term_generic: Context.generic -> term -> afun list
  val afuns_of_term: Proof.context -> term -> afun list
  val afuns_of_typ_generic: Context.generic -> typ -> afun list
  val afuns_of_typ: Proof.context -> typ -> afun list
  val name_of_afun: afun -> binding
  val unfolds_of_afun: afun -> thm list
  (* An afun instantiated at concrete types, with constructors and destructors for
     lifted terms. *)
  type afun_inst
  val match_afun_inst: Proof.context -> afun -> term * int -> afun_inst
  val import_afun_inst: afun -> Proof.context -> afun_inst * Proof.context
  val inner_sort_of: afun_inst -> sort
  val mk_type: afun_inst -> typ -> typ
  val mk_pure: afun_inst -> typ -> term
  val lift_term: afun_inst -> term -> term
  val mk_ap: afun_inst -> typ * typ -> term
  val mk_comb: afun_inst -> typ -> term * term -> term
  val mk_set: afun_inst -> typ -> term
  val dest_type: Proof.context -> afun_inst -> typ -> typ option
  val dest_type': Proof.context -> afun_inst -> typ -> typ
  val dest_pure: Proof.context -> afun_inst -> term -> term
  val dest_comb: Proof.context -> afun_inst -> term -> term * term
  val infer_comb: Proof.context -> afun_inst -> term * term -> term
  val subst_lift_term: afun_inst -> (term * term) list -> term -> term
  val generalize_lift_terms: afun_inst -> term list -> Proof.context -> term list * Proof.context
  (* Unfold/fold tactics, normal-form conversions, and lifting of rules. *)
  val afun_unfold_tac: Proof.context -> afun -> int -> tactic
  val afun_fold_tac: Proof.context -> afun -> int -> tactic
  val unfold_all_tac: Proof.context -> int -> tactic
  val normalform_conv: Proof.context -> afun -> conv
  val normalize_rel_tac: Proof.context -> afun -> int -> tactic
  val general_normalform_conv: Proof.context -> afun -> cterm * cterm -> thm * thm
  val general_normalize_rel_tac: Proof.context -> afun -> int -> tactic
  val forward_lift_rule: Proof.context -> afun -> thm -> thm
  (* Proof-method wrappers; NONE means: determine the afun from the goal. *)
  val unfold_wrapper_tac: Proof.context -> afun option -> int -> tactic
  val fold_wrapper_tac: Proof.context -> afun option -> int -> tactic
  val normalize_wrapper_tac: Proof.context -> afun option -> int -> tactic
  val lifting_wrapper_tac: Proof.context -> afun option -> int -> tactic
  (* Setup, attributes, parsers and the Isar command. *)
  val setup_combinators: (string * thm) list -> local_theory -> local_theory
  val combinator_rule_attrib: string list option -> attribute
  val parse_opt_afun: afun option context_parser
  val applicative_cmd: (((((binding * string list) * string) * string) * string option) * string option) ->
    local_theory -> Proof.state
  val print_afuns: Proof.context -> unit
  val add_unfold_attrib: xstring option -> attribute
  val forward_lift_attrib: xstring -> attribute
end;
structure Applicative : APPLICATIVE =
struct
open Ctr_Sugar_Util
(* Collect the payloads of the SOME elements of xs, dropping the NONEs.
   The result is in REVERSE order relative to xs.  The only visible caller feeds
   it to Ord_List.make, which sorts, so the order does not matter there. *)
fun fold_options xs = rev (map_filter I xs);
(* Convert a two-element list into a pair; any other length raises General.Size. *)
fun the_pair xs =
  (case xs of
    [x, y] => (x, y)
  | _ => raise General.Size);
(* Split a binary application f x y into (f, (x, y)); raises TERM otherwise. *)
fun strip_comb2 t =
  (case t of
    f $ x $ y => (f, (x, y))
  | _ => raise TERM ("strip_comb2", [t]));
(* Build the pattern t ?uu1 ... ?uun by applying t to n fresh schematic variables.
   The variables take the first n binder types of t.  Their indices are placed
   above maxidx_of_term t.  Returns the variables together with the beta-reduced
   pattern. *)
fun mk_comb_pattern (t, n) =
  let
    val arg_types = take n (binder_types (fastype_of t));
    val base_idx = maxidx_of_term t;
    fun mk_var (T, k) = ((Name.uu, base_idx + k), T);
    val vars = map mk_var (arg_types ~~ (1 upto n));
  in (vars, Term.betapplys (t, map Var vars)) end;
(* Match u against the pattern built by mk_comb_pattern tn.
   Returns the pattern variables and the (type, term) environments.
   Raises TERM if the match fails. *)
fun match_comb_pattern ctxt tn u =
  let
    val (vars, pat) = mk_comb_pattern tn;
    val envs =
      Pattern.match (Proof_Context.theory_of ctxt) (pat, u) (Vartab.empty, Vartab.empty)
        handle Pattern.MATCH => raise TERM ("match_comb_pattern", [u, pat]);
  in (vars, envs) end;
(* Return the terms bound to the pattern variables when matching u against
   mk_comb_pattern tn.  These are the arguments, in order. *)
fun dest_comb_pattern ctxt tn u =
  let
    val (vars, (_, env)) = match_comb_pattern ctxt tn u;
    fun bound_term v = the (Envir.lookup1 env v);
  in map bound_term vars end;
(* Normalise every type inside a term with respect to the type environment tyenv. *)
fun norm_term_types tyenv = Term.map_types (Envir.norm_type_same tyenv);
(* n fresh type variables, all of sort S, via mk_TFrees'. *)
fun mk_TFrees_of n S = mk_TFrees' (replicate n S);
(* Fix a fresh variant of name in ctxt.  Returns it as a Free of type typ,
   together with the extended context. *)
fun mk_Free name typ ctxt =
  let val (fresh_name, ctxt') = yield_singleton Variable.variant_fixes name ctxt;
  in (Free (fresh_name, typ), ctxt') end;
(* Encode a term list as a right-nested HOL tuple terminated by unit:
   [a, b] becomes (a, (b, ())). *)
fun mk_tuple' [] = HOLogic.unit
  | mk_tuple' (t :: ts) = HOLogic.mk_prod (t, mk_tuple' ts);
(* Inverse of mk_tuple'.  Raises TERM on anything that is not such a tuple. *)
fun strip_tuple' t =
  (case t of
    Const (@{const_name Unity}, _) => []
  | Const (@{const_name Pair}, _) $ t1 $ t2 => t1 :: strip_tuple' t2
  | _ => raise TERM ("strip_tuple'", [t]));
(* Build the term eq_on S for a set-typed term S.  The constant is typed as a
   binary predicate on the element type of S. *)
fun mk_eq_on S =
  let
    val setT = fastype_of S;
    val elemT = HOLogic.dest_setT setT;
  in Const (@{const_name eq_on}, setT --> BNF_Util.mk_pred2T elemT elemT) $ S end;
(* A polymorphic type or term, paired with the list of its parameter type
   variables.  Instantiation substitutes the parameters positionally by the
   given types. *)
type poly_type = typ list * typ;
type poly_term = typ list * term;
fun instantiate_poly_type (params, T) args = T |> typ_subst_atomic (params ~~ args);
fun instantiate_poly_term (params, t) args = t |> subst_atomic_types (params ~~ args);
(* Match U against the polymorphic type (tvars, T).  Returns, for each parameter
   in tvars, the type it was bound to (NONE if it stayed unbound).  Raises TYPE if
   U is not an instance of T. *)
fun dest_poly_type ctxt (tvars, T) U =
  let
    val tyenv =
      Sign.typ_match (Proof_Context.theory_of ctxt) (T, U) Vartab.empty
        handle Type.TYPE_MATCH => raise TYPE ("dest_poly_type", [U, T], []);
    fun lookup_param v = Type.lookup tyenv (dest_TVar v);
  in map lookup_param tvars end;
(* Conversions between a poly_type and a poly_term that carries its type as
   Logic.mk_type. *)
fun poly_type_to_term pT = apsnd Logic.mk_type pT;
fun poly_type_of_term pt = apsnd Logic.dest_type pt;
(* Pack a poly_term into one HOL term: the pair of (tuple of parameter type
   markers, term). *)
fun pack_poly_term (tvars, t) =
  let val type_tuple = mk_tuple' (map Logic.mk_type tvars)
  in HOLogic.mk_prod (type_tuple, t) end;
(* Inverse of pack_poly_term. *)
fun unpack_poly_term t =
  HOLogic.dest_prod t |>> (map Logic.dest_type o strip_tuple');
(* Lists of poly_terms are packed as a mk_tuple' tuple of packed entries. *)
fun pack_poly_terms pts = mk_tuple' (map pack_poly_term pts);
fun unpack_poly_terms t = map unpack_poly_term (strip_tuple' t);
(* Match the type U (taken from a term whose maximal schematic index is maxidx)
   against the type stored in the i-th entry of the packed poly_term list pt.
   Returns all entries of pt with the matching instantiation applied.
   - The schematic indices of pt are first shifted above maxidx.
   - The i-th entry's own parameter type variables are removed from the match
     environment, so they are left uninstantiated.
   Raises TYPE if U is not an instance of that entry's type.  The error label is
   this function's own name, consistent with dest_poly_type. *)
fun match_poly_terms_type ctxt (pt, i) (U, maxidx) =
  let
    val thy = Proof_Context.theory_of ctxt;
    (* shift pt's schematic indices above maxidx so they cannot clash with U *)
    val pt' = Logic.incr_indexes ([], maxidx + 1) pt;
    val (tvars, T) = poly_type_of_term (nth (unpack_poly_terms pt') i);
    val tyenv = Sign.typ_match thy (T, U) Vartab.empty
      handle Type.TYPE_MATCH => raise TYPE ("match_poly_terms_type", [U, T], []);
    (* keep the parameter type variables of the matched entry polymorphic *)
    val tyenv' = fold Vartab.delete_safe (map (#1 o dest_TVar) tvars) tyenv;
    val pt'' = Envir.subst_term_types tyenv' pt';
  in unpack_poly_terms pt'' end;
(* Same as match_poly_terms_type, but matches the type of the term t. *)
fun match_poly_terms ctxt pt_i (t, maxidx) =
  let val T = fastype_of t
  in match_poly_terms_type ctxt pt_i (T, maxidx) end;
(* Import the packed poly_term list pt into ctxt and return it unpacked, together
   with the extended context.
   - Every schematic type variable that is not one of its entry's own parameters
     (the tvars component) is replaced by a fresh TFree of the same sort.
   - Every schematic term variable of pt is replaced by a fresh Free; its name is
     derived from the variable name via Name.clean. *)
fun import_poly_terms pt ctxt =
  let
    (* Accumulate the TVars occurring in t that are not among tvars. *)
    fun insert_paramTs (tvars, t) = fold_types (fold_atyps
      (fn TVar v => if member (op =) tvars (TVar v) then I else insert (op =) v
        | _ => I)) t;
    val paramTs = rev (fold insert_paramTs (unpack_poly_terms pt) []);
    (* fresh type frees with the sorts of the collected TVars *)
    val (tfrees, ctxt') = Variable.invent_types (map #2 paramTs) ctxt;
    val instT = TVars.make (paramTs ~~ map TFree tfrees);
    (* schematic term variables, with their types already instantiated by instT *)
    val params = map (apsnd (Term_Subst.instantiateT instT)) (rev (Term.add_vars pt []));
    val (frees, ctxt'') = Variable.variant_fixes (map (Name.clean o #1 o #1) params) ctxt';
    val inst = Vars.make (params ~~ map Free (frees ~~ map #2 params));
    val pt' = Term_Subst.instantiate (instT, inst) pt;
  in (unpack_poly_terms pt', ctxt'') end;
(* Relator-related theorems of an applicative functor, if a relator is given. *)
type rel_thms = {
  pure_transfer: thm,
  ap_rel_fun: thm
};
(* Apply f to both theorems of a rel_thms record. *)
fun map_rel_thms f {pure_transfer = pt, ap_rel_fun = arf} =
  {pure_transfer = f pt, ap_rel_fun = f arf};
(* Theorems of an applicative functor:
   - the hom and ichng laws;
   - the available combinator reductions, keyed by combinator name;
   - optional relator theorems and the relator introduction rules;
   - pure_comp_conv. *)
type afun_thms = {
  hom: thm,
  ichng: thm,
  reds: thm Symtab.table,
  rel_thms: rel_thms option,
  rel_intros: thm list,
  pure_comp_conv: thm
};
(* Apply f to every theorem stored in an afun_thms record. *)
fun map_afun_thms f thms =
  let
    val {hom, ichng, reds, rel_thms, rel_intros, pure_comp_conv} = thms;
  in
    {hom = f hom,
     ichng = f ichng,
     reds = Symtab.map (K f) reds,
     rel_thms = Option.map (map_rel_thms f) rel_thms,
     rel_intros = map f rel_intros,
     pure_comp_conv = f pure_comp_conv}
  end;
(* A declared applicative functor.  It consists of:
   - its name;
   - the packed poly_terms (type, pure, ap, set), see mk_afun_inst;
   - an optional relator;
   - its theorems;
   - a net of unfold rules. *)
datatype afun = AFun of {
  name: binding,
  terms: term,
  rel: term option,
  thms: afun_thms,
  unfolds: thm Item_Net.T
};
fun rep_afun (AFun af) = af;
fun name_of_afun af = #name (rep_afun af);
fun terms_of_afun af = #terms (rep_afun af);
fun rel_of_afun af = #rel (rep_afun af);
fun thms_of_afun af = #thms (rep_afun af);
fun unfolds_of_afun af = Item_Net.content (#unfolds (rep_afun af));
(* Look up a combinator reduction by name (e.g. I, B, C, K, W); NONE if absent. *)
fun red_of_afun af name = Symtab.lookup (#reds (thms_of_afun af)) name;
fun has_red_afun af name = is_some (red_of_afun af name);
(* Create an afun with an empty unfold net. *)
fun mk_afun name terms rel thms =
  AFun {unfolds = Thm.item_net, thms = thms, rel = rel, terms = terms, name = name};
(* Apply one function to each field of an afun. *)
fun map_afun f1 f2 f3 f4 f5 (AFun af) =
  AFun {name = f1 (#name af), terms = f2 (#terms af), rel = f3 (#rel af),
    thms = f4 (#thms af), unfolds = f5 (#unfolds af)};
(* Rebuild an unfold net from scratch with f applied to every stored theorem. *)
fun map_unfolds f net =
  Thm.item_net |> fold (Item_Net.update o f) (Item_Net.content net);
(* Apply the morphism phi to the binding, terms and theorems of an afun. *)
fun morph_afun phi =
  let
    val morph_term = Morphism.term phi;
    val morph_thm = Morphism.thm phi;
  in
    map_afun (Morphism.binding phi) morph_term (Option.map morph_term)
      (map_afun_thms morph_thm) (map_unfolds morph_thm)
  end;
(* Transfer an afun to the theory thy. *)
fun transfer_afun thy = morph_afun (Morphism.transfer_morphism thy);
(* Add the theorems thms to the unfold net of an afun. *)
fun add_unfolds_afun thms = map_afun I I I I (fold Item_Net.update thms);
(* Index patterns for an afun: a pure application, an ap application, and the
   encoded functor type. *)
fun patterns_of_afun af =
  let
    val [Tt, (_, pure), (_, ap), _] = unpack_poly_terms (terms_of_afun af);
    val T = #2 (poly_type_of_term Tt);
    val pure_pattern = #2 (mk_comb_pattern (pure, 1));
    val ap_pattern = #2 (mk_comb_pattern (ap, 2));
  in [pure_pattern, ap_pattern, Net.encode_type T] end;
(* A rule stating that the combinator conclusion is derivable from the
   combinators in strong_premises.  weak_premises is true when the rule was
   registered with weak premises; such a rule is only applicable when have_weak
   holds (see is_applicable_rule). *)
datatype combinator_rule = Combinator_Rule of {
  strong_premises: string Ord_List.T,
  weak_premises: bool,
  conclusion: string,
  eq_thm: thm
};
fun rep_combinator_rule (Combinator_Rule rule) = rule;
fun conclusion_of_rule rule = #conclusion (rep_combinator_rule rule);
fun thm_of_rule rule = #eq_thm (rep_combinator_rule rule);
(* Two rules are equal if they are the same object or their equations coincide. *)
fun eq_combinator_rule (rule1, rule2) =
  if pointer_eq (rule1, rule2) then true
  else Thm.eq_thm (thm_of_rule rule1, thm_of_rule rule2);
(* A rule applies if its weak-premise requirement is met and have_premises
   accepts its strong premises. *)
fun is_applicable_rule rule have_weak have_premises =
  let val {strong_premises, weak_premises, ...} = rep_combinator_rule rule
  in
    if weak_premises andalso not have_weak then false
    else have_premises strong_premises
  end;
(* Apply one function to each field of a combinator rule. *)
fun map_combinator_rule f1 f2 f3 f4 (Combinator_Rule rule) =
  Combinator_Rule
    {strong_premises = f1 (#strong_premises rule),
     weak_premises = f2 (#weak_premises rule),
     conclusion = f3 (#conclusion rule),
     eq_thm = f4 (#eq_thm rule)};
fun transfer_combinator_rule thy rule = map_combinator_rule I I I (Thm.transfer thy) rule;
(* Build a rule from an equation thm of the form lhs == rhs.
   - comb_names maps constant names to combinator names.
   - The conclusion is the combinator named by the head constant of lhs.
   - The premises are the known combinators occurring in rhs, minus the
     listed weak premises. *)
fun mk_combinator_rule comb_names weak_premises thm =
  let
    val lookup_comb = Symtab.lookup comb_names;
    val (lhs, rhs) = Logic.dest_equals (Thm.prop_of thm);
    val conclusion = the (lookup_comb (#1 (dest_Const lhs)));
    val rhs_combs = fold_options (map (lookup_comb o #1) (Term.add_consts rhs []));
    val premises = Ord_List.make fast_string_ord rhs_combs;
    val weak = Ord_List.make fast_string_ord (these weak_premises);
  in
    Combinator_Rule
      {strong_premises = Ord_List.subtract fast_string_ord weak premises,
       weak_premises = is_some weak_premises,
       conclusion = conclusion,
       eq_thm = thm}
  end;
(* Merge two declarations of the same afun.  af1 is kept and the unfold nets are
   united.  For identical objects, Change_Table.SAME signals no change. *)
fun merge_afuns _ (af1, af2) =
  if pointer_eq (af1, af2) then raise Change_Table.SAME
  else
    let val unfolds2 = #unfolds (rep_afun af2)
    in map_afun I I I I (fn unfolds1 => Item_Net.merge (unfolds1, unfolds2)) af1 end;
(* Context data of the package:
   - combinators: named combinator definitions and the registered combinator rules;
   - afuns: the name space table of declared applicative functors;
   - patterns: a net of (afun name, term patterns), queried by match_term. *)
structure Data = Generic_Data
(
  type T = {
    combinators: thm Symtab.table * combinator_rule list,
    afuns: afun Name_Space.table,
    patterns: (string * term list) Item_Net.T
  };
  val empty = {
    combinators = (Symtab.empty, []),
    afuns = Name_Space.empty_table "applicative functor",
    (* pattern entries are identified by their afun name alone *)
    patterns = Item_Net.init (op = o apply2 #1) #2
  };
  (* Conflicting combinator definitions are treated as equal (K true), so that
     part of the merge does not fail; afuns are joined with merge_afuns. *)
  fun merge ({combinators = (cd1, cr1), afuns = a1, patterns = p1},
             {combinators = (cd2, cr2), afuns = a2, patterns = p2}) =
    {combinators = (Symtab.merge (K true) (cd1, cd2), Library.merge eq_combinator_rule (cr1, cr2)),
      afuns = Name_Space.join_tables merge_afuns (a1, a2),
      patterns = Item_Net.merge (p1, p2)};
);
(* Combinator definitions and rules, transferred to the current theory. *)
fun get_combinators context =
  let
    val {combinators = (defs, rules), ...} = Data.get context;
    val thy = Context.theory_of context;
    val defs' = Symtab.map (fn _ => Thm.transfer thy) defs;
    val rules' = map (transfer_combinator_rule thy) rules;
  in (defs', rules') end;
fun get_afun_table context = #afuns (Data.get context);
fun get_afun_space context = Name_Space.space_of_table (get_afun_table context);
fun get_patterns context = #patterns (Data.get context);
(* Apply one function to each component of the package data. *)
fun map_data f_combs f_afuns f_pats {combinators, afuns, patterns} =
  {combinators = f_combs combinators, afuns = f_afuns afuns, patterns = f_pats patterns};
(* Name-space conversion of afun names. *)
fun intern context = Name_Space.intern (get_afun_space context);
fun extern context =
  let val space = get_afun_space context
  in Name_Space.extern (Context.proof_of context) space end;
(* Lookup and update of declared afuns by internal name.  Both fail with an
   error for undeclared names. *)
local
  fun undeclared name = error ("Undeclared applicative functor " ^ quote name);
in
(* Look up an afun and transfer it to the theory of context. *)
fun afun_of_generic context name =
  (case Name_Space.lookup (get_afun_table context) name of
    NONE => undeclared name
  | SOME af => transfer_afun (Context.theory_of context) af);
fun afun_of ctxt = afun_of_generic (Context.Proof ctxt);
(* Apply f to the stored entry of a declared afun. *)
fun update_afun name f context =
  if not (Name_Space.defined (get_afun_table context) name) then undeclared name
  else Data.map (map_data I (Name_Space.map_table_entry name f) I) context;
end;
(* Names of the afuns whose stored patterns match a term or a type. *)
fun match_term context t = map #1 (Item_Net.retrieve_matching (get_patterns context) t);
fun match_typ context T = match_term context (Net.encode_type T);
(* The afuns themselves, looked up by the matched names. *)
fun afuns_of_term_generic context t = map (afun_of_generic context) (match_term context t);
fun afuns_of_term ctxt = afuns_of_term_generic (Context.Proof ctxt);
fun afuns_of_typ_generic context T = map (afun_of_generic context) (match_typ context T);
fun afuns_of_typ ctxt = afuns_of_typ_generic (Context.Proof ctxt);
(* Unfold rules of all declared afuns, transferred to context. *)
fun all_unfolds_of_generic context =
  let
    fun unfolds_of af = map (Thm.transfer'' context) (unfolds_of_afun af);
    fun collect (_, af) acc = unfolds_of af @ acc;
  in Name_Space.fold_table collect (get_afun_table context) [] end;
fun all_unfolds_of ctxt = all_unfolds_of_generic (Context.Proof ctxt);
(* An afun instantiated at concrete types.  T is the functor type with its single
   parameter type variable; pure, ap and set are the corresponding poly_terms. *)
type afun_inst = {
  T: poly_type,
  pure: poly_term,
  ap: poly_term,
  set: poly_term
};
fun mk_afun_inst [T, pure, ap, set] = {T = poly_type_of_term T, pure = pure, ap = ap, set = set};
fun pack_afun_inst ({T, pure, ap, set} : afun_inst) =
  pack_poly_terms [poly_type_to_term T, pure, ap, set];
(* Instantiate af so that its functor type matches the type of the given term. *)
fun match_afun_inst ctxt af t_maxidx =
  mk_afun_inst (match_poly_terms ctxt (terms_of_afun af, 0) t_maxidx);
fun import_afun_inst_raw terms ctxt = import_poly_terms terms ctxt |>> mk_afun_inst;
fun import_afun_inst af = import_afun_inst_raw (terms_of_afun af);
fun inner_sort_of ({T = (tvars, _), ...} : afun_inst) = Type.sort_of_atyp (the_single tvars);
(* Constructors: functor type, pure, ap and set at given argument types. *)
fun mk_type ({T, ...} : afun_inst) U = instantiate_poly_type T [U];
fun mk_pure ({pure, ...} : afun_inst) U = instantiate_poly_term pure [U];
fun mk_ap ({ap, ...} : afun_inst) (T1, T2) = instantiate_poly_term ap [T1, T2];
fun mk_set ({set, ...} : afun_inst) U = instantiate_poly_term set [U];
(* pure t *)
fun lift_term af_inst t =
  let val pure_t = mk_pure af_inst (Term.fastype_of t)
  in Term.betapply (pure_t, t) end;
(* ap t1 t2, where funT is the function type lifted by t1 *)
fun mk_comb af_inst funT (t1, t2) =
  let val ap_t = mk_ap af_inst (dest_funT funT)
  in Term.betapplys (ap_t, [t1, t2]) end;
(* Destructors. *)
fun dest_type ctxt ({T, ...} : afun_inst) U = the_single (dest_poly_type ctxt T U);
fun dest_type' ctxt af_inst U = the_default HOLogic.unitT (dest_type ctxt af_inst U);
fun dest_pure ctxt ({pure = (_, pure_t), ...} : afun_inst) t =
  the_single (dest_comb_pattern ctxt (pure_t, 1) t);
fun dest_comb ctxt ({ap = (_, ap_t), ...} : afun_inst) t =
  the_pair (dest_comb_pattern ctxt (ap_t, 2) t);
(* ap t1 t2, with the function type taken from t1 (dummy if unknown). *)
fun infer_comb ctxt af_inst (t1, t2) =
  let
    val funT =
      (case dest_type ctxt af_inst (fastype_of t1) of
        SOME T => T
      | NONE => dummyT --> dummyT);
  in mk_comb af_inst funT (t1, t2) end;
(* Lift tm into the applicative functor.
   - Subterms with an entry in subst are replaced by their (already lifted) image.
   - Applications containing a replaced part become ap-combinations (mk_comb).
   - Every other maximal subterm is wrapped with pure (lift_term).
   Abstractions are not entered, only looked up. *)
fun subst_lift_term af_inst subst tm =
  let
    fun lifted_or_pure (SOME u) _ = u
      | lifted_or_pure NONE v = lift_term af_inst v;
    fun go (s $ t) =
          (case (go s, go t) of
            (NONE, NONE) => NONE
          | (s_opt, t_opt) =>
              SOME (mk_comb af_inst (fastype_of s) (lifted_or_pure s_opt s, lifted_or_pure t_opt t)))
      | go t = AList.lookup (op aconv) subst t;
  in
    (case go tm of
      SOME tm' => tm'
    | NONE => lift_term af_inst tm)
  end;
(* Add the schematic variables occurring under some abstraction of a term. *)
fun add_lifted_vars tm =
  (case tm of
    s $ t => add_lifted_vars s #> add_lifted_vars t
  | Abs (_, _, body) => Term.add_vars body
  | _ => I);
(* Lift the terms ts.  Schematic variables that never occur under an abstraction
   become fresh Frees of the lifted type.  Returns the lifted terms and the
   context with those Frees fixed. *)
fun generalize_lift_terms af_inst ts ctxt =
  let
    val lifted_vars = fold add_lifted_vars ts [];
    val all_vars = fold Term.add_vars ts [];
    val vars = subtract (op =) lifted_vars all_vars;
    val (var_names, Ts) = split_list vars;
    val (free_names, ctxt') = Variable.variant_fixes (map #1 var_names) ctxt;
    val frees = map Free (free_names ~~ map (mk_type af_inst) Ts);
    val subst = map Var vars ~~ frees;
  in (map (subst_lift_term af_inst subst) ts, ctxt') end;
(* Strip skolem/internal name decorations where present. *)
fun clean_name name = perhaps (perhaps_apply [try Name.dest_skolem, try Name.dest_internal]) name;
(* A readable variable name suggested by a term; "x" as fallback. *)
fun term_to_vname tm =
  (case tm of
    Const (c, _) => Long_Name.base_name c
  | Free (x, _) => clean_name x
  | Var ((x, _), _) => clean_name x
  | _ => "x");
(* Candidate afuns for a relational goal R lhs rhs (after focusing, taking the
   conclusion and beta-eta contraction).
   - precise: look up by the lhs term, falling back to the rhs term.
   - otherwise: look up by the type of lhs. *)
fun afuns_of_rel precise ctxt t =
  let
    val ((_, goal), _) = Variable.focus NONE t ctxt;
    val (_, (lhs, rhs)) = goal
      |> Logic.strip_imp_concl
      |> Envir.beta_eta_contract
      |> HOLogic.dest_Trueprop
      |> strip_comb2;
  in
    if not precise then afuns_of_typ ctxt (fastype_of lhs)
    else
      (case afuns_of_term ctxt lhs of
        [] => afuns_of_term ctxt rhs
      | afs => afs)
  end;
(* Run tac on the given afun, or on the afuns found for the subgoal.
   Fails (no_tac) when none are found or the goal is not of the expected shape. *)
fun AUTO_AFUNS precise tac ctxt opt_af =
  (case opt_af of
    SOME af => tac [af]
  | NONE =>
      SUBGOAL (fn (goal, i) =>
        let val afs = afuns_of_rel precise ctxt goal
        in if null afs then no_tac else tac afs i end
        handle TERM _ => no_tac));
(* Like AUTO_AFUNS, but uses only the first candidate. *)
fun AUTO_AFUN precise tac = AUTO_AFUNS precise (fn afs => tac (hd afs));
(* Convert both operands of binop arg1 arg2 jointly with cv, which returns one
   theorem per operand, and recombine them by congruence. *)
fun binop_par_conv cv ct =
  let
    val (binop_arg1, arg2) = Thm.dest_comb ct;
    val (binop, arg1) = Thm.dest_comb binop_arg1;
    val (th1, th2) = cv (arg1, arg2);
  in Drule.binop_cong_rule binop th1 th2 end;
fun binop_par_conv_tac cv = CONVERSION (HOLogic.Trueprop_conv (binop_par_conv cv));
fun fold_goal_tac ctxt thms = SELECT_GOAL (Simplifier.fold_goals_tac ctxt thms);
(* Unfold/fold the unfold rules of one afun, or unfold those of all afuns. *)
fun afun_unfold_tac ctxt af = Simplifier.rewrite_goal_tac ctxt (unfolds_of_afun af);
fun afun_fold_tac ctxt af = fold_goal_tac ctxt (unfolds_of_afun af);
fun unfold_all_tac ctxt = Simplifier.rewrite_goal_tac ctxt (all_unfolds_of ctxt);
(* Apply cv to the argument of a pure application.
   Raises TERM if ct is not of that form; returns reflexivity if cv changes nothing. *)
fun pure_conv ctxt ({pure = (_, pure), ...} : afun_inst) cv ct =
  let
    val ([var], (tyenv, env)) = match_comb_pattern ctxt (pure, 1) (Thm.term_of ct);
    val thm = cv (Thm.cterm_of ctxt (the (Envir.lookup1 env var)));
  in
    if not (Thm.is_reflexive thm) then
      Drule.arg_cong_rule (Thm.cterm_of ctxt (Envir.subst_term_types tyenv pure)) thm
    else Conv.all_conv ct
  end;
(* Apply cv1 and cv2 to the two operands of an ap application.
   Raises TERM if ct is not of that form; returns reflexivity if neither changes. *)
fun ap_conv ctxt ({ap = (_, ap), ...} : afun_inst) cv1 cv2 ct =
  let
    val ([var1, var2], (tyenv, env)) = match_comb_pattern ctxt (ap, 2) (Thm.term_of ct);
    fun lookup v = the (Envir.lookup1 env v);
    val (arg1, arg2) = (lookup var1, lookup var2);
    val thm1 = cv1 (Thm.cterm_of ctxt arg1);
    val thm2 = cv2 (Thm.cterm_of ctxt arg2);
  in
    if not (Thm.is_reflexive thm1 andalso Thm.is_reflexive thm2) then
      Drule.binop_cong_rule (Thm.cterm_of ctxt (Envir.subst_term_types tyenv ap)) thm1 thm2
    else Conv.all_conv ct
  end;
(* Conversion bringing the applicative expression ct into normal form for af.
   Judging from the rewrite steps, the normal form is a single pure function
   applied via ap to the opaque leaves.
   Uses the hom, ichng and pure_comp_conv theorems and the I and B reductions.
   The reductions are fetched eagerly with the, so a missing I or B raises
   Option.Option. *)
fun normalform_conv ctxt af ct =
  let
    val {hom, ichng, pure_comp_conv, ...} = thms_of_afun af;
    val the_red = the o red_of_afun af;
    (* I reduction read right-to-left: wraps an opaque leaf *)
    val leaf_conv = Conv.rewr_conv (mk_meta_eq (the_red "I") |> Thm.symmetric);
    (* hom: combines two pure operands *)
    val merge_conv = Conv.rewr_conv (mk_meta_eq hom);
    (* interchange law *)
    val swap_conv = Conv.rewr_conv (mk_meta_eq ichng);
    (* B reduction read right-to-left: re-associates nested applications *)
    val rotate_conv = Conv.rewr_conv (mk_meta_eq (the_red "B") |> Thm.symmetric);
    val pure_rotate_conv = Conv.rewr_conv (mk_meta_eq pure_comp_conv);
    val af_inst = match_afun_inst ctxt af (Thm.term_of ct, Thm.maxidx_of_cterm ct);
    (* convert only the function operand of an ap application *)
    fun left_conv cv = ap_conv ctxt af_inst cv Conv.all_conv;
    (* pure operand applied to a normal form; the merge_conv branch ends the recursion *)
    fun norm_pure_nf ct =
      ((pure_rotate_conv then_conv left_conv norm_pure_nf) else_conv merge_conv) ct;
    (* normal form applied to a pure operand: interchange, then norm_pure_nf *)
    val norm_nf_pure = swap_conv then_conv norm_pure_nf;
    (* normal form applied to normal form: rotate with B and recurse, otherwise
       treat the right operand as pure *)
    fun norm_nf_nf ct = ((rotate_conv then_conv
        left_conv (left_conv norm_pure_nf then_conv norm_nf_nf)) else_conv
      norm_nf_pure) ct;
    (* ap: normalise both operands, then combine; pure: unchanged; else: leaf *)
    fun normalize ct = ((ap_conv ctxt af_inst normalize normalize then_conv norm_nf_nf) else_conv
      pure_conv ctxt af_inst Conv.all_conv else_conv
      leaf_conv) ct;
  in normalize ct end;
(* Normalise both operands of a binary relation goal under Trueprop. *)
val normalize_rel_tac = binop_par_conv_tac o apply2 oo normalform_conv;
(* Applicative expression tree:
   - Pure: a pure subterm;
   - ApVar: an opaque subterm numbered by its index;
   - Ap: an application of two trees. *)
datatype apterm =
    Pure of term
  | ApVar of int * term
  | Ap of apterm * apterm;
(* Prepend the (index, term) pairs of all opaque subterms to an accumulator. *)
fun apterm_vars tt =
  (case tt of
    Pure _ => I
  | ApVar v => cons v
  | Ap (tt1, tt2) => apterm_vars tt1 #> apterm_vars tt2);
(* Does any opaque subterm with an index in vs occur in tt? *)
fun occurs_any vs tt =
  (case tt of
    Pure _ => false
  | ApVar (i, _) => member (op =) vs i
  | Ap (tt1, tt2) => occurs_any vs tt1 orelse occurs_any vs tt2);
(* Rebuild the HOL term of an apterm, inferring the ap instances. *)
fun term_of_apterm ctxt af_inst tt =
  (case tt of
    Pure t => t
  | ApVar (_, t) => t
  | Ap (tt1, tt2) =>
      infer_comb ctxt af_inst (term_of_apterm ctxt af_inst tt1, term_of_apterm ctxt af_inst tt2));
(* Parse t into an apterm.
   - Opaque subterms are numbered consecutively starting at counter i.
   - Returns the tree and the next unused counter value. *)
fun apterm_of_term ctxt af_inst t i =
  (case try (dest_comb ctxt af_inst) t of
    SOME (t1, t2) =>
      let
        val (tt1, i1) = apterm_of_term ctxt af_inst t1 i;
        val (tt2, i2) = apterm_of_term ctxt af_inst t2 i1;
      in (Ap (tt1, tt2), i2) end
  | NONE =>
      if can (dest_pure ctxt af_inst) t then (Pure t, i)
      else (ApVar (i, t), i + 1));
(* Align the opaque subterms of the two apterms t1 and t2.
   Returns a list of (indices, term) groups; all indices in a group denote the
   same opaque term.  Which rearrangements are allowed depends on the combinator
   reductions of af:
   - C: occurrences may be reordered (sorted by shared instance number);
   - W: adjacent occurrences of the same instance are merged;
   - K: occurrences present on one side only are tolerated.
   Without K the two sides must correspond exactly; otherwise an error reports
   unbalanced or mismatched opaque terms. *)
fun consolidate ctxt af (t1, t2) =
  let
    (* Assign an instance number to the occurrence (i, t): equal terms share the
       number stored in insts, new terms get the counter j. *)
    fun common_inst (i, t) (j, insts) = case Termtab.lookup insts t of
        SOME k => (((i, t), k), (j, insts))
      | NONE => (((i, t), j), (j + 1, Termtab.update (t, j) insts));
    (* the state threads through both sides, so instance numbers are shared *)
    val (vars, _) = (0, Termtab.empty)
      |> fold_map common_inst (apterm_vars t1 [])
      ||>> fold_map common_inst (apterm_vars t2 []);
    (* Group runs of consecutive entries with the same instance number d. *)
    fun merge_adjacent (([], _), _) [] = []
      | merge_adjacent ((is, t), d) [] = [((is, t), d)]
      | merge_adjacent (([], _), _) (((i, t), d)::xs) = merge_adjacent (([i], t), d) xs
      | merge_adjacent ((is, t), d) (((i', t'), d')::xs) = if d = d'
          then merge_adjacent ((i'::is, t), d) xs
          else ((is, t), d) :: merge_adjacent (([i'], t'), d') xs;
    (* Find the first entry with instance d; join it with the given group and
       return (skipped entries followed by the joined group, remaining entries). *)
    fun align _ [] = NONE
      | align ((i, t), d) (((i', t'), d')::xs) = if d = d'
          then SOME ([((i @ i', t), d)], xs)
          else Option.map (apfst (cons ((i', t'), d'))) (align ((i, t), d) xs);
    (* Lenient pairing (K available): join equal heads, otherwise try to align
       one head into the other list, otherwise keep both heads separately. *)
    fun merge ([], ys) = ys
      | merge (xs, []) = xs
      | merge ((xs as ((is1, t1), d1)::xs'), ys as (((is2, t2), d2)::ys')) = if d1 = d2
          then ((is1 @ is2, t1), d1) :: merge (xs', ys')
          else case (align ((is2, t2), d2) xs, align ((is1, t1), d1) ys) of
              (SOME (zs, xs''), NONE) => zs @ merge (xs'', ys')
            | (NONE, SOME (zs, ys'')) => zs @ merge (xs', ys'')
            | _ => ((is1, t1), d1) :: ((is2, t2), d2) :: merge (xs', ys');
    fun unbalanced vs = error ("Unbalanced opaque terms " ^
      commas_quote (map (Syntax.string_of_term ctxt o #2 o #1) vs));
    fun mismatch (t1, t2) = error ("Mismatched opaque terms " ^
      quote (Syntax.string_of_term ctxt t1) ^ " and " ^ quote (Syntax.string_of_term ctxt t2));
    (* Strict pairing (no K): both lists must agree entry by entry. *)
    fun same ([], []) = []
      | same ([], ys) = unbalanced ys
      | same (xs, []) = unbalanced xs
      | same ((((i1, t1), d1)::xs), (((i2, t2), d2)::ys)) = if d1 = d2
          then ((i1 @ i2, t1), d1) :: same (xs, ys)
          else mismatch (t1, t2);
  in vars
    |> has_red_afun af "C" ? apply2 (sort (int_ord o apply2 #2))
    |> apply2 (if has_red_afun af "W"
        then merge_adjacent (([], Term.dummy), 0)
        else map (apfst (apfst single)))
    |> (if has_red_afun af "K" then merge else same)
    |> map #1
  end;
(* Congruence under ap: from thm1 (a == a') and thm2 (b == b') derive
   ap a b == ap a' b'.  The ap instance is taken from the type of a. *)
fun ap_cong ctxt af_inst thm1 thm2 =
  let
    val lhs_T = Thm.typ_of_cterm (Thm.lhs_of thm1);
    val funT =
      (case dest_type ctxt af_inst lhs_T of
        SOME T => T
      | NONE => dummyT --> dummyT);
    val ap_inst = Thm.cterm_of ctxt (mk_ap af_inst (dest_funT funT));
  in Drule.binop_cong_rule ap_inst thm1 thm2 end;
(* ap_cong followed by rewriting the result with rewr. *)
fun rewr_subst_ap ctxt af_inst rewr thm1 thm2 =
  let val cong = ap_cong ctxt af_inst thm1 thm2
  in Thm.transitive cong (Conv.rewr_conv rewr (Thm.rhs_of cong)) end;
(* If tt contains no opaque subterms, combine all its pure parts with merge_thm
   and return the resulting equation; NONE otherwise. *)
fun merge_pures ctxt af_inst merge_thm tt =
  (case tt of
    Pure t => SOME (Thm.reflexive (Thm.cterm_of ctxt t))
  | ApVar _ => NONE
  | Ap (tt1, tt2) =>
      merge_pures ctxt af_inst merge_thm tt1 |> Option.mapPartial (fn thm1 =>
        merge_pures ctxt af_inst merge_thm tt2 |> Option.map (fn thm2 =>
          rewr_subst_ap ctxt af_inst merge_thm thm1 thm2)));
(* Raised by eliminate when an internal invariant is violated. *)
exception ASSERT of string;
(* Abstract the opaque occurrences v out of the apterm tt, using the combinator
   reductions of af read right-to-left.
   - v is a list of ApVar indices, all standing for the term v_tm.
   - Returns (tt', thm), where thm rewrites the term of tt.  Judging from the use
     in general_normalform_conv, the rhs of thm is an ap application with argument
     v_tm and a function part corresponding to tt'.
   - I and B are required.  C, K and W are fetched with the only on the paths
     that need them; a missing one raises Option.Option. *)
fun eliminate ctxt (af, af_inst) tt (v, v_tm) =
  let
    val {hom, ichng, ...} = thms_of_afun af;
    val the_red = the o red_of_afun af;
    val hom_conv = mk_meta_eq hom;
    val ichng_conv = mk_meta_eq ichng;
    (* a reduction read right-to-left, so rewriting introduces the combinator *)
    val mk_combI = Thm.symmetric o mk_meta_eq;
    val id_conv = mk_combI (the_red "I");
    val comp_conv = mk_combI (the_red "B");
    val flip_conv = Option.map mk_combI (red_of_afun af "C");
    val const_conv = Option.map mk_combI (red_of_afun af "K");
    val dup_conv = Option.map mk_combI (red_of_afun af "W");
    val rewr_subst_ap = rewr_subst_ap ctxt af_inst;
    (* the pure term n steps down the function spine of the rhs of thm *)
    fun extract_comb n thm = Pure (thm |> Thm.rhs_of |> funpow n Thm.dest_arg1 |> Thm.term_of);
    (* trivial step: tt with reflexivity on its term *)
    fun refl_step tt = (tt, Thm.reflexive (Thm.cterm_of ctxt (term_of_apterm ctxt af_inst tt)));
    (* combine two steps under ap and rewrite with the combinator equation def;
       the new apterm is that combinator applied to tt1 and tt2 *)
    fun comb2_step def (tt1, thm1) (tt2, thm2) =
      let val thm = rewr_subst_ap def thm1 thm2;
      in (Ap (Ap (extract_comb 3 thm, tt1), tt2), thm) end;
    val B_step = comb2_step comp_conv;
    (* apply interchange (ichng) to a step with a pure right operand, then B *)
    fun swap_B_step (tt1, thm1) thm2 =
      let
        val thm3 = rewr_subst_ap ichng_conv thm1 thm2;
        val thm4 = Thm.transitive thm3 (Conv.rewr_conv comp_conv (Thm.rhs_of thm3));
      in (Ap (Ap (extract_comb 3 thm4, extract_comb 1 thm3), tt1), thm4) end;
    (* a single occurrence of v: introduce I *)
    fun I_step tm =
      let val thm = Conv.rewr_conv id_conv (Thm.cterm_of ctxt tm)
      in (extract_comb 1 thm, thm) end;
    (* v on both sides, C unavailable: combine with B, then duplicate with W *)
    fun W_step s1 s2 =
      let
        val (Ap (Ap (tt1, tt2), tt3), thm1) = B_step s1 s2;
        val thm2 = Conv.rewr_conv comp_conv (Thm.rhs_of thm1 |> funpow 2 Thm.dest_arg1);
        val thm3 = merge_pures ctxt af_inst hom_conv tt3 |> the;
        val (tt4, thm4) = swap_B_step (Ap (Ap (extract_comb 3 thm2, tt1), tt2), thm2) thm3;
        val var = Thm.rhs_of thm1 |> Thm.dest_arg;
        val thm5 = rewr_subst_ap (the dup_conv) thm4 (Thm.reflexive var);
        val thm6 = Thm.transitive thm1 thm5;
      in (Ap (extract_comb 2 thm6, tt4), thm6) end;
    (* v on both sides, C available: combine with C, then B and W *)
    fun S_step s1 s2 =
      let
        val (Ap (Ap (tt1, tt2), tt3), thm1) = comb2_step (the flip_conv) s1 s2;
        val thm2 = Conv.rewr_conv comp_conv (Thm.rhs_of thm1 |> Thm.dest_arg1);
        val var = Thm.rhs_of thm1 |> Thm.dest_arg;
        val thm3 = rewr_subst_ap (the dup_conv) thm2 (Thm.reflexive var);
        val thm4 = Thm.transitive thm1 thm3;
        val tt = Ap (extract_comb 2 thm4, Ap (Ap (extract_comb 3 thm2, Ap (tt1, tt2)), tt3));
      in (tt, thm4) end;
    (* v does not occur in tt: introduce K to ignore an argument tm *)
    fun K_step tt tm =
      let
        val ct = Thm.cterm_of ctxt tm;
        val T_opt = Term.fastype_of tm |> dest_type ctxt af_inst |> Option.map (Thm.ctyp_of ctxt);
        val thm = Thm.instantiate' [T_opt] [SOME ct]
          (Conv.rewr_conv (the const_conv) (term_of_apterm ctxt af_inst tt |> Thm.cterm_of ctxt))
      in (Ap (extract_comb 2 thm, tt), thm) end;
    fun unreachable _ = raise ASSERT "eliminate: assertion failed";
    (* recurse according to which operands contain an occurrence of v *)
    fun elim (Pure _) = unreachable ()
      | elim (ApVar (i, t)) = if exists (fn x => x = i) v then I_step t else unreachable ()
      | elim (Ap (t1, t2)) = (case (occurs_any v t1, occurs_any v t2) of
            (false, false) => unreachable ()
          | (false, true) => B_step (refl_step t1) (elim t2)
          | (true, false) => (case merge_pures ctxt af_inst hom_conv t2 of
                SOME thm => swap_B_step (elim t1) thm
              | NONE => comb2_step (the flip_conv) (elim t1) (refl_step t2))
          | (true, true) => if is_some flip_conv
              then S_step (elim t1) (elim t2)
              else W_step (elim t1) (elim t2));
  in if occurs_any v tt
    then elim tt
    else K_step tt v_tm
  end;
(* Rewrite both sides of a relation, (ct1, ct2), into a common normal form.
   Steps:
   1. Parse both sides into apterms and align their opaque subterms (consolidate).
   2. Eliminate the aligned groups one at a time.
   3. Merge the remaining pure parts with hom.
   Returns one equation per side. *)
fun general_normalform_conv ctxt af (ct1, ct2) =
  let
    val t1 = Thm.term_of ct1;
    val t2 = Thm.term_of ct2;
    val maxidx = Int.max (Thm.maxidx_of_cterm ct1, Thm.maxidx_of_cterm ct2);
    val af_inst = match_afun_inst ctxt af (t1, maxidx);
    val (apt1, next) = apterm_of_term ctxt af_inst t1 0;
    val (apt2, _) = apterm_of_term ctxt af_inst t2 next;
    val var_groups = consolidate ctxt af (apt1, apt2);
    val hom_thm = mk_meta_eq (#hom (thms_of_afun af));
    fun elim_all tt [] = the (merge_pures ctxt af_inst hom_thm tt)
      | elim_all tt (group :: rest) =
          let
            val (tt', rule1) = eliminate ctxt (af, af_inst) tt group;
            val rule2 = elim_all tt' rest;
            val var_tm = #2 (dest_comb ctxt af_inst (Thm.term_of (Thm.rhs_of rule1)));
            val rule3 = ap_cong ctxt af_inst rule2 (Thm.reflexive (Thm.cterm_of ctxt var_tm));
          in Thm.transitive rule1 rule3 end;
  in (elim_all apt1 var_groups, elim_all apt2 var_groups) end;
(* Normalise both operands of a binary relation goal under Trueprop. *)
fun general_normalize_rel_tac ctxt af = binop_par_conv_tac (general_normalform_conv ctxt af);
(* Rename the outermost parameters of subgoal i of st to names. *)
fun rename_params names i st =
  let
    val (_, prems, subgoal, concl) = Thm.dest_state (st, i);
    val renamed = Logic.list_rename_params names subgoal;
    val prop = Logic.list_implies (prems @ [renamed], concl);
  in Thm.renamed_prop prop st end;
(* Reduce a normalised relation goal to a statement about the lifted heads.
   - Apply the relator introduction rules of af, then ext / rel_fun_eq_onI /
     UNIV_E as long as they apply.
   - Rename the new parameters after the rhs arguments; names in renames take
     precedence, otherwise term_to_vname is used. *)
fun head_cong_tac ctxt af renames =
  let
    val rel_intros = #rel_intros (thms_of_afun af);
    fun term_name tm =
      (case AList.lookup (op aconv) renames tm of
        NONE => term_to_vname tm
      | SOME n => n);
    fun gather_vars' af_inst tm =
      (case try (dest_comb ctxt af_inst) tm of
        NONE => []
      | SOME (t1, t2) => term_name t2 :: gather_vars' af_inst t1);
    fun gather_vars (prop as Const (@{const_name Trueprop}, _) $ (_ $ rhs)) =
          rev (gather_vars' (match_afun_inst ctxt af (rhs, maxidx_of_term prop)) rhs)
      | gather_vars _ = [];
  in
    SUBGOAL (fn (subgoal, i) =>
      REPEAT_DETERM (resolve_tac ctxt rel_intros i)
      THEN REPEAT_DETERM
        (resolve_tac ctxt [ext, @{thm rel_fun_eq_onI}] i ORELSE eresolve_tac ctxt [@{thm UNIV_E}] i)
      THEN PRIMITIVE (rename_params (gather_vars subgoal) i))
  end;
(* Lift an equation thm (lhs = rhs) into the applicative functor af.
   The lifted statement is proved by general normalisation, head congruence and
   thm itself.  The result is folded with the unfold rules of af. *)
fun forward_lift_rule ctxt af thm =
  let
    val rule = Object_Logic.rulify ctxt thm;
    val (af_inst, ctxt_inst) = import_afun_inst af ctxt;
    val (prop, ctxt_Ts) = yield_singleton Variable.importT_terms (Thm.prop_of rule) ctxt_inst;
    val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop prop);
    val ([lhs', rhs'], ctxt_lifted) = generalize_lift_terms af_inst [lhs, rhs] ctxt_Ts;
    val goal = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs', rhs'));
    val (goal', ctxt') = yield_singleton (Variable.import_terms true) goal ctxt_lifted;
    fun prove_tac {prems, context = goal_ctxt} =
      HEADGOAL (general_normalize_rel_tac goal_ctxt af
        THEN' head_cong_tac goal_ctxt af []
        THEN' resolve_tac goal_ctxt [prems MRS rule]);
    val lifted_thm =
      Goal.prove ctxt' [] [] goal' prove_tac |> singleton (Variable.export ctxt' ctxt);
  in Simplifier.fold_rule ctxt (unfolds_of_afun af) lifted_thm end;
(* Attribute wrapper around forward_lift_rule for the afun named name. *)
fun forward_lift_attrib name =
  Thm.rule_attribute [] (fn context => fn thm =>
    let
      val af = afun_of_generic context (intern context name);
      val ctxt = Context.proof_of context;
    in forward_lift_rule ctxt af thm end);
(* Simplify with the unfold rules of all afuns found by type (or of the given afun). *)
fun unfold_wrapper_tac ctxt =
  AUTO_AFUNS false
    (fn afs => Simplifier.safe_asm_full_simp_tac (ctxt addsimps maps unfolds_of_afun afs))
    ctxt;
(* Fold back the unfold rules of the first afun found by term (or of the given afun). *)
fun fold_wrapper_tac ctxt =
  AUTO_AFUN true (fn af => fold_goal_tac ctxt (unfolds_of_afun af)) ctxt;
(* Common driver of the normalize and lifting proof methods; tac is the core
   normalisation tactic.  Steps:
   1. Strip leading allI quantifiers.
   2. In the focused subgoal, unfold the rules of all afuns found by type (or of
      the given afun).
   3. For the first afun found by term (or the given afun): unfold,
      beta-eta-convert, run tac, and apply head_cong_tac with the focused
      parameter names.
   4. Finally rewrite with Drule.triv_forall_equality. *)
fun WRAPPER tac ctxt opt_af =
  REPEAT_DETERM o resolve_tac ctxt [@{thm allI}] THEN'
  Subgoal.FOCUS (fn {context = ctxt, params, ...} =>
    (* map each focused parameter term back to its name, for head_cong_tac *)
    let val renames = map (swap o apsnd Thm.term_of) params
    in
      AUTO_AFUNS false (EVERY' o map (afun_unfold_tac ctxt)) ctxt opt_af 1 THEN
      AUTO_AFUN true (fn af =>
        afun_unfold_tac ctxt af THEN'
        CONVERSION Drule.beta_eta_conversion THEN'
        tac ctxt af THEN'
        head_cong_tac ctxt af renames) ctxt opt_af 1
    end) ctxt THEN'
  Simplifier.rewrite_goal_tac ctxt [Drule.triv_forall_equality];
val normalize_wrapper_tac = WRAPPER normalize_rel_tac;
val lifting_wrapper_tac = WRAPPER general_normalize_rel_tac;
(* Optional afun name, interned and resolved in the current context. *)
val parse_opt_afun = Scan.peek (fn context =>
  Scan.option Parse.name
    >> Option.map (fn name => afun_of_generic context (intern context name)));
(* Declaration that adds the named combinator definitions, transformed by phi.
   Uses Symtab.insert with K false, so any already present name raises Symtab.DUP. *)
fun declare_combinators combs phi =
  let
    val entries = map (apsnd (Morphism.thm phi)) combs;
    fun add_combs (defs, rules) = (fold (Symtab.insert (K false)) entries defs, rules);
  in Data.map (map_data add_combs I I) end;
fun setup_combinators combs =
  Local_Theory.declaration {syntax = false, pervasive = false, pos = ⌂} (declare_combinators combs);
fun combinator_of_red thm =
  let
    val (lhs, _) = Logic.dest_equals (Thm.prop_of thm);
    val (head, _) = strip_comb lhs;
  in #1 (dest_Const head) end;
(* Adds a combinator rule built from the meta-equation thm to the context data.
   Rejects equations whose right-hand side mentions type variables that do not occur on
   the left-hand side.  Combinator names are recovered from the head constants of the
   registered definitions. *)
fun register_combinator_rule weak_premises thm context =
  let
    val prop = Thm.prop_of thm;
    val (lhs, rhs) = Logic.dest_equals prop;
    val lhs_tvars = Term.add_tvars lhs [];
    val extra_tvars = filter_out (member (op =) lhs_tvars) (Term.add_tvars rhs []);
    fun fail_extra_tvars () =
      [Pretty.str "Combinator equation",
        Pretty.quote (Syntax.pretty_term (Context.proof_of context) prop),
        Pretty.str "has additional type variables on right-hand side."]
      |> Pretty.breaks |> Pretty.block |> Pretty.string_of |> error;
    val _ = if null extra_tvars then () else fail_extra_tvars ();
    val defs = #1 (#combinators (Data.get context));
    val names_by_const =
      Symtab.dest defs
      |> map (fn (comb, def_thm) => (combinator_of_red def_thm, comb))
      |> Symtab.make;
    val rule = mk_combinator_rule names_by_const weak_premises thm;
    fun add_rule (defs', rules) = (defs', insert eq_combinator_rule rule rules);
  in Data.map (map_data add_rule I I) context end;
(* Attribute registering the theorem as a combinator rule. *)
fun combinator_rule_attrib weak_premises =
  Thm.declaration_attribute (register_combinator_rule weak_premises);
(* Closes the ordered list of combinator names combs under rules: a rule adds its
   conclusion when that is not yet known and is_applicable_rule accepts it with respect
   to the known names.  Passes are repeated until one adds nothing. *)
fun combinator_closure rules have_weak combs =
  let
    fun step rule (known, progress) =
      let val concl = conclusion_of_rule rule in
        if Ord_List.member fast_string_ord known concl orelse
          not (is_applicable_rule rule have_weak
            (fn prems => Ord_List.subset fast_string_ord (prems, known)))
        then (known, progress)
        else (Ord_List.insert fast_string_ord concl known, true)
      end;
    fun saturate known =
      let val (known', progress) = fold step rules (known, false)
      in if progress then saturate known' else known end;
  in saturate combs end;
(* Derives the lifted reduction of a combinator.
   base_thm: a meta-equation lhs == rhs.  Its type variables are instantiated with TFrees
     of the functor's inner sort (mk_TFrees_of), and both sides are lifted with
     generalize_lift_terms.  combinator_red_closure passes the combinator's definition
     from comb_defs here.
   eq_thm: rewrite rule applied (top-down sweep) to the left-hand side of the lifted
     equation only.
   red_thms: rewrite rules for the remaining goal, which is then closed by refl.
   Returns the proven HOL equation, exported from the auxiliary context back to ctxt. *)
fun derive_combinator_red ctxt af_inst red_thms (base_thm, eq_thm) =
  let
    val base_prop = Thm.prop_of base_thm;
    val tvars = Term.add_tvars base_prop [];
    val (Ts, ctxt_Ts) = mk_TFrees_of (length tvars) (inner_sort_of af_inst) ctxt;
    val base_prop' = base_prop |> Term_Subst.instantiate (TVars.make (tvars ~~ Ts), Vars.empty);
    val (lhs, rhs) = Logic.dest_equals base_prop';
    val ([lhs', rhs'], ctxt') = generalize_lift_terms af_inst [lhs, rhs] ctxt_Ts;
    val lifted_prop = (lhs', rhs') |> HOLogic.mk_eq |> HOLogic.mk_Trueprop;
    (* rewrite the left-hand side with eq_thm, leave the right-hand side untouched *)
    val unfold_comb_conv = HOLogic.Trueprop_conv
      (HOLogic.eq_conv (Conv.top_sweep_rewrs_conv [eq_thm] ctxt') Conv.all_conv);
    fun tac goal_ctxt =
      HEADGOAL (CONVERSION unfold_comb_conv THEN'
      Simplifier.rewrite_goal_tac goal_ctxt red_thms THEN'
      resolve_tac goal_ctxt [@{thm refl}]);
  in
    singleton (Variable.export ctxt' ctxt) (Goal.prove ctxt' [] [] lifted_prop (tac o #context))
  end;
(* Computes the weak variants of the reduction meta-equation strong_red.
   The schematic variables of its left-hand side are visited in the order given by
   rev (Term.add_vars lhs []).  For each variable whose type dest_type accepts (does not
   raise on), the variable is instantiated with lift_term of a fresh free variable and the
   result is rewritten with merge_thm.  The free variable gets the argument type returned
   by dest_type, or a freshly invented type of the inner sort if that is NONE.
   Instantiations accumulate: each step starts from the previous result.
   Returns all intermediate theorems including strong_red itself, most recent first. *)
fun weak_red_closure ctxt (af_inst, merge_thm) strong_red =
  let
    val (lhs, _) = Thm.prop_of strong_red |> Logic.dest_equals;
    val vars = rev (Term.add_vars lhs []);
    fun closure [] prev thms = (prev::thms)
      | closure ((v, af_T)::vs) prev thms =
          (case try (dest_type ctxt af_inst) af_T of
            (* variable type not accepted by dest_type: skip it *)
            NONE => closure vs prev thms
          | SOME T_opt =>
            let
              val (T, ctxt') = (case T_opt of
                  NONE => yield_singleton Variable.invent_types (inner_sort_of af_inst) ctxt
                    |>> TFree
                | SOME T => (T, ctxt));
              val (v', ctxt'') = mk_Free (#1 v) T ctxt';
              val pure_v = Thm.cterm_of ctxt'' (lift_term af_inst v');
              (* replace the variable by the lifted free variable in the previous result *)
              val next = Drule.instantiate_normalize (TVars.empty, Vars.make [((v, af_T), pure_v)]) prev;
              val next' = Simplifier.rewrite_rule ctxt'' [merge_thm] next;
              (* generalise the auxiliary free variable again *)
              val next'' = singleton (Variable.export ctxt'' ctxt) next';
            in closure vs next'' (prev::thms) end);
   in closure vars strong_red [] end;
(* Saturates the table combs (combinator name -> HOL reduction equation).
   weak_reds are extra rewrite rules for the derivations.  Whether there are any decides
   the have_weak flag passed to is_applicable_rule (defined elsewhere).
   Rules are applied repeatedly while some rule fires.  A rule fires when its conclusion
   is not yet in the table and is_applicable_rule accepts it, given a check that every
   premise name is already in the table.  The new reduction is derived with
   derive_combinator_red from the conclusion's definition in comb_defs; `the` raises
   Option if there is none.  Each reduction, together with its weak_red_closure
   variants, is added to the rewrite set used for later derivations.
   Returns the final table. *)
fun combinator_red_closure ctxt (comb_defs, rules) (af_inst, merge_thm) weak_reds combs =
  let
    val have_weak = not (null weak_reds);
    val red_thms0 = Symtab.fold (fn (_, thm) => cons (mk_meta_eq thm)) combs weak_reds;
    val red_thms = flat (map (weak_red_closure ctxt (af_inst, merge_thm)) red_thms0);
    fun apply rule ((cs, rs), changed) =
      if not (Symtab.defined cs (conclusion_of_rule rule)) andalso
        is_applicable_rule rule have_weak (forall (Symtab.defined cs))
      then
        let
          val conclusion = conclusion_of_rule rule;
          val def = the (Symtab.lookup comb_defs conclusion);
          val new_red_thm = derive_combinator_red ctxt af_inst rs (def, thm_of_rule rule);
          val new_red_thms = weak_red_closure ctxt (af_inst, merge_thm) (mk_meta_eq new_red_thm);
        in ((Symtab.update (conclusion, new_red_thm) cs, new_red_thms @ rs), true) end
      else ((cs, rs), changed);
    (* iterate passes over all rules until a pass makes no change *)
    fun loop xs =
      (case fold apply rules (xs, false) of
        (xs', true) => loop xs'
      | (_, false) => xs);
  in #1 (loop (combs, red_thms)) end;
(* Checks the raw terms of an applicative functor declaration and normalises them.
   - raw_pure must have type 'a => T1 for a type variable 'a.
   - raw_ap must have type T23 => T2 => T3 such that T2, T3 and T23 each unify with T1,
     with 'a renamed apart.  The resulting argument types must be type variables 'b, 'c
     and 'b => 'c respectively.
   - The intersection of the sorts of these type variables must be closed under function
     types.
   - raw_set (optional) must be a function whose domain matches the range type of pure
     and whose range is a set over pure's argument type variable.  The default is
     %x. UNIV.
   - raw_rel (optional) must have type ('b => 'c => bool) => (X => Y => bool), with
     distinct type variables 'b and 'c, where X and Y are the corresponding functor types.
   All terms must be closed: no locally fixed variables, no hidden type variables.
   Returns the packed (functor type, pure, ap, set) terms and the optional packed rel;
   every violation is reported with error. *)
fun mk_terms ctxt (raw_pure, raw_ap, raw_rel, raw_set) =
  let
    val thy = Proof_Context.theory_of ctxt;
    val show_typ = quote o Syntax.string_of_typ ctxt;
    val show_term = quote o Syntax.string_of_term ctxt;
    (* generalise t; reject locally free variables and hidden type variables *)
    fun closed_poly_term t =
      let val poly_t = singleton (Variable.polymorphic ctxt) t;
      in case Term.add_vars (singleton
          (Variable.export_terms (Proof_Context.augment t ctxt) ctxt) t) [] of
          [] => (case (Term.hidden_polymorphism poly_t) of
              [] => poly_t
            | _ => error ("Hidden type variables in term " ^ show_term t))
        | _ => error ("Locally free variables in term " ^ show_term t)
      end;
    val pure = closed_poly_term raw_pure;
    val (tvar, T1) = fastype_of pure |> dest_funT |>> dest_TVar
      handle TYPE _ => error ("Bad type for pure: " ^ show_typ (fastype_of pure));
    val maxidx_pure = maxidx_of_term pure;
    (* shift ap's schematic indices above those of pure so their variables stay apart *)
    val ap = Logic.incr_indexes ([], maxidx_pure + 1) (closed_poly_term raw_ap);
    fun bad_ap _ = error ("Bad type for ap: " ^ show_typ (fastype_of ap));
    val (T23, (T2, T3)) = fastype_of ap |> dest_funT ||> dest_funT
      handle TYPE _ => bad_ap ();
    val maxidx_common = Term.maxidx_term ap maxidx_pure;

    fun no_unifier (T, U) = error ("Unable to infer common functor type from " ^
      commas (map show_typ [T, U]));
    (* unify T with T1, where pure's type variable is replaced by a fresh one (argT);
       returns argT together with the extended unifier *)
    fun unify_ap_type T (tyenv, maxidx) =
      let
        val argT = TVar ((Name.aT, maxidx + 1), []);
        val T1' = Term_Subst.instantiateT (TVars.make [(tvar, argT)]) T1;
        val (tyenv', maxidx') = Sign.typ_unify thy (T1', T) (tyenv, maxidx + 1)
          handle Type.TUNIFY => no_unifier (T1', T);
      in (argT, (tyenv', maxidx')) end;
    val (ap_args, (ap_env, maxidx_env)) =
      fold_map unify_ap_type [T2, T3, T23] (Vartab.empty, maxidx_common);
    val [T2_arg, T3_arg, T23_arg] = map (Envir.norm_type ap_env) ap_args;
    val (tvar2, tvar3) = (dest_TVar T2_arg, dest_TVar T3_arg) handle TYPE _ => bad_ap ();
    val _ = if T23_arg = T2_arg --> T3_arg then () else bad_ap ();
    val sort = foldl1 (Sign.inter_sort thy) (map #2 [tvar, tvar2, tvar3]);
    val _ = Sign.of_sort thy (Term.aT sort --> Term.aT sort, sort) orelse
      error ("Sort constraint " ^ quote (Syntax.string_of_sort ctxt sort) ^
        " not closed under function types");
    (* give all three type variables the common sort *)
    fun update_sort (v, S) (tyenv, maxidx) =
      (Vartab.update_new (v, (S, TVar ((Name.aT, maxidx + 1), sort))) tyenv, maxidx + 1);
    val (common_env, _) = fold update_sort [tvar, tvar2, tvar3] (ap_env, maxidx_env);
    val tvar' = Envir.norm_type common_env (TVar tvar);
    val pure' = norm_term_types common_env pure;
    val (tvar2', tvar3') = apply2 (Envir.norm_type common_env) (T2_arg, T3_arg);
    val ap' = norm_term_types common_env ap;
    fun bad_set set = error ("Bad type for set: " ^ show_typ (fastype_of set));
    fun mk_set set =
      let
        (* dest_funT reports a non-function type via TYPE, which becomes a bad_set error;
           domain_type would escape with an uninformative Match exception instead *)
        val (set_argT, _) = dest_funT (fastype_of set) handle TYPE _ => bad_set set;
        val tyenv = Sign.typ_match thy (set_argT, range_type (fastype_of pure'))
          Vartab.empty
          handle Type.TYPE_MATCH => bad_set set;
        val set' = Envir.subst_term_types tyenv set;
        val set_tvar = fastype_of set' |> range_type |> HOLogic.dest_setT |> dest_TVar
          handle TYPE _ => bad_set set;
        val _ = if Term.eq_tvar (dest_TVar tvar', set_tvar) then () else bad_set set;
      in ([tvar'], set') end
    val set = (case raw_set of
        NONE => ([tvar'], Abs ("x", tvar', HOLogic.mk_UNIV tvar'))
      | SOME t => mk_set (closed_poly_term t));
    val terms = Term_Subst.zero_var_indexes (pack_poly_terms
      [poly_type_to_term ([tvar'], range_type (fastype_of pure')),
      ([tvar'], pure'), ([tvar2', tvar3'], ap'), set]);

    fun bad_rel rel = error ("Bad type for rel: " ^ show_typ (fastype_of rel));
    fun mk_rel rel =
      let
        val ((T1, T2), (T1_af, T2_af)) = fastype_of rel
          |> dest_funT
          |>> BNF_Util.dest_pred2T
          ||> BNF_Util.dest_pred2T;
        val _ = (dest_TVar T1; dest_TVar T2);
        val _ = if T1 = T2 then bad_rel rel else ();
        val af_inst = mk_afun_inst (match_poly_terms_type ctxt (terms, 0) (T1_af, maxidx_of_term rel));
        val (T1', T2') = apply2 (dest_type ctxt af_inst) (T1_af, T2_af);
        val _ = if (is_none T1' andalso is_none T2') orelse (T1' = SOME T1 andalso T2' = SOME T2)
          then () else bad_rel rel;
      in Term_Subst.zero_var_indexes (pack_poly_terms [([T1, T2], rel)]) end
      handle TYPE _ => bad_rel rel;
    val rel = Option.map (mk_rel o closed_poly_term) raw_rel;
  in (terms, rel) end;
(* Relator introduction rules: pure_transfer turned into a rule with rel_funD, followed
   by ap_rel_fun as is. *)
fun mk_rel_intros {pure_transfer, ap_rel_fun} =
  [pure_transfer RS @{thm rel_funD}, ap_rel_fun];
(* Assembles the theorem record of an applicative functor instance.
   hom_thm: homomorphism law (HOL equation, used here as a meta rewrite rule);
   ichng_thm: interchange law, stored unchanged;
   reds: table of combinator reductions.  It must contain "B", which is used for
     pure_comp_conv; `the` raises Option otherwise;
   rel_axioms: optional relator axioms {pure_transfer, ap_rel_fun}.
   Derived here:
   - pure_comp_conv:
       ap (pure g) (ap f x) = ap (ap (pure (%f x. g (f x))) f) x  (ap = mk_comb)
     It is proved by rewriting with the reversed B reduction and the homomorphism law.
   - rel_intros: instances of arg_cong for pure and arg1_cong for ap, followed by the
     rules from mk_rel_intros when rel_axioms is given. *)
fun mk_afun_thms ctxt af_inst (hom_thm, ichng_thm, reds, rel_axioms) =
  let
    val pure_comp_conv =
      let
        val ([T1, T2, T3], ctxt_Ts) = mk_TFrees_of 3 (inner_sort_of af_inst) ctxt;
        val (((g, f), x), ctxt') = ctxt_Ts
          |> mk_Free "g" (T2 --> T3)
          ||>> mk_Free "f" (mk_type af_inst (T1 --> T2))
          ||>> mk_Free "x" (mk_type af_inst T1);
        val comb = mk_comb af_inst;
        val lhs = comb (T2 --> T3) (lift_term af_inst g, comb (T1 --> T2) (f, x));
        (* B_g = %f x. g (f x) *)
        val B_g = Abs ("f", T1 --> T2, Abs ("x", T1, Term.betapply (g, Bound 1 $ Bound 0)));
        val rhs = comb (T1 --> T3)
          (comb ((T1 --> T2) --> T1 --> T3) (lift_term af_inst B_g, f), x);
        val prop = HOLogic.mk_eq (lhs, rhs) |> HOLogic.mk_Trueprop;
        val merge_rule = mk_meta_eq hom_thm;
        val B_intro = the (Symtab.lookup reds "B") |> mk_meta_eq |> Thm.symmetric;
        fun tac goal_ctxt =
          HEADGOAL (Simplifier.rewrite_goal_tac goal_ctxt [B_intro, merge_rule] THEN'
          resolve_tac goal_ctxt [@{thm refl}]);
      in
        singleton (Variable.export ctxt' ctxt) (Goal.prove ctxt' [] [] prop (tac o #context))
      end;
    (* congruence rules for pure and for ap, instantiated with fresh frees *)
    val eq_intros =
      let
        val ([T1, T2], ctxt_Ts) = mk_TFrees_of 2 (inner_sort_of af_inst) ctxt;
        val T12 = mk_type af_inst (T1 --> T2);
        val (((((x, y), x'), f), g), ctxt') = ctxt_Ts
          |> mk_Free "x" T1
          ||>> mk_Free "y" T1
          ||>> mk_Free "x" (mk_type af_inst T1)
          ||>> mk_Free "f" T12
          ||>> mk_Free "g" T12;
        val pure_fun = mk_pure af_inst T1;
        val pure_cong = Drule.infer_instantiate' ctxt'
          (map (SOME o Thm.cterm_of ctxt') [x, y, pure_fun]) @{thm arg_cong};
        val ap_fun = mk_ap af_inst (T1, T2);
        val ap_cong1 = Drule.infer_instantiate' ctxt'
          (map (SOME o Thm.cterm_of ctxt')  [f, g, ap_fun, x']) @{thm arg1_cong};
      in Variable.export ctxt' ctxt [pure_cong, ap_cong1] end;
    val rel_intros = case rel_axioms of
        NONE => []
      | SOME axioms => mk_rel_intros axioms;
  in
    {hom = hom_thm,
      ichng = ichng_thm,
      reds = reds,
      rel_thms = rel_axioms,
      rel_intros = eq_intros @ rel_intros,
      pure_comp_conv = pure_comp_conv}
  end;
(* Supplies n type frees.  The first min (n, length Ts) are taken from Ts, and the rest
   are created with mk_TFrees_of (sort S) in ctxt.  Also returns the updated context and
   Ts extended by the newly created frees, so that later calls can reuse them. *)
fun reuse_TFrees n S (ctxt, Ts) =
  let
    val reuse_count = Int.min (n, length Ts);
    val reused = take reuse_count Ts;
    val (fresh, ctxt') = mk_TFrees_of (n - reuse_count) S ctxt;
  in (reused @ fresh, (ctxt', Ts @ fresh)) end;
(* Builds the lifted proposition for a combinator definition thm, a meta-equation
   c a0 ... an == rhs.
   Type variables are instantiated with type frees of the functor's inner sort, reusing
   those in ctxt_Ts (reuse_TFrees).
   Each argument ai whose 0-based index is NOT in lift_pos is paired with a schematic
   variable of the lifted type (mk_type) and substituted via subst_lift_term (defined
   elsewhere).  Arguments whose index is in lift_pos keep their original type.
   Returns the HOL equation of the transformed sides, universally quantified with
   Logic.all over the argument variables (first argument outermost), together with the
   updated (context, type frees) pair.
   Assumes the lhs arguments are schematic variables: dest_Var fails otherwise. *)
fun mk_comb_prop lift_pos thm af_inst ctxt_Ts =
  let
    val base = Thm.prop_of thm;
    val tvars = Term.add_tvars base [];
    val (Ts, (ctxt', Ts')) = reuse_TFrees (length tvars) (inner_sort_of af_inst) ctxt_Ts;
    val base' = base
      |> Term_Subst.instantiate (TVars.make (tvars ~~ Ts), Vars.empty);
    val (lhs, rhs) = Logic.dest_equals base';
    val (_, lhs_args) = strip_comb lhs;
    val lift_var = Var o apsnd (mk_type af_inst) o dest_Var;
    (* lhs_args' is accumulated in reverse order, so folding Logic.all over it below
       binds the last argument innermost *)
    val (lhs_args', subst) = fold_index (fn (i, v) =>
      if member (op =) lift_pos i then apfst (cons v)
      else map_prod (cons (lift_var v)) (cons (v, lift_var v))) lhs_args ([], []);
    val (lhs', rhs') = apply2 (subst_lift_term af_inst subst) (lhs, rhs);
    val lifted = (lhs', rhs') |> HOLogic.mk_eq |> HOLogic.mk_Trueprop;
  in (fold Logic.all lhs_args' lifted, (ctxt', Ts')) end;
(* Homomorphism law as a proposition:
     !!f x. ap (pure f) (pure x) = pure (f x)   (ap = mk_comb, pure = lift_term)
   with f :: T1 => T2 and x :: T1 for type frees T1, T2 supplied by reuse_TFrees.
   Also returns the updated (context, type frees) pair. *)
fun mk_homomorphism_prop af_inst ctxt_Ts =
  let
    val ([T1, T2], (ctxt', Ts')) = reuse_TFrees 2 (inner_sort_of af_inst) ctxt_Ts;
    val (f, ctxt_f) = mk_Free "f" (T1 --> T2) ctxt';
    val (x, _) = mk_Free "x" T1 ctxt_f;
    val lift = lift_term af_inst;
    val eq = HOLogic.mk_eq (mk_comb af_inst (T1 --> T2) (lift f, lift x), lift (f $ x));
  in (Logic.all f (Logic.all x (HOLogic.mk_Trueprop eq)), (ctxt', Ts')) end;
(* Interchange law as a proposition:
     !!f x. ap f (pure x) = ap (pure (%f. f x)) f   (ap = mk_comb, pure = lift_term)
   with f of the lifted type of T1 => T2 and x :: T1.  Also returns the updated
   (context, type frees) pair. *)
fun mk_interchange_prop af_inst ctxt_Ts =
  let
    val ([T1, T2], (ctxt', Ts')) = reuse_TFrees 2 (inner_sort_of af_inst) ctxt_Ts;
    val fT = T1 --> T2;
    val (f, ctxt_f) = mk_Free "f" (mk_type af_inst fT) ctxt';
    val (x, _) = mk_Free "x" T1 ctxt_f;
    val apply_to_x = Abs ("f", fT, Bound 0 $ x);
    val lhs = mk_comb af_inst fT (f, lift_term af_inst x);
    val rhs = mk_comb af_inst (fT --> T2) (lift_term af_inst apply_to_x, f);
    val prop = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs));
  in (Logic.all f (Logic.all x prop), (ctxt', Ts')) end;
(* Proof obligations for a relator rel_inst (rel_inst instantiated at the predicate's
   argument types is written af_rel below):
   - pure_prop: !!R. rel_fun R (af_rel R) (pure) (pure)
   - ap_prop:   !!R f g x. af_rel (rel_fun (eq_on (set x)) R) f g ==>
                  af_rel R (ap f x) (ap g x)
   Here ap = mk_comb, pure = mk_pure and set = mk_set.  Type frees come from reuse_TFrees.
   Returns both propositions and the updated (context, type frees) pair. *)
fun mk_rel_props (af_inst, rel_inst) ctxt_Ts =
  let
    (* apply the relator to the binary predicate tm *)
    fun mk_af_rel tm =
      let val (T1, T2) = BNF_Util.dest_pred2T (fastype_of tm);
      in betapply (instantiate_poly_term rel_inst [T1, T2], tm) end;
    val ([T1, T2, T3], (ctxt', Ts')) = reuse_TFrees 3 (inner_sort_of af_inst) ctxt_Ts;
    val (pure_R, _) = mk_Free "R" (T1 --> T2 --> @{typ bool}) ctxt';
    val rel_pure = BNF_Util.mk_rel_fun pure_R (mk_af_rel pure_R) $ mk_pure af_inst T1 $
      mk_pure af_inst T2;
    val pure_prop = Logic.all pure_R (HOLogic.mk_Trueprop rel_pure);
    val ((((f, g), x), ap_R), _) = ctxt'
      |> mk_Free "f" (mk_type af_inst (T1 --> T2))
      ||>> mk_Free "g" (mk_type af_inst (T1 --> T3))
      ||>> mk_Free "x" (mk_type af_inst T1)
      ||>> mk_Free "R" (T2 --> T3 --> @{typ bool});
    val fun_rel = BNF_Util.mk_rel_fun (mk_eq_on (mk_set af_inst T1 $ x)) ap_R;
    val rel_ap = Logic.mk_implies (HOLogic.mk_Trueprop (mk_af_rel fun_rel $ f $ g),
      HOLogic.mk_Trueprop (mk_af_rel ap_R $ mk_comb af_inst (T1 --> T2) (f, x) $
        mk_comb af_inst (T1 --> T3) (g, x)));
    val ap_prop = fold_rev Logic.all [ap_R, f, g, x] rel_ap;
  in ([pure_prop, ap_prop], (ctxt', Ts')) end;
(* Derives the interchange law from the reduction of combinator T.
   The lifted T equation, built with lift_pos [0], is proved by rewriting with the
   reversed merge_thm and resolving with the T reduction.  The combinators are then
   unfolded and the result is flipped with sym.  Raises Option if "T" is missing from
   comb_defs or reds. *)
fun mk_interchange ctxt ((comb_defs, _), comb_unfolds) (af_inst, merge_thm) reds =
  let
    val T_def = the (Symtab.lookup comb_defs "T");
    val T_red = the (Symtab.lookup reds "T");
    val (weak_prop, (ctxt', _)) = mk_comb_prop [0] T_def af_inst (ctxt, []);
    fun prove_tac {context = goal_ctxt, ...} =
      HEADGOAL (Simplifier.rewrite_goal_tac goal_ctxt [Thm.symmetric merge_thm] THEN'
        resolve_tac goal_ctxt [T_red]);
    val weak_red =
      Goal.prove ctxt' [] [] weak_prop prove_tac
      |> singleton (Variable.export ctxt' ctxt);
    val unfolded = Simplifier.rewrite_rule ctxt comb_unfolds weak_red;
  in unfolded RS sym end;
(* Derives the weak reductions of a new functor instance:
   - the lifted C equation with lift_pos [1];
   - the lifted C equation with lift_pos [2];
   - the lifted uncurry_pair.
   The proofs use a provisional anonymous afun, built from the given laws and the
   reductions unfolded with comb_unfolds, together with normalize_wrapper_tac.
   The results are returned as meta-equations. *)
fun mk_weak_reds ctxt ((comb_defs, _), comb_unfolds) af_inst (hom_thm, ichng_thm, reds) =
  let
    val unfold = Simplifier.rewrite_rule ctxt comb_unfolds;
    val provisional_thms =
      mk_afun_thms ctxt af_inst (hom_thm, ichng_thm, Symtab.map (K unfold) reds, NONE);
    val af = mk_afun Binding.empty (pack_afun_inst af_inst) NONE provisional_thms;
    fun prove_tac {context = goal_ctxt, ...} =
      HEADGOAL (normalize_wrapper_tac goal_ctxt (SOME af) THEN'
        Simplifier.rewrite_goal_tac goal_ctxt comb_unfolds THEN'
        resolve_tac goal_ctxt [refl]);
    fun weak_red comb lift_pos =
      let
        val comb_def = the (Symtab.lookup comb_defs comb);
        val (prop, (ctxt', _)) = mk_comb_prop lift_pos comb_def af_inst (ctxt, []);
      in
        Goal.prove ctxt' [] [] prop prove_tac
        |> singleton (Variable.export ctxt' ctxt)
        |> mk_meta_eq
      end;
    val uncurry_thm = mk_meta_eq (forward_lift_rule ctxt af @{thm uncurry_pair});
  in [weak_red "C" [1], weak_red "C" [2], uncurry_thm] end;
(* Computes the final reduction table for the combinators of a new functor instance.
   1. reds0: closure of the user-proved reductions (user_combs ~~ user_thms) under the
      combinator rules, without weak reductions.
   2. The interchange law is taken from ichng_thms when the user proved it; otherwise it
      is derived from the "T" reduction by mk_interchange.
   3. The weak reductions (mk_weak_reds) feed a second closure.
   4. All reductions are rewritten with comb_unfolds.
   Returns (reduction table, interchange theorem).
   NOTE(review): ichng_thms is expected to have length 0 or 1; longer lists raise Match. *)
fun mk_comb_reds ctxt combss af_inst user_combs (hom_thm, user_thms, ichng_thms) =
  let
    val ((comb_defs, comb_rules), comb_unfolds) = combss;
    val merge_thm = mk_meta_eq hom_thm;
    val user_reds = Symtab.make (user_combs ~~ user_thms);
    val reds0 = combinator_red_closure ctxt (comb_defs, comb_rules) (af_inst, merge_thm) [] user_reds;
    val ichng_thm = case ichng_thms of
        [] => singleton (Variable.export ctxt ctxt) (mk_interchange ctxt combss (af_inst, merge_thm) reds0)
      | [thm] => thm;
    val weak_reds = mk_weak_reds ctxt combss af_inst (hom_thm, ichng_thm, reds0);
    val reds1 = combinator_red_closure ctxt (comb_defs, comb_rules) (af_inst, merge_thm) weak_reds reds0;
    val unfold = Simplifier.rewrite_rule ctxt comb_unfolds;
  in (Symtab.map (K unfold) reds1, ichng_thm) end;
(* Notes the theorems of af in the local theory, each qualified with the functor's name:
   - homomorphism, interchange and afun_rel_intros;
   - one pure_<C>_conv per combinator reduction C;
   - pure_transfer and ap_rel_fun_cong when relator theorems are present. *)
fun note_afun_thms af =
  let
    val {hom, ichng, reds, rel_thms, rel_intros, ...} = thms_of_afun af;
    val law_notes =
      [("homomorphism", [hom]), ("interchange", [ichng]), ("afun_rel_intros", rel_intros)];
    val red_notes =
      Symtab.dest reds |> map (fn (comb, red) => ("pure_" ^ comb ^ "_conv", [red]));
    val rel_notes = (case rel_thms of
        NONE => []
      | SOME {pure_transfer, ap_rel_fun, ...} =>
          [("pure_transfer", [pure_transfer]), ("ap_rel_fun_cong", [ap_rel_fun])]);
    val qualify = Binding.qualify true (Binding.name_of (name_of_afun af));
    fun to_note (name, thms) = ((qualify (Binding.name name), []), [(thms, [])]);
  in Local_Theory.notes (map to_note (law_notes @ red_notes @ rel_notes)) #> snd end;
(* Local-theory declaration that adds af (transformed by the declaration's morphism) to
   the functor name space and its patterns to the pattern net. *)
fun register_afun af =
  let
    fun add_afun phi context {combinators, afuns, patterns} =
      let
        val af' = morph_afun phi af;
        val (name, afuns') = Name_Space.define context true (name_of_afun af', af') afuns;
      in
        {combinators = combinators, afuns = afuns',
          patterns = Item_Net.update (name, patterns_of_afun af') patterns}
      end;
    fun decl phi context = Data.map (add_afun phi context) context;
  in Local_Theory.declaration {syntax = false, pervasive = false, pos = ⌂} decl end;
(* Implementation of the applicative command.
   It checks the declared combinator names (flags) and the raw pure/ap/rel/set terms,
   then works out which laws must be proved.  The proof it opens has these goals:
   - the homomorphism law;
   - one lifted reduction per combinator in user_combs (the flags plus those of B and I
     that are not derivable from the flags);
   - the interchange law, unless T is derivable from user_combs;
   - the two relator laws, if a relator was given.
   The goals are first rewritten with the combinator_repr theorems.  After the proof,
   after_qed derives the remaining reductions, registers the functor and notes its
   theorems. *)
fun applicative_cmd (((((name, flags), raw_pure), raw_ap), raw_rel), raw_set) lthy =
  let
    val comb_unfolds = Named_Theorems.get lthy @{named_theorems combinator_unfold};
    val comb_reprs = Named_Theorems.get lthy @{named_theorems combinator_repr};
    val (comb_defs, comb_rules) = get_combinators (Context.Proof lthy);
    val _ = fold (fn name =>
      if Symtab.defined comb_defs name then I else error ("Unknown combinator " ^ quote name))
      flags ();
    val _ = if has_duplicates op = flags
      then warning "Ignoring duplicate combinators"
      else ();
    val user_combs0 = Ord_List.make fast_string_ord flags;
    val raw_pure' = Syntax.read_term lthy raw_pure;
    val raw_ap' = Syntax.read_term lthy raw_ap;
    val raw_rel' = Option.map (Syntax.read_term lthy) raw_rel;
    val raw_set' = Option.map (Syntax.read_term lthy) raw_set;
    val (terms, rel) = mk_terms lthy (raw_pure', raw_ap', raw_rel', raw_set');
    (* B and I must be available; add whichever the flags do not already imply *)
    val derived_combs0 = combinator_closure comb_rules false user_combs0;
    val required_combs = Ord_List.make fast_string_ord ["B", "I"];
    val user_combs = Ord_List.union fast_string_ord user_combs0
      (Ord_List.subtract fast_string_ord derived_combs0 required_combs);
    val derived_combs1 = combinator_closure comb_rules false user_combs;
    val derived_combs2 = combinator_closure comb_rules true derived_combs1;
    (* a user combinator is redundant if dropping it does not shrink the closure *)
    fun is_redundant comb = eq_list (op =) (derived_combs2,
      (combinator_closure comb_rules true (Ord_List.remove fast_string_ord comb user_combs)));
    val redundant_combs = filter is_redundant user_combs;
    val _ = if null redundant_combs then () else
      warning ("Redundant combinators: " ^ commas redundant_combs);
    val prove_interchange = not (Ord_List.member fast_string_ord derived_combs1 "T");
    val (af_inst, ctxt_af) = import_afun_inst_raw terms lthy;
    
    val (rel_insts, ctxt_inst) = (case rel of
        NONE => (NONE, ctxt_af)
      | SOME r =>
          let
            val (rel_inst, ctxt') = import_poly_terms r ctxt_af |>> the_single;
            val T = fastype_of (#2 rel_inst) |> range_type |> domain_type;
            val af_inst = match_poly_terms_type ctxt' (terms, 0) (T, ~1) |> mk_afun_inst;
          in (SOME (af_inst, rel_inst), ctxt') end);
    (* goal groups, in the order destructured by after_qed: homomorphism, user
       combinator reductions, interchange (possibly empty), relator laws (possibly empty) *)
    val mk_propss = [apfst single o mk_homomorphism_prop af_inst,
      fold_map (fn comb => mk_comb_prop [] (the (Symtab.lookup comb_defs comb)) af_inst) user_combs,
      if prove_interchange then apfst single o mk_interchange_prop af_inst else pair [],
      if is_some rel then mk_rel_props (the rel_insts) else pair []];
    val (propss, (ctxt_Ts, _)) = fold_map I mk_propss (ctxt_inst, []);
    fun repr_tac ctxt = Simplifier.rewrite_goals_tac ctxt comb_reprs;
    fun after_qed thmss lthy' =
      let
        val [[hom_thm], user_thms, ichng_thms, rel_thms] = map (Variable.export lthy' ctxt_inst) thmss;
        val (reds, ichng_thm) = mk_comb_reds ctxt_inst ((comb_defs, comb_rules), comb_unfolds)
          af_inst user_combs (hom_thm, user_thms, ichng_thms);
        val rel_axioms = case rel_thms of
            [] => NONE
          | [thm1, thm2] => SOME {pure_transfer = thm1, ap_rel_fun = thm2};
        val af_thms = mk_afun_thms ctxt_inst af_inst (hom_thm, ichng_thm, reds, rel_axioms);
        val af_thms = map_afun_thms (singleton (Variable.export ctxt_inst lthy)) af_thms;
        val af = mk_afun name terms rel af_thms;
      in lthy
        |> register_afun af
        |> note_afun_thms af
      end;
  in
    Proof.theorem NONE after_qed ((map o map) (rpair []) propss) ctxt_Ts
    |> Proof.refine (Method.Basic (SIMPLE_METHOD o repr_tac))
    |> Seq.the_result ""
  end;
(* Prints every registered applicative functor, sorted by name.  Each entry shows the
   functor type, the pure, ap and set terms, the rel term if present, and the combinators
   that have reductions. *)
fun print_afuns ctxt =
  let
    fun labelled label prt = Pretty.block [Pretty.str label, Pretty.brk 1, prt];
    val quoted_term = Pretty.quote o Syntax.pretty_term ctxt;
    fun pretty_afun (name, af) =
      let
        val [pT, (_, pure), (_, ap), (_, set)] = unpack_poly_terms (terms_of_afun af);
        val ([tvar], T) = poly_type_of_term pT;
        val rel = Option.map (#2 o the_single o unpack_poly_terms) (rel_of_afun af);
        val combinators = Symtab.keys (#reds (thms_of_afun af));
        val header = Pretty.block [Pretty.str name, Pretty.str ":", Pretty.brk 1,
          Pretty.quote (Syntax.pretty_typ ctxt T), Pretty.brk 1, Pretty.str "of", Pretty.brk 1,
          Syntax.pretty_typ ctxt tvar];
        val term_lines =
          [labelled "pure:" (quoted_term pure), labelled "ap:" (quoted_term ap),
            labelled "set:" (quoted_term set)];
        val rel_lines = (case rel of
            NONE => []
          | SOME rel' => [labelled "rel:" (quoted_term rel')]);
        val comb_line = Pretty.block
          (Pretty.str "combinators:" :: Pretty.brk 1 :: Pretty.commas (map Pretty.str combinators));
      in Pretty.block (Pretty.fbreaks (header :: term_lines @ rel_lines @ [comb_line])) end;
    val afuns = sort_by fst (Name_Space.fold_table cons (get_afun_table (Context.Proof ctxt)) []);
  in Pretty.writeln (Pretty.big_list "Registered applicative functors:" (map pretty_afun afuns)) end;
(* Registers thm, a HOL equation, as an unfold rule.
   - With SOME n, the functor named n is used.
   - With NONE, the functors are those match_typ finds for the type of the equation's
     left-hand side; if there are none, an error is raised.
   Each selected name is passed to afun_of_generic first, presumably so that unknown
   names fail early.  Then the meta-equation form of thm is added to every selected
   functor. *)
fun add_unfold_thm name thm context =
  let
    val lhs = #1 (HOLogic.dest_eq (HOLogic.dest_Trueprop (Thm.prop_of thm)))
      handle TERM _ => error "Not an equation";
    fun matching_afuns () =
      (case match_typ context (Term.fastype_of lhs) of
        [] => error "Unable to determine applicative functor instance"
      | ns => ns);
    val names = (case name of
        SOME n => [intern context n]
      | NONE => matching_afuns ());
    val _ = List.app (ignore o afun_of_generic context) names;
    val thm' = mk_meta_eq thm;
  in fold (fn n => update_afun n (add_unfolds_afun [thm'])) names context end;
(* Attribute registering the theorem as an unfold rule (see add_unfold_thm). *)
fun add_unfold_attrib name =
  Thm.declaration_attribute (fn thm => fn context => add_unfold_thm name thm context);
end;