Skip to content

Commit ea44272

Browse files
committed
address reviews
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
1 parent fb81b00 commit ea44272

4 files changed

Lines changed: 31 additions & 29 deletions

File tree

modelopt/torch/export/layer_utils.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -979,10 +979,14 @@ def module_match_name_list(module, name_list):
979979
"Qwen3NextSparseMoeBlock",
980980
"Qwen3_5MoeSparseMoeBlock",
981981
"DeepseekMoE",
982-
"MixtralSparseMoeBlock",
983982
],
984983
):
985984
return ["gate_proj", "down_proj", "up_proj"]
985+
elif module_match_name_list(module, ["MixtralSparseMoeBlock"]):
986+
# Old-style Mixtral (iterable experts) uses w1/w2/w3.
987+
# Fused Mixtral (transformers 5.0+) is already handled by the
988+
# structural gate_up_proj_weight_quantizers check above.
989+
return ["w1", "w2", "w3"]
986990
elif module_match_name_list(module, ["MixtralMoeSparseMoeBlock"]):
987991
# Older transformers naming for Mixtral
988992
return ["linear_fc1", "linear_fc2"]

modelopt/torch/export/moe_utils.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""Utilities for Mixture-of-Experts (MoE) model export."""
1717

1818
import copy
19+
import warnings
1920
from pathlib import Path
2021

2122
import torch
@@ -49,17 +50,9 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
4950
n = module.num_experts
5051
expert_dim = _get_fused_expert_intermediate_dim(module)
5152

52-
# 1. Input amax fallback — borrow from calibrated peers.
53-
for quantizer_list in [
54-
module.gate_up_proj_input_quantizers,
55-
module.down_proj_input_quantizers,
56-
]:
57-
wrappers = []
58-
for q in quantizer_list:
59-
w = nn.Module()
60-
w.input_quantizer = q
61-
wrappers.append(w)
62-
set_expert_quantizer_amax(modules=wrappers, quantizer_attrs=["input_quantizer"])
53+
# 1. Shared input quantizers — one per projection type, shared across all experts.
54+
gate_up_input_q = module.gate_up_proj_input_quantizer
55+
down_input_q = module.down_proj_input_quantizer
6356

6457
gate_up = module.gate_up_proj.data
6558
down = module.down_proj.data
@@ -82,11 +75,7 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
8275
if is_gate_up
8376
else module.down_proj_weight_quantizers[idx]
8477
)
85-
i_quantizer = (
86-
module.gate_up_proj_input_quantizers[idx]
87-
if is_gate_up
88-
else module.down_proj_input_quantizers[idx]
89-
)
78+
i_quantizer = gate_up_input_q if is_gate_up else down_input_q
9079

9180
# gate/up share a weight quantizer — clone so each gets independent amax.
9281
w_quantizer = copy.deepcopy(w_quantizer_src) if is_gate_up else w_quantizer_src
@@ -116,6 +105,12 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
116105
)
117106
):
118107
w_quantizer.amax = weight_slice.abs().amax().to(torch.float32)
108+
warnings.warn(
109+
f"Expert {idx} {proj_name} weight quantizer was not calibrated "
110+
f"(amax missing or zero). Using weight-derived amax as fallback. "
111+
f"Consider using more calibration data to activate all experts.",
112+
stacklevel=2,
113+
)
119114

120115
wrapper = nn.Module()
121116
wrapper.weight = nn.Parameter(weight_slice.contiguous(), requires_grad=False)
@@ -139,9 +134,9 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
139134
"gate_up_proj",
140135
"down_proj",
141136
"gate_up_proj_weight_quantizers",
142-
"gate_up_proj_input_quantizers",
137+
"gate_up_proj_input_quantizer",
143138
"down_proj_weight_quantizers",
144-
"down_proj_input_quantizers",
139+
"down_proj_input_quantizer",
145140
):
146141
if hasattr(module, attr):
147142
delattr(module, attr)

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -892,8 +892,9 @@ class _QuantFusedExperts(_QuantFunctionalMixin):
892892
893893
Per-expert quantization is achieved by intercepting ``F.linear`` and recovering
894894
the expert index from the weight tensor's storage offset into the 3-D parameter.
895-
Each expert gets its own weight and input quantizers (``nn.ModuleList``), so
896-
calibration granularity matches the per-expert decomposition approach.
895+
Each expert gets its own weight quantizers (``nn.ModuleList``), while input
896+
quantizers are shared across all experts (single ``TensorQuantizer``) to match
897+
the shared input quantization scale used by downstream inference frameworks.
897898
898899
Verified compatible models: Mixtral, Qwen2-MoE, Qwen3-MoE, Qwen3.5-MoE,
899900
DeepSeek-V3, Jamba, OLMoE.
@@ -926,9 +927,9 @@ def _get_expert_idx_from_gate_up(self, weight: torch.Tensor) -> int:
926927

927928
def _setup(self):
928929
n = self.num_experts
929-
self.gate_up_proj_input_quantizers = nn.ModuleList([TensorQuantizer() for _ in range(n)])
930+
self.gate_up_proj_input_quantizer = TensorQuantizer()
930931
self.gate_up_proj_weight_quantizers = nn.ModuleList([TensorQuantizer() for _ in range(n)])
931-
self.down_proj_input_quantizers = nn.ModuleList([TensorQuantizer() for _ in range(n)])
932+
self.down_proj_input_quantizer = TensorQuantizer()
932933
self.down_proj_weight_quantizers = nn.ModuleList([TensorQuantizer() for _ in range(n)])
933934

934935
self._register_temp_attribute("_down_proj_linear", False)
@@ -944,12 +945,12 @@ def functionals_to_replace(self):
944945
def _quantized_linear(input, weight, bias=None):
945946
if self._down_proj_linear:
946947
idx = self._current_expert_idx
947-
input = self.down_proj_input_quantizers[idx](input)
948+
input = self.down_proj_input_quantizer(input)
948949
weight = self.down_proj_weight_quantizers[idx](weight)
949950
else:
950951
idx = self._get_expert_idx_from_gate_up(weight)
951952
self._current_expert_idx = idx
952-
input = self.gate_up_proj_input_quantizers[idx](input)
953+
input = self.gate_up_proj_input_quantizer(input)
953954
weight = self.gate_up_proj_weight_quantizers[idx](weight)
954955
self._down_proj_linear = not self._down_proj_linear
955956
return _orig_linear(input, weight, bias)

tests/unit/torch/quantization/plugins/test_fused_experts.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,19 +189,21 @@ def test_two_level_registration(self):
189189
self._cleanup_registry(block_type)
190190

191191
def test_convert_creates_quantizers(self):
192-
"""After conversion, fused experts should have per-expert quantizer ModuleLists."""
192+
"""After conversion, fused experts should have shared input and per-expert weight quantizers."""
193193
model = _TinyMoEModel()
194194
expert_type = type(model.moe.experts)
195195
self._cleanup_registry(expert_type)
196196

197197
register_fused_experts_on_the_fly(model)
198198
converted = QuantModuleRegistry.convert(model.moe.experts)
199199

200-
assert hasattr(converted, "gate_up_proj_input_quantizers")
200+
# Shared input quantizers (single TensorQuantizer, not ModuleList)
201+
assert hasattr(converted, "gate_up_proj_input_quantizer")
202+
assert hasattr(converted, "down_proj_input_quantizer")
203+
# Per-expert weight quantizers (ModuleList)
201204
assert hasattr(converted, "gate_up_proj_weight_quantizers")
202-
assert hasattr(converted, "down_proj_input_quantizers")
203205
assert hasattr(converted, "down_proj_weight_quantizers")
204-
assert len(converted.gate_up_proj_input_quantizers) == NUM_EXPERTS
206+
assert len(converted.gate_up_proj_weight_quantizers) == NUM_EXPERTS
205207
assert len(converted.down_proj_weight_quantizers) == NUM_EXPERTS
206208
self._cleanup_registry(expert_type)
207209

0 commit comments

Comments
 (0)