@@ -61,29 +61,81 @@ def _is_cpu_clone_active() -> bool:
6161 return getattr (_CPU_CLONE_GUARD , "active" , False )
6262
6363
64+ def _full_zeros_preserving_strides (x : torch .Tensor , device ) -> torch .Tensor :
65+ """Allocate a zero-filled tensor matching ``x``'s size/stride/dtype on ``device``.
66+
67+ Used to re-synthesize KV-cache buffers whose storage was freed (``resize_(0)``)
68+ during the low-memory device move. KV content is all zeros, so this exactly
69+ reproduces the buffer for both the lifted graph value and serialization.
70+ """
71+ needed = 1
72+ for size , stride in zip (x .size (), x .stride ()):
73+ needed += (size - 1 ) * stride
74+ buf = torch .zeros (int (needed ), dtype = x .dtype , device = device )
75+ return torch .as_strided (buf , x .size (), x .stride ())
76+
77+
78+ def _is_emptied (x ) -> bool :
79+ return (
80+ isinstance (x , torch .Tensor )
81+ and x .numel () > 0
82+ and x .untyped_storage ().nbytes () == 0
83+ )
84+
85+
6486@contextlib .contextmanager
6587def _compile_time_cpu_clones (target_device : torch .device ):
6688 """Force AOTI's mutated-buffer clones onto CPU while preserving the
6789 serialized constants' target device."""
68- from torch ._inductor import compile_fx as _cfx
90+ from torch ._inductor import compile_fx as _cfx , graph as _graph
6991 from torch ._inductor .codegen .cpp_wrapper_cpu import CppWrapperCpu as _Cpp
92+ from torch ._inductor .graph import GraphLowering as _GL
7093
7194 orig_clone = _cfx .clone_preserve_strides
7295 orig_codegen_device = _Cpp .codegen_device
96+ orig_get_const = _GL .get_original_value_of_constant
97+ orig_is_same = _graph .is_same_tensor
98+
99+ def _is_same_skip_emptied (data , value ):
100+ # KV buffers freed via resize_(0) all have data_ptr 0, so the stock
101+ # is_same_tensor would treat every same-shape KV constant as a duplicate
102+ # and collapse the 60 layers' caches into one — the runtime needs each
103+ # FQN's own buffer, so the collapsed ones load uninitialized garbage.
104+ # Never dedup an emptied tensor.
105+ if _is_emptied (data ) or _is_emptied (value ):
106+ return False
107+ return orig_is_same (data , value )
73108
74109 def _cpu_clone_preserve_strides (x : torch .Tensor ) -> torch .Tensor :
75- # `clone_preserve_strides` is shared by `_unlift_graph` (clones
76- # lifted buffers — can be safely kept on CPU) and by autotuning code
77- # in `triton_heuristics.py` (clones for benchmark — must stay on
78- # GPU for Triton). Discriminate by caller frame so we only force
79- # CPU clones for the buffer-lifting path.
110+ # `clone_preserve_strides` is shared by `_unlift_graph` (clones lifted
111+ # buffers — can be safely kept on CPU) and by autotuning code in
112+ # `triton_heuristics.py` (clones for benchmark — must stay on GPU for
113+ # Triton). Discriminate by caller frame so we only force CPU clones for
114+ # the buffer-lifting path.
80115 import sys
81116
82117 caller = sys ._getframe (1 ).f_code .co_name
83118 if caller == "_unlift_graph" :
119+ # KV-cache buffers are emptied (storage resize_(0)) by the low-memory
120+ # device move so they never occupy GPU memory during compile. Their
121+ # content is all zeros, so re-synthesize zeros (on CPU, strides
122+ # preserved) instead of cloning the now-empty storage.
123+ if _is_emptied (x ):
124+ return _full_zeros_preserving_strides (x , "cpu" )
84125 return orig_clone (x ).cpu ()
85126 return orig_clone (x )
86127
128+ def _get_const_synthesize_zeros (self , name ):
129+ # AOTI serializes each constant via get_original_value_of_constant ->
130+ # _to_bytes. For KV buffers we freed with resize_(0) this would otherwise
131+ # fall back to the empty-storage constant and write 0 bytes, producing a
132+ # .ptd with an uninitialized cache. Re-synthesize the zeros so the blob
133+ # holds a correctly-zeroed KV cache.
134+ value = orig_get_const (self , name )
135+ if _is_emptied (value ):
136+ return _full_zeros_preserving_strides (value , "cpu" )
137+ return value
138+
87139 def _codegen_device_target_aware (self , device ):
88140 # Translate accidental CPU device strings back to the model target
89141 # device only when a constant we forced to CPU is being serialized.
@@ -99,6 +151,8 @@ def _codegen_device_target_aware(self, device):
99151
100152 _cfx .clone_preserve_strides = _cpu_clone_preserve_strides
101153 _Cpp .codegen_device = _codegen_device_target_aware
154+ _GL .get_original_value_of_constant = _get_const_synthesize_zeros
155+ _graph .is_same_tensor = _is_same_skip_emptied
102156 prev_active = getattr (_CPU_CLONE_GUARD , "active" , False )
103157 _CPU_CLONE_GUARD .active = True
104158 try :
@@ -107,6 +161,107 @@ def _codegen_device_target_aware(self, device):
107161 _CPU_CLONE_GUARD .active = prev_active
108162 _cfx .clone_preserve_strides = orig_clone
109163 _Cpp .codegen_device = orig_codegen_device
164+ _GL .get_original_value_of_constant = orig_get_const
165+ _graph .is_same_tensor = orig_is_same
166+
167+
168+ def _is_kv_buffer (name , v ) -> bool :
169+ """True only for an actual KV-cache *content* buffer that is safe to free.
170+
171+ The low-memory path (``_move_to_device_resize_kv``) frees every buffer this
172+ matches and re-synthesizes it as ZEROS in both the lifted graph and the
173+ serialized ``.ptd`` (see ``_full_zeros_preserving_strides`` /
174+ ``_get_const_synthesize_zeros``). That is only valid for genuine KV *content*,
175+ which is all-zeros at export time (caches start empty).
176+
177+ It must NOT match the non-zero constants that some KV-cache modules register
178+ alongside the cache — e.g. TurboQuant registers its codebook/rotation
179+ (``centroids``/``boundaries``/``rotation``/``rotation_T``) as buffers on the
180+ ``kv_cache`` module, so their FQNs also contain ``kv_cache``. Freeing+zeroing
181+ those silently corrupts the serialized model (TQ4 dequant -> 0 -> garbage).
182+ Gate on the buffer actually being all-zeros so only empty KV content is freed;
183+ this is robust to any future constant name (a non-zero buffer is never freed).
184+ """
185+ if not isinstance (v , torch .Tensor ) or isinstance (v , torch .nn .Parameter ):
186+ return False
187+ if "kv_cache" not in name or v .numel () == 0 or v .is_meta :
188+ return False
189+ # Only the genuinely all-zero KV content may be freed + re-zeroed; non-zero
190+ # constants (TurboQuant centroids/rotation/...) must be preserved as-is.
191+ return bool (torch .count_nonzero (v ) == 0 )
192+
193+
194+ def _empty_strided_on_device (v , location ):
195+ """A device tensor with v's shape/stride/dtype but zero (freed) storage."""
196+ t = torch .empty_strided (v .shape , v .stride (), dtype = v .dtype , device = location )
197+ t .untyped_storage ().resize_ (0 ) # free bytes, keep device + shape/stride
198+ return t
199+
200+
201+ def _move_graph_nodes_to_device (graph_module , location ):
202+ """Point node device kwargs / aten.to.device targets / meta vals at location."""
203+ import torch .utils ._pytree as pytree
204+
205+ def _to_loc (v ):
206+ return v .to (location ) if isinstance (v , torch .Tensor ) else v
207+
208+ for m in graph_module .modules ():
209+ if not isinstance (m , torch .fx .GraphModule ):
210+ continue
211+ for node in m .graph .nodes :
212+ if "device" in node .kwargs :
213+ node .kwargs = {** node .kwargs , "device" : location }
214+ if node .op == "call_function" and node .target is torch .ops .aten .to .device :
215+ args = list (node .args )
216+ args [1 ] = location
217+ node .args = tuple (args )
218+ node .meta ["val" ] = pytree .tree_map (_to_loc , node .meta .get ("val" ))
219+
220+
221+ def _move_to_device_resize_kv (ep , location ):
222+ """``move_to_device_pass`` variant that frees KV-cache storage on-device.
223+
224+ Mirrors ``torch.export.passes.move_to_device_pass`` exactly, except KV-cache
225+ buffers (FQN contains ``kv_cache``) are placed on ``location`` but with their
226+ storage immediately freed via ``resize_(0)``. This keeps ``device ==
227+ location`` — so the fake-tensor device check on the ``index_copy`` cache
228+ update passes (``self`` and ``values`` both on cuda) — while no real KV bytes
229+ occupy the device during the AOTI compile. KV content is all zeros, so the
230+ emptied tensors are re-synthesized as zeros at the ``_unlift_graph`` clone
231+ (see ``_compile_time_cpu_clones``), which is reused as both the lifted initial
232+ value and the serialized ``.ptd`` constant. The empty/free is interleaved per
233+ tensor so the transient device peak is a single KV buffer, not the whole cache.
234+ Only ``kv_cache`` tensors are emptied (they are the lone large zero-buffers);
235+ every other tensor is moved normally so non-zero content is never lost.
236+ """
237+ import torch .utils ._pytree as pytree
238+
239+ for k , v in ep .state_dict .items ():
240+ if isinstance (v , torch .nn .Parameter ):
241+ ep ._state_dict [k ] = torch .nn .Parameter (v .to (location ), v .requires_grad )
242+ elif _is_kv_buffer (k , v ):
243+ ep ._state_dict [k ] = _empty_strided_on_device (v , location )
244+ else :
245+ ep ._state_dict [k ] = v .to (location )
246+
247+ for k , v in ep .constants .items ():
248+ if isinstance (v , torch .Tensor ):
249+ ep ._constants [k ] = (
250+ _empty_strided_on_device (v , location )
251+ if _is_kv_buffer (k , v )
252+ else v .to (location )
253+ )
254+
255+ if ep .example_inputs is not None :
256+ args , kwargs = ep .example_inputs
257+ ep ._example_inputs = (
258+ pytree .tree_map_only (torch .Tensor , lambda t : t .to (location ), args ),
259+ pytree .tree_map_only (torch .Tensor , lambda t : t .to (location ), kwargs ),
260+ )
261+
262+ _move_graph_nodes_to_device (ep .graph_module , location )
263+ ep .validate ()
264+ return ep
110265
111266
112267@final
@@ -424,6 +579,29 @@ def _is_low_memory_mode(compile_specs: List[CompileSpec]) -> bool:
424579 return spec .value .decode ("utf-8" ).upper () == "ON"
425580 return False
426581
582+ @classmethod
583+ def move_program_to_device (
584+ cls ,
585+ edge_program ,
586+ device : str ,
587+ compile_specs : List [CompileSpec ],
588+ ):
589+ """Move the program to ``device`` for AOTI compile.
590+
591+ On a low-memory export (``low_memory_mode="ON"``) the KV-cache buffers —
592+ which can be 10+ GiB at long context — are placed on-device but with their
593+ storage freed (``resize_(0)``), so they never occupy device memory during
594+ the autotune / cpp_wrapper compile while still satisfying the device-match
595+ check on the cache update. They are re-synthesized as zeros for the lifted
596+ graph and the serialized blob. This activates automatically with low-memory
597+ mode. Other (non-low-memory) exports use the stock pass.
598+ """
599+ from torch .export .passes import move_to_device_pass
600+
601+ if not cls ._is_low_memory_mode (compile_specs ):
602+ return move_to_device_pass (edge_program , device )
603+ return _move_to_device_resize_kv (edge_program , device )
604+
427605 @classmethod
428606 def release_moved_tensors (
429607 cls ,
0 commit comments