|
4 | 4 | * and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu |
5 | 5 | * Copyright (c) 2024, Tri Dao. |
6 | 6 | * |
7 | | - * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. |
| 7 | + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. |
8 | 8 | * |
9 | 9 | * Licensed under the Apache License, Version 2.0 (the "License"); |
10 | 10 | * you may not use this file except in compliance with the License. |
@@ -349,20 +349,45 @@ void causal_conv1d_fwd_launch(ConvParamsBase& params, cudaStream_t stream) |
349 | 349 | }); |
350 | 350 | } |
351 | 351 |
|
// Select a launch configuration (thread count) for the forward conv1d kernel
// based on the shape of the work, then forward to causal_conv1d_fwd_launch.
//
// kWidth    - compile-time filter width (2, 3, or 4).
// input_t   - element type of the input tensor (drives per-thread vector width).
// weight_t  - element type of the filter weights.
template <int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_dispatch(ConvParamsBase& params, cudaStream_t stream)
{
    // Candidate block sizes for the kernel launch.
    constexpr int kNarrowThreads = 64;
    constexpr int kWideThreads = 128;
    // Elements each thread handles: 4 for 4-byte element types, 8 otherwise.
    constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
    // A sequence that fits within one narrow block's worth of elements is "short".
    constexpr int kShortSeqThreshold = kNarrowThreads * kNElts;

    // Varlen prefill launches one block per sequence/channel pair, so the
    // per-sequence work is usually much smaller than params.seqlen suggests.
    // That path also disables the wide vector-load specialization, so the
    // 128-thread kernel tends to overprovision threads for many short chunks.
    // Prefer the narrower launch for varlen and for short fixed-length inputs;
    // keep the wider launch for long dense sequences.
    if (params.query_start_loc_ptr != nullptr || params.seqlen <= kShortSeqThreshold)
    {
        causal_conv1d_fwd_launch<kNarrowThreads, kWidth, input_t, weight_t>(params, stream);
    }
    else
    {
        causal_conv1d_fwd_launch<kWideThreads, kWidth, input_t, weight_t>(params, stream);
    }
}
| 376 | + |
// Entry point for the forward causal conv1d on CUDA: maps the runtime filter
// width onto the compile-time kWidth template parameter required by the
// dispatch helper. Widths other than 2, 3, or 4 are left unhandled (no-op),
// matching the original if/else chain's behavior.
template <typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase& params, cudaStream_t stream)
{
    switch (params.width)
    {
    case 2:
        causal_conv1d_fwd_dispatch<2, input_t, weight_t>(params, stream);
        break;
    case 3:
        causal_conv1d_fwd_dispatch<3, input_t, weight_t>(params, stream);
        break;
    case 4:
        causal_conv1d_fwd_dispatch<4, input_t, weight_t>(params, stream);
        break;
    default:
        // Unsupported widths fall through silently, as before; callers are
        // presumably expected to validate width upstream — TODO confirm.
        break;
    }
}
368 | 393 |
|
|
0 commit comments