Skip to content

Commit 540b375

Browse files
authored
[PyTorch] Add FA4 Support (NVIDIA#2432)
* add fa4 support
  Signed-off-by: Xin Yao <xiny@nvidia.com>
* comment out unused import for cp
  Signed-off-by: Xin Yao <xiny@nvidia.com>
* fix lint
  Signed-off-by: Xin Yao <xiny@nvidia.com>
* install fa4 in L3 test
  Signed-off-by: Xin Yao <xiny@nvidia.com>
* fix sm90
  Signed-off-by: Xin Yao <xiny@nvidia.com>
---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
1 parent 9432c5a commit 540b375

4 files changed

Lines changed: 404 additions & 61 deletions

File tree

qa/L3_pytorch_FA_versions_test/test.sh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); pri
1818
export FLASH_ATTN_CUDA_ARCHS=$sm_arch
1919
if [ $sm_arch -gt 90 ]
2020
then
21-
FA_versions=(2.8.3)
21+
FA_versions=(2.8.3 4.0.0b8)
2222
elif [ $sm_arch -eq 90 ]
2323
then
24-
FA_versions=(2.7.3 2.8.3 3.0.0b1)
24+
FA_versions=(2.7.3 2.8.3 3.0.0b1 4.0.0b8)
2525
fi
2626

2727
for fa_version in "${FA_versions[@]}"
@@ -31,6 +31,9 @@ do
3131
if [ "${fa_version}" \< "3.0.0" ]
3232
then
3333
pip3 install flash-attn==${fa_version} --no-build-isolation
34+
elif [[ "${fa_version}" == 4.* ]]
35+
then
36+
# Quote the requirement specifiers: an unquoted "[cu13]" is a shell glob
# character class and could be expanded against files in the current directory.
pip3 install "flash-attn-4==${fa_version}" "nvidia-cutlass-dsl[cu13]==4.4.2" --no-build-isolation
3437
else
3538
git clone https://github.com/Dao-AILab/flash-attention.git
3639
cd flash-attention/hopper && python setup.py install

tests/pytorch/attention/test_attention.py

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
)
5454

5555
_current_file = pathlib.Path(__file__).resolve()
56-
sys.path.append(str(_current_file.parent.parent))
56+
# Put the grandparent directory of this file first on the module search path so
# it is searched before any other entry. Mutate the list in place (instead of
# rebinding sys.path) so existing references to the list stay valid.
sys.path.insert(0, str(_current_file.parent.parent))
5757
from utils import (
5858
reset_rng_states,
5959
compare_and_assert,
@@ -362,6 +362,139 @@ def test_dpa_num_splits(dtype, model_configs, model):
362362
)
363363

364364

365+
# ==============================
366+
# Flash Attention 4 (FA4) tests
367+
# ==============================
368+
369+
model_configs_fa4_base = {
370+
# test: ModelConfig(b, sq, hq, dqk)
371+
# Standard head dims
372+
"fa4_base_1": ModelConfig(4, 128, 16, 64),
373+
"fa4_base_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
374+
"fa4_base_3": ModelConfig(2, 1024, 8, 96, attn_mask_type="causal"),
375+
# GQA
376+
"fa4_gqa_1": ModelConfig(2, 1024, 32, 128, num_gqa_groups=8, attn_mask_type="causal"),
377+
"fa4_gqa_2": ModelConfig(2, 1024, 16, 128, num_gqa_groups=1, attn_mask_type="causal"),
378+
# num_splits
379+
"fa4_splits_1": ModelConfig(2, 2048, 24, 128, num_splits=2),
380+
"fa4_splits_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, num_splits=4),
381+
}
382+
383+
384+
@pytest.mark.skipif(
385+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
386+
)
387+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
388+
@pytest.mark.parametrize("dtype", param_types_lean)
389+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_base])
390+
@pytest.mark.parametrize("model", model_configs_fa4_base.keys())
391+
def test_dpa_fa4_base(dtype, model_configs, model):
392+
"""Test DotProductAttention with FA4: base configs, extended head dims, GQA, num_splits"""
393+
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
394+
395+
396+
model_configs_fa4_mla = {
397+
# test: ModelConfig(b, sq, hq, dqk, head_dim_v=dv)
398+
"fa4_mla_1": ModelConfig(4, 128, 16, 128, head_dim_v=64),
399+
"fa4_mla_2": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),
400+
"fa4_mla_3": ModelConfig(2, 1024, 16, 96, head_dim_v=64, attn_mask_type="causal"),
401+
# dqk=128, dv=96: FA4 SM100 backward has dK_reduce_ncol misalignment for dV;
402+
# the backend filter should reject FA4 and fall back to another backend.
403+
"fa4_mla_4": ModelConfig(2, 1024, 16, 128, head_dim_v=96, attn_mask_type="causal"),
404+
# DeepSeek-style MLA: dqk=192, dv=128 (supported on SM100 as special case)
405+
"fa4_mla_deepseek": ModelConfig(2, 1024, 16, 192, head_dim_v=128, attn_mask_type="causal"),
406+
}
407+
408+
409+
@pytest.mark.skipif(
410+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
411+
)
412+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
413+
@pytest.mark.parametrize("dtype", param_types_lean)
414+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_mla])
415+
@pytest.mark.parametrize("model", model_configs_fa4_mla.keys())
416+
def test_dpa_fa4_mla(dtype, model_configs, model):
417+
"""Test DotProductAttention with FA4: MLA (head_dim_qk != head_dim_v)"""
418+
test_dot_product_attention(
419+
dtype, model_configs, model, False, True, "bshd_bshd_bshd", False, False
420+
)
421+
422+
423+
model_configs_fa4_swa = {
424+
# test: ModelConfig(b, sq, hq, dqk, window_size=(left, right))
425+
"fa4_swa_1": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal", window_size=(128, 0)),
426+
"fa4_swa_2": ModelConfig(2, 2048, 24, 64, attn_mask_type="causal", window_size=(64, 0)),
427+
"fa4_swa_3": ModelConfig(
428+
2, 2048, 16, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(256, 0)
429+
),
430+
"fa4_swa_4": ModelConfig(
431+
2, 2048, 16, 128, attn_mask_type="padding_causal", window_size=(128, 0)
432+
),
433+
}
434+
435+
436+
@pytest.mark.skipif(
437+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
438+
)
439+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
440+
@pytest.mark.parametrize("dtype", param_types_lean)
441+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_swa])
442+
@pytest.mark.parametrize("model", model_configs_fa4_swa.keys())
443+
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "bshd_bshd_bshd"])
444+
def test_dpa_fa4_sliding_window(dtype, model_configs, model, qkv_layout):
445+
"""Test DotProductAttention with FA4: sliding window attention"""
446+
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False)
447+
448+
449+
model_configs_fa4_varlen = {
450+
# test: ModelConfig(b, sq, hq, dqk)
451+
"fa4_varlen_1": ModelConfig(4, 128, 16, 64, attn_mask_type="padding"),
452+
"fa4_varlen_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="padding_causal"),
453+
"fa4_varlen_3": ModelConfig(
454+
2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"
455+
),
456+
"fa4_varlen_4": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
457+
}
458+
459+
460+
@pytest.mark.skipif(
461+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
462+
)
463+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
464+
@pytest.mark.parametrize("dtype", param_types_lean)
465+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_varlen])
466+
@pytest.mark.parametrize("model", model_configs_fa4_varlen.keys())
467+
@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "bshd_bshd_bshd"])
468+
def test_dpa_fa4_varlen(dtype, model_configs, model, qkv_layout):
469+
"""Test DotProductAttention with FA4: variable-length sequences (varlen/thd)"""
470+
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
471+
472+
473+
model_configs_fa4_mask = {
474+
# test: ModelConfig(b, sq, hq, dqk)
475+
"fa4_mask_no_mask": ModelConfig(2, 1024, 16, 128),
476+
"fa4_mask_causal": ModelConfig(2, 1024, 16, 128, attn_mask_type="causal"),
477+
"fa4_mask_causal_br": ModelConfig(2, 1024, 16, 128, attn_mask_type="causal_bottom_right"),
478+
"fa4_mask_padding": ModelConfig(2, 1024, 16, 128, attn_mask_type="padding"),
479+
"fa4_mask_padding_causal": ModelConfig(2, 1024, 16, 128, attn_mask_type="padding_causal"),
480+
"fa4_mask_padding_causal_br": ModelConfig(
481+
2, 1024, 16, 128, attn_mask_type="padding_causal_bottom_right"
482+
),
483+
}
484+
485+
486+
@pytest.mark.skipif(
487+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
488+
)
489+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
490+
@pytest.mark.parametrize("dtype", param_types_lean)
491+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_mask])
492+
@pytest.mark.parametrize("model", model_configs_fa4_mask.keys())
493+
def test_dpa_fa4_mask(dtype, model_configs, model):
494+
"""Test DotProductAttention with FA4: various attention mask types"""
495+
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
496+
497+
365498
model_configs_softmax = {
366499
# test: ModelConfig(b, sq, hq, dqk)
367500
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),

transformer_engine/pytorch/attention/dot_product_attention/backends.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,16 @@
8080
from transformer_engine.pytorch.export import is_in_onnx_export_mode
8181
from transformer_engine.pytorch.graph import is_graph_capturing
8282

83-
# Global vars for flash attn v2 and v3 imports
83+
# Global vars for flash attn v2
8484
flash_attn_cuda_bwd = None
8585
flash_attn_func = None
8686
flash_attn_varlen_func = None
8787
_flash_attn_fwd = None
8888
_flash_attn_bwd = None
8989
_flash_attn_varlen_fwd = None
9090
_flash_attn_varlen_bwd = None
91+
92+
# Try to import Flash Attention v2
9193
try:
9294
fa_utils.version = PkgVersion(PkgVersion(get_pkg_version("flash-attn")).public)
9395
except PackageNotFoundError:
@@ -130,12 +132,16 @@
130132
),
131133
fa_utils.version,
132134
)
135+
136+
# Try to import Flash Attention v3
133137
try:
134138
fa_utils.fa3_version = PkgVersion(PkgVersion(get_pkg_version("flash-attn-3")).public)
135139
except PackageNotFoundError:
136140
flash_attn_func_v3 = None
137141
flash_attn_varlen_func_v3 = None
138142
flash_attn_with_kvcache_v3 = None
143+
_flash_attn_fwd_v3 = None
144+
_flash_attn_bwd_v3 = None
139145
# pass # only print warning if use_flash_attention_3 = True in get_attention_backend
140146
else:
141147
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
@@ -150,6 +156,20 @@
150156

151157
fa_utils.set_flash_attention_3_params()
152158

159+
# Try to import Flash Attention v4
160+
try:
161+
fa_utils.fa4_version = PkgVersion(get_pkg_version("flash-attn-4"))
162+
except PackageNotFoundError:
163+
flash_attn_func_v4 = None
164+
flash_attn_varlen_func_v4 = None
165+
else:
166+
from flash_attn.cute.interface import ( # pylint: disable=ungrouped-imports,no-name-in-module
167+
flash_attn_func as flash_attn_func_v4,
168+
flash_attn_varlen_func as flash_attn_varlen_func_v4,
169+
)
170+
171+
fa_utils.set_flash_attention_4_params()
172+
153173
# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16
154174
_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1"
155175

@@ -916,8 +936,13 @@ def forward(
916936
batch_size * context_len,
917937
)
918938

939+
# Pick the flash-attn major version from the requested backend version.
# FA4 pre-releases (e.g. 4.0.0b8) sort *below* the final release "4.0.0" under
# PEP 440 ordering, so bounding FA3 with `< PkgVersion("4.0.0")` would flag an
# FA4 beta as FA3 as well. Deriving the FA3 flag from the FA4 flag keeps the
# two mutually exclusive while leaving every other version unchanged.
use_flash_attn_4 = flash_attention_backend is not None and (
    flash_attention_backend > PkgVersion("4.0.0b")
)
use_flash_attn_3 = (
    not use_flash_attn_4
    and flash_attention_backend is not None
    and flash_attention_backend > PkgVersion("3.0.0b")
)
922947
if context_parallel and all(
923948
not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
@@ -971,24 +996,55 @@ def forward(
971996
# | | thd + padding
972997
# | flash_attn_with_kvcache | KV cache (not-paged/paged), i.e.
973998
# | | bshd/sbhd/thd + padding
999+
# FA v4 | flash_attn_func | bshd/sbhd + not padding
1000+
# | flash_attn_varlen_func | bshd/sbhd + padding
1001+
# | | thd + padding
9741002
fa_optional_forward_args_thd = []
9751003
if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
976-
func = (
977-
flash_attn_func if not use_flash_attn_3 else flash_attn_func_v3
978-
) # pylint: disable=possibly-used-before-assignment
1004+
func = None
1005+
if use_flash_attn_4:
1006+
func = flash_attn_func_v4
1007+
elif use_flash_attn_3:
1008+
func = flash_attn_func_v3
1009+
else:
1010+
func = flash_attn_func
9791011
else:
980-
if not use_flash_attn_3:
1012+
if use_flash_attn_4:
1013+
func = flash_attn_varlen_func_v4
1014+
elif not use_flash_attn_3:
9811015
func = flash_attn_varlen_func
9821016
elif inference_params is None:
9831017
func = flash_attn_varlen_func_v3 # pylint: disable=possibly-used-before-assignment
9841018
else:
9851019
func = flash_attn_with_kvcache_v3 # pylint: disable=possibly-used-before-assignment
986-
if not use_flash_attn_3 or inference_params is None:
1020+
if not use_flash_attn_4 and (not use_flash_attn_3 or inference_params is None):
9871021
fa_optional_forward_args_thd.append(cu_seqlens_q)
9881022
fa_optional_forward_args_thd.append(cu_seqlens_kv)
9891023
fa_optional_forward_args_thd.append(max_seqlen_q)
9901024
fa_optional_forward_args_thd.append(max_seqlen_kv)
991-
if not use_flash_attn_3:
1025+
if use_flash_attn_4:
1026+
fa_4_optional_forward_kwargs = {
1027+
"window_size": window_size,
1028+
"num_splits": num_splits,
1029+
}
1030+
if inference_params is None:
1031+
fa_4_optional_forward_kwargs["deterministic"] = self.deterministic
1032+
if func is flash_attn_varlen_func_v4:
1033+
fa_4_optional_forward_kwargs["cu_seqlens_q"] = cu_seqlens_q
1034+
fa_4_optional_forward_kwargs["cu_seqlens_k"] = cu_seqlens_kv
1035+
fa_4_optional_forward_kwargs["max_seqlen_q"] = max_seqlen_q
1036+
fa_4_optional_forward_kwargs["max_seqlen_k"] = max_seqlen_kv
1037+
output = func(
1038+
query_layer,
1039+
key_layer,
1040+
value_layer,
1041+
softmax_scale=self.softmax_scale,
1042+
causal="causal" in attn_mask_type,
1043+
**fa_4_optional_forward_kwargs,
1044+
)
1045+
if isinstance(output, (List, Tuple)):
1046+
output = output[0]
1047+
elif not use_flash_attn_3:
9921048
fa_optional_forward_kwargs = {}
9931049
if fa_utils.v2_3_plus:
9941050
fa_optional_forward_kwargs["window_size"] = window_size

0 commit comments

Comments
 (0)