Skip to content

Commit 79e408a

Browse files
sayakpauldg845
andauthored
[tests] migrate test_hooks.py to pytest (#13242)
* move test_hooks.py to pytest * Create `TestHooks` generator on `torch_device` (#13871) Create TestHooks generator on device to avoid device mismatch errors --------- Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
1 parent 924cfb2 commit 79e408a

1 file changed

Lines changed: 29 additions & 32 deletions

File tree

tests/hooks/test_hooks.py

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
# limitations under the License.
1414

1515
import gc
16-
import unittest
1716

17+
import pytest
1818
import torch
1919

2020
from diffusers.hooks import HookRegistry, ModelHook
@@ -134,20 +134,18 @@ def post_forward(self, module, output):
134134
return output
135135

136136

137-
class HookTests(unittest.TestCase):
137+
class TestHooks:
138138
in_features = 4
139139
hidden_features = 8
140140
out_features = 4
141141
num_layers = 2
142142

143-
def setUp(self):
143+
def setup_method(self):
144144
params = self.get_module_parameters()
145145
self.model = DummyModel(**params)
146146
self.model.to(torch_device)
147147

148-
def tearDown(self):
149-
super().tearDown()
150-
148+
def teardown_method(self):
151149
del self.model
152150
gc.collect()
153151
free_memory()
@@ -161,7 +159,7 @@ def get_module_parameters(self):
161159
}
162160

163161
def get_generator(self):
164-
return torch.manual_seed(0)
162+
return torch.Generator(device=torch_device).manual_seed(0)
165163

166164
def test_hook_registry(self):
167165
registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -171,20 +169,20 @@ def test_hook_registry(self):
171169
registry_repr = repr(registry)
172170
expected_repr = "HookRegistry(\n (0) add_hook - AddHook\n (1) multiply_hook - MultiplyHook(value=2)\n)"
173171

174-
self.assertEqual(len(registry.hooks), 2)
175-
self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
176-
self.assertEqual(registry_repr, expected_repr)
172+
assert len(registry.hooks) == 2
173+
assert registry._hook_order == ["add_hook", "multiply_hook"]
174+
assert registry_repr == expected_repr
177175

178176
registry.remove_hook("add_hook")
179177

180-
self.assertEqual(len(registry.hooks), 1)
181-
self.assertEqual(registry._hook_order, ["multiply_hook"])
178+
assert len(registry.hooks) == 1
179+
assert registry._hook_order == ["multiply_hook"]
182180

183181
def test_stateful_hook(self):
184182
registry = HookRegistry.check_if_exists_or_initialize(self.model)
185183
registry.register_hook(StatefulAddHook(1), "stateful_add_hook")
186184

187-
self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0)
185+
assert registry.hooks["stateful_add_hook"].increment == 0
188186

189187
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
190188
num_repeats = 3
@@ -194,13 +192,13 @@ def test_stateful_hook(self):
194192
if i == 0:
195193
output1 = result
196194

197-
self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats)
195+
assert registry.get_hook("stateful_add_hook").increment == num_repeats
198196

199197
registry.reset_stateful_hooks()
200198
output2 = self.model(input)
201199

202-
self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1)
203-
self.assertTrue(torch.allclose(output1, output2))
200+
assert registry.get_hook("stateful_add_hook").increment == 1
201+
assert torch.allclose(output1, output2)
204202

205203
def test_inference(self):
206204
registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -218,40 +216,39 @@ def test_inference(self):
218216
new_input = input * 2 + 1
219217
output3 = self.model(new_input).mean().detach().cpu().item()
220218

221-
self.assertAlmostEqual(output1, output2, places=5)
222-
self.assertAlmostEqual(output1, output3, places=5)
223-
self.assertAlmostEqual(output2, output3, places=5)
219+
assert output1 == pytest.approx(output2, abs=5e-6)
220+
assert output1 == pytest.approx(output3, abs=5e-6)
221+
assert output2 == pytest.approx(output3, abs=5e-6)
224222

225223
def test_skip_layer_hook(self):
226224
registry = HookRegistry.check_if_exists_or_initialize(self.model)
227225
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
228226

229227
input = torch.zeros(1, 4, device=torch_device)
230228
output = self.model(input).mean().detach().cpu().item()
231-
self.assertEqual(output, 0.0)
229+
assert output == 0.0
232230

233231
registry.remove_hook("skip_layer_hook")
234232
registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook")
235233
output = self.model(input).mean().detach().cpu().item()
236-
self.assertNotEqual(output, 0.0)
234+
assert output != 0.0
237235

238236
def test_skip_layer_internal_block(self):
239237
registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1)
240238
input = torch.zeros(1, 4, device=torch_device)
241239

242240
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
243-
with self.assertRaises(RuntimeError) as cm:
241+
with pytest.raises(RuntimeError, match="mat1 and mat2 shapes cannot be multiplied"):
244242
self.model(input).mean().detach().cpu().item()
245-
self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception))
246243

247244
registry.remove_hook("skip_layer_hook")
248245
output = self.model(input).mean().detach().cpu().item()
249-
self.assertNotEqual(output, 0.0)
246+
assert output != 0.0
250247

251248
registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1])
252249
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
253250
output = self.model(input).mean().detach().cpu().item()
254-
self.assertNotEqual(output, 0.0)
251+
assert output != 0.0
255252

256253
def test_invocation_order_stateful_first(self):
257254
registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -278,7 +275,7 @@ def test_invocation_order_stateful_first(self):
278275
.replace(" ", "")
279276
.replace("\n", "")
280277
)
281-
self.assertEqual(output, expected_invocation_order_log)
278+
assert output == expected_invocation_order_log
282279

283280
registry.remove_hook("add_hook")
284281
with CaptureLogger(logger) as cap_logger:
@@ -289,7 +286,7 @@ def test_invocation_order_stateful_first(self):
289286
.replace(" ", "")
290287
.replace("\n", "")
291288
)
292-
self.assertEqual(output, expected_invocation_order_log)
289+
assert output == expected_invocation_order_log
293290

294291
def test_invocation_order_stateful_middle(self):
295292
registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -316,7 +313,7 @@ def test_invocation_order_stateful_middle(self):
316313
.replace(" ", "")
317314
.replace("\n", "")
318315
)
319-
self.assertEqual(output, expected_invocation_order_log)
316+
assert output == expected_invocation_order_log
320317

321318
registry.remove_hook("add_hook")
322319
with CaptureLogger(logger) as cap_logger:
@@ -327,7 +324,7 @@ def test_invocation_order_stateful_middle(self):
327324
.replace(" ", "")
328325
.replace("\n", "")
329326
)
330-
self.assertEqual(output, expected_invocation_order_log)
327+
assert output == expected_invocation_order_log
331328

332329
registry.remove_hook("add_hook_2")
333330
with CaptureLogger(logger) as cap_logger:
@@ -336,7 +333,7 @@ def test_invocation_order_stateful_middle(self):
336333
expected_invocation_order_log = (
337334
("MultiplyHook pre_forward\nMultiplyHook post_forward\n").replace(" ", "").replace("\n", "")
338335
)
339-
self.assertEqual(output, expected_invocation_order_log)
336+
assert output == expected_invocation_order_log
340337

341338
def test_invocation_order_stateful_last(self):
342339
registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -363,7 +360,7 @@ def test_invocation_order_stateful_last(self):
363360
.replace(" ", "")
364361
.replace("\n", "")
365362
)
366-
self.assertEqual(output, expected_invocation_order_log)
363+
assert output == expected_invocation_order_log
367364

368365
registry.remove_hook("add_hook")
369366
with CaptureLogger(logger) as cap_logger:
@@ -374,4 +371,4 @@ def test_invocation_order_stateful_last(self):
374371
.replace(" ", "")
375372
.replace("\n", "")
376373
)
377-
self.assertEqual(output, expected_invocation_order_log)
374+
assert output == expected_invocation_order_log

0 commit comments

Comments
 (0)