Skip to content

Commit 4760264

Browse files
committed
fa4 support
Signed-off-by: Xin Yao <xiny@nvidia.com>
1 parent 38f62a0 commit 4760264

3 files changed

Lines changed: 273 additions & 82 deletions

File tree

tests/pytorch/attention/test_attention.py

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
)
5454

5555
_current_file = pathlib.Path(__file__).resolve()
56-
sys.path.append(str(_current_file.parent.parent))
56+
sys.path = [str(_current_file.parent.parent)] + sys.path
5757
from utils import (
5858
reset_rng_states,
5959
compare_and_assert,
@@ -362,6 +362,139 @@ def test_dpa_num_splits(dtype, model_configs, model):
362362
)
363363

364364

365+
# ==============================
366+
# Flash Attention 4 (FA4) tests
367+
# ==============================
368+
369+
model_configs_fa4_base = {
370+
# test: ModelConfig(b, sq, hq, dqk)
371+
# Standard head dims
372+
"fa4_base_1": ModelConfig(4, 128, 16, 64),
373+
"fa4_base_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
374+
"fa4_base_3": ModelConfig(2, 1024, 8, 96, attn_mask_type="causal"),
375+
# GQA
376+
"fa4_gqa_1": ModelConfig(2, 1024, 32, 128, num_gqa_groups=8, attn_mask_type="causal"),
377+
"fa4_gqa_2": ModelConfig(2, 1024, 16, 128, num_gqa_groups=1, attn_mask_type="causal"),
378+
# num_splits
379+
"fa4_splits_1": ModelConfig(2, 2048, 24, 128, num_splits=2),
380+
"fa4_splits_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, num_splits=4),
381+
}
382+
383+
384+
@pytest.mark.skipif(
385+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
386+
)
387+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
388+
@pytest.mark.parametrize("dtype", param_types_lean)
389+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_base])
390+
@pytest.mark.parametrize("model", model_configs_fa4_base.keys())
391+
def test_dpa_fa4_base(dtype, model_configs, model):
392+
"""Test DotProductAttention with FA4: base configs, extended head dims, GQA, num_splits"""
393+
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
394+
395+
396+
model_configs_fa4_mla = {
397+
# test: ModelConfig(b, sq, hq, dqk, head_dim_v=dv)
398+
"fa4_mla_1": ModelConfig(4, 128, 16, 128, head_dim_v=64),
399+
"fa4_mla_2": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),
400+
"fa4_mla_3": ModelConfig(2, 1024, 16, 96, head_dim_v=64, attn_mask_type="causal"),
401+
# dqk=128, dv=96: FA4 SM100 backward has dK_reduce_ncol misalignment for dV;
402+
# the backend filter should reject FA4 and fall back to another backend.
403+
"fa4_mla_4": ModelConfig(2, 1024, 16, 128, head_dim_v=96, attn_mask_type="causal"),
404+
# DeepSeek-style MLA: dqk=192, dv=128 (supported on SM100 as special case)
405+
"fa4_mla_deepseek": ModelConfig(2, 1024, 16, 192, head_dim_v=128, attn_mask_type="causal"),
406+
}
407+
408+
409+
@pytest.mark.skipif(
410+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
411+
)
412+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
413+
@pytest.mark.parametrize("dtype", param_types_lean)
414+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_mla])
415+
@pytest.mark.parametrize("model", model_configs_fa4_mla.keys())
416+
def test_dpa_fa4_mla(dtype, model_configs, model):
417+
"""Test DotProductAttention with FA4: MLA (head_dim_qk != head_dim_v)"""
418+
test_dot_product_attention(
419+
dtype, model_configs, model, False, True, "bshd_bshd_bshd", False, False
420+
)
421+
422+
423+
model_configs_fa4_swa = {
424+
# test: ModelConfig(b, sq, hq, dqk, window_size=(left, right))
425+
"fa4_swa_1": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal", window_size=(128, 0)),
426+
"fa4_swa_2": ModelConfig(2, 2048, 24, 64, attn_mask_type="causal", window_size=(64, 0)),
427+
"fa4_swa_3": ModelConfig(
428+
2, 2048, 16, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(256, 0)
429+
),
430+
"fa4_swa_4": ModelConfig(
431+
2, 2048, 16, 128, attn_mask_type="padding_causal", window_size=(128, 0)
432+
),
433+
}
434+
435+
436+
@pytest.mark.skipif(
437+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
438+
)
439+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
440+
@pytest.mark.parametrize("dtype", param_types_lean)
441+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_swa])
442+
@pytest.mark.parametrize("model", model_configs_fa4_swa.keys())
443+
@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "bshd_bshd_bshd"])
444+
def test_dpa_fa4_sliding_window(dtype, model_configs, model, qkv_layout):
445+
"""Test DotProductAttention with FA4: sliding window attention"""
446+
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, True, False)
447+
448+
449+
model_configs_fa4_varlen = {
450+
# test: ModelConfig(b, sq, hq, dqk)
451+
"fa4_varlen_1": ModelConfig(4, 128, 16, 64, attn_mask_type="padding"),
452+
"fa4_varlen_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="padding_causal"),
453+
"fa4_varlen_3": ModelConfig(
454+
2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"
455+
),
456+
"fa4_varlen_4": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"),
457+
}
458+
459+
460+
@pytest.mark.skipif(
461+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
462+
)
463+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
464+
@pytest.mark.parametrize("dtype", param_types_lean)
465+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_varlen])
466+
@pytest.mark.parametrize("model", model_configs_fa4_varlen.keys())
467+
@pytest.mark.parametrize("qkv_layout", ["thd_thd_thd", "bshd_bshd_bshd"])
468+
def test_dpa_fa4_varlen(dtype, model_configs, model, qkv_layout):
469+
"""Test DotProductAttention with FA4: variable-length sequences (varlen/thd)"""
470+
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
471+
472+
473+
model_configs_fa4_mask = {
474+
# test: ModelConfig(b, sq, hq, dqk)
475+
"fa4_mask_no_mask": ModelConfig(2, 1024, 16, 128),
476+
"fa4_mask_causal": ModelConfig(2, 1024, 16, 128, attn_mask_type="causal"),
477+
"fa4_mask_causal_br": ModelConfig(2, 1024, 16, 128, attn_mask_type="causal_bottom_right"),
478+
"fa4_mask_padding": ModelConfig(2, 1024, 16, 128, attn_mask_type="padding"),
479+
"fa4_mask_padding_causal": ModelConfig(2, 1024, 16, 128, attn_mask_type="padding_causal"),
480+
"fa4_mask_padding_causal_br": ModelConfig(
481+
2, 1024, 16, 128, attn_mask_type="padding_causal_bottom_right"
482+
),
483+
}
484+
485+
486+
@pytest.mark.skipif(
487+
not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
488+
)
489+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
490+
@pytest.mark.parametrize("dtype", param_types_lean)
491+
@pytest.mark.parametrize("model_configs", [model_configs_fa4_mask])
492+
@pytest.mark.parametrize("model", model_configs_fa4_mask.keys())
493+
def test_dpa_fa4_mask(dtype, model_configs, model):
494+
"""Test DotProductAttention with FA4: various attention mask types"""
495+
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
496+
497+
365498
model_configs_softmax = {
366499
# test: ModelConfig(b, sq, hq, dqk)
367500
"softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),

transformer_engine/pytorch/attention/dot_product_attention/backends.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,19 @@
8080
from transformer_engine.pytorch.export import is_in_onnx_export_mode
8181
from transformer_engine.pytorch.graph import is_graph_capturing
8282

83+
# Global vars for flash attn v2
84+
flash_attn_cuda_bwd = None
85+
flash_attn_func = None
86+
flash_attn_varlen_func = None
87+
_flash_attn_fwd = None
88+
_flash_attn_bwd = None
89+
_flash_attn_varlen_fwd = None
90+
_flash_attn_varlen_bwd = None
91+
8392
# Try to import Flash Attention v2
8493
try:
8594
fa_utils.version = PkgVersion(get_pkg_version("flash-attn"))
8695
except PackageNotFoundError:
87-
flash_attn_cuda_bwd = None
88-
flash_attn_func = None
89-
flash_attn_varlen_func = None
90-
_flash_attn_fwd = None
91-
_flash_attn_bwd = None
92-
_flash_attn_varlen_fwd = None
93-
_flash_attn_varlen_bwd = None
9496
pass # only print warning if use_flash_attention_2 = True in get_attention_backend
9597
else:
9698
if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0):
@@ -156,20 +158,21 @@
156158

157159
# Try to import Flash Attention v4
158160
try:
159-
fa_utils.fa4_version = PkgVersion(get_pkg_version("flash-attn-cute"))
161+
fa_utils.fa4_version = PkgVersion(get_pkg_version("flash-attn-4"))
160162
except PackageNotFoundError:
161163
flash_attn_func_v4 = None
162164
flash_attn_varlen_func_v4 = None
163-
flash_attn_with_kvcache_v4 = None
165+
# flash_attn_combine_v4 = None
164166
_flash_attn_fwd_v4 = None
165167
_flash_attn_bwd_v4 = None
166-
# pass # only print warning if use_flash_attention_4 = True in get_attention_backend
167168
else:
168-
from flash_attn.cute.interface import flash_attn_func as flash_attn_func_v4
169-
from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_v4
170-
from flash_attn.cute.interface import _flash_attn_fwd as _flash_attn_fwd_v4
171-
from flash_attn.cute.interface import _flash_attn_bwd as _flash_attn_bwd_v4
172-
# flash_attn_with_kvcache_v4 = None # FA4 does not support kvcache yet
169+
from flash_attn.cute.interface import ( # pylint: disable=ungrouped-imports
170+
flash_attn_func as flash_attn_func_v4,
171+
flash_attn_varlen_func as flash_attn_varlen_func_v4,
172+
_flash_attn_fwd as _flash_attn_fwd_v4,
173+
_flash_attn_bwd as _flash_attn_bwd_v4,
174+
)
175+
173176
fa_utils.set_flash_attention_4_params()
174177

175178
# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16
@@ -938,12 +941,14 @@ def forward(
938941
batch_size * context_len,
939942
)
940943

941-
use_flash_attn_3 = False
942-
if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
943-
use_flash_attn_3 = True
944944
use_flash_attn_4 = False
945-
if flash_attention_backend is not None and str(flash_attention_backend).endswith("cute"):
945+
if flash_attention_backend is not None and flash_attention_backend > PkgVersion("4.0.0b"):
946946
use_flash_attn_4 = True
947+
use_flash_attn_3 = False
948+
if flash_attention_backend is not None and PkgVersion(
949+
"3.0.0b"
950+
) < flash_attention_backend < PkgVersion("4.0.0"):
951+
use_flash_attn_3 = True
947952
if context_parallel and all(
948953
not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
949954
):
@@ -996,6 +1001,9 @@ def forward(
9961001
# | | thd + padding
9971002
# | flash_attn_with_kvcache | KV cache (not-paged/paged), i.e.
9981003
# | | bshd/sbhd/thd + padding
1004+
# FA v4 | flash_attn_func | bshd/sbhd + not padding
1005+
# | flash_attn_varlen_func | bshd/sbhd + padding
1006+
# | | thd + padding
9991007
fa_optional_forward_args_thd = []
10001008
if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
10011009
func = None
@@ -1006,24 +1014,31 @@ def forward(
10061014
else:
10071015
func = flash_attn_func
10081016
else:
1009-
if not use_flash_attn_3:
1017+
if use_flash_attn_4:
1018+
func = flash_attn_varlen_func_v4
1019+
elif not use_flash_attn_3:
10101020
func = flash_attn_varlen_func
10111021
elif inference_params is None:
10121022
func = flash_attn_varlen_func_v3 # pylint: disable=possibly-used-before-assignment
10131023
else:
10141024
func = flash_attn_with_kvcache_v3 # pylint: disable=possibly-used-before-assignment
1015-
if not use_flash_attn_3 or inference_params is None:
1025+
if not use_flash_attn_4 and (not use_flash_attn_3 or inference_params is None):
10161026
fa_optional_forward_args_thd.append(cu_seqlens_q)
10171027
fa_optional_forward_args_thd.append(cu_seqlens_kv)
10181028
fa_optional_forward_args_thd.append(max_seqlen_q)
10191029
fa_optional_forward_args_thd.append(max_seqlen_kv)
10201030
if use_flash_attn_4:
10211031
fa_4_optional_forward_kwargs = {
1022-
# "window_size": window_size,
1032+
"window_size": window_size,
10231033
"num_splits": num_splits,
10241034
}
10251035
if inference_params is None:
10261036
fa_4_optional_forward_kwargs["deterministic"] = self.deterministic
1037+
if func is flash_attn_varlen_func_v4:
1038+
fa_4_optional_forward_kwargs["cu_seqlens_q"] = cu_seqlens_q
1039+
fa_4_optional_forward_kwargs["cu_seqlens_k"] = cu_seqlens_kv
1040+
fa_4_optional_forward_kwargs["max_seqlen_q"] = max_seqlen_q
1041+
fa_4_optional_forward_kwargs["max_seqlen_k"] = max_seqlen_kv
10271042
output = func(
10281043
query_layer,
10291044
key_layer,

0 commit comments

Comments
 (0)