(* Theory Refine_Imperative_HOL.Sepref_Constraints *)
theory Sepref_Constraints
imports Main Automatic_Refinement.Refine_Lib Sepref_Basic
begin
(* Marker for the constraint slot: a distinguished subgoal that collects
   deferred constraints. Definitionally it is the identity on propositions;
   it only serves to make the slot recognizable in a proof state. *)
definition "CONSTRAINT_SLOT (x::prop) ≡ x"
(* Insert a constraint P into an EMPTY slot: a proof state whose first two
   premises are P and the empty slot CONSTRAINT_SLOT (Trueprop True) is
   turned into one whose single new premise is CONSTRAINT_SLOT P.
   Used by to_slot_tac. *)
lemma insert_slot_rl1:
  assumes "PROP P ⟹ PROP (CONSTRAINT_SLOT (Trueprop True)) ⟹ PROP Q"
  shows "PROP (CONSTRAINT_SLOT (PROP P)) ⟹ PROP Q"
  using assms unfolding CONSTRAINT_SLOT_def by simp
(* Insert a constraint P into a slot with contents S: the premises P and
   CONSTRAINT_SLOT S are merged into CONSTRAINT_SLOT (S &&& P), i.e. new
   constraints are appended on the right of a meta-conjunction.
   Used by to_slot_tac when insert_slot_rl1 does not apply. *)
lemma insert_slot_rl2:
  assumes "PROP P ⟹ PROP (CONSTRAINT_SLOT S) ⟹ PROP Q"
  shows "PROP (CONSTRAINT_SLOT (PROP S &&& PROP P)) ⟹ PROP Q"
  using assms unfolding CONSTRAINT_SLOT_def conjunction_def .
(* An empty slot is trivially provable; used by remove_slot_tac to
   discharge a slot that holds no constraints. *)
lemma remove_slot: "PROP (CONSTRAINT_SLOT (Trueprop True))"
  unfolding CONSTRAINT_SLOT_def by (rule TrueI)
(* Wrapper marking P x as a constraint goal for the constraint solver.
   The defining equation is declared [simp], so the simplifier unfolds it. *)
definition CONSTRAINT where [simp]: "CONSTRAINT P x ≡ P x"
(* Destruct a CONSTRAINT wrapper. Used backwards (resolution) it wraps a
   goal P x into CONSTRAINT P x again. *)
lemma CONSTRAINT_D:
  assumes "CONSTRAINT (P::'a => bool) x"
  shows "P x"
  using assms unfolding CONSTRAINT_def by simp
(* Introduce a CONSTRAINT wrapper. Used backwards it strips the wrapper
   from a goal CONSTRAINT P x, leaving P x. *)
lemma CONSTRAINT_I:
  assumes "P x"
  shows "CONSTRAINT (P::'a => bool) x"
  using assms unfolding CONSTRAINT_def by simp
text ‹Special predicate to indicate an unsolvable constraint.
  The constraint solver refuses to put such constraints into the slot.
  Thus, adding safe rules that introduce this predicate can be used to
  detect unsolvable constraints early.
›
(* Marker for an unsolvable constraint. It is always False; constraint goals
   of the form CONSTRAINT (CN_FALSE P) x are never moved into the slot. *)
definition CN_FALSE :: "('a⇒bool) ⇒ 'a ⇒ bool" where [simp]: "CN_FALSE P x ≡ False"  
(* Used as a rule, this replaces a goal P x by the unprovable CN_FALSE P x. *)
lemma CN_FALSEI: "CN_FALSE P x ⟹ P x" by simp
(* Simplification rules applied to constraint goals before each solver step,
   and unfolded in constraint rules when they are registered. *)
named_theorems constraint_simps ‹Simplification of constraints›
(* Abbreviations unfolded in constraint goals and in registered rules. *)
named_theorems constraint_abbrevs ‹Constraint Solver: Abbreviations›
(* Rewrites applied to constraint rules when they are registered:
   - turn HOL conjunctions into meta-conjunctions (atomize_conj[symmetric]);
   - distribute meta-conjunctions over implications and meta-quantifiers
     (imp_conjunction, all_conjunction);
   - curry conjunctive premises (conjunction_imp).
   The rule can then be split into several rules. *)
lemmas split_constraint_rls 
    = atomize_conj[symmetric] imp_conjunction all_conjunction conjunction_imp
ML ‹
  (* Interface of the constraint solver.
     Constraints are goals whose conclusion is CONSTRAINT P x. Constraints
     that are not solved right away can be collected in a single "slot"
     subgoal CONSTRAINT_SLOT S, where S is Trueprop True for an empty slot
     and a meta-conjunction (&&&) of the collected constraints otherwise. *)
  signature SEPREF_CONSTRAINTS = sig
    (* ---- Slot management ---- *)

    (* WITH_SLOT tac: apply tac to the index of the (first) slot subgoal.
       Fails, printing the warning "Constraints: No slot", if there is none. *)
    val WITH_SLOT: tactic' -> tactic
    (* ON_SLOT tac: focus on the constraints in the slot, run tac, then move
       all remaining goals back into the slot (focus THEN tac THEN unfocus). *)
    val ON_SLOT: tactic -> tactic
    (* Add an empty slot CONSTRAINT_SLOT (Trueprop True) as the last subgoal.
       Fails if the goal state already has a slot. *)
    val create_slot_tac: tactic
    (* Like create_slot_tac, but succeeds unchanged if a slot exists. *)
    val ensure_slot_tac: tactic
    (* Discharge the slot subgoal; only succeeds if the slot is empty. *)
    val remove_slot_tac: tactic
    (* Move the slot subgoal to the front of the goal list. *)
    val prefer_slot_tac: tactic
    (* Expand a slot subgoal into one subgoal per contained constraint,
       dropping the True placeholder and duplicate subgoals. *)
    val dest_slot_tac: tactic'
    (* Whether the proof state contains a slot subgoal. *)
    val has_slot: thm -> bool
    (* Move subgoal i into the slot. Only subgoals before the slot can be
       moved; afterwards the slot is the last subgoal. *)
    val to_slot_tac: tactic'
    (* Print the constraints currently held in the slot. *)
    val print_slot_tac: Proof.context -> tactic
    (* Restrict the proof state to the slot, expand the slot into its
       constraints as separate subgoals, and add a fresh empty slot. *)
    val focus: tactic
    (* Leave a focus: move every remaining non-slot subgoal into the slot,
       restore the outer proof state and defer the slot to the end. *)
    val unfocus: tactic
    (* Like unfocus, but remaining subgoals are NOT moved into the slot;
       they stay ordinary subgoals of the outer proof state. *)
    val unfocus_ins:tactic
    (* Like focus, but subgoals not satisfying the predicate are moved
       straight back into the slot, exposing only matching goals. *)
    val cond_focus: (term -> bool) -> tactic
    (* Move every non-slot subgoal satisfying the predicate into the slot. *)
    val some_to_slot_tac: (term -> bool) -> tactic

    (* ---- Constraint goals ---- *)

    (* Whether the conclusion of a goal (after stripping meta-quantifiers
       and premises) has the form CONSTRAINT _ _. *)
    val is_constraint_goal: term -> bool
    (* Guard tactic (COND' is_constraint_goal) restricting to constraint goals. *)
    val is_constraint_tac: tactic'
    (* Move a constraint goal into the slot. Fails on CONSTRAINT (CN_FALSE _) _
       goals, which are never slotted, and on non-constraint goals. *)
    val slot_constraint_tac: int -> tactic

    (* Rule collections. Ordinary (unsafe) rules must have premises;
       unconditional rules have to be registered as safe rules. *)
    val add_constraint_rule: thm -> Context.generic -> Context.generic
    val del_constraint_rule: thm -> Context.generic -> Context.generic
    val get_constraint_rules: Proof.context -> thm list
    val add_safe_constraint_rule: thm -> Context.generic -> Context.generic
    val del_safe_constraint_rule: thm -> Context.generic -> Context.generic
    val get_safe_constraint_rules: Proof.context -> thm list

    (* ---- Solver ---- *)

    (* Solve a constraint goal completely, or fail. *)
    val solve_constraint_tac: Proof.context -> tactic'
    (* Simplify a constraint goal using safe rules, and unsafe rules only
       where they resolve in exactly one way; remaining subgoals are turned
       back into CONSTRAINT goals. *)
    val safe_constraint_tac: Proof.context -> tactic'
    (* Variant of solve_constraint_tac that also accepts goals not wrapped
       in CONSTRAINT. *)
    val solve_constraint'_tac: Proof.context -> tactic'
    (* Variant of safe_constraint_tac that also accepts goals not wrapped in
       CONSTRAINT and does not re-wrap the remaining subgoals. *)
    val safe_constraint'_tac: Proof.context -> tactic'

    (* safe_constraint_tac followed by moving the resulting constraint goals
       into the slot; fails if a CN_FALSE constraint remains. *)
    val constraint_tac: Proof.context -> tactic'
    (* Apply safe_constraint_tac to every constraint in the slot. *)
    val process_constraint_slot: Proof.context -> tactic
    (* Try to solve every constraint in the slot; constraints that cannot be
       solved completely are simplified by safe steps and left as ordinary
       subgoals. *)
    val solve_constraint_slot: Proof.context -> tactic
    (* Theory setup of the two rule collections. *)
    val setup: theory -> theory
  end
  structure Sepref_Constraints: SEPREF_CONSTRAINTS  = struct
    (* ---- Slot handling ----
       The slot is one subgoal CONSTRAINT_SLOT S that collects deferred
       constraints. An empty slot has S = Trueprop True; inserting a
       constraint P into a non-empty slot yields S &&& P (insert_slot_rl2). *)

    (* Recognizes a slot subgoal. *)
    fun is_slot_goal @{mpat "CONSTRAINT_SLOT _"} = true | is_slot_goal _ = false

    (* 1-based index of the first slot subgoal of st, or 0 if there is none
       (find_index yields ~1 when no premise matches). *)
    fun slot_goal_num st = find_index is_slot_goal (Thm.prems_of st) + 1

    fun has_slot st = slot_goal_num st > 0

    (* Apply tac to the index of the slot subgoal; warn and fail without one. *)
    fun WITH_SLOT tac st = let
      val si = slot_goal_num st
    in
      if si>0 then tac si st else (warning "Constraints: No slot"; Seq.empty)
    end

    (* Move subgoal i into the slot at index si. Only i < si is supported.
       prefer_tac si moves the slot to position 1 and shifts goal i to
       position i+1. prefer_tac (i+1) then puts goal i at position 1 and the
       slot at position 2. These two premises are merged into one new slot
       premise: insert_slot_rl1 applies when the slot is empty, and
       insert_slot_rl2 otherwise. The new slot is finally deferred to the
       end. *)
    val to_slot_tac = IF_EXGOAL (fn i => WITH_SLOT (fn si => 
      if i<si then
        prefer_tac si THEN prefer_tac (i+1)
        THEN (
          PRIMITIVE (fn st => Drule.comp_no_flatten (st, 0) 1 @{thm insert_slot_rl1}) 
          ORELSE PRIMITIVE (fn st => Drule.comp_no_flatten (st, 0) 1 @{thm insert_slot_rl2})
        )
        THEN defer_tac 1
      else no_tac))

    (* Add an empty slot CONSTRAINT_SLOT (Trueprop True) as a new first
       subgoal (via Thm.implies_intr) and defer it to the end.
       Fails if the proof state already has a slot. *)
    val create_slot_tac = 
      COND (has_slot) no_tac
        (PRIMITIVE (Thm.implies_intr @{cterm "CONSTRAINT_SLOT (Trueprop True)"}) 
        THEN defer_tac 1)

    (* Create a slot unless one already exists. *)
    val ensure_slot_tac = TRY create_slot_tac

    (* Move the slot subgoal to the front. *)
    val prefer_slot_tac = WITH_SLOT prefer_tac

    (* Expand the selected slot subgoal into its constraints, in these steps:
       - unfold CONSTRAINT_SLOT;
       - split the meta-conjunction into separate subgoals;
       - try to close the first of these with TrueI (for a slot built by the
         insert rules this is the True of the initially empty slot);
       - remove duplicate subgoals. *)
    val dest_slot_tac = SELECT_GOAL (
      ALLGOALS (
        CONVERSION (Conv.rewr_conv @{thm CONSTRAINT_SLOT_def}) 
        THEN' Goal.conjunction_tac
        THEN' TRY o resolve0_tac @{thms TrueI})
      THEN distinct_subgoals_tac
    )

    (* Close the slot subgoal; remove_slot only matches an empty slot. *)
    val remove_slot_tac = WITH_SLOT (resolve0_tac @{thms remove_slot})

    (* Restrict the proof state to the slot subgoal (Goal.restrict), expand
       the slot into its constraints, and add a fresh empty slot. *)
    val focus = WITH_SLOT (fn i => 
      PRIMITIVE (Goal.restrict i 1) 
      THEN ALLGOALS dest_slot_tac
      THEN create_slot_tac)

    (* Undo the restriction made by focus (Goal.unrestrict) and defer the slot
       to the end. Goals still open in the focus stay ordinary subgoals. *)
    val unfocus_ins = 
      PRIMITIVE (Goal.unrestrict 1)
      THEN WITH_SLOT defer_tac

    (* Move every non-slot subgoal whose term satisfies cond into the slot.
       ALLGOALS visits subgoals from last to first, so moving a goal does not
       change the indices of the goals still to be visited.
       NOTE(review): the whole tactic fails if a goal satisfying cond comes
       after the slot, because to_slot_tac only moves earlier goals. The slot
       is normally kept last by create_slot_tac and to_slot_tac. *)
    fun some_to_slot_tac cond = (ALLGOALS (COND' (fn t => is_slot_goal t orelse not (cond t)) ORELSE' to_slot_tac))

    (* Move all non-slot subgoals back into the slot, then leave the focus. *)
    val unfocus = 
      some_to_slot_tac (K true)
      THEN unfocus_ins

    (* Focus on the slot, but put constraints not satisfying cond straight
       back into the slot, so only goals satisfying cond are exposed. *)
    fun cond_focus cond =
      focus 
      THEN some_to_slot_tac (not o cond)

    (* Run tac on the constraints in the slot, collecting leftovers again. *)
    fun ON_SLOT tac = focus THEN tac THEN unfocus
    fun print_slot_tac ctxt = ON_SLOT (print_tac ctxt "SLOT:")

    (* ---- Constraint rule collections ---- *)
    local
      (* Normalize a rule before it is stored. Unfold the rules in
         split_constraint_rls, CONSTRAINT_def, and the constraint_abbrevs and
         constraint_simps theorems. Then split the resulting
         meta-conjunction into separate rules. *)
      fun unfold_abbrevs ctxt = 
        Local_Defs.unfold0 ctxt (
          @{thms split_constraint_rls CONSTRAINT_def} 
          @ Named_Theorems.get ctxt @{named_theorems constraint_abbrevs}
          @ Named_Theorems.get ctxt @{named_theorems constraint_simps})
        #> Conjunction.elim_conjunctions

      (* Check that a rule is built only from atoms Trueprop (C ...) with a
         non-schematic head C, combined by meta-implication and
         meta-quantification. Raises TERM on violation; returns the rule. *)
      fun check_constraint_rl thm = let
        fun ck (t as @{mpat "Trueprop (?C _)"}) = 
              if is_Var (Term.head_of C) then
                raise TERM ("Schematic head in constraint rule",[t,Thm.prop_of thm])
              else ()
          | ck @{mpat "⋀_. PROP ?t"} = ck t
          | ck @{mpat "PROP ?s ⟹ PROP ?t"} = (ck s; ck t)
          | ck t = raise TERM ("Invalid part of constraint rule",[t,Thm.prop_of thm])

      in
        ck (Thm.prop_of thm); thm
      end

      (* Ordinary (unsafe) rules must have premises: an unconditional rule
         raises TERM and has to be registered as a safe rule instead. *)
      fun check_unsafe_constraint_rl thm = let
        val _ = Thm.nprems_of thm = 0 
          andalso raise TERM("Unconditional constraint rule must be safe (register this as safe rule)",[Thm.prop_of thm])
      in
        thm
      end
    in
      (* Unsafe rules, kept unsorted (sort = K I). transform normalizes each
         rule and checks it is conditional and well-formed. *)
      structure constraint_rules = Named_Sorted_Thms (
        val name = @{binding constraint_rules}
        val description = "Constraint rules"
        val sort = K I
        fun transform context = let
          val ctxt = Context.proof_of context
        in
          unfold_abbrevs ctxt #> map (check_constraint_rl o check_unsafe_constraint_rl)
        end
      )
      (* Safe rules, kept unsorted; unconditional rules are allowed here. *)
      structure safe_constraint_rules = Named_Sorted_Thms (
        val name = @{binding safe_constraint_rules}
        val description = "Safe Constraint rules"
        val sort = K I
        fun transform context = let
          val ctxt = Context.proof_of context
        in
          unfold_abbrevs ctxt #> map check_constraint_rl
        end
      )
    end  
    val add_constraint_rule = constraint_rules.add_thm
    val del_constraint_rule = constraint_rules.del_thm
    val get_constraint_rules = constraint_rules.get
    val add_safe_constraint_rule = safe_constraint_rules.add_thm
    val del_safe_constraint_rule = safe_constraint_rules.del_thm
    val get_safe_constraint_rules = safe_constraint_rules.get

    (* A constraint goal has conclusion CONSTRAINT _ _ after stripping
       meta-quantifiers and premises. *)
    fun is_constraint_goal t = case Logic.strip_assums_concl t of
      @{mpat "Trueprop (CONSTRAINT _ _)"} => true
    | _ => false
    val is_constraint_tac = COND' is_constraint_goal

    (* Constraint goals that may be moved into the slot: every constraint
       except CONSTRAINT (CN_FALSE _) _, which marks an unsolvable one. *)
    fun is_slottable_constraint_goal t = case Logic.strip_assums_concl t of
      @{mpat "Trueprop (CONSTRAINT (CN_FALSE _) _)"} => false
    | @{mpat "Trueprop (CONSTRAINT _ _)"} => true
    | _ => false
    val slot_constraint_tac = COND' is_slottable_constraint_goal THEN' to_slot_tac

    (* ---- Determinism helpers ----
       Classify a result sequence as empty, single or multiple. At most two
       elements are forced, and the sequence is rebuilt so no result is
       lost. *)
    datatype 'a seq_cases = SC_NONE | SC_SINGLE of 'a Seq.seq | SC_MULTIPLE of 'a Seq.seq
    fun seq_cases seq = 
      case Seq.pull seq of
        NONE => SC_NONE
      | SOME (st1,seq) => case Seq.pull seq of
          NONE => SC_SINGLE (Seq.single st1)
        | SOME (st2,seq) => SC_MULTIPLE (Seq.cons st1 (Seq.cons st2 seq))  
    (* Run tac; continue each result with single_tac if there was exactly one
       result, with multiple_tac if there were several; fail if none. *)
    fun SEQ_CASES tac (single_tac, multiple_tac) st = let
      val res = tac st
    in
      case seq_cases res of
        SC_NONE => Seq.empty
      | SC_SINGLE res => Seq.maps single_tac res
      | SC_MULTIPLE res => Seq.maps multiple_tac res
    end
    (* Succeed only if tac has exactly one result, and return that result. *)
    fun SAFE tac = SEQ_CASES tac (all_tac, no_tac)
    fun SAFE' tac = SAFE o tac

    (* ---- Solver ---- *)
    local
      (* Simplify with the constraint_simps theorems only. *)
      fun simp_constraints_tac ctxt = let
        val ctxt = put_simpset HOL_basic_ss ctxt 
          addsimps (Named_Theorems.get ctxt @{named_theorems constraint_simps})
      in
        simp_tac ctxt
      end
      (* Unfold constraint_abbrevs in the whole goal, then eliminate HOL
         conjunctions in premises (conjE) and split them in the conclusion
         (conjI). *)
      fun unfold_abbrevs_tac ctxt =  let
        val ctxt = put_simpset HOL_basic_ss ctxt 
          addsimps (Named_Theorems.get ctxt @{named_theorems constraint_abbrevs})
        val ethms = @{thms conjE}  
        val ithms = @{thms conjI}  
      in
        full_simp_tac ctxt 
        THEN_ALL_NEW TRY o REPEAT_ALL_NEW (ematch_tac ctxt ethms)
        THEN_ALL_NEW TRY o REPEAT_ALL_NEW (match_tac ctxt ithms)
      end

      (* Build discrimination nets of the safe and the unsafe rules and pass
         them to tac. *)
      fun WITH_RULE_NETS tac ctxt = let
        val scn_net = safe_constraint_rules.get ctxt |> Bires.build_net
        val cn_net = constraint_rules.get ctxt |> Bires.build_net
      in
        tac (scn_net,cn_net) ctxt
      end
      (* Repeatedly simplify, unfold abbreviations and apply step_tac to all
         emerging subgoals. *)
      fun wrap_tac step_tac ctxt = REPEAT_ALL_NEW (
        simp_constraints_tac ctxt 
        THEN_ALL_NEW unfold_abbrevs_tac ctxt
        THEN_ALL_NEW step_tac ctxt
      )
      (* Prefer safe rules, committing to their first result (DETERM).
         Otherwise resolve with unsafe rules, keeping all alternatives for
         backtracking. *)
      fun solve_step_tac (scn_net,cn_net) ctxt = REPEAT_ALL_NEW (
        DETERM o Bires.resolve_from_net_tac ctxt scn_net
        ORELSE' Bires.resolve_from_net_tac ctxt cn_net
      )
      (* Like solve_step_tac, but an unsafe rule is applied only if it
         resolves in exactly one way (SAFE'). *)
      fun safe_step_tac (scn_net,cn_net) ctxt = REPEAT_ALL_NEW (
        DETERM o Bires.resolve_from_net_tac ctxt scn_net
        ORELSE' SAFE' (Bires.resolve_from_net_tac ctxt cn_net)
      )
      (* Solve the goal completely, or fail. *)
      fun solve_tac cn_nets ctxt = SOLVED' (wrap_tac (solve_step_tac cn_nets) ctxt)
      (* After simplification, try to solve each goal completely; where that
         fails, apply safe steps as far as possible. *)
      fun safe_tac cn_nets ctxt =  
        simp_constraints_tac ctxt
        THEN_ALL_NEW unfold_abbrevs_tac ctxt
        THEN_ALL_NEW (solve_tac cn_nets ctxt ORELSE' TRY o wrap_tac (safe_step_tac cn_nets) ctxt)
    in
      (* Solve a CONSTRAINT goal completely, or fail. *)
      val solve_constraint_tac = TRADE (fn ctxt =>
        is_constraint_tac
        THEN' resolve_tac ctxt @{thms CONSTRAINT_I}
        THEN' WITH_RULE_NETS solve_tac ctxt)
      (* Apply safe steps to a CONSTRAINT goal; remaining subgoals are turned
         back into CONSTRAINT goals by resolution with CONSTRAINT_D. *)
      val safe_constraint_tac = TRADE (fn ctxt =>
        is_constraint_tac
        THEN' resolve_tac ctxt @{thms CONSTRAINT_I}
        THEN' WITH_RULE_NETS safe_tac ctxt
        THEN_ALL_NEW fo_resolve_tac @{thms CONSTRAINT_D} ctxt) 
      (* Variants that also accept goals not wrapped in CONSTRAINT.
         safe_constraint'_tac does not re-wrap the remaining subgoals. *)
      val solve_constraint'_tac = TRADE (fn ctxt =>
        TRY o resolve_tac ctxt @{thms CONSTRAINT_I}
        THEN' WITH_RULE_NETS solve_tac ctxt)
      val safe_constraint'_tac = TRADE (fn ctxt =>
        TRY o resolve_tac ctxt @{thms CONSTRAINT_I}
        THEN' WITH_RULE_NETS safe_tac ctxt)
    end  

    (* Apply safe steps, then move every resulting constraint into the slot.
       Fails if a goal cannot be slotted, e.g. a CN_FALSE constraint, which
       signals an unsolvable constraint early. *)
    fun constraint_tac ctxt = 
      safe_constraint_tac ctxt THEN_ALL_NEW slot_constraint_tac

    (* Apply safe steps to every constraint in the slot; leftovers stay in
       the slot. *)
    fun process_constraint_slot ctxt = ON_SLOT (ALLGOALS (TRY o safe_constraint_tac ctxt))

    (* Focus on the constraint goals in the slot and try to solve each one
       completely; otherwise simplify it with safe steps. Leftovers are not
       put back into the slot but remain ordinary subgoals (unfocus_ins). *)
    fun solve_constraint_slot ctxt = 
      cond_focus is_constraint_goal 
        THEN ALLGOALS (
          COND' is_slot_goal
          ORELSE' (
            solve_constraint_tac ctxt
            ORELSE' TRY o safe_constraint_tac ctxt
          )
        )
      THEN unfocus_ins

    (* Theory setup of both rule collections. *)
    val setup = I
      #> constraint_rules.setup
      #> safe_constraint_rules.setup
  end
›
(* Install the constraint_rules and safe_constraint_rules collections. *)
setup ‹Sepref_Constraints.setup›
(* Proof methods on top of the constraint solver:
   print_slot:       print the constraints currently held in the slot;
   solve_constraint: solve the first subgoal completely, or fail;
   safe_constraint:  simplify the first subgoal using safe steps only. *)
method_setup print_slot = ‹Scan.succeed (SIMPLE_METHOD o Sepref_Constraints.print_slot_tac)›
method_setup solve_constraint = ‹Scan.succeed (SIMPLE_METHOD' o Sepref_Constraints.solve_constraint'_tac)›
method_setup safe_constraint = ‹Scan.succeed (SIMPLE_METHOD' o Sepref_Constraints.safe_constraint'_tac)›
end