Skip to content

Commit 73871da

Browse files
committed
support import for recipe snippets
Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
1 parent 613f3d1 commit 73871da

9 files changed

Lines changed: 149 additions & 68 deletions

File tree

modelopt/recipe/_config_loader.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,39 @@ def load_config(config_file: str | Path | Traversable) -> dict[str, Any] | list[
103103
f"Cannot find config file of {config_file}, paths checked: {paths_to_check}"
104104
)
105105

106-
_raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
107-
if _raw is None:
106+
text = config_path.read_text(encoding="utf-8")
107+
docs = list(yaml.safe_load_all(text))
108+
109+
if len(docs) == 0 or docs[0] is None:
108110
return {}
111+
if len(docs) == 1:
112+
_raw = docs[0]
113+
elif len(docs) == 2:
114+
# Multi-document: first doc is imports/metadata, second is content.
115+
# Merge the imports into the content for downstream resolution.
116+
header, content = docs[0], docs[1]
117+
if not isinstance(header, dict):
118+
raise ValueError(
119+
f"Config file {config_path}: first YAML document must be a mapping, "
120+
f"got {type(header).__name__}"
121+
)
122+
if content is None:
123+
content = {}
124+
if isinstance(content, dict):
125+
_raw = {**header, **content}
126+
elif isinstance(content, list):
127+
# List content with a header dict — attach imports via wrapper
128+
_raw = {**header, "_list_content": content}
129+
else:
130+
raise ValueError(
131+
f"Config file {config_path}: second YAML document must be a mapping or list, "
132+
f"got {type(content).__name__}"
133+
)
134+
else:
135+
raise ValueError(
136+
f"Config file {config_path}: expected 1 or 2 YAML documents, got {len(docs)}"
137+
)
138+
109139
if not isinstance(_raw, (dict, list)):
110140
raise ValueError(
111141
f"Config file {config_path} must contain a YAML mapping or list, got {type(_raw).__name__}"

modelopt/recipe/loader.py

Lines changed: 57 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ def _resolve_imports(
7979
snippet = load_config(config_path)
8080
if isinstance(snippet, dict) and "imports" in snippet:
8181
snippet = _resolve_imports(snippet, _loading | {config_path})
82+
# Unwrap _list_content (multi-document YAML: imports + list content)
83+
if isinstance(snippet, dict) and "_list_content" in snippet:
84+
snippet = snippet["_list_content"]
8285
import_map[name] = snippet
8386

8487
def _lookup(ref_name: str, context: str) -> Any:
@@ -89,58 +92,65 @@ def _lookup(ref_name: str, context: str) -> Any:
8992
)
9093
return import_map[ref_name]
9194

95+
def _resolve_list(entries: list[Any]) -> list[Any]:
    """Resolve ``$import`` markers in a list of quant_cfg-style entries.

    Two forms are recognised:

    * ``{$import: name}`` as a whole list entry: the imported value must be a
      list and its items are spliced into the result at that position.
    * ``{..., cfg: {$import: name_or_list, ...inline}}``: each imported value
      must be a dict; they are merged in order and the inline keys are applied
      last.

    Any other entry is passed through unchanged.

    The input entries are never modified: an entry whose ``cfg`` is resolved is
    emitted as a new dict. Spliced list items and pass-through entries are
    still shared by reference with their source.

    Args:
        entries: Raw list entries, possibly containing ``$import`` markers.

    Returns:
        A new list with every recognised ``$import`` marker resolved.

    Raises:
        ValueError: If a splice entry carries keys besides ``$import``, if a
            spliced import is not a list, or if a ``cfg`` import is not a dict.
    """
    resolved: list[Any] = []
    for entry in entries:
        if not isinstance(entry, dict):
            resolved.append(entry)
            continue

        if _IMPORT_KEY in entry:
            # {$import: name} → splice imported list
            if len(entry) > 1:
                raise ValueError(
                    f"$import must be the only key in the dict, got extra keys: "
                    f"{sorted(k for k in entry if k != _IMPORT_KEY)}"
                )
            imported = _lookup(entry[_IMPORT_KEY], "list entry")
            if not isinstance(imported, list):
                raise ValueError(
                    f"$import {entry[_IMPORT_KEY]!r} in list must resolve to a "
                    f"list, got {type(imported).__name__}."
                )
            resolved.extend(imported)
            continue

        cfg = entry.get("cfg")
        if not (isinstance(cfg, dict) and _IMPORT_KEY in cfg):
            resolved.append(entry)
            continue

        # cfg: {$import: name_or_list, ...inline} → import then override
        #
        # Precedence (lowest → highest):
        #   1. Imports in list order (later imports override earlier)
        #   2. Inline keys (override all imports)
        ref = cfg[_IMPORT_KEY]
        inline_keys = {k: v for k, v in cfg.items() if k != _IMPORT_KEY}
        ref_names = ref if isinstance(ref, list) else [ref]

        # Describe the entry without its $import marker so the lookup context
        # reads the same as when the marker was popped from cfg beforehand.
        context_entry = {**entry, "cfg": inline_keys}

        merged: dict[str, Any] = {}
        for name in ref_names:
            snippet = _lookup(name, f"cfg of {context_entry}")
            if not isinstance(snippet, dict):
                raise ValueError(
                    f"$import {name!r} in cfg must resolve to a dict, "
                    f"got {type(snippet).__name__}."
                )
            merged.update(snippet)

        merged.update(inline_keys)
        # Emit a fresh entry rather than rewriting the caller's dict in place.
        resolved.append({**entry, "cfg": merged})
    return resolved
143+
92144
# Resolve $import references in quant_cfg entries
93145
quantize = data.get("quantize")
94146
if isinstance(quantize, dict):
95147
quant_cfg = quantize.get("quant_cfg")
96148
if isinstance(quant_cfg, list):
97-
resolved_cfg: list[Any] = []
98-
for entry in quant_cfg:
99-
if isinstance(entry, dict) and _IMPORT_KEY in entry:
100-
# {$import: name} → splice imported list into quant_cfg
101-
if len(entry) > 1:
102-
raise ValueError(
103-
f"$import must be the only key in the dict, got extra keys: "
104-
f"{sorted(k for k in entry if k != _IMPORT_KEY)}"
105-
)
106-
imported = _lookup(entry[_IMPORT_KEY], "quant_cfg entry")
107-
if not isinstance(imported, list):
108-
raise ValueError(
109-
f"$import {entry[_IMPORT_KEY]!r} in quant_cfg must resolve to a "
110-
f"list, got {type(imported).__name__}. Config snippets used as "
111-
f"quant_cfg entries must be YAML lists."
112-
)
113-
resolved_cfg.extend(imported)
114-
elif (
115-
isinstance(entry, dict)
116-
and isinstance(entry.get("cfg"), dict)
117-
and _IMPORT_KEY in entry["cfg"]
118-
):
119-
# cfg: {$import: name_or_list, ...inline} → import then override
120-
#
121-
# Precedence (lowest → highest):
122-
# 1. Imports in list order (later imports override earlier)
123-
# 2. Inline keys (override all imports)
124-
ref = entry["cfg"].pop(_IMPORT_KEY)
125-
inline_keys = dict(entry["cfg"]) # remaining inline keys
126-
ref_names = ref if isinstance(ref, list) else [ref]
127-
128-
merged: dict[str, Any] = {}
129-
for name in ref_names:
130-
snippet = _lookup(name, f"cfg of {entry}")
131-
if not isinstance(snippet, dict):
132-
raise ValueError(
133-
f"$import {name!r} in cfg must resolve to a dict, "
134-
f"got {type(snippet).__name__}."
135-
)
136-
merged.update(snippet)
137-
138-
merged.update(inline_keys)
139-
entry["cfg"] = merged
140-
resolved_cfg.append(entry)
141-
else:
142-
resolved_cfg.append(entry)
143-
quantize["quant_cfg"] = resolved_cfg
149+
quantize["quant_cfg"] = _resolve_list(quant_cfg)
150+
151+
# Resolve $import references in _list_content (multi-document snippets)
152+
if "_list_content" in data:
153+
data["_list_content"] = _resolve_list(data["_list_content"])
144154

145155
return data
146156

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# FP8 E4M3 KV cache quantization.
#
# Multi-document snippet: the first document declares the imports, and the
# second document (after `---`) is the list content that uses them.
imports:
  fp8: configs/numerics/fp8
---
- quantizer_name: '*[kv]_bmm_quantizer'
  cfg:
    $import: fp8

modelopt_recipes/general/ptq/fp8_default-fp8_kv.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ imports:
1717
base_disable_all: configs/ptq/base_disable_all
1818
default_disabled: configs/ptq/default_disabled_quantizers
1919
fp8: configs/numerics/fp8
20+
fp8_kv: configs/ptq/fp8_kv
2021

2122
metadata:
2223
recipe_type: ptq
@@ -31,7 +32,5 @@ quantize:
3132
- quantizer_name: '*weight_quantizer'
3233
cfg:
3334
$import: fp8
34-
- quantizer_name: '*[kv]_bmm_quantizer'
35-
cfg:
36-
$import: fp8
35+
- $import: fp8_kv
3736
- $import: default_disabled

modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ imports:
1717
base_disable_all: configs/ptq/base_disable_all
1818
default_disabled: configs/ptq/default_disabled_quantizers
1919
nvfp4: configs/numerics/nvfp4_dynamic
20-
fp8: configs/numerics/fp8
20+
fp8_kv: configs/ptq/fp8_kv
2121

2222
metadata:
2323
recipe_type: ptq
@@ -32,7 +32,5 @@ quantize:
3232
- quantizer_name: '*input_quantizer'
3333
cfg:
3434
$import: nvfp4
35-
- quantizer_name: '*[kv]_bmm_quantizer'
36-
cfg:
37-
$import: fp8
35+
- $import: fp8_kv
3836
- $import: default_disabled

modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ imports:
1717
base_disable_all: configs/ptq/base_disable_all
1818
default_disabled: configs/ptq/default_disabled_quantizers
1919
nvfp4: configs/numerics/nvfp4_dynamic
20-
fp8: configs/numerics/fp8
20+
fp8_kv: configs/ptq/fp8_kv
2121

2222
metadata:
2323
recipe_type: ptq
@@ -38,7 +38,5 @@ quantize:
3838
- quantizer_name: '*block_sparse_moe*input_quantizer'
3939
cfg:
4040
$import: nvfp4
41-
- quantizer_name: '*[kv]_bmm_quantizer'
42-
cfg:
43-
$import: fp8
41+
- $import: fp8_kv
4442
- $import: default_disabled

modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ imports:
1717
base_disable_all: configs/ptq/base_disable_all
1818
default_disabled: configs/ptq/default_disabled_quantizers
1919
nvfp4: configs/numerics/nvfp4_dynamic
20-
fp8: configs/numerics/fp8
20+
fp8_kv: configs/ptq/fp8_kv
2121

2222
metadata:
2323
recipe_type: ptq
@@ -38,7 +38,5 @@ quantize:
3838
- quantizer_name: '*block_sparse_moe*input_quantizer'
3939
cfg:
4040
$import: nvfp4
41-
- quantizer_name: '*[kv]_bmm_quantizer'
42-
cfg:
43-
$import: fp8
41+
- $import: fp8_kv
4442
- $import: default_disabled

modelopt_recipes/general/ptq/nvfp4_omlp_only-fp8_kv.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ imports:
1717
base_disable_all: configs/ptq/base_disable_all
1818
default_disabled: configs/ptq/default_disabled_quantizers
1919
nvfp4: configs/numerics/nvfp4_dynamic
20-
fp8: configs/numerics/fp8
20+
fp8_kv: configs/ptq/fp8_kv
2121

2222
metadata:
2323
recipe_type: ptq
@@ -44,7 +44,5 @@ quantize:
4444
- quantizer_name: '*o_proj*input_quantizer'
4545
cfg:
4646
$import: nvfp4
47-
- quantizer_name: '*[kv]_bmm_quantizer'
48-
cfg:
49-
$import: fp8
47+
- $import: fp8_kv
5048
- $import: default_disabled

tests/unit/recipe/test_loader.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,49 @@ def test_import_dir_format(tmp_path):
650650
assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3), "axis": None}
651651

652652

653+
# ---------------------------------------------------------------------------
654+
# imports — multi-document snippets
655+
# ---------------------------------------------------------------------------
656+
657+
658+
def test_import_multi_document_list_snippet(tmp_path):
    """List snippet using multi-document YAML (imports --- content) resolves $import."""
    (tmp_path / "fp8.yml").write_text("num_bits: e4m3\n")
    # Indentation inside these literals is significant YAML structure: ``cfg``
    # must nest under the list item and ``$import`` under ``cfg``.
    (tmp_path / "kv.yaml").write_text(
        f"imports:\n"
        f"  fp8: {tmp_path / 'fp8.yml'}\n"
        f"---\n"
        f"- quantizer_name: '*[kv]_bmm_quantizer'\n"
        f"  cfg:\n"
        f"    $import: fp8\n"
    )
    recipe_file = tmp_path / "recipe.yml"
    recipe_file.write_text(
        f"imports:\n"
        f"  kv: {tmp_path / 'kv.yaml'}\n"
        f"metadata:\n"
        f"  recipe_type: ptq\n"
        f"quantize:\n"
        f"  algorithm: max\n"
        f"  quant_cfg:\n"
        f"    - $import: kv\n"
    )
    recipe = load_recipe(recipe_file)
    # The single spliced entry must carry the cfg resolved from the nested import.
    assert len(recipe.quantize["quant_cfg"]) == 1
    assert recipe.quantize["quant_cfg"][0]["quantizer_name"] == "*[kv]_bmm_quantizer"
    assert recipe.quantize["quant_cfg"][0]["cfg"] == {"num_bits": (4, 3)}
684+
685+
686+
def test_import_builtin_fp8_kv_snippet():
    """The shipped fp8_kv multi-document snippet resolves inside a built-in recipe."""
    recipe = load_recipe("general/ptq/fp8_default-fp8_kv")

    # Collect every quant_cfg entry that targets the KV bmm quantizers.
    matches = []
    for item in recipe.quantize["quant_cfg"]:
        if item.get("quantizer_name") == "*[kv]_bmm_quantizer":
            matches.append(item)

    assert len(matches) == 1
    (kv_entry,) = matches
    assert kv_entry["cfg"]["num_bits"] == (4, 3)
694+
695+
653696
# ---------------------------------------------------------------------------
654697
# imports — recursive resolution and cycle detection
655698
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)