Skip to content

Commit 31b3bca

Browse files
realAsma authored and claude committed
Auto Quantize improvements and bug fixes for large sparse MoEs
- Add get_auto_quantize_config API to extract quant config from search results
- Save/restore calibration state in auto_quantize checkpoints
- Add NemotronH MoE expert support in auto_quantize grouping/scoring
- Fix SequentialQuantizer scope, use F.kl_div for numerical stability
- Fix mypy errors and clean up tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent 695c8e8 commit 31b3bca

12 files changed

Lines changed: 440 additions & 97 deletions

File tree

CHANGELOG.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ NVIDIA Model Optimizer Changelog
2121
- ``pass_through_bwd`` in the quantization config is now default to True. Please set it to False if you want to use STE with zeroed outlier gradients for potentially better QAT accuracy.
2222
- Add :meth:`compute_quantization_mse <modelopt.torch.quantization.model_quant.compute_quantization_mse>` API to measure per-quantizer mean-squared quantization error, with flexible wildcard and callable filtering.
2323
- **AutoQDQ**: New tool for automated Q/DQ (Quantize/Dequantize) placement optimization for ONNX models. Uses TensorRT latency measurements to choose insertion schemes that minimize inference time. Discovers regions automatically, groups them by structural pattern, and tests multiple Q/DQ schemes per pattern. Supports INT8 and FP8 quantization, pattern cache for warm-start on similar models, checkpoint/resume, and importing patterns from an existing QDQ baseline. CLI: ``python -m modelopt.onnx.quantization.autotune``. See the AutoQDQ guide in the documentation.
24+
- Add ``get_auto_quantize_config`` API to extract a flat quantization config from ``auto_quantize`` search results, enabling re-quantization at different effective bit targets without re-running calibration.
25+
- Improve ``auto_quantize`` checkpoint/resume: calibration state is now saved and restored across runs, avoiding redundant calibration when resuming a search.
26+
- Add NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules.
2427

2528
**Misc**
2629

modelopt/torch/opt/searcher.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -236,33 +236,50 @@ def state_dict(self) -> SearchStateDict:
236236
"""The state dictionary that can be stored/loaded."""
237237
return {key: getattr(self, key) for key in self.default_state_dict}
238238

239-
def load_search_checkpoint(self) -> bool:
239+
def _get_checkpoint_path(self) -> str | None:
240+
"""Get per-rank checkpoint path when distributed, otherwise the original path."""
241+
checkpoint = self.config["checkpoint"]
242+
if checkpoint is None:
243+
return None
244+
if dist.is_initialized():
245+
dirname, basename = os.path.split(checkpoint)
246+
name, ext = os.path.splitext(basename)
247+
return os.path.join(dirname, f"{name}{dist.rank()}{ext}")
248+
return checkpoint
249+
250+
def load_search_checkpoint(self, strict=True) -> bool:
240251
"""Load function for search checkpoint returning indicator whether checkpoint was loaded."""
241-
# check if checkpoint exists
242-
checkpoint: str | None = self.config["checkpoint"]
252+
checkpoint = self._get_checkpoint_path()
243253
if checkpoint is None:
244254
return False
255+
# Backward compat: fall back to the original single-file path
256+
if not os.path.exists(checkpoint):
257+
checkpoint = self.config["checkpoint"]
245258
if not os.path.exists(checkpoint):
246259
warn_rank_0(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.")
247260
return False
248261

249-
# iterate through state dict and load keys
250262
print_rank_0(f"Loading searcher state from {checkpoint}...")
251263
# Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input
252264
state_dict = torch.load(checkpoint, weights_only=False)
253-
assert state_dict.keys() == self.state_dict().keys(), "Keys in checkpoint don't match!"
254-
for key, state in state_dict.items():
255-
setattr(self, key, state)
265+
if strict:
266+
assert state_dict.keys() == self.state_dict().keys(), "Keys in checkpoint don't match!"
267+
for key, default_val in self.default_state_dict.items():
268+
setattr(self, key, state_dict.get(key, default_val))
256269
return True
257270

258271
def save_search_checkpoint(self, verbose=False) -> None:
259272
"""Save function for search checkpoint."""
260-
# check if save requirements are satisfied
261-
checkpoint: str | None = self.config["checkpoint"]
262-
if checkpoint is None or not dist.is_master():
273+
checkpoint = self._get_checkpoint_path()
274+
if checkpoint is None:
263275
return
264276

265-
# save state dict
277+
if dist.is_initialized():
278+
warn_rank_0(
279+
"torch.distributed is initialized. Please maintain the same parallelism "
280+
"configuration (world size, TP, EP, etc.) across search save and restore sessions."
281+
)
282+
266283
if verbose:
267284
print(f"Saving searcher state to {checkpoint}...")
268285
save_dirname, _ = os.path.split(checkpoint)

0 commit comments

Comments (0)