Skip to content

Commit a05c8e9

Browse files
Fix Dynamo lru_cache warnings during torch.compile (#13384)
* fix compile issue Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * compile friendly Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * add comments Signed-off-by: jiqing-feng <jiqing.feng@intel.com> --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 8070f6e commit a05c8e9

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,9 @@ def dispatch_attention_fn(
423423
**attention_kwargs,
424424
"_parallel_config": parallel_config,
425425
}
426-
if is_torch_version(">=", "2.5.0"):
426+
# Equivalent to `is_torch_version(">=", "2.5.0")` — use module-level constant to avoid
427+
# Dynamo tracing into the lru_cache-wrapped `is_torch_version` during torch.compile.
428+
if _CAN_USE_FLEX_ATTN:
427429
kwargs["enable_gqa"] = enable_gqa
428430

429431
if _AttentionBackendRegistry._checks_enabled:

src/diffusers/utils/torch_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,17 @@ def outer_wrapper(fn: Callable[P, T]):
347347

348348
@functools.wraps(fn)
349349
def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
350-
if torch.compiler.is_exporting():
350+
compiler = getattr(torch, "compiler", None)
351+
is_exporting = bool(compiler and hasattr(compiler, "is_exporting") and compiler.is_exporting())
352+
is_compiling = bool(compiler and hasattr(compiler, "is_compiling") and compiler.is_compiling())
353+
354+
# Fallback for older builds where compiler.is_compiling is unavailable.
355+
if not is_compiling:
356+
dynamo = getattr(torch, "_dynamo", None)
357+
if dynamo is not None and hasattr(dynamo, "is_compiling"):
358+
is_compiling = dynamo.is_compiling()
359+
360+
if is_exporting or is_compiling:
351361
return fn(*args, **kwargs)
352362
return cached(*args, **kwargs)
353363

0 commit comments

Comments
 (0)