Skip to content

Commit f18d18d

Browse files
authored
[TRTLLM-12963][refactor] LTX-2 attention: drop dead k_pe parameter; require cached cross-attn (#14555)
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
1 parent 85d5e6e commit f18d18d

2 files changed

Lines changed: 43 additions & 33 deletions

File tree

tensorrt_llm/_torch/visual_gen/models/ltx2/transformer_ltx2.py

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class LTX2Attention(Attention):
7272
- Output projection (to_out)
7373
7474
Adds LTX-2 specifics:
75-
- LTX 3D RoPE (INTERLEAVED / SPLIT) with separate k_pe support
75+
- LTX 3D RoPE (INTERLEAVED / SPLIT)
7676
- Gated attention (to_gate_logits)
7777
- Cross-attention with different context_dim for K/V input
7878
"""
@@ -274,8 +274,8 @@ def project_kv(
274274
before all-gather. RoPE is per-token element-wise so it commutes with
275275
seq-dim concat — bit-identical to the post-gather rope while saving
276276
the cos/sin all-gather collective and reducing K-rope compute by U×.
277-
The forward() consumer should pass ``k_pe=None`` to signal that K is
278-
already rotated.
277+
After this, K is already rotated, and the forward() consumer passes
278+
``pre_projected_kv=(k, v)`` to skip re-rotation.
279279
"""
280280
k = self.to_k(context)
281281
v = self.to_v(context)
@@ -298,19 +298,23 @@ def forward(
298298
x: torch.Tensor,
299299
context: torch.Tensor | None = None,
300300
pe: tuple[torch.Tensor, torch.Tensor] | None = None,
301-
k_pe: tuple[torch.Tensor, torch.Tensor] | None = None,
302301
pre_projected_kv: tuple[torch.Tensor, torch.Tensor] | None = None,
303302
key_padding_mask: torch.Tensor | None = None,
304303
) -> torch.Tensor:
305304
"""Forward pass.
306305
307306
Caller contract:
308-
- FUSE_QKV (self-attn): pe must be set; k_pe and pre_projected_kv unused.
309-
- SEPARATE_QKV (cross-attn): cached path requires pre_projected_kv;
310-
uncached path uses ``context`` (may be None when the async-Ulysses
311-
inner backend was swapped to a non-async one — falls back to
312-
self-attn via kv_source=x). pe optional (None = norm-only).
313-
k_pe overrides pe for K (e.g. AV cross-attn) when provided.
307+
- FUSE_QKV (self-attn): pe must be set; pre_projected_kv unused.
308+
- SEPARATE_QKV self-attn (async-Ulysses): pre_projected_kv=None,
309+
context=None — routed to ``forward_async`` (V/Q/K rolling A2A).
310+
Falls through to the sync SEPARATE_QKV self-attn path when the
311+
inner backend lacks ``forward_async`` (Ulysses-inactive swap).
312+
- SEPARATE_QKV cross-attn: pre_projected_kv must be set (K already
313+
norm+rope'd by ``project_kv`` upstream — text cache or AV
314+
project-before-gather). pe optional (None = norm-only on Q).
315+
Uncached cross-attn (context != None without pre_projected_kv) is
316+
rejected: Q/K may have different lengths so sharing pe would
317+
mis-rotate K. Caller must use project_kv + pre_projected_kv.
314318
315319
Args:
316320
key_padding_mask: Optional ``[B, S_kv]`` bool tensor; True = valid,
@@ -324,13 +328,13 @@ def forward(
324328
Routing:
325329
1. Async-Ulysses self-attn → ``forward_async`` (V/Q/K rolling A2A).
326330
2. FUSE_QKV self-attn → packed fused kernel (or naive mini-config).
327-
3. SEPARATE_QKV cross-attn → split fused kernel (or naive mini-config).
331+
3. SEPARATE_QKV cross-attn (cached) → split fused kernel.
332+
4. SEPARATE_QKV self-attn (sync fallback) → split fused kernel on x.
328333
"""
329334
# Async-Ulysses self-attn dispatch. ``hasattr`` guard: audio_attn1 may
330335
# have ``set_ulysses_active(False)`` swap ``self.attn`` to a plain
331336
# backend that lacks ``forward_async`` — fall through to the sync
332-
# uncached SEPARATE_QKV branch, which handles context=None via
333-
# kv_source=x.
337+
# uncached SEPARATE_QKV branch (self-attn on x).
334338
if (
335339
self.qkv_mode == QKVMode.SEPARATE_QKV
336340
and self._use_async_ulysses
@@ -378,34 +382,36 @@ def forward(
378382
if pe is not None:
379383
q = apply_rotary_emb(q, pe, self.rope_type)
380384
else:
381-
# ─── uncached cross-attn / async self-attn fallback ───
382-
# LTX-2 prod doesn't use uncached cross-attn (always pre-projects
383-
# K/V). This branch also catches async self-attn when the inner
384-
# backend lacks forward_async (audio Ulysses-inactive swap):
385-
# context=None then, fall back to self-attn via kv_source=x.
386-
kv_source = context if context is not None else x
385+
# ─── uncached SEPARATE_QKV ───
386+
# Two cases reach this branch; only (a) is valid:
387+
# (a) async-Ulysses self-attn fallback (context=None) when
388+
# the inner backend lacks forward_async (e.g. audio
389+
# Ulysses-inactive swap). Use x for K/V (self-attn).
390+
# (b) (forbidden) uncached cross-attn (context != None) —
391+
# Q/K may have different lengths so sharing pe would
392+
# mis-rotate K. Caller must use project_kv + pre_projected_kv.
393+
if context is not None:
394+
raise ValueError(
395+
"uncached SEPARATE_QKV cross-attn is forbidden; "
396+
"pass pre_projected_kv from project_kv(context, pe=...)."
397+
)
387398
q = self.to_q(x)
388-
k = self.to_k(kv_source)
389-
v = self.to_v(kv_source)
399+
k = self.to_k(x)
400+
v = self.to_v(x)
390401
if use_fused:
391402
self.apply_split_norm_or_norm_rope(
392403
q, self.norm_q.weight, self.num_attention_heads, pe
393404
)
394405
self.apply_split_norm_or_norm_rope(
395-
k,
396-
self.norm_k.weight,
397-
self.num_key_value_heads,
398-
k_pe if k_pe is not None else pe,
406+
k, self.norm_k.weight, self.num_key_value_heads, pe
399407
)
400408
else:
401409
if self.qk_norm:
402410
q = self.norm_q(q)
403411
k = self.norm_k(k)
404412
if pe is not None:
405413
q = apply_rotary_emb(q, pe, self.rope_type)
406-
k_pe_use = k_pe if k_pe is not None else pe
407-
if k_pe_use is not None:
408-
k = apply_rotary_emb(k, k_pe_use, self.rope_type)
414+
k = apply_rotary_emb(k, pe, self.rope_type)
409415

410416
attn_kwargs = {}
411417
if key_padding_mask is not None:
@@ -746,7 +752,8 @@ def forward(
746752
perturbations: Optional ``BatchedPerturbationConfig`` that masks
747753
attention outputs for selected blocks/modalities.
748754
text_kv_video: Pre-projected (K, V) for video text cross-attention.
749-
Falls back to inline computation if ``None``.
755+
Required when the video stream runs cross-attn — built by
756+
``LTXModel.prepare_text_cache``.
750757
text_kv_audio: Pre-projected (K, V) for audio text cross-attention.
751758
"""
752759
if video is None and audio is None:
@@ -885,7 +892,6 @@ def forward(
885892
vx_scaled,
886893
pre_projected_kv=(k_a2v, v_a2v),
887894
pe=video.cross_positional_embeddings,
888-
k_pe=None, # K already rotated in project_kv
889895
key_padding_mask=audio.audio_padding_mask,
890896
)
891897
* gate_out_a2v
@@ -927,7 +933,6 @@ def forward(
927933
ax_scaled,
928934
pre_projected_kv=(k_v2a, v_v2a),
929935
pe=audio.cross_positional_embeddings,
930-
k_pe=None, # K already rotated in project_kv
931936
)
932937
* gate_out_v2a
933938
)

tests/unittest/_torch/visual_gen/test_ltx2_attention.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,12 @@ def test_cross_attention_sanity(self):
190190
x = torch.randn(batch_size, q_seq, query_dim, device=self.DEVICE, dtype=dtype) * 0.02
191191
ctx = torch.randn(batch_size, kv_seq, context_dim, device=self.DEVICE, dtype=dtype) * 0.02
192192

193+
# Cross-attn must use the cached pattern: project K/V upstream, then
194+
# call forward with pre_projected_kv (production text cross-attn does
195+
# this via prepare_text_cache; AV cross-attn does it inline).
193196
with torch.no_grad():
194-
output = attn(x, context=ctx, pe=None)
197+
k, v = attn.project_kv(ctx, pe=None)
198+
output = attn(x, pre_projected_kv=(k, v), pe=None)
195199

196200
self.assertEqual(output.shape, (batch_size, q_seq, query_dim))
197201

@@ -229,7 +233,8 @@ def test_cross_attention_different_dims(self):
229233
ctx = torch.randn(batch_size, kv_seq, context_dim, device=self.DEVICE, dtype=dtype) * 0.02
230234

231235
with torch.no_grad():
232-
output = attn(x, context=ctx, pe=None)
236+
k, v = attn.project_kv(ctx, pe=None)
237+
output = attn(x, pre_projected_kv=(k, v), pe=None)
233238

234239
self.assertEqual(output.shape, (batch_size, q_seq, query_dim))
235240

0 commit comments

Comments
 (0)