Skip to content

Commit b9761ce

Browse files
cbensimon and sayakpaul authored
[export] Add export-safe LRU cache helper (#13290)
* [core] Add export-safe LRU cache helper
* torch version check!

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 52558b4 commit b9761ce

File tree

4 files changed

+31
-10
lines changed

4 files changed

+31
-10
lines changed

src/diffusers/hooks/context_parallel.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import copy
15-
import functools
1615
import inspect
1716
from dataclasses import dataclass
1817
from typing import Type
@@ -32,7 +31,7 @@
3231
gather_size_by_comm,
3332
)
3433
from ..utils import get_logger
35-
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
34+
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph, unwrap_module
3635
from .hooks import HookRegistry, ModelHook
3736

3837

@@ -327,7 +326,7 @@ def unshard_anything(
327326
return tensor
328327

329328

330-
@functools.lru_cache(maxsize=64)
329+
@lru_cache_unless_export(maxsize=64)
331330
def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
332331
gather_shapes = []
333332
for i in range(world_size):

src/diffusers/models/attention_dispatch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
is_xformers_version,
5050
)
5151
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
52-
from ..utils.torch_utils import maybe_allow_in_graph
52+
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
5353
from ._modeling_parallel import gather_size_by_comm
5454

5555

@@ -587,7 +587,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
587587
)
588588

589589

590-
@functools.lru_cache(maxsize=128)
590+
@lru_cache_unless_export(maxsize=128)
591591
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
592592
batch_size: int,
593593
seq_len_q: int,

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import functools
1615
import math
1716
from math import prod
1817
from typing import Any
@@ -25,7 +24,7 @@
2524
from ...configuration_utils import ConfigMixin, register_to_config
2625
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
2726
from ...utils import apply_lora_scale, deprecate, logging
28-
from ...utils.torch_utils import maybe_allow_in_graph
27+
from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
2928
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
3029
from ..attention import AttentionMixin, FeedForward
3130
from ..attention_dispatch import dispatch_attention_fn
@@ -307,7 +306,7 @@ def forward(
307306

308307
return vid_freqs, txt_freqs
309308

310-
@functools.lru_cache(maxsize=128)
309+
@lru_cache_unless_export(maxsize=128)
311310
def _compute_video_freqs(
312311
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
313312
) -> torch.Tensor:
@@ -428,7 +427,7 @@ def forward(
428427

429428
return vid_freqs, txt_freqs
430429

431-
@functools.lru_cache(maxsize=None)
430+
@lru_cache_unless_export(maxsize=None)
432431
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
433432
seq_lens = frame * height * width
434433
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
@@ -450,7 +449,7 @@ def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device
450449
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
451450
return freqs.clone().contiguous()
452451

453-
@functools.lru_cache(maxsize=None)
452+
@lru_cache_unless_export(maxsize=None)
454453
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
455454
seq_lens = frame * height * width
456455
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs

src/diffusers/utils/torch_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,16 @@
1919

2020
import functools
2121
import os
22+
from typing import Callable, ParamSpec, TypeVar
2223

2324
from . import logging
2425
from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
2526

2627

28+
T = TypeVar("T")
29+
P = ParamSpec("P")
30+
31+
2732
if is_torch_available():
2833
import torch
2934
from torch.fft import fftn, fftshift, ifftn, ifftshift
@@ -333,5 +338,23 @@ def disable_full_determinism():
333338
torch.use_deterministic_algorithms(False)
334339

335340

341+
# NOTE: `functools.wraps` copies `functools.lru_cache`'s metadata (name, docstring, ...) onto this
# factory, so the docstring below is only visible in the source, not via `help()` at runtime.
@functools.wraps(functools.lru_cache)
def lru_cache_unless_export(maxsize=128, typed=False):
    """Decorator factory: `functools.lru_cache` that is bypassed while `torch.compiler.is_exporting()` is true.

    Usable like `functools.lru_cache`, either with arguments (`@lru_cache_unless_export(maxsize=64)`) or bare
    (`@lru_cache_unless_export`).

    Args:
        maxsize: Forwarded to `functools.lru_cache`. If a callable is passed here (bare-decorator usage), it is
            treated as the function to decorate and the default size of 128 is used.
        typed: Forwarded to `functools.lru_cache`.

    Returns:
        A decorator. On torch < 2.7.0 the decorated function is the plain `lru_cache` wrapper. Otherwise it is a
        wrapper that calls the undecorated function directly while exporting and the cached function in all other
        cases; the wrapper re-exports `cache_info`, `cache_clear` and `cache_parameters` of the underlying cache.
    """

    def outer_wrapper(fn: Callable[P, T]):
        cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)
        # NOTE(review): presumably `torch.compiler.is_exporting` is not available before torch 2.7.0, hence the
        # plain cached function on older versions -- confirm against the torch release notes.
        if is_torch_version("<", "2.7.0"):
            return cached

        @functools.wraps(fn)
        def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
            # While exporting, skip the cache entirely and run the original function.
            if torch.compiler.is_exporting():
                return fn(*args, **kwargs)
            return cached(*args, **kwargs)

        # Keep the `lru_cache` introspection/maintenance API reachable on both code paths, so callers can use
        # e.g. `.cache_clear()` regardless of the installed torch version.
        inner_wrapper.cache_info = cached.cache_info
        inner_wrapper.cache_clear = cached.cache_clear
        inner_wrapper.cache_parameters = cached.cache_parameters

        return inner_wrapper

    # Mirror `functools.lru_cache`: used without parentheses, the decorated function arrives as `maxsize`.
    if callable(maxsize) and isinstance(typed, bool):
        fn, maxsize = maxsize, 128
        return outer_wrapper(fn)

    return outer_wrapper
357+
358+
336359
if is_torch_available():
337360
torch_device = get_device()

0 commit comments

Comments
 (0)