Skip to content

Commit 7796369

Browse files
Support gradient norm computation with respect to a subset of variables.
PiperOrigin-RevId: 519245638
1 parent d5d60e2 commit 7796369

3 files changed

Lines changed: 61 additions & 11 deletions

File tree

tensorflow_privacy/privacy/fast_gradient_clipping/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ py_library(
3535

3636
py_test(
3737
name = "clip_grads_test",
38+
size = "large",
3839
srcs = ["clip_grads_test.py"],
3940
python_version = "PY3",
41+
shard_count = 8,
4042
srcs_version = "PY3",
4143
deps = [
4244
":clip_grads",

tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,11 @@ def registry_generator_fn(layer_instance, args, kwargs):
5454
(layer_vars, layer_outputs, layer_sqr_norm_fn) = registry_fn(
5555
layer_instance, args, kwargs, tape, num_microbatches
5656
)
57-
return layer_outputs, (layer_vars, layer_sqr_norm_fn)
57+
return layer_outputs, (
58+
layer_vars,
59+
layer_sqr_norm_fn,
60+
layer_instance.trainable_weights,
61+
)
5862
else:
5963
# Non-trainable layer.
6064
return layer_instance(*args, **kwargs), None
@@ -69,6 +73,7 @@ def compute_gradient_norms(
6973
layer_registry: lr.LayerRegistry,
7074
per_example_loss_fn: Optional[Callable[[tf.Tensor, Any], tf.Tensor]] = None,
7175
num_microbatches: Optional[lr.BatchSize] = None,
76+
trainable_vars: Optional[List[tf.Variable]] = None,
7277
):
7378
"""Computes the per-example loss gradient norms for given data.
7479
@@ -96,6 +101,10 @@ def compute_gradient_norms(
96101
of num_microbatches). When there is microbatches, we always assume the
97102
loss is the mean over a microbatch. And the gradient norm is computed for
98103
each microbatch.
104+
trainable_vars: The list of variables included in computing the gradient
105+
norm. When a layer has multiple variables, we include all the variables if
106+
any of the variables is in the list. If `trainable_vars` is None, all the
107+
variables are included.
99108
100109
Returns:
101110
A 1D `tf.Tensor` whose i-th entry is the norm of the gradient of the i-th
@@ -126,8 +135,19 @@ def compute_gradient_norms(
126135
# Unwrap the generator outputs so that the next loop avoids duplicating
127136
# backprop ops.
128137
filtered_outputs = [t for t in generator_outputs_list if t is not None]
129-
vars_list = [a for (a, b) in filtered_outputs]
130-
sqr_norm_fns_list = [b for (a, b) in filtered_outputs]
138+
vars_list = []
139+
sqr_norm_fns_list = []
140+
if trainable_vars is not None:
141+
# Create a set using `ref()` for fast set membership check. tf.Variable
142+
# itself is not hashable.
143+
trainable_vars = set([v.ref() for v in trainable_vars])
144+
for v, f, weights_list in filtered_outputs:
145+
if trainable_vars is None or any(
146+
w.ref() in trainable_vars for w in weights_list
147+
):
148+
# Include only those variables in trainable_vars.
149+
vars_list.append(v)
150+
sqr_norm_fns_list.append(f)
131151
# Second loop evaluates the squared L2 norm functions and appends the results.
132152
grads_list = tape.gradient(
133153
summed_loss,
@@ -218,6 +238,7 @@ def compute_clipped_gradients_and_outputs(
218238
y_batch,
219239
layer_registry,
220240
num_microbatches=num_microbatches,
241+
trainable_vars=input_model.trainable_variables,
221242
)
222243
loss_weights = compute_clip_weights(l2_norm_clip, gradient_norms)
223244
with tf.GradientTape() as tape:

tensorflow_privacy/privacy/fast_gradient_clipping/clip_grads_test.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def compute_true_gradient_norms(
8484
y_batch: tf.Tensor,
8585
per_example_loss_fn: Optional[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]],
8686
num_microbatches: Optional[int],
87+
trainable_vars: Optional[tf.Variable] = None,
8788
) -> layer_registry.OutputTensor:
8889
"""Computes the real gradient norms for an input `(model, x, y)`."""
8990
if per_example_loss_fn is None:
@@ -104,7 +105,8 @@ def compute_true_gradient_norms(
104105
if isinstance(loss, tf.RaggedTensor):
105106
loss = loss.to_tensor()
106107
sqr_norms = []
107-
for var in input_model.trainable_variables:
108+
trainable_vars = trainable_vars or input_model.trainable_variables
109+
for var in trainable_vars:
108110
jacobian = tape.jacobian(loss, var, experimental_use_pfor=False)
109111
reduction_axes = tf.range(1, len(jacobian.shape))
110112
sqr_norm = tf.reduce_sum(tf.square(jacobian), axis=reduction_axes)
@@ -124,6 +126,7 @@ def get_computed_and_true_norms(
124126
x_input: tf.Tensor,
125127
rng_seed: int = 777,
126128
registry: layer_registry.LayerRegistry = None,
129+
partial: bool = False,
127130
) -> Tuple[tf.Tensor, tf.Tensor]:
128131
"""Obtains the true and computed gradient norms for a model and batch input.
129132
@@ -147,6 +150,8 @@ def get_computed_and_true_norms(
147150
x_input: `tf.Tensor` inputs to be tested.
148151
rng_seed: An `int` used to initialize model weights.
149152
registry: A `layer_registry.LayerRegistry` instance.
153+
partial: Whether to compute the gradient norm with respect to a partial set
154+
of varibles. If True, only consider the variables in the first layer.
150155
151156
Returns:
152157
A `tuple` `(computed_norm, true_norms)`. The first element contains the
@@ -163,6 +168,13 @@ def get_computed_and_true_norms(
163168
),
164169
run_eagerly=is_eager,
165170
)
171+
trainable_vars = None
172+
if partial:
173+
# Gets the first layer with variables.
174+
for l in model.layers:
175+
trainable_vars = l.trainable_variables
176+
if trainable_vars:
177+
break
166178
y_pred = model(x_input)
167179
y_batch = tf.ones_like(y_pred)
168180
tf.keras.utils.set_random_seed(rng_seed)
@@ -173,10 +185,16 @@ def get_computed_and_true_norms(
173185
layer_registry=registry,
174186
per_example_loss_fn=per_example_loss_fn,
175187
num_microbatches=num_microbatches,
188+
trainable_vars=trainable_vars,
176189
)
177190
tf.keras.utils.set_random_seed(rng_seed)
178191
true_norms = compute_true_gradient_norms(
179-
model, x_input, y_batch, per_example_loss_fn, num_microbatches
192+
model,
193+
x_input,
194+
y_batch,
195+
per_example_loss_fn,
196+
num_microbatches,
197+
trainable_vars=trainable_vars,
180198
)
181199
return (computed_norms, true_norms)
182200

@@ -360,10 +378,11 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
360378
model_name=list(get_dense_model_generators().keys()),
361379
layer_name=list(get_dense_layer_generators().keys()),
362380
input_dim=[4],
363-
output_dim=[1, 2],
381+
output_dim=[2],
364382
per_example_loss_fn=[None, test_loss_fn],
365383
num_microbatches=[None, 1, 2],
366384
is_eager=[True, False],
385+
partial=[True, False],
367386
)
368387
def test_gradient_norms_on_various_models(
369388
self,
@@ -374,6 +393,7 @@ def test_gradient_norms_on_various_models(
374393
per_example_loss_fn,
375394
num_microbatches,
376395
is_eager,
396+
partial,
377397
):
378398
model_generator = get_dense_model_generators()[model_name]
379399
layer_generator = get_dense_layer_generators()[layer_name]
@@ -399,6 +419,7 @@ def test_gradient_norms_on_various_models(
399419
is_eager,
400420
x_input,
401421
registry=default_registry,
422+
partial=partial,
402423
)
403424
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
404425

@@ -436,8 +457,9 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
436457
model_name=list(get_embedding_model_generators().keys()),
437458
output_dim=[2],
438459
per_example_loss_fn=[None, test_loss_fn],
439-
num_microbatches=[None, 1, 2],
440-
is_eager=[True],
460+
num_microbatches=[None, 2],
461+
is_eager=[True, False],
462+
partial=[True, False],
441463
)
442464
def test_gradient_norms_on_various_models(
443465
self,
@@ -447,6 +469,7 @@ def test_gradient_norms_on_various_models(
447469
per_example_loss_fn,
448470
num_microbatches,
449471
is_eager,
472+
partial,
450473
):
451474
if (
452475
num_microbatches is not None
@@ -470,18 +493,20 @@ def test_gradient_norms_on_various_models(
470493
is_eager=is_eager,
471494
x_input=x_batch,
472495
registry=default_registry,
496+
partial=partial,
473497
)
474498
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
475499

476500

477501
class ClipGradsCustomLayerTest(tf.test.TestCase, parameterized.TestCase):
478502

479503
@parameterized.product(
480-
input_dim=[1, 2],
481-
output_dim=[1, 2],
504+
input_dim=[3],
505+
output_dim=[2],
482506
per_example_loss_fn=[None, test_loss_fn],
483-
num_microbatches=[None, 1, 2],
507+
num_microbatches=[None, 2],
484508
is_eager=[True, False],
509+
partial=[True, False],
485510
)
486511
def test_gradient_norms_on_various_models(
487512
self,
@@ -490,6 +515,7 @@ def test_gradient_norms_on_various_models(
490515
per_example_loss_fn,
491516
num_microbatches,
492517
is_eager,
518+
partial,
493519
):
494520
registry = layer_registry.make_default_layer_registry()
495521
registry.insert(DoubleDense, double_dense_layer_computation)
@@ -510,6 +536,7 @@ def test_gradient_norms_on_various_models(
510536
is_eager=is_eager,
511537
x_input=x_batch,
512538
registry=registry,
539+
partial=partial,
513540
)
514541
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
515542

0 commit comments

Comments
 (0)