Skip to content

Commit 7bf0000

Browse files
Boffeeclaudedg845
authored
Fix LTX2 connector token/register layout (regression from #13564) (#13931)
* Fix LTX2 connector register layout to match the original LTX implementation The connector replaced left-padding positions with the tiled registers and then flipped the whole sequence, which put the prompt tokens at the front in reversed order and the register tile reversed within each block. The original LTX implementation (ltx-core _replace_padded_with_learnable_registers, also matched by ComfyUI) front-aligns the valid tokens in their original order and fills the tail with registers indexed by absolute position. Since the connector blocks apply RoPE, the reversed layout produces off-distribution embeddings; short prompts (e.g. negative prompts, whose context is mostly registers) are hit hardest, which manifests as overblown CFG: at cfg > 1 (or CFG++ samplers at cfg 1) the unconditional branch is computed from a mostly-register context with scrambled positions. Replace the fill+flip with a stable-argsort gather (valid tokens to the front, order preserved, per batch row) and fill the tail with the absolute-position register tile. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> * Add register-layout regression tests for the LTX2 text connectors Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> --------- Co-authored-by: Claude Fable 5 <noreply@anthropic.com> Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
1 parent b596c83 commit 7bf0000

2 files changed

Lines changed: 109 additions & 6 deletions

File tree

src/diffusers/pipelines/ltx2/connectors.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -302,13 +302,17 @@ def forward(
302302
if binary_attn_mask.ndim == 4:
303303
binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L]
304304

305-
# Replace padding positions with learned registers using vectorized masking
306-
mask = binary_attn_mask.unsqueeze(-1) # [B, L, 1]
305+
# Move the valid tokens to the front in their original order and fill the tail
306+
# with registers indexed by absolute position, matching the original LTX
307+
# implementation (`_replace_padded_with_learnable_registers`). A stable argsort
308+
# of the inverted mask gathers valid tokens first while preserving their order.
309+
order = torch.argsort(1 - binary_attn_mask, dim=1, stable=True) # [B, L]
310+
front_aligned = torch.gather(hidden_states, 1, order.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]))
311+
num_valid = binary_attn_mask.sum(dim=1, keepdim=True) # [B, 1]
312+
positions = torch.arange(seq_len, device=hidden_states.device).unsqueeze(0) # [1, L]
313+
front_mask = (positions < num_valid).unsqueeze(-1) # [B, L, 1]
307314
registers_expanded = registers.unsqueeze(0).expand(batch_size, -1, -1) # [B, L, D]
308-
hidden_states = mask * hidden_states + (1 - mask) * registers_expanded
309-
310-
# Flip sequence: embeddings move to front, registers to back (from left padding layout)
311-
hidden_states = torch.flip(hidden_states, dims=[1])
315+
hidden_states = torch.where(front_mask, front_aligned, registers_expanded.to(hidden_states.dtype))
312316

313317
# Overwrite attention_mask with an all-zeros mask if using registers.
314318
attention_mask = torch.zeros_like(attention_mask)
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright 2026 The HuggingFace Team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import torch
18+
19+
from diffusers.pipelines.ltx2.connectors import LTX2ConnectorTransformer1d
20+
21+
from ...testing_utils import enable_full_determinism
22+
23+
24+
enable_full_determinism()
25+
26+
27+
class LTX2ConnectorRegisterLayoutTests(unittest.TestCase):
28+
"""The connector must lay out its sequence exactly like the original LTX
29+
implementation (``ltx_core`` ``_replace_padded_with_learnable_registers``,
30+
also matched by ComfyUI): the valid tokens move to the front *in their
31+
original order*, and the tail is filled with the tiled learnable registers
32+
indexed by *absolute position*. The connector blocks apply RoPE, so any
33+
deviation (e.g. reversed token order) produces embeddings the DiT was
34+
never trained on.
35+
"""
36+
37+
num_registers = 4
38+
seq_len = 12
39+
num_heads = 2
40+
head_dim = 4
41+
42+
def get_connector(self):
43+
# num_layers=0 keeps the forward to layout + final RMSNorm, so the
44+
# register layout can be checked exactly.
45+
return LTX2ConnectorTransformer1d(
46+
num_attention_heads=self.num_heads,
47+
attention_head_dim=self.head_dim,
48+
num_layers=0,
49+
num_learnable_registers=self.num_registers,
50+
).eval()
51+
52+
def get_inputs(self, valid_lengths):
53+
dim = self.num_heads * self.head_dim
54+
batch_size = len(valid_lengths)
55+
hidden_states = torch.randn(batch_size, self.seq_len, dim)
56+
# Left padding, like the Gemma tokenization in the LTX2 pipelines.
57+
binary_mask = torch.zeros(batch_size, self.seq_len, dtype=torch.int64)
58+
for i, n in enumerate(valid_lengths):
59+
binary_mask[i, self.seq_len - n :] = 1
60+
additive_mask = (binary_mask - 1).to(hidden_states.dtype)
61+
additive_mask = additive_mask.reshape(batch_size, 1, 1, self.seq_len)
62+
additive_mask = additive_mask * torch.finfo(hidden_states.dtype).max
63+
return hidden_states, binary_mask, additive_mask
64+
65+
def reference_layout(self, connector, hidden_states, binary_mask):
66+
# Reference semantics: front-align valid tokens (order preserved),
67+
# fill the tail with the register tile by absolute position.
68+
batch_size, seq_len, _ = hidden_states.shape
69+
registers = connector.learnable_registers.detach()
70+
tiled = registers.repeat(seq_len // self.num_registers, 1)
71+
expected = torch.empty_like(hidden_states)
72+
for i in range(batch_size):
73+
valid = hidden_states[i, binary_mask[i].bool()]
74+
expected[i, : valid.shape[0]] = valid
75+
expected[i, valid.shape[0] :] = tiled[valid.shape[0] :]
76+
# The forward ends with a non-affine RMSNorm.
77+
return expected * torch.rsqrt(expected.pow(2).mean(-1, keepdim=True) + 1e-6)
78+
79+
def check_layout(self, valid_lengths):
80+
connector = self.get_connector()
81+
hidden_states, binary_mask, additive_mask = self.get_inputs(valid_lengths)
82+
with torch.no_grad():
83+
output, _ = connector(hidden_states, additive_mask)
84+
expected = self.reference_layout(connector, hidden_states, binary_mask)
85+
self.assertTrue(torch.allclose(output, expected, atol=1e-5))
86+
87+
def test_register_layout_left_padded(self):
88+
self.check_layout([5])
89+
90+
def test_register_layout_mixed_lengths_batch(self):
91+
# The pipelines concatenate negative and positive prompts of different
92+
# lengths into one batch; the layout must be computed per row.
93+
self.check_layout([5, 2])
94+
95+
def test_register_layout_fully_valid(self):
96+
self.check_layout([self.seq_len])
97+
98+
def test_register_layout_single_token(self):
99+
self.check_layout([1])

0 commit comments

Comments
 (0)