1616import functools
1717import inspect
1818import math
19+ from dataclasses import dataclass
1920from enum import Enum
20- from typing import TYPE_CHECKING , Any , Callable , Dict , List , Literal , Optional , Set , Tuple , Union
21+ from typing import TYPE_CHECKING , Any , Callable , Dict , List , Literal , Optional , Tuple , Union
2122
2223import torch
2324
4041 is_xformers_available ,
4142 is_xformers_version ,
4243)
43- from ..utils .constants import DIFFUSERS_ATTN_BACKEND , DIFFUSERS_ATTN_CHECKS , DIFFUSERS_ENABLE_HUB_KERNELS
44+ from ..utils .constants import DIFFUSERS_ATTN_BACKEND , DIFFUSERS_ATTN_CHECKS
4445
4546
4647if TYPE_CHECKING :
7879 flash_attn_3_func = None
7980 flash_attn_3_varlen_func = None
8081
81- _BACKEND_HANDLES : Dict ["AttentionBackendName" , Callable ] = {}
82- _PREPARED_BACKENDS : Set ["AttentionBackendName" ] = set ()
83-
8482if _CAN_USE_SAGE_ATTN :
8583 from sageattention import (
8684 sageattn ,
@@ -222,9 +220,7 @@ def decorator(func):
222220
@classmethod
def get_active_backend(cls):
    """Return the active backend name and the implementation registered for it.

    Returns:
        A 2-tuple `(backend, implementation)` where `backend` is the value of
        `cls._active_backend` and `implementation` is the entry stored under
        that key in `cls._backends`.
    """
    active = cls._active_backend
    return active, cls._backends[active]
228224
229225 @classmethod
230226 def list_backends (cls ):
@@ -242,6 +238,25 @@ def _is_context_parallel_enabled(
242238 return supports_context_parallel and is_degree_greater_than_1
243239
244240
241+ @dataclass
242+ class _HubKernelConfig :
243+ """Configuration for downloading and using a hub-based attention kernel."""
244+
245+ repo_id : str
246+ function_attr : str
247+ revision : Optional [str ] = None
248+ kernel_fn : Optional [Callable ] = None
249+
250+
# Hub-based attention kernels, keyed by the attention backend that uses them.
_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
    # TODO: the pinned revision is temporary. Remove it once the change is merged upstream into `main`.
    AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
        repo_id="kernels-community/flash-attn3",
        function_attr="flash_attn_func",
        revision="fake-ops-return-probs",
    ),
}
258+
259+
245260@contextlib .contextmanager
246261def attention_backend (backend : Union [str , AttentionBackendName ] = AttentionBackendName .NATIVE ):
247262 """
@@ -251,7 +266,7 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
251266 raise ValueError (f"Backend { backend } is not registered." )
252267
253268 backend = AttentionBackendName (backend )
254- _ensure_attention_backend_ready (backend )
269+ _check_attention_backend_requirements (backend )
255270
256271 old_backend = _AttentionBackendRegistry ._active_backend
257272 _AttentionBackendRegistry ._active_backend = backend
@@ -398,13 +413,9 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
398413
399414 # TODO: add support Hub variant of FA3 varlen later
400415 elif backend in [AttentionBackendName ._FLASH_3_HUB ]:
401- if not DIFFUSERS_ENABLE_HUB_KERNELS :
402- raise RuntimeError (
403- f"Flash Attention 3 Hub backend '{ backend .value } ' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
404- )
405416 if not is_kernels_available ():
406417 raise RuntimeError (
407- f"Flash Attention 3 Hub backend '{ backend .value } ' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
418+ f"Backend '{ backend .value } ' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
408419 )
409420
410421 elif backend in [
@@ -445,39 +456,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
445456 )
446457
447458
448- def _ensure_flash_attn_3_func_hub_loaded ():
449- cached = _BACKEND_HANDLES .get (AttentionBackendName ._FLASH_3_HUB )
450- if cached is not None :
451- return cached
452-
453- from ..utils .kernels_utils import _get_fa3_from_hub
454-
455- flash_attn_interface_hub = _get_fa3_from_hub ()
456- func = flash_attn_interface_hub .flash_attn_func
457- _BACKEND_HANDLES [AttentionBackendName ._FLASH_3_HUB ] = func
458-
459- return func
460-
461-
462- _BACKEND_PREPARERS : Dict [AttentionBackendName , Callable [[], None ]] = {
463- AttentionBackendName ._FLASH_3_HUB : _ensure_flash_attn_3_func_hub_loaded ,
464- }
465-
466-
467- def _prepare_attention_backend (backend : AttentionBackendName ) -> None :
468- preparer = _BACKEND_PREPARERS .get (backend )
469- if preparer is not None :
470- preparer ()
471-
472-
473- def _ensure_attention_backend_ready (backend : AttentionBackendName ) -> None :
474- if backend in _PREPARED_BACKENDS :
475- return
476- _check_attention_backend_requirements (backend )
477- _prepare_attention_backend (backend )
478- _PREPARED_BACKENDS .add (backend )
479-
480-
481459@functools .lru_cache (maxsize = 128 )
482460def _prepare_for_flash_attn_or_sage_varlen_without_mask (
483461 batch_size : int ,
@@ -581,6 +559,29 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
581559 return q_idx >= kv_idx
582560
583561
562+ # ===== Helpers for downloading kernels =====
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
    """Fetch and cache the Hub kernel for `backend`, if it is a Hub-based backend.

    Does nothing when `backend` has no entry in `_HUB_KERNELS_REGISTRY` or when a
    kernel callable has already been cached on its config.

    Args:
        backend: Attention backend whose kernel should be made available.

    Raises:
        Exception: Any error raised while importing `kernels`, fetching the
            kernel module, or looking up the configured attribute is logged and
            re-raised unchanged.
    """
    config = _HUB_KERNELS_REGISTRY.get(backend)
    if config is None:
        return

    # Read the cache slot defensively: `_kernel_fn` is not one of the declared
    # dataclass fields, so it may be absent until it is first assigned. A plain
    # attribute read would raise AttributeError on the very first call.
    if getattr(config, "_kernel_fn", None) is not None:
        return

    try:
        # Imported lazily so the module does not require `kernels` at import time.
        from kernels import get_kernel

        kernel_module = get_kernel(config.repo_id, revision=config.revision)
        kernel_func = getattr(kernel_module, config.function_attr)
    except Exception as e:
        logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
        raise

    # Cache the downloaded kernel function in the config object
    config._kernel_fn = kernel_func
583+
584+
584585# ===== torch op registrations =====
585586# Registrations are required for fullgraph tracing compatibility
586587# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
@@ -1348,9 +1349,7 @@ def _flash_attention_3_hub(
13481349 return_attn_probs : bool = False ,
13491350 _parallel_config : Optional ["ParallelConfig" ] = None ,
13501351) -> torch .Tensor :
1351- func = _BACKEND_HANDLES .get (AttentionBackendName ._FLASH_3_HUB )
1352- if func is None :
1353- func = _ensure_flash_attn_3_func_hub_loaded ()
1352+ func = _HUB_KERNELS_REGISTRY [AttentionBackendName ._FLASH_3_HUB ]._kernel_fn
13541353 out = func (
13551354 q = query ,
13561355 k = key ,
0 commit comments