Theory DiffArray_ST

(* Author: Peter Lammich *)

section ‹Single Threaded Arrays›
theory DiffArray_ST
imports DiffArray_Base
begin



subsection ‹Primitive Operations›

(* ST-arrays: a copy of 'a array (the carrier set is all of UNIV), introduced as a
   separate type so that the code generator can map it to its own, destructively
   updated target implementation (see the code_printing setup in this theory). *)
typedef 'a starray = "UNIV :: 'a array set" 
  morphisms Rep_starray STArray
  by blast
setup_lifting type_definition_starray

(* Every primitive on 'a starray is lifted verbatim from the 'a array operation of the
   corresponding name; only the types change. *)
lift_definition starray_new :: "nat \<Rightarrow> 'a \<Rightarrow> 'a starray" is "array_new" .

lift_definition starray_tabulate :: "nat \<Rightarrow> (nat \<Rightarrow> 'a) \<Rightarrow> 'a starray" is "array_tabulate" .

lift_definition starray_length :: "'a starray \<Rightarrow> nat" is array_length .

lift_definition starray_get :: "'a starray \<Rightarrow> nat \<Rightarrow> 'a" is array_get .

lift_definition starray_set :: "'a starray \<Rightarrow> nat \<Rightarrow> 'a \<Rightarrow> 'a starray" is array_set .

lift_definition starray_of_list :: "'a list \<Rightarrow> 'a starray" is array_of_list .

lift_definition starray_grow :: "'a starray \<Rightarrow> nat \<Rightarrow> 'a \<Rightarrow> 'a starray" is "array_grow" .

lift_definition starray_take :: "'a starray \<Rightarrow> nat \<Rightarrow> 'a starray" is array_take .

(* Out-of-bounds-tolerant variants: get_oo takes a default value, set_oo a fallback thunk. *)
lift_definition starray_get_oo :: "'a \<Rightarrow> 'a starray \<Rightarrow> nat \<Rightarrow> 'a" is array_get_oo .

lift_definition starray_set_oo :: "(unit \<Rightarrow> 'a starray) \<Rightarrow> 'a starray \<Rightarrow> nat \<Rightarrow> 'a \<Rightarrow> 'a starray" is array_set_oo .

lift_definition starray_map :: "('a \<Rightarrow> 'b) \<Rightarrow> 'a starray \<Rightarrow> 'b starray" is array_map .

lift_definition starray_fold :: "('a \<Rightarrow> 'b \<Rightarrow> 'b) \<Rightarrow> 'a starray \<Rightarrow> 'b \<Rightarrow> 'b" is array_fold .

lift_definition starray_foldr :: "('a \<Rightarrow> 'b \<Rightarrow> 'b) \<Rightarrow> 'a starray \<Rightarrow> 'b \<Rightarrow> 'b" is array_foldr .


(* Abstraction function: the list denoted by an ST-array, taken from the underlying array. *)
definition starray_α where "starray_α = array_α o Rep_starray"
 
subsubsection ‹Refinement Lemmas›

(* Refinement lemmas: each primitive is characterised by its list-level counterpart
   via starray_α. They are declared [simp, array_refine] (injectivity/ext excepted). *)
context
  notes [simp] = STArray_inverse array_eq_iff starray_α_def
begin

  (* starray_α is injective, so ST-arrays are equal iff their abstractions are. *)
  lemma starray_α_inj: "starray_α a = starray_α b \<Longrightarrow> a=b" unfolding starray_α_def by transfer auto

  lemma starray_eq_iff: "a=b \<longleftrightarrow> starray_α a = starray_α b" unfolding starray_α_def by transfer auto
  
  lemma starray_new_refine[simp,array_refine]: "starray_α (starray_new n a) = replicate n a" unfolding starray_α_def by transfer auto

  lemma starray_tabulate_refine[simp,array_refine]: "starray_α (starray_tabulate n f) = tabulate n f" unfolding starray_α_def by transfer auto
  
  lemma starray_length_refine[simp,array_refine]: "starray_length a = length (starray_α a)" unfolding starray_α_def by transfer auto
  
  lemma starray_get_refine[simp,array_refine]: "starray_get a i = starray_α a ! i" unfolding starray_α_def by transfer auto

  lemma starray_set_refine[simp,array_refine]: "starray_α (starray_set a i x) = (starray_α a)[i := x]" unfolding starray_α_def by transfer auto

  lemma starray_of_list_refine[simp,array_refine]: "starray_α (starray_of_list xs) = xs" unfolding starray_α_def by transfer auto

  lemma starray_grow_refine[simp,array_refine]: 
    "starray_α (starray_grow a n d) = take n (starray_α a) @ replicate (n-length (starray_α a)) d"
    unfolding starray_α_def by transfer auto

  lemma starray_take_refine[simp,array_refine]: "starray_α (starray_take a n) = take n (starray_α a)"
    unfolding starray_α_def by transfer auto
  
  (* Out-of-bounds reads return the supplied default. *)
  lemma starray_get_oo_refine[simp,array_refine]: "starray_get_oo x a i = (if i<length (starray_α a) then starray_α a!i else x)" unfolding starray_α_def by transfer auto

  (* Out-of-bounds writes evaluate the fallback thunk instead. *)
  lemma starray_set_oo_refine[simp,array_refine]: "starray_α (starray_set_oo f a i x) 
    = (if i<length (starray_α a) then (starray_α a)[i:=x] else starray_α (f ()))" 
    unfolding starray_α_def by transfer auto

  lemma starray_map_refine[simp,array_refine]: "starray_α (starray_map f a) = map f (starray_α a)"
    unfolding starray_α_def by transfer auto

  lemma starray_fold_refine[simp, array_refine]: "starray_fold f a s = fold f (starray_α a) s"  
    unfolding starray_α_def by transfer auto
    
  lemma starray_foldr_refine[simp, array_refine]: "starray_foldr f a s = foldr f (starray_α a) s"  
    unfolding starray_α_def by transfer auto

end  

(* Save the lifting/transfer setup for starray under the bundle starray.lifting, then
   remove it from the context. Presumably this keeps later theories from picking up
   the transfer rules by default; they can re-include the bundle if needed. *)
lifting_update starray.lifting
lifting_forget starray.lifting

subsection ‹Code Generator Setup›

subsubsection ‹Code-Numeral Preparation›


(* Primed variants of the primitives that use target integers (integer) instead of nat
   for sizes and indices. Their definitions are [code del]. Code for them comes from the
   target-specific code_printing setup or from the fallback code equations in this theory. *)
definition [code del]: "starray_new' == starray_new o nat_of_integer"
definition [code del]: "starray_tabulate' n f == starray_tabulate (nat_of_integer n) (f o integer_of_nat)"

definition [code del]: "starray_length' == integer_of_nat o starray_length"
definition [code del]: "starray_get' a == starray_get a o nat_of_integer"
definition [code del]: "starray_set' a == starray_set a o nat_of_integer"
definition [code del]:
  "starray_get_oo' x a == starray_get_oo x a o nat_of_integer"
definition [code del]:
  "starray_set_oo' f a == starray_set_oo f a o nat_of_integer"


(* Code equations routing the nat-based primitives through their integer-based primed
   variants, converting sizes/indices with integer_of_nat and results with nat_of_integer. *)
lemma [code]:
  "starray_new == starray_new' o integer_of_nat"
  "starray_tabulate n f == starray_tabulate' (integer_of_nat n) (f o nat_of_integer)"
  "starray_length == nat_of_integer o starray_length'"
  "starray_get a == starray_get' a o integer_of_nat"
  "starray_set a == starray_set' a o integer_of_nat"
  "starray_get_oo x a == starray_get_oo' x a o integer_of_nat"
  "starray_set_oo g a == starray_set_oo' g a o integer_of_nat"
  (* array_refine is removed from the simp set here, presumably so simp does not rewrite
     the operations to lists before the primed definitions are unfolded. *)
  by (simp_all
    del: array_refine
    add: o_def
    add: starray_new'_def starray_tabulate'_def starray_length'_def starray_get'_def starray_set'_def
      starray_get_oo'_def starray_set_oo'_def)

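text ‹Fallback code equations for the integer-based primitives, used when a target provides no native implementation.›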

(* Fallback code equations for the out-of-bounds variants, obtained by unfolding the
   definitions of starray_get_oo / starray_set_oo. Targets with a native mapping
   (SML, via code_printing) override them. *)
lemmas starray_get_oo'_fallback[code] = starray_get_oo'_def[unfolded starray_get_oo_def[abs_def]]
lemmas starray_set_oo'_fallback[code] = starray_set_oo'_def[unfolded starray_set_oo_def[abs_def]]

(* Default code for starray_tabulate': compute the list of values f 0, ..., f (n-1)
   and convert it with starray_of_list. *)
lemma starray_tabulate'_fallback[code]: 
  "starray_tabulate' n f = starray_of_list (map (f o integer_of_nat) [0..<nat_of_integer n])"
  unfolding starray_tabulate'_def 
  by (simp add: starray_eq_iff tabulate_def)

(* Default code for starray_new': build a replicated list and convert it with starray_of_list. *)
lemma starray_new'_fallback[code]: "starray_new' n x = starray_of_list (replicate (nat_of_integer n) x)"  
  by (simp add: starray_new'_def starray_eq_iff)
  

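(*
  Primitive operations to be implemented natively for each code target.
  Operations marked "dflt" also have a default (fallback) code equation:

    starray_of_list
    starray_tabulate' (dflt via of_list)
    starray_new' (dflt via of_list)

    starray_length'
    starray_get'
    starray_set'
    starray_get_oo' (dflt via array_get)
    starray_set_oo' (dflt, see fallback equation above)

*)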




code_printing code_module "STArray" \<rightharpoonup>
  (SML)
‹
structure STArray = struct

(* A single-threaded array is a reference to a cell. Updating destroys the old
   version: its cell is overwritten with Invalid, and any later access through the
   old reference raises AccessedOldVersion. *)
datatype 'a Cell = Invalid | Value of 'a array;

exception AccessedOldVersion;

type 'a starray = 'a Cell Unsynchronized.ref;

fun fromList l = Unsynchronized.ref (Value (Array.fromList l));
fun starray (size, v) = Unsynchronized.ref (Value (Array.array (size,v)));
fun tabulate (size, f) = Unsynchronized.ref (Value (Array.tabulate(size, f)));
fun sub (Unsynchronized.ref Invalid, idx) = raise AccessedOldVersion |
    sub (Unsynchronized.ref (Value a), idx) = Array.sub (a,idx);

(* Destructive update. The write is performed BEFORE the old reference is
   invalidated. If Array.update raises Subscript (index out of bounds), the old
   version must stay valid: starray_set_oo catches Subscript and runs its fallback
   thunk, which may still refer to the old array. *)
fun update (aref,idx,v) =
  case aref of
    (Unsynchronized.ref Invalid) => raise AccessedOldVersion |
    (Unsynchronized.ref (Value a)) => (
      Array.update (a,idx,v);
      aref := Invalid;
      Unsynchronized.ref (Value a)
    );

fun length (Unsynchronized.ref Invalid) = raise AccessedOldVersion |
    length (Unsynchronized.ref (Value a)) = Array.length a;


structure IsabelleMapping = struct
type 'a ArrayType = 'a starray;

fun starray_new (n:IntInf.int) (a:'a) = starray (IntInf.toInt n, a);
fun starray_of_list (xs:'a list) = fromList xs;

fun starray_tabulate (n:IntInf.int) (f:IntInf.int -> 'a) = tabulate (IntInf.toInt n, f o IntInf.fromInt)

fun starray_length (a:'a ArrayType) = IntInf.fromInt (length a);

fun starray_get (a:'a ArrayType) (i:IntInf.int) = sub (a, IntInf.toInt i);

fun starray_set (a:'a ArrayType) (i:IntInf.int) (e:'a) = update (a, IntInf.toInt i, e);

(* Out-of-bounds-tolerant variants. An index that does not fit into a machine int
   (IntInf.toInt raises Overflow) is necessarily out of bounds. It is therefore
   treated exactly like a Subscript failure. *)
fun starray_get_oo (d:'a) (a:'a ArrayType) (i:IntInf.int) =
  sub (a,IntInf.toInt i) handle Subscript => d | Overflow => d

fun starray_set_oo (d:(unit->'a ArrayType)) (a:'a ArrayType) (i:IntInf.int) (e:'a) =
  update (a, IntInf.toInt i, e) handle Subscript => d () | Overflow => d ()


end;

end;
›


(* Map the ST-array type and its integer-based primitives to the SML module STArray. *)
code_printing
  type_constructor starray \<rightharpoonup> (SML) "_/ STArray.IsabelleMapping.ArrayType"
| constant STArray \<rightharpoonup> (SML) "STArray.IsabelleMapping.starray'_of'_list"
| constant starray_new' \<rightharpoonup> (SML) "STArray.IsabelleMapping.starray'_new"
| constant starray_tabulate' \<rightharpoonup> (SML) "STArray.IsabelleMapping.starray'_tabulate"
| constant starray_length' \<rightharpoonup> (SML) "STArray.IsabelleMapping.starray'_length"
| constant starray_get' \<rightharpoonup> (SML) "STArray.IsabelleMapping.starray'_get"
| constant starray_set' \<rightharpoonup> (SML) "STArray.IsabelleMapping.starray'_set"
| constant starray_of_list \<rightharpoonup> (SML) "STArray.IsabelleMapping.starray'_of'_list"
| constant starray_get_oo' \<rightharpoonup> (SML) "STArray.IsabelleMapping.starray'_get'_oo"
| constant starray_set_oo' \<rightharpoonup> (SML) "STArray.IsabelleMapping.starray'_set'_oo"


subsection ‹Tests› 
(* TODO: Add more systematic tests! *)

(* Smoke test: arrays built by of_list, tabulate and new, with a single update at index 3. *)
definition "test1 ==
  let a=starray_of_list [1,2,3,4,5,6];
      b=starray_tabulate 6 (Suc);
      a'=starray_set a 3 42;
      b'=starray_set b 3 42;
      c=starray_new 6 0
  in
    \<forall>i\<in>{0..<6}. 
      starray_get a' i = (if i=3 then 42 else i+1)  
    \<and> starray_get b' i = (if i=3 then 42 else i+1)  
    \<and> starray_get c i = (0::nat)
          "

(* Case split for an index in a half-open nat range: either i is the lower bound l, or
   i lies in {Suc l..<h}. Repeated application enumerates a concrete range. *)
lemma enum_rangeE:
  assumes "i\<in>{l..<h}"
  assumes "P l"
  assumes "i\<in>{Suc l..<h} \<Longrightarrow> P i"
  shows "P i"
  using assms
  by (metis atLeastLessThan_iff less_eq_Suc_le nat_less_le)
          
          
(* test1 holds in the logic. Each of the three conjuncts is proved by peeling indices
   off {0..<6} with enum_rangeE and closing each index case by simp. The final apply simp
   presumably discharges the remaining goal, whose index range is empty. *)
lemma "test1"
  unfolding test1_def Let_def
  apply (intro ballI conjI)
  apply (erule enum_rangeE, (simp; fail))+ apply simp
  apply (erule enum_rangeE, (simp; fail))+ apply simp
  apply (erule enum_rangeE, (simp; fail))+ apply simp
  done  
  
(* Evaluate the generated SML code for test1 (which uses the STArray module) and fail if it is false. *)
ML_val ‹if not @{code test1} then error "ERROR" else ()›

(* Check that code for test1 can be exported and compiled. NOTE(review): the "?" suffix
   presumably makes the OCaml/Haskell checks optional, i.e. run only if the compiler is available. *)
export_code test1 checking OCaml? Haskell? SML


(* Remove the test constant and its definition from the name space of importing theories. *)
hide_const test1
hide_fact test1_def


(* Local experiment (nothing escapes the block): allTrue sets the first n list entries to
   True. allTrue' does the same on a bool array, and is shown to refine allTrue via array_α. *)
experiment
begin

fun allTrue :: "bool list \<Rightarrow> nat \<Rightarrow> bool list" where
"allTrue a 0 = a" |
"allTrue a (Suc i) = (allTrue a i)[i := True]"

lemma length_allTrue: "n \<le> length a \<Longrightarrow> length(allTrue a n) = length a"
by(induction n) (auto)

lemma "n \<le> length a \<Longrightarrow> \<forall>i < n. (allTrue a n) ! i"
by(induction n) (auto simp: nth_list_update length_allTrue)


fun allTrue' :: "bool array \<Rightarrow> nat \<Rightarrow> bool array" where
"allTrue' a 0 = a" |
"allTrue' a (Suc i) = array_set (allTrue' a i) i True"


lemma "array_α (allTrue' xs i) = allTrue (array_α xs) i"
  apply (induction xs i rule: allTrue'.induct)
  apply auto
  done


end



end