(* Theory Sepref_Definition *)

section ‹Sepref-Definition Command›
theory Sepref_Definition
imports Sepref_Rules "Lib/Pf_Mono_Prover" "Lib/Term_Synth"
keywords "sepref_definition" :: thy_goal
      and "sepref_thm" :: thy_goal
begin
subsection ‹Setup of Extraction-Tools›
  (* Register "hn_refine _ ?f _ _ _" as a cd_patterns pattern.  Presumably this tells
     Refine_Automation where the concrete program (?f) sits inside an hnr theorem
     when a concrete definition is extracted from it - TODO confirm. *)
  declare [[cd_patterns "hn_refine _ ?f _ _ _"]]

  (* Code-generation lemma for heap fixpoints: if f is defined from heap.fixp_fun cB
     and the body cB is monotone (mono_Heap) for every argument x, then f x unfolds
     to its body cB f x.  It is used as gen_thm of the "heap" extraction in
     Sepref_Extraction.
     NOTE(review): in this copy the relation symbol between "f" and
     "heap.fixp_fun cB" in DEF, and the binder before "x." in M, appear to be
     missing (presumably the meta-equality and the meta-universal quantifier);
     confirm against the original source. *)
  lemma heap_fixp_codegen:
    assumes DEF: "f  heap.fixp_fun cB"
    assumes M: "(x. mono_Heap (λf. cB f x))"
    shows "f x = cB f x"
    unfolding DEF
    (* reduce the pointwise goal to an equation between functions *)
    apply (rule fun_cong[of _ _ x])
    apply (rule heap.mono_body_fixp)
    (* the remaining monotonicity obligation is closed by the assumption M *)
    apply fact
    done


  ML structure Sepref_Extraction = struct
      (* Extraction for constants defined via heap.fixp_fun.  Its gen_thm is
         heap_fixp_codegen; gen_tac (the Pf_Mono_Prover monotonicity tactic) is
         presumably used to prove that lemma's mono_Heap premise. *)
      val heap_extraction: Refine_Automation.extraction =
        let
          val fixp_pattern = Logic.varify_global @{term "heap.fixp_fun x"}
        in
          { pattern = fixp_pattern,
            gen_thm = @{thm heap_fixp_codegen},
            gen_tac = Pf_Mono_Prover.mono_tac }
        end

      (* Theory setup: register heap_extraction under the name "heap".
         A "trivial" extraction used to be registered here too; it is disabled. *)
      val setup = Refine_Automation.add_extraction "heap" heap_extraction

    end

  (* Apply Sepref_Extraction.setup to the theory, registering the "heap" extraction. *)
  setup Sepref_Extraction.setup 


  subsection ‹Synthesis setup for sepref-definition goals›
  (* TODO: UNSPEC is an ad-hoc hack to specify the synthesis goal.  It stands for a
     part of the refinement relation that the user leaves unspecified (see the
     hfunspec abbreviation and the SPEC_RES_ASSN rules); presumably synthesis then
     instantiates it via the synth_rules below. *)
  consts UNSPEC::'a  

  (* Postfix notation "R?" abbreviating hf_pres R UNSPEC: an argument assertion whose
     annotation is left unspecified (presumably the keep/destroy flag, cf. the
     k and d variants in the INTRO_KD rules).
     NOTE(review): the function arrows in the type and the superscript/equality
     symbols in the mixfix and defining equation appear stripped in this copy;
     confirm against the original source. *)
  abbreviation hfunspec 
    :: "('a  'b  assn)  ('a  'b  assn)×('a  'b  assn)" 
    ("(_?)" [1000] 999)
    where "R?  hf_pres R UNSPEC"

  (* SYNTH f R is the head constant of a synthesis goal "implement the abstract
     program f w.r.t. the refinement relation R" (built by mk_synth_term in
     Sepref_Definition).  It is defined as True and carries no logical content.
     NOTE(review): arrows and the defining equality appear stripped in this copy. *)
  definition SYNTH :: "('a  'r nres)  (('ai 'ri Heap) × ('a  'r nres)) set  bool"
    where "SYNTH f R  True"

  (* Auxiliary marker predicates appearing as premises of synth_hnrI.  Each is
     defined as True (and declared simp); their instances are solved by the
     synth_rules lemmas below.
     NOTE(review): the defining equality symbol appears stripped in this copy. *)
  definition [simp]: "CP_UNCURRY _ _  True"
  definition [simp]: "INTRO_KD _ _  True"
  definition [simp]: "SPEC_RES_ASSN _ _  True"

  (* Synthesis rules relating a concrete function to an abstract one by their
     uncurrying structure: arbitrary f/g, uncurry0, and (recursively) uncurry.
     NOTE(review): the implication symbol in the third rule appears stripped. *)
  lemma [synth_rules]: "CP_UNCURRY f g" by simp
  lemma [synth_rules]: "CP_UNCURRY (uncurry0 f) (uncurry0 g)" by simp
  lemma [synth_rules]: "CP_UNCURRY f g  CP_UNCURRY (uncurry f) (uncurry g)" by simp

  (* Synthesis rules for argument assertions: product assertions are handled
     componentwise; an unspecified R? is replaced by hf_pres R k with a free k
     (presumably determined later during synthesis); already annotated assertions
     (presumably R^k and R^d) are passed through unchanged.
     NOTE(review): brackets, implication and sub/superscript symbols appear
     stripped in this copy; confirm against the original source. *)
  lemma [synth_rules]: "INTRO_KD R1 R1'; INTRO_KD R2 R2'  INTRO_KD (R1*aR2) (R1'*aR2')" by simp
  lemma [synth_rules]: "INTRO_KD (R?) (hf_pres R k)" by simp
  lemma [synth_rules]: "INTRO_KD (Rk) (Rk)" by simp
  lemma [synth_rules]: "INTRO_KD (Rd) (Rd)" by simp

  (* Synthesis rules for the result assertion: a specified assertion is kept
     as it is, while UNSPEC is replaced by an arbitrary R. *)
  lemma [synth_rules]: "SPEC_RES_ASSN R R"
    by (simp only: SPEC_RES_ASSN_def)
  lemma [synth_rules]: "SPEC_RES_ASSN UNSPEC R"
    by (simp only: SPEC_RES_ASSN_def)
  
  (* Entry rule for Term_Synth (used by make_hnr_goal in Sepref_Definition): from a
     SYNTH goal it synthesizes a pair consisting of a concrete-definition pattern
     for fi and the hnr goal for fi against f.  The CP_UNCURRY, INTRO_KD and
     SPEC_RES_ASSN premises are solved by the synth_rules above.
     NOTE(review): brackets, implication, arrow, membership and subscript symbols
     appear stripped in this copy; confirm against the original source. *)
  lemma synth_hnrI:
    "CP_UNCURRY fi f; INTRO_KD R R'; SPEC_RES_ASSN S S'  SYNTH_TERM (SYNTH f ([P]a RS)) ((fi,SDUMMY)SDUMMY,(fi,f)([P]a R'S'))" 
    by (simp add: SYNTH_def)

ML structure Sepref_Definition = struct
    (* Build the hnr proof goal for a SYNTH term t.
       The variables of t are declared in ctxt, then Term_Synth.synth_term is run
       on t with synth_hnrI.  Its result must be a pair (pat, goal): pat is
       certified and prepared as a concrete-definition pattern, goal is wrapped
       into Trueprop.
       Returns ((pat, goal), ctxt') where ctxt' is ctxt with t declared.
       Raises TERM if the synthesized term is not a pair. *)
    fun make_hnr_goal t ctxt = let
      val ctxt = Variable.declare_term t ctxt
      val (pat,goal) = case Term_Synth.synth_term @{thms synth_hnrI} ctxt t of
        @{mpat "(?pat,?goal)"} => (pat,goal) | t => raise TERM("Synthesized term does not match",[t])
      val pat = Thm.cterm_of ctxt pat |> Refine_Automation.prepare_cd_pattern ctxt
      val goal = HOLogic.mk_Trueprop goal
    in
      ((pat,goal),ctxt)
    end

    (* Configuration option sepref_definition_prep_code (default: true).  When set,
       sepref_definition also extracts recursion equations from the new definition
       using the heap extraction. *)
    val cfg_prep_code = Attrib.setup_config_bool @{binding sepref_definition_prep_code} (K true)

    local 
      open Refine_Util
      val flags = parse_bool_config' "prep_code" cfg_prep_code
      val parse_flags = parse_paren_list' flags  
    in       
      (* Syntax of sepref_definition:  (flags) name attribs is term :: term
         The parenthesized flag list may set prep_code (presumably it is
         optional, as are the attributes). *)
      val sd_parser = parse_flags -- Parse.binding -- Parse.opt_attribs --| @{keyword "is"} 
        -- Parse.term --| @{keyword "::"} -- Parse.term
    end  

    (* Parse t_raw (abstract program) and r_raw (refinement relation) and
       type-check the combined term SYNTH t r. *)
    fun mk_synth_term ctxt t_raw r_raw = let
        val t = Syntax.parse_term ctxt t_raw
        val r = Syntax.parse_term ctxt r_raw
        val t = Const (@{const_name SYNTH},dummyT)$t$r
      in
        Syntax.check_term ctxt t
      end  

    (* Note thm as <name>.refine_raw in lthy and return the updated local theory. *)
    fun note_refine_raw name thm lthy =
      Local_Theory.note
        ((Refine_Automation.mk_qualified (Binding.name_of name) "refine_raw",[]),[thm])
        lthy
      |> snd

    (* after_qed continuation for a Proof.theorem with exactly one goal: the proved
       theorem is exported from the proof's result context back into lthy
       (generalizing the variables declared by make_hnr_goal) and passed to cont.
       Any other result shape raises THM. *)
    fun single_thm_after_qed lthy cont = let
        fun after_qed [[thm]] ctxt = cont (singleton (Variable.export ctxt lthy) thm)
          | after_qed thmss _ = raise THM ("After-qed: Wrong thmss structure",~1,flat thmss)
      in
        after_qed
      end

    (* sepref_definition: open a proof of the hnr goal for t :: r.  After the proof:
       - the theorem is noted as <name>.refine_raw;
       - Refine_Automation.define_concrete_fun is applied with the given attributes
         (presumably defining the concrete constant <name>), yielding dthm and rthm;
       - if prep_code is set, recursion equations are extracted from dthm with the
         heap extraction;
       - dthm and rthm are printed. *)
    fun sd_cmd ((((flags,name),attribs),t_raw),r_raw) lthy = let
      (* the flags only influence the prep_code option, so read it once here *)
      val flag_prep_code = Config.get (Refine_Util.apply_configs flags lthy) cfg_prep_code

      val t = mk_synth_term lthy t_raw r_raw
      val ((pat,goal),ctxt) = make_hnr_goal t lthy

      fun define_from thm = let
          val lthy = note_refine_raw name thm lthy
          val ((dthm,rthm),lthy) = Refine_Automation.define_concrete_fun NONE name attribs [] thm [pat] lthy
          val lthy = lthy 
            |> flag_prep_code ? Refine_Automation.extract_recursion_eqs 
              [Sepref_Extraction.heap_extraction] (Binding.name_of name) dthm
          val _ = Thm.pretty_thm lthy dthm |> Pretty.string_of |> writeln
          val _ = Thm.pretty_thm lthy rthm |> Pretty.string_of |> writeln
        in
          lthy
        end
    in
      Proof.theorem NONE (single_thm_after_qed lthy define_from) [[ (goal,[]) ]] ctxt
    end

    val _ = Outer_Syntax.local_theory_to_proof @{command_keyword "sepref_definition"}
      "Synthesis of imperative program"
      (sd_parser >> sd_cmd)

    (* Syntax of sepref_thm:  name is term :: term *)
    val st_parser = Parse.binding --| @{keyword "is"} -- Parse.term --| @{keyword "::"} -- Parse.term

    (* sepref_thm: like sepref_definition, but after the proof the raw refinement
       theorem is only traced and noted as <name>.refine_raw; nothing is defined. *)
    fun st_cmd ((name,t_raw),r_raw) lthy = let
      val t = mk_synth_term lthy t_raw r_raw
      val ((_,goal),ctxt) = make_hnr_goal t lthy

      fun note_only thm = (
        Thm.pretty_thm lthy thm |> Pretty.string_of |> tracing;
        note_refine_raw name thm lthy)
    in
      Proof.theorem NONE (single_thm_after_qed lthy note_only) [[ (goal,[]) ]] ctxt
    end

    val _ = Outer_Syntax.local_theory_to_proof @{command_keyword "sepref_thm"}
      "Synthesis of imperative program: Only generate raw refinement theorem"
      (st_parser >> st_cmd)

  end

end