Commit 7eaec3d

memory refinement
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
1 parent b2f158a commit 7eaec3d

1 file changed

Lines changed: 9 additions & 0 deletions

modelopt/torch/quantization/model_calib.py

@@ -804,6 +804,15 @@ def capture(weight_quantizer, weight, input_tensor):
         error_func_for=lambda q: error_funcs.get(id(q)),
     )
 
+    # Free the per-block Hessians (pinned by error_func closures) and the sweep's cached
+    # allocations so export starts from a defragmented allocator.
+    error_funcs.clear()
+    for module in name_to_module.values():
+        if isinstance(module, TensorQuantizer) and isinstance(module._calibrator, MseCalibrator):
+            module._calibrator._error_func = None
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
     if debug:
         model._local_hessian_accumulators = accumulators
 
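Note on why the change clears two places: a closure keeps every object it captures alive, so the Hessians stay reachable until both the `error_funcs` dict and each calibrator's `_error_func` attribute let go of the closure. The sketch below illustrates this with plain Python objects. `FakeHessian`, `FakeCalibrator`, and `make_error_func` are hypothetical stand-ins invented for illustration, not ModelOpt classes, and the real code holds GPU tensors rather than small Python objects.

```python
import gc
import weakref


class FakeHessian:
    """Hypothetical stand-in for a large per-block Hessian tensor."""

    def __init__(self, name):
        self.name = name


class FakeCalibrator:
    """Hypothetical stand-in for a calibrator that stores an error function."""

    def __init__(self, error_func):
        self._error_func = error_func


def make_error_func(hessian):
    # The returned closure captures `hessian`, which keeps it alive.
    def error_func(x):
        return (hessian.name, x)

    return error_func


hessian = FakeHessian("block0")
hessian_ref = weakref.ref(hessian)  # lets us observe when the object is freed

error_funcs = {"q0": make_error_func(hessian)}
calibrator = FakeCalibrator(error_funcs["q0"])
del hessian  # drop the local name; the closure still pins the object

gc.collect()
alive_before = hessian_ref() is not None

# Clearing the dict alone is not enough: the calibrator still holds the closure.
error_funcs.clear()
gc.collect()
alive_after_clear = hessian_ref() is not None

# Dropping the calibrator's reference releases the last owner of the closure.
calibrator._error_func = None
gc.collect()
alive_after_reset = hessian_ref() is not None
```

Only after both references are gone can the underlying memory be reclaimed. In the commit, `torch.cuda.empty_cache()` then returns the allocator's unused cached blocks to the driver; it cannot free memory that is still referenced, which is why it runs last.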
