Skip to content

Commit d6e1973

Browse files
realAsmajenchen13
authored andcommitted
fix: layerwise calibration backward-compat, recipe split, batch-size guard (#1310)
## Summary Follow-up to #1251 (which renamed `use_sequential` → `layerwise`). Three related fixes bundled: 1. **Backward-compatible config loading.** PTQ checkpoints saved before #1251 store the legacy `use_sequential` key in the calibration-algorithm config, so loading them now raises `ValidationError: Extra inputs are not permitted (use_sequential)` because `QuantizeAlgorithmConfig` uses `extra='forbid'`. Accept `use_sequential` as an alias for `layerwise` via `AliasChoices`. The field still serializes as `layerwise`, so round-trips through the current schema are clean. 2. **Recipe split.** `nvfp4_experts_only-fp8_kv` previously enabled layerwise calibration by default, which changes the calibration flow materially. Split into two recipes: - `nvfp4_experts_only-fp8_kv.yaml` — default (no layerwise) - `nvfp4_experts_only-fp8_kv_layerwise.yaml` — layerwise variant 3. **`hf_ptq` batch-size guard.** Auto batch-size detection is not supported together with layerwise calibration. Default to `batch_size=1` when layerwise is enabled and the user hasn't set a batch size explicitly. Originally reported by Jenny Chen while resuming a PTQ checkpoint via `restore_sharded_modelopt_state`: ``` pydantic_core._pydantic_core.ValidationError: 1 validation error for MaxCalibConfig use_sequential Extra inputs are not permitted [type=extra_forbidden, input_value=False, input_type=bool] ``` ## Test plan - [x] `tests/unit/torch/quantization/test_config_validation.py` — legacy alias accepted, current name accepted, dump serializes under current name, `extra='forbid'` still rejects unknown keys. - [x] `pre-commit run` — clean. ### Before your PR is *Ready for review* - Is this change backward compatible?: ✅ (restores compatibility for pre-#1251 checkpoints) - New PIP dependency: N/A - New necessary tests: ✅ - Changelog update: N/A (bug fix) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Added new PTQ recipe for efficient layerwise calibration of large models. * Automatic batch size optimization for layerwise calibration recipes. * Backward compatibility support for legacy input naming conventions. * **Documentation** * Updated recipe guides and changelog with new layerwise calibration recipe. * **Tests** * Added validation tests for configuration compatibility. [![Review Change Stack](https://storage.googleapis.com/coderabbit_public_assets/review-stack-in-coderabbit-ui.svg)](https://app.coderabbit.ai/change-stack/NVIDIA/Model-Optimizer/pull/1310) <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent c3b1f5a commit d6e1973

8 files changed

Lines changed: 111 additions & 9 deletions

File tree

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ Changelog
4141
- Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.
4242
- [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution.
4343
- [Early Testing] Polish Claude Code evaluation skill (``.claude/skills/evaluation/``) for agent-assisted LLM accuracy benchmarking via NeMo Evaluator Launcher. Adds two companion skills vendored verbatim from `NVIDIA-NeMo/Evaluator <https://github.com/NVIDIA-NeMo/Evaluator>`_: ``launching-evals`` (run/check/debug/analyze NEL evaluations) and ``accessing-mlflow`` (query MLflow runs, compare metrics, fetch artifacts). Re-sync at a pinned upstream SHA via ``.claude/scripts/sync-upstream-skills.sh``. Also adds a shared ``skills/common/credentials.md`` covering HF / NGC / Docker token setup referenced by multiple skills. This feature is in early testing — use with caution.
44-
- Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). See `modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8.yaml>`_ for usage. Layerwise calibration also supports PTQ with intermediate progress saving — useful when long PTQ runs get hit with Slurm timeouts. See `modelopt_recipes/general/ptq/nvfp4_default-kv_none-gptq.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_default-kv_none-gptq.yaml>`_ for usage.
44+
- Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). See `modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_layerwise.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8_layerwise.yaml>`_ for usage. Layerwise calibration also supports PTQ with intermediate progress saving — useful when long PTQ runs get hit with Slurm timeouts. See `modelopt_recipes/general/ptq/nvfp4_default-kv_none-gptq.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_default-kv_none-gptq.yaml>`_ for usage.
4545
- Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.quantization.src.conv``). When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning.
4646
- Add FP8 MHA quantization support for vision transformers. Adds an attention-aware ONNX post-processing pass (scale Mul / K-transpose move before Q, Q→DQ insertion on softmax output) in :class:`FP8QuantExporter <modelopt.onnx.export.fp8_exporter.FP8QuantExporter>`, per-instance nested-attention-wrapper skipping in the HF plugin, and ``nn.LayerNorm`` registration in ``QuantModuleRegistry`` so BMM input quantizers and LayerNorm output quantizers defined in FP8_DEFAULT_CFG are honored end-to-end. See `examples/torch_onnx/torch_quant_to_onnx.py <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/torch_onnx/torch_quant_to_onnx.py>`_ for the general timm-model quantize→ONNX workflow.
4747

docs/source/guides/10_recipes.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,8 @@ General PTQ recipes are model-agnostic and apply to any supported architecture:
495495
- NVFP4 for MLP layers only, FP8 KV cache
496496
* - ``general/ptq/nvfp4_experts_only-kv_fp8``
497497
- NVFP4 for MoE expert layers only, FP8 KV cache
498+
* - ``general/ptq/nvfp4_experts_only-kv_fp8_layerwise``
499+
- NVFP4 for MoE expert layers only, FP8 KV cache, layerwise calibration
498500
* - ``general/ptq/nvfp4_omlp_only-kv_fp8``
499501
- NVFP4 for output projection + MLP layers, FP8 KV cache
500502

@@ -657,6 +659,7 @@ The ``modelopt_recipes/`` package is organized as follows:
657659
| +-- nvfp4_default-kv_nvfp4_cast.yaml
658660
| +-- nvfp4_mlp_only-kv_fp8.yaml
659661
| +-- nvfp4_experts_only-kv_fp8.yaml
662+
| +-- nvfp4_experts_only-kv_fp8_layerwise.yaml
660663
| +-- nvfp4_omlp_only-kv_fp8.yaml
661664
+-- models/ # Model-specific recipes
662665
| +-- Step3.5-Flash/

examples/llm_ptq/hf_ptq.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,25 @@ def quantize_main(
988988
default_pad_token,
989989
device: torch.device,
990990
):
991+
# Load the recipe up front so we can detect layerwise calibration before batch-size probing.
992+
recipe = None
993+
if args.recipe is not None and not args.auto_quantize_bits:
994+
print(f"Use recipe {args.recipe} for quantization")
995+
recipe = load_recipe(args.recipe)
996+
if not isinstance(recipe, ModelOptPTQRecipe):
997+
raise TypeError(
998+
f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}"
999+
)
1000+
1001+
def _is_layerwise(obj):
1002+
if isinstance(obj, ModelOptPTQRecipe):
1003+
return _is_layerwise(obj.quantize.algorithm)
1004+
if isinstance(obj, list):
1005+
return any(_is_layerwise(a) for a in obj)
1006+
return bool(getattr(obj, "layerwise", False))
1007+
1008+
is_layerwise = _is_layerwise(recipe)
1009+
9911010
if args.batch_size == 0:
9921011
# For VL models with image-text calibration, skip automatic batch size detection
9931012
# since get_max_batch_size can't handle multimodal inputs
@@ -1001,6 +1020,11 @@ def quantize_main(
10011020
"Offline speculative decoding calibration enabled. Using default batch_size=1 for calibration."
10021021
)
10031022
args.batch_size = 1
1023+
# Layerwise calibration processes one layer at a time; auto batch-size probing runs a
1024+
# full-model forward which defeats the point and can OOM on very large models.
1025+
elif is_layerwise:
1026+
print("Layerwise calibration enabled. Using default batch_size=1 for calibration.")
1027+
args.batch_size = 1
10041028
else:
10051029
# Calibration/sparsification will actually take much more memory than regular inference
10061030
# due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio
@@ -1064,12 +1088,7 @@ def quantize_main(
10641088
else:
10651089
# mono quantization
10661090

1067-
if args.recipe is not None:
1068-
print(f"Use recipe {args.recipe} for quantization")
1069-
recipe = load_recipe(args.recipe)
1070-
assert isinstance(recipe, ModelOptPTQRecipe), (
1071-
f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}"
1072-
)
1091+
if recipe is not None:
10731092
quant_cfg = recipe.quantize.model_dump()
10741093

10751094
else:

modelopt/torch/quantization/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@
154154
import warnings
155155
from typing import Any, Literal, cast
156156

157-
from pydantic import ValidationInfo, field_validator, model_validator
157+
from pydantic import AliasChoices, ValidationInfo, field_validator, model_validator
158158
from typing_extensions import Required, TypedDict
159159

160160
from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
@@ -588,6 +588,7 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
588588

589589
layerwise: bool = ModeloptField(
590590
default=False,
591+
validation_alias=AliasChoices("layerwise", "use_sequential"),
591592
title="Enable layerwise (layer-by-layer) calibration.",
592593
description=(
593594
"If True, the calibration algorithm is applied layer by layer. "

modelopt_recipes/general/ptq/nvfp4_experts_only-kv_fp8.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ imports:
2121

2222
metadata:
2323
recipe_type: ptq
24-
description: NVFP4 static weight and dynamic activation for expert layers only (W4A4), FP8 KV cache, max layerwise calibration.
24+
description: NVFP4 static weight and dynamic activation for expert layers only (W4A4), FP8 KV cache, max calibration.
2525
quantize:
2626
algorithm:
2727
method: max
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
imports:
17+
base_disable_all: configs/ptq/units/base_disable_all
18+
default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers
19+
nvfp4: configs/numerics/nvfp4
20+
kv_fp8: configs/ptq/units/kv_fp8
21+
22+
metadata:
23+
recipe_type: ptq
24+
description: NVFP4 static weight and dynamic activation for expert layers only (W4A4), FP8 KV cache, max layerwise calibration.
25+
quantize:
26+
algorithm:
27+
method: max
28+
# Max calibration is fast and does not typically need checkpointing.
29+
layerwise: true
30+
quant_cfg:
31+
- $import: base_disable_all
32+
- quantizer_name: '*mlp.experts*weight_quantizer'
33+
cfg:
34+
$import: nvfp4
35+
- quantizer_name: '*mlp.experts*input_quantizer'
36+
cfg:
37+
$import: nvfp4
38+
- quantizer_name: '*block_sparse_moe*weight_quantizer'
39+
cfg:
40+
$import: nvfp4
41+
- quantizer_name: '*block_sparse_moe*input_quantizer'
42+
cfg:
43+
$import: nvfp4
44+
- $import: kv_fp8
45+
- $import: default_disabled_quantizers

tests/unit/recipe/test_loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def test_load_recipe_builtin_description():
136136
"general/ptq/nvfp4_default-kv_nvfp4_cast",
137137
"general/ptq/nvfp4_default-kv_none-gptq",
138138
"general/ptq/nvfp4_experts_only-kv_fp8",
139+
"general/ptq/nvfp4_experts_only-kv_fp8_layerwise",
139140
"general/ptq/nvfp4_mlp_only-kv_fp8",
140141
"general/ptq/nvfp4_omlp_only-kv_fp8",
141142
]

tests/unit/torch/quantization/test_config_validation.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
INT4_AWQ_CFG,
2626
NVFP4_DEFAULT_CFG,
2727
W4A8_AWQ_BETA_CFG,
28+
MaxCalibConfig,
2829
QuantizeConfig,
2930
find_quant_cfg_entry_by_path,
3031
need_calibration,
@@ -525,3 +526,35 @@ def test_validate_quant_cfg_entries_accepts_valid_cfg(self):
525526
algorithm="max",
526527
)
527528
assert len(cfg.quant_cfg) == 2
529+
530+
531+
class TestLayerwiseUseSequentialAlias:
532+
"""`layerwise` accepts the legacy `use_sequential` name via validation_alias.
533+
534+
Old PTQ checkpoints serialized the field as `use_sequential` before #1251 renamed
535+
it to `layerwise`. AliasChoices lets those checkpoints load without a migration
536+
validator while still serializing under the current name.
537+
"""
538+
539+
def test_use_sequential_true_sets_layerwise(self):
540+
cfg = MaxCalibConfig(use_sequential=True)
541+
assert cfg.layerwise is True
542+
543+
def test_use_sequential_false_sets_layerwise(self):
544+
cfg = MaxCalibConfig(use_sequential=False)
545+
assert cfg.layerwise is False
546+
547+
def test_layerwise_name_still_accepted(self):
548+
cfg = MaxCalibConfig(layerwise=True)
549+
assert cfg.layerwise is True
550+
551+
def test_serializes_under_current_name(self):
552+
"""Dump must use `layerwise`, not the legacy alias."""
553+
dumped = MaxCalibConfig(use_sequential=True).model_dump()
554+
assert dumped["layerwise"] is True
555+
assert "use_sequential" not in dumped
556+
557+
def test_unknown_field_still_rejected(self):
558+
"""extra='forbid' must still reject unrelated unknown fields."""
559+
with pytest.raises(ValidationError):
560+
MaxCalibConfig(not_a_real_field=True)

0 commit comments

Comments
 (0)