@@ -359,16 +359,42 @@ def mse_calibrate(
359359 weight_quantizers .append ((parent_module , weight_name , weight_quantizer ))
360360 seen_modules .add (parent_module )
361361
362- # Step 3: Calibrate weight quantizers once with MSE calibration
363- # This ensures weights are only calibrated once, not during every forward pass
364- for parent_module , weight_name , weight_quantizer in weight_quantizers :
362+ # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
363+ # This prevents massive memory accumulation seen in large models
364+ for idx , ( parent_module , weight_name , weight_quantizer ) in enumerate ( weight_quantizers ) :
365365 # Enable calibration mode for the weight quantizer
366- enable_stats_collection (parent_module )
366+ weight_quantizer .disable_quant ()
367+ weight_quantizer .enable_calib ()
367368 with enable_weight_access_and_writeback (parent_module , model ):
368369 weight = getattr (parent_module , weight_name )
369370 weight_quantizer (weight )
370- finish_stats_collection (parent_module , method = "mse" )
371- weight_quantizer ._calibrator .reset ()
371+
372+ # IMMEDIATELY compute amax and reset calibrator to free memory
373+ cal = getattr (weight_quantizer , "_calibrator" , None )
374+ if cal is not None and cal .compute_amax () is not None :
375+ weight_quantizer .load_calib_amax ()
376+
377+ weight_quantizer .enable_quant ()
378+ weight_quantizer .disable_calib ()
379+
380+ # Synchronize ALL CUDA devices before resetting to ensure all async operations complete
381+ # This is critical for multi-GPU setups where tensors may be on different devices
382+ if torch .cuda .is_available ():
383+ for dev_id in range (torch .cuda .device_count ()):
384+ torch .cuda .synchronize (torch .device (f"cuda:{ dev_id } " ))
385+
386+ if cal is not None and hasattr (cal , "reset" ):
387+ cal .reset ()
388+
389+ if (idx + 1 ) % 10 == 0 and torch .cuda .is_available ():
390+ for dev_id in range (torch .cuda .device_count ()):
391+ torch .cuda .synchronize (torch .device (f"cuda:{ dev_id } " ))
392+ torch .cuda .empty_cache ()
393+
394+ if torch .cuda .is_available ():
395+ for dev_id in range (torch .cuda .device_count ()):
396+ torch .cuda .synchronize (torch .device (f"cuda:{ dev_id } " ))
397+ torch .cuda .empty_cache ()
372398
373399 # TODO: Sync amax across distributed processes
374400
@@ -604,19 +630,50 @@ def quant_func(x, amax, quantizer=weight_quantizer):
604630 error_func = error_func ,
605631 )
606632
607- # Calibrate weights with local Hessian MSE
608- for name , module in weight_quantizers_info :
633+ # Free cached memory before heavy calibration
634+ if torch .cuda .is_available ():
635+ torch .cuda .empty_cache ()
636+
637+ # Process weights ONE AT A TIME with immediate amax computation and cleanup
638+ weight_list = [
639+ (name , module )
640+ for name , module in weight_quantizers_info
641+ if module .weight_quantizer ._calibrator is not None
642+ ]
643+
644+ for idx , (name , module ) in enumerate (weight_list ):
609645 weight_quantizer = module .weight_quantizer
610- if weight_quantizer ._calibrator is None :
611- continue
646+ cal = weight_quantizer ._calibrator
612647
613- # Enable calibration mode for the weight quantizer
614- enable_stats_collection (module )
648+ # Step 1: Calibrate this weight
649+ weight_quantizer .disable_quant ()
650+ weight_quantizer .enable_calib ()
615651 with enable_weight_access_and_writeback (module , model , name_to_module ):
616652 weight = module .weight
617653 weight_quantizer (weight )
618- finish_stats_collection (module , method = "mse" )
619- weight_quantizer ._calibrator .reset ()
654+
655+ # Step 2: IMMEDIATELY compute amax (before calibration data grows)
656+ if cal .compute_amax () is not None :
657+ weight_quantizer .load_calib_amax ()
658+
659+ weight_quantizer .enable_quant ()
660+ weight_quantizer .disable_calib ()
661+
662+ # Step 3: Sync all devices and reset calibrator for next weight
663+ if torch .cuda .is_available ():
664+ for dev_id in range (torch .cuda .device_count ()):
665+ torch .cuda .synchronize (torch .device (f"cuda:{ dev_id } " ))
666+
667+ if hasattr (cal , "reset" ):
668+ cal .reset ()
669+
670+ if (idx + 1 ) % 10 == 0 and torch .cuda .is_available ():
671+ torch .cuda .empty_cache ()
672+
673+ if torch .cuda .is_available ():
674+ for dev_id in range (torch .cuda .device_count ()):
675+ torch .cuda .synchronize (torch .device (f"cuda:{ dev_id } " ))
676+ torch .cuda .empty_cache ()
620677
621678 # Cleanup and free memory
622679 LocalHessianHelper .cache_mode = False