Skip to content

Commit 5395f20

Browse files
authored
[MLX][Gemma4] Add turbo quant support (pytorch#19866)
Add TurboQuant TQ4 KV cache to the MLX backend, exposed on gemma4_31b via --turboquant. Compresses full-attention KV cache from bf16 to a 4-bit codebook + per-vector norms, letting Gemma 4 31B-IT scale to very long contexts. Sliding-window layers are unchanged. What's in the PR New cache subclass: - backends/mlx/llm/turboquant_cache.py: MLXTurboQuantKVCache, a drop-in subclass of TurboQuantKVCache. Three custom ops + Metal kernels: - mlx::tq4_compress (model_ops/tq4_compress.py): bucketize + cast(uint8) + nibble-pack in one kernel. - mlx::tq_norm (model_ops/tq_norm.py): L2 norm with simd_sum cross-lane reduction in fp32 registers; bf16 in / bf16 out. - mlx::tq_dequant (model_ops/tq_dequant.py): unpack + centroid gather + multiply-by-norm in one kernel. Per-op tests: - test_tq4_compress.py, test_tq_norm.py, test_tq_dequant.py Wiring: - examples/models/gemma4_31b/mlx_source_transformations.py: - examples/models/gemma4_31b/export.py: --turboquant CLI flag - examples/models/gemma4_31b/README.md: TurboQuant subsection. Perf on M4 Max 64GB Ram: ``` 2K prompt: bf16 cache: prefill 189.7 tok/s, decode 17.4 tok/s TurboQuant cache: prefill 187.7 tok/s, decode 16.9 tok/s 8K prompt: bf16 cache: prefill 170.0 tok/s, decode 17.1 tok/s TurboQuant cache: prefill 166.0 tok/s, decode 11.9 tok/s ``` For TQ, max context length is set to 64K. On bf16 cache, max context length is 10K. TODO: why does decode slow more for TQ than bf16?
1 parent 0e6b67e commit 5395f20

15 files changed

Lines changed: 1961 additions & 28 deletions

File tree

.github/workflows/mlx.yml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,18 @@ jobs:
8080
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run -v
8181
echo "::endgroup::"
8282
83+
echo "::group::Run tq_norm op tests"
84+
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_norm run -v
85+
echo "::endgroup::"
86+
87+
echo "::group::Run tq4_compress op tests"
88+
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq4_compress run -v
89+
echo "::endgroup::"
90+
91+
echo "::group::Run tq_dequant op tests"
92+
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_dequant run -v
93+
echo "::endgroup::"
94+
8395
test-mlx-qwen35-moe:
8496
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
8597
with:

backends/mlx/builder/op_helpers.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
if TYPE_CHECKING:
1919
from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder
20+
from executorch.backends.mlx.serialization.mlx_graph_schema import IntOrVid
2021

2122
# When True, always serialize the biases tensor for quantized ops.
2223
# When False, use init-time computation when zero_point is all zeros,
@@ -173,6 +174,117 @@ def emit_lifted_constant(P: "MLXProgramBuilder", value, dtype: torch.dtype) -> S
173174
return slot
174175

175176

177+
def emit_shape(
178+
P: "MLXProgramBuilder",
179+
node: Node,
180+
slot: Slot,
181+
*,
182+
end_dim: "Optional[int]" = None,
183+
) -> "list[IntOrVid]":
184+
"""Return the shape of ``node`` as a list of ``IntOrVid``.
185+
186+
Each static dim becomes a literal ``IntOrVid``; each dynamic dim
187+
emits a ``SymSizeNode`` against ``slot`` and is wrapped via
188+
``P.to_int_or_vid``.
189+
190+
Args:
191+
P: program builder.
192+
node: FX node whose shape to walk (must have ``meta['val']``).
193+
slot: slot corresponding to ``node`` (used as the
194+
``SymSize`` source for any dynamic dim).
195+
end_dim: stop index (exclusive). ``None`` means the full ndim.
196+
Negative values index from the end (e.g. ``-1`` is "all
197+
leading dims, drop the last").
198+
199+
Returns:
200+
``list[IntOrVid]`` of length ``end_dim`` (after normalization).
201+
"""
202+
from executorch.backends.mlx.serialization.mlx_graph_schema import (
203+
IntOrVid,
204+
SymSizeNode,
205+
)
206+
207+
shape = node.meta["val"].shape
208+
ndim = len(shape)
209+
if end_dim is None:
210+
end_dim = ndim
211+
elif end_dim < 0:
212+
end_dim += ndim
213+
214+
out: "list[IntOrVid]" = []
215+
for dim_idx in range(end_dim):
216+
s = shape[dim_idx]
217+
if isinstance(s, int):
218+
out.append(IntOrVid.from_literal(int(s)))
219+
else:
220+
_, d_val = P.make_tmp_value_slot()
221+
P.emit(
222+
SymSizeNode(
223+
a=P.slot_to_tid(slot),
224+
dim=dim_idx,
225+
out=P.slot_to_vid(d_val),
226+
)
227+
)
228+
out.append(P.to_int_or_vid(d_val))
229+
return out
230+
231+
232+
def emit_product(
233+
P: "MLXProgramBuilder",
234+
dims: "list[IntOrVid]",
235+
) -> "IntOrVid":
236+
"""Multiplicative reduction over a list of ``IntOrVid`` values.
237+
238+
Folds all literal entries AOT into a single static product, then
239+
emits ``MultiplyIntNode`` only for the dynamic entries (and one
240+
final node combining the static product with the dynamic accumulator
241+
when both contribute).
242+
243+
Args:
244+
P: program builder.
245+
dims: list of ``IntOrVid``. May be empty (returns
246+
``IntOrVid.from_literal(1)``), all literals, or a mix.
247+
248+
Returns:
249+
An ``IntOrVid`` representing the product. Always literal when
250+
every entry is literal (or ``dims`` is empty).
251+
"""
252+
from executorch.backends.mlx.serialization.mlx_graph_schema import (
253+
IntOrVid,
254+
MultiplyIntNode,
255+
)
256+
257+
static_product = 1
258+
dynamic_dims: "list[IntOrVid]" = []
259+
for d in dims:
260+
if d.is_vid:
261+
dynamic_dims.append(d)
262+
else:
263+
static_product *= d.literal
264+
265+
if not dynamic_dims:
266+
return IntOrVid.from_literal(static_product)
267+
268+
acc = dynamic_dims[0]
269+
for d in dynamic_dims[1:]:
270+
_, acc_val = P.make_tmp_value_slot()
271+
P.emit(MultiplyIntNode(a=acc, b=d, out=P.slot_to_vid(acc_val)))
272+
acc = P.to_int_or_vid(acc_val)
273+
274+
if static_product == 1:
275+
return acc
276+
277+
_, final_val = P.make_tmp_value_slot()
278+
P.emit(
279+
MultiplyIntNode(
280+
a=IntOrVid.from_literal(static_product),
281+
b=acc,
282+
out=P.slot_to_vid(final_val),
283+
)
284+
)
285+
return P.to_int_or_vid(final_val)
286+
287+
176288
def emit_quantized_biases(
177289
P: "MLXProgramBuilder",
178290
zero_point_key: str,
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
#!/usr/bin/env python3
2+
#
3+
# Copyright (c) Meta Platforms, Inc. and affiliates.
4+
# All rights reserved.
5+
#
6+
# This source code is licensed under the BSD-style license found in the
7+
# LICENSE file in the root directory of this source tree.
8+
9+
"""
10+
TurboQuant TQ4 KV cache for the MLX backend.
11+
12+
Subclass of the backend-agnostic
13+
``extension/llm/modules/turboquant/kv_cache.py::TurboQuantKVCache``.
14+
15+
The cache stores K and V in **rotated space** (post-multiplied by R^T)
16+
as nibble-packed uint8 codebook indices plus per-vector bf16 norms.
17+
SDPA runs in rotated space and undoes the rotation on the output side
18+
(both Q and output rotations are ``T_q × D²``, much smaller than
19+
applying the inverse rotation to K/V which would be ``T_kv × D²``).
20+
21+
Reference:
22+
TurboQuant: Online Vector Quantization with Near-optimal
23+
Distortion Rate. arXiv:2504.19874 (ICLR 2026).
24+
"""
25+
26+
from typing import Optional, Tuple
27+
28+
# Register the MLX custom ops used by this cache.
29+
import executorch.backends.mlx.custom_ops # noqa: F401 mlx::custom_sdpa, mlx::kv_cache_update
30+
import executorch.backends.mlx.model_ops.tq4_compress # noqa: F401 mlx::tq4_compress
31+
import executorch.backends.mlx.model_ops.tq_dequant # noqa: F401 mlx::tq_dequant
32+
import executorch.backends.mlx.model_ops.tq_norm # noqa: F401 mlx::tq_norm
33+
34+
import torch
35+
36+
from executorch.extension.llm.modules.turboquant.kv_cache import (
37+
TurboQuantKVCache as _SharedTurboQuantKVCache,
38+
)
39+
40+
41+
class TurboQuantKVCache(_SharedTurboQuantKVCache):
42+
"""
43+
TurboQuant TQ4 KV cache, MLX-backend variant.
44+
45+
Drop-in replacement for ``backends/mlx/llm/cache.py::KVCache``.
46+
47+
Args:
48+
max_batch_size: Must be 1 (TQ4 is batch=1 only).
49+
max_context_length: Maximum sequence length.
50+
n_heads: Number of KV heads.
51+
head_dim: Per-head dimension. Must be even and a multiple of 64.
52+
enable_dynamic_shape: Accepted for interface parity; ignored.
53+
dtype: Compute dtype (bf16). Used for pre-cast buffers.
54+
bits: Quantization bits (must be 4).
55+
seed: RNG seed for the orthogonal rotation matrix.
56+
"""
57+
58+
def __init__(
59+
self,
60+
max_batch_size: int,
61+
max_context_length: int,
62+
n_heads: int,
63+
head_dim: int,
64+
enable_dynamic_shape: bool,
65+
dtype: torch.dtype = torch.bfloat16,
66+
bits: int = 4,
67+
seed: int = 42,
68+
):
69+
if max_batch_size != 1:
70+
raise ValueError(
71+
f"TurboQuantKVCache only supports max_batch_size=1, "
72+
f"got {max_batch_size}"
73+
)
74+
if bits != 4:
75+
raise ValueError(
76+
f"TurboQuantKVCache only supports bits=4 "
77+
f"(16-entry codebook), got bits={bits}"
78+
)
79+
# MLX-backend Metal kernels need ``head_dim % 64 == 0``: ``tq_norm``
80+
# uses 32 SIMD lanes (so D must be a multiple of 32), and
81+
# ``tq_dequant`` packs 2 dims per byte across 32 lanes (so D must
82+
# be a multiple of 64). Take the stricter constraint here.
83+
if head_dim % 64 != 0:
84+
raise ValueError(
85+
f"TurboQuantKVCache requires head_dim to be "
86+
f"a multiple of 64 (Metal SIMD + 4-bit pack constraint), "
87+
f"got {head_dim}"
88+
)
89+
super().__init__(
90+
n_heads=n_heads,
91+
head_dim=head_dim,
92+
max_seq_len=max_context_length,
93+
bits=bits,
94+
seed=seed,
95+
)
96+
self.max_batch_size = max_batch_size
97+
self.max_context_length = max_context_length
98+
self.enable_dynamic_shape = enable_dynamic_shape
99+
100+
# Replace parent's fp32 ``rotation`` and ``centroids`` buffers
101+
# with compute-dtype versions in-place. Avoids a per-call
102+
# ``_to_copy`` cast in the lowered graph at every use site.
103+
# Parent's ``_decompress`` (testing-only) is the sole consumer
104+
# of these as fp32 and is not called at runtime.
105+
self.register_buffer(
106+
"rotation",
107+
self.rotation.to(dtype).contiguous(),
108+
persistent=False,
109+
)
110+
self.register_buffer(
111+
"centroids",
112+
self.centroids.to(dtype).contiguous(),
113+
persistent=False,
114+
)
115+
# Pre-cast eps for the divide-by-zero guard in _compress.
116+
self.register_buffer(
117+
"norm_eps",
118+
torch.tensor(1e-10, dtype=dtype),
119+
persistent=False,
120+
)
121+
122+
def _compress(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
123+
"""Compress ``(1, H, T, D)`` → packed ``(1, H, T, D//2)`` u8 +
124+
norms ``(1, H, T, 1)`` bf16.
125+
126+
The L2-norm reduction uses ``mlx::tq_norm`` (one Metal kernel
127+
with fp32 sum-of-squares in registers via ``simd_sum``); the
128+
bucketize + nibble-pack tail uses ``mlx::tq4_compress`` (one
129+
Metal kernel for both steps).
130+
"""
131+
orig_shape = x.shape
132+
flat = x.reshape(-1, self.head_dim)
133+
134+
norms = torch.ops.mlx.tq_norm(flat)
135+
normalized = flat / (norms + self.norm_eps)
136+
rotated = normalized @ self.rotation_T
137+
packed = torch.ops.mlx.tq4_compress(rotated, self.boundaries)
138+
139+
return (
140+
packed.reshape(*orig_shape[:-1], self.half_dim),
141+
norms.reshape(*orig_shape[:-1], 1),
142+
)
143+
144+
def update(
145+
self,
146+
input_pos,
147+
k_val: torch.Tensor,
148+
v_val: torch.Tensor,
149+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
150+
"""Compress + write K/V at ``input_pos``, return the full
151+
compressed cache buffers.
152+
153+
Accepts ``input_pos`` as either a ``(T,)`` LongTensor of
154+
positions or a Python int / SymInt ``start_pos``. Writes go
155+
through ``mlx::kv_cache_update`` (matching the non-TQ
156+
``MLXKVCache`` path) which lowers to a tighter in-place
157+
scatter than ``index_copy_`` would.
158+
"""
159+
if isinstance(input_pos, torch.Tensor):
160+
start_pos = input_pos[0].item()
161+
seq_len = k_val.size(2)
162+
torch._check(seq_len == v_val.size(2))
163+
torch._check(start_pos >= 0)
164+
torch._check(start_pos + seq_len <= self.max_context_length)
165+
else:
166+
start_pos = input_pos
167+
168+
k_packed, k_norms = self._compress(k_val)
169+
v_packed, v_norms = self._compress(v_val)
170+
171+
torch.ops.mlx.kv_cache_update(self.k_packed, k_packed, start_pos)
172+
torch.ops.mlx.kv_cache_update(self.k_norms, k_norms, start_pos)
173+
torch.ops.mlx.kv_cache_update(self.v_packed, v_packed, start_pos)
174+
torch.ops.mlx.kv_cache_update(self.v_norms, v_norms, start_pos)
175+
176+
# Slices on the return create new graph nodes so the same node
177+
# is not both BUFFER_MUTATION and USER_OUTPUT.
178+
return (
179+
self.k_packed[:, :, :, :],
180+
self.k_norms[:, :, :, :],
181+
self.v_packed[:, :, :, :],
182+
self.v_norms[:, :, :, :],
183+
)
184+
185+
# forward() is inherited from the parent (delegates to update).
186+
187+
def sdpa(
188+
self,
189+
query: torch.Tensor,
190+
start_pos,
191+
scale: Optional[float] = None,
192+
) -> torch.Tensor:
193+
"""SDPA over the compressed cache.
194+
195+
Runs attention in rotated space:
196+
1. Q_rot = Q @ R^T (T_q x D^2)
197+
2. K_rot, V_rot = tq_dequant(...) (rotated-space K/V)
198+
3. out_rot = custom_sdpa(Q_rot, K_rot, V_rot, ...)
199+
4. out = out_rot @ R (T_q x D^2)
200+
201+
Since R is orthogonal, score = (Q·R^T)·(K·R^T)^T = Q·K^T, so
202+
attention is invariant under matched rotation of Q and K. The
203+
``T_kv x D^2`` inverse-rotation matmul on K/V is replaced with
204+
two ``T_q x D^2`` matmuls (Q and output).
205+
206+
Args:
207+
query: ``(B, H_q, T_q, D)`` bf16.
208+
start_pos: int or SymInt — absolute position of the first
209+
query token.
210+
scale: 1/sqrt(D) if None.
211+
212+
Returns:
213+
``(B, H_q, T_q, D)`` bf16 attention output, in original
214+
(un-rotated) space.
215+
"""
216+
seq_len = query.size(2)
217+
end_pos = start_pos + seq_len
218+
torch._check(start_pos >= 0)
219+
torch._check(end_pos <= self.max_context_length)
220+
221+
q_rot = query @ self.rotation_T
222+
223+
k_packed_live = self.k_packed[:, :, :end_pos, :]
224+
k_norms_live = self.k_norms[:, :, :end_pos, :]
225+
v_packed_live = self.v_packed[:, :, :end_pos, :]
226+
v_norms_live = self.v_norms[:, :, :end_pos, :]
227+
228+
# TODO: optimize with a fused dequant + SDPA
229+
k_rot = torch.ops.mlx.tq_dequant(k_packed_live, k_norms_live, self.centroids)
230+
v_rot = torch.ops.mlx.tq_dequant(v_packed_live, v_norms_live, self.centroids)
231+
232+
out_rot = torch.ops.mlx.custom_sdpa(
233+
q_rot,
234+
k_rot,
235+
v_rot,
236+
start_pos,
237+
None, # attn_mask
238+
0.0, # dropout_p
239+
True, # is_causal
240+
scale,
241+
)
242+
243+
return out_rot @ self.rotation

0 commit comments

Comments
 (0)