Skip to content

Commit 8b5da94

Browse files
committed
Use NVFP4StaticQuantizer and NVFP4MSECalibrator
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
1 parent fa20e3b commit 8b5da94

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

modelopt/torch/quantization/model_calib.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -563,26 +563,37 @@ def quant_func(x, amax, quantizer=weight_quantizer):
563563

564564
return xq
565565

566-
is_nvfp4_per_block = (
567-
fp8_scale_sweep
568-
and weight_quantizer.is_static_block_quant
566+
is_nvfp4_static = (
567+
weight_quantizer.is_static_block_quant
569568
and weight_quantizer._num_bits == (2, 1)
570569
and weight_quantizer._block_sizes is not None
571570
and weight_quantizer._block_sizes.get("scale_bits") == (4, 3)
572571
)
573572

573+
if is_nvfp4_static:
574+
global_amax = reduce_amax(initial_amax, axis=None)
575+
NVFP4StaticQuantizer.from_tensor_quantizer(weight_quantizer, global_amax=global_amax)
576+
574577
error_func = helper.get_error_func()
575578

576-
weight_quantizer._calibrator = MseCalibrator(
577-
amax=initial_amax,
578-
axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None,
579-
step_size=step_size,
580-
start_multiplier=start_multiplier,
581-
stop_multiplier=stop_multiplier,
582-
quant_func=quant_func,
583-
error_func=error_func,
584-
fp8_scale_sweep=is_nvfp4_per_block,
585-
)
579+
if fp8_scale_sweep and is_nvfp4_static:
580+
weight_quantizer._calibrator = NVFP4MSECalibrator(
581+
amax=initial_amax,
582+
axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None,
583+
global_amax=weight_quantizer.global_amax,
584+
quant_func=quant_func,
585+
error_func=error_func,
586+
)
587+
else:
588+
weight_quantizer._calibrator = MseCalibrator(
589+
amax=initial_amax,
590+
axis=weight_quantizer._calibrator._axis if weight_quantizer._calibrator else None,
591+
step_size=step_size,
592+
start_multiplier=start_multiplier,
593+
stop_multiplier=stop_multiplier,
594+
quant_func=quant_func,
595+
error_func=error_func,
596+
)
586597

587598
# Calibrate weights with local Hessian MSE
588599
for name, module in weight_quantizers_info:

0 commit comments

Comments
 (0)