Theory RefineG_Transfer

section ‹Transfer between Domains›
theory RefineG_Transfer
imports "../Refine_Misc"
begin
  text ‹Currently, this theory is specialized to 
    transfers that include no data refinement.
›


(* Marker for the post-simplification obligation of a transfer:
   REFINEG_TRANSFER_POST_SIMP x y simply states x = y. *)
definition "REFINEG_TRANSFER_POST_SIMP x y ≡ x=y"
(* Alignment marker: always True (and unfolded by simp).  It only serves
   as a goal shape on which align_tac can instantiate one argument
   from the other. *)
definition [simp]: "REFINEG_TRANSFER_ALIGN x y == True"
lemma REFINEG_TRANSFER_ALIGNI: "REFINEG_TRANSFER_ALIGN x y" by simp

(* Entry rule used by post_transfer_tac: to show d ≤ a, align d with a
   term c, prove c ≤ a, and discharge the post-simplification goal
   relating c and d. *)
lemma START_REFINEG_TRANSFER: 
  assumes "REFINEG_TRANSFER_ALIGN d c"
  assumes "c≤a"
  assumes "REFINEG_TRANSFER_POST_SIMP c d"
  shows "d≤a"
  using assms
  by (simp add: REFINEG_TRANSFER_POST_SIMP_def)

(* Exit rule used by post_transfer_tac: the post-simplification goal is
   closed as soon as both of its arguments are the same term. *)
lemma STOP_REFINEG_TRANSFER: "REFINEG_TRANSFER_POST_SIMP c c"
  by (simp add: REFINEG_TRANSFER_POST_SIMP_def)

ML ‹
structure RefineG_Transfer = struct

  (* Named post-processing tactics, stored per theory and keyed by name.
     post_process_tac below combines all of them. *)
  structure Post_Processors = Theory_Data
  (
    type T = (Proof.context -> tactic') Symtab.table
    val empty = Symtab.empty
    val merge = Symtab.join (K snd)
  )

  (* Register / unregister a named post-processor.  Registration goes
     through Symtab.update_new, so the name must not be registered yet. *)
  fun add_post_processor name tac =
    Post_Processors.map (Symtab.update_new (name,tac))
  fun delete_post_processor name =
    Post_Processors.map (Symtab.delete name)
  val get_post_processors = Post_Processors.get #> Symtab.dest

  (* Combined post-processing tactic: each registered post-processor is
     tried once per round (a failing one is skipped via TRY), and rounds
     are repeated deterministically as long as they change the goal. *)
  fun post_process_tac ctxt = let
    val tacs = get_post_processors (Proof_Context.theory_of ctxt)
      |> map (fn (_,tac) => tac ctxt)

    val tac = REPEAT_DETERM' (CHANGED o EVERY' (map (fn t => TRY o t) tacs))
  in
    tac
  end

  (* Simpset for post-processing simplification.  It is maintained by the
     refine_transfer_post_simp attribute declared in setup below. *)
  structure Post_Simp = Generic_Data
  (
    type T = simpset
    val empty = HOL_basic_ss
    val merge = Raw_Simplifier.merge_ss
  )

  (* Lift a context-level simpset operation f (addsimps / delsimps) with
     argument a to an update of the stored post simpset. *)
  fun post_simps_op f a context = let
    val ctxt = Context.proof_of context
    fun do_it ss = simpset_of (f (put_simpset ss ctxt, a))
  in
    Post_Simp.map do_it context
  end
    
  val add_post_simps = post_simps_op (op addsimps)
  val del_post_simps = post_simps_op (op delsimps)

  (* Return ctxt with its simpset replaced by the post simpset.
     NB: despite the name, the result is a Proof.context, not a simpset. *)
  fun get_post_ss ctxt = let
    val ss = Post_Simp.get (Context.Proof ctxt)
    val ctxt = put_simpset ss ctxt
  in
    ctxt
  end

  (* Conditional rewrite rules applied by post_subst_tac. *)
  structure post_subst = Named_Thms
    ( val name = @{binding refine_transfer_post_subst}
      val description = "Refinement Framework: " ^ 
        "Transfer postprocessing substitutions" );

  (* Post-processor: simplify with the post simpset, then apply a
     top-sweep conversion with the post_subst rules.  The conditional
     rewrite conversion is given dis_tac (Tagged_Solver on all goals),
     presumably to discharge the rules' side conditions.  The pair of
     steps is repeated while it changes the goal. *)
  fun post_subst_tac ctxt = let
    val s_thms = post_subst.get ctxt
    fun dis_tac goal_ctxt = ALLGOALS (Tagged_Solver.solve_tac goal_ctxt)
    val cnv = Cond_Rewr_Conv.cond_rewrs_conv dis_tac s_thms
    val ts_conv = Conv.top_sweep_conv cnv ctxt
    val simp_ctxt = get_post_ss ctxt
  in
    REPEAT o CHANGED o 
    (Simplifier.simp_tac simp_ctxt THEN' CONVERSION ts_conv)
  end


  (* The refine_transfer rule collection used by transfer_tac. *)
  structure transfer = Named_Thms
    ( val name = @{binding refine_transfer}
      val description = "Refinement Framework: " ^ 
        "Transfer rules" );

  (* Work on subgoal i by repeatedly trying, in this order: an assumption,
     resolution with thms or a refine_transfer rule, the tagged solver, or
     simplification with nested_case_prod_simp (only if it changes the
     goal).  Iteration stops once the goal state has fewer subgoals than
     st had initially, or when no step applies; REPEAT_DETERM1 makes the
     whole tactic fail if not even one step succeeds. *)
  fun transfer_tac thms ctxt i st = let 
    val thms = thms @ transfer.get ctxt;
    val ss = put_simpset HOL_basic_ss ctxt addsimps @{thms nested_case_prod_simp}
  in
    REPEAT_DETERM1 (
      COND (has_fewer_prems (Thm.nprems_of st)) no_tac (
        FIRST [
          Method.assm_tac ctxt i,
          resolve_tac ctxt thms i,
          Tagged_Solver.solve_tac ctxt i,
          CHANGED_PROP (simp_tac ss i)]
      )) st
  end

  (* Close a goal REFINEG_TRANSFER_ALIGN l r at subgoal i by resolving
     with REFINEG_TRANSFER_ALIGNI whose second argument has been fixed to
     l.  This adjusts the right term to have the same structure as the
     left one.  Produces no result if subgoal i has a different shape. *)
  fun align_tac ctxt = IF_EXGOAL (fn i => fn st =>
    case Logic.concl_of_goal (Thm.prop_of st) i of
      @{mpat "Trueprop (REFINEG_TRANSFER_ALIGN ?c _)"} => let
        val c = Thm.cterm_of ctxt c
        val cT = Thm.ctyp_of_cterm c
        
        (* Lift the rule's schematic variables above those of st, then
           instantiate its second type and term variable with c. *)
        val rl = @{thm REFINEG_TRANSFER_ALIGNI}
          |> Thm.incr_indexes (Thm.maxidx_of st + 1)
          |> Thm.instantiate' [NONE,SOME cT] [NONE,SOME c]
      in
        resolve_tac ctxt [rl] i st
      end
    | _ => Seq.empty
  )

  (* Transfer with post-processing.  START_REFINEG_TRANSFER splits the
     goal into an alignment goal, a transfer goal and a post-simp goal.
     align_tac closes the alignment goal, and transfer_tac then works on
     the transfer goal.  IF_SOLVED (from Autoref_Tacticals) selects the
     continuation: presumably post-processing followed by
     STOP_REFINEG_TRANSFER when transfer_tac solved its goal, and all_tac
     otherwise -- TODO confirm the semantics of IF_SOLVED. *)
  fun post_transfer_tac thms ctxt = let open Autoref_Tacticals in
    resolve_tac ctxt @{thms START_REFINEG_TRANSFER} 
    THEN' align_tac ctxt 
    THEN' IF_SOLVED (transfer_tac thms ctxt)
      (post_process_tac ctxt THEN' resolve_tac ctxt @{thms STOP_REFINEG_TRANSFER})
      (K all_tac)

  end

  (* Rewrite rules currently in the post simpset; exported below as the
     dynamic fact refine_transfer_post_simps. *)
  fun get_post_simp_rules context = Context.proof_of context
      |> get_post_ss
      |> simpset_of 
      |> Raw_Simplifier.dest_ss
      |> #simps |> map snd


  local
    val add_ps = Thm.declaration_attribute (add_post_simps o single)
    val del_ps = Thm.declaration_attribute (del_post_simps o single)
  in
    (* Theory setup: register post_subst_tac as a post-processor, set up
       the post_subst and refine_transfer rule collections, declare the
       refine_transfer_post_simp attribute, and add the dynamic fact
       refine_transfer_post_simps. *)
    val setup = I
      #> add_post_processor "RefineG_Transfer.post_subst" post_subst_tac
      #> post_subst.setup
      #> transfer.setup
      #> Attrib.setup @{binding refine_transfer_post_simp} 
          (Attrib.add_del add_ps del_ps) 
          ("declaration of transfer post simplification rules")
      #> Global_Theory.add_thms_dynamic (
           @{binding refine_transfer_post_simps}, get_post_simp_rules)

  end
end
›

(* Install the transfer infrastructure of RefineG_Transfer (rule
   collections, post simpset attribute, post-processors) in this theory. *)
setup ‹RefineG_Transfer.setup›
(* Proof method refine_transfer [thms]: applies
   RefineG_Transfer.transfer_tac with the given extra theorems.  In mode
   (post) it applies RefineG_Transfer.post_transfer_tac instead. *)
method_setup refine_transfer = 
  ‹Scan.lift (Args.mode "post") -- Attrib.thms 
  >> (fn (post,thms) => fn ctxt => SIMPLE_METHOD'
    ( if post then RefineG_Transfer.post_transfer_tac thms ctxt
      else RefineG_Transfer.transfer_tac thms ctxt))›
  "Invoke transfer rules"


(* Transfer along a function α into a complete lattice.  Every lemma in
   this locale has a conclusion of the form α c ≤ C and is registered as a
   refine_transfer rule, so the refine_transfer method can decompose such
   goals along common HOL constructs. *)
locale transfer = fixes α :: "'c ⇒ 'a::complete_lattice"
begin

text ‹
  In the following, we define some transfer lemmas for general
  HOL constructs.
›

lemma transfer_if[refine_transfer]:
  assumes "b ⟹ α s1 ≤ S1"
  assumes "¬b ⟹ α s2 ≤ S2"
  shows "α (if b then s1 else s2) ≤ (if b then S1 else S2)"
  using assms by auto

lemma transfer_prod[refine_transfer]:
  assumes "⋀a b. α (f a b) ≤ F a b"
  shows "α (case_prod f x) ≤ (case_prod F x)"
  using assms by (auto split: prod.split)

lemma transfer_Let[refine_transfer]:
  assumes "⋀x. α (f x) ≤ F x"
  shows "α (Let x f) ≤ Let x F"
  using assms by auto

lemma transfer_option[refine_transfer]:
  assumes "α fa ≤ Fa"
  assumes "⋀x. α (fb x) ≤ Fb x"
  shows "α (case_option fa fb x) ≤ case_option Fa Fb x"
  using assms by (auto split: option.split)

lemma transfer_sum[refine_transfer]:
  assumes "⋀l. α (fl l) ≤ Fl l"
  assumes "⋀r. α (fr r) ≤ Fr r"
  shows "α (case_sum fl fr x) ≤ (case_sum Fl Fr x)"
  using assms by (auto split: sum.split)

lemma transfer_list[refine_transfer]:
  assumes "α fn ≤ Fn"
  assumes "⋀x xs. α (fc x xs) ≤ Fc x xs"
  shows "α (case_list fn fc l) ≤ case_list Fn Fc l"
  using assms by (auto split: list.split)


(* Primitive list recursion with an extra argument s: the step premise
   may assume the transfer property for the recursive call. *)
lemma transfer_rec_list[refine_transfer]:
  assumes FN: "⋀s. α (fn s) ≤ fn' s"
  assumes FC: "⋀x l rec rec' s. ⟦ ⋀s. α (rec s) ≤ (rec' s) ⟧ 
    ⟹ α (fc x l rec s) ≤ fc' x l rec' s"
  shows "α (rec_list fn fc l s) ≤ rec_list fn' fc' l s"
  apply (induct l arbitrary: s)
  apply (simp add: FN)
  apply (simp add: FC)
  done

(* Primitive recursion on nat, analogous to transfer_rec_list. *)
lemma transfer_rec_nat[refine_transfer]:
  assumes FN: "⋀s. α (fn s) ≤ fn' s"
  assumes FC: "⋀n rec rec' s. ⟦ ⋀s. α (rec s) ≤ rec' s ⟧ 
    ⟹ α (fs n rec s) ≤ fs' n rec' s"
  shows "α (rec_nat fn fs n s) ≤ rec_nat fn' fs' n s"
  apply (induct n arbitrary: s)
  apply (simp add: FN)
  apply (simp add: FC)
  done

end

text ‹Transfer into complete lattice structure›
(* Specialisation of transfer where the source type 'c is a complete
   lattice as well. *)
locale ordered_transfer = transfer + 
  constrains α :: "'c::complete_lattice ⇒ 'a::complete_lattice"

text ‹Transfer into complete lattice structure with distributive
  transfer function.›
(* α distributes over suprema of chains (assumption α_dist). *)
locale dist_transfer = ordered_transfer + 
  constrains α :: "'c::complete_lattice ⇒ 'a::complete_lattice"
  assumes α_dist: "⋀A. is_chain A ⟹ α (Sup A) = Sup (α`A)"
begin
  (* Monotonicity: the proof applies α_dist to the two-element chain
     {x,y} and finishes with le_iff_sup. *)
  lemma α_mono[simp, intro!]: "mono α"
    apply rule
    apply (subgoal_tac "is_chain {x,y}")
    apply (drule α_dist)
    apply (auto simp: le_iff_sup) []
    apply (rule chainI)
    apply auto
    done

  (* Strictness: α_dist instantiated with the empty set. *)
  lemma α_strict[simp]: "α bot = bot"
    using α_dist[of "{}"] by simp
end


text ‹Transfer into ccpo›
(* Specialisation of transfer where the source type 'c is a ccpo. *)
locale ccpo_transfer = transfer α for
  α :: "'c::ccpo ⇒ 'a::complete_lattice" 

text ‹Transfer into ccpo with distributive
  transfer function.›
(* ccpo version of dist_transfer: α distributes over suprema of chains. *)
locale dist_ccpo_transfer = ccpo_transfer α
  for α :: "'c::ccpo ⇒ 'a::complete_lattice" + 
  assumes α_dist: "⋀A. is_chain A ⟹ α (Sup A) = Sup (α`A)"
begin

  (* Monotonicity, via α_dist on the chain {x,y} whose supremum is y. *)
  lemma α_mono[simp, intro!]: "mono α"
  proof
    fix x y :: 'c
    assume LE: "x≤y"
    hence C[simp, intro!]: "is_chain {x,y}" by (auto intro: chainI)
    from LE have "α x ≤ sup (α x) (α y)" by simp
    also have "… = Sup (α`{x,y})" by simp
    also have "… = α (Sup {x,y})"
      by (rule α_dist[symmetric]) simp
    also have "Sup {x,y} = y"
      apply (rule antisym)
      apply (rule ccpo_Sup_least[OF C]) using LE apply auto []
      apply (rule ccpo_Sup_upper[OF C]) by auto
    finally show "α x ≤ α y" .
  qed

  (* Strictness: α_dist instantiated with the empty set. *)
  lemma α_strict[simp]: "α (Sup {}) = bot"
    using α_dist[of "{}"] by simp
end

end