-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdp_preparer.py
More file actions
55 lines (39 loc) · 1.93 KB
/
dp_preparer.py
File metadata and controls
55 lines (39 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# written by Lukas Abfalterer in 2021 (labfalterer a.t. student.ethz.ch)
import torch
class SymmetricNoise(torch.nn.Module):
    """Symmetric Differential Privacy Preparer.

    Takes a 1-D vector ``p`` of length ``element_size`` and produces two
    zero-padded copies that are offset against each other by ``Delta``
    positions, plus the slices / sums of ``p`` that fall inside and outside
    the shifted window.
    """

    def __init__(self, element_size, range_begin, args, device):
        """Store the sizes needed by :meth:`forward`.

        Args:
            element_size: Number of entries in the input vector ``p``.
            range_begin: Used only to derive the offset
                ``Delta = int(element_size / (2 * range_begin))``.
            args: Accepted for signature parity with ``MixtureNoise``;
                not read by this class.
            device: Device on which the padded tensors are allocated.
        """
        super().__init__()
        shift = int(element_size / (2 * range_begin))
        self.device = device
        self.element_size = element_size
        self.Delta = shift

    def forward(self, p):
        """Build the padded pair (A, B) and the associated slices of ``p``.

        Args:
            p: Tensor that can be assigned into a length-``element_size``
                slot (presumably a probability vector -- confirm with caller).

        Returns:
            Tuple ``(p_A_slice, p_B_slice, dist_events, dist_events_dual, A, B)``:
            ``p`` without its last ``Delta`` entries, ``p`` without its first
            ``Delta`` entries, the sum of the last ``Delta`` entries, the sum
            of the first ``Delta`` entries, ``p`` padded with zeros at the
            end, and ``p`` padded with zeros at the front.
        """
        n = self.element_size
        shift = self.Delta
        cut = n - shift  # boundary between the kept head and the dropped tail

        # A holds p at the front, B holds p at the back; both are zero elsewhere.
        A = torch.zeros(int(n + shift), device=self.device)
        B = torch.zeros_like(A)
        A[:n] = p
        B[-n:] = p

        head, tail = p[:cut], p[cut:]
        return head, p[shift:], tail.sum(), p[:shift].sum(), A, B
class MixtureNoise(torch.nn.Module):
    """Mixture Differential Privacy Preparer.

    Builds a pair of zero-padded vectors from an input ``p``: ``A`` is ``p``
    itself, and ``B`` is the mixture ``(1 - q) * p`` plus ``q`` times ``p``
    shifted by ``Delta`` positions, renormalised to sum to one.
    """

    def __init__(self, element_size, range_begin, args, device):
        """Store sizes and the mixture weight.

        Args:
            element_size: Number of entries in the input vector ``p``.
            range_begin: Used only to derive the offset
                ``Delta = int(element_size / (2 * range_begin))``.
            args: Object exposing ``mixture_q``, the weight of the shifted
                component.
            device: Device on which tensors are allocated.
        """
        super().__init__()
        self.element_size = element_size
        self.Delta = int(element_size / (2 * range_begin))
        self.q = torch.tensor(args.mixture_q, device=device)
        self.device = device

    def forward(self, p):
        """Build the padded pair (A, B) and the associated slices and sums.

        Args:
            p: Tensor that can be assigned into a length-``element_size``
                slot (presumably a probability vector -- confirm with caller).

        Returns:
            Tuple ``(p_A_slice, p_B_slice, dist_events, dist_events_dual, A, B)``:
            the first ``element_size`` entries of ``A`` and of ``B`` (views,
            not copies), a constant ``0.0`` tensor, the sum of ``B`` over the
            ``Delta`` padding positions, and the two padded vectors.
        """
        padded_len = int(self.element_size + self.Delta)
        A = torch.zeros(padded_len, device=self.device)
        B = torch.zeros(padded_len, device=self.device)

        A[: self.element_size] = p

        # B = (1 - q) * p  +  q * (p shifted right by Delta), then renormalised.
        B[: self.element_size] = (1.0 - self.q) * p
        B[-self.element_size :] += self.q * p
        B /= torch.sum(B)

        p_A_slice = A[: self.element_size]
        p_B_slice = B[: self.element_size]

        dist_events = torch.tensor(0.0, device=self.device)  # (A[B==0]) A/B

        # (B[A==0]) B/A -- mass of B on the padding, where A is zero.
        # Indexed from the front on purpose: ``B[-self.Delta:]`` degenerates to
        # ``B[0:]`` (the whole vector) when Delta == 0, which would report a
        # mass of ~1 instead of 0.
        dist_events_dual = torch.sum(B[self.element_size :])

        return p_A_slice, p_B_slice, dist_events, dist_events_dual, A, B