@@ -392,49 +392,37 @@ def _is_low_memory_mode(compile_specs: List[CompileSpec]) -> bool:
392392 return False
393393
@classmethod
def release_moved_tensors(
    cls,
    device_edge_program,
    compile_specs: List[CompileSpec],
) -> None:
    """
    Free GPU memory held by tensors that ``move_to_device_pass`` placed
    on CUDA (params, buffers, and constants of ``device_edge_program``).

    Resizing the underlying storage to 0 returns those bytes to PyTorch's
    caching allocator, so the next ``preprocess`` call (e.g. for the
    next method in a multi-method export) can reuse them when its own
    ``move_to_device_pass`` runs.
    """
    # Nothing to release when no CUDA device is present.
    if not torch.cuda.is_available():
        return

    def _cuda_tensors(container):
        # Yield the CUDA tensors held in a dict-like container,
        # tolerating a missing/empty attribute.
        if container:
            for value in container.values():
                if isinstance(value, torch.Tensor) and value.is_cuda:
                    yield value

    # Walk both tensor pools the exported program carries on-device.
    for attr_name in ("state_dict", "constants"):
        pool = getattr(device_edge_program, attr_name, None)
        for tensor in _cuda_tensors(pool):
            try:
                tensor.untyped_storage().resize_(0)
            except Exception:
                # Some storages may be shared / non-resizable; skip
                # them rather than failing the export.
                pass
0 commit comments