Skip to content

Commit f48ec46

Browse files
sayakpaul and Dhruv Nair committed
refactor according to Dhruv's ideas.
Co-authored-by: Dhruv Nair <dhruv@huggingface.co>
1 parent 75a8046 commit f48ec46

4 files changed

Lines changed: 57 additions & 54 deletions

File tree

src/diffusers/models/attention_dispatch.py

Lines changed: 49 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
import functools
1717
import inspect
1818
import math
19+
from dataclasses import dataclass
1920
from enum import Enum
20-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union
21+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
2122

2223
import torch
2324

@@ -40,7 +41,7 @@
4041
is_xformers_available,
4142
is_xformers_version,
4243
)
43-
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
44+
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
4445

4546

4647
if TYPE_CHECKING:
@@ -78,9 +79,6 @@
7879
flash_attn_3_func = None
7980
flash_attn_3_varlen_func = None
8081

81-
_BACKEND_HANDLES: Dict["AttentionBackendName", Callable] = {}
82-
_PREPARED_BACKENDS: Set["AttentionBackendName"] = set()
83-
8482
if _CAN_USE_SAGE_ATTN:
8583
from sageattention import (
8684
sageattn,
@@ -222,9 +220,7 @@ def decorator(func):
222220

223221
@classmethod
224222
def get_active_backend(cls):
225-
backend = cls._active_backend
226-
_ensure_attention_backend_ready(backend)
227-
return backend, cls._backends[backend]
223+
return cls._active_backend, cls._backends[cls._active_backend]
228224

229225
@classmethod
230226
def list_backends(cls):
@@ -242,6 +238,25 @@ def _is_context_parallel_enabled(
242238
return supports_context_parallel and is_degree_greater_than_1
243239

244240

241+
@dataclass
242+
class _HubKernelConfig:
243+
"""Configuration for downloading and using a hub-based attention kernel."""
244+
245+
repo_id: str
246+
function_attr: str
247+
revision: Optional[str] = None
248+
kernel_fn: Optional[Callable] = None
249+
250+
251+
# Registry for hub-based attention kernels
252+
_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
253+
# TODO: temporary revision for now. Remove when merged upstream into `main`.
254+
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
255+
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
256+
)
257+
}
258+
259+
245260
@contextlib.contextmanager
246261
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
247262
"""
@@ -251,7 +266,7 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
251266
raise ValueError(f"Backend {backend} is not registered.")
252267

253268
backend = AttentionBackendName(backend)
254-
_ensure_attention_backend_ready(backend)
269+
_check_attention_backend_requirements(backend)
255270

256271
old_backend = _AttentionBackendRegistry._active_backend
257272
_AttentionBackendRegistry._active_backend = backend
@@ -398,13 +413,9 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
398413

399414
# TODO: add support for the Hub variant of FA3 varlen later
400415
elif backend in [AttentionBackendName._FLASH_3_HUB]:
401-
if not DIFFUSERS_ENABLE_HUB_KERNELS:
402-
raise RuntimeError(
403-
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
404-
)
405416
if not is_kernels_available():
406417
raise RuntimeError(
407-
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
418+
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
408419
)
409420

410421
elif backend in [
@@ -445,39 +456,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
445456
)
446457

447458

448-
def _ensure_flash_attn_3_func_hub_loaded():
449-
cached = _BACKEND_HANDLES.get(AttentionBackendName._FLASH_3_HUB)
450-
if cached is not None:
451-
return cached
452-
453-
from ..utils.kernels_utils import _get_fa3_from_hub
454-
455-
flash_attn_interface_hub = _get_fa3_from_hub()
456-
func = flash_attn_interface_hub.flash_attn_func
457-
_BACKEND_HANDLES[AttentionBackendName._FLASH_3_HUB] = func
458-
459-
return func
460-
461-
462-
_BACKEND_PREPARERS: Dict[AttentionBackendName, Callable[[], None]] = {
463-
AttentionBackendName._FLASH_3_HUB: _ensure_flash_attn_3_func_hub_loaded,
464-
}
465-
466-
467-
def _prepare_attention_backend(backend: AttentionBackendName) -> None:
468-
preparer = _BACKEND_PREPARERS.get(backend)
469-
if preparer is not None:
470-
preparer()
471-
472-
473-
def _ensure_attention_backend_ready(backend: AttentionBackendName) -> None:
474-
if backend in _PREPARED_BACKENDS:
475-
return
476-
_check_attention_backend_requirements(backend)
477-
_prepare_attention_backend(backend)
478-
_PREPARED_BACKENDS.add(backend)
479-
480-
481459
@functools.lru_cache(maxsize=128)
482460
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
483461
batch_size: int,
@@ -581,6 +559,29 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
581559
return q_idx >= kv_idx
582560

583561

562+
# ===== Helpers for downloading kernels =====
563+
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
    """Fetch the Hub kernel for `backend` and cache it on its registry entry.

    Backends without an entry in `_HUB_KERNELS_REGISTRY` are ignored. For a registered backend, the kernel module is
    loaded with `kernels.get_kernel(repo_id, revision=...)`, the attribute named by `function_attr` is looked up on
    it, and the result is stored on the config object as `_kernel_fn` so that later calls return early.

    Args:
        backend: The attention backend whose kernel should be made available.

    Raises:
        Exception: Any error raised while importing `kernels`, fetching the kernel, or resolving the function
            attribute is logged and re-raised unchanged.
    """
    if backend not in _HUB_KERNELS_REGISTRY:
        return
    config = _HUB_KERNELS_REGISTRY[backend]

    # `_kernel_fn` is not a declared field of the config (the declared field is `kernel_fn`), so it may not exist
    # before the first successful download. Read it defensively instead of raising `AttributeError`.
    if getattr(config, "_kernel_fn", None) is not None:
        return

    try:
        from kernels import get_kernel

        kernel_module = get_kernel(config.repo_id, revision=config.revision)
        kernel_func = getattr(kernel_module, config.function_attr)
    except Exception as e:
        logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
        raise

    # Cache the downloaded kernel function in the config object
    config._kernel_fn = kernel_func
583+
584+
584585
# ===== torch op registrations =====
585586
# Registrations are required for fullgraph tracing compatibility
586587
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
@@ -1348,9 +1349,7 @@ def _flash_attention_3_hub(
13481349
return_attn_probs: bool = False,
13491350
_parallel_config: Optional["ParallelConfig"] = None,
13501351
) -> torch.Tensor:
1351-
func = _BACKEND_HANDLES.get(AttentionBackendName._FLASH_3_HUB)
1352-
if func is None:
1353-
func = _ensure_flash_attn_3_func_hub_loaded()
1352+
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]._kernel_fn
13541353
out = func(
13551354
q=query,
13561355
k=key,

src/diffusers/models/modeling_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,11 @@ def set_attention_backend(self, backend: str) -> None:
595595
attention as backend.
596596
"""
597597
from .attention import AttentionModuleMixin
598-
from .attention_dispatch import AttentionBackendName, _ensure_attention_backend_ready
598+
from .attention_dispatch import (
599+
AttentionBackendName,
600+
_check_attention_backend_requirements,
601+
_maybe_download_kernel_for_backend,
602+
)
599603

600604
# TODO: the following will not be required when everything is refactored to AttentionModuleMixin
601605
from .attention_processor import Attention, MochiAttention
@@ -606,8 +610,10 @@ def set_attention_backend(self, backend: str) -> None:
606610
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
607611
if backend not in available_backends:
608612
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
613+
609614
backend = AttentionBackendName(backend)
610-
_ensure_attention_backend_ready(backend)
615+
_check_attention_backend_requirements(backend)
616+
_maybe_download_kernel_for_backend(backend)
611617

612618
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
613619
for module in self.modules():

src/diffusers/utils/constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
4747
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
4848
DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES
49-
DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES
5049

5150
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
5251
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are

tests/others/test_attention_backends.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
88
```bash
99
export RUN_ATTENTION_BACKEND_TESTS=yes
10-
export DIFFUSERS_ENABLE_HUB_KERNELS=yes
1110
1211
pytest tests/others/test_attention_backends.py
1312
```

0 commit comments

Comments
 (0)