Skip to content

Commit 0b35834

Browse files
authored
[core] fa4 support. (#13280)
* start fa4 support.
* up
* specify minimum version
1 parent 522b523 commit 0b35834

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

docs/source/en/optimization/attention_backends.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ Refer to the table below for a complete list of available attention backends and
143143
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
144144
| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
145145
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
146+
| `flash_4_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-4 |
146147
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
147148
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
148149
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |

src/diffusers/models/attention_dispatch.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ class AttentionBackendName(str, Enum):
229229
FLASH_HUB = "flash_hub"
230230
FLASH_VARLEN = "flash_varlen"
231231
FLASH_VARLEN_HUB = "flash_varlen_hub"
232+
FLASH_4_HUB = "flash_4_hub"
232233
_FLASH_3 = "_flash_3"
233234
_FLASH_VARLEN_3 = "_flash_varlen_3"
234235
_FLASH_3_HUB = "_flash_3_hub"
@@ -358,6 +359,11 @@ class _HubKernelConfig:
358359
function_attr="sageattn",
359360
version=1,
360361
),
362+
AttentionBackendName.FLASH_4_HUB: _HubKernelConfig(
363+
repo_id="kernels-staging/flash-attn4",
364+
function_attr="flash_attn_func",
365+
version=0,
366+
),
361367
}
362368

363369

@@ -521,6 +527,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
521527
AttentionBackendName._FLASH_3_HUB,
522528
AttentionBackendName._FLASH_3_VARLEN_HUB,
523529
AttentionBackendName.SAGE_HUB,
530+
AttentionBackendName.FLASH_4_HUB,
524531
]:
525532
if not is_kernels_available():
526533
raise RuntimeError(
@@ -531,6 +538,11 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
531538
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
532539
)
533540

541+
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
542+
raise RuntimeError(
543+
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
544+
)
545+
534546
elif backend == AttentionBackendName.AITER:
535547
if not _CAN_USE_AITER_ATTN:
536548
raise RuntimeError(
@@ -2676,6 +2688,37 @@ def _flash_attention_3_varlen_hub(
26762688
return (out, lse) if return_lse else out
26772689

26782690

@_AttentionBackendRegistry.register(
    AttentionBackendName.FLASH_4_HUB,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_context_parallel=False,
)
def _flash_attention_4_hub(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    scale: float | None = None,
    is_causal: bool = False,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    """Run attention through the FlashAttention-4 kernel pulled from the Hub.

    The callable is looked up in ``_HUB_KERNELS_REGISTRY`` under the
    ``FLASH_4_HUB`` backend and invoked with the standard ``q``/``k``/``v``
    keyword interface. ``attn_mask`` is rejected because the FA4 kernel
    exposes no mask argument.

    Args:
        query: Query tensor.
        key: Key tensor.
        value: Value tensor.
        attn_mask: Must be ``None``; FA4 does not support attention masks.
        scale: Optional softmax scaling factor forwarded as ``softmax_scale``.
        is_causal: Whether to apply a causal mask inside the kernel.
        return_lse: If the kernel returns a ``(out, lse)`` tuple, also return
            the log-sum-exp tensor. NOTE(review): when the kernel returns a
            bare tensor, ``return_lse`` is silently ignored — presumably the
            kernel build determines which form is produced; confirm upstream.
        _parallel_config: Unused here; context parallelism is not supported.

    Returns:
        The attention output, or ``(out, lse)`` when ``return_lse`` is set and
        the kernel provided both values.

    Raises:
        ValueError: If ``attn_mask`` is not ``None``.
    """
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for flash-attn 4.")

    kernel_fn = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn
    result = kernel_fn(
        q=query,
        k=key,
        v=value,
        softmax_scale=scale,
        causal=is_causal,
    )

    # Bare-tensor return: nothing to unpack.
    if not isinstance(result, tuple):
        return result

    attn_out, lse = result[0], result[1]
    return (attn_out, lse) if return_lse else attn_out
26792722
@_AttentionBackendRegistry.register(
26802723
AttentionBackendName._FLASH_VARLEN_3,
26812724
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],

0 commit comments

Comments
 (0)