Commit 86c2912

Fridah-nv authored and jenchen13 committed
fixes for fused moe (qwen3.6, GLM5.1) + MSE calibration (#1382)
### What does this PR do?

Type of change: Bug fix

Fixes several issues with NVFP4 MSE calibration and export for fused MoE expert modules (`_QuantFusedExperts`, used by Qwen3.6, GLM-5.1, and other HF transformers 5.0+ models that store expert weights as 3-D `nn.Parameter`s).

- **Bug 1: MSE weight calibration runs 0 iterations for fused experts (`model_calib.py`).** The weight-quantizer discovery loop in `mse_calibrate` used the singular attribute name `gate_up_proj_weight_quantizer` to look up quantizers, but `_QuantFusedExperts` stores them in a plural `nn.ModuleList` named `gate_up_proj_weight_quantizers`. All 20,480 expert quantizers were silently skipped, resulting in "MSE weight calibration: 0it" and no MSE-optimized scales.
  **Fix:** add a second pass that detects plural `{param}_weight_quantizers` ModuleLists and enqueues each per-expert quantizer with a `(param_name, expert_idx)` tuple; step 3 unpacks the tuple to extract the per-expert weight slice.
- **Bug 2: Zero weight scales in exported checkpoint (`nvfp4_tensor.py`).** Per-block weight scales can silently underflow to 0 when cast to FP8 E4M3FN. The existing `scale == 0` guard only catches exact float32 zeros; values in (0, 2^-9) pass through and become 0 after the FP8 cast. This affects both the dynamic recompute path (`get_weights_scaling_factor`) and the static calibrated path (`get_weights_scaling_factor_from_quantizer`).
  **Fix:** clamp per-block scales to 2^-9 (smallest positive FP8 E4M3FN subnormal) before the FP8 cast in both paths.
- **Bug 3: Zero/corrupt amax for uncalibrated experts at export (`moe_utils.py`).** Experts that receive no tokens during calibration have `_amax = 0` or uninitialized values. The existing scalar fallback used 1e-4, which itself underflows to 0 in FP8 E4M3FN (1e-4 < 2^-9 ≈ 0.00195). Additionally, the per-block fallback tensor had shape (H*W, 1) instead of (H, W), causing a shape mismatch that silently bypassed the fallback and fell through to the bad scalar. Finally, a stale zero `global_amax` from an uncalibrated expert was not recomputed, causing division-by-zero in the FP8 scale formula.
  **Fix:** reshape the per-block fallback correctly; raise the clamp floor to 2e-3; always recompute `global_amax` from the current (possibly patched) per-block `_amax`.

Additional fixes:

- `moe_utils.py`: safe CPU extraction of `_amax` before deepcopy to avoid async CUDA errors from corrupt bfloat16 amax storage on under-calibrated experts.
- `model_quant.py`: `print_quant_summary` now calls `os.makedirs(output_dir, exist_ok=True)` before writing `.quant_summary.txt`, preventing a `FileNotFoundError` when the export directory doesn't exist yet.
- `tensor_quantizer.py`: change default format in `_short_amax` / `_short_tensor` from `".4f"` to `".2e"` so small amax values (e.g. 3.5e-7) display as `3.50e-07` instead of `0.0000`.
- `hf_ptq.py`: strip leading pad tokens from the preview input and add `skip_special_tokens=True` to `input_decode`, fixing degenerate pre/post-PTQ output on models that use EOS as the pad token (e.g. Qwen3).
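The attribute-name mismatch behind Bug 1 can be illustrated without modelopt. This is a minimal sketch, not the actual `mse_calibrate` code: `FusedExperts`, `discover_singular`, and `discover_with_plural` are hypothetical names, and plain strings stand in for quantizer modules.

```python
class FusedExperts:
    """Hypothetical stand-in for a fused-expert module (not the real _QuantFusedExperts)."""

    def __init__(self, num_experts):
        # Plural attribute: one quantizer placeholder per expert.
        self.gate_up_proj_weight_quantizers = [f"quantizer_{i}" for i in range(num_experts)]


def discover_singular(module, param_name):
    """Singular-name lookup only: finds nothing on a fused-expert module."""
    q = getattr(module, f"{param_name}_weight_quantizer", None)
    return [] if q is None else [(q, param_name)]


def discover_with_plural(module, param_name):
    """Adds a second pass over the plural list, tagging each entry with (param_name, expert_idx)."""
    found = discover_singular(module, param_name)
    plural = getattr(module, f"{param_name}_weight_quantizers", None)
    if plural is not None:
        found += [(q, (param_name, idx)) for idx, q in enumerate(plural)]
    return found


experts = FusedExperts(num_experts=4)
old = discover_singular(experts, "gate_up_proj")
new = discover_with_plural(experts, "gate_up_proj")
print(len(old), len(new))  # 0 4
```

The tuple tag is what would let a later step pick the matching per-expert slice out of a 3-D weight, as the fix description suggests.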
### Usage

```bash
# Quantize Qwen3.6-35B-A3B (or any compatible fused-expert MoE) with the new recipe:
python examples/llm_ptq/hf_ptq.py \
    --pyt_ckpt_path /path/to/Qwen3.6-35B-A3B \
    --recipe modelopt_recipes/general/ptq/nvfp4_experts_only_mse.yaml \
    --export_path /path/to/output \
    --calib_size 512 --calib_seq 2048
```

### Testing

Validated on Qwen3.6-35B-A3B (8× B200):

- 21,740 quantizers inserted; 20,480/20,480 MSE weight calibrations completed (~11 min)
- 0 / 2,013,265,920 zero `weight_scale` entries in the exported checkpoint (3 shards)
- Pre- and post-PTQ generation produce coherent, semantically consistent output

### Before your PR is "*Ready for review*"

Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`).

Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: ✅ / ❌ / N/A
- Did you write any new necessary tests?: ✅ / ❌ / N/A
- Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ / ❌ / N/A

### Additional Information

## Summary by CodeRabbit

* **New Features**
  * Added a new NVFP4 quantization recipe for expert layers with MSE-based calibration.
* **Bug Fixes**
  * Fixed FP8 scale underflow handling to prevent zero scaling factors.
  * Fixed output directory creation for quantization summaries.
* **Improvements**
  * Enhanced preview input handling for language models by removing padding tokens.
  * Improved quantizer display precision for better readability.

[![Review Change Stack](https://storage.googleapis.com/coderabbit_public_assets/review-stack-in-coderabbit-ui.svg)](https://app.coderabbit.ai/change-stack/NVIDIA/Model-Optimizer/pull/1382)

---------

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
1 parent 0be7138 commit 86c2912

6 files changed

Lines changed: 174 additions & 32 deletions

modelopt/torch/quantization/model_calib.py

Lines changed: 1 addition & 1 deletion
@@ -864,7 +864,7 @@ def finish_stats_collection(model: nn.Module, method: str | None = None, **kwarg

         cal = getattr(module, "_calibrator", None)
         if cal and not getattr(module, "_dynamic", False):
-            if method in {"entropy"}:
+            if method == "entropy":
                 if cal.compute_amax(method) is not None:
                     module.load_calib_amax("entropy", **kwargs)
             elif cal.compute_amax(**kwargs) is not None:

modelopt/torch/quantization/model_quant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,7 @@ def print_quant_summary(model: nn.Module, output_dir: str | None = None):
595595
lines.append(f"{len(lines)} TensorQuantizers found in model")
596596

597597
if output_dir:
598+
os.makedirs(output_dir, exist_ok=True)
598599
path = os.path.join(output_dir, ".quant_summary.txt")
599600
with open(path, "w", encoding="utf-8") as f:
600601
f.write("\n".join(lines) + "\n")
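The effect of the added `os.makedirs` call can be reproduced with the standard library alone. The sketch below is an illustration under assumptions: `write_summary` is a hypothetical helper, not the real `print_quant_summary`, and the directory names are made up.

```python
import os
import tempfile


def write_summary(lines, output_dir):
    """Write a summary file, creating output_dir first if it is missing."""
    os.makedirs(output_dir, exist_ok=True)  # no-op when the directory already exists
    path = os.path.join(output_dir, ".quant_summary.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
    return path


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "export", "nested")  # does not exist yet
    path = write_summary(["first"], target)
    first_exists = os.path.exists(path)
    # A second call with the directory already present must not raise.
    write_summary(["again"], target)
    with open(path, encoding="utf-8") as f:
        content = f.read()

print(first_exists, repr(content))  # True 'again\n'
```

Without the `os.makedirs` line, the first `open` would fail with `FileNotFoundError` because the parent directories are missing.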

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,7 +1122,7 @@ def forward(self, inputs):
11221122

11231123
return outputs
11241124

1125-
def _short_amax(self, fmt=".4f"):
1125+
def _short_amax(self, fmt=".2e"):
11261126
"""Short description of amax.
11271127
11281128
Returns:
@@ -1140,7 +1140,7 @@ def _short_amax(self, fmt=".4f"):
11401140
return "meta"
11411141
return self._short_tensor(self._amax, fmt)
11421142

1143-
def _short_tensor(self, tensor: torch.Tensor, fmt=".4f"):
1143+
def _short_tensor(self, tensor: torch.Tensor, fmt=".2e"):
11441144
"""Short description of tensor."""
11451145
if tensor.numel() == 1:
11461146
return f"{tensor.item():{fmt}}"
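The difference between the two default format specs is easy to check in plain Python. The value below is the 3.5e-7 example from the commit message; the variable names are only for illustration.

```python
amax = 3.5e-7

fixed = f"{amax:.4f}"  # fixed-point with 4 decimals: small values collapse to zero
sci = f"{amax:.2e}"    # scientific notation keeps the magnitude visible

print(fixed, sci)  # 0.0000 3.50e-07
```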

modelopt/torch/quantization/qtensor/nvfp4_tensor.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,22 @@
2828
__all__ = ["NVFP4QTensor"]
2929

3030

31+
def _cast_per_block_scale_to_fp8(
32+
per_block_scale: torch.Tensor,
33+
per_block_scale_max: torch.Tensor | None = None,
34+
) -> torch.Tensor:
35+
"""Clamp to FP8 E4M3FN range [2**-9, 448] and cast — avoids underflow→0 / overflow→NaN.
36+
37+
When ``per_block_scale_max`` is provided, first rescales as
38+
``per_block_scale.float() * 448 / per_block_scale_max`` — the static-export
39+
path needs this because the ``[==0]=1.0`` safety net combined with a small
40+
``global_amax`` can drive the rescaled value above 448 (see PR #1397).
41+
"""
42+
if per_block_scale_max is not None:
43+
per_block_scale = per_block_scale.float() * 448.0 / per_block_scale_max
44+
return per_block_scale.clamp(min=2**-9, max=448.0).to(torch.float8_e4m3fn)
45+
46+
3147
class NVFP4QTensor(BaseQuantizedTensor):
3248
"""Implements the INT4 quantization on tensors for more efficient storage or computation.
3349
@@ -122,17 +138,8 @@ def get_weights_scaling_factor_from_quantizer(
122138
expected_shape = (*weight.shape[:-1], num_blocks_per_row)
123139
per_block_scale = per_block_scale.view(expected_shape)
124140

125-
# Quantize scales to FP8. Saturate to the fp8_e4m3fn max (448) before the
126-
# cast: when the [==0]=1.0 safety net above fires (per_block_amax was zero
127-
# for an all-zero weight block) and global_amax is small, the pre-cast value
128-
# explodes to ``1.0 * 448 / (global_amax/6)``. fp8_e4m3fn has no Inf, so any
129-
# value >= 480 casts to NaN — clamp first to keep the stored byte finite.
130141
if not keep_high_precision:
131-
per_block_scale = (
132-
(per_block_scale * 448.0 / per_block_scale_max)
133-
.clamp_(max=448.0)
134-
.to(torch.float8_e4m3fn)
135-
)
142+
per_block_scale = _cast_per_block_scale_to_fp8(per_block_scale, per_block_scale_max)
136143
return per_block_scale, weights_scaling_factor_2
137144
else:
138145
# Dynamic path: compute from weight tensor
@@ -171,9 +178,8 @@ def get_weights_scaling_factor(
171178
)
172179
# Set all zero values in scale to 1.0
173180
per_block_scale[per_block_scale == 0] = 1.0
174-
# Convert to torch.float8_e4m3fn
175181
if not keep_high_precision:
176-
per_block_scale = per_block_scale.to(torch.float8_e4m3fn)
182+
per_block_scale = _cast_per_block_scale_to_fp8(per_block_scale)
177183
return per_block_scale, weights_scaling_factor_2
178184

179185
@classmethod
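Why clamping has to happen before the cast can be seen with a small pure-Python model of the FP8 E4M3FN value grid. This is a sketch under assumptions, not the PyTorch cast: `e4m3fn_grid`, `cast_unclamped`, and `cast_clamped` are hypothetical helpers, rounding is plain nearest-value, and anything above 448 is conservatively treated as NaN rather than modeling the exact overflow threshold. The inputs are the 1e-4 fallback from the commit message and the `1.0 * 448 / (global_amax/6)` expression from the removed comment, with `global_amax = 0.01` as in the new test.

```python
import math


def e4m3fn_grid():
    """Non-negative finite FP8 E4M3FN values (4 exponent bits, bias 7, 3 mantissa bits)."""
    grid = [m * 2.0**-9 for m in range(8)]  # zero and subnormals: (m / 8) * 2**-6
    for e in range(1, 16):
        for m in range(8):
            if e == 15 and m == 7:
                continue  # this bit pattern encodes NaN in E4M3FN
            grid.append((1 + m / 8) * 2.0 ** (e - 7))
    return grid


GRID = e4m3fn_grid()


def cast_unclamped(x):
    """Simplified cast: values above 448 become NaN, the rest go to the nearest grid value."""
    if x > 448.0:
        return math.nan
    return min(GRID, key=lambda v: abs(v - x))


def cast_clamped(x):
    """Clamp into [2**-9, 448] first, then cast (same order as the helper in the diff)."""
    return cast_unclamped(min(max(x, 2.0**-9), 448.0))


tiny = 1e-4                      # old scalar fallback mentioned in the commit message
huge = 1.0 * 448.0 / (0.01 / 6)  # safety-net scale 1.0 with global_amax = 0.01

print(cast_unclamped(tiny), cast_clamped(tiny))  # 0.0 0.001953125
print(cast_unclamped(huge), cast_clamped(huge))  # nan 448.0
```

In this model only inputs closer to 0 than to 2^-9 collapse to zero; the clamp removes that case entirely, along with the overflow case, which is presumably why a single helper is used for both export paths.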

tests/unit/torch/quantization/plugins/test_fused_experts.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -256,27 +256,51 @@ def test_expert_index_recovery(self):
256256
# Tests for export
257257
# ---------------------------------------------------------------------------
258258
class TestExportFusedExperts:
259+
@staticmethod
260+
def _cleanup_registry(mod_type):
261+
if QuantModuleRegistry.get(mod_type) is not None:
262+
QuantModuleRegistry.unregister(mod_type)
263+
259264
def test_export_creates_per_expert_submodules(self):
260265
"""_export_fused_experts should create per-expert submodules with standard naming."""
266+
import modelopt.torch.quantization as mtq
261267
from modelopt.torch.export.moe_utils import _export_fused_experts
262268

263-
experts = _SyntheticFusedExperts()
264-
expert_type = type(experts)
269+
model = _TinyMoEModel()
270+
expert_type = type(model.moe.experts)
271+
self._cleanup_registry(expert_type)
265272

266-
# Manually register and convert
267-
if QuantModuleRegistry.get(expert_type) is None:
268-
QuantModuleRegistry.register({expert_type: "test.SyntheticFusedExperts"})(
269-
_QuantFusedExperts
270-
)
271-
converted = QuantModuleRegistry.convert(experts)
273+
quant_cfg = {
274+
"quant_cfg": [
275+
{"quantizer_name": "*", "enable": False},
276+
{
277+
"quantizer_name": "*gate_up_proj_input_quantizer",
278+
"cfg": {"num_bits": 8, "axis": None},
279+
},
280+
{
281+
"quantizer_name": "*down_proj_input_quantizer",
282+
"cfg": {"num_bits": 8, "axis": None},
283+
},
284+
{
285+
"quantizer_name": "*gate_up_proj_weight_quantizer",
286+
"cfg": {"num_bits": 8, "axis": 0},
287+
},
288+
{
289+
"quantizer_name": "*down_proj_weight_quantizer",
290+
"cfg": {"num_bits": 8, "axis": 0},
291+
},
292+
],
293+
"algorithm": "max",
294+
}
272295

273-
# Run a forward pass to calibrate (set amaxes)
274-
seq_len = 16
275-
hidden_states = torch.randn(seq_len, HIDDEN_DIM)
276-
top_k_index = torch.randint(0, NUM_EXPERTS, (seq_len, TOP_K))
277-
top_k_weights = torch.softmax(torch.randn(seq_len, TOP_K), dim=-1)
278-
with torch.no_grad():
279-
converted(hidden_states, top_k_index, top_k_weights)
296+
def forward_loop(m):
297+
torch.manual_seed(0)
298+
for _ in range(2):
299+
x = torch.randn(1, 4, HIDDEN_DIM)
300+
m(x)
301+
302+
mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
303+
converted = model.moe.experts
280304

281305
_export_fused_experts(converted, torch.float16)
282306

@@ -297,8 +321,7 @@ def test_export_creates_per_expert_submodules(self):
297321
assert not hasattr(converted, "down_proj")
298322
assert not hasattr(converted, "gate_up_proj_weight_quantizers")
299323

300-
if QuantModuleRegistry.get(expert_type) is not None:
301-
QuantModuleRegistry.unregister(expert_type)
324+
self._cleanup_registry(expert_type)
302325

303326
def test_uncalibrated_expert_gate_up_share_amax(self, monkeypatch):
304327
"""gate_proj and up_proj must share weight_scale_2 even when an expert
Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for NVFP4QTensor per-block FP8 scale clamping (underflow + overflow)."""
+
+from types import SimpleNamespace
+
+import torch
+
+from modelopt.torch.quantization.qtensor.nvfp4_tensor import (
+    NVFP4QTensor,
+    _cast_per_block_scale_to_fp8,
+)
+
+_FP8_E4M3FN_MIN = 2**-9  # 0.001953125 — smallest positive FP8 E4M3FN subnormal
+_FP8_E4M3FN_MAX = 448.0
+
+
+class TestNVFP4ScaleClamping:
+    """Per-block weight scales outside the FP8 E4M3FN range must be clamped, not turned into 0/NaN."""
+
+    def test_no_zero_scales_for_tiny_weights(self):
+        """Tiny per-block amax (<<FP8 min) must not underflow to zero after FP8 cast."""
+        block_size = 16
+        tiny_weight = torch.full((4, block_size), 1e-10)
+        # wsf2=1.0 → per_block_scale = amax/(6*wsf2) ≈ 1.7e-11 << 2^-9, exercises FP8-min clamp
+        wsf2 = torch.tensor(1.0)
+
+        per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(tiny_weight, block_size, wsf2)
+        per_block_scale_f32 = per_block_scale.float()
+
+        assert (per_block_scale_f32 > 0).all(), (
+            f"Zero per-block scales found after FP8 cast: {per_block_scale_f32.tolist()}. "
+            "FP8 scale underflow clamping likely regressed."
+        )
+        assert (per_block_scale_f32 >= _FP8_E4M3FN_MIN).all(), (
+            "Per-block scales with zero values found after FP8 cast "
+            "(below the FP8 E4M3FN subnormal minimum — clamp would have prevented this)."
+        )
+
+    def test_normal_weights_unaffected_by_clamp(self):
+        """Weights with typical magnitudes must not be affected by the underflow clamp."""
+        block_size = 16
+        torch.manual_seed(42)
+        normal_weight = torch.randn(8, block_size)
+
+        per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(normal_weight, block_size)
+        assert (per_block_scale.float() > 0).all(), "Normal weights produced zero scales."
+
+    def test_mixed_weight_no_zeros(self):
+        """Mixed-magnitude tensor (normal + tiny blocks) must have no zero scales."""
+        block_size = 16
+        weight = torch.cat(
+            [
+                torch.randn(4, block_size),
+                torch.full((4, block_size), 1e-12),
+            ],
+            dim=0,
+        )
+
+        per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor(weight, block_size)
+        assert (per_block_scale.float() > 0).all(), (
+            "Zero scales in mixed-magnitude tensor after FP8 cast."
+        )
+
+    def test_helper_clamps_overflow_to_max(self):
+        """Values above 448 must saturate to 448, not cast to NaN (fp8_e4m3fn has no Inf)."""
+        oversized = torch.tensor([100.0, 448.0, 1e3, 1e6])
+        out = _cast_per_block_scale_to_fp8(oversized).float()
+        assert torch.isfinite(out).all(), f"FP8 cast produced non-finite values: {out.tolist()}"
+        assert (out <= _FP8_E4M3FN_MAX).all(), f"FP8 cast values exceed 448: {out.tolist()}"
+
+    def test_helper_clamps_underflow_to_min(self):
+        """Values below the FP8 subnormal must clamp up, not collapse to 0."""
+        tiny = torch.tensor([0.0, 1e-12, 1e-6, _FP8_E4M3FN_MIN / 2])
+        out = _cast_per_block_scale_to_fp8(tiny).float()
+        assert (out > 0).all(), f"FP8 cast produced zero scales: {out.tolist()}"
+
+    def test_static_path_no_nan_when_block_amax_zero(self):
+        """Static path: zero-amax block + small global_amax must clamp to 448, not cast to NaN."""
+        block_size = 16
+        # global_amax small enough that 1.0 * 448 / (global_amax/6) >> 448.
+        global_amax = torch.tensor(0.01)
+        # One block with amax=0 (triggers safety net), three normal blocks.
+        per_block_amax = torch.tensor([[0.0, 0.005, 0.008, 0.01]])
+        weight = torch.randn(1, 4 * block_size)
+        q = SimpleNamespace(
+            global_amax=global_amax,
+            _amax=per_block_amax,
+            block_sizes={-1: block_size},
+        )
+
+        per_block_scale, _ = NVFP4QTensor.get_weights_scaling_factor_from_quantizer(q, weight)
+        per_block_scale_f32 = per_block_scale.float()
+        assert torch.isfinite(per_block_scale_f32).all(), (
+            f"NaN/Inf in exported static per-block scale: {per_block_scale_f32.tolist()}"
+        )
+        assert (per_block_scale_f32 <= _FP8_E4M3FN_MAX).all(), (
+            f"Static per-block scale exceeds FP8 max 448: {per_block_scale_f32.tolist()}"
+        )
