
Commit 03a1899

Authored by cjluo-nv, coderabbitai[bot], realAsma, and Copilot
Support force tokens to % of total experts during calibration (#910)
## What does this PR do?

**Type of change:** New feature

**Overview:** Adds a configurable `moe_calib_experts_ratio` parameter that controls the fraction of experts to calibrate during the forward pass in MoE (Mixture of Experts) models. Previously, the calibration forward always routed tokens to **all** experts, which is expensive. This PR lets the user specify a ratio (default: still all experts, so no behavior change) to retain good expert calibration coverage without the cost of a full-expert forward. The token counting for the expert coverage table now tracks the calibration routing and runs on CUDA for efficiency.

**Changes include:**

- New `moe_calib_experts_ratio` field in `QuantizeAlgorithmConfig` (`config.py`)
- Propagation of the ratio from the algorithm config to MoE modules during calibration (`mode.py`)
- Updated `_QuantSparseMoe.forward` to use the configurable ratio instead of hard-coding all experts (`huggingface.py`)
- New `--moe_calib_experts_ratio` CLI flag in `hf_ptq.py` (default `1.0`, i.e. all experts)
- Moved the `expert_token_count` tensor to CUDA and updated the HTML table title in `moe_utils.py`

## Usage

Via the `hf_ptq.py` CLI — calibrate 50% of experts during MoE calibration:

```shell
python hf_ptq.py --model <model> --qformat int4_awq --moe_calib_experts_ratio 0.5
```

Via the Python API — pass the ratio through the algorithm config:

```python
import modelopt.torch.quantization as mtq

quant_cfg = {
    "quant_cfg": {...},
    "algorithm": {
        "method": "awq_lite",
        "moe_calib_experts_ratio": 0.25,  # calibrate 1/4 of experts
    },
}
mtq.quantize(model, quant_cfg, forward_loop=calib_loop)
```

## Testing

Tested with Qwen3 30B A3B calibration, checking the tokens-per-expert table.

## Summary by CodeRabbit

**Release Notes**

* **New Features**
  * Added support for configurable expert calibration during Mixture of Experts (MoE) model quantization. Users can now specify the fraction of experts to include during calibration, enabling better expert coverage and improved quantization accuracy for MoE models. Default: all experts.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Signed-off-by: Chenjie Luo <108829653+cjluo-nv@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: realAsma <86726418+realAsma@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 75b5da9 commit 03a1899

File tree

7 files changed (+80, −14 lines)

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions

```diff
@@ -8,6 +8,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 
 - User does not need to manually register MOE modules to cover experts calibration coverage in PTQ workflow.
 - ``hf_ptq.py`` now saves the quantization summary and moe expert token count table to the export directory.
+- Add ``--moe_calib_experts_ratio`` flag in ``hf_ptq.py`` to specify the ratio of experts to calibrate during forward pass to improve expert coverage during calibration. Default to all the experts.
 - Add sparse attention optimization for transformer models (``modelopt.torch.sparsity.attention_sparsity``). This reduces computational cost by skipping attention computation. Supports calibration for threshold selection on HuggingFace models. See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
 - Add support for rotating the input before quantization for RHT.
```

examples/llm_ptq/example_utils.py

Lines changed: 15 additions & 0 deletions

```diff
@@ -201,6 +201,7 @@ def build_quant_cfg(
     model_type,
     quant_cfg_choices,
     kv_quant_cfg_choices,
+    moe_calib_experts_ratio: float | None = None,
 ) -> dict[str, Any]:
     quant_cfg = {}
     assert qformat in quant_cfg_choices, (
@@ -232,6 +233,20 @@ def build_quant_cfg(
             getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
         )
 
+    if moe_calib_experts_ratio:
+        assert 0 < moe_calib_experts_ratio <= 1, "moe_calib_experts_ratio must be between 0 and 1"
+        if isinstance(quant_cfg["algorithm"], str):
+            quant_cfg["algorithm"] = {
+                "method": quant_cfg["algorithm"],
+                "moe_calib_experts_ratio": moe_calib_experts_ratio,
+            }
+        elif isinstance(quant_cfg["algorithm"], dict):
+            quant_cfg["algorithm"]["moe_calib_experts_ratio"] = moe_calib_experts_ratio
+        else:
+            warnings.warn(
+                f"Quantization algorithm: {quant_cfg['algorithm']} does not support setting moe_calib_experts_ratio"
+            )
+
     # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
     if model_type == "gemma" and "int8_sq" in qformat:
         quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
```

examples/llm_ptq/hf_ptq.py

Lines changed: 15 additions & 1 deletion

```diff
@@ -906,6 +906,7 @@ def quantize_main(
         model_type,
         QUANT_CFG_CHOICES,
         KV_QUANT_CFG_CHOICES,
+        args.moe_calib_experts_ratio,
     )
 
     # Exclude MTP layers from quantization if detected (e.g., GLM-4.7's layer 92)
@@ -1126,8 +1127,21 @@ def parse_args() -> argparse.Namespace:
             "(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified."
         ),
     )
+    parser.add_argument(
+        "--moe_calib_experts_ratio",
+        type=float,
+        default=1.0,
+        help=(
+            "Fraction of experts to calibrate during forward pass (ratio in (0.0, 1.0]). "
+            "Only used for MOE models; used to reduce the number of experts calibrated during the forward pass."
+            "Does not impact non-MOE models."
+        ),
+    )
 
-    return parser.parse_args()
+    args = parser.parse_args()
+    if not (0.0 < args.moe_calib_experts_ratio <= 1.0):
+        parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].")
+    return args
 
 
 def main(args: argparse.Namespace):
```

modelopt/torch/export/moe_utils.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -48,7 +48,7 @@ def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | Non
         "th, td { border: 1px solid #ccc; padding: 4px 8px; text-align: right; }",
         "th { background: #f0f0f0; }",
         "</style></head><body>",
-        "<h2>Expert Token Counts (per MoE layer)</h2>",
+        "<h2>Expert Calib Token Counts (per MoE layer)</h2>",
         "<table><tr><th>Layer/Expert</th>",
     ]
     html_parts.extend(f"<th>{i}</th>" for i in range(num_experts))
```

modelopt/torch/quantization/config.py

Lines changed: 10 additions & 0 deletions

```diff
@@ -1097,6 +1097,16 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
         title="This field specifies the name of the calibration algorithm. If None, no calibration is performed.",
     )
 
+    moe_calib_experts_ratio: float | None = ModeloptField(
+        default=None,
+        title="% of experts to calibrate during forward pass.",
+        description=(
+            "If specified, we force forward tokens to % of experts during the calibration"
+            " pass. This forward is for calibration purpose only and will not affect the"
+            " actual inference."
+        ),
+    )
+
 
 class MaxCalibConfig(QuantizeAlgorithmConfig):
     """The config for max calibration algorithm.
```

modelopt/torch/quantization/mode.py

Lines changed: 9 additions & 0 deletions

```diff
@@ -225,6 +225,15 @@ def wrapped_calib_func(
         # For backward compatibility
         kwargs["algorithm"] = method
 
+    moe_calib_experts_ratio = kwargs.pop("moe_calib_experts_ratio", None)
+    if moe_calib_experts_ratio is not None:
+        assert (
+            isinstance(moe_calib_experts_ratio, (int, float)) and 0 < moe_calib_experts_ratio <= 1
+        ), f"Invalid moe_calib_experts_ratio {moe_calib_experts_ratio!r}"
+        for module in model.modules():
+            if hasattr(module, "_moe_calib_experts_ratio"):
+                module._moe_calib_experts_ratio = moe_calib_experts_ratio
+
     if func is not None:
         # Call the function with forward_loop as a separate argument
         func(model, forward_loop=forward_loop, **kwargs)
```
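The propagation step — find every module that exposes `_moe_calib_experts_ratio` and stamp the validated ratio onto it — can be sketched without torch. `FakeMoeLayer`, `FakeLinear`, and `set_moe_calib_ratio` are illustrative stand-ins (the real code iterates `model.modules()` over `nn.Module`s):

```python
class FakeMoeLayer:
    """Stand-in for a quantized MoE module whose _setup() defined the attribute."""

    def __init__(self):
        self._moe_calib_experts_ratio = None


class FakeLinear:
    """Stand-in for a module without the attribute; must be left untouched."""


def set_moe_calib_ratio(modules, ratio):
    """Mirror of the wrapped_calib_func logic: validate the ratio, then set it
    on every module that opted in by defining _moe_calib_experts_ratio."""
    assert isinstance(ratio, (int, float)) and 0 < ratio <= 1, (
        f"Invalid moe_calib_experts_ratio {ratio!r}"
    )
    for m in modules:
        if hasattr(m, "_moe_calib_experts_ratio"):
            m._moe_calib_experts_ratio = ratio
```

The `hasattr` guard is the opt-in mechanism: only MoE wrappers that initialized the attribute in `_setup` receive the ratio, so non-MoE modules are unaffected.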

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 29 additions & 12 deletions

```diff
@@ -458,8 +458,13 @@ def _setup(self):
         elif hasattr(self, "experts") and hasattr(self.experts, "num_experts"):
             num_experts = self.experts.num_experts
 
-        self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cpu")
+        self.register_buffer(
+            "expert_token_count",
+            torch.zeros(num_experts, dtype=torch.long, device=next(self.parameters()).device),
+            persistent=False,
+        )
         self._count_expert_tokens = False
+        self._moe_calib_experts_ratio = None
 
         if num_experts == 0:
             warnings.warn(
@@ -483,36 +488,48 @@ def _gate_forward_hook(self, module, input, output):
         logits = output if not isinstance(output, tuple) else output[0]
         top_k = self.gate.top_k if hasattr(self.gate, "top_k") else self.top_k
         _, indices = torch.topk(logits.float(), top_k, dim=-1)
-        counts = torch.bincount(
-            indices.reshape(-1).cpu(), minlength=len(self.expert_token_count)
-        )
-        self.expert_token_count += counts
+        counts = torch.bincount(indices.reshape(-1), minlength=self.expert_token_count.shape[0])
+        self.expert_token_count += counts.to(self.expert_token_count.device)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules())
-        if is_calib:
-            # If any of the experts are in calibration mode, we will forward all tokens to all experts
+        self._count_expert_tokens = is_calib
+        if is_calib and self._moe_calib_experts_ratio:
+            self._count_expert_tokens = True
+            assert 0 < self._moe_calib_experts_ratio <= 1, (
+                "moe_calib_experts_ratio must be between 0 and 1"
+            )
+            # If any of the experts are in calibration mode, we will forward all tokens to
+            # self._moe_calib_experts_ratio % of the experts to improve the calibration coverage.
             # This is used only for calibration, we need to re-calculate the actual outputs again using
             # the original top_k
             if TRANSFORMERS_VERSION_GE_5_0:
                 assert hasattr(self, "gate") and hasattr(self.gate, "top_k")
                 original_top_k = self.gate.top_k
-                self.gate.top_k = self.gate.num_experts
+                self.gate.top_k = max(
+                    original_top_k, round(self.gate.num_experts * self._moe_calib_experts_ratio)
+                )
                 super().forward(hidden_states)
                 self.gate.top_k = original_top_k
             else:
                 # Path for transformers < 5.0
                 original_top_k = self.top_k
                 if hasattr(self, "num_experts"):
-                    self.top_k = self.num_experts
+                    self.top_k = max(
+                        original_top_k, round(self.num_experts * self._moe_calib_experts_ratio)
+                    )
                 elif hasattr(self, "experts"):
-                    self.top_k = self.experts.num_experts
+                    self.top_k = max(
+                        original_top_k,
+                        round(self.experts.num_experts * self._moe_calib_experts_ratio),
+                    )
                 else:
                     raise ValueError(f"Could not find num_experts in module {self}")
                 super().forward(hidden_states)
                 self.top_k = original_top_k
-        # Enable counting only for the real-routing forward during calibration
-        self._count_expert_tokens = is_calib
+            self._count_expert_tokens = False
+        else:
+            self._count_expert_tokens = True
         output = super().forward(hidden_states)
         self._count_expert_tokens = False
         return output
```
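The key arithmetic in the new `forward` is the calibration top-k: scale `num_experts` by the ratio, round, and never go below the model's own top-k, so the real routing is always a subset of what gets calibrated. A standalone sketch (`calib_top_k` is a hypothetical helper name; the diff inlines the expression):

```python
def calib_top_k(original_top_k: int, num_experts: int, ratio: float) -> int:
    """Number of experts each token is routed to during the calibration-only
    forward: round(ratio * num_experts), floored at the original top_k."""
    assert 0 < ratio <= 1, "moe_calib_experts_ratio must be between 0 and 1"
    return max(original_top_k, round(num_experts * ratio))
```

With a ratio of 1.0 this reproduces the old behavior of routing to all experts; with 0.25 on a 128-expert, top-8 model, each token is routed to 32 experts during the calibration forward.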
