From b6e6ae09f83c63e6e8567af05ee89ead2c430b46 Mon Sep 17 00:00:00 2001
From: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
Date: Fri, 3 Apr 2026 14:39:36 +0000
Subject: [PATCH] retune causalConv1d fwd dispatch for varlen and short sequences

Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
---
 .../kernels/causalConv1d/causalConv1d.cu      | 33 ++++++++++++++++++++++++++++++---
 1 file changed, 29 insertions(+), 4 deletions(-)

diff --git a/cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu b/cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu
index a5f22858ac6..faa1f2d9fca 100644
--- a/cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu
+++ b/cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu
@@ -4,7 +4,7 @@
  * and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
  * Copyright (c) 2024, Tri Dao.
  *
- * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -349,20 +349,45 @@ void causal_conv1d_fwd_launch(ConvParamsBase& params, cudaStream_t stream)
         });
 }
 
+template <int kWidth, typename input_t, typename weight_t>
+void causal_conv1d_fwd_dispatch(ConvParamsBase& params, cudaStream_t stream)
+{
+    bool const isVarlen = params.query_start_loc_ptr != nullptr;
+    constexpr int kNarrowThreads = 64;
+    constexpr int kWideThreads = 128;
+    constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
+    constexpr int kShortSeqThreshold = kNarrowThreads * kNElts;
+    // Varlen prefill launches one block per sequence/channel pair, so the per-sequence
+    // work is usually much smaller than params.seqlen suggests. That path also disables
+    // the wide vector-load specialization, so the 128-thread kernel tends to overprovision
+    // threads for many short chunks. Prefer the narrower launch for varlen and for short
+    // fixed-length inputs; keep the wider launch for long dense sequences.
+    bool const preferNarrowKernel = isVarlen || params.seqlen <= kShortSeqThreshold;
+
+    if (preferNarrowKernel)
+    {
+        causal_conv1d_fwd_launch<kNarrowThreads, kWidth, input_t, weight_t>(params, stream);
+    }
+    else
+    {
+        causal_conv1d_fwd_launch<kWideThreads, kWidth, input_t, weight_t>(params, stream);
+    }
+}
+
 template <typename input_t, typename weight_t>
 void causal_conv1d_fwd_cuda(ConvParamsBase& params, cudaStream_t stream)
 {
     if (params.width == 2)
     {
-        causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
+        causal_conv1d_fwd_dispatch<2, input_t, weight_t>(params, stream);
     }
     else if (params.width == 3)
     {
-        causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
+        causal_conv1d_fwd_dispatch<3, input_t, weight_t>(params, stream);
     }
     else if (params.width == 4)
     {
-        causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
+        causal_conv1d_fwd_dispatch<4, input_t, weight_t>(params, stream);
     }
 }
 