(* Theory Uniform_Sampling *)
theory Uniform_Sampling imports
CryptHOL.CryptHOL
"HOL-Number_Theory.Cong"
begin
(* sample_uniform_units q is spmf_of_set applied to {..<q} - {0},
   i.e. to the natural numbers below q with 0 removed.
   NOTE(review): for q <= 1 this set is empty; the losslessness lemma
   in this theory is only stated for arguments > 1. *)
definition sample_uniform_units :: "nat ⇒ nat spmf"
where "sample_uniform_units q = spmf_of_set ({..< q} - {0})"
(* The support of sample_uniform_units q is exactly {..<q} - {0}.
   Declared [simp], so it is applied automatically by the simplifier. *)
lemma set_spmf_sample_uniform_units [simp]:
"set_spmf (sample_uniform_units q) = {..< q} - {0}"
unfolding sample_uniform_units_def by simp
(* For p > 1, sample_uniform_units p is lossless.
   The proof unfolds the definition and discharges the resulting goal
   with auto, using the assumption p > 1. *)
lemma lossless_sample_uniform_units:
assumes "(p::nat) > 1"
shows "lossless_spmf (sample_uniform_units p)"
unfolding sample_uniform_units_def
using assms by auto
(* For p > 1, the total weight of sample_uniform_units p is 1.
   Obtained directly by feeding lossless_sample_uniform_units (instantiated
   with the assumption) into lossless_weight_spmfD. *)
lemma weight_sample_uniform_units:
assumes "(p::nat) > 1"
shows "weight_spmf (sample_uniform_units p) = 1"
by (rule lossless_weight_spmfD[OF lossless_sample_uniform_units[OF assms]])
(* One-time-pad principle for sample_uniform_units:
   if f is injective on {..<q} - {0} and maps that set onto itself,
   then mapping f over sample_uniform_units q gives sample_uniform_units q
   again.

   Assumptions:
     inj_on : f is injective on {..<q} - {0}
     sur    : the image of {..<q} - {0} under f is {..<q} - {0} *)
lemma one_time_pad':
assumes inj_on: "inj_on f ({..<q} - {0})"
and sur: "f ` ({..<q} - {0}) = ({..<q} - {0})"
shows "map_spmf f (sample_uniform_units q) = (sample_uniform_units q)"
(is "?lhs = ?rhs")
proof-
(* Step 1: rewrite the right-hand side to spmf_of_set form. *)
have rhs: "?rhs = spmf_of_set (({..<q} - {0}))"
by(auto simp add: sample_uniform_units_def)
(* Step 2: under injectivity, map_spmf over spmf_of_set is spmf_of_set
   of the image. *)
also have "map_spmf(λs. f s) (spmf_of_set ({..<q} - {0})) = spmf_of_set ((λs. f s) ` ({..<q} - {0}))"
by(simp add: inj_on)
(* Step 3: the image is the original set (this restates assumption sur,
   re-derived here through endo_inj_surj). *)
also have "f ` ({..<q} - {0}) = ({..<q} - {0})"
apply(rule endo_inj_surj) by(simp, simp add: sur, simp add: inj_on)
ultimately show ?thesis using rhs by simp
qed
(* One-time-pad principle for sample_uniform:
   if f is injective on {..<q} and maps {..<q} onto itself, then mapping f
   over sample_uniform q gives sample_uniform q again.

   Assumptions:
     inj_on : f is injective on {..<q}
     sur    : the image of {..<q} under f is {..<q}

   The later *_one_time_pad lemmas in this theory are instances of this
   statement for concrete functions f. *)
lemma one_time_pad:
assumes inj_on: "inj_on f {..<q}"
and sur: "f ` {..<q} = {..<q}"
shows "map_spmf f (sample_uniform q) = (sample_uniform q)"
(is "?lhs = ?rhs")
proof-
(* Step 1: sample_uniform q unfolds to spmf_of_set {..<q}. *)
have rhs: "?rhs = spmf_of_set ({..< q})"
by(auto simp add: sample_uniform_def)
(* Step 2: under injectivity, map_spmf over spmf_of_set is spmf_of_set
   of the image. *)
also have "map_spmf(λs. f s) (spmf_of_set {..<q}) = spmf_of_set ((λs. f s) ` {..<q})"
by(simp add: inj_on)
(* Step 3: the image is {..<q} itself (restates assumption sur). *)
also have "f ` {..<q} = {..<q}"
apply(rule endo_inj_surj) by(simp, simp add: sur, simp add: inj_on)
ultimately show ?thesis using rhs by simp
qed
(* Cancellation of a common summand modulo q for reduced residues:
   if x < q, x' < q and (y + x) mod q = (y + x') mod q, then x = x'.

   Proof outline: translate the mod-equation into a congruence, cancel y
   with cong_add_lcancel_nat, translate back to x mod q = x' mod q, and
   finally drop the "mod q" using x < q and x' < q. *)
lemma plus_inj_eq:
assumes x: "x < q"
and x': "x' < q"
and map: "((y :: nat) + x) mod q = (y + x') mod q"
shows "x = x'"
proof-
have "((y :: nat) + x) mod q = (y + x') mod q ⟹ x mod q = x' mod q"
proof-
(* mod-equality to congruence *)
have "((y:: nat) + x) mod q = (y + x') mod q ⟹ [((y:: nat) + x) = (y + x')] (mod q)"
by(simp add: cong_def)
(* cancel the common summand y *)
moreover have "[((y:: nat) + x) = (y + x')] (mod q) ⟹ [x = x'] (mod q)"
by (simp add: cong_add_lcancel_nat)
(* congruence back to mod-equality *)
moreover have "[x = x'] (mod q) ⟹ x mod q = x' mod q"
by(simp add: cong_def)
ultimately show ?thesis by(simp add: map)
qed
(* both values are already below q, so "mod q" is the identity on them *)
moreover have "x mod q = x' mod q ⟹ x = x'"
by(simp add: x x')
ultimately show ?thesis by(simp add: map)
qed
(* The map b ↦ (y + b) mod q is injective on {..<q}.
   After unfolding inj_on_def the goal is closed with plus_inj_eq. *)
lemma inj_uni_samp_plus: "inj_on (λ(b :: nat). (y + b) mod q ) {..<q}"
by(simp add: inj_on_def)(auto simp only: plus_inj_eq)
(* The map b ↦ (y + b) mod q sends {..<q} onto {..<q}.
   Injectivity is taken as an explicit assumption (it is available as
   inj_uni_samp_plus); the result then follows from endo_inj_surj, with
   the remaining side conditions solved by auto. *)
lemma surj_uni_samp_plus:
assumes inj: "inj_on (λ(b :: nat). (y + b) mod q ) {..<q}"
shows "(λ(b :: nat). (y + b) mod q) ` {..< q} = {..< q}"
apply(rule endo_inj_surj) using inj by auto
(* Additive one-time pad: mapping b ↦ (y + b) mod q over sample_uniform q
   yields sample_uniform q.
   Combines inj_uni_samp_plus and surj_uni_samp_plus with one_time_pad. *)
lemma samp_uni_plus_one_time_pad:
shows "map_spmf (λb. (y + b) mod q) (sample_uniform q) = sample_uniform q"
using inj_uni_samp_plus surj_uni_samp_plus one_time_pad by simp
(* Cancellation of a coprime factor modulo q for reduced residues:
   if coprime x q, y < q, y' < q and x * y mod q = x * y' mod q,
   then y = y'.

   Proof outline: translate the mod-equation into a congruence, cancel x
   with cong_mult_lcancel_nat (this is where coprimality is used),
   translate back, and drop "mod q" using y < q and y' < q. *)
lemma mult_inj_eq:
assumes coprime: "coprime x (q::nat)"
and y: "y < q"
and y': "y' < q"
and map: "x * y mod q = x * y' mod q"
shows "y = y'"
proof-
have "x*y mod q = x*y' mod q ⟹ y mod q = y' mod q"
proof-
(* mod-equality to congruence *)
have "x*y mod q = x*y' mod q ⟹ [x*y = x*y'] (mod q)"
by(simp add: cong_def)
(* cancel the factor x, which is coprime to q *)
moreover have "[x*y = x*y'] (mod q) = [y = y'] (mod q)"
by(simp add: cong_mult_lcancel_nat coprime)
(* congruence back to mod-equality *)
moreover have "[y = y'] (mod q) ⟹ y mod q = y' mod q"
by(simp add: cong_def)
ultimately show ?thesis by(simp add: map)
qed
(* both values are already below q, so "mod q" is the identity on them *)
moreover have "y mod q = y' mod q ⟹ y = y'"
by(simp add: y y')
ultimately show ?thesis by(simp add: map)
qed
(* If x is coprime to q, the map b ↦ x * b mod q is injective on {..<q}.
   After unfolding inj_on_def the goal is closed with mult_inj_eq. *)
lemma inj_on_mult:
assumes coprime: "coprime x (q::nat)"
shows "inj_on (λ b. x*b mod q) {..<q}"
apply(auto simp add: inj_on_def)
using coprime by(simp only: mult_inj_eq)
(* If x is coprime to q, the map b ↦ x * b mod q sends {..<q} onto {..<q}.
   Injectivity is taken as an explicit assumption (it is available as
   inj_on_mult); the result follows from endo_inj_surj, with the remaining
   side conditions solved by auto. *)
lemma surj_on_mult:
assumes coprime: "coprime x (q::nat)"
and inj: "inj_on (λ b. x*b mod q) {..<q}"
shows "(λ b. x*b mod q) ` {..< q} = {..< q}"
apply(rule endo_inj_surj) using coprime inj by auto
(* Multiplicative one-time pad: if x is coprime to q, mapping
   b ↦ x * b mod q over sample_uniform q yields sample_uniform q.
   Combines inj_on_mult and surj_on_mult with one_time_pad. *)
lemma mult_one_time_pad:
assumes coprime: "coprime x q"
shows "map_spmf (λ b. x*b mod q) (sample_uniform q) = sample_uniform q"
using inj_on_mult surj_on_mult one_time_pad coprime by simp
(* Variant of inj_on_mult for the set without 0: if x is coprime to q,
   the map b ↦ x * b mod q is injective on {..<q} - {0}.
   After unfolding inj_on_def the goal is closed with mult_inj_eq. *)
lemma inj_on_mult':
assumes coprime: "coprime x (q::nat)"
shows "inj_on (λ b. x*b mod q) ({..<q} - {0})"
apply(auto simp add: inj_on_def)
using coprime by(simp only: mult_inj_eq)
(* Variant of surj_on_mult for the set without 0: if x is coprime to q,
   the map b ↦ x * b mod q sends {..<q} - {0} onto {..<q} - {0}.
   Injectivity on that set is taken as an explicit assumption (it is
   available as inj_on_mult').

   The proof applies endo_inj_surj and shows its three premises:
   finiteness, closure of the set under the map, and injectivity.
   The closure part is the only non-trivial one: it has to rule out that a
   non-zero b is mapped to 0, which is done via mult_inj_eq (comparing
   with x * 0 mod q).

   NOTE(review): the closure sub-proof looks machine-generated (Skolem
   function nn obtained with the moura method, long metis/meson steps) and
   is sensitive to prover and library changes — consider replacing it with
   a short hand-written argument if it ever breaks. *)
lemma surj_on_mult':
assumes coprime: "coprime x (q::nat)"
and inj: "inj_on (λ b. x*b mod q) ({..<q} - {0})"
shows "(λ b. x*b mod q) ` ({..<q} - {0}) = ({..<q} - {0})"
proof(rule endo_inj_surj)
(* premise 1: the set is finite *)
show " finite ({..<q} - {0})" by auto
(* premise 2: the map sends the set into itself *)
show "(λb. x * b mod q) ` ({..<q} - {0}) ⊆ {..<q} - {0}"
proof-
(* nn x0 x1 x2 is a Skolem witness: an element of x2 whose image under x1
   falls outside x0, whenever such an element exists *)
obtain nn :: "nat set ⇒ (nat ⇒ nat) ⇒ nat set ⇒ nat" where
"∀x0 x1 x2. (∃v3. v3 ∈ x2 ∧ x1 v3 ∉ x0) = (nn x0 x1 x2 ∈ x2 ∧ x1 (nn x0 x1 x2) ∉ x0)"
by moura
(* either the witness is a counterexample, or the image is contained *)
hence 1: "∀N f Na. nn Na f N ∈ N ∧ f (nn Na f N) ∉ Na ∨ f ` N ⊆ Na"
by (meson image_subsetI)
have 2: "x * nn ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q ∉ {..<q} ∨ x * nn ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q ∈ insert 0 {..<q}"
by force
have 3: "(x * nn ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q ∈ insert 0 {..<q} - {0}) = (x * nn ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q ∈ {..<q} - {0})"
by simp
(* case: the image of the witness equals the image of 0 *)
{ assume "x * nn ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q = x * 0 mod q"
hence "(0 ≤ q) = (0 = q) ∨ (nn ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) ∉ {..<q} ∨ nn ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) ∈ {0}) ∨ nn ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) ∉ {..<q} - {0} ∨ x * nn ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q ∈ {..<q} - {0}"
by (metis antisym_conv1 insertCI lessThan_iff local.coprime mult_inj_eq) }
moreover
(* case: the image of the witness is non-zero *)
{ assume "0 ≠ x * nn ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q"
moreover
{ assume "x * nn ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q ∈ insert 0 {..<q} ∧ x * nn ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q ∉ {0}"
hence "(λn. x * n mod q) ` ({..<q} - {0}) ⊆ {..<q} - {0}"
using 3 1 by (meson Diff_iff) }
ultimately have "(λn. x * n mod q) ` ({..<q} - {0}) ⊆ {..<q} - {0} ∨ (0 ≤ q) = (0 = q)"
using 2 by (metis antisym_conv1 lessThan_iff mod_less_divisor singletonD) }
ultimately have "(λn. x * n mod q) ` ({..<q} - {0}) ⊆ {..<q} - {0} ∨ nn ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) ∉ {..<q} - {0} ∨ x * nn ({..<q} - {0}) (λn. x * n mod q) ({..<q} - {0}) mod q ∈ {..<q} - {0}"
by force
thus "(λn. x * n mod q) ` ({..<q} - {0}) ⊆ {..<q} - {0}"
using 1 by meson
qed
(* premise 3: injectivity, taken directly from the assumption *)
show "inj_on (λb. x * b mod q) ({..<q} - {0})"
using inj by blast
qed
(* Multiplicative one-time pad on the set without 0: if x is coprime to q,
   mapping b ↦ x * b mod q over sample_uniform_units q yields
   sample_uniform_units q.
   Combines inj_on_mult' and surj_on_mult' with one_time_pad'. *)
lemma mult_one_time_pad':
assumes coprime: "coprime x q"
shows "map_spmf (λ b. x*b mod q) (sample_uniform_units q) = sample_uniform_units q"
using inj_on_mult' surj_on_mult' one_time_pad' coprime by simp
(* Cancellation for the affine map b ↦ (y + x * b) mod q:
   if coprime x q, x' < q, y' < q and
   (y + x * x') mod q = (y + x * y') mod q, then x' = y'.

   Proof outline: pass to congruences, cancel the summand y
   (cong_add_lcancel_nat) and then the coprime factor x
   (cong_mult_lcancel_nat), and finally drop "mod q" using the bounds. *)
lemma samp_uni_add_mult:
assumes coprime: "coprime x (q::nat)"
and x': "x' < q"
and y': "y' < q"
and map: "(y + x * x') mod q = (y + x * y') mod q"
shows "x' = y'"
proof-
have "(y + x * x') mod q = (y + x * y') mod q ⟹ x' mod q = y' mod q"
proof-
(* mod-equality to congruence *)
have "(y + x * x') mod q = (y + x * y') mod q ⟹ [y + x*x' = y + x *y'] (mod q)"
using cong_def by blast
(* cancel y additively, then x multiplicatively *)
moreover have "[y + x*x' = y + x *y'] (mod q) ⟹ [x' = y'] (mod q)"
by(simp add: cong_add_lcancel_nat)(simp add: coprime cong_mult_lcancel_nat)
ultimately show ?thesis by(simp add: cong_def map)
qed
(* both values are already below q, so "mod q" is the identity on them *)
moreover have "x' mod q = y' mod q ⟹ x' = y'"
by(simp add: x' y')
ultimately show ?thesis by(simp add: map)
qed
(* If x is coprime to q, the map b ↦ (y + x * b) mod q is injective on
   {..<q}. After unfolding inj_on_def the goal is closed with
   samp_uni_add_mult. *)
lemma inj_on_add_mult:
assumes coprime: "coprime x (q::nat)"
shows "inj_on (λ b. (y + x*b) mod q) {..<q}"
apply(auto simp add: inj_on_def)
using coprime by(simp only: samp_uni_add_mult)
(* If x is coprime to q, the map b ↦ (y + x * b) mod q sends {..<q} onto
   {..<q}. Injectivity is taken as an explicit assumption (it is available
   as inj_on_add_mult); the result follows from endo_inj_surj, with the
   remaining side conditions solved by auto. *)
lemma surj_on_add_mult:
assumes coprime: "coprime x (q::nat)"
and inj: "inj_on (λ b. (y + x*b) mod q) {..<q}"
shows "(λ b. (y + x*b) mod q) ` {..< q} = {..< q}"
apply(rule endo_inj_surj) using coprime inj by auto
(* Affine one-time pad: if x is coprime to q, mapping
   b ↦ (y + x * b) mod q over sample_uniform q yields sample_uniform q.
   Combines inj_on_add_mult and surj_on_add_mult with one_time_pad. *)
lemma add_mult_one_time_pad:
assumes coprime: "coprime x q"
shows "map_spmf (λ b. (y + x*b) mod q) (sample_uniform q) = (sample_uniform q)"
using inj_on_add_mult surj_on_add_mult one_time_pad coprime by simp
(* The map b ↦ (y + (q - b)) mod q is injective on {..<q}.

   After unfolding inj_on_def and running auto, the remaining goal has two
   arguments x, y' below q with equal images (the hypothesis appears in
   the normalised form (y + q - x) mod q = (y + q - y') mod q).
   The equation is turned into one of the shape
   (y' + n) mod q = (n + x) mod q for some n, to which plus_inj_eq applies.

   NOTE(review): the intermediate facts and metis calls look
   machine-generated; they are kept verbatim. *)
lemma inj_on_minus: "inj_on (λ(b :: nat). (y + (q - b)) mod q ) {..<q}"
proof(unfold inj_on_def; auto)
fix x :: nat and y' :: nat
assume x: "x < q"
assume y': "y' < q"
assume map: "(y + q - x) mod q = (y + q - y') mod q"
(* auxiliary fact about natural-number subtraction, from nat_diff_split *)
have "∀n na p. ∃nb. ∀nc nd pa. (¬ (nc::nat) < nd ∨ ¬ pa (nc - nd) ∨ pa 0) ∧ (¬ p (0::nat) ∨ p (n - na) ∨ na + nb = n)"
by (metis (no_types) nat_diff_split)
hence "¬ y < y' - q ∧ ¬ y < x - q"
using y' x by (metis add.commute less_diff_conv not_add_less2)
(* move the subtracted terms to the other side of the equation *)
hence "∃n. (y' + n) mod q = (n + x) mod q"
using map by (metis add.commute add_diff_inverse_nat less_diff_conv mod_add_left_eq)
(* cancel the common summand n *)
thus "x = y'"
by (metis plus_inj_eq x y' add.commute)
qed
(* The map b ↦ (y + (q - b)) mod q sends {..<q} onto {..<q}.
   Injectivity is taken as an explicit assumption (it is available as
   inj_on_minus); the result follows from endo_inj_surj, with the
   remaining side conditions solved by auto. *)
lemma surj_on_minus:
assumes inj: "inj_on (λ(b :: nat). (y + (q - b)) mod q ) {..<q}"
shows "(λ(b :: nat). (y + (q - b)) mod q) ` {..< q} = {..< q}"
apply(rule endo_inj_surj) using inj by auto
(* Subtractive one-time pad: mapping b ↦ (y + (q - b)) mod q over
   sample_uniform q yields sample_uniform q.
   Combines inj_on_minus and surj_on_minus with one_time_pad. *)
lemma samp_uni_minus_one_time_pad:
shows "map_spmf(λ b. (y + (q - b)) mod q) (sample_uniform q) = sample_uniform q"
using inj_on_minus surj_on_minus one_time_pad by simp
(* Mapping Boolean negation over coin_spmf gives coin_spmf again.
   The proof records that Not is injective on {True, False} and maps that
   set onto itself, then finishes by simp with UNIV_bool. *)
lemma not_coin_spmf: "map_spmf (λ a. ¬ a) coin_spmf = coin_spmf"
proof-
have "inj_on Not {True, False}"
by simp
moreover have "Not ` {True, False} = {True, False}"
by auto
ultimately show ?thesis using one_time_pad
by (simp add: UNIV_bool)
qed
(* Mapping b ↦ y ⊕ b over coin_spmf gives the same distribution as mapping
   the identity over coin_spmf.

   Proof outline: rewrite the right-hand side as
   spmf_of_set {True, False}, push the map through spmf_of_set, and observe
   that the image of {True, False} under b ↦ y ⊕ b is {True, False}. *)
lemma xor_uni_samp: "map_spmf(λ b. y ⊕ b) (coin_spmf) = map_spmf(λ b. b) (coin_spmf)"
(is "?lhs = ?rhs")
proof-
(* right-hand side in spmf_of_set form *)
have rhs: "?rhs = spmf_of_set {True, False}"
by (simp add: UNIV_bool insert_commute)
(* map over spmf_of_set becomes spmf_of_set of the image *)
also have "map_spmf(λ b. y ⊕ b) (spmf_of_set {True, False}) = spmf_of_set((λ b. y ⊕ b) ` {True, False})"
by (simp add: xor_def)
(* the image is the full Boolean set again *)
also have "(λ b. xor y b) ` {True, False} = {True, False}"
using xor_def by auto
finally show ?thesis using rhs by(simp)
qed
(* For a < q and m not congruent to 0 modulo q, the two maps
     d ↦ (d + a * m) mod q      and      d ↦ (d + q * m - a * m) mod q
   induce the same distribution when applied to sample_uniform q.

   Proof outline: both sides are shown equal to sample_uniform q by
   samp_uni_plus_one_time_pad, after commuting the summands; the auxiliary
   fact ineq (q * m - a * m > 0) is used in the final steps.
   NOTE(review): the name suggests use in a Pedersen-commitment proof —
   confirm against the callers. *)
lemma ped_inv_mapping:
assumes "(a::nat) < q"
and "[m ≠ 0] (mod q)"
shows "map_spmf (λ d. (d + a * (m::nat)) mod q) (sample_uniform q) = map_spmf (λ d. (d + q * m - a * m) mod q) (sample_uniform q)"
(is "?lhs = ?rhs")
proof-
(* the natural-number difference q * m - a * m is strictly positive *)
have ineq: "q * m - a * m > 0"
using assms gr0I by force
(* left-hand side: commute the summands ... *)
have "?lhs = map_spmf (λ d. (a * m + d) mod q) (sample_uniform q)"
using add.commute by metis
(* ... and apply the additive one-time pad *)
also have "... = sample_uniform q"
using samp_uni_plus_one_time_pad by simp
(* the additive one-time pad again, now with offset q * m - a * m *)
also have "... = map_spmf (λ d. ((q * m - a * m) + d) mod q) (sample_uniform q)"
using ineq samp_uni_plus_one_time_pad by metis
ultimately show ?thesis
using add.commute ineq
by (simp add: Groups.add_ac(2))
qed
end