Theory Guessing_Many_One

(* Title: Guessing_Many_One.thy
  Author: Andreas Lochbihler, ETH Zurich 
  Author: S. Reza Sefidgar, ETH Zurich *)

subsection ‹Reducing games with many adversary guesses to games with single guesses›

theory Guessing_Many_One imports
  CryptHOL.Computational_Model
  CryptHOL.GPV_Bisim
begin

(* Abstract guessing game, parametrised by:
   - init: samples a triple (c_o, c_a, s). c_o is given to the oracle and to eval,
     c_a is given to the adversary (and to eval), s is the initial oracle state.
   - oracle c_o: answers a single adversary call, threading the oracle state.
   - eval c_o c_a s guess: probabilistically decides whether guess is judged correct
     in oracle state s (the games below output its result).
   NOTE(review): the function arrows in the types below, and several other symbols in
   this copy of the theory (e.g. arrows for monadic binds, the oracle-sum operator,
   connectives and quantifiers), appear to have been stripped; compare against the
   upstream theory before building. *)
locale guessing_many_one =
  fixes init :: "('c_o × 'c_a × 's) spmf"
  and "oracle" :: "'c_o  's  'call  ('ret × 's) spmf"
  and "eval" :: "'c_o  'c_a  's  'guess  bool spmf"
begin

(* Single-guess adversary: given its input c_a, it may issue 'call queries (answered
   with 'ret values) and finally returns one 'guess. *)
type_synonym ('c_a', 'guess', 'call', 'ret') adversary_single = "'c_a'  ('guess', 'call', 'ret') gpv"

(* Single-guess game: sample (c_o, c_a, s) from init, run the adversary on c_a against
   oracle c_o starting in state s, then evaluate its final guess in the final oracle
   state s'. *)
definition game_single :: "('c_a, 'guess, 'call, 'ret) adversary_single  bool spmf"
where
  "game_single 𝒜 = do {
    (c_o, c_a, s)  init;
    (guess, s')  exec_gpv (oracle c_o) (𝒜 c_a) s;
    eval c_o c_a s' guess
  }"

(* Advantage of a single-guess adversary: the probability that game_single outputs True. *)
definition advantage_single :: "('c_a, 'guess, 'call, 'ret) adversary_single  real"
where "advantage_single 𝒜 = spmf (game_single 𝒜) True"


(* Many-guess adversary: given c_a, each interaction is either an oracle query (Inl of a
   'call) or a guess (Inr of a 'guess); responses are of type 'ret + unit. The
   adversary's final result carries no information (unit). *)
type_synonym ('c_a', 'guess', 'call', 'ret') adversary_many = "'c_a'  (unit, 'call' + 'guess', 'ret' + unit) gpv"

(* Oracle that answers guesses in the many-guess game. Its state is (b, s'):
   b is a flag about the guesses evaluated so far and s' is the oracle state, which is
   passed to eval but returned unchanged. The fresh evaluation b' is combined with b
   (presumably by disjunction, so b means "some guess was correct"; the operator symbol
   is not visible in this copy -- cf. relation S in many_single_reduction). *)
definition eval_oracle :: "'c_o  'c_a  bool × 's  'guess  (unit × (bool × 's)) spmf"
where
  "eval_oracle c_o c_a = (λ(b, s') guess. map_spmf (λb'. ((), (b  b', s'))) (eval c_o c_a s' guess))"

(* Many-guess game: sample (c_o, c_a, s) from init and run the adversary on c_a. Queries
   are answered by oracle c_o (on the oracle-state component) and guesses by
   eval_oracle c_o c_a. The run starts with flag False and oracle state s, and the game
   outputs the final flag. *)
definition game_multi :: "('c_a, 'guess, 'call, 'ret) adversary_many  bool spmf"
where
  "game_multi 𝒜 = do {
     (c_o, c_a, s)  init;
     (_, (b, _))  exec_gpv
       ((oracle c_o) O eval_oracle c_o c_a)
       (𝒜 c_a)
       (False, s);
     return_spmf b
  }"

(* Advantage of a many-guess adversary: the probability that game_multi outputs True. *)
definition advantage_multi :: "('c_a, 'guess, 'call, 'ret) adversary_many  real"
where "advantage_multi 𝒜 = spmf (game_multi 𝒜) True"


(* State of the reduction oracle:
   Inr j -- j more guesses are to be skipped before the one that will be kept;
   Inl g -- guess g has been captured (see process_guess). *)
type_synonym 'guess' reduction_state = "'guess' + nat"

(* Handle an adversary oracle query x in the reduction.
   Inr j: no guess captured yet -- forward x to the real oracle and return Some ret,
          keeping the state.
   Inl g: a guess has already been captured -- answer None without forwarding
          (presumably the stop signal for plus_intercept_stop / inline_stop). *)
primrec process_call :: "'guess reduction_state  'call  ('ret option × 'guess reduction_state, 'call, 'ret) gpv"
where
  "process_call (Inr j) x = do {
    ret  Pause x Done;
    Done (Some ret, Inr j)
  }"
| "process_call (Inl guess) x = Done (None, Inl guess)"

(* Handle an adversary guess in the reduction, without making any oracle query.
   Inr j, j > 0: acknowledge with Some () and decrement the counter;
   Inr 0:        capture this guess (state Inl guess) and answer None;
   Inl g:        a guess was already captured -- keep it and answer None. *)
primrec process_guess :: "'guess reduction_state  'guess  (unit option × 'guess reduction_state, 'call, 'ret) gpv"
where
  "process_guess (Inr j) guess = Done (if j > 0 then (Some (), Inr (j - 1)) else (None, Inl guess))"
| "process_guess (Inl guess) _ = Done (None, Inl guess)"

(* Reduction oracle on the many-guess interface 'call + 'guess: queries are handled by
   process_call and guesses by process_guess, combined with plus_intercept_stop. *)
abbreviation reduction_oracle :: "'guess + nat  'call + 'guess  (('ret + unit) option × ('guess + nat), 'call, 'ret) gpv"
where "reduction_oracle  plus_intercept_stop process_call process_guess"

(* Turn a many-guess adversary into a single-guess adversary.
   - Draw j_star uniformly from {..<q}.
   - Run 𝒜 c_a inlined with reduction_oracle from state Inr j_star, so that the guess
     with index j_star (counting from 0) is captured.
   - Output the captured guess.
   If 𝒜 makes at most j_star guesses, the final state is still Inr _ and projl s is an
   unspecified value. This is harmless for many_single_reduction, which only needs a
   lower bound on the single-guess success probability.
   For q = 0 the index is sampled from an empty set (spmf_of_set of an empty set has no
   probability mass). *)
definition reduction :: "nat  ('c_a, 'guess, 'call, 'ret) adversary_many  ('c_a, 'guess, 'call, 'ret) adversary_single"
where
  "reduction q 𝒜 c_a = do {
    j_star  lift_spmf (spmf_of_set {..<q});
    (_, s)  inline_stop reduction_oracle (𝒜 c_a) (Inr j_star);
    Done (projl s)
  }"

(* Main result. Suppose that for every initial configuration the many-guess adversary 𝒜
   makes at most q guesses (interactions of the form Inr _, counted via Not ∘ isl), and
   that oracle and eval are lossless. Then the many-guess advantage of 𝒜 is at most q
   times the single-guess advantage of reduction q 𝒜.

   Proof outline (the steps are marked in the proof below):
   1. game_multi = bind_spmf init game_multi'. Here game_multi' records the index of the
      first correct guess instead of a boolean flag (transfer along relation S).
   2. For each initial configuration, Pr[game_multi'] = q * Pr[body2 with js uniform in
      {..<q}]. body2 js checks that the first correct guess has index js.
   3. body2 js is bounded by body3 js in ord_spmf (⟶). body3 stops 𝒜 at its js-th guess
      and evaluates that guess.
   4. initialize body3 is bounded by game_single (reduction q 𝒜). *)
lemma many_single_reduction:
  assumes bound: "c_a c_o s. (c_o, c_a, s)  set_spmf init  interaction_bounded_by (Not  isl) (𝒜 c_a) q"
  and lossless_oracle: "c_a c_o s s' x. (c_o, c_a, s)  set_spmf init  lossless_spmf (oracle c_o s' x)"
  and lossless_eval: "c_a c_o s s' guess. (c_o, c_a, s)  set_spmf init  lossless_spmf (eval c_o c_a s' guess)"
  shows "advantage_multi 𝒜  advantage_single (reduction q 𝒜) * q"
  including lifting_syntax
proof -
  (* eval_oracle' replaces the boolean flag of eval_oracle by a pair (id, occ).
     id counts the guesses evaluated so far. occ stores the index of the first guess
     judged correct, or None if there is none yet. *)
  define eval_oracle'
    where "eval_oracle' = (λc_o c_a ((id, occ :: nat option), s') guess. 
    map_spmf (λb'. case occ of Some j0  ((), (Suc id, Some j0), s')
                                | None  ((), (Suc id, (if b' then Some id else None)), s'))
      (eval c_o c_a s' guess))"
  let ?multi'_body = "λc_o c_a s. exec_gpv ((oracle c_o) O eval_oracle' c_o c_a) (𝒜 c_a) ((0, None), s)"
  (* game_multi': the run with eval_oracle'; it succeeds if a first correct guess was recorded. *)
  define game_multi' where "game_multi' = (λc_o c_a s. do {
    (_, ((id, j0), s' :: 's))  ?multi'_body c_o c_a s;
    return_spmf (j0  None) })"

  (* initialize body: sample the initial configuration and an index js uniformly from
     {..<q}, then run body. *)
  define initialize :: "('c_o  'c_a  's  nat  bool spmf)  bool spmf" where
    "initialize body = do {
      (c_o, c_a, s)  init;
      js  spmf_of_set {..<q};
      body c_o c_a s js }" for body
  (* body2 js: the first correct guess recorded by eval_oracle' has index js. *)
  define body2 where "body2 c_o c_a s js = do {
    (_, (id, j0), s')  ?multi'_body c_o c_a s;
    return_spmf (j0 = Some js) }" for c_o c_a s js
  let ?game2 = "initialize body2"

  (* stop_oracle: while the counter is Inr i, queries are forwarded to oracle c_o and
     each guess decrements the counter. At Inr 0 the current guess and oracle state are
     captured as Inl (guess, s). From then on, every interaction is answered with None. *)
  define stop_oracle where "stop_oracle = (λc_o. 
     (λ(idgs, s) x. case idgs of Inr _  map_spmf (λ(y, s). (Some y, (idgs, s))) (oracle c_o s x) | Inl _  return_spmf (None, (idgs, s)))
     OS
     (λ(idgs, s) guess :: 'guess. return_spmf (case idgs of Inr 0  (None, Inl (guess, s), s) | Inr (Suc i)  (Some (), Inr i, s) | Inl _  (None, idgs, s))))"
  (* body3 js: run 𝒜 with stop_oracle from counter Inr js and evaluate the captured
     guess in its captured oracle state; output False if no guess was captured. *)
  define body3 where "body3 c_o c_a s js = do {
    (_ :: unit option, idgs, _)  exec_gpv_stop (stop_oracle c_o) (𝒜 c_a) (Inr js, s);
    (b' :: bool)  case idgs of Inr _  return_spmf False | Inl (g, s')  eval c_o c_a s' g;
    return_spmf b' }" for c_o c_a s js
  let ?game3 = "initialize body3"

  (* Step 1: game_multi = bind_spmf init game_multi'. We transfer along S, which relates
     the boolean flag of eval_oracle to the (id, occ) state of eval_oracle': the flag
     holds exactly when occ is Some _. *)
  { define S :: "bool  nat × nat option  bool" where "S  λb' (id, occ). b'  (j0. occ = Some j0)"
    let ?S = "rel_prod S (=)"

    define initial :: "nat × nat option" where "initial = (0, None)"
    define result :: "nat × nat option  bool" where "result p = (snd p  None)" for p
    have [transfer_rule]: "(S ===> (=)) (λb. b) result" by(simp add: rel_fun_def result_def S_def)
    have [transfer_rule]: "S False initial" by (simp add: S_def initial_def)

    have eval_oracle'[transfer_rule]: 
      "((=) ===> (=) ===> ?S ===> (=) ===> rel_spmf (rel_prod (=) ?S))
       eval_oracle eval_oracle'"
      unfolding eval_oracle_def[abs_def] eval_oracle'_def[abs_def]
      by (auto simp add: rel_fun_def S_def map_spmf_conv_bind_spmf intro!: rel_spmf_bind_reflI split: option.split)
    
    have game_multi': "game_multi 𝒜 = bind_spmf init (λ(c_o, c_a, s). game_multi' c_o c_a s)"
      unfolding game_multi_def game_multi'_def initial_def[symmetric]
      by (rewrite in "case_prod " in "bind_spmf _ (case_prod )" in "_ = bind_spmf _ " split_def)
         (fold result_def; transfer_prover) }
  moreover
  (* Step 2: for a fixed initial configuration, Pr[game_multi'] = q * Pr[body2 with js
     uniform in {..<q}]. The first correct guess has index < q (bound_occ). Splitting on
     that index gives q disjoint events, each weighted 1/q by the uniform choice of js. *)
  have "spmf (game_multi' c_o c_a s) True = spmf (bind_spmf (spmf_of_set {..<q}) (body2 c_o c_a s)) True * q"
    if "(c_o, c_a, s)  set_spmf init" for c_o c_a s
  proof -
    have bnd: "interaction_bounded_by (Not  isl) (𝒜 c_a) q" using bound that by blast

    (* The recorded index is below q: at most q guesses are counted (id is bounded by q),
       and the recorded index is always below id (invariant ?I). *)
    have bound_occ: "js < q" if that: "((), (id, Some js), s')  set_spmf (?multi'_body c_o c_a s)" 
      for s' id js
    proof -
      have "id  q" 
        by(rule oi_True.interaction_bounded_by'_exec_gpv_count[OF bnd that, where count="fst  fst", simplified])
          (auto simp add: eval_oracle'_def split: plus_oracle_split_asm option.split_asm)
      moreover let ?I = "λ((id, occ), s'). case occ of None  True | Some js  js < id"
      have "callee_invariant ((oracle c_o) O eval_oracle' c_o c_a) ?I"
        by(clarsimp simp add: split_def intro!: conjI[OF callee_invariant_extend_state_oracle_const'])
          (unfold_locales; auto simp add: eval_oracle'_def split: option.split_asm)
      from callee_invariant_on.exec_gpv_invariant[OF this that] have "js < id" by simp
      ultimately show ?thesis by simp
    qed

    let ?M = "measure (measure_spmf (?multi'_body c_o c_a s))"
    have "spmf (game_multi' c_o c_a s) True = ?M {(u, (id, j0), s'). j0  None}"
      by(auto simp add: game_multi'_def map_spmf_conv_bind_spmf[symmetric] split_def spmf_conv_measure_spmf measure_map_spmf vimage_def)
    also have "{(u, (id, j0), s'). j0  None} =
      {((), (id, Some js), s') |js s' id. js < q}  {((), (id, Some js), s') |js s' id. js  q}"
      (is "_ = ?A  _") by auto
    also have "?M  = ?M ?A"
      by (rule measure_spmf.measure_zero_union)(auto simp add: measure_spmf_zero_iff dest: bound_occ)
    also have " = measure (measure_spmf (pair_spmf (spmf_of_set {..< q}) (?multi'_body c_o c_a s)))
         {(js, (), (id, j0), s') |js j0 s' id. j0 = Some js } * q"
      (is "_ = measure ?M' ?B * _")
    proof - 
      have "?B = {(js, (), (id, j0), s') |js j0 s' id. j0 = Some js  js < q} 
        {(js, (), (id, j0), s') |js j0 s' id. j0 = Some js  js  q}" (is "_ = ?Set1  ?Set2")
        by auto
      then have "measure ?M' ?B = measure ?M' (?Set1  ?Set2)" by simp
      also have " = measure ?M' ?Set1"
        by (rule measure_spmf.measure_zero_union) (auto simp add: measure_spmf_zero_iff)
      also have " = (j{0..<q}. measure ?M' ({j} × {((), (id, Some j), s')|s' id. True}))"
        by(subst measure_spmf.finite_measure_finite_Union[symmetric])
          (auto intro!: arg_cong2[where f=measure] simp add: disjoint_family_on_def)
      also have " = (j{0..<q}. 1 / q * measure (measure_spmf (?multi'_body c_o c_a s)) {((), (id, Some j), s')|s' id. True})"
        by(simp add: measure_pair_spmf_times spmf_conv_measure_spmf[symmetric] spmf_of_set)
      also have " = 1 / q * measure (measure_spmf (?multi'_body c_o c_a s)) {((), (id, Some js), s')|js s' id. js < q}"
        unfolding sum_distrib_left[symmetric]
        by(subst measure_spmf.finite_measure_finite_Union[symmetric])
          (auto intro!: arg_cong2[where f=measure] simp add: disjoint_family_on_def)
      finally show ?thesis by simp
    qed
    also have "?B = (λ(js, _, (_, j0), _). j0 = Some js) -` {True}"
      by (auto simp add: vimage_def)
    also have rw2: "measure ?M'  = spmf (bind_spmf (spmf_of_set {..<q}) (body2 c_o c_a s)) True"
      by (simp add: body2_def[abs_def] measure_map_spmf[symmetric] map_spmf_conv_bind_spmf
        split_def pair_spmf_alt_def spmf_conv_measure_spmf[symmetric])
    finally show ?thesis .
  qed
  (* Lift Step 2 to the whole game by integrating over init. *)
  hence "spmf (bind_spmf init (λ(c_o, c_a, s). game_multi' c_o c_a s)) True = spmf ?game2 True * q"
    unfolding initialize_def spmf_bind[where p=init]
    by (auto intro!: integral_cong_AE simp del: integral_mult_left_zero simp add: integral_mult_left_zero[symmetric])

  moreover
  (* Step 3: body2 js is bounded by body3 js in ord_spmf (⟶). The chain below
     (a) replaces eval_oracle' by oracle2', which evaluates only the guess with index js;
     (b) moves that evaluation after the run (exec_gpv_bind_materialize);
     (c) switches to exec_gpv_stop; and
     (d) stops the adversary once the js-th guess has been captured.
     NOTE(review): the hypothesis "js < Suc q" is not referenced in this subproof;
     js < q would match how initialize samples js. *)
  have "ord_spmf (⟶) (body2 c_o c_a s js) (body3 c_o c_a s js)"
    if init: "(c_o, c_a, s)  set_spmf init" and js: "js < Suc q" for c_o c_a s js
  proof -
    define oracle2' where "oracle2'  λ(b, (id, gs), s) guess. if id = js then do {
        b' :: bool  eval c_o c_a s guess;
        return_spmf ((), (Some b', (Suc id, Some (guess, s)), s))
      } else return_spmf ((), (b, (Suc id, gs), s))"

    let ?R = "λ((id1, j0), s1) (b', (id2, gs), s2). s1 = s2  id1 = id2  (j0 = Some js  b' = Some True)  (id2  js  b' = None)"
    from init have "rel_spmf (rel_prod (=) ?R)
      (exec_gpv (extend_state_oracle (oracle c_o) O eval_oracle' c_o c_a) (𝒜 c_a) ((0, None), s))
      (exec_gpv (extend_state_oracle (extend_state_oracle (oracle c_o)) O oracle2') (𝒜 c_a) (None, (0, None), s))"
      by(intro exec_gpv_oracle_bisim[where X="?R"])(auto simp add: oracle2'_def eval_oracle'_def spmf_rel_map map_spmf_conv_bind_spmf[symmetric] rel_spmf_return_spmf2 lossless_eval o_def intro!: rel_spmf_reflI split: option.split_asm plus_oracle_split if_split_asm)
    then have "rel_spmf (⟶) (body2 c_o c_a s js) 
      (do {
        (_, b', _, _)  exec_gpv ((oracle c_o) O oracle2') (𝒜 c_a) (None, (0, None), s);
        return_spmf (b' = Some True) })"
      (is "rel_spmf _ _ ?body2'")
      ― ‹We do not get equality here because the right hand side may return @{const True} even
        when the bad event has happened before the @{text js}-th iteration.›
      unfolding body2_def by(rule rel_spmf_bindI) clarsimp
    also
    let ?guess_oracle = "λ((id, gs), s) guess. return_spmf ((), (Suc id, if id = js then Some (guess, s) else gs), s)"
    let ?I = "λ(idgs, s). case idgs of (_, None)  False | (i, Some _)  js < i"
    interpret I: callee_invariant_on "(oracle c_o) O ?guess_oracle" "?I" ℐ_full
      by(simp)(unfold_locales; auto split: option.split)

    let ?f = "λs. case snd (fst s) of None  return_spmf False | Some a  eval c_o c_a (snd a) (fst a)"
    let ?X = "λjs (b1, (id1, gs1), s1) (b2, (id2, gs2), s2). b1 = b2  id1 = id2  gs1 = gs2  s1 = s2  (b2 = None  gs2 = None)  (id2  js  b2 = None)"
    have "?body2' = do {
      (a, r, s)  exec_gpv (λ(r, s) x. do {
               (y, s')  ((oracle c_o) O ?guess_oracle) s x;
               if ?I s'  r = None then map_spmf (λr. (y, Some r, s')) (?f s') else return_spmf (y, r, s')
             })
         (𝒜 c_a) (None, (0, None), s);
      case r of None  ?f s  return_spmf | Some r'  return_spmf r' }"
      unfolding oracle2'_def spmf_rel_eq[symmetric]
      by(rule rel_spmf_bindI[OF exec_gpv_oracle_bisim'[where X="?X js"]])
        (auto simp add: bind_map_spmf o_def spmf.map_comp split_beta conj_comms map_spmf_conv_bind_spmf[symmetric] spmf_rel_map rel_spmf_reflI cong: conj_cong split: plus_oracle_split)
    also have " = do {
        us'  exec_gpv ((oracle c_o) O ?guess_oracle) (𝒜 c_a) ((0, None), s);
        (b' :: bool)  ?f (snd us');
        return_spmf b' }"
      (is "_ = ?body2''")
      by(rule I.exec_gpv_bind_materialize[symmetric])(auto split: plus_oracle_split_asm option.split_asm)
    also have " = do {
        us'  exec_gpv_stop (lift_stop_oracle ((oracle c_o) O ?guess_oracle)) (𝒜 c_a) ((0, None), s);
        (b' :: bool)  ?f (snd us');
        return_spmf b' }"
      supply lift_stop_oracle_transfer[transfer_rule] gpv_stop_transfer[transfer_rule] exec_gpv_parametric'[transfer_rule]
      by transfer simp
    also let ?S = "λ((id1, gs1), s1) ((id2, gs2), s2). gs1 = gs2  (gs2 = None  s1 = s2  id1 = id2)  (gs1 = None  id1  js)"
    have "ord_spmf (⟶)  (exec_gpv_stop ((λ((id, gs), s) x. case gs of None  lift_stop_oracle ((oracle c_o)) ((id, gs), s) x | Some _  return_spmf (None, ((id, gs), s))) OS
            (λ((id, gs), s) guess. return_spmf (if id  js then None else Some (), (Suc id, if id = js then Some (guess, s) else gs), s)))
           (𝒜 c_a) ((0, None), s) 
          (λus'. case snd (fst (snd us')) of None  return_spmf False | Some a  eval c_o c_a (snd a) (fst a)))"
      unfolding body3_def stop_oracle_def
      by(rule ord_spmf_exec_gpv_stop[where stop = "λ((id, guess), _). guess  None" and S="?S", THEN ord_spmf_bindI])
        (auto split: prod.split_asm plus_oracle_split_asm split!: plus_oracle_stop_split simp del: not_None_eq simp add: spmf.map_comp o_def apfst_compose ord_spmf_map_spmf1 ord_spmf_map_spmf2 split_beta ord_spmf_return_spmf2 intro!: ord_spmf_reflI)
    also let ?X = "λ((id, gs), s1) (idgs, s2). s1 = s2  (case (gs, idgs) of (None, Inr id')  id' = js - id  id  js | (Some gs, Inl gs')  gs = gs'  id > js | _  False)"
    have " = body3 c_o c_a s js" unfolding body3_def spmf_rel_eq[symmetric] stop_oracle_def
      by(rule exec_gpv_oracle_bisim'[where X="?X", THEN rel_spmf_bindI])
        (auto split: option.split_asm plus_oracle_stop_split nat.splits split!: sum.split simp add: spmf_rel_map intro!: rel_spmf_reflI)
    finally show ?thesis by(rule pmf.rel_mono_strong)(auto elim!: option.rel_cases ord_option.cases)
  qed
  (* Step 4: ?game2 is bounded by ?game3, which is bounded by game_single (reduction q 𝒜).
     Hence Pr[?game2] is at most the single-guess advantage. In the last step, body3
     outputs False whenever the reduction's projl would pick an unspecified guess. *)
  { then have "ord_spmf (⟶) ?game2 ?game3"
      by(clarsimp simp add: initialize_def intro!: ord_spmf_bind_reflI)
    also
    let ?X = "λ(gsid, s) (gid, s'). s = s'  rel_sum (λ(g, s1) g'. g = g'  s1 = s') (=) gsid gid"
    have "rel_spmf (⟶) ?game3 (game_single (reduction q 𝒜))"
      unfolding body3_def stop_oracle_def game_single_def reduction_def split_def initialize_def
      apply(clarsimp simp add: bind_map_spmf exec_gpv_bind exec_gpv_inline intro!: rel_spmf_bind_reflI)
      apply(rule rel_spmf_bindI[OF exec_gpv_oracle_bisim'[where X="?X"]])
      apply(auto split: plus_oracle_stop_split elim!: rel_sum.cases simp add: map_spmf_conv_bind_spmf[symmetric] split_def spmf_rel_map rel_spmf_reflI rel_spmf_return_spmf1 lossless_eval split: nat.split)
      done
    finally have "ord_spmf (⟶) ?game2 (game_single (reduction q 𝒜))"
      by(rule pmf.rel_mono_strong)(auto elim!: option.rel_cases ord_option.cases)
    from this[THEN ord_spmf_measureD, of "{True}"]
    have "spmf ?game2 True  spmf (game_single (reduction q 𝒜)) True" unfolding spmf_conv_measure_spmf
      by(rule ord_le_eq_trans)(auto intro: arg_cong2[where f=measure]) }
  (* Combine Steps 1, 2 and 4. *)
  ultimately show ?thesis unfolding advantage_multi_def advantage_single_def 
    by(simp add: mult_right_mono)
qed

end

end