diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu b/cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu index 721dbfb95621..3d167660e521 100644 --- a/cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu +++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu @@ -308,11 +308,6 @@ namespace { constexpr double kGreedyTempThreshold = 1e-4; -bool isTopPEnabled(torch::optional const& topP) -{ - return topP.has_value() && topP->defined() && topP->lt(1.0).any().item(); -} - torch::Tensor computeSoftmaxForProbOp(torch::Tensor logits) { TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor"); @@ -337,110 +332,6 @@ torch::Tensor computeSoftmaxForProbOp(torch::Tensor logits) return probs; } -struct DraftProbMaxFloatOp -{ - __device__ __forceinline__ float operator()(float a, float b) const - { - return a > b ? a : b; - } -}; - -template -__global__ void computeDraftProbsSkipAllKernel(float const* draftLogits, int32_t const* d2t, float* draftProbs, - int32_t nRows, int32_t draftVocabSize, int32_t targetVocabSize) -{ - int32_t const rowId = static_cast(blockIdx.x); - int32_t const tid = static_cast(threadIdx.x); - if (rowId >= nRows) - { - return; - } - - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tempStorage; - __shared__ float sMaxLogit; - __shared__ float sExpSum; - - float const* rowLogits = draftLogits + static_cast(rowId) * draftVocabSize; - float* rowProbs = draftProbs + static_cast(rowId) * targetVocabSize; - - for (int32_t v = tid; v < targetVocabSize; v += BLOCK_SIZE) - { - rowProbs[v] = 0.0f; - } - - float localMax = -FLT_MAX; - for (int32_t v = tid; v < draftVocabSize; v += BLOCK_SIZE) - { - localMax = fmaxf(localMax, rowLogits[v]); - } - - float const blockMax = BlockReduce(tempStorage).Reduce(localMax, DraftProbMaxFloatOp{}); - if (tid == 0) - { - sMaxLogit = blockMax; - } - __syncthreads(); - - float localSum = 0.0f; - for (int32_t v = tid; v < draftVocabSize; v += BLOCK_SIZE) - { - localSum += __expf(rowLogits[v] - sMaxLogit); - } - - float const blockSum = BlockReduce(tempStorage).Sum(localSum); - if (tid == 0) - { - constexpr float kFloatSoftmaxEpsilon = 1e-6f; - sExpSum = blockSum + kFloatSoftmaxEpsilon; - } - __syncthreads(); - - for (int32_t v = tid; v < draftVocabSize; v += BLOCK_SIZE) - { - int64_t const targetIdx - = d2t != nullptr ? static_cast(v) + static_cast(d2t[v]) : static_cast(v); - if (targetIdx >= 0 && targetIdx < targetVocabSize) - { - rowProbs[targetIdx] = __expf(rowLogits[v] - sMaxLogit) / sExpSum; - } - } -} - -torch::Tensor computeDraftProbsSkipAllForDynamicTreeRejection(torch::Tensor const& draftLogits, int64_t batchSize, - SizeType32 const numDraftProbRows, SizeType32 const targetVocabSize, torch::optional const& d2t) -{ - auto const draftVocabSize = draftLogits.size(1); - bool const hasD2T = d2t.has_value() && d2t->defined(); - - auto draftLogitsFloat = draftLogits.contiguous().to(torch::kFloat32); - if (!hasD2T && draftVocabSize == targetVocabSize) - { - return computeSoftmaxForProbOp(draftLogitsFloat).reshape({batchSize, numDraftProbRows, targetVocabSize}); - } - - auto fullDraftProbs = torch::empty({draftLogitsFloat.size(0), targetVocabSize}, - torch::TensorOptions().dtype(torch::kFloat32).device(draftLogitsFloat.device())); - torch::Tensor d2tInt; - int32_t const* d2tPtr = nullptr; - if (hasD2T) - { - d2tInt = d2t->contiguous().to(torch::kInt32); - d2tPtr = d2tInt.data_ptr(); - } - - constexpr int32_t kBlockSize = 1024; - dim3 grid(draftLogitsFloat.size(0)); - dim3 block(kBlockSize); - auto stream = at::cuda::getCurrentCUDAStream(draftLogitsFloat.device().index()); - computeDraftProbsSkipAllKernel<<>>(draftLogitsFloat.data_ptr(), d2tPtr, - fullDraftProbs.data_ptr(), static_cast(draftLogitsFloat.size(0)), - static_cast(draftVocabSize), static_cast(targetVocabSize)); - sync_check_cuda_error(stream); - - return fullDraftProbs.reshape({batchSize, numDraftProbRows, targetVocabSize}); -} - // Fast path for top-K (and optional top-P) filtering using torch::topk instead of a // full vocab-size sort. kMax must be provided as a CPU integer (the caller computes it // via topK.max().item() on the Python side). When kMax == 0 or kMax >= vocabSize the @@ -626,252 +517,6 @@ torch::Tensor computeProbsFromLogits(torch::Tensor const& logits, torch::Tensor return torch::where(isGreedy.unsqueeze(1), oneHot, probs); } -torch::Tensor computeDraftProbsForDynamicTreeRejection(torch::Tensor const& draftLogits, - torch::Tensor const& temperatures, SizeType32 const numDraftProbRows, torch::optional const& topK, - torch::optional const& topP, SizeType32 const targetVocabSize, bool skipTemperature, - torch::optional const& d2t, SizeType32 const kMax, bool skipAllSamplingParams) -{ - TORCH_CHECK(draftLogits.is_cuda(), "draftLogits must be a CUDA tensor"); - TORCH_CHECK(temperatures.is_cuda(), "temperatures must be a CUDA tensor"); - TORCH_CHECK(draftLogits.dim() == 2, "draftLogits must be a 2D tensor"); - TORCH_CHECK(temperatures.dim() == 1, "temperatures must be a 1D tensor"); - TORCH_CHECK(numDraftProbRows > 0, "numDraftProbRows must be positive"); - - auto const batchSize = temperatures.size(0); - auto const draftVocabSize = draftLogits.size(1); - - TORCH_CHECK(batchSize > 0, "batchSize must be positive"); - TORCH_CHECK( - draftLogits.size(0) == batchSize * numDraftProbRows, "draftLogits row count does not match numDraftProbRows"); - TORCH_CHECK(targetVocabSize >= draftVocabSize, "targetVocabSize must be >= draft vocab size"); - - if (topK.has_value() && topK->defined()) - { - TORCH_CHECK(topK->is_cuda(), "top_k must be a CUDA tensor"); - TORCH_CHECK(topK->dim() == 1, "top_k must be a 1D tensor"); - TORCH_CHECK(topK->size(0) == batchSize, "top_k size mismatch"); - } - if (topP.has_value() && topP->defined()) - { - TORCH_CHECK(topP->is_cuda(), "top_p must be a CUDA tensor"); - TORCH_CHECK(topP->dim() == 1, "top_p must be a 1D tensor"); - TORCH_CHECK(topP->size(0) == batchSize, "top_p size mismatch"); - } - if (d2t.has_value() && d2t->defined()) - { - TORCH_CHECK(d2t->is_cuda(), "d2t must be a CUDA tensor"); - TORCH_CHECK(d2t->dim() == 1, "d2t must be a 1D tensor"); - TORCH_CHECK(d2t->size(0) >= draftVocabSize, "d2t size mismatch"); - } - - if (skipAllSamplingParams) - { - return computeDraftProbsSkipAllForDynamicTreeRejection( - draftLogits, batchSize, numDraftProbRows, targetVocabSize, d2t); - } - - auto draftTemps = temperatures.repeat_interleave(numDraftProbRows); - auto draftTopK = topK.has_value() && topK->defined() - ? torch::optional(topK->repeat_interleave(numDraftProbRows)) - : torch::optional(); - auto draftTopP = isTopPEnabled(topP) ? torch::optional(topP->repeat_interleave(numDraftProbRows)) - : torch::optional(); - - auto draftProbs = computeProbsFromLogits(draftLogits, draftTemps, draftTopK, draftTopP, skipTemperature, kMax) - .reshape({batchSize, numDraftProbRows, draftVocabSize}); - - if (draftVocabSize == targetVocabSize) - { - return draftProbs; - } - - auto fullDraftProbs = torch::zeros({batchSize, numDraftProbRows, targetVocabSize}, - torch::TensorOptions().dtype(torch::kFloat32).device(draftProbs.device())); - if (d2t.has_value() && d2t->defined()) - { - auto srcIdx - = torch::arange(draftVocabSize, torch::TensorOptions().dtype(torch::kInt64).device(draftProbs.device())); - auto targetIdx = srcIdx + d2t->slice(0, 0, draftVocabSize).to(torch::kInt64); - auto expandedTargetIdx - = targetIdx.view({1, 1, draftVocabSize}).expand({batchSize, numDraftProbRows, draftVocabSize}); - fullDraftProbs.scatter_(2, expandedTargetIdx, draftProbs); - } - else - { - fullDraftProbs.slice(/*dim=*/2, /*start=*/0, /*end=*/draftVocabSize).copy_(draftProbs); - } - - return fullDraftProbs; -} - -std::tuple computeTargetProbsForDynamicTreeRejection( - torch::Tensor const& targetLogits, torch::Tensor const& temperatures, SizeType32 const numDraftTokens, - torch::optional const& topK, torch::optional const& topP, bool skipTemperature, - SizeType32 const kMax, bool skipAllSamplingParams) -{ - TORCH_CHECK(targetLogits.is_cuda(), "targetLogits must be a CUDA tensor"); - TORCH_CHECK(temperatures.is_cuda(), "temperatures must be a CUDA tensor"); - TORCH_CHECK(targetLogits.dim() == 2, "targetLogits must be a 2D tensor"); - TORCH_CHECK(temperatures.dim() == 1, "temperatures must be a 1D tensor"); - TORCH_CHECK(numDraftTokens > 1, "numDraftTokens must be greater than 1"); - - auto const batchSize = temperatures.size(0); - auto const targetVocabSize = targetLogits.size(1); - auto const nRows = batchSize * numDraftTokens; - - TORCH_CHECK(batchSize > 0, "batchSize must be positive"); - TORCH_CHECK( - targetLogits.size(0) == batchSize * numDraftTokens, "targetLogits row count does not match numDraftTokens"); - - if (topK.has_value() && topK->defined()) - { - TORCH_CHECK(topK->is_cuda(), "top_k must be a CUDA tensor"); - TORCH_CHECK(topK->dim() == 1, "top_k must be a 1D tensor"); - TORCH_CHECK(topK->size(0) == batchSize, "top_k size mismatch"); - } - if (topP.has_value() && topP->defined()) - { - TORCH_CHECK(topP->is_cuda(), "top_p must be a CUDA tensor"); - TORCH_CHECK(topP->dim() == 1, "top_p must be a 1D tensor"); - TORCH_CHECK(topP->size(0) == batchSize, "top_p size mismatch"); - } - - if (skipAllSamplingParams) - { - auto targetSupportIndices - = torch::empty({0}, torch::TensorOptions().dtype(torch::kInt32).device(targetLogits.device())); - auto targetSupportLengths - = torch::empty({0}, torch::TensorOptions().dtype(torch::kInt32).device(targetLogits.device())); - auto targetProbs = computeSoftmaxForProbOp(targetLogits); - return std::make_tuple(targetProbs.reshape({batchSize, numDraftTokens, targetVocabSize}), targetSupportIndices, - targetSupportLengths); - } - - auto targetTemps = temperatures.repeat_interleave(numDraftTokens); - auto targetTopK = topK.has_value() && topK->defined() - ? torch::optional(topK->repeat_interleave(numDraftTokens)) - : torch::optional(); - auto targetTopP = isTopPEnabled(topP) ? torch::optional(topP->repeat_interleave(numDraftTokens)) - : torch::optional(); - - bool const hasTopK = targetTopK.has_value() && targetTopK->defined(); - bool const hasTopP = isTopPEnabled(targetTopP); - bool const hasFiltering = hasTopK || hasTopP; - torch::Tensor effectiveTargetTopK; - bool hasDisabledTopKRows = false; - - auto const isGreedy = targetTemps <= kGreedyTempThreshold; - auto const safeTargetTemps = torch::where(isGreedy, torch::ones_like(targetTemps), targetTemps); - auto scaledTargetLogits = (skipTemperature ? targetLogits : targetLogits.div(safeTargetTemps.unsqueeze(1))) - .contiguous() - .to(torch::kFloat32); - - if (hasTopK) - { - auto targetTopKLong = targetTopK->to(torch::kLong); - effectiveTargetTopK - = torch::where(targetTopKLong > 0, targetTopKLong, torch::full_like(targetTopKLong, targetVocabSize)) - .clamp_max(targetVocabSize); - hasDisabledTopKRows = targetTopKLong.le(0).any().item(); - } - - torch::Tensor maskedTargetLogits; - torch::Tensor targetSupportIndices; - torch::Tensor targetSupportLengths; - - if (!hasFiltering) - { - // No filtering: use full-vocab probs; sparse support is not applicable. - maskedTargetLogits = scaledTargetLogits; - targetSupportIndices - = torch::empty({0}, torch::TensorOptions().dtype(torch::kInt32).device(targetLogits.device())); - targetSupportLengths - = torch::empty({0}, torch::TensorOptions().dtype(torch::kInt32).device(targetLogits.device())); - } - else if (hasTopK && !hasDisabledTopKRows && kMax > 0 && static_cast(kMax) < targetVocabSize) - { - // Fast two-stage CUDA path for masked logits. - maskedTargetLogits = torch::empty_like(scaledTargetLogits); - auto topKForKernel = effectiveTargetTopK.to(torch::kInt32).contiguous(); - auto topPForKernel = hasTopP ? targetTopP->to(torch::kFloat32).contiguous() : torch::Tensor(); - auto stream = at::cuda::getCurrentCUDAStream(scaledTargetLogits.device().index()); - invokeTopKTopPMaskingForProbs(scaledTargetLogits.data_ptr(), maskedTargetLogits.data_ptr(), - topKForKernel.data_ptr(), hasTopP ? topPForKernel.data_ptr() : nullptr, kMax, - static_cast(nRows), static_cast(targetVocabSize), stream); - - // Extract support indices: the finite positions after masking (at most kMax per row). - auto [topKVals, topKIdx] = maskedTargetLogits.topk(kMax, /*dim=*/-1, /*largest=*/true, /*sorted=*/true); - auto validMask = topKVals.isfinite(); // [nRows, kMax] - auto supportLengthsLong = validMask.sum(/*dim=*/-1, /*keepdim=*/false, torch::kLong); // [nRows] - auto supportIndicesRaw - = torch::where(validMask, topKIdx.to(torch::kInt32), torch::full_like(topKIdx.to(torch::kInt32), -1)); - - targetSupportIndices = supportIndicesRaw.reshape({batchSize, numDraftTokens, kMax}); - targetSupportLengths = supportLengthsLong.to(torch::kInt32).reshape({batchSize, numDraftTokens}); - } - else - { - // Sort-based fallback: top-P only, or kMax == 0 / kMax >= vocabSize. - auto sortResult = scaledTargetLogits.sort(/*dim=*/-1, /*descending=*/false); - auto logitsSort = std::get<0>(sortResult); - auto logitsIdx = std::get<1>(sortResult); - - if (hasTopK) - { - auto topKMask = logitsSort.size(1) - effectiveTargetTopK; - topKMask = topKMask.clamp_min(0); - auto topKThreshold = logitsSort.gather(1, topKMask.unsqueeze(1)); - auto mask = logitsSort < topKThreshold; - logitsSort.masked_fill_(mask, -std::numeric_limits::infinity()); - } - - if (hasTopP) - { - auto probsSort = logitsSort.softmax(/*dim=*/-1); - auto probsSum = probsSort.cumsum(/*dim=*/-1, /*dtype=*/probsSort.scalar_type()); - auto topPMask = probsSum <= (1.0 - targetTopP->unsqueeze(1)); - topPMask.select(/*dim=*/1, /*index=*/logitsSort.size(1) - 1).fill_(false); - logitsSort.masked_fill_(topPMask, -std::numeric_limits::infinity()); - } - - maskedTargetLogits = logitsSort.scatter(/*dim=*/-1, /*index=*/logitsIdx, /*src=*/logitsSort); - - // Compact support indices: finite values are at the END of the ascending-sorted logitsSort. - auto supportLengthsLong - = logitsSort.isfinite().sum(/*dim=*/-1, /*keepdim=*/false, /*dtype=*/torch::kLong); // [nRows] - auto supportLengths1D = supportLengthsLong.to(torch::kInt32); - - int64_t maxSupportSize = targetVocabSize; - if (hasTopK && effectiveTargetTopK.defined()) - { - maxSupportSize = std::min(targetVocabSize, effectiveTargetTopK.max().item()); - } - - auto compactPositions - = torch::arange(maxSupportSize, torch::TensorOptions().dtype(torch::kLong).device(targetLogits.device())) - .unsqueeze(0) - .expand({nRows, maxSupportSize}); - auto supportStart = targetVocabSize - supportLengthsLong.unsqueeze(1); - auto gatherPositions = (supportStart + compactPositions).clamp_max(targetVocabSize - 1); - auto gatheredSupportIndices = logitsIdx.gather(1, gatherPositions); - auto validMask = compactPositions < supportLengthsLong.unsqueeze(1); - auto invalidFill = torch::full_like(gatheredSupportIndices, -1L); - targetSupportIndices = torch::where(validMask, gatheredSupportIndices, invalidFill) - .to(torch::kInt32) - .reshape({batchSize, numDraftTokens, maxSupportSize}); - targetSupportLengths = supportLengths1D.reshape({batchSize, numDraftTokens}); - } - - auto targetProbs = computeSoftmaxForProbOp(maskedTargetLogits); - - auto argmaxIds = maskedTargetLogits.argmax(/*dim=*/-1, /*keepdim=*/true); - auto oneHot = torch::zeros_like(targetProbs).scatter_(1, argmaxIds, 1.0); - targetProbs = torch::where(isGreedy.unsqueeze(1), oneHot, targetProbs); - - return std::make_tuple( - targetProbs.reshape({batchSize, numDraftTokens, targetVocabSize}), targetSupportIndices, targetSupportLengths); -} - //! \param parentList [in] layer-wise parent indices [bs, topK*(depth-1)+1] //! \param selectedIndex [in] resampled history buffer indices [bs, draftTokenNum-1] //! \param treeMask [out] attention mask (which nodes each node can see) @@ -1110,49 +755,6 @@ void invokeBuildDynamicTree(int64_t const* parentList, int64_t const* selectedIn sync_check_cuda_error(stream); } -__global__ void buildDraftProbIndicesKernel( - int64_t const* topkScoreIndices, int32_t* draftProbIndices, SizeType32 topK, SizeType32 numDraftTokens) -{ - int32_t const batchIdx = blockIdx.x; - int32_t const tokenIdx = threadIdx.x; - - if (tokenIdx > numDraftTokens) - { - return; - } - - int32_t* draftProbIndicesRow = draftProbIndices + batchIdx * (numDraftTokens + 1); - - if (tokenIdx == 0) - { - draftProbIndicesRow[0] = 0; - return; - } - - int64_t const histIdx = topkScoreIndices[batchIdx * numDraftTokens + (tokenIdx - 1)]; - int32_t draftProbRow = 0; - - if (histIdx >= topK) - { - int64_t const relative = histIdx - topK; - int64_t const depthBucket = relative / (topK * topK); - int64_t const parentK = (relative % (topK * topK)) / topK; - draftProbRow = static_cast(1 + depthBucket * topK + parentK); - } - - draftProbIndicesRow[tokenIdx] = draftProbRow; -} - -void invokeBuildDraftProbIndices(int64_t const* topkScoreIndices, int32_t* draftProbIndices, SizeType32 batchSize, - SizeType32 topK, SizeType32 numDraftTokens, cudaStream_t stream) -{ - dim3 const grid(batchSize); - dim3 const block(numDraftTokens + 1); - - buildDraftProbIndicesKernel<<>>(topkScoreIndices, draftProbIndices, topK, numDraftTokens); - sync_check_cuda_error(stream); -} - //! retrievePacked layout [bs, numDraftTokens, 3] int32 row-major: //! [b,n,0]=retrieveIndex, [b,n,1]=retrieveNextToken, [b,n,2]=retrieveNextSibling __global__ void verifyDynamicTreeGreedyPackedKernel(int32_t* acceptIndex, int32_t* acceptTokenNum, int32_t* acceptToken, @@ -1357,53 +959,44 @@ struct MaxInt32Op } }; -struct MaxFloatOp -{ - __device__ __forceinline__ float operator()(float a, float b) const - { - return a > b ? a : b; - } -}; - -struct SoftmaxStats -{ - float maxVal; - float sumVal; - int32_t argmax; -}; +// --------------------------------------------------------------------------- +// Target-only dynamic tree rejection sampling kernel +// +// Acceptance algorithm: +// - For each depth, accumulate cumulative target probability across siblings. +// - Accept the first sibling whose cumulative prob exceeds the random coin. +// - When all siblings are rejected, sample a correction token from the +// residual target mass (target prob for tokens NOT tried as siblings). +// +// This matches the mathematical guarantee of speculative sampling with the +// draft treated as a uniform empirical prior over the K candidate siblings. +// --------------------------------------------------------------------------- -//! \param acceptIndex [out] accepted path as tree positions [bs, numSpecStep]. int64. -//! \param acceptTokenNum [out] number of accepted draft tokens (excl. root) [bs]. int64. -//! \param acceptToken [out] emitted token ids [bs, numSpecStep]. int64. -//! \param candidates [in] candidate token ids [bs, numDraftTokens]; col 0 = root. int64. -//! \param draftProbs [in] unique draft probs [bs, numDraftProbRows, vocabSize]. float32. -//! \param targetProbs [in] target probs [bs, numDraftTokens, vocabSize]; index 0 = root. float32. -//! \param targetSupportIndices [in] compact target support per tree position -//! [bs, numDraftTokens, maxTargetSupportSize]. int32, or nullptr. -//! \param targetSupportLengths [in] support length per tree position [bs, numDraftTokens]. int32, or nullptr. -//! \param draftProbIndices [in] tree position -> draftProbs row [bs, numDraftTokens], root unused. int32. -//! \param retrieveNextToken [in] first-child pointer [bs, numDraftTokens], -1=none. int32. -//! \param retrieveNextSibling [in] next-sibling pointer [bs, numDraftTokens], -1=none. int32. -//! \param treeValid [in] per-request tree validity flag [bs]. bool. -//! \param batchSize batch size. -//! \param numDraftProbRows unique draft-prob rows per request. -//! \param maxTargetSupportSize support-array width. Zero when targetSupportIndices is null. -//! \param numSpecStep second dim of acceptIndex/acceptToken -//! (= max_path_len = max_draft_len + 1). -//! \param numDraftTokens total tree nodes per batch (including root). -//! \param vocabSize vocabulary size. -//! \param seed [1] int64 on GPU. Philox RNG seed. -//! \param offset [1] int64 on GPU. Philox RNG offset. -template +// Maximum siblings we track per level for the correction step. +// Matches the maximum supported K branching factor (dynamic_tree_max_topK). +constexpr int32_t kMaxTriedPerLevel = 32; + +//! \param acceptIndex output [batchSize, numSpecStep] int64 — tree positions of accepted tokens. +//! \param acceptTokenNum output [batchSize] int64 — # accepted draft tokens (excl. root). +//! \param acceptToken output [batchSize, numSpecStep] int64 — accepted/correction token ids. +//! \param candidates [batchSize, numDraftTokens] int64; col 0 = root (target sample). +//! \param targetProbs [batchSize, numDraftTokens, vocabSize] float32; full-vocab target probs. +//! \param retrieveNextToken [batchSize, numDraftTokens] int32 first-child pointer, -1=none. +//! \param retrieveNextSibling [batchSize, numDraftTokens] int32 next-sibling pointer, -1=none. +//! \param treeValid [batchSize] bool; false means no valid tree exists for this request. +//! \param batchSize batch size. +//! \param numSpeculativeTokens second dim of acceptIndex/acceptToken (= max_draft_len + 1). +//! \param numDraftTokens total tree nodes per request (including root). +//! \param vocabSize vocabulary size. +//! \param seed [1] int64 on GPU. Philox RNG seed. +//! \param offset [1] int64 on GPU. Philox RNG offset. +template __global__ void verifyDynamicTreeRejectionKernel(int64_t* acceptIndex, int64_t* acceptTokenNum, int64_t* acceptToken, - int64_t const* candidates, float const* draftInputs, float const* targetInputs, int32_t const* targetSupportIndices, - int32_t const* targetSupportLengths, int32_t const* draftProbIndices, int32_t const* retrieveNextToken, - int32_t const* retrieveNextSibling, bool const* treeValid, uint32_t batchSize, uint32_t numDraftProbRows, - uint32_t maxTargetSupportSize, uint32_t numSpeculativeTokens, uint32_t numDraftTokens, uint32_t vocabSize, - uint32_t draftVocabSize, int32_t const* targetToDraft, int64_t const* seed, int64_t const* offset, - float const* temperatures) + int64_t const* draftTokens, float const* targetProbs, int32_t const* retrieveNextToken, + int32_t const* retrieveNextSibling, bool const* treeValid, uint32_t batchSize, uint32_t numSpeculativeTokens, + uint32_t numDraftTokens, uint32_t vocabSize, int64_t const* seed, int64_t const* offset) { - uint32_t bx = blockIdx.x; + uint32_t const bx = blockIdx.x; int32_t const tid = static_cast(threadIdx.x); constexpr uint32_t kVecSize = 4; if (bx >= batchSize) @@ -1432,16 +1025,16 @@ __global__ void verifyDynamicTreeRejectionKernel(int64_t* acceptIndex, int64_t* __shared__ int32_t sWinnerIndex; __shared__ int32_t sLastValidIndex; __shared__ int64_t sSampledToken; - __shared__ float sLogitsMax; - __shared__ float sLogitsSum; - __shared__ int32_t sLogitsArgmax; - - // The first sibling that passes the rejection test at the current depth. __shared__ int32_t sAccSibIdx; __shared__ int64_t sAccSibTok; - __shared__ int32_t sNumAccSiblings; + __shared__ bool sAccepted; - uint32_t batchOffset = bx * numDraftTokens; + // Rejected siblings at the current depth (for correction sampling). + __shared__ int32_t sTriedTokenIds[kMaxTriedPerLevel]; + __shared__ int32_t sNumTriedTokens; + __shared__ float sProbResidual; // 1.0 - cumulative target prob of tried siblings + + uint32_t const batchOffset = bx * numDraftTokens; curandStatePhilox4_32_10_t state; if (tid == 0) @@ -1450,12 +1043,8 @@ __global__ void verifyDynamicTreeRejectionKernel(int64_t* acceptIndex, int64_t* static_cast(seed[0]), static_cast(bx), static_cast(offset[0]), &state); } __syncthreads(); - bool const hasCompactTargetSupport = targetSupportIndices != nullptr && targetSupportLengths != nullptr; - bool const isGreedyRequest - = USE_LOGITS && temperatures != nullptr && temperatures[bx] <= static_cast(kGreedyTempThreshold); - float const* draftProbs = draftInputs; - float const* targetProbs = targetInputs; - uint32_t const draftRowStride = USE_LOGITS ? draftVocabSize : vocabSize; + + // --- Helper lambdas --- auto canVectorizeLoad = [&](float const* probs, uint32_t rowSize) -> bool { @@ -1486,84 +1075,7 @@ __global__ void verifyDynamicTreeRejectionKernel(int64_t* acceptIndex, int64_t* return probVec; }; - auto computeLogitsStats = [&](float const* logitsRow, uint32_t rowSize) -> SoftmaxStats - { - float threadMax = -FLT_MAX; - for (uint32_t v = static_cast(tid); v < rowSize; v += BLOCK_SIZE) - { - threadMax = fmaxf(threadMax, logitsRow[v]); - } - - float const blockMax = BlockReduce(tempStorage.reduce).Reduce(threadMax, MaxFloatOp{}); - if (tid == 0) - { - sLogitsMax = blockMax; - } - __syncthreads(); - - float threadSum = 0.0f; - int32_t localArgmax = static_cast(rowSize); - for (uint32_t v = static_cast(tid); v < rowSize; v += BLOCK_SIZE) - { - float const logit = logitsRow[v]; - threadSum += __expf(logit - sLogitsMax); - if (logit == sLogitsMax && localArgmax == static_cast(rowSize)) - { - localArgmax = static_cast(v); - } - } - - float const blockSum = BlockReduce(tempStorage.reduce).Sum(threadSum); - __syncthreads(); - int32_t const blockArgmax = BlockReduceInt(tempStorage.reduceInt).Reduce(localArgmax, MinInt32Op{}); - if (tid == 0) - { - sLogitsSum = blockSum; - sLogitsArgmax = blockArgmax; - } - __syncthreads(); - - return SoftmaxStats{sLogitsMax, sLogitsSum, sLogitsArgmax}; - }; - - auto probFromLogits = [&](float const* logitsRow, uint32_t tokenId, uint32_t rowSize, SoftmaxStats stats) -> float - { - if (tokenId >= rowSize) - { - return 0.0f; - } - if (isGreedyRequest) - { - return tokenId == static_cast(stats.argmax) ? 1.0f : 0.0f; - } - constexpr float kFloatSoftmaxEpsilon = 1e-6f; - return __expf(logitsRow[tokenId] - stats.maxVal) / (stats.sumVal + kFloatSoftmaxEpsilon); - }; - - auto targetTokenToDraftToken = [&](uint32_t targetTokenId) -> int32_t - { - if (targetTokenId >= vocabSize) - { - return -1; - } - if (targetToDraft != nullptr) - { - return targetToDraft[targetTokenId]; - } - return targetTokenId < draftVocabSize ? static_cast(targetTokenId) : -1; - }; - - auto draftProbFromTargetToken - = [&](float const* draftLogitsRow, uint32_t targetTokenId, SoftmaxStats stats) -> float - { - int32_t const draftTokenId = targetTokenToDraftToken(targetTokenId); - if (draftTokenId < 0) - { - return 0.0f; - } - return probFromLogits(draftLogitsRow, static_cast(draftTokenId), draftVocabSize, stats); - }; - + // Block-parallel tile sampling: accumulate values, scan, find first >= sTargetMass. auto sampleProbTile = [&](float(&value)[kVecSize], uint32_t base) -> bool { float const tileSum = BlockReduce(tempStorage.reduce).template Sum(value); @@ -1627,6 +1139,7 @@ __global__ void verifyDynamicTreeRejectionKernel(int64_t* acceptIndex, int64_t* return false; }; + // Sample from full target distribution at tProbs. auto sampleTargetFullVocab = [&](float const* tProbs) { bool const useVectorizedLoads = canVectorizeLoad(tProbs, vocabSize); @@ -1652,7 +1165,6 @@ __global__ void verifyDynamicTreeRejectionKernel(int64_t* acceptIndex, int64_t* { value[j] = qVec[j]; } - if (sampleProbTile(value, base)) { break; @@ -1666,238 +1178,64 @@ __global__ void verifyDynamicTreeRejectionKernel(int64_t* acceptIndex, int64_t* __syncthreads(); }; - auto sampleResidualFullVocab = [&](float const* tProbs, float const* dProbs) + // Sample correction token from target excluding tried siblings. + // Correction prob at token v = target_prob[v] if v was not tried, else 0. + // probResidual = 1.0 - sum of target probs of tried siblings (pre-computed). + auto sampleResidualWithTriedTokens = [&](float const* tProbs, float probResidual) { - bool const useVectorizedTargetLoads = canVectorizeLoad(tProbs, vocabSize); - bool const useVectorizedDraftLoads = canVectorizeLoad(dProbs, vocabSize); + bool const useVectorizedLoads = canVectorizeLoad(tProbs, vocabSize); uint32_t const numIters = (vocabSize + BLOCK_SIZE * kVecSize - 1) / (BLOCK_SIZE * kVecSize); - float totalSum = 0.0f; -#pragma unroll 2 - for (uint32_t i = 0; i < numIters; ++i) - { - uint32_t const base = (i * BLOCK_SIZE + static_cast(tid)) * kVecSize; - auto const qVec = loadProbVec(tProbs, base, useVectorizedTargetLoads, vocabSize); - auto const pVec = loadProbVec(dProbs, base, useVectorizedDraftLoads, vocabSize); - float value[kVecSize]; -#pragma unroll - for (uint32_t j = 0; j < kVecSize; ++j) - { - value[j] = fmaxf(qVec[j] - pVec[j], 0.0f); - } - totalSum += BlockReduce(tempStorage.reduce).template Sum(value); - __syncthreads(); - } if (tid == 0) { - sDiffSum = totalSum; sPrefixBase = 0.0f; sWinnerIndex = static_cast(vocabSize); sLastValidIndex = -1; sSampledToken = static_cast(vocabSize) - 1; - sTargetMass = totalSum > 1e-10f ? curand_uniform_open_right(state) * totalSum : 0.0f; - if (totalSum <= 1e-10f) + if (probResidual > 1e-10f) { - sSampledToken = sampleFromDistribution(state, tProbs, vocabSize); + sTargetMass = curand_uniform_open_right(state) * probResidual; } - } - __syncthreads(); - - if (sDiffSum <= 1e-10f) - { - return; - } - -#pragma unroll 2 - for (uint32_t i = 0; i < numIters; ++i) - { - uint32_t const base = (i * BLOCK_SIZE + static_cast(tid)) * kVecSize; - auto const qVec = loadProbVec(tProbs, base, useVectorizedTargetLoads, vocabSize); - auto const pVec = loadProbVec(dProbs, base, useVectorizedDraftLoads, vocabSize); - float value[kVecSize]; -#pragma unroll - for (uint32_t j = 0; j < kVecSize; ++j) - { - value[j] = fmaxf(qVec[j] - pVec[j], 0.0f); - } - - if (sampleProbTile(value, base)) - { - break; - } - } - - if (tid == 0 && sWinnerIndex >= static_cast(vocabSize) && sLastValidIndex >= 0) - { - sSampledToken = static_cast(sLastValidIndex); - } - __syncthreads(); - }; - - auto sampleTargetLogitsFullVocabWithStats = [&](float const* tLogits, SoftmaxStats targetStats) - { - if (tid == 0) - { - constexpr float kFloatSoftmaxEpsilon = 1e-6f; - sPrefixBase = 0.0f; - sWinnerIndex = static_cast(vocabSize); - sLastValidIndex = -1; - sSampledToken = static_cast(vocabSize) - 1; - sTargetMass = isGreedyRequest - ? 0.0f - : curand_uniform_open_right(state) * (targetStats.sumVal + kFloatSoftmaxEpsilon); - if (isGreedyRequest) - { - sSampledToken = static_cast(targetStats.argmax); - } - } - __syncthreads(); - - if (isGreedyRequest) - { - return; - } - - bool const useVectorizedLoads = canVectorizeLoad(tLogits, vocabSize); - uint32_t const numIters = (vocabSize + BLOCK_SIZE * kVecSize - 1) / (BLOCK_SIZE * kVecSize); -#pragma unroll 2 - for (uint32_t i = 0; i < numIters; ++i) - { - uint32_t const base = (i * BLOCK_SIZE + static_cast(tid)) * kVecSize; - auto const qVec = loadProbVec(tLogits, base, useVectorizedLoads, vocabSize); - float value[kVecSize]; -#pragma unroll - for (uint32_t j = 0; j < kVecSize; ++j) - { - uint32_t const v = base + j; - value[j] = v < vocabSize ? __expf(qVec[j] - targetStats.maxVal) : 0.0f; - } - - if (sampleProbTile(value, base)) + else { - break; + // Nearly zero residual: fall back to argmax over full target. + sSampledToken = sampleFromDistribution(state, tProbs, vocabSize); } } - - if (tid == 0 && sWinnerIndex >= static_cast(vocabSize) && sLastValidIndex >= 0) - { - sSampledToken = static_cast(sLastValidIndex); - } __syncthreads(); - }; - auto sampleTargetLogitsFullVocab = [&](float const* tLogits) - { - auto const targetStats = computeLogitsStats(tLogits, vocabSize); - sampleTargetLogitsFullVocabWithStats(tLogits, targetStats); - }; - - auto sampleResidualLogitsFullVocab - = [&](float const* tLogits, float const* dLogits, SoftmaxStats targetStats, SoftmaxStats draftStats) - { - if (isGreedyRequest) + if (probResidual <= 1e-10f) { - if (tid == 0) - { - sSampledToken = static_cast(targetStats.argmax); - } - __syncthreads(); return; } - bool const useVectorizedTargetLoads = canVectorizeLoad(tLogits, vocabSize); - bool const useVectorizedDraftLoads = targetToDraft == nullptr && canVectorizeLoad(dLogits, draftVocabSize); - uint32_t const numIters = (vocabSize + BLOCK_SIZE * kVecSize - 1) / (BLOCK_SIZE * kVecSize); - constexpr float kFloatSoftmaxEpsilon = 1e-6f; - float totalSum = 0.0f; #pragma unroll 2 for (uint32_t i = 0; i < numIters; ++i) { uint32_t const base = (i * BLOCK_SIZE + static_cast(tid)) * kVecSize; - auto const qVec = loadProbVec(tLogits, base, useVectorizedTargetLoads, vocabSize); - flashinfer::vec_t pVec; - pVec.fill(0.0f); - if (targetToDraft == nullptr) - { - pVec = loadProbVec(dLogits, base, useVectorizedDraftLoads, draftVocabSize); - } + auto const qVec = loadProbVec(tProbs, base, useVectorizedLoads, vocabSize); float value[kVecSize]; #pragma unroll for (uint32_t j = 0; j < kVecSize; ++j) { uint32_t const v = base + j; - if (v < vocabSize) - { - float const q = __expf(qVec[j] - targetStats.maxVal) / (targetStats.sumVal + kFloatSoftmaxEpsilon); - int32_t const draftTokenId = targetTokenToDraftToken(v); - float p = 0.0f; - if (draftTokenId >= 0 && draftTokenId < static_cast(draftVocabSize)) - { - float const draftLogit = targetToDraft == nullptr ? pVec[j] : dLogits[draftTokenId]; - p = __expf(draftLogit - draftStats.maxVal) / (draftStats.sumVal + kFloatSoftmaxEpsilon); - } - value[j] = fmaxf(q - p, 0.0f); - } - else + if (v >= vocabSize) { value[j] = 0.0f; + continue; } - } - totalSum += BlockReduce(tempStorage.reduce).template Sum(value); - __syncthreads(); - } - - if (tid == 0) - { - sDiffSum = totalSum; - sPrefixBase = 0.0f; - sWinnerIndex = static_cast(vocabSize); - sLastValidIndex = -1; - sSampledToken = static_cast(vocabSize) - 1; - sTargetMass = totalSum > 1e-10f ? curand_uniform_open_right(state) * totalSum : 0.0f; - } - __syncthreads(); - - if (sDiffSum <= 1e-10f) - { - sampleTargetLogitsFullVocabWithStats(tLogits, targetStats); - return; - } - -#pragma unroll 2 - for (uint32_t i = 0; i < numIters; ++i) - { - uint32_t const base = (i * BLOCK_SIZE + static_cast(tid)) * kVecSize; - auto const qVec = loadProbVec(tLogits, base, useVectorizedTargetLoads, vocabSize); - flashinfer::vec_t pVec; - pVec.fill(0.0f); - if (targetToDraft == nullptr) - { - pVec = loadProbVec(dLogits, base, useVectorizedDraftLoads, draftVocabSize); - } - float value[kVecSize]; -#pragma unroll - for (uint32_t j = 0; j < kVecSize; ++j) - { - uint32_t const v = base + j; - if (v < vocabSize) + // Zero out tried siblings. + bool inTried = false; + for (int32_t k = 0; k < sNumTriedTokens; ++k) { - float const q = __expf(qVec[j] - targetStats.maxVal) / (targetStats.sumVal + kFloatSoftmaxEpsilon); - int32_t const draftTokenId = targetTokenToDraftToken(v); - float p = 0.0f; - if (draftTokenId >= 0 && draftTokenId < static_cast(draftVocabSize)) + if (sTriedTokenIds[k] == static_cast(v)) { - float const draftLogit = targetToDraft == nullptr ? pVec[j] : dLogits[draftTokenId]; - p = __expf(draftLogit - draftStats.maxVal) / (draftStats.sumVal + kFloatSoftmaxEpsilon); + inTried = true; + break; } - value[j] = fmaxf(q - p, 0.0f); - } - else - { - value[j] = 0.0f; } + value[j] = inTried ? 0.0f : qVec[j]; } - if (sampleProbTile(value, base)) { break; @@ -1911,32 +1249,11 @@ __global__ void verifyDynamicTreeRejectionKernel(int64_t* acceptIndex, int64_t* __syncthreads(); }; - // First-gen or dummy request: no valid tree exists yet. Sample directly - // from the target distribution at the root and skip tree traversal. + // --- First-gen / dummy request: no valid tree, sample from target root. --- if (treeValid != nullptr && !treeValid[bx]) { float const* tProbs = targetProbs + static_cast(bx) * numDraftTokens * vocabSize; - if (hasCompactTargetSupport) - { - if (tid == 0) - { - uint32_t const supportOffset = static_cast(bx) * numDraftTokens * maxTargetSupportSize; - uint32_t const supportSize = static_cast(targetSupportLengths[batchOffset]); - sSampledToken = sampleFromIndexedDistribution( - state, tProbs, targetSupportIndices + supportOffset, supportSize, vocabSize); - } - } - else - { - if constexpr (USE_LOGITS) - { - sampleTargetLogitsFullVocab(tProbs); - } - else - { - sampleTargetFullVocab(tProbs); - } - } + sampleTargetFullVocab(tProbs); if (tid == 0) { acceptIndex[bx * numSpeculativeTokens] = 0; @@ -1946,32 +1263,7 @@ __global__ void verifyDynamicTreeRejectionKernel(int64_t* acceptIndex, int64_t* return; } - // Root (depth 0): initialize path state at tree position 0. - // - // Example tree used in code review discussions: - // root: E - // children of E: F1, F2, F3 - // children of F1: G1, G2 - // children of F2: G3 - // - // In that example the per-request inputs are conceptually: - // candidates = [E, F1, F2, F3, G1, G2, G3] - // draftProbs = [p(.|E), p(.|F1), p(.|F2)] - // draftProbIndices = [0, 0, 0, 0, 1, 1, 2] - // targetProbs = [q(.|E), q(.|F1), q(.|F2), q(.|F3), q(.|G1), q(.|G2), q(.|G3)] - // - // draftProbs stores one row per unique parent context, so siblings that share - // the same parent also share the same draftProbs row via draftProbIndices. - // targetProbs remains aligned to all tree positions, including the root at slot 0. - // - // Output convention: - // - acceptIndex stores the accepted draft path as tree positions, with slot 0 - // reserved for the root position. - // - acceptToken stores the emitted token sequence, matching the greedy kernel: - // slot 0 = first emitted token - // slot numAcceptedTokens = final bonus/correction token - // - acceptTokenNum stores the number of accepted draft tokens only. The caller - // adds 1 to obtain the total number of emitted tokens. + // --- Root initialization --- if (tid == 0) { sLastAcceptedLocalIdx = 0; @@ -1980,44 +1272,22 @@ __global__ void verifyDynamicTreeRejectionKernel(int64_t* acceptIndex, int64_t* sHasTerminalToken = false; } __syncthreads(); + + // --- Main tree traversal --- for (uint32_t j = 1; j < numSpeculativeTokens; ++j) { - // Get first child of the last accepted node. if (tid == 0) { sFirstChild = retrieveNextToken[batchOffset + sLastAcceptedLocalIdx]; } __syncthreads(); - // Leaf node: no children at this depth. - // Emit bonus token from the target distribution at the last accepted position. + // Leaf node: emit bonus token from target at last accepted position. if (sFirstChild == -1) { float const* tProbs = targetProbs + (static_cast(bx) * numDraftTokens + sLastAcceptedLocalIdx) * vocabSize; - if (hasCompactTargetSupport) - { - if (tid == 0) - { - uint32_t const supportOffset - = (static_cast(bx) * numDraftTokens + sLastAcceptedLocalIdx) * maxTargetSupportSize; - uint32_t const supportSize - = static_cast(targetSupportLengths[batchOffset + sLastAcceptedLocalIdx]); - sSampledToken = sampleFromIndexedDistribution( - state, tProbs, targetSupportIndices + supportOffset, supportSize, vocabSize); - } - } - else - { - if constexpr (USE_LOGITS) - { - sampleTargetLogitsFullVocab(tProbs); - } - else - { - sampleTargetFullVocab(tProbs); - } - } + sampleTargetFullVocab(tProbs); if (tid == 0) { acceptToken[bx * numSpeculativeTokens + sNumAcceptedTokens] = sSampledToken; @@ -2027,183 +1297,87 @@ __global__ void verifyDynamicTreeRejectionKernel(int64_t* acceptIndex, int64_t* break; } - // Test siblings in linked-list order. Once a sibling passes the - // Bernoulli rejection test, accept it immediately and skip the rest. - int32_t const firstDraftProbRow = draftProbIndices[batchOffset + sFirstChild]; - float const* siblingTargetRow - = targetProbs + (static_cast(bx) * numDraftTokens + sLastAcceptedLocalIdx) * vocabSize; - float const* siblingDraftRow - = draftProbs + (static_cast(bx) * numDraftProbRows + firstDraftProbRow) * draftRowStride; - SoftmaxStats siblingTargetStats{}; - SoftmaxStats siblingDraftStats{}; - if constexpr (USE_LOGITS) - { - siblingTargetStats = computeLogitsStats(siblingTargetRow, vocabSize); - siblingDraftStats = computeLogitsStats(siblingDraftRow, draftVocabSize); - } - + // Try siblings using cumulative target probability. + // Accept the first sibling whose cumulative prob exceeds the coin. if (tid == 0) { - sNumAccSiblings = 0; + sNumTriedTokens = 0; + sAccepted = false; + float probAcc = 0.0f; + float const coin = curand_uniform_open_right(state); + + float const* parentProbs + = targetProbs + (static_cast(bx) * numDraftTokens + sLastAcceptedLocalIdx) * vocabSize; int32_t childIdx = sFirstChild; while (childIdx != -1) { - int64_t const draftTokenId = candidates[batchOffset + childIdx]; - int32_t const draftProbRow = draftProbIndices[batchOffset + childIdx]; - uint32_t const tokenId = static_cast(draftTokenId); - float const pDraft = USE_LOGITS - ? draftProbFromTargetToken(siblingDraftRow, tokenId, siblingDraftStats) - : draftProbs[(static_cast(bx) * numDraftProbRows + draftProbRow) * vocabSize - + draftTokenId]; - float const pTarget = USE_LOGITS - ? probFromLogits(siblingTargetRow, tokenId, vocabSize, siblingTargetStats) - : targetProbs[(static_cast(bx) * numDraftTokens + sLastAcceptedLocalIdx) * vocabSize - + draftTokenId]; - - float const acceptProb = fminf(1.0f, pTarget / (pDraft + 1e-10f)); - float const u = curand_uniform_open_right(state); - - if (u < acceptProb) + int64_t const tokenId = draftTokens[bx * (numDraftTokens - 1) + (childIdx - 1)]; + float const tProb = parentProbs[static_cast(tokenId)]; + probAcc += tProb; + + if (coin <= probAcc) { sAccSibIdx = childIdx; - sAccSibTok = draftTokenId; - sNumAccSiblings = 1; + sAccSibTok = tokenId; + sAccepted = true; break; } - childIdx = retrieveNextSibling[batchOffset + childIdx]; + else + { + // Clamp the counter together with the write: sTriedTokenIds is + // sized [kMaxTriedPerLevel], and sNumTriedTokens bounds the + // read loop above. Incrementing it past the array size (when a + // node has more than kMaxTriedPerLevel siblings) would make that + // loop read out-of-bounds shared memory. Only count tokens we + // actually recorded; probAcc below still accumulates every + // sibling, so the residual normalization stays correct. + if (sNumTriedTokens < kMaxTriedPerLevel) + { + sTriedTokenIds[sNumTriedTokens] = static_cast(tokenId); + sNumTriedTokens = sNumTriedTokens + 1; + } + childIdx = retrieveNextSibling[batchOffset + childIdx]; + } } + sProbResidual = 1.0f - probAcc; } __syncthreads(); - // Select the first accepted sibling or emit correction when all siblings reject. - if (sNumAccSiblings > 0) + if (sAccepted) { if (tid == 0) { - int32_t const childIdx = sAccSibIdx; acceptToken[bx * numSpeculativeTokens + sNumAcceptedTokens] = sAccSibTok; ++sNumAcceptedTokens; - acceptIndex[bx * numSpeculativeTokens + sNumAcceptedTokens] = childIdx; - sLastAcceptedLocalIdx = childIdx; + acceptIndex[bx * numSpeculativeTokens + sNumAcceptedTokens] = sAccSibIdx; + sLastAcceptedLocalIdx = sAccSibIdx; } __syncthreads(); } else { - // All siblings rejected -> sample correction token from relu(q - p). - { - int32_t const draftProbRow = draftProbIndices[batchOffset + sFirstChild]; - float const* tProbs - = targetProbs + (static_cast(bx) * numDraftTokens + sLastAcceptedLocalIdx) * vocabSize; - float const* dProbs - = draftProbs + (static_cast(bx) * numDraftProbRows + draftProbRow) * draftRowStride; - int32_t const* tProbIndices = nullptr; - uint32_t targetSupportSize = vocabSize; - if (hasCompactTargetSupport) - { - uint32_t const supportOffset - = (static_cast(bx) * numDraftTokens + sLastAcceptedLocalIdx) * maxTargetSupportSize; - tProbIndices = targetSupportIndices + supportOffset; - targetSupportSize - = static_cast(targetSupportLengths[batchOffset + sLastAcceptedLocalIdx]); - } - - if (hasCompactTargetSupport) - { - if (tid == 0) - { - float diffSum = 0.0f; - for (uint32_t i = 0; i < targetSupportSize; ++i) - { - uint32_t const v = static_cast(tProbIndices[i]); - float const diff = tProbs[v] - dProbs[v]; - if (diff > 0.0f) - { - diffSum += diff; - } - } - - int64_t corrTok = static_cast(vocabSize) - 1; - bool const useDiff = (diffSum > 1e-10f); - - if (useDiff) - { - float const r = curand_uniform_open_right(state); - float cumsum = 0.0f; - for (uint32_t i = 0; i < targetSupportSize; ++i) - { - uint32_t const v = static_cast(tProbIndices[i]); - float const diff = tProbs[v] - dProbs[v]; - float const prob = (diff > 0.0f) ? diff / diffSum : 0.0f; - cumsum += prob; - if (r <= cumsum) - { - corrTok = static_cast(v); - break; - } - } - } - else - { - corrTok = sampleFromIndexedDistribution( - state, tProbs, tProbIndices, targetSupportSize, vocabSize); - } - acceptToken[bx * numSpeculativeTokens + sNumAcceptedTokens] = corrTok; - sHasTerminalToken = true; - } - } - else - { - if constexpr (USE_LOGITS) - { - sampleResidualLogitsFullVocab(tProbs, dProbs, siblingTargetStats, siblingDraftStats); - } - else - { - sampleResidualFullVocab(tProbs, dProbs); - } + // All siblings rejected: sample correction from residual target mass. + float const* tProbs + = targetProbs + (static_cast(bx) * numDraftTokens + sLastAcceptedLocalIdx) * vocabSize; - if (tid == 0) - { - acceptToken[bx * numSpeculativeTokens + sNumAcceptedTokens] = sSampledToken; - sHasTerminalToken = true; - } - } + // Full-vocab parallel correction, excluding tried tokens. + sampleResidualWithTriedTokens(tProbs, sProbResidual); + if (tid == 0) + { + acceptToken[bx * numSpeculativeTokens + sNumAcceptedTokens] = sSampledToken; + sHasTerminalToken = true; } __syncthreads(); break; } } + // Reached max speculative depth: emit bonus token from last accepted position. if (!sHasTerminalToken) { - // Reached max speculative depth while continuing to accept the draft path. - // Emit the final bonus token from the last accepted position. float const* tProbs = targetProbs + (static_cast(bx) * numDraftTokens + sLastAcceptedLocalIdx) * vocabSize; - if (hasCompactTargetSupport) - { - if (tid == 0) - { - uint32_t const supportOffset - = (static_cast(bx) * numDraftTokens + sLastAcceptedLocalIdx) * maxTargetSupportSize; - uint32_t const supportSize - = static_cast(targetSupportLengths[batchOffset + sLastAcceptedLocalIdx]); - sSampledToken = sampleFromIndexedDistribution( - state, tProbs, targetSupportIndices + supportOffset, supportSize, vocabSize); - } - } - else - { - if constexpr (USE_LOGITS) - { - sampleTargetLogitsFullVocab(tProbs); - } - else - { - sampleTargetFullVocab(tProbs); - } - } + sampleTargetFullVocab(tProbs); if (tid == 0) { acceptToken[bx * numSpeculativeTokens + sNumAcceptedTokens] = sSampledToken; @@ -2217,21 +1391,18 @@ __global__ void verifyDynamicTreeRejectionKernel(int64_t* acceptIndex, int64_t* } void invokeVerifyDynamicTreeRejection(int64_t* acceptIndex, int64_t* acceptTokenNum, int64_t* acceptToken, - int64_t const* candidates, float const* draftProbs, float const* targetProbs, int32_t const* targetSupportIndices, - int32_t const* targetSupportLengths, int32_t const* draftProbIndices, int32_t const* retrieveNextToken, - int32_t const* retrieveNextSibling, bool const* treeValid, SizeType32 batchSize, SizeType32 numDraftProbRows, - SizeType32 maxTargetSupportSize, SizeType32 numDraftTokens, SizeType32 numSpecStep, SizeType32 vocabSize, - int64_t const* seed, int64_t const* offset, cudaStream_t stream) + int64_t const* draftTokens, float const* targetProbs, int32_t const* retrieveNextToken, + int32_t const* retrieveNextSibling, bool const* treeValid, SizeType32 batchSize, SizeType32 numDraftTokens, + SizeType32 numSpecStep, SizeType32 vocabSize, int64_t const* seed, int64_t const* offset, cudaStream_t stream) { - constexpr int32_t kVerifyDynamicTreeRejectionBlockSize = 1024; + constexpr int32_t kBlockSize = 256; dim3 grid(batchSize); - dim3 block(kVerifyDynamicTreeRejectionBlockSize); + dim3 block(kBlockSize); - verifyDynamicTreeRejectionKernel<<>>( - acceptIndex, acceptTokenNum, acceptToken, candidates, draftProbs, targetProbs, targetSupportIndices, - targetSupportLengths, draftProbIndices, retrieveNextToken, retrieveNextSibling, treeValid, batchSize, - numDraftProbRows, maxTargetSupportSize, numSpecStep, numDraftTokens, vocabSize, vocabSize, - /*targetToDraft=*/nullptr, seed, offset, nullptr); + verifyDynamicTreeRejectionKernel<<>>(acceptIndex, acceptTokenNum, acceptToken, + draftTokens, targetProbs, retrieveNextToken, retrieveNextSibling, treeValid, static_cast(batchSize), + static_cast(numSpecStep), static_cast(numDraftTokens), static_cast(vocabSize), + seed, offset); sync_check_cuda_error(stream); } diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h b/cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h index dfee15977326..eda52e7907c8 100644 --- a/cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h +++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h @@ -61,18 +61,6 @@ void invokeBuildDynamicTree(int64_t const* parentList, int64_t const* selectedIn runtime::SizeType32 topK, runtime::SizeType32 depth, runtime::SizeType32 numDraftTokens, TreeMaskMode treeMaskMode, cudaStream_t stream, runtime::SizeType32 numInt32PerRow); -//! \brief Build tree-position -> unique draft-prob row mapping for rejection sampling. -//! \param topkScoreIndices [batchSize, numDraftTokens], on GPU. int64. -//! History-buffer indices selected for each final tree position. -//! \param draftProbIndices output [batchSize, numDraftTokens + 1], on GPU. int32. -//! Column 0 is reserved for the root and set to 0. -//! \param batchSize runtime::SizeType32. Batch size. -//! \param topK runtime::SizeType32. Tree top-K branching factor. -//! \param numDraftTokens runtime::SizeType32. Total number of non-root draft positions. -//! \param stream cuda stream. -void invokeBuildDraftProbIndices(int64_t const* topkScoreIndices, int32_t* draftProbIndices, - runtime::SizeType32 batchSize, runtime::SizeType32 topK, runtime::SizeType32 numDraftTokens, cudaStream_t stream); - //! \brief Verify dynamic tree using greedy strategy with packed retrieve buffers. //! \param acceptIndex output buffer [batchSize, numSpecStep], on GPU. int32. //! \param acceptTokenNum output buffer [batchSize], on GPU. int32. @@ -91,29 +79,21 @@ void invokeVerifyDynamicTreeGreedyPacked(int32_t* acceptIndex, int32_t* acceptTo cudaStream_t stream); //! \brief Verify dynamic tree using rejection sampling. -//! For each request, traverses the tree depth-by-depth. At each depth, siblings are tried -//! in order; the first sibling accepted by rejection sampling (p_target/p_draft) continues -//! the path. If all siblings are rejected, a correction token is sampled from (target-draft)_+. +//! Accepts draft tokens by accumulating cumulative target probability across siblings. +//! The first sibling whose cumulative target prob exceeds the random coin is accepted. +//! When all siblings are rejected, a correction token is sampled from the residual +//! target distribution (target prob for tokens not tried as siblings). +//! No draft probabilities are needed. Always uses full-vocab path. +//! //! \param acceptIndex output [batchSize, numSpecStep] int64 — tree positions of accepted tokens. //! \param acceptTokenNum output [batchSize] int64 — # accepted draft tokens (excl. root). //! \param acceptToken output [batchSize, numSpecStep] int64 — accepted/correction token ids. -//! \param candidates [batchSize, numDraftTokens] int64; col 0 = root (target sample). -//! \param draftProbs [batchSize, numDraftProbRows, vocabSize] float32. -//! Unique draft probability rows per request; tree positions map into this -//! tensor via draftProbIndices. +//! \param draftTokens [batchSize, numDraftTokens-1] int64; draft token ids (excluding root). //! \param targetProbs [batchSize, numDraftTokens, vocabSize] float32; index 0 = root. -//! \param targetSupportIndices [batchSize, numDraftTokens, maxTargetSupportSize] int32; compact token ids that -//! survive top-k/top-p filtering for each row, padded with -1. May be empty when no filtering is active. -//! \param targetSupportLengths [batchSize, numDraftTokens] int32; valid support length per row. May be empty when -//! no filtering is active. -//! \param draftProbIndices [batchSize, numDraftTokens] int32; maps each tree position to the -//! corresponding row in draftProbs. Root is unused. //! \param retrieveNextToken [batchSize, numDraftTokens] int32 first-child pointer, -1=none. //! \param retrieveNextSibling [batchSize, numDraftTokens] int32 next-sibling pointer, -1=none. //! \param treeValid [batchSize] bool; false means no valid tree exists for this request. //! \param batchSize runtime::SizeType32. -//! \param numDraftProbRows runtime::SizeType32. Number of unique draft-prob rows per request. -//! \param maxTargetSupportSize runtime::SizeType32. Third dim of targetSupportIndices. Can be zero. //! \param numDraftTokens runtime::SizeType32. Total tree nodes per request (including root). //! \param numSpecStep runtime::SizeType32. Second dim of acceptIndex/acceptToken. //! \param vocabSize runtime::SizeType32. Vocabulary size. @@ -121,45 +101,10 @@ void invokeVerifyDynamicTreeGreedyPacked(int32_t* acceptIndex, int32_t* acceptTo //! \param offset [1] int64 on GPU. Philox RNG offset. //! \param stream cudaStream_t. void invokeVerifyDynamicTreeRejection(int64_t* acceptIndex, int64_t* acceptTokenNum, int64_t* acceptToken, - int64_t const* candidates, float const* draftProbs, float const* targetProbs, int32_t const* targetSupportIndices, - int32_t const* targetSupportLengths, int32_t const* draftProbIndices, int32_t const* retrieveNextToken, + int64_t const* draftTokens, float const* targetProbs, int32_t const* retrieveNextToken, int32_t const* retrieveNextSibling, bool const* treeValid, runtime::SizeType32 batchSize, - runtime::SizeType32 numDraftProbRows, runtime::SizeType32 maxTargetSupportSize, runtime::SizeType32 numDraftTokens, - runtime::SizeType32 numSpecStep, runtime::SizeType32 vocabSize, int64_t const* seed, int64_t const* offset, - cudaStream_t stream); - -//! \brief Compute draft probabilities for dynamic-tree rejection sampling from logits. -//! \param draftLogits [batchSize * numDraftProbRows, draftVocabSize], on GPU. -//! \param temperatures [batchSize], on GPU. -//! \param numDraftProbRows runtime::SizeType32. Unique draft-prob rows per request. -//! \param topK Optional [batchSize], on GPU. -//! \param topP Optional [batchSize], on GPU. -//! \param targetVocabSize runtime::SizeType32. Output vocabulary size after optional d2t expansion. -//! \param d2t Optional [draftVocabSize], on GPU. -//! \return [batchSize, numDraftProbRows, targetVocabSize] float32 probabilities. -torch::Tensor computeDraftProbsForDynamicTreeRejection(torch::Tensor const& draftLogits, - torch::Tensor const& temperatures, runtime::SizeType32 numDraftProbRows, torch::optional const& topK, - torch::optional const& topP, runtime::SizeType32 targetVocabSize, bool skipTemperature, - torch::optional const& d2t, runtime::SizeType32 kMax = 0, bool skipAllSamplingParams = false); - -//! \brief Compute target probabilities for dynamic-tree rejection sampling from logits. -//! \param targetLogits [batchSize * numDraftTokens, targetVocabSize], on GPU. -//! \param temperatures [batchSize], on GPU. -//! \param numDraftTokens runtime::SizeType32. Total tree nodes per request (including root). -//! \param topK Optional [batchSize], on GPU. -//! \param topP Optional [batchSize], on GPU. -//! \param kMax runtime::SizeType32. Max top-K value across the batch; enables the fast topk -//! path when > 0. Must be computed on CPU (e.g. topK.max().item()). Default 0 = fallback -//! to full sort. -//! \return Tuple of: -//! 1. [batchSize, numDraftTokens, targetVocabSize] float32 probabilities. -//! 2. [batchSize, numDraftTokens, maxTargetSupportSize] int32 compact token ids that survive -//! top-k/top-p filtering for each row, padded with -1. Empty when no filtering is active. -//! 3. [batchSize, numDraftTokens] int32 support lengths. Empty when no filtering is active. -std::tuple computeTargetProbsForDynamicTreeRejection( - torch::Tensor const& targetLogits, torch::Tensor const& temperatures, runtime::SizeType32 numDraftTokens, - torch::optional const& topK, torch::optional const& topP, bool skipTemperature, - runtime::SizeType32 kMax = 0, bool skipAllSamplingParams = false); + runtime::SizeType32 numDraftTokens, runtime::SizeType32 numSpecStep, runtime::SizeType32 vocabSize, + int64_t const* seed, int64_t const* offset, cudaStream_t stream); } // namespace kernels::speculative_decoding diff --git a/cpp/tensorrt_llm/thop/dynamicTreeOp.cpp b/cpp/tensorrt_llm/thop/dynamicTreeOp.cpp index 572536db8b25..0b862108a4e6 100644 --- a/cpp/tensorrt_llm/thop/dynamicTreeOp.cpp +++ b/cpp/tensorrt_llm/thop/dynamicTreeOp.cpp @@ -30,40 +30,11 @@ namespace kernels::speculative_decoding th::Tensor computeProbsFromLogits(th::Tensor const& logits, th::Tensor const& temperatures, th::optional const& topK, th::optional const& topP, bool skipTemperature, runtime::SizeType32 kMax); -void invokeBuildDraftProbIndices(int64_t const* topkScoreIndices, int32_t* draftProbIndices, - runtime::SizeType32 batchSize, runtime::SizeType32 topK, runtime::SizeType32 numDraftTokens, cudaStream_t stream); -th::Tensor computeDraftProbsForDynamicTreeRejection(th::Tensor const& draftLogits, th::Tensor const& temperatures, - runtime::SizeType32 numDraftProbRows, th::optional const& topK, th::optional const& topP, - runtime::SizeType32 targetVocabSize, bool skipTemperature, th::optional const& d2t, - runtime::SizeType32 kMax, bool skipAllSamplingParams); -std::tuple computeTargetProbsForDynamicTreeRejection(th::Tensor const& targetLogits, - th::Tensor const& temperatures, runtime::SizeType32 numDraftTokens, th::optional const& topK, - th::optional const& topP, bool skipTemperature, runtime::SizeType32 kMax, bool skipAllSamplingParams); } // namespace kernels::speculative_decoding namespace torch_ext { -void build_draft_prob_indices_out_op( - th::Tensor& topkScoreIndices, th::Tensor& draftProbIndices, int64_t topK, int64_t numDraftTokens) -{ - TORCH_CHECK(topkScoreIndices.is_cuda(), "topkScoreIndices must be a CUDA tensor"); - TORCH_CHECK(draftProbIndices.is_cuda(), "draftProbIndices must be a CUDA tensor"); - TORCH_CHECK(topkScoreIndices.dim() == 2, "topkScoreIndices must be a 2D tensor"); - TORCH_CHECK(draftProbIndices.dim() == 2, "draftProbIndices must be a 2D tensor"); - TORCH_CHECK(topkScoreIndices.scalar_type() == torch::kInt64, "topkScoreIndices must be int64 tensor"); - TORCH_CHECK(draftProbIndices.scalar_type() == torch::kInt32, "draftProbIndices must be int32 tensor"); - TORCH_CHECK(topkScoreIndices.size(1) == numDraftTokens, "topkScoreIndices size mismatch"); - TORCH_CHECK(draftProbIndices.size(0) == topkScoreIndices.size(0), "Batch size mismatch"); - TORCH_CHECK(draftProbIndices.size(1) == numDraftTokens + 1, "draftProbIndices size mismatch"); - TORCH_CHECK(topK > 0, "topK must be positive"); - TORCH_CHECK(numDraftTokens + 1 <= 1024, "numDraftTokens + 1 exceeds CUDA block size limit of 1024"); - - auto stream = at::cuda::getCurrentCUDAStream(topkScoreIndices.device().index()); - tk::invokeBuildDraftProbIndices(topkScoreIndices.data_ptr(), draftProbIndices.data_ptr(), - topkScoreIndices.size(0), topK, numDraftTokens, stream); -} - //! \brief Build dynamic tree structure (in-place, writes to pre-allocated output buffers) //! All index tensors use int64 to match PyTorch's default integer dtype. void build_dynamic_tree_op(th::Tensor& parentList, th::Tensor& selectedIndex, th::Tensor& treeMask, @@ -155,35 +126,41 @@ void verify_dynamic_tree_greedy_out_packed_op(th::Tensor& candidates, th::Tensor targetPredict.data_ptr(), treeValid.data_ptr(), batchSize, numDraftTokens, numSpecStep, stream); } -//! \brief In-place tree rejection sampling verify op. -//! Accepts draft tokens by rejection sampling at each depth using pre-computed probabilities. -void verify_dynamic_tree_rejection_out_op(th::Tensor& candidates, th::Tensor& draftProbs, th::Tensor& targetProbs, - th::Tensor& targetSupportIndices, th::Tensor& targetSupportLengths, th::Tensor& draftProbIndices, +th::Tensor compute_probs_from_logits_op(th::Tensor logits, th::Tensor temperatures, th::optional topK, + th::optional topP, bool skipTemperature) +{ + TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor"); + TORCH_CHECK(temperatures.is_cuda(), "temperatures must be a CUDA tensor"); + TORCH_CHECK(logits.dim() == 2, "logits must be a 2D tensor"); + TORCH_CHECK(temperatures.dim() == 1, "temperatures must be a 1D tensor"); + TORCH_CHECK(logits.size(0) == temperatures.size(0), "logits and temperatures size mismatch"); + if (topK.has_value() && topK->defined()) + { + TORCH_CHECK(topK->is_cuda(), "top_k must be a CUDA tensor"); + } + if (topP.has_value() && topP->defined()) + { + TORCH_CHECK(topP->is_cuda(), "top_p must be a CUDA tensor"); + } + + return tk::computeProbsFromLogits(logits, temperatures, topK, topP, skipTemperature, /*kMax=*/0); +} + +//! \brief Target-only rejection sampling verify op (no draft probabilities needed). +void verify_dynamic_tree_rejection_out_op(th::Tensor& draftTokens, th::Tensor& targetProbs, th::Tensor& retrieveNextToken, th::Tensor& retrieveNextSibling, th::Tensor& treeValid, th::Tensor& acceptIndex, th::Tensor& acceptTokenNum, th::Tensor& acceptToken, int64_t numSpecStep, th::Tensor& seed, th::Tensor& offset) { - TORCH_CHECK(candidates.dim() == 2, "candidates must be 2D tensor"); - TORCH_CHECK(draftProbs.dim() == 3, "draftProbs must be 3D tensor"); + TORCH_CHECK(draftTokens.dim() == 2, "draftTokens must be 2D tensor"); TORCH_CHECK(targetProbs.dim() == 3, "targetProbs must be 3D tensor"); - TORCH_CHECK(targetSupportIndices.dim() == 1 || targetSupportIndices.dim() == 3, - "targetSupportIndices must be 1D or 3D tensor"); - TORCH_CHECK(targetSupportLengths.dim() == 1 || targetSupportLengths.dim() == 2, - "targetSupportLengths must be 1D or 2D tensor"); - TORCH_CHECK(draftProbIndices.dim() == 2, "draftProbIndices must be 2D tensor"); TORCH_CHECK(retrieveNextToken.dim() == 2, "retrieveNextToken must be 2D tensor"); TORCH_CHECK(retrieveNextSibling.dim() == 2, "retrieveNextSibling must be 2D tensor"); TORCH_CHECK(treeValid.dim() == 1, "treeValid must be 1D tensor"); - TORCH_CHECK(candidates.scalar_type() == torch::kInt64, "candidates must be int64 tensor"); - TORCH_CHECK(draftProbs.scalar_type() == torch::kFloat32, "draftProbs must be float32 tensor"); + TORCH_CHECK(draftTokens.scalar_type() == torch::kInt64, "draftTokens must be int64 tensor"); TORCH_CHECK(targetProbs.scalar_type() == torch::kFloat32, "targetProbs must be float32 tensor"); - TORCH_CHECK(targetSupportIndices.scalar_type() == torch::kInt32, "targetSupportIndices must be int32 tensor"); - TORCH_CHECK(targetSupportLengths.scalar_type() == torch::kInt32, "targetSupportLengths must be int32 tensor"); - TORCH_CHECK(draftProbIndices.scalar_type() == torch::kInt32, "draftProbIndices must be int32 tensor"); TORCH_CHECK(treeValid.scalar_type() == torch::kBool, "treeValid must be bool tensor"); - TORCH_CHECK(candidates.is_cuda(), "candidates must be a CUDA tensor"); - TORCH_CHECK(draftProbs.is_cuda(), "draftProbs must be a CUDA tensor"); + TORCH_CHECK(draftTokens.is_cuda(), "draftTokens must be a CUDA tensor"); TORCH_CHECK(targetProbs.is_cuda(), "targetProbs must be a CUDA tensor"); - TORCH_CHECK(draftProbIndices.is_cuda(), "draftProbIndices must be a CUDA tensor"); TORCH_CHECK(retrieveNextToken.is_cuda(), "retrieveNextToken must be a CUDA tensor"); TORCH_CHECK(retrieveNextSibling.is_cuda(), "retrieveNextSibling must be a CUDA tensor"); TORCH_CHECK(treeValid.is_cuda(), "treeValid must be a CUDA tensor"); @@ -191,122 +168,70 @@ void verify_dynamic_tree_rejection_out_op(th::Tensor& candidates, th::Tensor& dr TORCH_CHECK(acceptTokenNum.is_cuda(), "acceptTokenNum must be a CUDA tensor"); TORCH_CHECK(acceptToken.is_cuda(), "acceptToken must be a CUDA tensor"); - int64_t batchSize = candidates.size(0); - int64_t numDraftProbRows = draftProbs.size(1); - int64_t numDraftTokens = candidates.size(1); - int64_t vocabSize = targetProbs.size(2); - int64_t maxTargetSupportSize = targetSupportIndices.dim() == 3 ? targetSupportIndices.size(2) : 0; + // The kernel reads/writes raw data_ptr buffers with dense row-major indexing and + // runs on the stream of draftTokens.device(), so every tensor must live on that + // same device and be contiguous; strided or cross-device tensors would silently + // read/write wrong memory or fault. + auto const device = draftTokens.device(); + TORCH_CHECK(targetProbs.device() == device, "targetProbs must be on the same device as draftTokens"); + TORCH_CHECK(retrieveNextToken.device() == device, "retrieveNextToken must be on the same device as draftTokens"); + TORCH_CHECK( + retrieveNextSibling.device() == device, "retrieveNextSibling must be on the same device as draftTokens"); + TORCH_CHECK(treeValid.device() == device, "treeValid must be on the same device as draftTokens"); + TORCH_CHECK(acceptIndex.device() == device, "acceptIndex must be on the same device as draftTokens"); + TORCH_CHECK(acceptTokenNum.device() == device, "acceptTokenNum must be on the same device as draftTokens"); + TORCH_CHECK(acceptToken.device() == device, "acceptToken must be on the same device as draftTokens"); + TORCH_CHECK(seed.device() == device, "seed must be on the same device as draftTokens"); + TORCH_CHECK(offset.device() == device, "offset must be on the same device as draftTokens"); + + TORCH_CHECK(draftTokens.is_contiguous(), "draftTokens must be contiguous"); + TORCH_CHECK(targetProbs.is_contiguous(), "targetProbs must be contiguous"); + TORCH_CHECK(retrieveNextToken.is_contiguous(), "retrieveNextToken must be contiguous"); + TORCH_CHECK(retrieveNextSibling.is_contiguous(), "retrieveNextSibling must be contiguous"); + TORCH_CHECK(treeValid.is_contiguous(), "treeValid must be contiguous"); + TORCH_CHECK(acceptIndex.is_contiguous(), "acceptIndex must be contiguous"); + TORCH_CHECK(acceptTokenNum.is_contiguous(), "acceptTokenNum must be contiguous"); + TORCH_CHECK(acceptToken.is_contiguous(), "acceptToken must be contiguous"); + TORCH_CHECK(seed.is_contiguous(), "seed must be contiguous"); + TORCH_CHECK(offset.is_contiguous(), "offset must be contiguous"); + + int64_t const batchSize = draftTokens.size(0); + // draftTokens has shape [batchSize, N-1]; numDraftTokens is the total tree nodes N (including root). + int64_t const numDraftTokens = draftTokens.size(1) + 1; + int64_t const vocabSize = targetProbs.size(2); - TORCH_CHECK(draftProbs.size(0) == batchSize, "draftProbs batch size mismatch"); - TORCH_CHECK(draftProbs.size(2) == vocabSize, "draftProbs vocabSize mismatch"); TORCH_CHECK(targetProbs.size(0) == batchSize, "targetProbs batch size mismatch"); TORCH_CHECK(targetProbs.size(1) == numDraftTokens, "targetProbs numDraftTokens mismatch"); - if (targetSupportIndices.numel() > 0) - { - TORCH_CHECK(targetSupportIndices.dim() == 3, "targetSupportIndices must be 3D when non-empty"); - TORCH_CHECK(targetSupportIndices.size(0) == batchSize, "targetSupportIndices batch size mismatch"); - TORCH_CHECK(targetSupportIndices.size(1) == numDraftTokens, "targetSupportIndices numDraftTokens mismatch"); - TORCH_CHECK(targetSupportIndices.is_cuda(), "targetSupportIndices must be a CUDA tensor when non-empty"); - TORCH_CHECK(targetSupportIndices.device() == candidates.device(), - "targetSupportIndices must be on the same device as candidates"); - } - if (targetSupportLengths.numel() > 0) - { - TORCH_CHECK(targetSupportLengths.dim() == 2, "targetSupportLengths must be 2D when non-empty"); - TORCH_CHECK(targetSupportLengths.size(0) == batchSize, "targetSupportLengths batch size mismatch"); - TORCH_CHECK(targetSupportLengths.size(1) == numDraftTokens, "targetSupportLengths numDraftTokens mismatch"); - TORCH_CHECK(targetSupportLengths.is_cuda(), "targetSupportLengths must be a CUDA tensor when non-empty"); - TORCH_CHECK(targetSupportLengths.device() == candidates.device(), - "targetSupportLengths must be on the same device as candidates"); - } - TORCH_CHECK((targetSupportIndices.numel() == 0) == (targetSupportLengths.numel() == 0), - "targetSupportIndices and targetSupportLengths must both be empty or both be non-empty"); - TORCH_CHECK(draftProbIndices.size(0) == batchSize, "draftProbIndices batch size mismatch"); - TORCH_CHECK(draftProbIndices.size(1) == numDraftTokens, "draftProbIndices size mismatch"); TORCH_CHECK(retrieveNextToken.size(0) == batchSize, "retrieveNextToken batch size mismatch"); TORCH_CHECK(retrieveNextToken.size(1) == numDraftTokens, "retrieveNextToken size mismatch"); TORCH_CHECK(retrieveNextSibling.size(0) == batchSize, "retrieveNextSibling batch size mismatch"); TORCH_CHECK(retrieveNextSibling.size(1) == numDraftTokens, "retrieveNextSibling size mismatch"); TORCH_CHECK(treeValid.size(0) >= batchSize, "treeValid buffer too small"); - TORCH_CHECK(draftProbs.device() == candidates.device(), "draftProbs must be on the same device as candidates"); - TORCH_CHECK(targetProbs.device() == candidates.device(), "targetProbs must be on the same device as candidates"); - TORCH_CHECK( - draftProbIndices.device() == candidates.device(), "draftProbIndices must be on the same device as candidates"); - TORCH_CHECK(retrieveNextToken.device() == candidates.device(), - "retrieveNextToken must be on the same device as candidates"); - TORCH_CHECK(retrieveNextSibling.device() == candidates.device(), - "retrieveNextSibling must be on the same device as candidates"); - TORCH_CHECK(treeValid.device() == candidates.device(), "treeValid must be on the same device as candidates"); TORCH_CHECK(acceptIndex.scalar_type() == torch::kInt64, "acceptIndex must be int64 tensor"); TORCH_CHECK(acceptTokenNum.scalar_type() == torch::kInt64, "acceptTokenNum must be int64 tensor"); TORCH_CHECK(acceptToken.scalar_type() == torch::kInt64, "acceptToken must be int64 tensor"); + // Guard before it reaches the kernel: numSpecStep is cast to uint32_t and used as a loop + // bound / index stride there, so a non-positive value would wrap to a huge bound and fault. + TORCH_CHECK(numSpecStep > 0, "numSpecStep must be > 0"); TORCH_CHECK(acceptIndex.size(0) >= batchSize && acceptIndex.size(1) >= numSpecStep, "acceptIndex buffer too small"); TORCH_CHECK(acceptTokenNum.size(0) >= batchSize, "acceptTokenNum buffer too small"); TORCH_CHECK(acceptToken.size(0) >= batchSize && acceptToken.size(1) >= numSpecStep, "acceptToken buffer too small"); - TORCH_CHECK(acceptIndex.device() == candidates.device(), "acceptIndex must be on the same device as candidates"); - TORCH_CHECK( - acceptTokenNum.device() == candidates.device(), "acceptTokenNum must be on the same device as candidates"); - TORCH_CHECK(acceptToken.device() == candidates.device(), "acceptToken must be on the same device as candidates"); - TORCH_CHECK(seed.scalar_type() == torch::kInt64, "seed must be int64 tensor"); - TORCH_CHECK(offset.scalar_type() == torch::kInt64, "offset must be int64 tensor"); - TORCH_CHECK(seed.numel() >= 1, "seed tensor must have at least one element"); - TORCH_CHECK(offset.numel() >= 1, "offset tensor must have at least one element"); - TORCH_CHECK(seed.is_cuda(), "seed must be CUDA tensor"); - TORCH_CHECK(offset.is_cuda(), "offset must be CUDA tensor"); - TORCH_CHECK(seed.device() == candidates.device(), "seed must be on the same device as candidates"); - TORCH_CHECK(offset.device() == candidates.device(), "offset must be on the same device as candidates"); + TORCH_CHECK(seed.scalar_type() == torch::kInt64 && seed.numel() >= 1 && seed.is_cuda(), + "seed must be int64 CUDA tensor with >=1 element"); + TORCH_CHECK(offset.scalar_type() == torch::kInt64 && offset.numel() >= 1 && offset.is_cuda(), + "offset must be int64 CUDA tensor with >=1 element"); - auto stream = at::cuda::getCurrentCUDAStream(candidates.device().index()); + auto stream = at::cuda::getCurrentCUDAStream(draftTokens.device().index()); acceptIndex.zero_(); acceptTokenNum.zero_(); acceptToken.zero_(); tk::invokeVerifyDynamicTreeRejection(acceptIndex.data_ptr(), acceptTokenNum.data_ptr(), - acceptToken.data_ptr(), candidates.data_ptr(), draftProbs.data_ptr(), - targetProbs.data_ptr(), - targetSupportIndices.numel() > 0 ? targetSupportIndices.data_ptr() : nullptr, - targetSupportLengths.numel() > 0 ? targetSupportLengths.data_ptr() : nullptr, - draftProbIndices.data_ptr(), retrieveNextToken.data_ptr(), - retrieveNextSibling.data_ptr(), treeValid.data_ptr(), batchSize, numDraftProbRows, - maxTargetSupportSize, numDraftTokens, numSpecStep, vocabSize, seed.data_ptr(), - offset.data_ptr(), stream); -} - -th::Tensor compute_draft_probs_for_dynamic_tree_rejection_op(th::Tensor draftLogits, th::Tensor temperatures, - int64_t numDraftProbRows, int64_t targetVocabSize, th::optional topK, th::optional topP, - bool skipTemperature, th::optional d2t, int64_t topKMax, bool skipAllSamplingParams) -{ - return tk::computeDraftProbsForDynamicTreeRejection(draftLogits, temperatures, numDraftProbRows, topK, topP, - targetVocabSize, skipTemperature, d2t, topKMax, skipAllSamplingParams); -} - -std::tuple compute_target_probs_for_dynamic_tree_rejection_op( - th::Tensor targetLogits, th::Tensor temperatures, int64_t numDraftTokens, th::optional topK, - th::optional topP, bool skipTemperature, int64_t topKMax, bool skipAllSamplingParams) -{ - return tk::computeTargetProbsForDynamicTreeRejection( - targetLogits, temperatures, numDraftTokens, topK, topP, skipTemperature, topKMax, skipAllSamplingParams); -} - -th::Tensor compute_probs_from_logits_op(th::Tensor logits, th::Tensor temperatures, th::optional topK, - th::optional topP, bool skipTemperature) -{ - TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor"); - TORCH_CHECK(temperatures.is_cuda(), "temperatures must be a CUDA tensor"); - TORCH_CHECK(logits.dim() == 2, "logits must be a 2D tensor"); - TORCH_CHECK(temperatures.dim() == 1, "temperatures must be a 1D tensor"); - TORCH_CHECK(logits.size(0) == temperatures.size(0), "logits and temperatures size mismatch"); - if (topK.has_value() && topK->defined()) - { - TORCH_CHECK(topK->is_cuda(), "top_k must be a CUDA tensor"); - } - if (topP.has_value() && topP->defined()) - { - TORCH_CHECK(topP->is_cuda(), "top_p must be a CUDA tensor"); - } - - return tk::computeProbsFromLogits(logits, temperatures, topK, topP, skipTemperature, /*kMax=*/0); + acceptToken.data_ptr(), draftTokens.data_ptr(), targetProbs.data_ptr(), + retrieveNextToken.data_ptr(), retrieveNextSibling.data_ptr(), treeValid.data_ptr(), + batchSize, numDraftTokens, numSpecStep, vocabSize, seed.data_ptr(), offset.data_ptr(), + stream); } } // namespace torch_ext @@ -315,18 +240,6 @@ TRTLLM_NAMESPACE_END //////////////////////////////////////////////////////////////////////////////////////////////////////////// -TORCH_LIBRARY_FRAGMENT(trtllm, m) -{ - m.def( - "build_draft_prob_indices_out_op(Tensor topkScoreIndices, Tensor(a!) draftProbIndices, " - "int topK, int numDraftTokens) -> ()"); -} - -TORCH_LIBRARY_IMPL(trtllm, CUDA, m) -{ - m.impl("build_draft_prob_indices_out_op", &tensorrt_llm::torch_ext::build_draft_prob_indices_out_op); -} - TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def( @@ -363,8 +276,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def( "verify_dynamic_tree_rejection_out_op(" - "Tensor candidates, Tensor draftProbs, Tensor targetProbs, Tensor targetSupportIndices, " - "Tensor targetSupportLengths, Tensor draftProbIndices, " + "Tensor draftTokens, Tensor targetProbs, " "Tensor retrieveNextToken, Tensor retrieveNextSibling, Tensor treeValid, " "Tensor(a!) acceptIndex, Tensor(b!) acceptTokenNum, Tensor(c!) acceptToken, " "int numSpecStep, Tensor seed, Tensor offset) -> ()"); @@ -375,36 +287,6 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m) m.impl("verify_dynamic_tree_rejection_out_op", &tensorrt_llm::torch_ext::verify_dynamic_tree_rejection_out_op); } -TORCH_LIBRARY_FRAGMENT(trtllm, m) -{ - m.def( - "compute_draft_probs_for_dynamic_tree_rejection_op(" - "Tensor draftLogits, Tensor temperatures, int numDraftProbRows, int targetVocabSize, " - "Tensor? top_k=None, Tensor? top_p=None, bool skip_temperature=False, Tensor? d2t=None, " - "int top_k_max=0, bool skip_all_sampling_params=False) -> Tensor"); -} - -TORCH_LIBRARY_IMPL(trtllm, CUDA, m) -{ - m.impl("compute_draft_probs_for_dynamic_tree_rejection_op", - &tensorrt_llm::torch_ext::compute_draft_probs_for_dynamic_tree_rejection_op); -} - -TORCH_LIBRARY_FRAGMENT(trtllm, m) -{ - m.def( - "compute_target_probs_for_dynamic_tree_rejection_op(" - "Tensor targetLogits, Tensor temperatures, int numDraftTokens, " - "Tensor? top_k=None, Tensor? top_p=None, bool skip_temperature=False, int top_k_max=0, " - "bool skip_all_sampling_params=False) -> (Tensor, Tensor, Tensor)"); -} - -TORCH_LIBRARY_IMPL(trtllm, CUDA, m) -{ - m.impl("compute_target_probs_for_dynamic_tree_rejection_op", - &tensorrt_llm::torch_ext::compute_target_probs_for_dynamic_tree_rejection_op); -} - TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def( diff --git a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py index 345d6fac7e29..2eea3900d266 100644 --- a/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py @@ -1247,63 +1247,10 @@ def _(logits: torch.Tensor, skip_temperature: bool = False) -> torch.Tensor: return logits.new_empty(list(logits.shape), dtype=torch.float32) - @torch.library.register_fake("trtllm::build_draft_prob_indices_out_op") - def _(topkScoreIndices: torch.Tensor, draftProbIndices: torch.Tensor, - topK: int, numDraftTokens: int) -> None: - return None - @torch.library.register_fake("trtllm::verify_dynamic_tree_rejection_out_op") - def _(candidates: torch.Tensor, draftProbs: torch.Tensor, - targetProbs: torch.Tensor, targetSupportIndices: torch.Tensor, - targetSupportLengths: torch.Tensor, draftProbIndices: torch.Tensor, + def _(draftTokens: torch.Tensor, targetProbs: torch.Tensor, retrieveNextToken: torch.Tensor, retrieveNextSibling: torch.Tensor, treeValid: torch.Tensor, acceptIndex: torch.Tensor, acceptTokenNum: torch.Tensor, acceptToken: torch.Tensor, numSpecStep: int, seed: torch.Tensor, offset: torch.Tensor) -> None: return None - - @torch.library.register_fake( - "trtllm::compute_draft_probs_for_dynamic_tree_rejection_op") - def _(draftLogits: torch.Tensor, - temperatures: torch.Tensor, - numDraftProbRows: int, - targetVocabSize: int, - top_k: Optional[torch.Tensor] = None, - top_p: Optional[torch.Tensor] = None, - skip_temperature: bool = False, - d2t: Optional[torch.Tensor] = None, - top_k_max: int = 0, - skip_all_sampling_params: bool = False) -> torch.Tensor: - batch_size = temperatures.shape[0] - return draftLogits.new_empty( - (batch_size, numDraftProbRows, targetVocabSize), - dtype=torch.float32) - - @torch.library.register_fake( - "trtllm::compute_target_probs_for_dynamic_tree_rejection_op") - def _( - targetLogits: torch.Tensor, - temperatures: torch.Tensor, - numDraftTokens: int, - top_k: Optional[torch.Tensor] = None, - top_p: Optional[torch.Tensor] = None, - skip_temperature: bool = False, - top_k_max: int = 0, - skip_all_sampling_params: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - batch_size = temperatures.shape[0] - target_vocab_size = targetLogits.shape[-1] - target_probs = targetLogits.new_empty( - (batch_size, numDraftTokens, target_vocab_size), - dtype=torch.float32) - has_filtering = (top_k is not None) or (top_p is not None) - if skip_all_sampling_params or not has_filtering: - support_indices = targetLogits.new_empty((0, ), dtype=torch.int32) - support_lengths = targetLogits.new_empty((0, ), dtype=torch.int32) - else: - support_dim = top_k_max if top_k_max > 0 else target_vocab_size - support_indices = targetLogits.new_empty( - (batch_size, numDraftTokens, support_dim), dtype=torch.int32) - support_lengths = targetLogits.new_empty( - (batch_size, numDraftTokens), dtype=torch.int32) - return target_probs, support_indices, support_lengths diff --git a/tensorrt_llm/_torch/speculative/dynamic_tree_ops.py b/tensorrt_llm/_torch/speculative/dynamic_tree_ops.py index bfb520f2d733..35a0bd992557 100644 --- a/tensorrt_llm/_torch/speculative/dynamic_tree_ops.py +++ b/tensorrt_llm/_torch/speculative/dynamic_tree_ops.py @@ -27,6 +27,8 @@ import torch +from tensorrt_llm._torch.speculative.one_model_sampler import compute_probs_from_logits + class DynamicTreeOpsConverter: """ @@ -218,101 +220,67 @@ def verify_dynamic_tree_greedy_out_packed( return accept_index, accept_token_num, accept_token - def verify_dynamic_tree_rejection_from_logits_out( + def verify_dynamic_tree_rejection_out( self, - candidates: torch.Tensor, - draft_logits_tree: torch.Tensor, + draft_tokens: torch.Tensor, target_logits_tree: torch.Tensor, - draft_prob_indices: torch.Tensor, retrieve_next_token: torch.Tensor, retrieve_next_sibling: torch.Tensor, tree_valid: torch.Tensor, temperatures: torch.Tensor, top_k: torch.Tensor | None, top_p: torch.Tensor | None, - skip_temperature: bool, num_gens: int, num_spec_step: int, seed: int | torch.Tensor = 0, offset: int | torch.Tensor = 0, - d2t: torch.Tensor | None = None, - skip_all_sampling_params: bool = False, - top_k_max: int | None = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Tree-aware rejection sampling from logits (three CUDA ops). + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Dynamic tree rejection sampling. - This path keeps draft/target logits as inputs, computes unique draft - and target probabilities with separate CUDA ops, then runs the tree - rejection kernel as a third CUDA op. `draft_prob_indices` maps each - tree position to its shared draft-prob row. `tree_valid` guards - first-gen and dummy requests that do not have a usable tree yet. + Computes target probabilities from logits, then runs the target-only + rejection kernel. No draft probabilities are needed. + `tree_valid` guards first-gen and dummy requests without a valid tree. """ accept_index = self._rej_accept_index_buf[:num_gens] accept_token = self._rej_accept_token_buf[:num_gens] accept_tok_num = self._rej_accept_token_num_buf[:num_gens] seed_tensor = self._get_rejection_rng_tensor(seed, self._rej_seed_buf, "seed") offset_tensor = self._get_rejection_rng_tensor(offset, self._rej_offset_buf, "offset") - num_draft_tokens = candidates.shape[1] if num_gens <= 0: raise ValueError(f"num_gens must be positive, got {num_gens}") - if draft_logits_tree.shape[0] % num_gens != 0: + if target_logits_tree.shape[0] % num_gens != 0: raise ValueError( - f"draft_logits_tree rows ({draft_logits_tree.shape[0]}) must be divisible by " - f"num_gens ({num_gens})" + "target_logits_tree rows must be divisible by num_gens, got " + f"{target_logits_tree.shape[0]} and {num_gens}" ) - num_draft_prob_rows = draft_logits_tree.shape[0] // num_gens - target_vocab_size = target_logits_tree.shape[-1] + # draft_tokens has shape [num_gens, N-1]; derive total tree nodes N from target_logits_tree. + num_draft_tokens = target_logits_tree.shape[0] // num_gens if tree_valid is None: - tree_valid = torch.ones(num_gens, dtype=torch.bool, device=candidates.device) + tree_valid = torch.ones(num_gens, dtype=torch.bool, device=draft_tokens.device) tree_valid = tree_valid.contiguous() - if top_k_max is not None: - # Pre-computed CPU-side (CUDA-graph-safe): use as-is. - pass - elif top_k is None: - top_k_max = 0 - else: - # Fallback path (non-CUDA-graph contexts): compute from tensor. - enabled_top_k = top_k[(top_k > 0) & (top_k < target_vocab_size)] - top_k_max = int(enabled_top_k.max().item()) if enabled_top_k.numel() > 0 else 0 - - try: - draft_probs_tree = torch.ops.trtllm.compute_draft_probs_for_dynamic_tree_rejection_op( - draft_logits_tree, - temperatures, - num_draft_prob_rows, - target_vocab_size, - top_k, - top_p, - skip_temperature, - d2t=d2t, - top_k_max=top_k_max, - skip_all_sampling_params=skip_all_sampling_params, - ) + # Expand per-request sampling params to per-tree-position (num_gens * N rows). + temps_exp = temperatures.repeat_interleave(num_draft_tokens) + top_k_exp = top_k.repeat_interleave(num_draft_tokens) if top_k is not None else None + top_p_exp = top_p.repeat_interleave(num_draft_tokens) if top_p is not None else None - ( - target_probs_tree, - target_support_indices, - target_support_lengths, - ) = torch.ops.trtllm.compute_target_probs_for_dynamic_tree_rejection_op( - target_logits_tree, - temperatures, - num_draft_tokens, - top_k, - top_p, - skip_temperature, - top_k_max=top_k_max, - skip_all_sampling_params=skip_all_sampling_params, - ) + # Compute target probs using the shared linear-path interface (FlashInfer fast + # path when available, sort-based fallback otherwise). Returns dense full-vocab + # probs [num_gens * N, vocab_size]; no sparse support indices needed. + target_probs_flat = compute_probs_from_logits( + target_logits_tree, + temps_exp, + top_k_exp, + top_p_exp, + ) + vocab_size = target_probs_flat.shape[-1] + target_probs_tree = target_probs_flat.reshape(num_gens, num_draft_tokens, vocab_size) + try: torch.ops.trtllm.verify_dynamic_tree_rejection_out_op( - candidates, - draft_probs_tree, + draft_tokens, target_probs_tree, - target_support_indices, - target_support_lengths, - draft_prob_indices, retrieve_next_token, retrieve_next_sibling, tree_valid, @@ -325,10 +293,9 @@ def verify_dynamic_tree_rejection_from_logits_out( ) except Exception as e: raise RuntimeError( - f"dynamic tree rejection op chain failed: {e}\n" - f"Inputs: num_gens={num_gens}, N={candidates.shape[1]}, " - f"draft_vocab={draft_logits_tree.shape[-1]}, " + f"dynamic tree rejection target-only op chain failed: {e}\n" + f"Inputs: num_gens={num_gens}, N={draft_tokens.shape[1] + 1}, " f"target_vocab={target_logits_tree.shape[-1]}, num_spec_step={num_spec_step}" ) from e - return target_support_indices, accept_index, accept_tok_num, accept_token + return accept_index, accept_tok_num, accept_token diff --git a/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py b/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py index e7a6746d7c70..7c80e534d72d 100644 --- a/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py +++ b/tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py @@ -15,7 +15,7 @@ """Eagle3 one-model dynamic tree speculative decoding.""" import math -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import torch import triton @@ -25,6 +25,7 @@ from ..attention_backend import AttentionMetadata from .eagle3 import Eagle3OneModelWorker +from .one_model_sampler import sampling_batch_spec_dec_one_model if TYPE_CHECKING: from ...llmapi.llm_args import EagleDecodingConfig @@ -293,9 +294,6 @@ def __init__( self._candidates_buf = torch.zeros( max_batch_size, tokens_per_gen_step, dtype=torch.int32, device="cuda" ) - self._rejection_candidates_buf = torch.zeros( - max_batch_size, tokens_per_gen_step, dtype=torch.int64, device="cuda" - ) self._target_predict_buf = torch.zeros( max_batch_size, tokens_per_gen_step, dtype=torch.int32, device="cuda" ) @@ -319,25 +317,8 @@ def __init__( sm = get_sm_version() self._needs_mask_repack = sm < 100 or sm in (120, 121) - # Rejection sampling buffers - # Unique draft logits per request, one row per distinct parent context. - # Shape: [max_batch_size, 1 + (max_draft_len - 1) * K, vocab_size] - # row 0 -> p(.|root) - # rows [1 : 1 + K] -> p(.|depth-0 parent_k) - # rows [1 + K : 1 + 2K] -> p(.|depth-1 parent_k) - # This avoids materializing repeated per-tree-position logits such as - # [p(.|E), p(.|E), p(.|E), p(.|F1), p(.|F1), p(.|F2)]. - self._draft_depth_logits_cat: Optional[torch.Tensor] = None - # topk_score_indices from resampling_final_draft_tokens for path tracing. - # Shape: [max_batch_size, max_total_draft_tokens] - self._topk_score_indices_buf = torch.zeros( - max_batch_size, self.max_total_draft_tokens, dtype=torch.int64, device="cuda" - ) - # Tree position -> unique draft-prob row mapping for rejection sampling. - # Shape: [max_batch_size, max_total_draft_tokens + 1], root row is unused and kept at 0. - self._draft_prob_indices_buf = torch.zeros( - max_batch_size, self.max_total_draft_tokens + 1, dtype=torch.int32, device="cuda" - ) + # No draft-side buffers needed: target-only rejection sampling does not + # require unique draft logits, topk score indices, or draft prob indices. def _repack_mask_padded_to_packed(self, mask_buf, n_req, n_tok): """XQA indexes mask flat via cuQSeqLens with per-request stride @@ -476,9 +457,14 @@ def sample_and_accept_draft_tokens(self, input_ids, logits, attn_metadata, spec_ num_contexts = attn_metadata.num_contexts num_gens = batch_size - num_contexts - accepted_tokens, num_accepted_tokens = self._sample_and_accept_dynamic_tree( - logits, attn_metadata, spec_metadata, batch_size, num_contexts, num_gens - ) + if self._can_use_rejection_sampling(spec_metadata): + accepted_tokens, num_accepted_tokens = self._sample_and_accept_dynamic_tree_rejection( + logits, attn_metadata, spec_metadata, batch_size, num_contexts, num_gens + ) + else: + accepted_tokens, num_accepted_tokens = self._sample_and_accept_dynamic_tree( + logits, attn_metadata, spec_metadata, batch_size, num_contexts, num_gens + ) if num_gens > 0: self._relocate_kv_eagerly(attn_metadata, batch_size) return accepted_tokens, num_accepted_tokens @@ -650,11 +636,6 @@ def _forward_draft_loop( hidden_states[gather_ids], draft_model.lm_head, attn_metadata, True ) - # Capture unique draft logits for the root parent context. - if spec_metadata.use_rejection_sampling and num_gens > 0: - self._lazy_alloc_draft_logits_buf(logits.shape[-1], logits.dtype, logits.device) - self._draft_depth_logits_cat[:num_gens, 0].copy_(logits[num_contexts:]) - new_draft_tokens, new_draft_scores = self.sample( logits, self.K, draft_model=draft_model ) @@ -717,13 +698,6 @@ def _forward_draft_loop( selected_hs, draft_model.lm_head, attn_metadata, True ) - # Capture unique draft logits for the K parent contexts at this depth. - if spec_metadata.use_rejection_sampling and num_gens > 0: - row_start = 1 + (layer_idx - 1) * self.K - self._draft_depth_logits_cat[:num_gens, row_start : row_start + self.K].copy_( - logits[num_contexts * self.K :].reshape(num_gens, self.K, -1) - ) - new_draft_tokens, new_draft_scores = self.sample( logits, self.K, draft_model=draft_model ) @@ -753,15 +727,6 @@ def _forward_draft_loop( # Resample final tokens and build tree real_draft_tokens, topk_score_indices = self.resampling_final_draft_tokens(batch_size) - # Save topk_score_indices for rejection sampling path tracing. - # Rejection sampling needs to map each final draft token back to its history - # buffer index to retrieve the corresponding draft logits. Greedy verification - # doesn't need this mapping since it only compares token IDs. - # Only save the gen rows (skip context rows) so that - # _build_draft_prob_indices can read them back at [:num_gens]. - if spec_metadata.use_rejection_sampling: - self._topk_score_indices_buf[:num_gens].copy_(topk_score_indices[num_contexts:]) - if spec_tree_manager is not None: # Build into contiguous work buffers indexed by bid (0..num_gens-1). # Cannot use slotIds because padded dummies share dummy_slot_id. @@ -787,7 +752,7 @@ def _forward_draft_loop( def _sample_and_accept_dynamic_tree( self, logits, attn_metadata, spec_metadata, batch_size, num_contexts, num_gens ): - """Dynamic tree verification using CUDA kernel.""" + """Dynamic tree verification: greedy and non-greedy-without-rejection paths.""" if num_gens > self._max_batch_size: raise RuntimeError( f"Dynamic tree batch size {num_gens} exceeds configured " @@ -804,8 +769,29 @@ def _sample_and_accept_dynamic_tree( self._accepted_draft_indices_tensor[:batch_size].fill_(-1) num_flat_tokens = logits.shape[0] - # torch.argmax writes LongTensor indices; token storage below remains int32. - torch.argmax(logits, dim=-1, out=self._target_tokens_buf[:num_flat_tokens]) + if not spec_metadata.is_all_greedy_sample: + # Non-greedy: sample target tokens with per-request temperature/top_k/top_p. + # Lazily initialize RNG tensors for CUDA graph compatibility. + if self.seed is None: + self.seed = torch.tensor([0], dtype=torch.int64, device=logits.device) + self.offset = torch.tensor([0], dtype=torch.int64, device=logits.device) + self.seed.add_(1).remainder_(2**31) + top_ks = spec_metadata.top_ks[:num_flat_tokens] + if self.use_flashinfer: + top_ks = top_ks.clamp(min=1, max=logits.shape[-1] - 1) + sampled = sampling_batch_spec_dec_one_model( + logits, + spec_metadata.temperatures[:num_flat_tokens], + top_ks, + spec_metadata.top_ps[:num_flat_tokens], + use_flashinfer=self.use_flashinfer, + seed=self.seed, + offset=self.offset, + ) + self._target_tokens_buf[:num_flat_tokens].copy_(sampled) + else: + # Greedy fast path (CUDA graph key: is_all_greedy_sample=True). + torch.argmax(logits, dim=-1, out=self._target_tokens_buf[:num_flat_tokens]) target_tokens = self._target_tokens_buf[:num_flat_tokens] # Context requests: accept sampled token @@ -834,115 +820,145 @@ def _sample_and_accept_dynamic_tree( tree_valid = slot_storage.has_tree[gen_slot_ids] retrieve_packed = slot_storage.pack_retrieve_from_slots(gen_slot_ids, num_gens) - if self._can_use_rejection_sampling(spec_metadata): - vocab_size = logits.shape[-1] - num_ctx_tokens = logits.shape[0] - num_gens * N - device = logits.device - - draft_logits_tree = self._get_unique_draft_logits(num_gens) - draft_prob_indices = self._build_draft_prob_indices(num_gens) - target_logits_tree = logits[num_ctx_tokens:].reshape(-1, vocab_size) - gen_slice = slice(num_contexts, num_contexts + num_gens) - skip_top_k = getattr(spec_metadata, "skip_top_k", False) - skip_top_p = getattr(spec_metadata, "skip_top_p", False) - skip_temperature = getattr(spec_metadata, "skip_temperature", False) - - if spec_metadata.request_temperatures is None: - temps = torch.ones(num_gens, dtype=torch.float32, device=device) - skip_temperature = True - else: - temps = spec_metadata.request_temperatures[gen_slice] - - top_ks = None - if not skip_top_k and spec_metadata.request_top_ks is not None: - top_ks = spec_metadata.request_top_ks[gen_slice] - - top_ps = None - if not skip_top_p and spec_metadata.request_top_ps is not None: - top_ps = spec_metadata.request_top_ps[gen_slice] - - skip_all_sampling_params = ( - skip_temperature - and skip_top_k - and skip_top_p - and not getattr(spec_metadata, "has_greedy_requests", False) + accept_index, accept_token_num, accept_token = ( + self.tree_ops_converter.verify_dynamic_tree_greedy_out_packed( + candidates, + retrieve_packed, + target_predict, + num_gens, + self._max_path_len, + tree_valid=tree_valid, ) + ) - # Lazily initialize seed/offset tensors on correct device - if self.seed is None: - self.seed = torch.tensor([0], dtype=torch.int64, device=device) - self.offset = torch.tensor([0], dtype=torch.int64, device=device) - # Use in-place operations for CUDA graph compatibility - self.seed.add_(1).remainder_(2**31) - rejection_candidates = self._rejection_candidates_buf[:num_gens] - rejection_candidates[:, 1:].copy_( - spec_metadata.draft_tokens.reshape(num_gens, N - 1) - ) - rejection_candidates[:, 0].copy_( - target_tokens[num_contexts:].reshape(num_gens, N)[:, 0] - ) - retrieve_next_token, retrieve_next_sibling = slot_storage.next_links_from_slots( - gen_slot_ids, num_gens - ) + self._finalize_dynamic_tree_verify_outputs( + accept_index=accept_index, + accept_token_num=accept_token_num, + accept_token=accept_token, + accepted_tokens=accepted_tokens, + num_accepted_tokens=num_accepted_tokens, + num_contexts=num_contexts, + batch_size=batch_size, + num_gens=num_gens, + max_path_len=max_path_len, + ) - _, accept_index, accept_token_num, accept_token = ( - self.tree_ops_converter.verify_dynamic_tree_rejection_from_logits_out( - rejection_candidates, - draft_logits_tree, - target_logits_tree, - draft_prob_indices, - retrieve_next_token, - retrieve_next_sibling, - tree_valid, - temps, - top_ks, - top_ps, - skip_temperature, - num_gens, - self._max_path_len, - seed=self.seed, - offset=self.offset, - d2t=self._d2t, - skip_all_sampling_params=skip_all_sampling_params, - # During CUDA graph capture bake top_k_max=0 so the - # full-sort (always-correct) path is captured. Outside - # capture, pass the pre-computed value for the fast - # topk(kMax) path. - top_k_max=( - 0 - if torch.cuda.is_current_stream_capturing() - else getattr(spec_metadata, "top_k_max", None) - ), - ) - ) + num_accepted_tokens = self._apply_force_accepted_tokens( + num_accepted_tokens, num_contexts, self.max_draft_len + ) - self._finalize_dynamic_tree_verify_outputs( - accept_index=accept_index, - accept_token_num=accept_token_num, - accept_token=accept_token, - accepted_tokens=accepted_tokens, - num_accepted_tokens=num_accepted_tokens, - num_contexts=num_contexts, - batch_size=batch_size, - num_gens=num_gens, - max_path_len=max_path_len, - ) - num_accepted_tokens = self._apply_force_accepted_tokens( - num_accepted_tokens, num_contexts, self.max_draft_len + return accepted_tokens, num_accepted_tokens + + def _sample_and_accept_dynamic_tree_rejection( + self, logits, attn_metadata, spec_metadata, batch_size, num_contexts, num_gens + ): + """Dynamic tree rejection sampling path (non-greedy only). + + Context tokens are sampled directly (they bypass the rejection kernel). + Gen tokens are handled entirely by the rejection kernel, which samples + from the target distribution internally — so only the raw logits (not + pre-sampled token IDs) are passed to the kernel. + """ + if num_gens > self._max_batch_size: + raise RuntimeError( + f"Dynamic tree batch size {num_gens} exceeds configured " + f"max_batch_size {self._max_batch_size}" + ) + N = self.tokens_per_gen_step + max_path_len = self._max_path_len + vocab_size = logits.shape[-1] + device = logits.device + + # Reset output buffers + self._accepted_tokens_buf[:batch_size].zero_() + accepted_tokens = self._accepted_tokens_buf[:batch_size, :max_path_len] + self._num_accepted_tokens_buf[:batch_size].fill_(1) + num_accepted_tokens = self._num_accepted_tokens_buf[:batch_size] + self._accepted_draft_indices_tensor[:batch_size].fill_(-1) + + # Lazily initialize RNG tensors (needed by rejection kernel). + if self.seed is None: + self.seed = torch.tensor([0], dtype=torch.int64, device=device) + self.offset = torch.tensor([0], dtype=torch.int64, device=device) + self.seed.add_(1).remainder_(2**31) + + # Context tokens bypass the rejection kernel — sample them directly. + if num_contexts > 0: + top_ks_ctx = spec_metadata.top_ks[:num_contexts] + if self.use_flashinfer: + top_ks_ctx = top_ks_ctx.clamp(min=1, max=vocab_size - 1) + sampled_ctx = sampling_batch_spec_dec_one_model( + logits[:num_contexts], + spec_metadata.temperatures[:num_contexts], + top_ks_ctx, + spec_metadata.top_ps[:num_contexts], + use_flashinfer=self.use_flashinfer, + seed=self.seed, + offset=self.offset, + ) + accepted_tokens[:num_contexts, 0].copy_(sampled_ctx) + + if num_gens > 0: + spec_tree_manager = self.spec_tree_manager + + if spec_tree_manager is None: + # CUDA graph warmup: accept only the root token per request. + # Use argmax of root-position logits as the accepted token. + gen_root_tokens = torch.argmax( + logits[num_contexts:].reshape(num_gens, N, vocab_size)[:, 0, :], dim=-1 ) + num_accepted_tokens[num_contexts:batch_size] = 1 + accepted_tokens[num_contexts:batch_size, 0] = gen_root_tokens + self._accepted_draft_indices_tensor[num_contexts:batch_size] = -1 return accepted_tokens, num_accepted_tokens + target_logits_tree = logits[-num_gens * N :].reshape(-1, vocab_size) + gen_slice = slice(num_contexts, num_contexts + num_gens) + temps = spec_metadata.request_temperatures[gen_slice] + top_ks_rej = spec_metadata.request_top_ks[gen_slice] + top_ps_rej = spec_metadata.request_top_ps[gen_slice] + + slot_storage = spec_tree_manager.slot_storage + gen_slot_ids = slot_storage.all_ids_buf[num_contexts : num_contexts + num_gens] + tree_valid = slot_storage.has_tree[gen_slot_ids] + + retrieve_next_token, retrieve_next_sibling = slot_storage.next_links_from_slots( + gen_slot_ids, num_gens + ) + accept_index, accept_token_num, accept_token = ( - self.tree_ops_converter.verify_dynamic_tree_greedy_out_packed( - candidates, - retrieve_packed, - target_predict, + self.tree_ops_converter.verify_dynamic_tree_rejection_out( + spec_metadata.draft_tokens.reshape(num_gens, N - 1).long(), + target_logits_tree, + retrieve_next_token, + retrieve_next_sibling, + tree_valid, + temps, + top_ks_rej, + top_ps_rej, num_gens, self._max_path_len, - tree_valid=tree_valid, + seed=self.seed, + offset=self.offset, ) ) + if self.force_num_accepted_tokens != 0.0: + # Fill accept_token positions 1..max_path_len-1 with draft tokens so + # that when _apply_force_accepted_tokens inflates num_accepted_tokens + # the decoder reads valid tokens instead of zeros. Finalize copies + # accept_token into accepted_tokens. + # accept_token shape: [num_gens, max_path_len] (max_path_len = max_draft_len+1) + # draft_tokens shape: [num_gens * (N-1)] where N-1 = max_total_draft_tokens + # We fill at most max_path_len-1 positions, taking the first ones from draft. + # All shapes are static Python ints: CUDA-graph-safe. + n_fill = accept_token.shape[1] - 1 # max_path_len - 1 + accept_token[:, 1 : n_fill + 1].copy_( + spec_metadata.draft_tokens.reshape(num_gens, N - 1)[:, :n_fill].to( + accept_token.dtype + ) + ) + self._finalize_dynamic_tree_verify_outputs( accept_index=accept_index, accept_token_num=accept_token_num, @@ -964,29 +980,17 @@ def _sample_and_accept_dynamic_tree( def _can_use_rejection_sampling(self, spec_metadata) -> bool: """Check if rejection sampling can be used for dynamic tree verification. - Dynamic tree uses its own compact unique-logit buffer - (_draft_depth_logits_cat) instead of spec_metadata.draft_logits, so we - check that buffer's allocation status rather than draft_logits_valid. - The buffer is lazily allocated and populated during the first forward - pass with generation requests. + Target-only rejection sampling only requires target logits, which are + always available during verification. We skip it only when the whole + batch is greedy (argmax is equivalent and avoids the rejection kernel cost). Args: spec_metadata: Speculative decoding metadata Returns: - True if rejection sampling is enabled and the draft logit buffer is allocated + True if rejection sampling is enabled for this batch """ - # Skip rejection sampling when the whole batch is greedy: argmax is - # equivalent and avoids the rejection kernel cost. - # Also skip during CUDA graph capture/replay: the rejection ops use - # dynamic memory allocation (full-sort fallback) which is incompatible - # with stream capture. - return ( - spec_metadata.use_rejection_sampling - and self._draft_depth_logits_cat is not None - and not spec_metadata.is_all_greedy_sample - and not spec_metadata.is_cuda_graph - ) + return spec_metadata.use_rejection_sampling and not spec_metadata.is_all_greedy_sample def _finalize_dynamic_tree_verify_outputs( self, @@ -1010,52 +1014,6 @@ def _finalize_dynamic_tree_verify_outputs( accept_index[:num_gens, 1:max_path_len] - 1 ).to(torch.int32) - def _lazy_alloc_draft_logits_buf(self, vocab_size: int, dtype, device): - """Lazily allocate unique draft-logit capture buffer.""" - if self._draft_depth_logits_cat is None: - rows = 1 + (self.max_draft_len - 1) * self.K - self._draft_depth_logits_cat = torch.empty( - self._max_batch_size, rows, vocab_size, dtype=dtype, device=device - ) - - def _get_unique_draft_logits( - self, - num_gens: int, - ) -> torch.Tensor: - """Return compact unique draft logits for rejection sampling. - - The returned rows are unique per parent context, not duplicated per - final tree position. - - Args: - num_gens: Number of generation requests. - - Returns: - draft_logits: [num_gens * (1 + (max_draft_len - 1) * K), draft_vocab_size] - """ - return self._draft_depth_logits_cat[:num_gens].reshape( - -1, self._draft_depth_logits_cat.shape[-1] - ) - - def _build_draft_prob_indices( - self, - num_gens: int, - ) -> torch.Tensor: - """Build tree-position -> unique draft-prob row mapping. - - For a final tree position: - - depth 0 children map to row 0, which stores p(.|root) - - deeper nodes map to row 1 + depth_bucket * K + parent_k - """ - draft_prob_indices = self._draft_prob_indices_buf[:num_gens] - torch.ops.trtllm.build_draft_prob_indices_out_op( - self._topk_score_indices_buf[:num_gens], - draft_prob_indices, - self.K, - self.max_total_draft_tokens, - ) - return draft_prob_indices - @nvtx_range("eagle3_dyn.sample") def sample( self, logits: torch.Tensor, max_top_k: int, draft_model=None diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index a6232866cc36..36e3b0485421 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -1324,6 +1324,14 @@ def _sample_and_accept_draft_tokens_rejection( offset=self.offset, ) + if self.force_num_accepted_tokens != 0.0: + # Fill gen_accepted positions 1..runtime_draft_len with all draft tokens + # so that when _apply_force_accepted_tokens inflates num_accepted_tokens + # the decoder reads valid draft tokens instead of zeros. + # Slice bounds are Python ints (static at CUDA-graph capture time). + gen_accepted[:, + 1:runtime_draft_len + 1].copy_(full_draft_tokens) + accepted_tokens[num_contexts:] = gen_accepted num_accepted_tokens[num_contexts:] = gen_num_accepted diff --git a/tensorrt_llm/_torch/speculative/one_model_sampler.py b/tensorrt_llm/_torch/speculative/one_model_sampler.py index 6734b5e9f79a..a61f6d1f4b08 100644 --- a/tensorrt_llm/_torch/speculative/one_model_sampler.py +++ b/tensorrt_llm/_torch/speculative/one_model_sampler.py @@ -117,6 +117,7 @@ def sampling_batch_spec_dec_one_model( return random_sampled +@torch.compile(options={"max-autotune": True}) def compute_probs_from_logits( logits: torch.Tensor, temperatures: torch.Tensor,