(* Theory PRF_UPF_IND_CCA *)

(* Title: PRF_UPF_IND_CCA.thy
  Author: Andreas Lochbihler, ETH Zurich 
  Author: S. Reza Sefidgar, ETH Zurich *)

subsection ‹IND-CCA from a PRF and an unpredictable function›

theory PRF_UPF_IND_CCA
imports
  Pseudo_Random_Function
  CryptHOL.List_Bits
  Unpredictable_Function
  IND_CCA2_sym
  CryptHOL.Negligible
begin

text ‹Formalisation of Shoup's construction of an IND-CCA secure cipher from a PRF and an unpredictable function \<^cite>‹‹\S 7› in "Shoup2004IACR"›.›

(* Bit strings are represented as lists of booleans. *)
type_synonym bitstring = "bool list"

(* Locale fixing the two primitives used by the cipher:
     - a PRF (prf_key_gen, prf_fun); the third parameter of the prf locale is
       instantiated with the uniform distribution on bit strings of length prf_clen;
     - an unpredictable function (upf_key_gen, upf_fun).
   Assumptions: the PRF domain is finite and non-empty, all its elements have
   length prf_dlen, PRF outputs on the domain (for generated keys) have length
   prf_clen, and both key generators are lossless.
   NOTE(review): prf_range is fixed here but not constrained by any assumption
   in this header. *)
locale simple_cipher = 
  PRF: "prf" prf_key_gen prf_fun "spmf_of_set (nlists UNIV prf_clen)" +
  UPF: upf upf_key_gen upf_fun
  for prf_key_gen :: "'prf_key spmf"
  and prf_fun :: "'prf_key ⇒ bitstring ⇒ bitstring"
  and prf_domain :: "bitstring set"
  and prf_range :: "bitstring set"
  and prf_dlen :: nat
  and prf_clen :: nat
  and upf_key_gen :: "'upf_key spmf"
  and upf_fun :: "'upf_key ⇒ bitstring ⇒ 'hash"
  +
  assumes prf_domain_finite: "finite prf_domain"
  assumes prf_domain_nonempty: "prf_domain ≠ {}"
  assumes prf_domain_length:  "x ∈ prf_domain ⟹ length x = prf_dlen"
  assumes prf_codomain_length: 
    "⟦ key_prf ∈ set_spmf prf_key_gen; m ∈ prf_domain ⟧ ⟹ length (prf_fun key_prf m) = prf_clen"
  assumes prf_key_gen_lossless: "lossless_spmf prf_key_gen"
  assumes upf_key_gen_lossless: "lossless_spmf upf_key_gen"
begin

(* A cipher text is a triple of two bit strings and a hash value; encrypt below
   fills it with (PRF input x, masked message c, tag t). *)
type_synonym 'hash' cipher_text = "bitstring × bitstring × 'hash'"

(* Key generation: sample a PRF key and a UPF key independently and pair them. *)
definition key_gen :: "('prf_key × 'upf_key) spmf" where
 "key_gen = do {
   k_prf ← prf_key_gen;
   k_upf :: 'upf_key ← upf_key_gen;
   return_spmf (k_prf, k_upf)
 }"

(* key_gen is lossless because both component key generators are assumed lossless. *)
lemma lossless_key_gen [simp]: "lossless_spmf key_gen"
  by(simp add: key_gen_def prf_key_gen_lossless upf_key_gen_lossless)

(* Encryption: pick x uniformly from prf_domain, mask the message with the PRF
   output on x (bitwise xor on lists), and tag the concatenation x @ c with the
   unpredictable function. The cipher text is the triple (x, c, t). *)
fun encrypt :: "('prf_key × 'upf_key) ⇒ bitstring ⇒ 'hash cipher_text spmf"
where
  "encrypt (k_prf, k_upf) m = do {
    x ← spmf_of_set prf_domain;
    let c = prf_fun k_prf x [⊕] m;
    let t = upf_fun k_upf (x @ c);
    return_spmf ((x, c, t))
  }"

(* Encryption never fails: the only sampling step draws from prf_domain, which
   is finite and non-empty by the locale assumptions. *)
lemma lossless_encrypt [simp]: "lossless_spmf (encrypt k m)"
  by (cases k) (simp add: Let_def prf_domain_nonempty prf_domain_finite split: bool.split)

(* Decryption: accept only if the tag t matches upf_fun on x @ c and x has the
   PRF input length prf_dlen; then unmask c with the PRF output on x.
   Otherwise reject with None. *)
fun decrypt :: "('prf_key × 'upf_key) ⇒ 'hash cipher_text ⇒ bitstring option"
where
  "decrypt (k_prf, k_upf) (x, c, t) = (
    if upf_fun k_upf (x @ c) = t ∧ length x = prf_dlen then
      Some (prf_fun k_prf x [⊕] c)
    else
      None
  )"

(* Correctness: for a generated key and a message of length prf_clen,
   decrypting a fresh encryption always yields the original message. *)
lemma cipher_correct:
  "⟦ k ∈ set_spmf key_gen; length m = prf_clen ⟧
  ⟹ encrypt k m ⤜ (λc. return_spmf (decrypt k c)) = return_spmf (Some m)"
by (cases k) (simp add: prf_domain_nonempty prf_domain_finite prf_domain_length
  prf_codomain_length key_gen_def bind_eq_return_spmf Let_def)

(* Keep the defining equation of encrypt out of the default simpset; the proofs
   below add encrypt.simps explicitly where they need to unfold it. *)
declare encrypt.simps[simp del]

(* ind_cca: the IND-CCA locale instantiated with the real cipher, admitting
   messages of length prf_clen.
   ind_cca': the same instantiation except that the decryption function always
   returns None. *)
sublocale ind_cca: ind_cca key_gen encrypt decrypt "λm. length m = prf_clen" .
interpretation ind_cca': ind_cca key_gen encrypt "λ _ _. None" "λm. length m = prf_clen" .

(* Encryption part of the UPF reduction's interceptor. The state (L, D) holds
   the set L of cipher texts handed out and a second set D, which this function
   leaves unchanged. If either message has the wrong length the answer is None.
   Otherwise x is sampled, the message selected by b is masked with the PRF
   under key k, and the tag is obtained by an external Inl query on x @ c
   (projl extracts the left component of the response). The new cipher text is
   recorded in L. *)
definition intercept_upf_enc
  :: "'prf_key ⇒ bool ⇒ 'hash cipher_text set × 'hash cipher_text set ⇒ bitstring × bitstring
  ⇒ ('hash cipher_text option × ('hash cipher_text set × 'hash cipher_text set),
     bitstring + (bitstring × 'hash), 'hash + unit) gpv" 
where 
  "intercept_upf_enc k b = (λ(L, D) (m1, m0).
    (case (length m1 = prf_clen ∧ length m0 = prf_clen) of
      False ⇒ Done (None, L, D)
    | True ⇒ do {
        x ← lift_spmf (spmf_of_set prf_domain);
        let c = prf_fun k x [⊕] (if b then m1 else m0);
        t ← Pause (Inl (x @ c)) Done;
        Done ((Some (x, c, projl t)), (insert (x, c, projl t) L, D))
      }))"

(* Decryption part of the UPF reduction's interceptor. Every decryption query
   is answered with None. Cipher texts that were handed out before (in L) or
   whose first component has the wrong length are ignored; any other cipher
   text is forwarded as an Inr query (x @ c, t) and recorded in D. *)
definition intercept_upf_dec
  :: "'hash cipher_text set × 'hash cipher_text set ⇒ 'hash cipher_text
  ⇒ (bitstring option × ('hash cipher_text set × 'hash cipher_text set),
     bitstring + (bitstring × 'hash), 'hash + unit) gpv" 
where
  "intercept_upf_dec = (λ(L, D) (x, c, t).
    if (x, c, t) ∈ L ∨ length x ≠ prf_dlen then Done (None, (L, D)) else do {
      Pause (Inr (x @ c, t)) Done;
      Done (None, (L, insert (x, c, t) D))
    })"

(* Combined interceptor: Inl queries (message pairs) go to intercept_upf_enc,
   Inr queries (cipher texts) go to intercept_upf_dec, glued by plus_intercept. *)
definition intercept_upf :: 
  "'prf_key ⇒ bool ⇒ 'hash cipher_text set × 'hash cipher_text set ⇒ bitstring × bitstring + 'hash cipher_text
  ⇒ (('hash cipher_text option + bitstring option) × ('hash cipher_text set × 'hash cipher_text set),
     bitstring + (bitstring × 'hash), 'hash + unit) gpv" 
where
  "intercept_upf k b = plus_intercept (intercept_upf_enc k b) intercept_upf_dec"

(* Unfolded characterisation of intercept_upf on each kind of query: the
   bodies of intercept_upf_dec / intercept_upf_enc with results wrapped in
   Inr / Inl respectively. Registered as simp rules. *)
lemma intercept_upf_simps [simp]:
  "intercept_upf k b (L, D) (Inr (x, c, t)) =
    (if (x, c, t) ∈ L ∨ length x ≠ prf_dlen then Done (Inr None, (L, D)) else do {
      Pause (Inr (x @ c, t)) Done;
      Done (Inr None, (L, insert (x, c, t) D))
    })"
  "intercept_upf k b (L, D) (Inl (m1, m0)) = 
    (case (length m1 = prf_clen ∧ length m0 = prf_clen) of
      False ⇒ Done (Inl None, L, D)
    | True ⇒ do {
        x ← lift_spmf (spmf_of_set prf_domain);
        let c = prf_fun k x [⊕] (if b then m1 else m0);
        t ← Pause (Inl (x @ c)) Done;
        Done (Inl (Some (x, c, projl t)), (insert (x, c, projl t) L, D))
      })"
   by(simp_all add: intercept_upf_def intercept_upf_dec_def intercept_upf_enc_def o_def map_gpv_bind_gpv gpv.map_id Let_def split!: bool.split)


(* The encryption interceptor issues no Inr queries (its only query is Inl). *)
lemma interaction_bounded_by_upf_enc_Inr [interaction_bound]:
  "interaction_bounded_by (Not ∘ isl) (intercept_upf_enc k b LD mm) 0"
unfolding intercept_upf_enc_def case_prod_app
by(interaction_bound, clarsimp simp add: SUP_constant bot_enat_def split: prod.split)

(* The decryption interceptor issues at most one Inr query per call. *)
lemma interaction_bounded_by_upf_dec_Inr [interaction_bound]:
  "interaction_bounded_by (Not ∘ isl) (intercept_upf_dec LD c) 1"
unfolding intercept_upf_dec_def case_prod_app
by(interaction_bound, clarsimp simp add: SUP_constant split: prod.split)

(* For an arbitrary query, the combined interceptor issues at most one Inr query. *)
lemma interaction_bounded_by_intercept_upf_Inr [interaction_bound]:
  "interaction_bounded_by (Not ∘ isl) (intercept_upf k b LD x) 1"
unfolding intercept_upf_def 
by interaction_bound(simp add: split_def one_enat_def SUP_le_iff split: sum.split)

(* On an Inl (encryption) query, the combined interceptor issues no Inr query. *)
lemma interaction_bounded_by_intercept_upf_Inl [interaction_bound]:
  "isl x ⟹ interaction_bounded_by (Not ∘ isl) (intercept_upf k b LD x) 0"
unfolding intercept_upf_def case_prod_app
by interaction_bound(auto split: sum.split)

(* The encryption interceptor is lossless w.r.t. the unrestricted interface on
   both query kinds; this relies on prf_domain being finite and non-empty. *)
lemma lossless_intercept_upf_enc [simp]: "lossless_gpv (ℐ_full ⊕⇩ℐ ℐ_full) (intercept_upf_enc k b LD mm)"
by(simp add: intercept_upf_enc_def split_beta prf_domain_finite prf_domain_nonempty Let_def split: bool.split)

(* The decryption interceptor is lossless w.r.t. the unrestricted interface. *)
lemma lossless_intercept_upf_dec [simp]: "lossless_gpv (ℐ_full ⊕⇩ℐ ℐ_full) (intercept_upf_dec LD mm)"
by(simp add: intercept_upf_dec_def split_beta)

(* Losslessness of the combined interceptor, by case distinction on the query. *)
lemma lossless_intercept_upf [simp]: "lossless_gpv (ℐ_full ⊕⇩ℐ ℐ_full) (intercept_upf k b LD x)"
by(cases x)(simp_all add: intercept_upf_def)

(* Results of the interceptor are well-typed responses for the query x, paired
   with an arbitrary state. *)
lemma results_gpv_intercept_upf [simp]: "results_gpv (ℐ_full ⊕⇩ℐ ℐ_full) (intercept_upf k b LD x) ⊆ responses_ℐ (ℐ_full ⊕⇩ℐ ℐ_full) x × UNIV"
by(cases x)(auto simp add: intercept_upf_def)

(* Reduction from an IND-CCA adversary to a UPF adversary: sample a PRF key
   and the challenge bit itself, run the adversary with its oracle calls
   translated by intercept_upf (starting from empty sets L and D), and discard
   the adversary's result. *)
definition reduction_upf :: "(bitstring, 'hash cipher_text) ind_cca.adversary
  ⇒ (bitstring, 'hash) UPF.adversary"
where "reduction_upf 𝒜 = do {
    k ← lift_spmf prf_key_gen;
    b ← lift_spmf coin_spmf;
    (_, (L, D)) ← inline (intercept_upf k b) 𝒜 ({}, {});
    Done () }"

(* The UPF reduction is lossless whenever the wrapped adversary is. *)
lemma lossless_reduction_upf [simp]: 
  "lossless_gpv (ℐ_full ⊕⇩ℐ ℐ_full) 𝒜 ⟹ lossless_gpv (ℐ_full ⊕⇩ℐ ℐ_full) (reduction_upf 𝒜)"
by(auto simp add: reduction_upf_def prf_key_gen_lossless intro: lossless_inline del: subsetI)

context includes lifting_syntax begin

(* Game hop 1: replacing the real decryption oracle by one that always answers
   None changes the adversary's success probability by at most the UPF
   advantage of reduction_upf 𝒜.
   Proof outline:
     1. Instrument both decryption oracles with a bad flag (oracle_decrypt0',
        oracle_decrypt1') that is set when a cipher text outside L with a valid
        tag and a well-formed x is submitted.
     2. Show both instrumented games project to the original games.
     3. The two games are identical until bad (fundamental lemma).
     4. The probability of bad in the second game equals the UPF game run with
        reduction_upf 𝒜. *)
lemma round_1:
  assumes "lossless_gpv (ℐ_full ⊕⇩ℐ ℐ_full) 𝒜"
  shows "¦spmf (ind_cca.game 𝒜) True - spmf (ind_cca'.game 𝒜) True¦ ≤ UPF.advantage (reduction_upf 𝒜)" 
proof -
  (* Real decryption oracle, additionally tracking the bad flag. *)
  define oracle_decrypt0' where "oracle_decrypt0' ≡ (λkey (bad, L) (x', c', t'). return_spmf (
      if (x', c', t') ∈ L ∨ length x' ≠ prf_dlen then (None, (bad, L))
      else (decrypt key (x', c', t'), (bad ∨ upf_fun (snd key) (x' @ c') = t', L))))"
  have oracle_decrypt0'_simps:
    "oracle_decrypt0' key (bad, L) (x', c', t') = return_spmf (
       if (x', c', t') ∈ L ∨ length x' ≠ prf_dlen then (None, (bad, L))
       else (decrypt key (x', c', t'), (bad ∨ upf_fun (snd key) (x' @ c') = t', L)))"
    for key L bad x' c' t' by(simp add: oracle_decrypt0'_def)
  have lossless_oracle_decrypt0' [simp]: "lossless_spmf (oracle_decrypt0' k Lbad c)" for k Lbad c
    by(simp add: oracle_decrypt0'_def split_def)
  have callee_invariant_oracle_decrypt0' [simp]: "callee_invariant (oracle_decrypt0' k) fst" for k
    by (unfold_locales) (auto simp add: oracle_decrypt0'_def split: if_split_asm)

  (* Always-rejecting decryption oracle with the same bad flag. *)
  define oracle_decrypt1'
    where "oracle_decrypt1' = (λ(key :: 'prf_key × 'upf_key) (bad, L) (x', c', t'). 
      return_spmf (None :: bitstring option,
        (bad ∨ upf_fun (snd key) (x' @ c') = t' ∧ (x', c', t') ∉ L ∧ length x' = prf_dlen), L))"
  have oracle_decrypt1'_simps:
    "oracle_decrypt1' key (bad, L) (x', c', t') = 
    return_spmf (None, 
      (bad ∨ upf_fun (snd key) (x' @ c') = t' ∧ (x', c', t') ∉ L ∧ length x' = prf_dlen, L))"
    for key L bad x' c' t' by(simp add: oracle_decrypt1'_def)
  have lossless_oracle_decrypt1' [simp]: "lossless_spmf (oracle_decrypt1' k Lbad c)" for k Lbad c
    by(simp add: oracle_decrypt1'_def split_def)
  have callee_invariant_oracle_decrypt1' [simp]: "callee_invariant (oracle_decrypt1' k) fst" for k
    by (unfold_locales) (auto simp add: oracle_decrypt1'_def)

  (* Common game skeleton, parametrised by the instrumented decryption oracle;
     it returns the success bit together with the final bad flag. *)
  define game01'
    where "game01' = (λ(decrypt :: 'prf_key × 'upf_key ⇒ (bitstring × bitstring × 'hash, bitstring option, bool × (bitstring × bitstring × 'hash) set) callee) 𝒜. do {
    key ← key_gen;
    b ← coin_spmf;
    (b', (bad', L')) ← exec_gpv (†(ind_cca.oracle_encrypt key b) ⊕⇩O decrypt key) 𝒜 (False, {});
    return_spmf (b = b', bad') })"
  let ?game0' = "game01' oracle_decrypt0'"
  let ?game1' = "game01' oracle_decrypt1'"

  have game0'_eq: "ind_cca.game 𝒜 = map_spmf fst (?game0' 𝒜)" (is ?game0)
    and game1'_eq: "ind_cca'.game 𝒜 = map_spmf fst (?game1' 𝒜)" (is ?game1)
  proof -
    let ?S = "rel_prod2 (=)"
    define initial where "initial = (False, {} :: 'hash cipher_text set)"
    have [transfer_rule]: "?S {} initial" by(simp add: initial_def)

    have [transfer_rule]: 
      "((=) ===> ?S ===> (=) ===> rel_spmf (rel_prod (=) ?S))
       ind_cca.oracle_decrypt oracle_decrypt0'"
      unfolding ind_cca.oracle_decrypt_def[abs_def] oracle_decrypt0'_def[abs_def]
      by(simp add: rel_spmf_return_spmf1 rel_fun_def)

    have [transfer_rule]: 
      "((=) ===> ?S ===> (=) ===> rel_spmf (rel_prod (=) ?S))
       ind_cca'.oracle_decrypt oracle_decrypt1'"
      unfolding ind_cca'.oracle_decrypt_def[abs_def] oracle_decrypt1'_def[abs_def]
      by (simp add: rel_spmf_return_spmf1 rel_fun_def)

    note [transfer_rule] = extend_state_oracle_transfer
    show ?game0 ?game1 unfolding game01'_def ind_cca.game_def ind_cca'.game_def initial_def[symmetric]
      by (simp_all add: map_spmf_bind_spmf o_def split_def) transfer_prover+
  qed

  (* Identical-until-bad: both executions agree on the bad flag, and on the
     adversary's output whenever bad is not set. *)
  have *: "rel_spmf (λ(b'1, (bad1, L1)) (b'2, (bad2, L2)). bad1 = bad2 ∧ (¬ bad2 ⟶ b'1 = b'2))
         (exec_gpv (†(ind_cca.oracle_encrypt k b) ⊕⇩O oracle_decrypt1' k) 𝒜 (False, {}))
         (exec_gpv (†(ind_cca.oracle_encrypt k b) ⊕⇩O oracle_decrypt0' k) 𝒜 (False, {}))"
    for k b
    by (cases k; rule exec_gpv_oracle_bisim_bad[where X="(=)" and ?bad1.0=fst and ?bad2.0=fst and ℐ = "ℐ_full ⊕⇩ℐ ℐ_full"])
       (auto intro: rel_spmf_reflI callee_invariant_extend_state_oracle_const' simp add: spmf_rel_map1 spmf_rel_map2 oracle_decrypt0'_simps oracle_decrypt1'_simps assms split: plus_oracle_split)
    ― ‹We cannot get rid of the losslessness assumption on @{term 𝒜} in this step, because if
      @{term 𝒜} were not lossless, then the bad event might still occur, but the adversary does
      not terminate in the case of @{term "?game1'"}. Thus, the reduction does not terminate
      either, but it cannot detect whether the bad event has happened. So the advantage in the
      UPF game could be lower than the probability of the bad event, if the adversary is not
      lossless.›
  have "¦measure (measure_spmf (?game1' 𝒜)) {(b, bad). b} - measure (measure_spmf (?game0' 𝒜)) {(b, bad). b}¦ 
     ≤ measure (measure_spmf (?game1' 𝒜)) {(b, bad). bad}"
    by (rule fundamental_lemma[where ?bad2.0=snd])(auto intro!: rel_spmf_bind_reflI rel_spmf_bindI[OF *] simp add: game01'_def)
  also have "… = spmf (map_spmf snd (?game1' 𝒜)) True"
    by (simp add: spmf_conv_measure_spmf measure_map_spmf split_def vimage_def)
  also have "map_spmf snd (?game1' 𝒜) = UPF.game (reduction_upf 𝒜)"
  proof -
    note [split del] = if_split
    have "map_spmf (λx. fst (snd x)) (exec_gpv (†(ind_cca.oracle_encrypt (k_prf, k_upf) b) ⊕⇩O oracle_decrypt1' (k_prf, k_upf)) 𝒜 (False, {})) = 
        map_spmf (λx. fst (snd x)) (exec_gpv (UPF.oracle k_upf) (inline (intercept_upf k_prf b) 𝒜 ({}, {})) (False, {}))"
      (is "map_spmf ?fl ?lhs = map_spmf ?fr ?rhs" is "map_spmf _ (exec_gpv ?oracle_normal _ ?init_normal) = _")
      for k_prf k_upf b
    proof(rule map_spmf_eq_map_spmfI)
      (* The interceptor composed with the UPF oracle, seen as a single oracle
         on the combined state. *)
      define oracle_intercept
        where [simp]: "oracle_intercept = (λ(s', s) y. map_spmf (λ((x, s'), s). (x, s', s))
         (exec_gpv (UPF.oracle k_upf) (intercept_upf k_prf b s' y) s))"
      (* Invariant: every handed-out cipher text carries a valid tag and a
         well-formed x; every string in Li stems from a cipher text in L; and
         the UPF flag is set iff D contains a cipher text with a valid tag. *)
      let ?I = "(λ((L, D), (flg, Li)).
          (∀(x, c, t) ∈ L. upf_fun k_upf (x @ c) = t ∧ length x = prf_dlen) ∧
          (∀e∈Li. ∃(x,c,_) ∈ L. e = x @ c) ∧
          ((∃(x, c, t) ∈ D. upf_fun k_upf (x @ c) = t) ⟷ flg))"
      interpret callee_invariant_on oracle_intercept "?I" ℐ_full
        apply(unfold_locales)
        subgoal for s x y s'
          apply(cases s; cases s'; cases x)
           apply(clarsimp simp add: set_spmf_of_set_finite[OF prf_domain_finite]
                UPF.oracle_hash_def prf_domain_length exec_gpv_bind Let_def split: bool.splits)
          apply(force simp add: exec_gpv_bind UPF.oracle_flag_def split: if_split_asm)
          done
        subgoal by simp
        done
      
      define S :: "bool × 'hash cipher_text set ⇒ ('hash cipher_text set × 'hash cipher_text set) ×  bool × bitstring set ⇒ bool"
        where "S = (λ(bad, L1) ((L2, D), _). bad = (∃(x, c, t)∈D. upf_fun k_upf (x @ c) = t) ∧ L1 = L2) ↿ (λ_. True) ⊗ ?I"
      define initial :: "('hash cipher_text set × 'hash cipher_text set) ×  bool × bitstring set"
        where "initial = (({}, {}), (False, {}))"
      have [transfer_rule]: "S ?init_normal initial" by(simp add: S_def initial_def)
      have [transfer_rule]: "(S ===> (=) ===> rel_spmf (rel_prod (=) S)) ?oracle_normal oracle_intercept"
        unfolding S_def
        by(rule callee_invariant_restrict_relp, unfold_locales)
          (auto simp add: rel_fun_def bind_spmf_of_set prf_domain_finite prf_domain_nonempty bind_spmf_pmf_assoc bind_assoc_pmf bind_return_pmf spmf_rel_map exec_gpv_bind Let_def ind_cca.oracle_encrypt_def oracle_decrypt1'_def encrypt.simps UPF.oracle_hash_def UPF.oracle_flag_def bind_map_spmf o_def split: plus_oracle_split bool.split if_split intro!: rel_spmf_bind_reflI rel_pmf_bind_reflI)
      have "rel_spmf (rel_prod (=) S) ?lhs (exec_gpv oracle_intercept 𝒜 initial)"
        by(transfer_prover)
      then show "rel_spmf (λx y. ?fl x = ?fr y) ?lhs ?rhs"
        by(auto simp add: S_def exec_gpv_inline spmf_rel_map initial_def elim: rel_spmf_mono)
    qed
    then show ?thesis including monad_normalisation
      by(auto simp add: reduction_upf_def UPF.game_def game01'_def key_gen_def map_spmf_conv_bind_spmf split_def exec_gpv_bind intro!: bind_spmf_cong[OF refl])
  qed
  finally show ?thesis using game0'_eq game1'_eq 
    by (auto simp add: spmf_conv_measure_spmf measure_map_spmf vimage_def fst_def UPF.advantage_def)
qed


(* Encryption oracle of game 2: the PRF is replaced by a lazily sampled table D.
   For the sampled x, the pad p is the stored value D x if present and the
   fresh uniform sample P otherwise; D is updated to map x to p.
   The k_prf component of the key is not used. *)
definition oracle_encrypt2 :: 
  "('prf_key × 'upf_key) ⇒ bool ⇒ (bitstring, bitstring) PRF.dict ⇒ bitstring × bitstring ⇒
     ('hash cipher_text option × (bitstring, bitstring) PRF.dict) spmf"
where
  "oracle_encrypt2 = (λ(k_prf, k_upf) b D (msg1, msg0). (case (length msg1 = prf_clen ∧ length msg0 = prf_clen) of
      False ⇒ return_spmf (None, D)
    | True ⇒ do {
        x ← spmf_of_set prf_domain;
        P ← spmf_of_set (nlists UNIV prf_clen);
        let p = (case D x of Some r ⇒ r | None ⇒ P);
        let c = p [⊕] (if b then msg1 else msg0);
        let t = upf_fun k_upf (x @ c);
        return_spmf (Some (x, c, t), D(x ↦ p)) 
      }))"

(* Decryption oracle used from game 2 onwards: ignores the key and the cipher
   text, always answers None, and leaves the (polymorphic) state unchanged. *)
definition oracle_decrypt2:: "('prf_key × 'upf_key) ⇒ ('hash cipher_text, bitstring option, 'state) callee"
where "oracle_decrypt2 = (λkey D cipher. return_spmf (None, D))"

(* oracle_decrypt2 always returns, so it is lossless. *)
lemma lossless_oracle_decrypt2 [simp]: "lossless_spmf (oracle_decrypt2 k Dbad c)"
  by(simp add: oracle_decrypt2_def split_def)

(* oracle_decrypt2 does not modify the state, so the first state component
   (used as a bad flag in later games) is preserved once set. *)
lemma callee_invariant_oracle_decrypt2 [simp]: "callee_invariant (oracle_decrypt2 key) fst"
  by (unfold_locales) (auto simp add: oracle_decrypt2_def split: if_split_asm)

(* Parametricity of oracle_decrypt2 in the key, the state and the hash
   component of the cipher text; registered for the transfer prover. *)
lemma oracle_decrypt2_parametric [transfer_rule]:
  "(rel_prod P U ===> S ===> rel_prod (=) (rel_prod (=) H) ===> rel_spmf (rel_prod (=) S))
   oracle_decrypt2 oracle_decrypt2"
  unfolding oracle_decrypt2_def split_def relator_eq[symmetric] by transfer_prover

(* Game 2: the IND-CCA game with the PRF replaced by the lazily sampled table
   (oracle_encrypt2) and the always-rejecting decryption oracle. Returns
   whether the adversary guessed the challenge bit. *)
definition game2 :: "(bitstring, 'hash cipher_text) ind_cca.adversary ⇒ bool spmf" 
where 
  "game2 𝒜 ≡ do {
    key ← key_gen;
    b ← coin_spmf;
    (b', D) ← exec_gpv 
      (oracle_encrypt2 key b ⊕⇩O oracle_decrypt2 key) 𝒜 Map_empty;
    return_spmf (b = b')
  }"

(* Interceptor for the PRF reduction (stateless: the state is unit).
   Decryption queries (Inr) are answered with None without any external call.
   Encryption queries (Inl) with well-sized messages sample x, obtain the pad p
   by querying the external oracle on x, mask the message selected by b, and
   compute the tag locally with the UPF key k. *)
fun intercept_prf :: 
  "'upf_key ⇒ bool ⇒ unit ⇒ (bitstring × bitstring) + 'hash cipher_text
  ⇒ (('hash cipher_text option + bitstring option) × unit, bitstring, bitstring) gpv" 
where
  "intercept_prf _ _ _ (Inr _) = Done (Inr None, ())"
| "intercept_prf k b _ (Inl (m1, m0)) = (case (length m1) = prf_clen ∧ (length m0) = prf_clen of
      False ⇒ Done (Inl None, ())
    | True ⇒ do {
        x ← lift_spmf (spmf_of_set prf_domain);
        p ← Pause x Done;
        let c = p [⊕] (if b then m1 else m0);
        let t = upf_fun k (x @ c);
        Done (Inl (Some (x, c, t)), ())
      })"

(* Reduction from an IND-CCA adversary to a PRF adversary: sample a UPF key
   and the challenge bit, run the adversary through intercept_prf, and output
   whether its guess equals the bit. *)
definition reduction_prf
  :: "(bitstring, 'hash cipher_text) ind_cca.adversary ⇒ (bitstring, bitstring) PRF.adversary"
where 
 "reduction_prf 𝒜 = do {
   k ← lift_spmf upf_key_gen;
   b ← lift_spmf coin_spmf;
   (b', _) ← inline (intercept_prf k b) 𝒜 ();
   Done (b' = b)
 }"

(* Game hop 2: the distance between ind_cca'.game and game2 is exactly the PRF
   advantage of reduction_prf 𝒜.
   Proof outline: ind_cca'.game is rewritten into a stateless variant game1'',
   which equals PRF.game_0 of the reduction; game2 equals PRF.game_1 of the
   reduction; the claim then follows from the definition of PRF.advantage. *)
lemma round_2: "¦spmf (ind_cca'.game 𝒜) True - spmf (game2 𝒜) True¦ = PRF.advantage (reduction_prf 𝒜)" 
proof -
  (* Encryption oracle of ind_cca' with the list of handed-out cipher texts
     dropped from the state. *)
  define oracle_encrypt1''
    where "oracle_encrypt1'' = (λ(k_prf, k_upf) b (_ :: unit) (msg1, msg0). 
      case length msg1 = prf_clen ∧ length msg0 = prf_clen of
        False ⇒ return_spmf (None, ())
      | True ⇒ do {
          x ← spmf_of_set prf_domain;
          let p = prf_fun k_prf x;
          let c = p [⊕] (if b then msg1 else msg0);
          let t = upf_fun k_upf (x @ c);
          return_spmf (Some (x, c, t), ())})"
  define game1'' where "game1'' = do {
    key ← key_gen;
    b ← coin_spmf;
    (b', D) ← exec_gpv (oracle_encrypt1'' key b ⊕⇩O oracle_decrypt2 key) 𝒜 ();
    return_spmf (b = b')}"

  have "ind_cca'.game 𝒜 = game1''"
  proof -
    define S where "S = (λ(L :: 'hash cipher_text set) (D :: unit). True)"
    have [transfer_rule]: "S {} ()" by (simp add: S_def)
    have [transfer_rule]: 
      "((=) ===> (=) ===> S ===> (=) ===> rel_spmf (rel_prod (=) S))
       ind_cca'.oracle_encrypt oracle_encrypt1''"
      unfolding ind_cca'.oracle_encrypt_def[abs_def] oracle_encrypt1''_def[abs_def]
      by (auto simp add: rel_fun_def Let_def S_def encrypt.simps prf_domain_finite prf_domain_nonempty intro: rel_spmf_bind_reflI rel_pmf_bind_reflI split: bool.split)
    have [transfer_rule]:
      "((=) ===> S ===> (=) ===> rel_spmf (rel_prod (=) S)) 
       ind_cca'.oracle_decrypt oracle_decrypt2"
      unfolding ind_cca'.oracle_decrypt_def[abs_def] oracle_decrypt2_def[abs_def]
      by(auto simp add: rel_fun_def)
    show ?thesis unfolding ind_cca'.game_def game1''_def by transfer_prover
  qed

  also have "… = PRF.game_0 (reduction_prf 𝒜)"
  proof -
    { fix k_prf k_upf b
      define oracle_normal
        where "oracle_normal = oracle_encrypt1'' (k_prf, k_upf) b ⊕⇩O oracle_decrypt2 (k_prf, k_upf)"
      define oracle_intercept
        where "oracle_intercept = (λ(s', s :: unit) y. map_spmf (λ((x, s'), s). (x, s', s)) (exec_gpv (PRF.prf_oracle k_prf) (intercept_prf k_upf b s' y) ()))"
      define initial where "initial = ()"
      define S where "S = (λ(s2 :: unit, _ :: unit) (s1 :: unit). True)"
      have [transfer_rule]: "S ((), ()) initial" by(simp add: S_def initial_def)
      have [transfer_rule]: "(S ===> (=) ===> rel_spmf (rel_prod (=) S)) oracle_intercept oracle_normal"
        unfolding oracle_normal_def oracle_intercept_def
        by(auto split: bool.split plus_oracle_split simp add: S_def rel_fun_def exec_gpv_bind PRF.prf_oracle_def oracle_encrypt1''_def Let_def map_spmf_conv_bind_spmf oracle_decrypt2_def intro!: rel_spmf_bind_reflI rel_spmf_reflI)
      have "map_spmf (λx. b = fst x) (exec_gpv oracle_normal 𝒜 initial) =
        map_spmf (λx. b = fst (fst x)) (exec_gpv (PRF.prf_oracle k_prf) (inline (intercept_prf k_upf b) 𝒜 ()) ())"
        by(transfer fixing: b 𝒜 prf_fun k_prf prf_domain prf_clen upf_fun k_upf)
          (auto simp add: map_spmf_eq_map_spmf_iff exec_gpv_inline spmf_rel_map oracle_intercept_def split_def intro: rel_spmf_reflI) }
    then show ?thesis unfolding game1''_def PRF.game_0_def key_gen_def reduction_prf_def
      by (auto simp add: exec_gpv_bind_lift_spmf exec_gpv_bind map_spmf_conv_bind_spmf split_def eq_commute intro!: bind_spmf_cong[OF refl])
  qed
  also have "game2 𝒜 = PRF.game_1 (reduction_prf 𝒜)"
  proof -
    note [split del] = if_split
    { fix k_upf b k_prf
      define oracle2
        where "oracle2 = oracle_encrypt2 (k_prf, k_upf) b ⊕⇩O oracle_decrypt2 (k_prf, k_upf)"
      define oracle_intercept
        where "oracle_intercept = (λ(s', s) y. map_spmf (λ((x, s'), s). (x, s', s)) (exec_gpv PRF.random_oracle (intercept_prf k_upf b s' y) s))"
      define S
        where "S = (λ(s2 :: unit, s2') (s1 :: (bitstring, bitstring) PRF.dict). s2' = s1)"

      have [transfer_rule]: "S ((), Map_empty) Map_empty" by(simp add: S_def)
      have [transfer_rule]: "(S ===> (=) ===> rel_spmf (rel_prod (=) S)) oracle_intercept oracle2"
        unfolding oracle2_def oracle_intercept_def
        by(auto split: bool.split plus_oracle_split option.split simp add: S_def rel_fun_def exec_gpv_bind PRF.random_oracle_def oracle_encrypt2_def Let_def map_spmf_conv_bind_spmf oracle_decrypt2_def rel_spmf_return_spmf1 fun_upd_idem intro!: rel_spmf_bind_reflI rel_spmf_reflI)

      have [symmetric]: "map_spmf (λx. b = fst (fst x)) (exec_gpv (PRF.random_oracle) (inline (intercept_prf k_upf b) 𝒜 ()) Map.empty) = 
        map_spmf (λx. b = fst x) (exec_gpv oracle2 𝒜 Map_empty)"
        by(transfer fixing: b prf_clen prf_domain upf_fun k_upf 𝒜 k_prf)
          (simp add: exec_gpv_inline map_spmf_conv_bind_spmf[symmetric] spmf.map_comp o_def split_def oracle_intercept_def) }
    then show ?thesis
      unfolding game2_def PRF.game_1_def key_gen_def reduction_prf_def
      by (clarsimp simp add: exec_gpv_bind_lift_spmf exec_gpv_bind map_spmf_conv_bind_spmf split_def bind_spmf_const prf_key_gen_lossless lossless_weight_spmfD eq_commute)
  qed
  ultimately show ?thesis by(simp add: PRF.advantage_def)
qed


(* Encryption oracle of game 3. Unlike oracle_encrypt2, the pad p is always the
   fresh uniform sample P, even if x already has an entry in D. The flag F
   records whether x was already in D, and is or-ed into the bad flag carried
   in the state. The k_prf component of the key is not used. *)
definition oracle_encrypt3 :: 
   "('prf_key × 'upf_key) ⇒ bool ⇒ (bool × (bitstring, bitstring) PRF.dict) ⇒
    bitstring × bitstring ⇒ ('hash cipher_text option × (bool × (bitstring, bitstring) PRF.dict)) spmf"
where
  "oracle_encrypt3 = (λ(k_prf, k_upf) b (bad, D) (msg1, msg0). 
    (case (length msg1 = prf_clen ∧ length msg0 = prf_clen) of
      False ⇒ return_spmf (None, (bad, D))
    | True ⇒ do {
        x ← spmf_of_set prf_domain;
        P ← spmf_of_set (nlists UNIV prf_clen);
        let (p, F) = (case D x of Some r ⇒ (P, True) | None ⇒ (P, False));
        let c = p [⊕] (if b then msg1 else msg0);
        let t = upf_fun k_upf (x @ c);
        return_spmf (Some (x, c, t), (bad ∨ F, D(x ↦ p))) 
      }))"

(* oracle_encrypt3 is lossless; the proof uses that prf_domain is finite and
   non-empty. *)
lemma lossless_oracle_encrypt3 [simp]:
  "lossless_spmf (oracle_encrypt3 k b D m10) "
  by (cases m10) (simp add: oracle_encrypt3_def prf_domain_nonempty prf_domain_finite
    split_def Let_def split: bool.splits)

(* The bad flag (first state component) is preserved by oracle_encrypt3 once it
   is set. *)
lemma callee_invariant_oracle_encrypt3 [simp]: "callee_invariant (oracle_encrypt3 key b) fst"
  by (unfold_locales) (auto simp add: oracle_encrypt3_def split_def Let_def split: bool.splits)

(* Game 3: like game2 but with oracle_encrypt3; returns the pair of the
   adversary's success bit and the final bad flag. *)
definition game3 :: "(bitstring, 'hash cipher_text) ind_cca.adversary ⇒ (bool × bool) spmf" 
where 
  "game3 𝒜 ≡ do {
    key ← key_gen;
    b ← coin_spmf;
    (b', (bad, D)) ← exec_gpv (oracle_encrypt3 key b ⊕⇩O oracle_decrypt2 key) 𝒜 (False, Map_empty);
    return_spmf (b = b', bad)
  }"

(* Game hop 3: game3 and game2 differ by at most the probability of the bad
   event in game3.
   Proof outline: game2 is instrumented with the same bad flag (game2', using
   oracle_encrypt2' which reuses the stored pad on a repeated x); game3 and
   game2' are identical until bad, so the fundamental lemma applies. *)
lemma round_3:
  assumes "lossless_gpv (ℐ_full ⊕⇩ℐ ℐ_full) 𝒜"
  shows "¦measure (measure_spmf (game3 𝒜)) {(b, bad). b} - spmf (game2 𝒜) True¦ 
          ≤ measure (measure_spmf (game3 𝒜)) {(b, bad). bad}" 
proof -
  (* oracle_encrypt2 extended with the bad flag: on a repeated x it keeps the
     stored pad r (unlike oracle_encrypt3) and raises the flag. *)
  define oracle_encrypt2'
    where "oracle_encrypt2' = (λ(k_prf :: 'prf_key, k_upf) b (bad, D) (msg1, msg0). 
      case length msg1 = prf_clen ∧ length msg0 = prf_clen of
        False ⇒ return_spmf (None, (bad, D))
      | True ⇒ do {
          x ← spmf_of_set prf_domain;
          P ← spmf_of_set (nlists UNIV prf_clen);
          let (p, F) = (case D x of Some r ⇒ (r, True) | None ⇒ (P, False));
          let c = p [⊕] (if b then msg1 else msg0);
          let t = upf_fun k_upf (x @ c);
          return_spmf (Some (x, c, t), (bad ∨ F, D(x ↦ p))) 
        })"

  have [simp]: "lossless_spmf (oracle_encrypt2' key b D msg10) " for key b D msg10
    by (cases msg10) (simp add: oracle_encrypt2'_def prf_domain_nonempty prf_domain_finite
      split_def Let_def split: bool.split)
  have [simp]: "callee_invariant (oracle_encrypt2' key b) fst" for key b
    by (unfold_locales) (auto simp add: oracle_encrypt2'_def split_def Let_def split: bool.splits)

  define game2'
    where "game2' = (λ𝒜. do {
      key ← key_gen;
      b ← coin_spmf;
      (b', (bad, D)) ← exec_gpv (oracle_encrypt2' key b ⊕⇩O oracle_decrypt2 key) 𝒜 (False, Map_empty);
      return_spmf (b = b', bad)})"

  (* Forgetting the bad flag of game2' gives back game2. *)
  have game2'_eq: "game2 𝒜 = map_spmf fst (game2' 𝒜)"
  proof -
    define S where "S = (λ(D1 :: (bitstring, bitstring) PRF.dict) (bad :: bool, D2). D1 = D2)"
    have [transfer_rule, simp]: "S Map_empty (b, Map_empty)" for b by (simp add: S_def)
  
    have [transfer_rule]: "((=) ===> (=) ===> S ===> (=) ===> rel_spmf (rel_prod (=) S))
      oracle_encrypt2 oracle_encrypt2'"
      unfolding oracle_encrypt2_def[abs_def] oracle_encrypt2'_def[abs_def]
      by (auto simp add: rel_fun_def Let_def split_def S_def
           intro!: rel_spmf_bind_reflI split: bool.split option.split)
    have [transfer_rule]: "((=) ===> S ===> (=) ===> rel_spmf (rel_prod (=) S)) 
      oracle_decrypt2 oracle_decrypt2"
      by(auto simp add: rel_fun_def oracle_decrypt2_def)

    show ?thesis unfolding game2_def game2'_def
      by (simp add: map_spmf_bind_spmf o_def split_def Map_empty_def[symmetric] del: Map_empty_def)
         transfer_prover
  qed
  (* Identical-until-bad relation between the game3 and game2' executions. *)
  moreover have *: "rel_spmf (λ(b'1, bad1, L1) (b'2, bad2, L2). (bad1 ⟷ bad2) ∧ (¬ bad2 ⟶ b'1 ⟷ b'2))
    (exec_gpv (oracle_encrypt3 key b ⊕⇩O oracle_decrypt2 key) 𝒜 (False, Map_empty))
    (exec_gpv (oracle_encrypt2' key b ⊕⇩O oracle_decrypt2 key) 𝒜 (False, Map_empty))"
    for key b
    apply(rule exec_gpv_oracle_bisim_bad[where X="(=)" and X_bad = "λ_ _. True" and ?bad1.0=fst and ?bad2.0=fst and ℐ="ℐ_full ⊕⇩ℐ ℐ_full"])
    apply(simp_all add: assms)
    apply(auto simp add: assms spmf_rel_map Let_def oracle_encrypt2'_def oracle_encrypt3_def split: plus_oracle_split prod.split bool.split option.split intro!: rel_spmf_bind_reflI rel_spmf_reflI)
    done
  have "¦measure (measure_spmf (game3 𝒜)) {(b, bad). b} - measure (measure_spmf (game2' 𝒜)) {(b, bad). b}¦ ≤
    measure (measure_spmf (game3 𝒜)) {(b, bad). bad}"
    unfolding game2'_def game3_def
    by(rule fundamental_lemma[where ?bad2.0=snd])(intro rel_spmf_bind_reflI rel_spmf_bindI[OF *]; clarsimp)
  ultimately show ?thesis by(simp add: spmf_conv_measure_spmf measure_map_spmf vimage_def fst_def)
qed

(* Game hop 4: in game3 the adversary's success bit is a fair coin.
   Proof outline: by the one-time-pad lemma, xoring a uniform pad onto the
   selected message gives a uniform cipher text, so the encryption oracle can
   be replaced by oracle_encrypt4, which does not depend on b. In the
   resulting game4 the bit is sampled after the adversary has finished; the
   final step uses losslessness of 𝒜. *)
lemma round_4:
  assumes "lossless_gpv (ℐ_full ⊕⇩ℐ ℐ_full) 𝒜"
  shows "map_spmf fst (game3 𝒜) = coin_spmf" 
proof -
  (* Encryption oracle that outputs a uniformly random c, independent of the
     messages and of the challenge bit. *)
  define oracle_encrypt4
    where "oracle_encrypt4 = (λ(k_prf :: 'prf_key, k_upf) (s :: unit) (msg1 :: bitstring, msg0 :: bitstring).
      case length msg1 = prf_clen ∧ length msg0 = prf_clen of
        False ⇒ return_spmf (None, s)
      | True ⇒ do {
          x ← spmf_of_set prf_domain;
          P ← spmf_of_set (nlists UNIV prf_clen);
          let c = P;
          let t = upf_fun k_upf (x @ c);
          return_spmf (Some (x, c, t), s) })"

  have [simp]: "lossless_spmf (oracle_encrypt4 k s msg10)" for k s msg10 
    by (cases msg10) (simp add: oracle_encrypt4_def prf_domain_finite prf_domain_nonempty
      split_def Let_def split: bool.splits)

  define game4 where "game4 = (λ𝒜. do {
    key ← key_gen;
    (b', _) ← exec_gpv (oracle_encrypt4 key ⊕⇩O oracle_decrypt2 key) 𝒜 ();
    map_spmf ((=) b') coin_spmf})"

  have "map_spmf fst (game3 𝒜) = game4 𝒜"
  proof -
    note [split del] = if_split
    define S where "S = (λ(_ :: unit) (_ :: bool × (bitstring, bitstring) PRF.dict). True)"
    define initial3 where "initial3 = (False, Map.empty :: (bitstring, bitstring) PRF.dict)"
    have [transfer_rule]: "S () initial3" by(simp add: S_def)
    have [transfer_rule]: "((=) ===> (=) ===> S ===> (=) ===> rel_spmf (rel_prod (=) S))
       (λkey b. oracle_encrypt4 key) oracle_encrypt3"
    proof(intro rel_funI; hypsubst)
      fix key unit msg10 b Dbad
      have "map_spmf fst (oracle_encrypt4 key () msg10) = map_spmf fst (oracle_encrypt3 key b Dbad msg10)"
        unfolding oracle_encrypt3_def oracle_encrypt4_def
        apply (clarsimp simp add: map_spmf_conv_bind_spmf Let_def split: bool.split prod.split; rule conjI; clarsimp)
        apply (rewrite in "⌑ = _" one_time_pad[symmetric, where xs="if b then fst msg10 else snd msg10"])
         apply(simp split: if_split)
        apply(simp add: bind_map_spmf o_def option.case_distrib case_option_collapse xor_list_commute split_def cong del: option.case_cong_weak if_weak_cong)
        done
      then show "rel_spmf (rel_prod (=) S) (oracle_encrypt4 key unit msg10) (oracle_encrypt3 key b Dbad msg10)"
        by(auto simp add: spmf_rel_eq[symmetric] spmf_rel_map S_def elim: rel_spmf_mono)
    qed

    show ?thesis
      unfolding game3_def game4_def including monad_normalisation
      by (simp add: map_spmf_bind_spmf o_def split_def map_spmf_conv_bind_spmf initial3_def[symmetric] eq_commute)
         transfer_prover
  qed
  also have "… = coin_spmf"
    by(simp add: map_eq_const_coin_spmf game4_def bind_spmf_const split_def lossless_exec_gpv[OF assms] lossless_weight_spmfD)
  finally  show ?thesis .
qed

(* Bound on the bad event of game3: if the adversary makes at most q
   encryption (Inl) queries, bad occurs with probability at most
   q / card prf_domain * q.
   Proof outline: the invariant says that dom D is a finite subset of
   prf_domain; card (dom D) grows by at most one per encryption query and not
   at all on decryption queries; a single encryption query sets bad with
   probability card (dom D) / card prf_domain, which is at most
   q / card prf_domain while fewer than q entries are stored. *)
lemma game3_bad:
  assumes "interaction_bounded_by isl 𝒜 q"
  shows "measure (measure_spmf (game3 𝒜)) {(b, bad). bad} ≤ q / card prf_domain * q"
proof -
  have "measure (measure_spmf (game3 𝒜)) {(b, bad). bad} = spmf (map_spmf snd (game3 𝒜)) True"
    by (simp add: spmf_conv_measure_spmf measure_map_spmf vimage_def snd_def)
  also
  have "spmf (map_spmf (fst ∘ snd) (exec_gpv (oracle_encrypt3 k b ⊕⇩O oracle_decrypt2 k) 𝒜 (False, Map.empty))) True ≤ q / card prf_domain * q"
    (is "spmf (map_spmf _ (exec_gpv ?oracle _ _)) _ ≤ _")
    if k: "k ∈ set_spmf key_gen" for k b
  proof(rule callee_invariant_on.interaction_bounded_by'_exec_gpv_bad_count)
    obtain k_prf k_upf where k: "k = (k_prf, k_upf)" by(cases k)
    let ?I = "λ(bad, D). finite (dom D) ∧ dom D ⊆ prf_domain"
    have "callee_invariant (oracle_encrypt3 k b) ?I"
      by unfold_locales(clarsimp simp add: prf_domain_finite oracle_encrypt3_def Let_def split_def split: bool.splits)+
    moreover have "callee_invariant (oracle_decrypt2 k) ?I"
      by unfold_locales (clarsimp simp add: prf_domain_finite oracle_decrypt2_def)+
    ultimately show "callee_invariant ?oracle ?I" by simp

    (* The counter is the number of table entries. *)
    let ?count = "λ(bad, D). card (dom D)"
    show "⋀s x y s'. ⟦ (y, s') ∈ set_spmf (?oracle s x); ?I s; isl x ⟧ ⟹ ?count s' ≤ Suc (?count s)"
      by(clarsimp simp add: isl_def oracle_encrypt3_def split_def Let_def card_insert_if split: bool.splits)
    show "⟦ (y, s') ∈ set_spmf (?oracle s x); ?I s; ¬ isl x ⟧ ⟹ ?count s' ≤ ?count s" for s x y s'
      by(cases x)(simp_all add: oracle_decrypt2_def)
    show "spmf (map_spmf (fst ∘ snd) (?oracle s' x)) True ≤ q / card prf_domain"
      if I: "?I s'" and bad: "¬ fst s'" and count: "?count s' < q + ?count (False, Map.empty)" 
      and x: "isl x"
      for s' x
    proof -
      obtain bad D where s' [simp]: "s' = (bad, D)" by(cases s')
      from x obtain m1 m0 where x [simp]: "x = Inl (m1, m0)" by(auto elim: islE)
      have *: "(case D x of None ⇒ False | Some x ⇒ True) ⟷ x ∈ dom D" for x
        by(auto split: option.split)
      show ?thesis
      proof(cases "length m1 = prf_clen ∧ length m0 = prf_clen")
        case True
        with bad
        have "spmf (map_spmf (fst ∘ snd) (?oracle s' x)) True = pmf (bernoulli_pmf (card (dom D ∩ prf_domain) / card prf_domain)) True"
          by(simp add: spmf.map_comp o_def oracle_encrypt3_def k * bool.case_distrib[where h="λp. spmf (map_spmf _ p) _"] option.case_distrib[where h=snd] map_spmf_bind_spmf Let_def split_beta bind_spmf_const cong: bool.case_cong option.case_cong split del: if_split split: bool.split)
            (simp add: map_spmf_conv_bind_spmf[symmetric] map_mem_spmf_of_set prf_domain_finite prf_domain_nonempty)
        also have "… = card (dom D ∩ prf_domain) / card prf_domain"
          by(rule pmf_bernoulli_True)(auto simp add: field_simps prf_domain_finite prf_domain_nonempty card_gt_0_iff card_mono)
        also have "dom D ∩ prf_domain = dom D" using I by auto
        also have "card (dom D) ≤ q" using count by simp
        finally show ?thesis by(simp add: divide_right_mono o_def)
      next
        case False
        thus ?thesis using bad 
          by(auto simp add: spmf.map_comp o_def oracle_encrypt3_def k split: bool.split)
      qed
    qed
  qed(auto split: plus_oracle_split_asm simp add: oracle_decrypt2_def assms)
  then have "spmf (map_spmf snd (game3 𝒜)) True ≤ q / card prf_domain * q"
    by(auto 4 3 simp add: game3_def map_spmf_bind_spmf o_def split_def map_spmf_conv_bind_spmf intro: spmf_bind_leI)
  finally show ?thesis .
qed


(* Concrete security bound.
   For a lossless adversary making at most q isl-queries, the IND-CCA advantage is at
   most the PRF advantage of reduction_prf, plus the UPF advantage of reduction_upf,
   plus q / card prf_domain * q.
   The proof is a chain of triangle inequalities over the games
   ind_cca.game, ind_cca'.game, game2 and game3, using round_1 .. round_4 and game3_bad. *)
theorem security:
  assumes lossless: "lossless_gpv (ℐ_full ⊕⇩ℐ ℐ_full) 𝒜"
  and bound: "interaction_bounded_by isl 𝒜 q"
  shows "ind_cca.advantage 𝒜 ≤
    PRF.advantage (reduction_prf 𝒜) + UPF.advantage (reduction_upf 𝒜) +
    real q / real (card prf_domain) * real q" (is "?LHS ≤ _")
proof -
  (* Hop 1: ind_cca.game versus ind_cca'.game, bounded by the UPF advantage. *)
  have "?LHS ≤ ¦spmf (ind_cca.game 𝒜) True - spmf (ind_cca'.game 𝒜) True¦ + ¦spmf (ind_cca'.game 𝒜) True - 1 / 2¦"
    (is "_ ≤ ?round1 + ?rest") using abs_triangle_ineq by(simp add: ind_cca.advantage_def)
  also have "?round1 ≤ UPF.advantage (reduction_upf 𝒜)"
    using lossless by(rule round_1)
  (* Hop 2: ind_cca'.game versus game2, which is exactly the PRF advantage. *)
  also have "?rest ≤ ¦spmf (ind_cca'.game 𝒜) True - spmf (game2 𝒜) True¦ + ¦spmf (game2 𝒜) True - 1 / 2¦"
    (is "_ ≤ ?round2 + ?rest") using abs_triangle_ineq by simp
  also have "?round2 = PRF.advantage (reduction_prf 𝒜)" by(rule round_2)
  (* Hop 3: game2 versus game3, bounded by the probability of the bad event. *)
  also have "?rest ≤ ¦measure (measure_spmf (game3 𝒜)) {(b, bad). b} - spmf (game2 𝒜) True¦ +
       ¦measure (measure_spmf (game3 𝒜)) {(b, bad). b} - 1 / 2¦"
    (is "_ ≤ ?round3 + _") using abs_triangle_ineq by simp
  also have "?round3 ≤ measure (measure_spmf (game3 𝒜)) {(b, bad). bad}"
    using round_3[OF lossless] .
  also have "… ≤ q / card prf_domain * q" using bound by(rule game3_bad)
  (* Hop 4: the first component of game3 is a fair coin, so the last term vanishes. *)
  also have "measure (measure_spmf (game3 𝒜)) {(b, bad). b} = spmf coin_spmf True"
    using round_4[OF lossless, symmetric]
    by(simp add: spmf_conv_measure_spmf measure_map_spmf vimage_def fst_def)
  also have "¦… - 1 / 2¦ = 0" by(simp add: spmf_of_set)
  finally show ?thesis by(simp)
qed

(* Variant of theorem security phrased with UPF.advantage1.
   q bounds the number of isl-queries and q' the number of non-isl queries of 𝒜.
   UPF.advantage (reduction_upf 𝒜) is replaced by q' times the UPF.advantage1 of the
   guessing_many_one reduction, via UPF.advantage_advantage1. *)
theorem security1:
  assumes lossless: "lossless_gpv (ℐ_full ⊕⇩ℐ ℐ_full) 𝒜"
  assumes q: "interaction_bounded_by isl 𝒜 q"
  and q': "interaction_bounded_by (Not ∘ isl) 𝒜 q'"
  shows "ind_cca.advantage 𝒜 ≤
    PRF.advantage (reduction_prf 𝒜) +
    UPF.advantage1 (guessing_many_one.reduction q' (λ_. reduction_upf 𝒜) ()) * q' +
    real q * real q / real (card prf_domain)"
proof -
  (* Start from the bound of theorem security. *)
  have "ind_cca.advantage 𝒜 ≤
    PRF.advantage (reduction_prf 𝒜) + UPF.advantage (reduction_upf 𝒜) +
    real q / real (card prf_domain) * real q"
    using lossless q by(rule security)
  (* Register the bound on non-isl queries so that the interaction_bound method can
     transfer it to reduction_upf 𝒜. *)
  also note q'[interaction_bound]
  have "interaction_bounded_by (Not ∘ isl) (reduction_upf 𝒜) q'"
    unfolding reduction_upf_def by(interaction_bound)(simp_all add: SUP_le_iff)
  then have "UPF.advantage (reduction_upf 𝒜) ≤ UPF.advantage1 (guessing_many_one.reduction q' (λ_. reduction_upf 𝒜) ()) * q'"
    by(rule UPF.advantage_advantage1)
  finally show ?thesis by(simp)
qed

end

end

(* Asymptotic version of locale simple_cipher: every parameter is indexed by a
   security parameter η, and each η must yield an instance of simple_cipher.
   NOTE(review): prf_range and upf_fun are not passed to the simple_cipher predicate
   below, presumably because they do not occur in that locale's assumptions; confirm
   against the generated predicate. *)
locale simple_cipher' =
  fixes prf_key_gen :: "security ⇒ 'prf_key spmf"
  and prf_fun :: "security ⇒ 'prf_key ⇒ bitstring ⇒ bitstring"
  and prf_domain :: "security ⇒ bitstring set"
  and prf_range :: "security ⇒ bitstring set"
  and prf_dlen :: "security ⇒ nat"
  and prf_clen :: "security ⇒ nat"
  and upf_key_gen :: "security ⇒ 'upf_key spmf"
  and upf_fun :: "security ⇒ 'upf_key ⇒ bitstring ⇒ 'hash"
  assumes simple_cipher: "⋀η. simple_cipher (prf_key_gen η) (prf_fun η) (prf_domain η) (prf_dlen η) (prf_clen η) (upf_key_gen η)"
begin

(* Interpret locale simple_cipher at every security parameter η, applying each
   η-indexed parameter of this locale to η. The resulting proof obligation is
   discharged by the locale assumption simple_cipher. *)
sublocale simple_cipher 
  "prf_key_gen η" "prf_fun η" "prf_domain η" "prf_range η" "prf_dlen η" "prf_clen η" "upf_key_gen η" "upf_fun η"
  for η
by(rule simple_cipher)

(* Asymptotic security.
   Assume that for every η the adversary 𝒜 η is lossless, makes at most q η isl-queries
   and at most q' η non-isl queries, that q and q' are polynomial, and that the PRF
   advantage, the UPF advantage1 of the guessing_many_one reduction, and
   1 / card (prf_domain η) are negligible. Then the IND-CCA advantage of 𝒜 is
   negligible. *)
theorem security_asymptotic:
  fixes q q' :: "security ⇒ nat"
  assumes lossless: "⋀η. lossless_gpv (ℐ_full ⊕⇩ℐ ℐ_full) (𝒜 η)"
  and bound: "⋀η. interaction_bounded_by isl (𝒜 η) (q η)"
  and bound': "⋀η. interaction_bounded_by (Not ∘ isl) (𝒜 η) (q' η)"
  and [negligible_intros]:
    "polynomial q'" "polynomial q"
    "negligible (λη. PRF.advantage η (reduction_prf η (𝒜 η)))"
    "negligible (λη. UPF.advantage1 η (guessing_many_one.reduction (q' η) (λ_. reduction_upf η (𝒜 η)) ()))"
    "negligible (λη. 1 / card (prf_domain η))"
  shows "negligible (λη. ind_cca.advantage η (𝒜 η))"
proof -
  (* The right-hand side of theorem security1 is negligible, by the introduction
     rules collected in negligible_intros (including the assumptions above). *)
  have "negligible (λη. PRF.advantage η (reduction_prf η (𝒜 η)) +
    UPF.advantage1 η (guessing_many_one.reduction (q' η) (λ_. reduction_upf η (𝒜 η)) ()) * q' η +
    real (q η) / real (card (prf_domain η)) * real (q η))"
    by(rule negligible_intros)+
  (* The advantage is nonnegative and dominated pointwise by that bound (security1). *)
  thus ?thesis by(rule negligible_le)(simp add: security1[OF lossless bound bound'] ind_cca.advantage_nonneg)
qed

end

end