@@ -222,7 +222,6 @@ def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
222222
223223 # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
224224 self .scale_rope = scale_rope
225- self ._device_freq_cache : dict [torch .device , tuple [torch .Tensor , torch .Tensor ]] = {}
226225
227226 def rope_params (self , index , dim , theta = 10000 ):
228227 """
@@ -234,11 +233,10 @@ def rope_params(self, index, dim, theta=10000):
234233 freqs = torch .polar (torch .ones_like (freqs ), freqs )
235234 return freqs
236235
236+ @lru_cache_unless_export (maxsize = None )
237237 def _get_device_freqs (self , device : torch .device ) -> tuple [torch .Tensor , torch .Tensor ]:
238- """Return pos_freqs and neg_freqs on the given device, caching the transfer."""
239- if device not in self ._device_freq_cache :
240- self ._device_freq_cache [device ] = (self .pos_freqs .to (device ), self .neg_freqs .to (device ))
241- return self ._device_freq_cache [device ]
238+ """Return pos_freqs and neg_freqs on the given device."""
239+ return self .pos_freqs .to (device ), self .neg_freqs .to (device )
242240
243241 def forward (
244242 self ,
@@ -365,7 +363,6 @@ def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
365363 )
366364
367365 self .scale_rope = scale_rope
368- self ._device_freq_cache : dict [torch .device , tuple [torch .Tensor , torch .Tensor ]] = {}
369366
370367 def rope_params (self , index , dim , theta = 10000 ):
371368 """
@@ -377,11 +374,10 @@ def rope_params(self, index, dim, theta=10000):
377374 freqs = torch .polar (torch .ones_like (freqs ), freqs )
378375 return freqs
379376
377+ @lru_cache_unless_export (maxsize = None )
380378 def _get_device_freqs (self , device : torch .device ) -> tuple [torch .Tensor , torch .Tensor ]:
381- """Return pos_freqs and neg_freqs on the given device, caching the transfer."""
382- if device not in self ._device_freq_cache :
383- self ._device_freq_cache [device ] = (self .pos_freqs .to (device ), self .neg_freqs .to (device ))
384- return self ._device_freq_cache [device ]
379+ """Return pos_freqs and neg_freqs on the given device."""
380+ return self .pos_freqs .to (device ), self .neg_freqs .to (device )
385381
386382 def forward (
387383 self ,
0 commit comments