Skip to content

Commit 096c8e7

Browse files
Simplify the structure of embedding, layer_normalization, and multi_head_attention tests.
PiperOrigin-RevId: 908775823
1 parent a7fd013 commit 096c8e7

11 files changed

Lines changed: 62 additions & 152 deletions

tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ py_library(
8585
py_test(
8686
name = "embedding_test",
8787
srcs = ["embedding_test.py"],
88+
env = {
89+
"TESTBRIDGE_TEST_ONLY": "GradNormTest.",
90+
},
8891
shard_count = 12,
8992
deps = [
9093
":dense",
@@ -107,6 +110,9 @@ py_library(
107110
py_test(
108111
name = "nlp_on_device_embedding_test",
109112
srcs = ["nlp_on_device_embedding_test.py"],
113+
env = {
114+
"TESTBRIDGE_TEST_ONLY": "GradNormTest.",
115+
},
110116
shard_count = 6,
111117
deps = [
112118
":dense",
@@ -129,6 +135,9 @@ py_library(
129135
py_test(
130136
name = "nlp_position_embedding_test",
131137
srcs = ["nlp_position_embedding_test.py"],
138+
env = {
139+
"TESTBRIDGE_TEST_ONLY": "GradNormTest.",
140+
},
132141
shard_count = 6,
133142
deps = [
134143
":dense",
@@ -151,6 +160,9 @@ py_library(
151160
py_test(
152161
name = "layer_normalization_test",
153162
srcs = ["layer_normalization_test.py"],
163+
env = {
164+
"TESTBRIDGE_TEST_ONLY": "GradNormTest.",
165+
},
154166
shard_count = 8,
155167
deps = [
156168
":dense",
@@ -174,6 +186,9 @@ py_library(
174186
py_test(
175187
name = "multi_head_attention_test",
176188
srcs = ["multi_head_attention_test.py"],
189+
env = {
190+
"TESTBRIDGE_TEST_ONLY": "GradNormTest.",
191+
},
177192
shard_count = 8,
178193
deps = [
179194
":dense",

tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,5 +151,15 @@ def test_op(x_batch):
151151
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
152152

153153

154+
class GradNormTpuTest(GradNormTest):
155+
156+
def setUp(self):
157+
tf.config.experimental.disable_mlir_bridge()
158+
super().setUp()
159+
self.strategy = common_test_utils.create_tpu_strategy()
160+
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
161+
self.using_tpu = True
162+
163+
154164
if __name__ == '__main__':
155165
tf.test.main()

tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/embedding_tpu_test.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,5 +153,14 @@ def test_op(x_batch):
153153
self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
154154

155155

156+
class GradNormTpuTest(GradNormTest):
157+
158+
def setUp(self):
159+
super().setUp()
160+
self.strategy = common_test_utils.create_tpu_strategy()
161+
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
162+
self.using_tpu = True
163+
164+
156165
if __name__ == '__main__':
157166
tf.test.main()

tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_tpu_test.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/multi_head_attention_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,5 +371,14 @@ def test_op(query_batch, value_batch, key_batch, mask_batch):
371371
self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
372372

373373

374+
class GradNormTpuTest(GradNormTest):
375+
376+
def setUp(self):
377+
super().setUp()
378+
self.strategy = common_test_utils.create_tpu_strategy()
379+
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
380+
self.using_tpu = True
381+
382+
374383
if __name__ == '__main__':
375384
tf.test.main()

tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/multi_head_attention_tpu_test.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,5 +145,15 @@ def test_op(x_batch):
145145
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
146146

147147

148+
class GradNormTpuTest(GradNormTest):
149+
150+
def setUp(self):
151+
tf.config.experimental.disable_mlir_bridge()
152+
super().setUp()
153+
self.strategy = common_test_utils.create_tpu_strategy()
154+
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
155+
self.using_tpu = True
156+
157+
148158
if __name__ == '__main__':
149159
tf.test.main()

tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_on_device_embedding_tpu_test.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/nlp_position_embedding_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,5 +136,14 @@ def test_op():
136136
self.assertAllClose(computed_norms, true_norms, rtol=1e-3, atol=1e-2)
137137

138138

139+
class GradNormTpuTest(GradNormTest):
140+
141+
def setUp(self):
142+
super().setUp()
143+
self.strategy = common_test_utils.create_tpu_strategy()
144+
self.assertIn('TPU', self.strategy.extended.worker_devices[0])
145+
self.using_tpu = True
146+
147+
139148
if __name__ == '__main__':
140149
tf.test.main()

0 commit comments

Comments
 (0)