Skip to content

Commit 2fa9b93

Browse files
sayakpaul and dg845 authored
[tests] migrate group offloading tests to pytest (#13234)
* migrate group offloading tests to pytest
* fix tests.

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
1 parent b95637a commit 2fa9b93

1 file changed

Lines changed: 73 additions & 71 deletions

File tree

tests/hooks/test_group_offloading.py

Lines changed: 73 additions & 71 deletions
Original file line number | Diff line number | Diff line change
@@ -14,19 +14,32 @@
1414

1515
import contextlib
1616
import gc
17-
import unittest
17+
import logging
1818

19+
import pytest
1920
import torch
20-
from parameterized import parameterized
2121

2222
from diffusers import AutoencoderKL
2323
from diffusers.hooks import HookRegistry, ModelHook
2424
from diffusers.models import ModelMixin
2525
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26-
from diffusers.utils import get_logger
26+
from diffusers.utils import logging as diffusers_logging
2727
from diffusers.utils.import_utils import compare_versions
2828

29-
from ..testing_utils import (
29+
30+
@contextlib.contextmanager
def _propagate_diffusers_logs():
    """Temporarily let diffusers log records propagate to the root logger.

    The diffusers root logger sets `propagate = False`, so its records never reach the
    root logger that pytest's `caplog` fixture hooks into. Propagation is enabled for the
    duration of the `with` block so emitted warnings can be captured.

    On exit the propagation flag is put back to whatever it was on entry (instead of
    being unconditionally disabled), so a caller or test session that had already
    enabled propagation is left undisturbed.
    """
    # `get_logger()` with no name returns the library root logger.
    root_logger = diffusers_logging.get_logger()
    previous_propagate = root_logger.propagate
    diffusers_logging.enable_propagation()
    try:
        yield
    finally:
        # Restore the prior state even if the body raised.
        root_logger.propagate = previous_propagate
40+
41+
42+
from ..testing_utils import ( # noqa: E402
3043
backend_empty_cache,
3144
backend_max_memory_allocated,
3245
backend_reset_peak_memory_stats,
@@ -219,20 +232,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
219232

220233

221234
@require_torch_accelerator
222-
class GroupOffloadTests(unittest.TestCase):
235+
class TestGroupOffload:
223236
in_features = 64
224237
hidden_features = 256
225238
out_features = 64
226239
num_layers = 4
227240

228-
def setUp(self):
241+
def setup_method(self):
229242
with torch.no_grad():
230243
self.model = self.get_model()
231244
self.input = torch.randn((4, self.in_features)).to(torch_device)
232245

233-
def tearDown(self):
234-
super().tearDown()
235-
246+
def teardown_method(self):
236247
del self.model
237248
del self.input
238249
gc.collect()
@@ -248,18 +259,20 @@ def get_model(self):
248259
num_layers=self.num_layers,
249260
)
250261

262+
@pytest.mark.skipif(
263+
torch.device(torch_device).type not in ["cuda", "xpu"],
264+
reason="Test requires a CUDA or XPU device.",
265+
)
251266
def test_offloading_forward_pass(self):
252267
@torch.no_grad()
253268
def run_forward(model):
254269
gc.collect()
255270
backend_empty_cache(torch_device)
256271
backend_reset_peak_memory_stats(torch_device)
257-
self.assertTrue(
258-
all(
259-
module._diffusers_hook.get_hook("group_offloading") is not None
260-
for module in model.modules()
261-
if hasattr(module, "_diffusers_hook")
262-
)
272+
assert all(
273+
module._diffusers_hook.get_hook("group_offloading") is not None
274+
for module in model.modules()
275+
if hasattr(module, "_diffusers_hook")
263276
)
264277
model.eval()
265278
output = model(self.input)[0].cpu()
@@ -291,73 +304,72 @@ def run_forward(model):
291304
output_with_group_offloading5, mem5 = run_forward(model)
292305

293306
# Precision assertions - offloading should not impact the output
294-
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
295-
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
296-
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
297-
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
298-
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5))
307+
assert torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)
308+
assert torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)
309+
assert torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)
310+
assert torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)
311+
assert torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5)
299312

300313
# Memory assertions - offloading should reduce memory usage
301-
self.assertTrue(mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline)
314+
assert mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline
302315

303-
def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self):
316+
def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self, caplog):
304317
if torch.device(torch_device).type not in ["cuda", "xpu"]:
305318
return
306319
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
307-
logger = get_logger("diffusers.models.modeling_utils")
308-
logger.setLevel("INFO")
309-
with self.assertLogs(logger, level="WARNING") as cm:
320+
with _propagate_diffusers_logs(), caplog.at_level(logging.WARNING, logger="diffusers.models.modeling_utils"):
310321
self.model.to(torch_device)
311-
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
322+
assert f"The module '{self.model.__class__.__name__}' is group offloaded" in caplog.text
312323

313-
def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self):
324+
def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self, caplog):
314325
if torch.device(torch_device).type not in ["cuda", "xpu"]:
315326
return
316327
pipe = DummyPipeline(self.model)
317328
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
318-
logger = get_logger("diffusers.pipelines.pipeline_utils")
319-
logger.setLevel("INFO")
320-
with self.assertLogs(logger, level="WARNING") as cm:
329+
with (
330+
_propagate_diffusers_logs(),
331+
caplog.at_level(logging.WARNING, logger="diffusers.pipelines.pipeline_utils"),
332+
):
321333
pipe.to(torch_device)
322-
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
334+
assert f"The module '{self.model.__class__.__name__}' is group offloaded" in caplog.text
323335

324336
def test_error_raised_if_streams_used_and_no_accelerator_device(self):
325337
torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
326338
original_is_available = torch_accelerator_module.is_available
327339
torch_accelerator_module.is_available = lambda: False
328-
with self.assertRaises(ValueError):
340+
with pytest.raises(ValueError):
329341
self.model.enable_group_offload(
330342
onload_device=torch.device(torch_device), offload_type="leaf_level", use_stream=True
331343
)
332344
torch_accelerator_module.is_available = original_is_available
333345

334346
def test_error_raised_if_supports_group_offloading_false(self):
335347
self.model._supports_group_offloading = False
336-
with self.assertRaisesRegex(ValueError, "does not support group offloading"):
348+
with pytest.raises(ValueError, match="does not support group offloading"):
337349
self.model.enable_group_offload(onload_device=torch.device(torch_device))
338350

339351
def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self):
340352
pipe = DummyPipeline(self.model)
341353
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
342-
with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
354+
with pytest.raises(ValueError, match="You are trying to apply model/sequential CPU offloading"):
343355
pipe.enable_model_cpu_offload()
344356

345357
def test_error_raised_if_sequential_offloading_applied_on_group_offloaded_module(self):
346358
pipe = DummyPipeline(self.model)
347359
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
348-
with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
360+
with pytest.raises(ValueError, match="You are trying to apply model/sequential CPU offloading"):
349361
pipe.enable_sequential_cpu_offload()
350362

351363
def test_error_raised_if_group_offloading_applied_on_model_offloaded_module(self):
352364
pipe = DummyPipeline(self.model)
353365
pipe.enable_model_cpu_offload()
354-
with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
366+
with pytest.raises(ValueError, match="Cannot apply group offloading"):
355367
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
356368

357369
def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module(self):
358370
pipe = DummyPipeline(self.model)
359371
pipe.enable_sequential_cpu_offload()
360-
with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
372+
with pytest.raises(ValueError, match="Cannot apply group offloading"):
361373
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
362374

363375
def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
@@ -376,12 +388,12 @@ def test_block_level_stream_with_invocation_order_different_from_initialization_
376388
context = contextlib.nullcontext()
377389
if compare_versions("diffusers", "<=", "0.33.0"):
378390
# Will raise a device mismatch RuntimeError mentioning weights are on CPU but input is on device
379-
context = self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device")
391+
context = pytest.raises(RuntimeError, match="Expected all tensors to be on the same device")
380392

381393
with context:
382394
model(self.input)
383395

384-
@parameterized.expand([("block_level",), ("leaf_level",)])
396+
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
385397
def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
386398
if torch.device(torch_device).type not in ["cuda", "xpu"]:
387399
return
@@ -407,14 +419,14 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
407419

408420
out_ref = model_ref(x)
409421
out = model(x)
410-
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")
422+
assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match."
411423

412424
num_repeats = 2
413425
for i in range(num_repeats):
414426
out_ref = model_ref(x)
415427
out = model(x)
416428

417-
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.")
429+
assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations."
418430

419431
for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
420432
assert ref_name == name
@@ -428,9 +440,7 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
428440
absdiff = diff.abs()
429441
absmax = absdiff.max().item()
430442
cumulated_absmax += absmax
431-
self.assertLess(
432-
cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
433-
)
443+
assert cumulated_absmax < 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
434444

435445
def test_vae_like_model_without_streams(self):
436446
"""Test VAE-like model with block-level offloading but without streams."""
@@ -452,9 +462,7 @@ def test_vae_like_model_without_streams(self):
452462
out_ref = model_ref(x).sample
453463
out = model(x).sample
454464

455-
self.assertTrue(
456-
torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
457-
)
465+
assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
458466

459467
def test_model_with_only_standalone_layers(self):
460468
"""Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading."""
@@ -475,12 +483,11 @@ def test_model_with_only_standalone_layers(self):
475483
for i in range(2):
476484
out_ref = model_ref(x)
477485
out = model(x)
478-
self.assertTrue(
479-
torch.allclose(out_ref, out, atol=1e-5),
480-
f"Outputs do not match at iteration {i} for model with standalone layers.",
486+
assert torch.allclose(out_ref, out, atol=1e-5), (
487+
f"Outputs do not match at iteration {i} for model with standalone layers."
481488
)
482489

483-
@parameterized.expand([("block_level",), ("leaf_level",)])
490+
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
484491
def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str):
485492
"""Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading."""
486493
if torch.device(torch_device).type not in ["cuda", "xpu"]:
@@ -501,9 +508,8 @@ def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str)
501508
out_ref = model_ref(x).sample
502509
out = model(x).sample
503510

504-
self.assertTrue(
505-
torch.allclose(out_ref, out, atol=1e-5),
506-
f"Outputs do not match for standalone Conv layers with {offload_type}.",
511+
assert torch.allclose(out_ref, out, atol=1e-5), (
512+
f"Outputs do not match for standalone Conv layers with {offload_type}."
507513
)
508514

509515
def test_multiple_invocations_with_vae_like_model(self):
@@ -526,7 +532,7 @@ def test_multiple_invocations_with_vae_like_model(self):
526532
for i in range(2):
527533
out_ref = model_ref(x).sample
528534
out = model(x).sample
529-
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.")
535+
assert torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}."
530536

531537
def test_nested_container_parameters_offloading(self):
532538
"""Test that parameters from non-computational layers in nested containers are handled correctly."""
@@ -547,9 +553,8 @@ def test_nested_container_parameters_offloading(self):
547553
for i in range(2):
548554
out_ref = model_ref(x)
549555
out = model(x)
550-
self.assertTrue(
551-
torch.allclose(out_ref, out, atol=1e-5),
552-
f"Outputs do not match at iteration {i} for nested parameters.",
556+
assert torch.allclose(out_ref, out, atol=1e-5), (
557+
f"Outputs do not match at iteration {i} for nested parameters."
553558
)
554559

555560
def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
@@ -602,7 +607,7 @@ def forward(self, x: torch.Tensor, optional_input: torch.Tensor | None = None) -
602607
return x
603608

604609

605-
class ConditionalModuleGroupOffloadTests(GroupOffloadTests):
610+
class TestConditionalModuleGroupOffload(TestGroupOffload):
606611
"""Tests for conditionally-executed modules under group offloading with streams.
607612
608613
Regression tests for the case where a module is not executed during the first forward pass
@@ -620,10 +625,10 @@ def get_model(self):
620625
num_layers=self.num_layers,
621626
)
622627

623-
@parameterized.expand([("leaf_level",), ("block_level",)])
624-
@unittest.skipIf(
628+
@pytest.mark.parametrize("offload_type", ["leaf_level", "block_level"])
629+
@pytest.mark.skipif(
625630
torch.device(torch_device).type not in ["cuda", "xpu"],
626-
"Test requires a CUDA or XPU device.",
631+
reason="Test requires a CUDA or XPU device.",
627632
)
628633
def test_conditional_modules_with_stream(self, offload_type: str):
629634
"""Regression test: conditionally-executed modules must not cause device mismatch when using streams.
@@ -670,23 +675,20 @@ def test_conditional_modules_with_stream(self, offload_type: str):
670675
# execution order is traced. optional_proj_1/2 are NOT in the traced order.
671676
out_ref_no_opt = model_ref(x, optional_input=None)
672677
out_no_opt = model(x, optional_input=None)
673-
self.assertTrue(
674-
torch.allclose(out_ref_no_opt, out_no_opt, atol=1e-5),
675-
f"[{offload_type}] Outputs do not match on first pass (no optional_input).",
678+
assert torch.allclose(out_ref_no_opt, out_no_opt, atol=1e-5), (
679+
f"[{offload_type}] Outputs do not match on first pass (no optional_input)."
676680
)
677681

678682
# Second forward pass WITH optional_input — optional_proj_1/2 ARE now called.
679683
out_ref_with_opt = model_ref(x, optional_input=optional_input)
680684
out_with_opt = model(x, optional_input=optional_input)
681-
self.assertTrue(
682-
torch.allclose(out_ref_with_opt, out_with_opt, atol=1e-5),
683-
f"[{offload_type}] Outputs do not match on second pass (with optional_input).",
685+
assert torch.allclose(out_ref_with_opt, out_with_opt, atol=1e-5), (
686+
f"[{offload_type}] Outputs do not match on second pass (with optional_input)."
684687
)
685688

686689
# Third pass again without optional_input — verify stable behavior.
687690
out_ref_no_opt2 = model_ref(x, optional_input=None)
688691
out_no_opt2 = model(x, optional_input=None)
689-
self.assertTrue(
690-
torch.allclose(out_ref_no_opt2, out_no_opt2, atol=1e-5),
691-
f"[{offload_type}] Outputs do not match on third pass (back to no optional_input).",
692+
assert torch.allclose(out_ref_no_opt2, out_no_opt2, atol=1e-5), (
693+
f"[{offload_type}] Outputs do not match on third pass (back to no optional_input)."
692694
)

0 commit comments

Comments
 (0)