
Commit 3fcc5a7

Fix non-scalar input amax in preprocess_linear_fusion for MoE export
preprocess_linear_fusion unconditionally asserts `modules[0].input_quantizer.amax.numel() == 1`, which breaks NVFP4 quantization when the model has per-expert-decomposed MoE linears (gate_proj/up_proj pairs per expert). NVFP4's per-channel input quantizer produces a vector amax, not a scalar, so the assertion trips immediately on the first expert during `export_hf_checkpoint()`.

Root cause: the function was written assuming fused linears have a per-tensor scalar input amax. That holds for dense FP8/INT8 paths but not for NVFP4's per-channel activation statistics, which modelopt's own NVFP4_AWQ_FULL_CFG produces.

This change:

- Keeps the existing scalar-amax path (dense + FP8/INT8 unchanged)
- Adds a non-scalar path using elementwise max (`.amax(dim=0)`) across the stacked per-channel amax tensors of the modules being fused

Numerical correctness for the MoE case: the modules being fused here (e.g. gate_proj and up_proj of one expert) consume the *same* input tensor by construction, so their per-channel input amax tensors are identical. Elementwise max is therefore a no-op, and it is also the correct unification rule should the tensors ever differ due to floating-point accumulation.

Validated end-to-end on SuperGemma4 26B (128-expert MoE) with NVFP4_AWQ_FULL_CFG; export now completes, and the serialized checkpoint loads and generates correctly. Before this fix, export failed with `AssertionError: Only support scalar input quant amax` after 2h 24min of successful calibration.

Signed-off-by: AEON-7 <m2vgz48wpp@privaterelay.appleid.com>
1 parent c9b1155 commit 3fcc5a7

File tree

1 file changed: +15 -5 lines changed


modelopt/torch/export/quant_utils.py

Lines changed: 15 additions & 5 deletions
@@ -1375,11 +1375,21 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
         return
 
     if modules[0].input_quantizer.is_enabled and modules[0].input_quantizer.amax is not None:
-        assert modules[0].input_quantizer.amax.numel() == 1, (
-            "Only support scalar input quant amax"
-        )
-
-        input_amax = torch.max(torch.stack([module.input_quantizer.amax for module in modules]))
+        if modules[0].input_quantizer.amax.numel() == 1:
+            # Scalar amax (e.g. dense layers with per-tensor activation quant):
+            # unify via scalar max across the modules being fused.
+            input_amax = torch.max(
+                torch.stack([module.input_quantizer.amax for module in modules])
+            )
+        else:
+            # Non-scalar amax (e.g. NVFP4 per-channel input quantizer on
+            # per-expert-decomposed MoE). Modules being fused here share the
+            # same input tensor, so their per-channel amax vectors are
+            # identical by construction. Elementwise max is a no-op in that
+            # case and is the correct unification rule if they ever differ.
+            input_amax = torch.stack(
+                [module.input_quantizer.amax for module in modules]
+            ).amax(dim=0)
         for module in modules:
             module.input_quantizer.amax = input_amax
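The unification rule in the patch can be sketched standalone. This is an illustrative sketch only: `unify_input_amax` and `amaxes` are hypothetical names standing in for the modules' `input_quantizer.amax` tensors, not modelopt's actual API.

```python
import torch


def unify_input_amax(amaxes: list[torch.Tensor]) -> torch.Tensor:
    """Unify input-quantizer amax tensors across modules being fused."""
    if amaxes[0].numel() == 1:
        # Per-tensor (scalar) amax: reduce to a single scalar max.
        return torch.max(torch.stack(amaxes))
    # Per-channel (vector) amax: elementwise max across the modules.
    return torch.stack(amaxes).amax(dim=0)


# Scalar case: two fused modules with per-tensor amax.
scalar = unify_input_amax([torch.tensor([2.0]), torch.tensor([3.0])])

# Per-channel case: the fused modules see the same input tensor, so
# their amax vectors match and the elementwise max is a no-op.
chan = unify_input_amax([torch.tensor([1.0, 4.0]), torch.tensor([1.0, 4.0])])
```

Note that `torch.stack(...).amax(dim=0)` reduces over the module dimension, so the result keeps the per-channel shape of the individual amax tensors.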
