[TRTLLM-13319][fix] fix output distribution correctness for Eagle3 dynamic-tree rejection sampling#15098

Open
zhaoyangwang-nvidia wants to merge 5 commits into
NVIDIA:mainfrom
zhaoyangwang-nvidia:eagle3-rej-target-only

Conversation

@zhaoyangwang-nvidia

@zhaoyangwang-nvidia zhaoyangwang-nvidia commented Jun 8, 2026

Collaborator

Summary by CodeRabbit

Release Notes

  • Performance Improvements

    • Optimized speculative decoding verification path with simplified computation logic, reducing unnecessary intermediate steps and memory overhead.
    • Added compilation optimization for probability calculations to improve inference throughput.
  • Refactor

    • Streamlined internal verification interface for speculative decoding, removing intermediate computation stages while maintaining accuracy.

Summary

This PR fixes the probability distribution used during verification in Eagle3 one-model dynamic tree rejection sampling, and includes a refactor and a new feature.

Changes

  • Fix: Correct the probability distribution used in dynamic-tree rejection sampling. Verification previously used an incorrect slice of the logits.
  • Refactor: Pass draftTokens[num_gens, N-1] directly to the rejection kernel, removing redundant reshaping.
  • Feat: Support TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS on both the dynamic-tree and linear rejection-sampling paths. This was previously unsupported because rejection sampling processes tokens sequentially: when the forced length exceeds the actual accepted length, the unfilled slots contain out-of-vocabulary garbage. The fix pre-fills those positions with draft tokens when force mode is active.
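
The force-mode fix in the last bullet can be pictured with a small pure-Python sketch. This is not the PR's code: `prefill_forced_slots`, its arguments, and the `-1` garbage sentinel are hypothetical, and the sketch assumes that slot `i` of the accepted-token row lines up with draft token `i`.

```python
from typing import List


def prefill_forced_slots(
    accepted: List[int],
    num_accepted: int,
    draft_tokens: List[int],
    forced_len: int,
) -> List[int]:
    """Fill slots the sampler never wrote with the matching draft tokens.

    Hypothetical helper: slots [0, num_accepted) hold tokens written by
    rejection sampling; later slots may hold garbage. Assumes slot i of
    `accepted` corresponds to `draft_tokens[i]`.
    """
    out = list(accepted)
    # Never write past either row, even if forced_len is too large.
    limit = min(forced_len, len(out), len(draft_tokens))
    for pos in range(num_accepted, limit):
        out[pos] = draft_tokens[pos]
    return out


# The sampler accepted 2 tokens; slots 2 and 3 were never written (-1 = garbage).
row = prefill_forced_slots([11, 22, -1, -1], 2, [11, 22, 33, 44], forced_len=4)
print(row)  # [11, 22, 33, 44]
```

With force mode off (a forced length no larger than the accepted length) the row comes back unchanged, so the pre-fill only matters when the accepted count is inflated.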

Test Coverage

TRT-LLM numbers reflect batch-full periods only (all concurrency slots occupied), so they appear higher than system-level throughput.

| Model | Framework | greedy tok/s | rejection tok/s | AL (greedy) | AL (rejection) | AR (greedy) | AR (rejection) |
|---|---|---|---|---|---|---|---|
| Qwen3-8B | SGLang | 2060 | 1998 | 3.23 | 3.18 | - | - |
| Qwen3-8B | TRT-LLM | 2977.2 | 2806.6 | 3.327 | 3.330 | 0.776 | 0.777 |
| Qwen3-235B-FP8 | SGLang | 977 | 1022 | 2.33 | 2.34 | - | - |
| Qwen3-235B-A22B | TRT-LLM | 1344.7 | 1332.3 | 2.713 | 2.709 | 0.571 | 0.570 |
| Llama-3.3-70B | SGLang | 1381 | 1354 | 2.33 | 2.31 | - | - |
| Llama-3.3-70B | TRT-LLM | 1168.3 | 1118.6 | 3.747 | 3.708 | 0.458 | 0.451 |

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • If PR introduces API changes, an appropriate PR label is added - either api-compatible or api-breaking. For api-breaking, include BREAKING in the PR title.

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@zhaoyangwang-nvidia zhaoyangwang-nvidia changed the title from "[TRTLLM-12669][fix] fix output distribution correctness for Eagle3 dynamic-tree rejection sampling" to "[TRTLLM-13319][fix] fix output distribution correctness for Eagle3 dynamic-tree rejection sampling" Jun 9, 2026
@zhaoyangwang-nvidia zhaoyangwang-nvidia force-pushed the eagle3-rej-target-only branch 3 times, most recently from d0e09ad to 6103ec1 Compare June 17, 2026 08:35
@zhaoyangwang-nvidia zhaoyangwang-nvidia marked this pull request as ready for review June 17, 2026 08:37
@zhaoyangwang-nvidia zhaoyangwang-nvidia requested review from a team as code owners June 17, 2026 08:37
@coderabbitai

coderabbitai Bot commented Jun 17, 2026

Contributor
📝 Walkthrough

Walkthrough

The PR replaces the draft-aware dynamic-tree rejection-sampling kernel with a target-only approach across the full stack. The new CUDA kernel accumulates cumulative target probabilities across siblings, accepts the first sibling exceeding a random coin, and samples a correction token from residual mass when all siblings are rejected. All draft-prob computation functions, support-index tensors, and related C++/Python op registrations are removed, and the Eagle3 worker integrates the new path.
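
One way to read the per-level rule described above is the following pure-Python sketch. It is an illustration under assumptions, not the CUDA kernel: `verify_level` and its arguments are hypothetical names, sibling tokens are assumed distinct, and "exceeding a random coin" is interpreted as the running sum of target probabilities passing a uniform draw.

```python
import random
from typing import List, Optional, Tuple


def verify_level(
    target_probs: List[float],
    sibling_tokens: List[int],
    rng: random.Random,
) -> Tuple[Optional[int], int]:
    """Verify one tree level using target probabilities only.

    Returns (accepted sibling index, token), or (None, token) when every
    sibling is rejected and a correction token is sampled instead.
    """
    coin = rng.random()
    prob_acc = 0.0
    for idx, tok in enumerate(sibling_tokens):
        prob_acc += target_probs[tok]  # cumulative target mass so far
        if coin < prob_acc:
            return idx, tok

    # All siblings rejected: sample from the target mass left over after
    # removing the tokens that were already tried.
    tried = set(sibling_tokens)
    residual = [0.0 if t in tried else p for t, p in enumerate(target_probs)]
    if sum(residual) <= 0.0:
        # Hypothetical rounding guard: no mass left, keep the last sibling.
        return len(sibling_tokens) - 1, sibling_tokens[-1]
    tok = rng.choices(range(len(residual)), weights=residual, k=1)[0]
    return None, tok


rng = random.Random(0)
probs = [0.1, 0.2, 0.3, 0.4]
counts = [0] * len(probs)
for _ in range(20000):
    _, tok = verify_level(probs, [2, 0], rng)
    counts[tok] += 1
print([round(c / 20000, 2) for c in counts])  # should be close to probs
```

Under these assumptions each token is emitted with its target probability: a tried token is accepted with exactly its own mass, and an untried token is reached with the leftover mass times its renormalized share, which multiplies back to its target probability.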

Changes

Target-only dynamic-tree rejection sampling refactor

| Layer | File(s) | Summary |
|---|---|---|
| CUDA kernel contract and header update | cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h, dynamicTreeKernels.cu | Removes invokeBuildDraftProbIndices declaration; updates invokeVerifyDynamicTreeRejection signature to accept draftTokens + targetProbs instead of draft-prob/support tensors; deletes computeDraftProbsSkipAllForDynamicTreeRejection, computeDraftProbsForDynamicTreeRejection, computeTargetProbsForDynamicTreeRejection, and invokeBuildDraftProbIndices implementations. |
| Target-only rejection sampling CUDA kernel implementation | cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu | Introduces verifyDynamicTreeRejectionKernel with kMaxTriedPerLevel cap, new shared-memory state for tried tokens and residual mass, tile-based sampleTargetFullVocab and sampleResidualWithTriedTokens helpers; rewrites sibling traversal to accumulate cumulative target probabilities and apply residual correction sampling. |
| C++ PyTorch op bindings update | cpp/tensorrt_llm/thop/dynamicTreeOp.cpp | Adds compute_probs_from_logits_op wrapper; replaces verify_dynamic_tree_rejection_out_op with target-only interface; removes build_draft_prob_indices_out_op, compute_draft_probs_for_dynamic_tree_rejection_op, and compute_target_probs_for_dynamic_tree_rejection_op registrations. |
| Python op fake registration and DynamicTreeOpsConverter | tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py, tensorrt_llm/_torch/speculative/dynamic_tree_ops.py | Updates fake op signature for trtllm::verify_dynamic_tree_rejection_out_op; adds verify_dynamic_tree_rejection_out method to DynamicTreeOpsConverter; removes old verify_dynamic_tree_rejection_from_logits_out. |
| Eagle3 worker integration | tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py | Removes draft-logit-capture buffers and helpers; adds _sample_and_accept_dynamic_tree_rejection for non-greedy batches; updates _sample_and_accept_dynamic_tree with non-greedy sampling_batch_spec_dec_one_model branch; simplifies _can_use_rejection_sampling. |
| Force-accept prefill fix and torch.compile decoration | tensorrt_llm/_torch/speculative/interface.py, tensorrt_llm/_torch/speculative/one_model_sampler.py | Pre-fills gen_accepted draft-token slots from full_draft_tokens when force_num_accepted_tokens is set; decorates compute_probs_from_logits with @torch.compile(options={"max-autotune": True}). |

Sequence Diagram(s)

sequenceDiagram
  participant Eagle3 as Eagle3OneModelDynamicTreeWorker
  participant Converter as DynamicTreeOpsConverter
  participant TorchOp as trtllm::verify_dynamic_tree_rejection_out_op
  participant Kernel as verifyDynamicTreeRejectionKernel

  Eagle3->>Eagle3: _can_use_rejection_sampling(spec_metadata)?
  Eagle3->>Eagle3: sample context tokens directly
  Eagle3->>Converter: verify_dynamic_tree_rejection_out(draft_tokens, target_logits_tree, ...)
  Converter->>Converter: compute_probs_from_logits(target_logits_tree, temperatures, top_k, top_p)
  Converter->>Converter: reshape target_probs → [num_gens, num_draft_tokens, vocab]
  Converter->>TorchOp: verify_dynamic_tree_rejection_out_op(draft_tokens, target_probs, ...)
  TorchOp->>Kernel: invokeVerifyDynamicTreeRejection(draftTokens, targetProbs, seed, offset, ...)
  loop each depth level
    Kernel->>Kernel: accumulate probAcc over siblings
    alt accepted
      Kernel->>Kernel: record accepted sibling token
    else all rejected
      Kernel->>Kernel: sampleResidualWithTriedTokens(probResidual)
    end
  end
  Kernel-->>TorchOp: fill acceptIndex, acceptTokenNum, acceptToken in-place
  TorchOp-->>Converter: (void)
  Converter-->>Eagle3: (accept_index, accept_tok_num, accept_token)

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

  • NVIDIA/TensorRT-LLM#12588: Added the original invokeVerifyDynamicTreeRejection rejection-sampling kernel and draft/target probability computation plumbing in dynamicTreeKernels.cu/.h that this PR directly replaces with the target-only interface.
  • NVIDIA/TensorRT-LLM#14745: Also modifies Eagle3 dynamic-tree rejection-sampling eligibility via _can_use_rejection_sampling / is_all_greedy_sample and the verification op interface, directly overlapping with the greedy-eligibility simplification in this PR.

Suggested reviewers

  • lancelly
  • laikhtewari
  • liji-nv
  • ziyixiong-nv
  • nv-guomingz
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 61.76% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (4 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The PR title clearly summarizes the main fix: correcting output distribution for Eagle3 dynamic-tree rejection sampling, which aligns with the primary objective stated in the PR description and objectives. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | PR description comprehensively covers the fix, refactor, and new feature with test coverage data and checklist completion. |

Warning

Review ran into problems

🔥 Problems

Stopped waiting for pipeline failures after 30000ms. One of your pipelines takes longer than our 30000ms fetch window to run, so review may not consider pipeline-failure results for inline comments if any failures occurred after the fetch window. Increase the timeout if you want to wait longer or run a @coderabbit review after the pipeline has finished.


Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 6

🧹 Nitpick comments (1)
tensorrt_llm/_torch/speculative/one_model_sampler.py (1)

120-121: 💤 Low value

Verify torch.compile interaction with data-dependent branches and custom ops.

The function contains runtime-dependent branches (logits.is_cuda) and calls to custom C++ ops (torch.ops.trtllm.compute_probs_from_logits_op) and flashinfer functions. While torch.compile handles these via graph breaks and opaque op handling, this may reduce compilation benefits for certain code paths.

Given the PR reports measured ~15% throughput improvement, this is likely working well in practice for the hot paths. Consider adding a brief inline comment noting that the flashinfer path is the primary optimization target if future maintainers question the compilation strategy.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tensorrt_llm/_torch/speculative/one_model_sampler.py` around lines 120 - 121,
Add an inline comment near the torch.compile decorator or at the start of the
compute_probs_from_logits function to document that the flashinfer execution
path is the primary optimization target for the torch.compile strategy. This
clarifies to future maintainers why torch.compile is used despite the presence
of data-dependent branches (logits.is_cuda check) and custom C++ ops
(torch.ops.trtllm.compute_probs_from_logits_op) that may cause graph breaks,
acknowledging that the current approach yields measured performance improvements
for the hot paths.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu`:
- Around line 1227-1237: The loop iterating over sTriedTokenIds array can read
out-of-bounds shared memory because sNumTriedTokens is incremented without
bounds checking at line 1331, allowing it to exceed the kMaxTriedPerLevel size
limit of the sTriedTokenIds array. Fix this by clamping the loop bound to the
minimum of sNumTriedTokens and kMaxTriedPerLevel, or alternatively clamp
sNumTriedTokens itself at the increment location (line 1331) to never exceed
kMaxTriedPerLevel. Either approach prevents the undefined shared memory read
when the number of siblings exceeds the array capacity.

In `@cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h`:
- Line 91: The documentation comment at line 91 incorrectly names the parameter
as candidates and specifies its shape as [batchSize, numDraftTokens], but the
actual parameter is named draftTokens with shape [batchSize, numDraftTokens-1]
because the root token is stored separately and not included in draftTokens.
Update the documentation comment to reflect the correct parameter name
draftTokens and correct shape [batchSize, numDraftTokens-1], and clarify that
the root token at position 0 is not included in this parameter, consistent with
the upstream implementation in dynamicTreeOp.cpp.

In `@cpp/tensorrt_llm/thop/dynamicTreeOp.cpp`:
- Around line 186-188: The `numSpecStep` parameter is not validated to be
positive before being passed to `invokeVerifyDynamicTreeRejection`, where
negative values convert to very large unsigned loop bounds causing invalid
memory access. Add a TORCH_CHECK validation to ensure `numSpecStep` is greater
than 0, placing this check before the existing buffer size validations for
acceptIndex, acceptTokenNum, and acceptToken to prevent non-positive values from
flowing into the kernel launch.
- Around line 154-170: The code verifies that input and output tensors are on
CUDA devices but does not verify they are contiguous in memory. Since the
operation uses raw data_ptr buffer access with dense indexing and uses a stream
from draftTokens.device(), all tensors must be contiguous to prevent incorrect
memory access. Add TORCH_CHECK calls to verify contiguity for all tensors
involved: draftTokens, targetProbs, retrieveNextToken, retrieveNextSibling,
treeValid, acceptIndex, acceptTokenNum, and acceptToken using the
.is_contiguous() method on each tensor. Apply the same contiguity checks in all
other sections that use raw data_ptr access with dense indexing patterns as
mentioned in the comment scope.
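
The `numSpecStep` concern in the first comment above comes down to signed-to-unsigned wraparound. A minimal Python sketch of that arithmetic follows; the 32-bit width and the helper names `as_uint32` and `require_positive_steps` are assumptions for illustration, and the real check would be a `TORCH_CHECK` in C++.

```python
def as_uint32(value: int) -> int:
    """Reinterpret a signed integer as a 32-bit unsigned value (two's complement)."""
    return value & 0xFFFFFFFF


def require_positive_steps(num_spec_step: int) -> int:
    """Hypothetical stand-in for the suggested check: reject values <= 0."""
    if num_spec_step <= 0:
        raise ValueError(f"numSpecStep must be positive, got {num_spec_step}")
    return num_spec_step


# A negative step count becomes a huge loop bound once treated as unsigned.
print(as_uint32(-1))  # 4294967295
```

Validating before the kernel launch keeps such a value from ever reaching the loop bound.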

In `@tensorrt_llm/_torch/speculative/dynamic_tree_ops.py`:
- Around line 249-252: Move the validation check for num_gens (the if statement
checking if num_gens <= 0) to execute before the line computing
num_draft_tokens. Currently, the division operation num_draft_tokens =
target_logits_tree.shape[0] // num_gens happens first, which causes a
ZeroDivisionError when num_gens is 0, preventing the intended ValueError from
being raised. Reorder these statements so that num_gens is validated as positive
before it is used in any arithmetic operation.
- Around line 249-274: After computing num_draft_tokens from
target_logits_tree.shape[0] // num_gens, add an explicit divisibility check to
ensure that target_logits_tree.shape[0] is perfectly divisible by num_gens
without remainder. If divisibility fails, raise a ValueError with a descriptive
message indicating the mismatch between the target logits shape and the number
of generations, since this mismatch will later cause failures in the reshape
operation that reshapes target_probs_flat into target_probs_tree dimensions.
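
The two `dynamic_tree_ops.py` comments above amount to an ordering rule: validate `num_gens` first, then check divisibility, then divide. A minimal sketch, assuming a hypothetical helper `draft_tokens_per_gen` that takes the row count of `target_logits_tree` as a plain integer:

```python
def draft_tokens_per_gen(num_rows: int, num_gens: int) -> int:
    """Return rows per generation request, validating inputs before dividing."""
    # Check num_gens first so a zero raises ValueError, not ZeroDivisionError.
    if num_gens <= 0:
        raise ValueError(f"num_gens must be positive, got {num_gens}")
    # A remainder would make the later reshape fail, so reject it up front.
    if num_rows % num_gens != 0:
        raise ValueError(
            f"target logits rows ({num_rows}) not divisible by num_gens ({num_gens})"
        )
    return num_rows // num_gens


print(draft_tokens_per_gen(12, 3))  # 4
```

Raising early with both numbers in the message points at the shape mismatch directly, rather than at a reshape error further downstream.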

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: ed733216-742c-4041-a372-970fdbbc75b3

📥 Commits

Reviewing files that changed from the base of the PR and between 52cbeee and 6103ec1.

📒 Files selected for processing (8)
  • cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu
  • cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h
  • cpp/tensorrt_llm/thop/dynamicTreeOp.cpp
  • tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py
  • tensorrt_llm/_torch/speculative/dynamic_tree_ops.py
  • tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py
  • tensorrt_llm/_torch/speculative/interface.py
  • tensorrt_llm/_torch/speculative/one_model_sampler.py

Comment thread cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.cu
Comment thread cpp/tensorrt_llm/kernels/speculativeDecoding/dynamicTreeKernels.h Outdated
Comment thread cpp/tensorrt_llm/thop/dynamicTreeOp.cpp
Comment thread cpp/tensorrt_llm/thop/dynamicTreeOp.cpp
Comment thread tensorrt_llm/_torch/speculative/dynamic_tree_ops.py Outdated
Comment thread tensorrt_llm/_torch/speculative/dynamic_tree_ops.py Outdated
…ee rejection sampling

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
…rejection kernel

The verifyDynamicTreeRejectionKernel traversal loop starts at childIdx>=1
and never reads candidates[:,0] (the root position). Remove the
_rejection_candidates_buf intermediate buffer and the buffer-fill
boilerplate; pass draft_tokens[num_gens, N-1] directly instead.

Also add .long() cast at the call site since draft_tokens_buffer is
int32 but the kernel requires int64 (previously the int64 buffer
provided an implicit cast).

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
When force_num_accepted_tokens != 0, pre-fill accept_token positions
1..max_path_len-1 with draft tokens before _finalize so the decoder
reads valid tokens when num_accepted_tokens is inflated; slices are
bounded by max_path_len-1 for shape correctness.

Same force_num draft fill added to the linear rejection path in
interface.py.

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
…_from_logits

Profiling on H200 shows +15% rejection sampling throughput (1135 → 1304 tok/s)
at bs=16 with Qwen3-8B Eagle3 dynamic tree.

Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
@zhaoyangwang-nvidia

Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #54836 [ run ] triggered by Bot. Commit: 56cc46a Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #54836 [ run ] completed with state SUCCESS. Commit: 56cc46a
/LLM/main/L0_MergeRequest_PR pipeline #43848 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation
