@@ -300,18 +300,15 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
300300
301301 quantization_format = get_quantization_format (module )
302302
303- # Handle NVFP4 variants (static or dynamic)
304- is_nvfp4_static = isinstance (weight_quantizer , NVFP4StaticQuantizer )
305- if is_nvfp4_static or quantization_format in [
303+ if quantization_format in [
306304 QUANTIZATION_NVFP4 ,
307305 QUANTIZATION_NVFP4_AWQ ,
308306 QUANTIZATION_NVFP4_SVDQUANT ,
309307 QUANTIZATION_W4A8_NVFP4_FP8 ,
310308 ]:
311- # Calibrate weight quantizer if amax is not set (only needed for dynamic quantizers)
312- if not is_nvfp4_static :
313- module_name = f"{ type (module ).__name__ } .{ weight_name } "
314- _ensure_weight_quantizer_calibrated (weight_quantizer , weight , module_name )
309+ # Calibrate weight quantizer if amax is not set
310+ module_name = f"{ type (module ).__name__ } .{ weight_name } "
311+ _ensure_weight_quantizer_calibrated (weight_quantizer , weight , module_name )
315312
316313 if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8 :
317314 # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
@@ -347,18 +344,16 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
347344
348345 quantization_format = get_quantization_format (module )
349346
350- is_nvfp4_static = isinstance (weight_quantizer , NVFP4StaticQuantizer )
351- if is_nvfp4_static or quantization_format in [
347+ if quantization_format in [
352348 QUANTIZATION_NVFP4 ,
353349 QUANTIZATION_NVFP4_AWQ ,
354350 QUANTIZATION_NVFP4_SVDQUANT ,
355351 QUANTIZATION_W4A8_NVFP4_FP8 ,
356352 ]:
357- # Calibrate weight quantizer if amax is not set (only needed for dynamic quantizers)
358- if not is_nvfp4_static :
359- weight = getattr (module , weight_name )
360- module_name = f"{ type (module ).__name__ } .{ weight_name } "
361- _ensure_weight_quantizer_calibrated (weight_quantizer , weight , module_name )
353+ # Calibrate weight quantizer if amax is not set
354+ weight = getattr (module , weight_name )
355+ module_name = f"{ type (module ).__name__ } .{ weight_name } "
356+ _ensure_weight_quantizer_calibrated (weight_quantizer , weight , module_name )
362357
363358 if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8 :
364359 # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
0 commit comments