Skip to content

Commit 6b6a92c

Browse files
committed
fa4 support
Signed-off-by: Xin Yao <xiny@nvidia.com>
1 parent 38f62a0 commit 6b6a92c

3 files changed

Lines changed: 217 additions & 44 deletions

File tree

tests/pytorch/attention/test_attention.py

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
)
5454

5555
_current_file = pathlib.Path(__file__).resolve()
56-
sys.path.append(str(_current_file.parent.parent))
56+
sys.path = [str(_current_file.parent.parent)] + sys.path
5757
from utils import (
5858
reset_rng_states,
5959
compare_and_assert,
@@ -362,6 +362,141 @@ def test_dpa_num_splits(dtype, model_configs, model):
362362
)
363363

364364

365+
# ==============================
366+
# Flash Attention 4 (FA4) tests
367+
# ==============================
368+
369+
model_configs_fa4_base = {
370+
# test: ModelConfig(b, sq, hq, dqk)
371+
# Standard head dims
372+
"fa4_base_1": ModelConfig(4, 128, 16, 64),
373+
"fa4_base_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
374+
"fa4_base_3": ModelConfig(2, 1024, 8, 96, attn_mask_type="causal"),
375+
# GQA
376+
"fa4_gqa_1": ModelConfig(2, 1024, 32, 128, num_gqa_groups=8, attn_mask_type="causal"),
377+
"fa4_gqa_2": ModelConfig(2, 1024, 16, 128, num_gqa_groups=1, attn_mask_type="causal"),
378+
# num_splits
379+
"fa4_splits_1": ModelConfig(2, 2048, 24, 128, num_splits=2),
380+
"fa4_splits_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, num_splits=4),
381+
}
382+
383+
384+
@pytest.mark.skipif(
385+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
386+
)
387+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
388+
@pytest.mark.parametrize("dtype", param_types_lean)
389+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_base])
390+
@pytest.mark.parametrize("model", model_configs_fa4_base.keys())
391+
def test_dpa_fa4_base(dtype, model_configs, model):
392+
"""Test DotProductAttention with FA4: base configs, extended head dims, GQA, num_splits"""
393+
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
394+
395+
396+
model_configs_fa4_mla = {
397+
# test: ModelConfig(b, sq, hq, dqk, head_dim_v=dv)
398+
"fa4_mla_1": ModelConfig(4, 128, 16, 128, head_dim_v=64),
399+
"fa4_mla_2": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),
400+
"fa4_mla_3": ModelConfig(2, 1024, 16, 96, head_dim_v=64, attn_mask_type="causal"),
401+
# dqk=128, dv=96: FA4 SM100 backward has dK_reduce_ncol misalignment for dV;
402+
# the backend filter should reject FA4 and fall back to another backend.
403+
"fa4_mla_4": ModelConfig(2, 1024, 16, 128, head_dim_v=96, attn_mask_type="causal"),
404+
# DeepSeek-style MLA: dqk=192, dv=128 (supported on SM100 as special case)
405+
"fa4_mla_deepseek": ModelConfig(2, 1024, 16, 192, head_dim_v=128, attn_mask_type="causal"),
406+
}
407+
408+
409+
@pytest.mark.skipif(
410+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
411+
)
412+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
413+
@pytest.mark.parametrize("dtype", param_types_lean)
414+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_mla])
415+
@pytest.mark.parametrize("model", model_configs_fa4_mla.keys())
416+
def test_dpa_fa4_mla(dtype, model_configs, model):
417+
"""Test DotProductAttention with FA4: MLA (head_dim_qk != head_dim_v)"""
418+
test_dot_product_attention(
419+
dtype, model_configs, model, False, True, "bshd_bshd_bshd", False, False
420+
)
421+
422+
423+
model_configs_fa4_swa = {
424+
# test: ModelConfig(b, sq, hq, dqk, window_size=(left, right))
425+
"fa4_swa_1": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal", window_size=(128, 0)),
426+
"fa4_swa_2": ModelConfig(2, 2048, 24, 64, attn_mask_type="causal", window_size=(64, 0)),
427+
"fa4_swa_3": ModelConfig(
428+
2, 2048, 16, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(256, 0)
429+
),
430+
"fa4_swa_4": ModelConfig(
431+
2, 2048, 16, 128, attn_mask_type="padding_causal", window_size=(128, 0)
432+
),
433+
}
434+
435+
436+
@pytest.mark.skipif(
437+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
438+
)
439+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
440+
@pytest.mark.parametrize("dtype", param_types_lean)
441+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_swa])
442+
@pytest.mark.parametrize("model", model_configs_fa4_swa.keys())
443+
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "bshd_bshd_bshd"])
444+
def test_dpa_fa4_sliding_window(dtype, model_configs, model, qkv_layout):
445+
"""Test DotProductAttention with FA4: sliding window attention"""
446+
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False)
447+
448+
449+
model_configs_fa4_varlen = {
450+
# test: ModelConfig(b, sq, hq, dqk)
451+
"fa4_varlen_1": ModelConfig(4, 128, 16, 64, attn_mask_type="padding"),
452+
"fa4_varlen_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="padding_causal"),
453+
"fa4_varlen_3": ModelConfig(
454+
2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"
455+
),
456+
"fa4_varlen_4": ModelConfig(
457+
2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"
458+
),
459+
}
460+
461+
462+
@pytest.mark.skipif(
463+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
464+
)
465+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
466+
@pytest.mark.parametrize("dtype", param_types_lean)
467+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_varlen])
468+
@pytest.mark.parametrize("model", model_configs_fa4_varlen.keys())
469+
@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "bshd_bshd_bshd"])
470+
def test_dpa_fa4_varlen(dtype, model_configs, model, qkv_layout):
471+
"""Test DotProductAttention with FA4: variable-length sequences (varlen/thd)"""
472+
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
473+
474+
475+
model_configs_fa4_mask = {
476+
# test: ModelConfig(b, sq, hq, dqk)
477+
"fa4_mask_no_mask": ModelConfig(2, 1024, 16, 128),
478+
"fa4_mask_causal": ModelConfig(2, 1024, 16, 128, attn_mask_type="causal"),
479+
"fa4_mask_causal_br": ModelConfig(2, 1024, 16, 128, attn_mask_type="causal_bottom_right"),
480+
"fa4_mask_padding": ModelConfig(2, 1024, 16, 128, attn_mask_type="padding"),
481+
"fa4_mask_padding_causal": ModelConfig(2, 1024, 16, 128, attn_mask_type="padding_causal"),
482+
"fa4_mask_padding_causal_br": ModelConfig(
483+
2, 1024, 16, 128, attn_mask_type="padding_causal_bottom_right"
484+
),
485+
}
486+
487+
488+
@pytest.mark.skipif(
489+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
490+
)
491+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
492+
@pytest.mark.parametrize("dtype", param_types_lean)
493+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_mask])
494+
@pytest.mark.parametrize("model", model_configs_fa4_mask.keys())
495+
def test_dpa_fa4_mask(dtype, model_configs, model):
496+
"""Test DotProductAttention with FA4: various attention mask types"""
497+
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
498+
499+
365500
model_configs_softmax = {
366501
# test: ModelConfig(b, sq, hq, dqk)
367502
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),

transformer_engine/pytorch/attention/dot_product_attention/backends.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,20 +156,19 @@
156156

157157
# Try to import Flash Attention v4
158158
try:
159-
fa_utils.fa4_version = PkgVersion(get_pkg_version("flash-attn-cute"))
159+
fa_utils.fa4_version = PkgVersion(get_pkg_version("flash-attn-4"))
160160
except PackageNotFoundError:
161161
flash_attn_func_v4 = None
162162
flash_attn_varlen_func_v4 = None
163-
flash_attn_with_kvcache_v4 = None
163+
flash_attn_combine_v4 = None
164164
_flash_attn_fwd_v4 = None
165165
_flash_attn_bwd_v4 = None
166-
# pass # only print warning if use_flash_attention_4 = True in get_attention_backend
167166
else:
168167
from flash_attn.cute.interface import flash_attn_func as flash_attn_func_v4
169168
from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_v4
169+
from flash_attn.cute.interface import flash_attn_combine as flash_attn_combine_v4
170170
from flash_attn.cute.interface import _flash_attn_fwd as _flash_attn_fwd_v4
171171
from flash_attn.cute.interface import _flash_attn_bwd as _flash_attn_bwd_v4
172-
# flash_attn_with_kvcache_v4 = None # FA4 does not support kvcache yet
173172
fa_utils.set_flash_attention_4_params()
174173

175174
# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16
@@ -942,7 +941,7 @@ def forward(
942941
if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
943942
use_flash_attn_3 = True
944943
use_flash_attn_4 = False
945-
if flash_attention_backend is not None and str(flash_attention_backend).endswith("cute"):
944+
if flash_attention_backend is not None and flash_attention_backend > PkgVersion("4.0.0b"):
946945
use_flash_attn_4 = True
947946
if context_parallel and all(
948947
not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
@@ -996,6 +995,9 @@ def forward(
996995
# | | thd + padding
997996
# | flash_attn_with_kvcache | KV cache (not-paged/paged), i.e.
998997
# | | bshd/sbhd/thd + padding
998+
# FA v4 | flash_attn_func | bshd/sbhd + not padding
999+
# | flash_attn_varlen_func | bshd/sbhd + padding
1000+
# | | thd + padding
9991001
fa_optional_forward_args_thd = []
10001002
if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
10011003
func = None
@@ -1006,24 +1008,33 @@ def forward(
10061008
else:
10071009
func = flash_attn_func
10081010
else:
1009-
if not use_flash_attn_3:
1011+
if use_flash_attn_4:
1012+
func = flash_attn_varlen_func_v4
1013+
elif not use_flash_attn_3:
10101014
func = flash_attn_varlen_func
10111015
elif inference_params is None:
10121016
func = flash_attn_varlen_func_v3 # pylint: disable=possibly-used-before-assignment
10131017
else:
10141018
func = flash_attn_with_kvcache_v3 # pylint: disable=possibly-used-before-assignment
1015-
if not use_flash_attn_3 or inference_params is None:
1019+
if not use_flash_attn_4 and (
1020+
not use_flash_attn_3 or inference_params is None
1021+
):
10161022
fa_optional_forward_args_thd.append(cu_seqlens_q)
10171023
fa_optional_forward_args_thd.append(cu_seqlens_kv)
10181024
fa_optional_forward_args_thd.append(max_seqlen_q)
10191025
fa_optional_forward_args_thd.append(max_seqlen_kv)
10201026
if use_flash_attn_4:
10211027
fa_4_optional_forward_kwargs = {
1022-
# "window_size": window_size,
1028+
"window_size": window_size,
10231029
"num_splits": num_splits,
10241030
}
10251031
if inference_params is None:
10261032
fa_4_optional_forward_kwargs["deterministic"] = self.deterministic
1033+
if func is flash_attn_varlen_func_v4:
1034+
fa_4_optional_forward_kwargs["cu_seqlens_q"] = cu_seqlens_q
1035+
fa_4_optional_forward_kwargs["cu_seqlens_k"] = cu_seqlens_kv
1036+
fa_4_optional_forward_kwargs["max_seqlen_q"] = max_seqlen_q
1037+
fa_4_optional_forward_kwargs["max_seqlen_k"] = max_seqlen_kv
10271038
output = func(
10281039
query_layer,
10291040
key_layer,

transformer_engine/pytorch/attention/dot_product_attention/utils.py

Lines changed: 62 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,7 @@ class FlashAttentionUtils:
145145
fa4_version = PkgVersion("0")
146146
use_v4 = False
147147
v4_installation_steps = """\
148-
(1) git clone https://github.com/Dao-AILab/flash-attention.git
149-
(2) pip install flash-attention/flash_attn/cute"""
148+
(1) pip install flash-attn-4"""
150149
v4_warning_printed = False
151150

152151
@staticmethod
@@ -460,13 +459,10 @@ def get_attention_backend(
460459
if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed:
461460
logger.debug("Disabling FlashAttention 3 for compute capability != sm90")
462461
use_flash_attention_3 = False
463-
# TODO: Other compute capabilities support:
464-
# SM80: not enabled
465-
# SM90: has bugs
466-
# SM120: WIP
467-
if device_compute_capability != (10, 0):
462+
# FA4 supports SM80, SM90, SM100, SM120
463+
if device_compute_capability < (8, 0):
468464
if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed:
469-
logger.debug("Disabling FlashAttention 4 for compute capability != sm100")
465+
logger.debug("Disabling FlashAttention 4 for compute capability < sm80")
470466
use_flash_attention_4 = False
471467

472468
# Filter: Data type
@@ -588,7 +584,7 @@ def get_attention_backend(
588584
# Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 256
589585
# Flash v3 | FP16/BF16 | non-paged/paged | sm90 | bshd,sbhd,thd | >= 1
590586
# | FP8 | non-paged/paged | sm90 | thd | >= 1
591-
# Flash v4 | N/A | N/A | N/A | N/A | N/A
587+
# Flash v4 | FP16/BF16 | TODO | sm80+ | bshd,sbhd,thd | TODO
592588
# Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1
593589
if inference_params is not None:
594590
# Temporarily disabling fused attention for kv caching for sm89 irrespective of cuDNN version
@@ -642,9 +638,6 @@ def get_attention_backend(
642638
if use_flash_attention_2 and FlashAttentionUtils.is_installed:
643639
logger.debug("Disabling FlashAttention 2 as it does not support MLA.")
644640
use_flash_attention_2 = False
645-
if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed:
646-
logger.debug("Disabling FlashAttention 4 as it does not support MLA.")
647-
use_flash_attention_4 = False
648641

649642
qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "")
650643
if use_fused_attention and qkv_layout_group != "hd_hd_hd":
@@ -717,17 +710,50 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
717710
)
718711
use_flash_attention_3 = False
719712

720-
if (
721-
use_flash_attention_4
722-
and FlashAttentionUtils.v4_is_installed
723-
and (head_dim_qk != head_dim_v or head_dim_qk not in [64, 96, 128])
724-
):
725-
logger.debug(
726-
"Disabling FlashAttention 4 due to unsupported head_dim_qk and head_dim_v. "
727-
"Supported: head_dim_qk == head_dim_v, head_dim_qk in [64, 96, 128]. "
728-
f"Found: head_dim_qk = {head_dim_qk}, head_dim_v = {head_dim_v}."
729-
)
730-
use_flash_attention_4 = False
713+
if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed:
714+
# FA4 head dimension support is architecture-dependent
715+
# (matches _validate_head_dims in flash_attn.cute.interface):
716+
# SM90: head_dim <= 256 and head_dim_v <= 256
717+
# SM100/110: head_dim <= 128 and head_dim_v <= 128,
718+
# OR DeepSeek MLA shape (head_dim=192, head_dim_v=128)
719+
# SM80/120: constrained by shared memory (~256 max in practice)
720+
_fa4_hdim_ok = True
721+
if device_compute_capability >= (10, 0) and device_compute_capability < (12, 0):
722+
_is_standard = head_dim_qk <= 128 and head_dim_v <= 128
723+
_is_deepseek = head_dim_qk == 192 and head_dim_v == 128
724+
_fa4_hdim_ok = _is_standard or _is_deepseek
725+
else:
726+
_fa4_hdim_ok = head_dim_qk <= 256 and head_dim_v <= 256
727+
if not _fa4_hdim_ok:
728+
logger.debug(
729+
"Disabling FlashAttention 4 due to unsupported head dimensions. "
730+
f"Found: head_dim_qk = {head_dim_qk}, head_dim_v = {head_dim_v}, "
731+
f"on sm{device_compute_capability[0] * 10 + device_compute_capability[1]}."
732+
)
733+
use_flash_attention_4 = False
734+
# Workaround: SM100 backward kernel bug when MLA + 2CTA (head_dim_qk >= 128).
735+
# FlashAttentionBackwardSm100 computes dK_reduce_ncol = gcd(32, tile_hdim // 2)
736+
# based on Q/K head_dim but reuses it for dV TMEM load atoms. When
737+
# (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are misaligned.
738+
# See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
739+
elif (
740+
_fa4_hdim_ok
741+
and is_training
742+
and head_dim_qk != head_dim_v
743+
and head_dim_qk >= 128
744+
and device_compute_capability >= (10, 0)
745+
and device_compute_capability < (12, 0)
746+
):
747+
_tile_hdim = math.ceil(head_dim_qk / 16) * 16
748+
_tile_hdimv = math.ceil(head_dim_v / 16) * 16
749+
_dk_reduce_ncol = math.gcd(32, _tile_hdim // 2)
750+
if (_tile_hdimv // 2) % _dk_reduce_ncol != 0:
751+
logger.debug(
752+
"Disabling FlashAttention 4 for training due to SM100 backward kernel "
753+
"bug with MLA head dimensions (dK_reduce_ncol misalignment for dV). "
754+
f"Found: head_dim_qk = {head_dim_qk}, head_dim_v = {head_dim_v}."
755+
)
756+
use_flash_attention_4 = False
731757

732758
# Filter: QKV layout
733759
if qkv_format == "thd":
@@ -749,10 +775,6 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
749775
" not supported for compute capability = sm120"
750776
)
751777
use_fused_attention = False
752-
if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed:
753-
logger.debug("Disabling FlashAttention 4 for qkv_format = thd")
754-
use_flash_attention_4 = False
755-
756778
# Filter: Dropout
757779
if attention_dropout != 0.0:
758780
if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed:
@@ -816,6 +838,11 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
816838
"Disabling UnfusedDotProductAttention as it does not support context parallelism"
817839
)
818840
use_unfused_attention = False
841+
if context_parallel and use_flash_attention_4 and FlashAttentionUtils.v4_is_installed:
842+
logger.debug(
843+
"Disabling FlashAttention 4 as it does not support context parallelism yet"
844+
)
845+
use_flash_attention_4 = False
819846
if context_parallel and (
820847
use_flash_attention_2 or use_flash_attention_3 or use_flash_attention_4
821848
):
@@ -1228,10 +1255,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
12281255
flash_attention_backend = FlashAttentionUtils.version
12291256
if use_flash_attention_3:
12301257
flash_attention_backend = FlashAttentionUtils.fa3_version
1231-
# FA4 is released with the package name "flash-attn-cute" and version starting from 0.1.0
1232-
# We need to add the ".cute" suffix to the version number to distinguish.
12331258
if use_flash_attention_4:
1234-
flash_attention_backend = PkgVersion(f"{str(FlashAttentionUtils.fa4_version)}+cute")
1259+
flash_attention_backend = FlashAttentionUtils.fa4_version
12351260

12361261
logger.debug(
12371262
"Available backends = {FlashAttention=%s%s, FusedAttention=%s%s,"
@@ -1248,12 +1273,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
12481273
)
12491274

12501275
# Select FusedAttention for performance
1276+
# FA4 is preferred over FusedAttention when available
12511277
if use_flash_attention and use_fused_attention and device_compute_capability >= (9, 0):
1252-
logger.debug(
1253-
"Disabling FlashAttention to give FusedAttention preference on Hopper+ "
1254-
"for performance reasons"
1255-
)
1256-
use_flash_attention = False
1278+
if not use_flash_attention_4:
1279+
logger.debug(
1280+
"Disabling FlashAttention to give FusedAttention preference on Hopper+ "
1281+
"for performance reasons"
1282+
)
1283+
use_flash_attention = False
12571284

12581285
# Selected backend
12591286
if use_flash_attention:

0 commit comments

Comments
 (0)