(* Theory Certification_Monads.Error_Monad *)

(* Title:     Error_Monad
   Author:    Christian Sternagel
   Author:    René Thiemann
*)

section ‹The Sum Type as Error Monad›

theory Error_Monad
imports
  "HOL-Library.Monad_Syntax"
  Error_Syntax
begin

text ‹Make monad syntax (including do-notation) available for the sum type.›
(* Monadic bind for the sum type: a successful computation (Inr x) continues
   with f x, while an error (Inl e) is propagated unchanged and f is skipped. *)
definition bind :: "'e + 'a ⇒ ('a ⇒ 'e + 'b) ⇒ 'e + 'b"
where
  "bind m f = (case m of Inr x ⇒ f x | Inl e ⇒ Inl e)"

(* Register bind as the sum-type instance of the overloaded Monad_Syntax bind,
   so that the bind operator and do-notation can be used on 'e + 'a. *)
adhoc_overloading Monad_Syntax.bind bind

(* Input-only abbreviations that give monadic names to the sum type:
   return wraps a successful result (Inr), error wraps a failure (Inl),
   and run extracts the Inr component via projr. The lemmas below only
   rely on run for computations known to be OK. *)
abbreviation (input) "return ≡ Inr"
abbreviation (input) "error ≡ Inl"
abbreviation (input) "run ≡ projr"


subsection ‹Monad Laws›

(* Left identity: binding a returned value simply applies the continuation. *)
lemma return_bind [simp]:
  "(return x ⤜ f) = f x"
  by (simp add: bind_def)

(* Right identity: binding into return leaves the computation unchanged.
   Proved by case analysis on m. *)
lemma bind_return [simp]:
  "(m ⤜ return) = m"
  by (cases m) (simp_all add: bind_def)

(* An error short-circuits bind: the continuation is never applied. *)
lemma error_bind [simp]:
  "(error e ⤜ f) = error e"
  by (simp add: bind_def)

(* Associativity of bind; as a simp rule it re-associates nested binds
   to the right. *)
lemma bind_assoc [simp]:
  fixes m :: "'a + 'b"
  shows "((m ⤜ f) ⤜ g) = (m ⤜ (λx. f x ⤜ g))"
  by (cases m) (simp_all add: bind_def)

(* Congruence rule for the function package ([fundef_cong]): two binds are
   equal if the computations are equal and the continuations agree on every
   value y that the computation can return (m2 = Inr y). *)
lemma bind_cong [fundef_cong]:
  fixes m1 m2 :: "'e + 'a"
    and f1 f2 :: "'a ⇒ 'e + 'b"
  assumes "m1 = m2"
    and "⋀y. m2 = Inr y ⟹ f1 y = f2 y"
  shows "(m1 ⤜ f1) = (m2 ⤜ f2)"
  using assms by (cases "m1") (auto simp: bind_def)

(* Error handler: on failure (Inl e) the handler f is run on the error,
   which may change the error type from 'e to 'f; a successful result
   (Inr x) is passed through unchanged. *)
definition catch_error :: "'e + 'a ⇒ ('e ⇒ 'f + 'a) ⇒ 'f + 'a"
where
  catch_def: "catch_error m f = (case m of Inl e ⇒ f e | Inr x ⇒ Inr x)"

(* Use catch_error as the sum-type instance of the overloaded
   "try _ catch _" syntax from Error_Syntax. *)
adhoc_overloading Error_Syntax.catch catch_error

(* Split rules for "try m catch f": the first form splits occurrences in the
   conclusion, the second (negated existential form) is the variant for
   assumptions. Both are proved by case analysis on m. *)
lemma catch_splits:
  "P (try m catch f) ⟷ (∀e. m = Inl e ⟶ P (f e)) ∧ (∀x. m = Inr x ⟶ P (Inr x))"
  "P (try m catch f) ⟷ (¬ ((∃e. m = Inl e ∧ ¬ P (f e)) ∨ (∃x. m = Inr x ∧ ¬ P (Inr x))))"
  by (case_tac [!] m) (simp_all add: catch_def)

(* Transform the error of a failed computation with f; successful results
   are kept. Expressed via catch by re-raising the mapped error. *)
abbreviation update_error :: "'e + 'a ⇒ ('e ⇒ 'f) ⇒ 'f + 'a"
where
  "update_error m f ≡ try m catch (λx. error (f x))"

(* Use update_error as the sum-type instance of the overloaded "<+?"
   syntax from Error_Syntax. *)
adhoc_overloading Error_Syntax.update_error update_error

(* Catching errors of a successful computation leaves it unchanged. *)
lemma catch_return [simp]:
  "(try return x catch f) = return x"
  unfolding catch_def by simp

(* Catching a failure passes its error to the handler. *)
lemma catch_error [simp]:
  "(try error e catch f) = f e"
  unfolding catch_def by simp

(* Updating the error never turns a computation into a success, nor changes
   a successful result. *)
lemma update_error_return [simp]:
  "(m <+? c = return x) ⟷ (m = return x)"
  by (cases m) simp_all

(* isOK m holds exactly when m is a successful (Inr) computation. *)
definition "isOK m ⟷ (case m of Inl e ⇒ False | Inr x ⇒ True)"

(* Elimination rule: an OK computation has the form return x for some x.
   Declared [elim] so that the classical reasoner can use it. *)
lemma isOK_E [elim]:
  assumes "isOK m"
  obtains x where "m = return x"
  using assms by (cases m) (simp_all add: isOK_def)

(* Introduction rule: any returned value is OK. *)
lemma isOK_I [simp, intro]:
  "m = return x ⟹ isOK m"
  by (cases m) (simp_all add: isOK_def)

(* Characterisation of isOK as "is some returned value"; follows from
   isOK_I and isOK_E. *)
lemma isOK_iff:
  "isOK m ⟷ (∃x. m = return x)"
  by blast

(* A failed computation is never OK. *)
lemma isOK_error [simp]:
  "isOK (error x) = False"
  by (simp add: isOK_def)

(* A bind is OK iff the first computation is OK and the continuation,
   applied to its result, is OK as well. *)
lemma isOK_bind [simp]:
  "isOK (m ⤜ f) ⟷ isOK m ∧ isOK (f (run m))"
  by (cases m) simp_all

(* Updating the error does not affect whether a computation is OK. *)
lemma isOK_update_error [simp]:
  "isOK (m <+? f) ⟷ isOK m"
  by (cases m) simp_all

(* Push isOK inside a pair case distinction (instance of prod.case_distrib). *)
lemma isOK_case_prod [simp]:
  "isOK (case lr of (l, r) ⇒ P l r) = (case lr of (l, r) ⇒ isOK (P l r))"
  by (rule prod.case_distrib)

(* Push isOK inside an option case distinction; proved by cases on x. *)
lemma isOK_case_option [simp]:
  "isOK (case x of None ⇒ P | Some v ⇒ Q v) = (case x of None ⇒ isOK P | Some v ⇒ isOK (Q v))"
  by (cases x) auto

(* Let-bindings are transparent to isOK. *)
lemma isOK_Let [simp]:
  "isOK (Let s f) = isOK (f s)"
  by (simp only: Let_def)

(* For an OK first computation, the result of a bind is the result of the
   continuation applied to the first result. *)
lemma run_bind [simp]:
  "isOK m ⟹ run (m ⤜ f) = run (f (run m))"
  by auto

(* For an OK computation, an error handler does not change the result. *)
lemma run_catch [simp]:
  "isOK m ⟹ run (try m catch f) = run m"
  by auto

(* Monadic left fold: threads the accumulator d through f for each list
   element from left to right, stopping at the first error. *)
fun foldM :: "('a ⇒ 'b ⇒ 'e + 'a) ⇒ 'a ⇒ 'b list ⇒ 'e + 'a"
where
  "foldM f d [] = return d" |
  "foldM f d (x # xs) = do { y ← f d x; foldM f y xs }"

(* Check P x i for every element x, with indices counting up from i.
   The first failure e of P at element x and index i is reported as
   ((x, i), e); later elements are not checked. *)
fun forallM_index_aux :: "('a ⇒ nat ⇒ 'e + unit) ⇒ nat ⇒ 'a list ⇒ (('a × nat) × 'e) + unit"
where
  "forallM_index_aux P i [] = return ()" |
  "forallM_index_aux P i (x # xs) = do {
    P x i <+? Pair (x, i);
    forallM_index_aux P (Suc i) xs
  }"

(* forallM_index_aux starting at offset n is OK iff P succeeds on every
   element at its position shifted by n. Induction on xs, generalising n. *)
lemma isOK_forallM_index_aux [simp]:
  "isOK (forallM_index_aux P n xs) = (∀i < length xs. isOK (P (xs ! i) (i + n)))"
proof (induct xs arbitrary: n)
  case (Cons x xs)
  (* Peel off index 0 so the remaining indices match the induction
     hypothesis instantiated at Suc n. *)
  have "(∀i < length (x # xs). isOK (P ((x # xs) ! i) (i + n))) ⟷
    (isOK (P x n) ∧ (∀i < length xs. isOK (P (xs ! i) (i + Suc n))))"
    by (auto, case_tac i) (simp_all)
  then show ?case
    unfolding Cons [of "Suc n", symmetric] by simp
qed auto

(* Check P x i for every element x at its (0-based) index i, reporting the
   first failure together with the offending element and index. *)
definition forallM_index :: "('a ⇒ nat ⇒ 'e + unit) ⇒ 'a list ⇒ (('a × nat) × 'e) + unit"
where
  "forallM_index P xs = forallM_index_aux P 0 xs"

(* forallM_index is OK iff P succeeds on every element at its index. *)
lemma isOK_forallM_index [simp]:
  "isOK (forallM_index P xs) ⟷ (∀i < length xs. isOK (P (xs ! i) i))"
  unfolding forallM_index_def isOK_forallM_index_aux by simp

(* Congruence rule for the function package: forallM_index only depends on
   the check function's values at elements of xs. *)
lemma forallM_index [fundef_cong]:
  fixes c :: "'a ⇒ nat ⇒ 'e + unit"
  assumes "⋀x i. x ∈ set xs ⟹ c x i = d x i"
  shows "forallM_index c xs = forallM_index d xs"
proof -
  (* Generalise over the start index, as forallM_index_aux changes it. *)
  { fix n
    have "forallM_index_aux c n xs = forallM_index_aux d n xs"
      using assms by (induct xs arbitrary: n) simp_all }
  then show ?thesis by (simp add: forallM_index_def)
qed

(* forallM_index_aux is only a helper for forallM_index; remove it from the
   name space so that clients use forallM_index. *)
hide_const forallM_index_aux

text ‹
  Check whether @{term f} succeeds for all elements of a given list. In case it doesn't,
  return the first offending element together with the produced error.
›
fun forallM :: "('a ⇒ 'e + unit) ⇒ 'a list ⇒ ('a × 'e) + unit"
where
  "forallM f [] = return ()" |
  (* Tag a failure of f x with x, then continue with the rest of the list. *)
  "forallM f (x # xs) = f x <+? Pair x ⪢ forallM f xs"

(* Congruence rule for the function package: forallM only depends on the
   check function's values at elements of the list. *)
lemma forallM_fundef_cong [fundef_cong]:
  assumes "xs = ys" "⋀x. x ∈ set ys ⟹ f x = g x"
  shows "forallM f xs = forallM g ys"
  unfolding assms(1) using assms(2)
proof (induct ys)
  case (Cons x xs)
  thus ?case by (cases "g x", auto)
qed auto

(* forallM f xs is OK iff f succeeds on every element of xs. *)
lemma isOK_forallM [simp]:
  "isOK (forallM f xs) ⟷ (∀x ∈ set xs. isOK (f x))"
  by (induct xs) (simp_all)

text ‹
  Check whether @{term f} succeeds for at least one element of a given list.
  In case it doesn't, return the list of produced errors.
›
fun existsM :: "('a ⇒ 'e + unit) ⇒ 'a list ⇒ 'e list + unit"
where
  "existsM f [] = error []" |
  (* On failure of f x, try the rest and prepend e to the collected errors,
     so the error list follows the order of the input list. *)
  "existsM f (x # xs) = (try f x catch (λe. existsM f xs <+? Cons e))"

(* Congruence rule for the function package: existsM only depends on the
   check function's values at elements of the list. *)
lemma existsM_cong [fundef_cong]:
  assumes "xs = ys"
  and "⋀x. x ∈ set ys ⟹ f x = g x"
  shows "existsM f xs = existsM g ys"
  using assms
  by (induct ys arbitrary:xs) (auto split:catch_splits)

(* existsM f xs is OK iff f succeeds on at least one element of xs. *)
lemma isOK_existsM [simp]:
  "isOK (existsM f xs) ⟷ (∃x∈set xs. isOK (f x))"
proof (induct xs)
  case (Cons x xs)
  show ?case
  proof (cases "f x")
    case (Inl e)
    with Cons show ?thesis by simp
  qed (auto simp add: catch_def)
qed simp

(* An if-expression with a return branch is OK iff the condition selects
   that branch or the other branch is OK. *)
lemma is_OK_if_return [simp]:
  "isOK (if b then return x else m) ⟷ b ∨ isOK m"
  "isOK (if b then m else return x) ⟷ ¬ b ∨ isOK m"
  by simp_all

(* Alias following the isOK_* naming of the surrounding lemmas; the original
   name is kept for backward compatibility. *)
lemmas isOK_if_return = is_OK_if_return

(* An if-expression with an error branch is OK iff the condition avoids
   that branch and the other branch is OK. *)
lemma isOK_if_error [simp]:
  "isOK (if b then error e else m) ⟷ ¬ b ∧ isOK m"
  "isOK (if b then m else error e) ⟷ b ∧ isOK m"
  by simp_all

(* General case split for isOK on an if-expression (not a simp rule). *)
lemma isOK_if:
  "isOK (if b then x else y) ⟷ b ∧ isOK x ∨ ¬ b ∧ isOK y"
  by simp

(* Run a list of computations from left to right and collect their results;
   the first error aborts the whole sequence. *)
fun sequence :: "('e + 'a) list ⇒ 'e + 'a list"
where
  "sequence [] = return []" |
  "sequence (m # ms) = do {
    x ← m;
    xs ← sequence ms;
    return (x # xs)
  }"


subsection ‹Monadic Map for Error Monad›

(* Monadic map: apply f to each element from left to right and collect the
   results; the first error aborts the traversal. *)
fun mapM :: "('a ⇒ 'e + 'b) ⇒ 'a list ⇒ 'e + 'b list"
where
  "mapM f [] = return []" |
  "mapM f (x#xs) = do {
    y ← f x;
    ys ← mapM f xs;
    return (y # ys)
  }"

(* mapM fails iff f fails on some element of the list. *)
lemma mapM_error:
  "(∃e. mapM f xs = error e) ⟷ (∃x∈set xs. ∃e. f x = error e)"
proof (induct xs)
  case (Cons x xs)
  then show ?case
    by (cases "f x", simp_all, cases "mapM f xs", simp_all)
qed simp

(* If mapM succeeds with ys, then ys is the pointwise result of f, and f
   fails on no element of xs. *)
lemma mapM_return:
  assumes "mapM f xs = return ys"
  shows "ys = map (run ∘ f) xs ∧ (∀x∈set xs. ∀e. f x ≠ error e)"
using assms
proof (induct xs arbitrary: ys)
  case (Cons x xs ys)
  then show ?case
    by (cases "f x", simp, cases "mapM f xs", simp_all)
qed simp

(* Index-wise version of mapM_return: the i-th result of a successful mapM
   is the (successful) result of f on the i-th element. *)
lemma mapM_return_idx:
  assumes *: "mapM f xs = Inr ys" and "i < length xs"
  shows "∃y. f (xs ! i) = Inr y ∧ ys ! i = y"
proof -
  note ** = mapM_return [OF *, unfolded set_conv_nth]
  with assms have "∀e. f (xs ! i) ≠ Inl e" by auto
  then obtain y where "f (xs ! i) = Inr y" by (cases "f (xs ! i)") auto
  then have "f (xs ! i) = Inr y ∧ ys ! i = y" unfolding ** [THEN conjunct1] using assms by auto
  then show ?thesis ..
qed

(* Congruence rule for the function package: mapM only depends on the
   function's values at elements of the list. *)
lemma mapM_cong [fundef_cong]:
  assumes "xs = ys" and "⋀x. x ∈ set ys ⟹ f x = g x"
  shows "mapM f xs = mapM g ys"
  unfolding assms(1) using assms(2) by (induct ys) auto

(* Elimination rule for a successful bind: both the first computation and
   the continuation (on its result) must have succeeded. *)
lemma bindE [elim]:
  assumes "(p ⤜ f) = return x"
  obtains y where "p = return y" and "f y = return x"
  using assms by (cases p) simp_all

(* "p then q" returns f iff p is OK and q itself returns f. *)
lemma then_return_eq [simp]:
  "(p ⪢ q) = return f ⟷ isOK p ∧ q = return f"
  by (cases p) simp_all

(* Return the result of the first successful computation in the list; if all
   of them fail, return the list of their errors in list order. *)
fun choice :: "('e + 'a) list ⇒ 'e list + 'a"
where
  "choice [] = error []" |
  "choice (x # xs) = (try x catch (λe. choice xs <+? Cons e))"

(* The equations are not used as simp rules by default. *)
declare choice.simps [simp del]

(* Consequences of a successful mapM: f succeeds on every element, and the
   result is the pointwise result of f. *)
lemma isOK_mapM:
  assumes "isOK (mapM f xs)"
  shows "(∀x. x ∈ set xs ⟶ isOK (f x)) ∧ run (mapM f xs) = map (λx. run (f x)) xs"
  using assms mapM_return[of f xs] by (force simp: isOK_def split: sum.splits)+

(* Return the first element x of the list for which f x succeeds (the result
   of f x itself is discarded); if f fails on all elements, return the list
   of errors in list order. *)
fun firstM :: "('a ⇒ 'e + 'b) ⇒ 'a list ⇒ 'e list + 'a"
  where
    "firstM f [] = error []"
  | "firstM f (x # xs) = (try f x ⪢ return x catch (λe. firstM f xs <+? Cons e))"

(* firstM f xs is OK iff f succeeds on at least one element of xs. *)
lemma firstM:
  "isOK (firstM f xs) ⟷ (∃x∈set xs. isOK (f x))"
  by (induct xs) (auto simp: catch_def split: sum.splits)

(* The element returned by firstM is from the list and f succeeds on it. *)
lemma firstM_return:
  assumes "firstM f xs = return y"
  shows "isOK (f y) ∧ y ∈ set xs"
  using assms by (induct xs) (auto simp: catch_def split: sum.splits)


end