@@ -72,7 +72,7 @@ class LTX2Attention(Attention):
7272 - Output projection (to_out)
7373
7474 Adds LTX-2 specifics:
75- - LTX 3D RoPE (INTERLEAVED / SPLIT) with separate k_pe support
75+ - LTX 3D RoPE (INTERLEAVED / SPLIT)
7676 - Gated attention (to_gate_logits)
7777 - Cross-attention with different context_dim for K/V input
7878 """
@@ -274,8 +274,8 @@ def project_kv(
274274 before all-gather. RoPE is per-token element-wise so it commutes with
275275 seq-dim concat — bit-identical to the post-gather rope while saving
276276 the cos/sin all-gather collective and reducing K-rope compute by U×.
277- The forward() consumer should pass ``k_pe=None`` to signal that K is
278- already rotated .
277+ After this, K is already rotated, and the forward() consumer passes
278+ ``pre_projected_kv=(k, v)`` to skip re-rotation .
279279 """
280280 k = self .to_k (context )
281281 v = self .to_v (context )
@@ -298,19 +298,23 @@ def forward(
298298 x : torch .Tensor ,
299299 context : torch .Tensor | None = None ,
300300 pe : tuple [torch .Tensor , torch .Tensor ] | None = None ,
301- k_pe : tuple [torch .Tensor , torch .Tensor ] | None = None ,
302301 pre_projected_kv : tuple [torch .Tensor , torch .Tensor ] | None = None ,
303302 key_padding_mask : torch .Tensor | None = None ,
304303 ) -> torch .Tensor :
305304 """Forward pass.
306305
307306 Caller contract:
308- - FUSE_QKV (self-attn): pe must be set; k_pe and pre_projected_kv unused.
309- - SEPARATE_QKV (cross-attn): cached path requires pre_projected_kv;
310- uncached path uses ``context`` (may be None when the async-Ulysses
311- inner backend was swapped to a non-async one — falls back to
312- self-attn via kv_source=x). pe optional (None = norm-only).
313- k_pe overrides pe for K (e.g. AV cross-attn) when provided.
307+ - FUSE_QKV (self-attn): pe must be set; pre_projected_kv unused.
308+ - SEPARATE_QKV self-attn (async-Ulysses): pre_projected_kv=None,
309+ context=None — routed to ``forward_async`` (V/Q/K rolling A2A).
310+ Falls through to the sync SEPARATE_QKV self-attn path when the
311+ inner backend lacks ``forward_async`` (Ulysses-inactive swap).
312+ - SEPARATE_QKV cross-attn: pre_projected_kv must be set (K already
313+ norm+rope'd by ``project_kv`` upstream — text cache or AV
314+ project-before-gather). pe optional (None = norm-only on Q).
315+ Uncached cross-attn (context != None without pre_projected_kv) is
316+ rejected: Q/K may have different lengths so sharing pe would
317+ mis-rotate K. Caller must use project_kv + pre_projected_kv.
314318
315319 Args:
316320 key_padding_mask: Optional ``[B, S_kv]`` bool tensor; True = valid,
@@ -324,13 +328,13 @@ def forward(
324328 Routing:
325329 1. Async-Ulysses self-attn → ``forward_async`` (V/Q/K rolling A2A).
326330 2. FUSE_QKV self-attn → packed fused kernel (or naive mini-config).
327- 3. SEPARATE_QKV cross-attn → split fused kernel (or naive mini-config).
331+ 3. SEPARATE_QKV cross-attn (cached) → split fused kernel.
332+ 4. SEPARATE_QKV self-attn (sync fallback) → split fused kernel on x.
328333 """
329334 # Async-Ulysses self-attn dispatch. ``hasattr`` guard: audio_attn1 may
330335 # have ``set_ulysses_active(False)`` swap ``self.attn`` to a plain
331336 # backend that lacks ``forward_async`` — fall through to the sync
332- # uncached SEPARATE_QKV branch, which handles context=None via
333- # kv_source=x.
337+ # uncached SEPARATE_QKV branch (self-attn on x).
334338 if (
335339 self .qkv_mode == QKVMode .SEPARATE_QKV
336340 and self ._use_async_ulysses
@@ -378,34 +382,36 @@ def forward(
378382 if pe is not None :
379383 q = apply_rotary_emb (q , pe , self .rope_type )
380384 else :
381- # ─── uncached cross-attn / async self-attn fallback ───
382- # LTX-2 prod doesn't use uncached cross-attn (always pre-projects
383- # K/V). This branch also catches async self-attn when the inner
384- # backend lacks forward_async (audio Ulysses-inactive swap):
385- # context=None then, fall back to self-attn via kv_source=x.
386- kv_source = context if context is not None else x
385+ # ─── uncached SEPARATE_QKV ───
386+ # Two valid cases:
387+ # (a) async-Ulysses self-attn fallback (context=None) when
388+ # the inner backend lacks forward_async (e.g. audio
389+ # Ulysses-inactive swap). Use x for K/V (self-attn).
390+ # (b) (forbidden) uncached cross-attn (context != None) —
391+ # Q/K may have different lengths so sharing pe would
392+ # mis-rotate K. Caller must use project_kv + pre_projected_kv.
393+ if context is not None :
394+ raise ValueError (
395+ "uncached SEPARATE_QKV cross-attn is forbidden; "
396+ "pass pre_projected_kv from project_kv(context, pe=...)."
397+ )
387398 q = self .to_q (x )
388- k = self .to_k (kv_source )
389- v = self .to_v (kv_source )
399+ k = self .to_k (x )
400+ v = self .to_v (x )
390401 if use_fused :
391402 self .apply_split_norm_or_norm_rope (
392403 q , self .norm_q .weight , self .num_attention_heads , pe
393404 )
394405 self .apply_split_norm_or_norm_rope (
395- k ,
396- self .norm_k .weight ,
397- self .num_key_value_heads ,
398- k_pe if k_pe is not None else pe ,
406+ k , self .norm_k .weight , self .num_key_value_heads , pe
399407 )
400408 else :
401409 if self .qk_norm :
402410 q = self .norm_q (q )
403411 k = self .norm_k (k )
404412 if pe is not None :
405413 q = apply_rotary_emb (q , pe , self .rope_type )
406- k_pe_use = k_pe if k_pe is not None else pe
407- if k_pe_use is not None :
408- k = apply_rotary_emb (k , k_pe_use , self .rope_type )
414+ k = apply_rotary_emb (k , pe , self .rope_type )
409415
410416 attn_kwargs = {}
411417 if key_padding_mask is not None :
@@ -746,7 +752,8 @@ def forward(
746752 perturbations: Optional ``BatchedPerturbationConfig`` that masks
747753 attention outputs for selected blocks/modalities.
748754 text_kv_video: Pre-projected (K, V) for video text cross-attention.
749- Falls back to inline computation if ``None``.
755+ Required when the video stream runs cross-attn — built by
756+ ``LTXModel.prepare_text_cache``.
750757 text_kv_audio: Pre-projected (K, V) for audio text cross-attention.
751758 """
752759 if video is None and audio is None :
@@ -885,7 +892,6 @@ def forward(
885892 vx_scaled ,
886893 pre_projected_kv = (k_a2v , v_a2v ),
887894 pe = video .cross_positional_embeddings ,
888- k_pe = None , # K already rotated in project_kv
889895 key_padding_mask = audio .audio_padding_mask ,
890896 )
891897 * gate_out_a2v
@@ -927,7 +933,6 @@ def forward(
927933 ax_scaled ,
928934 pre_projected_kv = (k_v2a , v_v2a ),
929935 pe = audio .cross_positional_embeddings ,
930- k_pe = None , # K already rotated in project_kv
931936 )
932937 * gate_out_v2a
933938 )
0 commit comments