Commit a5d46ff (authored by realAsma and claude)
Auto Quantize improvements and bug fixes for large sparse MoEs (#953)
## What does this PR do?

**Type of change:** New feature + Bug fixes

**Overview:** Enable AutoQuantize for NemotronH and large SparseMoE models, and update the FP8 workflow split between `mtq.auto_quantize` and `mtq.quantize`. `mtq.auto_quantize` is now positioned as the lightweight search phase (lite calibration + scoring), while `mtq.quantize` is used for heavier/final calibration workflows (longer calibration passes, force-all-token style MoE calibration, and advanced recipes such as GPTQ, MSE, etc.).

### Algorithm & feature changes

- **NemotronH / SparseMoE support**: Updated `quant_module` and `score_module` rules (these should eventually move to the proposed modeling lib). In the future, this should be the only change needed to support new models — the bug fixes below were unearthed while enabling NemotronH.
- **Config generation**: Added `mtq.get_auto_quantize_config(search_state, constraints=None, verbose=False)` to re-solve from `search_state` and produce plain-dict configs (no redundant `output_quantizer`), with an optional verbose summary.
- **FP8 workflow split**: Use lite calibration in `mtq.auto_quantize`, then run longer/final calibration with `mtq.quantize` using the generated config.
- **Performance**: Pass `name_to_module` to `enable_weight_access_and_writeback` to avoid O(N^2) overhead on large MoE models.
- **Calibration caching in checkpoint**: Save/restore quantizer calibration states (metadata + state_dict) per recipe in the AutoQuantize checkpoint, so resuming a search skips redundant calibration.
- **Per-rank distributed checkpointing**: When `torch.distributed` is initialized, each rank saves/loads its own checkpoint file (`search_state{rank}.pt`), with a backward-compatible fallback to the single-file path.

### API updates

- **Config API naming**: Use `mtq.get_auto_quantize_config(...)` for exporting the searched recipe into a quantize-ready config.
- **Recommended usage pattern**:

```python
# 1) Lightweight search + lite calibration
model, search_state = mtq.auto_quantize(
    model,
    constraints={"effective_bits": 6.0},
    quantization_formats=[mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG],
    data_loader=data_loader,
    forward_step=forward_step,
    loss_func=loss_func,
    num_calib_steps=64,  # lite calibration during search
    num_score_steps=128,
)

# 2) Export searched config (optionally re-solve constraints)
auto_quantize_config = mtq.get_auto_quantize_config(
    search_state,
    constraints={"effective_bits": 6.0},
    verbose=True,
)

# 3) Final / longer calibration pass with quantize
model = mtq.quantize(
    model,
    config=auto_quantize_config,
    forward_loop=long_calibration_loop,  # e.g. force-all-token style MoE calibration
)
```

### Bug fixes

- Fixed `disabled_layers` handling so fused kernels (e.g. Mamba blocks) are properly skipped.
- Fixed gradient checkpointing to keep all modules except the checkpointed modules in eval mode.
- Fixed FP8 fake quant producing NaN/inf when `amax ≈ 0`.
- Fixed `SequentialQuantizer.convert_to_single_quantizer` to operate on `module` instead of `model`, avoiding O(N^2) CPU iteration on SparseMoE models with 1000s of submodules.
- Switched to proper `F.kl_div` for KL divergence scoring.

### Not yet exposed to `llm_ptq`

`mtq.get_auto_quantize_config` is not yet wired into `llm_ptq`. The plain config records per-expert quantization settings for all MoE experts, resulting in large JSON files. For my experiments I used a quick workaround. A follow-up PR will add a better config representation and expose it to `llm_ptq`.

## Testing

- Tested on NemotronH-tiny and Nemotron-Super-RL models
- Verified auto_quantize scoring + config generation end-to-end
- Unit test for checkpoint resume verifies calibration cache correctness (metadata + tensor values)
- Existing unit tests pass

## Before your PR is "*Ready for review*"

- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: AutoQuantize end-to-end requires GPU + large MoE models; verified manually on NemotronH-tiny and Nemotron-Super-RL. Unit test coverage to follow.
- **Did you add or update any necessary documentation?**: No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes

## Additional Information

Follow-up planned: expose `mtq.get_auto_quantize_config` to `llm_ptq` with a compact config format for MoE models. AWQ support in AutoQuantize can also be removed in a future PR to keep it lightweight.

---------

Signed-off-by: realAsma <akuriparambi@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
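The `amax ≈ 0` bug fix above can be illustrated with a minimal fake-quant sketch. This is a hypothetical toy in plain Python, not the ModelOpt kernel: `fake_quant_fp8` and its `eps` parameter are invented names, and only the clamping idea is taken from the fix.

```python
# Minimal fake-quant sketch (hypothetical, not the ModelOpt implementation).
# The per-tensor scale FP8_MAX / amax explodes when amax ~ 0, which then
# yields NaN from 0 * inf; clamping amax with a small epsilon keeps the
# quantize-dequantize round trip finite.
FP8_MAX = 448.0  # largest representable magnitude in FP8 E4M3


def fake_quant_fp8(x: list[float], amax: float, eps: float = 1e-12) -> list[float]:
    """Quantize-dequantize x with a per-tensor scale derived from amax."""
    amax = max(amax, eps)  # the fix: never divide by (approximately) zero
    scale = FP8_MAX / amax
    return [max(min(round(v * scale), FP8_MAX), -FP8_MAX) / scale for v in x]


# With amax == 0.0 every input is numerically zero; the clamped scale
# returns exact zeros instead of NaN.
out = fake_quant_fp8([0.0, 0.0], amax=0.0)
```

Without the `max(amax, eps)` line, `scale` would be `inf` and `0.0 * inf` would produce NaN for all-zero tensors, which is exactly the failure mode the fix addresses.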
1 parent fff65b0

12 files changed: 450 additions & 97 deletions
CHANGELOG.rst

Lines changed: 3 additions & 0 deletions
```diff
@@ -21,6 +21,9 @@ NVIDIA Model Optimizer Changelog
 - ``pass_through_bwd`` in the quantization config is now default to True. Please set it to False if you want to use STE with zeroed outlier gradients for potentially better QAT accuracy.
 - Add :meth:`compute_quantization_mse <modelopt.torch.quantization.model_quant.compute_quantization_mse>` API to measure per-quantizer mean-squared quantization error, with flexible wildcard and callable filtering.
 - **AutoQDQ**: New tool for automated Q/DQ (Quantize/Dequantize) placement optimization for ONNX models. Uses TensorRT latency measurements to choose insertion schemes that minimize inference time. Discovers regions automatically, groups them by structural pattern, and tests multiple Q/DQ schemes per pattern. Supports INT8 and FP8 quantization, pattern cache for warm-start on similar models, checkpoint/resume, and importing patterns from an existing QDQ baseline. CLI: ``python -m modelopt.onnx.quantization.autotune``. See the AutoQDQ guide in the documentation.
+- Add ``get_auto_quantize_config`` API to extract a flat quantization config from ``auto_quantize`` search results, enabling re-quantization at different effective bit targets without re-running calibration.
+- Improve ``auto_quantize`` checkpoint/resume: calibration state is now saved and restored across runs, avoiding redundant calibration when resuming a search.
+- Add NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules.
 
 **Misc**
```
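The changelog entry above mentions re-solving at different "effective bit" targets. As a rough illustration of what that constraint measures, effective bits can be thought of as a parameter-weighted average of per-layer bit-widths; the helper and numbers below are invented for this sketch, not the ModelOpt API.

```python
# Hypothetical sketch of the "effective bits" metric that the
# auto_quantize search constrains: a parameter-count-weighted average
# of the bit-width assigned to each layer.
def effective_bits(layer_params: dict[str, int], layer_bits: dict[str, float]) -> float:
    """Weighted-average bit-width across layers (toy illustration)."""
    total = sum(layer_params.values())
    return sum(layer_params[name] * layer_bits[name] for name in layer_params) / total


bits = effective_bits(
    {"attn": 100, "mlp": 300},   # toy parameter counts
    {"attn": 8.0, "mlp": 4.0},   # e.g. FP8 vs NVFP4 assignment
)
# (100 * 8 + 300 * 4) / 400 = 5.0
```

Under this reading, a constraint like `{"effective_bits": 6.0}` caps the weighted average, and re-solving from `search_state` at a different target just re-runs the selection without repeating calibration.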

modelopt/torch/opt/searcher.py

Lines changed: 32 additions & 11 deletions
```diff
@@ -236,33 +236,54 @@ def state_dict(self) -> SearchStateDict:
         """The state dictionary that can be stored/loaded."""
         return {key: getattr(self, key) for key in self.default_state_dict}
 
-    def load_search_checkpoint(self) -> bool:
+    def _get_checkpoint_path(self) -> str | None:
+        """Get per-rank checkpoint path when distributed, otherwise the original path."""
+        checkpoint = self.config["checkpoint"]
+        if checkpoint is None:
+            return None
+        if dist.is_initialized():
+            dirname, basename = os.path.split(checkpoint)
+            name, ext = os.path.splitext(basename)
+            return os.path.join(dirname, f"{name}{dist.rank()}{ext}")
+        return checkpoint
+
+    def load_search_checkpoint(self, strict=True) -> bool:
         """Load function for search checkpoint returning indicator whether checkpoint was loaded."""
-        # check if checkpoint exists
-        checkpoint: str | None = self.config["checkpoint"]
+        checkpoint = self._get_checkpoint_path()
         if checkpoint is None:
             return False
+        # Backward compat: fall back to the original single-file path
+        if not os.path.exists(checkpoint):
+            warn_rank_0(
+                f"Per-rank checkpoint {checkpoint} not found, falling back to "
+                f"{self.config['checkpoint']}. Ensure world size matches the original run."
+            )
+            checkpoint = self.config["checkpoint"]
         if not os.path.exists(checkpoint):
             warn_rank_0(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.")
             return False
 
-        # iterate through state dict and load keys
         print_rank_0(f"Loading searcher state from {checkpoint}...")
         # Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input
         state_dict = torch.load(checkpoint, weights_only=False)
-        assert state_dict.keys() == self.state_dict().keys(), "Keys in checkpoint don't match!"
-        for key, state in state_dict.items():
-            setattr(self, key, state)
+        if strict:
+            assert state_dict.keys() == self.state_dict().keys(), "Keys in checkpoint don't match!"
+        for key, default_val in self.default_state_dict.items():
+            setattr(self, key, state_dict.get(key, default_val))
         return True
 
     def save_search_checkpoint(self, verbose=False) -> None:
         """Save function for search checkpoint."""
-        # check if save requirements are satisfied
-        checkpoint: str | None = self.config["checkpoint"]
-        if checkpoint is None or not dist.is_master():
+        checkpoint = self._get_checkpoint_path()
+        if checkpoint is None:
             return
 
-        # save state dict
+        if dist.is_initialized():
+            warn_rank_0(
+                "torch.distributed is initialized. Please maintain the same parallelism "
+                "configuration (world size, TP, EP, etc.) across search save and restore sessions."
+            )
+
         if verbose:
             print(f"Saving searcher state to {checkpoint}...")
         save_dirname, _ = os.path.split(checkpoint)
```