Skip to content

Commit 42010e8

Browse files
Fix label persistence during the training phase, other minor changes
1 parent eb4be60 commit 42010e8

8 files changed

Lines changed: 40 additions & 30 deletions

File tree

nebula/addons/attacks/dataset/labelflipping.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
"""
99

1010
import copy
11+
import logging
1112
import random
12-
13-
import torch
13+
import numpy as np
1414

1515
from nebula.addons.attacks.dataset.datasetattack import DatasetAttack
1616

@@ -87,30 +87,41 @@ def labelFlipping(
8787
- In targeted mode, labels that match `target_label` are directly changed to `target_changed_label`.
8888
"""
8989
new_dataset = copy.deepcopy(dataset)
90+
if not isinstance(new_dataset.targets, np.ndarray):
91+
new_dataset.targets = np.array(new_dataset.targets)
92+
else:
93+
new_dataset.targets = new_dataset.targets.copy()
9094

91-
targets = torch.tensor(new_dataset.targets) if isinstance(new_dataset.targets, list) else new_dataset.targets
95+
# logging.info(f"[{self.__class__.__name__}] First 20 labels before flipping: {new_dataset.targets[:20]}")
96+
# logging.info(f"[{self.__class__.__name__}] First 20 indices before flipping: {indices[:20]}")
9297

93-
num_indices = len(indices)
94-
class_list = list(set(targets.tolist()))
9598
if not targeted:
99+
num_indices = len(indices)
96100
num_flipped = int(poisoned_percent * num_indices)
97-
if num_indices == 0:
98-
return new_dataset
99-
if num_flipped > num_indices:
100-
return new_dataset
101-
flipped_indice = random.sample(indices, num_flipped)
102-
103-
for i in flipped_indice:
104-
t = targets[i]
105-
flipped = torch.tensor(random.sample(class_list, 1)[0])
106-
while t == flipped:
107-
flipped = torch.tensor(random.sample(class_list, 1)[0])
108-
targets[i] = flipped
101+
if num_indices == 0 or num_flipped > num_indices:
102+
return
103+
flipped_indices = random.sample(indices, num_flipped)
104+
class_list = list(set(new_dataset.targets.tolist()))
105+
for i in flipped_indices:
106+
current_label = new_dataset.targets[i]
107+
new_label = random.choice(class_list)
108+
while new_label == current_label:
109+
new_label = random.choice(class_list)
110+
new_dataset.targets[i] = new_label
109111
else:
110112
for i in indices:
111-
if int(targets[i]) == int(target_label):
112-
targets[i] = torch.tensor(target_changed_label)
113-
new_dataset.targets = targets
113+
if int(new_dataset.targets[i]) == target_label:
114+
new_dataset.targets[i] = target_changed_label
115+
116+
if target_label in new_dataset.targets:
117+
logging.info(f"[{self.__class__.__name__}] Target label {target_label} still present after flipping.")
118+
else:
119+
logging.info(
120+
f"[{self.__class__.__name__}] Target label {target_label} successfully flipped to {target_changed_label}."
121+
)
122+
123+
# logging.info(f"[{self.__class__.__name__}] First 20 labels after flipping: {new_dataset.targets[:20]}")
124+
114125
return new_dataset
115126

116127
def get_malicious_dataset(self):

nebula/core/datasets/cifar10/cifar10.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __getitem__(self, idx):
2626

2727
# CIFAR10 from torchvision returns a tuple (image, target)
2828
if isinstance(data, tuple):
29-
img, target = data
29+
img, _ = data
3030
else:
3131
img = data
3232

nebula/core/datasets/cifar100/cifar100.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __getitem__(self, idx):
2626

2727
# CIFAR100 from torchvision returns a tuple (image, target)
2828
if isinstance(data, tuple):
29-
img, target = data
29+
img, _ = data
3030
else:
3131
img = data
3232

nebula/core/datasets/emnist/emnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __getitem__(self, idx):
2626

2727
# EMNIST from torchvision returns a tuple (image, target)
2828
if isinstance(data, tuple):
29-
img, target = data
29+
img, _ = data
3030
else:
3131
img = data
3232

nebula/core/datasets/fashionmnist/fashionmnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __getitem__(self, idx):
2222

2323
# FashionMNIST from torchvision returns a tuple (image, target)
2424
if isinstance(data, tuple):
25-
img, target = data
25+
img, _ = data
2626
else:
2727
img = data
2828

nebula/core/datasets/mnist/mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __getitem__(self, idx):
2222

2323
# MNIST from torchvision returns a tuple (image, target)
2424
if isinstance(data, tuple):
25-
img, target = data
25+
img, _ = data
2626
else:
2727
img = data
2828

nebula/core/datasets/nebuladataset.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def __len__(self):
9090

9191
def __getitem__(self, idx):
9292
data = self.data[idx]
93-
target = self.targets[idx]
93+
# Persist the modified targets (if any) during the training process
94+
target = self.targets[idx] if hasattr(self, "targets") and self.targets is not None else None
9495
return data, target
9596

9697
def set_data(self, data, targets, data_opt=None, targets_opt=None):
@@ -269,9 +270,7 @@ def load_partition(self):
269270
self.test_set = self.handler(test_partition_file, "test", config=self.config)
270271
self.test_indices = list(range(len(self.test_set)))
271272

272-
self.local_test_set = self.handler(
273-
test_partition_file, "local_test", config=self.config, empty=True
274-
)
273+
self.local_test_set = self.handler(test_partition_file, "local_test", config=self.config, empty=True)
275274
self.local_test_set.set_data(self.test_set.data, self.test_set.targets)
276275
self.local_test_indices = self.set_local_test_indices()
277276

nebula/frontend/templates/deployment.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ <h5 id="selection-interval-title" class="step-title" style="display: none;">Targ
588588
<h5 id="start-attack-title" class="step-title" style="display: none;">Starting round</h5>
589589
<div class="form-check form-check-inline" id="start-attack-container" style="display: none;">
590590
<input type="number" class="form-control" id="start-attack"
591-
placeholder="Starting round" min="1" value="1"
591+
placeholder="Starting round" min="0" value="1"
592592
style="display: inline; width: 80%">
593593
</div>
594594
<h5 id="stop-attack-title" class="step-title" style="display: none;">Stopping round</h5>

0 commit comments

Comments
 (0)