[TRTLLM-13319][fix] fix output distribution correctness for Eagle3 dynamic-tree rejection sampling#15098
Conversation
d0e09ad to
6103ec1
Compare
📝 WalkthroughWalkthroughThe PR replaces the draft-aware dynamic-tree rejection-sampling kernel with a target-only approach across the full stack. The new CUDA kernel accumulates cumulative target probabilities across siblings, accepts the first sibling exceeding a random coin, and samples a correction token from residual mass when all siblings are rejected. All draft-prob computation functions, support-index tensors, and related C++/Python op registrations are removed, and the Eagle3 worker integrates the new path. ChangesTarget-only dynamic-tree rejection sampling refactor
Sequence Diagram(s)sequenceDiagram
participant Eagle3 as Eagle3OneModelDynamicTreeWorker
participant Converter as DynamicTreeOpsConverter
participant TorchOp as trtllm::verify_dynamic_tree_rejection_out_op
participant Kernel as verifyDynamicTreeRejectionKernel
Eagle3->>Eagle3: _can_use_rejection_sampling(spec_metadata)?
Eagle3->>Eagle3: sample context tokens directly
Eagle3->>Converter: verify_dynamic_tree_rejection_out(draft_tokens, target_logits_tree, ...)
Converter->>Converter: compute_probs_from_logits(target_logits_tree, temperatures, top_k, top_p)
Converter->>Converter: reshape target_probs → [num_gens, num_draft_tokens, vocab]
Converter->>TorchOp: verify_dynamic_tree_rejection_out_op(draft_tokens, target_probs, ...)
TorchOp->>Kernel: invokeVerifyDynamicTreeRejection(draftTokens, targetProbs, seed, offset, ...)
loop each depth level
Kernel->>Kernel: accumulate probAcc over siblings
alt accepted
Kernel->>Kernel: record accepted sibling token
else all rejected
Kernel->>Kernel: sampleResidualWithTriedTokens(probResidual)
end
end
Kernel-->>TorchOp: fill acceptIndex, acceptTokenNum, acceptToken in-place
TorchOp-->>Converter: (void)
Converter-->>Eagle3: (accept_index, accept_tok_num, accept_token)
Estimated code review effort🎯 5 (Critical) | ⏱️ ~120 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Warning Review ran into problems🔥 ProblemsStopped waiting for pipeline failures after 30000ms. One of your pipelines takes longer than our 30000ms fetch window to run, so review may not consider pipeline-failure results for inline comments if any failures occurred after the fetch window. Increase the timeout if you want to wait longer or run a Comment |
There was a problem hiding this comment.
Actionable comments posted: 6
🧹 Nitpick comments (1)
tensorrt_llm/_torch/speculative/one_model_sampler.py (1)
120-121: 💤 Low valueVerify
torch.compileinteraction with data-dependent branches and custom ops.The function contains runtime-dependent branches (
logits.is_cuda) and calls to custom C++ ops (torch.ops.trtllm.compute_probs_from_logits_op) and flashinfer functions. Whiletorch.compilehandles these via graph breaks and opaque op handling, this may reduce compilation benefits for certain code paths.Given the PR reports measured ~15% throughput improvement, this is likely working well in practice for the hot paths. Consider adding a brief inline comment noting that the flashinfer path is the primary optimization target if future maintainers question the compilation strategy.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/speculative/one_model_sampler.py` around lines 120 - 121, Add an inline comment near the torch.compile decorator or at the start of the compute_probs_from_logits function to document that the flashinfer execution path is the primary optimization target for the torch.compile strategy. This clarifies to future maintainers why torch.compile is used despite the presence of data-dependent branches (logits.is_cuda check) and custom C++ ops (torch.ops.trtllm.compute_probs_from_logits_op) that may cause graph breaks, acknowledging that the current approach yields measured performance improvements for the hot paths.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu`:
- Around line 1227-1237: The loop iterating over sTriedTokenIds array can read
out-of-bounds shared memory because sNumTriedTokens is incremented without
bounds checking at line 1331, allowing it to exceed the kMaxTriedPerLevel size
limit of the sTriedTokenIds array. Fix this by clamping the loop bound to the
minimum of sNumTriedTokens and kMaxTriedPerLevel, or alternatively clamp
sNumTriedTokens itself at the increment location (line 1331) to never exceed
kMaxTriedPerLevel. Either approach prevents the undefined shared memory read
when the number of siblings exceeds the array capacity.
In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h`:
- Line 91: The documentation comment at line 91 incorrectly names the parameter
as candidates and specifies its shape as [batchSize, numDraftTokens], but the
actual parameter is named draftTokens with shape [batchSize, numDraftTokens-1]
because the root token is stored separately and not included in draftTokens.
Update the documentation comment to reflect the correct parameter name
draftTokens and correct shape [batchSize, numDraftTokens-1], and clarify that
the root token at position 0 is not included in this parameter, consistent with
the upstream implementation in dynamicTreeOp.cpp.
In `@cpp/tensorrt_llm/thop/dynamicTreeOp.cpp`:
- Around line 186-188: The `numSpecStep` parameter is not validated to be
positive before being passed to `invokeVerifyDynamicTreeRejection`, where
negative values convert to very large unsigned loop bounds causing invalid
memory access. Add a TORCH_CHECK validation to ensure `numSpecStep` is greater
than 0, placing this check before the existing buffer size validations for
acceptIndex, acceptTokenNum, and acceptToken to prevent non-positive values from
flowing into the kernel launch.
- Around line 154-170: The code verifies that input and output tensors are on
CUDA devices but does not verify they are contiguous in memory. Since the
operation uses raw data_ptr buffer access with dense indexing and uses a stream
from draftTokens.device(), all tensors must be contiguous to prevent incorrect
memory access. Add TORCH_CHECK calls to verify contiguity for all tensors
involved: draftTokens, targetProbs, retrieveNextToken, retrieveNextSibling,
treeValid, acceptIndex, acceptTokenNum, and acceptToken using the
.is_contiguous() method on each tensor. Apply the same contiguity checks in all
other sections that use raw data_ptr access with dense indexing patterns as
mentioned in the comment scope.
In `@tensorrt_llm/_torch/speculative/dynamic_tree_ops.py`:
- Around line 249-252: Move the validation check for num_gens (the if statement
checking if num_gens <= 0) to execute before the line computing
num_draft_tokens. Currently, the division operation num_draft_tokens =
target_logits_tree.shape[0] // num_gens happens first, which causes a
ZeroDivisionError when num_gens is 0, preventing the intended ValueError from
being raised. Reorder these statements so that num_gens is validated as positive
before it is used in any arithmetic operation.
- Around line 249-274: After computing num_draft_tokens from
target_logits_tree.shape[0] // num_gens, add an explicit divisibility check to
ensure that target_logits_tree.shape[0] is perfectly divisible by num_gens
without remainder. If divisibility fails, raise a ValueError with a descriptive
message indicating the mismatch between the target logits shape and the number
of generations, since this mismatch will later cause failures in the reshape
operation that reshapes target_probs_flat into target_probs_tree dimensions.
---
Nitpick comments:
In `@tensorrt_llm/_torch/speculative/one_model_sampler.py`:
- Around line 120-121: Add an inline comment near the torch.compile decorator or
at the start of the compute_probs_from_logits function to document that the
flashinfer execution path is the primary optimization target for the
torch.compile strategy. This clarifies to future maintainers why torch.compile
is used despite the presence of data-dependent branches (logits.is_cuda check)
and custom C++ ops (torch.ops.trtllm.compute_probs_from_logits_op) that may
cause graph breaks, acknowledging that the current approach yields measured
performance improvements for the hot paths.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: ed733216-742c-4041-a372-970fdbbc75b3
📒 Files selected for processing (8)
cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cucpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.hcpp/tensorrt_llm/thop/dynamicTreeOp.cpptensorrt_llm/_torch/custom_ops/cpp_custom_ops.pytensorrt_llm/_torch/speculative/dynamic_tree_ops.pytensorrt_llm/_torch/speculative/eagle3_dynamic_tree.pytensorrt_llm/_torch/speculative/interface.pytensorrt_llm/_torch/speculative/one_model_sampler.py
6103ec1 to
5c71b95
Compare
…ee rejection sampling Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
…rejection kernel The verifyDynamicTreeRejectionKernel traversal loop starts at childIdx>=1 and never reads candidates[:,0] (the root position). Remove the _rejection_candidates_buf intermediate buffer and the buffer-fill boilerplate; pass draft_tokens[num_gens, N-1] directly instead. Also add .long() cast at the call site since draft_tokens_buffer is int32 but the kernel requires int64 (previously the int64 buffer provided an implicit cast). Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
When force_num_accepted_tokens != 0, pre-fill accept_token positions 1..max_path_len-1 with draft tokens before _finalize so the decoder reads valid tokens when num_accepted_tokens is inflated; slices are bounded by max_path_len-1 for shape correctness. Same force_num draft fill added to the linear rejection path in interface.py. Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
…_from_logits Profiling on H200 shows +15% rejection sampling throughput (1135 → 1304 tok/s) at bs=16 with Qwen3-8B Eagle3 dynamic tree. Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
5c71b95 to
56cc46a
Compare
|
/bot run --disable-fail-fast |
|
PR_Github #54836 [ run ] triggered by Bot. Commit: |
|
PR_Github #54836 [ run ] completed with state
|
Summary by CodeRabbit
Release Notes
Performance Improvements
Refactor
Summary
This PR fixes the probability distribution used during verification in Eagle3 one-model dynamic tree rejection sampling, and includes a refactor and a new feature.
Changes
draftTokens[num_gens, N-1]directly to rejection kernel, removing redundant reshapingTLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENSfor both dynamic tree and linear rejection sampling paths. Previously unsupported because rejection sampling processes tokens sequentially; when the forced length exceeds actual accepted length, unfilled slots contain out-of-vocabulary garbage. Fix: pre-fill those positions with draft tokens when force mode is active.Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either
api-compatibleorapi-breaking. Forapi-breaking, includeBREAKINGin the PR title.Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.