Skip to content

Commit 738f278

Browse files
committed
gracefully error out when attn-backend x cp combo isn't supported.
1 parent 23251d6 commit 738f278

2 files changed

Lines changed: 24 additions & 4 deletions

File tree

src/diffusers/models/attention_dispatch.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,10 @@ def decorator(func):
235235
def get_active_backend(cls):
236236
return cls._active_backend, cls._backends[cls._active_backend]
237237

238+
@classmethod
239+
def set_active_backend(cls, backend: str):
240+
cls._active_backend = backend
241+
238242
@classmethod
239243
def list_backends(cls):
240244
return list(cls._backends.keys())
@@ -294,12 +298,12 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
294298
_maybe_download_kernel_for_backend(backend)
295299

296300
old_backend = _AttentionBackendRegistry._active_backend
297-
_AttentionBackendRegistry._active_backend = backend
301+
_AttentionBackendRegistry.set_active_backend(backend)
298302

299303
try:
300304
yield
301305
finally:
302-
_AttentionBackendRegistry._active_backend = old_backend
306+
_AttentionBackendRegistry.set_active_backend(old_backend)
303307

304308

305309
def dispatch_attention_fn(
@@ -325,7 +329,7 @@ def dispatch_attention_fn(
325329
else:
326330
backend_name = AttentionBackendName(backend)
327331
backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
328-
332+
329333
kwargs = {
330334
"query": query,
331335
"key": key,
@@ -348,6 +352,18 @@ def dispatch_attention_fn(
348352
check(**kwargs)
349353

350354
kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
355+
356+
if "_parallel_config" in kwargs and kwargs["_parallel_config"] is not None:
357+
attention_backend = AttentionBackendName(backend_name)
358+
if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
359+
compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
360+
raise ValueError(
361+
f"Context parallelism is enabled but backend '{attention_backend.value}' "
362+
f"which does not support context parallelism. "
363+
f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
364+
f"calling `model.enable_parallelism()`."
365+
)
366+
351367
return backend_fn(**kwargs)
352368

353369

src/diffusers/models/modeling_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,7 @@ def set_attention_backend(self, backend: str) -> None:
601601
"""
602602
from .attention import AttentionModuleMixin
603603
from .attention_dispatch import (
604+
_AttentionBackendRegistry,
604605
AttentionBackendName,
605606
_check_attention_backend_requirements,
606607
_maybe_download_kernel_for_backend,
@@ -628,6 +629,9 @@ def set_attention_backend(self, backend: str) -> None:
628629
if processor is None or not hasattr(processor, "_attention_backend"):
629630
continue
630631
processor._attention_backend = backend
632+
633+
# Important to set the active backend so that it propagates gracefully throughout.
634+
_AttentionBackendRegistry.set_active_backend(backend)
631635

632636
def reset_attention_backend(self) -> None:
633637
"""
@@ -1541,7 +1545,7 @@ def enable_parallelism(
15411545
f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
15421546
f"is using backend '{attention_backend.value}' which does not support context parallelism. "
15431547
f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
1544-
f"calling `enable_parallelism()`."
1548+
f"calling `model.enable_parallelism()`."
15451549
)
15461550

15471551
# All modules use the same attention processor and backend. We don't need to

0 commit comments

Comments
 (0)