section ‹Monadify›
theory Sepref_Monadify
imports Sepref_Basic Sepref_Id_Op
begin

text ‹
In this phase, a monadic program is converted to complete monadic form,
that is, computations of compound expressions are made visible as top-level
operations in the monad.

The monadify process is separated into two steps.
\begin{enumerate}
\item In a first step, eta-expansion is used to add missing operands
to operations and combinators. This way, operators and combinators
always occur with the same arity, which simplifies further processing.

\item In a second step, computation of compound operands is flattened,
introducing new bindings for the intermediate values.
\end{enumerate}
›

(* SP is an identity tag (SP x ≡ x, declared [simp]).  SP_cong and
   PR_CONST_cong are trivial congruence rules: when they are in a simpset,
   the simplifier leaves the arguments of SP and PR_CONST untouched. *)
definition SP :: "'a ⇒ 'a" ― ‹Tag to protect content from further application of arity
and combinator equations›
  where [simp]: "SP x ≡ x"
lemma SP_cong[cong]: "SP x ≡ SP x"
  by (rule Pure.reflexive)
lemma PR_CONST_cong[cong]: "PR_CONST x ≡ PR_CONST x"
  by (rule Pure.reflexive)

(* Identity-like tags, both declared [simp]: RCALL D unfolds to D,
   and EVAL x unfolds to RETURN x. *)
definition RCALL :: "'a ⇒ 'a" ― ‹Tag that marks recursive call›
  where [simp]: "RCALL D ≡ D"
definition EVAL :: "'a ⇒ 'a nres" ― ‹Tag that marks evaluation of plain expression for monadify phase›
  where [simp]: "EVAL x ≡ RETURN x"

text ‹
Internally, the package first applies rewriting rules from
‹sepref_monadify_arity›, which use eta-expansion to ensure that
every combinator has enough actual parameters. Moreover, this phase will
mark recursive calls by the tag @{const RCALL}.

Next, rewriting rules from ‹sepref_monadify_comb› are used to
add @{const EVAL}-tags to plain expressions that should be evaluated
in the monad. The @{const EVAL} tags are flattened using a default simproc
that generates left-to-right argument order.
›

(* Rules for the equations produced by the EVAL-flattening conversion:
   a bind of a RETURN collapses, and EVAL unfolds to RETURN. *)
lemma monadify_simps:
  "Refine_Basic.bind$(RETURN$x)$(λ⇩2x. f x) = f x"
  "EVAL$x ≡ RETURN$x"
  by simp_all

definition [simp]: "PASS ≡ RETURN" ― ‹Pass on value, invalidating old one›

(* Remove PASS binds that are immediately consumed or immediately returned. *)
lemma remove_pass_simps:
  "Refine_Basic.bind$(PASS$x)$(λ⇩2x. f x) ≡ f x"
  "Refine_Basic.bind$m$(λ⇩2x. PASS$x) ≡ m"
  by simp_all

definition COPY :: "'a ⇒ 'a" ― ‹Marks required copying of parameter›
  where [simp]: "COPY x ≡ x"
lemma RET_COPY_PASS_eq: "RETURN$(COPY$p) = PASS$p" by simp

ML ‹
structure Sepref_Monadify = struct
  local
    (* Fresh operand variable number i of type T: a display name "v<i>"
       paired with the free variable "__v<i>". *)
    fun cr_var (i,T) = ("v"^string_of_int i, Free ("__v"^string_of_int i,T))

    (* Wrap t as PROTECT2 t DUMMY and abstract over the free variable
       described by n (a name/Free pair), i.e. build a λ⇩2 abstraction. *)
    fun lambda2_name n t = let
      val t = @{mk_term "PROTECT2 ?t DUMMY"}
    in
      Term.lambda_name n t
    end

    (* bind_args exp0 [(x1,m1),...,(xn,mn)] builds
         bind m1 (λ⇩2x1. ... bind mn (λ⇩2xn. exp0))
       so the first pair becomes the outermost bind. *)
    fun
      bind_args exp0 [] = exp0
    | bind_args exp0 ((x,m)::xms) = let
        val lr = bind_args exp0 xms
          |> incr_boundvars 1
          |> lambda2_name x
      in @{mk_term "Refine_Basic.bind$?m$?lr"} end

    (* Flatten one application f$a0$...$a(n-1): each operand is wrapped in
       EVAL and bound, left to right, to a fresh variable v<i>; the innermost
       result is SP (RETURN$(f$v0$...$v(n-1))).  An abstraction in head
       position is rejected with TERM. *)
    fun monadify t = let
      val (f,args) = Autoref_Tagging.strip_app t
      val _ = not (is_Abs f) orelse
        raise TERM ("monadify: higher-order",[t])

      val argTs = map fastype_of args
      (*val args = map monadify args*)
      val args = map (fn a => @{mk_term "EVAL$?a"}) args

      (*val fT = fastype_of f
      val argTs = binder_types fT*)

      val argVs = tag_list 0 argTs
        |> map cr_var

      val res0 = let
        val x = Autoref_Tagging.list_APP (f,map #2 argVs)
      in
        @{mk_term "SP (RETURN$?x)"}
      end

      val res = bind_args res0 (argVs ~~ args)
    in
      res
    end

    (* Conversion on EVAL$t: rewrites to monadify t, proving the equation
       by simp with monadify_simps and SP_def.  Any other term raises TERM. *)
    fun monadify_conv_aux ctxt ct = case Thm.term_of ct of
      @{mpat "EVAL$_"} => let
        fun tac goal_ctxt =
          simp_tac (put_simpset HOL_basic_ss goal_ctxt
            addsimps @{thms monadify_simps SP_def}) 1
      in
        (*Refine_Util.monitor_conv "monadify"*) (
          Refine_Util.f_tac_conv ctxt (dest_comb #> #2 #> monadify) tac) ct
      end
    | t => raise TERM ("monadify_conv",[t])

    (*fun extract_comb_conv ctxt = Conv.rewrs_conv
      (Named_Theorems_Rev.get ctxt @{named_theorems_rev sepref_monadify_evalcomb})
    *)
  in
    (*
    val monadify_conv = Conv.top_conv
      (fn ctxt =>
        Conv.try_conv (
          extract_comb_conv ctxt else_conv monadify_conv_aux ctxt
        )
      )
    *)

    (* Passive simproc on EVAL$a: it only runs when added to a simpset
       explicitly.  A failing conversion is turned into NONE by try. *)
    val monadify_simproc =
      \<^simproc_setup>‹passive monadify ("EVAL$a") = ‹K (try o monadify_conv_aux)››;
  end

local
  open Sepref_Basic

  (* Rewrite the abstract program a of an hn_refine term (split into
     (P,c,Q,R,a) by dest_hn_refine).  Every RETURN$x whose x is a bound
     variable, or is aconv to one of the terms in pps, becomes PASS$x.
     pps is the second component of each conjunct of P accepted by
     dest_hn_ctxt_opt (presumably the abstract value of an hn_ctxt
     assertion; confirm). *)
  fun mark_params t = let
    val (P,c,Q,R,a) = dest_hn_refine t
    val pps = strip_star P |> map_filter (dest_hn_ctxt_opt #> map_option #2)

    (* env collects the binder types passed so far; mk_term needs them
       to build terms below abstractions. *)
    fun tr env (t as @{mpat "RETURN$?x"}) =
          (* aconv is infix in Isabelle/ML, so it must be passed as (op aconv). *)
          if is_Bound x orelse member (op aconv) pps x then @{mk_term env: "PASS$?x"}
          else t
      | tr env (t1$t2) = tr env t1 $ tr env t2
      | tr env (Abs (x,T,t)) = Abs (x,T,tr (T::env) t)
      | tr _ t = t

    val a = tr [] a
  in
    mk_hn_refine (P,c,Q,R,a)
  end

in
  (* Conversion form of mark_params.  The resulting equation is discharged
     by simp with PASS_def (through Refine_Util.f_tac_conv). *)
  fun mark_params_conv ctxt = Refine_Util.f_tac_conv ctxt
    (mark_params)
    (fn goal_ctxt => simp_tac (put_simpset HOL_basic_ss goal_ctxt addsimps @{thms PASS_def}) 1)
end

local
  open Sepref_Basic

  (* dp ctxt t returns (t', ps).  For bind (PASS p) (λ⇩2_. body), body is
     processed first, yielding the parameters ps passed further down the
     chain.  If p occurs in ps (up to aconv), this PASS is replaced by
     RETURN$(COPY$p); otherwise it stays a PASS.  The result list is p::ps.
     Applications and abstractions are traversed recursively but report [].
     Raises Match if the continuation body is not of the form PROTECT2 _ _. *)
  fun dp ctxt (@{mpat "Refine_Basic.bind$(PASS$?p)$(?t' AS⇩p (λ_. PROTECT2 _ DUMMY))"}) = let
        val (t',ps) = let
          val ((t',rc),ctxt) = dest_lambda_rc ctxt t'
          val f = case t' of @{mpat "PROTECT2 ?f _"} => f | _ => raise Match
          val (f,ps) = dp ctxt f
          val t' = @{mk_term "PROTECT2 ?f DUMMY"}
          val t' = rc t'
        in
          (t',ps)
        end

        (* aconv is infix in Isabelle/ML, so it must be passed as (op aconv). *)
        val dup = member (op aconv) ps p
        val t = if dup then
            @{mk_term "Refine_Basic.bind$(RETURN$(COPY$?p))$?t'"}
          else
            @{mk_term "Refine_Basic.bind$(PASS$?p)$?t'"}
      in
        (t,p::ps)
      end
    | dp ctxt (t1$t2) = (#1 (dp ctxt t1) $ #1 (dp ctxt t2),[])
    | dp ctxt (t as (Abs _)) = (apply_under_lambda (#1 oo dp) ctxt t,[])
    | dp _ t = (t,[])

  (* Conversion form of dp.  The equation is discharged by simp with
     RET_COPY_PASS_eq on all goals (through Refine_Util.f_tac_conv). *)
  fun dp_conv ctxt = Refine_Util.f_tac_conv ctxt
    (#1 o dp ctxt)
    (fn goal_ctxt =>
      ALLGOALS (simp_tac (put_simpset HOL_basic_ss goal_ctxt addsimps @{thms RET_COPY_PASS_eq})))

in
  (* Apply dp_conv through Sepref_Basic.hn_refine_concl_conv_a (presumably
     to the abstract-program component of the hn_refine conclusion; confirm). *)
  fun dup_tac ctxt = CONVERSION (Sepref_Basic.hn_refine_concl_conv_a dp_conv ctxt)
end

(* Arity phase: two successive simp_tac passes, each using a simpset built
   from HOL_basic_ss.
   NOTE(review): as written, neither simpset receives any rewrite rules or
   congruences (no sepref_monadify_arity equations, no SP_cong /
   PR_CONST_cong).  So this phase cannot perform the eta-expansion the
   theory text describes.  The blank line after arity1_ss suggests lines
   were lost.  The fix presumably needs the sepref_monadify_arity rule set,
   which is not declared before this point in the visible file; confirm. *)
fun arity_tac ctxt = let
val arity1_ss = put_simpset HOL_basic_ss ctxt
(* NOTE(review): presumably the arity equations and the SP_cong /
   PR_CONST_cong congruences should be added to these simpsets. *)
val arity2_ss = put_simpset HOL_basic_ss ctxt
in
simp_tac arity1_ss THEN' simp_tac arity2_ss
end

(* Combinator phase: two successive simp_tac passes.
   NOTE(review): the sepref_monadify_comb rule set named in the theory text
   is not added to comb1_ss, because no declaration of it is visible before
   this point in the file; confirm and add it. *)
fun comb_tac ctxt = let
  (* Pass 1: flatten EVAL tags with monadify_simproc.  It is declared
     passive, so it must be registered here explicitly.  SP and PR_CONST
     contents are protected by their congruence rules. *)
  val comb1_ss = put_simpset HOL_basic_ss ctxt
    addsimprocs [monadify_simproc]
    |> Simplifier.add_cong @{thm SP_cong}
    |> Simplifier.add_cong @{thm PR_CONST_cong}

  (* Pass 2: drop the SP protection tags introduced during pass 1. *)
  val comb2_ss = put_simpset HOL_basic_ss ctxt
    addsimps @{thms SP_def}
in
  simp_tac comb1_ss THEN' simp_tac comb2_ss
end

(*fun ops_tac ctxt = CONVERSION (
  Sepref_Basic.hn_refine_concl_conv_a monadify_conv ctxt)*)

(* Phase "mark_params": apply mark_params_conv to the HOL conclusion of the
   addressed subgoal, via Refine_Util.HOL_concl_conv. *)
fun mark_params_tac ctxt = CONVERSION (
  Refine_Util.HOL_concl_conv mark_params_conv ctxt)

(* Does the abstract program ?a of a goal Trueprop (hn_refine _ _ _ _ ?a)
   still mention the constant EVAL?  Any other goal shape raises TERM. *)
fun contains_eval t = case t of
    @{mpat "Trueprop (hn_refine _ _ _ _ ?a)"} =>
      Term.exists_Const (fn (c, _) => c = @{const_name EVAL}) a
  | _ => raise TERM("contains_eval",[t]);

(* Simplify with remove_pass_simps only.  This drops PASS binds whose value
   is immediately consumed, and binds whose continuation merely PASSes its
   argument on. *)
fun remove_pass_tac ctxt = let
  val pass_ss = put_simpset HOL_basic_ss ctxt addsimps @{thms remove_pass_simps}
in
  simp_tac pass_ss
end

(* Entry point of the monadify phase.  Runs the named sub-phases in the
   order listed, via PHASES'.  dbg is handed to flag_phases_ctrl.  The
   check_EVAL step guards with CONCL_COND' on (not o contains_eval). *)
fun monadify_tac dbg ctxt = let
  open Sepref_Basic
  val phases = [
    ("arity", arity_tac, 0),
    ("comb", comb_tac, 0),
    (*("ops", ops_tac, 0),*)
    ("check_EVAL", K (CONCL_COND' (not o contains_eval)), 0),
    ("mark_params", mark_params_tac, 0),
    ("dup", dup_tac, 0),
    ("remove_pass", remove_pass_tac, 0)
  ]
  val ctrl = flag_phases_ctrl dbg
in
  PHASES' phases ctrl ctxt
end

end
›

(* Rule sets named in the theory text: arity (eta-expansion) equations and
   combinator equations. *)
named_theorems_rev sepref_monadify_arity "Sepref.Monadify: Arity alignment equations"
named_theorems_rev sepref_monadify_comb "Sepref.Monadify: Combinator equations"

(* Eta-expansion rules: RETURN, RECT, case_list, case_prod, case_option, If
   and Let get their full number of operands under an SP tag.  The RECT
   rule additionally tags recursive calls with RCALL. *)
lemma dflt_arity[sepref_monadify_arity]:
  "RETURN ≡ λ⇩2x. SP RETURN$x"
  "RECT ≡ λ⇩2B x. SP RECT$(λ⇩2D x. B$(λ⇩2x. RCALL$D$x)$x)$x"
  "case_list ≡ λ⇩2fn fc l. SP case_list$fn$(λ⇩2x xs. fc$x$xs)$l"
  "case_prod ≡ λ⇩2fp p. SP case_prod$(λ⇩2a b. fp$a$b)$p"
  "case_option ≡ λ⇩2fn fs ov. SP case_option$fn$(λ⇩2x. fs$x)$ov"
  "If ≡ λ⇩2b t e. SP If$b$t$e"
  "Let ≡ λ⇩2x f. SP Let$x$(λ⇩2x. f$x)"
  by (simp_all only: SP_def APP_def PROTECT2_def RCALL_def)

(* Combinator rules: one designated operand (argument, scrutinee or
   condition) is evaluated via EVAL and bound before the SP-protected
   combinator. *)
lemma dflt_comb[sepref_monadify_comb]:
  "⋀B x. RECT$B$x ≡ Refine_Basic.bind$(EVAL$x)$(λ⇩2x. SP (RECT$B$x))"
  "⋀D x. RCALL$D$x ≡ Refine_Basic.bind$(EVAL$x)$(λ⇩2x. SP (RCALL$D$x))"
  "⋀fn fc l. case_list$fn$fc$l ≡ Refine_Basic.bind$(EVAL$l)$(λ⇩2l. (SP case_list$fn$fc$l))"
  "⋀fp p. case_prod$fp$p ≡ Refine_Basic.bind$(EVAL$p)$(λ⇩2p. (SP case_prod$fp$p))"
  "⋀fn fs ov. case_option$fn$fs$ov ≡ Refine_Basic.bind$(EVAL$ov)$(λ⇩2ov. (SP case_option$fn$fs$ov))"
  "⋀b t e. If$b$t$e ≡ Refine_Basic.bind$(EVAL$b)$(λ⇩2b. (SP If$b$t$e))"
  "⋀x. RETURN$x ≡ Refine_Basic.bind$(EVAL$x)$(λ⇩2x. SP (RETURN$x))"
  "⋀x f. Let$x$f ≡ Refine_Basic.bind$(EVAL$x)$(λ⇩2x. (SP Let$x$f))"
  by (simp_all)

(* EVAL of a plain (non-monadic) combinator: evaluate the scrutinee or
   condition first, then push EVAL into the branches. *)
lemma dflt_plain_comb[sepref_monadify_comb]:
  "EVAL$(If$b$t$e) ≡ Refine_Basic.bind$(EVAL$b)$(λ⇩2b. If$b$(EVAL$t)$(EVAL$e))"
  "EVAL$(case_list$fn$(λ⇩2x xs. fc x xs)$l) ≡
    Refine_Basic.bind$(EVAL$l)$(λ⇩2l. case_list$(EVAL$fn)$(λ⇩2x xs. EVAL$(fc x xs))$l)"
  "EVAL$(case_prod$(λ⇩2a b. fp a b)$p) ≡ Refine_Basic.bind$(EVAL$p)$(λ⇩2p. case_prod$(λ⇩2a b. EVAL$(fp a b))$p)"
  "EVAL$(case_option$fn$(λ⇩2x. fs x)$ov) ≡ Refine_Basic.bind$(EVAL$ov)$(λ⇩2ov. case_option$(EVAL$fn)$(λ⇩2x. EVAL$(fs x))$ov)"
  "EVAL$ (Let $v$ (λ⇩2x. f x)) ≡ (⤜) $(EVAL$ v) $(λ⇩2x. EVAL$ (f x))"
  apply (rule eq_reflection, simp split: list.split prod.split option.split)+
  done

"EVAL$(PR_CONST x) ≡ SP (RETURN$(PR_CONST x))"