@@ -222,6 +222,7 @@ def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
222222
223223 # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
224224 self .scale_rope = scale_rope
225+ self ._device_freq_cache : dict [torch .device , tuple [torch .Tensor , torch .Tensor ]] = {}
225226
226227 def rope_params (self , index , dim , theta = 10000 ):
227228 """
@@ -233,6 +234,12 @@ def rope_params(self, index, dim, theta=10000):
233234 freqs = torch .polar (torch .ones_like (freqs ), freqs )
234235 return freqs
235236
237+ def _get_device_freqs (self , device : torch .device ) -> tuple [torch .Tensor , torch .Tensor ]:
238+ """Return pos_freqs and neg_freqs on the given device, caching the transfer."""
239+ if device not in self ._device_freq_cache :
240+ self ._device_freq_cache [device ] = (self .pos_freqs .to (device ), self .neg_freqs .to (device ))
241+ return self ._device_freq_cache [device ]
242+
236243 def forward (
237244 self ,
238245 video_fhw : tuple [int , int , int , list [tuple [int , int , int ]]],
@@ -300,8 +307,9 @@ def forward(
300307 max_vid_index = max (height , width , max_vid_index )
301308
302309 max_txt_seq_len_int = int (max_txt_seq_len )
303- # Create device-specific copy for text freqs without modifying self.pos_freqs
304- txt_freqs = self .pos_freqs .to (device )[max_vid_index : max_vid_index + max_txt_seq_len_int , ...]
310+ # Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
311+ pos_freqs_device , _ = self ._get_device_freqs (device )
312+ txt_freqs = pos_freqs_device [max_vid_index : max_vid_index + max_txt_seq_len_int , ...]
305313 vid_freqs = torch .cat (vid_freqs , dim = 0 )
306314
307315 return vid_freqs , txt_freqs
@@ -311,8 +319,7 @@ def _compute_video_freqs(
311319 self , frame : int , height : int , width : int , idx : int = 0 , device : torch .device = None
312320 ) -> torch .Tensor :
313321 seq_lens = frame * height * width
314- pos_freqs = self .pos_freqs .to (device ) if device is not None else self .pos_freqs
315- neg_freqs = self .neg_freqs .to (device ) if device is not None else self .neg_freqs
322+ pos_freqs , neg_freqs = self ._get_device_freqs (device ) if device is not None else (self .pos_freqs , self .neg_freqs )
316323
317324 freqs_pos = pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
318325 freqs_neg = neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
@@ -356,6 +363,7 @@ def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
356363 )
357364
358365 self .scale_rope = scale_rope
366+ self ._device_freq_cache : dict [torch .device , tuple [torch .Tensor , torch .Tensor ]] = {}
359367
360368 def rope_params (self , index , dim , theta = 10000 ):
361369 """
@@ -367,6 +375,12 @@ def rope_params(self, index, dim, theta=10000):
367375 freqs = torch .polar (torch .ones_like (freqs ), freqs )
368376 return freqs
369377
378+ def _get_device_freqs (self , device : torch .device ) -> tuple [torch .Tensor , torch .Tensor ]:
379+ """Return pos_freqs and neg_freqs on the given device, caching the transfer."""
380+ if device not in self ._device_freq_cache :
381+ self ._device_freq_cache [device ] = (self .pos_freqs .to (device ), self .neg_freqs .to (device ))
382+ return self ._device_freq_cache [device ]
383+
370384 def forward (
371385 self ,
372386 video_fhw : tuple [int , int , int , list [tuple [int , int , int ]]],
@@ -421,17 +435,17 @@ def forward(
421435
422436 max_vid_index = max (max_vid_index , layer_num )
423437 max_txt_seq_len_int = int (max_txt_seq_len )
424- # Create device-specific copy for text freqs without modifying self.pos_freqs
425- txt_freqs = self .pos_freqs .to (device )[max_vid_index : max_vid_index + max_txt_seq_len_int , ...]
438+ # Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
439+ pos_freqs_device , _ = self ._get_device_freqs (device )
440+ txt_freqs = pos_freqs_device [max_vid_index : max_vid_index + max_txt_seq_len_int , ...]
426441 vid_freqs = torch .cat (vid_freqs , dim = 0 )
427442
428443 return vid_freqs , txt_freqs
429444
430445 @lru_cache_unless_export (maxsize = None )
431446 def _compute_video_freqs (self , frame , height , width , idx = 0 , device : torch .device = None ):
432447 seq_lens = frame * height * width
433- pos_freqs = self .pos_freqs .to (device ) if device is not None else self .pos_freqs
434- neg_freqs = self .neg_freqs .to (device ) if device is not None else self .neg_freqs
448+ pos_freqs , neg_freqs = self ._get_device_freqs (device ) if device is not None else (self .pos_freqs , self .neg_freqs )
435449
436450 freqs_pos = pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
437451 freqs_neg = neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
@@ -452,8 +466,7 @@ def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device
452466 @lru_cache_unless_export (maxsize = None )
453467 def _compute_condition_freqs (self , frame , height , width , device : torch .device = None ):
454468 seq_lens = frame * height * width
455- pos_freqs = self .pos_freqs .to (device ) if device is not None else self .pos_freqs
456- neg_freqs = self .neg_freqs .to (device ) if device is not None else self .neg_freqs
469+ pos_freqs , neg_freqs = self ._get_device_freqs (device ) if device is not None else (self .pos_freqs , self .neg_freqs )
457470
458471 freqs_pos = pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
459472 freqs_neg = neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
0 commit comments