Skip to content

Commit 38f62a0

Browse files
committed
Initial effort at adding Flash Attention v4 (FA4) support
Signed-off-by: Xin Yao <xiny@nvidia.com>
1 parent b7598aa commit 38f62a0

2 files changed

Lines changed: 225 additions & 84 deletions

File tree

transformer_engine/pytorch/attention/dot_product_attention/backends.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,17 @@
8080
from transformer_engine.pytorch.export import is_in_onnx_export_mode
8181
from transformer_engine.pytorch.graph import is_graph_capturing
8282

83-
# Global vars for flash attn v2 and v3 imports
84-
flash_attn_cuda_bwd = None
85-
flash_attn_func = None
86-
flash_attn_varlen_func = None
87-
_flash_attn_fwd = None
88-
_flash_attn_bwd = None
89-
_flash_attn_varlen_fwd = None
90-
_flash_attn_varlen_bwd = None
83+
# Try to import Flash Attention v2
9184
try:
9285
fa_utils.version = PkgVersion(get_pkg_version("flash-attn"))
9386
except PackageNotFoundError:
87+
flash_attn_cuda_bwd = None
88+
flash_attn_func = None
89+
flash_attn_varlen_func = None
90+
_flash_attn_fwd = None
91+
_flash_attn_bwd = None
92+
_flash_attn_varlen_fwd = None
93+
_flash_attn_varlen_bwd = None
9494
pass # only print warning if use_flash_attention_2 = True in get_attention_backend
9595
else:
9696
if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0):
@@ -130,12 +130,16 @@
130130
),
131131
fa_utils.version,
132132
)
133+
134+
# Try to import Flash Attention v3
133135
try:
134136
fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3"))
135137
except PackageNotFoundError:
136138
flash_attn_func_v3 = None
137139
flash_attn_varlen_func_v3 = None
138140
flash_attn_with_kvcache_v3 = None
141+
_flash_attn_fwd_v3 = None
142+
_flash_attn_bwd_v3 = None
139143
# pass # only print warning if use_flash_attention_3 = True in get_attention_backend
140144
else:
141145
from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
@@ -150,6 +154,24 @@
150154

151155
fa_utils.set_flash_attention_3_params()
152156

157+
# Try to import Flash Attention v4
158+
try:
159+
fa_utils.fa4_version = PkgVersion(get_pkg_version("flash-attn-cute"))
160+
except PackageNotFoundError:
161+
flash_attn_func_v4 = None
162+
flash_attn_varlen_func_v4 = None
163+
flash_attn_with_kvcache_v4 = None
164+
_flash_attn_fwd_v4 = None
165+
_flash_attn_bwd_v4 = None
166+
# pass # only print warning if use_flash_attention_4 = True in get_attention_backend
167+
else:
168+
from flash_attn.cute.interface import flash_attn_func as flash_attn_func_v4
169+
from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_v4
170+
from flash_attn.cute.interface import _flash_attn_fwd as _flash_attn_fwd_v4
171+
from flash_attn.cute.interface import _flash_attn_bwd as _flash_attn_bwd_v4
172+
# flash_attn_with_kvcache_v4 = None # FA4 does not support kvcache yet
173+
fa_utils.set_flash_attention_4_params()
174+
153175
# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16
154176
_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1"
155177

@@ -919,6 +941,9 @@ def forward(
919941
use_flash_attn_3 = False
920942
if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
921943
use_flash_attn_3 = True
944+
use_flash_attn_4 = False
945+
if flash_attention_backend is not None and str(flash_attention_backend).endswith("cute"):
946+
use_flash_attn_4 = True
922947
if context_parallel and all(
923948
not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
924949
):
@@ -973,9 +998,13 @@ def forward(
973998
# | | bshd/sbhd/thd + padding
974999
fa_optional_forward_args_thd = []
9751000
if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
976-
func = (
977-
flash_attn_func if not use_flash_attn_3 else flash_attn_func_v3
978-
) # pylint: disable=possibly-used-before-assignment
1001+
func = None
1002+
if use_flash_attn_4:
1003+
func = flash_attn_func_v4
1004+
elif use_flash_attn_3:
1005+
func = flash_attn_func_v3
1006+
else:
1007+
func = flash_attn_func
9791008
else:
9801009
if not use_flash_attn_3:
9811010
func = flash_attn_varlen_func
@@ -988,7 +1017,24 @@ def forward(
9881017
fa_optional_forward_args_thd.append(cu_seqlens_kv)
9891018
fa_optional_forward_args_thd.append(max_seqlen_q)
9901019
fa_optional_forward_args_thd.append(max_seqlen_kv)
991-
if not use_flash_attn_3:
1020+
if use_flash_attn_4:
1021+
fa_4_optional_forward_kwargs = {
1022+
# "window_size": window_size,
1023+
"num_splits": num_splits,
1024+
}
1025+
if inference_params is None:
1026+
fa_4_optional_forward_kwargs["deterministic"] = self.deterministic
1027+
output = func(
1028+
query_layer,
1029+
key_layer,
1030+
value_layer,
1031+
softmax_scale=self.softmax_scale,
1032+
causal="causal" in attn_mask_type,
1033+
**fa_4_optional_forward_kwargs,
1034+
)
1035+
if isinstance(output, (List, Tuple)):
1036+
output = output[0]
1037+
elif not use_flash_attn_3:
9921038
fa_optional_forward_kwargs = {}
9931039
if fa_utils.v2_3_plus:
9941040
fa_optional_forward_kwargs["window_size"] = window_size

0 commit comments

Comments
 (0)