3838 is_flash_attn_available ,
3939 is_flash_attn_version ,
4040 is_kernels_available ,
41+ is_kernels_version ,
4142 is_sageattention_available ,
4243 is_sageattention_version ,
4344 is_torch_npu_available ,
@@ -318,6 +319,7 @@ class _HubKernelConfig:
318319 repo_id : str
319320 function_attr : str
320321 revision : str | None = None
322+ version : int | None = None
321323 kernel_fn : Callable | None = None
322324 wrapped_forward_attr : str | None = None
323325 wrapped_backward_attr : str | None = None
@@ -327,31 +329,34 @@ class _HubKernelConfig:
327329
328330# Registry for hub-based attention kernels
# Maps each hub-backed `AttentionBackendName` member to the `_HubKernelConfig`
# that describes its kernel:
#   * `repo_id`       - passed as the first argument to `kernels.get_kernel`.
#   * `function_attr` - attribute looked up on the loaded kernel module via
#                       `_resolve_kernel_attr` and cached as `kernel_fn`.
#   * `wrapped_forward_attr` / `wrapped_backward_attr` - dotted attribute paths;
#                       presumably resolved on the same kernel module for the
#                       forward/backward ops (their use is not visible here -
#                       TODO confirm).
#   * `version`       - forwarded to `get_kernel(..., version=...)`.
# In this diff view, rows marked `-` are the removed side and rows marked `+`
# the added side: the per-entry `revision` pins are dropped, every entry gains
# `version = 1`, and the sage repo id changes from `sage_attention` to
# `sage-attention`.
# NOTE(review): `get_kernel` is still called with `revision=config.revision`
# as well as `version=config.version`; all entries below leave `revision` at
# its `None` default. Confirm `get_kernel` accepts `revision=None` together
# with a version, and that no caller sets both.
329331_HUB_KERNELS_REGISTRY : dict ["AttentionBackendName" , _HubKernelConfig ] = {
330- # TODO: temporary revision for now. Remove when merged upstream into `main`.
331332 AttentionBackendName ._FLASH_3_HUB : _HubKernelConfig (
332333 repo_id = "kernels-community/flash-attn3" ,
333334 function_attr = "flash_attn_func" ,
334- revision = "fake-ops-return-probs" ,
335335 wrapped_forward_attr = "flash_attn_interface._flash_attn_forward" ,
336336 wrapped_backward_attr = "flash_attn_interface._flash_attn_backward" ,
337+ version = 1 ,
337338 ),
338339 AttentionBackendName ._FLASH_3_VARLEN_HUB : _HubKernelConfig (
339340 repo_id = "kernels-community/flash-attn3" ,
340341 function_attr = "flash_attn_varlen_func" ,
341- # revision="fake-ops-return-probs" ,
342+ version = 1 ,
342343 ),
343344 AttentionBackendName .FLASH_HUB : _HubKernelConfig (
344345 repo_id = "kernels-community/flash-attn2" ,
345346 function_attr = "flash_attn_func" ,
346- revision = None ,
347347 wrapped_forward_attr = "flash_attn_interface._wrapped_flash_attn_forward" ,
348348 wrapped_backward_attr = "flash_attn_interface._wrapped_flash_attn_backward" ,
349+ version = 1 ,
349350 ),
350351 AttentionBackendName .FLASH_VARLEN_HUB : _HubKernelConfig (
351- repo_id = "kernels-community/flash-attn2" , function_attr = "flash_attn_varlen_func" , revision = None
352+ repo_id = "kernels-community/flash-attn2" ,
353+ function_attr = "flash_attn_varlen_func" ,
354+ version = 1 ,
352355 ),
353356 AttentionBackendName .SAGE_HUB : _HubKernelConfig (
354- repo_id = "kernels-community/sage_attention" , function_attr = "sageattn" , revision = None
357+ repo_id = "kernels-community/sage-attention" ,
358+ function_attr = "sageattn" ,
359+ version = 1 ,
355360 ),
356361}
357362
@@ -521,6 +526,10 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
521526 raise RuntimeError (
522527 f"Backend '{ backend .value } ' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
523528 )
529+ if not is_kernels_version (">=" , "0.12" ):
530+ raise RuntimeError (
531+ f"Backend '{ backend .value } ' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
532+ )
524533
525534 elif backend == AttentionBackendName .AITER :
526535 if not _CAN_USE_AITER_ATTN :
@@ -694,7 +703,7 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
694703 try :
695704 from kernels import get_kernel
696705
697- kernel_module = get_kernel (config .repo_id , revision = config .revision )
706+ kernel_module = get_kernel (config .repo_id , revision = config .revision , version = config . version )
698707 if needs_kernel :
699708 config .kernel_fn = _resolve_kernel_attr (kernel_module , config .function_attr )
700709
0 commit comments