Skip to content

Commit 7a8f85b

Browse files
committed
up
1 parent 82d20e6 commit 7a8f85b

1 file changed

Lines changed: 24 additions & 11 deletions

File tree

src/diffusers/models/attention_dispatch.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,8 @@ class _HubKernelConfig:
277277
repo_id="kernels-community/flash-attn2",
278278
function_attr="flash_attn_func",
279279
revision=None,
280-
wrapped_forward_attr="_wrapped_flash_attn_forward",
281-
wrapped_backward_attr="_wrapped_flash_attn_backward",
280+
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
281+
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
282282
),
283283
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
284284
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
@@ -602,27 +602,39 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
602602

603603

604604
# ===== Helpers for downloading kernels =====
605+
def _resolve_kernel_attr(module, attr_path: str):
606+
target = module
607+
for attr in attr_path.split("."):
608+
if not hasattr(target, attr):
609+
raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
610+
target = getattr(target, attr)
611+
return target
612+
613+
605614
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
606615
if backend not in _HUB_KERNELS_REGISTRY:
607616
return
608617
config = _HUB_KERNELS_REGISTRY[backend]
609618

610-
if config.kernel_fn is not None:
619+
needs_kernel = config.kernel_fn is None
620+
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
621+
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
622+
623+
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
611624
return
612625

613626
try:
614627
from kernels import get_kernel
615628

616629
kernel_module = get_kernel(config.repo_id, revision=config.revision)
617-
kernel_func = getattr(kernel_module, config.function_attr)
618-
# Cache the downloaded kernel function in the config object
619-
config.kernel_fn = kernel_func
630+
if needs_kernel:
631+
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
620632

621-
if config.wrapped_forward_attr is not None and config.wrapped_forward_attr is not None:
622-
wrapped_forward_fn = getattr(kernel_module, config.wrapped_forward_attr)
623-
wrapped_backward_fn = getattr(kernel_module, config.wrapped_backward_attr)
624-
config.wrapped_forward_fn = wrapped_forward_fn
625-
config.wrapped_backward_fn = wrapped_backward_fn
633+
if needs_wrapped_forward:
634+
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
635+
636+
if needs_wrapped_backward:
637+
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
626638

627639
except Exception as e:
628640
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
@@ -1307,6 +1319,7 @@ def _sage_attention_hub_forward_op(
13071319

13081320
return (out, lse) if return_lse else out
13091321

1322+
13101323
# ===== Context parallel =====
13111324

13121325

0 commit comments

Comments
 (0)