Skip to content

Commit 47af14f

Browse files
authored
[gemma4_31b][cuda] Export Gemma4-31B @128k on 5090 (#20480)
Current gemma4-31b can not be successfully exported on consumer gpu like 5090 with three reasons: 1. During int4_dispatch we need to dequant whole matmul weight to bf16 for prefill in one step for lm_head, leading to weight duplcation; 2. When lowering to AOTI-CUDA, we moved the whole model, including kv cache, onto gpu. With context length increased, the gpu memory consumption will also be increased dramatically. 3. No autotune config for kernels like sdpa work for consumer gpu like 5090. Three CUDA-export memory optimizations, all gated behind the existing low_memory_mode compile spec (no impact on other models or on runtime): - int4_dispatch: chunk the inline _dequant_matmul along N for vocab-sized weights, gated behind a low-memory flag with an N>65536 threshold so only the lm_head crosses it. Avoids transiently materializing the full ~10 GiB bf16 lm_head during AOTI autotune / cpp_wrapper. The prefill MLP path is untouched -> zero runtime impact. - cuda_backend / aoti_backend: skip occupying the GPU with the KV-cache buffers during AOTI compile. A new move_program_to_device hook places KV constants on the target device but immediately frees their storage (resize_(0)), so the fake-tensor device check passes while no real KV bytes sit on the GPU during autotune. The emptied buffers are re-synthesized as zeros at the _unlift_graph clone and at serialization, and excluded from constant dedup (resize_(0) gives every KV data_ptr 0, which would otherwise collapse same-shape caches across layers). All gated behind low_memory_mode. - tq4_sdpa: add BLOCK_N=16 (and a BLOCK_M=32) autotune config. The superset is kept for big-shared-memory GPUs (A100/H100); the Triton autotuner auto-prunes configs that exceed a GPU's shared memory (OutOfResources -> inf), so the same config list also works on the 5090 (Blackwell, ~101 KB SMEM) where the previous smallest config did not fit. Full Gemma4-31B on 128k TQ export: peak 28.0 GiB, runtime output correct ("...Paris.").
1 parent c0643f5 commit 47af14f

4 files changed

Lines changed: 247 additions & 19 deletions

File tree

backends/aoti/aoti_backend.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,21 @@ def codesign_so(cls, so_path: str, compile_specs: List[CompileSpec]) -> None:
112112
"""
113113
return
114114

115+
@classmethod
116+
def move_program_to_device(
117+
cls,
118+
edge_program: ExportedProgram,
119+
device: str,
120+
compile_specs: List[CompileSpec],
121+
) -> ExportedProgram:
122+
"""Move the exported program to the target device for compilation.
123+
124+
Default implementation moves everything (params, buffers, constants) via
125+
``move_to_device_pass``. Concrete backends may override to keep large
126+
non-parameter tensors off the device during a low-memory export.
127+
"""
128+
return move_to_device_pass(edge_program, device)
129+
115130
@classmethod
116131
def release_moved_tensors(
117132
cls,
@@ -196,9 +211,13 @@ def preprocess(
196211
decomposition_table = cls.get_decomposition_table()
197212
options = cls.get_aoti_compile_options(compile_specs)
198213

199-
# Move the edge_program to the target device
200-
device_edge_program = move_to_device_pass(
201-
edge_program, device_name if device_name != "metal" else "mps"
214+
# Move the edge_program to the target device. Routed through a hook so
215+
# backends can keep large non-parameter tensors (e.g. KV-cache buffers)
216+
# off the device during a low-memory export.
217+
device_edge_program = cls.move_program_to_device(
218+
edge_program,
219+
device_name if device_name != "metal" else "mps",
220+
compile_specs,
202221
)
203222

204223
# Replace view_copy with view

backends/cuda/cuda_backend.py

Lines changed: 184 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,29 +61,81 @@ def _is_cpu_clone_active() -> bool:
6161
return getattr(_CPU_CLONE_GUARD, "active", False)
6262

6363

64+
def _full_zeros_preserving_strides(x: torch.Tensor, device) -> torch.Tensor:
65+
"""Allocate a zero-filled tensor matching ``x``'s size/stride/dtype on ``device``.
66+
67+
Used to re-synthesize KV-cache buffers whose storage was freed (``resize_(0)``)
68+
during the low-memory device move. KV content is all zeros, so this exactly
69+
reproduces the buffer for both the lifted graph value and serialization.
70+
"""
71+
needed = 1
72+
for size, stride in zip(x.size(), x.stride()):
73+
needed += (size - 1) * stride
74+
buf = torch.zeros(int(needed), dtype=x.dtype, device=device)
75+
return torch.as_strided(buf, x.size(), x.stride())
76+
77+
78+
def _is_emptied(x) -> bool:
79+
return (
80+
isinstance(x, torch.Tensor)
81+
and x.numel() > 0
82+
and x.untyped_storage().nbytes() == 0
83+
)
84+
85+
6486
@contextlib.contextmanager
6587
def _compile_time_cpu_clones(target_device: torch.device):
6688
"""Force AOTI's mutated-buffer clones onto CPU while preserving the
6789
serialized constants' target device."""
68-
from torch._inductor import compile_fx as _cfx
90+
from torch._inductor import compile_fx as _cfx, graph as _graph
6991
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu as _Cpp
92+
from torch._inductor.graph import GraphLowering as _GL
7093

7194
orig_clone = _cfx.clone_preserve_strides
7295
orig_codegen_device = _Cpp.codegen_device
96+
orig_get_const = _GL.get_original_value_of_constant
97+
orig_is_same = _graph.is_same_tensor
98+
99+
def _is_same_skip_emptied(data, value):
100+
# KV buffers freed via resize_(0) all have data_ptr 0, so the stock
101+
# is_same_tensor would treat every same-shape KV constant as a duplicate
102+
# and collapse the 60 layers' caches into one — the runtime needs each
103+
# FQN's own buffer, so the collapsed ones load uninitialized garbage.
104+
# Never dedup an emptied tensor.
105+
if _is_emptied(data) or _is_emptied(value):
106+
return False
107+
return orig_is_same(data, value)
73108

74109
def _cpu_clone_preserve_strides(x: torch.Tensor) -> torch.Tensor:
75-
# `clone_preserve_strides` is shared by `_unlift_graph` (clones
76-
# lifted buffers — can be safely kept on CPU) and by autotuning code
77-
# in `triton_heuristics.py` (clones for benchmark — must stay on
78-
# GPU for Triton). Discriminate by caller frame so we only force
79-
# CPU clones for the buffer-lifting path.
110+
# `clone_preserve_strides` is shared by `_unlift_graph` (clones lifted
111+
# buffers — can be safely kept on CPU) and by autotuning code in
112+
# `triton_heuristics.py` (clones for benchmark — must stay on GPU for
113+
# Triton). Discriminate by caller frame so we only force CPU clones for
114+
# the buffer-lifting path.
80115
import sys
81116

82117
caller = sys._getframe(1).f_code.co_name
83118
if caller == "_unlift_graph":
119+
# KV-cache buffers are emptied (storage resize_(0)) by the low-memory
120+
# device move so they never occupy GPU memory during compile. Their
121+
# content is all zeros, so re-synthesize zeros (on CPU, strides
122+
# preserved) instead of cloning the now-empty storage.
123+
if _is_emptied(x):
124+
return _full_zeros_preserving_strides(x, "cpu")
84125
return orig_clone(x).cpu()
85126
return orig_clone(x)
86127

128+
def _get_const_synthesize_zeros(self, name):
129+
# AOTI serializes each constant via get_original_value_of_constant ->
130+
# _to_bytes. For KV buffers we freed with resize_(0) this would otherwise
131+
# fall back to the empty-storage constant and write 0 bytes, producing a
132+
# .ptd with an uninitialized cache. Re-synthesize the zeros so the blob
133+
# holds a correctly-zeroed KV cache.
134+
value = orig_get_const(self, name)
135+
if _is_emptied(value):
136+
return _full_zeros_preserving_strides(value, "cpu")
137+
return value
138+
87139
def _codegen_device_target_aware(self, device):
88140
# Translate accidental CPU device strings back to the model target
89141
# device only when a constant we forced to CPU is being serialized.
@@ -99,6 +151,8 @@ def _codegen_device_target_aware(self, device):
99151

100152
_cfx.clone_preserve_strides = _cpu_clone_preserve_strides
101153
_Cpp.codegen_device = _codegen_device_target_aware
154+
_GL.get_original_value_of_constant = _get_const_synthesize_zeros
155+
_graph.is_same_tensor = _is_same_skip_emptied
102156
prev_active = getattr(_CPU_CLONE_GUARD, "active", False)
103157
_CPU_CLONE_GUARD.active = True
104158
try:
@@ -107,6 +161,107 @@ def _codegen_device_target_aware(self, device):
107161
_CPU_CLONE_GUARD.active = prev_active
108162
_cfx.clone_preserve_strides = orig_clone
109163
_Cpp.codegen_device = orig_codegen_device
164+
_GL.get_original_value_of_constant = orig_get_const
165+
_graph.is_same_tensor = orig_is_same
166+
167+
168+
def _is_kv_buffer(name, v) -> bool:
169+
"""True only for an actual KV-cache *content* buffer that is safe to free.
170+
171+
The low-memory path (``_move_to_device_resize_kv``) frees every buffer this
172+
matches and re-synthesizes it as ZEROS in both the lifted graph and the
173+
serialized ``.ptd`` (see ``_full_zeros_preserving_strides`` /
174+
``_get_const_synthesize_zeros``). That is only valid for genuine KV *content*,
175+
which is all-zeros at export time (caches start empty).
176+
177+
It must NOT match the non-zero constants that some KV-cache modules register
178+
alongside the cache — e.g. TurboQuant registers its codebook/rotation
179+
(``centroids``/``boundaries``/``rotation``/``rotation_T``) as buffers on the
180+
``kv_cache`` module, so their FQNs also contain ``kv_cache``. Freeing+zeroing
181+
those silently corrupts the serialized model (TQ4 dequant -> 0 -> garbage).
182+
Gate on the buffer actually being all-zeros so only empty KV content is freed;
183+
this is robust to any future constant name (a non-zero buffer is never freed).
184+
"""
185+
if not isinstance(v, torch.Tensor) or isinstance(v, torch.nn.Parameter):
186+
return False
187+
if "kv_cache" not in name or v.numel() == 0 or v.is_meta:
188+
return False
189+
# Only the genuinely all-zero KV content may be freed + re-zeroed; non-zero
190+
# constants (TurboQuant centroids/rotation/...) must be preserved as-is.
191+
return bool(torch.count_nonzero(v) == 0)
192+
193+
194+
def _empty_strided_on_device(v, location):
195+
"""A device tensor with v's shape/stride/dtype but zero (freed) storage."""
196+
t = torch.empty_strided(v.shape, v.stride(), dtype=v.dtype, device=location)
197+
t.untyped_storage().resize_(0) # free bytes, keep device + shape/stride
198+
return t
199+
200+
201+
def _move_graph_nodes_to_device(graph_module, location):
202+
"""Point node device kwargs / aten.to.device targets / meta vals at location."""
203+
import torch.utils._pytree as pytree
204+
205+
def _to_loc(v):
206+
return v.to(location) if isinstance(v, torch.Tensor) else v
207+
208+
for m in graph_module.modules():
209+
if not isinstance(m, torch.fx.GraphModule):
210+
continue
211+
for node in m.graph.nodes:
212+
if "device" in node.kwargs:
213+
node.kwargs = {**node.kwargs, "device": location}
214+
if node.op == "call_function" and node.target is torch.ops.aten.to.device:
215+
args = list(node.args)
216+
args[1] = location
217+
node.args = tuple(args)
218+
node.meta["val"] = pytree.tree_map(_to_loc, node.meta.get("val"))
219+
220+
221+
def _move_to_device_resize_kv(ep, location):
222+
"""``move_to_device_pass`` variant that frees KV-cache storage on-device.
223+
224+
Mirrors ``torch.export.passes.move_to_device_pass`` exactly, except KV-cache
225+
buffers (FQN contains ``kv_cache``) are placed on ``location`` but with their
226+
storage immediately freed via ``resize_(0)``. This keeps ``device ==
227+
location`` — so the fake-tensor device check on the ``index_copy`` cache
228+
update passes (``self`` and ``values`` both on cuda) — while no real KV bytes
229+
occupy the device during the AOTI compile. KV content is all zeros, so the
230+
emptied tensors are re-synthesized as zeros at the ``_unlift_graph`` clone
231+
(see ``_compile_time_cpu_clones``), which is reused as both the lifted initial
232+
value and the serialized ``.ptd`` constant. The empty/free is interleaved per
233+
tensor so the transient device peak is a single KV buffer, not the whole cache.
234+
Only ``kv_cache`` tensors are emptied (they are the lone large zero-buffers);
235+
every other tensor is moved normally so non-zero content is never lost.
236+
"""
237+
import torch.utils._pytree as pytree
238+
239+
for k, v in ep.state_dict.items():
240+
if isinstance(v, torch.nn.Parameter):
241+
ep._state_dict[k] = torch.nn.Parameter(v.to(location), v.requires_grad)
242+
elif _is_kv_buffer(k, v):
243+
ep._state_dict[k] = _empty_strided_on_device(v, location)
244+
else:
245+
ep._state_dict[k] = v.to(location)
246+
247+
for k, v in ep.constants.items():
248+
if isinstance(v, torch.Tensor):
249+
ep._constants[k] = (
250+
_empty_strided_on_device(v, location)
251+
if _is_kv_buffer(k, v)
252+
else v.to(location)
253+
)
254+
255+
if ep.example_inputs is not None:
256+
args, kwargs = ep.example_inputs
257+
ep._example_inputs = (
258+
pytree.tree_map_only(torch.Tensor, lambda t: t.to(location), args),
259+
pytree.tree_map_only(torch.Tensor, lambda t: t.to(location), kwargs),
260+
)
261+
262+
_move_graph_nodes_to_device(ep.graph_module, location)
263+
ep.validate()
264+
return ep
110265

111266

112267
@final
@@ -424,6 +579,29 @@ def _is_low_memory_mode(compile_specs: List[CompileSpec]) -> bool:
424579
return spec.value.decode("utf-8").upper() == "ON"
425580
return False
426581

582+
@classmethod
583+
def move_program_to_device(
584+
cls,
585+
edge_program,
586+
device: str,
587+
compile_specs: List[CompileSpec],
588+
):
589+
"""Move the program to ``device`` for AOTI compile.
590+
591+
On a low-memory export (``low_memory_mode="ON"``) the KV-cache buffers —
592+
which can be 10+ GiB at long context — are placed on-device but with their
593+
storage freed (``resize_(0)``), so they never occupy device memory during
594+
the autotune / cpp_wrapper compile while still satisfying the device-match
595+
check on the cache update. They are re-synthesized as zeros for the lifted
596+
graph and the serialized blob. This activates automatically with low-memory
597+
mode. Other (non-low-memory) exports use the stock pass.
598+
"""
599+
from torch.export.passes import move_to_device_pass
600+
601+
if not cls._is_low_memory_mode(compile_specs):
602+
return move_to_device_pass(edge_program, device)
603+
return _move_to_device_resize_kv(edge_program, device)
604+
427605
@classmethod
428606
def release_moved_tensors(
429607
cls,

backends/cuda/quantize_op_dispatch/int4_dispatch.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,28 +60,54 @@ def _cuda(self, qdata, scale, zero, group_size):
6060
return _dequant_matmul(self, qdata, scale, zero, group_size)
6161

6262

63+
# Chunked dequant for the export GPU budget. The lm_head dequant (N = vocab_size,
64+
# e.g. 262144) runs through the int4_plain_mm custom op (M=1); AOTI executes that
65+
# op's CUDA impl during autotune / cpp_wrapper codegen, where it transiently holds
66+
# ~5 full-size bf16 temporaries (low/high/data/data-z/w_deq) — ~10 GiB for a
67+
# 262144-row weight even though the final w_deq is only ~2.6 GiB. Chunking along N
68+
# caps that at ~chunk rows. It is numerically identical (F.linear output rows are
69+
# independent), and because only the lm_head (custom-op) path crosses the N
70+
# threshold — never the M>4 prefill inline path — it never enters the runtime
71+
# graph: ZERO runtime / accuracy impact. Applied unconditionally to any weight
72+
# whose row count exceeds the threshold.
73+
_DEQUANT_N_THRESHOLD = 65536
74+
_DEQUANT_N_CHUNK = 32768
75+
76+
6377
def _dequant_matmul(x, qdata, scale, zero, group_size):
6478
"""Dequant INT4 weights to input dtype and call F.linear.
6579
6680
scale/zero are in the coalesced [N, n_groups] layout (baked into the
6781
weight constant at pack time), aligned row-for-row with qdata's [N, *].
82+
83+
Large weights (N > threshold, i.e. the lm_head) are chunked along N to bound
84+
the dequant intermediate (see note above); smaller weights take the original
85+
single-shot dequant.
6886
"""
6987
N, K_half = qdata.shape
7088
K = K_half * 2
7189
n_groups = K // group_size
7290
gs_half = group_size // 2
7391
dtype = x.dtype
7492

75-
p = qdata.to(torch.uint8).reshape(N, n_groups, gs_half)
76-
low = (p & 0x0F).to(dtype)
77-
high = ((p >> 4) & 0x0F).to(dtype)
78-
data = torch.stack([low, high], dim=-1).reshape(N, n_groups, group_size)
79-
80-
s = scale.to(dtype).unsqueeze(-1)
81-
z = zero.to(dtype).unsqueeze(-1)
82-
w_deq = ((data - z) * s).reshape(N, K)
83-
84-
return F.linear(x, w_deq)
93+
def _dq(qd, sc, ze, rows):
94+
p = qd.to(torch.uint8).reshape(rows, n_groups, gs_half)
95+
low = (p & 0x0F).to(dtype)
96+
high = ((p >> 4) & 0x0F).to(dtype)
97+
data = torch.stack([low, high], dim=-1).reshape(rows, n_groups, group_size)
98+
s = sc.to(dtype).unsqueeze(-1)
99+
z = ze.to(dtype).unsqueeze(-1)
100+
w_deq = ((data - z) * s).reshape(rows, K)
101+
return F.linear(x, w_deq)
102+
103+
if N <= _DEQUANT_N_THRESHOLD:
104+
return _dq(qdata, scale, zero, N)
105+
106+
outs = []
107+
for i in range(0, N, _DEQUANT_N_CHUNK):
108+
j = min(i + _DEQUANT_N_CHUNK, N)
109+
outs.append(_dq(qdata[i:j], scale[i:j], zero[i:j], j - i))
110+
return torch.cat(outs, dim=-1)
85111

86112

87113
# ---------------------------------------------------------------------------

backends/cuda/triton/kernels/tq4_sdpa.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,10 @@ def _tq4_sdpa_fwd_kernel_body(
294294
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2),
295295
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=3),
296296
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4, num_stages=2),
297+
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=2),
298+
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=4, num_stages=3),
299+
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=2),
300+
triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_warps=8, num_stages=3),
297301
],
298302
key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"],
299303
)
@@ -410,6 +414,7 @@ def _tq4_sdpa_fwd_kernel_m64(
410414
triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2),
411415
triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2),
412416
triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2),
417+
triton.Config({"BLOCK_M": 32, "BLOCK_N": 16}, num_warps=4, num_stages=3),
413418
],
414419
key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"],
415420
)

0 commit comments

Comments
 (0)