Skip to content

Commit b206f68

Browse files
authored
[None][feat] Indexer TopK: single-block / multi-pass radix (#14268)
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
1 parent 31e730a commit b206f68

4 files changed

Lines changed: 765 additions & 364 deletions

File tree

cpp/tensorrt_llm/kernels/IndexerTopK.h

Lines changed: 28 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -27,54 +27,47 @@ TRTLLM_NAMESPACE_BEGIN
2727

2828
namespace kernels
2929
{
30-
/// fp32 indexer TopK decode — L2-aware BS-threshold dispatcher with four
31-
/// fallback tiers:
32-
/// - GVR Heuristic (preIdx provided, kSeqSmall ≤ N < splitWork, BS < kBsLarge, K ∈ {512,1024,2048})
33-
/// - Insertion sort (N < kSortingAlgorithmThreshold)
34-
/// - Radix sort (kSortingAlgorithmThreshold ≤ N < splitWork)
35-
/// - Radix split-work (N ≥ splitWork — uses outLogitsAux / outIndicesAux)
36-
void invokeIndexerTopKDecode(float const* logits, int const* seqLens, int* indices, float* outLogitsAux,
37-
int* outIndicesAux, int const splitWorkThreshold, int const numRows, int const numColumns, int const stride0,
38-
int const stride1, int const next_n, int const topK = 2048, int const* preIdx = nullptr, int const preIdxStride = 0,
39-
int const preIdxCount = 0, float* heuristicScratch = nullptr, cudaStream_t const stream = 0);
40-
41-
/// bf16 indexer TopK decode — same dispatch axes as the fp32 entry, except
42-
/// kBsL2 uses sizeof(__nv_bfloat16) bytes/elem (L2 footprint is half) and
43-
/// the split-work tier is unsupported (the bf16/fp16 entry does not expose
44-
/// the float aux buffers required for split-work). Insertion + radix tiers
45-
/// share topKPerRowDecode with fp32 — histogram and sort run on float keys
46-
/// after a static_cast<float>(InputT) at HBM-read sites.
30+
/// Indexer TopK decode. Three tiers:
31+
/// - GVR Heuristic (preIdx provided, K in {512,1024,2048}, numColumns in
32+
/// [kSeqSmall, splitWorkThreshold), numRows below the
33+
/// architecture-derived wave/L2 bound).
34+
/// - Single-block (numColumns < split-work threshold)
35+
/// - Multi-pass radix (numColumns >= split-work threshold; requires
36+
/// `scratch` sized via indexerTopKDecodeScratchBytes,
37+
///                       zero-initialized for the first call, and reusable afterwards).
4738
///
48-
/// Aborts with TLLM_CHECK if numColumns ≥ splitWorkThreshold; callers in
49-
/// that regime must use the fp32 entry.
39+
/// `is_prefill = true` forces the single-block tier (the split-work multi-pass radix tier is suppressed).
40+
void invokeIndexerTopKDecode(float const* logits, int const* seqLens, int* indices, int const splitWorkThreshold,
41+
int const numRows, int const numColumns, int const stride0, int const stride1, int const next_n,
42+
int const topK = 2048, int const* preIdx = nullptr, int const preIdxStride = 0, int const preIdxCount = 0,
43+
float* heuristicScratch = nullptr, cudaStream_t const stream = 0, void* scratch = nullptr, size_t scratchBytes = 0,
44+
bool is_prefill = false);
45+
46+
/// Size of the multi-pass radix `scratch` buffer for these shapes.
47+
size_t indexerTopKDecodeScratchBytes(int numRows, int numColumns, int topK);
48+
49+
/// bf16 overload; same contract.
5050
void invokeIndexerTopKDecode(__nv_bfloat16 const* logits, int const* seqLens, int* indices,
5151
int const splitWorkThreshold, int const numRows, int const numColumns, int const stride0, int const stride1,
5252
int const next_n, int const topK = 2048, int const* preIdx = nullptr, int const preIdxStride = 0,
53-
int const preIdxCount = 0, __nv_bfloat16* heuristicScratch = nullptr, cudaStream_t const stream = 0);
53+
int const preIdxCount = 0, __nv_bfloat16* heuristicScratch = nullptr, cudaStream_t const stream = 0,
54+
void* scratch = nullptr, size_t scratchBytes = 0, bool is_prefill = false);
5455

55-
/// fp16 indexer TopK decode — see bf16 overload for dispatcher contract.
56+
/// fp16 overload; same contract.
5657
void invokeIndexerTopKDecode(__half const* logits, int const* seqLens, int* indices, int const splitWorkThreshold,
5758
int const numRows, int const numColumns, int const stride0, int const stride1, int const next_n,
5859
int const topK = 2048, int const* preIdx = nullptr, int const preIdxStride = 0, int const preIdxCount = 0,
59-
__half* heuristicScratch = nullptr, cudaStream_t const stream = 0);
60+
__half* heuristicScratch = nullptr, cudaStream_t const stream = 0, void* scratch = nullptr, size_t scratchBytes = 0,
61+
bool is_prefill = false);
6062

6163
void invokeIndexerTopKPrefill(float const* logits, int const* rowStarts, int const* rowEnds, int* indices,
6264
int const numRows, int const numColumns, int const stride0, int const stride1, int const topK = 2048,
6365
cudaStream_t const stream = 0);
6466

65-
/// Returns true iff invokeIndexerTopKDecode would route to the GVR Heuristic
66-
/// kernel for this (numRows, numColumns, topK) triple, assuming valid preIdx
67-
/// is provided and stride1 == 1. Useful for callers that need to provision a
68-
/// preIdx tensor or heuristicScratch buffer only when GVR will be selected.
69-
///
70-
/// Mirrors the gating logic of the dispatcher: K ∈ {512, 1024, 2048},
71-
/// numColumns ∈ [kSeqSmall, splitWorkThreshold), numRows < kBsLarge, where
72-
/// kBsLarge = min(kBsWave, kBsL2) and kBsL2 scales with bytesPerElem.
73-
///
74-
/// @param numRows logits rows (batch · next_n)
75-
/// @param numColumns logits columns (max sequence length)
76-
/// @param topK requested output size
77-
/// @param bytesPerElem element size of logits (4 for fp32, 2 for bf16/fp16)
67+
/// True iff invokeIndexerTopKDecode would pick the GVR tier for this shape:
68+
/// K in {512,1024,2048}, numColumns in [kSeqSmall, splitWorkThreshold), and
69+
/// numRows below the architecture-derived wave/L2 bound. Lets callers
70+
/// provision preIdx / heuristicScratch only when needed. `bytesPerElem` is the logits element size (4 for fp32, 2 for bf16/fp16).
7871
bool canIndexerTopKDecodeUseGvr(int numRows, int numColumns, int topK, int bytesPerElem = 4);
7972

8073
} // namespace kernels

0 commit comments

Comments
 (0)