Commit c0b37b1
fix(group_offloading): gate default stream on transfer stream in _onload_from_memory
## Problem
`ModuleGroup._onload_from_memory` schedules async CPU→GPU tensor copies on a
dedicated transfer stream, but returns without making the default stream (on
which the module's forward pass runs) wait for those copies to finish.
On NVIDIA CUDA, implicit stream ordering and driver-level synchronization
generally prevent this race from manifesting. On **AMD ROCm** (tested on
gfx1101 / RX 7800 XT with ROCm 7.x), the race is reliable: the first matmul
in the freshly onloaded module executes before the async copies complete,
raising:

    RuntimeError: Expected all tensors to be on the same device, but found at
    least two devices, cuda:0 and cpu! (when checking argument for argument
    mat2 in method wrapper_CUDA_mm)

This affects any pipeline that uses `enable_group_offload(use_stream=True)`,
including FLUX.1-dev with int8 group offloading on ROCm.
## Fix
After the `with context:` block, call `default_stream.wait_stream(self.stream)`
so the forward pass is gated on the completed transfers. A `stream.synchronize()`
fallback is included for backends that do not expose `wait_stream`.
On CUDA, `wait_stream` does not block the host and adds negligible cost when
the transfers have already finished, so existing behaviour is preserved.
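
The gating step described above can be sketched as a small standalone helper. This is a minimal illustration, not the code from the patch: the name `gate_on_transfer` and its signature are hypothetical, and the two arguments are only assumed to behave like `torch.cuda.Stream` objects.

```python
def gate_on_transfer(default_stream, transfer_stream):
    """Make work queued later on default_stream wait for transfer_stream.

    Hypothetical helper: the name and signature are illustrative and do not
    come from the patch. Both arguments are assumed to behave like
    torch.cuda.Stream objects.
    """
    wait_stream = getattr(default_stream, "wait_stream", None)
    if callable(wait_stream):
        # Device-side dependency: the host thread is not blocked.
        wait_stream(transfer_stream)
    else:
        # Fallback: block the host until the transfer stream has drained.
        transfer_stream.synchronize()
```

With PyTorch on a CUDA or ROCm device, this would presumably be called as `gate_on_transfer(torch.cuda.current_stream(), transfer_stream)` once the copies have been enqueued. The `wait_stream` path is preferred because it orders the two streams on the device, whereas the `synchronize()` fallback stalls the host thread.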
## Reproduction (ROCm)
```python
import torch
from diffusers import FluxPipeline
from diffusers.hooks import apply_group_offloading

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Offload the transformer block by block, streaming weights in on a side stream.
apply_group_offloading(
    pipe.transformer,
    offload_type="block_level",
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    use_stream=True,
    num_blocks_per_group=1,
)

pipe("test prompt", num_inference_steps=4)
# → RuntimeError: Expected all tensors to be on the same device … cpu vs cuda
# Fixed with this patch.
```
Tested on: 5× AMD RX 7800 XT (gfx1101), ROCm 7.1, PyTorch 2.7, diffusers main.
CUDA regression: none (`wait_stream` adds negligible cost when the transfers have already finished).

1 parent 2173c55, commit c0b37b1
1 file changed
Lines changed: 12 additions & 0 deletions
Added lines: 295-306 in the new file, inserted after original line 294 (diff content not captured).