diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index f19e82c5d4..327605406c 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -1059,7 +1059,7 @@ def quantize_main( assert isinstance(recipe, ModelOptPTQRecipe), ( f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}" ) - quant_cfg = recipe.quantize + quant_cfg = recipe.quantize.model_dump() else: assert len(args.qformat.split(",")) == 1, ( diff --git a/modelopt/recipe/config.py b/modelopt/recipe/config.py index 22a17e6452..cc9276c0ff 100644 --- a/modelopt/recipe/config.py +++ b/modelopt/recipe/config.py @@ -18,11 +18,11 @@ from __future__ import annotations from enum import Enum -from typing import Any from pydantic import field_validator from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField +from modelopt.torch.quantization.config import QuantizeConfig class RecipeType(str, Enum): @@ -66,8 +66,8 @@ def validate_recipe_type(cls, v): class ModelOptPTQRecipe(ModelOptRecipeBase): """Our config class for PTQ recipes.""" - quantize: dict[str, Any] = ModeloptField( - default={}, + quantize: QuantizeConfig = ModeloptField( + default=QuantizeConfig(), title="PTQ config", description="PTQ config containing quant_cfg and algorithm.", validate_default=True, diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index 58fd3ddabe..d1277d40c0 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -170,7 +170,8 @@ def test_load_recipe_dir(tmp_path): recipe = load_recipe(tmp_path) assert recipe.recipe_type == RecipeType.PTQ assert recipe.description == "Dir test." - assert recipe.quantize == {"algorithm": "max", "quant_cfg": []} + assert recipe.quantize.algorithm == "max" + assert recipe.quantize.quant_cfg == [] def test_load_recipe_dir_missing_recipe_raises(tmp_path):