1313# limitations under the License.
1414
1515import gc
16- import unittest
1716
17+ import pytest
1818import torch
1919
2020from diffusers .hooks import HookRegistry , ModelHook
@@ -134,20 +134,18 @@ def post_forward(self, module, output):
134134 return output
135135
136136
137- class HookTests ( unittest . TestCase ) :
137+ class TestHooks :
138138 in_features = 4
139139 hidden_features = 8
140140 out_features = 4
141141 num_layers = 2
142142
143- def setUp (self ):
143+ def setup_method (self ):
144144 params = self .get_module_parameters ()
145145 self .model = DummyModel (** params )
146146 self .model .to (torch_device )
147147
148- def tearDown (self ):
149- super ().tearDown ()
150-
148+ def teardown_method (self ):
151149 del self .model
152150 gc .collect ()
153151 free_memory ()
@@ -161,7 +159,7 @@ def get_module_parameters(self):
161159 }
162160
163161 def get_generator (self ):
164- return torch .manual_seed (0 )
162+ return torch .Generator ( device = torch_device ). manual_seed (0 )
165163
166164 def test_hook_registry (self ):
167165 registry = HookRegistry .check_if_exists_or_initialize (self .model )
@@ -171,20 +169,20 @@ def test_hook_registry(self):
171169 registry_repr = repr (registry )
172170 expected_repr = "HookRegistry(\n (0) add_hook - AddHook\n (1) multiply_hook - MultiplyHook(value=2)\n )"
173171
174- self . assertEqual ( len (registry .hooks ), 2 )
175- self . assertEqual ( registry ._hook_order , ["add_hook" , "multiply_hook" ])
176- self . assertEqual ( registry_repr , expected_repr )
172+ assert len (registry .hooks ) == 2
173+ assert registry ._hook_order == ["add_hook" , "multiply_hook" ]
174+ assert registry_repr == expected_repr
177175
178176 registry .remove_hook ("add_hook" )
179177
180- self . assertEqual ( len (registry .hooks ), 1 )
181- self . assertEqual ( registry ._hook_order , ["multiply_hook" ])
178+ assert len (registry .hooks ) == 1
179+ assert registry ._hook_order == ["multiply_hook" ]
182180
183181 def test_stateful_hook (self ):
184182 registry = HookRegistry .check_if_exists_or_initialize (self .model )
185183 registry .register_hook (StatefulAddHook (1 ), "stateful_add_hook" )
186184
187- self . assertEqual ( registry .hooks ["stateful_add_hook" ].increment , 0 )
185+ assert registry .hooks ["stateful_add_hook" ].increment == 0
188186
189187 input = torch .randn (1 , 4 , device = torch_device , generator = self .get_generator ())
190188 num_repeats = 3
@@ -194,13 +192,13 @@ def test_stateful_hook(self):
194192 if i == 0 :
195193 output1 = result
196194
197- self . assertEqual ( registry .get_hook ("stateful_add_hook" ).increment , num_repeats )
195+ assert registry .get_hook ("stateful_add_hook" ).increment == num_repeats
198196
199197 registry .reset_stateful_hooks ()
200198 output2 = self .model (input )
201199
202- self . assertEqual ( registry .get_hook ("stateful_add_hook" ).increment , 1 )
203- self . assertTrue ( torch .allclose (output1 , output2 ) )
200+ assert registry .get_hook ("stateful_add_hook" ).increment == 1
201+ assert torch .allclose (output1 , output2 )
204202
205203 def test_inference (self ):
206204 registry = HookRegistry .check_if_exists_or_initialize (self .model )
@@ -218,40 +216,39 @@ def test_inference(self):
218216 new_input = input * 2 + 1
219217 output3 = self .model (new_input ).mean ().detach ().cpu ().item ()
220218
221- self . assertAlmostEqual ( output1 , output2 , places = 5 )
222- self . assertAlmostEqual ( output1 , output3 , places = 5 )
223- self . assertAlmostEqual ( output2 , output3 , places = 5 )
219+ assert output1 == pytest . approx ( output2 , abs = 5e-6 )
220+ assert output1 == pytest . approx ( output3 , abs = 5e-6 )
221+ assert output2 == pytest . approx ( output3 , abs = 5e-6 )
224222
225223 def test_skip_layer_hook (self ):
226224 registry = HookRegistry .check_if_exists_or_initialize (self .model )
227225 registry .register_hook (SkipLayerHook (skip_layer = True ), "skip_layer_hook" )
228226
229227 input = torch .zeros (1 , 4 , device = torch_device )
230228 output = self .model (input ).mean ().detach ().cpu ().item ()
231- self . assertEqual ( output , 0.0 )
229+ assert output == 0.0
232230
233231 registry .remove_hook ("skip_layer_hook" )
234232 registry .register_hook (SkipLayerHook (skip_layer = False ), "skip_layer_hook" )
235233 output = self .model (input ).mean ().detach ().cpu ().item ()
236- self . assertNotEqual ( output , 0.0 )
234+ assert output != 0.0
237235
238236 def test_skip_layer_internal_block (self ):
239237 registry = HookRegistry .check_if_exists_or_initialize (self .model .linear_1 )
240238 input = torch .zeros (1 , 4 , device = torch_device )
241239
242240 registry .register_hook (SkipLayerHook (skip_layer = True ), "skip_layer_hook" )
243- with self . assertRaises (RuntimeError ) as cm :
241+ with pytest . raises (RuntimeError , match = "mat1 and mat2 shapes cannot be multiplied" ) :
244242 self .model (input ).mean ().detach ().cpu ().item ()
245- self .assertIn ("mat1 and mat2 shapes cannot be multiplied" , str (cm .exception ))
246243
247244 registry .remove_hook ("skip_layer_hook" )
248245 output = self .model (input ).mean ().detach ().cpu ().item ()
249- self . assertNotEqual ( output , 0.0 )
246+ assert output != 0.0
250247
251248 registry = HookRegistry .check_if_exists_or_initialize (self .model .blocks [1 ])
252249 registry .register_hook (SkipLayerHook (skip_layer = True ), "skip_layer_hook" )
253250 output = self .model (input ).mean ().detach ().cpu ().item ()
254- self . assertNotEqual ( output , 0.0 )
251+ assert output != 0.0
255252
256253 def test_invocation_order_stateful_first (self ):
257254 registry = HookRegistry .check_if_exists_or_initialize (self .model )
@@ -278,7 +275,7 @@ def test_invocation_order_stateful_first(self):
278275 .replace (" " , "" )
279276 .replace ("\n " , "" )
280277 )
281- self . assertEqual ( output , expected_invocation_order_log )
278+ assert output == expected_invocation_order_log
282279
283280 registry .remove_hook ("add_hook" )
284281 with CaptureLogger (logger ) as cap_logger :
@@ -289,7 +286,7 @@ def test_invocation_order_stateful_first(self):
289286 .replace (" " , "" )
290287 .replace ("\n " , "" )
291288 )
292- self . assertEqual ( output , expected_invocation_order_log )
289+ assert output == expected_invocation_order_log
293290
294291 def test_invocation_order_stateful_middle (self ):
295292 registry = HookRegistry .check_if_exists_or_initialize (self .model )
@@ -316,7 +313,7 @@ def test_invocation_order_stateful_middle(self):
316313 .replace (" " , "" )
317314 .replace ("\n " , "" )
318315 )
319- self . assertEqual ( output , expected_invocation_order_log )
316+ assert output == expected_invocation_order_log
320317
321318 registry .remove_hook ("add_hook" )
322319 with CaptureLogger (logger ) as cap_logger :
@@ -327,7 +324,7 @@ def test_invocation_order_stateful_middle(self):
327324 .replace (" " , "" )
328325 .replace ("\n " , "" )
329326 )
330- self . assertEqual ( output , expected_invocation_order_log )
327+ assert output == expected_invocation_order_log
331328
332329 registry .remove_hook ("add_hook_2" )
333330 with CaptureLogger (logger ) as cap_logger :
@@ -336,7 +333,7 @@ def test_invocation_order_stateful_middle(self):
336333 expected_invocation_order_log = (
337334 ("MultiplyHook pre_forward\n MultiplyHook post_forward\n " ).replace (" " , "" ).replace ("\n " , "" )
338335 )
339- self . assertEqual ( output , expected_invocation_order_log )
336+ assert output == expected_invocation_order_log
340337
341338 def test_invocation_order_stateful_last (self ):
342339 registry = HookRegistry .check_if_exists_or_initialize (self .model )
@@ -363,7 +360,7 @@ def test_invocation_order_stateful_last(self):
363360 .replace (" " , "" )
364361 .replace ("\n " , "" )
365362 )
366- self . assertEqual ( output , expected_invocation_order_log )
363+ assert output == expected_invocation_order_log
367364
368365 registry .remove_hook ("add_hook" )
369366 with CaptureLogger (logger ) as cap_logger :
@@ -374,4 +371,4 @@ def test_invocation_order_stateful_last(self):
374371 .replace (" " , "" )
375372 .replace ("\n " , "" )
376373 )
377- self . assertEqual ( output , expected_invocation_order_log )
374+ assert output == expected_invocation_order_log
0 commit comments