Skip to content

Commit 1994ea3

Browse files
committed
Cache RoPE freqs on device to avoid repeated CPU-GPU copy in QwenImage
1 parent 5adc544 commit 1994ea3

1 file changed

Lines changed: 23 additions & 10 deletions

File tree

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
222222

223223
# DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
224224
self.scale_rope = scale_rope
225+
self._device_freq_cache: dict[torch.device, tuple[torch.Tensor, torch.Tensor]] = {}
225226

226227
def rope_params(self, index, dim, theta=10000):
227228
"""
@@ -233,6 +234,12 @@ def rope_params(self, index, dim, theta=10000):
233234
freqs = torch.polar(torch.ones_like(freqs), freqs)
234235
return freqs
235236

237+
def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
238+
"""Return pos_freqs and neg_freqs on the given device, caching the transfer."""
239+
if device not in self._device_freq_cache:
240+
self._device_freq_cache[device] = (self.pos_freqs.to(device), self.neg_freqs.to(device))
241+
return self._device_freq_cache[device]
242+
236243
def forward(
237244
self,
238245
video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],
@@ -300,8 +307,9 @@ def forward(
300307
max_vid_index = max(height, width, max_vid_index)
301308

302309
max_txt_seq_len_int = int(max_txt_seq_len)
303-
# Create device-specific copy for text freqs without modifying self.pos_freqs
304-
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
310+
# Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
311+
pos_freqs_device, _ = self._get_device_freqs(device)
312+
txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
305313
vid_freqs = torch.cat(vid_freqs, dim=0)
306314

307315
return vid_freqs, txt_freqs
@@ -311,8 +319,7 @@ def _compute_video_freqs(
311319
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
312320
) -> torch.Tensor:
313321
seq_lens = frame * height * width
314-
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
315-
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
322+
pos_freqs, neg_freqs = self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
316323

317324
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
318325
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
@@ -356,6 +363,7 @@ def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
356363
)
357364

358365
self.scale_rope = scale_rope
366+
self._device_freq_cache: dict[torch.device, tuple[torch.Tensor, torch.Tensor]] = {}
359367

360368
def rope_params(self, index, dim, theta=10000):
361369
"""
@@ -367,6 +375,12 @@ def rope_params(self, index, dim, theta=10000):
367375
freqs = torch.polar(torch.ones_like(freqs), freqs)
368376
return freqs
369377

378+
def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
379+
"""Return pos_freqs and neg_freqs on the given device, caching the transfer."""
380+
if device not in self._device_freq_cache:
381+
self._device_freq_cache[device] = (self.pos_freqs.to(device), self.neg_freqs.to(device))
382+
return self._device_freq_cache[device]
383+
370384
def forward(
371385
self,
372386
video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],
@@ -421,17 +435,17 @@ def forward(
421435

422436
max_vid_index = max(max_vid_index, layer_num)
423437
max_txt_seq_len_int = int(max_txt_seq_len)
424-
# Create device-specific copy for text freqs without modifying self.pos_freqs
425-
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
438+
# Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
439+
pos_freqs_device, _ = self._get_device_freqs(device)
440+
txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
426441
vid_freqs = torch.cat(vid_freqs, dim=0)
427442

428443
return vid_freqs, txt_freqs
429444

430445
@lru_cache_unless_export(maxsize=None)
431446
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
432447
seq_lens = frame * height * width
433-
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
434-
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
448+
pos_freqs, neg_freqs = self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
435449

436450
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
437451
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
@@ -452,8 +466,7 @@ def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device
452466
@lru_cache_unless_export(maxsize=None)
453467
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
454468
seq_lens = frame * height * width
455-
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
456-
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
469+
pos_freqs, neg_freqs = self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
457470

458471
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
459472
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

0 commit comments

Comments
 (0)