Skip to content

Commit 2931f61

Browse files
committed
add reviewers feedback
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
1 parent 21aa442 commit 2931f61

2 files changed

Lines changed: 20 additions & 32 deletions

File tree

modelopt/torch/quantization/config.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@
388388
"algorithm": "max",
389389
}
390390

391-
NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG = {
391+
NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG = {
392392
"quant_cfg": {
393393
"*weight_quantizer": {
394394
"num_bits": (2, 1),
@@ -397,7 +397,10 @@
397397
"enable": True,
398398
},
399399
"*input_quantizer": {
400-
"enable": False,
400+
"num_bits": (2, 1),
401+
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
402+
"axis": None,
403+
"enable": True,
401404
},
402405
**_default_disabled_quantizer_cfg,
403406
},

modelopt/torch/quantization/model_calib.py

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -387,10 +387,12 @@ def local_hessian_calibrate(
387387
):
388388
"""Calibrate the model using local Hessian-weighted MSE search.
389389
390-
This calibration method collects input activations during forward pass, computes
391-
per-block local Hessian matrices (H = X @ X.T), and uses them to weight the
392-
MSE loss for scale selection. This minimizes output reconstruction error rather
393-
than weight reconstruction error.
390+
Instead of minimizing weight error ||W - Wq||², this minimizes Hessian-weighted error:
391+
loss = (W - Wq)ᵀ H (W - Wq)
392+
where H = X @ X.T approximates output reconstruction error ||WX - WqX||².
393+
394+
Per-block Hessians of shape (cin // block_size, block_size, block_size) are accumulated
395+
during forward pass and used to weight the MSE loss during scale search.
394396
395397
Args:
396398
model: Model to be calibrated.
@@ -512,6 +514,11 @@ def forward(self, input, *args, **kwargs):
512514

513515
return self._forward_no_local_hessian(input, *args, **kwargs)
514516

517+
# First, run max_calibrate on the whole model to get initial amax for all quantizers
518+
# This calibrates both weight_quantizer and input_quantizer with max calibration
519+
print_rank_0("local_hessian: Running max calibration for all quantizers...")
520+
max_calibrate(model, forward_loop, distributed_sync)
521+
515522
# Setup helpers for all quantized linear modules
516523
name_to_module = dict(model.named_modules())
517524
weight_quantizers_info = []
@@ -531,12 +538,6 @@ def forward(self, input, *args, **kwargs):
531538

532539
# TODO(fridah-nv): Sync Hessian across distributed processes if needed
533540

534-
# Get initial amax using max calibration on weights
535-
print_rank_0("local_hessian: Computing initial amax with max calibration...")
536-
for name, module in weight_quantizers_info:
537-
with enable_weight_access_and_writeback(module, model, name_to_module):
538-
max_calibrate(module, lambda m: m.weight_quantizer(m.weight), distributed_sync)
539-
540541
# Replace calibrators with MseCalibrator using local Hessian error function
541542
print_rank_0("local_hessian: Running MSE calibration with local Hessian loss...")
542543
for name, module in weight_quantizers_info:
@@ -608,34 +609,18 @@ def quant_func(x, amax, quantizer=weight_quantizer):
608609
if weight_quantizer._calibrator is None:
609610
continue
610611

611-
weight_quantizer.disable_quant()
612-
weight_quantizer.enable_calib()
613-
612+
# Enable calibration mode for the weight quantizer
613+
enable_stats_collection(module)
614614
with enable_weight_access_and_writeback(module, model, name_to_module):
615615
weight = module.weight
616616
weight_quantizer(weight)
617-
618-
# Compute optimal amax and load it
619-
for name, module in weight_quantizers_info:
620-
weight_quantizer = module.weight_quantizer
621-
if weight_quantizer._calibrator is None:
622-
continue
623-
624-
cal = weight_quantizer._calibrator
625-
if cal.compute_amax() is not None:
626-
weight_quantizer.load_calib_amax()
627-
628-
weight_quantizer.enable_quant()
629-
weight_quantizer.disable_calib()
617+
finish_stats_collection(module, method="mse")
618+
weight_quantizer._calibrator.reset()
630619

631620
# Cleanup and free memory
632621
LocalHessianHelper.cache_mode = False
633622
for name, module in weight_quantizers_info:
634623
module.local_hessian.cleanup()
635-
if hasattr(module.weight_quantizer, "_calibrator"):
636-
cal = module.weight_quantizer._calibrator
637-
if hasattr(cal, "clear"):
638-
cal.clear()
639624

640625
print_rank_0("local_hessian: Calibration complete.")
641626

0 commit comments

Comments
 (0)