[1/3][Refactor]: File reorg; deprecate ParallelDraft#1296
Conversation
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info — Configuration used: .coderabbit.yaml; Review profile: CHILL; Plan: Pro Plus
📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 Walkthrough
This PR splits the prior unified Transformers speculative plugin into separate modules.

Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
Codecov Report
❌ Patch coverage check failed. Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1296      +/-   ##
==========================================
- Coverage   75.87%   71.51%   -4.37%
==========================================
  Files         462      465       +3
  Lines       49747    49731      -16
==========================================
- Hits        37745    35564    -2181
- Misses      12002    14167    +2165
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
[AI Review] Generated by Claude — reference only, not a substitute for human review.

Verdict: pure file-reorg + rename + the documented breaking changes.

Verified:

Please double-check:
Actionable comments posted: 4
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/hf_medusa.py (1)
Lines 147-156: Consider avoiding in-place mutation of the `labels` parameter.

The loop mutates `labels` on each iteration (`labels = labels[..., 1:].contiguous()`), which shadows the input parameter and could be confusing. Using a separate variable would improve clarity.

♻️ Suggested improvement
```diff
 if labels is not None:
     loss = 0
     loss_fct = CrossEntropyLoss()
     # Base model loss
     if not freeze_base_model:
         loss_logits = logits.view(-1, logits.shape[-1])
         loss_labels = labels.view(-1)
         base_model_loss = loss_fct(loss_logits, loss_labels)
         loss += base_model_loss
     # Medusa loss
+    shifted_labels = labels
     for i in range(self.medusa_num_heads):
-        labels = labels[..., 1:].contiguous()
+        shifted_labels = shifted_labels[..., 1:].contiguous()
         loss_logits = medusa_logits[i][:, : -(1 + i)].contiguous()
         loss_logits = loss_logits.view(-1, loss_logits.shape[-1])
-        loss_labels = labels.view(-1)
+        loss_labels = shifted_labels.view(-1)
         loss += (
             loss_fct(loss_logits, loss_labels)
             * medusa_decay_coefficient**i
             * medusa_heads_coefficient
         )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_medusa.py` around lines 147 - 156, The loop in function(s) using medusa_num_heads currently mutates the input parameter labels with labels = labels[..., 1:].contiguous() each iteration; replace this in-place mutation by computing a fresh shifted slice per head (e.g. shifted_labels = labels[..., 1 + i:].contiguous() or compute start = i+1 and use labels[..., start:]) and then use shifted_labels to form loss_labels; keep the rest of the computation (loss_logits from medusa_logits[i], view/reshape, loss_fct, medusa_decay_coefficient**i, medusa_heads_coefficient) unchanged so the original labels argument is not shadowed or modified.
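The shift logic in this suggestion can be checked in isolation. Below is a torch-free sketch of the non-mutating loop: plain lists stand in for label tensors, `shifted` plays the role of `shifted_labels`, and the helper name `medusa_label_views` is hypothetical, not part of the codebase.

```python
def medusa_label_views(labels, num_heads):
    """Return the label slice each Medusa head trains on, leaving `labels` untouched.

    List slicing stands in for `labels[..., 1:].contiguous()`; the point is
    that the caller's `labels` is never reassigned or mutated.
    """
    views = []
    shifted = labels
    for _ in range(num_heads):
        shifted = shifted[1:]  # head i predicts tokens i + 1 steps ahead
        views.append(shifted)
    return views

labels = [10, 11, 12, 13, 14]
views = medusa_label_views(labels, 2)
# labels is unchanged; views[0] == [11, 12, 13, 14], views[1] == [12, 13, 14]
```

Each head sees the labels shifted one further position than the previous head, matching the per-head `-(1 + i)` logit truncation in the real loss.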
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/hf_medusa.py`:
- Line 122: Replace the incorrect keyword argument name rcache_position with the
correct cache_position where the model forward call is constructed so the
HuggingFace PreTrainedModel.forward() receives the cache position correctly;
locate the call using the rcache_position symbol in hf_medusa plugin code and
rename that argument to cache_position (and ensure any downstream uses expect
cache_position, not rcache_position).
In `@modelopt/torch/speculative/plugins/modeling_dflash.py`:
- Around line 25-31: Top-level hard imports of transformers/Qwen3 symbols
(ALL_ATTENTION_FUNCTIONS, Qwen3MLP/_MLP_CLS, Qwen3RMSNorm/_NORM_CLS,
Qwen3RotaryEmbedding/_ROTARY_CLS, _rotate_half) must be removed from
modeling_dflash.py and instead acquired via the plugin lazy loader; replace
those module-level imports with calls to the plugin system (import_plugin()) and
perform the transformers/Qwen3 imports inside the plugin initialization or
inside the functions that need them so they only run when the hf_dflash plugin
is loaded, ensuring modeling_dflash.py can be imported in environments without
the optional transformers integration.
- Around line 36-47: The function build_target_layer_ids is producing duplicate
layer indices for shallow target models (e.g., build_target_layer_ids(4,2) ->
[1,1]); update it to detect when the computed interior window is too small to
yield unique indices and either (a) fall back to using the full layer span (0
through num_target_layers-1) to select uniformly spaced indices, or (b) raise a
ValueError for unsupported configs; implement the check inside
build_target_layer_ids (use the computed start/end/span and compare span to
num_draft_layers - 1) and then either compute indices across the full range or
raise, so hf_dflash.py will not receive duplicate hidden-state indices.
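One way to implement the suggested guard is sketched below. The uniform-spacing-over-interior-layers scheme is an assumption about the production algorithm; only the fallback-on-too-small-window behavior comes from the review comment.

```python
def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> list[int]:
    """Pick uniformly spaced, unique target-layer indices for the draft model.

    Prefers the interior layers (skipping the first and last), but falls back
    to the full 0..num_target_layers-1 span when the interior window is too
    small to yield unique indices — e.g. the (4, 2) case that produced [1, 1].
    """
    if num_draft_layers < 1 or num_target_layers < num_draft_layers:
        raise ValueError("need at least as many target layers as draft layers")
    start, end = 1, num_target_layers - 2  # interior window
    span = end - start
    if num_draft_layers > 1 and span < num_draft_layers - 1:
        # Interior window cannot yield unique ids: use the full layer span.
        start, end = 0, num_target_layers - 1
        span = end - start
    if num_draft_layers == 1:
        return [start + span // 2]
    step = span / (num_draft_layers - 1)
    ids = [round(start + i * step) for i in range(num_draft_layers)]
    assert len(set(ids)) == len(ids), "duplicate layer ids"
    return ids
```

With this guard, `build_target_layer_ids(4, 2)` yields `[1, 2]` instead of `[1, 1]`, and shallow configs like `(3, 2)` fall back to the full span rather than duplicating an index.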
In `@modelopt/torch/speculative/plugins/modeling_eagle.py`:
- Line 49: Fix the typo in the comment that reads "Their values depend on
specific tokenzier and calibrate dataset, and should be set in training script."
by changing "tokenzier" to "tokenizer" so the comment reads "Their values depend
on specific tokenizer and calibrate dataset, and should be set in training
script."; update the text in modeling_eagle.py wherever that exact comment
appears.
---
Nitpick comments:
In `@modelopt/torch/speculative/plugins/hf_medusa.py`:
- Around line 147-156: The loop in function(s) using medusa_num_heads currently
mutates the input parameter labels with labels = labels[..., 1:].contiguous()
each iteration; replace this in-place mutation by computing a fresh shifted
slice per head (e.g. shifted_labels = labels[..., 1 + i:].contiguous() or
compute start = i+1 and use labels[..., start:]) and then use shifted_labels to
form loss_labels; keep the rest of the computation (loss_logits from
medusa_logits[i], view/reshape, loss_fct, medusa_decay_coefficient**i,
medusa_heads_coefficient) unchanged so the original labels argument is not
shadowed or modified.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 67d25ab7-428e-47d9-b671-09260bc72cf7
📒 Files selected for processing (13)
- .pre-commit-config.yaml
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/scripts/ar_validate.py
- modelopt/torch/export/plugins/hf_spec_export.py
- modelopt/torch/speculative/eagle/default_config.py
- modelopt/torch/speculative/plugins/__init__.py
- modelopt/torch/speculative/plugins/hf_dflash.py
- modelopt/torch/speculative/plugins/hf_eagle.py
- modelopt/torch/speculative/plugins/hf_medusa.py
- modelopt/torch/speculative/plugins/modeling_dflash.py
- modelopt/torch/speculative/plugins/modeling_eagle.py
- modelopt/torch/speculative/utils.py
- tests/unit/torch/export/test_hf_spec_rope_export.py
💤 Files with no reviewable changes (1)
- modelopt/torch/speculative/eagle/default_config.py
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
ChenhanYu
left a comment
Review
Clean refactor — file renames and code extraction are straightforward. Three items need attention before merge:
1. parallel_draft_step removal breaks Megatron Eagle
default_config.py removes parallel_draft_step and parallel_draft_heads_num_layers, but megatron_eagle.py:526 still does:

```python
if self.config.parallel_draft_step > 1:
```

This is the direct cause of the test_unified_export_megatron CI failure (`TypeError: '>' not supported between instances of 'NoneType' and 'int'`). The HF side was cleaned up but the Megatron side was missed.

Fix: Guard with `getattr(self.config, "parallel_draft_step", 1) > 1` or clean up the Megatron Eagle references in this PR.
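The suggested guard can be exercised without Megatron. The sketch below uses hypothetical stub classes and extends the bare `getattr` form to also cover a field that exists but is `None` — which is the actual failure mode the `NoneType` comparison error points at.

```python
class EagleConfigStub:
    """Stand-in for an Eagle config after the parallel-draft fields were removed."""

class UnsetEagleConfigStub:
    """Stand-in for the CI failure mode: the field exists but is None."""
    parallel_draft_step = None

class LegacyEagleConfigStub:
    """Stand-in for an older config that still carries the field."""
    parallel_draft_step = 3

def uses_parallel_draft(config) -> bool:
    # Missing field OR a field left at None both mean "parallel draft off".
    # A bare getattr(config, "parallel_draft_step", 1) would still hit
    # None > 1 on UnsetEagleConfigStub, so the None case is folded in too.
    return (getattr(config, "parallel_draft_step", None) or 1) > 1
```

This keeps the Megatron path working whether the attribute was removed, defaulted to `None`, or set by an older config.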
2. CUDAGraph tensor overwrite in modeling_eagle.py
The test_llama_eagle3[1-False] failure:

```
RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run.
```

at modeling_eagle.py:166. This suggests the EagleModule extraction changed tensor lifecycle — possibly `_input_embeds` stashing retains a tensor from a previous CUDAGraph capture. Worth diffing the old `transformers.py` `EagleModule.forward()` against the new `modeling_eagle.py` version to verify no accidental changes during extraction.
3. No backward-compat shim for transformers.py → hf_eagle.py
The old import path `from modelopt.torch.speculative.plugins.transformers import ...` will break outright. This was specifically called out in the original #1271 review. A three-line `transformers.py` stub with a DeprecationWarning plus re-export would prevent downstream breakage.
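The stub the review asks for could look like the sketch below. The module path comes from the PR itself; the exact set of names to re-export is not stated, so a wildcard re-export is shown (commented out here so the sketch runs standalone).

```python
# Hypothetical contents of modelopt/torch/speculative/plugins/transformers.py,
# kept only as a backward-compat shim after the rename to hf_eagle.py.
import warnings

warnings.warn(
    "modelopt.torch.speculative.plugins.transformers is deprecated; "
    "import from modelopt.torch.speculative.plugins.hf_eagle instead.",
    DeprecationWarning,
    stacklevel=2,
)

# Re-export the public names so old imports keep resolving:
# from modelopt.torch.speculative.plugins.hf_eagle import *  # noqa: F401,F403
```

Importing the old path then works for one deprecation cycle while pointing users at the new module.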
What does this PR do?
Type of change: refactoring
Part 1 of a 3-PR series splitting #1271:
- deprecate `ParallelDraft` (this PR)
- `HFSpecDecMixin` (follow-up)

Changes:
- Rename `transformers.py` → `hf_eagle.py`; extract `HFMedusaModel` → `hf_medusa.py`; extract `EagleModule`/`EagleBaseModelOutput` → `modeling_eagle.py`; extract `DFlashModule`/`DFlashAttention`/`DFlashDecoderLayer`/`build_target_layer_ids`/`apply_rotary_pos_emb` → `modeling_dflash.py`.
- `ParallelDraft`: remove `parallel_draft_step`, `parallel_draft_heads_num_layers`, and the `ParallelDraft` module from HF Eagle; remove the `EagleMedusaExporter` branch from `HFEagleModel.get_exporter()` (the `EagleMedusaExporter` class itself still lives in `hf_spec_export.py` for Megatron parity).
- Rename `_draft_model_config` → `eagle_config` in export plugin.
- Update `examples/speculative_decoding/` and `modelopt/torch/speculative/utils.py` to follow the module rename.

Testing
Validated with existing Eagle and DFlash training scripts (re-run after the 9ae5302729 revert behavior change).

Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- Deprecation/rename notes: `modelopt.torch.speculative.plugins.transformers` → `.hf_eagle`; removes `parallel_draft_step`/`parallel_draft_heads_num_layers` from Eagle config; renames `_draft_model_config` → `eagle_config` in export plugin.
- CONTRIBUTING.md: N/A
- `test_hf_spec_rope_export.py` assertions were also corrected to reflect the actual production path (the old assertions were masked by `MagicMock` not invoking the `_draft_model_config` `@property`).

Additional Information
Breaking changes:
- `modelopt.torch.speculative.plugins.transformers` → `.hf_eagle`
- `parallel_draft_step`/`parallel_draft_heads_num_layers` removed from Eagle config
- `_draft_model_config` → `eagle_config` in export plugin

Summary by CodeRabbit
Refactoring
New Features
Chores
Tests