Skip to content

Commit 993cff5

Browse files
committed
[gemma4_31b][cuda] Export Gemma4-31B @128k under 32 GB
Three CUDA-export memory optimizations: - tq4_sdpa: add BLOCK_N=16 (and a BLOCK_M=32) autotune config. The superset is kept for big-shared-memory GPUs (A100/H100); the Triton autotuner auto-prunes configs that exceed a GPU's shared memory (OutOfResources -> inf), so the same config list also works on the 5090 (Blackwell, ~101 KB SMEM) where the previous smallest config did not fit. - int4_dispatch: chunk the inline _dequant_matmul along N for vocab-sized weights (N>65536, i.e. only the lm_head). Avoids transiently materializing the full ~10 GiB bf16 lm_head when AOTI executes the int4_plain_mm custom op during autotune / cpp_wrapper. The runtime decode path uses the C++ dp4a shim and the M>4 prefill inline path is below the threshold, so this never enters the runtime graph -> zero runtime / accuracy impact. Applied unconditionally (no flag). - cuda_backend / aoti_backend: skip occupying the GPU with the KV-cache buffers during AOTI compile (gated behind low_memory_mode). A new move_program_to_device hook places KV constants on the target device but immediately frees their storage (resize_(0)), so the fake-tensor device check passes while no real KV bytes sit on the GPU during autotune. The emptied buffers are re-synthesized as zeros at the _unlift_graph clone and at serialization, and excluded from constant dedup (resize_(0) gives every KV data_ptr 0, which would otherwise collapse same-shape caches across layers). Result on 2xA100: Gemma4-31B @128k no-TQ export peak 36.3 -> 27.0 GiB; the exported model runs correctly (output "...Paris.").
1 parent 1b726b2 commit 993cff5

4 files changed

Lines changed: 229 additions & 19 deletions

File tree

backends/aoti/aoti_backend.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,21 @@ def codesign_so(cls, so_path: str, compile_specs: List[CompileSpec]) -> None:
112112
"""
113113
return
114114

115+
@classmethod
116+
def move_program_to_device(
117+
cls,
118+
edge_program: ExportedProgram,
119+
device: str,
120+
compile_specs: List[CompileSpec],
121+
) -> ExportedProgram:
122+
"""Move the exported program to the target device for compilation.
123+
124+
Default implementation moves everything (params, buffers, constants) via
125+
``move_to_device_pass``. Concrete backends may override to keep large
126+
non-parameter tensors off the device during a low-memory export.
127+
"""
128+
return move_to_device_pass(edge_program, device)
129+
115130
@classmethod
116131
def release_moved_tensors(
117132
cls,
@@ -196,9 +211,13 @@ def preprocess(
196211
decomposition_table = cls.get_decomposition_table()
197212
options = cls.get_aoti_compile_options(compile_specs)
198213

199-
# Move the edge_program to the target device
200-
device_edge_program = move_to_device_pass(
201-
edge_program, device_name if device_name != "metal" else "mps"
214+
# Move the edge_program to the target device. Routed through a hook so
215+
# backends can keep large non-parameter tensors (e.g. KV-cache buffers)
216+
# off the device during a low-memory export.
217+
device_edge_program = cls.move_program_to_device(
218+
edge_program,
219+
device_name if device_name != "metal" else "mps",
220+
compile_specs,
202221
)
203222

204223
# Replace view_copy with view

backends/cuda/cuda_backend.py

Lines changed: 166 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,29 +61,81 @@ def _is_cpu_clone_active() -> bool:
6161
return getattr(_CPU_CLONE_GUARD, "active", False)
6262

6363

64+
def _full_zeros_preserving_strides(x: torch.Tensor, device) -> torch.Tensor:
65+
"""Allocate a zero-filled tensor matching ``x``'s size/stride/dtype on ``device``.
66+
67+
Used to re-synthesize KV-cache buffers whose storage was freed (``resize_(0)``)
68+
during the low-memory device move. KV content is all zeros, so this exactly
69+
reproduces the buffer for both the lifted graph value and serialization.
70+
"""
71+
needed = 1
72+
for size, stride in zip(x.size(), x.stride()):
73+
needed += (size - 1) * stride
74+
buf = torch.zeros(int(needed), dtype=x.dtype, device=device)
75+
return torch.as_strided(buf, x.size(), x.stride())
76+
77+
78+
def _is_emptied(x) -> bool:
79+
return (
80+
isinstance(x, torch.Tensor)
81+
and x.numel() > 0
82+
and x.untyped_storage().nbytes() == 0
83+
)
84+
85+
6486
@contextlib.contextmanager
6587
def _compile_time_cpu_clones(target_device: torch.device):
6688
"""Force AOTI's mutated-buffer clones onto CPU while preserving the
6789
serialized constants' target device."""
68-
from torch._inductor import compile_fx as _cfx
90+
from torch._inductor import compile_fx as _cfx, graph as _graph
6991
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu as _Cpp
92+
from torch._inductor.graph import GraphLowering as _GL
7093

7194
orig_clone = _cfx.clone_preserve_strides
7295
orig_codegen_device = _Cpp.codegen_device
96+
orig_get_const = _GL.get_original_value_of_constant
97+
orig_is_same = _graph.is_same_tensor
98+
99+
def _is_same_skip_emptied(data, value):
100+
# KV buffers freed via resize_(0) all have data_ptr 0, so the stock
101+
# is_same_tensor would treat every same-shape KV constant as a duplicate
102+
# and collapse the 60 layers' caches into one — the runtime needs each
103+
# FQN's own buffer, so the collapsed ones load uninitialized garbage.
104+
# Never dedup an emptied tensor.
105+
if _is_emptied(data) or _is_emptied(value):
106+
return False
107+
return orig_is_same(data, value)
73108

74109
def _cpu_clone_preserve_strides(x: torch.Tensor) -> torch.Tensor:
75-
# `clone_preserve_strides` is shared by `_unlift_graph` (clones
76-
# lifted buffers — can be safely kept on CPU) and by autotuning code
77-
# in `triton_heuristics.py` (clones for benchmark — must stay on
78-
# GPU for Triton). Discriminate by caller frame so we only force
79-
# CPU clones for the buffer-lifting path.
110+
# `clone_preserve_strides` is shared by `_unlift_graph` (clones lifted
111+
# buffers — can be safely kept on CPU) and by autotuning code in
112+
# `triton_heuristics.py` (clones for benchmark — must stay on GPU for
113+
# Triton). Discriminate by caller frame so we only force CPU clones for
114+
# the buffer-lifting path.
80115
import sys
81116

82117
caller = sys._getframe(1).f_code.co_name
83118
if caller == "_unlift_graph":
119+
# KV-cache buffers are emptied (storage resize_(0)) by the low-memory
120+
# device move so they never occupy GPU memory during compile. Their
121+
# content is all zeros, so re-synthesize zeros (on CPU, strides
122+
# preserved) instead of cloning the now-empty storage.
123+
if _is_emptied(x):
124+
return _full_zeros_preserving_strides(x, "cpu")
84125
return orig_clone(x).cpu()
85126
return orig_clone(x)
86127

128+
def _get_const_synthesize_zeros(self, name):
129+
# AOTI serializes each constant via get_original_value_of_constant ->
130+
# _to_bytes. For KV buffers we freed with resize_(0) this would otherwise
131+
# fall back to the empty-storage constant and write 0 bytes, producing a
132+
# .ptd with an uninitialized cache. Re-synthesize the zeros so the blob
133+
# holds a correctly-zeroed KV cache.
134+
value = orig_get_const(self, name)
135+
if _is_emptied(value):
136+
return _full_zeros_preserving_strides(value, "cpu")
137+
return value
138+
87139
def _codegen_device_target_aware(self, device):
88140
# Translate accidental CPU device strings back to the model target
89141
# device only when a constant we forced to CPU is being serialized.
@@ -99,6 +151,8 @@ def _codegen_device_target_aware(self, device):
99151

100152
_cfx.clone_preserve_strides = _cpu_clone_preserve_strides
101153
_Cpp.codegen_device = _codegen_device_target_aware
154+
_GL.get_original_value_of_constant = _get_const_synthesize_zeros
155+
_graph.is_same_tensor = _is_same_skip_emptied
102156
prev_active = getattr(_CPU_CLONE_GUARD, "active", False)
103157
_CPU_CLONE_GUARD.active = True
104158
try:
@@ -107,6 +161,89 @@ def _codegen_device_target_aware(self, device):
107161
_CPU_CLONE_GUARD.active = prev_active
108162
_cfx.clone_preserve_strides = orig_clone
109163
_Cpp.codegen_device = orig_codegen_device
164+
_GL.get_original_value_of_constant = orig_get_const
165+
_graph.is_same_tensor = orig_is_same
166+
167+
168+
def _is_kv_buffer(name, v) -> bool:
169+
return (
170+
isinstance(v, torch.Tensor)
171+
and not isinstance(v, torch.nn.Parameter)
172+
and "kv_cache" in name
173+
)
174+
175+
176+
def _empty_strided_on_device(v, location):
177+
"""A device tensor with v's shape/stride/dtype but zero (freed) storage."""
178+
t = torch.empty_strided(v.shape, v.stride(), dtype=v.dtype, device=location)
179+
t.untyped_storage().resize_(0) # free bytes, keep device + shape/stride
180+
return t
181+
182+
183+
def _move_graph_nodes_to_device(graph_module, location):
184+
"""Point node device kwargs / aten.to.device targets / meta vals at location."""
185+
import torch.utils._pytree as pytree
186+
187+
def _to_loc(v):
188+
return v.to(location) if isinstance(v, torch.Tensor) else v
189+
190+
for m in graph_module.modules():
191+
if not isinstance(m, torch.fx.GraphModule):
192+
continue
193+
for node in m.graph.nodes:
194+
if "device" in node.kwargs:
195+
node.kwargs = {**node.kwargs, "device": location}
196+
if node.op == "call_function" and node.target is torch.ops.aten.to.device:
197+
args = list(node.args)
198+
args[1] = location
199+
node.args = tuple(args)
200+
node.meta["val"] = pytree.tree_map(_to_loc, node.meta.get("val"))
201+
202+
203+
def _move_to_device_resize_kv(ep, location):
204+
"""``move_to_device_pass`` variant that frees KV-cache storage on-device.
205+
206+
Mirrors ``torch.export.passes.move_to_device_pass`` exactly, except KV-cache
207+
buffers (FQN contains ``kv_cache``) are placed on ``location`` but with their
208+
storage immediately freed via ``resize_(0)``. This keeps ``device ==
209+
location`` — so the fake-tensor device check on the ``index_copy`` cache
210+
update passes (``self`` and ``values`` both on cuda) — while no real KV bytes
211+
occupy the device during the AOTI compile. KV content is all zeros, so the
212+
emptied tensors are re-synthesized as zeros at the ``_unlift_graph`` clone
213+
(see ``_compile_time_cpu_clones``), which is reused as both the lifted initial
214+
value and the serialized ``.ptd`` constant. The empty/free is interleaved per
215+
tensor so the transient device peak is a single KV buffer, not the whole cache.
216+
Only ``kv_cache`` tensors are emptied (they are the lone large zero-buffers);
217+
every other tensor is moved normally so non-zero content is never lost.
218+
"""
219+
import torch.utils._pytree as pytree
220+
221+
for k, v in ep.state_dict.items():
222+
if isinstance(v, torch.nn.Parameter):
223+
ep._state_dict[k] = torch.nn.Parameter(v.to(location), v.requires_grad)
224+
elif _is_kv_buffer(k, v):
225+
ep._state_dict[k] = _empty_strided_on_device(v, location)
226+
else:
227+
ep._state_dict[k] = v.to(location)
228+
229+
for k, v in ep.constants.items():
230+
if isinstance(v, torch.Tensor):
231+
ep._constants[k] = (
232+
_empty_strided_on_device(v, location)
233+
if _is_kv_buffer(k, v)
234+
else v.to(location)
235+
)
236+
237+
if ep.example_inputs is not None:
238+
args, kwargs = ep.example_inputs
239+
ep._example_inputs = (
240+
pytree.tree_map_only(torch.Tensor, lambda t: t.to(location), args),
241+
pytree.tree_map_only(torch.Tensor, lambda t: t.to(location), kwargs),
242+
)
243+
244+
_move_graph_nodes_to_device(ep.graph_module, location)
245+
ep.validate()
246+
return ep
110247

111248

112249
@final
@@ -424,6 +561,29 @@ def _is_low_memory_mode(compile_specs: List[CompileSpec]) -> bool:
424561
return spec.value.decode("utf-8").upper() == "ON"
425562
return False
426563

564+
@classmethod
565+
def move_program_to_device(
566+
cls,
567+
edge_program,
568+
device: str,
569+
compile_specs: List[CompileSpec],
570+
):
571+
"""Move the program to ``device`` for AOTI compile.
572+
573+
On a low-memory export (``low_memory_mode="ON"``) the KV-cache buffers —
574+
which can be 10+ GiB at long context — are placed on-device but with their
575+
storage freed (``resize_(0)``), so they never occupy device memory during
576+
the autotune / cpp_wrapper compile while still satisfying the device-match
577+
check on the cache update. They are re-synthesized as zeros for the lifted
578+
graph and the serialized blob. This activates automatically with low-memory
579+
mode. Other (non-low-memory) exports use the stock pass.
580+
"""
581+
from torch.export.passes import move_to_device_pass
582+
583+
if not cls._is_low_memory_mode(compile_specs):
584+
return move_to_device_pass(edge_program, device)
585+
return _move_to_device_resize_kv(edge_program, device)
586+
427587
@classmethod
428588
def release_moved_tensors(
429589
cls,

backends/cuda/quantize_op_dispatch/int4_dispatch.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,28 +60,54 @@ def _cuda(self, qdata, scale, zero, group_size):
6060
return _dequant_matmul(self, qdata, scale, zero, group_size)
6161

6262

63+
# Chunked dequant for the export GPU budget. The lm_head dequant (N = vocab_size,
64+
# e.g. 262144) runs through the int4_plain_mm custom op (M=1); AOTI executes that
65+
# op's CUDA impl during autotune / cpp_wrapper codegen, where it transiently holds
66+
# ~5 full-size bf16 temporaries (low/high/data/data-z/w_deq) — ~10 GiB for a
67+
# 262144-row weight even though the final w_deq is only ~2.6 GiB. Chunking along N
68+
# caps that at ~chunk rows. It is numerically identical (F.linear output rows are
69+
# independent), and because only the lm_head (custom-op) path crosses the N
70+
# threshold — never the M>4 prefill inline path — it never enters the runtime
71+
# graph: ZERO runtime / accuracy impact. Applied unconditionally to any weight
72+
# whose row count exceeds the threshold.
73+
_DEQUANT_N_THRESHOLD = 65536
74+
_DEQUANT_N_CHUNK = 32768
75+
76+
6377
def _dequant_matmul(x, qdata, scale, zero, group_size):
6478
"""Dequant INT4 weights to input dtype and call F.linear.
6579
6680
scale/zero are in the coalesced [N, n_groups] layout (baked into the
6781
weight constant at pack time), aligned row-for-row with qdata's [N, *].
82+
83+
Large weights (N > threshold, i.e. the lm_head) are chunked along N to bound
84+
the dequant intermediate (see note above); smaller weights take the original
85+
single-shot dequant.
6886
"""
6987
N, K_half = qdata.shape
7088
K = K_half * 2
7189
n_groups = K // group_size
7290
gs_half = group_size // 2
7391
dtype = x.dtype
7492

75-
p = qdata.to(torch.uint8).reshape(N, n_groups, gs_half)
76-
low = (p & 0x0F).to(dtype)
77-
high = ((p >> 4) & 0x0F).to(dtype)
78-
data = torch.stack([low, high], dim=-1).reshape(N, n_groups, group_size)
79-
80-
s = scale.to(dtype).unsqueeze(-1)
81-
z = zero.to(dtype).unsqueeze(-1)
82-
w_deq = ((data - z) * s).reshape(N, K)
83-
84-
return F.linear(x, w_deq)
93+
def _dq(qd, sc, ze, rows):
94+
p = qd.to(torch.uint8).reshape(rows, n_groups, gs_half)
95+
low = (p & 0x0F).to(dtype)
96+
high = ((p >> 4) & 0x0F).to(dtype)
97+
data = torch.stack([low, high], dim=-1).reshape(rows, n_groups, group_size)
98+
s = sc.to(dtype).unsqueeze(-1)
99+
z = ze.to(dtype).unsqueeze(-1)
100+
w_deq = ((data - z) * s).reshape(rows, K)
101+
return F.linear(x, w_deq)
102+
103+
if N <= _DEQUANT_N_THRESHOLD:
104+
return _dq(qdata, scale, zero, N)
105+
106+
outs = []
107+
for i in range(0, N, _DEQUANT_N_CHUNK):
108+
j = min(i + _DEQUANT_N_CHUNK, N)
109+
outs.append(_dq(qdata[i:j], scale[i:j], zero[i:j], j - i))
110+
return torch.cat(outs, dim=-1)
85111

86112

87113
# ---------------------------------------------------------------------------

backends/cuda/triton/kernels/tq4_sdpa.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,10 @@ def _tq4_sdpa_fwd_kernel_body(
294294
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2),
295295
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=3),
296296
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4, num_stages=2),
297+
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=2),
298+
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=3),
299+
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=2),
300+
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=3),
297301
],
298302
key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"],
299303
)
@@ -410,6 +414,7 @@ def _tq4_sdpa_fwd_kernel_m64(
410414
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2),
411415
triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2),
412416
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2),
417+
triton.Config({"BLOCK_M": 32, "BLOCK_N": 16}, num_warps=4, num_stages=3),
413418
],
414419
key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"],
415420
)

0 commit comments

Comments
 (0)