|
49 | 49 |
|
50 | 50 | import modelopt.torch.opt as mto |
51 | 51 | import modelopt.torch.speculative as mtsp |
| 52 | +from modelopt.recipe import load_config |
52 | 53 | from modelopt.torch.speculative.config import DFlashConfig, EagleConfig |
53 | 54 | from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading |
54 | 55 | from modelopt.torch.utils import print_rank_0 |
@@ -167,10 +168,14 @@ def _load_config(config_path: str, overrides: list[str] = ()) -> tuple[dict, dic |
167 | 168 | eagle_cfg: Eagle section dict (EagleConfig fields), passed directly to mtsp.convert() |
168 | 169 | dflash_cfg: DFlash section dict (DFlashConfig fields), passed directly to mtsp.convert() |
169 | 170 | """ |
170 | | - merged = OmegaConf.load(config_path) |
| 171 | + # Resolve $import / imports: via modelopt's loader, then layer OmegaConf |
| 172 | + # dotlist overrides on top. |
| 173 | + cfg = load_config(config_path) |
| 174 | + assert isinstance(cfg, dict), f"Top-level recipe must be a YAML mapping: {config_path}" |
171 | 175 | if overrides: |
172 | | - merged = OmegaConf.merge(merged, OmegaConf.from_dotlist(list(overrides))) |
173 | | - cfg = OmegaConf.to_container(merged, resolve=True) |
| 176 | + merged = OmegaConf.merge(OmegaConf.create(cfg), OmegaConf.from_dotlist(list(overrides))) |
| 177 | + cfg = OmegaConf.to_container(merged, resolve=True) |
| 178 | + assert isinstance(cfg, dict) |
174 | 179 |
|
175 | 180 | # Eagle/DFlash sections map directly to config fields — no field enumeration needed. |
176 | 181 | eagle_cfg = cfg.get("eagle", {}) |
@@ -318,8 +323,15 @@ def train(): |
318 | 323 | model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True) |
319 | 324 | print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.") |
320 | 325 | elif training_args.mode == "dflash": |
| 326 | + # Mask-token resolution: recipe value wins; otherwise fall back to the |
| 327 | + # tokenizer's built-in mask_token_id. DFlashConfig still raises if neither |
| 328 | + # source provides one. |
| 329 | + if dflash_cfg.get("dflash_mask_token_id") is None: |
| 330 | + tok_mask_id = getattr(tokenizer, "mask_token_id", None) |
| 331 | + if tok_mask_id is not None: |
| 332 | + dflash_cfg["dflash_mask_token_id"] = tok_mask_id |
321 | 333 | dflash_cfg = DFlashConfig.model_validate( |
322 | | - dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args} |
| 334 | + dflash_cfg, context={"data_args": data_args} |
323 | 335 | ).model_dump() |
324 | 336 | mtsp.convert(model, [("dflash", dflash_cfg)]) |
325 | 337 | else: |
|
0 commit comments