Skip to content

Commit 29b9109

Browse files
authored
[attention backends] change to updated repo and version. (#13161)
* change to updated repo and version.
* fix version and force updated kernels.
* propagate version.
1 parent ae5881b commit 29b9109

File tree

3 files changed

+33
-7
lines changed

3 files changed

+33
-7
lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
is_flash_attn_available,
3939
is_flash_attn_version,
4040
is_kernels_available,
41+
is_kernels_version,
4142
is_sageattention_available,
4243
is_sageattention_version,
4344
is_torch_npu_available,
@@ -318,6 +319,7 @@ class _HubKernelConfig:
318319
repo_id: str
319320
function_attr: str
320321
revision: str | None = None
322+
version: int | None = None
321323
kernel_fn: Callable | None = None
322324
wrapped_forward_attr: str | None = None
323325
wrapped_backward_attr: str | None = None
@@ -327,31 +329,34 @@ class _HubKernelConfig:
327329

328330
# Registry for hub-based attention kernels
329331
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
330-
# TODO: temporary revision for now. Remove when merged upstream into `main`.
331332
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
332333
repo_id="kernels-community/flash-attn3",
333334
function_attr="flash_attn_func",
334-
revision="fake-ops-return-probs",
335335
wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
336336
wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
337+
version=1,
337338
),
338339
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
339340
repo_id="kernels-community/flash-attn3",
340341
function_attr="flash_attn_varlen_func",
341-
# revision="fake-ops-return-probs",
342+
version=1,
342343
),
343344
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
344345
repo_id="kernels-community/flash-attn2",
345346
function_attr="flash_attn_func",
346-
revision=None,
347347
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
348348
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
349+
version=1,
349350
),
350351
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
351-
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
352+
repo_id="kernels-community/flash-attn2",
353+
function_attr="flash_attn_varlen_func",
354+
version=1,
352355
),
353356
AttentionBackendName.SAGE_HUB: _HubKernelConfig(
354-
repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
357+
repo_id="kernels-community/sage-attention",
358+
function_attr="sageattn",
359+
version=1,
355360
),
356361
}
357362

@@ -521,6 +526,10 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
521526
raise RuntimeError(
522527
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
523528
)
529+
if not is_kernels_version(">=", "0.12"):
530+
raise RuntimeError(
531+
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
532+
)
524533

525534
elif backend == AttentionBackendName.AITER:
526535
if not _CAN_USE_AITER_ATTN:
@@ -694,7 +703,7 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
694703
try:
695704
from kernels import get_kernel
696705

697-
kernel_module = get_kernel(config.repo_id, revision=config.revision)
706+
kernel_module = get_kernel(config.repo_id, revision=config.revision, version=config.version)
698707
if needs_kernel:
699708
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
700709

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
is_inflect_available,
8787
is_invisible_watermark_available,
8888
is_kernels_available,
89+
is_kernels_version,
8990
is_kornia_available,
9091
is_librosa_available,
9192
is_matplotlib_available,

src/diffusers/utils/import_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,22 @@ def is_transformers_version(operation: str, version: str):
724724
return compare_versions(parse(_transformers_version), operation, version)
725725

726726

727+
@cache
728+
def is_kernels_version(operation: str, version: str):
729+
"""
730+
Compares the current Kernels version to a given reference with an operation.
731+
732+
Args:
733+
operation (`str`):
734+
A string representation of an operator, such as `">"` or `"<="`
735+
version (`str`):
736+
A version string
737+
"""
738+
if not _kernels_available:
739+
return False
740+
return compare_versions(parse(_kernels_version), operation, version)
741+
742+
727743
@cache
728744
def is_hf_hub_version(operation: str, version: str):
729745
"""

0 commit comments

Comments (0)