@@ -563,26 +563,37 @@ def quant_func(x, amax, quantizer=weight_quantizer):
563563
564564 return xq
565565
566- is_nvfp4_per_block = (
567- fp8_scale_sweep
568- and weight_quantizer .is_static_block_quant
566+ is_nvfp4_static = (
567+ weight_quantizer .is_static_block_quant
569568 and weight_quantizer ._num_bits == (2 , 1 )
570569 and weight_quantizer ._block_sizes is not None
571570 and weight_quantizer ._block_sizes .get ("scale_bits" ) == (4 , 3 )
572571 )
573572
573+ if is_nvfp4_static :
574+ global_amax = reduce_amax (initial_amax , axis = None )
575+ NVFP4StaticQuantizer .from_tensor_quantizer (weight_quantizer , global_amax = global_amax )
576+
574577 error_func = helper .get_error_func ()
575578
576- weight_quantizer ._calibrator = MseCalibrator (
577- amax = initial_amax ,
578- axis = weight_quantizer ._calibrator ._axis if weight_quantizer ._calibrator else None ,
579- step_size = step_size ,
580- start_multiplier = start_multiplier ,
581- stop_multiplier = stop_multiplier ,
582- quant_func = quant_func ,
583- error_func = error_func ,
584- fp8_scale_sweep = is_nvfp4_per_block ,
585- )
579+ if fp8_scale_sweep and is_nvfp4_static :
580+ weight_quantizer ._calibrator = NVFP4MSECalibrator (
581+ amax = initial_amax ,
582+ axis = weight_quantizer ._calibrator ._axis if weight_quantizer ._calibrator else None ,
583+ global_amax = weight_quantizer .global_amax ,
584+ quant_func = quant_func ,
585+ error_func = error_func ,
586+ )
587+ else :
588+ weight_quantizer ._calibrator = MseCalibrator (
589+ amax = initial_amax ,
590+ axis = weight_quantizer ._calibrator ._axis if weight_quantizer ._calibrator else None ,
591+ step_size = step_size ,
592+ start_multiplier = start_multiplier ,
593+ stop_multiplier = stop_multiplier ,
594+ quant_func = quant_func ,
595+ error_func = error_func ,
596+ )
586597
587598 # Calibrate weights with local Hessian MSE
588599 for name , module in weight_quantizers_info :
0 commit comments