Skip to content

Commit ebaa187

Browse files
Authored by ViktoriiaRomanova, dg845, and github-actions[bot]
Eliminate GPU sync overhead and CPU→GPU transfers across LTX2 pipeline (#13564)
* Remove unnecessary CUDA synchronization points and avoid CPU→GPU tensor creation across the LTX2 pipeline, transformer, scheduler, and connector logic.
  - Add set_begin_index(0) to schedulers to eliminate DtoH sync in _init_step_index
  - Replace torch.tensor(..., device=...) with on-device tensor construction for decode scaling
  - Move RoPE-related tensor creation to GPU to avoid memcpy overhead
  - Refactor connector padding logic using vectorized masking instead of list-based ops

* Apply style fixes

* Revert low-impact CUDA synchronization changes and remove redundant `hasattr` check

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 4ca8633 commit ebaa187

2 files changed

Lines changed: 13 additions & 11 deletions

File tree

src/diffusers/pipelines/ltx2/connectors.py

Lines changed: 9 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22

33
import torch
44
import torch.nn as nn
5-
import torch.nn.functional as F
65

76
from ...configuration_utils import ConfigMixin, register_to_config
87
from ...loaders import PeftAdapterMixin
@@ -295,22 +294,21 @@ def forward(
295294
)
296295

297296
num_register_repeats = seq_len // self.num_learnable_registers
298-
registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim]
297+
registers = (
298+
self.learnable_registers.unsqueeze(0).expand(num_register_repeats, -1, -1).reshape(seq_len, -1)
299+
) # [seq_len, inner_dim]
299300

300301
binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int()
301302
if binary_attn_mask.ndim == 4:
302303
binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L]
303304

304-
hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)]
305-
valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]
306-
pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]
307-
padded_hidden_states = [
308-
F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths)
309-
]
310-
padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D]
305+
# Replace padding positions with learned registers using vectorized masking
306+
mask = binary_attn_mask.unsqueeze(-1) # [B, L, 1]
307+
registers_expanded = registers.unsqueeze(0).expand(batch_size, -1, -1) # [B, L, D]
308+
hidden_states = mask * hidden_states + (1 - mask) * registers_expanded
311309

312-
flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1]
313-
hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers
310+
# Flip sequence: embeddings move to front, registers to back (from left padding layout)
311+
hidden_states = torch.flip(hidden_states, dims=[1])
314312

315313
# Overwrite attention_mask with an all-zeros mask if using registers.
316314
attention_mask = torch.zeros_like(attention_mask)

src/diffusers/pipelines/ltx2/pipeline_ltx2.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1189,6 +1189,10 @@ def __call__(
11891189
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
11901190
self._num_timesteps = len(timesteps)
11911191

1192+
# Set begin index to skip nonzero().item() call in scheduler initialization, which triggers GPU sync
1193+
self.scheduler.set_begin_index(0)
1194+
audio_scheduler.set_begin_index(0)
1195+
11921196
# 6. Prepare micro-conditions
11931197
# Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
11941198
video_coords = self.transformer.rope.prepare_video_coords(

0 commit comments

Comments (0)