@@ -611,14 +611,24 @@ def set_attention_backend(self, backend: str) -> None:
611611 from .attention_processor import Attention , MochiAttention
612612
613613 logger .warning ("Attention backends are an experimental feature and the API may be subject to change." )
614+ attention_classes = (Attention , MochiAttention , AttentionModuleMixin )
615+
616+ parallel_config_set = False
617+ for module in self .modules ():
618+ if not isinstance (module , attention_classes ):
619+ continue
620+ processor = module .processor
621+ if getattr (processor , "_parallel_config" , None ) is not None :
622+ parallel_config_set = True
623+ break
614624
615625 backend = backend .lower ()
616626 available_backends = {x .value for x in AttentionBackendName .__members__ .values ()}
617627 if backend not in available_backends :
618628 raise ValueError (f"`{ backend = } ` must be one of the following: " + ", " .join (available_backends ))
619629
620630 backend = AttentionBackendName (backend )
621- if not _AttentionBackendRegistry ._is_context_parallel_available (backend ):
631+ if parallel_config_set and not _AttentionBackendRegistry ._is_context_parallel_available (backend ):
622632 compatible_backends = sorted (_AttentionBackendRegistry ._supports_context_parallel )
623633 raise ValueError (
624634 f"Context parallelism is enabled but backend '{ backend .value } ' "
@@ -630,7 +640,6 @@ def set_attention_backend(self, backend: str) -> None:
630640 _check_attention_backend_requirements (backend )
631641 _maybe_download_kernel_for_backend (backend )
632642
633- attention_classes = (Attention , MochiAttention , AttentionModuleMixin )
634643 for module in self .modules ():
635644 if not isinstance (module , attention_classes ):
636645 continue
0 commit comments