
Commit e0118b2

Update
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
1 parent f12c16f commit e0118b2

5 files changed

Lines changed: 15 additions & 5 deletions

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 
 - User does not need to manually register MOE modules to cover experts calibration coverage in PTQ workflow.
 - ``hf_ptq.py`` now saves the quantization summary and moe expert token count table to the export directory.
-- Add ``--moe_calib_experts_ratio`` flag in ``hf_ptq.py`` to specify the ratio of experts to calibrate during forward pass to improve expert coverage during calibration. Default to 1/4 of all the experts.
+- Add ``--moe_calib_experts_ratio`` flag in ``hf_ptq.py`` to specify the ratio of experts to calibrate during forward pass to improve expert coverage during calibration. Default to all the experts.
 - Add sparse attention optimization for transformer models (``modelopt.torch.sparsity.attention_sparsity``). This reduces computational cost by skipping attention computation. Supports calibration for threshold selection on HuggingFace models. See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
 
 0.42 (2026-02-xx)
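
For context, a hypothetical invocation of the flag this changelog entry describes. Only --moe_calib_experts_ratio is introduced by this commit; the checkpoint path and the --qformat value below are illustrative assumptions about hf_ptq.py's interface, not taken from the diff:

    # Hypothetical usage sketch: calibrate half of the experts during PTQ.
    # Only --moe_calib_experts_ratio comes from this commit; other flags are assumed.
    python hf_ptq.py \
        --pyt_ckpt_path <path/to/moe/checkpoint> \
        --qformat fp8 \
        --moe_calib_experts_ratio 0.5

With the new default of 1.0, every expert is calibrated unless the flag lowers the ratio.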

examples/llm_ptq/example_utils.py

Lines changed: 7 additions & 2 deletions
@@ -201,7 +201,7 @@ def build_quant_cfg(
     model_type,
     quant_cfg_choices,
     kv_quant_cfg_choices,
-    moe_calib_experts_ratio,
+    moe_calib_experts_ratio: float | None = None,
 ) -> dict[str, Any]:
     quant_cfg = {}
     assert qformat in quant_cfg_choices, (
@@ -234,13 +234,18 @@ def build_quant_cfg(
     )
 
     if moe_calib_experts_ratio:
+        assert 0 < moe_calib_experts_ratio <= 1, "moe_calib_experts_ratio must be between 0 and 1"
         if isinstance(quant_cfg["algorithm"], str):
            quant_cfg["algorithm"] = {
                "method": quant_cfg["algorithm"],
                "moe_calib_experts_ratio": moe_calib_experts_ratio,
            }
-        else:
+        elif isinstance(quant_cfg["algorithm"], dict):
             quant_cfg["algorithm"]["moe_calib_experts_ratio"] = moe_calib_experts_ratio
+        else:
+            warnings.warn(
+                f"Quantization algorithm: {quant_cfg['algorithm']} does not support setting moe_calib_experts_ratio"
+            )
 
     # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
     if model_type == "gemma" and "int8_sq" in qformat:
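
To illustrate what the branch above produces, a minimal self-contained sketch that mirrors it (not the library function; the "max" and "awq_lite" method names are placeholders for whatever calibration algorithm is configured):

    import warnings

    def apply_moe_ratio(quant_cfg: dict, moe_calib_experts_ratio: float | None) -> dict:
        # Sketch mirroring the branch added to build_quant_cfg above.
        if moe_calib_experts_ratio:
            assert 0 < moe_calib_experts_ratio <= 1, "moe_calib_experts_ratio must be between 0 and 1"
            if isinstance(quant_cfg["algorithm"], str):
                # A bare method name is promoted to a dict so the ratio can ride along.
                quant_cfg["algorithm"] = {
                    "method": quant_cfg["algorithm"],
                    "moe_calib_experts_ratio": moe_calib_experts_ratio,
                }
            elif isinstance(quant_cfg["algorithm"], dict):
                quant_cfg["algorithm"]["moe_calib_experts_ratio"] = moe_calib_experts_ratio
            else:
                # e.g. algorithm is None: nothing to attach the ratio to.
                warnings.warn("algorithm does not support setting moe_calib_experts_ratio")
        return quant_cfg

    print(apply_moe_ratio({"algorithm": "max"}, 0.5))
    # -> {'algorithm': {'method': 'max', 'moe_calib_experts_ratio': 0.5}}
    print(apply_moe_ratio({"algorithm": {"method": "awq_lite"}}, 0.5))
    # -> {'algorithm': {'method': 'awq_lite', 'moe_calib_experts_ratio': 0.5}}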

examples/llm_ptq/hf_ptq.py

Lines changed: 2 additions & 1 deletion
@@ -1130,10 +1130,11 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--moe_calib_experts_ratio",
         type=float,
-        default=1.0 / 4,
+        default=1.0,
         help=(
             "Fraction of experts to calibrate during forward pass (ratio in (0.0, 1.0]). "
             "Only used for MOE models; used to reduce the number of experts calibrated during the forward pass."
+            "Does not impact non-MOE models."
         ),
     )
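
A runnable toy version of the argument as it now stands, with the range check the other files enforce (the parser below is standalone, not hf_ptq.py's):

    import argparse

    # Standalone toy parser, showing the flag with its new default.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--moe_calib_experts_ratio",
        type=float,
        default=1.0,  # new default: calibrate all experts
        help="Fraction of experts to calibrate during forward pass (ratio in (0.0, 1.0]).",
    )

    args = parser.parse_args(["--moe_calib_experts_ratio", "0.25"])
    assert 0 < args.moe_calib_experts_ratio <= 1  # the range check enforced elsewhere
    print(args.moe_calib_experts_ratio)  # 0.25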

modelopt/torch/quantization/mode.py

Lines changed: 3 additions & 0 deletions
@@ -227,6 +227,9 @@ def wrapped_calib_func(
 
     moe_calib_experts_ratio = kwargs.pop("moe_calib_experts_ratio", None)
     if moe_calib_experts_ratio is not None:
+        assert (
+            isinstance(moe_calib_experts_ratio, (int, float)) and 0 < moe_calib_experts_ratio <= 1
+        ), f"Invalid moe_calib_experts_ratio {moe_calib_experts_ratio!r}"
         for module in model.modules():
             if hasattr(module, "_moe_calib_experts_ratio"):
                 module._moe_calib_experts_ratio = moe_calib_experts_ratio
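
The loop above is an opt-in broadcast: only modules that already define _moe_calib_experts_ratio pick up the override. A standalone sketch of the pattern, with a hypothetical ToyMoELayer standing in for the plugin modules:

    import torch.nn as nn

    class ToyMoELayer(nn.Module):
        # Hypothetical stand-in for a quantized MOE module. The attribute acts as
        # an opt-in marker: only modules that define it receive the override.
        _moe_calib_experts_ratio: float = 1.0

    model = nn.Sequential(nn.Linear(8, 8), ToyMoELayer())

    moe_calib_experts_ratio = 0.5
    if moe_calib_experts_ratio is not None:
        assert isinstance(moe_calib_experts_ratio, (int, float)) and 0 < moe_calib_experts_ratio <= 1
        for module in model.modules():
            if hasattr(module, "_moe_calib_experts_ratio"):
                module._moe_calib_experts_ratio = moe_calib_experts_ratio

    print([type(m).__name__ for m in model.modules() if hasattr(m, "_moe_calib_experts_ratio")])
    # -> ['ToyMoELayer']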

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 2 additions & 1 deletion
@@ -499,7 +499,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         assert 0 < self._moe_calib_experts_ratio <= 1, (
             "moe_calib_experts_ratio must be between 0 and 1"
         )
-        # If any of the experts are in calibration mode, we will forward all tokens to all experts
+        # If any of the experts are in calibration mode, we will forward all tokens to
+        # self._moe_calib_experts_ratio % of the experts to improve the calibration coverage.
         # This is used only for calibration, we need to re-calculate the actual outputs again using
         # the original top_k
         if TRANSFORMERS_VERSION_GE_5_0:
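
A sketch of the routing idea the updated comment describes: during calibration, each token is sent to roughly ratio * num_experts experts instead of the model's top_k, so rarely selected experts still see calibration data. The helper below is illustrative, not the modelopt implementation:

    import math
    import torch

    def calib_expert_indices(router_logits, num_experts, top_k, ratio):
        # Illustrative helper: during calibration, route each token to
        # ~ratio * num_experts experts (never fewer than the model's top_k)
        # so that rarely selected experts still receive calibration data.
        assert 0 < ratio <= 1, "ratio must be in (0, 1]"
        k = max(top_k, math.ceil(ratio * num_experts))
        return router_logits.topk(k, dim=-1).indices

    logits = torch.randn(6, 8)  # 6 tokens, 8 experts
    idx = calib_expert_indices(logits, num_experts=8, top_k=2, ratio=0.5)
    print(idx.shape)  # torch.Size([6, 4]): 4 of 8 experts per token

As the comment notes, the widened routing is used only to collect calibration statistics; the actual outputs are recomputed with the original top_k.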
