Commit 16d562a
Refactor local_hessian onto shared MSE flow + fused-MoE expert support (#1578)
### What does this PR do?
Type of change: Bug fix + new feature (fused-MoE coverage)
**Dusted off and refactored `local_hessian_calibrate`** to align with
the new MSE calibration flow, fix latent drift, extend coverage to
fused-MoE experts, and decouple module-specific handling behind a clean
extension point.
Core changes (`modelopt/torch/quantization/model_calib.py`):
- Extracted a shared `_mse_calibrate_weights` helper now used by both
`mse_calibrate` and `local_hessian_calibrate`
- Replaced the monolithic `LocalHessianHelper` + bespoke per-weight loop
with a small `_LocalHessianAccumulator` (lazy fp32 buffer, freed after
building the error func), removing ~200 lines of duplicated scale-search
logic and all manual `cuda.synchronize`/`empty_cache` bookkeeping. The
`XᵀX` GEMM accumulates in fp32 to avoid bf16/fp16 precision loss.
- Removed dead NVFP4-static promotion (now handled inside
`max_calibrate`).
- **Fused-MoE expert support:** per-expert Hessians are captured from
each expert's routed activations and keyed by `id(weight_quantizer)`,
so dense and per-expert paths share one calibration loop. Never-routed
experts, non-eager kernels, `cin` not divisible by `block_size`, and
registered backends all fall back to plain MSE (with an eager,
module-named warning).
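
To make the accumulator and the Hessian-weighted error concrete, here is a minimal pure-Python sketch. It is not the code from `model_calib.py`: `HessianAccumulatorSketch`, `fake_quant`, `mse`, and `search_scale` are hypothetical names, the division by the row count is an assumption, and Python floats stand in for the fp32 tensors the PR describes. It shows how an accumulated `XᵀX` turns the scale search from plain MSE into an error that only penalizes channels the activations actually exercise, and how a missing Hessian falls back to MSE.

```python
from typing import Callable, List, Optional, Sequence

ErrorFunc = Callable[[Sequence[float], Sequence[float]], float]


class HessianAccumulatorSketch:
    """Hypothetical stand-in for the PR's `_LocalHessianAccumulator`."""

    def __init__(self, cin: int) -> None:
        self.cin = cin
        self.hessian: Optional[List[List[float]]] = None  # allocated lazily
        self.num_rows = 0

    def add_batch(self, x: Sequence[Sequence[float]]) -> None:
        """Accumulate X^T X over a batch of activation rows of length cin."""
        if self.hessian is None:
            self.hessian = [[0.0] * self.cin for _ in range(self.cin)]
        for row in x:
            for i in range(self.cin):
                for j in range(self.cin):
                    self.hessian[i][j] += row[i] * row[j]
        self.num_rows += len(x)

    def build_error_func(self) -> Optional[ErrorFunc]:
        """Return (w, wq) -> d H d^T, or None if no activations were seen."""
        if self.hessian is None or self.num_rows == 0:
            return None  # caller is expected to fall back to plain MSE
        # Assumption: normalize by row count (does not change the argmin).
        h = [[v / self.num_rows for v in row] for row in self.hessian]
        self.hessian = None  # free the buffer; the closure keeps its own copy

        def error(w: Sequence[float], wq: Sequence[float]) -> float:
            d = [a - b for a, b in zip(w, wq)]
            n = len(d)
            return sum(d[i] * h[i][j] * d[j] for i in range(n) for j in range(n))

        return error


def mse(w: Sequence[float], wq: Sequence[float]) -> float:
    """Plain mean squared error, used when no Hessian is available."""
    return sum((a - b) ** 2 for a, b in zip(w, wq)) / len(w)


def fake_quant(w: Sequence[float], scale: float, qmax: int) -> List[float]:
    """Toy symmetric quantizer: round to the grid {-qmax..qmax} * scale."""
    return [max(-qmax, min(qmax, round(v / scale))) * scale for v in w]


def search_scale(
    w: Sequence[float],
    candidates: Sequence[float],
    error_func: Optional[ErrorFunc] = None,
    qmax: int = 1,
) -> float:
    """Pick the candidate scale with the lowest error (MSE by default)."""
    err = error_func if error_func is not None else mse
    return min(candidates, key=lambda s: err(w, fake_quant(w, s, qmax)))
```

In this toy setup, with activations that only excite the first input channel, the Hessian-weighted search prefers the scale that reproduces the first weight exactly, while plain MSE prefers the scale that protects the larger second weight.
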
Decoupling (zero module-type-specific code in `model_calib.py`):
- Added `QuantModule.register_calibration_input_hooks(callback)`, the
activation-side counterpart to `iter_weights_for_calibration`. The base
default is a no-op; `QuantLinearConvBase` pairs the weight quantizer
with the forward input (linear only), and `_QuantFusedExperts` (in
`plugins/huggingface.py`) owns the per-expert pairing via
`_current_expert_idx`. Any future module type gains local-Hessian
support by implementing this one method.
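
The shape of this extension point can be sketched without any ModelOpt code. Everything below is hypothetical: the `*Sketch` classes, the `callback(weight_quantizer, inputs)` signature, and the returned list of remover callables are assumptions made for illustration, and the real `QuantModule` hierarchy may differ. The point is only that the calibration loop sees `(weight_quantizer, inputs)` pairs and never needs to know which module type produced them.

```python
from typing import Any, Callable, Dict, List

# Assumed callback signature: callback(weight_quantizer, inputs)
Callback = Callable[[object, Any], None]
Remover = Callable[[], None]


class QuantModuleSketch:
    """Base class: the default hook registration is a no-op."""

    def register_calibration_input_hooks(self, callback: Callback) -> List[Remover]:
        return []


class LinearSketch(QuantModuleSketch):
    """Pairs its single weight quantizer with the forward input."""

    def __init__(self) -> None:
        self.weight_quantizer = object()  # placeholder for a real quantizer
        self._hooks: List[Callable[[Any], None]] = []

    def register_calibration_input_hooks(self, callback: Callback) -> List[Remover]:
        def hook(x: Any) -> None:
            callback(self.weight_quantizer, x)

        self._hooks.append(hook)
        return [lambda: self._hooks.remove(hook)]

    def forward(self, x: Any) -> Any:
        for hook in list(self._hooks):
            hook(x)
        return x


class FusedExpertsSketch(QuantModuleSketch):
    """Pairs each expert's quantizer with only the tokens routed to it."""

    def __init__(self, num_experts: int) -> None:
        self.weight_quantizers = [object() for _ in range(num_experts)]
        self._current_expert_idx = -1  # set while an expert is executing
        self._hooks: List[Callable[[Any], None]] = []

    def register_calibration_input_hooks(self, callback: Callback) -> List[Remover]:
        def hook(x: Any) -> None:
            callback(self.weight_quantizers[self._current_expert_idx], x)

        self._hooks.append(hook)
        return [lambda: self._hooks.remove(hook)]

    def forward(self, routed: Dict[int, Any]) -> None:
        # `routed` maps expert index -> tokens; unrouted experts never fire.
        for idx, tokens in routed.items():
            self._current_expert_idx = idx
            for hook in list(self._hooks):
                hook(tokens)
        self._current_expert_idx = -1
```

A calibration loop built on this could collect inputs into a dictionary keyed by `id(weight_quantizer)`; an expert that is never routed simply has no entry, which is the case the PR handles by falling back to plain MSE.
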
### Testing
- New `tests/unit/torch/quantization/test_local_hessian.py` (accumulator
math/shape/dtype, dense end-to-end, backend-skip, block-size guard) and
a per-expert MoE test in `test_fused_experts.py`.
- Behavior-preserving check: refactored branch produces
**bit-identical** dense weight scales to `main` (216/216 tensors) on
Qwen3-8B with the fp32 accumulation neutralized to isolate the
structural change.
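
For readers who want to run a similar check, a bit-identical comparison can be sketched as follows. This is not the script used for the result above: `count_bit_identical` is a hypothetical helper, and scales are represented as plain lists of Python floats rather than tensors. It compares the raw byte patterns, so `0.0` and `-0.0` count as different even though `==` treats them as equal.

```python
import struct
from typing import Dict, Sequence, Tuple


def _bits(value: float) -> bytes:
    """Raw IEEE-754 double encoding of a Python float."""
    return struct.pack("<d", value)


def count_bit_identical(
    ref: Dict[str, Sequence[float]],
    new: Dict[str, Sequence[float]],
) -> Tuple[int, int]:
    """Return (matching tensors, total tensors) under bitwise equality."""
    matches = 0
    for name, ref_vals in ref.items():
        new_vals = new.get(name)
        if new_vals is None or len(new_vals) != len(ref_vals):
            continue  # missing or reshaped tensors never count as a match
        if all(_bits(a) == _bits(b) for a, b in zip(ref_vals, new_vals)):
            matches += 1
    return matches, len(ref)
```

With scales dumped from both branches into two such dictionaries, a result whose two numbers are equal means every tensor matched bit for bit.
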
### Before your PR is "*Ready for review*"
Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)
and your commits are signed (`git commit -s -S`).
Make sure you read and follow the [Security Best
Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors)
(e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(...,
weights_only=False)`, `pickle`, etc.).
- Is this change backward compatible?: ✅
- If you copied code from any other sources or added a new PIP
dependency, did you follow guidance in `CONTRIBUTING.md`: ✅
- Did you write any new necessary tests?: ✅
## Summary by CodeRabbit
* **New Features**
  * Introduced Hessian-weighted calibration method for enhanced weight
    quantization refinement in mixed-precision models.
  * Enhanced MSE calibration to support custom per-quantizer error
    functions for specialized calibration workflows.
* **Tests**
  * Added comprehensive unit tests for local Hessian calibration
    validation across dense models and fused-expert architectures.
  * Included tests for custom error function integration and fallback
    behavior.
---------
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>

1 parent b98a595, commit 16d562a
3 files changed: 561 additions and 216 deletions