
inplement mix hidden_states for eagle3; deprecate eagle1#946

Merged
yeyu-nvidia merged 9 commits into main from yeyu/mix_hidden_states
Mar 9, 2026

Conversation

@yeyu-nvidia
Contributor

@yeyu-nvidia yeyu-nvidia commented Feb 27, 2026

What does this PR do?

new feature

Overview:
Enable mixing of hidden_states in eagle3 training. Deprecate eagle1.

Usage

Add --mix_hidden_states True to launch_train.sh

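A minimal sketch of the flag wiring described above. The flag name comes from this PR; the variable plumbing below is illustrative, not the actual contents of launch_train.sh:

```shell
# Sketch only: shows how the new flag would be appended to the training
# command inside the launch script. Other arguments are omitted.
MIX_HIDDEN_STATES="True"
CMD="python main.py --mix_hidden_states ${MIX_HIDDEN_STATES}"
echo "$CMD"
```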

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Added --mix_hidden_states option to enable optional hidden-state mixing during training.
    • Added eagle_ttt_steps setting to control speculative multi-step iterations.
  • Chores

    • Consolidated speculative decoding to EAGLE3 only; legacy Medusa/EAGLE1 paths removed.
    • Unified configuration handling so models and plugins accept a single config object.
  • Tests

    • Updated and expanded tests for hidden-state mixing and EAGLE3-only scenarios.

@copy-pr-bot

copy-pr-bot Bot commented Feb 27, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai Bot commented Feb 27, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds eagle_ttt_steps and eagle_mix_hidden_states to Eagle config; refactors many modify() APIs to accept a single config; implements conditional hidden-state mixing, mask/cache adjustments, and gated aux-FC usage; removes Medusa/Eagle1 code paths; and exposes a new --mix_hidden_states CLI flag propagated through examples and tests.
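The per-step replacement schedule quoted later in this review (`num_to_replace = max(1, seq_len // (2**ttt_step + 1))`) can be sketched standalone. This NumPy sketch mirrors the names in the diff but is illustrative, not the plugin's actual code:

```python
import numpy as np

def mix_hidden_states(input_h, output_h, ttt_step, rng):
    """Replace a shrinking random subset of positions in input_h with output_h.

    Shapes are (batch, seq_len, hidden); the replacement count shrinks with
    each TTT step, following the schedule shown in the reviewed diff.
    """
    batch, seq_len, _ = output_h.shape
    num_to_replace = max(1, seq_len // (2 ** ttt_step + 1))
    mixed = input_h.copy()  # avoid in-place mutation of the caller's tensor
    for b in range(batch):
        idx = rng.permutation(seq_len)[:num_to_replace]
        mixed[b, idx, :] = output_h[b, idx, :]
    return mixed, num_to_replace

rng = np.random.default_rng(0)
inp = np.zeros((2, 8, 4))
out = np.ones((2, 8, 4))
mixed, n = mix_hidden_states(inp, out, ttt_step=1, rng=rng)
# seq_len=8, ttt_step=1 -> 8 // (2**1 + 1) = 2 positions replaced per batch
print(n, int(mixed.sum()))
```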

Changes

Cohort / File(s) Summary
Config & Core Model
modelopt/torch/speculative/config.py, modelopt/torch/speculative/eagle/eagle_model.py
Add eagle_ttt_steps: int and eagle_mix_hidden_states: bool to EagleConfig; remove EAGLE1_DEFAULT_CFG; change EagleModel.modify() to accept a single config and read new fields.
Conversion Layer
modelopt/torch/speculative/eagle/conversion.py
Pass a single config object into EagleModel.modify(config) instead of expanding individual eagle_* keyword arguments.
Megatron Plugin
modelopt/torch/speculative/plugins/megatron_eagle.py
Gate fc creation on config.use_aux_hidden_state; change modify() to accept config; update forward/TTT loops to use eagle_ttt_steps and eagle_mix_hidden_states; adjust attention mask, kv-cache/inference_context, and per-step hidden-state mixing logic.
Transformers Plugin
modelopt/torch/speculative/plugins/transformers.py
Remove Eagle-1/Medusa branches, standardize EAGLE-3 flow; change plugin modify() signatures to accept config; wire eagle_ttt_steps/eagle_mix_hidden_states into generation/TTT loops and adjust input-embedding/aux-state handling.
Examples & Launch
examples/speculative_decoding/main.py, examples/speculative_decoding/launch_train.sh
Add --mix_hidden_states CLI argument and propagate its value through the launch script into the training invocation; add mix_hidden_states field to example args.
Tests — examples, gpu, unit
tests/examples/speculative_decoding/test_eagle.py, tests/gpu_megatron/.../test_speculative_megatron_modules.py, tests/unit/torch/speculative/plugins/test_hf_speculative.py
Expand example test parameterization to include mix_hidden_states; remove Medusa/Eagle1 test branches and imports; simplify tests to target Eagle3 only and adapt signatures/params accordingly.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant CLI as "examples/speculative_decoding/main.py"
    participant Converter as "modelopt/.../eagle/conversion.py"
    participant EagleModel as "EagleModel.modify(config)"
    participant Plugin as "megatron/transformers plugin"
    participant Trainer as "Training loop / GPU"

    User->>CLI: launch with --mix_hidden_states
    CLI->>Converter: build conversion config (includes eagle_mix_hidden_states, eagle_ttt_steps)
    Converter->>EagleModel: modify(config)
    EagleModel->>Plugin: expose eagle_config, eagle_ttt_steps, eagle_mix_hidden_states
    Plugin->>Trainer: run TTT loop (uses eagle_ttt_steps)
    alt eagle_mix_hidden_states == true
        Trainer->>Plugin: apply hidden-state mixing (random replacement subset)
        Plugin->>Trainer: propagate mixed hidden states (no kv-cache inference_context)
    else eagle_mix_hidden_states == false
        Trainer->>Plugin: propagate generated hidden states normally (kv-cache preserved)
    end
    Trainer->>Plugin: return outputs / updated cache

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 65.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title contains a typo ('inplement' instead of 'implement') and uses abbreviations that may be unclear to developers unfamiliar with the codebase. However, it does accurately describe the main changes: adding mix_hidden_states support to Eagle3 and deprecating Eagle1.
Security Anti-Patterns ✅ Passed PR does not introduce security anti-patterns from SECURITY.md. Changes limited to configuration parameters, method refactoring, and removing deprecated code.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch yeyu/mix_hidden_states

Comment @coderabbitai help to get the list of available commands and usage tips.

@codecov

codecov Bot commented Feb 27, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 72.01%. Comparing base (e8f9687) to head (e2b7e0e).
⚠️ Report is 13 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #946      +/-   ##
==========================================
- Coverage   72.12%   72.01%   -0.11%     
==========================================
  Files         209      209              
  Lines       23628    23631       +3     
==========================================
- Hits        17042    17019      -23     
- Misses       6586     6612      +26     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@yeyu-nvidia yeyu-nvidia marked this pull request as ready for review February 27, 2026 22:11
@yeyu-nvidia yeyu-nvidia requested a review from a team as a code owner February 27, 2026 22:11
@yeyu-nvidia yeyu-nvidia requested a review from h-guo18 February 27, 2026 22:11
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
examples/speculative_decoding/main.py (1)

228-239: ⚠️ Potential issue | 🔴 Critical

data_module can be undefined for mode="medusa" and crash trainer construction.

Line 228 only initializes data_module for "eagle3", but Line 238 always expands **data_module. With mode="medusa" this raises UnboundLocalError.

💡 Proposed fix (fail fast if only eagle3 is supported now)
-    if training_args.mode == "eagle3":
-        data_module = make_eagle_supervised_data_module(
-            tokenizer, data_args, train_len=training_args.training_seq_len
-        )
+    if training_args.mode != "eagle3":
+        raise ValueError(f"mode={training_args.mode} is not supported in this training entrypoint.")
+    data_module = make_eagle_supervised_data_module(
+        tokenizer, data_args, train_len=training_args.training_seq_len
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 228 - 239, The code only
sets data_module when training_args.mode == "eagle3" but always expands
**data_module into EagleTrainerWithAccLog, causing UnboundLocalError for other
modes (e.g., "medusa"); fix by ensuring data_module is always defined before
trainer creation: either call make_eagle_supervised_data_module when mode ==
"eagle3" and otherwise set data_module to a safe default (e.g., an empty dict)
or explicitly raise a clear error (ValueError) if only "eagle3" is supported;
perform this check/assignment just before constructing EagleTrainerWithAccLog so
the trainer call can safely use **data_module.
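The fail-fast variant from the proposed fix, reduced to a standalone sketch; `make_eagle_supervised_data_module` is stubbed here, not the example script's real factory:

```python
def make_eagle_supervised_data_module():
    # Stub for the real data-module factory in the example script
    return {"train_dataset": object(), "eval_dataset": object()}

def build_data_module(mode: str) -> dict:
    # Fail fast instead of leaving data_module unbound for unsupported modes,
    # which would otherwise raise UnboundLocalError at **data_module expansion
    if mode != "eagle3":
        raise ValueError(f"mode={mode} is not supported in this training entrypoint.")
    return make_eagle_supervised_data_module()

print(sorted(build_data_module("eagle3")))
```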
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

1208-1233: ⚠️ Potential issue | 🔴 Critical

Critical: In-place mixed-state writes on sharded tensors.

When sequence_parallel is enabled and eagle_mix_hidden_states is True, eagle_module_input_hidden_states remains sharded in the sequence dimension (shape [s/TP, b, h]) when the mixing block executes at lines 1208–1227. However, the mixing logic computes rand_indices based on eagle_module_output_hidden_states.shape[0] (which is the full gathered sequence length) and performs in-place indexed updates on the sharded tensor. This causes index misalignment—indices valid for the full sequence are applied to a shard, resulting in out-of-bounds access or incorrect position updates. Additionally, in-place mutation breaks distributed tensor semantics and can corrupt autograd state across TTT steps.

The scatter at line 1231 occurs after the corrupt updates, amplifying the risk.

💡 Proposed fix
             if self.eagle_mix_hidden_states:
+                if self.config.sequence_parallel:
+                    eagle_module_input_hidden_states = gather_from_sequence_parallel_region(
+                        eagle_module_input_hidden_states
+                    )
+                updated_hidden_states = eagle_module_input_hidden_states.clone()
                 seq_len_s, batch_size, _ = eagle_module_output_hidden_states.shape
                 num_to_replace = max(1, seq_len_s // (2**ttt_step + 1))

                 # Randomly select positions for each batch to replace
                 rand_indices = torch.stack(
                     [
                         torch.randperm(seq_len_s, device=eagle_module_output_hidden_states.device)[
                             :num_to_replace
                         ]
                         for _ in range(batch_size)
                     ],
                     dim=0,
                 )

-                for batch_idx in range(batch_size):
-                    eagle_module_input_hidden_states[rand_indices[batch_idx], batch_idx, :] = (
-                        eagle_module_output_hidden_states[rand_indices[batch_idx], batch_idx, :]
-                    )
+                batch_indices = torch.arange(
+                    batch_size, device=updated_hidden_states.device
+                ).unsqueeze(1)
+                updated_hidden_states[rand_indices, batch_indices, :] = (
+                    eagle_module_output_hidden_states[rand_indices, batch_indices, :]
+                )
+                eagle_module_input_hidden_states = updated_hidden_states
             else:
                 eagle_module_input_hidden_states = eagle_module_output_hidden_states
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/megatron_eagle.py` around lines 1208 -
1233, When self.eagle_mix_hidden_states is True and
self.config.sequence_parallel is enabled, you must avoid in-place indexed writes
on a sharded sequence tensor; instead gather the full sequence, perform the
random-index mixing on the gathered eagle_module_input_hidden_states and
eagle_module_output_hidden_states (using gather_from_sequence_parallel_region),
then scatter the mixed result back with scatter_to_sequence_parallel_region;
reference and update the logic around eagle_module_input_hidden_states,
eagle_module_output_hidden_states, eagle_mix_hidden_states,
gather_from_sequence_parallel_region and scatter_to_sequence_parallel_region so
indices are computed against the full sequence and mutations occur on the
gathered tensor before re-sharding.
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/transformers.py (1)

938-956: Vectorize hidden-state replacement in the TTT loop to reduce Python overhead.

Line 953–956 does per-batch indexed assignment in Python for every TTT step; this is on the training hot path and can be batched with advanced indexing.

💡 Proposed refactor
-                for batch_idx in range(batch_size):
-                    eagle_input_hiddens[batch_idx, rand_indices[batch_idx], :] = (
-                        eagle_output_hiddens[batch_idx, rand_indices[batch_idx], :]
-                    )
+                batch_idx = torch.arange(batch_size, device=eagle_input_hiddens.device).unsqueeze(1)
+                eagle_input_hiddens[batch_idx, rand_indices, :] = eagle_output_hiddens[
+                    batch_idx, rand_indices, :
+                ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/transformers.py` around lines 938 - 956,
The per-batch Python loop that replaces hidden states should be replaced with a
single advanced-indexed assignment: keep computing rand_indices (ensure it's a
LongTensor) of shape (batch_size, num_to_replace) and create batch_idx =
torch.arange(batch_size, device=eagle_input_hiddens.device).unsqueeze(1); then
perform eagle_input_hiddens[batch_idx, rand_indices, :] =
eagle_output_hiddens[batch_idx, rand_indices, :] to vectorize the replacement in
the TTT loop when eagle_mix_hidden_states is true (references:
eagle_mix_hidden_states, eagle_input_hiddens, eagle_output_hiddens,
rand_indices, ttt_step).
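The vectorization suggested above relies on advanced-indexing broadcast rules (identical in NumPy and PyTorch); a small NumPy check with illustrative shapes, not the plugin's actual tensors:

```python
import numpy as np

batch_size, seq_len, hidden = 2, 6, 3
inp = np.zeros((batch_size, seq_len, hidden))
out = np.arange(batch_size * seq_len * hidden, dtype=float).reshape(
    batch_size, seq_len, hidden
)

# (batch_size, num_to_replace) index matrix, as in the review comment
rand_indices = np.array([[0, 3], [1, 4]])
batch_idx = np.arange(batch_size)[:, None]  # shape (batch_size, 1), broadcasts

# One assignment replaces all selected rows for every batch element,
# replacing the per-batch Python loop
inp[batch_idx, rand_indices, :] = out[batch_idx, rand_indices, :]
```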
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/speculative/config.py`:
- Around line 103-105: eagle_ttt_steps currently allows non-positive integers
which can skip the TTT loop; update its ModeloptField declaration to enforce a
lower bound of 1 (e.g., set a validation constraint such as ge=1 or min=1
depending on ModeloptField API) so only positive integers are accepted; if
ModeloptField doesn't support a direct constraint, add a validation step for the
eagle_ttt_steps attribute (a pydantic/field validator in the enclosing config
class) that raises a clear error when eagle_ttt_steps < 1.

In `@tests/examples/speculative_decoding/test_eagle.py`:
- Around line 64-69: The parameterized test test_llama_eagle3 allows runs with
the same cp_size but different mix_hidden_states to write to the same output
directory (eagle_output_dir/output_dir), causing checkpoint collisions; update
the test to create a unique output_dir per parameter combination (use cp_size
and mix_hidden_states in the directory name or use tmp_path.joinpath/f-string
with cp_size and mix_hidden_states) wherever output_dir is constructed/used so
each param case writes to an isolated directory and cannot overwrite another
case’s checkpoints.

In
`@tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py`:
- Around line 101-109: The pytest parametrize names ("algo", "num_layers",
"activation_func", "normalization") do not match the test function signature:
rename the function parameter num_medusa_heads_or_eagle_layers to num_layers in
test_speculative_gpt_model and update all uses (including the partial call that
currently references num_medusa_heads_or_eagle_layers) to use num_layers so
pytest parametrize can bind correctly.

---

Outside diff comments:
In `@examples/speculative_decoding/main.py`:
- Around line 228-239: The code only sets data_module when training_args.mode ==
"eagle3" but always expands **data_module into EagleTrainerWithAccLog, causing
UnboundLocalError for other modes (e.g., "medusa"); fix by ensuring data_module
is always defined before trainer creation: either call
make_eagle_supervised_data_module when mode == "eagle3" and otherwise set
data_module to a safe default (e.g., an empty dict) or explicitly raise a clear
error (ValueError) if only "eagle3" is supported; perform this check/assignment
just before constructing EagleTrainerWithAccLog so the trainer call can safely
use **data_module.

In `@modelopt/torch/speculative/plugins/megatron_eagle.py`:
- Around line 1208-1233: When self.eagle_mix_hidden_states is True and
self.config.sequence_parallel is enabled, you must avoid in-place indexed writes
on a sharded sequence tensor; instead gather the full sequence, perform the
random-index mixing on the gathered eagle_module_input_hidden_states and
eagle_module_output_hidden_states (using gather_from_sequence_parallel_region),
then scatter the mixed result back with scatter_to_sequence_parallel_region;
reference and update the logic around eagle_module_input_hidden_states,
eagle_module_output_hidden_states, eagle_mix_hidden_states,
gather_from_sequence_parallel_region and scatter_to_sequence_parallel_region so
indices are computed against the full sequence and mutations occur on the
gathered tensor before re-sharding.

---

Nitpick comments:
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 938-956: The per-batch Python loop that replaces hidden states
should be replaced with a single advanced-indexed assignment: keep computing
rand_indices (ensure it's a LongTensor) of shape (batch_size, num_to_replace)
and create batch_idx = torch.arange(batch_size,
device=eagle_input_hiddens.device).unsqueeze(1); then perform
eagle_input_hiddens[batch_idx, rand_indices, :] =
eagle_output_hiddens[batch_idx, rand_indices, :] to vectorize the replacement in
the TTT loop when eagle_mix_hidden_states is true (references:
eagle_mix_hidden_states, eagle_input_hiddens, eagle_output_hiddens,
rand_indices, ttt_step).

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 35e6099 and 802689ffb0b9384b3dc071b268a66f4c497d02d9.

📒 Files selected for processing (9)
  • examples/speculative_decoding/main.py
  • modelopt/torch/speculative/config.py
  • modelopt/torch/speculative/eagle/conversion.py
  • modelopt/torch/speculative/eagle/eagle_model.py
  • modelopt/torch/speculative/plugins/megatron_eagle.py
  • modelopt/torch/speculative/plugins/transformers.py
  • tests/examples/speculative_decoding/test_eagle.py
  • tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py
  • tests/unit/torch/speculative/plugins/test_hf_speculative.py

Comment on lines +103 to +105
eagle_ttt_steps: int = ModeloptField(
default=4, description=("The number of train-time-test steps in training.")
)
Contributor


⚠️ Potential issue | 🟠 Major

Add a lower-bound validation for eagle_ttt_steps.

At Line 103, eagle_ttt_steps accepts non-positive values. That can bypass the TTT loop and trigger downstream training-time loss assertions. Constrain it at config level.

💡 Proposed fix
     eagle_ttt_steps: int = ModeloptField(
-        default=4, description=("The number of train-time-test steps in training.")
+        default=4,
+        ge=1,
+        description=("The number of train-time-test steps in training."),
     )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
eagle_ttt_steps: int = ModeloptField(
default=4, description=("The number of train-time-test steps in training.")
)
eagle_ttt_steps: int = ModeloptField(
default=4,
ge=1,
description=("The number of train-time-test steps in training."),
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/config.py` around lines 103 - 105, eagle_ttt_steps
currently allows non-positive integers which can skip the TTT loop; update its
ModeloptField declaration to enforce a lower bound of 1 (e.g., set a validation
constraint such as ge=1 or min=1 depending on ModeloptField API) so only
positive integers are accepted; if ModeloptField doesn't support a direct
constraint, add a validation step for the eagle_ttt_steps attribute (a
pydantic/field validator in the enclosing config class) that raises a clear
error when eagle_ttt_steps < 1.
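If ModeloptField does not expose a `ge=` constraint directly, the same guard can be expressed as a plain validator. This sketch uses a dataclass stand-in, not the real config class:

```python
from dataclasses import dataclass

@dataclass
class EagleConfigSketch:
    eagle_ttt_steps: int = 4

    def __post_init__(self):
        # Mirrors the proposed ge=1 constraint: non-positive values would
        # silently skip the TTT loop, so fail loudly at config time.
        if self.eagle_ttt_steps < 1:
            raise ValueError(
                f"eagle_ttt_steps must be >= 1, got {self.eagle_ttt_steps}"
            )

print(EagleConfigSketch().eagle_ttt_steps)
```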

Comment on lines +64 to +69
@pytest.mark.parametrize(
    ("cp_size", "mix_hidden_states"),
    [(1, "false"), (2, "false"), (1, "true"), (2, "true")],
)
def test_llama_eagle3(
    tiny_llama_path,
    tiny_daring_anteater_path,
    tmp_path,
    eagle_output_dir,
    cp_size,
    mix_hidden_states,
):
Contributor


⚠️ Potential issue | 🟠 Major

Parameterized cases can overwrite each other’s checkpoints.

After adding mix_hidden_states (Line 64), both variants for the same cp_size still write to the same output_dir (Line 100). That risks resume/contamination between cases.

💡 Proposed fix
-    run_example_command(
+    mix_tag = "mix" if mix_hidden_states == "true" else "nomix"
+    run_example_command(
         [
             "./launch_train.sh",
@@
-            "--output_dir", eagle_output_dir / f"eagle-tinyllama-cp{cp_size}",
+            "--output_dir", eagle_output_dir / f"eagle-tinyllama-cp{cp_size}-{mix_tag}",
             "--training_seq_len", "128", # Match max_position_embeddings
             "--cp_size", str(cp_size),
             "--mix_hidden_states", mix_hidden_states,
         ],

Also applies to: 100-104

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/examples/speculative_decoding/test_eagle.py` around lines 64 - 69, The
parameterized test test_llama_eagle3 allows runs with the same cp_size but
different mix_hidden_states to write to the same output directory
(eagle_output_dir/output_dir), causing checkpoint collisions; update the test to
create a unique output_dir per parameter combination (use cp_size and
mix_hidden_states in the directory name or use tmp_path.joinpath/f-string with
cp_size and mix_hidden_states) wherever output_dir is constructed/used so each
param case writes to an isolated directory and cannot overwrite another case’s
checkpoints.
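One way to isolate the checkpoints per parameter case, as the comment suggests. The directory naming scheme here is hypothetical:

```python
from pathlib import Path

def output_dir_for(base: Path, cp_size: int, mix_hidden_states: str) -> Path:
    # Encode both parameters in the directory name so parameterized test
    # cases can never clobber each other's checkpoints
    mix_tag = "mix" if mix_hidden_states == "true" else "nomix"
    return base / f"eagle-tinyllama-cp{cp_size}-{mix_tag}"

dirs = {
    str(output_dir_for(Path("out"), cp, mix))
    for cp in (1, 2)
    for mix in ("true", "false")
}
print(len(dirs))  # all four parameter combinations map to distinct paths
```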

Comment thread tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py Outdated
Comment thread modelopt/torch/speculative/eagle/eagle_model.py Outdated
Comment thread modelopt/torch/speculative/plugins/transformers.py
@@ -58,6 +58,8 @@ def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertRetu
eagle_loss_decay_factor=config.eagle_loss_decay_factor,
eagle_architecture_config=config.eagle_architecture_config,
eagle_decoder_type=config.eagle_decoder_type,
eagle_ttt_steps=config.eagle_ttt_steps,
eagle_mix_hidden_states=config.eagle_mix_hidden_states,
Contributor Author

@yeyu-nvidia yeyu-nvidia Feb 27, 2026


No, mix hidden states still uses multiple training steps, like TTT. It just no longer concatenates hidden_states; they are replaced with mixing.

@yeyu-nvidia yeyu-nvidia requested review from a team and ChenhanYu and removed request for a team February 27, 2026 22:50
@yeyu-nvidia yeyu-nvidia force-pushed the yeyu/mix_hidden_states branch 2 times, most recently from 4d77152 to 0c19663 Compare March 2, 2026 19:07
Comment thread modelopt/torch/speculative/plugins/transformers.py Outdated
Comment thread modelopt/torch/speculative/plugins/transformers.py Outdated
Comment thread modelopt/torch/speculative/plugins/transformers.py Outdated
@yeyu-nvidia yeyu-nvidia force-pushed the yeyu/mix_hidden_states branch from 0c19663 to ca56d25 Compare March 3, 2026 20:57
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
examples/speculative_decoding/main.py (2)

222-222: ⚠️ Potential issue | 🔴 Critical

Use safe torch.load behavior with explicit weights_only=True parameter.

Line 222 loads a draft vocab cache without the weights_only=True parameter. Even though the cache is internally generated (per the calibration script), the security coding guidelines require explicit safe loading semantics.

Proposed fix
-                model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache)
+                model.eagle_module.d2t = torch.load(
+                    data_args.draft_vocab_cache,
+                    weights_only=True,  # cache tensor is generated internally by calibration script
+                )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` at line 222, The call that assigns
model.eagle_module.d2t from torch.load(data_args.draft_vocab_cache) must use
safe loading semantics; update the call to pass weights_only=True to torch.load
(e.g., torch.load(data_args.draft_vocab_cache, weights_only=True)) so the draft
vocab cache is loaded with explicit safe behavior; locate the assignment to
model.eagle_module.d2t and replace the plain torch.load call accordingly
(optionally preserve any map_location usage if present).

171-175: ⚠️ Potential issue | 🔴 Critical

Remove hardcoded trust_remote_code=True and make it user-configurable (default False).

Hardcoding remote code execution violates security policy and weakens the threat model. The repository requires trust_remote_code to be exposed as a caller-configurable parameter defaulting to False (per SECURITY.md).

Add a trust_remote_code field to ModelArguments and propagate it to all four model/tokenizer loading calls at lines 173, 175, 184, and 194:

Proposed fix
 `@dataclass`
 class ModelArguments:
     model_name_or_path: str | None = field(default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={"help": "Allow loading custom remote model/tokenizer code."},
+    )
@@
             _, model = load_vlm_or_llm_with_kwargs(
-                checkpoint, torch_dtype="auto", trust_remote_code=True
+                checkpoint,
+                torch_dtype="auto",
+                trust_remote_code=model_args.trust_remote_code,
             )
-        tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
+        tokenizer = transformers.AutoTokenizer.from_pretrained(
+            checkpoint, trust_remote_code=model_args.trust_remote_code
+        )
@@
         model_config, model = load_vlm_or_llm_with_kwargs(
             model_args.model_name_or_path,
             torch_dtype="auto",
             device_map="cpu",
-            trust_remote_code=True,
+            trust_remote_code=model_args.trust_remote_code,
             **offline_kwargs,
         )
@@
         tokenizer = transformers.AutoTokenizer.from_pretrained(
             model_args.model_name_or_path,
             model_max_length=training_args.training_seq_len,
-            trust_remote_code=True,
+            trust_remote_code=model_args.trust_remote_code,
         )

Also, line 222 uses torch.load() without weights_only=True. Unsafe deserialization of untrusted checkpoints can lead to arbitrary code execution. Add weights_only=True or document why it is safe if the file source is internally trusted.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 171 - 175, Add a new
boolean field trust_remote_code (default False) to ModelArguments and thread it
through the model/tokenizer loading calls instead of hardcoding
trust_remote_code=True: pass model_args.trust_remote_code into the calls that
use patch_transformers5_params_loading/load_vlm_or_llm_with_kwargs and
transformers.AutoTokenizer.from_pretrained (the four locations currently forcing
True), and update any helper wrappers that call these functions to accept and
forward the new flag; additionally, change the unsafe torch.load(...) call to
use torch.load(..., weights_only=True) (or document/validate the source) to
avoid arbitrary-code deserialization.
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

1167-1211: ⚠️ Potential issue | 🔴 Critical

Sequence-parallel and in-place mixing will cause incorrect indexing or crashes when both are enabled together.

When sequence_parallel=True and eagle_mix_hidden_states=True in subsequent loop iterations, eagle_module_input_hidden_states is sharded (from the scatter at line 1207-1209 in the previous iteration) while eagle_module_output_hidden_states is gathered. Computing seq_len_s and rand_indices from the gathered output then applying them to the sharded input creates a shape mismatch—indices will exceed the sharded tensor's first dimension.

Additionally, the in-place mutation at lines 1201-1204 can invalidate autograd's tracking if the tensor flows through backward computation.

Fix: Gather input before mixing and clone to avoid in-place mutation
            if self.config.sequence_parallel:
                eagle_module_output_hidden_states = gather_from_sequence_parallel_region(
                    eagle_module_output_hidden_states
                )
+                if self.eagle_mix_hidden_states:
+                    eagle_module_input_hidden_states = gather_from_sequence_parallel_region(
+                        eagle_module_input_hidden_states
+                    )
+
             eagle_module_output_hidden_states = torch.cat(
                 (
                     torch.zeros(
@@ -1187,7 +1195,8 @@
             )

             if self.eagle_mix_hidden_states:
+                mixed_hidden_states = eagle_module_input_hidden_states.clone()
                 seq_len_s, batch_size, _ = eagle_module_output_hidden_states.shape
                 num_to_replace = max(1, seq_len_s // (2**ttt_step + 1))

@@ -1199,10 +1208,11 @@
                 )

                 for batch_idx in range(batch_size):
-                    eagle_module_input_hidden_states[rand_indices[batch_idx], batch_idx, :] = (
+                    mixed_hidden_states[rand_indices[batch_idx], batch_idx, :] = (
                         eagle_module_output_hidden_states[rand_indices[batch_idx], batch_idx, :]
                     )
+                eagle_module_input_hidden_states = mixed_hidden_states
             else:
                 eagle_module_input_hidden_states = eagle_module_output_hidden_states
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/megatron_eagle.py` around lines 1167 -
1211, When sequence_parallel and eagle_mix_hidden_states are both enabled, the
code must gather the (possibly sharded) eagle_module_input_hidden_states to the
full tensor before computing seq_len_s/rand_indices and performing replacements,
and must avoid in-place mutation to preserve autograd; specifically, call
gather_from_sequence_parallel_region(eagle_module_input_hidden_states) (matching
how eagle_module_output_hidden_states is gathered), clone the gathered tensor
(or create a new tensor and use index_copy_ or advanced indexing to write
replacements) when applying rand_indices, then if self.config.sequence_parallel
scatter the modified tensor back with scatter_to_sequence_parallel_region;
update references to eagle_module_input_hidden_states,
eagle_module_output_hidden_states, eagle_mix_hidden_states, seq_len_s,
rand_indices, and the gather/scatter helpers accordingly.
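The gather-before-mix fix above can be sketched in a single-process setting; the sequence-parallel shards are simulated with `chunk`, and `torch.cat` stands in for `gather_from_sequence_parallel_region` (shapes are illustrative):

```python
import torch

tp_world_size = 2
seq_len, batch, hidden = 8, 1, 4

# simulated sequence-parallel state: input is sharded, output is already gathered
input_shards = list(torch.randn(seq_len, batch, hidden).chunk(tp_world_size, dim=0))
output_full = torch.randn(seq_len, batch, hidden)

# gather the input back to the full sequence length before computing indices
input_full = torch.cat(input_shards, dim=0)  # stand-in for the gather helper

num_to_replace = max(1, seq_len // 3)
rand_indices = torch.randperm(seq_len)[:num_to_replace]  # valid for the full length

mixed = input_full.clone()  # avoid in-place mutation of a tensor in the graph
mixed[rand_indices, 0, :] = output_full[rand_indices, 0, :]
```

In the real module the mixed tensor would then be scattered back with `scatter_to_sequence_parallel_region` when sequence parallelism is enabled.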
♻️ Duplicate comments (1)
tests/examples/speculative_decoding/test_eagle.py (1)

64-104: ⚠️ Potential issue | 🟠 Major

Isolate checkpoints per (cp_size, mix_hidden_states) case.

The new matrix still writes both mix variants to the same output_dir for a given cp_size, so one case can overwrite another and contaminate downstream resume/export checks.

💡 Proposed fix
+    mix_tag = "mix" if mix_hidden_states == "true" else "nomix"
     run_example_command(
         [
             "./launch_train.sh",
@@
-            "--output_dir", eagle_output_dir / f"eagle-tinyllama-cp{cp_size}",
+            "--output_dir", eagle_output_dir / f"eagle-tinyllama-cp{cp_size}-{mix_tag}",
             "--training_seq_len", "128", # Match max_position_embeddings
             "--cp_size", str(cp_size),
             "--mix_hidden_states", mix_hidden_states,
         ],
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/examples/speculative_decoding/test_eagle.py` around lines 64 - 104, The
test test_llama_eagle3 currently writes both mix_hidden_states variants to the
same eagle_output_dir for a given cp_size, causing overwrites; update the
run_example_command call so the "--output_dir" path includes mix_hidden_states
(or both cp_size and mix_hidden_states) to create a unique directory per
(cp_size, mix_hidden_states) case (refer to test_llama_eagle3,
run_example_command, eagle_output_dir, cp_size, mix_hidden_states) so each
matrix cell writes to its own checkpoint folder and prevents contamination.
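The directory-isolation idea above can be sketched with a small helper; the names mirror the test's variables but are illustrative:

```python
from pathlib import Path

def eagle_output_dir_for(base: Path, cp_size: int, mix_hidden_states: str) -> Path:
    # one checkpoint directory per (cp_size, mix_hidden_states) matrix cell
    mix_tag = "mix" if mix_hidden_states == "true" else "nomix"
    return base / f"eagle-tinyllama-cp{cp_size}-{mix_tag}"

a = eagle_output_dir_for(Path("out"), 1, "true")
b = eagle_output_dir_for(Path("out"), 1, "false")
assert a != b  # the two mix variants no longer collide
```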
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/speculative_decoding/main.py`:
- Around line 228-231: The code only initializes data_module inside the if
training_args.mode == "eagle3" branch (via make_eagle_supervised_data_module),
leaving data_module undefined for other modes and causing a crash at Trainer
construction; initialize data_module before the branch (e.g., data_module =
None) and either (A) add an else branch that constructs the appropriate data
module for other modes (e.g., a make_medusa_supervised_data_module call) or (B)
ensure the Trainer is only passed data_module when it is not None (guard the
trainer construction or pass a fallback) so training_args.mode and
make_eagle_supervised_data_module/data_module usage are consistent and never
leave data_module undefined.

In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 544-550: The code currently leaves decoder_cls unset when
self.eagle_decoder_type is not "llama" or "kimik2"; add an explicit fail-fast
check after those branches that raises a clear ValueError (or TypeError)
including the offending self.eagle_decoder_type value and supported options so
the error surfaces immediately; reference the eagle_decoder_type variable and
decoder_cls assignment flow (LlamaDecoderLayer and _setup_kimi_k2_decoder) and
raise before using decoder_cls later.
- Around line 916-934: The in-place assignments to eagle_input_hiddens when
eagle_mix_hidden_states is true (inside the loop using ttt_step and
rand_indices) mutate a tensor that’s part of the autograd graph; to fix, avoid
in-place updates by making a detached/independent copy before modifying (e.g.,
replace usage of eagle_input_hiddens with a cloned tensor like
eagle_input_hiddens = eagle_input_hiddens.clone() or a shallow copy only when
you will assign into it) then perform the per-batch indexed replacements from
eagle_output_hiddens into that cloned tensor; ensure subsequent forward calls
reference the cloned tensor (and not the original) so autograd is not broken
while preserving the same semantics for eagle_mix_hidden_states,
eagle_output_hiddens, rand_indices, and ttt_step.

In
`@tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py`:
- Around line 49-53: ALGO_TO_CONFIG[algo] is being mutated in-place for the
"eagle3" branch; instead clone the config before changing nested fields so tests
don't leak state—create a deep copy of ALGO_TO_CONFIG[algo] into mtsp_config
(e.g., using copy.deepcopy or similar) and then set
mtsp_config["config"]["eagle_architecture_config"]["num_hidden_layers"] =
num_layers and mtsp_config["config"]["eagle_architecture_config"]["hidden_size"]
= model.config.hidden_size, leaving ALGO_TO_CONFIG unchanged.

---

Outside diff comments:
In `@examples/speculative_decoding/main.py`:
- Line 222: The call that assigns model.eagle_module.d2t from
torch.load(data_args.draft_vocab_cache) must use safe loading semantics; update
the call to pass weights_only=True to torch.load (e.g.,
torch.load(data_args.draft_vocab_cache, weights_only=True)) so the draft vocab
cache is loaded with explicit safe behavior; locate the assignment to
model.eagle_module.d2t and replace the plain torch.load call accordingly
(optionally preserve any map_location usage if present).
- Around line 171-175: Add a new boolean field trust_remote_code (default False)
to ModelArguments and thread it through the model/tokenizer loading calls
instead of hardcoding trust_remote_code=True: pass model_args.trust_remote_code
into the calls that use
patch_transformers5_params_loading/load_vlm_or_llm_with_kwargs and
transformers.AutoTokenizer.from_pretrained (the four locations currently forcing
True), and update any helper wrappers that call these functions to accept and
forward the new flag; additionally, change the unsafe torch.load(...) call to
use torch.load(..., weights_only=True) (or document/validate the source) to
avoid arbitrary-code deserialization.

In `@modelopt/torch/speculative/plugins/megatron_eagle.py`:
- Around line 1167-1211: When sequence_parallel and eagle_mix_hidden_states are
both enabled, the code must gather the (possibly sharded)
eagle_module_input_hidden_states to the full tensor before computing
seq_len_s/rand_indices and performing replacements, and must avoid in-place
mutation to preserve autograd; specifically, call
gather_from_sequence_parallel_region(eagle_module_input_hidden_states) (matching
how eagle_module_output_hidden_states is gathered), clone the gathered tensor
(or create a new tensor and use index_copy_ or advanced indexing to write
replacements) when applying rand_indices, then if self.config.sequence_parallel
scatter the modified tensor back with scatter_to_sequence_parallel_region;
update references to eagle_module_input_hidden_states,
eagle_module_output_hidden_states, eagle_mix_hidden_states, seq_len_s,
rand_indices, and the gather/scatter helpers accordingly.

---

Duplicate comments:
In `@tests/examples/speculative_decoding/test_eagle.py`:
- Around line 64-104: The test test_llama_eagle3 currently writes both
mix_hidden_states variants to the same eagle_output_dir for a given cp_size,
causing overwrites; update the run_example_command call so the "--output_dir"
path includes mix_hidden_states (or both cp_size and mix_hidden_states) to
create a unique directory per (cp_size, mix_hidden_states) case (refer to
test_llama_eagle3, run_example_command, eagle_output_dir, cp_size,
mix_hidden_states) so each matrix cell writes to its own checkpoint folder and
prevents contamination.

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 802689ffb0b9384b3dc071b268a66f4c497d02d9 and ca56d2505ba4d8c64ae3d70571a60227c196e5d2.

📒 Files selected for processing (10)
  • examples/speculative_decoding/launch_train.sh
  • examples/speculative_decoding/main.py
  • modelopt/torch/speculative/config.py
  • modelopt/torch/speculative/eagle/conversion.py
  • modelopt/torch/speculative/eagle/eagle_model.py
  • modelopt/torch/speculative/plugins/megatron_eagle.py
  • modelopt/torch/speculative/plugins/transformers.py
  • tests/examples/speculative_decoding/test_eagle.py
  • tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py
  • tests/unit/torch/speculative/plugins/test_hf_speculative.py
✅ Files skipped from review due to trivial changes (1)
  • tests/unit/torch/speculative/plugins/test_hf_speculative.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/speculative/config.py

Comment on lines +228 to 231
if training_args.mode == "eagle3":
data_module = make_eagle_supervised_data_module(
tokenizer, data_args, train_len=training_args.training_seq_len
)

⚠️ Potential issue | 🟠 Major

data_module can be undefined for non-eagle3 modes.

TrainingArguments.mode still allows "medusa", but only the eagle3 branch initializes data_module, which can crash at trainer construction.

💡 Proposed fix
-    if training_args.mode == "eagle3":
-        data_module = make_eagle_supervised_data_module(
-            tokenizer, data_args, train_len=training_args.training_seq_len
-        )
+    if training_args.mode != "eagle3":
+        raise ValueError(f"{training_args.mode} is not supported!")
+    data_module = make_eagle_supervised_data_module(
+        tokenizer, data_args, train_len=training_args.training_seq_len
+    )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if training_args.mode == "eagle3":
data_module = make_eagle_supervised_data_module(
tokenizer, data_args, train_len=training_args.training_seq_len
)
if training_args.mode != "eagle3":
raise ValueError(f"{training_args.mode} is not supported!")
data_module = make_eagle_supervised_data_module(
tokenizer, data_args, train_len=training_args.training_seq_len
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 228 - 231, The code only
initializes data_module inside the if training_args.mode == "eagle3" branch (via
make_eagle_supervised_data_module), leaving data_module undefined for other
modes and causing a crash at Trainer construction; initialize data_module before
the branch (e.g., data_module = None) and either (A) add an else branch that
constructs the appropriate data module for other modes (e.g., a
make_medusa_supervised_data_module call) or (B) ensure the Trainer is only
passed data_module when it is not None (guard the trainer construction or pass a
fallback) so training_args.mode and
make_eagle_supervised_data_module/data_module usage are consistent and never
leave data_module undefined.
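The fail-fast guard proposed above can be sketched as follows; `make_eagle_supervised_data_module` is stubbed with a placeholder dict:

```python
def build_data_module(mode: str, tokenizer=None):
    """Guard sketched from the review note; only eagle3 is supported."""
    if mode != "eagle3":
        raise ValueError(f"{mode} is not supported!")
    # stand-in for make_eagle_supervised_data_module(...)
    return {"tokenizer": tokenizer, "mode": mode}

module = build_data_module("eagle3")

try:
    build_data_module("medusa")  # the deprecated mode now fails fast
    unsupported_rejected = False
except ValueError:
    unsupported_rejected = True
```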

Comment on lines +544 to +550
if self.eagle_decoder_type == "llama":
# Use default eagle config
decoder_cls = LlamaDecoderLayer
elif eagle_decoder_type == "kimik2":
elif self.eagle_decoder_type == "kimik2":
decoder_cls = _setup_kimi_k2_decoder()

self.eagle_config = PretrainedConfig.from_dict(eagle_architecture_config)
self.eagle_config.eagle_decoder_type = eagle_decoder_type
self.eagle_config = PretrainedConfig.from_dict(config.eagle_architecture_config)

⚠️ Potential issue | 🟠 Major

Fail fast on unsupported eagle_decoder_type.

If the value is neither "llama" nor "kimik2", decoder_cls stays unset, and the code fails later with a less actionable error.

💡 Proposed fix
         if self.eagle_decoder_type == "llama":
             # Use default eagle config
             decoder_cls = LlamaDecoderLayer
         elif self.eagle_decoder_type == "kimik2":
             decoder_cls = _setup_kimi_k2_decoder()
+        else:
+            raise ValueError(f"Unsupported eagle_decoder_type: {self.eagle_decoder_type}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if self.eagle_decoder_type == "llama":
# Use default eagle config
decoder_cls = LlamaDecoderLayer
elif eagle_decoder_type == "kimik2":
elif self.eagle_decoder_type == "kimik2":
decoder_cls = _setup_kimi_k2_decoder()
self.eagle_config = PretrainedConfig.from_dict(eagle_architecture_config)
self.eagle_config.eagle_decoder_type = eagle_decoder_type
self.eagle_config = PretrainedConfig.from_dict(config.eagle_architecture_config)
if self.eagle_decoder_type == "llama":
# Use default eagle config
decoder_cls = LlamaDecoderLayer
elif self.eagle_decoder_type == "kimik2":
decoder_cls = _setup_kimi_k2_decoder()
else:
raise ValueError(f"Unsupported eagle_decoder_type: {self.eagle_decoder_type}")
self.eagle_config = PretrainedConfig.from_dict(config.eagle_architecture_config)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/transformers.py` around lines 544 - 550,
The code currently leaves decoder_cls unset when self.eagle_decoder_type is not
"llama" or "kimik2"; add an explicit fail-fast check after those branches that
raises a clear ValueError (or TypeError) including the offending
self.eagle_decoder_type value and supported options so the error surfaces
immediately; reference the eagle_decoder_type variable and decoder_cls
assignment flow (LlamaDecoderLayer and _setup_kimi_k2_decoder) and raise before
using decoder_cls later.
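The dispatch-with-fail-fast pattern above can be sketched like this; both decoder classes are stand-ins, and `_setup_kimi_k2_decoder` is a hypothetical stub for the plugin helper:

```python
class LlamaDecoderLayer:  # stand-in for the transformers class
    pass

def _setup_kimi_k2_decoder():  # hypothetical stand-in for the plugin helper
    class KimiK2DecoderLayer:
        pass
    return KimiK2DecoderLayer

def select_decoder_cls(eagle_decoder_type: str):
    if eagle_decoder_type == "llama":
        return LlamaDecoderLayer
    if eagle_decoder_type == "kimik2":
        return _setup_kimi_k2_decoder()
    # fail fast with the offending value and the supported options
    raise ValueError(
        f"Unsupported eagle_decoder_type: {eagle_decoder_type!r} "
        "(expected 'llama' or 'kimik2')"
    )
```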

Comment on lines +916 to +934
if self.eagle_mix_hidden_states:
batch_size, seq_len_s, _ = eagle_input_hiddens.shape
num_to_replace = max(1, seq_len_s // (2**ttt_step + 1))

# Randomly select positions for each batch to replace
rand_indices = torch.stack(
[
torch.randperm(seq_len_s, device=eagle_input_hiddens.device)[
:num_to_replace
]
for _ in range(batch_size)
],
dim=0,
)
eagle_input_hiddens = eagle_input_hiddens.roll(1, 1)

for batch_idx in range(batch_size):
eagle_input_hiddens[batch_idx, rand_indices[batch_idx], :] = (
eagle_output_hiddens[batch_idx, rand_indices[batch_idx], :]
)

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, locate and examine the file
fd -t f transformers.py | grep -i speculative | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 116


🏁 Script executed:

# Get the file and check its size
wc -l modelopt/torch/speculative/plugins/transformers.py

Repository: NVIDIA/Model-Optimizer

Length of output: 121


🏁 Script executed:

# Read the code context around lines 916-934
sed -n '880,960p' modelopt/torch/speculative/plugins/transformers.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 4252


🏁 Script executed:

# Look at the _prepare_eagle_inputs method to understand tensor creation
sed -n '1,100p' modelopt/torch/speculative/plugins/transformers.py | grep -A 30 "_prepare_eagle_inputs" | head -40

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Search for _prepare_eagle_inputs method definition
rg -n "_prepare_eagle_inputs" modelopt/torch/speculative/plugins/transformers.py | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 144


🏁 Script executed:

# Read the _prepare_eagle_inputs method
sed -n '643,750p' modelopt/torch/speculative/plugins/transformers.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 5062


🏁 Script executed:

# Check if eagle_input_hiddens is modified or used in training context
sed -n '850,880p' modelopt/torch/speculative/plugins/transformers.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 1587


🏁 Script executed:

# Verify the training context - check where this method is called and if training is involved
rg -B 5 "eagle_ttt_steps" modelopt/torch/speculative/plugins/transformers.py | head -30

Repository: NVIDIA/Model-Optimizer

Length of output: 252


🏁 Script executed:

# Check what base_outputs.out_hiddens is - trace back to understand if requires_grad
sed -n '730,820p' modelopt/torch/speculative/plugins/transformers.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 4122


🏁 Script executed:

# Check the broader forward method context to see if training mode is involved
sed -n '800,850p' modelopt/torch/speculative/plugins/transformers.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 2412


🏁 Script executed:

# Verify the condition for in-place modification - check if eagle_input_hiddens has requires_grad
sed -n '920,940p' modelopt/torch/speculative/plugins/transformers.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 1117


🏁 Script executed:

# Check the full forward method signature and training condition
sed -n '820,900p' modelopt/torch/speculative/plugins/transformers.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 3863


🏁 Script executed:

# Check if there are any torch.no_grad() contexts around the problematic code
sed -n '910,960p' modelopt/torch/speculative/plugins/transformers.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 2858


🏁 Script executed:

# Verify if eagle_input_hiddens would have requires_grad during training
# Check if fc layer or forward output preserves gradients
rg -A 5 "eagle_module.fc" modelopt/torch/speculative/plugins/transformers.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 705


🏁 Script executed:

# Double-check the flow: verify that eagle_input_hiddens is used in forward before mutation
# and that it would have gradients during training
sed -n '900,945p' modelopt/torch/speculative/plugins/transformers.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 2543


🏁 Script executed:

# Check if there's any detach() or torch.no_grad() between the forward and the in-place mutation
sed -n '900,935p' modelopt/torch/speculative/plugins/transformers.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 1874


Avoid in-place mutation of eagle_input_hiddens in the mix path.

The slice assignments at lines 933-935 modify a tensor that is in the active computation graph, causing autograd "modified by an inplace operation" errors during training. The tensor is used in the forward pass (line 928) and then immediately modified in-place, and the mutated version is reused in the next loop iteration's forward pass. Clone the tensor before mutation to preserve the computation graph:

Proposed fix
             if self.eagle_mix_hidden_states:
                 batch_size, seq_len_s, _ = eagle_input_hiddens.shape
                 num_to_replace = max(1, seq_len_s // (2**ttt_step + 1))
+                mixed_hiddens = eagle_input_hiddens.clone()
 
                 # Randomly select positions for each batch to replace
                 rand_indices = torch.stack(
@@ -933,10 +935,12 @@ class EagleModel(PreTrainedModel):
                 for batch_idx in range(batch_size):
-                    eagle_input_hiddens[batch_idx, rand_indices[batch_idx], :] = (
+                    mixed_hiddens[batch_idx, rand_indices[batch_idx], :] = (
                         eagle_output_hiddens[batch_idx, rand_indices[batch_idx], :]
                     )
+                eagle_input_hiddens = mixed_hiddens
             else:
                 eagle_input_hiddens = eagle_output_hiddens
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/transformers.py` around lines 916 - 934,
The in-place assignments to eagle_input_hiddens when eagle_mix_hidden_states is
true (inside the loop using ttt_step and rand_indices) mutate a tensor that’s
part of the autograd graph; to fix, avoid in-place updates by making a
detached/independent copy before modifying (e.g., replace usage of
eagle_input_hiddens with a cloned tensor like eagle_input_hiddens =
eagle_input_hiddens.clone() or a shallow copy only when you will assign into it)
then perform the per-batch indexed replacements from eagle_output_hiddens into
that cloned tensor; ensure subsequent forward calls reference the cloned tensor
(and not the original) so autograd is not broken while preserving the same
semantics for eagle_mix_hidden_states, eagle_output_hiddens, rand_indices, and
ttt_step.
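The cloned, vectorized mixing step described above can be sketched on plain tensors; the shapes and the ttt_step value are illustrative:

```python
import torch

batch_size, seq_len_s, hidden = 2, 8, 4
ttt_step = 0

eagle_input_hiddens = torch.randn(batch_size, seq_len_s, hidden, requires_grad=True)
eagle_output_hiddens = torch.randn(batch_size, seq_len_s, hidden)

num_to_replace = max(1, seq_len_s // (2**ttt_step + 1))

# clone so the indexed write does not mutate a tensor the graph still needs
mixed = eagle_input_hiddens.clone()

# vectorized per-batch random positions, created on the tensors' device
rand_indices = torch.rand(batch_size, seq_len_s, device=mixed.device).argsort(dim=1)[
    :, :num_to_replace
]
batch_indices = torch.arange(batch_size, device=mixed.device)[:, None]
mixed[batch_indices, rand_indices] = eagle_output_hiddens[batch_indices, rand_indices]

mixed.sum().backward()  # gradient still flows through the untouched positions
```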

Comment on lines +49 to 53
if algo == "eagle3":
mtsp_config = ALGO_TO_CONFIG[algo]

mtsp_config["config"]["eagle_architecture_config"]["num_hidden_layers"] = (
num_medusa_heads_or_eagle_layers
)
mtsp_config["config"]["eagle_architecture_config"]["num_hidden_layers"] = num_layers
mtsp_config["config"]["eagle_architecture_config"]["hidden_size"] = model.config.hidden_size

⚠️ Potential issue | 🟡 Minor

Avoid mutating shared default config in-place.

ALGO_TO_CONFIG[algo] is shared; mutating nested fields here can leak state between test cases. Copy before mutation.

💡 Proposed fix
+import copy
@@
-        mtsp_config = ALGO_TO_CONFIG[algo]
+        mtsp_config = copy.deepcopy(ALGO_TO_CONFIG[algo])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py`
around lines 49 - 53, ALGO_TO_CONFIG[algo] is being mutated in-place for the
"eagle3" branch; instead clone the config before changing nested fields so tests
don't leak state—create a deep copy of ALGO_TO_CONFIG[algo] into mtsp_config
(e.g., using copy.deepcopy or similar) and then set
mtsp_config["config"]["eagle_architecture_config"]["num_hidden_layers"] =
num_layers and mtsp_config["config"]["eagle_architecture_config"]["hidden_size"]
= model.config.hidden_size, leaving ALGO_TO_CONFIG unchanged.
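The copy-before-mutation pattern above can be sketched as follows; the nested config values are illustrative, not the real defaults:

```python
import copy

ALGO_TO_CONFIG = {
    "eagle3": {
        "config": {
            "eagle_architecture_config": {"num_hidden_layers": 1, "hidden_size": 64}
        }
    }
}

# deep-copy so per-test tweaks never leak back into the shared default
mtsp_config = copy.deepcopy(ALGO_TO_CONFIG["eagle3"])
mtsp_config["config"]["eagle_architecture_config"]["num_hidden_layers"] = 4
mtsp_config["config"]["eagle_architecture_config"]["hidden_size"] = 128
```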

Signed-off-by: Ye Yu <yeyu@nvidia.com>
@yeyu-nvidia yeyu-nvidia force-pushed the yeyu/mix_hidden_states branch from 122f1b9 to ef832c4 Compare March 4, 2026 21:21
@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/transformers.py (1)

913-927: ⚠️ Potential issue | 🟠 Major

In-place mutation of eagle_input_hiddens may break autograd during training.

The tensor eagle_input_hiddens is used in the forward pass at line 904, and then modified in-place at lines 923-925. During training, this can cause "modified by an inplace operation" errors because the tensor is part of the computation graph.

Clone before mutation to preserve the gradient flow:

Proposed fix
             if self.eagle_mix_hidden_states:
                 batch_size, seq_len_s, _ = eagle_input_hiddens.shape
                 num_to_replace = max(1, seq_len_s // (2**ttt_step + 1))
+                # Clone to avoid in-place modification of tensor in computation graph
+                eagle_input_hiddens = eagle_input_hiddens.clone()

                 # Randomly select positions for each batch to replace
                 rand_indices = torch.rand(
                     batch_size, seq_len_s, device=eagle_input_hiddens.device
                 ).argsort(dim=1)[:, :num_to_replace]

                 batch_indices = torch.arange(batch_size)[:, None]
                 eagle_input_hiddens[batch_indices, rand_indices] = eagle_output_hiddens[
                     batch_indices, rand_indices
                 ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/transformers.py` around lines 913 - 927,
eagle_input_hiddens is being mutated in-place when eagle_mix_hidden_states is
true which can break autograd; before performing the indexed replacement (the
block using rand_indices, batch_indices and assigning into eagle_input_hiddens
from eagle_output_hiddens) make a detached or plain clone of eagle_input_hiddens
(e.g., eagle_input_hiddens = eagle_input_hiddens.clone()) and then perform the
assignment on that clone so the original computation graph is preserved; ensure
this change happens inside the same conditional branch that checks
eagle_mix_hidden_states and keep the else branch assigning eagle_input_hiddens =
eagle_output_hiddens unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 922-925: batch_indices is created on CPU which can cause indexing
device-mismatch with eagle_input_hiddens; recreate batch_indices on the same
device as eagle_input_hiddens (e.g., use torch.arange(batch_size,
device=eagle_input_hiddens.device)[:, None] or call
.to(eagle_input_hiddens.device)) so batch_indices, rand_indices and
eagle_input_hiddens are all on the same device before using
eagle_input_hiddens[batch_indices, rand_indices]; ensure dtype is appropriate
(long) for indexing.
- Around line 364-367: Guard the assignment to self._input_embeds with the same
use_aux_hidden_state condition used when registering the hook so the attribute
is only set when it's consumed; specifically, in the method where
self._input_embeds = self.layers[0].input_layernorm(inputs_embeds) appears, wrap
that assignment in an if use_aux_hidden_state: block (matching the hook
registration code that references self._input_embeds), ensuring only the path
that registers the hook (use_aux_hidden_state True) sets the attribute.

---

Duplicate comments:
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 913-927: eagle_input_hiddens is being mutated in-place when
eagle_mix_hidden_states is true which can break autograd; before performing the
indexed replacement (the block using rand_indices, batch_indices and assigning
into eagle_input_hiddens from eagle_output_hiddens) make a detached or plain
clone of eagle_input_hiddens (e.g., eagle_input_hiddens =
eagle_input_hiddens.clone()) and then perform the assignment on that clone so
the original computation graph is preserved; ensure this change happens inside
the same conditional branch that checks eagle_mix_hidden_states and keep the
else branch assigning eagle_input_hiddens = eagle_output_hiddens unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b5aff685-7ea7-4925-a7d6-2a38cd62012e

📥 Commits

Reviewing files that changed from the base of the PR and between ca56d2505ba4d8c64ae3d70571a60227c196e5d2 and 122f1b9ae09d7f079ef078081d700c925356af54.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/transformers.py

Comment on lines +364 to +367
# In EAGLE-3, we save input embeddings to attribute, and use it in first decoder layer by hook function
# Also, we normalize input embeddings and hidden states before concatenating them.
# The default input norm in first layer attn will be disabled.
self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find all uses of _input_embeds in the file
rg -n "_input_embeds" modelopt/torch/speculative/plugins/transformers.py

Repository: NVIDIA/Model-Optimizer

Length of output: 919


🏁 Script executed:

#!/bin/bash
# Read context around lines 322-326 and the surrounding structure
sed -n '315,340p' modelopt/torch/speculative/plugins/transformers.py

Repository: NVIDIA/Model-Optimizer

Length of output: 998


🏁 Script executed:

#!/bin/bash
# Read context around lines 270-280 to see hook registration
sed -n '265,285p' modelopt/torch/speculative/plugins/transformers.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1032


🏁 Script executed:

#!/bin/bash
# Read context around lines 360-375 to see if line 367 is conditional
sed -n '355,375p' modelopt/torch/speculative/plugins/transformers.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1180


🏁 Script executed:

#!/bin/bash
# Check if line 367 is inside a conditional block
# Look at the forward method structure from line 335 onwards
sed -n '335,380p' modelopt/torch/speculative/plugins/transformers.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2209


🏁 Script executed:

#!/bin/bash
# Check the full __init__ to understand when use_aux_hidden_state condition applies
# and verify if forward should also be guarded
sed -n '250,280p' modelopt/torch/speculative/plugins/transformers.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1494


🏁 Script executed:

#!/bin/bash
# Verify the full class structure - check what config.use_aux_hidden_state controls
# and confirm line 367 is unconditional
sed -n '200,220p' modelopt/torch/speculative/plugins/transformers.py

Repository: NVIDIA/Model-Optimizer

Length of output: 737


🏁 Script executed:

#!/bin/bash
# Get the full class definition to see if there are other uses of _input_embeds
# outside of the hook and if it serves any other purpose
sed -n '200,300p' modelopt/torch/speculative/plugins/transformers.py | head -50

Repository: NVIDIA/Model-Optimizer

Length of output: 1954


🏁 Script executed:

#!/bin/bash
# Check if the entire forward method or just line 367 should be guarded
# Look at the complete forward to see what happens
wc -l modelopt/torch/speculative/plugins/transformers.py

Repository: NVIDIA/Model-Optimizer

Length of output: 121


🏁 Script executed:

#!/bin/bash
# Read the complete forward method to check if line 367 is guarded
sed -n '333,430p' modelopt/torch/speculative/plugins/transformers.py

Repository: NVIDIA/Model-Optimizer

Length of output: 3912


🏁 Script executed:

#!/bin/bash
# Verify that input_layernorm exists by default in LlamaDecoderLayer
# Search for where decoder_layer_cls is used and what it is
rg -n "decoder_layer_cls" modelopt/torch/speculative/plugins/transformers.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 236


🏁 Script executed:

#!/bin/bash
# Check if _input_embeds is initialized in __init__
rg -n "self._input_embeds" modelopt/torch/speculative/plugins/transformers.py

Repository: NVIDIA/Model-Optimizer

Length of output: 334


🏁 Script executed:

#!/bin/bash
# Look for __init__ of EagleModule to see if _input_embeds is initialized
sed -n '230,290p' modelopt/torch/speculative/plugins/transformers.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2867


🏁 Script executed:

#!/bin/bash
# Verify that input_layernorm exists in the base LlamaDecoderLayer
# to confirm it can be called without error when use_aux_hidden_state=False
rg -n "input_layernorm" modelopt/torch/speculative/plugins/transformers.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 303


🏁 Script executed:

#!/bin/bash
# Check the imports and class usage to understand where decoder_layer_cls comes from
head -50 modelopt/torch/speculative/plugins/transformers.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2052


🏁 Script executed:

#!/bin/bash
# Search for where the forward method is called to understand the flow
# and verify if _input_embeds should only be set when use_aux_hidden_state=True
rg -B5 -A5 "def forward" modelopt/torch/speculative/plugins/transformers.py | grep -A10 "class EagleModule"

Repository: NVIDIA/Model-Optimizer

Length of output: 48


Line 367 should be guarded by use_aux_hidden_state condition to avoid setting unused _input_embeds attribute.

When use_aux_hidden_state=False, the code unconditionally sets self._input_embeds at line 367, but the hook that consumes it (lines 322-326) is only registered when use_aux_hidden_state=True (lines 272-274). This results in the attribute being set but never used in the non-aux path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/transformers.py` around lines 364 - 367,
Guard the assignment to self._input_embeds with the same use_aux_hidden_state
condition used when registering the hook so the attribute is only set when it's
consumed; specifically, in the method where self._input_embeds =
self.layers[0].input_layernorm(inputs_embeds) appears, wrap that assignment in
an if use_aux_hidden_state: block (matching the hook registration code that
references self._input_embeds), ensuring only the path that registers the hook
(use_aux_hidden_state True) sets the attribute.

Comment thread modelopt/torch/speculative/plugins/transformers.py
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/speculative_decoding/main.py (1)

173-175: ⚠️ Potential issue | 🔴 Critical

Remove hardcoded trust_remote_code=True and expose as a caller-configurable parameter.

Lines 173, 175, 184, and 194 hardcode trust_remote_code=True in model and tokenizer loading. This bypasses caller control over remote code execution. Expose as a configuration parameter with default value False, allowing callers to opt-in explicitly.

Proposed fix
 `@dataclass`
 class ModelArguments:
     model_name_or_path: str | None = field(default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={"help": "Allow loading custom code from model repos."},
+    )
@@
             _, model = load_vlm_or_llm_with_kwargs(
-                checkpoint, torch_dtype="auto", trust_remote_code=True
+                checkpoint,
+                torch_dtype="auto",
+                trust_remote_code=model_args.trust_remote_code,
             )
-        tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
+        tokenizer = transformers.AutoTokenizer.from_pretrained(
+            checkpoint, trust_remote_code=model_args.trust_remote_code
+        )
@@
         model_config, model = load_vlm_or_llm_with_kwargs(
             model_args.model_name_or_path,
             torch_dtype="auto",
             device_map="cpu",
-            trust_remote_code=True,
+            trust_remote_code=model_args.trust_remote_code,
             **offline_kwargs,
         )
@@
         tokenizer = transformers.AutoTokenizer.from_pretrained(
             model_args.model_name_or_path,
             model_max_length=training_args.training_seq_len,
-            trust_remote_code=True,
+            trust_remote_code=model_args.trust_remote_code,
         )

Per coding guidelines: "Do not hardcode trust_remote_code=True when loading transformers models. Let the caller decide and default to False."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 173 - 175, The code
currently hardcodes trust_remote_code=True when loading the model and tokenizer
(e.g., the from_pretrained calls that use checkpoint and torch_dtype="auto"), so
change the public function or script configuration to accept a trust_remote_code
boolean parameter (default False) and pass that variable into all
transformers.AutoModel*.from_pretrained(...) and
transformers.AutoTokenizer.from_pretrained(...) calls instead of the literal
True; update any callers to use the new parameter when they intentionally opt-in
to remote code execution.
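The loading pattern the review asks for can be sketched without pulling in transformers: a caller-controlled flag that defaults to `False` and is forwarded rather than hardcoded. The `load_model_and_tokenizer` helper below is a hypothetical stand-in for the real `from_pretrained` calls, kept only to show the flag being threaded through:

```python
from dataclasses import dataclass, field


@dataclass
class ModelArguments:
    """Sketch of the reviewer's suggestion: remote-code execution is opt-in."""

    model_name_or_path: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Allow loading custom code from model repos."},
    )


def load_model_and_tokenizer(args: ModelArguments) -> dict:
    # Placeholder for the AutoModel*/AutoTokenizer.from_pretrained calls;
    # the point is that the flag is forwarded, never a hardcoded True.
    return {
        "checkpoint": args.model_name_or_path,
        "trust_remote_code": args.trust_remote_code,
    }


kwargs = load_model_and_tokenizer(ModelArguments())
print(kwargs["trust_remote_code"])  # False unless the caller opts in
```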
♻️ Duplicate comments (5)
modelopt/torch/speculative/plugins/transformers.py (2)

561-565: ⚠️ Potential issue | 🟠 Major

Fail fast on unsupported eagle_decoder_type.

If the value is not "llama" or "kimik2", decoder_cls remains unset and fails later with a less actionable error.

💡 Proposed fix
         if self.eagle_decoder_type == "llama":
             # Use default eagle config
             decoder_cls = LlamaDecoderLayer
         elif self.eagle_decoder_type == "kimik2":
             decoder_cls = _setup_kimi_k2_decoder()
+        else:
+            raise ValueError(
+                f"Unsupported eagle_decoder_type={self.eagle_decoder_type}. "
+                "Supported: ['llama', 'kimik2']."
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/transformers.py` around lines 561 - 565,
The current branch sets decoder_cls only for eagle_decoder_type values "llama"
or "kimik2" (using LlamaDecoderLayer and _setup_kimi_k2_decoder()), leaving
decoder_cls undefined for other values; add an else branch that fails fast by
raising a clear ValueError (or TypeError) referencing eagle_decoder_type and
listing the supported options ("llama", "kimik2") so callers get an actionable
error instead of a later undefined-variable failure.
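The fail-fast dispatch the review recommends can be sketched with a lookup table; the class names below are stand-ins for `LlamaDecoderLayer` and the Kimi-K2 decoder, which this sketch does not import:

```python
# Map supported decoder types to their (stubbed) implementation names.
SUPPORTED_DECODERS = {
    "llama": "LlamaDecoderLayer",
    "kimik2": "KimiK2DecoderLayer",
}


def resolve_decoder_cls(eagle_decoder_type: str) -> str:
    try:
        return SUPPORTED_DECODERS[eagle_decoder_type]
    except KeyError:
        # Fail fast with an actionable message instead of leaving
        # decoder_cls undefined for a later NameError.
        raise ValueError(
            f"Unsupported eagle_decoder_type={eagle_decoder_type!r}. "
            f"Supported: {sorted(SUPPORTED_DECODERS)}."
        ) from None


print(resolve_decoder_cls("llama"))  # LlamaDecoderLayer
```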

930-942: ⚠️ Potential issue | 🔴 Critical

Fix mixed-device indexing and in-place update in hidden-state mixing.

Line 939 creates CPU indices; Line 940 writes in-place into eagle_input_hiddens. This can fail on CUDA and can also destabilize autograd across TTT steps.

💡 Proposed fix
             if self.eagle_mix_hidden_states:
                 batch_size, seq_len_s, _ = eagle_input_hiddens.shape
                 num_to_replace = max(1, seq_len_s // (2**ttt_step + 1))
+                mixed_hiddens = eagle_input_hiddens.clone()
@@
-                batch_indices = torch.arange(batch_size)[:, None]
-                eagle_input_hiddens[batch_indices, rand_indices] = eagle_output_hiddens[
+                batch_indices = torch.arange(
+                    batch_size, device=eagle_input_hiddens.device
+                )[:, None]
+                mixed_hiddens[batch_indices, rand_indices] = eagle_output_hiddens[
                     batch_indices, rand_indices
                 ]
+                eagle_input_hiddens = mixed_hiddens
             else:
                 eagle_input_hiddens = eagle_output_hiddens
#!/bin/bash
set -euo pipefail
FILE="modelopt/torch/speculative/plugins/transformers.py"

# 1) Verify CPU index creation in mix path.
rg -n -C2 'batch_indices = torch\.arange\(batch_size\)\[:, None\]' "$FILE"

# 2) Verify in-place assignment on eagle_input_hiddens in same block.
rg -n -C2 'eagle_input_hiddens\[batch_indices, rand_indices\]\s*=' "$FILE"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/transformers.py` around lines 930 - 942,
The hidden-state mixing block must avoid CPU/GPU index mismatch and in-place
writes; when eagle_mix_hidden_states is true, create batch_indices on the same
device/dtype as eagle_input_hiddens (e.g., batch_indices =
torch.arange(batch_size, device=eagle_input_hiddens.device)[:, None]) and do the
replacement without mutating the original tensor in-place (e.g., make a copy:
mixed = eagle_input_hiddens.clone(); mixed[batch_indices, rand_indices] =
eagle_output_hiddens[batch_indices, rand_indices]; then use mixed further or
assign it back), referencing eagle_mix_hidden_states, eagle_input_hiddens,
eagle_output_hiddens, rand_indices and batch_indices.
tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py (1)

49-53: ⚠️ Potential issue | 🟡 Minor

Avoid mutating shared default config in-place.

Line 50 references a shared config dict; Lines 52-55 mutate it directly, which can leak state between tests.

💡 Proposed fix
+import copy
@@
-        mtsp_config = ALGO_TO_CONFIG[algo]
+        mtsp_config = copy.deepcopy(ALGO_TO_CONFIG[algo])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py`
around lines 49 - 53, The test mutates the shared ALGO_TO_CONFIG in-place via
mtsp_config = ALGO_TO_CONFIG[algo] and then changing
mtsp_config["config"]["eagle_architecture_config"]["num_hidden_layers"] and
["hidden_size"]; instead, create an independent copy before modifying (e.g.,
deepcopy ALGO_TO_CONFIG[algo] into mtsp_config or copy the nested
"eagle_architecture_config") so changes to mtsp_config
(num_hidden_layers/hidden_size) do not leak into ALGO_TO_CONFIG or other tests;
update the code that references mtsp_config accordingly.
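The leak the review flags is a generic Python pitfall and is easy to reproduce: assigning a nested dict out of a shared mapping gives a reference, so mutating it mutates the shared default, while `copy.deepcopy` gives an independent copy. The config shape below mimics `ALGO_TO_CONFIG` for illustration only:

```python
import copy

ALGO_TO_CONFIG = {
    "eagle3": {"config": {"eagle_architecture_config": {"num_hidden_layers": 32}}}
}

# Shallow reference: mutating it leaks into the shared default.
leaky = ALGO_TO_CONFIG["eagle3"]
leaky["config"]["eagle_architecture_config"]["num_hidden_layers"] = 1
print(ALGO_TO_CONFIG["eagle3"]["config"]["eagle_architecture_config"]["num_hidden_layers"])  # 1

# Reset, then take an independent deep copy before modifying.
ALGO_TO_CONFIG["eagle3"]["config"]["eagle_architecture_config"]["num_hidden_layers"] = 32
safe = copy.deepcopy(ALGO_TO_CONFIG["eagle3"])
safe["config"]["eagle_architecture_config"]["num_hidden_layers"] = 1
print(ALGO_TO_CONFIG["eagle3"]["config"]["eagle_architecture_config"]["num_hidden_layers"])  # 32
```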
modelopt/torch/speculative/config.py (1)

103-105: ⚠️ Potential issue | 🟠 Major

Enforce positive validation for eagle_ttt_steps.

Line 103 currently accepts 0/negative values, which can silently bypass TTT behavior. This should be constrained at config level.

💡 Proposed fix
     eagle_ttt_steps: int = ModeloptField(
-        default=4, description=("The number of train-time-test steps in training.")
+        default=4,
+        ge=1,
+        description=("The number of train-time-test steps in training."),
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/config.py` around lines 103 - 105, The field
eagle_ttt_steps on the config uses ModeloptField but currently allows
zero/negative values; add validation to enforce a positive integer (>=1). Modify
the ModeloptField declaration for eagle_ttt_steps to include a minimum
constraint (e.g., min/ge=1 if ModeloptField forwards pydantic constraints) or
add a class-level/field validator (e.g., validate_eagle_ttt_steps) that raises a
clear ValueError when eagle_ttt_steps <= 0 so invalid configs fail fast.
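Assuming `ModeloptField` forwards pydantic constraints, the review's `ge=1` is the direct fix; an equivalent plain-Python sketch of failing fast on a non-positive `eagle_ttt_steps` (the `EagleConfigSketch` class is hypothetical, not the real config):

```python
from dataclasses import dataclass


@dataclass
class EagleConfigSketch:
    eagle_ttt_steps: int = 4

    def __post_init__(self):
        # Reject 0/negative values at construction time instead of
        # silently bypassing TTT behavior later.
        if self.eagle_ttt_steps < 1:
            raise ValueError(
                f"eagle_ttt_steps must be >= 1, got {self.eagle_ttt_steps}"
            )


print(EagleConfigSketch().eagle_ttt_steps)  # 4

try:
    EagleConfigSketch(eagle_ttt_steps=0)
except ValueError as err:
    print(err)  # eagle_ttt_steps must be >= 1, got 0
```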
examples/speculative_decoding/main.py (1)

228-231: ⚠️ Potential issue | 🟠 Major

data_module is undefined for non-eagle3 modes.

Line 228 initializes data_module only for eagle3, but mode still permits "medusa", which can crash at trainer construction.

💡 Proposed fix
-    if training_args.mode == "eagle3":
-        data_module = make_eagle_supervised_data_module(
-            tokenizer, data_args, train_len=training_args.training_seq_len
-        )
+    if training_args.mode != "eagle3":
+        raise ValueError(f"{training_args.mode} is not supported!")
+    data_module = make_eagle_supervised_data_module(
+        tokenizer, data_args, train_len=training_args.training_seq_len
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 228 - 231, The code only
sets data_module when training_args.mode == "eagle3", causing data_module to be
undefined for other modes (e.g., "medusa"); update the branch around
training_args.mode to either (A) construct the correct data module for other
supported modes (e.g., call a make_medusa_data_module or appropriate factory) or
(B) explicitly set data_module = None and ensure the downstream trainer
construction/Trainer init handles None; modify the block containing
make_eagle_supervised_data_module and training_args.mode to provide a defined
data_module for all allowed modes so trainer construction won't crash.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/speculative/plugins/megatron_eagle.py`:
- Around line 1186-1204: The in-place assignment to
eagle_module_input_hidden_states inside the TTT mix loop causes autograd
versioning errors; instead, create a fresh tensor for the mixed inputs (e.g.,
modified_input = eagle_module_input_hidden_states.clone() or torch.where-based
copy) and perform the per-batch replacement on that copy using the same
rand_indices and eagle_module_output_hidden_states, then pass the modified_input
into the rest of the forward/TTT iteration; ensure you never write directly into
eagle_module_input_hidden_states so repeated iterations and backpropagation use
non-mutated tensors.

---

Outside diff comments:
In `@examples/speculative_decoding/main.py`:
- Around line 173-175: The code currently hardcodes trust_remote_code=True when
loading the model and tokenizer (e.g., the from_pretrained calls that use
checkpoint and torch_dtype="auto"), so change the public function or script
configuration to accept a trust_remote_code boolean parameter (default False)
and pass that variable into all transformers.AutoModel*.from_pretrained(...) and
transformers.AutoTokenizer.from_pretrained(...) calls instead of the literal
True; update any callers to use the new parameter when they intentionally opt-in
to remote code execution.

---

Duplicate comments:
In `@examples/speculative_decoding/main.py`:
- Around line 228-231: The code only sets data_module when training_args.mode ==
"eagle3", causing data_module to be undefined for other modes (e.g., "medusa");
update the branch around training_args.mode to either (A) construct the correct
data module for other supported modes (e.g., call a make_medusa_data_module or
appropriate factory) or (B) explicitly set data_module = None and ensure the
downstream trainer construction/Trainer init handles None; modify the block
containing make_eagle_supervised_data_module and training_args.mode to provide a
defined data_module for all allowed modes so trainer construction won't crash.

In `@modelopt/torch/speculative/config.py`:
- Around line 103-105: The field eagle_ttt_steps on the config uses
ModeloptField but currently allows zero/negative values; add validation to
enforce a positive integer (>=1). Modify the ModeloptField declaration for
eagle_ttt_steps to include a minimum constraint (e.g., min/ge=1 if ModeloptField
forwards pydantic constraints) or add a class-level/field validator (e.g.,
validate_eagle_ttt_steps) that raises a clear ValueError when eagle_ttt_steps <=
0 so invalid configs fail fast.

In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 561-565: The current branch sets decoder_cls only for
eagle_decoder_type values "llama" or "kimik2" (using LlamaDecoderLayer and
_setup_kimi_k2_decoder()), leaving decoder_cls undefined for other values; add
an else branch that fails fast by raising a clear ValueError (or TypeError)
referencing eagle_decoder_type and listing the supported options ("llama",
"kimik2") so callers get an actionable error instead of a later
undefined-variable failure.
- Around line 930-942: The hidden-state mixing block must avoid CPU/GPU index
mismatch and in-place writes; when eagle_mix_hidden_states is true, create
batch_indices on the same device/dtype as eagle_input_hiddens (e.g.,
batch_indices = torch.arange(batch_size, device=eagle_input_hiddens.device)[:,
None]) and do the replacement without mutating the original tensor in-place
(e.g., make a copy: mixed = eagle_input_hiddens.clone(); mixed[batch_indices,
rand_indices] = eagle_output_hiddens[batch_indices, rand_indices]; then use
mixed further or assign it back), referencing eagle_mix_hidden_states,
eagle_input_hiddens, eagle_output_hiddens, rand_indices and batch_indices.

In
`@tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py`:
- Around line 49-53: The test mutates the shared ALGO_TO_CONFIG in-place via
mtsp_config = ALGO_TO_CONFIG[algo] and then changing
mtsp_config["config"]["eagle_architecture_config"]["num_hidden_layers"] and
["hidden_size"]; instead, create an independent copy before modifying (e.g.,
deepcopy ALGO_TO_CONFIG[algo] into mtsp_config or copy the nested
"eagle_architecture_config") so changes to mtsp_config
(num_hidden_layers/hidden_size) do not leak into ALGO_TO_CONFIG or other tests;
update the code that references mtsp_config accordingly.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 9d2871e3-9ad5-40ed-989f-73c3cd9fb1ed

📥 Commits

Reviewing files that changed from the base of the PR and between 122f1b9ae09d7f079ef078081d700c925356af54 and ef832c4.

📒 Files selected for processing (10)
  • examples/speculative_decoding/launch_train.sh
  • examples/speculative_decoding/main.py
  • modelopt/torch/speculative/config.py
  • modelopt/torch/speculative/eagle/conversion.py
  • modelopt/torch/speculative/eagle/eagle_model.py
  • modelopt/torch/speculative/plugins/megatron_eagle.py
  • modelopt/torch/speculative/plugins/transformers.py
  • tests/examples/speculative_decoding/test_eagle.py
  • tests/gpu_megatron/torch/speculative/plugins/test_speculative_megatron_modules.py
  • tests/unit/torch/speculative/plugins/test_hf_speculative.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt/torch/speculative/eagle/conversion.py
  • tests/examples/speculative_decoding/test_eagle.py

Comment thread modelopt/torch/speculative/plugins/megatron_eagle.py
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Contributor

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

1189-1209: ⚠️ Potential issue | 🟠 Major

In-place hidden-state mutation may break autograd during backpropagation.

The in-place assignment at lines 1204-1207 modifies eagle_module_input_hidden_states after it has been used in the forward pass. This can cause PyTorch autograd version tracking errors during backpropagation since the tensor is part of the computational graph.

Proposed fix: clone before mutation
         if self.eagle_mix_hidden_states:
             seq_len_s, batch_size, _ = eagle_module_output_hidden_states.shape
             num_to_replace = max(1, seq_len_s // (2**ttt_step + 1))

             # Randomly select positions for each batch to replace
             rand_indices = torch.stack(
                 [
                     torch.randperm(seq_len_s, device=eagle_module_output_hidden_states.device)[
                         :num_to_replace
                     ]
                     for _ in range(batch_size)
                 ],
                 dim=0,
             )

+            mixed_hidden_states = eagle_module_input_hidden_states.clone()
             for batch_idx in range(batch_size):
-                eagle_module_input_hidden_states[rand_indices[batch_idx], batch_idx, :] = (
+                mixed_hidden_states[rand_indices[batch_idx], batch_idx, :] = (
                     eagle_module_output_hidden_states[rand_indices[batch_idx], batch_idx, :]
                 )
+            eagle_module_input_hidden_states = mixed_hidden_states
         else:
             eagle_module_input_hidden_states = eagle_module_output_hidden_states
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/megatron_eagle.py` around lines 1189 -
1209, The in-place assignment to eagle_module_input_hidden_states inside the
eagle_mix_hidden_states branch can break autograd; to fix, make a copy of
eagle_module_input_hidden_states (e.g., via .clone() or .detach().clone()
depending on desired grad behavior) before mutating it, perform the indexed
replacements from eagle_module_output_hidden_states using the rand_indices loop
(respecting ttt_step and num_to_replace), and then use the cloned-and-modified
tensor for subsequent computation (replace the in-place writes at the
rand_indices loop with writes to the clone and ensure
eagle_module_input_hidden_states is reassigned to that clone).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@modelopt/torch/speculative/plugins/megatron_eagle.py`:
- Around line 1189-1209: The in-place assignment to
eagle_module_input_hidden_states inside the eagle_mix_hidden_states branch can
break autograd; to fix, make a copy of eagle_module_input_hidden_states (e.g.,
via .clone() or .detach().clone() depending on desired grad behavior) before
mutating it, perform the indexed replacements from
eagle_module_output_hidden_states using the rand_indices loop (respecting
ttt_step and num_to_replace), and then use the cloned-and-modified tensor for
subsequent computation (replace the in-place writes at the rand_indices loop
with writes to the clone and ensure eagle_module_input_hidden_states is
reassigned to that clone).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 65edc277-c840-4019-88d3-94ce78799829

📥 Commits

Reviewing files that changed from the base of the PR and between ef832c4 and 486001b.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/megatron_eagle.py

Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
@yeyu-nvidia yeyu-nvidia merged commit 5d0e012 into main Mar 9, 2026
52 of 54 checks passed
@yeyu-nvidia yeyu-nvidia deleted the yeyu/mix_hidden_states branch March 9, 2026 18:14