Commit 6b1f7da

Edwardf0t1 and claude authored and committed
Generic Fused MoE Quantization + Export for transformers 5.0+ (#1187)
## What does this PR do?

Add generic quantization and export support for **fused MoE expert modules** in HuggingFace transformers 5.0+.

In transformers 5.0+, all major MoE models switched from sequential per-expert `nn.ModuleList` to **fused 3D tensor parameters** (`gate_up_proj`, `down_proj`). This breaks ModelOpt's existing per-expert quantization and export pipeline, which assumes iterable expert submodules.

**Affected models (verified against transformers v5.5.0 source):**
- `MixtralExperts` (Mixtral)
- `Qwen2MoeExperts` (Qwen2-MoE)
- `Qwen3MoeExperts` (Qwen3-MoE)
- `Qwen3_5MoeExperts` (Qwen3.5-MoE)
- `DeepseekV3NaiveMoe` (DeepSeek-V3)
- `JambaExperts`, `OlmoeExperts`, and any future model following the same HF standard pattern

**Key insight:** All these models share an identical fused expert structure and forward pattern. A single generic solution replaces N model-specific implementations.

### Context: relationship to PR #975 and PR #1170

- **PR #975** (`kmorabi/bump-transformers-5.0`): Adds experimental transformers 5.0 support but explicitly skips batched MoE experts. This PR fills that gap.
- **PR #1170** (`chenjiel/refactor_qwen35`): Handles only Qwen3.5 using `_QuantFunctionalMixin`. This PR generalizes that approach to all fused MoE models.

### Changes

**Quantization** (`modelopt/torch/quantization/plugins/huggingface.py`):
- `_QuantFusedExperts(_QuantFunctionalMixin)` -- Generic wrapper that intercepts `F.linear` calls and applies per-expert quantization via storage-offset-based expert index recovery. Each expert gets its own weight and input quantizers (`nn.ModuleList`).
- `_is_fused_experts_module()` -- Structural detector: `gate_up_proj` (3D) + `down_proj` (3D) + `num_experts` + `act_fn`.
- `register_fused_experts_on_the_fly()` -- Auto-registration callback, added to `CUSTOM_MODEL_PLUGINS` before `register_sparse_moe_on_the_fly` so explicit registrations (Llama4, GptOss, etc.) take priority.
- `_get_fused_expert_intermediate_dim()` -- Helper for cross-version attribute name resolution (`intermediate_dim` / `intermediate_size` / fallback to shape).

**Export** (`modelopt/torch/export/moe_utils.py`, `unified_export_hf.py`, `layer_utils.py`):
- `_export_fused_experts()` -- Splits fused 3D weights into per-expert 2D projections (`gate_proj`, `up_proj`, `down_proj`), handles amax fallback for uncalibrated experts, proportionally slices per-channel amax, and registers results under the standard `experts.{E}.gate_proj.weight` naming convention.
- Integration in `_process_quantized_modules` and `_export_transformers_checkpoint` to dispatch to `_export_fused_experts` for fused expert modules.
- Structural detection in `get_expert_linear_names` for fused experts. Added `MixtralSparseMoeBlock` to the `gate_proj`/`down_proj`/`up_proj` group (transformers 5.0 naming).

**Tests** (`tests/unit/torch/quantization/plugins/test_fused_experts.py`):
- Synthetic fused expert model matching the exact HF 5.0+ pattern.
- Tests for structural detection, auto-registration, two-level registration (block + expert), quantizer creation, forward pass-through correctness, expert index recovery, and export output structure.

### Two-level registration design

    SparseMoeBlock  --> _QuantSparseMoe    (calibration control, token counting, top_k override)
      .experts      --> _QuantFusedExperts (per-expert F.linear interception + quantization)

`register_fused_experts_on_the_fly` runs first to register the inner expert module; `register_sparse_moe_on_the_fly` then registers the outer block. `_QuantSparseMoe.layer_sync_moe_local_experts_amax` skips fused experts (they are not iterable), as per-expert amax is managed internally by `_QuantFusedExperts`.

### Known limitations

- **`@use_experts_implementation` backends**: The `F.linear` interception only works with `experts_implementation="eager"` (default). `batched_mm` / `grouped_mm` use `torch.bmm` / `torch._grouped_mm` instead and are not intercepted.
- **Storage offset fragility**: Expert index recovery via `storage_offset()` breaks under `.contiguous()`, FSDP2 redistribution, or `torch.compile` materialization. Runtime assertions are included (an illustrative sketch of the recovery follows this description).
- **Toggle state machine**: Assumes exactly 2 `F.linear` calls per expert. Documented in docstrings.
- **Non-standard MoE models**: DBRX, GptOss, Llama4, Step3p5 have different layouts and are already explicitly handled. The generic solution does not attempt to cover these.

### Testing

- [x] Unit tests with synthetic fused expert model: detection, registration, quantization, export
- [x] Verify existing sequential MoE tests still pass (`test_sparse_moe.py`)
- [ ] GPU test with a real MoE model on transformers 5.x

### Before your PR is "Ready for review"

- Is this change backward compatible?: Yes -- existing explicit registrations take priority; sequential MoE models are unaffected.
- Did you write any new necessary tests?: Yes
- Did you update Changelog?: No (pending)

## Summary by CodeRabbit

* **New Features**
  * Added quantization support for fused Mixture-of-Experts (MoE) modules with automatic detection, per-expert quantization handling, and export to per-expert submodules; unified checkpoint export now supports fused MoE experts.
* **Tests**
  * Added end-to-end tests covering fused-experts detection, conversion, forward correctness, expert index recovery, and export.
* **Changelog**
  * Updated release notes to announce fused MoE expert support for Hugging Face exports.

---------

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
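Below is an illustrative sketch of the two mechanisms described above: structural detection of a fused expert container and storage-offset-based expert index recovery. The function names, shapes, and self-check are placeholders for explanation only; the actual implementations are `_is_fused_experts_module()` and the index-recovery logic inside `_QuantFusedExperts` in `modelopt/torch/quantization/plugins/huggingface.py`.

```python
# Illustrative sketch only; helper names and shapes are hypothetical.
import torch
import torch.nn as nn


def looks_like_fused_experts(module: nn.Module) -> bool:
    """Structural check mirroring the PR description: a fused HF 5.0+ expert
    container carries 3-D gate_up_proj / down_proj parameters plus
    num_experts and act_fn attributes."""
    gate_up = getattr(module, "gate_up_proj", None)
    down = getattr(module, "down_proj", None)
    return (
        isinstance(gate_up, torch.Tensor)
        and gate_up.dim() == 3
        and isinstance(down, torch.Tensor)
        and down.dim() == 3
        and hasattr(module, "num_experts")
        and hasattr(module, "act_fn")
    )


def expert_index_from_view(weight_view: torch.Tensor, fused: torch.Tensor) -> int:
    """Recover the expert index of a per-expert 2-D view of a fused 3-D
    parameter from its storage offset. This is the fragility noted under
    'Known limitations': it breaks if the view is made contiguous or
    re-materialized (FSDP2, torch.compile)."""
    per_expert = fused.stride(0)  # elements per expert slice in a contiguous layout
    offset = weight_view.storage_offset() - fused.storage_offset()
    assert offset % per_expert == 0, "unexpected storage layout"
    return offset // per_expert


# Tiny self-check with a synthetic fused parameter: [num_experts, 2*dim, hidden].
fused = nn.Parameter(torch.randn(4, 6, 8), requires_grad=False)
assert expert_index_from_view(fused[2], fused) == 2
```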
1 parent e861961 commit 6b1f7da

8 files changed

Lines changed: 625 additions & 102 deletions

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ Changelog
 
 - [Security] Changed the default of ``weights_only`` to ``True`` in ``torch.load`` for secure checkpoint loading. If you need to load a checkpoint that requires unpickling arbitrary objects, first register the class in ``torch.serialization.add_safe_globals([cls])`` before loading. Added :meth:`safe_save <modelopt.torch.utils.serialization.safe_save>` and :meth:`safe_load <modelopt.torch.utils.serialization.safe_load>` API to save and load checkpoints securely.
 - Bump minimum required PyTorch version to 2.8.
-- [Experimental] Add support for transformers>=5.0. Unified Hugging Face checkpoint export for quantized checkpoints may not work for MoE models with transformers>=5.0 yet.
+- [Experimental] Add support for transformers>=5.0, including generic PTQ and unified HF checkpoint export for fused MoE expert modules (Mixtral, Qwen2-MoE, Qwen3-MoE, Qwen3.5-MoE, DeepSeek-V3, Jamba, OLMoE, etc.).
 - Improve ``megatron_preprocess_data``: add ``--reasoning_content`` support for Nemotron v3 datasets, eliminate intermediate JSONL for HuggingFace datasets, return output file prefixes from the Python API, add gzip input support (``.jsonl.gz``), add ``--strip_newlines`` flag for plain-text pretraining data, add ``--hf_streaming`` for very large datasets (only consumed rows downloaded), and auto-shuffle when ``--hf_max_samples_per_split`` is set to avoid biased sampling.
 
 0.43 (2026-04-09)

modelopt/torch/export/layer_utils.py

Lines changed: 12 additions & 1 deletion
@@ -965,6 +965,12 @@ def module_match_name_list(module, name_list):
         """
         return any(name.lower() in type(module).__name__.lower() for name in name_list)
 
+    # Structural detection: after _export_fused_experts, fused expert modules
+    # have per-expert submodules with gate_proj/up_proj/down_proj.
+    # Also handles models that originally used this naming (Qwen, DeepSeek, etc.).
+    if hasattr(module, "experts") and hasattr(module.experts, "gate_up_proj_weight_quantizers"):
+        return ["gate_up_proj", "down_proj"]
+
     if module_match_name_list(
         module,
         [
@@ -976,12 +982,17 @@ def module_match_name_list(module, name_list):
         ],
     ):
        return ["gate_proj", "down_proj", "up_proj"]
+    elif module_match_name_list(module, ["MixtralSparseMoeBlock"]):
+        # Old-style Mixtral (iterable experts) uses w1/w2/w3.
+        # Fused Mixtral (transformers 5.0+) is already handled by the
+        # structural gate_up_proj_weight_quantizers check above.
+        return ["w1", "w2", "w3"]
     elif module_match_name_list(module, ["MixtralMoeSparseMoeBlock"]):
+        # Older transformers naming for Mixtral
         return ["linear_fc1", "linear_fc2"]
     elif module_match_name_list(module, ["DBRXMoeSparseMoeBlock"]):
         return ["w1_linear", "w2_linear", "v1_linear"]
     elif module_match_name_list(module, ["GptOssMoE"]):
-        # GPT-OSS MoE modules use gate_up_proj and down_proj
         return ["gate_up_proj", "down_proj"]
     else:
         # assuming w1, w2, w3 by default
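The structural branch added above fires before any class-name matching, so an already-converted fused expert container resolves to the fused projection names regardless of its class name. A small illustrative check follows; the toy modules are assumptions rather than real transformers or ModelOpt classes, and it assumes `get_expert_linear_names` is importable from `modelopt.torch.export.layer_utils` as shown in the diff.

```python
# Toy modules used only to exercise the structural dispatch added above.
import torch.nn as nn

from modelopt.torch.export.layer_utils import get_expert_linear_names


class FakeFusedExperts(nn.Module):
    def __init__(self):
        super().__init__()
        # The presence of this attribute marks a _QuantFusedExperts-converted module.
        self.gate_up_proj_weight_quantizers = nn.ModuleList([nn.Identity()])


class FakeMoeBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = FakeFusedExperts()


# Expected per the structural check: fused naming, no class-name match needed.
assert get_expert_linear_names(FakeMoeBlock()) == ["gate_up_proj", "down_proj"]
```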

modelopt/torch/export/moe_utils.py

Lines changed: 128 additions & 0 deletions
@@ -15,11 +15,139 @@
 
 """Utilities for Mixture-of-Experts (MoE) model export."""
 
+import copy
+import warnings
 from pathlib import Path
 
+import torch
 import torch.nn as nn
 
 
+def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
+    """Split fused MoE expert weights and export per-expert quantization scales.
+
+    Works with any module wrapped by ``_QuantFusedExperts`` — i.e. any HF
+    transformers 5.0+ fused expert container that stores ``gate_up_proj`` and
+    ``down_proj`` as 3-D ``nn.Parameter`` tensors with per-expert quantizer
+    ``nn.ModuleList`` s.
+
+    Steps:
+
+    1. Handle amax fallback for uncalibrated expert input quantizers.
+    2. Split fused 3-D weights into per-expert 2-D projections
+       (``gate_proj``, ``up_proj``, ``down_proj``).
+    3. Call ``_export_quantized_weight`` on each projection.
+    4. Register results under the standard naming convention::
+
+        {E}.gate_proj.weight, {E}.gate_proj.weight_scale, ...
+        {E}.up_proj.weight, {E}.up_proj.weight_scale, ...
+        {E}.down_proj.weight, {E}.down_proj.weight_scale, ...
+    """
+    from modelopt.torch.export.unified_export_hf import _export_quantized_weight
+    from modelopt.torch.quantization.plugins.huggingface import _get_fused_expert_intermediate_dim
+
+    n = module.num_experts
+    expert_dim = _get_fused_expert_intermediate_dim(module)
+
+    # 1. Shared input quantizers — one per projection type, shared across all experts.
+    gate_up_input_q = module.gate_up_proj_input_quantizer
+    down_input_q = module.down_proj_input_quantizer
+
+    gate_up = module.gate_up_proj.data
+    down = module.down_proj.data
+
+    # 2-3. Split + export each per-expert projection.
+    fused_dim0 = gate_up.shape[1]  # 2 * expert_dim
+
+    for idx in range(n):
+        expert = nn.Module()
+
+        projections = [
+            ("gate_proj", gate_up[idx, :expert_dim, :], 0, fused_dim0, True),
+            ("up_proj", gate_up[idx, expert_dim:, :], expert_dim, fused_dim0, True),
+            ("down_proj", down[idx], 0, down.shape[1], False),
+        ]
+
+        for proj_name, weight_slice, fused_start, fused_total, is_gate_up in projections:
+            w_quantizer_src = (
+                module.gate_up_proj_weight_quantizers[idx]
+                if is_gate_up
+                else module.down_proj_weight_quantizers[idx]
+            )
+            i_quantizer = gate_up_input_q if is_gate_up else down_input_q
+
+            # gate/up share a weight quantizer — clone so each gets independent amax.
+            w_quantizer = copy.deepcopy(w_quantizer_src) if is_gate_up else w_quantizer_src
+
+            # For per-channel amax (dim >= 1), proportionally slice dim-0
+            # to match the split weight.
+            if (
+                hasattr(w_quantizer, "_amax")
+                and w_quantizer._amax is not None
+                and w_quantizer._amax.dim() >= 1
+            ):
+                amax = w_quantizer._amax
+                amax_dim0 = amax.shape[0]
+                if fused_total % amax_dim0 == 0:
+                    slice_start = fused_start * amax_dim0 // fused_total
+                    slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
+                    w_quantizer.amax = amax[slice_start:slice_end].contiguous()
+                else:
+                    warnings.warn(
+                        f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not "
+                        f"evenly divide fused_total ({fused_total}). Skipping amax slicing, "
+                        f"which may produce incorrect quantization scales.",
+                        stacklevel=2,
+                    )
+
+            # If the weight quantizer was never calibrated, compute amax from weights.
+            if (
+                hasattr(w_quantizer, "is_enabled")
+                and w_quantizer.is_enabled
+                and (
+                    not hasattr(w_quantizer, "_amax")
+                    or w_quantizer._amax is None
+                    or torch.all(w_quantizer._amax == 0)
+                )
+            ):
+                w_quantizer.amax = weight_slice.abs().amax().to(torch.float32)
+                warnings.warn(
+                    f"Expert {idx} {proj_name} weight quantizer was not calibrated "
+                    f"(amax missing or zero). Using weight-derived amax as fallback. "
+                    f"Consider using more calibration data to activate all experts.",
+                    stacklevel=2,
+                )
+
+            wrapper = nn.Module()
+            wrapper.weight = nn.Parameter(weight_slice.contiguous(), requires_grad=False)
+            wrapper.weight_quantizer = w_quantizer
+            wrapper.input_quantizer = i_quantizer
+
+            _export_quantized_weight(wrapper, dtype)
+
+            proj = nn.Module()
+            proj.weight = wrapper.weight
+            for attr in ("weight_scale", "weight_scale_2", "input_scale"):
+                if hasattr(wrapper, attr):
+                    proj.register_buffer(attr, getattr(wrapper, attr))
+
+            expert.add_module(proj_name, proj)
+
+        module.add_module(str(idx), expert)
+
+    # 4. Remove fused params and quantizer lists — replaced by per-expert submodules
+    for attr in (
+        "gate_up_proj",
+        "down_proj",
+        "gate_up_proj_weight_quantizers",
+        "gate_up_proj_input_quantizer",
+        "down_proj_weight_quantizers",
+        "down_proj_input_quantizer",
+    ):
+        if hasattr(module, attr):
+            delattr(module, attr)
+
+
 def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | None = None):
     """Collect expert_token_count from all quantized MoE layers and save as an HTML table.
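To make the proportional per-channel amax slicing above concrete: for each expert, the fused `gate_up_proj` weight has `2 * expert_dim` output rows, and an amax vector over those rows is split so the gate half and the up half each keep their own channels. A standalone numeric sketch with toy shapes (not library code):

```python
import torch

# Toy shapes; real fused weights are [num_experts, 2 * expert_dim, hidden_size].
expert_dim = 4
fused_total = 2 * expert_dim  # dim-0 rows of one expert's fused gate_up weight
amax = torch.arange(1, fused_total + 1, dtype=torch.float32)  # one amax per output row


def slice_amax(fused_start: int, rows: int) -> torch.Tensor:
    # Same arithmetic as _export_fused_experts: map the row range of the split
    # weight onto the (possibly coarser-grained) amax vector proportionally.
    amax_dim0 = amax.shape[0]
    assert fused_total % amax_dim0 == 0
    start = fused_start * amax_dim0 // fused_total
    end = (fused_start + rows) * amax_dim0 // fused_total
    return amax[start:end].contiguous()


gate_amax = slice_amax(0, expert_dim)         # rows [0, expert_dim)
up_amax = slice_amax(expert_dim, expert_dim)  # rows [expert_dim, 2 * expert_dim)
assert gate_amax.tolist() == [1.0, 2.0, 3.0, 4.0]
assert up_amax.tolist() == [5.0, 6.0, 7.0, 8.0]
```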

modelopt/torch/export/unified_export_hf.py

Lines changed: 10 additions & 0 deletions
@@ -677,6 +677,13 @@ def _process_quantized_modules(
             with fsdp2_aware_weight_update(model, sub_module, reshard=False):
                 for weight_name in ["gate_up_proj", "down_proj"]:
                     _export_quantized_weight(sub_module, dtype, weight_name)
+        elif hasattr(sub_module, "gate_up_proj_weight_quantizers"):
+            # Generic fused MoE experts (_QuantFusedExperts) with per-expert
+            # quantizer ModuleLists. Split into per-expert modules and export.
+            from modelopt.torch.export.moe_utils import _export_fused_experts
+
+            with fsdp2_aware_weight_update(model, sub_module, reshard=False):
+                _export_fused_experts(sub_module, dtype)
 
 
 def _export_transformers_checkpoint(
@@ -721,6 +728,9 @@ def _export_transformers_checkpoint(
                         modules=list(linear_modulelist),
                         quantizer_attrs=["input_quantizer"],
                     )
+                elif hasattr(sub_module.experts, "gate_up_proj_weight_quantizers"):
+                    # _QuantFusedExperts: amax fallback is handled in _export_fused_experts
+                    break
                 elif "QuantGptOssExperts" in type(sub_module.experts).__name__:
                     # Handle GPT-OSS experts specifically
                     # GPT-OSS experts use gate_up_proj and down_proj
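After `_export_fused_experts` has run on a fused container, each expert is exposed as a numbered submodule whose state dict follows the conventional per-expert naming used by the rest of the unified export path. The snippet below only mimics that resulting layout; dtypes, shapes, and scale values are placeholders, and whether `weight_scale_2` is present depends on the quantization format.

```python
import torch
import torch.nn as nn

# Mimic the post-export layout of one fused experts container with 2 experts:
# experts.{E}.{gate_proj,up_proj,down_proj}.{weight,weight_scale,input_scale}
experts = nn.Module()
for idx in range(2):
    expert = nn.Module()
    for proj_name in ("gate_proj", "up_proj", "down_proj"):
        proj = nn.Module()
        proj.weight = nn.Parameter(torch.zeros(4, 8, dtype=torch.float16), requires_grad=False)
        proj.register_buffer("weight_scale", torch.ones(()))
        proj.register_buffer("input_scale", torch.ones(()))
        expert.add_module(proj_name, proj)
    experts.add_module(str(idx), expert)

print(sorted(experts.state_dict().keys())[:3])
# ['0.down_proj.input_scale', '0.down_proj.weight', '0.down_proj.weight_scale']
```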
