Skip to content

Commit a7fd013

Browse files
Simplify the structure of einsum_dense_test.
PiperOrigin-RevId: 908333752
1 parent 03035d5 commit a7fd013

3 files changed

Lines changed: 12 additions & 30 deletions

File tree

tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ py_test(
3535
name = "einsum_dense_test",
3636
size = "large",
3737
srcs = ["einsum_dense_test.py"],
38+
env = {
39+
"TESTBRIDGE_TEST_ONLY": "GradNormTest.",
40+
},
3841
shard_count = 12,
3942
deps = [
4043
":dense",

tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_dense_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,5 +167,14 @@ def test_op(x):
167167
self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
168168

169169

170+
class GradNormTpuTest(GradNormTest):
171+
172+
def setUp(self):
173+
super().setUp()
174+
self.strategy = common_test_utils.create_tpu_strategy()
175+
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
176+
self.using_tpu = True
177+
178+
170179
if __name__ == '__main__':
171180
tf.test.main()

tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/einsum_dense_tpu_test.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

0 commit comments

Comments
 (0)