-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Expand file tree
/
Copy pathdynamicTreeKernels.cu
More file actions
1412 lines (1258 loc) · 56.8 KB
/
Copy pathdynamicTreeKernels.cu
File metadata and controls
1412 lines (1258 loc) · 56.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Copyright (c) 2024-2026, NVIDIA CORPORATION. All rights reserved.
* Portions Copyright (c) 2025 by SGLang team (original implementation).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif
#include "dynamicTreeKernels.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include "tensorrt_llm/common/vec_dtypes.cuh"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include <ATen/cuda/CUDAContext.h>
#include <algorithm>
#include <cfloat>
#include <cstdint>
#include <limits>
#include <torch/extension.h>
TRTLLM_NAMESPACE_BEGIN
using namespace tensorrt_llm::common;
using namespace tensorrt_llm::runtime;
namespace kernels::speculative_decoding
{
// ---------------------------------------------------------------------------
// Two-stage top-k / top-p masking kernels
// Mirrors the approach in invokeBatchTopKSampling (samplingTopKKernels.cu),
// but outputs a masked logits tensor instead of sampling a token.
// ---------------------------------------------------------------------------
// Stage 1: per-block candidate selection.
// The grid holds BLOCKS_PER_BEAM_ blocks per logits row (blockIdx.x / BLOCKS_PER_BEAM_ is the
// row, blockIdx.x % BLOCKS_PER_BEAM_ is the block's lane within that row). A block visits the
// row in interleaved stripes of BLOCK_SIZE_ elements (stride BLOCK_SIZE_ * BLOCKS_PER_BEAM_)
// and performs k rounds of block-wide arg-max over its stripes. Each round appends one
// (flat index, value) pair to topKTmpIdBuf / topKTmpValBuf and overwrites the winner in the
// tmpLogProbs scratch copy so that it cannot be selected again.
//
//   logits        [in]  flat (nRows * vocabSize) input, read-only
//   tmpLogProbs   [tmp] flat (nRows * vocabSize) scratch copy that is destroyed by the selection
//   topKTmpIdBuf  [out] flat index (row * vocabSize + vocabIdx) of each candidate, -1 if the
//                       round found nothing selectable
//   topKTmpValBuf [out] value of each candidate
//   maxTopK             row pitch of the tmp buffers is BLOCKS_PER_BEAM_ * maxTopK
//   topKs               optional per-row k; maxTopK is used for every row when nullptr
template <typename T, int32_t BLOCK_SIZE_, int32_t BLOCKS_PER_BEAM_>
__global__ void topKProbStage1(T const* __restrict__ logits, T* tmpLogProbs, int32_t* topKTmpIdBuf, T* topKTmpValBuf,
    int32_t maxTopK, int32_t const* topKs, int32_t vocabSize)
{
    using BlockReduce = cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_>;
    __shared__ typename BlockReduce::TempStorage reduceStorage;

    // Largest finite magnitude of T; its negation marks an element as already taken.
    T const maxMagnitude = std::is_same<T, half>::value ? HALF_FLT_MAX : FLT_MAX;

    int32_t const lane = static_cast<int32_t>(threadIdx.x);
    int32_t const block = static_cast<int32_t>(blockIdx.x);
    int32_t const row = block / BLOCKS_PER_BEAM_;
    int32_t const rowLane = block % BLOCKS_PER_BEAM_;
    int32_t const rowK = (topKs == nullptr) ? maxTopK : topKs[row];

    // First element of this row inside the flat logits / scratch arrays.
    int32_t const rowBase = row * vocabSize;
    // First output slot of this (row, rowLane) pair inside the tmp buffers.
    int32_t const outBase = row * BLOCKS_PER_BEAM_ * maxTopK + rowLane * rowK;

    int32_t const firstElem = lane + rowLane * BLOCK_SIZE_;
    constexpr int32_t kElemStride = BLOCK_SIZE_ * BLOCKS_PER_BEAM_;

    // Take a private, mutable copy of the stripes owned by this block.
    for (int32_t elem = firstElem; elem < vocabSize; elem += kElemStride)
    {
        int32_t const flat = rowBase + elem;
        tmpLogProbs[flat] = logits[flat];
    }
    __syncthreads();

    // One block-wide arg-max per requested candidate.
    for (int32_t round = 0; round < rowK; ++round)
    {
        TopK_2<T> best;
        best.init();
        for (int32_t elem = firstElem; elem < vocabSize; elem += kElemStride)
        {
            int32_t const flat = rowBase + elem;
            best.insert(tmpLogProbs[flat], flat);
        }

        TopK_2<T> const winner = BlockReduce(reduceStorage).Reduce(best, reduce_topk_op_2<T>);

        if (lane == 0)
        {
            int32_t const slot = outBase + round;
            topKTmpIdBuf[slot] = winner.p;  // flat index (rowBase + vocabIdx)
            topKTmpValBuf[slot] = winner.u; // logit value
            if (winner.p >= 0)
            {
                // Knock the winner out so the next round yields the next-best element.
                tmpLogProbs[winner.p] = -maxMagnitude;
            }
        }
        // Makes the knock-out visible to all threads and protects reduceStorage for reuse.
        __syncthreads();
    }
}
// Stage 2: merge the BLOCKS_PER_BEAM_ * k candidates that Stage 1 produced for a row, apply the
// optional top-p cutoff and scatter the surviving logits into the output row. Every other
// position of the output row is set to -inf so that a subsequent softmax gives it zero
// probability.
//
// Launch layout: one block of BLOCK_SIZE_ threads per row (blockIdx.x is the row). Dynamic
// shared memory must hold at least maxTopK * (sizeof(int32_t) + sizeof(float)) bytes.
//
//   topKTmpIdBuf  [in]     flat logits index of each candidate (-1 for an empty slot)
//   topKTmpValBuf [in,tmp] candidate values; selected entries are overwritten during the merge
//   outputLogits  [out]    (nRows * vocabSize) masked logits
//   maxTopK                row pitch of the tmp buffers is BLOCKS_PER_BEAM_ * maxTopK
//   topKs                  optional per-row k; maxTopK is used for every row when nullptr
//   topPs                  optional per-row top-p; no top-p cutoff is applied when nullptr
template <typename T, int32_t BLOCK_SIZE_, int32_t BLOCKS_PER_BEAM_>
__global__ void topKProbStage2ForLogits(int32_t const* __restrict__ topKTmpIdBuf, T* topKTmpValBuf, float* outputLogits,
    int32_t maxTopK, int32_t const* topKs, float const* topPs, int32_t vocabSize)
{
    bool const IS_FP16 = std::is_same<T, half>::value;
    T const MAX_T_VAL = IS_FP16 ? HALF_FLT_MAX : FLT_MAX;

    auto const tid = static_cast<int32_t>(threadIdx.x);
    auto const rowId = static_cast<int32_t>(blockIdx.x);
    auto const k = (topKs != nullptr) ? topKs[rowId] : maxTopK;

    // size: number of candidate slots written by Stage 1 for this row.
    // stride: row pitch in the tmp buffers.
    auto const size = k * BLOCKS_PER_BEAM_;
    auto const stride = maxTopK * BLOCKS_PER_BEAM_;

    typedef cub::BlockReduce<TopK_2<float>, BLOCK_SIZE_> BlockReduce;
    __shared__ typename BlockReduce::TempStorage tempStorage;
    extern __shared__ char sharedArray[];

    // Dynamic shared layout: sId[maxTopK] | sLogit[maxTopK]
    auto* sId = reinterpret_cast<int32_t*>(sharedArray);
    auto* sLogit = reinterpret_cast<float*>(sId + maxTopK);

    // This row's candidates in the tmp value buffer (modified in place during the merge).
    T* sVal = topKTmpValBuf + rowId * stride;

    // Step 1: initialise the output row to -inf (all threads cooperate).
    // The offset is formed in 64 bits so that nRows * vocabSize may exceed INT32_MAX here.
    float* outRow = outputLogits + static_cast<size_t>(rowId) * static_cast<size_t>(vocabSize);
    float const negInf = -std::numeric_limits<float>::infinity();
    for (int32_t i = tid; i < vocabSize; i += BLOCK_SIZE_)
    {
        outRow[i] = negInf;
    }
    __syncthreads();

    // Step 2: k rounds of block-wide arg-max over the candidate slots. After the loop,
    // sId / sLogit hold the merged top-k in descending order of logit.
    TopK_2<float> partial;
    __shared__ float sMaxLogit;
    for (int32_t ite = 0; ite < k; ite++)
    {
        partial.init();
        for (int32_t i = tid; i < size; i += BLOCK_SIZE_)
        {
            partial.insert(static_cast<float>(sVal[i]), i);
        }
        TopK_2<float> total = BlockReduce(tempStorage).Reduce(partial, reduce_topk_op_2<float>);
        if (tid == 0)
        {
            if (ite == 0)
            {
                sMaxLogit = total.u;
            }
            sId[ite] = total.p;
            sLogit[ite] = total.u; // raw logit; kept untouched until the scatter below
            // total.p stays -1 when no selectable candidate is left (fewer than k usable
            // logits in the row). Indexing with it would write in front of this row's slots.
            if (total.p >= 0)
            {
                sVal[total.p] = -MAX_T_VAL; // knock out so the next round finds the next-best
            }
        }
        __syncthreads();
    }

    // Steps 3 and 4 are serial and run on thread 0 only.
    if (tid == 0)
    {
        // Step 3: determine the top-p cutoff. Probabilities are softmax over the merged
        // top-k logits, evaluated on the fly so that sLogit keeps the exact input values.
        int32_t cutoff = k;
        if (topPs != nullptr)
        {
            float const topP = topPs[rowId];
            if (topP < 1.0f)
            {
                float sSum = 0.0f;
                for (int32_t ki = 0; ki < k; ki++)
                {
                    sSum += __expf(sLogit[ki] - sMaxLogit);
                }
                // Walk in descending-probability order; stop once the cumulative mass reaches topP.
                float cumProb = 0.0f;
                for (int32_t ki = 0; ki < k; ki++)
                {
                    cumProb += __expf(sLogit[ki] - sMaxLogit) / sSum;
                    if (cumProb >= topP)
                    {
                        cutoff = ki + 1; // the token that crosses the threshold is kept
                        break;
                    }
                }
            }
        }

        // Step 4: scatter the surviving logits to the output row.
        // topKTmpIdBuf stores (row * vocabSize + vocabIdx); recover vocabIdx with % vocabSize.
        auto const rowStride = rowId * stride;
        for (int32_t ki = 0; ki < cutoff; ki++)
        {
            auto const candidateIdx = sId[ki];
            if (candidateIdx < 0)
            {
                continue; // this merge round found nothing
            }
            auto const globalIdx = topKTmpIdBuf[rowStride + candidateIdx];
            if (globalIdx >= 0) // Stage 1 stores -1 for slots it could not fill
            {
                outRow[globalIdx % vocabSize] = sLogit[ki];
            }
        }
    }
}
// Launches both masking stages for one bucket of maxTopK. The macro expands inside
// invokeTopKTopPMaskingForProbs and uses that function's local variables directly.
//   K_MAX            upper bound of the bucket; sizes Stage 2's dynamic shared memory
//   BLOCK_SIZE_1_    threads per block for Stage 1
//   BLOCK_SIZE_2_    threads per block for Stage 2
//   BLOCKS_PER_BEAM_ Stage 1 blocks cooperating on one row
#define CASE_K_PROB(K_MAX, BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_)                                             \
    do                                                                                                                 \
    {                                                                                                                  \
        topKProbStage1<T, BLOCK_SIZE_1_, BLOCKS_PER_BEAM_>                                                             \
            <<<dim3(nRows * BLOCKS_PER_BEAM_, 1), BLOCK_SIZE_1_, 0, stream>>>(                                         \
                logits, tmpLogProbs, topKTmpIdBuf, topKTmpValBuf, maxTopK, topKs, vocabSize);                          \
        topKProbStage2ForLogits<T, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_>                                                    \
            <<<dim3(nRows, 1), BLOCK_SIZE_2_, (K_MAX) * (sizeof(int32_t) + sizeof(float)), stream>>>(                  \
                topKTmpIdBuf, topKTmpValBuf, outputLogits, maxTopK, topKs, topPs, vocabSize);                          \
    } while (0)

// Host launcher: allocates workspace tensors internally and dispatches the two-stage kernels.
//   logits       [nRows, vocabSize] – temperature-scaled input (float or half)
//   outputLogits [nRows, vocabSize] – output: -inf everywhere except selected top-k-p positions
//   topKs        [nRows] or nullptr – per-row k values (int32, on device); maxTopK is used
//                                     for every row when nullptr
//   topPs        [nRows] or nullptr – per-row p values (float, on device)
//   maxTopK                         – maximum k across all rows (CPU scalar, 1–1024)
//
// The workspace is allocated through ATen on the current CUDA device and released when this
// function returns while the kernels may still be queued.
// NOTE(review): that looks safe only if `stream` is ATen's current stream for that device, as
// it is for the caller in this file – confirm before calling with any other stream.
template <typename T>
void invokeTopKTopPMaskingForProbs(T const* logits, float* outputLogits, int32_t const* topKs, float const* topPs,
    int32_t maxTopK, int32_t nRows, int32_t vocabSize, cudaStream_t stream)
{
    constexpr int32_t BLOCKS_PER_BEAM = 8;

    // Validate before the bucket computation: for maxTopK <= 0 the shift loop below would
    // start from a negative value, and an arithmetic right shift of -1 never reaches zero.
    TLLM_CHECK_WITH_INFO(
        maxTopK >= 1 && maxTopK <= 1024, "topKProbMasking supports 1 <= k <= 1024 but got k=%d", maxTopK);
    TLLM_CHECK_WITH_INFO(nRows >= 0 && vocabSize >= 0,
        "topKProbMasking expects non-negative sizes but got nRows=%d, vocabSize=%d", nRows, vocabSize);
    if (nRows == 0 || vocabSize == 0)
    {
        return; // nothing to mask; a zero-sized grid would be an invalid launch configuration
    }

    // The kernels index both the logits and the candidate buffers with int32 arithmetic.
    int64_t const numLogits = static_cast<int64_t>(nRows) * vocabSize;
    int64_t const numCandidates = static_cast<int64_t>(nRows) * BLOCKS_PER_BEAM * maxTopK;
    TLLM_CHECK_WITH_INFO(numLogits <= std::numeric_limits<int32_t>::max()
            && numCandidates <= std::numeric_limits<int32_t>::max(),
        "topKProbMasking indexes with int32: nRows * vocabSize = %lld and candidate count = %lld must fit",
        static_cast<long long>(numLogits), static_cast<long long>(numCandidates));

    // Workspace buffers (allocated as CUDA device tensors via ATen).
    auto opts = at::TensorOptions().dtype(torch::kFloat32).device(at::kCUDA);
    auto tmpLogProbsTensor = torch::empty({numLogits}, opts);
    auto topKTmpIdBufTensor
        = torch::empty({numCandidates}, at::TensorOptions().dtype(torch::kInt32).device(at::kCUDA));
    // The value buffers hold elements of type T; they are allocated as float (at least as
    // large as T for the float and half instantiations) and reinterpreted.
    auto topKTmpValBufTensor = torch::empty({numCandidates}, opts);

    T* tmpLogProbs = reinterpret_cast<T*>(tmpLogProbsTensor.data_ptr<float>());
    int32_t* topKTmpIdBuf = topKTmpIdBufTensor.data_ptr<int32_t>();
    T* topKTmpValBuf = reinterpret_cast<T*>(topKTmpValBufTensor.data_ptr<float>());

    // logMaxTopK = floor(log2(maxTopK - 1)) for maxTopK > 2, else 0; selects the launch bucket.
    int32_t logMaxTopK = 0;
    int32_t recursor = maxTopK - 1;
    while (recursor >>= 1)
    {
        ++logMaxTopK;
    }

    switch (logMaxTopK)
    {
    case 0:
    case 1:
    case 2:
    case 3: // 0 < maxTopK <= 16
        CASE_K_PROB(16, 128, 128, BLOCKS_PER_BEAM);
        break;
    case 4: // 16 < maxTopK <= 32
        CASE_K_PROB(32, 256, 128, BLOCKS_PER_BEAM);
        break;
    case 5: // 32 < maxTopK <= 64
        CASE_K_PROB(64, 256, 256, BLOCKS_PER_BEAM);
        break;
    case 6:
    case 7:
    case 8:
    case 9: // 64 < maxTopK <= 1024
        CASE_K_PROB(1024, 256, 256, BLOCKS_PER_BEAM);
        break;
    default: TLLM_CHECK_WITH_INFO(false, "topKProbMasking supports 1 <= k <= 1024 but got k=%d", maxTopK);
    }

    // Same post-launch check that invokeBuildDynamicTree performs.
    sync_check_cuda_error(stream);
}
#undef CASE_K_PROB
namespace
{
// Rows whose temperature is at or below this value are treated as greedy by
// computeProbsFromLogits: their logits are divided by 1 instead of the temperature and the
// returned distribution for them is a one-hot vector at the arg-max.
constexpr double kGreedyTempThreshold = 1e-4;
// Row-wise softmax of a 2D CUDA tensor via invokeAddBiasSoftMax, launched on the current
// ATen stream of the tensor's device. Returns a contiguous float32 tensor of the same shape.
//
// The softmax runs in place on the float32 buffer: when `logits` is already contiguous
// float32, contiguous()/to() hand back the same storage, so the returned tensor aliases
// `logits` and the input values are overwritten.
torch::Tensor computeSoftmaxForProbOp(torch::Tensor logits)
{
    TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor");
    TORCH_CHECK(logits.dim() == 2, "logits must be a 2D tensor");

    torch::Tensor probs = logits.contiguous().to(torch::kFloat32);
    auto const numRows = static_cast<SizeType32>(probs.size(0));
    auto const numCols = static_cast<SizeType32>(probs.size(1));
    float* const buffer = probs.data_ptr<float>();
    auto cudaStream = at::cuda::getCurrentCUDAStream(probs.device().index());

    BiasSoftmaxParams<float> params;
    // Input and output share one buffer.
    params.logits = buffer;
    params.probs = buffer;
    // One beam per row, no padding of the batch or vocabulary dimension.
    params.batchSize = numRows;
    params.maxBatchSize = numRows;
    params.maxBeamWidth = 1;
    params.vocabSize = numCols;
    params.vocabSizePadded = numCols;
    params.skipSoftMax = false;
    params.batchSlotsLogits = false;
    params.checkParams();

    invokeAddBiasSoftMax(params, cudaStream);
    return probs;
}
// Masks `logits` ([nRows, vocabSize]) with -inf outside the per-row top-K / top-P selection.
//
//   topK  optional [nRows] tensor; values <= 0 disable top-K for that row
//   topP  optional [nRows] tensor of cumulative-probability thresholds
//   kMax  CPU-side upper bound of topK (the caller computes it via topK.max().item() on the
//         Python side); 0 selects the sort-based path
//
// Returns `logits` itself (no copy) when neither filter is present; otherwise a new tensor.
//
// Fast path: uses torch::topk(kMax) instead of a full vocab-size sort. It is taken only when
// top-K is present, 0 < kMax < vocabSize and every row's effective top-K is <= kMax. In all
// other cases the function falls back to the original sort-based path.
//
// Key advantages of the fast path over the full-sort path:
//   1. torch::topk with small kMax is O(V * log kMax) vs O(V * log V) for full sort.
//   2. The topk index tensor is [nRows, kMax] instead of [nRows, V] — much smaller.
//   3. No scatter-back of sorted indices needed; masking is done directly on logits.
//   4. For combined top-K + top-P, softmax/cumsum are computed on kMax values (not V).
torch::Tensor applyTopKTopPForProbOp(torch::Tensor logits, torch::optional<torch::Tensor> const& topK,
    torch::optional<torch::Tensor> const& topP, int32_t kMax)
{
    int64_t const vocabSize = logits.size(1);

    // Host-only checks: the caller is expected to pass nullopt when filtering is fully
    // disabled (see SpecMetadata.skip_top_k / skip_top_p). Probing the tensor contents
    // via `.item<bool>()` here would force a host-device sync and break CUDA graph
    // capture; the per-row `effectiveTopK` formula below already handles disabled rows.
    bool const hasTopK = topK.has_value() && topK->defined();
    bool const hasTopP = topP.has_value() && topP->defined();
    if (!hasTopK && !hasTopP)
    {
        return logits;
    }

    // Per-row k with disabled rows (k <= 0) expanded to the full vocabulary.
    torch::Tensor effectiveTopK;
    if (hasTopK)
    {
        auto topKLong = topK->to(torch::kLong);
        effectiveTopK
            = torch::where(topKLong > 0, topKLong, torch::full_like(topKLong, vocabSize)).clamp_max(vocabSize);
    }

    // Fast path uses `topk(kMax)`, which is unsafe when any row has effective top-k > kMax:
    // that covers disabled rows (expanded to the full vocab, and kMax < vocabSize here) as
    // well as rows whose k exceeds the kMax the caller passed, which would otherwise be
    // silently truncated to kMax. Detecting this requires a tensor reduction +
    // `.item<bool>()`, which is incompatible with CUDA graph capture. Only probe when the
    // caller explicitly opted into the fast path via kMax > 0 (today only the dynamic-tree
    // caller, which is not graph-captured).
    bool useFastPath = false;
    if (hasTopK && kMax > 0 && kMax < vocabSize)
    {
        useFastPath = !effectiveTopK.gt(kMax).any().item<bool>();
    }

    if (useFastPath)
    {
        // Fast topk path ─────────────────────────────────────────────────────────────
        // topKValues/topKIdx: [nRows, kMax], values in descending order
        auto [topKValues, topKIdx] = logits.topk(kMax, /*dim=*/-1, /*largest=*/true, /*sorted=*/true);

        // validTopK[i, j]: True when position j falls within top-K[i] for row i
        auto kArange = torch::arange(kMax, torch::TensorOptions().dtype(torch::kInt64).device(logits.device()))
                           .unsqueeze(0);                          // [1, kMax]
        auto kVals = effectiveTopK.to(torch::kInt64).unsqueeze(1); // [nRows, 1]
        auto validTopK = kArange < kVals;                          // [nRows, kMax]

        // Start with everything masked; scatter will unmark the kept positions.
        auto mask = torch::ones(
            {logits.size(0), vocabSize}, torch::TensorOptions().dtype(torch::kBool).device(logits.device()));

        if (hasTopP)
        {
            // Compute top-P on the kMax descending-sorted values only (much cheaper).
            // Positions beyond K[i] are treated as -inf so their probability ≈ 0.
            auto validTopKValues = topKValues.masked_fill(~validTopK, -std::numeric_limits<float>::infinity());
            auto sortedProbs = validTopKValues.softmax(/*dim=*/-1); // [nRows, kMax]
            auto cumsum = sortedProbs.cumsum(/*dim=*/-1);           // [nRows, kMax]

            // Mask positions where the cumulative probability *before* this token
            // already reaches topP — i.e. we have enough probability mass already.
            auto topPMask = (cumsum - sortedProbs) >= topP->unsqueeze(1); // [nRows, kMax]
            topPMask.select(/*dim=*/1, /*index=*/0).fill_(false);         // always keep the top-1 token

            // combinedMask: True  → mask this vocab position
            //               False → keep this vocab position
            auto combinedMask = topPMask | (~validTopK); // [nRows, kMax]
            mask.scatter_(/*dim=*/1, /*index=*/topKIdx, /*src=*/combinedMask);
        }
        else
        {
            // Top-K only: unmark the first K[i] positions (those within validTopK).
            // ~validTopK is True for positions j >= K[i] → they should stay masked.
            mask.scatter_(/*dim=*/1, /*index=*/topKIdx, /*src=*/(~validTopK));
        }
        return logits.masked_fill(mask, -std::numeric_limits<float>::infinity());
    }

    // Fallback: full-sort path (top-P only, kMax == 0, or a row needs more than kMax) ──────
    // Values are sorted ascending, so the kept tokens sit at the end of each row.
    auto sortResult = logits.sort(/*dim=*/-1, /*descending=*/false);
    auto logitsSort = std::get<0>(sortResult);
    auto logitsIdx = std::get<1>(sortResult);

    if (hasTopK)
    {
        // Column of the K-th largest value in the ascending order; everything strictly
        // smaller than that value is masked.
        auto topKMask = logitsSort.size(1) - effectiveTopK;
        topKMask = topKMask.clamp_min(0);
        auto topKThreshold = logitsSort.gather(1, topKMask.unsqueeze(1));
        auto mask = logitsSort < topKThreshold;
        logitsSort.masked_fill_(mask, -std::numeric_limits<float>::infinity());
    }

    if (hasTopP)
    {
        // Mask the low-probability prefix whose cumulative mass is at most 1 - topP.
        auto probsSort = logitsSort.softmax(/*dim=*/-1);
        auto probsSum = probsSort.cumsum(/*dim=*/-1, /*dtype=*/probsSort.scalar_type());
        auto topPMask = probsSum <= (1.0 - topP->unsqueeze(1));
        topPMask.select(/*dim=*/1, /*index=*/logitsSort.size(1) - 1).fill_(false); // always keep the top-1 token
        logitsSort.masked_fill_(topPMask, -std::numeric_limits<float>::infinity());
    }

    // Undo the sort: place every (possibly masked) value back at its original vocab index.
    return logitsSort.scatter(/*dim=*/-1, /*index=*/logitsIdx, /*src=*/logitsSort);
}
} // namespace
// Converts logits into per-row sampling probabilities.
//
//   logits          [nRows, vocabSize] CUDA tensor; never modified
//   temperatures    [nRows] CUDA tensor; rows at or below kGreedyTempThreshold are greedy
//   topK            optional [nRows] CUDA tensor of per-row k (values <= 0 disable top-K)
//   topP            optional [nRows] CUDA tensor of per-row cumulative-probability thresholds
//   skipTemperature when true the logits are used without temperature scaling
//   kMax            CPU-side maximum of topK; a value > 0 opts into the CUDA masking kernels
//                   (this performs a host-device sync and is not CUDA-graph capturable)
//
// Returns a float32 [nRows, vocabSize] tensor: softmax of the top-K / top-P masked logits for
// sampling rows and a one-hot vector at the arg-max for greedy rows.
torch::Tensor computeProbsFromLogits(torch::Tensor const& logits, torch::Tensor const& temperatures,
    torch::optional<torch::Tensor> const& topK, torch::optional<torch::Tensor> const& topP, bool skipTemperature,
    int32_t kMax)
{
    TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor");
    TORCH_CHECK(temperatures.is_cuda(), "temperatures must be a CUDA tensor");
    TORCH_CHECK(logits.dim() == 2, "logits must be a 2D tensor");
    TORCH_CHECK(temperatures.dim() == 1, "temperatures must be a 1D tensor");
    TORCH_CHECK(logits.size(0) == temperatures.size(0), "logits and temperatures size mismatch");
    if (topK.has_value() && topK->defined())
    {
        TORCH_CHECK(topK->is_cuda(), "top_k must be a CUDA tensor");
        TORCH_CHECK(topK->dim() == 1, "top_k must be a 1D tensor");
        TORCH_CHECK(topK->size(0) == logits.size(0), "top_k and logits size mismatch");
    }
    if (topP.has_value() && topP->defined())
    {
        TORCH_CHECK(topP->is_cuda(), "top_p must be a CUDA tensor");
        TORCH_CHECK(topP->dim() == 1, "top_p must be a 1D tensor");
        TORCH_CHECK(topP->size(0) == logits.size(0), "top_p and logits size mismatch");
    }

    // Greedy rows are divided by 1 so that a (near-)zero temperature cannot blow up the logits.
    auto const isGreedy = temperatures <= kGreedyTempThreshold;
    auto const safeTemperatures = torch::where(isGreedy, torch::ones_like(temperatures), temperatures);
    auto scaledLogits
        = (skipTemperature ? logits : logits.div(safeTemperatures.unsqueeze(1))).contiguous().to(torch::kFloat32);

    int64_t const vocabSize = scaledLogits.size(1);
    int64_t const nRows = scaledLogits.size(0);

    // Host-only presence checks; see comment in applyTopKTopPForProbOp() for why we
    // avoid probing tensor contents (would sync and break CUDA graph capture).
    bool const hasTopKPresence = topK.has_value() && topK->defined();
    bool const hasTopPPresence = topP.has_value() && topP->defined();

    // The kernel path produces an all -inf row for any row whose top_k value is 0, and its
    // workspace holds kMax candidates per block, so it is only safe when EVERY row satisfies
    // 0 < top_k <= kMax. Determining that requires a host-device sync, so only probe when
    // the caller has opted into the kernel path (kMax > 0). The kMax > 0 callers
    // (dynamic-tree) are not graph-captured. An empty batch never takes the kernel path
    // because it would launch a zero-sized grid.
    bool useKernelPath = false;
    if (hasTopKPresence && nRows > 0 && kMax > 0 && kMax < vocabSize)
    {
        useKernelPath = torch::logical_and(topK->gt(0), topK->le(kMax)).all().item<bool>();
    }

    torch::Tensor maskedLogits;
    if (useKernelPath)
    {
        // Two-stage CUDA top-k/top-p masking (mirrors invokeBatchTopKSampling).
        maskedLogits = torch::empty_like(scaledLogits);
        auto topKForKernel = topK->to(torch::kInt32).contiguous();
        auto topPForKernel = hasTopPPresence ? topP->to(torch::kFloat32).contiguous() : torch::Tensor();
        auto stream = at::cuda::getCurrentCUDAStream(scaledLogits.device().index());
        invokeTopKTopPMaskingForProbs<float>(scaledLogits.data_ptr<float>(), maskedLogits.data_ptr<float>(),
            topKForKernel.data_ptr<int32_t>(), hasTopPPresence ? topPForKernel.data_ptr<float>() : nullptr, kMax,
            static_cast<int32_t>(nRows), static_cast<int32_t>(vocabSize), stream);
    }
    else
    {
        // Fallback: PyTorch-based path (top-P only, kMax == 0, or rows the kernels cannot handle).
        maskedLogits = applyTopKTopPForProbOp(scaledLogits, topK, topP, kMax);
    }

    // computeSoftmaxForProbOp() works in place on contiguous float32 input. With
    // skipTemperature set, no filter present and float32 contiguous logits, maskedLogits is
    // still the caller's own tensor at this point; copy it so the input is not overwritten.
    if (maskedLogits.is_alias_of(logits))
    {
        maskedLogits = maskedLogits.clone();
    }

    // Take the arg-max on the logits, before the in-place softmax replaces them with
    // probabilities (distinct logits can round to equal probabilities).
    auto argmaxIds = maskedLogits.argmax(/*dim=*/-1, /*keepdim=*/true);
    auto probs = computeSoftmaxForProbOp(maskedLogits);
    auto oneHot = torch::zeros_like(probs).scatter_(1, argmaxIds, 1.0);
    return torch::where(isGreedy.unsqueeze(1), oneHot, probs);
}
//! Builds the tree attention mask, node positions and the first-child / next-sibling links
//! of one draft tree per block (blockIdx.x = batch index, threadIdx.x = tree node; node 0 is
//! the root). Thread 0 builds the linked structure serially; every other thread walks from
//! its node up to the root to fill its own mask row and compute its depth.
//!
//! retrieveNextToken / retrieveNextSibling are only written where a link exists and are
//! tested against -1, so they are expected to hold -1 on entry.
//!
//! \param parentList          [in] layer-wise parent indices [bs, topK*(depth-1)+1]
//! \param selectedIndex       [in] resampled history buffer indices [bs, draftTokenNum-1]
//! \param treeMask            [out] attention mask (which nodes each node can see)
//! \param positions           [out] position id per node [bs, draftTokenNum]
//! \param retrieveIndex       [out] tree node -> local index mapping [bs, draftTokenNum]
//! \param retrieveNextToken   [in,out] first-child pointer [bs, draftTokenNum], -1=none
//! \param retrieveNextSibling [in,out] next-sibling pointer [bs, draftTokenNum], -1=none
//! \param topK                top-K value per layer
//! \param depth               max tree depth (number of draft layers)
//! \param draftTokenNum       total tree nodes per batch (including root)
__global__ void buildDynamicTreeKernel(int64_t const* parentList, int64_t const* selectedIndex, int32_t* treeMask,
    int32_t* positions, int32_t* retrieveIndex, int32_t* retrieveNextToken, int32_t* retrieveNextSibling,
    SizeType32 topK, SizeType32 depth, SizeType32 draftTokenNum)
{
    int32_t bid = blockIdx.x;
    int32_t tid = threadIdx.x;

    if (tid >= draftTokenNum)
    {
        return;
    }

    // selectedIndex holds one entry per non-root node: draftTokenNum - 1 per batch. Every
    // search below is bounded by this count so it never reads past the end of the row.
    int32_t const numSelected = draftTokenNum - 1;
    int64_t const* selectedRow = selectedIndex + bid * numSelected;
    int64_t const* parentRow = parentList + bid * (topK * (depth - 1) + 1);

    // treeMask layout: [batchSize, draftTokenNum, draftTokenNum] (QLEN_ONLY mode)
    // tokenTreeIdx addresses column 1 of this node's row; column c + 1 belongs to selectedRow[c].
    int32_t tokenTreeIdx = draftTokenNum * draftTokenNum * bid + draftTokenNum * tid + 1;
    treeMask[tokenTreeIdx - 1] = 1; // column 0 = root, visible to every node
    for (int32_t i = 0; i < draftTokenNum - 1; i++)
    {
        treeMask[tokenTreeIdx + i] = 0;
    }

    int32_t position = 0;
    if (tid == 0)
    {
        positions[bid * draftTokenNum] = 0;

        // Reverse iteration: inserting at list head produces forward sibling order
        for (int32_t i = draftTokenNum - 1; i > 0; --i)
        {
            retrieveIndex[bid * draftTokenNum + i] = i;

            int64_t parentTbIdx = selectedRow[i - 1] / topK;
            int32_t parentPosition = 0; // parentTbIdx == 0: the parent is the root
            if (parentTbIdx > 0)
            {
                int64_t parentTokenIdx = parentRow[parentTbIdx];
                parentPosition = draftTokenNum; // sentinel: parent not among the selected nodes
                for (int32_t j = 0; j < numSelected; ++j)
                {
                    if (selectedRow[j] == parentTokenIdx)
                    {
                        parentPosition = j + 1; // +1 because position 0 is root
                        break;
                    }
                }
            }

            if (parentPosition == draftTokenNum)
            {
                printf(
                    "WARNING: Invalid dynamic tree! Detected a token with no parent token selected. "
                    "Please check if the logprob has nan. The token will be ignored.\n");
                continue;
            }

            if (retrieveNextToken[bid * draftTokenNum + parentPosition] == -1)
            {
                // First child seen for this parent.
                retrieveNextToken[bid * draftTokenNum + parentPosition] = i;
            }
            else
            {
                // Push node i in front of the parent's existing child list.
                int32_t originNextToken = retrieveNextToken[bid * draftTokenNum + parentPosition];
                retrieveNextToken[bid * draftTokenNum + parentPosition] = i;
                retrieveNextSibling[bid * draftTokenNum + i] = originNextToken;
            }
        }
        retrieveIndex[bid * draftTokenNum] = 0;
    }
    else
    {
        // Walk up to root, setting treeMask ancestor bits and counting depth
        int32_t curPosition = tid - 1;
        while (position < depth + 1)
        {
            position += 1;
            treeMask[tokenTreeIdx + curPosition] = 1;

            int64_t parentTbIdx = selectedRow[curPosition] / topK;
            if (parentTbIdx == 0)
            {
                break; // the parent is the root, whose column is already set
            }

            int64_t tokenIdx = parentRow[parentTbIdx];
            int32_t parentPos = 0;
            for (; parentPos < numSelected; ++parentPos)
            {
                if (selectedRow[parentPos] == tokenIdx)
                {
                    break;
                }
            }
            if (parentPos == numSelected)
            {
                break; // parent was not selected; stop the walk here
            }
            curPosition = parentPos;
        }
        positions[bid * draftTokenNum + tid] = position;
    }
}
//! Bit-packed variant of buildDynamicTreeKernel: one block per batch entry (blockIdx.x) and
//! one thread per tree node (threadIdx.x; node 0 is the root). Bit b of a node's mask row is
//! set when that node can see node b. The positions / retrieve* outputs are produced exactly
//! as in the unpacked kernel.
//!
//! retrieveNextToken / retrieveNextSibling are only written where a link exists and are
//! tested against -1, so they are expected to hold -1 on entry.
//!
//! \param numInt32PerRow int32 count per treeMask row (buffer stride; >= ceil(draftTokenNum/32) if padded)
__global__ void buildDynamicTreeKernelPacked(int64_t const* parentList, int64_t const* selectedIndex, int32_t* treeMask,
    int32_t* positions, int32_t* retrieveIndex, int32_t* retrieveNextToken, int32_t* retrieveNextSibling,
    SizeType32 topK, SizeType32 depth, SizeType32 draftTokenNum, SizeType32 numInt32PerRow)
{
    int32_t bid = blockIdx.x;
    int32_t tid = threadIdx.x;

    if (tid >= draftTokenNum)
    {
        return;
    }

    // selectedIndex holds one entry per non-root node: draftTokenNum - 1 per batch. Every
    // search below is bounded by this count so it never reads past the end of the row.
    int32_t const numSelected = draftTokenNum - 1;
    int64_t const* selectedRow = selectedIndex + bid * numSelected;
    int64_t const* parentRow = parentList + bid * (topK * (depth - 1) + 1);

    int32_t rowBaseIdx = (bid * draftTokenNum + tid) * numInt32PerRow;
    treeMask[rowBaseIdx] = 1; // bit 0 = root, always visible
    // Clear the remaining words of this row so that bits left over from an earlier tree
    // cannot survive; the unpacked kernel clears its row in the same way.
    for (int32_t word = 1; word < numInt32PerRow; ++word)
    {
        treeMask[rowBaseIdx + word] = 0;
    }

    int32_t position = 0;
    if (tid == 0)
    {
        positions[bid * draftTokenNum] = 0;

        // Reverse iteration: inserting at the list head yields forward sibling order.
        for (int32_t i = draftTokenNum - 1; i > 0; --i)
        {
            retrieveIndex[bid * draftTokenNum + i] = i;

            int64_t parentTbIdx = selectedRow[i - 1] / topK;
            int32_t parentPosition = 0; // parentTbIdx == 0: the parent is the root
            if (parentTbIdx > 0)
            {
                int64_t parentTokenIdx = parentRow[parentTbIdx];
                parentPosition = draftTokenNum; // sentinel: parent not among the selected nodes
                for (int32_t j = 0; j < numSelected; ++j)
                {
                    if (selectedRow[j] == parentTokenIdx)
                    {
                        parentPosition = j + 1; // +1 because position 0 is root
                        break;
                    }
                }
            }

            if (parentPosition == draftTokenNum)
            {
                printf("WARNING: Invalid dynamic tree! Detected a token with no parent token selected.\n");
                continue;
            }

            if (retrieveNextToken[bid * draftTokenNum + parentPosition] == -1)
            {
                // First child seen for this parent.
                retrieveNextToken[bid * draftTokenNum + parentPosition] = i;
            }
            else
            {
                // Push node i in front of the parent's existing child list.
                int32_t originNextToken = retrieveNextToken[bid * draftTokenNum + parentPosition];
                retrieveNextToken[bid * draftTokenNum + parentPosition] = i;
                retrieveNextSibling[bid * draftTokenNum + i] = originNextToken;
            }
        }
        retrieveIndex[bid * draftTokenNum] = 0;
    }
    else
    {
        // Walk up to the root, setting one bit per visited node and counting the depth.
        int32_t curPosition = tid - 1;
        while (position < depth + 1)
        {
            position += 1;

            int32_t bitPosition = curPosition + 1; // +1 because bit 0 is root
            int32_t int32Idx = bitPosition / 32;
            int32_t bitIdx = bitPosition % 32;
            if (int32Idx < numInt32PerRow)
            {
                // Shift as unsigned: 1 << 31 is not representable in a signed int.
                atomicOr(&treeMask[rowBaseIdx + int32Idx], static_cast<int32_t>(1u << bitIdx));
            }

            int64_t parentTbIdx = selectedRow[curPosition] / topK;
            if (parentTbIdx == 0)
            {
                break; // the parent is the root, whose bit is already set
            }

            int64_t tokenIdx = parentRow[parentTbIdx];
            int32_t parentPos = 0;
            for (; parentPos < numSelected; ++parentPos)
            {
                if (selectedRow[parentPos] == tokenIdx)
                {
                    break;
                }
            }
            if (parentPos == numSelected)
            {
                break; // parent was not selected; stop the walk here
            }
            curPosition = parentPos;
        }
        positions[bid * draftTokenNum + tid] = position;
    }
}
//! Host launcher for the dynamic-tree build kernels: one block per batch entry and one thread
//! per tree node. Dispatches to the bit-packed kernel for
//! TreeMaskMode::QLEN_ONLY_BITPACKING and to the unpacked kernel for every other mode.
//!
//! \param treeMask       output mask buffer, written as int32
//! \param numDraftTokens tree nodes per batch entry including the root; becomes the block
//!                       size, so it must lie in [1, 1024]
//! \param numInt32PerRow packed row stride in int32s; only used (and required to be > 0) in
//!                       bit-packing mode
//! The remaining arguments are forwarded to the kernels unchanged.
void invokeBuildDynamicTree(int64_t const* parentList, int64_t const* selectedIndex, void* treeMask, int32_t* positions,
    int32_t* retrieveIndex, int32_t* retrieveNextToken, int32_t* retrieveNextSibling, SizeType32 batchSize,
    SizeType32 topK, SizeType32 depth, SizeType32 numDraftTokens, TreeMaskMode treeMaskMode, cudaStream_t stream,
    SizeType32 numInt32PerRow)
{
    // The kernels map one thread to one node inside a single block, so the node count is
    // bounded by the CUDA limit of 1024 threads per block.
    constexpr SizeType32 kMaxThreadsPerBlock = 1024;
    TLLM_CHECK_WITH_INFO(numDraftTokens > 0 && numDraftTokens <= kMaxThreadsPerBlock,
        "numDraftTokens must be in [1, %d] (one thread per tree node) but got %d", kMaxThreadsPerBlock,
        numDraftTokens);
    TLLM_CHECK_WITH_INFO(batchSize >= 0, "batchSize must be non-negative but got %d", batchSize);
    if (batchSize == 0)
    {
        return; // nothing to build; a zero-sized grid is an invalid launch configuration
    }

    dim3 grid(batchSize);
    dim3 block(numDraftTokens);

    if (treeMaskMode == TreeMaskMode::QLEN_ONLY_BITPACKING)
    {
        TLLM_CHECK_WITH_INFO(
            numInt32PerRow > 0, "numInt32PerRow must be the packed treeMask row stride in int32s (from buffer shape).");
        buildDynamicTreeKernelPacked<<<grid, block, 0, stream>>>(parentList, selectedIndex,
            static_cast<int32_t*>(treeMask), positions, retrieveIndex, retrieveNextToken, retrieveNextSibling, topK,
            depth, numDraftTokens, numInt32PerRow);
    }
    else
    {
        buildDynamicTreeKernel<<<grid, block, 0, stream>>>(parentList, selectedIndex, static_cast<int32_t*>(treeMask),
            positions, retrieveIndex, retrieveNextToken, retrieveNextSibling, topK, depth, numDraftTokens);
    }

    sync_check_cuda_error(stream);
}
//! Greedy verification of one draft tree per block (blockIdx.x = batch index).
//!
//! Starting at the root, the kernel repeatedly scans the children of the last accepted node
//! (first child, then the sibling chain) for a candidate token equal to the target model's
//! prediction at the last accepted node. A match is accepted and becomes the new last
//! accepted node; the walk ends when no child matches or numSpeculativeTokens entries have
//! been produced.
//!
//! retrievePacked layout [bs, numDraftTokens, 3] int32 row-major:
//!   [b,n,0]=retrieveIndex, [b,n,1]=retrieveNextToken, [b,n,2]=retrieveNextSibling
//!
//! Outputs per batch entry: acceptIndex / acceptToken rows of length numSpeculativeTokens
//! (entry 0 is always written) and acceptTokenNum, the number of accepted draft tokens.
//! When treeValid is given and false for the entry, nothing is accepted and entry 0 is
//! filled from node 0.
//!
//! The walk is serial and does not use threadIdx, so every thread of a block performs the
//! same work. NOTE(review): presumably launched with a single thread per block – confirm
//! against the launcher.
__global__ void verifyDynamicTreeGreedyPackedKernel(int32_t* acceptIndex, int32_t* acceptTokenNum, int32_t* acceptToken,
    int32_t const* candidates, int32_t const* retrievePacked, int32_t const* targetPredict, bool const* treeValid,
    uint32_t numSpeculativeTokens, uint32_t numDraftTokens)
{
    // Field offsets inside one packed node record.
    constexpr size_t kFieldIndex = 0;
    constexpr size_t kFieldNextToken = 1;
    constexpr size_t kFieldNextSibling = 2;
    constexpr size_t kFieldsPerNode = 3;

    uint32_t const batch = blockIdx.x;
    uint32_t const nodeBase = batch * numDraftTokens;
    uint32_t const outBase = batch * numSpeculativeTokens;
    int32_t const* nodes = retrievePacked + static_cast<size_t>(batch) * numDraftTokens * kFieldsPerNode;

    bool const treeIsInvalid = (treeValid != nullptr) && !treeValid[batch];
    if (treeIsInvalid)
    {
        acceptTokenNum[batch] = 0;
        acceptIndex[outBase] = 0;
        acceptToken[outBase] = targetPredict[nodeBase];
        return;
    }

    // Entry 0: the root and the target prediction made at the root.
    int32_t lastAccepted = nodes[kFieldIndex];
    acceptIndex[outBase] = lastAccepted;
    acceptToken[outBase] = targetPredict[nodeBase + lastAccepted];

    uint32_t accepted = 0;
    int32_t cursor = 0; // node whose children are examined next
    for (uint32_t step = 1; step < numSpeculativeTokens; ++step)
    {
        bool matched = false;
        int32_t child = nodes[static_cast<size_t>(cursor) * kFieldsPerNode + kFieldNextToken];
        // Follow the sibling chain until a child matches or the chain leaves the valid range.
        while (child >= 0 && static_cast<uint32_t>(child) < numDraftTokens)
        {
            if (candidates[nodeBase + child] == targetPredict[nodeBase + lastAccepted])
            {
                int32_t const childLocalIdx = nodes[static_cast<size_t>(child) * kFieldsPerNode + kFieldIndex];
                ++accepted;
                acceptIndex[outBase + accepted] = childLocalIdx;
                acceptToken[outBase + accepted] = targetPredict[nodeBase + childLocalIdx];
                lastAccepted = childLocalIdx;
                cursor = child;
                matched = true;
                break;
            }
            child = nodes[static_cast<size_t>(child) * kFieldsPerNode + kFieldNextSibling];
        }
        if (!matched)
        {
            break; // no child continues the accepted path
        }
    }

    acceptTokenNum[batch] = accepted;
}
//! \brief Host launcher for verifyDynamicTreeGreedyPackedKernel.
//!
//! Launches one single-thread block per request on \p stream. An empty batch skips the launch,
//! because a zero-sized grid is an invalid launch configuration. For a non-empty batch the kernel
//! unconditionally reads node 0 of each request and writes slot 0 of acceptIndex/acceptToken, so
//! both \p numDraftTokens and \p numSpecStep must be at least 1.
//!
//! \param acceptIndex     output, forwarded to the kernel (indexed with stride numSpecStep).
//! \param acceptTokenNum  output, one entry per request.
//! \param acceptToken     output, forwarded to the kernel (indexed with stride numSpecStep).
//! \param candidates      input, indexed with stride numDraftTokens per request.
//! \param retrievePacked  input, indexed with stride numDraftTokens * 3 per request.
//! \param targetPredict   input, indexed with stride numDraftTokens per request.
//! \param treeValid       optional per-request validity flags; may be nullptr.
//! \param batchSize       number of requests; also the grid size.
//! \param numDraftTokens  tree nodes per request.
//! \param numSpecStep     per-request stride of acceptIndex/acceptToken.
//! \param stream          CUDA stream the kernel is enqueued on.
void invokeVerifyDynamicTreeGreedyPacked(int32_t* acceptIndex, int32_t* acceptTokenNum, int32_t* acceptToken,
    int32_t const* candidates, int32_t const* retrievePacked, int32_t const* targetPredict, bool const* treeValid,
    SizeType32 batchSize, SizeType32 numDraftTokens, SizeType32 numSpecStep, cudaStream_t stream)
{
    // NOTE(review): assumes SizeType32 is a signed integer type - confirm against its declaration.
    TLLM_CHECK_WITH_INFO(batchSize >= 0, "batchSize must be non-negative.");

    if (batchSize > 0)
    {
        // Both values are converted to uint32_t kernel arguments; reject values that would make
        // the kernel touch memory outside the per-request rows.
        TLLM_CHECK_WITH_INFO(numDraftTokens > 0, "numDraftTokens must be positive: the kernel reads node 0 of every request.");
        TLLM_CHECK_WITH_INFO(
            numSpecStep > 0, "numSpecStep must be positive: the kernel writes slot 0 of acceptIndex/acceptToken.");

        dim3 grid(batchSize);
        dim3 block(1);
        verifyDynamicTreeGreedyPackedKernel<<<grid, block, 0, stream>>>(acceptIndex, acceptTokenNum, acceptToken,
            candidates, retrievePacked, targetPredict, treeValid, numSpecStep, numDraftTokens);
    }
    sync_check_cuda_error(stream);
}
// ------------------------------------------------------------
// Background: Speculative Sampling Theory
// ------------------------------------------------------------
//
// Goal: reuse draft model samples to speed up generation while keeping the
// final output distribution strictly equal to the target distribution q.
//
// For a given token x:
// p(x) = draft_probs[x] (draft model probability)
// q(x) = target_probs[x] (target model probability)
//
// Step 1 - The draft model proposes token x sampled from p.
// Step 2 - Accept x with probability min(1, q(x)/p(x)).
// Equivalently: accept when u * p(x) < q(x), where u ~ Uniform(0,1).
//
// Why does this work?
// x is proposed with probability p(x) and then accepted with probability
// min(1, q(x)/p(x)), so its total probability mass reaching the output is:
// p(x) * min(1, q(x)/p(x)) = min(p(x), q(x))
//
// This covers only the min(p, q) portion of the target mass.
// The remaining portion q - min(p, q) = relu(q - p) is not yet covered.
//
// Therefore, if the draft token is rejected, we must resample from the
// residual distribution relu(q - p) (normalised) to fill the gap and
// restore the full target distribution.
//
// Example:
// p = [0.6, 0.3, 0.1] tokens [A, B, C]
// q = [0.2, 0.5, 0.3]
//
// Accept probabilities:
// A: min(1, 0.2/0.6) = 1/3 B: min(1, 0.5/0.3) = 1 C: min(1, 0.3/0.1) = 1
//
// Case 1 - draft proposes A (prob 0.6):
// Accept (1/3): contributes 0.6 * 1/3 = 0.2 to output A.
// Reject (2/3): total rejected mass = 0.6 * 2/3 = 0.4.
// relu(q-p) = [0, 0.2, 0.2] -> normalised [0, 0.5, 0.5]
// contributes 0.4*0.5 = 0.2 to B and 0.4*0.5 = 0.2 to C.
// Case 2 - draft proposes B (prob 0.3): always accepted -> 0.3 to B.
// Case 3 - draft proposes C (prob 0.1): always accepted -> 0.1 to C.
//
// Final output distribution:
// A = 0.2, B = 0.3 + 0.2 = 0.5, C = 0.1 + 0.2 = 0.3 -> exactly q.
//
// Tree extension:
// The same logic applies depth-by-depth along the draft tree. At each
// depth the kernel tries siblings in score order; the first accepted
// sibling extends the current path. If every sibling at a depth is
// rejected the kernel samples a correction token from relu(q-p) and
// terminates traversal for that request.
// ------------------------------------------------------------
#include <curand_kernel.h>
/// Draws one uniform float and folds the closed upper end onto zero.
///
/// curand_uniform yields values in (0, 1]; cumulative-sum sampling wants [0, 1) so that a draw
/// can never land past the end of a distribution whose float32 running sum stops short of 1.
/// Any draw that is not strictly below 1.0 is returned as 0.0.
__device__ __forceinline__ float curand_uniform_open_right(curandStatePhilox4_32_10_t& state)
{
    float const draw = curand_uniform(&state); // (0, 1]
    if (draw < 1.0f)
    {
        return draw;
    }
    return 0.0f; // the single value 1.0 is remapped, giving [0, 1)
}
//! Samples a token index from \p probs by inverse-CDF over the first \p vocabSize entries.
//!
//! One uniform draw in [0, 1) is compared against the running sum of probs; the first index whose
//! running sum exceeds the draw is returned. If the sum never exceeds the draw, the highest index
//! with a strictly positive probability is returned, and vocabSize - 1 if there is none.
//!
//! \param state     Philox RNG state; advanced by one draw.
//! \param probs     probabilities, read at indices [0, vocabSize).
//! \param vocabSize number of entries to consider.
//! \return sampled index as int64.
__device__ int64_t sampleFromDistribution(curandStatePhilox4_32_10_t& state, float const* probs, uint32_t vocabSize)
{
    float const threshold = curand_uniform_open_right(state); // [0, 1)

    float runningMass = 0.0f;
    uint32_t tok = 0;
    while (tok < vocabSize)
    {
        runningMass += probs[tok];
        if (threshold < runningMass)
        {
            return static_cast<int64_t>(tok);
        }
        ++tok;
    }

    // The float32 running sum may stop short of the draw for large vocabularies:
    // walk back to the last entry carrying positive probability.
    int64_t const lastTok = static_cast<int64_t>(vocabSize) - 1;
    int64_t fallback = lastTok;
    while (fallback >= 0 && !(probs[fallback] > 0.0f))
    {
        --fallback;
    }
    return fallback >= 0 ? fallback : lastTok;
}
//! Samples a token from \p probs restricted to the tokens listed in \p supportIndices.
//!
//! One uniform draw in [0, 1) is compared against the running sum of probs[supportIndices[i]],
//! taken in list order; the first listed token whose running sum exceeds the draw is returned.
//! If none does, the last listed token with strictly positive probability is returned, and
//! vocabSize - 1 if there is none.
//!
//! \param state          Philox RNG state; advanced by one draw.
//! \param probs          probabilities, indexed by token id.
//! \param supportIndices token ids to sample among; each is used directly as an index into probs.
//! \param supportSize    number of entries in supportIndices.
//! \param vocabSize      only used to form the final fallback value vocabSize - 1.
//! \return sampled token id as int64.
__device__ int64_t sampleFromIndexedDistribution(curandStatePhilox4_32_10_t& state, float const* probs,
    int32_t const* supportIndices, uint32_t supportSize, uint32_t vocabSize)
{
    float const threshold = curand_uniform_open_right(state); // [0, 1)

    float runningMass = 0.0f;
    int32_t const* const supportEnd = supportIndices + supportSize;
    for (int32_t const* it = supportIndices; it != supportEnd; ++it)
    {
        int32_t const tok = *it;
        runningMass += probs[tok];
        if (threshold < runningMass)
        {
            return static_cast<int64_t>(tok);
        }
    }

    // Running sum never exceeded the draw: scan the support backwards for positive mass.
    for (uint32_t remaining = supportSize; remaining > 0; --remaining)
    {
        int32_t const tok = supportIndices[remaining - 1];
        if (probs[tok] > 0.0f)
        {
            return static_cast<int64_t>(tok);
        }
    }
    return static_cast<int64_t>(vocabSize) - 1;
}
//! Device-side binary functor returning the smaller of two int32 values.
struct MinInt32Op
{
    __device__ __forceinline__ int32_t operator()(int32_t a, int32_t b) const
    {
        if (a < b)
        {
            return a;
        }
        return b;
    }
};
//! Device-side binary functor returning the larger of two int32 values.
struct MaxInt32Op
{
    __device__ __forceinline__ int32_t operator()(int32_t a, int32_t b) const
    {
        if (a > b)
        {
            return a;
        }
        return b;
    }
};
// ---------------------------------------------------------------------------
// Target-only dynamic tree rejection sampling kernel
//
// Acceptance algorithm:
// - For each depth, accumulate cumulative target probability across siblings.
// - Accept the first sibling whose cumulative prob exceeds the random coin.
// - When all siblings are rejected, sample a correction token from the
// residual target mass (target prob for tokens NOT tried as siblings).
//
// This matches the mathematical guarantee of speculative sampling with the
// draft treated as a uniform empirical prior over the K candidate siblings.
// ---------------------------------------------------------------------------
// Upper bound on the number of sibling candidates tracked per tree level for the
// correction step. Chosen to match the maximum supported K branching factor
// (dynamic_tree_max_topK).
// NOTE(review): the code that consumes this bound is outside this view - confirm that
// callers never exceed it.
constexpr int32_t kMaxTriedPerLevel = 32;
//! \param acceptIndex output [batchSize, numSpecStep] int64 — tree positions of accepted tokens.
//! \param acceptTokenNum output [batchSize] int64 — # accepted draft tokens (excl. root).
//! \param acceptToken output [batchSize, numSpecStep] int64 — accepted/correction token ids.
//! \param draftTokens [batchSize, numDraftTokens] int64; col 0 = root (target sample).
//! \param targetProbs [batchSize, numDraftTokens, vocabSize] float32; full-vocab target probs.
//! \param retrieveNextToken [batchSize, numDraftTokens] int32 first-child pointer, -1=none.
//! \param retrieveNextSibling [batchSize, numDraftTokens] int32 next-sibling pointer, -1=none.
//! \param treeValid [batchSize] bool; false means no valid tree exists for this request.
//! \param batchSize batch size.
//! \param numSpeculativeTokens second dim of acceptIndex/acceptToken (= max_draft_len + 1).
//! \param numDraftTokens total tree nodes per request (including root).
//! \param vocabSize vocabulary size.
//! \param seed [1] int64 on GPU. Philox RNG seed.
//! \param offset [1] int64 on GPU. Philox RNG offset.
template <int32_t BLOCK_SIZE>
__global__ void verifyDynamicTreeRejectionKernel(int64_t* acceptIndex, int64_t* acceptTokenNum, int64_t* acceptToken,
int64_t const* draftTokens, float const* targetProbs, int32_t const* retrieveNextToken,
int32_t const* retrieveNextSibling, bool const* treeValid, uint32_t batchSize, uint32_t numSpeculativeTokens,
uint32_t numDraftTokens, uint32_t vocabSize, int64_t const* seed, int64_t const* offset)
{
uint32_t const bx = blockIdx.x;
int32_t const tid = static_cast<int32_t>(threadIdx.x);