# Theory Refine_Imperative_HOL.Sepref_Constraints

```theory Sepref_Constraints
imports Main Automatic_Refinement.Refine_Lib Sepref_Basic
begin

(* Tag for the distinguished "slot" subgoal that collects deferred constraints.
   Logically it is the identity on propositions. *)
definition CONSTRAINT_SLOT :: "prop ⇒ prop"
  where "CONSTRAINT_SLOT x ≡ x"

(* TODO: Find something better than True to put in empty slot! Perhaps "A⟹A" *)
(* Insert subgoal P into an empty slot, i.e. one that contains only True.
   If Q follows from P and an empty slot, then Q follows from a slot that
   contains just P. The ML tactic to_slot_tac tries this rule first. *)
lemma insert_slot_rl1:
assumes "PROP P ⟹ PROP (CONSTRAINT_SLOT (Trueprop True)) ⟹ PROP Q"
shows "PROP (CONSTRAINT_SLOT (PROP P)) ⟹ PROP Q"
using assms unfolding CONSTRAINT_SLOT_def by simp

(* Insert subgoal P into a slot that already holds S. The new slot holds the
   meta-conjunction S &&& P. The ML tactic to_slot_tac uses this rule when
   insert_slot_rl1 does not apply. *)
lemma insert_slot_rl2:
assumes "PROP P ⟹ PROP (CONSTRAINT_SLOT S) ⟹ PROP Q"
shows "PROP (CONSTRAINT_SLOT (PROP S &&& PROP P)) ⟹ PROP Q"
using assms unfolding CONSTRAINT_SLOT_def conjunction_def .

(* An empty slot (one containing only True) is trivially provable.
   The ML tactic remove_slot_tac uses this to discharge it. *)
lemma remove_slot: "PROP (CONSTRAINT_SLOT (Trueprop True))"
  by (unfold CONSTRAINT_SLOT_def) (rule TrueI)

(* Tag marking the application P x as a constraint for the constraint solver.
   Logically it is the identity. Because the definition is declared [simp],
   the simplifier removes the tag. *)
definition CONSTRAINT where [simp]: "CONSTRAINT P x ≡ P x"

(* Elimination rule for the CONSTRAINT tag: from a tagged predicate
   application, obtain the plain application. *)
lemma CONSTRAINT_D:
  assumes "CONSTRAINT (P::'a => bool) x"
  shows "P x"
  by (rule assms[unfolded CONSTRAINT_def])

(* Introduction rule for the CONSTRAINT tag: a plain predicate application
   can be wrapped as a constraint. *)
lemma CONSTRAINT_I:
  assumes "P x"
  shows "CONSTRAINT (P::'a => bool) x"
  by (unfold CONSTRAINT_def) (rule assms)

text ‹Special predicate to mark a constraint as unsolvable.
The constraint solver refuses to defer such constraints to the slot.
Thus, safe rules that introduce this predicate can be used to detect
unsolvable constraints early.
›
(* Marker for an unsolvable constraint on P and x. It is logically False,
   and the definition is declared [simp]. *)
definition CN_FALSE :: "('a⇒bool) ⇒ 'a ⇒ bool" where [simp]: "CN_FALSE P x ≡ False"

(* Since CN_FALSE P x unfolds to False, it implies any P x. *)
lemma CN_FALSEI: "CN_FALSE P x ⟹ P x"
  by (unfold CN_FALSE_def) (erule FalseE)

(* Simplification rules for constraints. The ML solver simplifies constraint
   goals with them, and they are also unfolded in registered constraint rules. *)
named_theorems constraint_simps ‹Simplification of constraints›

(* Abbreviations unfolded in constraint goals and in registered constraint rules. *)
named_theorems constraint_abbrevs ‹Constraint Solver: Abbreviations›
(* Rewrite rules unfolded in registered constraint rules by unfold_abbrevs in
   the ML section. That function then applies Conjunction.elim_conjunctions,
   presumably to split a rule containing conjunctions into several rules. *)
lemmas split_constraint_rls
= atomize_conj[symmetric] imp_conjunction all_conjunction conjunction_imp

ML ‹
(* Interface of the Sepref constraint machinery.
   A slot is a distinguished subgoal of the form CONSTRAINT_SLOT _ that
   collects deferred constraints. A constraint is a subgoal whose conclusion
   is CONSTRAINT _ _. Such goals are solved using the theorem collections
   constraint_rules and safe_constraint_rules. *)
signature SEPREF_CONSTRAINTS = sig
(******** Constraint Slot *)
(* Apply a tactic to the slot subgoal index; warn and fail if there is no slot *)
val WITH_SLOT: tactic' -> tactic
(* Focus on the goals in the slot, run the tactic, then unfocus again *)
val ON_SLOT: tactic -> tactic
(* Create an empty slot as last subgoal. Fail if a slot is already present. *)
val create_slot_tac: tactic
(* Create slot if there isn't one already *)
val ensure_slot_tac: tactic
(* Remove the slot, provided it is empty (contains only True) *)
val remove_slot_tac: tactic
(* Move slot to first subgoal *)
val prefer_slot_tac: tactic
(* Destruct slot into its individual goals *)
val dest_slot_tac: tactic'
(* Check if goal state has slot *)
val has_slot: thm -> bool
(* Defer subgoal to slot. Fails for subgoals not located before the slot. *)
val to_slot_tac: tactic'
(* Print slot constraints *)
val print_slot_tac: Proof.context -> tactic

(* Focus on goals in slot: they become the subgoals, with a fresh empty slot *)
val focus: tactic
(* Move all focused goals back into the slot and undo the focus *)
val unfocus: tactic
(* Undo the focus, inserting the remaining focused goals as first subgoals *)
val unfocus_ins:tactic

(* Focus on the slot goals satisfying the predicate; others stay in the slot *)
val cond_focus: (term -> bool) -> tactic
(* Move all non-slot subgoals satisfying the predicate into the slot *)
val some_to_slot_tac: (term -> bool) -> tactic

(******** Constraints *)
(* Check if subgoal is a constraint. To be used with COND' *)
val is_constraint_goal: term -> bool
(* Identity on constraint subgoal, no_tac otherwise *)
val is_constraint_tac: tactic'
(* Defer constraint to slot. Fails for non-constraints and for CN_FALSE constraints. *)
val slot_constraint_tac: int -> tactic

(******** Constraint solving *)

(* Register, remove, and list (unsafe) constraint rules.
   Registered rules are normalized and must have at least one premise. *)
val add_constraint_rule: thm -> Context.generic -> Context.generic
val del_constraint_rule: thm -> Context.generic -> Context.generic
val get_constraint_rules: Proof.context -> thm list

(* Register, remove, and list safe constraint rules (applied deterministically) *)
val add_safe_constraint_rule: thm -> Context.generic -> Context.generic
val del_safe_constraint_rule: thm -> Context.generic -> Context.generic
val get_safe_constraint_rules: Proof.context -> thm list

(* Solve constraint subgoal completely, or fail *)
val solve_constraint_tac: Proof.context -> tactic'
(* Solve constraint subgoal if solvable, fail if definitely unsolvable,
apply simplification and unique rules otherwise. *)
val safe_constraint_tac: Proof.context -> tactic'

(* Like solve_constraint_tac, but the CONSTRAINT tag on the goal is optional *)
val solve_constraint'_tac: Proof.context -> tactic'
(* Like safe_constraint_tac, but the CONSTRAINT tag on the goal is optional *)
val safe_constraint'_tac: Proof.context -> tactic'

(* Solve, or apply safe-rules and defer to constraint slot *)
val constraint_tac: Proof.context -> tactic'

(* Apply safe rules to all constraint goals in slot *)
val process_constraint_slot: Proof.context -> tactic

(* Solve all constraint goals in slot, insert unsolved ones as first subgoals *)
val solve_constraint_slot: Proof.context -> tactic

(* Sets up the constraint_rules and safe_constraint_rules collections *)
val setup: theory -> theory

end

structure Sepref_Constraints: SEPREF_CONSTRAINTS  = struct
  (******** Constraint slot *)

  (* A subgoal is the slot iff it has the shape CONSTRAINT_SLOT _ *)
  fun is_slot_goal @{mpat "CONSTRAINT_SLOT _"} = true | is_slot_goal _ = false

  (* One-based index of the first slot subgoal; 0 if there is none,
     which is what has_slot relies on. *)
  fun slot_goal_num st = let
    val i = find_index is_slot_goal (Thm.prems_of st) + 1
  in
    i
  end

  fun has_slot st = slot_goal_num st > 0

  (* Apply tac to the slot index; warn and fail if no slot exists. *)
  fun WITH_SLOT tac st = let
    val si = slot_goal_num st
  in
    if si>0 then tac si st else (warning "Constraints: No slot"; Seq.empty)
  end

  (* Move subgoal i into the slot. Only subgoals located before the slot are
     accepted; otherwise the tactic fails. The slot is brought to position 1,
     then goal i (now at i+1) is brought to position 1, so the slot is at 2.
     The two are merged with insert_slot_rl1 (slot containing only True) or
     else insert_slot_rl2, and the merged slot is deferred to the end. *)
  val to_slot_tac = IF_EXGOAL (fn i => WITH_SLOT (fn si =>
    if i<si then
      prefer_tac si THEN prefer_tac (i+1)
      THEN (
        PRIMITIVE (fn st => Drule.comp_no_flatten (st, 0) 1 @{thm insert_slot_rl1})
        ORELSE PRIMITIVE (fn st => Drule.comp_no_flatten (st, 0) 1 @{thm insert_slot_rl2})
      )
      THEN defer_tac 1
    else no_tac))

  (* Add an empty slot as the last subgoal; fail if a slot already exists. *)
  val create_slot_tac =
    COND (has_slot) no_tac
      (PRIMITIVE (Thm.implies_intr @{cterm "CONSTRAINT_SLOT (Trueprop True)"})
       THEN defer_tac 1)

  val ensure_slot_tac = TRY create_slot_tac

  val prefer_slot_tac = WITH_SLOT prefer_tac

  (* Split a slot subgoal into its goals: unfold CONSTRAINT_SLOT, split the
     meta-conjunction, try TrueI on the first resulting goal (the True
     placeholder of an initially empty slot), and drop duplicate subgoals. *)
  val dest_slot_tac = SELECT_GOAL (
    ALLGOALS (
      CONVERSION (Conv.rewr_conv @{thm CONSTRAINT_SLOT_def})
      THEN' Goal.conjunction_tac
      THEN' TRY o resolve0_tac @{thms TrueI})
    THEN distinct_subgoals_tac
  )

  (* Discharge the slot; only succeeds if it contains just True. *)
  val remove_slot_tac = WITH_SLOT (resolve0_tac @{thms remove_slot})

  (* Restrict the goal state to the slot (Goal.restrict), split it into its
     goals, and open a fresh empty slot. *)
  val focus = WITH_SLOT (fn i =>
    PRIMITIVE (Goal.restrict i 1)
    THEN ALLGOALS dest_slot_tac
    THEN create_slot_tac)

  (* Undo the restriction and move the slot to the end. *)
  val unfocus_ins =
    PRIMITIVE (Goal.unrestrict 1)
    THEN WITH_SLOT defer_tac

  (* Move every non-slot subgoal that satisfies cond into the slot. *)
  fun some_to_slot_tac cond = (ALLGOALS (COND' (fn t => is_slot_goal t orelse not (cond t)) ORELSE' to_slot_tac))

  val unfocus =
    some_to_slot_tac (K true)
    THEN unfocus_ins

  (* Focus, then move back to the slot every goal that does not satisfy cond. *)
  fun cond_focus cond =
    focus
    THEN some_to_slot_tac (not o cond)

  fun ON_SLOT tac = focus THEN tac THEN unfocus

  fun print_slot_tac ctxt = ON_SLOT (print_tac ctxt "SLOT:")

  (******** Rule collections *)
  local
    (*fun prepare_constraint_conv ctxt = let
      open Conv
      fun CONSTRAINT_conv ct = case Thm.term_of ct of
        @{mpat "Trueprop (_ _)"} =>
          HOLogic.Trueprop_conv
            (rewr_conv @{thm CONSTRAINT_def[symmetric]}) ct
      | _ => raise CTERM ("CONSTRAINT_conv", [ct])

      fun rec_conv ctxt ct = (
        CONSTRAINT_conv
        else_conv
        implies_conv (rec_conv ctxt) (rec_conv ctxt)
        else_conv
        forall_conv (rec_conv o #2) ctxt
      ) ct
    in
      rec_conv ctxt
    end*)

    (* Normalize a rule before registration: unfold the split rules, the
       CONSTRAINT tag, and the registered abbreviations and simplification
       theorems, then split the result at meta-conjunctions. *)
    fun unfold_abbrevs ctxt =
      Local_Defs.unfold0 ctxt (
        @{thms split_constraint_rls CONSTRAINT_def}
        @ Named_Theorems.get ctxt @{named_theorems constraint_abbrevs}
        @ Named_Theorems.get ctxt @{named_theorems constraint_simps})
      #> Conjunction.elim_conjunctions

    (* Accept only rules built from Trueprop atoms, meta-quantifiers and
       meta-implications, and reject atoms whose predicate head is schematic. *)
    fun check_constraint_rl thm = let
      fun ck (t as @{mpat "Trueprop (?C _)"}) =
            if is_Var (Term.head_of C) then
              raise TERM ("Schematic head in constraint rule",[t,Thm.prop_of thm])
            else ()
        | ck @{mpat "⋀_. PROP ?t"} = ck t
        | ck @{mpat "PROP ?s ⟹ PROP ?t"} = (ck s; ck t)
        | ck t = raise TERM ("Invalid part of constraint rule",[t,Thm.prop_of thm])
    in
      ck (Thm.prop_of thm); thm
    end

    (* Non-safe constraint rules must have premises; unconditional rules
       have to be registered as safe rules instead. *)
    fun check_unsafe_constraint_rl thm = let
      val _ = Thm.nprems_of thm = 0
        andalso raise TERM("Unconditional constraint rule must be safe (register this as safe rule)",[Thm.prop_of thm])
    in
      thm
    end

  in
    structure constraint_rules = Named_Sorted_Thms (
      val name = @{binding constraint_rules}
      val description = "Constraint rules"
      val sort = K I
      fun transform context = let
        val ctxt = Context.proof_of context
      in
        unfold_abbrevs ctxt #> map (check_constraint_rl o check_unsafe_constraint_rl)
      end
    )

    structure safe_constraint_rules = Named_Sorted_Thms (
      val name = @{binding safe_constraint_rules}
      val description = "Safe Constraint rules"
      val sort = K I
      fun transform context = let
        val ctxt = Context.proof_of context
      in
        unfold_abbrevs ctxt #> map check_constraint_rl
      end
    )
  end

  (* Add, remove and list operations for both rule collections, as required
     by SEPREF_CONSTRAINTS. *)
  val add_constraint_rule = constraint_rules.add_thm
  val del_constraint_rule = constraint_rules.del_thm
  val get_constraint_rules = constraint_rules.get

  val add_safe_constraint_rule = safe_constraint_rules.add_thm
  val del_safe_constraint_rule = safe_constraint_rules.del_thm
  val get_safe_constraint_rules = safe_constraint_rules.get

  (******** Constraints *)
  fun is_constraint_goal t = case Logic.strip_assums_concl t of
      @{mpat "Trueprop (CONSTRAINT _ _)"} => true
    | _ => false

  val is_constraint_tac = COND' is_constraint_goal

  (* Like is_constraint_goal, but CONSTRAINT (CN_FALSE _) _ goals are never
     deferred to the slot. *)
  fun is_slottable_constraint_goal t = case Logic.strip_assums_concl t of
      @{mpat "Trueprop (CONSTRAINT (CN_FALSE _) _)"} => false
    | @{mpat "Trueprop (CONSTRAINT _ _)"} => true
    | _ => false

  val slot_constraint_tac = COND' is_slottable_constraint_goal THEN' to_slot_tac

  (* Classify a result sequence by whether it has zero, one, or more than
     one element, pulling at most two elements. *)
  datatype 'a seq_cases = SC_NONE | SC_SINGLE of 'a Seq.seq | SC_MULTIPLE of 'a Seq.seq

  fun seq_cases seq =
    case Seq.pull seq of
      NONE => SC_NONE
    | SOME (st1,seq) => case Seq.pull seq of
        NONE => SC_SINGLE (Seq.single st1)
      | SOME (st2,seq) => SC_MULTIPLE (Seq.cons st1 (Seq.cons st2 seq))

  fun SEQ_CASES tac (single_tac, multiple_tac) st = let
    val res = tac st
  in
    case seq_cases res of
      SC_NONE => Seq.empty
    | SC_SINGLE res => Seq.maps single_tac res
    | SC_MULTIPLE res => Seq.maps multiple_tac res
  end

  (* Succeeds only if tac yields exactly one result. *)
  fun SAFE tac = SEQ_CASES tac (all_tac, no_tac)
  fun SAFE' tac = SAFE o tac

  (******** Constraint solving *)
  local
    (* Simplify with the constraint_simps collection only. *)
    fun simp_constraints_tac ctxt = let
      val ctxt = put_simpset HOL_basic_ss ctxt
        addsimps (Named_Theorems.get ctxt @{named_theorems constraint_simps})
    in
      simp_tac ctxt
    end

    (* Unfold abbreviations everywhere, then split HOL conjunctions in
       premises and conclusion. *)
    fun unfold_abbrevs_tac ctxt =  let
      val ctxt = put_simpset HOL_basic_ss ctxt
        addsimps (Named_Theorems.get ctxt @{named_theorems constraint_abbrevs})
      val ethms = @{thms conjE}
      val ithms = @{thms conjI}
    in
      full_simp_tac ctxt
      THEN_ALL_NEW TRY o REPEAT_ALL_NEW (ematch_tac ctxt ethms)
      THEN_ALL_NEW TRY o REPEAT_ALL_NEW (match_tac ctxt ithms)
    end

    (* Build discrimination nets of the safe and the other constraint rules
       and pass them to tac. *)
    fun WITH_RULE_NETS tac ctxt = let
      val scn_net = safe_constraint_rules.get ctxt |> Tactic.build_net
      val cn_net = constraint_rules.get ctxt |> Tactic.build_net
    in
      tac (scn_net,cn_net) ctxt
    end

    fun wrap_tac step_tac ctxt = REPEAT_ALL_NEW (
      simp_constraints_tac ctxt
      THEN_ALL_NEW unfold_abbrevs_tac ctxt
      THEN_ALL_NEW step_tac ctxt
    )

    (* Safe rules deterministically, other rules with backtracking. *)
    fun solve_step_tac (scn_net,cn_net) ctxt = REPEAT_ALL_NEW (
      DETERM o resolve_from_net_tac ctxt scn_net
      ORELSE' resolve_from_net_tac ctxt cn_net
    )

    (* Safe rules deterministically, other rules only if unique. *)
    fun safe_step_tac (scn_net,cn_net) ctxt = REPEAT_ALL_NEW (
      DETERM o resolve_from_net_tac ctxt scn_net
      ORELSE' SAFE' (resolve_from_net_tac ctxt cn_net)
    )

    fun solve_tac cn_nets ctxt = SOLVED' (wrap_tac (solve_step_tac cn_nets) ctxt)
    fun safe_tac cn_nets ctxt =
      simp_constraints_tac ctxt
      THEN_ALL_NEW unfold_abbrevs_tac ctxt
      THEN_ALL_NEW (solve_tac cn_nets ctxt ORELSE' TRY o wrap_tac (safe_step_tac cn_nets) ctxt)

  in
    val solve_constraint_tac = TRADE (fn ctxt =>
      is_constraint_tac
      THEN' resolve_tac ctxt @{thms CONSTRAINT_I}
      THEN' WITH_RULE_NETS solve_tac ctxt)

    (* Remaining subgoals are re-tagged with CONSTRAINT via CONSTRAINT_D. *)
    val safe_constraint_tac = TRADE (fn ctxt =>
      is_constraint_tac
      THEN' resolve_tac ctxt @{thms CONSTRAINT_I}
      THEN' WITH_RULE_NETS safe_tac ctxt
      THEN_ALL_NEW fo_resolve_tac @{thms CONSTRAINT_D} ctxt) (* TODO/FIXME: fo_resolve_tac has non-canonical parameter order *)

    val solve_constraint'_tac = TRADE (fn ctxt =>
      TRY o resolve_tac ctxt @{thms CONSTRAINT_I}
      THEN' WITH_RULE_NETS solve_tac ctxt)

    val safe_constraint'_tac = TRADE (fn ctxt =>
      TRY o resolve_tac ctxt @{thms CONSTRAINT_I}
      THEN' WITH_RULE_NETS safe_tac ctxt)

  end

  fun constraint_tac ctxt =
    safe_constraint_tac ctxt THEN_ALL_NEW slot_constraint_tac

  fun process_constraint_slot ctxt = ON_SLOT (ALLGOALS (TRY o safe_constraint_tac ctxt))

  (* Focus on the constraint goals in the slot, try to solve each one (or at
     least apply safe rules), then put the unsolved ones first. *)
  fun solve_constraint_slot ctxt =
    cond_focus is_constraint_goal
    THEN ALLGOALS (
      COND' is_slot_goal
      ORELSE' (
        solve_constraint_tac ctxt
        ORELSE' TRY o safe_constraint_tac ctxt
      )
    )
    THEN unfocus_ins

  val setup = I
    #> constraint_rules.setup
    #> safe_constraint_rules.setup

end
›

setup Sepref_Constraints.setup

method_setup print_slot = ‹Scan.succeed (fn ctxt => SIMPLE_METHOD (Sepref_Constraints.print_slot_tac ctxt))›

method_setup solve_constraint = ‹Scan.succeed (fn ctxt => SIMPLE_METHOD' (Sepref_Constraints.solve_constraint'_tac ctxt))›
method_setup safe_constraint = ‹Scan.succeed (fn ctxt => SIMPLE_METHOD' (Sepref_Constraints.safe_constraint'_tac ctxt))›

end
```