
Commit c67685b

fix(group_offloading): gate default stream on transfer stream in _onload_from_memory
## Problem

`ModuleGroup._onload_from_memory` schedules async CPU→GPU tensor copies on a dedicated transfer stream, but returns without making the default stream (on which the module's forward pass runs) wait for those copies to finish.

On NVIDIA CUDA, implicit stream ordering and driver-level synchronization generally prevent this race from manifesting. On **AMD ROCm** (tested on gfx1101 / RX 7800 XT with ROCm 7.x), the race is reliable: the first matmul in the freshly onloaded module executes before the async copies complete, raising:

```
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)
```

This affects any pipeline that uses `enable_group_offload(use_stream=True)`, including FLUX.1-dev with int8 group offloading on ROCm.

## Fix

After the `with context:` block, call `default_stream.wait_stream(self.stream)` so the forward pass is gated on the completed transfers. A `stream.synchronize()` fallback is included for backends that do not expose `wait_stream`. On CUDA this call is a no-op when both streams are already synchronized, so existing behaviour is preserved.

## Reproduction (ROCm)

```python
import torch

from diffusers import FluxPipeline
from diffusers.hooks import apply_group_offloading

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")
apply_group_offloading(
    pipe.transformer,
    offload_type="block_level",
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    use_stream=True,
    num_blocks_per_group=1,
)
pipe("test prompt", num_inference_steps=4)
# → RuntimeError: Expected all tensors to be on the same device … cpu vs cuda
# Fixed with this patch.
```

Tested on: 5× AMD RX 7800 XT (gfx1101), ROCm 7.1, PyTorch 2.7, diffusers main.
CUDA regression: none (`wait_stream` is a no-op when streams are synchronized).
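For reference, a minimal standalone sketch of the stream-gating pattern the fix relies on (not the diffusers code itself; tensor names and sizes are illustrative):

```python
import torch

# Sketch of the pattern: issue an async H2D copy on a side stream, then make the
# default stream wait on that stream before the compute that consumes the tensor.
if torch.cuda.is_available():  # ROCm builds of PyTorch also expose torch.cuda
    transfer_stream = torch.cuda.Stream()
    x_cpu = torch.randn(1024, 1024, pin_memory=True)
    w = torch.randn(1024, 1024, device="cuda")

    with torch.cuda.stream(transfer_stream):
        # Enqueued on transfer_stream; returns immediately without waiting for the copy.
        x_gpu = x_cpu.to("cuda", non_blocking=True)

    # Without this wait, the matmul below (enqueued on the default stream) can start
    # before the copy has landed on backends with weaker implicit ordering (e.g. ROCm).
    torch.cuda.current_stream().wait_stream(transfer_stream)

    y = x_gpu @ w  # safe: the default stream is now ordered after the transfer stream
    torch.cuda.synchronize()
```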
1 parent c8c8401 commit c67685b

1 file changed

Lines changed: 12 additions & 0 deletions

src/diffusers/hooks/group_offloading.py

```diff
@@ -292,6 +292,18 @@ def _onload_from_memory(self):
         else:
             self._process_tensors_from_modules(None)
 
+        # Gate the default stream on the transfer stream completing before the forward pass runs.
+        # On CUDA, implicit stream ordering often masks this race; on AMD ROCm (gfx1xxx) the
+        # first matmul can race ahead of the async CPU→GPU copies and raise a device-mismatch
+        # error ("mat2 is on cpu") inside the first matmul of the loaded module.
+        # `wait_stream` is a no-op when both handles refer to the same stream.
+        if self.stream is not None:
+            current_default = self._torch_accelerator_module.current_stream()
+            if hasattr(current_default, "wait_stream"):
+                current_default.wait_stream(self.stream)
+            else:
+                self.stream.synchronize()
+
     def _offload_to_disk(self):
         self._check_disk_offload_torchao()
```

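The same race can also be reached through the higher-level `enable_group_offload` entry point mentioned in the commit message. A hedged usage sketch, assuming the model-level wrapper accepts the same arguments as `apply_group_offloading` (exact keyword names may differ across diffusers versions):

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
# Model-level wrapper around the same group-offloading hooks; with use_stream=True
# the onload path touched by this diff (_onload_from_memory) runs for every block.
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
)
image = pipe("test prompt", num_inference_steps=4).images[0]
```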