(* Theory Certification_Monads.Error_Monad *)
section ‹The Sum Type as Error Monad›
theory Error_Monad
imports
"HOL-Library.Monad_Syntax"
Error_Syntax
begin
text ‹Make monad syntax (including do-notation) available for the sum type.›
text ‹
  Monadic bind for the sum type: a left value (an error) short-circuits the rest of the
  computation, while a right value (a result) is handed to the continuation f.
›
definition bind :: "'e + 'a ⇒ ('a ⇒ 'e + 'b) ⇒ 'e + 'b"
  where
    "bind m f = (case m of Inl e ⇒ Inl e | Inr x ⇒ f x)"

adhoc_overloading
  Monad_Syntax.bind bind
text ‹
  Input-only shorthands: return wraps a successful result (Inr), error wraps a failure (Inl),
  and run extracts a result via projr. Lemmas about run below (run_bind, run_catch) assume
  the computation isOK, presumably because projr is only specified on Inr values.
›
abbreviation (input) "return ≡ Inr"
abbreviation (input) "error ≡ Inl"
abbreviation (input) "run ≡ projr"
subsection ‹Monad Laws›
text ‹Left identity: binding a returned value simply applies the continuation to it.›
lemma return_bind [simp]:
  "(return x ⤜ f) = f x"
  unfolding bind_def by simp
text ‹Right identity: binding with return leaves the computation unchanged.›
lemma bind_return [simp]:
  "(m ⤜ return) = m"
  by (cases m; simp add: bind_def)
text ‹An error short-circuits bind: the continuation is never applied.›
lemma error_bind [simp]:
  "(error e ⤜ f) = error e"
  unfolding bind_def by simp
text ‹
  Associativity of bind. As a simp rule it rewrites left-nested binds into right-nested ones.
›
lemma bind_assoc [simp]:
  fixes m :: "'a + 'b"
  shows "((m ⤜ f) ⤜ g) = (m ⤜ (λx. f x ⤜ g))"
  by (cases m; simp add: bind_def)
text ‹
  Congruence rule for bind, registered with the function package (fundef_cong): the two
  continuations only need to agree on those values y for which the computation is Inr y.
›
lemma bind_cong [fundef_cong]:
fixes m1 m2 :: "'e + 'a"
and f1 f2 :: "'a ⇒ 'e + 'b"
assumes "m1 = m2"
and "⋀y. m2 = Inr y ⟹ f1 y = f2 y"
shows "(m1 ⤜ f1) = (m2 ⤜ f2)"
using assms by (cases "m1") (auto simp: bind_def)
text ‹
  Error handling: a failed computation Inl e is replaced by the handler result f e, while a
  successful result Inr x is passed through unchanged (at the possibly new error type 'f).
›
definition catch_error :: "'e + 'a ⇒ ('e ⇒ 'f + 'a) ⇒ 'f + 'a"
  where
    catch_def: "catch_error m f = (case m of Inr x ⇒ Inr x | Inl e ⇒ f e)"

adhoc_overloading
  Error_Syntax.catch catch_error
text ‹
  Split rules for try-catch, for use with split: (as in existsM_cong). The first form states
  the error and success cases as implications, the second as negated existentials.
›
lemma catch_splits:
"P (try m catch f) ⟷ (∀e. m = Inl e ⟶ P (f e)) ∧ (∀x. m = Inr x ⟶ P (Inr x))"
"P (try m catch f) ⟷ (¬ ((∃e. m = Inl e ∧ ¬ P (f e)) ∨ (∃x. m = Inr x ∧ ¬ P (Inr x))))"
by (case_tac [!] m) (simp_all add: catch_def)
text ‹
  Map a function over the error of a computation, keeping successful results. It is an
  abbreviation for a try-catch whose handler re-raises the transformed error, and is used
  below with the infix syntax m <+? f.
›
abbreviation update_error :: "'e + 'a ⇒ ('e ⇒ 'f) ⇒ 'f + 'a"
where
"update_error m f ≡ try m catch (λx. error (f x))"
adhoc_overloading
Error_Syntax.update_error update_error
text ‹try-catch on a successful computation returns it unchanged.›
lemma catch_return [simp]:
  "(try return x catch f) = return x"
  unfolding catch_def by simp

text ‹try-catch on an error hands that error to the handler.›
lemma catch_error [simp]:
  "(try error e catch f) = f e"
  unfolding catch_def by simp
text ‹Updating the error of m affects neither whether m succeeds nor the value it returns.›
lemma update_error_return [simp]:
  "(m <+? c = return x) ⟷ (m = return x)"
  by (cases m; simp)
text ‹isOK m holds exactly when m is a successful result, i.e. an Inr value.›
definition isOK :: "'a + 'b ⇒ bool"
  where "isOK m ⟷ (case m of Inl e ⇒ False | Inr x ⇒ True)"
text ‹
  Elimination and introduction rules for isOK: a successful computation has the form
  return x. They are declared elim and intro (isOK_I also simp) so that automation such as
  blast and auto can use them.
›
lemma isOK_E [elim]:
assumes "isOK m"
obtains x where "m = return x"
using assms by (cases m) (simp_all add: isOK_def)
lemma isOK_I [simp, intro]:
"m = return x ⟹ isOK m"
by (cases m) (simp_all add: isOK_def)
text ‹
  isOK as an existential over the returned value, and the fact that an error is never OK.
  Both are proved by blast, presumably via the intro/elim rules isOK_I and isOK_E.
›
lemma isOK_iff:
"isOK m ⟷ (∃x. m = return x)"
by blast
lemma isOK_error [simp]:
"isOK (error x) = False"
by blast
text ‹
  A bind succeeds iff the first computation succeeds and the continuation succeeds on
  the extracted result.
›
lemma isOK_bind [simp]:
  "isOK (m ⤜ f) ⟷ isOK m ∧ isOK (f (run m))"
  by (cases m; simp)

text ‹Changing the error of a computation does not change whether it succeeds.›
lemma isOK_update_error [simp]:
  "isOK (m <+? f) ⟷ isOK m"
  by (cases m; simp)
text ‹
  Simp rules that push isOK inside case distinctions on pairs and options and through Let,
  so that each branch is reasoned about separately.
›
lemma isOK_case_prod [simp]:
"isOK (case lr of (l, r) ⇒ P l r) = (case lr of (l, r) ⇒ isOK (P l r))"
by (rule prod.case_distrib)
lemma isOK_case_option [simp]:
"isOK (case x of None ⇒ P | Some v ⇒ Q v) = (case x of None ⇒ isOK P | Some v ⇒ isOK (Q v))"
by (cases x) (auto)
lemma isOK_Let [simp]:
"isOK (Let s f) = isOK (f s)"
by (simp add: Let_def)
text ‹
  Evaluating run on a bind and on a try-catch whose first computation succeeds. The isOK
  premise restricts these rules to the case where run is applied to an Inr value.
›
lemma run_bind [simp]:
"isOK m ⟹ run (m ⤜ f) = run (f (run m))"
by auto
lemma run_catch [simp]:
"isOK m ⟹ run (try m catch f) = run m"
by auto
text ‹
  Monadic left fold: thread the accumulator through f from left to right, stopping at the
  first error.
›
fun foldM :: "('a ⇒ 'b ⇒ 'e + 'a) ⇒ 'a ⇒ 'b list ⇒ 'e + 'a"
  where
    "foldM f d [] = return d"
  | "foldM f d (x # xs) = f d x ⤜ (λy. foldM f y xs)"
text ‹
  Auxiliary for forallM_index: check P on each element together with its position, counted
  from the start index i. The first failure is reported with its element and position.
›
fun forallM_index_aux :: "('a ⇒ nat ⇒ 'e + unit) ⇒ nat ⇒ 'a list ⇒ (('a × nat) × 'e) + unit"
  where
    "forallM_index_aux P i [] = return ()"
  | "forallM_index_aux P i (x # xs) =
      (P x i <+? Pair (x, i)) ⪢ forallM_index_aux P (Suc i) xs"
text ‹
  forallM_index_aux succeeds iff P succeeds on every element, where the element at list
  position i is checked with index i + n. The proof is by induction on xs with n
  generalised; the Cons case reindexes the tail, shifting the start index to Suc n.
›
lemma isOK_forallM_index_aux [simp]:
"isOK (forallM_index_aux P n xs) = (∀i < length xs. isOK (P (xs ! i) (i + n)))"
proof (induct xs arbitrary: n)
case (Cons x xs)
have "(∀i < length (x # xs). isOK (P ((x # xs) ! i) (i + n))) ⟷
(isOK (P x n) ∧ (∀i < length xs. isOK (P (xs ! i) (i + Suc n))))"
by (auto, case_tac i) (simp_all)
then show ?case
unfolding Cons [of "Suc n", symmetric] by simp
qed auto
text ‹
  Check P on every element of xs together with its 0-based position. On failure, the first
  offending element, its position and the produced error are reported.
›
definition forallM_index :: "('a ⇒ nat ⇒ 'e + unit) ⇒ 'a list ⇒ (('a × nat) × 'e) + unit"
where
"forallM_index P xs = forallM_index_aux P 0 xs"
lemma isOK_forallM_index [simp]:
"isOK (forallM_index P xs) ⟷ (∀i < length xs. isOK (P (xs ! i) i))"
unfolding forallM_index_def isOK_forallM_index_aux by simp
text ‹
  Congruence rule for the function package: forallM_index only depends on the check on
  elements of xs. NOTE(review): unlike forallM_fundef_cong and mapM_cong, the list is not
  generalised here (both sides use the same xs). Afterwards the auxiliary constant is
  hidden, presumably so that clients only use forallM_index.
›
lemma forallM_index [fundef_cong]:
fixes c :: "'a ⇒ nat ⇒ 'e + unit"
assumes "⋀x i. x ∈ set xs ⟹ c x i = d x i"
shows "forallM_index c xs = forallM_index d xs"
proof -
{ fix n
have "forallM_index_aux c n xs = forallM_index_aux d n xs"
using assms by (induct xs arbitrary: n) simp_all }
then show ?thesis by (simp add: forallM_index_def)
qed
hide_const forallM_index_aux
text ‹
Check whether @{term f} succeeds for all elements of a given list. In case it doesn't,
return the first offending element together with the produced error.
›
fun forallM :: "('a ⇒ 'e + unit) ⇒ 'a list ⇒ ('a * 'e) + unit"
  where
    "forallM f [] = return ()"
  | "forallM f (x # xs) = do {
      f x <+? Pair x;
      forallM f xs
    }"
text ‹
  Congruence rule for the function package: forallM only depends on f at the elements of
  the list. The proof substitutes xs = ys and inducts on ys.
›
lemma forallM_fundef_cong [fundef_cong]:
assumes "xs = ys" "⋀x. x ∈ set ys ⟹ f x = g x"
shows "forallM f xs = forallM g ys"
unfolding assms(1) using assms(2)
proof (induct ys)
case (Cons x xs)
thus ?case by (cases "g x", auto)
qed auto
text ‹forallM succeeds iff f succeeds on every element of the list.›
lemma isOK_forallM [simp]:
  "isOK (forallM f xs) ⟷ (∀x ∈ set xs. isOK (f x))"
  by (induction xs) simp_all
text ‹
Check whether @{term f} succeeds for at least one element of a given list.
In case it doesn't, return the list of produced errors.
›
text ‹
  The errors are collected in list order: the head's error is prepended to the errors of
  the tail. On the empty list the result is error [] because no element can succeed.
›
fun existsM :: "('a ⇒ 'e + unit) ⇒ 'a list ⇒ 'e list + unit"
where
"existsM f [] = error []" |
"existsM f (x # xs) = (try f x catch (λe. existsM f xs <+? Cons e))"
text ‹
  Congruence rule for the function package: existsM only depends on f at list elements.
  The proof case-splits the try-catch with catch_splits.
›
lemma existsM_cong [fundef_cong]:
assumes "xs = ys"
and "⋀x. x ∈ set ys ⟹ f x = g x"
shows "existsM f xs = existsM g ys"
using assms
by (induct ys arbitrary:xs) (auto split:catch_splits)
text ‹
  existsM succeeds iff f succeeds on some list element. By induction on xs: in the Cons
  case an error from f x reduces to the induction hypothesis, and success is closed by auto.
›
lemma isOK_existsM [simp]:
"isOK (existsM f xs) ⟷ (∃x∈set xs. isOK (f x))"
proof (induct xs)
case (Cons x xs)
show ?case
proof (cases "f x")
case (Inl e)
with Cons show ?thesis by simp
qed (auto simp add: catch_def)
qed simp
text ‹
  Simp rules for conditionals with a successful branch. The original name is_OK_if_return
  is kept for compatibility; isOK_if_return is an alias that follows the isOK_... naming
  used by all other lemmas in this theory.
›
lemma is_OK_if_return [simp]:
  "isOK (if b then return x else m) ⟷ b ∨ isOK m"
  "isOK (if b then m else return x) ⟷ ¬ b ∨ isOK m"
  by simp_all

lemmas isOK_if_return = is_OK_if_return
text ‹
  Simp rules for conditionals with an error branch, and a general case split of isOK over
  an if-then-else (not declared simp).
›
lemma isOK_if_error [simp]:
"isOK (if b then error e else m) ⟷ ¬ b ∧ isOK m"
"isOK (if b then m else error e) ⟷ b ∧ isOK m"
by simp_all
lemma isOK_if:
"isOK (if b then x else y) ⟷ b ∧ isOK x ∨ ¬ b ∧ isOK y"
by simp
text ‹
  Run a list of computations from left to right and collect their results; the first
  error aborts the whole sequence.
›
fun sequence :: "('e + 'a) list ⇒ 'e + 'a list"
  where
    "sequence [] = return []"
  | "sequence (m # ms) = m ⤜ (λx. sequence ms ⤜ (λxs. return (x # xs)))"
subsection ‹Monadic Map for Error Monad›
text ‹
  Monadic map: apply f to the elements from left to right and collect the results,
  aborting with the first error.
›
fun mapM :: "('a ⇒ 'e + 'b) ⇒ 'a list ⇒ 'e + 'b list"
  where
    "mapM f [] = return []"
  | "mapM f (x # xs) = do { y ← f x; ys ← mapM f xs; return (y # ys) }"
text ‹mapM fails iff f fails on some element of the list.›
lemma mapM_error:
"(∃e. mapM f xs = error e) ⟷ (∃x∈set xs. ∃e. f x = error e)"
proof (induct xs)
case (Cons x xs)
then show ?case
by (cases "f x", simp_all, cases "mapM f xs", simp_all)
qed simp
text ‹
  If mapM succeeds with ys, then ys is the list of results of f (extracted with run), and
  f fails on no element.
›
lemma mapM_return:
assumes "mapM f xs = return ys"
shows "ys = map (run ∘ f) xs ∧ (∀x∈set xs. ∀e. f x ≠ error e)"
using assms
proof (induct xs arbitrary: ys)
case (Cons x xs ys)
then show ?case
by (cases "f x", simp, cases "mapM f xs", simp_all)
qed simp
text ‹
  Index-wise form of mapM_return: for a valid index i, f succeeds on the i-th element and
  its result is the i-th entry of ys.
›
lemma mapM_return_idx:
assumes *: "mapM f xs = Inr ys" and "i < length xs"
shows "∃y. f (xs ! i) = Inr y ∧ ys ! i = y"
proof -
note ** = mapM_return [OF *, unfolded set_conv_nth]
with assms have "⋀e. f (xs ! i) ≠ Inl e" by auto
then obtain y where "f (xs ! i) = Inr y" by (cases "f (xs ! i)") auto
then have "f (xs ! i) = Inr y ∧ ys ! i = y" unfolding ** [THEN conjunct1] using assms by auto
then show ?thesis ..
qed
text ‹Congruence rule for the function package: mapM only depends on f at list elements.›
lemma mapM_cong [fundef_cong]:
assumes "xs = ys" and "⋀x. x ∈ set ys ⟹ f x = g x"
shows "mapM f xs = mapM g ys"
unfolding assms(1) using assms(2) by (induct ys) auto
text ‹
  Elimination rule for a successful bind, and a simp rule for a successful then:
  p ⪢ q returns f iff p succeeds and q returns f.
›
lemma bindE [elim]:
assumes "(p ⤜ f) = return x"
obtains y where "p = return y" and "f y = return x"
using assms by (cases p) simp_all
lemma then_return_eq [simp]:
"(p ⪢ q) = return f ⟷ isOK p ∧ q = return f"
by (cases p) simp_all
text ‹
  Return the first successful computation in the list. If all of them fail, return their
  errors in list order. The defining equations are then removed from the simp set.
  NOTE(review): the reason for simp del is not visible here (presumably to prevent
  eager unfolding) — confirm.
›
fun choice :: "('e + 'a) list ⇒ 'e list + 'a"
where
"choice [] = error []" |
"choice (x # xs) = (try x catch (λe. choice xs <+? Cons e))"
declare choice.simps [simp del]
text ‹
  If mapM succeeds, then f succeeds on every element and run of the result is the list of
  results of f. Derived from mapM_return.
›
lemma isOK_mapM:
assumes "isOK (mapM f xs)"
shows "(∀x. x ∈ set xs ⟶ isOK (f x)) ∧ run (mapM f xs) = map (λx. run (f x)) xs"
using assms mapM_return[of f xs] by (force simp: isOK_def split: sum.splits)+
text ‹
  Return the first element x of the list on which f succeeds; the value produced by f x
  itself is discarded. If f fails on every element, return all errors in list order.
›
fun firstM :: "('a ⇒ 'e + 'b) ⇒ 'a list ⇒ 'e list + 'a"
  where
    "firstM f [] = error []"
  | "firstM f (x # xs) = (try f x ⪢ return x catch (λe. firstM f xs <+? Cons e))"
text ‹
  firstM succeeds iff f succeeds on some element. A returned element y satisfies
  isOK (f y) and is a member of the list.
›
lemma firstM:
"isOK (firstM f xs) ⟷ (∃x∈set xs. isOK (f x))"
by (induct xs) (auto simp: catch_def split: sum.splits)
lemma firstM_return:
assumes "firstM f xs = return y"
shows "isOK (f y) ∧ y ∈ set xs"
using assms by (induct xs) (auto simp: catch_def split: sum.splits)
end