@@ -229,6 +229,7 @@ class AttentionBackendName(str, Enum):
     FLASH_HUB = "flash_hub"
     FLASH_VARLEN = "flash_varlen"
     FLASH_VARLEN_HUB = "flash_varlen_hub"
+    FLASH_4_HUB = "flash_4_hub"
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
@@ -358,6 +359,11 @@ class _HubKernelConfig:
         function_attr="sageattn",
         version=1,
     ),
+    AttentionBackendName.FLASH_4_HUB: _HubKernelConfig(
+        repo_id="kernels-staging/flash-attn4",
+        function_attr="flash_attn_func",
+        version=0,
+    ),
 }
 
 
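For context, the new registry entry above points the FLASH_4_HUB backend at a kernel repository on the Hugging Face Hub. The sketch below shows roughly what resolving that entry amounts to with the `kernels` package; the repo id and attribute name come from the diff itself, while the plain `get_kernel` call is an assumption about how such a hub kernel is typically fetched (the actual diffusers loading path, including any version pinning, may differ).

from kernels import get_kernel

# Fetch the kernel repo named in the FLASH_4_HUB registry entry (assumption:
# unpinned `get_kernel`; diffusers may pin a specific revision/version).
flash_attn4_module = get_kernel("kernels-staging/flash-attn4")

# `function_attr` tells the registry which callable to pull off the loaded module.
flash_attn_func = getattr(flash_attn4_module, "flash_attn_func")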
@@ -521,6 +527,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
         AttentionBackendName._FLASH_3_HUB,
         AttentionBackendName._FLASH_3_VARLEN_HUB,
         AttentionBackendName.SAGE_HUB,
+        AttentionBackendName.FLASH_4_HUB,
     ]:
         if not is_kernels_available():
             raise RuntimeError(
@@ -531,6 +538,11 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                 f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
             )
 
+        if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
+            raise RuntimeError(
+                f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
+            )
+
     elif backend == AttentionBackendName.AITER:
         if not _CAN_USE_AITER_ATTN:
             raise RuntimeError(
@@ -2676,6 +2688,37 @@ def _flash_attention_3_varlen_hub(
     return (out, lse) if return_lse else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_4_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _flash_attention_4_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: torch.Tensor | None = None,
+    scale: float | None = None,
+    is_causal: bool = False,
+    return_lse: bool = False,
+    _parallel_config: "ParallelConfig" | None = None,
+) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 4.")
+
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn
+    out = func(
+        q=query,
+        k=key,
+        v=value,
+        softmax_scale=scale,
+        causal=is_causal,
+    )
+    if isinstance(out, tuple):
+        return (out[0], out[1]) if return_lse else out[0]
+    return out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_VARLEN_3,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
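Below is a hypothetical end-to-end usage sketch. It assumes the existing backend-selection API on diffusers models (`set_attention_backend`) accepts the new "flash_4_hub" string once this change lands, that `kernels>=0.12.3` is installed, and that the GPU is supported by the flash-attn 4 kernel; the pipeline class and checkpoint are placeholders, not part of this diff.

import torch
from diffusers import FluxPipeline

# Placeholder pipeline/checkpoint; any model whose attention goes through the
# dispatcher should be selectable the same way.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Assumption: the enum's string value is what the selection API expects,
# mirroring the existing hub backends such as "_flash_3_hub".
pipe.transformer.set_attention_backend("flash_4_hub")

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
image.save("flash4_sample.png")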