@@ -236,6 +236,31 @@ def get_scaling_factor(quantizer: TensorQuantizer) -> torch.Tensor:
236236 return scaling_factor
237237
238238
239+ def _ensure_weight_quantizer_calibrated (
240+ weight_quantizer : TensorQuantizer , weight : torch .Tensor , module_name : str = ""
241+ ) -> None :
242+ """Calibrate weight quantizer if amax is not set.
243+
244+ This is a lazy calibration pattern used during export when weight quantizers
245+ may not have been calibrated during the main calibration phase.
246+
247+ Args:
248+ weight_quantizer: The weight quantizer to calibrate
249+ weight: The weight tensor to use for calibration
250+ module_name: Optional module name for better warning messages
251+ """
252+ if not hasattr (weight_quantizer , "_amax" ) or weight_quantizer ._amax is None :
253+ warn (
254+ f"Weight quantizer{ f' for { module_name } ' if module_name else '' } was not calibrated. "
255+ f"Computing amax from weights. This may occur if: "
256+ f"some experts were not activated during calibration (expected for MoE models), try increasing --calib_size"
257+ )
258+ weight_quantizer .reset_amax ()
259+ enable_stats_collection (weight_quantizer )
260+ weight_quantizer (weight )
261+ finish_stats_collection (weight_quantizer )
262+
263+
239264def get_activation_scaling_factor (
240265 module : nn .Module , input_quantizer_name : str = "input_quantizer"
241266) -> torch .Tensor :
@@ -279,6 +304,10 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
279304 QUANTIZATION_NVFP4_SVDQUANT ,
280305 QUANTIZATION_W4A8_NVFP4_FP8 ,
281306 ]:
307+ # Calibrate weight quantizer if amax is not set
308+ module_name = f"{ type (module ).__name__ } .{ weight_name } "
309+ _ensure_weight_quantizer_calibrated (weight_quantizer , weight , module_name )
310+
282311 if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8 :
283312 # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
284313 # This is because the kernel dequantizes weight to fp8, which is in range 448.
@@ -307,13 +336,26 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
307336 if weight_quantizer is None :
308337 return None
309338
310- if get_quantization_format (module ) in [
339+ quantization_format = get_quantization_format (module )
340+
341+ # Calibrate weight quantizer if amax is not set for all NVFP4 variants
342+ if quantization_format in [
343+ QUANTIZATION_NVFP4 ,
344+ QUANTIZATION_NVFP4_AWQ ,
345+ QUANTIZATION_NVFP4_SVDQUANT ,
346+ QUANTIZATION_W4A8_NVFP4_FP8 ,
347+ ]:
348+ weight = getattr (module , weight_name )
349+ module_name = f"{ type (module ).__name__ } .{ weight_name } "
350+ _ensure_weight_quantizer_calibrated (weight_quantizer , weight , module_name )
351+
352+ if quantization_format in [
311353 QUANTIZATION_NVFP4 ,
312354 QUANTIZATION_NVFP4_AWQ ,
313355 QUANTIZATION_NVFP4_SVDQUANT ,
314356 ]:
315357 return NVFP4QTensor .get_weights_scaling_factor_2_from_quantizer (weight_quantizer )
316- elif get_quantization_format ( module ) == QUANTIZATION_W4A8_NVFP4_FP8 :
358+ elif quantization_format == QUANTIZATION_W4A8_NVFP4_FP8 :
317359 # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
318360 # This is because the kernel dequantizes weight to fp8, which is in range 448.
319361 return weight_quantizer ._amax .float () / 448.0
0 commit comments