@@ -84,6 +84,7 @@ def compute_true_gradient_norms(
8484 y_batch : tf .Tensor ,
8585 per_example_loss_fn : Optional [Callable [[tf .Tensor , tf .Tensor ], tf .Tensor ]],
8686 num_microbatches : Optional [int ],
87+ trainable_vars : Optional [tf .Variable ] = None ,
8788) -> layer_registry .OutputTensor :
8889 """Computes the real gradient norms for an input `(model, x, y)`."""
8990 if per_example_loss_fn is None :
@@ -104,7 +105,8 @@ def compute_true_gradient_norms(
104105 if isinstance (loss , tf .RaggedTensor ):
105106 loss = loss .to_tensor ()
106107 sqr_norms = []
107- for var in input_model .trainable_variables :
108+ trainable_vars = trainable_vars or input_model .trainable_variables
109+ for var in trainable_vars :
108110 jacobian = tape .jacobian (loss , var , experimental_use_pfor = False )
109111 reduction_axes = tf .range (1 , len (jacobian .shape ))
110112 sqr_norm = tf .reduce_sum (tf .square (jacobian ), axis = reduction_axes )
@@ -124,6 +126,7 @@ def get_computed_and_true_norms(
124126 x_input : tf .Tensor ,
125127 rng_seed : int = 777 ,
126128 registry : layer_registry .LayerRegistry = None ,
129+ partial : bool = False ,
127130) -> Tuple [tf .Tensor , tf .Tensor ]:
128131 """Obtains the true and computed gradient norms for a model and batch input.
129132
@@ -147,6 +150,8 @@ def get_computed_and_true_norms(
147150 x_input: `tf.Tensor` inputs to be tested.
148151 rng_seed: An `int` used to initialize model weights.
149152 registry: A `layer_registry.LayerRegistry` instance.
153+ partial: Whether to compute the gradient norm with respect to a partial set
154+ of varibles. If True, only consider the variables in the first layer.
150155
151156 Returns:
152157 A `tuple` `(computed_norm, true_norms)`. The first element contains the
@@ -163,6 +168,13 @@ def get_computed_and_true_norms(
163168 ),
164169 run_eagerly = is_eager ,
165170 )
171+ trainable_vars = None
172+ if partial :
173+ # Gets the first layer with variables.
174+ for l in model .layers :
175+ trainable_vars = l .trainable_variables
176+ if trainable_vars :
177+ break
166178 y_pred = model (x_input )
167179 y_batch = tf .ones_like (y_pred )
168180 tf .keras .utils .set_random_seed (rng_seed )
@@ -173,10 +185,16 @@ def get_computed_and_true_norms(
173185 layer_registry = registry ,
174186 per_example_loss_fn = per_example_loss_fn ,
175187 num_microbatches = num_microbatches ,
188+ trainable_vars = trainable_vars ,
176189 )
177190 tf .keras .utils .set_random_seed (rng_seed )
178191 true_norms = compute_true_gradient_norms (
179- model , x_input , y_batch , per_example_loss_fn , num_microbatches
192+ model ,
193+ x_input ,
194+ y_batch ,
195+ per_example_loss_fn ,
196+ num_microbatches ,
197+ trainable_vars = trainable_vars ,
180198 )
181199 return (computed_norms , true_norms )
182200
@@ -360,10 +378,11 @@ class ClipGradsDenseLayerTest(tf.test.TestCase, parameterized.TestCase):
360378 model_name = list (get_dense_model_generators ().keys ()),
361379 layer_name = list (get_dense_layer_generators ().keys ()),
362380 input_dim = [4 ],
363- output_dim = [1 , 2 ],
381+ output_dim = [2 ],
364382 per_example_loss_fn = [None , test_loss_fn ],
365383 num_microbatches = [None , 1 , 2 ],
366384 is_eager = [True , False ],
385+ partial = [True , False ],
367386 )
368387 def test_gradient_norms_on_various_models (
369388 self ,
@@ -374,6 +393,7 @@ def test_gradient_norms_on_various_models(
374393 per_example_loss_fn ,
375394 num_microbatches ,
376395 is_eager ,
396+ partial ,
377397 ):
378398 model_generator = get_dense_model_generators ()[model_name ]
379399 layer_generator = get_dense_layer_generators ()[layer_name ]
@@ -399,6 +419,7 @@ def test_gradient_norms_on_various_models(
399419 is_eager ,
400420 x_input ,
401421 registry = default_registry ,
422+ partial = partial ,
402423 )
403424 self .assertAllClose (computed_norms , true_norms , rtol = 1e-3 , atol = 1e-2 )
404425
@@ -436,8 +457,9 @@ class ClipGradsEmbeddingLayerTest(tf.test.TestCase, parameterized.TestCase):
436457 model_name = list (get_embedding_model_generators ().keys ()),
437458 output_dim = [2 ],
438459 per_example_loss_fn = [None , test_loss_fn ],
439- num_microbatches = [None , 1 , 2 ],
440- is_eager = [True ],
460+ num_microbatches = [None , 2 ],
461+ is_eager = [True , False ],
462+ partial = [True , False ],
441463 )
442464 def test_gradient_norms_on_various_models (
443465 self ,
@@ -447,6 +469,7 @@ def test_gradient_norms_on_various_models(
447469 per_example_loss_fn ,
448470 num_microbatches ,
449471 is_eager ,
472+ partial ,
450473 ):
451474 if (
452475 num_microbatches is not None
@@ -470,18 +493,20 @@ def test_gradient_norms_on_various_models(
470493 is_eager = is_eager ,
471494 x_input = x_batch ,
472495 registry = default_registry ,
496+ partial = partial ,
473497 )
474498 self .assertAllClose (computed_norms , true_norms , rtol = 1e-3 , atol = 1e-2 )
475499
476500
477501class ClipGradsCustomLayerTest (tf .test .TestCase , parameterized .TestCase ):
478502
479503 @parameterized .product (
480- input_dim = [1 , 2 ],
481- output_dim = [1 , 2 ],
504+ input_dim = [3 ],
505+ output_dim = [2 ],
482506 per_example_loss_fn = [None , test_loss_fn ],
483- num_microbatches = [None , 1 , 2 ],
507+ num_microbatches = [None , 2 ],
484508 is_eager = [True , False ],
509+ partial = [True , False ],
485510 )
486511 def test_gradient_norms_on_various_models (
487512 self ,
@@ -490,6 +515,7 @@ def test_gradient_norms_on_various_models(
490515 per_example_loss_fn ,
491516 num_microbatches ,
492517 is_eager ,
518+ partial ,
493519 ):
494520 registry = layer_registry .make_default_layer_registry ()
495521 registry .insert (DoubleDense , double_dense_layer_computation )
@@ -510,6 +536,7 @@ def test_gradient_norms_on_various_models(
510536 is_eager = is_eager ,
511537 x_input = x_batch ,
512538 registry = registry ,
539+ partial = partial ,
513540 )
514541 self .assertAllClose (computed_norms , true_norms , rtol = 1e-3 , atol = 1e-2 )
515542
0 commit comments