
Commit 1d71ba3

release moved tensors
1 parent a15733b commit 1d71ba3

2 files changed: 37 additions & 43 deletions


backends/aoti/aoti_backend.py

Lines changed: 6 additions & 0 deletions
@@ -275,6 +275,12 @@ def preprocess(
         os.remove(so_path)
         os.remove(blob_path)
 
+        # Release device memory held by tensors that ``move_to_device_pass``
+        # placed on the target device. Default impl is a no-op; concrete
+        # backends (e.g. CudaBackend) override this to free GPU memory before
+        # the next preprocess call (e.g. for the next method).
+        cls.release_moved_tensors(device_edge_program, compile_specs)
+
         return PreprocessResult(
             processed_bytes=b"",
             debug_handle_map={},
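The call added above lands on a hook whose declaration the diff does not show. A minimal sketch of what the default no-op might look like on the base class, assuming ``device_edge_program`` is a ``torch.export.ExportedProgram`` and ``CompileSpec`` comes from ``executorch.exir.backend.compile_spec_schema`` (both assumptions, not confirmed by this commit):

    from typing import List

    from executorch.exir.backend.compile_spec_schema import CompileSpec
    from torch.export import ExportedProgram


    class AotiBackend:
        @classmethod
        def release_moved_tensors(
            cls,
            device_edge_program: ExportedProgram,
            compile_specs: List[CompileSpec],
        ) -> None:
            # No-op by default: CPU-only backends have nothing to free.
            # Device backends (e.g. CudaBackend) override this to release
            # memory that move_to_device_pass allocated on the accelerator.
            pass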

backends/cuda/cuda_backend.py

Lines changed: 31 additions & 43 deletions
@@ -392,49 +392,37 @@ def _is_low_memory_mode(compile_specs: List[CompileSpec]) -> bool:
         return False
 
     @classmethod
-    def preprocess_multimethod(
+    def release_moved_tensors(
         cls,
-        edge_programs,
-        compile_specs,
-    ):
+        device_edge_program,
+        compile_specs: List[CompileSpec],
+    ) -> None:
         """
-        Override of base preprocess_multimethod to run aggressive GPU cleanup
-        between methods (e.g. decode then prefill). Inductor caches hold CUDA
-        tensors from the first compilation, causing the second to OOM under
-        tight VRAM caps (e.g. 24GB simulating an RTX 4090).
-
-        The aggressive cleanup (resizing every CUDA tensor's storage to 0)
-        is only enabled for methods that opt into ``low_memory_mode="ON"``
-        — it can otherwise break models that expect their CUDA tensors to
-        stay live across method preprocessing.
+        Free GPU memory held by tensors that ``move_to_device_pass`` placed
+        on CUDA (params, buffers, and constants of ``device_edge_program``).
+
+        Resizing the underlying storage to 0 returns those bytes to PyTorch's
+        caching allocator, so the next ``preprocess`` call (e.g. for the
+        next method in a multi-method export) can reuse them when its own
+        ``move_to_device_pass`` runs.
         """
-        import gc
-
-        preprocess_results = {}
-        for method_name, programs in edge_programs.items():
-            assert method_name in compile_specs
-            compile_specs_for_method = compile_specs[method_name]
-            assert len(compile_specs_for_method) == len(programs)
-            results_for_method = []
-            for program, compile_spec_for_program in zip(
-                programs, compile_specs_for_method
-            ):
-                preprocess_result = cls.preprocess(program, compile_spec_for_program)
-                results_for_method.append(preprocess_result)
-
-            # GPU cleanup between methods. Aggressive storage resize is
-            # only run for methods that opt into low-memory mode.
-            if torch.cuda.is_available():
-                if cls._is_low_memory_mode(compile_spec_for_program):
-                    gc.collect()
-                    for obj in gc.get_objects():
-                        if isinstance(obj, torch.Tensor) and obj.is_cuda:
-                            try:
-                                obj.untyped_storage().resize_(0)
-                            except Exception:
-                                pass
-                    gc.collect()
-                torch.cuda.empty_cache()
-
-            preprocess_results[method_name] = results_for_method
-        return preprocess_results
+        if not torch.cuda.is_available():
+            return
+
+        pools = []
+        state_dict = getattr(device_edge_program, "state_dict", None)
+        if state_dict:
+            pools.append(state_dict.values())
+        constants = getattr(device_edge_program, "constants", None)
+        if constants:
+            pools.append(constants.values())
+
+        for pool in pools:
+            for tensor in pool:
+                if isinstance(tensor, torch.Tensor) and tensor.is_cuda:
+                    try:
+                        tensor.untyped_storage().resize_(0)
+                    except Exception:
+                        # Some storages may be shared / non-resizable; skip
+                        # them rather than failing the export.
+                        pass
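The storage-resize trick at the heart of the new helper can be exercised in isolation. A standalone sketch (not part of this commit) showing that ``untyped_storage().resize_(0)`` hands a tensor's bytes back to PyTorch's caching allocator while the Python object stays alive:

    import torch

    if torch.cuda.is_available():
        # Allocate ~256 MiB on the GPU (64M float32 elements).
        t = torch.empty(64 * 1024 * 1024, dtype=torch.float32, device="cuda")
        print(torch.cuda.memory_allocated())  # roughly 256 MiB

        # Shrink the backing storage to zero bytes. The tensor object (and
        # its shape metadata) survives, but the memory returns to the
        # caching allocator for the next allocation to reuse.
        t.untyped_storage().resize_(0)
        print(torch.cuda.memory_allocated())  # back near zero

        # Reading t's data from here on is invalid, which is why the
        # backend only resizes once a method's compilation is finished.

Compared with the removed ``gc.get_objects()`` sweep, iterating only over the program's own ``state_dict`` and ``constants`` avoids resizing unrelated CUDA tensors that other methods may still need.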
