(* Theory RP_RF *)

(* Title: RP_RF.thy
  Author: Andreas Lochbihler, ETH Zurich *)

subsection ‹The random-permutation random-function switching lemma›

theory RP_RF imports
  Pseudo_Random_Function
  Pseudo_Random_Permutation
  CryptHOL.GPV_Bisim
begin

(* Resampling lemma: draw x uniformly from B and, whenever x lies in A, draw
   afresh uniformly from C; the overall result is the uniform distribution on C.
   The assumptions presumably read: B is contained in A Un C, A and C are
   disjoint, C is contained in B, and B is finite.
   NOTE(review): the set-operator symbols inside the assumption and case-split
   strings of this copy appear to be stripped (e.g. B  A  C); restore them
   from the upstream AFP theory before re-checking. *)
lemma rp_resample:
  assumes "B  A  C" "A  C = {}" "C  B" and finB: "finite B"
  shows "bind_spmf (spmf_of_set B) (λx. if x  A then spmf_of_set C else return_spmf x) = spmf_of_set C"
proof(cases "C = {}  A  B = {}")
  (* Main case: C is non-empty and A' = A Int B is non-empty.  B then splits
     into the disjoint parts A' and C (facts B and A'C below). *)
  case False
  define A' where "A'  A  B"
  from False have C: "C  {}" and A': "A'  {}" by(auto simp add: A'_def)
  have B: "B = A'  C" using assms by(auto simp add: A'_def)
  with finB have finA: "finite A'" and finC: "finite C" by simp_all
  from assms have A'C: "A'  C = {}" by(auto simp add: A'_def)
  (* x is drawn from B, so testing membership in A or in A' gives the same result. *)
  have "bind_spmf (spmf_of_set B) (λx. if x  A then spmf_of_set C else return_spmf x) = 
        bind_spmf (spmf_of_set B) (λx. if x  A' then spmf_of_set C else return_spmf x)"
    by(rule bind_spmf_cong[OF refl])(simp add: set_spmf_of_set finB A'_def)
  (* Show the remaining equation pointwise: every outcome i has the same probability. *)
  also have " = spmf_of_set C" (is "?lhs = ?rhs")
  proof(rule spmf_eqI)
    fix i
    (* The contribution of the C-part of B, summed over x in C, is the indicator of C at i. *)
    have "(xC. spmf (if x  A' then spmf_of_set C else return_spmf x) i) = indicator C i" using finA finC 
      by(simp add: disjoint_notin1[OF A'C] indicator_single_Some sum_mult_indicator[of C "λ_. 1 :: real" "λ_. _" "λx. x", simplified] split: split_indicator cong: conj_cong sum.cong)
    then show "spmf ?lhs i = spmf ?rhs i" using B finA finC A'C C A'
      by(simp add: spmf_bind integral_spmf_of_set sum_Un spmf_of_set field_simps)(simp add: field_simps card_Un_disjoint)
  qed
  finally show ?thesis .
(* The terminal method discharges the remaining case of the split, where C is
   empty or A and B are disjoint. *)
qed(use assms in auto 4 3 cong: bind_spmf_cong_simp simp add: subsetD bind_spmf_const spmf_of_set_empty disjoint_notin1 intro!: arg_cong[where f=spmf_of_set])

(* Setting for the random-permutation / random-function switching lemma over a
   domain A that is assumed finite and non-empty.  The prefix rp gives access to
   the random-permutation oracle on A.  The prefix rf gives access to the random
   function; its parameter spmf_of_set A is presumably the distribution of fresh
   outputs.
   NOTE(review): many Isabelle symbols appear stripped from this copy.  Examples:
   nonempty_A reads A  {}, the name after lift_definition is missing, and the
   relation symbol in the conclusion of rp_rf is missing.  Restore them from the
   upstream AFP theory before re-checking. *)
locale rp_rf =
  rp: random_permutation A +
  rf: random_function "spmf_of_set A"
  for A :: "'a set"
  +
  assumes finite_A: "finite A"
  and nonempty_A: "A  {}"
begin

(* An adversary is a gpv with result type bool that exchanges values of type 'a'
   with its oracle.  This assumes CryptHOL's result/query/response parameter
   order for gpv. *)
type_synonym 'a' adversary = "(bool, 'a', 'a') gpv"
  
(* game b runs the adversary with run_gpv and the empty map as initial state.
   The oracle is the random permutation if b holds, otherwise the random oracle. *)
definition game :: "bool  'a adversary  bool spmf" where
  "game b 𝒜 = run_gpv (if b then rp.random_permutation else rf.random_oracle) 𝒜 Map.empty"
  
(* The two instances of game: against the permutation and against the function. *)
abbreviation prp_game :: "'a adversary  bool spmf" where "prp_game  game True"
abbreviation prf_game :: "'a adversary  bool spmf" where "prf_game  game False"
  
(* Distinguishing advantage: the absolute difference between the probabilities
   that the adversary outputs True in the two games. *)
definition advantage :: "'a adversary  real" where
  "advantage 𝒜 = ¦spmf (prp_game 𝒜) True - spmf (prf_game 𝒜) True¦"
  
(* The advantage lies between 0 and 1. *)
lemma advantage_nonneg: "0  advantage 𝒜" by(simp add: advantage_def)

lemma advantage_le_1: "advantage 𝒜  1"
  by(auto simp add: advantage_def intro!: abs_leI)(metis diff_0_right diff_left_mono order_trans pmf_le_1 pmf_nonneg) +

(* Oracle interface for the games.  Queries in A may be answered with any
   element of A; other queries have no admissible response.  The lifting setup
   is stored with lifting_update and then hidden again with lifting_forget. *)
context includes ℐ.lifting begin
lift_definition  :: "('a, 'a) " is "(λx. if x  A then A else {})" .
lemma outs_ℐ_ℐ [simp]: "outs_ℐ  = A" by transfer auto
lemma responses_ℐ_ℐ [simp]: "responses_ℐ  x = (if x  A then A else {})" by transfer simp
lifting_update ℐ.lifting
lifting_forget ℐ.lifting
end

(* Switching lemma.  The adversary makes at most q queries (bound), is lossless,
   and is well-typed with respect to the interface.  Its advantage is then
   presumably at most q * q / card A.
   Proof outline:
   1. Rewrite the PRP game as a run against rp_bad, an instrumented
      permutation oracle that sets a bad flag whenever it must resample.
   2. Relate rp_bad to the random oracle up to bad events.  On the
      random-oracle side, the bad event is a collision in the sampled map.
   3. Apply the fundamental lemma, then bound the collision probability by
      counting card (dom s). *)
lemma rp_rf:
  assumes bound: "interaction_any_bounded_by 𝒜 q"
    and lossless: "lossless_gpv  𝒜"
    and WT: " ⊢g 𝒜 "
  shows "advantage 𝒜  q * q / card A"
  including lifting_syntax  
proof -
  (* ?run b: execute the adversary against the oracle chosen by b, keeping the
     final oracle state alongside the result. *)
  let ?run = "λb. exec_gpv (if b then rp.random_permutation else rf.random_oracle) 𝒜 Map.empty"
  (* rp_bad works on a state of the form (bad, σ).
     - Known query: it answers σ x and leaves the state unchanged.
     - Fresh query: it draws y from A.  If y is already in ran σ, it resamples
       from A - ran σ and sets the flag to True.  Otherwise it stores y for x and
       keeps the flag. *)
  define rp_bad :: "bool × ('a  'a)  'a  ('a × (bool × ('a  'a))) spmf"
    where "rp_bad = (λ(bad, σ) x. case σ x of Some y  return_spmf (y, (bad, σ))
      | None  bind_spmf (spmf_of_set A) (λy. if y  ran σ then map_spmf (λy'. (y', (True, σ(x  y')))) (spmf_of_set (A - ran σ)) else return_spmf (y, (bad, (σ(x  y))))))"
  have rp_bad_simps: "rp_bad (bad, σ) x = (case σ x of Some y  return_spmf (y, (bad, σ))
      | None  bind_spmf (spmf_of_set A) (λy. if y  ran σ then map_spmf (λy'. (y', (True, σ(x  y')))) (spmf_of_set (A - ran σ)) else return_spmf (y, (bad, (σ(x  y))))))"
    for bad σ x by(simp add: rp_bad_def)

  (* ?S presumably relates a map σ to a pair (bad, σ) and ignores the flag.
     init is the initial rp_bad state: flag unset, empty map. *)
  let ?S = "rel_prod2 (=)"
  define init :: "bool × ('a  'a)" where "init = (False, Map.empty)"
  (* Restate rp.random_permutation in draw-then-resample form, using rp_resample. *)
  have rp: "rp.random_permutation = (λσ x. case σ x of Some y  return_spmf (y, σ) 
    | None  bind_spmf (bind_spmf (spmf_of_set A) (λy. if y  ran σ then spmf_of_set (A - ran σ) else return_spmf y)) (λy. return_spmf (y, σ(x  y))))"
    by(subst rp_resample)(auto simp add: finite_A rp.random_permutation_def[abs_def])
  (* Transfer rules: on ?S-related states and equal queries, the two oracles
     give equal answers and ?S-related successor states. *)
  have [transfer_rule]: "(?S ===> (=) ===> rel_spmf (rel_prod (=) ?S)) rp.random_permutation rp_bad"
    unfolding rp rp_bad_def
    by(auto simp add: rel_fun_def map_spmf_conv_bind_spmf split: option.split intro!: rel_spmf_bind_reflI)
  have [transfer_rule]: "?S Map.empty init" by(simp add: init_def)
  (* Step 1: the PRP game and the rp_bad run output True with the same probability. *)
  have "spmf (prp_game 𝒜) True = spmf (run_gpv rp_bad 𝒜 init) True"
    unfolding vimage_def game_def if_True by transfer_prover
  moreover {
    (* A map has a collision when it is not injective on its domain.  The next
       facts describe how map updates affect collisions; the exact statements
       have symbols stripped in this copy. *)
    define collision :: "('a  'a)  bool" where "collision m  ¬ inj_on m (dom m)" for m
    have [simp]: "¬ collision Map.empty" by(simp add: collision_def)
    have [simp]: " collision m; m x = None   collision (m(x := y))" for m x y
      by(auto simp add: collision_def fun_upd_idem dom_minus fun_upd_image dest: inj_on_fun_updD)
    have collision_map_updI: " m x = None; y  ran m   collision (m(x  y))" for m x y
      by(auto simp add: collision_def ran_def intro: rev_image_eqI)
    have collision_map_upd_iff: "¬ collision m  collision (m(x  y))  y  ran m  m x  Some y" for m x y
      by(auto simp add: collision_def ran_def fun_upd_idem intro: inj_on_fun_updI rev_image_eqI dest: inj_on_eq_iff)
  
    (* Bad events, coupling relation and invariants for the bisimulation:
       - ?bad1: a collision in the random-oracle state.
       - ?bad2: the flag of rp_bad.
       - ?X: both maps are equal, there is no collision, and the flag is unset.
       - ?I1 / ?I2: presumably domain and range of the map stay inside A.
       - ?X_bad: the only requirement once a bad event has happened. *)
    let ?bad1 = "collision" and ?bad2 = "fst"
      and ?X = "λσ1 (bad, σ2). σ1 = σ2  ¬ collision σ1  ¬ bad"
      and ?I1 = "λσ1. dom σ1  A  ran σ1  A"
      and ?I2 = "λ(bad, σ2). dom σ2  A  ran σ2  A"
    let ?X_bad = "λσ1 s2. ?I1 σ1  ?I2 s2"
    (* Side conditions for the bisimulation and counting lemmas:
       - well-typedness of both callees;
       - the invariants ?I1 / ?I2, also combined with the bad events, which
         therefore persist once raised;
       - losslessness of rp_bad on queries in A. *)
    have [simp]: " ⊢c rf.random_oracle s1 " if "ran s1  A" for s1 using that
      by(intro WT_calleeI)(auto simp add: rf.random_oracle_def[abs_def] finite_A nonempty_A ran_def split: option.split_asm)
    have [simp]: "callee_invariant_on rf.random_oracle ?I1 "
      by(unfold_locales)(auto simp add: rf.random_oracle_def finite_A split: option.split_asm)
    then interpret rf: callee_invariant_on rf.random_oracle ?I1  .
    have [simp]: " ⊢c rp_bad s2  " if "ran (snd s2)  A" for s2 using that
      by(intro WT_calleeI)(auto simp add: rp_bad_def finite_A split: prod.split_asm option.split_asm if_split_asm intro: ranI)
    have [simp]: "callee_invariant_on rf.random_oracle (λσ1. ?bad1 σ1  ?I1 σ1) "
      by(unfold_locales)(clarsimp simp add: rf.random_oracle_def finite_A  split: option.split_asm)+
    have [simp]: "callee_invariant_on rp_bad (λs2. ?I2 s2) "
      by(unfold_locales)(auto 4 3 simp add: rp_bad_simps finite_A split: option.splits if_split_asm iff del: domIff)
    have [simp]: "callee_invariant_on rp_bad (λs2. ?bad2 s2  ?I2 s2) "
      by(unfold_locales)(auto 4 3 simp add: rp_bad_simps finite_A split: option.splits if_split_asm iff del: domIff)
    have [simp]: " ⊢c rp_bad (bad, σ2) " if "ran σ2  A" for bad σ2 using that
      by(intro WT_calleeI)(auto simp add: rp_bad_def finite_A nonempty_A ran_def split: option.split_asm if_split_asm)
    have [simp]: "lossless_spmf (rp_bad (b, σ2) x)" if "x  A" "dom σ2  A" "ran σ2  A" for b σ2 x
      using finite_A that unfolding rp_bad_def
      by(clarsimp simp add: nonempty_A dom_subset_ran_iff eq_None_iff_not_dom split: option.split)
    (* Single oracle step from ?X-related states satisfying the invariants.
       Both sides reach bad together.  Afterwards only ?X_bad must hold;
       otherwise the answers agree and ?X is preserved.  The if False form
       presumably mirrors the shape of ?run False. *)
    have "rel_spmf (λ(b1, σ1) (b2, state2). (?bad1 σ1  ?bad2 state2)  (if ?bad2 state2 then ?X_bad σ1 state2 else b1 = b2  ?X σ1 state2))
            ((if False then rp.random_permutation else rf.random_oracle) s1 x) (rp_bad s2 x)"
      if "?X s1 s2" "x  outs_ℐ " "?I1 s1" "?I2 s2" for s1 s2 x using that finite_A
      by(auto split!: option.split simp add: rf.random_oracle_def rp_bad_def rel_spmf_return_spmf1 collision_map_updI dom_subset_ran_iff eq_None_iff_not_dom collision_map_upd_iff intro!: rel_spmf_bind_reflI)
    (* Lift the single-step relation to whole executions of the adversary. *)
    with _ _ have "rel_spmf
       (λ(b1, σ1) (b2, state2). (?bad1 σ1  ?bad2 state2)  (if ?bad2 state2 then ?X_bad σ1 state2 else b1 = b2  ?X σ1 state2))
       (?run False) (exec_gpv rp_bad 𝒜 init)"
      by(rule exec_gpv_oracle_bisim_bad_invariant[where=  and ?I1.0 = "?I1" and ?I2.0="?I2"])(auto simp add: init_def WT lossless finite_A nonempty_A)
    (* Fundamental lemma: the difference in output probabilities is bounded by
       the probability of a collision in the random-oracle run. *)
   then have "¦spmf (map_spmf fst (?run False)) True - spmf (run_gpv rp_bad 𝒜 init) True¦  spmf (map_spmf (?bad1  snd) (?run False)) True"
      unfolding spmf_conv_measure_spmf measure_map_spmf vimage_def
      by(intro fundamental_lemma[where ?bad2.0="λ(_, s2). ?bad2 s2"])(auto simp add: split_def elim: rel_spmf_mono)
    (* Bound the collision probability in ennreal by (q / card A) * q.  The
       counting lemma uses card (dom s) as its count on the oracle state. *)
    also have "ennreal   ennreal (q / card A) * (enat q)" unfolding if_False using bound _ _ _ _ _ _ _ WT
      by(rule rf.interaction_bounded_by_exec_gpv_bad_count[where count="λs. card (dom s)"])
        (auto simp add: rf.random_oracle_def finite_A nonempty_A card_insert_if finite_subset[OF _ finite_A] map_spmf_conv_bind_spmf[symmetric] spmf.map_comp o_def collision_map_upd_iff map_mem_spmf_of_set card_gt_0_iff card_mono field_simps Int_absorb2 intro: card_ran_le_dom[OF finite_subset, OF _ finite_A, THEN order_trans] split: option.splits)
    (* Transport that bound back to real numbers. *)
    hence "spmf (map_spmf (?bad1  snd) (?run False)) True  q * q / card A"
      by(simp add: ennreal_of_nat_eq_real_of_nat ennreal_times_divide ennreal_mult''[symmetric])
    (* Chain the steps.  The by simp presumably identifies map_spmf fst of
       ?run False with the run_gpv of the random oracle. *)
    finally have "¦spmf (run_gpv rp_bad 𝒜 init) True - spmf (run_gpv rf.random_oracle 𝒜 Map.empty) True¦  q * q / card A"
      by simp }
  (* Combine with step 1 to obtain the bound on the advantage. *)
  ultimately show ?thesis by(simp add: advantage_def game_def)
qed

end

end