Skip to content

Commit 06ccde9

Browse files
Fix group-offloading bug (#13211)
* Implement synchronous onload for offloaded parameters Add fallback synchronous onload for conditionally-executed modules. * add test for new code path about group-offloading * Update tests/hooks/test_group_offloading.py Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * use unittest.skipIf and update the comment --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 8879824 commit 06ccde9

File tree

2 files changed

+135
-0
lines changed

2 files changed

+135
-0
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,17 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
307307
if self.group.onload_leader == module:
308308
if self.group.onload_self:
309309
self.group.onload_()
310+
else:
311+
# onload_self=False means this group relies on prefetching from a previous group.
312+
# However, for conditionally-executed modules (e.g. patch_short/patch_mid/patch_long in Helios),
313+
# the prefetch chain may not cover them if they were absent during the first forward pass
314+
# when the execution order was traced. In that case, their weights remain on offload_device,
315+
# so we fall back to a synchronous onload here.
316+
params = [p for m in self.group.modules for p in m.parameters()] + list(self.group.parameters)
317+
if params and params[0].device == self.group.offload_device:
318+
self.group.onload_()
319+
if self.group.stream is not None:
320+
self.group.stream.synchronize()
310321

311322
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
312323
if should_onload_next_group:

tests/hooks/test_group_offloading.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,3 +566,127 @@ def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=Non
566566
"layers_per_block": 1,
567567
}
568568
return init_dict
569+
570+
571+
# Test model with conditionally-executed submodules, mirroring the Helios
# patch_short/patch_mid/patch_long pattern: the optional projections run only when
# the optional input is supplied. The lazy-prefetch execution-order tracer therefore
# may never see them on the first forward pass, which used to leave their weights on
# the offload device. See: https://github.com/huggingface/diffusers/pull/13211
class DummyModelWithConditionalModules(ModelMixin):
    """Small MLP with two projection layers that execute only when `optional_input` is given."""

    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
        super().__init__()

        # NOTE: module creation order is deliberate — it fixes both the RNG stream used
        # for weight init and the registration order seen by group offloading.
        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        self.blocks = torch.nn.ModuleList(
            [DummyBlock(hidden_features, hidden_features, hidden_features) for _layer in range(num_layers)]
        )
        self.linear_2 = torch.nn.Linear(hidden_features, out_features)

        # Only invoked when `optional_input` is not None; output dim matches
        # hidden_features so the projections can be added right after linear_1.
        self.optional_proj_1 = torch.nn.Linear(in_features, hidden_features)
        self.optional_proj_2 = torch.nn.Linear(in_features, hidden_features)

    def forward(self, x: torch.Tensor, optional_input: torch.Tensor | None = None) -> torch.Tensor:
        hidden = self.activation(self.linear_1(x))
        if optional_input is not None:
            # Same addition order as the reference path: (hidden + p1) + p2.
            hidden = hidden + self.optional_proj_1(optional_input) + self.optional_proj_2(optional_input)
        for blk in self.blocks:
            hidden = blk(hidden)
        return self.linear_2(hidden)
603+
604+
605+
class ConditionalModuleGroupOffloadTests(GroupOffloadTests):
    """Tests for conditionally-executed modules under group offloading with streams.

    Regression tests for the case where a module is not executed during the first forward pass
    (when the lazy prefetch execution order is traced), but IS executed on subsequent passes.
    Without the fix, the weights of such modules remain on CPU while the input is on GPU,
    causing a RuntimeError about tensor device mismatch.
    """

    def get_model(self):
        """Build the conditional-module dummy model with a fixed seed for reproducible weights."""
        torch.manual_seed(0)
        return DummyModelWithConditionalModules(
            in_features=self.in_features,
            hidden_features=self.hidden_features,
            out_features=self.out_features,
            num_layers=self.num_layers,
        )

    @parameterized.expand([("leaf_level",), ("block_level",)])
    @unittest.skipIf(
        torch.device(torch_device).type not in ["cuda", "xpu"],
        "Test requires a CUDA or XPU device.",
    )
    def test_conditional_modules_with_stream(self, offload_type: str) -> None:
        """Regression test: conditionally-executed modules must not cause device mismatch when using streams.

        The model contains two optional Linear layers (optional_proj_1, optional_proj_2) that are only
        executed when `optional_input` is provided. This simulates modules like patch_short/patch_mid/
        patch_long in HeliosTransformer3DModel, which are only called when history latents are present.

        When using streams, `LazyPrefetchGroupOffloadingHook` traces the execution order on the first
        forward pass and sets up a prefetch chain so each module pre-loads the next one's weights.
        Modules not executed during this tracing pass are excluded from the prefetch chain.

        The bug: if a module was absent from the first (tracing) pass, its `onload_self` flag gets set
        to False (meaning "someone else will onload me"). But since it's not in the prefetch chain,
        nobody ever does — so its weights remain on CPU. When the module is eventually called in a
        subsequent pass, the input is on GPU but the weights are on CPU, causing a RuntimeError.

        We therefore must invoke the model multiple times:
          1. First pass WITHOUT optional_input: triggers the lazy prefetch tracing. optional_proj_1/2
             are absent, so they are excluded from the prefetch chain.
          2. Second pass WITH optional_input: the regression case. Without the fix, this raises a
             RuntimeError because optional_proj_1/2 weights are still on CPU.
          3. Third pass WITHOUT optional_input: verifies the model remains stable after having seen
             both code paths.
        """

        # Two identically-initialized models: `model` runs with group offloading,
        # `model_ref` stays fully on-device and provides the expected outputs.
        model = self.get_model()
        model_ref = self.get_model()
        model_ref.load_state_dict(model.state_dict(), strict=True)
        model_ref.to(torch_device)

        model.enable_group_offload(
            torch_device,
            offload_type=offload_type,
            num_blocks_per_group=1,
            use_stream=True,
        )

        x = torch.randn(4, self.in_features).to(torch_device)
        optional_input = torch.randn(4, self.in_features).to(torch_device)

        with torch.no_grad():
            # First forward pass WITHOUT optional_input — this is when the lazy prefetch
            # execution order is traced. optional_proj_1/2 are NOT in the traced order.
            out_ref_no_opt = model_ref(x, optional_input=None)
            out_no_opt = model(x, optional_input=None)
            self.assertTrue(
                torch.allclose(out_ref_no_opt, out_no_opt, atol=1e-5),
                f"[{offload_type}] Outputs do not match on first pass (no optional_input).",
            )

            # Second forward pass WITH optional_input — optional_proj_1/2 ARE now called.
            out_ref_with_opt = model_ref(x, optional_input=optional_input)
            out_with_opt = model(x, optional_input=optional_input)
            self.assertTrue(
                torch.allclose(out_ref_with_opt, out_with_opt, atol=1e-5),
                f"[{offload_type}] Outputs do not match on second pass (with optional_input).",
            )

            # Third pass again without optional_input — verify stable behavior.
            out_ref_no_opt2 = model_ref(x, optional_input=None)
            out_no_opt2 = model(x, optional_input=None)
            self.assertTrue(
                torch.allclose(out_ref_no_opt2, out_no_opt2, atol=1e-5),
                f"[{offload_type}] Outputs do not match on third pass (back to no optional_input).",
            )

0 commit comments

Comments (0)