(* Theory IND_CCA2 *)

(* Title: IND_CCA2.thy
  Author: Andreas Lochbihler, ETH Zurich *)

theory IND_CCA2 imports
  CryptHOL.Computational_Model
  CryptHOL.Negligible
  CryptHOL.Environment_Functor
begin

(* Syntax of a public-key encryption scheme, parametrised by the security parameter.
   This locale fixes only the operations; it imposes no correctness or security assumptions. *)
locale pk_enc = 
  fixes key_gen :: "security ⇒ ('ekey × 'dkey) spmf" ― ‹probabilistic›
  and encrypt :: "security ⇒ 'ekey ⇒ 'plain ⇒ 'cipher spmf"  ― ‹probabilistic›
  and decrypt :: "security ⇒ 'dkey ⇒ 'cipher ⇒ 'plain option" ― ‹deterministic; used by the decryption oracle›
  and valid_plain :: "security ⇒ 'plain ⇒ bool" ― ‹checks whether a plain text is valid, i.e., has the right format›

subsection ‹The IND-CCA2 game for public-key encryption›

text ‹
  We model the IND-CCA2 security game in the single-user setting and in the
  multi-user setting, the latter as described in
  @{cite "BellareBoldyrevaMicali2000EUROCRYPT"}.
›

(* Context for the IND-CCA2 game. It inherits the scheme from pk_enc; the constrains
   clause only fixes the types of the inherited parameters. *)
locale ind_cca2 = pk_enc +
  constrains key_gen :: "security ⇒ ('ekey × 'dkey) spmf"
  and encrypt :: "security ⇒ 'ekey ⇒ 'plain ⇒ 'cipher spmf"
  and decrypt :: "security ⇒ 'dkey ⇒ 'cipher ⇒ 'plain option"
  and valid_plain :: "security ⇒ 'plain ⇒ bool"
begin

(* State of one user's oracles. None means no key pair has been generated yet.
   Some (ekey, dkey, cstars) stores the key pair and the challenge ciphertexts
   produced so far; the decryption oracle refuses the ciphertexts in cstars.
   NOTE(review): the primed type variables presumably avoid a clash with the
   locale's fixed type variables 'ekey, 'dkey, 'cipher. *)
type_synonym ('ekey', 'dkey', 'cipher') state_oracle = "('ekey' × 'dkey' × 'cipher' list) option"

(* Decryption oracle. Before a key pair exists (state None) it answers None and leaves
   the state uninitialised. Otherwise it decrypts the query with the stored decryption
   key, but answers None for any recorded challenge ciphertext. The state is returned
   unchanged in both cases. *)
fun decrypt_oracle
  :: "security ⇒ ('ekey, 'dkey, 'cipher) state_oracle ⇒ 'cipher
  ⇒ ('plain option × ('ekey, 'dkey, 'cipher) state_oracle) spmf"
where
  "decrypt_oracle η None cipher = return_spmf (None, None)"
| "decrypt_oracle η (Some (ekey, dkey, cstars)) cipher = return_spmf
   (if cipher ∈ set cstars then None else decrypt η dkey cipher, Some (ekey, dkey, cstars))"

(* Public-key oracle. On the first call (state None) it samples a key pair with key_gen
   and initialises an empty challenge list. On later calls it returns the stored
   encryption key and leaves the state unchanged. *)
fun ekey_oracle
  :: "security ⇒ ('ekey, 'dkey, 'cipher) state_oracle ⇒ unit ⇒ ('ekey × ('ekey, 'dkey, 'cipher) state_oracle) spmf"
where
  "ekey_oracle η None _ = do {
      (ekey, dkey) ← key_gen η;
      return_spmf (ekey, Some (ekey, dkey, []))
    }"
| "ekey_oracle η (Some (ekey, rest)) _ = return_spmf (ekey, Some (ekey, rest))"

(* Closed form of ekey_oracle as a single case distinction on the state, with key
   generation expressed via map_spmf instead of a monadic bind. *)
lemma ekey_oracle_conv:
  "ekey_oracle η σ x =
  (case σ of None ⇒ map_spmf (λ(ekey, dkey). (ekey, Some (ekey, dkey, []))) (key_gen η) 
   | Some (ekey, rest) ⇒ return_spmf (ekey, Some (ekey, rest)))"
by(cases σ)(auto simp add: map_spmf_conv_bind_spmf split_def)

(* Challenge (left-or-right encryption) oracle.
   - If no key pair exists yet, it first initialises the state via ekey_oracle.
   - It fails (return_pmf None) unless both messages are valid.
   - Otherwise it encrypts m0 when b is True and m1 otherwise.
   - It prepends the ciphertext to the challenge list, so decrypt_oracle refuses it.
   bind_spmf_cong is declared as a congruence rule so that the recursive call under the
   bind keeps the information that σ comes from ekey_oracle's support; the termination
   proof needs this. *)
context notes bind_spmf_cong[fundef_cong] begin
function encrypt_oracle
  :: "bool ⇒ security ⇒ ('ekey, 'dkey, 'cipher) state_oracle ⇒ 'plain × 'plain
  ⇒ ('cipher × ('ekey, 'dkey, 'cipher) state_oracle) spmf"
where
  "encrypt_oracle b η None m01 = do { (_, σ) ← ekey_oracle η None (); encrypt_oracle b η σ m01 }"
| "encrypt_oracle b η (Some (ekey, dkey, cstars)) (m0, m1) =
  (if valid_plain η m0 ∧ valid_plain η m1 then do {  
     let pb = (if b then m0 else m1);
     cstar ← encrypt η ekey pb;
     return_spmf (cstar, Some (ekey, dkey, cstar # cstars))
   } else return_pmf None)"
by pat_completeness auto
(* The only recursive call happens in the None case, with a state produced by ekey_oracle,
   which is always Some; hence the measure drops from 1 to 0. *)
termination by(relation "Wellfounded.measure (λ(b, η, σ, m01). case σ of None ⇒ 1 | _ ⇒ 0)") auto
end

subsubsection ‹Single-user setting›

(* Query/answer interface of the single-user oracle (see oracle1_simps):
   Inl ()              -- request the public key, answered with Inl ekey;
   Inr (Inl c)         -- decryption query, answered with Inr (Inl (plain option));
   Inr (Inr (m0, m1))  -- challenge query, answered with Inr (Inr cipher). *)
type_synonym ('plain', 'cipher') call1 = "unit + 'cipher' + 'plain' × 'plain'"
type_synonym ('ekey', 'plain', 'cipher') ret1 = "'ekey' + 'plain' option + 'cipher'"

(* Single-user oracle: the public-key, decryption and challenge oracles combined with
   the oracle sum ⊕⇩O. All three operate on one shared state_oracle. *)
definition oracle1 :: "bool ⇒ security
  ⇒ (('ekey, 'dkey, 'cipher) state_oracle, ('plain, 'cipher) call1, ('ekey, 'plain, 'cipher) ret1) oracle'"
where "oracle1 b η = ekey_oracle η ⊕⇩O (decrypt_oracle η ⊕⇩O encrypt_oracle b η)"

(* Dispatch equations for oracle1: each query tag is routed to its component oracle,
   and the answer is injected with the matching tag. *)
lemma oracle1_simps [simp]:
  "oracle1 b η s (Inl x) = map_spmf (apfst Inl) (ekey_oracle η s x)"
  "oracle1 b η s (Inr (Inl y)) = map_spmf (apfst (Inr ∘ Inl)) (decrypt_oracle η s y)"
  "oracle1 b η s (Inr (Inr z)) = map_spmf (apfst (Inr ∘ Inr)) (encrypt_oracle b η s z)"
by(simp_all add: oracle1_def spmf.map_comp apfst_compose o_def)

(* Single-user adversary: a gpv that issues call1 queries, receives ret1 answers and
   returns a bool, its guess for the hidden bit. adversary1 is the family indexed by
   the security parameter. *)
type_synonym ('ekey', 'plain', 'cipher') adversary1' = 
  "(bool, ('plain', 'cipher') call1, ('ekey', 'plain', 'cipher') ret1) gpv"
type_synonym ('ekey', 'plain', 'cipher') adversary1 =
  "security ⇒ ('ekey', 'plain', 'cipher') adversary1'"

(* Single-user IND-CCA2 game:
   1. flip a fair bit b;
   2. run the adversary against oracle1 b, starting from the uninitialised state None;
   3. report whether the guess equals b.
   The final oracle state is discarded. If the experiment fails, TRY … ELSE replaces
   the outcome by a fair coin. *)
definition ind_cca21 :: "('ekey, 'plain, 'cipher) adversary1 ⇒ security ⇒ bool spmf"
where
  "ind_cca21 𝒜 η = TRY do {
    b ← coin_spmf;
    (guess, _) ← exec_gpv (oracle1 b η) (𝒜 η) None;
    return_spmf (guess = b)
  } ELSE coin_spmf"

(* Advantage: distance between the adversary's winning probability and 1/2. *)
definition advantage1 :: "('ekey, 'plain, 'cipher) adversary1 ⇒ advantage"
where "advantage1 𝒜 η = ¦spmf (ind_cca21 𝒜 η) True - 1/2¦"

lemma advantage1_nonneg: "advantage1 𝒜 η ≥ 0" by(simp add: advantage1_def)

(* Security against 𝒜: its advantage is negligible in the security parameter. *)
abbreviation secure_for1 :: "('ekey, 'plain, 'cipher) adversary1 ⇒ bool"
where "secure_for1 𝒜 ≡ negligible (advantage1 𝒜)"

(* Side conditions on single-user adversaries.
   - ibounded_by1' bounds the interactions via interaction_any_bounded_by.
   - lossless1' requires lossless_gpv with respect to ℐ_full.
   - The unprimed versions lift these to families indexed by the security parameter,
     using rel_envir / pred_envir from CryptHOL.Environment_Functor. *)
definition ibounded_by1' :: "('ekey, 'plain, 'cipher) adversary1' ⇒ nat ⇒ bool"
where "ibounded_by1' 𝒜 q = interaction_any_bounded_by 𝒜 q"

abbreviation ibounded_by1 :: "('ekey, 'plain, 'cipher) adversary1 ⇒ (security ⇒ nat) ⇒ bool"
where "ibounded_by1 ≡ rel_envir ibounded_by1'"

definition lossless1' :: "('ekey, 'plain, 'cipher) adversary1' ⇒ bool"
where "lossless1' 𝒜 = lossless_gpv ℐ_full 𝒜"

abbreviation lossless1 :: "('ekey, 'plain, 'cipher) adversary1 ⇒ bool"
where "lossless1 ≡ pred_envir lossless1'"

(* The decryption oracle is lossless in every state: both equations of decrypt_oracle
   are a return_spmf. *)
lemma lossless_decrypt_oracle [simp]: "lossless_spmf (decrypt_oracle η σ cipher)"
by(cases "(η, σ, cipher)" rule: decrypt_oracle.cases) simp_all

(* The public-key oracle is lossless unless it still has to generate a key (σ = None);
   in that case it is lossless iff key generation is. *)
lemma lossless_ekey_oracle [simp]:
  "lossless_spmf (ekey_oracle η σ x) ⟷ (σ = None ⟶ lossless_spmf (key_gen η))"
by(cases "(η, σ, x)" rule: ekey_oracle.cases)(auto)

(* Assume key generation is lossless (when a key must still be generated) and encryption
   of valid messages is lossless. Then the challenge oracle is lossless exactly on pairs
   of valid messages. *)
lemma lossless_encrypt_oracle [simp]:
  "⟦ σ = None ⟹ lossless_spmf (key_gen η);
    ⋀ekey m. valid_plain η m ⟹ lossless_spmf (encrypt η ekey m) ⟧
  ⟹ lossless_spmf (encrypt_oracle b η σ (m0, m1)) ⟷ valid_plain η m0 ∧ valid_plain η m1"
apply(cases "(b, η, σ, (m0, m1))" rule: encrypt_oracle.cases)
apply(auto simp add: split_beta dest: lossless_spmfD_set_spmf_nonempty split: if_split_asm)
done

subsubsection ‹Multi-user setting›

(* Multi-user oracle: one copy of oracle1 per user index, each with its own state
   (a map from indices to state_oracle). All users share the same hidden bit b. *)
definition oraclen :: "bool ⇒ security
    ⇒ ('i ⇒ ('ekey, 'dkey, 'cipher) state_oracle, 'i × ('plain, 'cipher) call1, ('ekey, 'plain, 'cipher) ret1) oracle'"
where "oraclen b η = family_oracle (λ_. oracle1 b η)"

(* A query (i, x) is answered by user i's copy of oracle1 on state s i; only the entry
   for i in the state map is updated. *)
lemma oraclen_apply [simp]:
  "oraclen b η s (i, x) = map_spmf (apsnd (fun_upd s i)) (oracle1 b η (s i) x)"
by(simp add: oraclen_def)

(* Multi-user adversary: like adversary1, but every query is tagged with a user index 'i.
   adversaryn is the family indexed by the security parameter. *)
type_synonym ('i, 'ekey', 'plain', 'cipher') adversaryn' = 
  "(bool, 'i × ('plain', 'cipher') call1, ('ekey', 'plain', 'cipher') ret1) gpv"
type_synonym ('i, 'ekey', 'plain', 'cipher') adversaryn =
  "security ⇒ ('i, 'ekey', 'plain', 'cipher') adversaryn'"

(* Multi-user IND-CCA2 game:
   1. flip a fair bit b;
   2. run the adversary against oraclen b, starting with every user uninitialised;
   3. report whether the guess equals b.
   The final state is discarded. If the experiment fails, TRY … ELSE replaces the
   outcome by a fair coin. *)
definition ind_cca2n :: "('i, 'ekey, 'plain, 'cipher) adversaryn ⇒ security ⇒ bool spmf"
where
  "ind_cca2n 𝒜 η = TRY do {
    b ← coin_spmf;
    (guess, _) ← exec_gpv (oraclen b η) (𝒜 η) (λ_. None);
    return_spmf (guess = b)
  } ELSE coin_spmf"

(* Multi-user advantage: distance between the winning probability and 1/2. *)
definition advantagen :: "('i, 'ekey, 'plain, 'cipher) adversaryn ⇒ advantage"
where "advantagen 𝒜 η = ¦spmf (ind_cca2n 𝒜 η) True - 1/2¦"

lemma advantagen_nonneg: "advantagen 𝒜 η ≥ 0" by(simp add: advantagen_def)

(* Multi-user security against 𝒜: its advantage is negligible in the security parameter. *)
abbreviation secure_forn :: "('i, 'ekey, 'plain, 'cipher) adversaryn ⇒ bool"
where "secure_forn 𝒜 ≡ negligible (advantagen 𝒜)"

(* Side conditions on multi-user adversaries. They mirror the single-user versions:
   an interaction bound via interaction_any_bounded_by, and lossless_gpv with respect
   to ℐ_full. Both are lifted over the security parameter with rel_envir / pred_envir. *)
definition ibounded_byn' :: "('i, 'ekey, 'plain, 'cipher) adversaryn' ⇒ nat ⇒ bool"
where "ibounded_byn' 𝒜 q = interaction_any_bounded_by 𝒜 q"

abbreviation ibounded_byn :: "('i, 'ekey, 'plain, 'cipher) adversaryn ⇒ (security ⇒ nat) ⇒ bool"
where "ibounded_byn ≡ rel_envir ibounded_byn'"

definition losslessn' :: "('i, 'ekey, 'plain, 'cipher) adversaryn' ⇒ bool"
where "losslessn' 𝒜 = lossless_gpv ℐ_full 𝒜"

abbreviation losslessn :: "('i, 'ekey, 'plain, 'cipher) adversaryn ⇒ bool"
where "losslessn ≡ pred_envir losslessn'"


(* All challenge ciphertexts recorded in the state of any user. *)
definition cipher_queries :: "('i ⇒ ('ekey, 'dkey, 'cipher) state_oracle) ⇒ 'cipher set"
where "cipher_queries ose = (⋃(_, _, ciphers)∈ran ose. set ciphers)"

(* Introduction rule: a ciphertext recorded by some user n is in cipher_queries. *)
lemma cipher_queriesI:
  "⟦ ose n = Some (ek, dk, ciphers); x ∈ set ciphers ⟧ ⟹ x ∈ cipher_queries ose"
by(auto simp add: cipher_queries_def ran_def)

(* Elimination rule: every element of cipher_queries was recorded by some user n. *)
lemma cipher_queriesE:
  assumes "x ∈ cipher_queries ose"
  obtains (cipher_queries) n ek dk ciphers where "ose n = Some (ek, dk, ciphers)" "x ∈ set ciphers"
using assms by(auto simp add: cipher_queries_def ran_def)

(* Case split for a point update of the state map: a recorded ciphertext either
   (old) was already recorded and is not among the new entry's ciphertexts, or
   (new) is among the new entry's ciphertexts. *)
lemma cipher_queries_updE:
  assumes "x ∈ cipher_queries (ose(n ↦ (ek, dk, ciphers)))"
  obtains (old) "x ∈ cipher_queries ose" "x ∉ set ciphers" | (new) "x ∈ set ciphers"
using assms by(cases "x ∈ set ciphers")(fastforce elim!: cipher_queriesE split: if_split_asm intro: cipher_queriesI)+

(* With no user initialised, no challenge ciphertexts have been recorded. *)
lemma cipher_queries_empty [simp]: "cipher_queries Map.empty = {}"
by(simp add: cipher_queries_def)

end

end