Skip to content

Commit cda4150

Browse files
committed
wip: autoquant recipe schema + hf_ptq dispatch
Signed-off-by: Juhi Mittal <juhim@nvidia.com>
1 parent b02e888 commit cda4150

4 files changed

Lines changed: 376 additions & 68 deletions

File tree

examples/llm_ptq/hf_ptq.py

Lines changed: 138 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,12 @@
5555
import modelopt.torch.opt as mto
5656
import modelopt.torch.quantization as mtq
5757
import modelopt.torch.sparsity as mts
58-
from modelopt.recipe import ModelOptPTQRecipe, load_recipe
58+
from modelopt.recipe import (
59+
ModelOptAutoQuantizeRecipe,
60+
ModelOptPTQRecipe,
61+
ModelOptRecipeBase,
62+
load_recipe,
63+
)
5964
from modelopt.torch.export import (
6065
export_hf_checkpoint,
6166
export_hf_vllm_fq_checkpoint,
@@ -208,6 +213,7 @@ def make_calib_dataloader(
208213
tokenizer: PreTrainedTokenizerBase | None,
209214
device: torch.device,
210215
model_type: str | None,
216+
recipe: ModelOptRecipeBase | None = None,
211217
) -> tuple[DataLoader | _DeviceDataLoader, str | None]:
212218
calib_dataloader = None
213219
first_text_speech_dataset = None
@@ -271,8 +277,12 @@ def make_calib_dataloader(
271277
assert tokenizer is not None and isinstance(
272278
tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)
273279
), "The PreTrainedTokenizer must be set"
274-
# Labels are only needed for gradient-based auto_quantize
275-
include_labels = (
280+
# Labels are only needed for gradient-based auto_quantize (CLI or recipe path).
281+
is_autoquant_recipe_gradient = (
282+
isinstance(recipe, ModelOptAutoQuantizeRecipe)
283+
and recipe.auto_quantize.method == "gradient"
284+
)
285+
include_labels = is_autoquant_recipe_gradient or (
276286
args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient"
277287
)
278288

@@ -292,48 +302,32 @@ def auto_quantize(
292302
args: argparse.Namespace,
293303
language_model: torch.nn.Module,
294304
calib_dataloader: DataLoader,
295-
auto_quantize_method="gradient",
296-
auto_quantize_score_size=128,
297-
auto_quantize_checkpoint=None,
298305
full_model: torch.nn.Module | None = None,
306+
*,
307+
auto_quantize_method: str,
308+
auto_quantize_score_size: int,
309+
auto_quantize_checkpoint: str | None,
310+
constraints: dict,
311+
quantization_formats: list[dict],
312+
disabled_layers: list[str],
313+
kv_cache_qformat: str,
299314
):
300-
"""Auto search quantization of multiple formats."""
315+
"""Pure orchestrator: build forward_step/loss_func, call mtq.auto_quantize,
316+
run KV cache post-step. All knobs are explicit keyword-only args; the
317+
caller (dispatch site in ``quantize_main``) is responsible for resolving
318+
them from either CLI args or a recipe before invoking this function.
319+
"""
301320

302321
if args.calib_with_images:
303322
raise NotImplementedError(
304323
"AutoQuantize with image-text calibration is not supported yet. "
305324
"Please run plain PTQ (e.g., --qformat nvfp4) with --calib_with_images."
306325
)
307326

308-
assert not (args.auto_quantize_bits and args.inference_pipeline_parallel > 1), (
327+
assert args.inference_pipeline_parallel <= 1, (
309328
"Auto Quantization is not supported for pipeline parallel size > 1"
310329
)
311330

312-
qformat_list = args.qformat.split(",")
313-
assert qformat_list, "No quantization formats provided"
314-
# Check if all provided quantization formats are supported
315-
assert all(
316-
qformat
317-
in [
318-
"fp8",
319-
"int8_sq",
320-
"int8_wo",
321-
"int4_awq",
322-
"nvfp4",
323-
"nvfp4_awq",
324-
"nvfp4_mse",
325-
"w4a8_awq",
326-
"fp8_pb_wo",
327-
"w4a8_mxfp4_fp8",
328-
"nvfp4_mlp_only",
329-
"nvfp4_experts_only",
330-
"nvfp4_omlp_only",
331-
"nvfp4_local_hessian",
332-
"mxfp8",
333-
]
334-
for qformat in qformat_list
335-
), "One or more quantization formats provided are not supported for unified checkpoint export"
336-
337331
# When language_model is a base text model without lm_head (e.g. Gemma4TextModel),
338332
# use full_model's lm_head to compute logits/loss from hidden states.
339333
is_base_model = (
@@ -384,45 +378,39 @@ def forward_step(model, batch):
384378

385379
language_model, _ = mtq.auto_quantize(
386380
language_model,
387-
constraints={"effective_bits": args.auto_quantize_bits},
381+
constraints=constraints,
388382
data_loader=calib_dataloader,
389383
forward_step=forward_step,
390384
loss_func=loss_func, # Only used for gradient-based method
391385
# TRTLLM only supports one quantization format or None (do not quantize, internally supported)
392-
quantization_formats=[QUANT_CFG_CHOICES[format] for format in qformat_list],
386+
quantization_formats=quantization_formats, # type: ignore[arg-type]
393387
num_calib_steps=len(calib_dataloader),
394388
# AutoQuantize scoring is the costly phase; allow smaller sample counts than calibration.
395389
num_score_steps=min(
396390
len(calib_dataloader), max(auto_quantize_score_size // args.batch_size, 1)
397391
),
398392
verbose=True,
399-
# Disable all default disabled layers such as lm_head, mlp.gate, router etc.
400-
disabled_layers=[
401-
entry["quantizer_name"]
402-
for entry in _default_disabled_quantizer_cfg
403-
if "parent_class" not in entry
404-
],
393+
disabled_layers=disabled_layers,
405394
method=auto_quantize_method,
406395
checkpoint=auto_quantize_checkpoint,
407396
)
408397

409398
calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
410-
# We need to explicitly set up KV cache quantization after auto_quantize
411-
enable_quant_kv_cache = args.kv_cache_qformat != "none"
399+
enable_quant_kv_cache = kv_cache_qformat != "none"
412400
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
413401
if enable_quant_kv_cache:
414402
kv_cache_quant_cfg = copy.deepcopy(
415-
getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"]
403+
getattr(mtq, KV_QUANT_CFG_CHOICES[kv_cache_qformat])["quant_cfg"]
416404
)
417405
kv_cache_quant_cfg = [
418406
e for e in kv_cache_quant_cfg if e["quantizer_name"] != "*"
419407
] # keep other quantizers from auto_quantize
420408

421-
if args.kv_cache_qformat in _KV_CAST_FORMATS:
409+
if kv_cache_qformat in _KV_CAST_FORMATS:
422410
_set_kv_cache_constant_amax(kv_cache_quant_cfg)
423411

424412
mtq.set_quantizer_by_cfg(language_model, quant_cfg=kv_cache_quant_cfg)
425-
if args.kv_cache_qformat not in _KV_CAST_FORMATS:
413+
if kv_cache_qformat not in _KV_CAST_FORMATS:
426414
# Calibrate only the KV cache quantizers; disable all others.
427415
with mtq.set_quantizer_by_cfg_context(
428416
language_model,
@@ -1003,12 +991,20 @@ def quantize_main(
1003991
):
1004992
# Load the recipe up front so we can detect layerwise calibration before batch-size probing.
1005993
recipe = None
1006-
if args.recipe is not None and not args.auto_quantize_bits:
994+
if args.recipe is not None:
1007995
print(f"Use recipe {args.recipe} for quantization")
1008996
recipe = load_recipe(args.recipe)
1009-
if not isinstance(recipe, ModelOptPTQRecipe):
997+
if not isinstance(recipe, (ModelOptPTQRecipe, ModelOptAutoQuantizeRecipe)):
1010998
raise TypeError(
1011-
f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}"
999+
f"Expected PTQ or AutoQuantize recipe, but got {type(recipe).__name__} "
1000+
f"from {args.recipe}"
1001+
)
1002+
# Fail-fast on conflicting budget sources: a recipe carries its own
1003+
# effective_bits, so silently honoring one over the other would be a
1004+
# reproducibility hazard.
1005+
if args.auto_quantize_bits is not None:
1006+
raise ValueError(
1007+
"Cannot combine --auto_quantize_bits with --recipe; the recipe owns the budget."
10121008
)
10131009

10141010
def _is_layerwise(obj):
@@ -1059,7 +1055,9 @@ def _is_layerwise(obj):
10591055
else:
10601056
sample_input_single_batch = None
10611057

1062-
run_auto_quant = args.auto_quantize_bits is not None
1058+
run_auto_quant = args.auto_quantize_bits is not None or isinstance(
1059+
recipe, ModelOptAutoQuantizeRecipe
1060+
)
10631061

10641062
args.batch_size = get_max_batch_size(
10651063
language_model,
@@ -1073,7 +1071,7 @@ def _is_layerwise(obj):
10731071
print(f"Use calib batch_size {args.batch_size}")
10741072

10751073
calib_dataloader, first_text_speech_dataset = make_calib_dataloader(
1076-
args, language_model, processor, tokenizer, device, model_type
1074+
args, language_model, processor, tokenizer, device, model_type, recipe=recipe
10771075
)
10781076

10791077
# Detect if this is a Nemotron VL model using architecture-based detection
@@ -1083,20 +1081,91 @@ def _is_layerwise(obj):
10831081
args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model
10841082
)
10851083

1086-
if args.auto_quantize_bits:
1087-
assert len(args.qformat.split(",")) > 1, (
1088-
"Auto quantization needs multiple quantization format."
1089-
)
1084+
# All auto_quantize() knobs are resolved here before calling the helper.
1085+
# Helper is a leaf orchestrator — it does not know whether inputs came from
1086+
# CLI args or a recipe.
1087+
if isinstance(recipe, ModelOptAutoQuantizeRecipe) or args.auto_quantize_bits:
1088+
default_disabled_layers = [
1089+
entry["quantizer_name"]
1090+
for entry in _default_disabled_quantizer_cfg
1091+
if "parent_class" not in entry
1092+
]
10901093

1091-
auto_quantize(
1092-
args,
1093-
language_model,
1094-
calib_dataloader,
1095-
auto_quantize_method=args.auto_quantize_method,
1096-
auto_quantize_score_size=args.auto_quantize_score_size,
1097-
auto_quantize_checkpoint=args.auto_quantize_checkpoint,
1098-
full_model=full_model,
1099-
)
1094+
if isinstance(recipe, ModelOptAutoQuantizeRecipe):
1095+
aq = recipe.auto_quantize
1096+
1097+
# mtq.auto_quantize labels candidates by upstream identity: dicts that ARE
1098+
# an mtq.X_CFG object get the constant's name in logs (e.g. NVFP4_DEFAULT_CFG);
1099+
# all other dicts get "CUSTOM_N" plus a "results may not be optimal" warning.
1100+
# Recipe candidates come from .model_dump() — equal by value but not identity,
1101+
# so we'd lose the friendly names. Substitute the canonical object back when
1102+
# the dump matches a known preset, so logs and the warning line up with CLI.
1103+
# The match check uses exclude_unset=True so it compares against the
1104+
# preset YAML's natural shape (mtq.X_CFG dicts don't carry Pydantic-filled
1105+
# defaults). The payload still passes the full dump to upstream.
1106+
def _candidate_for_mtq(fmt):
1107+
strict = fmt.model_dump(exclude_unset=True)
1108+
for cfg in QUANT_CFG_CHOICES.values():
1109+
if cfg == strict:
1110+
return cfg
1111+
return fmt.model_dump()
1112+
1113+
auto_quantize(
1114+
args,
1115+
language_model,
1116+
calib_dataloader,
1117+
full_model=full_model,
1118+
auto_quantize_method=aq.method,
1119+
auto_quantize_score_size=aq.num_score_steps,
1120+
auto_quantize_checkpoint=aq.score_checkpoint,
1121+
constraints=aq.constraints.model_dump(exclude_none=True),
1122+
quantization_formats=[_candidate_for_mtq(fmt) for fmt in aq.candidate_formats],
1123+
disabled_layers=aq.disabled_layers or default_disabled_layers,
1124+
kv_cache_qformat=(
1125+
aq.kv_cache.qformat
1126+
if (aq.kv_cache and aq.kv_cache.qformat)
1127+
else args.kv_cache_qformat
1128+
),
1129+
)
1130+
else:
1131+
qformat_list = args.qformat.split(",")
1132+
assert len(qformat_list) > 1, "Auto quantization needs multiple quantization format."
1133+
assert all(
1134+
qformat
1135+
in [
1136+
"fp8",
1137+
"int8_sq",
1138+
"int8_wo",
1139+
"int4_awq",
1140+
"nvfp4",
1141+
"nvfp4_awq",
1142+
"nvfp4_mse",
1143+
"w4a8_awq",
1144+
"fp8_pb_wo",
1145+
"w4a8_mxfp4_fp8",
1146+
"nvfp4_mlp_only",
1147+
"nvfp4_experts_only",
1148+
"nvfp4_omlp_only",
1149+
"nvfp4_local_hessian",
1150+
"mxfp8",
1151+
]
1152+
for qformat in qformat_list
1153+
), (
1154+
"One or more quantization formats provided are not supported for unified checkpoint export"
1155+
)
1156+
auto_quantize(
1157+
args,
1158+
language_model,
1159+
calib_dataloader,
1160+
full_model=full_model,
1161+
auto_quantize_method=args.auto_quantize_method,
1162+
auto_quantize_score_size=args.auto_quantize_score_size,
1163+
auto_quantize_checkpoint=args.auto_quantize_checkpoint,
1164+
constraints={"effective_bits": args.auto_quantize_bits},
1165+
quantization_formats=[QUANT_CFG_CHOICES[fmt] for fmt in qformat_list],
1166+
disabled_layers=default_disabled_layers,
1167+
kv_cache_qformat=args.kv_cache_qformat,
1168+
)
11001169

11011170
else:
11021171
# mono quantization
@@ -1214,9 +1283,11 @@ def parse_args() -> argparse.Namespace:
12141283
parser.add_argument(
12151284
"--recipe",
12161285
help=(
1217-
"PTQ recipe YAML file or name without suffix (e.g. general/ptq/fp8_default-kv_fp8_cast, "
1218-
"general/ptq/nvfp4_default-kv_fp8_cast, general/ptq/nvfp4_default-kv_nvfp4_cast). "
1219-
"When set, --kv_cache_qformat is ignored; the recipe fully determines KV cache config."
1286+
"PTQ or AutoQuantize recipe YAML file or name without suffix "
1287+
"(e.g. general/ptq/nvfp4_default-kv_fp8_cast, "
1288+
"general/auto_quantize/nvfp4_fp8_at_4p8bits-kv_fp8_cast). "
1289+
"PTQ recipes fully own quant config; AutoQuantize recipes own search config "
1290+
"and may optionally override --kv_cache_qformat via their kv_cache field."
12201291
),
12211292
default=None,
12221293
)

0 commit comments

Comments
 (0)