@@ -387,10 +387,12 @@ def local_hessian_calibrate(
387387):
388388 """Calibrate the model using local Hessian-weighted MSE search.
389389
390- This calibration method collects input activations during forward pass, computes
391- per-block local Hessian matrices (H = X @ X.T), and uses them to weight the
392- MSE loss for scale selection. This minimizes output reconstruction error rather
393- than weight reconstruction error.
390+ Instead of minimizing the weight error ||W - Wq||², this minimizes the Hessian-weighted error:
391+ loss = (W - Wq)ᵀ H (W - Wq)
392+ where H = X @ X.T, so that the loss approximates the output reconstruction error ||WX - WqX||².
393+
394+ Per-block Hessians of shape (cin // block_size, block_size, block_size) are accumulated
395+ during the forward pass and used to weight the MSE loss during the scale search.
394396
395397 Args:
396398 model: Model to be calibrated.
@@ -512,6 +514,11 @@ def forward(self, input, *args, **kwargs):
512514
513515 return self ._forward_no_local_hessian (input , * args , ** kwargs )
514516
517+ # First, run max_calibrate on the whole model to get an initial amax for all quantizers.
518+ # This calibrates both weight_quantizer and input_quantizer with max calibration.
519+ print_rank_0 ("local_hessian: Running max calibration for all quantizers..." )
520+ max_calibrate (model , forward_loop , distributed_sync )
521+
515522 # Setup helpers for all quantized linear modules
516523 name_to_module = dict (model .named_modules ())
517524 weight_quantizers_info = []
@@ -531,12 +538,6 @@ def forward(self, input, *args, **kwargs):
531538
532539 # TODO(fridah-nv): Sync Hessian across distributed processes if needed
533540
534- # Get initial amax using max calibration on weights
535- print_rank_0 ("local_hessian: Computing initial amax with max calibration..." )
536- for name , module in weight_quantizers_info :
537- with enable_weight_access_and_writeback (module , model , name_to_module ):
538- max_calibrate (module , lambda m : m .weight_quantizer (m .weight ), distributed_sync )
539-
540541 # Replace calibrators with MseCalibrator using local Hessian error function
541542 print_rank_0 ("local_hessian: Running MSE calibration with local Hessian loss..." )
542543 for name , module in weight_quantizers_info :
@@ -608,34 +609,18 @@ def quant_func(x, amax, quantizer=weight_quantizer):
608609 if weight_quantizer ._calibrator is None :
609610 continue
610611
611- weight_quantizer .disable_quant ()
612- weight_quantizer .enable_calib ()
613-
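612+ # Enable calibration mode for the weight quantizer. NOTE(review): this is called on the whole module, not just weight_quantizer; confirm input_quantizer state/amax is unaffected.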
613+ enable_stats_collection (module )
614614 with enable_weight_access_and_writeback (module , model , name_to_module ):
615615 weight = module .weight
616616 weight_quantizer (weight )
617-
618- # Compute optimal amax and load it
619- for name , module in weight_quantizers_info :
620- weight_quantizer = module .weight_quantizer
621- if weight_quantizer ._calibrator is None :
622- continue
623-
624- cal = weight_quantizer ._calibrator
625- if cal .compute_amax () is not None :
626- weight_quantizer .load_calib_amax ()
627-
628- weight_quantizer .enable_quant ()
629- weight_quantizer .disable_calib ()
617+ finish_stats_collection (module , method = "mse" )
618+ weight_quantizer ._calibrator .reset ()
630619
631620 # Cleanup and free memory
632621 LocalHessianHelper .cache_mode = False
633622 for name , module in weight_quantizers_info :
634623 module .local_hessian .cleanup ()
635- if hasattr (module .weight_quantizer , "_calibrator" ):
636- cal = module .weight_quantizer ._calibrator
637- if hasattr (cal , "clear" ):
638- cal .clear ()
639624
640625 print_rank_0 ("local_hessian: Calibration complete." )
641626
0 commit comments