1414
1515import contextlib
1616import gc
17- import unittest
17+ import logging
1818
19+ import pytest
1920import torch
20- from parameterized import parameterized
2121
2222from diffusers import AutoencoderKL
2323from diffusers .hooks import HookRegistry , ModelHook
2424from diffusers .models import ModelMixin
2525from diffusers .pipelines .pipeline_utils import DiffusionPipeline
26- from diffusers .utils import get_logger
26+ from diffusers .utils import logging as diffusers_logging
2727from diffusers .utils .import_utils import compare_versions
2828
29- from ..testing_utils import (
29+
30+ @contextlib .contextmanager
31+ def _propagate_diffusers_logs ():
32+ # The diffusers root logger sets `propagate = False`, so its records never reach the root
33+ # logger that pytest's `caplog` fixture hooks into. Temporarily enable propagation so the
34+ # emitted warnings can be captured, then restore the default.
35+ diffusers_logging .enable_propagation ()
36+ try :
37+ yield
38+ finally :
39+ diffusers_logging .disable_propagation ()
40+
41+
42+ from ..testing_utils import ( # noqa: E402
3043 backend_empty_cache ,
3144 backend_max_memory_allocated ,
3245 backend_reset_peak_memory_stats ,
@@ -219,20 +232,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
219232
220233
221234@require_torch_accelerator
222- class GroupOffloadTests ( unittest . TestCase ) :
235+ class TestGroupOffload :
223236 in_features = 64
224237 hidden_features = 256
225238 out_features = 64
226239 num_layers = 4
227240
228- def setUp (self ):
241+ def setup_method (self ):
229242 with torch .no_grad ():
230243 self .model = self .get_model ()
231244 self .input = torch .randn ((4 , self .in_features )).to (torch_device )
232245
233- def tearDown (self ):
234- super ().tearDown ()
235-
246+ def teardown_method (self ):
236247 del self .model
237248 del self .input
238249 gc .collect ()
@@ -248,18 +259,20 @@ def get_model(self):
248259 num_layers = self .num_layers ,
249260 )
250261
262+ @pytest .mark .skipif (
263+ torch .device (torch_device ).type not in ["cuda" , "xpu" ],
264+ reason = "Test requires a CUDA or XPU device." ,
265+ )
251266 def test_offloading_forward_pass (self ):
252267 @torch .no_grad ()
253268 def run_forward (model ):
254269 gc .collect ()
255270 backend_empty_cache (torch_device )
256271 backend_reset_peak_memory_stats (torch_device )
257- self .assertTrue (
258- all (
259- module ._diffusers_hook .get_hook ("group_offloading" ) is not None
260- for module in model .modules ()
261- if hasattr (module , "_diffusers_hook" )
262- )
272+ assert all (
273+ module ._diffusers_hook .get_hook ("group_offloading" ) is not None
274+ for module in model .modules ()
275+ if hasattr (module , "_diffusers_hook" )
263276 )
264277 model .eval ()
265278 output = model (self .input )[0 ].cpu ()
@@ -291,73 +304,72 @@ def run_forward(model):
291304 output_with_group_offloading5 , mem5 = run_forward (model )
292305
293306 # Precision assertions - offloading should not impact the output
294- self . assertTrue ( torch .allclose (output_without_group_offloading , output_with_group_offloading1 , atol = 1e-5 ) )
295- self . assertTrue ( torch .allclose (output_without_group_offloading , output_with_group_offloading2 , atol = 1e-5 ) )
296- self . assertTrue ( torch .allclose (output_without_group_offloading , output_with_group_offloading3 , atol = 1e-5 ) )
297- self . assertTrue ( torch .allclose (output_without_group_offloading , output_with_group_offloading4 , atol = 1e-5 ) )
298- self . assertTrue ( torch .allclose (output_without_group_offloading , output_with_group_offloading5 , atol = 1e-5 ) )
307+ assert torch .allclose (output_without_group_offloading , output_with_group_offloading1 , atol = 1e-5 )
308+ assert torch .allclose (output_without_group_offloading , output_with_group_offloading2 , atol = 1e-5 )
309+ assert torch .allclose (output_without_group_offloading , output_with_group_offloading3 , atol = 1e-5 )
310+ assert torch .allclose (output_without_group_offloading , output_with_group_offloading4 , atol = 1e-5 )
311+ assert torch .allclose (output_without_group_offloading , output_with_group_offloading5 , atol = 1e-5 )
299312
300313 # Memory assertions - offloading should reduce memory usage
301- self . assertTrue ( mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline )
314+ assert mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline
302315
303- def test_warning_logged_if_group_offloaded_module_moved_to_accelerator (self ):
316+ def test_warning_logged_if_group_offloaded_module_moved_to_accelerator (self , caplog ):
304317 if torch .device (torch_device ).type not in ["cuda" , "xpu" ]:
305318 return
306319 self .model .enable_group_offload (torch_device , offload_type = "block_level" , num_blocks_per_group = 3 )
307- logger = get_logger ("diffusers.models.modeling_utils" )
308- logger .setLevel ("INFO" )
309- with self .assertLogs (logger , level = "WARNING" ) as cm :
320+ with _propagate_diffusers_logs (), caplog .at_level (logging .WARNING , logger = "diffusers.models.modeling_utils" ):
310321 self .model .to (torch_device )
311- self . assertIn ( f"The module '{ self .model .__class__ .__name__ } ' is group offloaded" , cm . output [ 0 ])
322+ assert f"The module '{ self .model .__class__ .__name__ } ' is group offloaded" in caplog . text
312323
313- def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator (self ):
324+ def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator (self , caplog ):
314325 if torch .device (torch_device ).type not in ["cuda" , "xpu" ]:
315326 return
316327 pipe = DummyPipeline (self .model )
317328 self .model .enable_group_offload (torch_device , offload_type = "block_level" , num_blocks_per_group = 3 )
318- logger = get_logger ("diffusers.pipelines.pipeline_utils" )
319- logger .setLevel ("INFO" )
320- with self .assertLogs (logger , level = "WARNING" ) as cm :
329+ with (
330+ _propagate_diffusers_logs (),
331+ caplog .at_level (logging .WARNING , logger = "diffusers.pipelines.pipeline_utils" ),
332+ ):
321333 pipe .to (torch_device )
322- self . assertIn ( f"The module '{ self .model .__class__ .__name__ } ' is group offloaded" , cm . output [ 0 ])
334+ assert f"The module '{ self .model .__class__ .__name__ } ' is group offloaded" in caplog . text
323335
324336 def test_error_raised_if_streams_used_and_no_accelerator_device (self ):
325337 torch_accelerator_module = getattr (torch , torch_device , torch .cuda )
326338 original_is_available = torch_accelerator_module .is_available
327339 torch_accelerator_module .is_available = lambda : False
328- with self . assertRaises (ValueError ):
340+ with pytest . raises (ValueError ):
329341 self .model .enable_group_offload (
330342 onload_device = torch .device (torch_device ), offload_type = "leaf_level" , use_stream = True
331343 )
332344 torch_accelerator_module .is_available = original_is_available
333345
334346 def test_error_raised_if_supports_group_offloading_false (self ):
335347 self .model ._supports_group_offloading = False
336- with self . assertRaisesRegex (ValueError , "does not support group offloading" ):
348+ with pytest . raises (ValueError , match = "does not support group offloading" ):
337349 self .model .enable_group_offload (onload_device = torch .device (torch_device ))
338350
339351 def test_error_raised_if_model_offloading_applied_on_group_offloaded_module (self ):
340352 pipe = DummyPipeline (self .model )
341353 pipe .model .enable_group_offload (torch_device , offload_type = "block_level" , num_blocks_per_group = 3 )
342- with self . assertRaisesRegex (ValueError , "You are trying to apply model/sequential CPU offloading" ):
354+ with pytest . raises (ValueError , match = "You are trying to apply model/sequential CPU offloading" ):
343355 pipe .enable_model_cpu_offload ()
344356
345357 def test_error_raised_if_sequential_offloading_applied_on_group_offloaded_module (self ):
346358 pipe = DummyPipeline (self .model )
347359 pipe .model .enable_group_offload (torch_device , offload_type = "block_level" , num_blocks_per_group = 3 )
348- with self . assertRaisesRegex (ValueError , "You are trying to apply model/sequential CPU offloading" ):
360+ with pytest . raises (ValueError , match = "You are trying to apply model/sequential CPU offloading" ):
349361 pipe .enable_sequential_cpu_offload ()
350362
351363 def test_error_raised_if_group_offloading_applied_on_model_offloaded_module (self ):
352364 pipe = DummyPipeline (self .model )
353365 pipe .enable_model_cpu_offload ()
354- with self . assertRaisesRegex (ValueError , "Cannot apply group offloading" ):
366+ with pytest . raises (ValueError , match = "Cannot apply group offloading" ):
355367 pipe .model .enable_group_offload (torch_device , offload_type = "block_level" , num_blocks_per_group = 3 )
356368
357369 def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module (self ):
358370 pipe = DummyPipeline (self .model )
359371 pipe .enable_sequential_cpu_offload ()
360- with self . assertRaisesRegex (ValueError , "Cannot apply group offloading" ):
372+ with pytest . raises (ValueError , match = "Cannot apply group offloading" ):
361373 pipe .model .enable_group_offload (torch_device , offload_type = "block_level" , num_blocks_per_group = 3 )
362374
363375 def test_block_level_stream_with_invocation_order_different_from_initialization_order (self ):
@@ -376,12 +388,12 @@ def test_block_level_stream_with_invocation_order_different_from_initialization_
376388 context = contextlib .nullcontext ()
377389 if compare_versions ("diffusers" , "<=" , "0.33.0" ):
378390 # Will raise a device mismatch RuntimeError mentioning weights are on CPU but input is on device
379- context = self . assertRaisesRegex (RuntimeError , "Expected all tensors to be on the same device" )
391+ context = pytest . raises (RuntimeError , match = "Expected all tensors to be on the same device" )
380392
381393 with context :
382394 model (self .input )
383395
384- @parameterized . expand ([( "block_level" ,), ( "leaf_level" ,) ])
396+ @pytest . mark . parametrize ( "offload_type" , [ "block_level" , "leaf_level" ])
385397 def test_block_level_offloading_with_parameter_only_module_group (self , offload_type : str ):
386398 if torch .device (torch_device ).type not in ["cuda" , "xpu" ]:
387399 return
@@ -407,14 +419,14 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
407419
408420 out_ref = model_ref (x )
409421 out = model (x )
410- self . assertTrue ( torch .allclose (out_ref , out , atol = 1e-5 ), "Outputs do not match." )
422+ assert torch .allclose (out_ref , out , atol = 1e-5 ), "Outputs do not match."
411423
412424 num_repeats = 2
413425 for i in range (num_repeats ):
414426 out_ref = model_ref (x )
415427 out = model (x )
416428
417- self . assertTrue ( torch .allclose (out_ref , out , atol = 1e-5 ), "Outputs do not match after multiple invocations." )
429+ assert torch .allclose (out_ref , out , atol = 1e-5 ), "Outputs do not match after multiple invocations."
418430
419431 for (ref_name , ref_module ), (name , module ) in zip (model_ref .named_modules (), model .named_modules ()):
420432 assert ref_name == name
@@ -428,9 +440,7 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
428440 absdiff = diff .abs ()
429441 absmax = absdiff .max ().item ()
430442 cumulated_absmax += absmax
431- self .assertLess (
432- cumulated_absmax , 1e-5 , f"Output differences for { name } exceeded threshold: { cumulated_absmax :.5f} "
433- )
443+ assert cumulated_absmax < 1e-5 , f"Output differences for { name } exceeded threshold: { cumulated_absmax :.5f} "
434444
435445 def test_vae_like_model_without_streams (self ):
436446 """Test VAE-like model with block-level offloading but without streams."""
@@ -452,9 +462,7 @@ def test_vae_like_model_without_streams(self):
452462 out_ref = model_ref (x ).sample
453463 out = model (x ).sample
454464
455- self .assertTrue (
456- torch .allclose (out_ref , out , atol = 1e-5 ), "Outputs do not match for VAE-like model without streams."
457- )
465+ assert torch .allclose (out_ref , out , atol = 1e-5 ), "Outputs do not match for VAE-like model without streams."
458466
459467 def test_model_with_only_standalone_layers (self ):
460468 """Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading."""
@@ -475,12 +483,11 @@ def test_model_with_only_standalone_layers(self):
475483 for i in range (2 ):
476484 out_ref = model_ref (x )
477485 out = model (x )
478- self .assertTrue (
479- torch .allclose (out_ref , out , atol = 1e-5 ),
480- f"Outputs do not match at iteration { i } for model with standalone layers." ,
486+ assert torch .allclose (out_ref , out , atol = 1e-5 ), (
487+ f"Outputs do not match at iteration { i } for model with standalone layers."
481488 )
482489
483- @parameterized . expand ([( "block_level" ,), ( "leaf_level" ,) ])
490+ @pytest . mark . parametrize ( "offload_type" , [ "block_level" , "leaf_level" ])
484491 def test_standalone_conv_layers_with_both_offload_types (self , offload_type : str ):
485492 """Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading."""
486493 if torch .device (torch_device ).type not in ["cuda" , "xpu" ]:
@@ -501,9 +508,8 @@ def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str)
501508 out_ref = model_ref (x ).sample
502509 out = model (x ).sample
503510
504- self .assertTrue (
505- torch .allclose (out_ref , out , atol = 1e-5 ),
506- f"Outputs do not match for standalone Conv layers with { offload_type } ." ,
511+ assert torch .allclose (out_ref , out , atol = 1e-5 ), (
512+ f"Outputs do not match for standalone Conv layers with { offload_type } ."
507513 )
508514
509515 def test_multiple_invocations_with_vae_like_model (self ):
@@ -526,7 +532,7 @@ def test_multiple_invocations_with_vae_like_model(self):
526532 for i in range (2 ):
527533 out_ref = model_ref (x ).sample
528534 out = model (x ).sample
529- self . assertTrue ( torch .allclose (out_ref , out , atol = 1e-5 ), f"Outputs do not match at iteration { i } ." )
535+ assert torch .allclose (out_ref , out , atol = 1e-5 ), f"Outputs do not match at iteration { i } ."
530536
531537 def test_nested_container_parameters_offloading (self ):
532538 """Test that parameters from non-computational layers in nested containers are handled correctly."""
@@ -547,9 +553,8 @@ def test_nested_container_parameters_offloading(self):
547553 for i in range (2 ):
548554 out_ref = model_ref (x )
549555 out = model (x )
550- self .assertTrue (
551- torch .allclose (out_ref , out , atol = 1e-5 ),
552- f"Outputs do not match at iteration { i } for nested parameters." ,
556+ assert torch .allclose (out_ref , out , atol = 1e-5 ), (
557+ f"Outputs do not match at iteration { i } for nested parameters."
553558 )
554559
555560 def get_autoencoder_kl_config (self , block_out_channels = None , norm_num_groups = None ):
@@ -602,7 +607,7 @@ def forward(self, x: torch.Tensor, optional_input: torch.Tensor | None = None) -
602607 return x
603608
604609
605- class ConditionalModuleGroupOffloadTests ( GroupOffloadTests ):
610+ class TestConditionalModuleGroupOffload ( TestGroupOffload ):
606611 """Tests for conditionally-executed modules under group offloading with streams.
607612
608613 Regression tests for the case where a module is not executed during the first forward pass
@@ -620,10 +625,10 @@ def get_model(self):
620625 num_layers = self .num_layers ,
621626 )
622627
623- @parameterized . expand ([( "leaf_level" ,), ( "block_level" ,) ])
624- @unittest . skipIf (
628+ @pytest . mark . parametrize ( "offload_type" , [ "leaf_level" , "block_level" ])
629+ @pytest . mark . skipif (
625630 torch .device (torch_device ).type not in ["cuda" , "xpu" ],
626- "Test requires a CUDA or XPU device." ,
631+ reason = "Test requires a CUDA or XPU device." ,
627632 )
628633 def test_conditional_modules_with_stream (self , offload_type : str ):
629634 """Regression test: conditionally-executed modules must not cause device mismatch when using streams.
@@ -670,23 +675,20 @@ def test_conditional_modules_with_stream(self, offload_type: str):
670675 # execution order is traced. optional_proj_1/2 are NOT in the traced order.
671676 out_ref_no_opt = model_ref (x , optional_input = None )
672677 out_no_opt = model (x , optional_input = None )
673- self .assertTrue (
674- torch .allclose (out_ref_no_opt , out_no_opt , atol = 1e-5 ),
675- f"[{ offload_type } ] Outputs do not match on first pass (no optional_input)." ,
678+ assert torch .allclose (out_ref_no_opt , out_no_opt , atol = 1e-5 ), (
679+ f"[{ offload_type } ] Outputs do not match on first pass (no optional_input)."
676680 )
677681
678682 # Second forward pass WITH optional_input — optional_proj_1/2 ARE now called.
679683 out_ref_with_opt = model_ref (x , optional_input = optional_input )
680684 out_with_opt = model (x , optional_input = optional_input )
681- self .assertTrue (
682- torch .allclose (out_ref_with_opt , out_with_opt , atol = 1e-5 ),
683- f"[{ offload_type } ] Outputs do not match on second pass (with optional_input)." ,
685+ assert torch .allclose (out_ref_with_opt , out_with_opt , atol = 1e-5 ), (
686+ f"[{ offload_type } ] Outputs do not match on second pass (with optional_input)."
684687 )
685688
686689 # Third pass again without optional_input — verify stable behavior.
687690 out_ref_no_opt2 = model_ref (x , optional_input = None )
688691 out_no_opt2 = model (x , optional_input = None )
689- self .assertTrue (
690- torch .allclose (out_ref_no_opt2 , out_no_opt2 , atol = 1e-5 ),
691- f"[{ offload_type } ] Outputs do not match on third pass (back to no optional_input)." ,
692+ assert torch .allclose (out_ref_no_opt2 , out_no_opt2 , atol = 1e-5 ), (
693+ f"[{ offload_type } ] Outputs do not match on third pass (back to no optional_input)."
692694 )
0 commit comments