@@ -172,7 +172,7 @@ class AttentionBackendName(str, Enum):
172172 _FLASH_3 = "_flash_3"
173173 _FLASH_VARLEN_3 = "_flash_varlen_3"
174174 _FLASH_3_HUB = "_flash_3_hub"
175- _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub "
175+ _FLASH_3_VARLEN_HUB = "_flash_3_varlen_hub "
176176
177177 # `aiter`
178178 AITER = "aiter"
@@ -264,10 +264,10 @@ class _HubKernelConfig:
264264 AttentionBackendName ._FLASH_3_HUB : _HubKernelConfig (
265265 repo_id = "kernels-community/flash-attn3" , function_attr = "flash_attn_func" , revision = "fake-ops-return-probs"
266266 ),
267- AttentionBackendName ._FLASH_VARLEN_3_HUB : _HubKernelConfig (
267+ AttentionBackendName ._FLASH_3_VARLEN_HUB : _HubKernelConfig (
268268 repo_id = "kernels-community/flash-attn3" ,
269269 function_attr = "flash_attn_varlen_func" ,
270- revision = "fake-ops-return-probs" ,
270+ # revision="fake-ops-return-probs",
271271 ),
272272 AttentionBackendName .FLASH_HUB : _HubKernelConfig (
273273 repo_id = "kernels-community/flash-attn2" , function_attr = "flash_attn_func" , revision = None
@@ -438,7 +438,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
438438 AttentionBackendName .FLASH_HUB ,
439439 AttentionBackendName .FLASH_VARLEN_HUB ,
440440 AttentionBackendName ._FLASH_3_HUB ,
441- AttentionBackendName ._FLASH_VARLEN_3_HUB ,
441+ AttentionBackendName ._FLASH_3_VARLEN_HUB ,
442442 AttentionBackendName .SAGE_HUB ,
443443 ]:
444444 if not is_kernels_available ():
@@ -1552,7 +1552,6 @@ def _flash_attention_3(
15521552 softmax_scale = scale ,
15531553 causal = is_causal ,
15541554 )
1555- out = _maybe_unflatten_attention_heads (out , query )
15561555 return (out , lse ) if return_lse else out
15571556
15581557
@@ -1597,17 +1596,15 @@ def _flash_attention_3_hub(
15971596 )
15981597 # When `return_attn_probs` is True, the above returns a tuple of
15991598 # actual outputs and lse.
1600- if return_attn_probs :
1601- return (_maybe_unflatten_attention_heads (out [0 ], query ), out [1 ])
1602- return _maybe_unflatten_attention_heads (out , query )
1599+ return (out [0 ], out [1 ]) if return_attn_probs else out
16031600
16041601
16051602@_AttentionBackendRegistry .register (
1606- AttentionBackendName ._FLASH_VARLEN_3_HUB ,
1603+ AttentionBackendName ._FLASH_3_VARLEN_HUB ,
16071604 constraints = [_check_device , _check_qkv_dtype_bf16_or_fp16 , _check_shape ],
16081605 supports_context_parallel = False ,
16091606)
1610- def _flash_varlen_attention_3_hub (
1607+ def _flash_attention_3_varlen_hub (
16111608 query : torch .Tensor ,
16121609 key : torch .Tensor ,
16131610 value : torch .Tensor ,
@@ -1639,7 +1636,7 @@ def _flash_varlen_attention_3_hub(
16391636 key_packed = torch .cat (key_valid , dim = 0 )
16401637 value_packed = torch .cat (value_valid , dim = 0 )
16411638
1642- func = _HUB_KERNELS_REGISTRY [AttentionBackendName ._FLASH_VARLEN_3_HUB ].kernel_fn
1639+ func = _HUB_KERNELS_REGISTRY [AttentionBackendName ._FLASH_3_VARLEN_HUB ].kernel_fn
16431640 out , lse , * _ = func (
16441641 q = query_packed ,
16451642 k = key_packed ,
@@ -1652,7 +1649,6 @@ def _flash_varlen_attention_3_hub(
16521649 causal = is_causal ,
16531650 )
16541651 out = out .unflatten (0 , (batch_size , - 1 ))
1655- out = _maybe_unflatten_attention_heads (out , query )
16561652
16571653 return (out , lse ) if return_lse else out
16581654
0 commit comments