@@ -27,54 +27,47 @@ TRTLLM_NAMESPACE_BEGIN
2727
2828namespace kernels
2929{
30- // / fp32 indexer TopK decode — L2-aware BS-threshold dispatcher with four
31- // / fallback tiers:
32- // / - GVR Heuristic (preIdx provided, kSeqSmall ≤ N < splitWork, BS < kBsLarge, K ∈ {512,1024,2048})
33- // / - Insertion sort (N < kSortingAlgorithmThreshold)
34- // / - Radix sort (kSortingAlgorithmThreshold ≤ N < splitWork)
35- // / - Radix split-work (N ≥ splitWork — uses outLogitsAux / outIndicesAux)
36- void invokeIndexerTopKDecode (float const * logits, int const * seqLens, int * indices, float * outLogitsAux,
37- int * outIndicesAux, int const splitWorkThreshold, int const numRows, int const numColumns, int const stride0,
38- int const stride1, int const next_n, int const topK = 2048 , int const * preIdx = nullptr , int const preIdxStride = 0 ,
39- int const preIdxCount = 0 , float * heuristicScratch = nullptr , cudaStream_t const stream = 0 );
40-
41- // / bf16 indexer TopK decode — same dispatch axes as the fp32 entry, except
42- // / kBsL2 uses sizeof(__nv_bfloat16) bytes/elem (L2 footprint is half) and
43- // / the split-work tier is unsupported (the bf16/fp16 entry does not expose
44- // / the float aux buffers required for split-work). Insertion + radix tiers
45- // / share topKPerRowDecode with fp32 — histogram and sort run on float keys
46- // / after a static_cast<float>(InputT) at HBM-read sites.
30+ // / Indexer TopK decode. Three tiers:
31+ // / - GVR Heuristic (preIdx provided, K in {512,1024,2048}, numColumns in
32+ // / [kSeqSmall, splitWorkThreshold), numRows below the
33+ // / architecture-derived wave/L2 bound).
34+ // / - Single-block (numColumns < split-work threshold)
35+ // / - Multi-pass radix (numColumns >= split-work threshold; requires
36+ // / `scratch` sized via indexerTopKDecodeScratchBytes; the buffer must be
37+ // / zero-initialized before the first call and may be reused afterwards).
4738// /
48- // / Aborts with TLLM_CHECK if numColumns ≥ splitWorkThreshold; callers in
49- // / that regime must use the fp32 entry.
39+ // / `is_prefill = true` forces the single-block tier (the split-work / multi-pass radix tier is suppressed).
40+ void invokeIndexerTopKDecode (float const * logits, int const * seqLens, int * indices, int const splitWorkThreshold,
41+ int const numRows, int const numColumns, int const stride0, int const stride1, int const next_n,
42+ int const topK = 2048 , int const * preIdx = nullptr , int const preIdxStride = 0 , int const preIdxCount = 0 ,
43+ float * heuristicScratch = nullptr , cudaStream_t const stream = 0 , void * scratch = nullptr , size_t scratchBytes = 0 ,
44+ bool is_prefill = false );
45+
46+ // / Size in bytes of the `scratch` buffer required by the multi-pass radix tier for the given shape.
47+ size_t indexerTopKDecodeScratchBytes (int numRows, int numColumns, int topK);
48+
49+ // / bf16 overload; same contract.
5050void invokeIndexerTopKDecode (__nv_bfloat16 const * logits, int const * seqLens, int * indices,
5151 int const splitWorkThreshold, int const numRows, int const numColumns, int const stride0, int const stride1,
5252 int const next_n, int const topK = 2048 , int const * preIdx = nullptr , int const preIdxStride = 0 ,
53- int const preIdxCount = 0 , __nv_bfloat16* heuristicScratch = nullptr , cudaStream_t const stream = 0 );
53+ int const preIdxCount = 0 , __nv_bfloat16* heuristicScratch = nullptr , cudaStream_t const stream = 0 ,
54+ void * scratch = nullptr , size_t scratchBytes = 0 , bool is_prefill = false );
5455
55- // / fp16 indexer TopK decode — see bf16 overload for dispatcher contract.
56+ // / fp16 overload; same contract.
5657void invokeIndexerTopKDecode (__half const * logits, int const * seqLens, int * indices, int const splitWorkThreshold,
5758 int const numRows, int const numColumns, int const stride0, int const stride1, int const next_n,
5859 int const topK = 2048 , int const * preIdx = nullptr , int const preIdxStride = 0 , int const preIdxCount = 0 ,
59- __half* heuristicScratch = nullptr , cudaStream_t const stream = 0 );
60+ __half* heuristicScratch = nullptr , cudaStream_t const stream = 0 , void * scratch = nullptr , size_t scratchBytes = 0 ,
61+ bool is_prefill = false );
6062
6163void invokeIndexerTopKPrefill (float const * logits, int const * rowStarts, int const * rowEnds, int * indices,
6264 int const numRows, int const numColumns, int const stride0, int const stride1, int const topK = 2048 ,
6365 cudaStream_t const stream = 0 );
6466
65- // / Returns true iff invokeIndexerTopKDecode would route to the GVR Heuristic
66- // / kernel for this (numRows, numColumns, topK) triple, assuming valid preIdx
67- // / is provided and stride1 == 1. Useful for callers that need to provision a
68- // / preIdx tensor or heuristicScratch buffer only when GVR will be selected.
69- // /
70- // / Mirrors the gating logic of the dispatcher: K ∈ {512, 1024, 2048},
71- // / numColumns ∈ [kSeqSmall, splitWorkThreshold), numRows < kBsLarge, where
72- // / kBsLarge = min(kBsWave, kBsL2) and kBsL2 scales with bytesPerElem.
73- // /
74- // / @param numRows logits rows (batch · next_n)
75- // / @param numColumns logits columns (max sequence length)
76- // / @param topK requested output size
77- // / @param bytesPerElem element size of logits (4 for fp32, 2 for bf16/fp16)
67+ // / True iff invokeIndexerTopKDecode would pick the GVR tier for this shape,
68+ // / assuming a valid preIdx is provided: K in {512,1024,2048}, numColumns in
69+ // / [kSeqSmall, splitWorkThreshold), and numRows below the architecture-derived
70+ // / wave/L2 bound. Lets callers provision preIdx / heuristicScratch only when needed.
7871bool canIndexerTopKDecodeUseGvr (int numRows, int numColumns, int topK, int bytesPerElem = 4 );
7972
8073} // namespace kernels
0 commit comments