Skip to content

Commit be5d9be

Browse files
avoid issues with tensors in some datasets
1 parent a09d3c2 commit be5d9be

1 file changed

Lines changed: 33 additions & 12 deletions

File tree

nebula/addons/attacks/dataset/datapoison.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,16 @@ def apply_noise(self, t, noise_type, poisoned_ratio):
8989
- Noise for types "salt", "gaussian", and "s&p" is generated using `random_noise` from
9090
the `skimage.util` package, and returned as a `torch.Tensor`.
9191
"""
92+
arr = t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else np.array(t)
93+
9294
if noise_type == "salt":
93-
return torch.tensor(random_noise(t, mode=noise_type, amount=poisoned_ratio))
95+
return torch.tensor(random_noise(arr, mode=noise_type, amount=poisoned_ratio))
9496
elif noise_type == "gaussian":
95-
return torch.tensor(random_noise(t, mode=noise_type, mean=0, var=poisoned_ratio, clip=True))
97+
return torch.tensor(random_noise(arr, mode=noise_type, mean=0, var=poisoned_ratio, clip=True))
9698
elif noise_type == "s&p":
97-
return torch.tensor(random_noise(t, mode=noise_type, amount=poisoned_ratio))
99+
return torch.tensor(random_noise(arr, mode=noise_type, amount=poisoned_ratio))
98100
elif noise_type == "nlp_rawdata":
99-
return self.poison_to_nlp_rawdata(t, poisoned_ratio)
101+
return self.poison_to_nlp_rawdata(arr, poisoned_ratio)
100102
else:
101103
logging.info(f"ERROR: noise_type '{noise_type}' not supported in data poison attack.")
102104
return t
@@ -144,8 +146,11 @@ def datapoison(
144146
- Targeted poisoning modifies only samples with `target_label` by adding an 'X' pattern, regardless of `poisoned_ratio`.
145147
"""
146148
new_dataset = copy.deepcopy(dataset)
147-
train_data = new_dataset.data
148-
targets = new_dataset.targets
149+
if not isinstance(new_dataset.targets, np.ndarray):
150+
new_dataset.targets = np.array(new_dataset.targets)
151+
else:
152+
new_dataset.targets = new_dataset.targets.copy()
153+
149154
num_indices = len(indices)
150155
if not isinstance(noise_type, str):
151156
noise_type = noise_type[0]
@@ -157,18 +162,34 @@ def datapoison(
157162
if num_poisoned > num_indices:
158163
return new_dataset
159164
poisoned_indice = random.sample(indices, num_poisoned)
165+
logging.info(f"Number of poisoned samples: {num_poisoned}")
160166

161167
for i in poisoned_indice:
162-
t = train_data[i]
168+
t = new_dataset.data[i]
169+
if isinstance(t, tuple):
170+
t = t[0]
163171
poisoned = self.apply_noise(t, noise_type, poisoned_ratio)
164-
train_data[i] = poisoned
172+
if isinstance(t, tuple):
173+
poisoned = (poisoned, t[1])
174+
if isinstance(poisoned, torch.Tensor):
175+
poisoned = poisoned.detach().clone()
176+
if len(poisoned.shape) == 0:
177+
poisoned = poisoned.view(-1)
178+
new_dataset.data[i] = poisoned
165179
else:
166180
for i in indices:
167-
if int(targets[i]) == int(target_label):
168-
t = train_data[i]
181+
if int(new_dataset.targets[i]) == int(target_label):
182+
t = new_dataset.data[i]
183+
if isinstance(t, tuple):
184+
t = t[0]
185+
if isinstance(t, torch.Tensor):
186+
t = t.detach().clone()
187+
if len(t.shape) == 0:
188+
t = t.view(-1)
169189
poisoned = self.add_x_to_image(t)
170-
train_data[i] = poisoned
171-
new_dataset.data = train_data
190+
if isinstance(t, tuple):
191+
poisoned = (poisoned, t[1])
192+
new_dataset.data[i] = poisoned
172193
return new_dataset
173194

174195
def add_x_to_image(self, img):

0 commit comments

Comments
 (0)