@@ -235,6 +235,10 @@ def decorator(func):
235235 def get_active_backend (cls ):
236236 return cls ._active_backend , cls ._backends [cls ._active_backend ]
237237
@classmethod
def set_active_backend(cls, backend: str):
    """Record ``backend`` as the registry's active attention backend.

    The value is stored on the class attribute ``_active_backend`` exactly as
    given; no validation or conversion is performed here.
    """
    setattr(cls, "_active_backend", backend)
241+
@classmethod
def list_backends(cls):
    """Return the keys of the ``_backends`` registry as a new list."""
    registered = cls._backends.keys()
    return [*registered]
@@ -294,12 +298,12 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
294298 _maybe_download_kernel_for_backend (backend )
295299
296300 old_backend = _AttentionBackendRegistry ._active_backend
297- _AttentionBackendRegistry ._active_backend = backend
301+ _AttentionBackendRegistry .set_active_backend ( backend )
298302
299303 try :
300304 yield
301305 finally :
302- _AttentionBackendRegistry ._active_backend = old_backend
306+ _AttentionBackendRegistry .set_active_backend ( old_backend )
303307
304308
305309def dispatch_attention_fn (
@@ -325,7 +329,7 @@ def dispatch_attention_fn(
325329 else :
326330 backend_name = AttentionBackendName (backend )
327331 backend_fn = _AttentionBackendRegistry ._backends .get (backend_name )
328-
332+
329333 kwargs = {
330334 "query" : query ,
331335 "key" : key ,
@@ -348,6 +352,18 @@ def dispatch_attention_fn(
348352 check (** kwargs )
349353
350354 kwargs = {k : v for k , v in kwargs .items () if k in _AttentionBackendRegistry ._supported_arg_names [backend_name ]}
355+
# Fail early with an actionable message when context parallelism was requested
# (a non-None `_parallel_config` survived the supported-argument filtering above)
# but the selected backend is not registered as supporting it.
if kwargs.get("_parallel_config") is not None:
    # Named `selected_backend` so it does not shadow the module-level
    # `attention_backend` context manager.
    selected_backend = AttentionBackendName(backend_name)
    if not _AttentionBackendRegistry._is_context_parallel_available(selected_backend):
        compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
        raise ValueError(
            f"Context parallelism is enabled but the current attention backend '{selected_backend.value}' "
            f"does not support context parallelism. "
            f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
            f"calling `model.enable_parallelism()`."
        )
366+
351367 return backend_fn (** kwargs )
352368
353369
0 commit comments