Skip to content

Commit 070f215

Browse files
committed
cleaner code
Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
1 parent e3c9e50 commit 070f215

2 files changed

Lines changed: 45 additions & 26 deletions

File tree

modelopt/recipe/loader.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,10 @@ def _load_recipe_from_file(recipe_file: Path | Traversable) -> ModelOptRecipeBas
8888
plus a ``quant_cfg`` mapping and an optional ``algorithm`` for PTQ recipes.
8989
"""
9090
raw = _load_raw_config(recipe_file)
91-
assert isinstance(raw, dict), f"Recipe file {recipe_file} must be a YAML mapping."
91+
if not isinstance(raw, dict):
92+
raise ValueError(
93+
f"Recipe file {recipe_file} must be a YAML mapping, got {type(raw).__name__}."
94+
)
9295
data = _resolve_imports(raw)
9396

9497
metadata = data.get("metadata", {})
@@ -121,7 +124,10 @@ def _load_recipe_from_dir(recipe_dir: Path | Traversable) -> ModelOptRecipeBase:
121124
)
122125

123126
recipe_data = _load_raw_config(recipe_file)
124-
assert isinstance(recipe_data, dict), f"Recipe file {recipe_file} must be a YAML mapping."
127+
if not isinstance(recipe_data, dict):
128+
raise ValueError(
129+
f"Recipe file {recipe_file} must be a YAML mapping, got {type(recipe_data).__name__}."
130+
)
125131
metadata = recipe_data.get("metadata", {})
126132
recipe_type = metadata.get("recipe_type")
127133
if recipe_type is None:
@@ -138,14 +144,21 @@ def _load_recipe_from_dir(recipe_dir: Path | Traversable) -> ModelOptRecipeBase:
138144
raise ValueError(
139145
f"Cannot find quantize in {recipe_dir}. Looked for: quantize.yml, quantize.yaml"
140146
)
141-
# Resolve imports: imports are in recipe.yml, quantize data is separate
147+
# Resolve imports from both recipe.yaml and quantize.yaml
142148
quantize_data = _load_raw_config(quantize_file)
143-
assert isinstance(quantize_data, dict), f"{quantize_file} must be a YAML mapping."
149+
if not isinstance(quantize_data, dict):
150+
raise ValueError(
151+
f"{quantize_file} must be a YAML mapping, got {type(quantize_data).__name__}."
152+
)
153+
# Resolve quantize.yaml's own imports first (if any)
154+
if "imports" in quantize_data:
155+
quantize_data = _resolve_imports(quantize_data)
156+
# Then resolve recipe.yaml's imports applied to the quantize data
144157
combined: dict[str, Any] = {"quantize": quantize_data}
145158
imports = recipe_data.get("imports")
146159
if imports:
147160
combined["imports"] = imports
148-
combined = _resolve_imports(combined)
161+
combined = _resolve_imports(combined)
149162
return ModelOptPTQRecipe(
150163
recipe_type=RecipeType.PTQ,
151164
description=metadata.get("description", "PTQ recipe."),

modelopt/torch/opt/config_loader.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,28 @@ def _lookup(ref_name: str, context: str) -> Any:
198198
)
199199
return import_map[ref_name]
200200

201+
def _resolve_dict_value(d: dict[str, Any], key: str) -> None:
202+
"""Resolve ``$import`` in a dict value: ``key: {$import: name, ...inline}``."""
203+
val = d[key]
204+
if not isinstance(val, dict) or _IMPORT_KEY not in val:
205+
return
206+
ref = val.pop(_IMPORT_KEY)
207+
inline_keys = dict(val)
208+
ref_names = ref if isinstance(ref, list) else [ref]
209+
210+
merged: dict[str, Any] = {}
211+
for rname in ref_names:
212+
snippet = _lookup(rname, f"{key} of {d}")
213+
if not isinstance(snippet, dict):
214+
raise ValueError(
215+
f"$import {rname!r} in {key} must resolve to a dict, "
216+
f"got {type(snippet).__name__}."
217+
)
218+
merged.update(snippet)
219+
220+
merged.update(inline_keys)
221+
d[key] = merged
222+
201223
def _resolve_list(entries: list[Any]) -> list[Any]:
202224
"""Resolve $import markers in a list of entries."""
203225
resolved: list[Any] = []
@@ -215,27 +237,11 @@ def _resolve_list(entries: list[Any]) -> list[Any]:
215237
f"list, got {type(imported).__name__}."
216238
)
217239
resolved.extend(imported)
218-
elif (
219-
isinstance(entry, dict)
220-
and isinstance(entry.get("cfg"), dict)
221-
and _IMPORT_KEY in entry["cfg"]
222-
):
223-
ref = entry["cfg"].pop(_IMPORT_KEY)
224-
inline_keys = dict(entry["cfg"])
225-
ref_names = ref if isinstance(ref, list) else [ref]
226-
227-
merged: dict[str, Any] = {}
228-
for rname in ref_names:
229-
snippet = _lookup(rname, f"cfg of {entry}")
230-
if not isinstance(snippet, dict):
231-
raise ValueError(
232-
f"$import {rname!r} in cfg must resolve to a dict, "
233-
f"got {type(snippet).__name__}."
234-
)
235-
merged.update(snippet)
236-
237-
merged.update(inline_keys)
238-
entry["cfg"] = merged
240+
elif isinstance(entry, dict):
241+
# Resolve $import in any dict value within the entry
242+
for key in list(entry):
243+
if isinstance(entry.get(key), dict) and _IMPORT_KEY in entry[key]:
244+
_resolve_dict_value(entry, key)
239245
resolved.append(entry)
240246
else:
241247
resolved.append(entry)

0 commit comments

Comments
 (0)