Skip to content

Commit da41d79

Browse files
shengliangxuEdwardf0t1
authored andcommitted
use typed quantize config instead of a raw dict (#1249)
### What does this PR do? But fix: Use typed QuantizeConfig instead using raw dict for formal typed ModelOpt configs. The dict typing was accidental. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Quantization recipe configuration is now implemented with a strongly-typed, structured schema that enforces type safety and provides enhanced validation with comprehensive error detection capabilities. * **Tests** * Updated recipe loading tests to correctly validate quantization configurations when recipes are loaded from directories, fully supporting the new structured object-based configuration format. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Shengliang Xu <shengliangx@nvidia.com>
1 parent a3be686 commit da41d79

3 files changed

Lines changed: 6 additions & 5 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1059,7 +1059,7 @@ def quantize_main(
10591059
assert isinstance(recipe, ModelOptPTQRecipe), (
10601060
f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}"
10611061
)
1062-
quant_cfg = recipe.quantize
1062+
quant_cfg = recipe.quantize.model_dump()
10631063

10641064
else:
10651065
assert len(args.qformat.split(",")) == 1, (

modelopt/recipe/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
from __future__ import annotations
1919

2020
from enum import Enum
21-
from typing import Any
2221

2322
from pydantic import field_validator
2423

2524
from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
25+
from modelopt.torch.quantization.config import QuantizeConfig
2626

2727

2828
class RecipeType(str, Enum):
@@ -66,8 +66,8 @@ def validate_recipe_type(cls, v):
6666
class ModelOptPTQRecipe(ModelOptRecipeBase):
6767
"""Our config class for PTQ recipes."""
6868

69-
quantize: dict[str, Any] = ModeloptField(
70-
default={},
69+
quantize: QuantizeConfig = ModeloptField(
70+
default=QuantizeConfig(),
7171
title="PTQ config",
7272
description="PTQ config containing quant_cfg and algorithm.",
7373
validate_default=True,

tests/unit/recipe/test_loader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@ def test_load_recipe_dir(tmp_path):
170170
recipe = load_recipe(tmp_path)
171171
assert recipe.recipe_type == RecipeType.PTQ
172172
assert recipe.description == "Dir test."
173-
assert recipe.quantize == {"algorithm": "max", "quant_cfg": []}
173+
assert recipe.quantize.algorithm == "max"
174+
assert recipe.quantize.quant_cfg == []
174175

175176

176177
def test_load_recipe_dir_missing_recipe_raises(tmp_path):

0 commit comments

Comments
 (0)