
Commit ac7c985

[NVBUG: 5804406] Auto detect MOE layers (#900)
## What does this PR do?

**Type of change:** New feature, new tests

**Overview:** Replace hardcoded per-model MoE class registrations (Mixtral, Qwen2Moe, Qwen3Moe, Qwen3Next, Llama4TextMoe, Qwen3VLMoe, MiniMaxM2, etc.) with a single generic auto-detection mechanism (`register_sparse_moe_on_the_fly`) that walks the model tree and identifies MoE blocks by their structural attributes (`gate` + `experts` with `top_k`/`num_experts`). This makes MoE quantization forward-compatible with new HuggingFace MoE architectures without requiring explicit registration for each model family.

Additionally, this PR:

- Tracks per-expert token routing counts during calibration via a gate forward hook, enabling visibility into expert utilization.
- Saves an HTML report of expert token counts during export (`save_expert_token_count_table`), highlighting under-utilized experts.
- Fixes the `topk` -> `top_k` attribute name for transformers >= 5.0 compatibility.
- Moves the PTQ summary prints in `hf_ptq.py` to a file to reduce console output.

## Usage

Auto-detection is transparent -- no user-facing API changes are needed. Any HuggingFace MoE model with the standard `gate`/`experts` pattern is automatically detected and quantized:

```python
import modelopt.torch.quantization as mtq

# Any HuggingFace MoE model (Mixtral, Qwen3Moe, DeepSeek, etc.)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")
mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)

# During export, an .moe.html report with per-expert token counts is saved automatically
```

## Testing

Unit tests; also tested exporting a Qwen MoE model.

## Before your PR is "*Ready for review*"

<!-- If you haven't finished some of the above items you can still open `Draft` PR. -->

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No <!--- If No, explain why. -->
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. -->

## Additional Information

<!-- E.g. related issue. -->

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Added expert token count visualization for Mixture of Experts models, exported as HTML reports during model export.
  * Enhanced sparse MoE quantization with improved calibration-aware routing and automatic model block detection.
* **Tests**
  * Added comprehensive test suite for sparse MoE quantization validation.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
1 parent c4b662f commit ac7c985

File tree

7 files changed: +531 −77 lines changed


CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
```diff
@@ -6,6 +6,8 @@ NVIDIA Model Optimizer Changelog (Linux)
 
 **New Features**
 
+- User does not need to manually register MOE modules to cover experts calibration coverage in PTQ workflow.
+- ``hf_ptq.py`` now saves the quantization summary and moe expert token count table to the export directory.
 - Add sparse attention optimization for transformer models (``modelopt.torch.sparsity.attention_sparsity``). This reduces computational cost by skipping attention computation. Supports calibration for threshold selection on HuggingFace models. See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
 
 0.42 (2026-02-xx)
```

examples/llm_ptq/hf_ptq.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -53,6 +53,7 @@
     export_hf_checkpoint,
     export_tensorrt_llm_checkpoint,
     get_model_type,
+    save_expert_token_count_table,
 )
 from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model
 from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration
@@ -726,7 +727,12 @@ def post_quantize(
     """
 
     if args.verbose:
-        mtq.print_quant_summary(full_model)
+        try:
+            mtq.print_quant_summary(full_model, args.export_path)
+            save_expert_token_count_table(full_model, args.export_path)
+        except Exception as e:
+            print(f"Error saving quant summary: {e}")
+            print("Continuing with generation...")
 
     # Run some samples
     torch.cuda.empty_cache()
```

modelopt/torch/export/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -19,6 +19,7 @@
 from .model_config import *
 from .model_config_export import *
 from .model_utils import *
+from .moe_utils import *
 from .plugins import *
 from .transformer_engine import *
 from .unified_export_hf import *
```

modelopt/torch/export/moe_utils.py

Lines changed: 77 additions & 0 deletions
New file (77 lines added):

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for Mixture-of-Experts (MoE) model export."""

from pathlib import Path

import torch.nn as nn


def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | None = None):
    """Collect expert_token_count from all quantized MoE layers and save as an HTML table.

    The table has rows for each MoE layer and columns for each expert, with cell values
    showing the number of tokens routed to that expert during calibration.

    Args:
        model: The model containing quantized MoE layers with ``expert_token_count`` attributes.
        output_dir: Directory to save the HTML file. Defaults to current directory.
    """
    rows = []
    for name, module in model.named_modules():
        if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0:
            rows.append((name, module.expert_token_count))

    if not rows:
        return

    num_experts = rows[0][1].shape[0]
    assert all(r[1].shape[0] == num_experts for r in rows), (
        "All MoE layers must have the same number of experts"
    )
    html_parts = [
        "<html><head><style>",
        "table { border-collapse: collapse; font-family: monospace; }",
        "th, td { border: 1px solid #ccc; padding: 4px 8px; text-align: right; }",
        "th { background: #f0f0f0; }",
        "</style></head><body>",
        "<h2>Expert Token Counts (per MoE layer)</h2>",
        "<table><tr><th>Layer/Expert</th>",
    ]
    html_parts.extend(f"<th>{i}</th>" for i in range(num_experts))
    html_parts.append("</tr>")

    for name, counts in rows:
        avg = counts.float().mean().item()
        html_parts.append(f"<tr><td>{name}</td>")
        for c in counts.tolist():
            if avg > 0 and c < avg * 0.05:
                style = ' style="background: #ff6666;"'
            elif avg > 0 and c < avg * 0.1:
                style = ' style="background: #ffcccc;"'
            else:
                style = ""
            html_parts.append(f"<td{style}>{c}</td>")
        html_parts.append("</tr>")

    html_parts.append("</table></body></html>")
    html_content = "\n".join(html_parts)

    if output_dir is None:
        output_dir = Path(".")
    output_path = Path(output_dir) / ".moe.html"
    output_path.write_text(html_content, encoding="utf-8")
    print(f"\033[1mExpert token count table saved to {output_path}\033[0m")
```
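The under-utilization highlighting rule above (cells turn red below 5% of the per-layer mean, light red below 10%) can be exercised in isolation. The `highlight_style` helper below is a hypothetical, torch-free extraction for illustration, not a function in this PR:

```python
# Standalone sketch of the cell-highlighting rule used by
# save_expert_token_count_table: experts receiving fewer than 5% of the
# mean token count are marked red, fewer than 10% light red.
def highlight_style(count: int, avg: float) -> str:
    if avg > 0 and count < avg * 0.05:
        return ' style="background: #ff6666;"'
    if avg > 0 and count < avg * 0.1:
        return ' style="background: #ffcccc;"'
    return ""

counts = [1000, 20, 50, 1200]
avg = sum(counts) / len(counts)  # 567.5
styles = [highlight_style(c, avg) for c in counts]
print(styles)
```

With these counts, only expert 1 (20 tokens, below 5% of the mean) is flagged red and expert 2 (50 tokens, below 10%) light red; well-utilized experts get no style.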

modelopt/torch/quantization/model_quant.py

Lines changed: 19 additions & 7 deletions
```diff
@@ -508,14 +508,26 @@ def enable_quantizer(model: nn.Module, wildcard_or_filter_func: str | Callable):
 
 
 @atomic_print
-def print_quant_summary(model: nn.Module):
+def print_quant_summary(model: nn.Module, output_dir: str | None = None):
     """Print summary of all quantizer modules in the model."""
-    count = 0
-    for name, mod in model.named_modules():
-        if isinstance(mod, TensorQuantizer):
-            print(f"{name:80} {mod}")
-            count += 1
-    print(f"{count} TensorQuantizers found in model")
+    lines = [
+        f"{name:80} {mod}"
+        for name, mod in model.named_modules()
+        if isinstance(mod, TensorQuantizer)
+    ]
+    lines.append(f"{len(lines)} TensorQuantizers found in model")
+
+    if output_dir:
+        path = (
+            output_dir.joinpath(".quant_summary.txt")
+            if hasattr(output_dir, "joinpath")
+            else f"{output_dir}/.quant_summary.txt"
+        )
+        with open(path, "w", encoding="utf-8") as f:
+            f.write("\n".join(lines) + "\n")
+        print(f"\033[1mQuant summary saved to {path}\033[0m")
+    else:
+        print("\n".join(lines))
 
 
 def fold_weight(model: nn.Module):
```
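The new `print_quant_summary` accepts either a plain string or a `Path`-like `output_dir` via duck typing on `joinpath`. The `summary_path` helper below is a hypothetical extraction of just that branch for illustration, not a function in this PR:

```python
from pathlib import Path


def summary_path(output_dir):
    # Mirrors the branch in the diff: Path-like objects expose joinpath,
    # plain strings fall back to f-string concatenation.
    return (
        output_dir.joinpath(".quant_summary.txt")
        if hasattr(output_dir, "joinpath")
        else f"{output_dir}/.quant_summary.txt"
    )


print(summary_path("export_dir"))        # a plain string path
print(summary_path(Path("export_dir")))  # a pathlib.Path
```

Both call sites in `hf_ptq.py` pass `args.export_path` (a string), but the duck-typed branch keeps the API usable from code that already works with `pathlib`.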

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 99 additions & 69 deletions
```diff
@@ -450,20 +450,56 @@ class _QuantSparseMoe(QuantModule):
     """
 
     def _setup(self):
-        pass
+        num_experts = 0
+        if hasattr(self, "gate") and hasattr(self.gate, "num_experts"):
+            num_experts = self.gate.num_experts
+        elif hasattr(self, "num_experts"):
+            num_experts = self.num_experts
+        elif hasattr(self, "experts") and hasattr(self.experts, "num_experts"):
+            num_experts = self.experts.num_experts
+
+        self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cpu")
+        self._count_expert_tokens = False
+
+        if num_experts == 0:
+            warnings.warn(
+                f"{self.__class__.__name__}: could not resolve num_experts; "
+                "expert routing will not be tracked for this layer."
+            )
+            return
+
+        if hasattr(self, "gate"):
+            self.gate.register_forward_hook(self._gate_forward_hook)
+
+    def _gate_forward_hook(self, module, input, output):
+        if not self._count_expert_tokens:
+            return
+        with torch.no_grad():
+            if isinstance(output, tuple) and len(output) >= 3:
+                # v5.x TopKRouter: returns (logits, scores, indices)
+                indices = output[2]
+            else:
+                # v4.x nn.Linear gate: returns logits tensor
+                logits = output if not isinstance(output, tuple) else output[0]
+                top_k = self.gate.top_k if hasattr(self.gate, "top_k") else self.top_k
+                _, indices = torch.topk(logits.float(), top_k, dim=-1)
+            counts = torch.bincount(
+                indices.reshape(-1).cpu(), minlength=len(self.expert_token_count)
+            )
+            self.expert_token_count += counts
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
+        is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules())
+        if is_calib:
             # If any of the experts are in calibration mode, we will forward all tokens to all experts
             # This is used only for calibration, we need to re-calculate the actual outputs again using
             # the original top_k
             if TRANSFORMERS_VERSION_GE_5_0:
-                assert hasattr(self, "gate")
-                # Path for transformers >= 5.0
-                original_top_k = self.gate.topk
-                self.gate.topk = self.gate.num_experts
+                assert hasattr(self, "gate") and hasattr(self.gate, "top_k")
+                original_top_k = self.gate.top_k
+                self.gate.top_k = self.gate.num_experts
                 super().forward(hidden_states)
-                self.gate.topk = original_top_k
+                self.gate.top_k = original_top_k
             else:
                 # Path for transformers < 5.0
                 original_top_k = self.top_k
@@ -475,7 +511,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                 raise ValueError(f"Could not find num_experts in module {self}")
             super().forward(hidden_states)
             self.top_k = original_top_k
-        return super().forward(hidden_states)
+        # Enable counting only for the real-routing forward during calibration
+        self._count_expert_tokens = is_calib
+        output = super().forward(hidden_states)
+        self._count_expert_tokens = False
+        return output
 
 
 class _QuantLlama4TextExperts(QuantModule):
@@ -765,10 +805,7 @@ def unpack_weight(self):
 
 
 try:
-    from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe
-
-    if Llama4TextMoe not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantSparseMoe)
+    from transformers.models.llama4.modeling_llama4 import Llama4TextExperts
 
     if Llama4TextExperts not in QuantModuleRegistry:
         QuantModuleRegistry.register({Llama4TextExperts: "hf.Llama4TextExperts"})(
@@ -791,16 +828,6 @@ def unpack_weight(self):
 except ImportError:
     pass
 
-try:
-    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-
-    if MixtralSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register({MixtralSparseMoeBlock: "hf.MixtralSparseMoeBlock"})(
-            _QuantSparseMoe
-        )
-except ImportError:
-    pass
-
 try:
     from transformers.models.falcon.modeling_falcon import FalconLinear
 
@@ -809,36 +836,6 @@ def unpack_weight(self):
 except ImportError:
     pass
 
-try:
-    from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
-
-    if Qwen3MoeSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Qwen3MoeSparseMoeBlock: "hf.Qwen3MoeSparseMoeBlock"})(
-            _QuantSparseMoe
-        )
-except ImportError:
-    pass
-
-try:
-    from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
-
-    if Qwen2MoeSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Qwen2MoeSparseMoeBlock: "hf.Qwen2MoeSparseMoeBlock"})(
-            _QuantSparseMoe
-        )
-except ImportError:
-    pass
-
-try:
-    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
-
-    if Qwen3NextSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})(
-            _QuantSparseMoe
-        )
-except ImportError:
-    pass
-
 try:
     from compressed_tensors.linear.compressed_linear import CompressedLinear
 
@@ -850,15 +847,7 @@ def unpack_weight(self):
     pass
 
 try:
-    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
-        Qwen3VLMoeTextExperts,
-        Qwen3VLMoeTextSparseMoeBlock,
-    )
-
-    if Qwen3VLMoeTextSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register(
-            {Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"}
-        )(_QuantSparseMoe)
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts
 
     if Qwen3VLMoeTextExperts not in QuantModuleRegistry:
         QuantModuleRegistry.register({Qwen3VLMoeTextExperts: "hf.Qwen3VLMoeTextExperts"})(
@@ -989,15 +978,56 @@ def register_falcon_linears_on_the_fly(model):
             QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear)
 
 
-def register_minimax_m2_moe_on_the_fly(model):
-    """Register MiniMax M2 MoE modules as a QUANT_MODULE.
+def _is_sparse_moe_block(module):
+    """Check if a module is structurally a sparse MoE block compatible with _QuantSparseMoe.
+
+    All HuggingFace MoE blocks (Mixtral, Qwen3Moe, Qwen2Moe, Qwen3Next, Llama4, MiniMax, etc.)
+    share a common structural pattern: a ``gate`` (TopKRouter) sub-module with routing attributes
+    (``top_k`` and ``num_experts``), and an ``experts`` sub-module.
 
-    MiniMax M2 MoE modules are defined in the model card, so we need to register them on the fly.
+    This function detects that pattern instead of relying on class names, making it forward-compatible
+    with new MoE architectures. Some MoE models (e.g. Glm4MoeMoE) have ``gate`` and ``experts`` but
+    use a different routing interface (``n_routed_experts`` instead of ``num_experts``, custom
+    ``route_tokens_to_experts``), so we require ``num_experts`` to be present to avoid false positives.
     """
-    if type(model).__name__ in ["MiniMaxM2ForCausalLM"]:
-        moe_type = type(model.model.layers[0].block_sparse_moe)
-        if QuantModuleRegistry.get(moe_type) is None:
-            QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantSparseMoe)
+    if not hasattr(module, "experts"):
+        return False
+
+    # Primary: gate sub-module has top_k + num_experts (standard TopKRouter pattern)
+    if hasattr(module, "gate"):
+        gate = module.gate
+        has_topk = hasattr(gate, "top_k")
+        has_num_experts = hasattr(gate, "num_experts")
+        if has_topk and has_num_experts:
+            return True
+
+    # Fallback: top_k + num_experts on the block itself (older transformers, e.g. v4.x Qwen3Next)
+    return hasattr(module, "top_k") and hasattr(module, "num_experts")
+
+
+def register_sparse_moe_on_the_fly(model):
+    """Auto-detect and register MOE modules as _QuantSparseMoe.
+
+    Walks the model tree, identifies MoE blocks by their structural attributes
+    (``gate`` + ``experts``), and registers unregistered ones with ``_QuantSparseMoe``.
+    """
+    visited_types = set()
+    for name, module in model.named_modules():
+        mod_type = type(module)
+
+        # Avoid duplicate registration: skip if we already processed this type
+        # in this walk, or if it was previously registered in the QuantModuleRegistry.
+        if mod_type in visited_types or QuantModuleRegistry.get(mod_type) is not None:
+            continue
+
+        visited_types.add(mod_type)
+
+        if _is_sparse_moe_block(module):
+            print(
+                f"\033[1mDetected MOE module '{name}' of type {mod_type.__name__}, "
+                f"registering with _QuantSparseMoe.\033[0m"
+            )
+            QuantModuleRegistry.register({mod_type: f"hf.{mod_type.__name__}"})(_QuantSparseMoe)
 
 
 def _is_supported_hf_model(model):
@@ -1065,7 +1095,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
     [
         register_falcon_linears_on_the_fly,
         register_dbrx_moe_on_the_fly,
-        register_minimax_m2_moe_on_the_fly,
+        register_sparse_moe_on_the_fly,
         register_hf_attentions_on_the_fly,
         convert_hf_parallel_linears_on_the_fly,
     ]
```
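The structural detection rule can be exercised without transformers or torch. The mock objects below are illustrative stand-ins, and `is_sparse_moe_block` re-states the logic of `_is_sparse_moe_block` from the diff above under those assumptions:

```python
from types import SimpleNamespace


def is_sparse_moe_block(module) -> bool:
    # Re-statement of _is_sparse_moe_block: require an `experts` sub-module
    # plus `top_k` + `num_experts` on the gate (v5 TopKRouter pattern) or,
    # as a fallback, on the block itself (older transformers).
    if not hasattr(module, "experts"):
        return False
    if hasattr(module, "gate"):
        gate = module.gate
        if hasattr(gate, "top_k") and hasattr(gate, "num_experts"):
            return True
    return hasattr(module, "top_k") and hasattr(module, "num_experts")


# v5-style block: routing attributes live on the gate -> detected
router_block = SimpleNamespace(experts=object(), gate=SimpleNamespace(top_k=2, num_experts=8))
# v4.x-style block: routing attributes on the block itself -> detected
flat_block = SimpleNamespace(experts=object(), gate=object(), top_k=2, num_experts=8)
# Glm4MoeMoE-like block: `n_routed_experts` instead of `num_experts` -> rejected
glm_like = SimpleNamespace(experts=object(), gate=SimpleNamespace(top_k=2, n_routed_experts=8))
```

This illustrates why the docstring insists on `num_experts`: a gate exposing only `n_routed_experts` falls through both checks and avoids a false-positive registration.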
