Skip to content

Commit fea90fa

Browse files
nv-guomingz and suyoggupta
authored and committed
[None][feat] retune causalConv1d fwd dispatch for varlen and short sequences (NVIDIA#12739)
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
1 parent 0bcc709 commit fea90fa

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
* and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
55
* Copyright (c) 2024, Tri Dao.
66
*
7-
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
7+
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
88
*
99
* Licensed under the Apache License, Version 2.0 (the "License");
1010
* you may not use this file except in compliance with the License.
@@ -349,20 +349,45 @@ void causal_conv1d_fwd_launch(ConvParamsBase& params, cudaStream_t stream)
349349
});
350350
}
351351

352+
template <int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_dispatch(ConvParamsBase& params, cudaStream_t stream)
{
    // Candidate block sizes for the forward kernel launch.
    constexpr int kNarrowThreads = 64;
    constexpr int kWideThreads = 128;
    // Elements processed per thread: 4 for 4-byte inputs, 8 for 16-bit inputs.
    constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
    // A fixed-length sequence at or below this bound fits within one narrow
    // block's worth of per-thread elements.
    constexpr int kShortSeqThreshold = kNarrowThreads * kNElts;

    // Varlen prefill launches one block per sequence/channel pair, so the per-sequence
    // work is usually much smaller than params.seqlen suggests. That path also disables
    // the wide vector-load specialization, so the 128-thread kernel tends to overprovision
    // threads for many short chunks. Reserve the wider launch for long, dense,
    // fixed-length inputs; everything else takes the narrower launch.
    bool const useWideKernel = params.query_start_loc_ptr == nullptr && params.seqlen > kShortSeqThreshold;

    if (useWideKernel)
    {
        causal_conv1d_fwd_launch<kWideThreads, kWidth, input_t, weight_t>(params, stream);
    }
    else
    {
        causal_conv1d_fwd_launch<kNarrowThreads, kWidth, input_t, weight_t>(params, stream);
    }
}
376+
352377
template <typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase& params, cudaStream_t stream)
{
    // Route to the dispatch helper specialized on the runtime filter width.
    // Widths outside {2, 3, 4} fall through with no launch, matching the
    // original if/else chain — presumably width is validated upstream
    // (TODO confirm with callers).
    switch (params.width)
    {
    case 2: causal_conv1d_fwd_dispatch<2, input_t, weight_t>(params, stream); break;
    case 3: causal_conv1d_fwd_dispatch<3, input_t, weight_t>(params, stream); break;
    case 4: causal_conv1d_fwd_dispatch<4, input_t, weight_t>(params, stream); break;
    default: break;
    }
}
368393

0 commit comments

Comments (0)