99import typing
1010from abc import ABC , abstractmethod
1111from enum import Enum
12- from typing import Any , Dict , List , Set
12+ from typing import Any , Dict , List , Optional , Set
1313
1414import torch
1515from executorch .backends .aoti .passes .replace_view_copy_with_view import (
@@ -88,8 +88,14 @@ def save_data_externally(cls) -> bool:
8888 return False
8989
@classmethod
def get_extra_aoti_compile_context_manager(
    cls, compile_specs: Optional[List[CompileSpec]] = None
):
    """Provide an additional context manager for the aoti_compile stage.

    ``preprocess`` enters the returned manager around its
    ``torch._inductor.aot_compile`` call, next to ``torch.no_grad()``.

    Args:
        compile_specs: Optional compile specs for the method being
            processed. This base implementation does not read them; they
            are accepted so that subclasses can opt into behaviors that
            only apply to specific methods/models (e.g. low-memory export).

    Returns:
        An empty context manager (``contextlib.nullcontext()``) by default.
    """
    # Nothing extra is needed in the base class, so hand back a manager
    # whose enter/exit do nothing.
    no_op_manager = contextlib.nullcontext()
    return no_op_manager
94100
95101 @classmethod
@@ -105,6 +111,24 @@ def codesign_so(cls, so_path: str, compile_specs: List[CompileSpec]) -> None:
105111 """
106112 return
107113
@classmethod
def release_moved_tensors(
    cls,
    device_edge_program: ExportedProgram,
    compile_specs: List[CompileSpec],
) -> None:
    """Hook for freeing device memory once ``preprocess`` is finished.

    ``preprocess`` calls this just before building its result, passing the
    program whose tensors ``move_to_device_pass`` placed on the target
    device. Freeing that memory here lets the next ``preprocess`` call
    (e.g. for the next method in a multi-method export) reuse it.

    Args:
        device_edge_program: The exported program handed over by
            ``preprocess``; not read by this base implementation.
        compile_specs: Compile specs of the current ``preprocess`` call;
            not read by this base implementation.

    The base implementation is a no-op. Concrete backends (e.g.
    ``CudaBackend``) are expected to override it to actually free device
    memory.
    """
    # Intentionally empty: there is nothing to release in the base class.
    return None
131+
108132 @classmethod
109133 @contextlib .contextmanager
110134 def collect_unsupported_fallback_kernels (cls , missing_fallback_kernels : Set [str ]):
@@ -208,7 +232,7 @@ def preprocess(
208232 # Compile with fallback kernel collection
209233 with cls .collect_unsupported_fallback_kernels (
210234 missing_fallback_kernels
211- ), torch .no_grad (), cls .get_extra_aoti_compile_context_manager ():
235+ ), torch .no_grad (), cls .get_extra_aoti_compile_context_manager (compile_specs ):
212236 paths = torch ._inductor .aot_compile (
213237 edge_program_module , tuple (user_input_placeholders ), options = options
214238 )
@@ -269,6 +293,12 @@ def preprocess(
269293 os .remove (so_path )
270294 os .remove (blob_path )
271295
296+ # Release device memory held by tensors that ``move_to_device_pass``
297+ # placed on the target device. Default impl is a no-op; concrete
298+ # backends (e.g. CudaBackend) override this to free GPU memory before
299+ # the next preprocess call (e.g. for the next method).
300+ cls .release_moved_tensors (device_edge_program , compile_specs )
301+
272302 return PreprocessResult (
273303 processed_bytes = b"" ,
274304 debug_handle_map = {},
0 commit comments