8080from transformer_engine .pytorch .export import is_in_onnx_export_mode
8181from transformer_engine .pytorch .graph import is_graph_capturing
8282
83- # Global vars for flash attn v2 and v3 imports
84- flash_attn_cuda_bwd = None
85- flash_attn_func = None
86- flash_attn_varlen_func = None
87- _flash_attn_fwd = None
88- _flash_attn_bwd = None
89- _flash_attn_varlen_fwd = None
90- _flash_attn_varlen_bwd = None
83+ # Try to import Flash Attention v2
9184try :
9285 fa_utils .version = PkgVersion (get_pkg_version ("flash-attn" ))
9386except PackageNotFoundError :
87+ flash_attn_cuda_bwd = None
88+ flash_attn_func = None
89+ flash_attn_varlen_func = None
90+ _flash_attn_fwd = None
91+ _flash_attn_bwd = None
92+ _flash_attn_varlen_fwd = None
93+ _flash_attn_varlen_bwd = None
9494 pass # only print warning if use_flash_attention_2 = True in get_attention_backend
9595else :
9696 if torch .cuda .is_available () and get_device_compute_capability () >= (10 , 0 ):
130130 ),
131131 fa_utils .version ,
132132 )
133+
134+ # Try to import Flash Attention v3
133135try :
134136 fa_utils .fa3_version = PkgVersion (get_pkg_version ("flash-attn-3" ))
135137except PackageNotFoundError :
136138 flash_attn_func_v3 = None
137139 flash_attn_varlen_func_v3 = None
138140 flash_attn_with_kvcache_v3 = None
141+ _flash_attn_fwd_v3 = None
142+ _flash_attn_bwd_v3 = None
139143 # pass # only print warning if use_flash_attention_3 = True in get_attention_backend
140144else :
141145 from flash_attn_3 .flash_attn_interface import flash_attn_func as flash_attn_func_v3
150154
151155 fa_utils .set_flash_attention_3_params ()
152156
157+ # Try to import Flash Attention v4
158+ try :
159+ fa_utils .fa4_version = PkgVersion (get_pkg_version ("flash-attn-cute" ))
160+ except PackageNotFoundError :
161+ flash_attn_func_v4 = None
162+ flash_attn_varlen_func_v4 = None
163+ flash_attn_with_kvcache_v4 = None
164+ _flash_attn_fwd_v4 = None
165+ _flash_attn_bwd_v4 = None
166+ # pass # only print warning if use_flash_attention_4 = True in get_attention_backend
167+ else :
168+ from flash_attn .cute .interface import flash_attn_func as flash_attn_func_v4
169+ from flash_attn .cute .interface import flash_attn_varlen_func as flash_attn_varlen_func_v4
170+ from flash_attn .cute .interface import _flash_attn_fwd as _flash_attn_fwd_v4
171+ from flash_attn .cute .interface import _flash_attn_bwd as _flash_attn_bwd_v4
172+ # flash_attn_with_kvcache_v4 = None # FA4 does not support kvcache yet
173+ fa_utils .set_flash_attention_4_params ()
174+
153175# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16
154176_dpa_fp8_cs_o_in_f16 = os .getenv ("NVTE_DPA_FP8CS_O_in_F16" , "1" ) == "1"
155177
@@ -859,6 +881,9 @@ def forward(
859881 use_flash_attn_3 = False
860882 if flash_attention_backend is not None and flash_attention_backend > PkgVersion ("3.0.0b" ):
861883 use_flash_attn_3 = True
884+ use_flash_attn_4 = False
885+ if flash_attention_backend is not None and str (flash_attention_backend ).endswith ("cute" ):
886+ use_flash_attn_4 = True
862887 if context_parallel and all (
863888 not isinstance (x , Float8Tensor ) for x in [query_layer , key_layer , value_layer ]
864889 ):
@@ -913,9 +938,13 @@ def forward(
913938 # | | bshd/sbhd/thd + padding
914939 fa_optional_forward_args_thd = []
915940 if qkv_format in ["bshd" , "sbhd" ] and "padding" not in attn_mask_type :
916- func = (
917- flash_attn_func if not use_flash_attn_3 else flash_attn_func_v3
918- ) # pylint: disable=possibly-used-before-assignment
941+ func = None
942+ if use_flash_attn_4 :
943+ func = flash_attn_func_v4
944+ elif use_flash_attn_3 :
945+ func = flash_attn_func_v3
946+ else :
947+ func = flash_attn_func
919948 else :
920949 if not use_flash_attn_3 :
921950 func = flash_attn_varlen_func
@@ -928,7 +957,24 @@ def forward(
928957 fa_optional_forward_args_thd .append (cu_seqlens_kv )
929958 fa_optional_forward_args_thd .append (max_seqlen_q )
930959 fa_optional_forward_args_thd .append (max_seqlen_kv )
931- if not use_flash_attn_3 :
960+ if use_flash_attn_4 :
961+ fa_4_optional_forward_kwargs = {
962+ # "window_size": window_size,
963+ "num_splits" : num_splits ,
964+ }
965+ if inference_params is None :
966+ fa_4_optional_forward_kwargs ["deterministic" ] = self .deterministic
967+ output = func (
968+ query_layer ,
969+ key_layer ,
970+ value_layer ,
971+ softmax_scale = self .softmax_scale ,
972+ causal = "causal" in attn_mask_type ,
973+ ** fa_4_optional_forward_kwargs ,
974+ )
975+ if isinstance (output , (List , Tuple )):
976+ output = output [0 ]
977+ elif not use_flash_attn_3 :
932978 fa_optional_forward_kwargs = {}
933979 if fa_utils .v2_3_plus :
934980 fa_optional_forward_kwargs ["window_size" ] = window_size
0 commit comments