Skip to content

Commit db7c09e

Browse files
committed
initial effort of adding fa4
Signed-off-by: Xin Yao <xiny@nvidia.com>
1 parent 3ff0b8d commit db7c09e

2 files changed

Lines changed: 225 additions & 84 deletions

File tree

transformer_engine/pytorch/attention/dot_product_attention/backends.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,17 @@
8080
from transformer_engine.pytorch.export import is_in_onnx_export_mode
8181
from transformer_engine.pytorch.graph import is_graph_capturing
8282

83-
# Global vars for flash attn v2 and v3 imports
84-
flash_attn_cuda_bwd = None
85-
flash_attn_func = None
86-
flash_attn_varlen_func = None
87-
_flash_attn_fwd = None
88-
_flash_attn_bwd = None
89-
_flash_attn_varlen_fwd = None
90-
_flash_attn_varlen_bwd = None
83+
# Try to import Flash Attention v2
9184
try:
9285
fa_utils.version = PkgVersion(get_pkg_version("flash-attn"))
9386
except PackageNotFoundError:
87+
flash_attn_cuda_bwd = None
88+
flash_attn_func = None
89+
flash_attn_varlen_func = None
90+
_flash_attn_fwd = None
91+
_flash_attn_bwd = None
92+
_flash_attn_varlen_fwd = None
93+
_flash_attn_varlen_bwd = None
9494
pass # only print warning if use_flash_attention_2 = True in get_attention_backend
9595
else:
9696
if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0):
@@ -130,12 +130,16 @@
130130
),
131131
fa_utils.version,
132132
)
133+
134+
# Try to import Flash Attention v3
133135
try:
134136
fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3"))
135137
except PackageNotFoundError:
136138
flash_attn_func_v3 = None
137139
flash_attn_varlen_func_v3 = None
138140
flash_attn_with_kvcache_v3 = None
141+
_flash_attn_fwd_v3 = None
142+
_flash_attn_bwd_v3 = None
139143
# pass # only print warning if use_flash_attention_3 = True in get_attention_backend
140144
else:
141145
from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
@@ -150,6 +154,24 @@
150154

151155
fa_utils.set_flash_attention_3_params()
152156

157+
# Try to import Flash Attention v4
158+
try:
159+
fa_utils.fa4_version = PkgVersion(get_pkg_version("flash-attn-cute"))
160+
except PackageNotFoundError:
161+
flash_attn_func_v4 = None
162+
flash_attn_varlen_func_v4 = None
163+
flash_attn_with_kvcache_v4 = None
164+
_flash_attn_fwd_v4 = None
165+
_flash_attn_bwd_v4 = None
166+
# pass # only print warning if use_flash_attention_4 = True in get_attention_backend
167+
else:
168+
from flash_attn.cute.interface import flash_attn_func as flash_attn_func_v4
169+
from flash_attn.cute.interface import flash_attn_varlen_func as flash_attn_varlen_func_v4
170+
from flash_attn.cute.interface import _flash_attn_fwd as _flash_attn_fwd_v4
171+
from flash_attn.cute.interface import _flash_attn_bwd as _flash_attn_bwd_v4
172+
# flash_attn_with_kvcache_v4 = None # FA4 does not support kvcache yet
173+
fa_utils.set_flash_attention_4_params()
174+
153175
# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16
154176
_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1"
155177

@@ -859,6 +881,9 @@ def forward(
859881
use_flash_attn_3 = False
860882
if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"):
861883
use_flash_attn_3 = True
884+
use_flash_attn_4 = False
885+
if flash_attention_backend is not None and str(flash_attention_backend).endswith("cute"):
886+
use_flash_attn_4 = True
862887
if context_parallel and all(
863888
not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]
864889
):
@@ -913,9 +938,13 @@ def forward(
913938
# | | bshd/sbhd/thd + padding
914939
fa_optional_forward_args_thd = []
915940
if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type:
916-
func = (
917-
flash_attn_func if not use_flash_attn_3 else flash_attn_func_v3
918-
) # pylint: disable=possibly-used-before-assignment
941+
func = None
942+
if use_flash_attn_4:
943+
func = flash_attn_func_v4
944+
elif use_flash_attn_3:
945+
func = flash_attn_func_v3
946+
else:
947+
func = flash_attn_func
919948
else:
920949
if not use_flash_attn_3:
921950
func = flash_attn_varlen_func
@@ -928,7 +957,24 @@ def forward(
928957
fa_optional_forward_args_thd.append(cu_seqlens_kv)
929958
fa_optional_forward_args_thd.append(max_seqlen_q)
930959
fa_optional_forward_args_thd.append(max_seqlen_kv)
931-
if not use_flash_attn_3:
960+
if use_flash_attn_4:
961+
fa_4_optional_forward_kwargs = {
962+
# "window_size": window_size,
963+
"num_splits": num_splits,
964+
}
965+
if inference_params is None:
966+
fa_4_optional_forward_kwargs["deterministic"] = self.deterministic
967+
output = func(
968+
query_layer,
969+
key_layer,
970+
value_layer,
971+
softmax_scale=self.softmax_scale,
972+
causal="causal" in attn_mask_type,
973+
**fa_4_optional_forward_kwargs,
974+
)
975+
if isinstance(output, (List, Tuple)):
976+
output = output[0]
977+
elif not use_flash_attn_3:
932978
fa_optional_forward_kwargs = {}
933979
if fa_utils.v2_3_plus:
934980
fa_optional_forward_kwargs["window_size"] = window_size

0 commit comments

Comments
 (0)