(* Theory Wasm_Interpreter *)

section ‹WebAssembly Interpreter›

theory Wasm_Interpreter imports Wasm begin

(* Why an interpreter run stopped without a WebAssembly-level result.
   CError: the configuration could not be executed (every crash_error below,
   e.g. an operand stack of the wrong shape or type).
   CExhaustion: a resource budget ran out (run_v reports it when its fuel
   reaches 0). *)
datatype res_crash =
  CError
| CExhaustion

(* Final outcome of run_v: an interpreter crash, a WebAssembly trap, or the
   values of the final all-constant instruction sequence, in the order the
   constants appear. *)
datatype res =
  RCrash res_crash
| RTrap
| RValue "v list"  

(* Outcome of a single interpreter step.
   RSCrash c: execution cannot continue.
   RSBreak n vs: a br is propagating outward; n counts the enclosing labels
     still to be crossed (each Label decrements it, 0 is handled by the
     innermost one) and vs are the operand values, top of stack first.
   RSReturn vs: a return is propagating outward, operand values top of
     stack first.
   RSNormal es: the rewritten instruction sequence to continue with. *)
datatype res_step =
  RSCrash res_crash
| RSBreak nat "v list"
| RSReturn "v list"
| RSNormal "e list"

(* Shorthand for the generic "cannot execute this configuration" step result. *)
abbreviation crash_error where "crash_error ≡ RSCrash CError"

(* Budgets: depth bounds how many nested Local frames run_step may enter;
   fuel bounds how many steps run_v may take. *)
type_synonym depth = nat
type_synonym fuel = nat

(* Store, local variables, instruction sequence. *)
type_synonym config_tuple = "s × v list × e list"

(* Store, local variables, operand values (top of stack first) and the
   single instruction to execute next. *)
type_synonym config_one_tuple = " s × v list × v list × e"

(* Resulting store, resulting local variables, step outcome. *)
type_synonym res_tuple = "s × v list × res_step"

(* Split the maximal prefix of constant basic instructions off a list:
   returns the constants' values (in order) and the remaining instructions. *)
fun split_vals :: "b_e list ⇒ v list × b_e list" where
  "split_vals ((C v)#es) = (let (vs', es') = split_vals es in (v#vs', es'))"
| "split_vals es = ([], es)"

(* Administrative-instruction counterpart of split_vals: split the maximal
   prefix of constants off an e list, returning their values (in order) and
   the remaining instructions. *)
fun split_vals_e :: "e list ⇒ v list × e list" where
  "split_vals_e (($ C v)#es) = (let (vs', es') = split_vals_e es in (v#vs', es'))"
| "split_vals_e es = ([], es)"

(* Split a value list after its first n elements; if the list is shorter than
   n, everything ends up in the first component (see split_n_conv_take_drop). *)
fun split_n :: "v list ⇒ nat ⇒ v list × v list" where
  "split_n [] n = ([], [])"
| "split_n es 0 = ([], es)"
| "split_n (e#es) (Suc n) = (let (es', es'') = split_n es n in (e#es', es''))"

(* split_n coincides with the standard take/drop pair. *)
lemma split_n_conv_take_drop: "split_n es n = (take n es, drop n es)"
  by (induction es n rule: split_n.induct, simp_all)

(* When the list has at least n elements, the first component of split_n
   has exactly n elements. *)
lemma split_n_length:
  assumes "split_n es n = (es1, es2)" "length es ≥ n"
  shows "length es1 = n"
  using assms
  unfolding split_n_conv_take_drop
  by fastforce

(* Concatenating the two halves returned by split_n gives back the input. *)
lemma split_n_conv_app:
  assumes "split_n es n = (es1, es2)"
  shows "es = es1@es2"
  using assms
  unfolding split_n_conv_take_drop
  by auto

(* Converse of split_n_conv_app: splitting an append at the length of its
   first part recovers both parts. *)
lemma app_conv_split_n:
  assumes "es = es1@es2"
  shows "split_n es (length es1) = (es1, es2)"
  using assms
  unfolding split_n_conv_take_drop
  by auto

(* A list consisting only of constants is split entirely into values. *)
lemma split_vals_const_list: "split_vals (map EConst vs) = (vs, [])"
  by (induction vs, simp_all)

(* Same property for administrative instruction lists built with $$*. *)
lemma split_vals_e_const_list: "split_vals_e ($$* vs) = (vs, [])"
  by (induction vs, simp_all)

(* Reassembling the result of split_vals_e (values turned back into constant
   instructions, followed by the rest) gives back the original list.
   Proved by induction following the equations of split_vals_e; only the
   constant-prefix case needs work. *)
lemma split_vals_e_conv_app:
  assumes "split_vals_e xs = (as, bs)"
  shows "xs = ($$* as)@bs"
  using assms
proof (induction xs arbitrary: as rule: split_vals_e.induct)
  case (1 v es)
  obtain as' bs' where "split_vals_e es = (as', bs')"
    by (meson surj_pair)
  thus ?case
    using 1
    by fastforce
qed simp_all

(* Option eliminator: apply f to the content of Some, or return the default b
   for None. Used below to turn failing lookups/operations into traps or
   crashes. *)
abbreviation expect :: "'a option ⇒ ('a ⇒ 'b) ⇒ 'b ⇒ 'b" where
  "expect a f b ≡ (case a of
                     Some a' ⇒ f a'
                   | None ⇒ b)"

(* Turn operand values held top-of-stack first back into constant
   instructions in program order. *)
abbreviation vs_to_es :: " v list ⇒ e list"
  where "vs_to_es v ≡ $$* (rev v)"

(* Test whether an instruction is Trap (characterised by a simp lemma below). *)
definition e_is_trap :: "e ⇒ bool" where
  "e_is_trap e = (case e of Trap ⇒ True | _ ⇒ False)"

(* Test whether an instruction list is exactly the singleton [Trap]. *)
definition es_is_trap :: "e list ⇒ bool" where
  "es_is_trap es = (case es of [e] ⇒ e_is_trap e | _ ⇒ False)"

(* Simplification rule: e_is_trap is equality with Trap. *)
lemma[simp]: "e_is_trap e = (e = Trap)"
  using e_is_trap_def
  by (cases "e") auto

(* Simplification rule: es_is_trap holds exactly for the singleton [Trap].
   Proved by case analysis: empty list, singleton, and lists of length at
   least two. *)
lemma[simp]: "es_is_trap es = (es = [Trap])"
proof (cases es)
  case Nil
  thus ?thesis
    using es_is_trap_def
    by auto
next
  case outer_Cons:(Cons a list)
  thus ?thesis
  proof (cases list)
    case Nil
    thus ?thesis
      using outer_Cons es_is_trap_def
      by auto
  next
    case (Cons a' list')
    thus ?thesis
      using es_is_trap_def outer_Cons
      by auto
  qed
qed

(* Memory growth is left abstract: only the stated correctness property is
   assumed. The axiom only constrains Some results, so it is satisfiable
   (e.g. by an implementation that always returns None); a None result is
   treated by the interpreter as a failed grow. *)
axiomatization
  mem_grow_impl:: "mem ⇒ nat ⇒ mem option" where
  mem_grow_impl_correct:"(mem_grow_impl m n = Some m') ⟹ (mem_grow m n = m')"

(*
definition mem_grow_impl:: "mem ⇒ nat ⇒ mem option" where
  "mem_grow_impl m n = Some (mem_grow m n)"

lemma mem_grow_impl_correct:
  "(mem_grow_impl m n = Some m') ⟹ (mem_grow m n = m')"
  unfolding mem_grow_impl_def
*)
(* Host function calls are left abstract: any Some result must be one that
   host_apply can produce for some hs. Like mem_grow_impl, the axiom only
   constrains Some results and is therefore satisfiable (e.g. by always
   returning None); the interpreter turns None into a trap. *)
axiomatization
  host_apply_impl:: "s ⇒ tf ⇒ host ⇒ v list ⇒ (s × v list) option" where
  host_apply_impl_correct:"(host_apply_impl s tf h vs = Some m') ⟹ (∃hs. host_apply s tf h vs hs = Some m')"

(* Small-step interpreter for WebAssembly configurations.
   run_step d i (s, vs, es) splits the leading constant instructions off es.
   If the next instruction is Trap and anything surrounds it, the whole
   sequence reduces to [Trap]. Otherwise the next instruction is executed by
   run_one_step, and the untouched tail is re-appended to an RSNormal result.
   run_one_step d i (s, vs, ves, e) executes the single instruction e. The
   list ves holds the preceding constant values, top of stack first
   (run_step passes them reversed). vs_to_es turns values back into
   instructions in program order.
   i is handed to the store lookups (sfunc, stab, stypes, sglob_val,
   smem_ind, supdate_glob). It is presumably the current module instance
   index; a Local frame switches to its own index j for its body.
   d bounds how many nested Local frames may be entered. At d = 0 the step
   crashes with CExhaustion. Label and Local forward an inner RSCrash
   unchanged, so this classification reaches run_v intact. A br that
   escapes a Local frame is reported as CError. *)
function (sequential)
    run_step :: "depth ⇒ nat ⇒ config_tuple ⇒ res_tuple"
and run_one_step :: "depth ⇒ nat ⇒ config_one_tuple ⇒ res_tuple" where
  "run_step d i (s,vs,es) = (let (ves, es') = split_vals_e es in
                             case es' of
                               [] ⇒ (s,vs, crash_error)
                             | e#es'' ⇒
                               if e_is_trap e
                                 then
                                   if (es'' ≠ [] ∨ ves ≠ [])
                                     then
                                       (s, vs, RSNormal [Trap])
                                     else
                                       (s, vs, crash_error)
                                 else
                                   (let (s',vs',r) = run_one_step d i (s,vs,(rev ves),e) in
                                    case r of
                                      RSNormal res ⇒ (s', vs', RSNormal (res@es''))
                                  | _ ⇒ (s', vs', r)))"
| "run_one_step d i (s, vs, ves, e) =
     (case e of
    ― ‹B_E›
      ― ‹UNOPS›
        $(Unop_i T_i32 iop) ⇒
         (case ves of
            (ConstInt32 c)#ves' ⇒
              (s, vs, RSNormal (vs_to_es ((ConstInt32 (app_unop_i iop c))#ves')))
          | _ ⇒ (s, vs, crash_error))
      | $(Unop_i T_i64 iop) ⇒
          (case ves of
             (ConstInt64 c)#ves' ⇒
               (s, vs, RSNormal (vs_to_es ((ConstInt64 (app_unop_i iop c))#ves')))
           | _ ⇒ (s, vs, crash_error))
      | $(Unop_i _ iop) ⇒ (s, vs, crash_error)
      | $(Unop_f T_f32 fop) ⇒
          (case ves of
             (ConstFloat32 c)#ves' ⇒
               (s, vs, RSNormal (vs_to_es ((ConstFloat32 (app_unop_f fop c))#ves')))
           | _ ⇒ (s, vs, crash_error))
      | $(Unop_f T_f64 fop) ⇒
          (case ves of
             (ConstFloat64 c)#ves' ⇒
               (s, vs, RSNormal (vs_to_es ((ConstFloat64 (app_unop_f fop c))#ves')))
           | _ ⇒ (s, vs, crash_error))
      | $(Unop_f _ fop) ⇒ (s, vs, crash_error)
      ― ‹BINOPS›
      | $(Binop_i T_i32 iop) ⇒
          (case ves of
             (ConstInt32 c2)#(ConstInt32 c1)#ves' ⇒
                expect (app_binop_i iop c1 c2) (λc. (s, vs, RSNormal (vs_to_es ((ConstInt32 c)#ves')))) (s, vs, RSNormal ((vs_to_es ves')@[Trap]))
           | _ ⇒ (s, vs, crash_error))
      | $(Binop_i T_i64 iop) ⇒
          (case ves of
             (ConstInt64 c2)#(ConstInt64 c1)#ves' ⇒
                expect (app_binop_i iop c1 c2) (λc. (s, vs, RSNormal (vs_to_es ((ConstInt64 c)#ves')))) (s, vs, RSNormal ((vs_to_es ves')@[Trap]))
           | _ ⇒ (s, vs, crash_error))
      | $(Binop_i _ iop) ⇒ (s, vs, crash_error)
      | $(Binop_f T_f32 fop) ⇒
          (case ves of
             (ConstFloat32 c2)#(ConstFloat32 c1)#ves' ⇒
                expect (app_binop_f fop c1 c2) (λc. (s, vs, RSNormal (vs_to_es ((ConstFloat32 c)#ves')))) (s, vs, RSNormal ((vs_to_es ves')@[Trap]))
           | _ ⇒ (s, vs, crash_error))
      | $(Binop_f T_f64 fop) ⇒
        (case ves of
           (ConstFloat64 c2)#(ConstFloat64 c1)#ves' ⇒
              expect (app_binop_f fop c1 c2) (λc. (s, vs, RSNormal (vs_to_es ((ConstFloat64 c)#ves')))) (s, vs, RSNormal ((vs_to_es ves')@[Trap]))
         | _ ⇒ (s, vs, crash_error))
      | $(Binop_f _ fop) ⇒ (s, vs, crash_error)
      ― ‹TESTOPS›
      | $(Testop T_i32 testop) ⇒
          (case ves of
             (ConstInt32 c)#ves' ⇒
               (s, vs, RSNormal (vs_to_es ((ConstInt32 (wasm_bool (app_testop_i testop c)))#ves')))
           | _ ⇒ (s, vs, crash_error))
      | $(Testop T_i64 testop) ⇒
          (case ves of
             (ConstInt64 c)#ves' ⇒
               (s, vs, RSNormal (vs_to_es ((ConstInt32 (wasm_bool (app_testop_i testop c)))#ves')))
           | _ ⇒ (s, vs, crash_error))
      | $(Testop _ testop) ⇒ (s, vs, crash_error)
      ― ‹RELOPS›
      | $(Relop_i T_i32 iop) ⇒
          (case ves of
             (ConstInt32 c2)#(ConstInt32 c1)#ves' ⇒
               (s, vs, RSNormal (vs_to_es ((ConstInt32 (wasm_bool (app_relop_i iop c1 c2)))#ves')))
           | _ ⇒ (s, vs, crash_error))
      | $(Relop_i T_i64 iop) ⇒
          (case ves of
             (ConstInt64 c2)#(ConstInt64 c1)#ves' ⇒
               (s, vs, RSNormal (vs_to_es ((ConstInt32 (wasm_bool (app_relop_i iop c1 c2)))#ves')))
           | _ ⇒ (s, vs, crash_error))
      | $(Relop_i _ iop) ⇒ (s, vs, crash_error)
      | $(Relop_f T_f32 fop) ⇒
          (case ves of
             (ConstFloat32 c2)#(ConstFloat32 c1)#ves' ⇒
               (s, vs, RSNormal (vs_to_es ((ConstInt32 (wasm_bool (app_relop_f fop c1 c2)))#ves')))
           | _ ⇒ (s, vs, crash_error))
      | $(Relop_f T_f64 fop) ⇒
          (case ves of
             (ConstFloat64 c2)#(ConstFloat64 c1)#ves' ⇒
               (s, vs, RSNormal (vs_to_es ((ConstInt32 (wasm_bool (app_relop_f fop c1 c2)))#ves')))
           | _ ⇒ (s, vs, crash_error))
      | $(Relop_f _ fop) ⇒ (s, vs, crash_error)
      ― ‹CONVERT›
      | $(Cvtop t2 Convert t1 sx) ⇒
          (case ves of
             v#ves' ⇒
               (if (types_agree t1 v)
                  then
                    expect (cvt t2 sx v) (λv'. (s, vs, RSNormal (vs_to_es (v'#ves')))) (s, vs, RSNormal ((vs_to_es ves')@[Trap]))
                  else
                    (s, vs, crash_error))
           | _ ⇒ (s, vs, crash_error))
      | $(Cvtop t2 Reinterpret t1 sx) ⇒
          (case ves of
             v#ves' ⇒
               (if (types_agree t1 v ∧ sx = None)
                  then
                    (s, vs, RSNormal (vs_to_es ((wasm_deserialise (bits v) t2)#ves')))
                  else
                    (s, vs, crash_error))
           | _ ⇒ (s, vs, crash_error))
      ― ‹UNREACHABLE›
      | $Unreachable ⇒
          (s, vs, RSNormal ((vs_to_es ves)@[Trap]))
      ― ‹NOP›
      | $Nop ⇒
          (s, vs, RSNormal (vs_to_es ves))
      ― ‹DROP›
      | $Drop ⇒
          (case ves of
             v#ves' ⇒
               (s, vs, RSNormal (vs_to_es ves'))
           | _ ⇒ (s, vs, crash_error))
      ― ‹SELECT›
      | $Select ⇒
          (case ves of
             (ConstInt32 c)#v2#v1#ves' ⇒
               (if int_eq c 0 then (s, vs, RSNormal (vs_to_es (v2#ves'))) else (s, vs, RSNormal (vs_to_es (v1#ves'))))
           | _ ⇒ (s, vs, crash_error))
      ― ‹BLOCK›
      | $(Block (t1s _> t2s) es) ⇒
          (if length ves ≥ length t1s
             then
               let (ves', ves'') = split_n ves (length t1s) in
               (s, vs, RSNormal ((vs_to_es ves'') @ [Label (length t2s) [] ((vs_to_es ves')@($* es))]))
             else
               (s, vs, crash_error))
      ― ‹LOOP›
      | $(Loop (t1s _> t2s) es) ⇒
          (if length ves ≥ length t1s
             then
               let (ves', ves'') = split_n ves (length t1s) in
               (s, vs, RSNormal ((vs_to_es ves'') @ [Label (length t1s) [$(Loop (t1s _> t2s) es)] ((vs_to_es ves')@($* es))]))
             else
               (s, vs, crash_error))
      ― ‹IF›
      | $(If tf es1 es2) ⇒
          (case ves of
             (ConstInt32 c)#ves' ⇒
                if int_eq c 0
                  then
                    (s, vs, RSNormal ((vs_to_es ves')@[$(Block tf es2)]))
                  else
                    (s, vs, RSNormal ((vs_to_es ves')@[$(Block tf es1)]))
           | _ ⇒ (s, vs, crash_error))
      ― ‹BR›
      | $Br j ⇒
          (s, vs, RSBreak j ves)
      ― ‹BR_IF›
      | $Br_if j ⇒
          (case ves of
             (ConstInt32 c)#ves' ⇒
                if int_eq c 0
                  then
                    (s, vs, RSNormal (vs_to_es ves'))
                  else
                    (s, vs, RSNormal ((vs_to_es ves') @ [$Br j]))
           | _ ⇒ (s, vs, crash_error))
      ― ‹BR_TABLE›
      | $Br_table js j ⇒
          (case ves of
             (ConstInt32 c)#ves' ⇒
             let k = nat_of_int c in
                if k < length js
                  then
                    (s, vs, RSNormal ((vs_to_es ves') @ [$Br (js!k)]))
                  else
                    (s, vs, RSNormal ((vs_to_es ves') @ [$Br j]))
           | _ ⇒ (s, vs, crash_error))
      ― ‹CALL›
      | $Call j ⇒
          (s, vs, RSNormal ((vs_to_es ves) @ [Callcl (sfunc s i j)]))
      ― ‹CALL_INDIRECT›
      | $Call_indirect j ⇒
          (case ves of
             (ConstInt32 c)#ves' ⇒
               (case (stab s i (nat_of_int c)) of
                  Some cl ⇒
                    if (stypes s i j = cl_type cl)
                      then
                        (s, vs, RSNormal ((vs_to_es ves') @ [Callcl cl]))
                      else
                        (s, vs, RSNormal ((vs_to_es ves')@[Trap]))
                | _ ⇒ (s, vs, RSNormal ((vs_to_es ves')@[Trap])))
           | _ ⇒ (s, vs, crash_error))
      ― ‹RETURN›
      | $Return ⇒
          (s, vs, RSReturn ves)
      ― ‹GET_LOCAL›
      | $Get_local j ⇒
          (if j < length vs
             then (s, vs, RSNormal (vs_to_es ((vs!j)#ves)))
             else (s, vs, crash_error))
      ― ‹SET_LOCAL›
      | $Set_local j ⇒
          (case ves of
             v#ves' ⇒
               if j < length vs
                 then (s, vs[j := v], RSNormal (vs_to_es ves'))
                 else (s, vs, crash_error)
           | _ ⇒ (s, vs, crash_error))
      ― ‹TEE_LOCAL›
      | $Tee_local j ⇒
          (case ves of
             v#ves' ⇒
               (s, vs, RSNormal ((vs_to_es (v#ves)) @ [$(Set_local j)]))
           | _ ⇒ (s, vs, crash_error))
      ― ‹GET_GLOBAL›
      | $Get_global j ⇒
          (s, vs, RSNormal (vs_to_es ((sglob_val s i j)#ves)))
      ― ‹SET_GLOBAL›
      | $Set_global j ⇒
          (case ves of
             v#ves' ⇒ ((supdate_glob s i j v), vs, RSNormal (vs_to_es ves'))
           | _ ⇒ (s, vs, crash_error))
      ― ‹LOAD›
      | $(Load t None a off) ⇒
          (case ves of
             (ConstInt32 k)#ves' ⇒
               expect (smem_ind s i)
                  (λj.
                    expect (load ((mem s)!j) (nat_of_int k) off (t_length t))
                      (λbs. (s, vs, RSNormal (vs_to_es ((wasm_deserialise bs t)#ves'))))
                      (s, vs, RSNormal ((vs_to_es ves')@[Trap])))
                  (s, vs, crash_error)
           | _ ⇒ (s, vs, crash_error))
      ― ‹LOAD PACKED›
      | $(Load t (Some (tp, sx)) a off) ⇒
          (case ves of
             (ConstInt32 k)#ves' ⇒
               expect (smem_ind s i)
                  (λj.
                    expect (load_packed sx ((mem s)!j) (nat_of_int k) off (tp_length tp) (t_length t))
                      (λbs. (s, vs, RSNormal (vs_to_es ((wasm_deserialise bs t)#ves'))))
                      (s, vs, RSNormal ((vs_to_es ves')@[Trap])))
                  (s, vs, crash_error)
           | _ ⇒ (s, vs, crash_error))
      ― ‹STORE›
      | $(Store t None a off) ⇒
          (case ves of
             v#(ConstInt32 k)#ves' ⇒
               (if (types_agree t v)
                 then
                   expect (smem_ind s i)
                      (λj.
                         expect (store ((mem s)!j) (nat_of_int k) off (bits v) (t_length t))
                           (λmem'. (s⦇mem:= ((mem s)[j := mem'])⦈, vs, RSNormal (vs_to_es ves')))
                           (s, vs, RSNormal ((vs_to_es ves')@[Trap])))
                      (s, vs, crash_error)
                 else
                   (s, vs, crash_error))
           | _ ⇒ (s, vs, crash_error))
      ― ‹STORE_PACKED›
      | $(Store t (Some tp) a off) ⇒
          (case ves of
                  v#(ConstInt32 k)#ves' ⇒
                    (if (types_agree t v)
                      then
                        expect (smem_ind s i)
                           (λj.
                              expect (store_packed ((mem s)!j) (nat_of_int k) off (bits v) (tp_length tp))
                                (λmem'. (s⦇mem:= ((mem s)[j := mem'])⦈, vs, RSNormal (vs_to_es ves')))
                                (s, vs, RSNormal ((vs_to_es ves')@[Trap])))
                           (s, vs, crash_error)
                      else
                        (s, vs, crash_error))
                | _ ⇒ (s, vs, crash_error))
      ― ‹CURRENT_MEMORY›
      | $Current_memory ⇒
          expect (smem_ind s i)
            (λj. (s, vs, RSNormal (vs_to_es ((ConstInt32 (int_of_nat (mem_size ((s.mem s)!j))))#ves))))
            (s, vs, crash_error)
      ― ‹GROW_MEMORY›
      | $Grow_memory ⇒
          (case ves of
             (ConstInt32 c)#ves' ⇒
                expect (smem_ind s i)
                  (λj.
                     let l = (mem_size ((s.mem s)!j)) in
                     (expect (mem_grow_impl ((mem s)!j) (nat_of_int c))
                        (λmem'. (s⦇mem:= ((mem s)[j := mem'])⦈, vs, RSNormal (vs_to_es ((ConstInt32 (int_of_nat l))#ves'))))
                        (s, vs, RSNormal (vs_to_es ((ConstInt32 int32_minus_one)#ves')))))
                  (s, vs, crash_error)
           | _ ⇒ (s, vs, crash_error))
      ― ‹VAL - should not be executed›
      | $C v ⇒ (s, vs, crash_error)
    ― ‹E›
      ― ‹CALLCL›
      | Callcl cl ⇒
          (case cl of
             Func_native i' (t1s _> t2s) ts es ⇒
               let n = length t1s in
               let m = length t2s in
               if length ves ≥ n
                 then
                   let (ves', ves'') = split_n ves n in
                   let zs = n_zeros ts in
                     (s, vs, RSNormal ((vs_to_es ves'') @ ([Local m i' ((rev ves')@zs) [$(Block ([] _> t2s) es)]])))
                 else
                   (s, vs, crash_error)
           | Func_host (t1s _> t2s) f ⇒
               let n = length t1s in
               let m = length t2s in
               if length ves ≥ n
                 then
                   let (ves', ves'') = split_n ves n in
                   case host_apply_impl s (t1s _> t2s) f (rev ves') of
                     Some (s',rves) ⇒
                       if list_all2 types_agree t2s rves
                         then
                           (s', vs, RSNormal ((vs_to_es ves'') @ ($$* rves)))
                         else
                           (s', vs, crash_error)
                   | None ⇒ (s, vs, RSNormal ((vs_to_es ves'')@[Trap]))
                 else
                   (s, vs, crash_error))
      ― ‹LABEL›
      | Label ln les es ⇒
          if es_is_trap es
            then
              (s, vs, RSNormal ((vs_to_es ves)@[Trap]))
             else
               (if (const_list es)
                  then
                    (s, vs, RSNormal ((vs_to_es ves)@es))
                  else
                    let (s', vs', res) = run_step d i (s, vs, es) in
                    (case res of
                       RSBreak 0 bvs ⇒
                         if (length bvs ≥ ln)
                           then (s', vs', RSNormal ((vs_to_es ((take ln bvs)@ves))@les))
                           else (s', vs', crash_error)
                     | RSBreak (Suc n) bvs ⇒
                         (s', vs', RSBreak n bvs)
                     | RSReturn rvs ⇒
                         (s', vs', RSReturn rvs)
                     | RSNormal es' ⇒
                         (s', vs', RSNormal ((vs_to_es ves)@[Label ln les es']))
                     ― ‹forward inner crashes (including exhaustion) unchanged›
                     | RSCrash c ⇒ (s', vs', RSCrash c)))
     ― ‹LOCAL›
     | Local ln j vls es ⇒
          if es_is_trap es
            then
              (s, vs, RSNormal ((vs_to_es ves)@[Trap]))
             else
               (if (const_list es)
                  then
                    if (length es = ln)
                      then (s, vs, RSNormal ((vs_to_es ves)@es))
                      else (s, vs, crash_error)
                  else
                    case d of
                      0 ⇒ (s, vs, RSCrash CExhaustion)
                    | Suc d' ⇒
                        let (s', vls', res) = run_step d' j (s, vls, es) in
                        (case res of
                           RSReturn rvs ⇒
                             if (length rvs ≥ ln)
                               then (s', vs, RSNormal (vs_to_es ((take ln rvs)@ves)))
                               else (s', vs, crash_error)
                         | RSNormal es' ⇒
                             (s', vs, RSNormal ((vs_to_es ves)@[Local ln j vls' es']))
                         ― ‹forward inner crashes (including exhaustion) unchanged›
                         | RSCrash c ⇒ (s', vs, RSCrash c)
                         ― ‹a br escaping the frame is an error, not exhaustion›
                         | RSBreak _ _ ⇒ (s', vs, crash_error)))
     ― ‹TRAP - should not be executed›
     | Trap ⇒ (s, vs, crash_error))"
  by pat_completeness auto
(* Termination of run_step/run_one_step.
   run_step calls (Inl) are measured by 2 * size_list size es + 1 and
   run_one_step calls (Inr) by 2 * size e. The local fact shows that the
   instruction run_step hands to run_one_step (the first element after the
   constant prefix computed by split_vals_e) is strictly smaller than the
   run_step measure. The calls from Label/Local back to run_step on their
   body are discharged by auto. *)
termination
proof -
  {
    fix xs::"e list" and as b bs
    assume local_assms:"(as, b#bs) = split_vals_e xs"
    have "2*(size b) < 2*(size_list size xs) + 1"
      using local_assms[symmetric] split_vals_e_conv_app
            size_list_estimation'[of b xs "size b" size]
      unfolding size_list_def
      by fastforce
  }
  thus ?thesis
    by (relation "measure (case_sum
                               (λp. 2 * (size_list size (snd (snd (snd (snd p))))) + 1)
                               (λp. 2 * size (snd (snd (snd (snd (snd p)))))))") auto
qed

(* Run the interpreter for at most the given number of steps (fuel).
   A lone Trap yields RTrap, and an all-constant sequence yields its values.
   Otherwise run_step is applied and the loop continues on RSNormal.
   A crash reported by run_step is passed on unchanged. Any other non-normal
   step result (a br or return reaching the top level) becomes
   RCrash CError, and running out of fuel yields RCrash CExhaustion. *)
fun run_v :: "fuel ⇒ depth ⇒ nat ⇒ config_tuple ⇒ (s × res)" where
  "run_v (Suc n) d i (s,vs,es) = (if (es_is_trap es)
                                    then (s, RTrap)
                                    else if (const_list es)
                                           then (s, RValue (fst (split_vals_e es)))
                                           else (let (s',vs',res) = (run_step d i (s,vs,es)) in
                                                 case res of
                                                   RSNormal es' ⇒ run_v n d i (s',vs',es')
                                                 | RSCrash error ⇒ (s, RCrash error)
                                                 | _ ⇒ (s, RCrash CError)))"
| "run_v 0 d i (s,vs,es) = (s, RCrash CExhaustion)"

end