src/diffusers/pipelines/ltx2/connectors.py (9 additions, 11 deletions)
@@ -2,7 +2,6 @@

 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
@@ -295,22 +294,21 @@ def forward(
         )
 
         num_register_repeats = seq_len // self.num_learnable_registers
-        registers = torch.tile(self.learnable_registers, (num_register_repeats, 1))  # [seq_len, inner_dim]
+        registers = (
+            self.learnable_registers.unsqueeze(0).expand(num_register_repeats, -1, -1).reshape(seq_len, -1)
+        )  # [seq_len, inner_dim]

         binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int()
         if binary_attn_mask.ndim == 4:
             binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1)  # [B, 1, 1, L] --> [B, L]

-        hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)]
-        valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]
-        pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]
-        padded_hidden_states = [
-            F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths)
-        ]
-        padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0)  # [B, L, D]
+        # Replace padding positions with learned registers using vectorized masking
+        mask = binary_attn_mask.unsqueeze(-1)  # [B, L, 1]
+        registers_expanded = registers.unsqueeze(0).expand(batch_size, -1, -1)  # [B, L, D]
+        hidden_states = mask * hidden_states + (1 - mask) * registers_expanded

-        flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1)  # [B, L, 1]
-        hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers
+        # Flip sequence: embeddings move to front, registers to back (from left padding layout)
+        hidden_states = torch.flip(hidden_states, dims=[1])

         # Overwrite attention_mask with an all-zeros mask if using registers.
         attention_mask = torch.zeros_like(attention_mask)
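
On the registers construction: the new unsqueeze/expand/reshape chain should be a drop-in equivalent of the removed torch.tile call, since both repeat the full register block num_register_repeats times along the sequence axis. A minimal sanity check, with hypothetical sizes (not taken from the model config):

```python
import torch

num_registers, inner_dim, num_repeats = 4, 8, 3  # hypothetical sizes
learnable_registers = torch.randn(num_registers, inner_dim)

tiled = torch.tile(learnable_registers, (num_repeats, 1))  # removed path
expanded = (
    learnable_registers.unsqueeze(0).expand(num_repeats, -1, -1).reshape(num_repeats * num_registers, -1)
)  # new path

# Both yield [num_repeats * num_registers, inner_dim] with the register block repeated in order.
assert torch.equal(tiled, expanded)
```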
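On the padding replacement: the old path gathered the non-padded tokens per sample, re-padded each with F.pad, and mixed the stacked result with a flipped mask; the new path stays fully batched, overwriting padded positions with registers in place and then flipping the sequence, which under the left-padding layout puts embeddings at the front and registers at the back. A toy trace of the layout (hypothetical values; note that the flip also reverses the relative order of the valid tokens):

```python
import torch

batch_size, seq_len, dim = 1, 4, 2  # toy sizes
binary_attn_mask = torch.tensor([[0, 0, 1, 1]])  # left padding: two pads, then two valid tokens
hidden_states = torch.arange(seq_len, dtype=torch.float32).view(1, seq_len, 1).expand(1, seq_len, dim)
registers = torch.full((seq_len, dim), -1.0)  # stand-in for the learned registers

mask = binary_attn_mask.unsqueeze(-1)  # [B, L, 1]
registers_expanded = registers.unsqueeze(0).expand(batch_size, -1, -1)  # [B, L, D]
mixed = mask * hidden_states + (1 - mask) * registers_expanded
print(mixed[0, :, 0])    # tensor([-1., -1.,  2.,  3.]): registers fill the padded front

flipped = torch.flip(mixed, dims=[1])
print(flipped[0, :, 0])  # tensor([ 3.,  2., -1., -1.]): embeddings first, registers last
```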
src/diffusers/pipelines/ltx2/pipeline_ltx2.py (4 additions, 0 deletions)
@@ -1189,6 +1189,10 @@ def __call__(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)

+        # Set begin index to skip nonzero().item() call in scheduler initialization, which triggers GPU sync
+        self.scheduler.set_begin_index(0)
+        audio_scheduler.set_begin_index(0)
+
         # 6. Prepare micro-conditions
         # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
         video_coords = self.transformer.rope.prepare_video_coords(
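On the set_begin_index change: when begin_index is left unset, diffusers schedulers typically resolve each step index by comparing the timestep against the full schedule, roughly the pattern sketched below (a paraphrase, not the exact scheduler source). The nonzero()/.item() pair forces a host/device synchronization when the schedule lives on the GPU, so seeding the index with set_begin_index(0) lets the scheduler skip the lookup entirely.

```python
import torch

class SchedulerSketch:
    """Paraphrased sketch of the lookup that set_begin_index(0) bypasses (not exact diffusers source)."""

    def __init__(self, timesteps):
        self.timesteps = timesteps
        self.begin_index = None
        self._step_index = None

    def set_begin_index(self, begin_index=0):
        self.begin_index = begin_index

    def _init_step_index(self, timestep):
        if self.begin_index is None:
            # Comparing against the full schedule and calling .item() on the
            # result forces a host/device sync when timesteps live on the GPU.
            indices = (self.timesteps == timestep).nonzero()
            self._step_index = indices[0].item()
        else:
            # After set_begin_index(0): a plain Python int, no sync needed.
            self._step_index = self.begin_index
```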