Commit 43f54f3
fix(attention_dispatch): use lse.ndim check instead of torch version for ROCm compat
## Problem
Two call sites in `TemplatedRingAttention.forward` and
`_ulysses_context_parallel_attention` condition the LSE unsqueeze on
`is_torch_version("<", "2.9.0")`, following the assumption introduced in
#12693 that torch>=2.9 always returns LSE with shape [B,H,S,1] (4D) from
`_scaled_dot_product_flash_attention`.
That assumption holds on **NVIDIA CUDA** but not on **AMD ROCm**: on ROCm 7.x
with torch>=2.9, `aten._scaled_dot_product_flash_attention` (backed by
AOTriton / hipBLASLt) still returns LSE as [B,H,S] (3D). The downstream ring
merge then broadcasts a 3D tensor against a 4D `out` tensor, raising:
```
RuntimeError: The size of tensor a (24) must match the size of tensor b (128) at non-singleton dimension 3
```
This blocks ring / context-parallel attention entirely on AMD hardware with
any torch>=2.9 build.
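The broadcasting failure is easy to reproduce in isolation. A minimal sketch (numpy stands in for torch here; the broadcasting rules are the same, and the shapes are illustrative, not taken from the actual model):

```python
import numpy as np

# Illustrative shapes only: B=2, H=8, S=16, D=64.
B, H, S, D = 2, 8, 16, 64
out = np.zeros((B, H, S, D))       # attention output, always 4D

lse_cuda = np.zeros((B, H, S, 1))  # torch>=2.9 on CUDA: LSE already 4D
lse_rocm = np.zeros((B, H, S))     # ROCm flash attention: LSE still 3D

# 4D LSE broadcasts cleanly over the head dim:
merged = out - lse_cuda            # shape stays (B, H, S, D)

# 3D LSE aligns from the right, so S lines up with D and fails:
try:
    out - lse_rocm
except ValueError as e:            # numpy's analogue of the torch RuntimeError
    print(e)
```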
## Fix
Replace the `is_torch_version` guard with a shape check: `lse.ndim < out.ndim`
in the ring merge, and `lse.ndim == 3` in the Ulysses path. This is
backend-agnostic: on CUDA with torch>=2.9, where LSE is already 4D, the
condition is False and behaviour is unchanged; on ROCm, where LSE is 3D, the
unsqueeze happens regardless of torch version.
The same logical fix is applied to both affected call sites:
- `TemplatedRingAttention.forward` (ring merge loop)
- `_ulysses_context_parallel_attention` (Ulysses all-to-all path)
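The guard at both sites reduces to the same idea, sketched below as a hypothetical helper (`normalize_lse` is not the actual diffusers code; numpy stands in for torch):

```python
import numpy as np

def normalize_lse(lse, out):
    # Backend-agnostic: unsqueeze whenever LSE has fewer dims than out,
    # instead of gating on the torch version.
    if lse.ndim < out.ndim:
        lse = lse[..., None]  # [B, H, S] -> [B, H, S, 1]
    return lse

out = np.zeros((2, 8, 16, 64))
# ROCm-style 3D LSE gets unsqueezed:
assert normalize_lse(np.zeros((2, 8, 16)), out).shape == (2, 8, 16, 1)
# CUDA torch>=2.9-style 4D LSE passes through unchanged:
assert normalize_lse(np.zeros((2, 8, 16, 1)), out).shape == (2, 8, 16, 1)
```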
## Tested on
- 5× AMD RX 7800 XT (gfx1101), ROCm 7.1, PyTorch 2.7, diffusers main
- Ring attention + context parallel with FLUX.1-dev, 4-GPU tensor parallel
- CUDA regression: none (the ndim guard is equivalent to the version guard on CUDA)

1 parent c67685b commit 43f54f3
1 file changed
Lines changed: 7 additions & 2 deletions