(* Theory Monad_Normalisation_Test *)

(*  Title:      Monad_Normalisation_Test.thy
    Author:     Manuel Eberl, TU München
    Author:     Andreas Lochbihler, ETH Zurich
    Author:     Joshua Schneider, ETH Zurich
*)

theory Monad_Normalisation_Test
imports Monad_Normalisation
begin

section ‹Tests and examples›

context includes monad_normalisation
begin

(* The binds are reordered by the first simp call; the assumption f = id
   is then needed to identify the two return values. *)
lemma
  assumes "f = id"
  shows
    "do {x ← B; z ← C x; d ← E z x; a ← D z x; y ← A; return_pmf (x,y)} =
     do {y ← A; x ← B; z ← C x; a ← D z x; d ← E z x; return_pmf (f (x,y))}"
apply (simp)
apply (simp add: assms)
done

(* PMF monad: the binds of w and z occur in opposite orders on the two sides. *)
lemma "(do {a ← E; b ← E; w ← B b a; z ← B a b; return_pmf (w,z)}) =
       (do {a ← E; b ← E; z ← B a b; w ← B b a; return_pmf (w,z)})"
by (simp)

(* NOTE(review): this statement is character-for-character identical to the
   preceding lemma; possibly a different variant was intended — confirm. *)
lemma "(do {a ← E; b ← E; w ← B b a; z ← B a b; return_pmf (w,z)}) =
       (do {a ← E; b ← E; z ← B a b; w ← B b a; return_pmf (w,z)})"
by (simp)

(* Option monad: the two binds on A, and the two binds on B, appear in swapped order. *)
lemma "do {y ← A; x ← A; z ← B x y y; w ← B x x y; Some (x,y)} =
       do {x ← A; y ← A; z ← B x x y; w ← B x y y; Some (x,y)}"
by (simp)

(* Set monad: the two binds on A, and the two binds on B, appear in swapped order. *)
lemma "do {y ← A; x ← A; z ← B x y y; w ← B x x y; {x,y}} =
       do {x ← A; y ← A; z ← B x x y; w ← B x y y; {x,y}}"
by (simp)

(* PMF monad: the two binds on A, and the two binds on B, appear in swapped order. *)
lemma "do {y ← A; x ← A; z ← B x y y; w ← B x x y; return_pmf (x,y)} =
       do {x ← A; y ← A; z ← B x x y; w ← B x y y; return_pmf (x,y)}"
by (simp)

(* Predicate monad: only the two unused binds on B are swapped. *)
lemma "do {x ← A 0; y ← A x; w ← B y y; z ← B x y; a ← C; Predicate.single (a,a)} =
       do {x ← A 0; y ← A x; z ← B x y; w ← B y y; a ← C; Predicate.single (a,a)}"
by (simp)

(* PMF monad: the unused binds on B x y and B y y are swapped. *)
lemma "do {x ← A 0; y ← A x; z ← B x y; w ← B y y; a ← C; return_pmf (a,a)} =
       do {x ← A 0; y ← A x; z ← B y y; w ← B x y; a ← C; return_pmf (a,a)}"
by (simp)

(* Same bind reordering as the first lemma of this context, but without f. *)
lemma "do {x ← B; z ← C x; d ← E z x; a ← D z x; y ← A; return_pmf (x,y)} =
       do {y ← A; x ← B; z ← C x; a ← D z x; d ← E z x; return_pmf (x,y)}"
by (simp)


(* Remove bind_pmf from the do-notation overloading; presumably so that the
   spmf examples below are not ambiguous — confirm. *)
no_adhoc_overloading Monad_Syntax.bind bind_pmf

(* Abstract operations shared by the spmf examples in this context
   (e.g. elgamal_step3). *)
context
  fixes 𝒜1 :: "'a ⇒ (('a × 'a) × 'b) spmf"
  and 𝒜2 :: "'a × 'a ⇒ 'b ⇒ bool spmf"
  and sample_uniform :: "nat ⇒ nat spmf"
  and order :: "'a ⇒ nat"
begin

(* The coin b and the sample z are bound before 𝒜1 on the left, but after the
   assertion on the right (where z reappears as a shadowing x). *)
lemma
  "do {
      x ← sample_uniform (order 𝒢);
      y ← sample_uniform (order 𝒢);
      z ← sample_uniform (order 𝒢);
      b ← coin_spmf;
      ((msg1, msg2), σ) ← 𝒜1 (f x);
      _ :: unit ← assert_spmf (valid_plain msg1 ∧ valid_plain msg2);
      guess ← 𝒜2 (f y, xor (f z) (if b then msg1 else msg2)) σ;
      return_spmf (guess ⟷ b)
    } = do {
      x ← sample_uniform (order 𝒢);
      y ← sample_uniform (order 𝒢);
      ((msg1, msg2), σ) ← 𝒜1 (f x);
      _ :: unit ← assert_spmf (valid_plain msg1 ∧ valid_plain msg2);
      b ← coin_spmf;
      x ← sample_uniform (order 𝒢);
      guess ← 𝒜2 (f y, xor (f x) (if b then msg1 else msg2)) σ;
      return_spmf (guess ⟷ b)
    }" for xor
by (simp add: split_def)

(* The coin flip is moved to the end; on the right it is expressed with
   map_spmf, which the proof unfolds via map_spmf_conv_bind_spmf. *)
lemma
  "do {
      x ← sample_uniform (order 𝒢);
      xa ← sample_uniform (order 𝒢);
      x ← 𝒜1 (f x);
      case x of
      (x, xb) ⇒
        (case x of
         (msg1, msg2) ⇒
           λσ. do {
                a ← assert_spmf (valid_plain msg1 ∧ valid_plain msg2);
                x ← coin_spmf;
                xaa ← map_spmf f (sample_uniform (order 𝒢));
                guess ← 𝒜2 (f xa, xaa) σ;
                return_spmf (guess ⟷ x)
              })
         xb
    } = do {
      x ← sample_uniform (order 𝒢);
      xa ← sample_uniform (order 𝒢);
      x ← 𝒜1 (f x);
      case x of
      (x, xb) ⇒
        (case x of
         (msg1, msg2) ⇒
           λσ. do {
                a ← assert_spmf (valid_plain msg1 ∧ valid_plain msg2);
                z ← map_spmf f (sample_uniform (order 𝒢));
                guess ← 𝒜2 (f xa, z) σ;
                map_spmf ((⟷) guess) coin_spmf
              })
         xb
    }"
by (simp add: map_spmf_conv_bind_spmf)

(* Up to renaming (x becomes y, y becomes ya), the second sample is delayed
   until after the assertion. *)
lemma elgamal_step3:
  "do {
      x ← sample_uniform (order 𝒢);
      y ← sample_uniform (order 𝒢);
      b ← coin_spmf;
      p ← 𝒜1 (f x);
      _ ← assert_spmf (valid_plain (fst (fst p)) ∧ valid_plain (snd (fst p)));
      guess ←
        𝒜2 (f y, xor (f (x * y)) (if b then fst (fst p) else snd (fst p)))
         (snd p);
      return_spmf (guess ⟷ b)
    } = do {
      y ← sample_uniform (order 𝒢);
      b ← coin_spmf;
      p ← 𝒜1 (f y);
      _ ← assert_spmf (valid_plain (fst (fst p)) ∧ valid_plain (snd (fst p)));
      ya ← sample_uniform (order 𝒢);
      b' ← 𝒜2 (f ya,
                 xor (f (y * ya)) (if b then fst (fst p) else snd (fst p)))
             (snd p);
      return_spmf (b' ⟷ b)
    }" for xor
by (simp)

end

text ‹Distributivity›

(* The bind on A is pushed into both branches of the if; in the then-branch
   it collapses to A itself. *)
lemma
  "do {
      x ← A :: nat spmf;
      a ← B;
      b ← B;
      if a = b then do {
        return_spmf x
      } else do {
        y ← C;
        return_spmf (x + y)
      }
   } = do {
      a ← B;
      b ← B;
      if b = a then A else do {
        y ← C;
        x ← A;
        return_spmf (y + x)
      }
   }"
by (simp add: add.commute cong: if_cong)

(* A nested do-block is flattened, and the surrounding binds are distributed
   into the branches of the if on the coin q. *)
lemma
  "do {
      x ← A :: nat spmf;
      p ← do {
        a ← B;
        b ← B;
        return_spmf (a, b)
      };
      q ← coin_spmf;
      if q then do {
        return_spmf (x + fst p)
      } else do {
        y ← C;
        return_spmf (y + snd p)
      }
   } = do {
      q ← coin_spmf;
      if q then do {
        x ← A;
        a ← B;
        _ ← B;
        return_spmf (x + a)
      } else do {
        y ← C;
        a ← B;
        _ ← B;
        _ ← A;
        return_spmf (y + a)
      }
   }"
by (simp cong: if_cong)

(* Set monad: the bind on A is distributed over the branches of a case on a
   sum type. *)
lemma
  fixes f :: "nat ⇒ nat ⇒ nat + nat"
  shows
  "do {
      x ← (A::nat set);
      a ← B;
      b ← B;
      case f a b of
        Inl c ⇒ {x}
      | Inr c ⇒ do {
          y ← C x;
          {(x + y + c)}
        }
   } = do {
      a ← B;
      b ← B;
      case f b a of
        Inl c ⇒ A
      | Inr c ⇒ do {
          x ← A;
          y ← C x;
          {(y + c + x)}
      }
   }"
by (simp add: add.commute add.left_commute cong: sum.case_cong)


section ‹Limits›

text ‹
  The following example shows that the combination of monad normalisation and regular ordered
  rewriting is not necessarily confluent.
›

(* simp with conj_comms makes no progress here; the goal is only closed after
   fixing the binder order explicitly with option_bind_commute. *)
lemma "do {a ← A; b ← A; Some (a ∧ b, b)} =
       do {a ← A; b ← A; Some (a ∧ b, a)}"
apply (simp add: conj_comms)?       ― ‹no progress made›
apply (rewrite option_bind_commute) ― ‹force a particular binder order›
apply (simp only: conj_comms)
done

text ‹
  The next example shows that even monad normalisation alone is not confluent, because
  the term ordering prevents the reordering of ‹f A› with ‹f B›.
  But if we change ‹A› to ‹E›, then the reordering works as expected.
›

(* simp makes no progress; two explicit option_bind_commute steps are needed. *)
lemma
  "do {a ← f A; b ← f B; c ← D b; d ← f C; F a c d} =
   do {b ← f B; c ← D b; a ← f A; d ← f C; F a c d}"
  for f :: "'b ⇒ 'a option" and D :: "'a ⇒ 'a option"
  apply(simp)? ― ‹no progress made›
  apply(subst option_bind_commute, subst (2) option_bind_commute, rule refl)
  done

(* Same statement as the previous lemma with A replaced by E; here simp alone
   proves it. *)
lemma
  "do {a ← f E; b ← f B; c ← D b; d ← f C; F a c d} =
   do {b ← f B; c ← D b; a ← f E; d ← f C; F a c d}"
  for f :: "'b ⇒ 'a option" and D :: "'a ⇒ 'a option"
  by simp

end

end