Skip to content

Commit b396397

Browse files
Sparsity Preserving DP-SGD in TF Privacy [5 of 5]
Integrate sparsity preserving noise into DP Keras Model with fast gradient clipping. See https://research.google/blog/sparsity-preserving-differentially-private-training/ for more details on the algorithm.

PiperOrigin-RevId: 666849100
1 parent 93c7e54 commit b396397

6 files changed

Lines changed: 86 additions & 16 deletions

File tree

tensorflow_privacy/privacy/fast_gradient_clipping/type_aliases.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
# Tensorflow aliases.
2222
Tensor = Union[tf.Tensor, tf.IndexedSlices, tf.SparseTensor, tf.RaggedTensor]
2323

24-
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Mapping[str, tf.Tensor]]
24+
PackedTensors = Union[Tensor, Iterable[Tensor], Mapping[str, Tensor]]
2525

2626
InputTensors = PackedTensors
2727

28-
OutputTensors = Union[tf.Tensor, Iterable[tf.Tensor]]
28+
OutputTensors = Union[Tensor, Iterable[Tensor]]
2929

3030
BatchSize = Union[int, tf.Tensor]
3131

tensorflow_privacy/privacy/keras_models/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ py_library(
1818
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_manip_utils",
1919
"//tensorflow_privacy/privacy/fast_gradient_clipping:gradient_clipping_utils",
2020
"//tensorflow_privacy/privacy/fast_gradient_clipping:noise_utils",
21+
"//tensorflow_privacy/privacy/sparsity_preserving_noise:layer_registry",
22+
"//tensorflow_privacy/privacy/sparsity_preserving_noise:sparse_noise_utils",
2123
],
2224
)
2325

tensorflow_privacy/privacy/keras_models/dp_keras_model.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,37 @@
1313
# limitations under the License.
1414
"""Keras Model for vectorized dpsgd with XLA acceleration."""
1515

16+
import dataclasses
17+
1618
from absl import logging
1719
import tensorflow as tf
1820
from tensorflow_privacy.privacy.fast_gradient_clipping import clip_grads
1921
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
2022
from tensorflow_privacy.privacy.fast_gradient_clipping import gradient_clipping_utils
2123
from tensorflow_privacy.privacy.fast_gradient_clipping import noise_utils
24+
from tensorflow_privacy.privacy.sparsity_preserving_noise import layer_registry as snlr
25+
from tensorflow_privacy.privacy.sparsity_preserving_noise import sparse_noise_utils
26+
2227

2328
_PRIVATIZED_LOSS_NAME = 'privatized_loss'
2429

2530

31+
@dataclasses.dataclass
32+
class SparsityPreservingDPSGDConfig:
33+
"""Config for adding sparsity preserving noise to the gradients."""
34+
35+
# The ratio of how the noise is split between partition selection and gradient
36+
# noise.
37+
sparse_selection_ratio: float = 0.0
38+
# The threshold to use for private partition selection.
39+
sparse_selection_threshold: int = 100
40+
# A `LayerRegistry` instance containing functions that help compute
41+
# contribution counts for sparse layers. See
42+
# `tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry` for
43+
# more details.
44+
sparse_selection_layer_registry: snlr.LayerRegistry | None = None
45+
46+
2647
def make_dp_model_class(cls):
2748
"""Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
2849

@@ -104,6 +125,9 @@ def __init__(
104125
num_microbatches=None,
105126
use_xla=True,
106127
layer_registry=None,
128+
sparsity_preserving_dpsgd_config: (
129+
SparsityPreservingDPSGDConfig | None
130+
) = None,
107131
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
108132
**kwargs,
109133
):
@@ -118,6 +142,9 @@ def __init__(
118142
help compute gradient norms quickly. See
119143
`tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
120144
more details.
145+
sparsity_preserving_dpsgd_config: If provided, uses partition selection
146+
and sparse noise for privatizing sparse gradients for layers in
147+
`sparsity_preserving_dpsgd_config.sparse_selection_layer_registry`.
121148
*args: These will be passed on to the base class `__init__` method.
122149
**kwargs: These will be passed on to the base class `__init__` method.
123150
"""
@@ -127,6 +154,8 @@ def __init__(
127154
self._layer_registry = layer_registry
128155
self._clipping_loss = None
129156

157+
self._sparsity_preserving_dpsgd_config = sparsity_preserving_dpsgd_config
158+
130159
# Given that `num_microbatches` was added as an argument after the fact,
131160
# this check helps detect unintended calls to the earlier API.
132161
# In particular, boolean values supplied to `use_xla` in the earlier API
@@ -276,11 +305,16 @@ def train_step(self, data):
276305
# microbatches is done here.
277306
tape = tf.GradientTape(persistent=True, watch_accessed_variables=False)
278307

308+
sparse_noise_layer_registry = None
309+
if self._sparsity_preserving_dpsgd_config is not None:
310+
sparse_noise_layer_registry = (
311+
self._sparsity_preserving_dpsgd_config.sparse_selection_layer_registry
312+
)
279313
registry_generator_fn = (
280314
gradient_clipping_utils.get_registry_generator_fn(
281315
tape=tape,
282316
layer_registry=self._layer_registry,
283-
sparse_noise_layer_registry=None,
317+
sparse_noise_layer_registry=sparse_noise_layer_registry,
284318
num_microbatches=num_microbatches,
285319
)
286320
)
@@ -310,14 +344,53 @@ def train_step(self, data):
310344
)
311345
)
312346
output_metrics[_PRIVATIZED_LOSS_NAME] = clipping_loss
313-
if self._noise_multiplier > 0:
347+
noise_multiplier, noise_multiplier_sparse = self._noise_multiplier, None
348+
contribution_counts = None
349+
if self._sparsity_preserving_dpsgd_config is not None:
350+
logging.info('Using sparse noise.')
351+
352+
varname_to_contribution_counts_fns = (
353+
sparse_noise_utils.extract_varname_to_contribution_counts_fns(
354+
registry_fn_outputs_list,
355+
self.trainable_variables,
356+
)
357+
)
358+
contribution_counts = sparse_noise_utils.get_contribution_counts(
359+
self.trainable_variables,
360+
clipped_grads,
361+
varname_to_contribution_counts_fns,
362+
)
363+
364+
noise_multiplier_sparse, noise_multiplier = (
365+
sparse_noise_utils.split_noise_multiplier(
366+
noise_multiplier,
367+
self._sparsity_preserving_dpsgd_config.sparse_selection_ratio,
368+
contribution_counts,
369+
)
370+
)
371+
logging.info(
372+
'Split noise multiplier for gradient noise: %s and partition'
373+
' selection: %s',
374+
noise_multiplier,
375+
noise_multiplier_sparse,
376+
)
377+
378+
if noise_multiplier > 0:
379+
sparse_noise_config = None
380+
if self._sparsity_preserving_dpsgd_config is not None:
381+
sparse_noise_config = noise_utils.SparsityPreservingNoiseConfig(
382+
sparse_noise_multiplier=noise_multiplier_sparse,
383+
sparse_selection_threshold=self._sparsity_preserving_dpsgd_config.sparse_selection_threshold,
384+
sparse_contribution_counts=contribution_counts,
385+
)
314386
grads = noise_utils.add_aggregate_noise(
315387
clipped_grads,
316388
num_microbatches,
317389
self._l2_norm_clip,
318-
self._noise_multiplier,
390+
noise_multiplier,
319391
loss_reduction=None,
320392
loss_model=self,
393+
sparse_noise_config=sparse_noise_config,
321394
)
322395
else:
323396
grads = clipped_grads

tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def sample_true_positive_indices(
166166
tf.shape(contribution_count_values),
167167
mean=0.0,
168168
stddev=noise_multiplier,
169-
dtype=tf.float32,
169+
dtype=contribution_count_values.dtype,
170170
)
171171
)
172172
noised_contribution_counts_indices = contribution_counts.indices[
@@ -281,7 +281,7 @@ def add_sparse_gradient_noise(
281281
"""
282282
filtered_grad_values = tf.gather(grad, indices)
283283
sparse_noise_values = tf.random.normal(
284-
filtered_grad_values.shape, mean=0.0, stddev=noise_stddev
284+
tf.shape(filtered_grad_values), mean=0.0, stddev=noise_stddev
285285
)
286286
filtered_noised_grad_values = filtered_grad_values + sparse_noise_values
287287
return tf.IndexedSlices(
@@ -362,15 +362,10 @@ def get_contribution_counts(
362362
if var.name not in varname_to_contribution_counts_fns:
363363
contribution_counts_list.append(None)
364364
continue
365-
contribution_counts_fns = varname_to_contribution_counts_fns[var.name]
366-
if not contribution_counts_fns or not contribution_counts_fns[0]:
365+
contribution_counts_fn = varname_to_contribution_counts_fns[var.name]
366+
if not contribution_counts_fn:
367367
contribution_counts_list.append(None)
368368
continue
369-
if len(contribution_counts_fns) > 1:
370-
raise NotImplementedError(
371-
'Sparse noise is not supported for shared weight variables.'
372-
)
373-
contribution_counts_fn = contribution_counts_fns[0]
374369
contribution_counts = contribution_counts_fn(grad)
375370
contribution_counts_list.append(contribution_counts)
376371

tensorflow_privacy/privacy/sparsity_preserving_noise/sparse_noise_utils_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def test_get_contribution_counts(self):
369369
tf.ones((1, 2)),
370370
]
371371
varname_to_contribution_counts_fns = {
372-
'var1:0': [lambda grad: 1.0],
372+
'var1:0': lambda grad: 1.0,
373373
'var2:0': None,
374374
}
375375
contribution_counts = sparse_noise_utils.get_contribution_counts(

tensorflow_privacy/privacy/sparsity_preserving_noise/type_aliases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
InputArgs = Sequence[Any]
2121
InputKwargs = Mapping[str, Any]
22-
SparseGradient = tf.IndexedSlices
22+
SparseGradient = tf.IndexedSlices | tf.SparseTensor
2323
ContributionCountHistogram = tf.SparseTensor
2424
ContributionCountHistogramFn = Callable[
2525
[SparseGradient], ContributionCountHistogram

0 commit comments

Comments (0)