Skip to content

Commit e334633

Browse files
Bugfix.
PiperOrigin-RevId: 486344068
1 parent f7e1e61 commit e334633

2 files changed

Lines changed: 10 additions & 7 deletions

File tree

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,19 @@ def _slice_data_by_indices(data: AttackInputData, idx_train,
4242
result.labels_train = _slice_if_not_none(data.labels_train, idx_train)
4343
result.loss_train = _slice_if_not_none(data.loss_train, idx_train)
4444
result.entropy_train = _slice_if_not_none(data.entropy_train, idx_train)
45-
# Copy over sample weights if provided.
46-
result.sample_weight_train = data.sample_weight_train
47-
result.sample_weight_test = data.sample_weight_test
45+
# Slice sample weights if provided.
46+
result.sample_weight_train = _slice_if_not_none(data.sample_weight_train,
47+
idx_train)
4848

4949
# Slice test data.
5050
result.logits_test = _slice_if_not_none(data.logits_test, idx_test)
5151
result.probs_test = _slice_if_not_none(data.probs_test, idx_test)
5252
result.labels_test = _slice_if_not_none(data.labels_test, idx_test)
5353
result.loss_test = _slice_if_not_none(data.loss_test, idx_test)
5454
result.entropy_test = _slice_if_not_none(data.entropy_test, idx_test)
55+
# Slice sample weights if provided.
56+
result.sample_weight_test = _slice_if_not_none(data.sample_weight_test,
57+
idx_test)
5558

5659
# A slice has the same multilabel status as the original data. This is because
5760
# of the way multilabel status is computed. A dataset is multilabel if at

tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/dataset_slicing_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ def __init__(self, methodname):
115115
loss_test = np.array([0.5, 3.5, 7, 4.5])
116116
entropy_train = np.array([0.4, 8, 0.6, 10])
117117
entropy_test = np.array([15, 10.5, 4.5, 0.3])
118-
sample_weight_train = np.array([1.0, 0.5])
119-
sample_weight_test = np.array([0.5, 1.0])
118+
sample_weight_train = np.array([1.0, 0.2, 0.5, 0.8])
119+
sample_weight_test = np.array([0.5, 1.0, 0.1, 0.8])
120120

121121
self.input_data = AttackInputData(
122122
logits_train=logits_train,
@@ -175,8 +175,8 @@ def test_slice_by_class(self):
175175
# Check sample weights
176176
self.assertLen(output.sample_weight_train, 2)
177177
np.testing.assert_array_equal(output.sample_weight_train, [1.0, 0.5])
178-
self.assertLen(output.sample_weight_test, 2)
179-
np.testing.assert_array_equal(output.sample_weight_test, [0.5, 1.0])
178+
self.assertLen(output.sample_weight_test, 1)
179+
np.testing.assert_array_equal(output.sample_weight_test, [0.5])
180180

181181
def test_slice_by_percentile(self):
182182
percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))

0 commit comments

Comments
 (0)