section ‹ Imperative Version ›

theory Imperative
imports UpDown_Scheme Separation_Logic_Imperative_HOL.Sep_Main
begin

(* A pointmap assigns a natural number to every grid point. Inside the locale
   linearization below it is required to be a bijection from the sparse grid
   onto the array indices {..< card (sparsegrid dm lm)}. *)
type_synonym pointmap = "grid_point ⇒ nat"
(* Imperative grid representation: one heap array of rationals. The stored
   values are related to the real-valued functional grid through of_rat
   (see gridI inside the locale). *)
type_synonym impgrid = "rat array"

(* Arrays in Imperative_HOL require element types of class heap; this makes
   rat arrays (and hence impgrid) usable. *)
instance rat :: heap ..

(* Componentwise of_rat on a pair. Used to compare the pair of rationals
   returned by upI' with the pair returned by the functional up'. *)
primrec rat_pair where "rat_pair (a, b) = (of_rat a, of_rat b)"

(* Keep rat_pair folded by default; proofs unfold it explicitly where needed. *)
declare rat_pair.simps [simp del]

(* zipWithA f a b updates the array a in place: for every index n below the
   length of a it reads x = a[n] and y = b[n] and stores f x y at a[n].
   The array b is only read. The result is the array a itself.
   Note: the lambda-bound n inside fold_map shadows the length n; the loop
   runs over the indices [0..<n]. The correctness theorem below assumes a
   and b have equal length. *)
definition
   zipWithA :: "('a::heap ⇒ 'b::heap ⇒ 'a::heap) ⇒ 'a array ⇒ 'b array ⇒ 'a array Heap"
where
  "zipWithA f a b = do {
     n ← Array.len a;
     Heap_Monad.fold_map (λn. do {
       x ← Array.nth a n ;
       y ← Array.nth b n ;
       Array.upd n (f x y) a
     }) [0..<n];
     return a
   }"

(* Correctness of zipWithA for equal-length contents xs and ys:
   - a ends up holding map (case_prod f) (zip xs ys);
   - b still holds ys;
   - the returned array is a itself. *)
theorem zipWithA [sep_heap_rules]:
  fixes xs ys :: "'a::heap list"
  assumes "length xs = length ys"
  shows "< a ↦a xs * b ↦a ys > zipWithA f a b < λr. (a ↦a map (case_prod f) (zip xs ys)) * b ↦a ys * ↑(a = r) >"
proof -
  (* Loop invariant, by induction on n: after running the loop body on the
     indices [0..<n], the first n entries of a are combined with ys and the
     remaining entries are still those of xs. *)
  { fix n and xs :: "'a list"
    let ?part_res = "λn xs. (map (case_prod f) (zip (take n xs) (take n ys)) @ drop n xs)"
    assume "n ≤ length xs" "length xs = length ys"
    then have "< a ↦a xs * b ↦a ys > Heap_Monad.fold_map (λn. do {
         x ← Array.nth a n ;
         y ← Array.nth b n ;
         Array.upd n (f x y) a
       }) [0..<n] < λr. a ↦a ?part_res n xs * b ↦a ys >"
    proof (induct n arbitrary: xs)
      case 0 then show ?case by sep_auto
    next
      case (Suc n)
      note Suc.hyps [sep_heap_rules]
      (* One loop step: updating position n of the partial result for n
         yields the partial result for Suc n. *)
      have *: "(?part_res n xs)[n := f (?part_res n xs ! n) (ys ! n)] =  ?part_res (Suc n) xs"
        using Suc.prems by (simp add: nth_append take_Suc_conv_app_nth upd_conv_take_nth_drop)
      from Suc.prems show ?case
        by (sep_auto simp add: fold_map_append *)
    qed }
  (* Register the invariant as a Hoare rule so sep_auto can discharge the
     fold_map loop after unfolding zipWithA_def. *)
  note this[sep_heap_rules]
  show ?thesis
    unfolding zipWithA_def
    by (sep_auto simp add: assms)
qed

(* Allocate a fresh array with the same contents as a: freeze a to a list,
   then build a new array from that list. *)
definition copy_array :: "'a::heap array ⇒ ('a::heap array) Heap" where
  "copy_array a = Array.freeze a ⤜ Array.of_list"

(* The returned array r holds the same list xs as a, and a is unchanged. *)
theorem copy_array [sep_heap_rules]:
  "< a ↦a xs > copy_array a < λr. a ↦a xs * r ↦a xs >"
  unfolding copy_array_def
  by sep_auto

(* In-place pointwise addition of rational arrays: a[n] := a[n] + b[n].
   The array returned by zipWithA is discarded. *)
definition sum_array :: "rat array ⇒ rat array ⇒ unit Heap" where
  "sum_array a b = zipWithA (+) a b ⪢ return ()"

(* For equal-length contents, a ends up holding the pointwise sums and b is
   unchanged. *)
theorem sum_array [sep_heap_rules]:
  fixes xs ys :: "rat list"
  shows "length xs = length ys ⟹ < a ↦a xs * b ↦a ys > sum_array a b < λr. (a ↦a map (λ(a, b). a + b) (zip xs ys)) * b ↦a ys >"
  unfolding sum_array_def by sep_auto

(* The imperative algorithms work relative to a fixed linearization pm of the
   sparse grid of dimension dm and level lm: a bijection between the grid
   points and the array indices 0 ..< card (sparsegrid dm lm). *)
locale linearization =
  fixes dm lm :: nat
  fixes pm :: pointmap
  assumes pm: "bij_betw pm (sparsegrid dm lm) {..< card (sparsegrid dm lm)}"
begin

(* Every sparse grid point is mapped to a valid array index. *)
lemma linearizationD:
  "p ∈ sparsegrid dm lm ⟹ pm p < card (sparsegrid dm lm)"
  using pm by (auto simp: bij_betw_def)

(* Abstraction relation: the array a represents the functional grid v if
   - a holds a list xs with exactly one entry per sparse grid point, and
   - for every sparse grid point p, v p is of_rat of the entry at index pm p. *)
definition gridI :: "impgrid ⇒ (grid_point ⇒ real) ⇒ assn" where
  "gridI a v =
    (∃A xs. a ↦a xs * ↑((∀p∈sparsegrid dm lm. v p = of_rat (xs ! pm p)) ∧ length xs = card (sparsegrid dm lm)))"

(* Reading the entry of grid point g returns a rational r with of_rat r = v g;
   the represented grid is unchanged. *)
lemma gridI_nth_rule [sep_heap_rules]:
  "g ∈ sparsegrid dm lm ⟹ < gridI a v > Array.nth a (pm g) <λr. gridI a v * ↑ (of_rat r = v g)>"
  using pm by (sep_auto simp: bij_betw_def gridI_def)

(* Writing x at the entry of grid point g sets the represented grid at g to
   of_rat x and leaves every other point unchanged; the result is a itself. *)
lemma gridI_upd_rule [sep_heap_rules]:
  "g ∈ sparsegrid dm lm ⟹
    < gridI a v > Array.upd (pm g) x a <λa'. gridI a (fun_upd v g (of_rat x)) * ↑(a' = a)>"
  unfolding gridI_def using pm
  by (sep_auto simp: bij_betw_def inj_onD intro!: nth_list_update_eq[symmetric] nth_list_update_neq[symmetric])

(* Imperative counterpart of up' along dimension d, descending l levels
   from p.
   - It first recurses into the left child (giving fl, fml) and the right
     child (giving fmr, fr).
   - It then reads the old value val at p and overwrites it with fml + fmr.
   - With result = (fml + fmr + val / 2 ^ lv p d / 2) / 2, it returns
     (fl + result, fr + result).
   At level count 0 it returns (0, 0) without touching the array. *)
primrec upI' :: "nat ⇒ nat ⇒ grid_point ⇒ impgrid ⇒ (rat * rat) Heap" where
  "upI' d       0 p a = return (0, 0)" |
  "upI' d (Suc l) p a = do {
       (fl, fml) ← upI' d l (child p left d) a ;
       (fmr, fr) ← upI' d l (child p right d) a ;
       val ← Array.nth a (pm p) ;
       Array.upd (pm p) (fml + fmr) a ;
       let result = ((fml + fmr + val / 2 ^ (lv p d) / 2) / 2) ;
       return (fl + result, fr + result)
     }"

(* upI' refines up': the final array represents the grid computed by up',
   and the returned rational pair, mapped through of_rat, equals the pair
   returned by up'. *)
lemma upI' [sep_heap_rules]:
  assumes lin[simp]: "d < dm"
    and l: "level p + l = lm" "l = 0 ∨ p ∈ sparsegrid dm lm"
  shows "< gridI a v > upI' d l p a <λr. let (r', v') = up' d l p v in gridI a v' * ↑(rat_pair r = r') >"
  using l
proof (induct l arbitrary: p v)
  note rat_pair.simps [simp]
  case 0 then show ?case by sep_auto
next
  case (Suc l)
  (* Level bookkeeping for the two recursive calls; p itself is a sparse
     grid point. *)
  from Suc.prems ‹d < dm›
  have [simp]: "level (child p left d) + l = lm" "level (child p right d) + l = lm" "p ∈ sparsegrid dm lm"
    by (auto simp: sparsegrid_length)

  (* A child can only lie outside the sparse grid when no levels remain. *)
  have [simp]: "child p left d ∉ sparsegrid dm lm ⟹ l = 0" "child p right d ∉ sparsegrid dm lm ⟹ l = 0"
    using Suc.prems by (auto simp: sparsegrid_def lgrid_def)

  (* Use the induction hypothesis as a Hoare rule for the recursive calls. *)
  note Suc(1)[sep_heap_rules]
  show ?case
    by (sep_auto split: prod.split simp: of_rat_add of_rat_divide of_rat_power of_rat_mult rat_pair_def Let_def)
qed

(* Imperative counterpart of down' along dimension d, descending l levels
   from p with boundary values fl and fr.
   - It reads the old value val at p and sets fm = (fl + fr) / 2 + val.
   - It overwrites the entry of p with
     ((fl + fr) / 4 + (1 / 3) times val) / 2 ^ lv p d.
   - It then recurses into the left child with boundaries (fl, fm) and into
     the right child with (fm, fr).
   At level count 0 it does nothing. *)
primrec downI' :: "nat ⇒ nat ⇒ grid_point ⇒ impgrid ⇒ rat ⇒ rat ⇒ unit Heap" where
  "downI' d       0 p a fl fr = return ()" |
  "downI' d (Suc l) p a fl fr = do {
      val ← Array.nth a (pm p) ;
      let fm = ((fl + fr) / 2 + val) ;
      Array.upd (pm p) (((fl + fr) / 4 + (1 / 3) * val) / 2 ^ (lv p d)) a ;
      downI' d l (child p left d) a fl fm ;
      downI' d l (child p right d) a fm fr
    }"

(* downI' refines down': the final array represents the grid computed by
   down' with the boundary values mapped through of_rat.
   No rat_pair rules are needed here, since downI' returns unit. *)
lemma downI' [sep_heap_rules]:
  assumes lin[simp]: "d < dm"
    and l: "level p + l = lm" "l = 0 ∨ p ∈ sparsegrid dm lm"
  shows "< gridI a v > downI' d l p a fl fr <λr. gridI a (down' d l p (of_rat fl) (of_rat fr) v) >"
  using l
proof (induct l arbitrary: p v fl fr)
  case 0 then show ?case by sep_auto
next
  case (Suc l)
  (* Level bookkeeping for the two recursive calls; p itself is a sparse
     grid point. *)
  from Suc.prems ‹d < dm›
  have [simp]: "level (child p left d) + l = lm" "level (child p right d) + l = lm" "p ∈ sparsegrid dm lm"
    by (auto simp: sparsegrid_length)

  (* A child can only lie outside the sparse grid when no levels remain. *)
  have [simp]: "child p left d ∉ sparsegrid dm lm ⟹ l = 0" "child p right d ∉ sparsegrid dm lm ⟹ l = 0"
    using Suc.prems by (auto simp: sparsegrid_def lgrid_def)

  (* Use the induction hypothesis as a Hoare rule for the recursive calls. *)
  note Suc(1)[sep_heap_rules]
  show ?case
    by (sep_auto split: prod.split simp: of_rat_add of_rat_divide of_rat_power of_rat_mult Let_def fun_upd_def)
qed

(* Lift a per-point operation f along dimension d to the whole array.
   f d (lm - level p) p a is run for every point p of
   gridgen (start dm) ({0..<dm} - {d}) lm, the grid over all dimensions
   except d. The actions are chained with foldr, so the last point of the
   generated list is processed first. *)
definition liftI :: "(nat ⇒ nat ⇒ grid_point ⇒ impgrid ⇒ unit Heap) ⇒ nat ⇒ impgrid ⇒ unit Heap" where
  "liftI f d a = 
    foldr (λ p n. n ⪢ f d (lm - level p) p a) (gridgen (start dm) ({ 0 ..< dm } - { d }) lm) (return ())"

(* If f refines f' at every point of the grid without dimension d, then
   liftI f refines Grid.lift f'. *)
theorem liftI [sep_heap_rules]:
  assumes "d < dm"
  and f[sep_heap_rules]: "⋀v p. p ∈ lgrid (start dm) ({0..<dm} - {d}) lm ⟹
    < gridI a v > f d (lm - level p) p a <λr. gridI a (f' d (lm - level p) p v) >"
  shows "< gridI a v > liftI f d a <λr. gridI a (Grid.lift f' dm lm d v) >"
proof -
  let ?ds = "{0..<dm} - {d}" and ?g = "gridI a"
  (* Generalize from the full generated list to any distinct sublist ps,
     by induction on ps. *)
  { fix ps assume "set ps ⊆ set (gridgen (start dm) ?ds lm)" and "distinct ps"
    then have "< ?g v >
        foldr (λp n. (n :: unit Heap) ⪢ f d (lm - level p) p a) ps (return ())
      <λr. ?g (foldr (λp α. f' d (lm - level p) p α) ps v) >"
      by (induct ps arbitrary: v) (sep_auto simp: gridgen_lgrid_eq)+ }
  from this[OF subset_refl gridgen_distinct]
  show ?thesis
    by (simp add: liftI_def Grid.lift_def)
qed

(* Full up step along dimension d: upI' at every point of the grid without
   dimension d, discarding the returned pair. *)
definition upI where "upI = liftI (λd l p a. upI' d l p a ⪢ return ())"

(* upI refines the functional up. *)
theorem upI [sep_heap_rules]:
  assumes [simp]: "d < dm"
  shows "< gridI a v > upI d a <λr. gridI a (up dm lm d v) > "
  unfolding up_def upI_def
  by (sep_auto simp: lgrid_def sparsegrid_def split: prod.split
               intro: grid_union_dims[of "{0..<dm} - {d}" "{0..<dm}"])

(* Full down step along dimension d: downI' with zero boundary values at
   every point of the grid without dimension d. *)
definition downI where "downI = liftI (λd l p a. downI' d l p a 0 0)"

(* downI refines the functional down. *)
theorem downI [sep_heap_rules]:
  assumes [simp]: "d < dm"
  shows "< gridI a v > downI d a <λr. gridI a (down dm lm d v) > "
  unfolding down_def downI_def
  by (sep_auto simp: lgrid_def sparsegrid_def split: prod.split
               intro: grid_union_dims[of "{0..<dm} - {d}" "{0..<dm}"])

(* Copying a grid array yields a second array representing the same grid. *)
theorem copy_array_gridI [sep_heap_rules]:
  "< gridI a v > copy_array a < λr. gridI a v * gridI r v >"
  unfolding gridI_def
  by sep_auto

(* Adding array b into array a yields the pointwise sum of the represented
   grids; b is unchanged. *)
theorem sum_array_gridI [sep_heap_rules]:
  "< gridI a v * gridI b w > sum_array a b < λr. gridI a (sum_vector v w) * gridI b w >"
  unfolding gridI_def
  by (sep_auto simp: sum_vector_def nth_map linearizationD of_rat_add)

(* updownI' d a handles dimension d - 1 and, recursively, all lower ones.
   - It copies a into a fresh array b.
   - It runs upI on a and then recurses on a.
   - It recurses on b and then runs downI on b.
   - Finally it adds b into a.
   The copy b is never freed; the triple below therefore allows garbage
   (>t). *)
primrec updownI' :: "nat ⇒ impgrid ⇒ unit Heap" where
  "updownI' 0 a = return ()" |
  "updownI' (Suc d) a = do {
      b ← copy_array a ;
      upI d a ;
      updownI' d a ;
      updownI' d b ;
      downI d b ;
      sum_array a b
    }"

(* updownI' refines updown' for every d up to the grid dimension. *)
theorem updownI' [sep_heap_rules]:
  "d ≤ dm ⟹ < gridI a v > updownI' d a <λr. gridI a (updown' dm lm d v) >t"
proof (induct d arbitrary: a v)
  case (Suc d)
  note Suc.hyps [sep_heap_rules]
  from Suc.prems show ?case
    by sep_auto
qed sep_auto

(* The complete up-down algorithm over all dm dimensions. *)
definition updownI where "updownI a = updownI' dm a"

(* updownI refines the functional updown. *)
theorem updownI [sep_heap_rules]:
  "< gridI a v > updownI a <λr. gridI a (updown dm lm v) >t"
  unfolding updown_def updownI_def by sep_auto

end

end