Skip to content

Commit 6875e95

Browse files
committed
use lru_cache_unless_export
1 parent 073b5b1 commit 6875e95

1 file changed

Lines changed: 6 additions & 10 deletions

File tree

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,6 @@ def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
222222

223223
# DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
224224
self.scale_rope = scale_rope
225-
self._device_freq_cache: dict[torch.device, tuple[torch.Tensor, torch.Tensor]] = {}
226225

227226
def rope_params(self, index, dim, theta=10000):
228227
"""
@@ -234,11 +233,10 @@ def rope_params(self, index, dim, theta=10000):
234233
freqs = torch.polar(torch.ones_like(freqs), freqs)
235234
return freqs
236235

236+
@lru_cache_unless_export(maxsize=None)
237237
def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
238-
"""Return pos_freqs and neg_freqs on the given device, caching the transfer."""
239-
if device not in self._device_freq_cache:
240-
self._device_freq_cache[device] = (self.pos_freqs.to(device), self.neg_freqs.to(device))
241-
return self._device_freq_cache[device]
238+
"""Return pos_freqs and neg_freqs on the given device."""
239+
return self.pos_freqs.to(device), self.neg_freqs.to(device)
242240

243241
def forward(
244242
self,
@@ -365,7 +363,6 @@ def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
365363
)
366364

367365
self.scale_rope = scale_rope
368-
self._device_freq_cache: dict[torch.device, tuple[torch.Tensor, torch.Tensor]] = {}
369366

370367
def rope_params(self, index, dim, theta=10000):
371368
"""
@@ -377,11 +374,10 @@ def rope_params(self, index, dim, theta=10000):
377374
freqs = torch.polar(torch.ones_like(freqs), freqs)
378375
return freqs
379376

377+
@lru_cache_unless_export(maxsize=None)
380378
def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
381-
"""Return pos_freqs and neg_freqs on the given device, caching the transfer."""
382-
if device not in self._device_freq_cache:
383-
self._device_freq_cache[device] = (self.pos_freqs.to(device), self.neg_freqs.to(device))
384-
return self._device_freq_cache[device]
379+
"""Return pos_freqs and neg_freqs on the given device."""
380+
return self.pos_freqs.to(device), self.neg_freqs.to(device)
385381

386382
def forward(
387383
self,

0 commit comments

Comments
 (0)