(* Theory Hashed_Elgamal *)

(* Title: Hashed_Elgamal.thy
  Author: Andreas Lochbihler, ETH Zurich *)

subsection ‹Hashed Elgamal in the Random Oracle Model›

theory Hashed_Elgamal imports
  CryptHOL.GPV_Bisim
  CryptHOL.Cyclic_Group_SPMF
  CryptHOL.List_Bits
  IND_CPA_PK
  Diffie_Hellman
begin

(* Bitstrings are modelled as lists of booleans. *)
type_synonym bitstring = "bool list"

locale hash_oracle = fixes len :: "nat" begin

text ‹The oracle state is a partial map that records the answer given to every query so far.›
type_synonym 'a state = "'a ⇀ bitstring"

text ‹Lazily sampled random oracle: a fresh query is answered with a bitstring drawn uniformly
  from all bitstrings of length ‹len›, and the answer is recorded in the state; a repeated
  query is answered with the recorded bitstring and leaves the state unchanged.›
definition "oracle" :: "'a state ⇒ 'a ⇒ (bitstring × 'a state) spmf"
where
  "oracle σ x = 
  (case σ x of None ⇒ do {
     bs ← spmf_of_set (nlists UNIV len);
     return_spmf (bs, σ(x ↦ bs))
   } | Some bs ⇒ return_spmf (bs, σ))"

text ‹Initially, no query has been answered.›
abbreviation (input) initial :: "'a state" where "initial ≡ Map.empty"

text ‹Well-formed oracle states: finitely many recorded queries, all answers of length ‹len›.›
inductive invariant :: "'a state ⇒ bool"
where
  invariant: "⟦ finite (dom σ); length ` ran σ ⊆ {len} ⟧ ⟹ invariant σ"

lemma invariant_initial [simp]: "invariant initial"
by(rule invariant.intros) auto

lemma invariant_update [simp]: "⟦ invariant σ; length bs = len ⟧ ⟹ invariant (σ(x ↦ bs))"
by(auto simp add: invariant.simps ran_def)

text ‹The oracle preserves well-formedness of its state.›
lemma invariant [intro!, simp]: "callee_invariant oracle invariant"
by unfold_locales(simp_all add: oracle_def in_nlists_UNIV split: option.split_asm)

text ‹Once a query has been recorded, it stays recorded.›
lemma invariant_in_dom [simp]: "callee_invariant oracle (λσ. x ∈ dom σ)"
by unfold_locales(simp_all add: oracle_def split: option.split_asm)

lemma lossless_oracle [simp]: "lossless_spmf (oracle σ x)"
by(simp add: oracle_def split: option.split)

text ‹Running a computation that makes at most ‹n› oracle interactions adds at most ‹n› entries
  to the oracle state. If the initial domain is infinite, the final domain is infinite as well,
  so both ‹card› terms are ‹0› and the bound holds trivially.›
lemma card_dom_state:
  assumes σ': "(x, σ') ∈ set_spmf (exec_gpv oracle gpv σ)"
  and ibound: "interaction_any_bounded_by gpv n"
  shows "card (dom σ') ≤ n + card (dom σ)"
proof(cases "finite (dom σ)")
  case True
  interpret callee_invariant_on "oracle" "λσ. finite (dom σ)" ℐ_full
    by unfold_locales(auto simp add: oracle_def split: option.split_asm)
  from ibound σ' _ _ _ True show ?thesis
    by(rule interaction_bounded_by'_exec_gpv_count)(auto simp add: oracle_def card_insert_if simp del: fun_upd_apply split: option.split_asm)
next
  case False
  ― ‹The domain only grows, so an infinite initial domain stays infinite.›
  interpret callee_invariant_on "oracle" "λσ'. dom σ ⊆ dom σ'" ℐ_full
    by unfold_locales(auto simp add: oracle_def split: option.split_asm)
  from σ' have "dom σ ⊆ dom σ'" by(rule exec_gpv_invariant) simp_all
  with False have "infinite (dom σ')" by(auto intro: finite_subset)
  with False show ?thesis by simp
qed

end

locale elgamal_base =
  fixes 𝒢 :: "'grp cyclic_group" (structure)
  and len_plain :: "nat"
begin

text ‹The random oracle answers with bitstrings of the plaintext length.›
sublocale hash: hash_oracle "len_plain" .

text ‹A single query ‹x› to the random oracle, returning the oracle's answer.›
abbreviation hash :: "'grp ⇒ (bitstring, 'grp, bitstring) gpv"
where "hash x ≡ Pause x Done"

type_synonym 'grp' pub_key = "'grp'"
type_synonym 'grp' priv_key = nat
type_synonym plain = bitstring
type_synonym 'grp' cipher = "'grp' × bitstring"

text ‹Key generation: sample an exponent ‹x› uniformly below the group order;
  the public key is ‹g [^] x› and the private key is ‹x›.›
definition key_gen :: "('grp pub_key × 'grp priv_key) spmf"
where 
  "key_gen = do {
     x ← sample_uniform (order 𝒢);
     return_spmf (g [^] x, x)
  }"

text ‹Encryption: sample ‹y›, hash ‹α [^] y›, and mask the message with the hash by xor.›
definition aencrypt :: "'grp pub_key ⇒ plain ⇒ ('grp cipher, 'grp, bitstring) gpv"
where
  "aencrypt α msg = do {
    y ← lift_spmf (sample_uniform (order 𝒢));
    h ← hash (α [^] y);
    Done (g [^] y, h [⊕] msg)
  }"

text ‹Decryption: recompute the mask from ‹β [^] x› via the hash and unmask the payload.›
definition adecrypt :: "'grp priv_key ⇒ 'grp cipher ⇒ (plain, 'grp, bitstring) gpv"
where
  "adecrypt x = (λ(β, ζ). do {
    h ← hash (β [^] x);
    Done (ζ [⊕] h)
  })"

text ‹The adversary's challenge messages must both have the plaintext length.›
definition valid_plains :: "plain ⇒ plain ⇒ bool"
where "valid_plains msg1 msg2 ⟷ length msg1 = len_plain ∧ length msg2 = len_plain"

lemma lossless_aencrypt [simp]: "lossless_gpv ℐ (aencrypt α msg) ⟷ 0 < order 𝒢"
by(simp add: aencrypt_def Let_def)

lemma interaction_bounded_by_aencrypt [interaction_bound, simp]:
  "interaction_bounded_by (λ_. True) (aencrypt α msg) 1"
unfolding aencrypt_def by interaction_bound(simp add: one_enat_def SUP_le_iff)

sublocale ind_cpa: ind_cpa_pk "lift_spmf key_gen" aencrypt adecrypt valid_plains .
sublocale lcdh: lcdh 𝒢 .

text ‹Reduction to the ‹lcdh› game: run the IND-CPA adversary on public key ‹α› with a
  simulated random oracle, answer its challenge with ‹β› and a uniformly random payload ‹h'›,
  and output the set of all hash queries the adversary made.›
fun elgamal_adversary
   :: "('grp pub_key, plain, 'grp cipher, 'grp, bitstring, 'state) ind_cpa.adversary
    ⇒ 'grp lcdh.adversary"                     
where
  "elgamal_adversary (𝒜1, 𝒜2) α β = do {
    (((msg1, msg2), σ), s) ← exec_gpv hash.oracle (𝒜1 α) hash.initial;
    ― ‹check that the adversary's two messages are valid plaintexts of length len_plain; otherwise stop early›
    TRY do {
      _ :: unit ← assert_spmf (valid_plains msg1 msg2);
      h' ← spmf_of_set (nlists UNIV len_plain);
      (guess, s') ← exec_gpv hash.oracle (𝒜2 (β, h') σ) s;
      return_spmf (dom s')
    } ELSE return_spmf (dom s)
  }"

end

locale elgamal = elgamal_base +
  assumes cyclic_group: "cyclic_group 𝒢"
begin

interpretation cyclic_group 𝒢 by(fact cyclic_group)

text ‹Security reduction: the IND-CPA advantage of a lossless adversary ‹𝒜› in the random
  oracle model is bounded by the ‹lcdh› advantage of ‹elgamal_adversary 𝒜›.
  The proof proceeds by game hops. In ‹game0› the adversary's hash oracle also records its
  queries. ‹game1› makes the sampled exponents ‹x› and ‹y› explicit and additionally returns
  whether ‹g [^] (x * y)› was queried (the bad event). ‹game2› samples the challenge mask
  freshly instead of looking it up, which differs from ‹game1› only under the bad event.
  ‹game3› generalises how the masked payload is produced. By optimistic sampling (one-time pad)
  the winning bit becomes a fair coin, and the bad event coincides with the ‹lcdh› game.›
lemma advantage_elgamal: 
  includes lifting_syntax
  assumes lossless: "ind_cpa.lossless 𝒜"
  shows "ind_cpa.advantage hash.oracle hash.initial 𝒜 ≤ lcdh.advantage (elgamal_adversary 𝒜)"
proof -
  note [cong del] = if_weak_cong and [split del] = if_split
    and [simp] = map_lift_spmf gpv.map_id lossless_weight_spmfD map_spmf_bind_spmf bind_spmf_const
  obtain 𝒜1 𝒜2 where 𝒜 [simp]: "𝒜 = (𝒜1, 𝒜2)" by(cases "𝒜")

  interpret cyclic_group: cyclic_group 𝒢 by(rule cyclic_group)
  from finite_carrier have [simp]: "order 𝒢 > 0" using order_gt_0_iff_finite by(simp)

  from lossless have lossless1 [simp]: "⋀pk. lossless_gpv ℐ_full (𝒜1 pk)"
    and lossless2 [simp]: "⋀σ cipher. lossless_gpv ℐ_full (𝒜2 σ cipher)"
    by(auto simp add: ind_cpa.lossless_def)

  text ‹We change the adversary's oracle to record the queries made by the adversary›
  define hash_oracle' where "hash_oracle' = (λσ x. do {
      h ← hash x;
      Done (h, insert x σ)
    })"
  have [simp]: "lossless_gpv ℐ_full (hash_oracle' σ x)" for σ x by(simp add: hash_oracle'_def)
  have [simp]: "lossless_gpv ℐ_full (inline hash_oracle' (𝒜1 α) s)" for α s
    by(rule lossless_inline[where ℐ=ℐ_full]) simp_all
  define game0 where "game0 = TRY do {
      (pk, _) ← lift_spmf key_gen;
      b ← lift_spmf coin_spmf;
      (((msg1, msg2), σ), s) ← inline hash_oracle' (𝒜1 pk) {};
      assert_gpv (valid_plains msg1 msg2);
      cipher ← aencrypt pk (if b then msg1 else msg2);
      (guess, s') ← inline hash_oracle' (𝒜2 cipher σ) s;
      Done (guess = b)
    } ELSE lift_spmf coin_spmf"
  { define cr where "cr = (λ_ :: unit. λ_ :: 'a set. True)"
    have [transfer_rule]: "cr () {}" by(simp add: cr_def)
    have [transfer_rule]: "((=) ===> cr ===> cr) (λ_ σ. σ) insert" by(simp add: rel_fun_def cr_def)
    have [transfer_rule]: "(cr ===> (=) ===> rel_gpv (rel_prod (=) cr) (=)) id_oracle hash_oracle'"
      unfolding hash_oracle'_def id_oracle_def[abs_def] bind_gpv_Pause bind_rpv_Done by transfer_prover
    have "ind_cpa.ind_cpa 𝒜 = game0" unfolding game0_def 𝒜 ind_cpa_pk.ind_cpa.simps
      by(transfer fixing: 𝒢 len_plain 𝒜1 𝒜2)(simp add: bind_map_gpv o_def ind_cpa_pk.ind_cpa.simps split_def) }
  note game0 = this
  have game0_alt_def: "game0 = do {
      x ← lift_spmf (sample_uniform (order 𝒢));
      b ← lift_spmf coin_spmf;
      (((msg1, msg2), σ), s) ← inline hash_oracle' (𝒜1 (g [^] x)) {};
      TRY do {
        _ :: unit ← assert_gpv (valid_plains msg1 msg2);
        cipher ← aencrypt (g [^] x) (if b then msg1 else msg2);
        (guess, s') ← inline hash_oracle' (𝒜2 cipher σ) s;
        Done (guess = b)
      } ELSE lift_spmf coin_spmf
    }"
    by(simp add: split_def game0_def key_gen_def lift_spmf_bind_spmf bind_gpv_assoc try_gpv_bind_lossless[symmetric])

  text ‹‹hash_oracle''› combines the query recording of ‹hash_oracle'› with the
    lazy sampling of ‹hash.oracle› into a single oracle on pairs of states.›
  define hash_oracle'' where "hash_oracle'' = (λ(s, σ) (x :: 'a). do {
      (h, σ') ← case σ x of
          None ⇒ bind_spmf (spmf_of_set (nlists UNIV len_plain)) (λbs. return_spmf (bs, σ(x ↦ bs)))
        | Some (bs :: bitstring) ⇒ return_spmf (bs, σ);
      return_spmf (h, insert x s, σ')
    })"
  have *: "exec_gpv hash.oracle (inline hash_oracle' 𝒜 s) σ = 
    map_spmf (λ(a, b, c). ((a, b), c)) (exec_gpv hash_oracle'' 𝒜 (s, σ))" for 𝒜 σ s
    by(simp add: hash_oracle'_def hash_oracle''_def hash.oracle_def Let_def exec_gpv_inline exec_gpv_bind o_def split_def cong del: option.case_cong_weak)
  have [simp]: "lossless_spmf (hash_oracle'' s plain)" for s plain
    by(simp add: hash_oracle''_def Let_def split: prod.split option.split)
  have [simp]: "lossless_spmf (exec_gpv hash_oracle'' (𝒜1 α) s)" for s α
    by(rule lossless_exec_gpv[where ℐ=ℐ_full]) simp_all
  have [simp]: "lossless_spmf (exec_gpv hash_oracle'' (𝒜2 σ cipher) s)" for σ cipher s
    by(rule lossless_exec_gpv[where ℐ=ℐ_full]) simp_all

  text ‹Each game below is parametrised by the two exponents ‹x› and ‹y›, which ‹?sample›
    draws uniformly and independently. Games return the winning bit paired with the bad event.›
  let ?sample = "λf. bind_spmf (sample_uniform (order 𝒢)) (λx. bind_spmf (sample_uniform (order 𝒢)) (f x))"
  define game1 where "game1 = (λ(x :: nat) (y :: nat). do {
      b ← coin_spmf;
      (((msg1, msg2), σ), (s, s_h)) ← exec_gpv hash_oracle'' (𝒜1 (g [^] x)) ({}, hash.initial);
      TRY do {
        _ :: unit ← assert_spmf (valid_plains msg1 msg2);
        (h, s_h') ← hash.oracle s_h (g [^] (x * y));
        let cipher = (g [^] y, h [⊕] (if b then msg1 else msg2));
        (guess, (s', s_h'')) ← exec_gpv hash_oracle'' (𝒜2 cipher σ) (s, s_h');
        return_spmf (guess = b, g [^] (x * y) ∈ s')
      } ELSE do {
        b ← coin_spmf;
        return_spmf (b, g [^] (x * y) ∈ s)
      }
    })"
  have game01: "run_gpv hash.oracle game0 hash.initial = map_spmf fst (?sample game1)"
    apply(simp add: exec_gpv_bind split_def bind_gpv_assoc aencrypt_def game0_alt_def game1_def o_def bind_map_spmf if_distribs * try_bind_assert_gpv try_bind_assert_spmf lossless_inline[where ℐ=ℐ_full] bind_rpv_def nat_pow_pow del: bind_spmf_const)
    including monad_normalisation by(simp add: bind_rpv_def nat_pow_pow)
  
  text ‹‹game2› answers the challenge query with a fresh uniform sample.›
  define game2 where "game2 = (λ(x :: nat) (y :: nat). do {
    b ← coin_spmf;
    (((msg1, msg2), σ), (s, s_h)) ← exec_gpv hash_oracle'' (𝒜1 (g [^] x)) ({}, hash.initial);
    TRY do {
      _ :: unit ← assert_spmf (valid_plains msg1 msg2);
      h ← spmf_of_set (nlists UNIV len_plain);
      ― ‹We do not do the lookup in s_h here, so the rest differs only if the adversary guessed y›
      let cipher = (g [^] y, h [⊕] (if b then msg1 else msg2));
      (guess, (s', s_h')) ← exec_gpv hash_oracle'' (𝒜2 cipher σ) (s, s_h);
      return_spmf (guess = b, g [^] (x * y) ∈ s')
    } ELSE do {
      b ← coin_spmf;
      return_spmf (b, g [^] (x * y) ∈ s)
    }
  })"
  interpret inv'': callee_invariant_on "hash_oracle''" "λ(s, s_h). s = dom s_h" ℐ_full
    by unfold_locales(auto simp add: hash_oracle''_def split: option.split_asm if_split)
  have in_encrypt_oracle: "callee_invariant hash_oracle'' (λ(s, _). x ∈ s)" for x
    by unfold_locales(auto simp add: hash_oracle''_def)

  { fix x y :: nat
    let ?bad = "λ(s, s_h). g [^] (x * y) ∈ s"
    let ?X = "λ(s, s_h) (s', s_h'). s = dom s_h ∧ s' = s ∧ s_h = s_h'(g [^] (x * y) := None)"
    have bisim:
      "rel_spmf (λ(a, s1') (b, s2'). ?bad s1' = ?bad s2' ∧ (¬ ?bad s2' ⟶ a = b ∧ ?X s1' s2'))
             (hash_oracle'' s1 plain) (hash_oracle'' s2 plain)"
      if "?X s1 s2" for s1 s2 plain using that
      by(auto split: prod.splits intro!: rel_spmf_bind_reflI simp add: hash_oracle''_def rel_spmf_return_spmf2 fun_upd_twist split: option.split dest!: fun_upd_eqD)
    have inv: "callee_invariant hash_oracle'' ?bad"
      by(unfold_locales)(auto simp add: hash_oracle''_def split: option.split_asm)
    have "rel_spmf (λ(win, bad) (win', bad'). bad = bad' ∧ (¬ bad' ⟶ win = win')) (game2 x y) (game1 x y)"
      unfolding game1_def game2_def
      apply(clarsimp simp add: split_def o_def hash.oracle_def rel_spmf_bind_reflI if_distribs intro!: rel_spmf_bind_reflI simp del: bind_spmf_const)
      apply(rule rel_spmf_try_spmf)
      subgoal for b msg1 msg2 σ s s_h
        apply(rule rel_spmf_bind_reflI)
        apply(drule inv''.exec_gpv_invariant; clarsimp)
        apply(cases "s_h (g [^] (x * y))")
        subgoal ― ‹case @{const None}›
          apply(clarsimp intro!: rel_spmf_bind_reflI)
          apply(rule rel_spmf_bindI)
           apply(rule exec_gpv_oracle_bisim_bad_full[OF _ _ bisim inv inv, where R="λ(x, s1) (y, s2). ?bad s1 = ?bad s2 ∧ (¬ ?bad s2 ⟶ x = y)"]; clarsimp simp add: fun_upd_idem; fail)
          apply clarsimp
          done
        subgoal by(auto intro!: rel_spmf_bindI1 rel_spmf_bindI2 lossless_exec_gpv[where ℐ=ℐ_full] dest!: callee_invariant_on.exec_gpv_invariant[OF in_encrypt_oracle])
        done
      subgoal by(rule rel_spmf_reflI) simp
      done }
  hence "rel_spmf (λ(win, bad) (win', bad'). (bad ⟷ bad') ∧ (¬ bad' ⟶ win ⟷ win')) (?sample game2) (?sample game1)"
    by(intro rel_spmf_bind_reflI)
  hence "¦measure (measure_spmf (?sample game2)) {(x, _). x} - measure (measure_spmf (?sample game1)) {(y, _). y}¦
         ≤ measure (measure_spmf (?sample game2)) {(_, bad). bad}"
    unfolding split_def by(rule fundamental_lemma)
  moreover have "measure (measure_spmf (?sample game2)) {(x, _). x} = spmf (map_spmf fst (?sample game2)) True"
    and "measure (measure_spmf (?sample game1)) {(y, _). y} = spmf (map_spmf fst (?sample game1)) True"
    and "measure (measure_spmf (?sample game2)) {(_, bad). bad} = spmf (map_spmf snd (?sample game2)) True"
    unfolding spmf_conv_measure_spmf measure_map_spmf by(rule arg_cong2[where f=measure]; fastforce)+
  ultimately have hop23: "¦spmf (map_spmf fst (?sample game2)) True - spmf (map_spmf fst (?sample game1)) True¦ ≤ spmf (map_spmf snd (?sample game2)) True" by simp

  text ‹‹game3› abstracts over how the masked challenge payload is derived from the random mask.›
  define game3
    where "game3 = (λf :: _ ⇒ _ ⇒ _ ⇒ bitstring spmf ⇒ _ spmf. λ(x :: nat) (y :: nat). do {
      b ← coin_spmf;
      (((msg1, msg2), σ), (s, s_h)) ← exec_gpv hash_oracle'' (𝒜1 (g [^] x)) ({}, hash.initial);
      TRY do {
        _ :: unit ← assert_spmf (valid_plains msg1 msg2);
        h' ← f b msg1 msg2 (spmf_of_set (nlists UNIV len_plain));
        let cipher = (g [^] y, h');
        (guess, (s', s_h')) ← exec_gpv hash_oracle'' (𝒜2 cipher σ) (s, s_h);
        return_spmf (guess = b, g [^] (x * y) ∈ s')
      } ELSE do {
        b ← coin_spmf;
        return_spmf (b, g [^] (x * y) ∈ s)
      }
    })"
  let ?f = "λb msg1 msg2. map_spmf (λh. (if b then msg1 else msg2) [⊕] h)"
  have "game2 x y = game3 ?f x y" for x y
    unfolding game2_def game3_def by(simp add: Let_def bind_map_spmf xor_list_commute o_def nat_pow_pow)
  also have "game3 ?f x y = game3 (λ_ _ _ x. x) x y" for x y (* optimistic sampling *)
    unfolding game3_def
    by(auto intro!: try_spmf_cong bind_spmf_cong[OF refl] if_cong[OF refl] simp add: split_def one_time_pad valid_plains_def simp del: map_spmf_of_set_inj_on bind_spmf_const split: if_split)
  finally have game23: "game2 x y = game3 (λ_ _ _ x. x) x y" for x y .

  text ‹The bad event of ‹game3› is exactly the ‹lcdh› game played by the reduction.›
  define hash_oracle''' where "hash_oracle''' = (λ(σ :: 'a ⇀ _). hash.oracle σ)"
  { define bisim where "bisim = (λσ (s :: 'a set, σ' :: 'a ⇀ bitstring). s = dom σ ∧ σ = σ')"
    have [transfer_rule]: "bisim Map_empty ({}, Map_empty)" by(simp add: bisim_def)
    have [transfer_rule]: "(bisim ===> (=) ===> rel_spmf (rel_prod (=) bisim)) hash_oracle''' hash_oracle''"
      by(auto simp add: hash_oracle''_def split_def hash_oracle'''_def spmf_rel_map hash.oracle_def rel_fun_def bisim_def split: option.split intro!: rel_spmf_bind_reflI)
    have * [transfer_rule]: "(bisim ===> (=)) dom fst" by(simp add: bisim_def rel_fun_def)
    have * [transfer_rule]: "(bisim ===> (=)) (λx. x) snd" by(simp add: rel_fun_def bisim_def)
    have "game3 (λ_ _ _ x. x) x y = do {
        b ← coin_spmf;
        (((msg1, msg2), σ), s) ← exec_gpv hash_oracle''' (𝒜1 (g [^] x)) hash.initial;
        TRY do {
          _ :: unit ← assert_spmf (valid_plains msg1 msg2);
          h' ← spmf_of_set (nlists UNIV len_plain);
          let cipher = (g [^] y, h');
          (guess, s') ← exec_gpv hash_oracle''' (𝒜2 cipher σ) s;
          return_spmf (guess = b, g [^] (x * y) ∈ dom s')
        } ELSE do {
          b ← coin_spmf;
          return_spmf (b, g [^] (x * y) ∈ dom s)
        }
      }" for x y
      unfolding game3_def Map_empty_def[symmetric] split_def fst_conv snd_conv prod.collapse
      by(transfer fixing: 𝒜1 𝒢 len_plain x y 𝒜2) simp
    moreover have "map_spmf snd (… x y) = do {
        zs ← elgamal_adversary 𝒜 (g [^] x) (g [^] y);
        return_spmf (g [^] (x * y) ∈ zs)
      }" for x y
      by(simp add: o_def split_def hash_oracle'''_def map_try_spmf map_scale_spmf)
        (simp add: o_def map_try_spmf map_scale_spmf map_spmf_conv_bind_spmf[symmetric] spmf.map_comp map_const_spmf_of_set)
    ultimately have "map_spmf snd (?sample (game3 (λ_ _ _ x. x))) = lcdh.lcdh (elgamal_adversary 𝒜)"
      by(simp add: o_def lcdh.lcdh_def Let_def nat_pow_pow) }
  then have game2_snd: "map_spmf snd (?sample game2) = lcdh.lcdh (elgamal_adversary 𝒜)"
    using game23 by(simp add: o_def)

  text ‹After optimistic sampling the challenge is independent of ‹b›, so the winning bit is a fair coin.›
  have "map_spmf fst (game3 (λ_ _ _ x. x) x y) = do {
      (((msg1, msg2), σ), (s, s_h)) ← exec_gpv hash_oracle'' (𝒜1 (g [^] x)) ({}, hash.initial);
      TRY do {
        _ :: unit ← assert_spmf (valid_plains msg1 msg2);
        h' ← spmf_of_set (nlists UNIV len_plain);
        (guess, (s', s_h')) ← exec_gpv hash_oracle'' (𝒜2 (g [^] y, h') σ) (s, s_h);
        map_spmf ((=) guess) coin_spmf
      } ELSE coin_spmf
    }" for x y 
    including monad_normalisation
    by(simp add: game3_def o_def split_def map_spmf_conv_bind_spmf try_spmf_bind_out weight_spmf_le_1 scale_bind_spmf try_spmf_bind_out1 bind_scale_spmf)
  then have game3_fst: "map_spmf fst (game3 (λ_ _ _ x. x) x y) = coin_spmf" for x y
    by(simp add: o_def if_distribs spmf.map_comp map_eq_const_coin_spmf split_def)

  text ‹Chain the hops together.›
  have "ind_cpa.advantage hash.oracle hash.initial 𝒜 = ¦spmf (map_spmf fst (?sample game1)) True - 1 / 2¦"
    using game0 by(simp add: ind_cpa_pk.advantage_def game01 o_def)
  also have "… = ¦1 / 2 - spmf (map_spmf fst (?sample game1)) True¦"
    by(simp add: abs_minus_commute)
  also have "1 / 2 = spmf (map_spmf fst (?sample game2)) True"
    by(simp add: game23 o_def game3_fst spmf_of_set)
  also note hop23 also note game2_snd
  finally show ?thesis by(simp add: lcdh.advantage_def)
qed

end

context elgamal_base begin

text ‹Key generation terminates with probability one exactly when the group order is positive.›
lemma lossless_key_gen [simp]: "lossless_spmf key_gen ⟷ 0 < order 𝒢"
by(simp add: key_gen_def Let_def)

text ‹The reduction is lossless whenever the IND-CPA adversary is lossless and the group order
  is positive. NOTE(review): the binder ‹η› in the second premise is vacuous.›
lemma lossless_elgamal_adversary:
  "⟦ ind_cpa.lossless 𝒜; ⋀η. 0 < order 𝒢 ⟧
  ⟹ lcdh.lossless (elgamal_adversary 𝒜)"
by(cases 𝒜)(auto simp add: lcdh.lossless_def ind_cpa.lossless_def split_def Let_def intro!: lossless_exec_gpv[where ℐ=ℐ_full] lossless_inline)

end

end