1313# limitations under the License.
1414"""Keras Model for vectorized dpsgd with XLA acceleration."""
1515
16+ import dataclasses
17+
1618from absl import logging
1719import tensorflow as tf
1820from tensorflow_privacy .privacy .fast_gradient_clipping import clip_grads
1921from tensorflow_privacy .privacy .fast_gradient_clipping import common_manip_utils
2022from tensorflow_privacy .privacy .fast_gradient_clipping import gradient_clipping_utils
2123from tensorflow_privacy .privacy .fast_gradient_clipping import noise_utils
24+ from tensorflow_privacy .privacy .sparsity_preserving_noise import layer_registry as snlr
25+ from tensorflow_privacy .privacy .sparsity_preserving_noise import sparse_noise_utils
26+
2227
2328_PRIVATIZED_LOSS_NAME = 'privatized_loss'
2429
2530
@dataclasses.dataclass
class SparsityPreservingDPSGDConfig:
  """Config for adding sparsity preserving noise to the gradients.

  Attributes:
    sparse_selection_ratio: The ratio of how the noise is split between
      partition selection and gradient noise. Defaults to 0.0 (all noise
      budget goes to gradient noise).
    sparse_selection_threshold: The threshold to use for private partition
      selection.
    sparse_selection_layer_registry: A `LayerRegistry` instance containing
      functions that help compute contribution counts for sparse layers. See
      `tensorflow_privacy.privacy.sparsity_preserving_noise.layer_registry`
      for more details.
  """

  sparse_selection_ratio: float = 0.0
  sparse_selection_threshold: int = 100
  sparse_selection_layer_registry: snlr.LayerRegistry | None = None
46+
2647def make_dp_model_class (cls ):
2748 """Given a subclass of `tf.keras.Model`, returns a DP-SGD version of it."""
2849
@@ -104,6 +125,9 @@ def __init__(
104125 num_microbatches = None ,
105126 use_xla = True ,
106127 layer_registry = None ,
128+ sparsity_preserving_dpsgd_config : (
129+ SparsityPreservingDPSGDConfig | None
130+ ) = None ,
107131 * args , # pylint: disable=keyword-arg-before-vararg, g-doc-args
108132 ** kwargs ,
109133 ):
@@ -118,6 +142,9 @@ def __init__(
118142 help compute gradient norms quickly. See
119143 `tensorflow_privacy.privacy.fast_gradient_clipping.layer_registry` for
120144 more details.
145+ sparsity_preserving_dpsgd_config: If provided, uses partition selection
146+ and sparse noise for privatizing sparse gradients for layers in
147+ `sparsity_preserving_dpsgd_config.sparse_selection_layer_registry`.
121148 *args: These will be passed on to the base class `__init__` method.
122149 **kwargs: These will be passed on to the base class `__init__` method.
123150 """
@@ -127,6 +154,8 @@ def __init__(
127154 self ._layer_registry = layer_registry
128155 self ._clipping_loss = None
129156
157+ self ._sparsity_preserving_dpsgd_config = sparsity_preserving_dpsgd_config
158+
130159 # Given that `num_microbatches` was added as an argument after the fact,
131160 # this check helps detect unintended calls to the earlier API.
132161 # In particular, boolean values supplied to `use_xla` in the earlier API
@@ -276,11 +305,16 @@ def train_step(self, data):
276305 # microbatches is done here.
277306 tape = tf .GradientTape (persistent = True , watch_accessed_variables = False )
278307
308+ sparse_noise_layer_registry = None
309+ if self ._sparsity_preserving_dpsgd_config is not None :
310+ sparse_noise_layer_registry = (
311+ self ._sparsity_preserving_dpsgd_config .sparse_selection_layer_registry
312+ )
279313 registry_generator_fn = (
280314 gradient_clipping_utils .get_registry_generator_fn (
281315 tape = tape ,
282316 layer_registry = self ._layer_registry ,
283- sparse_noise_layer_registry = None ,
317+ sparse_noise_layer_registry = sparse_noise_layer_registry ,
284318 num_microbatches = num_microbatches ,
285319 )
286320 )
@@ -310,14 +344,53 @@ def train_step(self, data):
310344 )
311345 )
312346 output_metrics [_PRIVATIZED_LOSS_NAME ] = clipping_loss
313- if self ._noise_multiplier > 0 :
347+ noise_multiplier , noise_multiplier_sparse = self ._noise_multiplier , None
348+ contribution_counts = None
349+ if self ._sparsity_preserving_dpsgd_config is not None :
350+ logging .info ('Using sparse noise.' )
351+
352+ varname_to_contribution_counts_fns = (
353+ sparse_noise_utils .extract_varname_to_contribution_counts_fns (
354+ registry_fn_outputs_list ,
355+ self .trainable_variables ,
356+ )
357+ )
358+ contribution_counts = sparse_noise_utils .get_contribution_counts (
359+ self .trainable_variables ,
360+ clipped_grads ,
361+ varname_to_contribution_counts_fns ,
362+ )
363+
364+ noise_multiplier_sparse , noise_multiplier = (
365+ sparse_noise_utils .split_noise_multiplier (
366+ noise_multiplier ,
367+ self ._sparsity_preserving_dpsgd_config .sparse_selection_ratio ,
368+ contribution_counts ,
369+ )
370+ )
371+ logging .info (
372+ 'Split noise multiplier for gradient noise: %s and partition'
373+ ' selection: %s' ,
374+ noise_multiplier ,
375+ noise_multiplier_sparse ,
376+ )
377+
378+ if noise_multiplier > 0 :
379+ sparse_noise_config = None
380+ if self ._sparsity_preserving_dpsgd_config is not None :
381+ sparse_noise_config = noise_utils .SparsityPreservingNoiseConfig (
382+ sparse_noise_multiplier = noise_multiplier_sparse ,
383+ sparse_selection_threshold = self ._sparsity_preserving_dpsgd_config .sparse_selection_threshold ,
384+ sparse_contribution_counts = contribution_counts ,
385+ )
314386 grads = noise_utils .add_aggregate_noise (
315387 clipped_grads ,
316388 num_microbatches ,
317389 self ._l2_norm_clip ,
318- self . _noise_multiplier ,
390+ noise_multiplier ,
319391 loss_reduction = None ,
320392 loss_model = self ,
393+ sparse_noise_config = sparse_noise_config ,
321394 )
322395 else :
323396 grads = clipped_grads