Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,089 changes: 130 additions & 959 deletions cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu

Large diffs are not rendered by default.

75 changes: 10 additions & 65 deletions cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,6 @@ void invokeBuildDynamicTree(int64_t const* parentList, int64_t const* selectedIn
runtime::SizeType32 topK, runtime::SizeType32 depth, runtime::SizeType32 numDraftTokens, TreeMaskMode treeMaskMode,
cudaStream_t stream, runtime::SizeType32 numInt32PerRow);

//! \brief Build tree-position -> unique draft-prob row mapping for rejection sampling.
//! \param topkScoreIndices [batchSize, numDraftTokens], on GPU. int64.
//! History-buffer indices selected for each final tree position.
//! \param draftProbIndices output [batchSize, numDraftTokens + 1], on GPU. int32.
//! Column 0 is reserved for the root and set to 0.
//! \param batchSize runtime::SizeType32. Batch size.
//! \param topK runtime::SizeType32. Tree top-K branching factor.
//! \param numDraftTokens runtime::SizeType32. Total number of non-root draft positions.
//! \param stream cuda stream.
void invokeBuildDraftProbIndices(int64_t const* topkScoreIndices, int32_t* draftProbIndices,
runtime::SizeType32 batchSize, runtime::SizeType32 topK, runtime::SizeType32 numDraftTokens, cudaStream_t stream);

//! \brief Verify dynamic tree using greedy strategy with packed retrieve buffers.
//! \param acceptIndex output buffer [batchSize, numSpecStep], on GPU. int32.
//! \param acceptTokenNum output buffer [batchSize], on GPU. int32.
Expand All @@ -91,75 +79,32 @@ void invokeVerifyDynamicTreeGreedyPacked(int32_t* acceptIndex, int32_t* acceptTo
cudaStream_t stream);

//! \brief Verify dynamic tree using rejection sampling.
//! For each request, traverses the tree depth-by-depth. At each depth, siblings are tried
//! in order; the first sibling accepted by rejection sampling (p_target/p_draft) continues
//! the path. If all siblings are rejected, a correction token is sampled from (target-draft)_+.
//! Accepts draft tokens by accumulating cumulative target probability across siblings.
//! The first sibling whose cumulative target prob exceeds the random coin is accepted.
//! When all siblings are rejected, a correction token is sampled from the residual
//! target distribution (target prob for tokens not tried as siblings).
//! No draft probabilities are needed. Always uses full-vocab path.
//!
//! \param acceptIndex output [batchSize, numSpecStep] int64 — tree positions of accepted tokens.
//! \param acceptTokenNum output [batchSize] int64 — # accepted draft tokens (excl. root).
//! \param acceptToken output [batchSize, numSpecStep] int64 — accepted/correction token ids.
//! \param candidates [batchSize, numDraftTokens] int64; col 0 = root (target sample).
//! \param draftProbs [batchSize, numDraftProbRows, vocabSize] float32.
//! Unique draft probability rows per request; tree positions map into this
//! tensor via draftProbIndices.
//! \param draftTokens [batchSize, numDraftTokens-1] int64; draft token ids (excluding root).
//! \param targetProbs [batchSize, numDraftTokens, vocabSize] float32; index 0 = root.
//! \param targetSupportIndices [batchSize, numDraftTokens, maxTargetSupportSize] int32; compact token ids that
//! survive top-k/top-p filtering for each row, padded with -1. May be empty when no filtering is active.
//! \param targetSupportLengths [batchSize, numDraftTokens] int32; valid support length per row. May be empty when
//! no filtering is active.
//! \param draftProbIndices [batchSize, numDraftTokens] int32; maps each tree position to the
//! corresponding row in draftProbs. Root is unused.
//! \param retrieveNextToken [batchSize, numDraftTokens] int32 first-child pointer, -1=none.
//! \param retrieveNextSibling [batchSize, numDraftTokens] int32 next-sibling pointer, -1=none.
//! \param treeValid [batchSize] bool; false means no valid tree exists for this request.
//! \param batchSize runtime::SizeType32.
//! \param numDraftProbRows runtime::SizeType32. Number of unique draft-prob rows per request.
//! \param maxTargetSupportSize runtime::SizeType32. Third dim of targetSupportIndices. Can be zero.
//! \param numDraftTokens runtime::SizeType32. Total tree nodes per request (including root).
//! \param numSpecStep runtime::SizeType32. Second dim of acceptIndex/acceptToken.
//! \param vocabSize runtime::SizeType32. Vocabulary size.
//! \param seed [1] int64 on GPU. Philox RNG seed.
//! \param offset [1] int64 on GPU. Philox RNG offset.
//! \param stream cudaStream_t.
void invokeVerifyDynamicTreeRejection(int64_t* acceptIndex, int64_t* acceptTokenNum, int64_t* acceptToken,
int64_t const* candidates, float const* draftProbs, float const* targetProbs, int32_t const* targetSupportIndices,
int32_t const* targetSupportLengths, int32_t const* draftProbIndices, int32_t const* retrieveNextToken,
int64_t const* draftTokens, float const* targetProbs, int32_t const* retrieveNextToken,
int32_t const* retrieveNextSibling, bool const* treeValid, runtime::SizeType32 batchSize,
runtime::SizeType32 numDraftProbRows, runtime::SizeType32 maxTargetSupportSize, runtime::SizeType32 numDraftTokens,
runtime::SizeType32 numSpecStep, runtime::SizeType32 vocabSize, int64_t const* seed, int64_t const* offset,
cudaStream_t stream);

//! \brief Compute draft probabilities for dynamic-tree rejection sampling from logits.
//! \param draftLogits [batchSize * numDraftProbRows, draftVocabSize], on GPU.
//! \param temperatures [batchSize], on GPU.
//! \param numDraftProbRows runtime::SizeType32. Unique draft-prob rows per request.
//! \param topK Optional [batchSize], on GPU.
//! \param topP Optional [batchSize], on GPU.
//! \param targetVocabSize runtime::SizeType32. Output vocabulary size after optional d2t expansion.
//! \param d2t Optional [draftVocabSize], on GPU.
//! \return [batchSize, numDraftProbRows, targetVocabSize] float32 probabilities.
torch::Tensor computeDraftProbsForDynamicTreeRejection(torch::Tensor const& draftLogits,
torch::Tensor const& temperatures, runtime::SizeType32 numDraftProbRows, torch::optional<torch::Tensor> const& topK,
torch::optional<torch::Tensor> const& topP, runtime::SizeType32 targetVocabSize, bool skipTemperature,
torch::optional<torch::Tensor> const& d2t, runtime::SizeType32 kMax = 0, bool skipAllSamplingParams = false);

//! \brief Compute target probabilities for dynamic-tree rejection sampling from logits.
//! \param targetLogits [batchSize * numDraftTokens, targetVocabSize], on GPU.
//! \param temperatures [batchSize], on GPU.
//! \param numDraftTokens runtime::SizeType32. Total tree nodes per request (including root).
//! \param topK Optional [batchSize], on GPU.
//! \param topP Optional [batchSize], on GPU.
//! \param kMax runtime::SizeType32. Max top-K value across the batch; enables the fast topk
//! path when > 0. Must be computed on CPU (e.g. topK.max().item()). Default 0 = fallback
//! to full sort.
//! \return Tuple of:
//! 1. [batchSize, numDraftTokens, targetVocabSize] float32 probabilities.
//! 2. [batchSize, numDraftTokens, maxTargetSupportSize] int32 compact token ids that survive
//! top-k/top-p filtering for each row, padded with -1. Empty when no filtering is active.
//! 3. [batchSize, numDraftTokens] int32 support lengths. Empty when no filtering is active.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> computeTargetProbsForDynamicTreeRejection(
torch::Tensor const& targetLogits, torch::Tensor const& temperatures, runtime::SizeType32 numDraftTokens,
torch::optional<torch::Tensor> const& topK, torch::optional<torch::Tensor> const& topP, bool skipTemperature,
runtime::SizeType32 kMax = 0, bool skipAllSamplingParams = false);
runtime::SizeType32 numDraftTokens, runtime::SizeType32 numSpecStep, runtime::SizeType32 vocabSize,
int64_t const* seed, int64_t const* offset, cudaStream_t stream);

} // namespace kernels::speculative_decoding

Expand Down
Loading
Loading