Skip to content

Commit 1420299

Browse files
committed
use recipe import system
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent ed1d6a9 commit 1420299

10 files changed

Lines changed: 225 additions & 101 deletions

File tree

examples/speculative_decoding/main.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050
import modelopt.torch.opt as mto
5151
import modelopt.torch.speculative as mtsp
52+
from modelopt.recipe import load_config
5253
from modelopt.torch.speculative.config import DFlashConfig, EagleConfig
5354
from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading
5455
from modelopt.torch.utils import print_rank_0
@@ -167,10 +168,14 @@ def _load_config(config_path: str, overrides: list[str] = ()) -> tuple[dict, dic
167168
eagle_cfg: Eagle section dict (EagleConfig fields), passed directly to mtsp.convert()
168169
dflash_cfg: DFlash section dict (DFlashConfig fields), passed directly to mtsp.convert()
169170
"""
170-
merged = OmegaConf.load(config_path)
171+
# Resolve $import / imports: via modelopt's loader, then layer OmegaConf
172+
# dotlist overrides on top.
173+
cfg = load_config(config_path)
174+
assert isinstance(cfg, dict), f"Top-level recipe must be a YAML mapping: {config_path}"
171175
if overrides:
172-
merged = OmegaConf.merge(merged, OmegaConf.from_dotlist(list(overrides)))
173-
cfg = OmegaConf.to_container(merged, resolve=True)
176+
merged = OmegaConf.merge(OmegaConf.create(cfg), OmegaConf.from_dotlist(list(overrides)))
177+
cfg = OmegaConf.to_container(merged, resolve=True)
178+
assert isinstance(cfg, dict)
174179

175180
# Eagle/DFlash sections map directly to config fields — no field enumeration needed.
176181
eagle_cfg = cfg.get("eagle", {})
@@ -318,8 +323,15 @@ def train():
318323
model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True)
319324
print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.")
320325
elif training_args.mode == "dflash":
326+
# Mask-token resolution: recipe value wins; otherwise fall back to the
327+
# tokenizer's built-in mask_token_id. DFlashConfig still raises if neither
328+
# source provides one.
329+
if dflash_cfg.get("dflash_mask_token_id") is None:
330+
tok_mask_id = getattr(tokenizer, "mask_token_id", None)
331+
if tok_mask_id is not None:
332+
dflash_cfg["dflash_mask_token_id"] = tok_mask_id
321333
dflash_cfg = DFlashConfig.model_validate(
322-
dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args}
334+
dflash_cfg, context={"data_args": data_args}
323335
).model_dump()
324336
mtsp.convert(model, [("dflash", dflash_cfg)])
325337
else:

modelopt/torch/speculative/config.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@
2525

2626
from .eagle.default_config import default_eagle_config, default_kimik2_eagle_config
2727

28+
# Permissive schema for `model:` / `data:` / `training:` recipe snippets used
29+
# via $import in modelopt_recipes/configs/speculative_decoding/. Real field
30+
# validation happens downstream in transformers.HfArgumentParser.parse_dict()
31+
# (examples/speculative_decoding/main.py); this alias exists so snippets can
32+
# satisfy load_config()'s requirement that modelopt-schema paths resolve under
33+
# the modelopt.* namespace.
34+
SpeculativeArgsSnippet = dict
35+
2836
kimik2_eagle_default_config = deepcopy(default_kimik2_eagle_config)
2937

3038
eagle3_default_config = deepcopy(default_eagle_config)
@@ -132,18 +140,6 @@ def _derive_dflash_offline(cls, data: Any, info: ValidationInfo) -> Any:
132140
data["dflash_offline"] = getattr(data_args, "offline_data_path", None) is not None
133141
return data
134142

135-
@model_validator(mode="before")
136-
@classmethod
137-
def _resolve_mask_token_id(cls, data: Any, info: ValidationInfo) -> Any:
138-
"""Auto-detect ``dflash_mask_token_id`` from tokenizer when provided in context."""
139-
if not isinstance(data, dict) or data.get("dflash_mask_token_id") is not None:
140-
return data
141-
ctx = info.context if info.context else {}
142-
tokenizer = ctx.get("tokenizer")
143-
if tokenizer is not None and getattr(tokenizer, "mask_token_id", None) is not None:
144-
data["dflash_mask_token_id"] = tokenizer.mask_token_id
145-
return data
146-
147143
@model_validator(mode="after")
148144
def _check_mask_token_id(self) -> "DFlashConfig":
149145
"""Validate that mask_token_id is set after all resolution attempts."""
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# Default DFlashConfig values for DFlash training. Imported into the `dflash:`
5+
# section of recipes. ``dflash_mask_token_id`` is intentionally omitted, which is
6+
# why this snippet uses the permissive ``SpeculativeArgsSnippet`` schema
7+
# (DFlashConfig's after-validator would raise during snippet load without it).
8+
# Per-model recipes should set ``dflash_mask_token_id`` explicitly; if the recipe
9+
# leaves it unset, main.py falls back to ``tokenizer.mask_token_id``.
10+
11+
# modelopt-schema: modelopt.torch.speculative.config.SpeculativeArgsSnippet
12+
dflash_block_size: 8
13+
dflash_num_anchors: 512
14+
dflash_use_torch_compile: false
15+
dflash_self_logit_distillation: true
16+
dflash_loss_decay_factor: 4.0
17+
dflash_architecture_config:
18+
num_hidden_layers: 5
19+
# mask_token_id: auto-detected from model vocab (override for specific models)
20+
# sliding_window and layer_types are inherited from base model config automatically
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# Default `training:` section values for DFlash training. Imported into the
5+
# `training:` section of recipes. Real field validation is performed by
6+
# transformers.HfArgumentParser.parse_dict() in main.py.
7+
8+
# modelopt-schema: modelopt.torch.speculative.config.SpeculativeArgsSnippet
9+
10+
# --- commonly modified ---
11+
mode: dflash
12+
output_dir:
13+
num_train_epochs: 10
14+
per_device_train_batch_size: 1
15+
learning_rate: 6.0e-4
16+
warmup_steps: 100
17+
training_seq_len: 4096
18+
logging_steps: 100
19+
save_steps: 5000
20+
cp_size: 1
21+
dp_shard_size: 1
22+
disable_tqdm: true
23+
estimate_ar: false
24+
ar_validate_steps: 0
25+
answer_only_loss: true
26+
27+
# --- rarely modified ---
28+
do_eval: false
29+
lr_scheduler_type: linear
30+
save_strategy: steps
31+
weight_decay: 0.0
32+
dataloader_drop_last: true
33+
bf16: true
34+
tf32: true
35+
remove_unused_columns: false
36+
ddp_find_unused_parameters: true
37+
ddp_timeout: 1800
38+
report_to: tensorboard
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# Default EagleConfig values for EAGLE3 training. Imported into the `eagle:` section of recipes.
17+
# eagle_offline is derived from data.offline_data_path; do not set here.
18+
19+
# modelopt-schema: modelopt.torch.speculative.config.EagleConfig
20+
eagle_decoder_type: llama
21+
eagle_ttt_steps: 3
22+
eagle_mix_hidden_states: false
23+
eagle_use_torch_compile: true
24+
eagle_self_logit_distillation: true
25+
eagle_freeze_base_model: true
26+
eagle_loss_decay_factor: 0.9
27+
eagle_hidden_state_distillation: false
28+
eagle_reuse_base_decoder: false
29+
eagle_report_acc: true
30+
eagle_enable_nvtx: false
31+
# Rope scaling: disable during training (default_config.py uses rope_type=default),
32+
# inject YaRN during export for long-context inference.
33+
eagle_export_rope_scaling:
34+
rope_type: yarn
35+
factor: 32.0
36+
original_max_position_embeddings: 2048
37+
# Overrides applied on top of the defaults in modelopt/torch/speculative/eagle/default_config.py
38+
eagle_architecture_config: {}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# Default `training:` section values for EAGLE3 training. Imported into the
5+
# `training:` section of recipes. Real field validation is performed by
6+
# transformers.HfArgumentParser.parse_dict() in main.py.
7+
8+
# modelopt-schema: modelopt.torch.speculative.config.SpeculativeArgsSnippet
9+
10+
# --- commonly modified ---
11+
mode: eagle3
12+
output_dir:
13+
num_train_epochs: 1
14+
per_device_train_batch_size: 1
15+
learning_rate: 1.0e-4
16+
warmup_steps: 1000
17+
training_seq_len: 2048
18+
logging_steps: 100
19+
save_steps: 8192
20+
cp_size: 1
21+
disable_tqdm: false
22+
estimate_ar: false
23+
ar_validate_steps: -1
24+
answer_only_loss: false
25+
26+
# --- rarely modified ---
27+
do_eval: false
28+
lr_scheduler_type: linear
29+
save_strategy: steps
30+
weight_decay: 0.0
31+
dataloader_drop_last: true
32+
bf16: true
33+
tf32: true
34+
remove_unused_columns: false
Lines changed: 11 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
# Base config for DFlash training. Override fields via OmegaConf dotlist on the CLI.
1+
# Base config for DFlash training. Override fields via OmegaConf dotlist on the CLI
2+
# or by importing this file from a per-model recipe in modelopt_recipes/models/.
3+
4+
imports:
5+
dflash_default: configs/speculative_decoding/dflash/default
6+
dflash_training_default: configs/speculative_decoding/dflash/training_default
27

38
# maps to ModelArguments (main.py)
49
model:
@@ -17,44 +22,11 @@ data:
1722

1823
# maps to TrainingArguments (main.py)
1924
training:
20-
# --- commonly modified ---
21-
mode: dflash
22-
output_dir:
23-
num_train_epochs: 10
24-
per_device_train_batch_size: 1
25-
learning_rate: 6.0e-4
26-
warmup_steps: 100
27-
training_seq_len: 4096
28-
logging_steps: 100
29-
save_steps: 5000
30-
cp_size: 1
31-
dp_shard_size: 1
32-
disable_tqdm: true
33-
estimate_ar: false
34-
ar_validate_steps: 0
35-
answer_only_loss: true
36-
37-
# --- rarely modified ---
38-
do_eval: false
39-
lr_scheduler_type: linear
40-
save_strategy: steps
41-
weight_decay: 0.0
42-
dataloader_drop_last: true
43-
bf16: true
44-
tf32: true
45-
remove_unused_columns: false
46-
ddp_find_unused_parameters: true
47-
ddp_timeout: 1800
48-
report_to: tensorboard
25+
$import: dflash_training_default
4926

5027
# maps to DFlashConfig (modelopt/torch/speculative/config.py).
28+
# Per-model recipes should also set ``dflash_mask_token_id``; otherwise main.py
29+
# falls back to ``tokenizer.mask_token_id``, and DFlashConfig raises if neither
30+
# source provides one.
5131
dflash:
52-
dflash_block_size: 8
53-
dflash_num_anchors: 512
54-
dflash_use_torch_compile: false
55-
dflash_self_logit_distillation: true
56-
dflash_loss_decay_factor: 4.0
57-
dflash_architecture_config:
58-
num_hidden_layers: 5
59-
# mask_token_id: auto-detected from model vocab (override for specific models)
60-
# sliding_window and layer_types are inherited from base model config automatically
32+
$import: dflash_default
Lines changed: 8 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
# Base config for EAGLE3 training. Override fields via OmegaConf dotlist on the CLI.
1+
# Base config for EAGLE3 training. Override fields via OmegaConf dotlist on the CLI
2+
# or by importing this file from a per-model recipe in modelopt_recipes/models/.
3+
4+
imports:
5+
eagle_default: configs/speculative_decoding/eagle/default
6+
eagle_training_default: configs/speculative_decoding/eagle/training_default
27

38
# maps to ModelArguments (main.py)
49
model:
@@ -16,51 +21,8 @@ data:
1621

1722
# maps to TrainingArguments (main.py)
1823
training:
19-
# --- commonly modified ---
20-
mode: eagle3
21-
output_dir:
22-
num_train_epochs: 1
23-
per_device_train_batch_size: 1
24-
learning_rate: 1.0e-4
25-
warmup_steps: 1000
26-
training_seq_len: 2048
27-
logging_steps: 100
28-
save_steps: 8192
29-
cp_size: 1
30-
disable_tqdm: false
31-
estimate_ar: false
32-
ar_validate_steps: -1
33-
answer_only_loss: false
34-
35-
# --- rarely modified ---
36-
do_eval: false
37-
lr_scheduler_type: linear
38-
save_strategy: steps
39-
weight_decay: 0.0
40-
dataloader_drop_last: true
41-
bf16: true
42-
tf32: true
43-
remove_unused_columns: false
24+
$import: eagle_training_default
4425

4526
# maps to EagleConfig (modelopt/torch/speculative/config.py).
4627
eagle:
47-
# eagle_offline is derived from data.offline_data_path; do not set here.
48-
eagle_decoder_type: llama
49-
eagle_ttt_steps: 3
50-
eagle_mix_hidden_states: false
51-
eagle_use_torch_compile: true
52-
eagle_self_logit_distillation: true
53-
eagle_freeze_base_model: true
54-
eagle_loss_decay_factor: 0.9
55-
eagle_hidden_state_distillation: false
56-
eagle_reuse_base_decoder: false
57-
eagle_report_acc: true
58-
eagle_enable_nvtx: false
59-
# Rope scaling: disable during training (default_config.py uses rope_type=default),
60-
# inject YaRN during export for long-context inference.
61-
eagle_export_rope_scaling:
62-
rope_type: yarn
63-
factor: 32.0
64-
original_max_position_embeddings: 2048
65-
# overwrite to modelopt/torch/speculative/eagle/default_config.py
66-
eagle_architecture_config: {}
28+
$import: eagle_default
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# Per-model DFlash offline training recipe for Kimi-K2.5.
5+
6+
imports:
7+
dflash_default: configs/speculative_decoding/dflash/default
8+
dflash_training_default: configs/speculative_decoding/dflash/training_default
9+
10+
model:
11+
model_name_or_path: moonshotai/Kimi-K2.5
12+
trust_remote_code: true
13+
use_fake_base_for_offline: true
14+
15+
data:
16+
offline_data_path: <path to offline data>
17+
18+
training:
19+
$import: dflash_training_default
20+
output_dir: ckpts/kimi-k25-dflash
21+
22+
dflash:
23+
$import: dflash_default
24+
# If unset, main.py falls back to tokenizer.mask_token_id; DFlashConfig
25+
# raises if neither this field nor the tokenizer provides one.
26+
# dflash_mask_token_id:
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# Per-model EAGLE3 offline training recipe for Kimi-K2.5.
5+
# Mirrors examples/speculative_decoding/scripts/train_kimi_k25_offline.sh.
6+
7+
imports:
8+
eagle_default: configs/speculative_decoding/eagle/default
9+
eagle_training_default: configs/speculative_decoding/eagle/training_default
10+
11+
model:
12+
model_name_or_path: moonshotai/Kimi-K2.5
13+
trust_remote_code: true
14+
use_fake_base_for_offline: true
15+
16+
data:
17+
offline_data_path: <path to offline data>
18+
19+
training:
20+
$import: eagle_training_default
21+
output_dir: ckpts/kimi-k25-eagle3
22+
training_seq_len: 4096
23+
24+
eagle:
25+
$import: eagle_default
26+
eagle_decoder_type: kimik2

0 commit comments

Comments
 (0)