Skip to content

Commit 5fa3204

Browse files
committed
up
1 parent 75c61d1 commit 5fa3204

1 file changed

Lines changed: 11 additions & 2 deletions

File tree

src/diffusers/models/modeling_utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -611,14 +611,24 @@ def set_attention_backend(self, backend: str) -> None:
611611
from .attention_processor import Attention, MochiAttention
612612

613613
logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
614+
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
615+
616+
parallel_config_set = False
617+
for module in self.modules():
618+
if not isinstance(module, attention_classes):
619+
continue
620+
processor = module.processor
621+
if getattr(processor, "_parallel_config", None) is not None:
622+
parallel_config_set = True
623+
break
614624

615625
backend = backend.lower()
616626
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
617627
if backend not in available_backends:
618628
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
619629

620630
backend = AttentionBackendName(backend)
621-
if not _AttentionBackendRegistry._is_context_parallel_available(backend):
631+
if parallel_config_set and not _AttentionBackendRegistry._is_context_parallel_available(backend):
622632
compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
623633
raise ValueError(
624634
f"Context parallelism is enabled but backend '{backend.value}' "
@@ -630,7 +640,6 @@ def set_attention_backend(self, backend: str) -> None:
630640
_check_attention_backend_requirements(backend)
631641
_maybe_download_kernel_for_backend(backend)
632642

633-
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
634643
for module in self.modules():
635644
if not isinstance(module, attention_classes):
636645
continue

0 commit comments

Comments
 (0)