Skip to content

Commit b87e5da

Browse files
committed
add fa4 support
Signed-off-by: Xin Yao <xiny@nvidia.com>
1 parent 53fefa4 commit b87e5da

3 files changed

Lines changed: 392 additions & 59 deletions

File tree

tests/pytorch/attention/test_attention.py

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
)
5454

5555
_current_file = pathlib.Path(__file__).resolve()
56-
sys.path.append(str(_current_file.parent.parent))
56+
sys.path = [str(_current_file.parent.parent)] + sys.path
5757
from utils import (
5858
reset_rng_states,
5959
compare_and_assert,
@@ -362,6 +362,139 @@ def test_dpa_num_splits(dtype, model_configs, model):
362362
)
363363

364364

365+
# ==============================
366+
# Flash Attention 4 (FA4) tests
367+
# ==============================
368+
369+
model_configs_fa4_base = {
370+
# test: ModelConfig(b, sq, hq, dqk)
371+
# Standard head dims
372+
"fa4_base_1": ModelConfig(4, 128, 16, 64),
373+
"fa4_base_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
374+
"fa4_base_3": ModelConfig(2, 1024, 8, 96, attn_mask_type="causal"),
375+
# GQA
376+
"fa4_gqa_1": ModelConfig(2, 1024, 32, 128, num_gqa_groups=8, attn_mask_type="causal"),
377+
"fa4_gqa_2": ModelConfig(2, 1024, 16, 128, num_gqa_groups=1, attn_mask_type="causal"),
378+
# num_splits
379+
"fa4_splits_1": ModelConfig(2, 2048, 24, 128, num_splits=2),
380+
"fa4_splits_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, num_splits=4),
381+
}
382+
383+
384+
@pytest.mark.skipif(
385+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
386+
)
387+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
388+
@pytest.mark.parametrize("dtype", param_types_lean)
389+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_base])
390+
@pytest.mark.parametrize("model", model_configs_fa4_base.keys())
391+
def test_dpa_fa4_base(dtype, model_configs, model):
392+
"""Test DotProductAttention with FA4: base configs, extended head dims, GQA, num_splits"""
393+
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
394+
395+
396+
model_configs_fa4_mla = {
397+
# test: ModelConfig(b, sq, hq, dqk, head_dim_v=dv)
398+
"fa4_mla_1": ModelConfig(4, 128, 16, 128, head_dim_v=64),
399+
"fa4_mla_2": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),
400+
"fa4_mla_3": ModelConfig(2, 1024, 16, 96, head_dim_v=64, attn_mask_type="causal"),
401+
# dqk=128, dv=96: FA4 SM100 backward has dK_reduce_ncol misalignment for dV;
402+
# the backend filter should reject FA4 and fall back to another backend.
403+
"fa4_mla_4": ModelConfig(2, 1024, 16, 128, head_dim_v=96, attn_mask_type="causal"),
404+
# DeepSeek-style MLA: dqk=192, dv=128 (supported on SM100 as special case)
405+
"fa4_mla_deepseek": ModelConfig(2, 1024, 16, 192, head_dim_v=128, attn_mask_type="causal"),
406+
}
407+
408+
409+
@pytest.mark.skipif(
410+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
411+
)
412+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
413+
@pytest.mark.parametrize("dtype", param_types_lean)
414+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_mla])
415+
@pytest.mark.parametrize("model", model_configs_fa4_mla.keys())
416+
def test_dpa_fa4_mla(dtype, model_configs, model):
417+
"""Test DotProductAttention with FA4: MLA (head_dim_qk != head_dim_v)"""
418+
test_dot_product_attention(
419+
dtype, model_configs, model, False, True, "bshd_bshd_bshd", False, False
420+
)
421+
422+
423+
model_configs_fa4_swa = {
424+
# test: ModelConfig(b, sq, hq, dqk, window_size=(left, right))
425+
"fa4_swa_1": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal", window_size=(128, 0)),
426+
"fa4_swa_2": ModelConfig(2, 2048, 24, 64, attn_mask_type="causal", window_size=(64, 0)),
427+
"fa4_swa_3": ModelConfig(
428+
2, 2048, 16, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(256, 0)
429+
),
430+
"fa4_swa_4": ModelConfig(
431+
2, 2048, 16, 128, attn_mask_type="padding_causal", window_size=(128, 0)
432+
),
433+
}
434+
435+
436+
@pytest.mark.skipif(
437+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
438+
)
439+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
440+
@pytest.mark.parametrize("dtype", param_types_lean)
441+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_swa])
442+
@pytest.mark.parametrize("model", model_configs_fa4_swa.keys())
443+
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "bshd_bshd_bshd"])
444+
def test_dpa_fa4_sliding_window(dtype, model_configs, model, qkv_layout):
445+
"""Test DotProductAttention with FA4: sliding window attention"""
446+
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False)
447+
448+
449+
model_configs_fa4_varlen = {
450+
# test: ModelConfig(b, sq, hq, dqk)
451+
"fa4_varlen_1": ModelConfig(4, 128, 16, 64, attn_mask_type="padding"),
452+
"fa4_varlen_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="padding_causal"),
453+
"fa4_varlen_3": ModelConfig(
454+
2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"
455+
),
456+
"fa4_varlen_4": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
457+
}
458+
459+
460+
@pytest.mark.skipif(
461+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
462+
)
463+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
464+
@pytest.mark.parametrize("dtype", param_types_lean)
465+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_varlen])
466+
@pytest.mark.parametrize("model", model_configs_fa4_varlen.keys())
467+
@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "bshd_bshd_bshd"])
468+
def test_dpa_fa4_varlen(dtype, model_configs, model, qkv_layout):
469+
"""Test DotProductAttention with FA4: variable-length sequences (varlen/thd)"""
470+
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
471+
472+
473+
model_configs_fa4_mask = {
474+
# test: ModelConfig(b, sq, hq, dqk)
475+
"fa4_mask_no_mask": ModelConfig(2, 1024, 16, 128),
476+
"fa4_mask_causal": ModelConfig(2, 1024, 16, 128, attn_mask_type="causal"),
477+
"fa4_mask_causal_br": ModelConfig(2, 1024, 16, 128, attn_mask_type="causal_bottom_right"),
478+
"fa4_mask_padding": ModelConfig(2, 1024, 16, 128, attn_mask_type="padding"),
479+
"fa4_mask_padding_causal": ModelConfig(2, 1024, 16, 128, attn_mask_type="padding_causal"),
480+
"fa4_mask_padding_causal_br": ModelConfig(
481+
2, 1024, 16, 128, attn_mask_type="padding_causal_bottom_right"
482+
),
483+
}
484+
485+
486+
@pytest.mark.skipif(
487+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
488+
)
489+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
490+
@pytest.mark.parametrize("dtype", param_types_lean)
491+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_mask])
492+
@pytest.mark.parametrize("model", model_configs_fa4_mask.keys())
493+
def test_dpa_fa4_mask(dtype, model_configs, model):
494+
"""Test DotProductAttention with FA4: various attention mask types"""
495+
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
496+
497+
365498
model_configs_softmax = {
366499
# test: ModelConfig(b, sq, hq, dqk)
367500
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),

transformer_engine/pytorch/attention/dot_product_attention/backends.py

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,16 @@
8080
from transformer_engine.pytorch.export import is_in_onnx_export_mode
8181
from transformer_engine.pytorch.graph import is_graph_capturing
8282

83-
# Global vars for flash attn v2 and v3 imports
83+
# Global vars for flash attn v2
8484
flash_attn_cuda_bwd = None
8585
flash_attn_func = None
8686
flash_attn_varlen_func = None
8787
_flash_attn_fwd = None
8888
_flash_attn_bwd = None
8989
_flash_attn_varlen_fwd = None
9090
_flash_attn_varlen_bwd = None
91+
92+
# Try to import Flash Attention v2
9193
try:
9294
fa_utils.version = PkgVersion(get_pkg_version("flash-attn"))
9395
except PackageNotFoundError:
@@ -130,12 +132,16 @@
130132
),
131133
fa_utils.version,
132134
)
135+
136+
# Try to import Flash Attention v3
133137
try:
134138
fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3"))
135139
except PackageNotFoundError:
136140
flash_attn_func_v3 = None
137141
flash_attn_varlen_func_v3 = None
138142
flash_attn_with_kvcache_v3 = None
143+
_flash_attn_fwd_v3 = None
144+
_flash_attn_bwd_v3 = None
139145
# pass # only print warning if use_flash_attention_3 = True in get_attention_backend
140146
else:
141147
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
@@ -150,6 +156,25 @@
150156

151157
fa_utils.set_flash_attention_3_params()
152158

159+
# Try to import Flash Attention v4
160+
try:
161+
fa_utils.fa4_version = PkgVersion(get_pkg_version("flash-attn-4"))
162+
except PackageNotFoundError:
163+
flash_attn_func_v4 = None
164+
flash_attn_varlen_func_v4 = None
165+
# flash_attn_combine_v4 = None
166+
_flash_attn_fwd_v4 = None
167+
_flash_attn_bwd_v4 = None
168+
else:
169+
from flash_attn.cute.interface import ( # pylint: disable=ungrouped-imports
170+
flash_attn_func as flash_attn_func_v4,
171+
flash_attn_varlen_func as flash_attn_varlen_func_v4,
172+
_flash_attn_fwd as _flash_attn_fwd_v4,
173+
_flash_attn_bwd as _flash_attn_bwd_v4,
174+
)
175+
176+
fa_utils.set_flash_attention_4_params()
177+
153178
# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16
154179
_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1"
155180

@@ -916,8 +941,13 @@ def forward(
916941
batch_size * context_len,
917942
)
918943

944+
use_flash_attn_4 = False
945+
if flash_attention_backend is not None and flash_attention_backend > PkgVersion("4.0.0b"):
946+
use_flash_attn_4 = True
919947
use_flash_attn_3 = False
920-
if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
948+
if flash_attention_backend is not None and PkgVersion(
949+
"3.0.0b"
950+
) < flash_attention_backend < PkgVersion("4.0.0"):
921951
use_flash_attn_3 = True
922952
if context_parallel and all(
923953
not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
@@ -971,24 +1001,55 @@ def forward(
9711001
# | | thd + padding
9721002
# | flash_attn_with_kvcache | KV cache (not-paged/paged), i.e.
9731003
# | | bshd/sbhd/thd + padding
1004+
# FA v4 | flash_attn_func | bshd/sbhd + not padding
1005+
# | flash_attn_varlen_func | bshd/sbhd + padding
1006+
# | | thd + padding
9741007
fa_optional_forward_args_thd = []
9751008
if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
976-
func = (
977-
flash_attn_func if not use_flash_attn_3 else flash_attn_func_v3
978-
) # pylint: disable=possibly-used-before-assignment
1009+
func = None
1010+
if use_flash_attn_4:
1011+
func = flash_attn_func_v4
1012+
elif use_flash_attn_3:
1013+
func = flash_attn_func_v3
1014+
else:
1015+
func = flash_attn_func
9791016
else:
980-
if not use_flash_attn_3:
1017+
if use_flash_attn_4:
1018+
func = flash_attn_varlen_func_v4
1019+
elif not use_flash_attn_3:
9811020
func = flash_attn_varlen_func
9821021
elif inference_params is None:
9831022
func = flash_attn_varlen_func_v3 # pylint: disable=possibly-used-before-assignment
9841023
else:
9851024
func = flash_attn_with_kvcache_v3 # pylint: disable=possibly-used-before-assignment
986-
if not use_flash_attn_3 or inference_params is None:
1025+
if not use_flash_attn_4 and (not use_flash_attn_3 or inference_params is None):
9871026
fa_optional_forward_args_thd.append(cu_seqlens_q)
9881027
fa_optional_forward_args_thd.append(cu_seqlens_kv)
9891028
fa_optional_forward_args_thd.append(max_seqlen_q)
9901029
fa_optional_forward_args_thd.append(max_seqlen_kv)
991-
if not use_flash_attn_3:
1030+
if use_flash_attn_4:
1031+
fa_4_optional_forward_kwargs = {
1032+
"window_size": window_size,
1033+
"num_splits": num_splits,
1034+
}
1035+
if inference_params is None:
1036+
fa_4_optional_forward_kwargs["deterministic"] = self.deterministic
1037+
if func is flash_attn_varlen_func_v4:
1038+
fa_4_optional_forward_kwargs["cu_seqlens_q"] = cu_seqlens_q
1039+
fa_4_optional_forward_kwargs["cu_seqlens_k"] = cu_seqlens_kv
1040+
fa_4_optional_forward_kwargs["max_seqlen_q"] = max_seqlen_q
1041+
fa_4_optional_forward_kwargs["max_seqlen_k"] = max_seqlen_kv
1042+
output = func(
1043+
query_layer,
1044+
key_layer,
1045+
value_layer,
1046+
softmax_scale=self.softmax_scale,
1047+
causal="causal" in attn_mask_type,
1048+
**fa_4_optional_forward_kwargs,
1049+
)
1050+
if isinstance(output, (List, Tuple)):
1051+
output = output[0]
1052+
elif not use_flash_attn_3:
9921053
fa_optional_forward_kwargs = {}
9931054
if fa_utils.v2_3_plus:
9941055
fa_optional_forward_kwargs["window_size"] = window_size

0 commit comments

Comments
 (0)