Skip to content

Commit dc3a468

Browse files
committed
Optimize qwen3.5 prefill kernels
- update chunked Gated Delta Rule prefill to use indexed in-kernel state updates
- remove explicit Qwen3Next prefill state gather/scatter in forward_extend
- retune causalConv1d forward launch selection for varlen and short sequences

Signed-off-by: Anthony Chang <27950904+rosenrodt@users.noreply.github.com>
1 parent 9a2c39d commit dc3a468

4 files changed

Lines changed: 91 additions & 19 deletions

File tree

cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
* and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
55
* Copyright (c) 2024, Tri Dao.
66
*
7-
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
7+
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
88
*
99
* Licensed under the Apache License, Version 2.0 (the "License");
1010
* you may not use this file except in compliance with the License.
@@ -349,20 +349,45 @@ void causal_conv1d_fwd_launch(ConvParamsBase& params, cudaStream_t stream)
349349
});
350350
}
351351

352+
template <int kWidth, typename input_t, typename weight_t>
353+
void causal_conv1d_fwd_dispatch(ConvParamsBase& params, cudaStream_t stream)
354+
{
355+
bool const isVarlen = params.query_start_loc_ptr != nullptr;
356+
constexpr int kNarrowThreads = 64;
357+
constexpr int kWideThreads = 128;
358+
constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
359+
constexpr int kShortSeqThreshold = kNarrowThreads * kNElts;
360+
// Varlen prefill launches one block per sequence/channel pair, so the per-sequence
361+
// work is usually much smaller than params.seqlen suggests. That path also disables
362+
// the wide vector-load specialization, so the 128-thread kernel tends to overprovision
363+
// threads for many short chunks. Prefer the narrower launch for varlen and for short
364+
// fixed-length inputs; keep the wider launch for long dense sequences.
365+
bool const preferNarrowKernel = isVarlen || params.seqlen <= kShortSeqThreshold;
366+
367+
if (preferNarrowKernel)
368+
{
369+
causal_conv1d_fwd_launch<kNarrowThreads, kWidth, input_t, weight_t>(params, stream);
370+
}
371+
else
372+
{
373+
causal_conv1d_fwd_launch<kWideThreads, kWidth, input_t, weight_t>(params, stream);
374+
}
375+
}
376+
352377
template <typename input_t, typename weight_t>
353378
void causal_conv1d_fwd_cuda(ConvParamsBase& params, cudaStream_t stream)
354379
{
355380
if (params.width == 2)
356381
{
357-
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
382+
causal_conv1d_fwd_dispatch<2, input_t, weight_t>(params, stream);
358383
}
359384
else if (params.width == 3)
360385
{
361-
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
386+
causal_conv1d_fwd_dispatch<3, input_t, weight_t>(params, stream);
362387
}
363388
else if (params.width == 4)
364389
{
365-
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
390+
causal_conv1d_fwd_dispatch<4, input_t, weight_t>(params, stream);
366391
}
367392
}
368393

tensorrt_llm/_torch/models/modeling_qwen3_next.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -747,23 +747,22 @@ def forward_extend(
747747
g = g.unsqueeze(0)
748748
beta = beta.unsqueeze(0)
749749

750-
recurrent_state = ssm_states[cache_indices]
751-
752-
core_attn_out, last_recurrent_state = chunk_gated_delta_rule(
750+
core_attn_out, _ = chunk_gated_delta_rule(
753751
q=query,
754752
k=key,
755753
v=value,
756754
g=g,
757755
beta=beta,
758-
initial_state=recurrent_state,
759-
output_final_state=True,
756+
initial_state=ssm_states,
757+
initial_state_indices=cache_indices,
758+
# This path writes recurrent state directly back into the shared
759+
# pool; callers **must** ensure cache_indices do not alias live slots.
760+
inplace_indexed_state_update=True,
761+
output_final_state=False,
760762
cu_seqlens=query_start_loc_long,
761763
head_first=False,
762764
use_qk_l2norm_in_kernel=True,
763765
)
764-
last_recurrent_state = last_recurrent_state.to(ssm_states.dtype,
765-
copy=False)
766-
ssm_states[cache_indices] = last_recurrent_state
767766

768767
return core_attn_out
769768

tensorrt_llm/_torch/modules/fla/chunk.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def chunk_gated_delta_rule_fwd(
2929
beta: torch.Tensor,
3030
scale: float,
3131
initial_state: torch.Tensor,
32+
initial_state_indices: Optional[torch.Tensor],
33+
inplace_indexed_state_update: bool,
3234
output_final_state: bool,
3335
cu_seqlens: Optional[torch.LongTensor] = None,
3436
):
@@ -54,7 +56,9 @@ def chunk_gated_delta_rule_fwd(
5456
u=u,
5557
g=g,
5658
initial_state=initial_state,
59+
initial_state_indices=initial_state_indices,
5760
output_final_state=output_final_state,
61+
inplace_indexed_state_update=inplace_indexed_state_update,
5862
cu_seqlens=cu_seqlens,
5963
)
6064
o = chunk_fwd_o(
@@ -86,6 +90,8 @@ def forward(
8690
beta: torch.Tensor,
8791
scale: float,
8892
initial_state: torch.Tensor,
93+
initial_state_indices: Optional[torch.Tensor],
94+
inplace_indexed_state_update: bool,
8995
output_final_state: bool,
9096
cu_seqlens: Optional[torch.LongTensor] = None,
9197
use_qk_l2norm_in_kernel: bool = False,
@@ -102,6 +108,8 @@ def forward(
102108
beta=beta,
103109
scale=scale,
104110
initial_state=initial_state,
111+
initial_state_indices=initial_state_indices,
112+
inplace_indexed_state_update=inplace_indexed_state_update,
105113
output_final_state=output_final_state,
106114
cu_seqlens=cu_seqlens,
107115
)
@@ -117,6 +125,8 @@ def chunk_gated_delta_rule(
117125
beta: torch.Tensor,
118126
scale: float = None,
119127
initial_state: torch.Tensor = None,
128+
initial_state_indices: Optional[torch.Tensor] = None,
129+
inplace_indexed_state_update: bool = False,
120130
output_final_state: bool = False,
121131
cu_seqlens: Optional[torch.LongTensor] = None,
122132
head_first: bool = False,
@@ -141,6 +151,13 @@ def chunk_gated_delta_rule(
141151
Initial state of shape `[N, H, K, V]` for `N` input sequences.
142152
For equal-length input sequences, `N` equals the batch size `B`.
143153
Default: `None`.
154+
initial_state_indices (Optional[torch.Tensor]):
155+
Optional state-pool indices of shape `[N]` selecting the slots to
156+
read from `initial_state`.
157+
inplace_indexed_state_update (Optional[bool]):
158+
Explicit opt-in for writing indexed final states back into
159+
`initial_state` in-place. Callers are responsible for ensuring the
160+
selected slots are safe to update without aliasing races.
144161
output_final_state (Optional[bool]):
145162
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
146163
cu_seqlens (torch.LongTensor):
@@ -211,12 +228,18 @@ def chunk_gated_delta_rule(
211228
raise ValueError(
212229
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
213230
f"Please flatten variable-length inputs before processing.")
214-
if initial_state is not None and initial_state.shape[0] != len(
215-
cu_seqlens) - 1:
231+
num_sequences = len(cu_seqlens) - 1
232+
if initial_state_indices is not None:
233+
if initial_state_indices.shape[0] != num_sequences:
234+
raise ValueError(
235+
f"The number of initial-state indices is expected to be equal to the number of input "
236+
f"sequences, i.e., {num_sequences} rather than {initial_state_indices.shape[0]}."
237+
)
238+
elif initial_state is not None and initial_state.shape[
239+
0] != num_sequences:
216240
raise ValueError(
217241
f"The number of initial states is expected to be equal to the number of input sequences, "
218-
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
219-
)
242+
f"i.e., {num_sequences} rather than {initial_state.shape[0]}.")
220243
if scale is None:
221244
scale = k.shape[-1]**-0.5
222245
o, final_state = ChunkGatedDeltaRuleFunction.apply(
@@ -227,6 +250,8 @@ def chunk_gated_delta_rule(
227250
beta,
228251
scale,
229252
initial_state,
253+
initial_state_indices,
254+
inplace_indexed_state_update,
230255
output_final_state,
231256
cu_seqlens,
232257
use_qk_l2norm_in_kernel,

tensorrt_llm/_torch/modules/fla/chunk_delta_h.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
@triton.heuristics({
2020
"USE_G": lambda args: args["g"] is not None,
2121
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
22+
"USE_INDEXED_STATE": lambda args: args["h0_i"] is not None,
2223
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
2324
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
2425
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
@@ -42,6 +43,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
4243
g,
4344
h,
4445
h0,
46+
h0_i,
4547
ht,
4648
cu_seqlens,
4749
chunk_offsets,
@@ -54,6 +56,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
5456
BV: tl.constexpr,
5557
USE_G: tl.constexpr,
5658
USE_INITIAL_STATE: tl.constexpr,
59+
USE_INDEXED_STATE: tl.constexpr,
5760
STORE_FINAL_STATE: tl.constexpr,
5861
SAVE_NEW_VALUE: tl.constexpr,
5962
IS_VARLEN: tl.constexpr,
@@ -91,10 +94,16 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
9194
stride_h = H * K * V
9295
stride_k = Hg * K
9396
stride_w = H * K
97+
if USE_INDEXED_STATE:
98+
state_index = tl.load(h0_i + i_n).to(tl.int64)
99+
h0 = h0 + state_index * stride_h
100+
ht = h0
94101
if USE_INITIAL_STATE:
95-
h0 = h0 + i_nh * K * V
102+
h0 = h0 + ((i_h if USE_INDEXED_STATE else i_nh) * K * V)
96103
if STORE_FINAL_STATE:
97104
ht = ht + i_nh * K * V
105+
elif USE_INDEXED_STATE:
106+
ht = ht + i_h * K * V
98107

99108
# load initial state
100109
if USE_INITIAL_STATE:
@@ -209,7 +218,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
209218
b_h4 += tl.dot(b_k, b_v_new)
210219

211220
# epilogue
212-
if STORE_FINAL_STATE:
221+
if STORE_FINAL_STATE or USE_INDEXED_STATE:
213222
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV),
214223
(1, 0))
215224
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
@@ -239,7 +248,9 @@ def chunk_gated_delta_rule_fwd_h(
239248
u: torch.Tensor,
240249
g: Optional[torch.Tensor] = None,
241250
initial_state: Optional[torch.Tensor] = None,
251+
initial_state_indices: Optional[torch.Tensor] = None,
242252
output_final_state: bool = False,
253+
inplace_indexed_state_update: bool = False,
243254
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
244255
save_new_value: bool = True,
245256
cu_seqlens: Optional[torch.LongTensor] = None,
@@ -262,8 +273,14 @@ def chunk_gated_delta_rule_fwd_h(
262273
assert K <= 256, "current kernel does not support head dimension larger than 256."
263274

264275
h = k.new_empty(B, NT, H, K, V)
276+
use_indexed_state = initial_state is not None and initial_state_indices is not None
277+
if use_indexed_state and not inplace_indexed_state_update:
278+
raise ValueError(
279+
"Indexed chunk state updates require inplace_indexed_state_update=True."
280+
)
281+
store_final_state_in_kernel = output_final_state and not use_indexed_state
265282
final_state = (k.new_empty(N, H, K, V, dtype=torch.float32)
266-
if output_final_state else None)
283+
if store_final_state_in_kernel else None)
267284

268285
v_new = torch.empty_like(u) if save_new_value else None
269286

@@ -278,6 +295,7 @@ def grid(meta):
278295
g=g,
279296
h=h,
280297
h0=initial_state,
298+
h0_i=initial_state_indices,
281299
ht=final_state,
282300
cu_seqlens=cu_seqlens,
283301
chunk_offsets=chunk_offsets,
@@ -291,4 +309,9 @@ def grid(meta):
291309
num_warps=4,
292310
num_stages=2,
293311
)
312+
if output_final_state and use_indexed_state:
313+
# The indexed kernel path updates h0 in-place, so returning
314+
# the final state means gathering those updated slots back out.
315+
final_state = initial_state.index_select(
316+
0, initial_state_indices.to(torch.long))
294317
return h, v_new, final_state

0 commit comments

Comments (0)