Theory Refine_Imperative_HOL.Sepref_Monadify
section ‹Monadify›
theory Sepref_Monadify
imports Sepref_Basic Sepref_Id_Op
begin
text ‹
  In this phase, a monadic program is converted to complete monadic form,
  that is, the computations of compound expressions are made visible as
  top-level operations in the monad.
  The monadify process is separated into 2 steps.
  \begin{enumerate}
    \item In a first step, eta-expansion is used to add missing operands 
      to operations and combinators. This way, operators and combinators
      always occur with the same arity, which simplifies further processing.
    \item In a second step, computation of compound operands is flattened,
      introducing new bindings for the intermediate values. 
  \end{enumerate}
›
(* SP is an identity tag: SP x unfolds to x, and the unfolding equation SP_def
   is a default simp rule. The ML tactics in this theory work with restricted
   simpsets and add SP_def explicitly when the tag is to be removed. *)
definition [simp]: "SP x ≡ x"
(* Premise-free congruence rules for SP and PR_CONST, declared [cong] and also
   added explicitly to the simpsets built in arity_tac and comb_tac.
   NOTE(review): a congruence rule without premises gives the simplifier no
   sub-goal for the argument, so these presumably stop rewriting below
   SP / PR_CONST -- confirm against the simplifier documentation. *)
lemma SP_cong[cong]: "SP x ≡ SP x" by simp
lemma PR_CONST_cong[cong]: "PR_CONST x ≡ PR_CONST x" by simp
(* RCALL is an identity tag (RCALL D unfolds to D, by default simp rule
   RCALL_def). The arity rule for RECT in this theory wraps the recursion
   parameter D in RCALL, which is how recursive calls get marked. *)
definition [simp]: "RCALL D ≡ D"
(* EVAL tags a plain expression that is to be evaluated in the monad.
   Semantically it is just RETURN (EVAL x unfolds to RETURN x, by default simp
   rule EVAL_def); the tag exists so that the combinator rules and the
   monadify simproc of this theory can find and flatten such expressions. *)
definition [simp]: "EVAL x ≡ RETURN x"
text ‹
  Internally, the package first applies rewriting rules from 
  ‹sepref_monadify_arity›, which use eta-expansion to ensure that
  every combinator has enough actual parameters. Moreover, this phase will
  mark recursive calls by the tag @{const RCALL}.
  Next, rewriting rules from ‹sepref_monadify_comb› are used to
  add @{const EVAL}-tags to plain expressions that should be evaluated
  in the monad. The @{const EVAL} tags are flattened using a default simproc 
  that generates left-to-right argument order.
›
(* Equations used (together with SP_def) by the proof tactic of the monadify
   simproc below:
     1. binding the result of RETURN$x is the same as applying the
        continuation to x;
     2. EVAL$x is RETURN$x.
   Note that the first equation is stated with = and the second with the
   meta-equality. *)
lemma monadify_simps: 
  "Refine_Basic.bind$(RETURN$x)$(λ⇩2x. f x) = f x" 
  "EVAL$x ≡ RETURN$x"
  by simp_all
(* PASS is a synonym for RETURN (default simp rule PASS_def). The mark_params
   phase of the ML code below rewrites RETURN$x to PASS$x when x is a bound
   variable or one of the parameters collected from the precondition. *)
definition PASS
  where [simp]: "PASS ≡ RETURN"
  
(* Equations that eliminate PASS again; they form the simpset of
   remove_pass_tac in the ML code below:
     1. a bind whose first operand is PASS$x collapses to the continuation
        applied to x;
     2. a bind whose continuation merely PASSes its argument collapses to
        the bound computation m. *)
lemma remove_pass_simps:
  "Refine_Basic.bind$(PASS$x)$(λ⇩2x. f x) ≡ f x" 
  "Refine_Basic.bind$m$(λ⇩2x. PASS$x) ≡ m"
  by simp_all
(* COPY is an identity tag (COPY x unfolds to x, by default simp rule
   COPY_def). The dup phase of the ML code below inserts it as
   RETURN$(COPY$p) where a parameter p is passed on more than once. *)
definition [simp]: "COPY (x::'a) ≡ x"
(* Returning a COPY of p equals PASSing p. This is the only simp rule given to
   the proof tactic of dp_conv in the ML code below, i.e. it justifies
   replacing PASS$p by RETURN$(COPY$p). *)
lemma RET_COPY_PASS_eq: "RETURN$(COPY$p) = PASS$p" by simp
(* Theorem collections that drive the first two monadify phases:
     sepref_monadify_arity -- read by arity_tac,
     sepref_monadify_comb  -- read by comb_tac.
   Rules are added with the attributes of the same name (see the dflt_*
   lemmas at the end of this theory).
   NOTE(review): named_theorems_rev is declared outside this file; presumably
   it is the variant of named_theorems that yields the rules in reverse
   order -- confirm. *)
named_theorems_rev sepref_monadify_arity "Sepref.Monadify: Arity alignment equations"
named_theorems_rev sepref_monadify_comb "Sepref.Monadify: Combinator equations"
ML ‹
  structure Sepref_Monadify = struct
    (* Simproc that flattens an EVAL-tagged application: every argument is
       evaluated by its own bind, in left-to-right order, and the operator is
       then applied to the bound results. Only monadify_simproc is exported
       from this local block. *)
    local
      (* For argument index i of type T: the binder name "v<i>" paired with a
         free variable "__v<i>" of type T that stands for the bound value. *)
      fun cr_var (i,T) = ("v"^string_of_int i, Free ("__v"^string_of_int i,T))
      (* Wrap t as PROTECT2 t DUMMY and abstract it over n, where n is a
         (binder name, variable) pair as expected by Term.lambda_name.
         The inner val deliberately shadows the parameter t. *)
      fun lambda2_name n t = let
        val t = @{mk_term "PROTECT2 ?t DUMMY"}
      in
        Term.lambda_name n t
      end
      (* Nest exp0 inside one bind per (variable, computation) pair. The first
         pair of the list becomes the outermost bind, so the computations run
         in list order. *)
      fun 
        bind_args exp0 [] = exp0
      | bind_args exp0 ((x,m)::xms) = let
          (* Build the continuation for the remaining pairs first, shift its
             loose bound variables by one, then abstract over x. *)
          val lr = bind_args exp0 xms 
            |> incr_boundvars 1 
            |> lambda2_name x
        in @{mk_term "Refine_Basic.bind$?m$?lr"} end
      (* Turn the application t = f$a0$...$an into
           bind$(EVAL$a0)$(%v0. ... bind$(EVAL$an)$(%vn. SP (RETURN$(f$v0$...$vn))))
         Raises TERM if the head f is a lambda abstraction. *)
      fun monadify t = let
        val (f,args) = Autoref_Tagging.strip_app t
        val _ = not (is_Abs f) orelse 
          raise TERM ("monadify: higher-order",[t])
        (* Types are taken from the original arguments, before tagging. *)
        val argTs = map fastype_of args
        
        (* Arguments are only tagged with EVAL here, not flattened recursively;
           the nested tags are left for further rewriting. *)
        val args = map (fn a => @{mk_term "EVAL$?a"}) args
        
  
        (* One (name, free variable) pair per argument, numbered from 0. *)
        val argVs = tag_list 0 argTs
          |> map cr_var
        (* Innermost result: f applied to the fresh variables, under SP. *)
        val res0 = let
          val x = Autoref_Tagging.list_APP (f,map #2 argVs)
        in 
          @{mk_term "SP (RETURN$?x)"}
        end
        val res = bind_args res0 (argVs ~~ args)
      in
        res
      end
      (* Conversion for terms of the form EVAL$x: the new right-hand side is
         monadify x, and the tactic handed to Refine_Util.f_tac_conv simplifies
         with monadify_simps and SP_def only. Any other term raises TERM.
         NOTE(review): f_tac_conv is defined outside this file; presumably it
         proves the equation between the old and the computed term with the
         given tactic -- confirm. *)
      fun monadify_conv_aux ctxt ct = case Thm.term_of ct of
        @{mpat "EVAL$_"} => let
          fun tac goal_ctxt =
            simp_tac (put_simpset HOL_basic_ss goal_ctxt addsimps @{thms monadify_simps SP_def}) 1
        in  (
          Refine_Util.f_tac_conv ctxt (dest_comb #> #2 #> monadify) tac) ct
        end
      | t => raise TERM ("monadify_conv",[t])
        
    in
        
      (* The simproc itself, triggered on the pattern EVAL$a. The conversion is
         wrapped in try, so an exception becomes NONE rather than propagating.
         It is set up as passive and added explicitly by comb_tac. *)
      val monadify_simproc =
        \<^simproc_setup>‹passive monadify ("EVAL$a") = ‹K (try o monadify_conv_aux)››;
    end
    (* Conversion that marks returned parameters: inside the last component of
       an hn_refine statement, RETURN$x becomes PASS$x when x is a bound
       variable or one of the values collected from the precondition. *)
    local
      open Sepref_Basic
      fun mark_params t = let
        (* Take the hn_refine statement apart; only a is rewritten below. *)
        val (P,c,Q,R,a) = dest_hn_refine t
        (* Second component of every hn_ctxt conjunct of the precondition P;
           conjuncts that are not hn_ctxt are dropped by map_filter.
           NOTE(review): presumably these are the abstract parameters --
           confirm against dest_hn_ctxt_opt. *)
        val pps = strip_star P |> map_filter (dest_hn_ctxt_opt #> map_option #2)
        (* Term traversal; env collects the types of the binders passed so
           far, innermost first, and is handed to mk_term. The RETURN clause
           must come before the generic application clause. *)
        fun tr env (t as @{mpat "RETURN$?x"}) = 
              if is_Bound x orelse member (aconv) pps x then
                @{mk_term env: "PASS$?x"}
              else t
          | tr env (t1$t2) = tr env t1 $ tr env t2
          | tr env (Abs (x,T,t)) = Abs (x,T,tr (T::env) t)
          | tr _ t = t
        val a = tr [] a
      in
        mk_hn_refine (P,c,Q,R,a)
      end
    in  
    (* The rewritten statement is justified by unfolding PASS_def only. *)
    fun mark_params_conv ctxt = Refine_Util.f_tac_conv ctxt 
      (mark_params) 
      (fn goal_ctxt => simp_tac (put_simpset HOL_basic_ss goal_ctxt addsimps @{thms PASS_def}) 1)
    end  
    (* Duplicate detection for passed parameters: in a chain of binds of the
       form bind$(PASS$p)$(%x. ...), a parameter p that is PASSed again further
       inside the chain is replaced at the outer position by
       RETURN$(COPY$p). *)
    local
      open Sepref_Basic
      (* dp ctxt t returns the rewritten term together with the list of
         parameters that are PASSed along the bind chain starting at t. *)
      fun dp ctxt (@{mpat "Refine_Basic.bind$(PASS$?p)$(?t' AS⇩p (λ_. PROTECT2 _ DUMMY))"}) = 
          let
            (* Process the continuation first: open the lambda, unwrap the
               PROTECT2 marker, recurse on the body, then re-wrap and close
               the lambda again with rc. *)
            val (t',ps) = let
                val ((t',rc),ctxt) = dest_lambda_rc ctxt t'
                val f = case t' of @{mpat "PROTECT2 ?f _"} => f | _ => raise Match 
                val (f,ps) = dp ctxt f
                val t' = @{mk_term "PROTECT2 ?f DUMMY"}
                val t' = rc t'
              in
                (t',ps)
              end
  
            (* p is a duplicate if the continuation passes it on as well. *)
            val dup = member (aconv) ps p
            val t = if dup then
              @{mk_term "Refine_Basic.bind$(RETURN$(COPY$?p))$?t'"}
            else
              @{mk_term "Refine_Basic.bind$(PASS$?p)$?t'"}
          in
            (t,p::ps)
          end
        (* Other applications and abstractions are traversed, but start a new
           chain: the collected parameter lists are discarded. *)
        | dp ctxt (t1$t2) = (#1 (dp ctxt t1) $ #1 (dp ctxt t2),[])
        | dp ctxt (t as (Abs _)) = (apply_under_lambda (#1 oo dp) ctxt t,[])
        | dp _ t = (t,[])
      (* The rewritten term is justified by RET_COPY_PASS_eq only. *)
      fun dp_conv ctxt = Refine_Util.f_tac_conv ctxt 
        (#1 o dp ctxt) 
        (fn goal_ctxt =>
          ALLGOALS (simp_tac (put_simpset HOL_basic_ss goal_ctxt addsimps @{thms RET_COPY_PASS_eq})))
    in
      (* Tactic form: apply dp_conv through hn_refine_concl_conv_a.
         NOTE(review): presumably that targets the abstract program in the
         conclusion of an hn_refine goal -- confirm in Sepref_Basic. *)
      fun dup_tac ctxt = CONVERSION (Sepref_Basic.hn_refine_concl_conv_a dp_conv ctxt)
    end
    (* Arity phase, as two simplifier passes over the selected subgoal, both
       starting from the minimal HOL_basic_ss simpset:
         1. rewrite with the sepref_monadify_arity rules, with SP_cong and
            PR_CONST_cong installed as congruence rules;
         2. rewrite with the theorems beta and SP_def. *)
    fun arity_tac ctxt = let
      val basic_ctxt = put_simpset HOL_basic_ss ctxt
      val arity_thms =
        Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_arity}

      val expand_ctxt = (basic_ctxt addsimps arity_thms)
        |> Simplifier.add_cong @{thm SP_cong}
        |> Simplifier.add_cong @{thm PR_CONST_cong}

      val cleanup_ctxt = basic_ctxt addsimps @{thms beta SP_def}
    in
      fn i => simp_tac expand_ctxt i THEN simp_tac cleanup_ctxt i
    end
    (* Combinator phase, as two simplifier passes over the selected subgoal,
       both starting from the minimal HOL_basic_ss simpset:
         1. rewrite with the sepref_monadify_comb rules and the monadify
            simproc, with SP_cong and PR_CONST_cong installed as congruence
            rules;
         2. unfold SP_def. *)
    fun comb_tac ctxt = let
      val basic_ctxt = put_simpset HOL_basic_ss ctxt
      val comb_thms =
        Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_comb}

      val rewrite_ctxt = basic_ctxt
        |> Simplifier.add_simps comb_thms
        |> Simplifier.add_proc monadify_simproc
        |> Simplifier.add_cong @{thm SP_cong}
        |> Simplifier.add_cong @{thm PR_CONST_cong}

      val cleanup_ctxt = basic_ctxt addsimps @{thms SP_def}
    in
      fn i => simp_tac rewrite_ctxt i THEN simp_tac cleanup_ctxt i
    end
    
    (* Tactic form of mark_params_conv: the conversion is lifted with
       Refine_Util.HOL_concl_conv and applied to the goal via CONVERSION. *)
    fun mark_params_tac ctxt = let
      val concl_conv = Refine_Util.HOL_concl_conv mark_params_conv ctxt
    in
      CONVERSION concl_conv
    end
    (* For a term of the form Trueprop (hn_refine _ _ _ _ a): does the constant
       EVAL occur anywhere inside a? Any other term raises TERM. *)
    fun contains_eval goal = let
      val is_eval = fn @{mpat EVAL} => true | _ => false
    in
      case goal of
        @{mpat "Trueprop (hn_refine _ _ _ _ ?a)"} => Term.exists_subterm is_eval a
      | _ => raise TERM("contains_eval",[goal])
    end;
    (* Final cleanup: simplify with remove_pass_simps on top of the minimal
       HOL_basic_ss simpset. *)
    fun remove_pass_tac ctxt = let
      val pass_ctxt = put_simpset HOL_basic_ss ctxt addsimps @{thms remove_pass_simps}
    in
      simp_tac pass_ctxt
    end
    (* Entry point of the monadify phase. The listed sub-phases are handed to
       PHASES' in this order:
         arity, comb, check_EVAL, mark_params, dup, remove_pass.
       check_EVAL is a pure condition on the conclusion: it holds only if no
       EVAL constant is left.
       NOTE(review): PHASES', flag_phases_ctrl and the meaning of the third
       tuple component (0 throughout) are defined outside this file --
       confirm there how dbg affects the run. *)
    fun monadify_tac dbg ctxt = let
      open Sepref_Basic

      val phases = [
        ("arity", arity_tac, 0),
        ("comb", comb_tac, 0),
        ("check_EVAL", K (CONCL_COND' (not o contains_eval)), 0),
        ("mark_params", mark_params_tac, 0),
        ("dup", dup_tac, 0),
        ("remove_pass", remove_pass_tac, 0)
      ]
    in
      PHASES' phases (flag_phases_ctrl dbg) ctxt
    end
  end
›
(* Default arity rules. Each equation replaces an unapplied combinator by its
   eta-expanded form, so that afterwards the combinator always occurs with
   the same number of arguments (RETURN: 1, RECT: 2, case_list: 3,
   case_prod: 2, case_option: 3, If: 3, Let: 2). Functional arguments are
   eta-expanded as well. The RECT rule additionally wraps the recursion
   parameter D in RCALL. All right-hand sides carry an SP tag.
   The proof only unfolds the tagging constants SP, APP, PROTECT2, RCALL. *)
lemma dflt_arity[sepref_monadify_arity]:
  "RETURN ≡ λ⇩2x. SP RETURN$x" 
  "RECT ≡ λ⇩2B x. SP RECT$(λ⇩2D x. B$(λ⇩2x. RCALL$D$x)$x)$x" 
  "case_list ≡ λ⇩2fn fc l. SP case_list$fn$(λ⇩2x xs. fc$x$xs)$l" 
  "case_prod ≡ λ⇩2fp p. SP case_prod$(λ⇩2a b. fp$a$b)$p" 
  "case_option ≡ λ⇩2fn fs ov. SP case_option$fn$(λ⇩2x. fs$x)$ov" 
  "If ≡ λ⇩2b t e. SP If$b$t$e" 
  "Let ≡ λ⇩2x f. SP Let$x$(λ⇩2x. f$x)"
  by (simp_all only: SP_def APP_def PROTECT2_def RCALL_def)
(* Default combinator rules for monadic combinators. Each equation takes the
   operand the combinator inspects (the argument of RECT / RCALL / RETURN,
   the scrutinee of a case, the condition of If, the bound value of Let) and
   evaluates it first, through bind$(EVAL$_); the combinator is then applied
   to the bound variable. Every right-hand side carries an SP tag.
   NOTE(review): the SP tag, together with SP_cong in comb_tac, presumably
   keeps these rules from being applied again to their own right-hand
   sides -- confirm. *)
lemma dflt_comb[sepref_monadify_comb]:
  "⋀B x. RECT$B$x ≡ Refine_Basic.bind$(EVAL$x)$(λ⇩2x. SP (RECT$B$x))"
  "⋀D x. RCALL$D$x ≡ Refine_Basic.bind$(EVAL$x)$(λ⇩2x. SP (RCALL$D$x))"
  "⋀fn fc l. case_list$fn$fc$l ≡ Refine_Basic.bind$(EVAL$l)$(λ⇩2l. (SP case_list$fn$fc$l))"
  "⋀fp p. case_prod$fp$p ≡ Refine_Basic.bind$(EVAL$p)$(λ⇩2p. (SP case_prod$fp$p))"
  "⋀fn fs ov. case_option$fn$fs$ov 
    ≡ Refine_Basic.bind$(EVAL$ov)$(λ⇩2ov. (SP case_option$fn$fs$ov))"
  "⋀b t e. If$b$t$e ≡ Refine_Basic.bind$(EVAL$b)$(λ⇩2b. (SP If$b$t$e))"
  "⋀x. RETURN$x ≡ Refine_Basic.bind$(EVAL$x)$(λ⇩2x. SP (RETURN$x))"
  "⋀x f. Let$x$f ≡ Refine_Basic.bind$(EVAL$x)$(λ⇩2x. (SP Let$x$f))"
  by (simp_all)
(* Default combinator rules for control structures inside plain (EVAL-tagged)
   expressions. The condition, scrutinee or bound value is evaluated first
   through bind$(EVAL$_), and the EVAL tag is pushed into the branches or the
   body, so that the control structure itself becomes part of the monadic
   program. Unlike dflt_comb, the right-hand sides carry no SP tag.
   Each goal is turned into an object-level equation by eq_reflection and
   solved by simp with the list / prod / option split rules. *)
lemma dflt_plain_comb[sepref_monadify_comb]:
  "EVAL$(If$b$t$e) ≡ Refine_Basic.bind$(EVAL$b)$(λ⇩2b. If$b$(EVAL$t)$(EVAL$e))"
  "EVAL$(case_list$fn$(λ⇩2x xs. fc x xs)$l) ≡ 
    Refine_Basic.bind$(EVAL$l)$(λ⇩2l. case_list$(EVAL$fn)$(λ⇩2x xs. EVAL$(fc x xs))$l)"
  "EVAL$(case_prod$(λ⇩2a b. fp a b)$p) ≡ 
    Refine_Basic.bind$(EVAL$p)$(λ⇩2p. case_prod$(λ⇩2a b. EVAL$(fp a b))$p)"
  "EVAL$(case_option$fn$(λ⇩2x. fs x)$ov) ≡ 
    Refine_Basic.bind$(EVAL$ov)$(λ⇩2ov. case_option$(EVAL$fn)$(λ⇩2x. EVAL$(fs x))$ov)"
  "EVAL $ (Let $ v $ (λ⇩2x. f x)) ≡ (⤜) $ (EVAL $ v) $ (λ⇩2x. EVAL $ (f x))"
  apply (rule eq_reflection, simp split: list.split prod.split option.split)+
  done
(* Combinator rule for constants tagged with PR_CONST: evaluating such a
   constant is a plain RETURN of the tagged constant; the result carries an
   SP tag. *)
lemma evalcomb_PR_CONST[sepref_monadify_comb]:
  "EVAL$(PR_CONST x) ≡ SP (RETURN$(PR_CONST x))"
  by simp
end