Commit 6403389

Feat: Configurable Eagle ROPE scaling during export (#1238)
### What does this PR do?

JIRA ticket: https://jirasw.nvidia.com/browse/OMNIML-3469

Type of change: New feature

Decouple EAGLE training rope configuration from export rope configuration, enabling separate YaRN rope scaling injection at export time for long-context inference.

#### Changes

**Configurable export rope scaling (`EagleConfig`)**
- Add `eagle_export_rope_scaling` field to `EagleConfig` with a default YaRN config (`factor=32.0`, `original_max_position_embeddings=2048`)
- Set to `{}` to disable rope scaling injection at export

**Simplified training defaults (`default_config.py`)**
- Change the default training rope from `llama3` (theta=500k) to `default` (theta=10k) — models now train with simple positional embeddings; rope scaling is applied only at export
- Add `rope_theta` inside the `rope_scaling` dict for transformers 5.x cross-version compatibility

**Move config validation/rewriting into `EagleConfig` (`config.py`)**
- `_derive_eagle_offline`: derives `eagle_offline` from `data_args.offline_data_path` via the validation context, removing the manual assignment in `main.py`
- `_check_rope_scaling_consistency`: rejects configs where `eagle_export_rope_scaling` is set but the training `rope_type` is not `"default"`
- `_warn_rope_vs_training_seq_len`: warns when `original_max_position_embeddings` differs from `training_seq_len`

**Export rope injection (`hf_spec_export.py`)**
- Inject `eagle_export_rope_scaling` into the exported HF config when the training rope_type is `"default"`
- Fall back to `rope_theta` from the `rope_scaling` dict for transformers 5.x compatibility

**Fix Megatron RotaryEmbedding crash (`megatron_eagle.py`)**
- `dict_to_config()` set `rope_scaling=True` whenever the `rope_scaling` key existed, even without a `"factor"` — causing `RotaryEmbedding` to divide by `None`
- Now only enables `rope_scaling` when the dict actually contains a `"factor"` key

### Usage

Configure in the YAML config (or use the defaults from `eagle3.yaml`):

```yaml
eagle:
  eagle_export_rope_scaling:
    rope_type: yarn
    factor: 32.0
    original_max_position_embeddings: 2048
```

Set to an empty dict to disable export rope injection:

```yaml
eagle:
  eagle_export_rope_scaling: {}
```

### Testing

- New unit tests: `tests/unit/torch/speculative/test_eagle_config.py` — rope consistency validator, seq_len warning, context-derived `eagle_offline`
- New unit tests: `tests/unit/torch/export/test_hf_spec_rope_export.py` — export rope injection, fallback, and empty-config cases

### Before your PR is "*Ready for review*"

Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`).

Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅ (new field has a sensible default; existing configs work unchanged)
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`?: N/A
- Did you write any new necessary tests?: ✅
- Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ❌ (should be added if merging as a feature)

## Summary by CodeRabbit

* **New Features**
  * Add export-time rope-scaling configuration for EAGLE models.
* **Improvements**
  * Stronger validation and context-aware reconciliation between training and export configs.
  * Export now injects rope-scaling and rope-theta when appropriate.
  * Default rope-scaling values updated for EAGLE variants.
  * Model instances now expose export rope-scaling for downstream use.
* **Tests**
  * Added unit tests covering rope-scaling export behavior and configuration validators.

---------

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 14b78ae commit 6403389

9 files changed

Lines changed: 289 additions & 17 deletions

examples/speculative_decoding/main.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -48,6 +48,7 @@

 import modelopt.torch.opt as mto
 import modelopt.torch.speculative as mtsp
+from modelopt.torch.speculative.config import EagleConfig
 from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading
 from modelopt.torch.utils import print_rank_0

@@ -266,8 +267,11 @@ def train():
         }
         mtsp.convert(model, [("medusa", config)])
     elif training_args.mode == "eagle3":
-        # eagle_cfg maps directly to EagleConfig fields; eagle_offline is derived here.
-        eagle_cfg["eagle_offline"] = use_offline_training
+        # Validate and rewrite eagle config fields
+        eagle_cfg = EagleConfig.model_validate(
+            eagle_cfg,
+            context={"training_args": training_args, "data_args": data_args},
+        ).model_dump()
         mtsp.convert(model, [("eagle", eagle_cfg)])

     # Load draft vocab cache if the draft model uses a compressed vocabulary
```

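The `EagleConfig.model_validate(..., context=...)` call above relies on Pydantic v2's validation context, which the new validators in `config.py` read via `info.context`. Below is a minimal sketch of that pattern with a toy model; the class and field names are illustrative stand-ins, not the real `EagleConfig`.

```python
# Minimal sketch of Pydantic's validation-context pattern (toy names, not EagleConfig).
from types import SimpleNamespace
from typing import Any

from pydantic import BaseModel, ValidationInfo, model_validator


class ToyConfig(BaseModel):
    offline: bool = False

    @model_validator(mode="before")
    @classmethod
    def _derive_offline(cls, data: Any, info: ValidationInfo) -> Any:
        # Context is optional; validating without it leaves the field at its default.
        ctx = info.context or {}
        data_args = ctx.get("data_args")
        if data_args is not None and isinstance(data, dict):
            data["offline"] = data_args.offline_data_path is not None
        return data


data_args = SimpleNamespace(offline_data_path="/tmp/offline_data")
assert ToyConfig.model_validate({}, context={"data_args": data_args}).model_dump() == {"offline": True}
assert ToyConfig.model_validate({}).model_dump() == {"offline": False}
```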
modelopt/torch/export/plugins/hf_spec_export.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -187,6 +187,18 @@ def _get_config_from_draft_or_base(key: str, model: nn.Module):
                 new_value = str(new_value).replace("torch.", "")
             template_config[key] = new_value

+        # Inject export rope scaling override when training rope_type is "default".
+        rope_cfg = self.model.eagle_config.rope_scaling or {}
+        training_rope_type = rope_cfg.get("rope_type") or rope_cfg.get("type")
+        eagle_export_rope_scaling = getattr(self.model, "eagle_export_rope_scaling", None)
+        if eagle_export_rope_scaling and training_rope_type == "default":
+            template_config["rope_scaling"] = eagle_export_rope_scaling
+
+        # In transformers 5.x, rope_theta is under rope_scaling, not the main config.
+        # Always source from the training rope config (rope_theta is not in export overrides).
+        if template_config.get("rope_theta") is None and rope_cfg:
+            template_config["rope_theta"] = rope_cfg.get("rope_theta")
+
         return template_config

     def export(
```

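To make the rule above easier to read in isolation, here is a hedged, standalone restatement of the same decision logic on plain dicts. The helper name is hypothetical; the real logic runs inline while the exporter assembles `template_config`.

```python
# Standalone restatement of the export-time rope injection rule (hypothetical helper name).
def inject_export_rope(template_config: dict, training_rope_cfg: dict, export_rope: dict | None) -> dict:
    rope_cfg = training_rope_cfg or {}
    # Accept both the "rope_type" key and the legacy "type" key.
    rope_type = rope_cfg.get("rope_type") or rope_cfg.get("type")
    # Override only when training used plain (unscaled) rope and an export config was given.
    if export_rope and rope_type == "default":
        template_config["rope_scaling"] = export_rope
    # transformers 5.x keeps rope_theta inside rope_scaling; fall back to it when missing.
    if template_config.get("rope_theta") is None and rope_cfg:
        template_config["rope_theta"] = rope_cfg.get("rope_theta")
    return template_config


# YaRN is injected only for a "default" training rope; rope_theta falls back from rope_scaling.
cfg = inject_export_rope({}, {"rope_type": "default", "rope_theta": 10000},
                         {"rope_type": "yarn", "factor": 32.0})
assert cfg["rope_scaling"]["rope_type"] == "yarn" and cfg["rope_theta"] == 10000
```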
modelopt/torch/speculative/config.py

Lines changed: 52 additions & 0 deletions
```diff
@@ -15,7 +15,11 @@

 """Configurations for speculative decoding modes."""

+import warnings
 from copy import deepcopy
+from typing import Any
+
+from pydantic import ValidationInfo, model_validator

 from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField

@@ -120,3 +124,51 @@ class EagleConfig(ModeloptBaseConfig):
         default=False,
         description="Whether to enable NVTX ranges for profiling eagle forward/loss methods.",
     )
+
+    eagle_export_rope_scaling: dict = ModeloptField(
+        default={"rope_type": "yarn", "factor": 32.0, "original_max_position_embeddings": 2048},
+        description=(
+            "The rope_scaling config to inject into the exported HuggingFace model config. "
+            "Applied when the training rope_type is 'default' (no scaling). "
+            "Set to empty dict {} to disable rope scaling injection at export."
+        ),
+    )
+
+    @model_validator(mode="before")
+    @classmethod
+    def _derive_eagle_offline(cls, data: Any, info: ValidationInfo) -> Any:
+        """Derive ``eagle_offline`` from ``data_args.offline_data_path`` when provided in context."""
+        ctx = info.context if info.context else {}
+        data_args = ctx.get("data_args")
+        if data_args is not None and isinstance(data, dict):
+            data["eagle_offline"] = data_args.offline_data_path is not None
+        return data
+
+    @model_validator(mode="after")
+    def _check_rope_scaling_consistency(self) -> "EagleConfig":
+        if not self.eagle_export_rope_scaling:
+            return self
+        rope_cfg = self.eagle_architecture_config.get("rope_scaling", {}) or {}
+        rope_type = rope_cfg.get("rope_type") or rope_cfg.get("type")
+        if rope_type is not None and rope_type != "default":
+            raise ValueError(
+                f"eagle_export_rope_scaling is set but eagle_architecture_config has "
+                f"rope_type='{rope_type}'. Export rope overwrite is only valid when the "
+                f"training rope_type is 'default' (no scaling)."
+            )
+        return self
+
+    @model_validator(mode="after")
+    def _warn_rope_vs_training_seq_len(self, info: ValidationInfo) -> "EagleConfig":
+        ctx = info.context if info.context else {}
+        training_args = ctx.get("training_args")
+        if training_args is None:
+            return self
+        orig_max_pos = self.eagle_export_rope_scaling.get("original_max_position_embeddings")
+        if orig_max_pos is not None and orig_max_pos != training_args.training_seq_len:
+            warnings.warn(
+                f"eagle_export_rope_scaling.original_max_position_embeddings ({orig_max_pos}) "
+                f"differs from training_seq_len ({training_args.training_seq_len}). "
+                f"This may affect long-context inference quality."
+            )
+        return self
```

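For reference, the context-driven `eagle_offline` derivation can be exercised directly with a stand-in for the real data args object, assuming `eagle_offline` is exposed as a boolean field on `EagleConfig` (which the removed manual assignment in `main.py` implies). A hedged sketch:

```python
# Hedged sketch: eagle_offline is derived from data_args.offline_data_path via the validation context.
from types import SimpleNamespace

from modelopt.torch.speculative.config import EagleConfig

offline_args = SimpleNamespace(offline_data_path="/data/eagle_offline")  # stand-in for the real DataArgs
online_args = SimpleNamespace(offline_data_path=None)

assert EagleConfig.model_validate({}, context={"data_args": offline_args}).eagle_offline is True
assert EagleConfig.model_validate({}, context={"data_args": online_args}).eagle_offline is False
# Without a context (e.g. inside the convert chain) the field keeps its default.
EagleConfig.model_validate({})
```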
modelopt/torch/speculative/eagle/default_config.py

Lines changed: 7 additions & 9 deletions
```diff
@@ -19,15 +19,10 @@
     "hidden_act": "silu",
     "torch_dtype": "bfloat16",
     "position_embedding_type": "rope",
-    "rope_scaling": {
-        "factor": 8.0,
-        "low_freq_factor": 1.0,
-        "high_freq_factor": 4.0,
-        "original_max_position_embeddings": 8192,
-        "rope_type": "llama3",
-        "rope_theta": 500000.0,
-    },
-    "rope_theta": 500000.0,
+    # rope_theta is set both inside rope_scaling and at the top level for cross-version
+    # compatibility: transformers 5.x reads it from rope_scaling, while 4.x reads it top-level.
+    "rope_scaling": {"rope_type": "default", "rope_theta": 10000},
+    "rope_theta": 10000,
     "num_hidden_layers": 1,
     "intermediate_size": 14336,
     "num_attention_heads": 32,
@@ -83,6 +78,8 @@
     "qk_nope_head_dim": 128,
     "qk_rope_head_dim": 64,
     "rms_norm_eps": 0.00001,
+    # rope_theta is set both inside rope_scaling and at the top level for cross-version
+    # compatibility: transformers 5.x reads it from rope_scaling, while 4.x reads it top-level.
     "rope_scaling": {
         "beta_fast": 1.0,
         "beta_slow": 1.0,
@@ -91,6 +88,7 @@
         "mscale_all_dim": 1.0,
         "original_max_position_embeddings": 4096,
         "type": "yarn",
+        "rope_theta": 50000.0,
     },
     "rope_theta": 50000.0,
     "routed_scaling_factor": 2.827,
```

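The duplicated `rope_theta` above exists only so that both transformers generations can find it. A small illustration of the lookup order on plain dicts (the helper name is made up; no transformers import is needed):

```python
# Illustrative lookup order for rope_theta across config layouts (helper name is made up).
def read_rope_theta(config: dict, default: float = 10000.0) -> float:
    # transformers 4.x layout: top-level key.
    if config.get("rope_theta") is not None:
        return config["rope_theta"]
    # transformers 5.x layout: nested under rope_scaling.
    return (config.get("rope_scaling") or {}).get("rope_theta", default)


assert read_rope_theta({"rope_theta": 500000.0}) == 500000.0
assert read_rope_theta({"rope_scaling": {"rope_type": "default", "rope_theta": 10000}}) == 10000
assert read_rope_theta({}) == 10000.0
```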
modelopt/torch/speculative/eagle/eagle_model.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -41,3 +41,4 @@ def modify(
         self.eagle_mix_hidden_states = config.eagle_mix_hidden_states
         self.eagle_use_torch_compile = config.eagle_use_torch_compile
         self.eagle_enable_nvtx = config.eagle_enable_nvtx
+        self.eagle_export_rope_scaling = config.eagle_export_rope_scaling
```

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -107,12 +107,9 @@ def dict_to_config(
     config.position_embedding_type = architecture_config.get("position_embedding_type")
     config.rotary_percent = 1.0
     config.rotary_base = architecture_config.get("rope_theta")
-    config.rope_scaling = "rope_scaling" in architecture_config
-    config.rope_scaling_factor = (
-        architecture_config.get("rope_scaling").get("factor")
-        if "rope_scaling" in architecture_config
-        else None
-    )
+    _rope_scaling_dict = architecture_config.get("rope_scaling", {})
+    config.rope_scaling = isinstance(_rope_scaling_dict, dict) and "factor" in _rope_scaling_dict
+    config.rope_scaling_factor = _rope_scaling_dict.get("factor") if config.rope_scaling else None

     config.draft_vocab_size = architecture_config.get("draft_vocab_size")
     config.use_input_layernorm_in_first_layer = architecture_config.get(
```

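To see why the guard matters, here is a hedged before/after comparison on a plain dict, mirroring the hunk above; `SimpleNamespace` stands in for the Megatron config object.

```python
# Before/after comparison of the dict_to_config rope_scaling guard (SimpleNamespace stands in for the config).
from types import SimpleNamespace

arch = {"rope_scaling": {"rope_type": "default", "rope_theta": 10000}}  # note: no "factor" key
config = SimpleNamespace()

# Old behavior: the flag was set from key presence alone, so rope_scaling_factor stayed None
# and RotaryEmbedding later divided by None.
config.rope_scaling = "rope_scaling" in arch  # True, even though there is no factor
config.rope_scaling_factor = (
    arch.get("rope_scaling").get("factor") if "rope_scaling" in arch else None
)  # None -> downstream crash

# New behavior: scaling is enabled only when a "factor" is actually present.
_rope_scaling_dict = arch.get("rope_scaling", {})
config.rope_scaling = isinstance(_rope_scaling_dict, dict) and "factor" in _rope_scaling_dict
config.rope_scaling_factor = _rope_scaling_dict.get("factor") if config.rope_scaling else None
assert config.rope_scaling is False and config.rope_scaling_factor is None
```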
modelopt_recipes/general/speculative_decoding/eagle3.yaml

Lines changed: 6 additions & 0 deletions
```diff
@@ -55,5 +55,11 @@ eagle:
   eagle_reuse_base_decoder: false
   eagle_report_acc: true
   eagle_enable_nvtx: false
+  # Rope scaling: disable during training (default_config.py uses rope_type=default),
+  # inject YaRN during export for long-context inference.
+  eagle_export_rope_scaling:
+    rope_type: yarn
+    factor: 32.0
+    original_max_position_embeddings: 2048
   # overwrite to modelopt/torch/speculative/eagle/default_config.py
   eagle_architecture_config: {}
```
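As a hedged sanity check, the `eagle:` block of this recipe can be fed straight through the new validators, assuming the recipe loads with `yaml.safe_load` and `eagle` is a top-level key (as the hunk header suggests):

```python
# Hedged sketch: validate the recipe's eagle block with EagleConfig (file path as in this repo).
import yaml

from modelopt.torch.speculative.config import EagleConfig

with open("modelopt_recipes/general/speculative_decoding/eagle3.yaml") as f:
    recipe = yaml.safe_load(f)

eagle_cfg = EagleConfig.model_validate(recipe["eagle"]).model_dump()
assert eagle_cfg["eagle_export_rope_scaling"]["rope_type"] == "yarn"
```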
tests/unit/torch/export/test_hf_spec_rope_export.py (new file)

Lines changed: 73 additions & 0 deletions

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for EAGLE export rope scaling logic in hf_spec_export.py."""

from unittest.mock import MagicMock

from modelopt.torch.export.plugins.hf_spec_export import EagleExporter

DEFAULT_ROPE_SCALING = {
    "rope_type": "yarn",
    "factor": 32.0,
    "original_max_position_embeddings": 2048,
}


def _make_exporter(
    rope_type="default",
    rope_theta=10000,
    eagle_export_rope_scaling=None,
):
    if eagle_export_rope_scaling is None:
        eagle_export_rope_scaling = DEFAULT_ROPE_SCALING

    model = MagicMock()
    model.eagle_config.eagle_decoder_type = "llama"
    model.eagle_config.rope_scaling = {"rope_type": rope_type, "rope_theta": rope_theta}
    model.eagle_export_rope_scaling = eagle_export_rope_scaling
    model._draft_model_config = None
    model.config.rope_scaling = None
    model.config.rope_theta = None

    exporter = EagleExporter.__new__(EagleExporter)
    exporter.model = model
    exporter.eagle_decoder_type = "llama"
    exporter.num_hidden_layers = 1
    return exporter


def test_yarn_rope_injected_with_correct_config():
    """YaRN rope_scaling is injected as-is when training rope_type is 'default'."""
    config = _make_exporter(rope_type="default")._export_config()
    assert config["rope_scaling"] == DEFAULT_ROPE_SCALING


def test_rope_not_injected_when_non_default_training_rope():
    """rope_scaling is not overridden when training rope_type is not 'default'."""
    config = _make_exporter(rope_type="llama3")._export_config()
    assert config.get("rope_scaling") is None


def test_rope_not_injected_when_eagle_export_rope_scaling_is_empty():
    """rope_scaling is not injected when eagle_export_rope_scaling is empty dict."""
    config = _make_exporter(eagle_export_rope_scaling={})._export_config()
    assert config.get("rope_scaling") is None


def test_rope_theta_fallback_from_rope_scaling():
    """rope_theta is populated from rope_scaling when not available as top-level attr."""
    config = _make_exporter(rope_type="default", rope_theta=500000)._export_config()
    assert config["rope_theta"] == 500000
```
tests/unit/torch/speculative/test_eagle_config.py (new file)

Lines changed: 129 additions & 0 deletions

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for EagleConfig model validators."""

import types
import warnings

import pytest
from pydantic import ValidationError

from modelopt.torch.speculative.config import EagleConfig

# --- rope scaling consistency validator tests ---


def test_rope_consistency_error_non_default_rope_type():
    """Error when eagle_export_rope_scaling is set but training rope_type is not 'default'."""
    cfg = {
        "eagle_export_rope_scaling": {"rope_type": "yarn", "factor": 32.0},
        "eagle_architecture_config": {"rope_scaling": {"rope_type": "llama3"}},
    }
    with pytest.raises(ValidationError, match="rope_type='llama3'"):
        EagleConfig.model_validate(cfg)


def test_rope_consistency_error_non_default_rope_type_alt_key():
    """Error when rope_scaling uses 'type' key instead of 'rope_type' (kimik2-style)."""
    cfg = {
        "eagle_export_rope_scaling": {"rope_type": "yarn", "factor": 32.0},
        "eagle_architecture_config": {"rope_scaling": {"type": "yarn"}},
    }
    with pytest.raises(ValidationError, match="rope_type='yarn'"):
        EagleConfig.model_validate(cfg)


def test_rope_consistency_ok_default_rope_type():
    """No error when training rope_type is 'default'."""
    cfg = {
        "eagle_export_rope_scaling": {"rope_type": "yarn", "factor": 32.0},
        "eagle_architecture_config": {"rope_scaling": {"rope_type": "default"}},
    }
    EagleConfig.model_validate(cfg)


def test_rope_consistency_ok_no_rope_scaling_in_arch():
    """No error when eagle_architecture_config has no rope_scaling (defaults to 'default')."""
    cfg = {
        "eagle_export_rope_scaling": {"rope_type": "yarn", "factor": 32.0},
        "eagle_architecture_config": {},
    }
    EagleConfig.model_validate(cfg)


def test_rope_consistency_ok_empty_export_rope():
    """No error when eagle_export_rope_scaling is empty (disabled)."""
    cfg = {
        "eagle_export_rope_scaling": {},
        "eagle_architecture_config": {"rope_scaling": {"rope_type": "llama3"}},
    }
    EagleConfig.model_validate(cfg)


# --- rope vs training_seq_len warning tests ---


def _make_training_args(training_seq_len: int):
    return types.SimpleNamespace(training_seq_len=training_seq_len)


def test_warn_rope_mismatch():
    """Warning should fire when original_max_position_embeddings != training_seq_len."""
    cfg = {
        "eagle_export_rope_scaling": {
            "rope_type": "yarn",
            "factor": 32.0,
            "original_max_position_embeddings": 2048,
        },
    }
    with pytest.warns(UserWarning, match="differs from training_seq_len"):
        EagleConfig.model_validate(cfg, context={"training_args": _make_training_args(4096)})


def test_no_warn_rope_match():
    """No warning when original_max_position_embeddings == training_seq_len."""
    cfg = {
        "eagle_export_rope_scaling": {
            "rope_type": "yarn",
            "factor": 32.0,
            "original_max_position_embeddings": 2048,
        },
    }
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        EagleConfig.model_validate(cfg, context={"training_args": _make_training_args(2048)})


def test_no_warn_without_context():
    """No warning when context is not provided (e.g. inside convert chain)."""
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        EagleConfig.model_validate({})


def test_no_warn_missing_orig_max_pos():
    """No warning when original_max_position_embeddings is absent from rope scaling config."""
    cfg = {"eagle_export_rope_scaling": {}}
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        EagleConfig.model_validate(cfg, context={"training_args": _make_training_args(4096)})


def test_no_warn_empty_context():
    """No warning when context dict has no training_args key."""
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        EagleConfig.model_validate({}, context={})
```
