
Commit 7b34de6

Unify weight_scale_2 between gate_proj/up_proj (and w1/w3) in the HF export path for MOE models (#1033)
### What does this PR do?

Unify `weight_scale_2` between `gate_proj/up_proj` (and `w1/w3`) in the HF export path for MOE models. Serving engines fuse these projections into a single `gate_up_proj` and require a shared scale; this PR takes the element-wise max of the two independent scales as a conservative choice that avoids overflow.

**Type of change:** Bug fix

### Usage

```python
# Add a code snippet demonstrating how to use this
```

### Testing

<!-- Mention how you have tested your change if applicable. -->

### Before your PR is "*Ready for review*"

Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅ / N/A
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`?: ✅ / ❌ / N/A
- Did you write any new necessary tests?: ✅ / ❌ / N/A
- Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ / ❌ / N/A

### Additional Information

<!-- E.g. related issue. -->

## Summary by CodeRabbit

* **New Features**
  * Automatic synchronization of quantization scaling between Mixture-of-Experts gate and up projections during model export for non‑fused MoE setups (e.g., Qwen MoE, DeepSeek).
* **Bug Fixes / Improvements**
  * Export now emits a brief notification when gate/up scaling values are adjusted to ensure consistent quantization.

---

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
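The conservative-max choice described above can be sketched with made-up numbers. This is illustrative only: plain Python floats stand in for per-expert amax tensors, and none of these names come from the modelopt API.

```python
# Hypothetical per-expert amax values (invented for illustration).
gate_amax = [0.8, 1.2, 0.5]  # amax of gate_proj weights, per expert
up_amax = [1.0, 0.9, 0.7]    # amax of up_proj weights, per expert

# A fused gate_up_proj needs one shared global scale; the element-wise max
# is conservative because neither projection's amax can exceed it.
shared_amax = [max(g, u) for g, u in zip(gate_amax, up_amax)]

def weight_scale_2(amax):
    # Global scale as stated in the PR: weight_scale_2 = amax / (6 * 448).
    return [a / (6 * 448) for a in amax]

# The shared scale dominates both per-projection scales, so neither
# projection's quantized values overflow under the fused scale.
for s, g, u in zip(weight_scale_2(shared_amax),
                   weight_scale_2(gate_amax),
                   weight_scale_2(up_amax)):
    assert s >= g and s >= u

print(shared_amax)  # → [1.0, 1.2, 0.7]
```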
1 parent 1070d89 commit 7b34de6

File tree

2 files changed: +62 −0 lines


modelopt/torch/export/layer_utils.py

Lines changed: 49 additions & 0 deletions
```diff
@@ -1171,6 +1171,55 @@ def set_expert_quantizer_amax(
     return uncalibrated_modules
 
 
+# Gate/up naming pairs for standard (unfused) MoE architectures.
+# Fused variants (gate_up_proj, linear_fc1) already share a single quantizer and need no sync.
+_GATE_UP_PAIRS = [("gate_proj", "up_proj"), ("w1", "w3")]
+
+
+def sync_moe_gate_up_amax(model: nn.Module) -> int:
+    """Take element-wise max of gate and up weight quantizer amaxes per expert.
+
+    Serving engines fuse gate_proj and up_proj into a single gate_up_proj and
+    require a single weight_scale_2. Since weight_scale_2 = amax / (6 * 448),
+    syncing amaxes before quantization ensures the per-block weight_scale values
+    are computed against a consistent global scale.
+
+    Only affects standard MoE models with separate gate/up linear layers
+    (e.g. Qwen MoE, DeepSeek). Models with already-fused gate_up_proj
+    (e.g. Llama4, GptOss) are unaffected.
+
+    Returns:
+        Number of expert gate/up pairs whose amaxes were synced.
+    """
+    synced = 0
+    for _, sub_module in model.named_modules():
+        if not (is_moe(sub_module) and hasattr(sub_module, "experts")):
+            continue
+        if not hasattr(sub_module.experts, "__iter__"):
+            continue
+        for expert in sub_module.experts:
+            for gate_name, up_name in _GATE_UP_PAIRS:
+                gate_linear = getattr(expert, gate_name, None)
+                up_linear = getattr(expert, up_name, None)
+                if gate_linear is None or up_linear is None:
+                    continue
+                gate_wq = getattr(gate_linear, "weight_quantizer", None)
+                up_wq = getattr(up_linear, "weight_quantizer", None)
+                if gate_wq is None or up_wq is None:
+                    break
+                gate_amax = getattr(gate_wq, "amax", None)
+                up_amax = getattr(up_wq, "amax", None)
+                if gate_amax is None or up_amax is None:
+                    break
+                if not torch.equal(gate_amax, up_amax):
+                    shared_amax = torch.max(gate_amax, up_amax)
+                    gate_wq.amax = shared_amax
+                    up_wq.amax = shared_amax.clone()
+                    synced += 1
+                break
+    return synced
+
+
 def build_stacked_experts(
     experts: nn.Module,
     linear_names: list[str],
```

modelopt/torch/export/unified_export_hf.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -73,6 +73,7 @@
     is_moe,
     is_quantlinear,
     set_expert_quantizer_amax,
+    sync_moe_gate_up_amax,
 )
 from .model_config import (
     QUANTIZATION_FP8,
@@ -775,6 +776,18 @@ def _export_transformers_checkpoint(
             exclude_modules.append(pattern)
             print(f"Adding MTP layer to quantization_config ignore: {pattern}")
 
+    # Safety net: sync any gate/up weight quantizer amaxes that
+    # requantize_resmooth_fused_llm_layers did not reach (e.g. experts not
+    # activated during the dummy forward, or non-standard expert naming).
+    synced = sync_moe_gate_up_amax(model)
+    if synced:
+        warnings.warn(
+            f"Found {synced} MoE expert gate/up projection pair(s) with mismatched "
+            f"weight_scale_2 after requantize_resmooth_fused_llm_layers. "
+            f"This typically means the dummy forward did not activate these experts. "
+            f"Taking element-wise max of amaxes for serving-engine fusion."
+        )
+
     # Process all quantized modules and export weights
     _process_quantized_modules(model, dtype, is_modelopt_qlora)
```
