Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ NVIDIA Model Optimizer Changelog
- ``pass_through_bwd`` in the quantization config now defaults to True. Please set it to False if you want to use STE with zeroed outlier gradients for potentially better QAT accuracy.
- Add :meth:`compute_quantization_mse <modelopt.torch.quantization.model_quant.compute_quantization_mse>` API to measure per-quantizer mean-squared quantization error, with flexible wildcard and callable filtering.
- **AutoQDQ**: New tool for automated Q/DQ (Quantize/Dequantize) placement optimization for ONNX models. Uses TensorRT latency measurements to choose insertion schemes that minimize inference time. Discovers regions automatically, groups them by structural pattern, and tests multiple Q/DQ schemes per pattern. Supports INT8 and FP8 quantization, pattern cache for warm-start on similar models, checkpoint/resume, and importing patterns from an existing QDQ baseline. CLI: ``python -m modelopt.onnx.quantization.autotune``. See the AutoQDQ guide in the documentation.
- Add ``get_auto_quantize_config`` API to extract a flat quantization config from ``auto_quantize`` search results, enabling re-quantization at different effective bit targets without re-running calibration.
- Improve ``auto_quantize`` checkpoint/resume: calibration state is now saved and restored across runs, avoiding redundant calibration when resuming a search.
- Add NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules.

**Misc**

Expand Down
43 changes: 32 additions & 11 deletions modelopt/torch/opt/searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,33 +236,54 @@ def state_dict(self) -> SearchStateDict:
"""The state dictionary that can be stored/loaded."""
return {key: getattr(self, key) for key in self.default_state_dict}

def load_search_checkpoint(self) -> bool:
def _get_checkpoint_path(self) -> str | None:
"""Get per-rank checkpoint path when distributed, otherwise the original path."""
checkpoint = self.config["checkpoint"]
if checkpoint is None:
return None
if dist.is_initialized():
dirname, basename = os.path.split(checkpoint)
name, ext = os.path.splitext(basename)
return os.path.join(dirname, f"{name}{dist.rank()}{ext}")
return checkpoint

def load_search_checkpoint(self, strict=True) -> bool:
    """Load the searcher state from the configured checkpoint, if one exists.

    The path returned by ``self._get_checkpoint_path()`` is tried first. If that file
    is missing and differs from ``self.config["checkpoint"]``, the configured path is
    tried as a backward-compatible fallback.

    Args:
        strict: If True, assert that the checkpoint keys exactly match the keys of
            ``self.state_dict()``. If False, keys missing from the checkpoint take
            their value from ``self.default_state_dict`` and checkpoint keys that are
            not in ``self.default_state_dict`` are ignored.

    Returns:
        True if a checkpoint was loaded; False if no checkpoint is configured or no
        checkpoint file was found.

    Raises:
        AssertionError: If ``strict`` is True and the checkpoint keys don't match.
    """
    checkpoint = self._get_checkpoint_path()
    if checkpoint is None:
        return False

    # Backward compat: fall back to the original single-file path. Only announce the
    # fallback when the resolved path actually differs from the configured one;
    # otherwise the message would claim to fall back to the very same file.
    original = self.config["checkpoint"]
    if checkpoint != original and not os.path.exists(checkpoint):
        warn_rank_0(
            f"Per-rank checkpoint {checkpoint} not found, falling back to "
            f"{original}. Ensure world size matches the original run."
        )
        checkpoint = original
    if not os.path.exists(checkpoint):
        warn_rank_0(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.")
        return False

    print_rank_0(f"Loading searcher state from {checkpoint}...")
    # Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input
    state_dict = torch.load(checkpoint, weights_only=False)

    # The key check is gated on ``strict`` so that ``strict=False`` can load checkpoints
    # whose key set differs from the current state dict.
    if strict:
        assert state_dict.keys() == self.state_dict().keys(), "Keys in checkpoint don't match!"
    # Iterate over the expected keys (not the checkpoint's) so missing entries get defaults.
    for key, default_val in self.default_state_dict.items():
        setattr(self, key, state_dict.get(key, default_val))
    return True

def save_search_checkpoint(self, verbose=False) -> None:
"""Save function for search checkpoint."""
# check if save requirements are satisfied
checkpoint: str | None = self.config["checkpoint"]
if checkpoint is None or not dist.is_master():
checkpoint = self._get_checkpoint_path()
if checkpoint is None:
return

# save state dict
if dist.is_initialized():
warn_rank_0(
"torch.distributed is initialized. Please maintain the same parallelism "
"configuration (world size, TP, EP, etc.) across search save and restore sessions."
)

if verbose:
print(f"Saving searcher state to {checkpoint}...")
save_dirname, _ = os.path.split(checkpoint)
Expand Down
Loading
Loading