Theory Inca_to_Ubx_compiler

theory Inca_to_Ubx_compiler
  imports Inca_to_Ubx_simulation Result
    Inca_Verification
    "VeriComp.Compiler"
    "HOL-Library.Monad_Syntax"
begin

section ‹Generic program rewriting›

(* monadic_fold_map f acc xs: a map over xs in the option monad that also threads an
   accumulator from left to right.  For each element x, f is given the current
   accumulator and x and must return Some (acc', x'); acc' is passed on to the next
   element and x' is collected in the output list.  The result is
   Some (final accumulator, outputs in input order), or None as soon as any call of f
   returns None. *)
primrec monadic_fold_map where
  "monadic_fold_map f acc [] = Some (acc, [])" |
  "monadic_fold_map f acc (x # xs) = do {
    (acc', x')  f acc x;
    (acc'', xs')  monadic_fold_map f acc' xs;
    Some (acc'', x' # xs')
  }"

(* A successful fold-map produces exactly one output element per input element. *)
lemma monadic_fold_map_length:
  "monadic_fold_map f acc xs = Some (acc', xs')  length xs = length xs'"
  by (induction xs arbitrary: acc xs') (auto simp: bind_eq_Some_conv)

(* Destruction rule (declared [dest]): a successful run on a non-empty list splits into a
   successful first step f a x = Some (b, y) and a successful run on the tail that starts
   from the intermediate accumulator b. *)
lemma monadic_fold_map_ConsD[dest]:
  assumes "monadic_fold_map f a (x # xs) = Some (c, ys)"
  shows "y ys' b. ys = y # ys'  f a x = Some (b, y)  monadic_fold_map f b xs = Some (c, ys')"
  using assms
  by (auto simp add: bind_eq_Some_conv)

(* Equation form of monadic_fold_map_ConsD, usable for rewriting. *)
lemma monadic_fold_map_eq_Some_conv:
  "monadic_fold_map f a (x # xs) = Some (c, ys) 
    (y ys' b. f a x = Some (b, y)  monadic_fold_map f b xs = Some (c, ys')  ys = y # ys')"
  by (auto simp add: bind_eq_Some_conv)

(* Variant of monadic_fold_map_eq_Some_conv for a result p that is not syntactically a
   pair; its components are accessed through fst and snd. *)
lemma monadic_fold_map_eq_Some_conv':
  "monadic_fold_map f a (x # xs) = Some p 
    (y ys' b. f a x = Some (b, y)  monadic_fold_map f b xs = Some (fst p, ys')  snd p = y # ys')"
  by (cases p) (auto simp add: bind_eq_Some_conv)

(* If every single step of f relates its input element x to its output element y by P,
   then a successful fold-map relates the input list and the output list pointwise by P.
   Proved by induction on xs, generalizing the accumulator and the output list. *)
lemma monadic_fold_map_list_all2:
  assumes "monadic_fold_map f acc xs = Some (acc', ys)" and
    "acc acc' x y. f acc x = Some (acc', y)  P x y"
  shows "list_all2 P xs ys"
  using assms(1)
proof (induction xs arbitrary: acc ys)
  case Nil
  then show ?case by simp
next
  case (Cons x xs)
  show ?case
    using Cons.prems
    by (auto simp: bind_eq_Some_conv intro: assms(2) Cons.IH)
qed

(* Special case of monadic_fold_map_list_all2 for a property P of output elements only:
   if every output element f can produce satisfies P, then every element of ys does.
   Obtained from the list_all2 version by forgetting the input side. *)
lemma monadic_fold_map_list_all:
  assumes "monadic_fold_map f acc xs = Some (acc', ys)" and
    "acc acc' x y. f acc x = Some (acc', y)  P y"
  shows "list_all P ys"
proof -
  have "list_all2 (λ_. P) xs ys"
    using assms
    by (auto elim: monadic_fold_map_list_all2)
  thus ?thesis
    by (auto elim: list_rel_imp_pred2)
qed

(* gen_pop_push instr (domain, codomain) Σ: generic abstract-stack step for an instruction
   that pops the entries listed in domain and pushes the entries listed in codomain.
   It succeeds only if Σ has at least length domain entries and its first length domain
   entries are exactly domain.  In that case the result is Some (instr, stack in which
   that prefix is replaced by codomain).  Otherwise the result is None.  The head of Σ
   is treated as the top of the stack. *)
fun gen_pop_push where
  "gen_pop_push instr (domain, codomain) Σ = (
    let ar = length domain in
    if ar  length Σ  take ar Σ = domain then
      Some (instr, codomain @ drop ar Σ)
    else
      None
  )"

context inca_to_ubx_simulation begin

section ‹Lifting›

(* Lifts one Inca instruction to the corresponding Ubx instruction while simulating the
   operand stack abstractly.  Σ is the abstract stack before the instruction, with its
   head as the top: pushes prepend.  On success the result is
   Some (lifted instruction, stack after it).
   Parameters:
     F   -- maps a callee f to Some (ar, ret): the number of entries popped and pushed by
            ICall f.  F f = None makes the lift fail.
     L   -- labels of the enclosing function; both ICJump targets must be members of L.
     ret -- IReturn is accepted only on a stack of exactly ret None entries.
     N   -- IGet n / ISet n are accepted only if n < N.
   In the ICall equation the pattern (ar, ret) taken from F f shadows the parameter ret,
   so the pushed entries are the callee's return count, not the caller's.
   Every entry this function pushes is None; Some τ entries are introduced only by
   optim_instr further down.  Any instruction/stack combination not covered by a specific
   equation yields None through the final catch-all equation.  Examples are IPop on an
   empty stack and ICJump on a stack other than [None]. *)
fun lift_instr :: "_  _  _  _  _  _ 
  ((_, _, _, _, _, _, 'opubx, 'ubx1, 'ubx2) Ubx.instr × _) option" where
  "lift_instr F L ret N (Inca.IPush d) Σ = Some (IPush d, None # Σ)" |
  "lift_instr F L ret N Inca.IPop (_ # Σ) = Some (IPop, Σ)" |
  "lift_instr F L ret N (Inca.IGet n) Σ = (if n < N then Some (IGet n, None # Σ) else None)" |
  "lift_instr F L ret N (Inca.ISet n) (None # Σ) = (if n < N then Some (ISet n, Σ) else None)" |
  "lift_instr F L ret N (Inca.ILoad x) (None # Σ) = Some (ILoad x, None # Σ)" |
  "lift_instr F L ret N (Inca.IStore x) (None # None # Σ) = Some (IStore x, Σ)" |
  "lift_instr F L ret N (Inca.IOp op) Σ =
    gen_pop_push (IOp op) (replicate (𝔄𝔯𝔦𝔱𝔶 op) None, [None]) Σ" |
  "lift_instr F L ret N (Inca.IOpInl opinl) Σ =
    gen_pop_push (IOpInl opinl) (replicate (𝔄𝔯𝔦𝔱𝔶 (𝔇𝔢ℑ𝔫𝔩 opinl)) None, [None]) Σ" |
  "lift_instr F L ret N (Inca.ICJump lt lf) [None] =
    (if List.member L lt  List.member L lf then Some (ICJump lt lf, []) else None)" |
  "lift_instr F L ret N (Inca.ICall f) Σ = do {
    (ar, ret)  F f;
    gen_pop_push (ICall f) (replicate ar None, replicate ret None) Σ
  }" |
  "lift_instr F L ret N Inca.IReturn Σ =
    (if Σ = replicate ret None then Some (IReturn, []) else None)" |
  "lift_instr _ _ _ _ _ _ = None"

(* Lifts a whole instruction list, starting from an initial abstract stack (the argument
   after N), by threading the stack through lift_instr with monadic_fold_map.
   lift_instr returns (instruction, stack).  prod.swap reorders this into the
   (accumulator, output) shape that monadic_fold_map expects, so the overall result is
   Some (final stack, lifted instructions). *)
definition lift_instrs where
  "lift_instrs F L ret N 
    monadic_fold_map (λΣ instr. map_option prod.swap (lift_instr F L ret N instr Σ))"

(* Lifting maps each instruction to exactly one instruction. *)
lemma lift_instrs_length:
  assumes "lift_instrs F L ret N Σi xs = Some (Σo, ys)"
  shows "length xs = length ys"
  using assms unfolding lift_instrs_def
  by (auto intro: monadic_fold_map_length)

(* A successful lift of a non-empty list yields a non-empty list. *)
lemma lift_instrs_not_Nil: "lift_instrs F L ret N Σi xs = Some (Σo, ys)  xs  []  ys  []"
  using lift_instrs_length by fastforce

(* Lifting the empty list leaves the stack unchanged and produces no instructions
   (declared [dest]). *)
lemma lift_instrs_NilD[dest]:
  assumes "lift_instrs F L ret N Σi [] = Some (Σo, ys)"
  shows "Σo = Σi  ys = []"
  using assms
  by (simp_all add: lift_instrs_def)

(* bind_eq_Some_conv with the equation flipped, for goals and assumptions written as
   Some x = Option.bind f g. *)
lemmas Some_eq_bind_conv =
  bind_eq_Some_conv[unfolded eq_commute[of "Option.bind f g" "Some x" for f g x]]

(* Relates Inca.is_jump of an instruction to Ubx.is_jump of its lifted form.
   Proved by case analysis over the equations of lift_instr. *)
lemma lift_instr_is_jump:
  assumes "lift_instr F L ret N x Σi = Some (y, Σo)"
  shows "Inca.is_jump x  Ubx.is_jump y"
  using assms
  by (rule lift_instr.elims)
    (auto simp add: if_split_eq2 Let_def Some_eq_bind_conv)

(* Relates Inca.is_return of an instruction to Ubx.is_return of its lifted form.
   Proved like lift_instr_is_jump. *)
lemma lift_instr_is_return:
  assumes "lift_instr F L ret N x Σi = Some (y, Σo)"
  shows "Inca.is_return x  Ubx.is_return y"
  using assms
  by (rule lift_instr.elims)
    (auto simp add: if_split_eq2 Let_def Some_eq_bind_conv)

(* Transfers the property "contains neither a jump nor a return" between an instruction
   list and its lifted version.  Proved instruction by instruction with
   lift_instr_is_jump and lift_instr_is_return. *)
lemma lift_instrs_all_not_jump_not_return:
  assumes "lift_instrs F L ret N Σi xs = Some (Σo, ys)"
  shows
    "list_all (λi. ¬ Inca.is_jump i  ¬ Inca.is_return i) xs 
     list_all (λi. ¬ Ubx.is_jump i  ¬ Ubx.is_return i) ys"
  using assms
proof (induction xs arbitrary: Σi Σo ys)
  case Nil
  then show ?case by (simp add: lift_instrs_def)
next
  case (Cons x xs)
  from Cons.prems show ?case
    (* Expose the first fold step, then fold lift_instrs_def back so that the induction
       hypothesis (stated in terms of lift_instrs) applies to the tail. *)
    apply (simp add: lift_instrs_def bind_eq_Some_conv)
    apply (fold lift_instrs_def)
    by (auto simp add: Cons.IH lift_instr_is_jump lift_instr_is_return)
qed

(* Same as lift_instrs_all_not_jump_not_return, restricted to all instructions except
   the last one (butlast).  This is the shape required for a basic block, whose final
   instruction may be a jump or a return. *)
lemma lift_instrs_all_butlast_not_jump_not_return:
  assumes "lift_instrs F L ret N Σi xs = Some (Σo, ys)"
  shows
    "list_all (λi. ¬ Inca.is_jump i  ¬ Inca.is_return i) (butlast xs) 
     list_all (λi. ¬ Ubx.is_jump i  ¬ Ubx.is_return i) (butlast ys)"
  using lift_instrs_length[OF assms(1)] assms unfolding lift_instrs_def
  (* Simultaneous induction on xs and ys; list_induct2 needs the equal lengths
     provided by lift_instrs_length. *)
proof (induction xs ys arbitrary: Σi Σo rule: list_induct2)
  case Nil
  then show ?case by simp
next
  case (Cons x xs y ys)
  thus ?case
    by (auto simp add: bind_eq_Some_conv lift_instr_is_jump lift_instr_is_return)
qed

(* A lifted instruction is accepted by the Ubx stack-typing relation Subx.sp_instr, going
   from the stack Σi before it to the stack Σo computed by lift_instr.  For IOp, IOpInl and
   ICall, the matching sp_instr rule is applied after splitting the stack with
   append_take_drop_id (take n Σ @ drop n Σ = Σ).  This mirrors how gen_pop_push builds
   its result. *)
lemma lift_instr_sp:
  assumes "lift_instr F L ret N x Σi = Some (y, Σo)"
  shows "Subx.sp_instr F ret y Σi Σo"
  using assms
  apply (induction F L ret N x Σi rule: lift_instr.induct;
      auto simp: Let_def intro: Subx.sp_instr.intros)
    apply (rule Subx.sp_instr.Op, metis append_take_drop_id)
   apply (rule Subx.sp_instr.OpInl, metis append_take_drop_id)
  apply (auto simp add: bind_eq_Some_conv intro!: Subx.sp_instr.Call, metis append_take_drop_id)
  done

(* List version: the lifted instruction list is accepted by Subx.sp_instrs from the
   initial stack Σi to the final stack Σo. *)
lemma lift_instrs_sp:
  assumes "lift_instrs F L ret N Σi xs = Some (Σo, ys)"
  shows "Subx.sp_instrs F ret ys Σi Σo"
  using assms unfolding lift_instrs_def
proof (induction xs arbitrary: Σi Σo ys)
  case Nil
  thus ?case by (auto intro: Subx.sp_instrs.Nil)
next
  case (Cons x xs)
  from Cons.prems show ?case
    by (auto simp add: bind_eq_Some_conv intro: Subx.sp_instrs.Cons lift_instr_sp Cons.IH)
qed

(* Every lifted instruction satisfies Subx.fun_call_in_range F.  lift_instr succeeds on
   ICall f only when F f is defined. *)
lemma lift_instr_fun_call_in_range:
  assumes "lift_instr F L ret N x Σi = Some (y, Σo)"
  shows "Subx.fun_call_in_range F y"
  using assms
  by (induction F L ret N x Σi rule: lift_instr.induct) (auto simp: Let_def bind_eq_Some_conv)

(* List version of lift_instr_fun_call_in_range, via monadic_fold_map_list_all. *)
lemma lift_instrs_all_fun_call_in_range:
  assumes "lift_instrs F L ret N Σi xs = Some (Σo, ys)"
  shows "list_all (Subx.fun_call_in_range F) ys"
  using assms unfolding lift_instrs_def
  by (auto intro!: monadic_fold_map_list_all intro: lift_instr_fun_call_in_range)

(* Every lifted instruction satisfies Subx.local_var_in_range N.  lift_instr checks n < N
   for IGet n and ISet n. *)
lemma lift_instr_local_var_in_range:
  assumes "lift_instr F L ret N x Σi = Some (y, Σo)"
  shows "Subx.local_var_in_range N y"
  using assms
  by (induction F L ret N x Σi rule: lift_instr.induct) (auto simp: Let_def bind_eq_Some_conv)

(* List version of lift_instr_local_var_in_range, via monadic_fold_map_list_all. *)
lemma lift_instrs_all_local_var_in_range:
  assumes "lift_instrs F L ret N Σi xs = Some (Σo, ys)"
  shows "list_all (Subx.local_var_in_range N) ys"
  using assms unfolding lift_instrs_def
  by (auto intro!: monadic_fold_map_list_all intro: lift_instr_local_var_in_range)

(* Every lifted instruction satisfies Subx.jump_in_range (set L).  lift_instr checks both
   ICJump targets with List.member L; in_set_member converts this to set membership. *)
lemma lift_instr_jump_in_range:
  assumes "lift_instr F L ret N x Σi = Some (y, Σo)"
  shows "Subx.jump_in_range (set L) y"
  using assms
  by (induction F L ret N x Σi rule: lift_instr.induct)
    (auto simp: Let_def bind_eq_Some_conv in_set_member)

(* List version of lift_instr_jump_in_range, via monadic_fold_map_list_all. *)
lemma lift_instrs_all_jump_in_range:
  assumes "lift_instrs F L ret N Σi xs = Some (Σo, ys)"
  shows "list_all (Subx.jump_in_range (set L)) ys"
  using assms unfolding lift_instrs_def
  by (auto intro!: monadic_fold_map_list_all intro: lift_instr_jump_in_range)

(* A lifted instruction is related to its Inca original by norm_eq (norm_eq is defined
   elsewhere). *)
lemma lift_instr_norm:
  "lift_instr F L ret N instr1 Σ1 = Some (instr2, Σ2)  norm_eq instr1 instr2"
  by (induction instr1 Σ1 rule: lift_instr.induct) (auto simp: Let_def bind_eq_Some_conv)

(* List version: original and lifted lists are pointwise related by norm_eq. *)
lemma lift_instrs_all_norm:
  assumes "lift_instrs F L ret N Σ1 instrs1 = Some (Σ2, instrs2)"
  shows "list_all2 norm_eq instrs1 instrs2"
  using assms unfolding lift_instrs_def
  by (auto simp: lift_instr_norm elim!: monadic_fold_map_list_all2)


section ‹Optimization›

context
  fixes load_oracle :: "nat  type option"
begin

(* x orelse y: left-biased choice between two options.  The result is x if x is Some,
   and y otherwise.  Declared as a right-associative infix operator with priority 55.
   It does not use the load_oracle fixed by the surrounding context. *)
definition orelse :: "'a option  'a option  'a option"  (infixr orelse 55) where
  "x orelse y = (case x of Some x'  Some x' | None  y)"

(* Simplification rules for orelse with a literal None or Some operand. *)
lemma None_orelse[simp]: "None orelse y = y"
  by (simp add: orelse_def)

lemma orelse_None[simp]: "x orelse None = x"
  by (cases x) (simp_all add: orelse_def)

lemma Some_orelse[simp]: "Some x orelse y = Some x"
  by (simp add: orelse_def)

(* x orelse y is Some z exactly when x is Some z, or x is None and y is Some z. *)
lemma orelse_eq_Some_conv:
  "x orelse y = Some z  (x = Some z  x = None  y = Some z)"
  by (cases x) simp_all

(* Elimination-rule form of orelse_eq_Some_conv: case split on which operand
   supplied the result. *)
lemma orelse_eq_SomeE:
  assumes
    "x orelse y = Some z" and
    "x = Some z  P" and
    "x = None  y = Some z  P"
  shows "P"
  using assms(1)
  unfolding orelse_def
  by (cases x; auto intro: assms(2,3))

(* drop_prefix xs ys: if xs is a prefix of ys, returns Some of what remains of ys after
   xs; otherwise returns None.  optim_instr uses it to pop an expected sequence of
   abstract stack entries. *)
fun drop_prefix where
  "drop_prefix [] ys = Some ys" |
  "drop_prefix (x # xs) (y # ys) = (if x = y then drop_prefix xs ys else None)" |
  "drop_prefix _ _ = None "

(* Dropping xs from xs @ ys always succeeds and yields ys. *)
lemma drop_prefix_append_prefix[simp]: "drop_prefix xs (xs @ ys) = Some ys"
  by (induction xs) simp_all

(* drop_prefix succeeds with remainder zs exactly when ys = xs @ zs. *)
lemma drop_prefix_eq_Some_conv: "drop_prefix xs ys = Some zs  ys = xs @ zs"
  by (induction xs ys arbitrary: zs rule: drop_prefix.induct)
    (auto simp: if_split_eq1)

(* Optimizes one Ubx instruction located at program counter pc, again tracking an
   abstract stack Σ whose head is the top.  Some τ entries are pushed only by the
   unboxed (...Ubx...) instruction forms, IOpUbx included, and record their type.  None
   entries come from the ordinary forms.  The arguments before pc are the callee
   lookup F (used by ICall) and ret (used by IReturn), with the same meaning as in
   lift_instr.
   Summary of the equations:
     - IPush d: try unbox_ubx1 d, then unbox_ubx2 d; otherwise keep the plain IPush.
     - IGet / ILoad: if load_oracle pc (fixed by the enclosing context) suggests a type
       τ, emit the ...Ubx τ variant and push Some τ; otherwise keep the plain form.
     - ISet/ISetUbx: the variant is chosen again from the top stack entry.
       IStore/IStoreUbx: the top entry must be None, and the variant is chosen from the
       second entry.  An unboxed instruction can thus be turned back into its plain
       form.
     - IOp: pops Arity op None entries and pushes None.
     - IOpInl: if 𝔘𝔟𝔵 yields an unboxed operation for the operand entries, emit IOpUbx
       with the operand/result entries given by 𝔗𝔶𝔭𝔢𝔒𝔣𝔒𝔭.  Otherwise the operands
       must all be None.
     - ICall: pops ar None entries and pushes ret None entries, both taken from F f.
     - ICJump requires the stack [None]; IReturn requires exactly ret None entries.
   Any other instruction/stack combination yields None through the last equation.
   An example is IPop on an empty stack. *)
fun optim_instr where
  "optim_instr _ _ _ (IPush d) Σ =
    Some Pair  (Some IPushUbx1  (unbox_ubx1 d))  Some (Some Ubx1 # Σ) orelse
    Some Pair  (Some IPushUbx2  (unbox_ubx2 d))  Some (Some Ubx2 # Σ) orelse
    Some (IPush d, None # Σ)
  " |
  "optim_instr _ _ _ (IPushUbx1 n) Σ = Some (IPushUbx1 n, Some Ubx1 # Σ)" |
  "optim_instr _ _ _ (IPushUbx2 b) Σ = Some (IPushUbx2 b, Some Ubx2 # Σ)" |
  "optim_instr _ _ _ IPop (_ # Σ) = Some (IPop, Σ)" |
  "optim_instr _ _ pc (IGet n) Σ =
    map_option (λτ. (IGetUbx τ n, Some τ # Σ)) (load_oracle pc) orelse
    Some (IGet n, None # Σ)" |
  "optim_instr _ _ pc (IGetUbx τ n) Σ = Some (IGetUbx τ n, Some τ # Σ)" |
  "optim_instr _ _ _ (ISet n) (None # Σ) = Some (ISet n, Σ)" |
  "optim_instr _ _ _ (ISet n) (Some τ # Σ) = Some (ISetUbx τ n, Σ)" |
  "optim_instr _ _ _ (ISetUbx _ n) (None # Σ) = Some (ISet n, Σ)" |
  "optim_instr _ _ _ (ISetUbx _ n) (Some τ # Σ) = Some (ISetUbx τ n, Σ)" |
  "optim_instr _ _ pc (ILoad x) (None # Σ) =
    map_option (λτ. (ILoadUbx τ x, Some τ # Σ)) (load_oracle pc) orelse
    Some (ILoad x, None # Σ)" |
  "optim_instr _ _ _ (ILoadUbx τ x) (None # Σ) = Some (ILoadUbx τ x, Some τ # Σ)" |
  "optim_instr _ _ _ (IStore x) (None # None # Σ) = Some (IStore x, Σ)" |
  "optim_instr _ _ _ (IStore x) (None # Some τ # Σ) = Some (IStoreUbx τ x, Σ)" |
  "optim_instr _ _ _ (IStoreUbx _ x) (None # None # Σ) = Some (IStore x, Σ)" |
  "optim_instr _ _ _ (IStoreUbx _ x) (None # Some τ # Σ) = Some (IStoreUbx τ x, Σ)" |
  "optim_instr _ _ _ (IOp op) Σ =
    map_option (λΣo. (IOp op, None # Σo)) (drop_prefix (replicate (𝔄𝔯𝔦𝔱𝔶 op) None) Σ)" |
  "optim_instr _ _ _ (IOpInl opinl) Σ = (
    let ar = 𝔄𝔯𝔦𝔱𝔶 (𝔇𝔢ℑ𝔫𝔩 opinl) in
    if ar  length Σ then
      case 𝔘𝔟𝔵 opinl (take ar Σ) of
        None  map_option (λΣo. (IOpInl opinl, None # Σo)) (drop_prefix (replicate ar None) Σ) |
        Some opubx  map_option (λΣo. (IOpUbx opubx, snd (𝔗𝔶𝔭𝔢𝔒𝔣𝔒𝔭 opubx) # Σo))
          (drop_prefix (fst (𝔗𝔶𝔭𝔢𝔒𝔣𝔒𝔭 opubx)) Σ)
    else
      None
  )" |
  "optim_instr _ _ _ (IOpUbx opubx) Σ =
    (let p = 𝔗𝔶𝔭𝔢𝔒𝔣𝔒𝔭 opubx in
     map_option (λΣo. (IOpUbx opubx, snd p # Σo)) (drop_prefix (fst p) Σ))" |
  "optim_instr _ _ _ (ICJump lt lf) [None] = Some (ICJump lt lf, []) " |
  "optim_instr F _ _ (ICall f) Σ = do {
    (ar, ret)  F f;
    Σo  drop_prefix (replicate ar None) Σ;
    Some (ICall f, replicate ret None @ Σo)
  }" |
  "optim_instr _ ret _ IReturn Σ = (if Σ = replicate ret None then Some (IReturn, []) else None)" |
  "optim_instr _ _ _ _ _ = None"

(* Optimizes an instruction list, starting at program counter pc with abstract stack Σi.
   The fold accumulator is the pair (current pc, current stack).  pc is incremented by one
   per instruction, so each optim_instr call (and hence load_oracle) sees the starting pc
   plus the instruction's index.  The final pc is dropped, and the result is
   Some (final stack, optimized instructions). *)
definition optim_instrs where
  "optim_instrs F ret  λpc Σi instrs.
    map_option (λ((_, Σo), instrs'). (Σo, instrs'))
      (monadic_fold_map (λ(pc, Σ) instr.
        map_option (λ(instr', Σo). ((Suc pc, Σo), instr')) (optim_instr F ret pc instr Σ))
      (pc, Σi) instrs)"

(* Equation form of one unfolding step: optimizing instr # instrs at pc means optimizing
   instr at pc and then the rest at Suc pc from the intermediate stack. *)
lemma optim_instrs_Cons_eq_Some_conv:
  "optim_instrs F ret pc Σi (instr # instrs) = Some (Σo, ys)  (y ys' Σ.
    ys = y # ys' 
    optim_instr F ret pc instr Σi = Some (y, Σ) 
    optim_instrs F ret (Suc pc) Σ instrs = Some (Σo, ys'))"
  unfolding optim_instrs_def
  by (auto simp: bind_eq_Some_conv)

(* Optimization maps each instruction to exactly one instruction. *)
lemma optim_instrs_length:
  assumes "optim_instrs F ret pc Σi xs = Some (Σo, ys)"
  shows "length xs = length ys"
  using assms unfolding optim_instrs_def
  by (auto intro: monadic_fold_map_length)

(* A successful optimization of a non-empty list yields a non-empty list. *)
lemma optim_instrs_not_Nil: "optim_instrs F ret pc Σi xs = Some (Σo, ys)  xs  []  ys  []"
  using optim_instrs_length by fastforce

(* Optimizing the empty list leaves the stack unchanged and produces no instructions
   (declared [dest]). *)
lemma optim_instrs_NilD[dest]:
  assumes "optim_instrs F ret pc Σi [] = Some (Σo, ys)"
  shows "Σo = Σi  ys = []"
  using assms
  by (simp_all add: optim_instrs_def)

(* Destruction-rule form of optim_instrs_Cons_eq_Some_conv (declared [dest]).  Together
   with optim_instrs_NilD it lets `auto` decompose optim_instrs in the list inductions
   below. *)
lemma optim_instrs_ConsD[dest]:
  assumes "optim_instrs F ret pc Σi (x # xs) = Some (Σo, ys)"
  shows "y ys' Σ. ys = y # ys' 
    optim_instr F ret pc x Σi = Some (y, Σ) 
    optim_instrs F ret (Suc pc) Σ xs = Some (Σo, ys')"
  using assms
  unfolding optim_instrs_def
  by (auto simp: bind_eq_Some_conv)

(* Optimization does not change an instruction modulo norm_instr (defined elsewhere).
   The unboxing cases use Subx.box_unbox_inverse, for IPush, and Subx.𝔘𝔟𝔵_invertible,
   for IOpInl, to show that the chosen unboxed form normalizes back to the original. *)
lemma optim_instr_norm:
  assumes "optim_instr F ret pc instr1 Σ1 = Some (instr2, Σ2)"
  shows "norm_instr instr1 = norm_instr instr2"
  using assms
  by (cases "(F, ret, pc, instr1, Σ1)" rule: optim_instr.cases)
    (auto simp: ap_option_eq_Some_conv Let_def if_split_eq1 bind_eq_Some_conv option.case_eq_if
      orelse_eq_Some_conv
      dest!: Subx.box_unbox_inverse dest: Subx.𝔘𝔟𝔵_invertible)

(* List version: input and optimized lists agree pointwise modulo norm_instr. *)
lemma optim_instrs_all_norm:
  assumes "optim_instrs F ret pc Σ1 instrs1 = Some (Σ2, instrs2)"
  shows "list_all2 (λi1 i2. norm_instr i1 = norm_instr i2) instrs1 instrs2"
  using assms unfolding optim_instrs_def
  by (auto simp: optim_instr_norm elim!: monadic_fold_map_list_all2)

(* Relates is_jump of an instruction before and after optimization.  Proved by case
   analysis over the equations of optim_instr. *)
lemma optim_instr_is_jump:
  assumes "optim_instr F ret pc x Σi = Some (y, Σo)"
  shows "is_jump x  is_jump y"
  using assms
  by (cases "(F, ret, pc, x, Σi)" rule: optim_instr.cases;
      simp add: orelse_eq_Some_conv ap_option_eq_Some_conv bind_eq_Some_conv
        Let_def if_split_eq1 option.case_eq_if;
      safe; simp)

(* Relates is_return of an instruction before and after optimization.  Proved like
   optim_instr_is_jump. *)
lemma optim_instr_is_return:
  assumes "optim_instr F ret pc x Σi = Some (y, Σo)"
  shows "is_return x  is_return y"
  using assms
  by (cases "(F, ret, pc, x, Σi)" rule: optim_instr.cases;
      simp add: orelse_eq_Some_conv ap_option_eq_Some_conv bind_eq_Some_conv
        Let_def if_split_eq1 option.case_eq_if;
      safe; simp)

(* Transfers "no jump and no return except possibly in the last instruction" between an
   instruction list and its optimized version.  Simultaneous induction on the two lists,
   which have equal length by optim_instrs_length. *)
lemma optim_instrs_all_butlast_not_jump_not_return:
  assumes "optim_instrs F ret pc Σi xs = Some (Σo, ys)"
  shows
    "list_all (λi. ¬ is_jump i  ¬ is_return i) (butlast xs) 
     list_all (λi. ¬ is_jump i  ¬ is_return i) (butlast ys)"
  using optim_instrs_length[OF assms(1)] assms
proof (induction xs ys arbitrary: pc Σi Σo rule: list_induct2)
  case Nil
  thus ?case by simp
next
  case (Cons x xs y ys)
  (* Split off the first optimization step; optim_instrs_ConsD is a [dest] rule. *)
  from Cons.prems obtain Σ where
    optim_x: "optim_instr F ret pc x Σi = Some (y, Σ)" and
    optim_xs: "optim_instrs F ret (Suc pc) Σ xs = Some (Σo, ys)"
    by auto
  show ?case
    using Cons.hyps
    using optim_x optim_xs
    apply (simp add: Cons.IH optim_instr_is_jump optim_instr_is_return)
    by fastforce
qed

(* Optimization preserves Subx.jump_in_range L: optim_instr keeps ICJump targets
   unchanged. *)
lemma optim_instr_jump_in_range:
  assumes "optim_instr F ret pc x Σi = Some (y, Σo)"
  shows "Subx.jump_in_range L x  Subx.jump_in_range L y"
  using assms
  by (cases "(F, ret, pc, x, Σi)" rule: optim_instr.cases)
    (auto simp: ap_option_eq_Some_conv Let_def if_split_eq1 option.case_eq_if
      bind_eq_Some_conv orelse_eq_Some_conv)

(* List version of optim_instr_jump_in_range.  The induction decomposes optim_instrs
   with the [dest] rules optim_instrs_NilD and optim_instrs_ConsD. *)
lemma optim_instrs_all_jump_in_range:
  assumes "optim_instrs F ret pc Σi xs = Some (Σo, ys)"
  shows "list_all (Subx.jump_in_range L) xs  list_all (Subx.jump_in_range L) ys"
  using assms
  by (induction xs arbitrary: pc Σi Σo ys) (auto simp: optim_instr_jump_in_range)

(* Optimization preserves Subx.fun_call_in_range F: ICall f is kept as ICall f. *)
lemma optim_instr_fun_call_in_range:
  assumes "optim_instr F ret pc x Σi = Some (y, Σo)"
  shows "Subx.fun_call_in_range F x  Subx.fun_call_in_range F y"
  using assms
  by (cases "(F, ret, pc, x, Σi)" rule: optim_instr.cases)
    (auto simp: ap_option_eq_Some_conv Let_def if_split_eq1 option.case_eq_if
      bind_eq_Some_conv orelse_eq_Some_conv)

(* List version of optim_instr_fun_call_in_range. *)
lemma optim_instrs_all_fun_call_in_range:
  assumes "optim_instrs F ret pc Σi xs = Some (Σo, ys)"
  shows "list_all (Subx.fun_call_in_range F) xs  list_all (Subx.fun_call_in_range F) ys"
  using assms
  by (induction xs arbitrary: pc Σi Σo ys) (auto simp: optim_instr_fun_call_in_range)

(* Optimization preserves Subx.local_var_in_range N: the local-variable index n of
   IGet/ISet and their unboxed variants is kept unchanged. *)
lemma optim_instr_local_var_in_range:
  assumes "optim_instr F ret pc x Σi = Some (y, Σo)"
  shows "Subx.local_var_in_range N x  Subx.local_var_in_range N y"
  using assms
  by (cases "(F, ret, pc, x, Σi)" rule: optim_instr.cases)
    (auto simp: ap_option_eq_Some_conv Let_def if_split_eq1 option.case_eq_if
      bind_eq_Some_conv orelse_eq_Some_conv)

(* List version of optim_instr_local_var_in_range. *)
lemma optim_instrs_all_local_var_in_range:
  assumes "optim_instrs F ret pc Σi xs = Some (Σo, ys)"
  shows "list_all (Subx.local_var_in_range N) xs  list_all (Subx.local_var_in_range N) ys"
  using assms
  by (induction xs arbitrary: pc Σi Σo ys) (auto simp: optim_instr_local_var_in_range)

(* An optimized instruction is accepted by Subx.sp_instr, going from the stack Σi before
   it to the stack Σo computed by optim_instr.  drop_prefix_eq_Some_conv turns the
   drop_prefix checks into stack decompositions Σ = prefix @ rest. *)
lemma optim_instr_sp:
  assumes "optim_instr F ret pc x Σi = Some (y, Σo)"
  shows "Subx.sp_instr F ret y Σi Σo"
  using assms
  by (cases "(F, ret, pc, x, Σi)" rule: optim_instr.cases)
    (auto simp add: Let_def if_split_eq1 option.case_eq_if
      simp: ap_option_eq_Some_conv orelse_eq_Some_conv drop_prefix_eq_Some_conv bind_eq_Some_conv
      intro: Subx.sp_instr.intros)

(* List version: the optimized instruction list is accepted by Subx.sp_instrs from Σi
   to Σo. *)
lemma optim_instrs_sp:
  assumes "optim_instrs F ret pc Σi xs = Some (Σo, ys)"
  shows "Subx.sp_instrs F ret ys Σi Σo"
  using assms
  by (induction xs arbitrary: pc Σi Σo ys)
    (auto intro!: Subx.sp_instrs.intros optim_instr_sp)


section ‹Compilation of function definition›

(* Compiles one basic block, given as a pair whose first component (the label, which is
   how compile_fundef uses it) is kept unchanged via Some.  Compilation of the
   instruction list i1 proceeds as follows:
     1. i1 must be non-empty.
     2. No instruction except the last may be a jump or a return.
     3. i1 is lifted with lift_instrs starting from the empty stack, and the final
        abstract stack must also be empty.
     4. optim_instrs is then tried from pc 0 and the empty stack.  Its output is used
        only if it also ends with the empty stack.  Otherwise, including when it fails,
        the merely lifted code i2 is kept, so a failed optimization never makes
        compilation fail. *)
definition compile_basic_block where
  "compile_basic_block F L ret N 
    ap_map_prod Some (λi1. do {
      _  if i1  [] then Some () else None;
      _  if list_all (λi. ¬ Inca.is_jump i  ¬ Inca.is_return i) (butlast i1) then Some () else None;
      (Σo, i2)  lift_instrs F L ret N ([] :: type option list) i1;
      if Σo = [] then
        case optim_instrs F ret 0 ([] :: type option list) i2 of
          Some (Σo', i2')  Some (if Σo' = [] then i2' else i2) |
          None  Some i2
      else
        None
    })"

(* A compiled basic block has the same label as the original, and its instructions are
   pointwise related to the original ones by norm_eq.  The proof composes the norm_eq
   relation from lifting with the norm_instr equality from optimization
   (list_all2_trans).  It splits on whether optim_instrs succeeded, to cover the
   fallback to the lifted code. *)
lemma compile_basic_block_rel_prod_all_norm_eq:
  assumes "compile_basic_block F L ret N bblock1 = Some bblock2"
  shows "rel_prod (=) (list_all2 norm_eq) bblock1 bblock2"
  using assms
  unfolding compile_basic_block_def
  apply (auto simp add: ap_map_prod_eq_Some_conv bind_eq_Some_conv
      simp: if_split_eq1
      intro: lift_instrs_all_norm
      dest!: optim_instrs_all_norm lift_instrs_all_norm)
  subgoal premises prems for _ xs zs ys
    using list_all2 norm_eq xs ys
  proof (rule list_all2_trans[of norm_eq "λi. norm_eq (norm_instr i)" norm_eq xs ys zs, simplified])
    show "list_all2 (λi. norm_eq (norm_instr i)) ys zs"
    proof (cases "optim_instrs F ret 0 [] ys")
      case None
      with prems show ?thesis by (simp add: list.rel_refl)
    next
      case (Some p)
      with prems show ?thesis
        by (cases p) (auto simp: list.rel_refl intro: optim_instrs_all_norm)
    qed
  qed
  done

(* For a non-empty list, list_all P splits into P on all elements but the last, together
   with P on the last element. *)
lemma list_all_iff_butlast_last:
  assumes "xs  []"
  shows "list_all P xs  list_all P (butlast xs)  P (last xs)"
  using assms
  by (induction xs) auto

(* Every basic block produced by compile_basic_block satisfies Subx.wf_basic_block.  The
   facts established below are:
     - the instruction list is non-empty;
     - local variables, function calls and jump targets are all in range;
     - no jump or return occurs before the last instruction;
     - the instructions are stack-typed from the empty stack to the empty stack. *)
lemma compile_basic_block_wf:
  assumes "compile_basic_block F L ret N x = Some y"
  shows "Subx.wf_basic_block F (set L) ret N y"
proof -
  (* Unfold the compiler.  instrs2 is the lifted code.  instrs3 is either instrs2 or a
     successful optimization of it that ends with the empty stack. *)
  obtain f instrs1 instrs2 instrs3 where
    x_def: "x = (f, instrs1)" and
    y_def: "y = (f, instrs3)" and
    "instrs1  []" and
    all_not_jump_not_return_instrs1:
      "list_all (λi. ¬ Inca.is_jump i  ¬ Inca.is_return i) (butlast instrs1)" and
    lift_instrs1: "lift_instrs F L ret N ([] :: type option list) instrs1 = Some ([], instrs2)" and
    instr4_defs: "instrs3 = instrs2 
      optim_instrs F ret 0 ([] :: type option list) instrs2 = Some ([], instrs3)"
    using assms
    unfolding compile_basic_block_def
    apply (auto simp: ap_map_prod_eq_Some_conv bind_eq_Some_conv if_split_eq1 option.case_eq_if)
    by blast

  have "instrs3  []"
    using instr4_defs instrs1  []
    using lift_instrs_not_Nil[OF lift_instrs1]
    by (auto simp: optim_instrs_not_Nil)
  moreover have "list_all (Subx.local_var_in_range N) instrs3"
    using instr4_defs lift_instrs1
    by (auto dest: lift_instrs_all_local_var_in_range simp: optim_instrs_all_local_var_in_range)
  moreover have "list_all (Subx.fun_call_in_range F) instrs3"
    using instr4_defs lift_instrs1
    by (auto dest: lift_instrs_all_fun_call_in_range simp: optim_instrs_all_fun_call_in_range)
  moreover have "list_all (Subx.jump_in_range (set L)) instrs3"
    using instr4_defs lift_instrs1
    by (auto dest: lift_instrs_all_jump_in_range simp: optim_instrs_all_jump_in_range)
  moreover have "list_all (λi. ¬ Ubx.instr.is_jump i  ¬ Ubx.instr.is_return i) (butlast instrs3)"
    using instr4_defs lift_instrs1 all_not_jump_not_return_instrs1
    by (auto simp:
        lift_instrs_all_butlast_not_jump_not_return
        optim_instrs_all_butlast_not_jump_not_return)
  moreover have "Subx.sp_instrs F ret instrs3 [] []"
    using instr4_defs lift_instrs1
    by (auto intro: lift_instrs_sp optim_instrs_sp)
  ultimately show ?thesis
    by (auto simp: y_def intro!: Subx.wf_basic_blockI)
qed

(* Compiles a function definition.  The list of basic blocks must be non-empty.  Each
   block is compiled with compile_basic_block, which receives:
     - the labels of all blocks (map fst bblocks1) as valid jump targets;
     - the return count ret;
     - ar + locals as the bound on local-variable indices.
   Arity, return count and number of locals are copied unchanged.  The result is None if
   any block fails to compile. *)
fun compile_fundef where
  "compile_fundef F (Fundef bblocks1 ar ret locals) = do {
    _  if bblocks1 = [] then None else Some ();
    bblocks2  ap_map_list (compile_basic_block F (map fst bblocks1) ret (ar + locals)) bblocks1;
    Some (Fundef bblocks2 ar ret locals)
  }"

(* compile_fundef preserves the arity. *)
lemma compile_fundef_arities: "compile_fundef F fd1 = Some fd2  arity fd1 = arity fd2"
  by (cases fd1) (auto simp: bind_eq_Some_conv)

(* compile_fundef preserves the number of return values. *)
lemma compile_fundef_returns: "compile_fundef F fd1 = Some fd2  return fd1 = return fd2"
  by (cases fd1) (auto simp: bind_eq_Some_conv)

(* compile_fundef preserves the number of locals. *)
lemma compile_fundef_locals:
  "compile_fundef F fd1 = Some fd2  fundef_locals fd1 = fundef_locals fd2"
  by (cases fd1) (auto simp: bind_eq_Some_conv)

(* Simplification rules for the guard pattern `if a then None else Some b`, as used at
   the start of compile_fundef. *)
lemma if_then_None_else_Some_eq[simp]:
  "(if a then None else Some b) = Some c  ¬ a  b = c"
  "(if a then None else Some b) = None  a"
  by (cases a) simp_all

(* Two facts about a successful compile_fundef, proved together as a conjunction
   (atomize_conj) and exported under separate names:
     rel_compile_fundef -- fd1 and fd2 are related by rel_fundef (=) norm_eq;
     wf_compile_fundef  -- fd2 is well-formed according to Subx.wf_fundef F. *)
lemma
  assumes "compile_fundef F fd1 = Some fd2"
  shows
    rel_compile_fundef: "rel_fundef (=) norm_eq fd1 fd2" (is ?REL) and
    wf_compile_fundef: "Subx.wf_fundef F fd2" (is ?WF)
  unfolding atomize_conj
proof (cases fd1)
  case (Fundef bblocks1 ar ret locals)
  with assms obtain bblocks2 where
    "bblocks1  []" and
    lift_bblocks1:
      "ap_map_list (compile_basic_block F (map fst bblocks1) ret (ar + locals)) bblocks1 = Some bblocks2" and
    fd2_def: "fd2 = Fundef bblocks2 ar ret locals"
    by (auto simp add: bind_eq_Some_conv)

  show "?REL  ?WF"
  proof (rule conjI)
    (* Relation: blockwise, from compile_basic_block_rel_prod_all_norm_eq. *)
    show ?REL
      unfolding Fundef fd2_def
    proof (rule fundef.rel_intros)
      show "list_all2 (rel_prod (=) (list_all2 norm_eq)) bblocks1 bblocks2"
        using lift_bblocks1
        unfolding ap_map_list_iff_list_all2
        by (auto elim: list.rel_mono_strong intro: compile_basic_block_rel_prod_all_norm_eq)
    qed simp_all
  next
    (* Well-formedness: non-emptiness, blockwise well-formedness, and the fact that
       compilation keeps the set of block labels. *)
    have "bblocks2  []"
      using bblocks1  [] length_ap_map_list[OF lift_bblocks1] by force
    moreover have "list_all (Subx.wf_basic_block F (fst ` set bblocks1) ret (ar + locals)) bblocks2"
      using lift_bblocks1
      unfolding ap_map_list_iff_list_all2
      by (auto elim!: list_rel_imp_pred2 dest: compile_basic_block_wf)
    moreover have "fst ` set bblocks1 = fst ` set bblocks2"
      using lift_bblocks1
      unfolding ap_map_list_iff_list_all2
      by (induction bblocks1 bblocks2 rule: list.rel_induct)
        (auto simp add: compile_basic_block_def ap_map_prod_eq_Some_conv)
    ultimately show ?WF
      unfolding fd2_def
      by (auto intro: Subx.wf_fundefI)
  qed
qed

end

end

(* Compiler setting.  It extends inca_to_ubx_simulation, instantiated with Finca_empty
   and Finca_get (presumably the empty Inca function environment and its lookup;
   confirm in Inca_to_Ubx_simulation).  It adds load_oracle, which for a function name
   and a program counter may suggest a type.  Below, compile_env_entry passes
   load_oracle f to the compiler of function f. *)
locale inca_ubx_compiler =
  inca_to_ubx_simulation Finca_empty Finca_get
  for
    Finca_empty and
    Finca_get :: "_  'fun  _ option" +
  fixes
    load_oracle :: "'fun  nat  type option"
begin


section ‹Compilation of function environment›

(* Compiles one (function name, function definition) entry of an environment.  The name
   is kept unchanged.  The definition is compiled with the load oracle specialized to
   that name (load_oracle (fst p)), and F supplies the callee types. *)
definition compile_env_entry where
  "compile_env_entry F  λp. ap_map_prod Some (compile_fundef (load_oracle (fst p)) F) p"

(* A successfully compiled entry is related to its source by rel_fundef (=) norm_eq. *)
lemma rel_compile_env_entry:
  assumes "compile_env_entry F (f, fd1) = Some (f, fd2)"
  shows "rel_fundef (=) norm_eq fd1 fd2"
  using assms unfolding compile_env_entry_def
  by (auto simp: ap_map_prod_eq_Some_conv intro!: rel_compile_fundef)

(* Compiles a whole Inca function environment.  The environment is listed with
   Finca_to_list, and every entry is compiled with compile_env_entry.  The callee types
   passed to each entry are the funtype of the Inca definitions: map_option funtype
   composed with Finca_get e.  The Ubx environment is rebuilt with Subx.Fenv.from_list.
   The result is None if any entry fails to compile. *)
definition compile_env where
  "compile_env e  do {
    let fundefs1 = Finca_to_list e;
    fundefs2  ap_map_list (compile_env_entry (map_option funtype  Finca_get e)) fundefs1;
    Some (Subx.Fenv.from_list fundefs2)
  }"

(* If the entry list xs compiles to ys, the environments built from xs (Inca) and from ys
   (Ubx) are related by rel_fundefs.  The proof is by induction on xs.  Each step adds one
   entry (f, fd1) / (f, fd2) whose definitions are related by rel_compile_env_entry. *)
lemma rel_ap_map_list_ap_map_list_compile_env_entries:
  assumes "ap_map_list (compile_env_entry F) xs = Some ys"
  shows "rel_fundefs (Finca_get (Sinca.Fenv.from_list xs)) (Fubx_get (Subx.Fenv.from_list ys))"
  using assms
proof (induction xs arbitrary: ys)
  case Nil
  thus ?case
    using rel_fundefs_empty by simp
next
  case (Cons x xs)
  from Cons.prems obtain y ys' where
    ys_def: "ys = y # ys'" and
    compile_env_x: "compile_env_entry F x = Some y" and
    compile_env_xs: "ap_map_list (compile_env_entry F) xs = Some ys'"
    by (auto simp add: ap_option_eq_Some_conv)

  (* compile_env_entry keeps the function name. *)
  obtain f fd1 fd2 where
    prods: "x = (f, fd1)" "y = (f, fd2)" and "compile_fundef (load_oracle f) F fd1 = Some fd2"
    using compile_env_x
    by (cases x) (auto simp: compile_env_entry_def eq_fst_iff ap_map_prod_eq_Some_conv)

  have "rel_fundef (=) norm_eq fd1 fd2"
    using compile_env_x[unfolded prods]
    by (auto intro: rel_compile_env_entry)
  thus ?case
    using Cons.IH[OF compile_env_xs, THEN rel_fundefsD]
    unfolding prods ys_def
    unfolding Sinca.Fenv.from_list_correct Subx.Fenv.from_list_correct
    by (auto intro: rel_fundefsI)
qed

(* Whole-environment version: a source environment and its compiled environment are
   related by rel_fundefs.  It reduces to the list lemma above via
   Sinca.Fenv.get_from_list_to_list, which recovers F1 from its entry list. *)
lemma rel_fundefs_compile_env:
  assumes "compile_env F1 = Some F2"
  shows "rel_fundefs (Finca_get F1) (Fubx_get F2)"
proof -
  from assms obtain xs where
    ap_map_list_F1: "ap_map_list (compile_env_entry (map_option funtype  Finca_get F1)) (Finca_to_list F1) = Some xs" and
    F2_def: "F2 = Subx.Fenv.from_list xs"
    by (auto simp: compile_env_def bind_eq_Some_conv)

  show ?thesis
    using rel_ap_map_list_ap_map_list_compile_env_entries[OF ap_map_list_F1]
    unfolding F2_def Sinca.Fenv.get_from_list_to_list
    by assumption
qed


section ‹Compilation of program›

(* Compiles a whole program: the function environment is compiled with compile_env, and
   the other two components (H and the main function name f) are copied unchanged.
   The result is None if compile_env fails. *)
fun compile where
  "compile (Prog F1 H f) = Some Prog  compile_env F1  Some H  Some f"

(* Congruence rule for ap_map_list: functions that agree on the elements of the list
   give the same result. *)
lemma ap_map_list_cong:
  assumes "x. x  set ys  f x = g x" and "xs = ys"
  shows "ap_map_list f xs = ap_map_list g ys"
  using assms
  by (induction xs arbitrary: ys) auto

(* Every function of a compiled environment is well-formed (Subx.wf_fundefs).  The key
   step is funtype_F1_eq_funtype_F2: compilation preserves every function's funtype
   (arity and return count).  The callee types used during compilation, which come from
   F1, therefore coincide with those of F2, which Subx.wf_fundef expects. *)
lemma compile_env_wf_fundefs:
  assumes "compile_env F1 = Some F2"
  shows "Subx.wf_fundefs (Fubx_get F2)"
proof (intro Subx.wf_fundefsI allI)
  fix f
  obtain xs where
    ap_map_list_F1: "ap_map_list (compile_env_entry (map_option funtype  Finca_get F1)) (Finca_to_list F1) = Some xs" and
    F2_def: "F2 = Subx.Fenv.from_list xs"
    using assms by (auto simp: compile_env_def bind_eq_Some_conv)

  (* Name-wise: a name is defined in F1's entry list exactly when it is defined in xs,
     and the definitions are related by a successful compile_fundef. *)
  have rel_map_of_F1_xs:
    "f. rel_option (λx y. compile_fundef (load_oracle f) (map_option funtype  Finca_get F1) x = Some y)
      (map_of (Finca_to_list F1) f) (map_of xs f)"
    using ap_map_list_F1
    by (auto simp: compile_env_entry_def ap_map_prod_eq_Some_conv
        dest: ap_map_list_imp_rel_option_map_of)

  have funtype_F1_eq_funtype_F2:
    "map_option funtype  Finca_get F1 = map_option funtype  Fubx_get F2"
  proof (rule ext, simp)
    fix x
    show "map_option funtype (Finca_get F1 x) = map_option funtype (Fubx_get F2 x)"
      unfolding F2_def Subx.Fenv.from_list_correct Sinca.Fenv.to_list_correct[symmetric]
      using rel_map_of_F1_xs[of x]
      by (cases rule: option.rel_cases)
        (simp_all add: funtype_def compile_fundef_arities compile_fundef_returns)
  qed

  show "pred_option (Subx.wf_fundef (map_option funtype  Fubx_get F2)) (Fubx_get F2 f) "
  proof (cases "map_of (Finca_to_list F1) f")
    case None
    thus ?thesis
      using rel_map_of_F1_xs[of f, unfolded None]
      by (simp add: F2_def Subx.Fenv.from_list_correct)
  next
    case (Some fd1)
    show ?thesis
      using rel_map_of_F1_xs[of f, unfolded Some option_rel_Some1]
      unfolding funtype_F1_eq_funtype_F2 F2_def Subx.Fenv.from_list_correct
      by (auto intro: wf_compile_fundef)
  qed
qed

(* Load correctness.  For every initial Ubx state s2 that the compiled program p2 loads
   into, there is an initial Inca state s1 of the source program p1 with match s1 s2.
   s1 is built from the source definition fd1 of main, which exists because the
   environments are related by rel_fundefs. *)
lemma compile_load:
  assumes
    compile_p1: "compile p1 = Some p2" and
    load: "Subx.load p2 s2"
  shows "s1. Sinca.load p1 s1  match s1 s2"
proof -
  obtain F1 H main where p1_def: "p1 = Prog F1 H main"
    by (cases p1) simp
  then obtain F2 where
    compile_F1: "compile_env F1 = Some F2" and
    p2_def: "p2 = Prog F2 H main"
    using compile_p1
    by (auto simp: ap_option_eq_Some_conv)

  note rel_F1_F2 = rel_fundefs_compile_env[OF compile_F1]

  show ?thesis
    using assms(2) unfolding p2_def Subx.load_def
  proof (cases _ _ _ s2 rule: Global.load.cases)
    case (1 fd2)
    then obtain fd1 where
      F1_main: "Finca_get F1 main = Some fd1" and rel_fd1_fd2: "rel_fundef (=) norm_eq fd1 fd2"
      using rel_fundefs_Some2[OF rel_F1_F2]
      by auto
      
    (* Candidate source state: same heap component H and a single frame for main. *)
    let ?s1 = "State F1 H [allocate_frame main fd1 [] uninitialized]"

    show ?thesis
    proof (intro exI conjI)
      show "Sinca.load p1 ?s1"
        unfolding Sinca.load_def p1_def
        using 1 F1_main rel_fd1_fd2
        by (auto simp: rel_fundef_arities intro!: Global.load.intros dest: rel_fundef_body_length)
    next
      (* match also needs a well-formed target state; its function environment is
         well-formed by compile_env_wf_fundefs. *)
      have "Subx.wf_state s2"
        unfolding 1
        using compile_F1
        by (auto intro!: Subx.wf_stateI intro: compile_env_wf_fundefs)
      then show "match ?s1 s2"
        using 1 rel_F1_F2 rel_fd1_fd2
        by (auto simp: allocate_frame_def rel_fundef_locals
            simp: rel_fundef_rel_fst_hd_bodies[OF rel_fd1_fd2 disjI2]
            intro!: match.intros rel_stacktraces.intros intro: Subx.sp_instrs.Nil)
    qed
  qed
qed

(* Instantiates VeriComp's compiler locale with Inca as the source language (Sinca.step,
   Sinca.load, final on Inca.IReturn) and Ubx as the target language (Subx.step,
   Subx.load, final on Ubx.IReturn).  The match relation ignores its index argument, the
   order is the empty relation (λ_ _. False), and the compilation function is compile.
   The proof obligations are discharged by unfold_locales together with compile_load.
   NOTE(review): the interpretation is named std_to_inca_compiler although it interprets
   the Inca -> Ubx compiler.  It looks copied from the Std -> Inca theory.  Renaming it
   (e.g. to inca_to_ubx_compiler) would break any external references to
   std_to_inca_compiler.*, so check users before renaming. *)
interpretation std_to_inca_compiler: compiler where
  step1 = Sinca.step and final1 = "final Finca_get Inca.IReturn" and load1 = Sinca.load and
  step2 = Subx.step and final2 = "final Fubx_get Ubx.IReturn" and load2 = Subx.load and
  match = "λ_. match" and order = "λ_ _. False" and
  compile = compile
using compile_load
  by unfold_locales auto


subsection ‹Completeness of compilation›

(* If lift_instr succeeds on a type stack that contains only None entries,
   the resulting type stack again contains only None entries. *)
lemma lift_instr_None_preservation:
  assumes
    lifted: "lift_instr F L ret N instr Σ = Some (instr', Σ')" and
    all_None: "list_all ((=) None) Σ"
  shows "list_all ((=) None) Σ'"
  using lifted all_None
proof (cases "(F, L, ret, N, instr, Σ)" rule: lift_instr.cases)
qed (auto simp: Let_def bind_eq_Some_conv)

(* Completeness of lifting a single instruction.
   Assume instr passes the Inca range checks (local variables, jump targets,
   function calls) and Sinca.sp_instr relates the input height length Σ to the
   output height k. Then lift_instr succeeds on any all-None type stack Σ and
   produces a type stack of length k. *)
lemma lift_instr_complete:
  assumes
    "Sinca.local_var_in_range N instr" and
    "Sinca.jump_in_range (set L) instr" and
    "Sinca.fun_call_in_range F instr" and
    "Sinca.sp_instr F ret instr (length Σ) k" and
    "list_all ((=) None) Σ"
  shows "∃instr' Σ'. lift_instr F L ret N instr Σ = Some (instr', Σ') ∧ length Σ' = k"
  using assms
  by (cases "(F, L, ret, N, instr, Σ)" rule: lift_instr.cases)
    (auto simp add: in_set_member Let_def
      dest: Map.domD dest!: list_all_eq_const_imp_replicate' elim: Sinca.sp_instr.cases)

(* Completeness of lifting an instruction sequence.
   If every instruction passes the Inca range checks and Sinca.sp_instrs
   relates the input height length Σ to k, then lift_instrs succeeds from any
   all-None type stack Σ, and the final type stack has length k.
   The proof is by induction on the instruction list. *)
lemma lift_instrs_complete:
  fixes Σ :: "type option list"
  assumes
    "list_all (Sinca.local_var_in_range N) instrs" and
    "list_all (Sinca.jump_in_range (set L)) instrs" and
    "list_all (Sinca.fun_call_in_range F) instrs" and
    "Sinca.sp_instrs F ret instrs (length Σ) k" and
    "list_all ((=) None) Σ"
  shows "∃Σ' instrs'. lift_instrs F L ret N Σ instrs = Some (Σ', instrs') ∧ length Σ' = k"
  using assms
proof (induction instrs arbitrary: Σ)
  case Nil
  thus ?case
    unfolding lift_instrs_def
    by (auto elim: Sinca.sp_instrs.cases)
next
  case (Cons instr instrs')
  (* Split the stack typing of the sequence: the head goes from length Σ to k',
     the tail from k' to k. *)
  from Cons.prems(4) obtain k' where
    sp_head: "Sinca.sp_instr F ret instr (length Σ) k'" and
    sp_tail: "Sinca.sp_instrs F ret instrs' k' k"
    by (cases rule: Sinca.sp_instrs.cases) simp

  (* The range checks also hold for the tail. *)
  have inv_instrs':
    "list_all (Sinca.local_var_in_range N) instrs'"
    "list_all (Sinca.jump_in_range (set L)) instrs'"
    "list_all (Sinca.fun_call_in_range F) instrs'"
    using Cons.prems(1-3) by simp_all

  (* Lift the head instruction. Its output stack has height k' and stays all-None,
     so the induction hypothesis applies to the tail. *)
  from Cons.prems(1-3,5) obtain instr2 Σtmp where
    lift_head: "lift_instr F L ret N instr Σ = Some (instr2, Σtmp)" and
    "length Σtmp = k'"
    using lift_instr_complete[OF _ _ _ sp_head, of N L] by auto
  hence "list_all ((=) None) Σtmp"
    by (meson Cons.prems(5) lift_instr_None_preservation)
  then obtain instrs2 and Σ' :: "type option list" where
    lift_tail: "lift_instrs F L ret N Σtmp instrs' = Some (Σ', instrs2)" and
    "length Σ' = k"
    using Cons.IH[OF inv_instrs', of Σtmp] sp_tail
    unfolding ‹length Σtmp = k'›
    by auto
  show ?case
  proof (intro exI conjI)
    show "lift_instrs F L ret N Σ (instr # instrs') = Some (Σ', instr2 # instrs2)"
      using lift_head lift_tail
      by (simp add: lift_instrs_def)
  next
    show "length Σ' = k"
      by (rule ‹length Σ' = k›)
  qed
qed

(* Completeness of optim_instr for a single instruction: every instruction
   typed by Subx.sp_instr is accepted. The type stack it produces has the same
   length as the output stack given by Subx.sp_instr.
   The proof goes by case analysis on the typing derivation. *)
lemma optim_instr_complete:
  assumes sp: "Subx.sp_instr F ret instr Σ Σ'"
  shows "∃Σ'' instr'. optim_instr 𝒪 F ret pc instr Σ = Some (instr', Σ'') ∧ length Σ' = length Σ''"
  using sp
proof (cases F ret instr Σ Σ' rule: Subx.sp_instr.cases)
  case (Push d)
  (* Split on the results of both unboxing attempts for the pushed value d. *)
  thus ?thesis
    by (cases "unbox_ubx1 d"; cases "unbox_ubx2 d") simp_all
next
  case (Get n)
  (* Split on whether 𝒪 provides an entry for the current pc. *)
  thus ?thesis
    by (cases "𝒪 pc") simp_all
next
  case (Load x Σ)
  then show ?thesis
    by (cases "𝒪 pc") simp_all
next
  case (OpInl opinl Σ)
  (* Split on the result of 𝔘𝔟𝔵 applied to an all-None argument-type list. *)
  then show ?thesis
    by (cases "𝔘𝔟𝔵 opinl (replicate (𝔄𝔯𝔦𝔱𝔶 (𝔇𝔢ℑ𝔫𝔩 opinl)) None)")
      (simp_all add: Let_def Subx.𝔘𝔟𝔵_opubx_type)
qed simp_all

(* A well-formed Inca basic block always compiles.
   Lifting starts from the empty type stack. By lift_instrs_complete with k = 0,
   it also ends with the empty type stack. The proof then case-splits on the
   result of the optim_instrs pass. *)
lemma compile_basic_block_complete:
  assumes wf_bblock1: "Sinca.wf_basic_block F (set L) ret n bblock1"
  shows "∃bblock2. compile_basic_block 𝒪 F L ret n bblock1 = Some bblock2"
proof (cases bblock1)
  case (Pair label instrs1)
  moreover obtain instrs2 where
    "lift_instrs F L ret n ([] :: type option list) instrs1 = Some ([], instrs2)"
    using wf_bblock1[unfolded Pair, simplified]
    using lift_instrs_complete[of n instrs1 L F ret "[]" 0]
    by (auto simp: Sinca.wf_basic_block_def)
  ultimately show ?thesis
    using wf_bblock1[unfolded Pair, simplified]
    apply (simp add: compile_basic_block_def ap_map_prod_eq_Some_conv)
    by (cases "optim_instrs 𝒪 F ret 0 [] instrs2") (auto simp: Sinca.wf_basic_block_def)
qed

(* Binding an option to a function that always returns Some is map_option.
   Because this is declared [simp], the simplifier uses it in subsequent proofs
   within this context. *)
lemma bind_eq_map_option[simp]: "x ⤜ (λy. Some (f y)) = map_option f x"
  by (cases x) simp_all

(* A well-formed Inca function definition always compiles. Each of its basic
   blocks is compiled by compile_basic_block_complete, lifted over the block
   list through ex_ap_map_list_eq_SomeI. *)
lemma compile_fundef_complete:
  assumes wf_fd1: "Sinca.wf_fundef F fd1"
  shows "∃fd2. compile_fundef 𝒪 F fd1 = Some fd2"
proof (cases fd1)
  case (Fundef bblocks ar ret locals)
  show ?thesis
    using wf_fd1
    by (auto simp add: Fundef Sinca.wf_fundef_def
        intro!: ex_ap_map_list_eq_SomeI intro: compile_basic_block_complete
        elim!: list.pred_mono_strong)
qed

(* Compiling a single (name, definition) entry of the function environment
   succeeds, for any name f, whenever the definition fd1 is well-formed. *)
lemma compile_env_entry_complete:
  assumes wf_fd1: "Sinca.wf_fundef F fd1"
  shows "∃fd2. compile_env_entry F (f, fd1) = Some fd2"
  using compile_fundef_complete[OF wf_fd1]
  by (simp add: compile_env_entry_def ap_map_prod_eq_Some_conv)

(* The whole function environment F1 compiles, provided that every function
   stored in F1 is well-formed with respect to the function types looked up
   in F1 itself. *)
lemma compile_env_complete:
  assumes wf_F1: "pred_map (Sinca.wf_fundef (map_option funtype ∘ Finca_get F1)) (Finca_get F1)"
  shows "∃F2. compile_env F1 = Some F2"
  using wf_F1
  by (auto simp add: compile_env_def
      intro: ex_ap_map_list_eq_SomeI Sinca.Fenv.to_list_list_allI compile_env_entry_complete
        pred_map_get)

(* Compilation completeness: the compiler accepts every Inca program
   satisfying Sinca.wf_prog. *)
theorem compile_complete:
  assumes wf_p1: "Sinca.wf_prog p1"
  shows "∃p2. compile p1 = Some p2"
proof (cases p1)
  case (Prog F1 H main)
  then show ?thesis
    using wf_p1 unfolding Sinca.wf_prog_def
    by (auto simp: Let_def dest: compile_env_complete)
qed

end

end