Skip to content

Commit 03035d5

Browse files
Simplify the structure of the dense_test.
PiperOrigin-RevId: 895460903
1 parent 96aa81e commit 03035d5

3 files changed

Lines changed: 12 additions & 31 deletions

File tree

tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,12 @@ py_test(
5959
name = "dense_test",
6060
size = "large",
6161
srcs = ["dense_test.py"],
62+
env = {
63+
"TESTBRIDGE_TEST_ONLY": "GradNormTest.",
64+
},
6265
shard_count = 12,
6366
deps = [
6467
":dense",
65-
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
6668
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
6769
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
6870
],

tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,5 +142,14 @@ def test_op(x_batch, weight_batch):
142142
self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
143143

144144

145+
class GradNormTpuTest(GradNormTest):
146+
147+
def setUp(self):
148+
super(GradNormTest, self).setUp()
149+
self.strategy = common_test_utils.create_tpu_strategy()
150+
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
151+
self.using_tpu = True
152+
153+
145154
if __name__ == '__main__':
146155
tf.test.main()

tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/dense_tpu_test.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

0 commit comments

Comments
 (0)