Skip to content

Commit b1dfd30

Browse files
authored
[TRTLLM-12653][feat] LTX-2 Ulysses cross-attention for v2a with audio padding (#14044)
Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
1 parent ecb1b44 commit b1dfd30

9 files changed

Lines changed: 924 additions & 52 deletions

File tree

tensorrt_llm/_torch/visual_gen/attention_backend/flash_attn4.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,14 @@ def _fwd(
7171
k: torch.Tensor,
7272
v: torch.Tensor,
7373
causal: bool,
74+
seqused_k: Optional[torch.Tensor] = None,
7475
) -> Tuple[torch.Tensor, torch.Tensor]:
7576
"""Calls _flash_attn_fwd with torch.compile disabled. Returns (output, lse)."""
7677
output, lse = _flash_attn_fwd(
7778
q,
7879
k,
7980
v,
81+
seqused_k=seqused_k,
8082
softmax_scale=self.scale,
8183
causal=causal,
8284
window_size_left=None,
@@ -120,6 +122,7 @@ def forward(
120122
v: torch.Tensor,
121123
*,
122124
attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.FULL,
125+
key_padding_mask: Optional[torch.Tensor] = None,
123126
**kwargs,
124127
) -> torch.Tensor:
125128
"""
@@ -132,11 +135,21 @@ def forward(
132135
k: Key tensor [batch_size, seq_len_kv, num_kv_heads, head_dim]
133136
v: Value tensor [batch_size, seq_len_kv, num_kv_heads, head_dim]
134137
attention_mask: Attention mask type (CAUSAL or FULL)
138+
key_padding_mask: Optional ``[B, S_kv]`` bool tensor; True = valid,
139+
False = pad. Translated to FA4's ``seqused_k = mask.sum(dim=1)``
140+
(assumes True-prefix layout). Non-causal only.
135141
136142
Returns:
137143
Output tensor [batch_size, seq_len, num_heads, head_dim]
138144
"""
139-
output, _ = self.forward_with_lse(q, k, v, attention_mask=attention_mask, **kwargs)
145+
output, _ = self.forward_with_lse(
146+
q,
147+
k,
148+
v,
149+
attention_mask=attention_mask,
150+
key_padding_mask=key_padding_mask,
151+
**kwargs,
152+
)
140153
return output
141154

142155
def forward_with_lse(
@@ -145,6 +158,7 @@ def forward_with_lse(
145158
k: torch.Tensor,
146159
v: torch.Tensor,
147160
attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.FULL,
161+
key_padding_mask: Optional[torch.Tensor] = None,
148162
**kwargs,
149163
) -> Tuple[torch.Tensor, torch.Tensor]:
150164
"""
@@ -157,7 +171,20 @@ def forward_with_lse(
157171
partial attention results in Attention2D parallelism.
158172
"""
159173
q, k, v, is_causal, origin_dtype = self._prepare_inputs(q, k, v, attention_mask)
160-
output, lse = self._fwd(q, k, v, is_causal)
174+
seqused_k = None
175+
if key_padding_mask is not None:
176+
assert not is_causal, "key_padding_mask is not supported with causal attention"
177+
assert key_padding_mask.dim() == 2 and key_padding_mask.shape == (
178+
q.shape[0],
179+
k.shape[1],
180+
), (
181+
f"Invalid key_padding_mask shape: expected [B={q.shape[0]}, "
182+
f"S_kv={k.shape[1]}], got {tuple(key_padding_mask.shape)}"
183+
)
184+
# FA4 seqused_k assumes a True-prefix layout: positions [0, valid)
185+
# are kept, [valid, S_kv) are masked. mask.sum gives the prefix length.
186+
seqused_k = key_padding_mask.sum(dim=1).to(torch.int32)
187+
output, lse = self._fwd(q, k, v, is_causal, seqused_k=seqused_k)
161188
if output.dtype != origin_dtype:
162189
output = output.to(origin_dtype)
163190
return output, lse

tensorrt_llm/_torch/visual_gen/attention_backend/vanilla.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def forward(
7171
v: torch.Tensor,
7272
*,
7373
attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.FULL,
74+
key_padding_mask: Optional[torch.Tensor] = None,
7475
**kwargs,
7576
) -> torch.Tensor:
7677
"""
@@ -83,6 +84,9 @@ def forward(
8384
k: Key tensor [batch_size, num_kv_heads, seq_len_kv, head_dim]
8485
v: Value tensor [batch_size, num_kv_heads, seq_len_kv, head_dim]
8586
attention_mask: Attention mask type (CAUSAL or FULL)
87+
key_padding_mask: Optional ``[B, S_kv]`` bool tensor; True = valid,
88+
False = pad. Expanded internally to ``[B, 1, 1, S_kv]`` and
89+
passed as ``attn_mask`` to SDPA. Non-causal only.
8690
8791
Returns:
8892
Output tensor [batch_size, num_heads, seq_len, head_dim]
@@ -99,13 +103,24 @@ def forward(
99103
f"Invalid v shape: expected [B={q.shape[0]}, H_kv, S_kv, D={self.head_dim}], got {v.shape}"
100104
)
101105

106+
enable_gqa = self.num_heads != self.num_kv_heads
107+
if key_padding_mask is not None:
108+
assert not is_causal, "key_padding_mask is not supported with causal attention"
109+
assert key_padding_mask.dim() == 2 and key_padding_mask.shape == (
110+
q.shape[0],
111+
k.shape[2],
112+
), (
113+
f"Invalid key_padding_mask shape: expected [B={q.shape[0]}, "
114+
f"S_kv={k.shape[2]}], got {tuple(key_padding_mask.shape)}"
115+
)
116+
# [B, S_kv] -> [B, 1, 1, S_kv] so SDPA broadcasts over H and S_q.
117+
attn_mask = key_padding_mask[:, None, None, :]
118+
return F.scaled_dot_product_attention(
119+
q, k, v, attn_mask=attn_mask, scale=self.scale, enable_gqa=enable_gqa
120+
)
121+
102122
return F.scaled_dot_product_attention(
103-
q,
104-
k,
105-
v,
106-
is_causal=is_causal,
107-
scale=self.scale,
108-
enable_gqa=self.num_heads != self.num_kv_heads,
123+
q, k, v, is_causal=is_causal, scale=self.scale, enable_gqa=enable_gqa
109124
)
110125

111126
@property

tensorrt_llm/_torch/visual_gen/models/ltx2/ltx2_core/transformer_args.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ class TransformerArgs:
3232
cross_scale_shift_timestep: torch.Tensor | None
3333
cross_gate_timestep: torch.Tensor | None
3434
enabled: bool
35+
# Optional [B, S_full_padded] bool mask (True=valid, False=pad) for the
36+
# audio modality when Ulysses padding is engaged (T_a padded to be
37+
# divisible by ulysses_size). Identical across Ulysses ranks (full-seq).
38+
# None when no padding is applied.
39+
audio_padding_mask: torch.Tensor | None = None
3540

3641

3742
class TransformerArgsPreprocessor:

0 commit comments

Comments
 (0)