(* Theory Monad_Normalisation_Test *)
theory Monad_Normalisation_Test
imports Monad_Normalisation
begin
section ‹Tests and examples›
context includes monad_normalisation
begin
(* Test: bind reordering together with a rewrite that comes from an assumption.
   Compared with the left-hand side, the right-hand side binds A first,
   exchanges the D and E binds, and wraps the result in f, which the
   assumption fixes to id. *)
lemma
assumes "f = id"
shows
"do {x ← B; z ← C x; d ← E z x; a ← D z x; y ← A; return_pmf (x,y)} =
do {y ← A; x ← B; z ← C x; a ← D z x; d ← E z x; return_pmf (f (x,y))}"
(* First simp call runs without the assumption; presumably it only aligns the
   bind order of the two sides -- TODO confirm what goal it leaves. *)
apply (simp)
(* Second simp call additionally uses the assumption f = id as a simp rule. *)
apply (simp add: assms)
done
(* Test in the pmf monad: both sides draw a and b from E; they differ only in
   the order of the last two binds (B b a and B a b).  Closed by one simp call. *)
lemma "(do {a ← E; b ← E; w ← B b a; z ← B a b; return_pmf (w,z)}) =
(do {a ← E; b ← E; z ← B a b; w ← B b a; return_pmf (w,z)})"
by (simp)
(* Same statement and proof as the lemma directly above: the last two binds
   (B b a and B a b) appear in opposite order on the two sides.
   NOTE(review): this looks like a verbatim repeat -- confirm whether it is intended. *)
lemma "(do {a ← E; b ← E; w ← B b a; z ← B a b; return_pmf (w,z)}) =
(do {a ← E; b ← E; z ← B a b; w ← B b a; return_pmf (w,z)})"
by (simp)
(* Test with an option result (the final expression is Some (x,y)).
   Both sides draw twice from A and run B x y y and B x x y; the order of the
   two A binds and the order of the two B binds are exchanged. *)
lemma "do {y ← A; x ← A; z ← B x y y; w ← B x x y; Some (x,y)} =
do {x ← A; y ← A; z ← B x x y; w ← B x y y; Some (x,y)}"
by (simp)
(* Test where the final expression is the set {x,y}.
   Both sides draw twice from A and run B x y y and B x x y; the order of the
   two A binds and the order of the two B binds are exchanged. *)
lemma "do {y ← A; x ← A; z ← B x y y; w ← B x x y; {x,y}} =
do {x ← A; y ← A; z ← B x x y; w ← B x y y; {x,y}}"
by (simp)
(* Test in the pmf monad (the final expression is return_pmf (x,y)).
   Both sides draw twice from A and run B x y y and B x x y; the order of the
   two A binds and the order of the two B binds are exchanged. *)
lemma "do {y ← A; x ← A; z ← B x y y; w ← B x x y; return_pmf (x,y)} =
do {x ← A; y ← A; z ← B x x y; w ← B x y y; return_pmf (x,y)}"
by (simp)
(* Test ending in Predicate.single (a,a).
   The second bind depends on the first (A x after x from A 0).  The two sides
   differ only in the order of the binds B y y and B x y; the result uses
   neither of them. *)
lemma "do {x ← A 0; y ← A x; w ← B y y; z ← B x y; a ← C; Predicate.single (a,a)} =
do {x ← A 0; y ← A x; z ← B x y; w ← B y y; a ← C; Predicate.single (a,a)}"
by (simp)
(* Test in the pmf monad with a dependent prefix (A x after x from A 0).
   The binds B x y and B y y occur in opposite order on the two sides; note
   that the bound names z and w stay in place while the computations move. *)
lemma "do {x ← A 0; y ← A x; z ← B x y; w ← B y y; a ← C; return_pmf (a,a)} =
do {x ← A 0; y ← A x; z ← B y y; w ← B x y; a ← C; return_pmf (a,a)}"
by (simp)
(* Test in the pmf monad with five binds.  Compared with the left-hand side,
   the right-hand side binds A first and exchanges the D and E binds; the
   dependent chain B, then C x is kept in order on both sides. *)
lemma "do {x ← B; z ← C x; d ← E z x; a ← D z x; y ← A; return_pmf (x,y)} =
do {y ← A; x ← B; z ← C x; a ← D z x; d ← E z x; return_pmf (x,y)}"
by (simp)
(* Stop treating bind_pmf as an overloading of Monad_Syntax.bind.  None of the
   examples after this point end in return_pmf.
   NOTE(review): presumably this avoids ambiguity with the spmf bind used
   below -- confirm. *)
no_adhoc_overloading Monad_Syntax.bind bind_pmf
(* Local context that fixes four parameters used by the spmf examples inside it:
   𝒜1 maps an 'a to an spmf over (('a × 'a) × 'b),
   𝒜2 takes an 'a × 'a pair and a 'b and yields a bool spmf,
   sample_uniform maps a nat to a nat spmf,
   order maps an 'a to a nat. *)
context
fixes 𝒜1 :: "'a ⇒ (('a × 'a) × 'b) spmf"
and 𝒜2 :: "'a × 'a ⇒ 'b ⇒ bool spmf"
and sample_uniform :: "nat ⇒ nat spmf"
and order :: "'a ⇒ nat"
begin
(* Test in the spmf monad with tuple patterns in the binders.
   Left: three samples x, y, z and the coin b are drawn up front, then 𝒜1,
   the assertion and 𝒜2 follow.
   Right: only x and y are drawn up front; the coin and the third sample come
   after the assertion, and the third sample reuses the name x (shadowing the
   first one), so f x in the 𝒜2 call corresponds to f z on the left.
   xor is fixed locally by the "for" clause. *)
lemma
"do {
x ← sample_uniform (order 𝒢);
y ← sample_uniform (order 𝒢);
z ← sample_uniform (order 𝒢);
b ← coin_spmf;
((msg1, msg2), σ) ← 𝒜1 (f x);
_ :: unit ← assert_spmf (valid_plain msg1 ∧ valid_plain msg2);
guess ← 𝒜2 (f y, xor (f z) (if b then msg1 else msg2)) σ;
return_spmf (guess ⟷ b)
} = do {
x ← sample_uniform (order 𝒢);
y ← sample_uniform (order 𝒢);
((msg1, msg2), σ) ← 𝒜1 (f x);
_ :: unit ← assert_spmf (valid_plain msg1 ∧ valid_plain msg2);
b ← coin_spmf;
x ← sample_uniform (order 𝒢);
guess ← 𝒜2 (f y, xor (f x) (if b then msg1 else msg2)) σ;
return_spmf (guess ⟷ b)
}" for xor
(* split_def is supplied as an extra simp rule; presumably it unfolds the
   tuple-pattern binders into fst/snd form -- TODO confirm. *)
by (simp add: split_def)
(* Test in the spmf monad where the continuation is written with explicit
   nested case expressions on the result of 𝒜1.
   Both sides share the same prefix and the same case structure.  They differ
   inside the lambda: the left binds coin_spmf before the map_spmf sample and
   ends with return_spmf (guess ⟷ x); the right has no separate coin bind and
   ends with map_spmf ((⟷) guess) coin_spmf. *)
lemma
"do {
x ← sample_uniform (order 𝒢);
xa ← sample_uniform (order 𝒢);
x ← 𝒜1 (f x);
case x of
(x, xb) ⇒
(case x of
(msg1, msg2) ⇒
λσ. do {
a ← assert_spmf (valid_plain msg1 ∧ valid_plain msg2);
x ← coin_spmf;
xaa ← map_spmf f (sample_uniform (order 𝒢));
guess ← 𝒜2 (f xa, xaa) σ;
return_spmf (guess ⟷ x)
})
xb
} = do {
x ← sample_uniform (order 𝒢);
xa ← sample_uniform (order 𝒢);
x ← 𝒜1 (f x);
case x of
(x, xb) ⇒
(case x of
(msg1, msg2) ⇒
λσ. do {
a ← assert_spmf (valid_plain msg1 ∧ valid_plain msg2);
z ← map_spmf f (sample_uniform (order 𝒢));
guess ← 𝒜2 (f xa, z) σ;
map_spmf ((⟷) guess) coin_spmf
})
xb
}"
(* map_spmf_conv_bind_spmf is supplied as an extra simp rule; judging by its
   name it turns map_spmf into a bind -- verify against its statement. *)
by (simp add: map_spmf_conv_bind_spmf)
(* Test in the spmf monad, stated with fst/snd projections instead of tuple
   patterns.
   Left: samples x and y, flips the coin b, then runs 𝒜1 on f x.
   Right: samples y, flips b, runs 𝒜1 on f y, and draws the second sample ya
   only after the assertion.  The product x * y on the left corresponds to
   y * ya on the right, and guess corresponds to b'.
   xor is fixed locally by the "for" clause. *)
lemma elgamal_step3:
"do {
x ← sample_uniform (order 𝒢);
y ← sample_uniform (order 𝒢);
b ← coin_spmf;
p ← 𝒜1 (f x);
_ ← assert_spmf (valid_plain (fst (fst p)) ∧ valid_plain (snd (fst p)));
guess ←
𝒜2 (f y, xor (f (x * y)) (if b then fst (fst p) else snd (fst p)))
(snd p);
return_spmf (guess ⟷ b)
} = do {
y ← sample_uniform (order 𝒢);
b ← coin_spmf;
p ← 𝒜1 (f y);
_ ← assert_spmf (valid_plain (fst (fst p)) ∧ valid_plain (snd (fst p)));
ya ← sample_uniform (order 𝒢);
b' ← 𝒜2 (f ya,
xor (f (y * ya)) (if b then fst (fst p) else snd (fst p)))
(snd p);
return_spmf (b' ⟷ b)
}" for xor
by (simp)
end
text ‹Distributivity›
(* Test: a bind placed before an if-then-else versus the same bind placed
   inside the branches.
   Left: x is drawn from A before a and b, and both branches use x.
   Right: a and b are drawn first; the then-branch is just A and the
   else-branch draws x from A after y from C.  The condition is written
   b = a instead of a = b, and the sum y + x instead of x + y. *)
lemma
"do {
x ← A :: nat spmf;
a ← B;
b ← B;
if a = b then do {
return_spmf x
} else do {
y ← C;
return_spmf (x + y)
}
} = do {
a ← B;
b ← B;
if b = a then A else do {
y ← C;
x ← A;
return_spmf (y + x)
}
}"
(* add.commute is an extra simp rule (the sums are written in opposite order);
   if_cong is used as a congruence rule. *)
by (simp add: add.commute cong: if_cong)
(* Test: binds before an if-then-else versus binds inside the branches, with a
   nested do block on the left.
   Left: x from A, then p as the pair (a, b) of two draws from B built in a
   nested do block, then the coin q; the branches use fst p and snd p.
   Right: the coin q is drawn first and each branch performs its own draws
   from A, B and C, discarding the ones it does not use (bound to _).
   Note that the else-branch on the left uses snd p (the second B draw) while
   the else-branch on the right uses the first B draw and discards the second. *)
lemma
"do {
x ← A :: nat spmf;
p ← do {
a ← B;
b ← B;
return_spmf (a, b)
};
q ← coin_spmf;
if q then do {
return_spmf (x + fst p)
} else do {
y ← C;
return_spmf (y + snd p)
}
} = do {
q ← coin_spmf;
if q then do {
x ← A;
a ← B;
_ ← B;
return_spmf (x + a)
} else do {
y ← C;
a ← B;
_ ← B;
_ ← A;
return_spmf (y + a)
}
}"
(* if_cong is used as a congruence rule; no extra simp rules are added. *)
by (simp cong: if_cong)
(* Test: a bind placed before a case expression on a sum type versus the same
   bind placed inside the case branches; the computations are sets.
   Left: x is drawn from A before a and b, and the scrutinee is f a b.
   Right: a and b are drawn first, the scrutinee is f b a, the Inl branch is
   just A, and the Inr branch draws x from A before y from C x.
   The summands in the Inr branch are written in a different order. *)
lemma
fixes f :: "nat ⇒ nat ⇒ nat + nat"
shows
"do {
x ← (A::nat set);
a ← B;
b ← B;
case f a b of
Inl c ⇒ {x}
| Inr c ⇒ do {
y ← C x;
{(x + y + c)}
}
} = do {
a ← B;
b ← B;
case f b a of
Inl c ⇒ A
| Inr c ⇒ do {
x ← A;
y ← C x;
{(y + c + x)}
}
}"
(* add.commute and add.left_commute are extra simp rules (the three summands
   are ordered differently); sum.case_cong is used as a congruence rule. *)
by (simp add: add.commute add.left_commute cong: sum.case_cong)
section ‹Limits›
text ‹
The following example shows that the combination of monad normalisation and regular ordered
rewriting is not necessarily confluent.
›
(* Example for the non-confluence remark above: both sides draw a and b from A
   and return a pair whose first component is a ∧ b; the second component is
   b on the left and a on the right. *)
lemma "do {a ← A; b ← A; Some (a ∧ b, b)} =
do {a ← A; b ← A; Some (a ∧ b, a)}"
(* The trailing "?" makes this step optional: the script continues with the
   unchanged goal if simp with conj_comms fails. *)
apply (simp add: conj_comms)?
(* Commute the binds by hand with option_bind_commute ... *)
apply (rewrite option_bind_commute)
(* ... and finish using nothing but conj_comms. *)
apply (simp only: conj_comms)
done
text ‹
The next example shows that even monad normalisation alone is not confluent because
the term ordering prevents the reordering of ‹f A› with ‹f B›.
But if we change ‹A› to ‹E›, then the reordering works as expected.
›
(* Example for the remark above about f A and f B: in the option monad, the
   bind a from f A comes first on the left and after the binds b from f B and
   c from D b on the right. *)
lemma
"do {a ← f A; b ← f B; c ← D b; d ← f C; F a c d} =
do {b ← f B; c ← D b; a ← f A; d ← f C; F a c d}"
for f :: "'b ⇒ 'a option" and D :: "'a ⇒ 'a option"
(* simp is attempted but optional ("?"), so the script continues if it fails. *)
apply(simp)?
(* Manual proof: apply option_bind_commute once at the default position and
   once at the second matching position, then close the goal by reflexivity. *)
apply(subst option_bind_commute, subst (2) option_bind_commute, rule refl)
done
(* Same statement as the lemma directly above, with E in place of A.
   Here a plain simp call closes the goal. *)
lemma
"do {a ← f E; b ← f B; c ← D b; d ← f C; F a c d} =
do {b ← f B; c ← D b; a ← f E; d ← f C; F a c d}"
for f :: "'b ⇒ 'a option" and D :: "'a ⇒ 'a option"
by simp
end
end