Skip to content

Commit d686ac9

Browse files
committed
Memory optimization
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
1 parent 35a6aea commit d686ac9

1 file changed

Lines changed: 71 additions & 14 deletions

File tree

modelopt/torch/quantization/model_calib.py

Lines changed: 71 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -359,16 +359,42 @@ def mse_calibrate(
359359
weight_quantizers.append((parent_module, weight_name, weight_quantizer))
360360
seen_modules.add(parent_module)
361361

362-
# Step 3: Calibrate weight quantizers once with MSE calibration
363-
# This ensures weights are only calibrated once, not during every forward pass
364-
for parent_module, weight_name, weight_quantizer in weight_quantizers:
362+
# Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
363+
# This prevents massive memory accumulation seen in large models
364+
for idx, (parent_module, weight_name, weight_quantizer) in enumerate(weight_quantizers):
365365
# Enable calibration mode for the weight quantizer
366-
enable_stats_collection(parent_module)
366+
weight_quantizer.disable_quant()
367+
weight_quantizer.enable_calib()
367368
with enable_weight_access_and_writeback(parent_module, model):
368369
weight = getattr(parent_module, weight_name)
369370
weight_quantizer(weight)
370-
finish_stats_collection(parent_module, method="mse")
371-
weight_quantizer._calibrator.reset()
371+
372+
# IMMEDIATELY compute amax and reset calibrator to free memory
373+
cal = getattr(weight_quantizer, "_calibrator", None)
374+
if cal is not None and cal.compute_amax() is not None:
375+
weight_quantizer.load_calib_amax()
376+
377+
weight_quantizer.enable_quant()
378+
weight_quantizer.disable_calib()
379+
380+
# Synchronize ALL CUDA devices before resetting to ensure all async operations complete
381+
# This is critical for multi-GPU setups where tensors may be on different devices
382+
if torch.cuda.is_available():
383+
for dev_id in range(torch.cuda.device_count()):
384+
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
385+
386+
if cal is not None and hasattr(cal, "reset"):
387+
cal.reset()
388+
389+
if (idx + 1) % 10 == 0 and torch.cuda.is_available():
390+
for dev_id in range(torch.cuda.device_count()):
391+
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
392+
torch.cuda.empty_cache()
393+
394+
if torch.cuda.is_available():
395+
for dev_id in range(torch.cuda.device_count()):
396+
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
397+
torch.cuda.empty_cache()
372398

373399
# TODO: Sync amax across distributed processes
374400

@@ -604,19 +630,50 @@ def quant_func(x, amax, quantizer=weight_quantizer):
604630
error_func=error_func,
605631
)
606632

607-
# Calibrate weights with local Hessian MSE
608-
for name, module in weight_quantizers_info:
633+
# Free cached memory before heavy calibration
634+
if torch.cuda.is_available():
635+
torch.cuda.empty_cache()
636+
637+
# Process weights ONE AT A TIME with immediate amax computation and cleanup
638+
weight_list = [
639+
(name, module)
640+
for name, module in weight_quantizers_info
641+
if module.weight_quantizer._calibrator is not None
642+
]
643+
644+
for idx, (name, module) in enumerate(weight_list):
609645
weight_quantizer = module.weight_quantizer
610-
if weight_quantizer._calibrator is None:
611-
continue
646+
cal = weight_quantizer._calibrator
612647

613-
# Enable calibration mode for the weight quantizer
614-
enable_stats_collection(module)
648+
# Step 1: Calibrate this weight
649+
weight_quantizer.disable_quant()
650+
weight_quantizer.enable_calib()
615651
with enable_weight_access_and_writeback(module, model, name_to_module):
616652
weight = module.weight
617653
weight_quantizer(weight)
618-
finish_stats_collection(module, method="mse")
619-
weight_quantizer._calibrator.reset()
654+
655+
# Step 2: IMMEDIATELY compute amax (before calibration data grows)
656+
if cal.compute_amax() is not None:
657+
weight_quantizer.load_calib_amax()
658+
659+
weight_quantizer.enable_quant()
660+
weight_quantizer.disable_calib()
661+
662+
# Step 3: Sync all devices and reset calibrator for next weight
663+
if torch.cuda.is_available():
664+
for dev_id in range(torch.cuda.device_count()):
665+
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
666+
667+
if hasattr(cal, "reset"):
668+
cal.reset()
669+
670+
if (idx + 1) % 10 == 0 and torch.cuda.is_available():
671+
torch.cuda.empty_cache()
672+
673+
if torch.cuda.is_available():
674+
for dev_id in range(torch.cuda.device_count()):
675+
torch.cuda.synchronize(torch.device(f"cuda:{dev_id}"))
676+
torch.cuda.empty_cache()
620677

621678
# Cleanup and free memory
622679
LocalHessianHelper.cache_mode = False

0 commit comments

Comments (0)