[TRTLLM-12958][feat] Enable speculative decoding for dis-agg gen-only requests #14546
📝 Walkthrough
This PR enables uniform draft token handling across disaggregated CTX and GEN peers by relaxing peer compatibility checks, padding empty draft token lists, and pre-filling missing context draft tokens in the executor to ensure consistent speculative decoding input shapes.
Changes
Uniform draft token handling in disaggregated generation
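The padding step mentioned in the walkthrough can be illustrated with a minimal sketch. The helper name `pad_draft_tokens`, the `max_draft_len` parameter, and the pad value `0` are assumptions made for illustration; they are not taken from the PR diff.

```python
from typing import List, Optional


def pad_draft_tokens(
    draft_tokens: Optional[List[int]],
    max_draft_len: int,
    pad_id: int = 0,  # hypothetical pad value, not from the PR
) -> List[int]:
    """Return a draft token list of exactly max_draft_len entries.

    An empty or missing list is filled with pad_id so that every request
    presents the same input shape to the speculative decoding step.
    """
    tokens = list(draft_tokens or [])
    if len(tokens) > max_draft_len:
        raise ValueError(
            f"got {len(tokens)} draft tokens, expected at most {max_draft_len}"
        )
    return tokens + [pad_id] * (max_draft_len - len(tokens))
```

A request that arrives with no draft tokens then has the same length as one that arrives with a full list, which is the property the walkthrough describes as consistent input shapes.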
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/disaggregation/native/peer.py`:
- Around line 100-107: The current unconditional allowance for differing layer
counts in PeerRegistrar can register incompatible peers. Change the logic to
compute the actual transferable overlap, e.g. derive the overlapping layer
names/ids via set(self_layers) ∩ set(peer_layers), or use the existing
pool_mapping logic to compute transferable_layers. Proceed with the
partial-transfer path only when that overlap is non-empty or the difference
matches an explicitly allowed MTP-only delta; otherwise log an error and
raise/fail compatibility instead of silently allowing registration. Update the
code around PeerRegistrar, the pool_mapping check, and the logger call that
currently emits "layer count differs ... allowing partial layer transfer".
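The overlap check suggested in this comment could look roughly like the sketch below. It is a standalone illustration, not the PeerRegistrar code: the function name `transferable_layers`, the `allowed_delta` parameter, and the use of plain string layer names are all assumptions.

```python
from typing import Iterable, Set


def transferable_layers(
    self_layers: Iterable[str],
    peer_layers: Iterable[str],
    allowed_delta: Iterable[str] = frozenset(),  # hypothetical MTP-only layer names
) -> Set[str]:
    """Return the layers both peers hold, or raise if the peers are incompatible.

    Follows the rule in the review comment: a mismatch is accepted only when
    the overlap is non-empty or the difference is an explicitly allowed delta.
    """
    mine, theirs = set(self_layers), set(peer_layers)
    overlap = mine & theirs
    difference = mine ^ theirs  # layers present on exactly one side
    if not difference:
        return overlap  # identical layer sets, nothing to decide
    if overlap or difference <= set(allowed_delta):
        return overlap  # partial transfer over the shared layers
    raise ValueError(
        f"incompatible peers: no shared layers, unexpected layers {sorted(difference)}"
    )
```

Raising instead of logging a warning makes an incompatible pair fail at registration time, which is the behavior the comment asks for.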
📒 Files selected for processing (3)
- tensorrt_llm/_torch/disaggregation/native/peer.py
- tensorrt_llm/_torch/disaggregation/native/transfer.py
- tensorrt_llm/_torch/pyexecutor/py_executor.py
/bot run --add-multi-gpu-test --disable-fail-fast
PR_Github #50415 [ run ] triggered by Bot. Commit:
PR_Github #50415 [ run ] completed with state
/bot run --add-multi-gpu-test --disable-fail-fast
PR_Github #51474 [ run ] triggered by Bot. Commit:
PR_Github #51474 [ run ] completed with state
/bot run --add-multi-gpu-test --disable-fail-fast
PR_Github #51609 [ run ] triggered by Bot. Commit:
PR_Github #51609 [ run ] completed with state
/bot run --add-multi-gpu-test --disable-fail-fast
PR_Github #51714 [ run ] triggered by Bot. Commit:
PR_Github #51714 [ run ] completed with state
/bot run --add-multi-gpu-test --disable-fail-fast
PR_Github #51783 [ run ] triggered by Bot. Commit:
PR_Github #51783 [ run ] completed with state
/bot run
PR_Github #51823 [ run ] triggered by Bot. Commit:
/bot run
PR_Github #53138 [ run ] triggered by Bot. Commit:
PR_Github #53138 [ run ] completed with state
/bot run
PR_Github #53149 [ run ] triggered by Bot. Commit:
PR_Github #53149 [ run ] completed with state
/bot run
PR_Github #53162 [ run ] triggered by Bot. Commit:
Signed-off-by: Bo Deng <deemod@nvidia.com>
/bot run --add-multi-gpu-test --disable-fail-fast
PR_Github #53197 [ run ] triggered by Bot. Commit:
PR_Github #53162 [ run ] completed with state
    # Limit to prompt_len blocks, matching C++ cacheFormatter behavior.
    # Extra blocks from num_extra_kv_tokens (speculative decoding) have
    # uninitialized KV data and must not be transferred.
    total_blocks = (req.prompt_len + tpb - 1) // tpb
Here, along with token_range, the requirement is to transfer only the blocks covering prompt_len. In practice, however, blocks for prompt_len + num_extra_kv_tokens are allocated.
If MTP is enabled for both the context phase and the generation phase, the current modification transfers only the prompt_len blocks, and any extra block that was allocated will not be transferred. The questions are:
- Will the KV cache for num_extra_kv_tokens be written to during the context phase?
- Will the KV cache written during the context phase be used by the generation phase?
- When both context and generation have MTP enabled, do we need to transfer prompt_len + num_extra_kv_tokens KV blocks?
cc @lfr-0531
@chuangz0 these changes fix the accuracy issue with py-transceiver + eagle3. Without them, accuracy drops even when eagle3 is enabled for both ctx and gen.
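The block arithmetic discussed in this thread can be made concrete with a small sketch. The function name `blocks_needed` and the example values (a 96-token prompt, 32 tokens per block, 3 extra KV tokens) are hypothetical; only the ceiling-division expression comes from the diff.

```python
def blocks_needed(num_tokens: int, tokens_per_block: int) -> int:
    """Ceiling division: number of KV blocks needed to hold num_tokens."""
    if tokens_per_block <= 0:
        raise ValueError("tokens_per_block must be positive")
    return (num_tokens + tokens_per_block - 1) // tokens_per_block


# Hypothetical values chosen so the two counts differ.
prompt_len = 96
tokens_per_block = 32
num_extra_kv_tokens = 3

# Blocks transferred under the change in the diff (prompt only).
transferred = blocks_needed(prompt_len, tokens_per_block)
# Blocks allocated when the extra speculative tokens are included.
allocated = blocks_needed(prompt_len + num_extra_kv_tokens, tokens_per_block)

print(f"transferred={transferred} allocated={allocated}")
```

With these values the prompt fills exactly 3 blocks while the allocation spans 4, so one allocated block is left out of the transfer. Whether that block ever holds data the generation phase needs is the open question raised above; the sketch does not answer it.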
PR_Github #53197 [ run ] completed with state
/bot run --add-multi-gpu-test --disable-fail-fast
PR_Github #53286 [ run ] triggered by Bot. Commit:
PR_Github #53286 [ run ] completed with state
ADEngine subclasses the abstract ModelEngine and does not run PyTorchModelEngine.__init__, so it never set `enable_spec_decode`. After NVIDIA#14546 added an unguarded `self.model_engine.enable_spec_decode` read in `_prepare_disagg_gen_transmission_complete` (the disagg generation handoff path that ADEngine traverses via NVIDIA#14057 AutoDeploy Basic Disagg Support), AutoDeploy disaggregated runs crash with:

    AttributeError: 'ADEngine' object has no attribute 'enable_spec_decode'

NVIDIA#14546 and NVIDIA#14057 each passed CI independently but conflict semantically once both are on main.

Set `is_spec_decode`/`enable_spec_decode` in ADEngine.__init__, mirroring PyTorchModelEngine (enable_spec_decode == spec_config is not None), so ADEngine satisfies the ModelEngine attribute contract that shared PyExecutor code relies on.

Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
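The failure mode and the fix described in this commit message can be reproduced with stand-in classes. Every class and function name below is a hypothetical simplification; none of them is the real TensorRT-LLM engine code, and only the rule `enable_spec_decode == spec_config is not None` is taken from the commit message.

```python
from abc import ABC
from typing import Any, Optional


class ModelEngineStub(ABC):
    """Stand-in for the abstract base; it defines no attributes itself."""


class EngineWithoutFlag(ModelEngineStub):
    """Mimics the pre-fix situation: __init__ never sets the flag."""

    def __init__(self, spec_config: Optional[Any] = None) -> None:
        self.spec_config = spec_config


class EngineWithFlag(ModelEngineStub):
    """Mimics the fix: both flags are derived from spec_config in __init__."""

    def __init__(self, spec_config: Optional[Any] = None) -> None:
        self.spec_config = spec_config
        self.is_spec_decode = spec_config is not None
        self.enable_spec_decode = self.is_spec_decode


def shared_executor_path(engine: ModelEngineStub) -> bool:
    """Unguarded attribute read, as shared executor code might perform."""
    return engine.enable_spec_decode
```

Calling `shared_executor_path(EngineWithoutFlag())` raises `AttributeError`, while `EngineWithFlag()` returns `False` and `EngineWithFlag(spec_config=object())` returns `True`. Setting the attribute in every subclass keeps the shared code free of `getattr` guards, at the cost of an implicit contract that each new engine must remember to honor.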
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either `api-compatible` or `api-breaking`. For `api-breaking`, include `BREAKING` in the PR title.
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.