|
1 | 1 | /* |
2 | | - * Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved. |
| 2 | + * Copyright (c) 2019-2026, NVIDIA CORPORATION. All rights reserved. |
3 | 3 | * Copyright (c) 2021, NAVER Corp. Authored by CLOVA. |
4 | 4 | * |
5 | 5 | * Licensed under the Apache License, Version 2.0 (the "License"); |
@@ -27,47 +27,63 @@ TRTLLM_NAMESPACE_BEGIN |
27 | 27 |
|
28 | 28 | namespace kernels |
29 | 29 | { |
30 | | -/// Indexer TopK decode. Three tiers: |
31 | | -/// - GVR Heuristic (preIdx provided, K in {512,1024,2048}, numColumns in |
32 | | -/// [kSeqSmall, splitWorkThreshold), numRows below the |
33 | | -/// architecture-derived wave/L2 bound). |
34 | | -/// - Single-block (numColumns < split-work threshold) |
35 | | -/// - Multi-pass radix (numColumns >= split-work threshold; requires |
36 | | -/// `scratch` sized via indexerTopKDecodeScratchBytes, |
37 | | -/// zero-init on first call and may be reused). |
38 | | -/// |
39 | | -/// `is_prefill = true` forces single-block (split-work suppressed). |
40 | | -void invokeIndexerTopKDecode(float const* logits, int const* seqLens, int* indices, int const splitWorkThreshold, |
41 | | - int const numRows, int const numColumns, int const stride0, int const stride1, int const next_n, |
42 | | - int const topK = 2048, int const* preIdx = nullptr, int const preIdxStride = 0, int const preIdxCount = 0, |
43 | | - float* heuristicScratch = nullptr, cudaStream_t const stream = 0, void* scratch = nullptr, size_t scratchBytes = 0, |
44 | | - bool is_prefill = false); |
| 30 | +// Number of blocks-per-row used by the multi-block split + merge dispatch path of |
| 31 | +// invokeIndexerTopKDecode. Returns 1 when the single-block path is preferred. |
| 32 | +// Callers that allocate aux buffers must use this same helper to size them, and |
| 33 | +// must pass the same splitWorkThreshold they will pass to invokeIndexerTopKDecode |
| 34 | +// (a value <= 0 selects the internal default). |
| 35 | +int computeIndexerTopKDecodeBlocksPerRow(int numRows, int numColumns, int splitWorkThreshold = 0); |
45 | 36 |
|
46 | | -/// Size of the multi-pass radix `scratch` buffer for these shapes. |
47 | | -size_t indexerTopKDecodeScratchBytes(int numRows, int numColumns, int topK); |
| 37 | +/// fp32 indexer TopK decode — L2-aware BS-threshold dispatcher with four |
| 38 | +/// fallback tiers (N = numColumns, BS = numRows, K = topK, splitWork = splitWorkThreshold): |
| 39 | +/// - GVR Heuristic (preIdx provided, kSeqSmall ≤ N < splitWork, BS < kBsLarge, K ∈ {512,1024,2048}) |
| 40 | +/// - Insertion sort (N < kSortingAlgorithmThreshold) |
| 41 | +/// - Radix sort (kSortingAlgorithmThreshold ≤ N < splitWork) |
| 42 | +/// - Radix split-work (N ≥ splitWork — uses outLogitsAux / outIndicesAux) |
| 43 | +void invokeIndexerTopKDecode(float const* logits, int const* seqLens, int* indices, float* outLogitsAux, |
| 44 | + int* outIndicesAux, int const splitWorkThreshold, int const numRows, int const numColumns, int const stride0, |
| 45 | + int const stride1, int const next_n, int const topK = 2048, int const* preIdx = nullptr, int const preIdxStride = 0, |
| 46 | + int const preIdxCount = 0, float* heuristicScratch = nullptr, int const compressRatio = 1, |
| 47 | + cudaStream_t const stream = 0); |
48 | 48 |
|
49 | | -/// bf16 overload; same contract. |
| 49 | +/// bf16 indexer TopK decode — same dispatch axes as the fp32 entry, except |
| 50 | +/// kBsL2 uses sizeof(__nv_bfloat16) bytes/elem (L2 footprint is half) and |
| 51 | +/// the split-work tier is unsupported (the bf16/fp16 entry does not expose |
| 52 | +/// the float aux buffers required for split-work). Insertion + radix tiers |
| 53 | +/// share topKPerRowDecode with fp32 — histogram and sort run on float keys |
| 54 | +/// after a static_cast<float>(InputT) at HBM-read sites. |
| 55 | +/// |
| 56 | +/// Aborts with TLLM_CHECK if numColumns ≥ splitWorkThreshold; callers in |
| 57 | +/// that regime must use the fp32 entry. |
50 | 58 | void invokeIndexerTopKDecode(__nv_bfloat16 const* logits, int const* seqLens, int* indices, |
51 | 59 | int const splitWorkThreshold, int const numRows, int const numColumns, int const stride0, int const stride1, |
52 | 60 | int const next_n, int const topK = 2048, int const* preIdx = nullptr, int const preIdxStride = 0, |
53 | | - int const preIdxCount = 0, __nv_bfloat16* heuristicScratch = nullptr, cudaStream_t const stream = 0, |
54 | | - void* scratch = nullptr, size_t scratchBytes = 0, bool is_prefill = false); |
| 61 | + int const preIdxCount = 0, __nv_bfloat16* heuristicScratch = nullptr, int const compressRatio = 1, |
| 62 | + cudaStream_t const stream = 0); |
55 | 63 |
|
56 | | -/// fp16 overload; same contract. |
| 64 | +/// fp16 indexer TopK decode — see bf16 overload for dispatcher contract. |
57 | 65 | void invokeIndexerTopKDecode(__half const* logits, int const* seqLens, int* indices, int const splitWorkThreshold, |
58 | 66 | int const numRows, int const numColumns, int const stride0, int const stride1, int const next_n, |
59 | 67 | int const topK = 2048, int const* preIdx = nullptr, int const preIdxStride = 0, int const preIdxCount = 0, |
60 | | - __half* heuristicScratch = nullptr, cudaStream_t const stream = 0, void* scratch = nullptr, size_t scratchBytes = 0, |
61 | | - bool is_prefill = false); |
| 68 | + __half* heuristicScratch = nullptr, int const compressRatio = 1, cudaStream_t const stream = 0); |
62 | 69 |
|
63 | 70 | void invokeIndexerTopKPrefill(float const* logits, int const* rowStarts, int const* rowEnds, int* indices, |
64 | 71 | int const numRows, int const numColumns, int const stride0, int const stride1, int const topK = 2048, |
65 | 72 | cudaStream_t const stream = 0); |
66 | 73 |
|
67 | | -/// True iff invokeIndexerTopKDecode would pick the GVR tier for this shape: |
68 | | -/// K in {512,1024,2048}, numColumns in [kSeqSmall, splitWorkThreshold), and |
69 | | -/// numRows below the architecture-derived wave/L2 bound. Lets callers |
70 | | -/// provision preIdx / heuristicScratch only when needed. |
| 74 | +/// Returns true iff invokeIndexerTopKDecode would route to the GVR Heuristic |
| 75 | +/// kernel for this (numRows, numColumns, topK, bytesPerElem) shape, assuming valid preIdx |
| 76 | +/// is provided and stride1 == 1. Useful for callers that need to provision a |
| 77 | +/// preIdx tensor or heuristicScratch buffer only when GVR will be selected. |
| 78 | +/// |
| 79 | +/// Mirrors the gating logic of the dispatcher: K ∈ {512, 1024, 2048}, |
| 80 | +/// numColumns ∈ [kSeqSmall, splitWorkThreshold), numRows < kBsLarge, where |
| 81 | +/// kBsLarge = min(kBsWave, kBsL2) and kBsL2 scales with bytesPerElem. |
| 82 | +/// |
| 83 | +/// @param numRows logits rows (batch · next_n) |
| 84 | +/// @param numColumns logits columns (max sequence length) |
| 85 | +/// @param topK requested output size |
| 86 | +/// @param bytesPerElem element size of logits (4 for fp32, 2 for bf16/fp16) |
71 | 87 | bool canIndexerTopKDecodeUseGvr(int numRows, int numColumns, int topK, int bytesPerElem = 4); |
72 | 88 |
|
73 | 89 | } // namespace kernels |
|
0 commit comments