Skip to content

Commit 1127f32

Browse files
committed
A new test
Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
1 parent 070f215 commit 1127f32

1 file changed

Lines changed: 47 additions & 0 deletions

File tree

tests/unit/recipe/test_loader.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,53 @@ def test_import_cfg_inline_overrides_import(tmp_path):
557557
assert cfg["axis"] is None
558558

559559

560+
def test_import_in_non_cfg_dict_value(tmp_path):
561+
"""$import resolves in any dict value, not just cfg."""
562+
(tmp_path / "bias_cfg.yml").write_text("enable: true\ntype: static\naxis: -1\n")
563+
recipe_file = tmp_path / "recipe.yml"
564+
recipe_file.write_text(
565+
f"imports:\n"
566+
f" bias_cfg: {tmp_path / 'bias_cfg.yml'}\n"
567+
f"metadata:\n"
568+
f" recipe_type: ptq\n"
569+
f"quantize:\n"
570+
f" algorithm: max\n"
571+
f" quant_cfg:\n"
572+
f" - quantizer_name: '*weight_quantizer'\n"
573+
f" bias:\n"
574+
f" $import: bias_cfg\n"
575+
)
576+
recipe = load_recipe(recipe_file)
577+
entry = recipe.quantize["quant_cfg"][0]
578+
assert entry["bias"] == {"enable": True, "type": "static", "axis": -1}
579+
580+
581+
def test_import_in_multiple_dict_values(tmp_path):
582+
"""$import resolves independently in multiple dict values of the same entry."""
583+
(tmp_path / "fp8.yml").write_text("num_bits: e4m3\n")
584+
(tmp_path / "bias_cfg.yml").write_text("enable: true\ntype: dynamic\n")
585+
recipe_file = tmp_path / "recipe.yml"
586+
recipe_file.write_text(
587+
f"imports:\n"
588+
f" fp8: {tmp_path / 'fp8.yml'}\n"
589+
f" bias_cfg: {tmp_path / 'bias_cfg.yml'}\n"
590+
f"metadata:\n"
591+
f" recipe_type: ptq\n"
592+
f"quantize:\n"
593+
f" algorithm: max\n"
594+
f" quant_cfg:\n"
595+
f" - quantizer_name: '*weight_quantizer'\n"
596+
f" cfg:\n"
597+
f" $import: fp8\n"
598+
f" bias:\n"
599+
f" $import: bias_cfg\n"
600+
)
601+
recipe = load_recipe(recipe_file)
602+
entry = recipe.quantize["quant_cfg"][0]
603+
assert entry["cfg"] == {"num_bits": (4, 3)}
604+
assert entry["bias"] == {"enable": True, "type": "dynamic"}
605+
606+
560607
def test_import_cfg_multi_import(tmp_path):
561608
"""$import with a list of names merges non-overlapping snippets."""
562609
(tmp_path / "bits.yml").write_text("num_bits: e4m3\n")

0 commit comments

Comments
 (0)