diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index 4ceb51cd2c..1202eeac39 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -1375,11 +1375,21 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
         return
 
     if modules[0].input_quantizer.is_enabled and modules[0].input_quantizer.amax is not None:
-        assert modules[0].input_quantizer.amax.numel() == 1, (
-            "Only support scalar input quant amax"
-        )
-
-        input_amax = torch.max(torch.stack([module.input_quantizer.amax for module in modules]))
+        if modules[0].input_quantizer.amax.numel() == 1:
+            # Scalar amax (e.g. dense layers with per-tensor activation quant):
+            # unify via scalar max across the modules being fused.
+            input_amax = torch.max(
+                torch.stack([module.input_quantizer.amax for module in modules])
+            )
+        else:
+            # Non-scalar amax (e.g. NVFP4 per-channel input quantizer on
+            # per-expert-decomposed MoE). Modules being fused here share the
+            # same input tensor, so their per-channel amax vectors are
+            # identical by construction. Elementwise max is a no-op in that
+            # case and is the correct unification rule if they ever differ.
+            input_amax = torch.stack(
+                [module.input_quantizer.amax for module in modules]
+            ).amax(dim=0)
 
         for module in modules:
             module.input_quantizer.amax = input_amax
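
Note for reviewers: a minimal standalone sketch of the two unification paths above, using hand-made amax tensors rather than real quantizer modules (the three-module setup and the values are illustrative assumptions, not taken from modelopt):

    import torch

    # Scalar path: per-tensor activation quant, one amax value per module.
    scalar_amaxes = [torch.tensor(2.0), torch.tensor(3.5), torch.tensor(1.25)]
    unified_scalar = torch.max(torch.stack(scalar_amaxes))
    assert unified_scalar.item() == 3.5  # scalar max across the fused modules

    # Per-channel path: e.g. NVFP4, one amax value per input channel. The
    # fused modules see the same input tensor, so their amax vectors match;
    # the elementwise max then reduces to any one of them.
    per_channel = torch.tensor([1.0, 4.0, 0.5])
    channel_amaxes = [per_channel.clone() for _ in range(3)]
    unified_vector = torch.stack(channel_amaxes).amax(dim=0)
    assert torch.equal(unified_vector, per_channel)  # no-op when identical

One shape subtlety worth noting: the two branches are not interchangeable when the per-module amax has shape (1,) rather than being 0-dim. torch.max over the stacked tensor collapses to a 0-dim scalar, while .amax(dim=0) preserves the (1,) shape, which is why the scalar branch keeps the original torch.max call.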