Commit a9d5674

[MLX][Gemma4] Introduce Q6K kernels (pytorch#20004)
### Summary

Adds fused **GGUF Q6_K** custom Metal kernels to the MLX backend and wires them into the Gemma 4 31B GGUF export path, so Q6_K-quantized linear and embedding weights run directly from llama.cpp's packed block layout instead of taking the slow non-fused dequantize path. Also shrinks the exported `.pte` (and its in-memory footprint) by de-duplicating repeated kernel source blobs.

**New custom kernel ops** (`backends/mlx/custom_kernel_ops/gguf/`)

The `gguf/` package is organized as format routers over per-format implementations, so new GGUF formats (e.g. Q4_K) can be added without touching the op definitions:

- `gguf/linear.py` / `gguf/embedding.py`: thin **format routers**. Each owns the op identity (`mlx::gguf_linear` / `mlx::gguf_embedding`: custom op, fake, and lowering registration) and dispatches on the `format` arg. Only `"q6k"` is supported today; other formats raise `NotImplementedError`.
- `gguf/q6k/common.py`: shared Q6_K primitives: constants, the pure-torch `dequantize_q6_k` reference, and the Metal header (`block_q6_K` struct + dequant helpers). Lightweight (no builder import), re-exported from `gguf/q6k/__init__.py`.
- `gguf/q6k/linear.py`: `out = x @ dequant(weight)^T (+bias)` against a raw GGUF `block_q6_K` blob (no repacking). Emits two Metal kernels: a fused mat-vec for decode (`M==1`, ported from llama.cpp `kernel_mul_mv_q6_K_f32_impl`) and a tiled simdgroup mat-mat for prefill (`M>1`). For dynamic/symbolic `M`, both chains are emitted and selected at runtime via a new `IfNode`.
- `gguf/q6k/embedding.py`: gather counterpart that dequantizes Q6_K rows directly.

**Runtime / schema**

New `IfNode` in `schema.fbs` (runtime conditional selecting one of two instruction chains on an integer condition) plus `exec_if` dispatch in `MLXInterpreter.h`.

**Serialization: smaller `.pte` + lower load-time RAM**

- Serializer de-duplicates identical strings into a single FlatBuffer offset (shared-string emission in the generated serializers / `generate.py` / `mlx_graph_serialize.py`). The big repeated `MetalKernelNode` source/header blobs are now written once. On Gemma 4 31B this cut the MLX graph metadata from ~1.23 MiB to ~0.47 MiB (~62%).
- Loader interns those shared blobs into one `std::shared_ptr<const std::string>` keyed by the FlatBuffer string pointer (`StringPool` in `MLXLoader.{h,cpp}.tmpl`; `MLXInterpreter.h` derefs the handle), so a newly-produced `.pte` also uses less RAM at runtime.
- Fully backward-compatible: no schema/format change. Old `.pte` files load unchanged (just without the dedup).

**Gemma 4 31B GGUF loader** (`examples/models/gemma4_31b/`)

- `iter_gguf_tensors` now yields the tensor's quant type and can emit Q6_K tensors as the raw `(N, n_blocks*210)` uint8 blob (`q6k_raw`); added `_raw_q6_k` helper and made `_unpack_q6_k` accept an already-materialized tensor.
- New `mlx_gguf_linear.py` carrier modules (`GGUFLinear`/`GGUFEmbedding`) and `_handle_mlx_q6k` routing: Linear weights → `gguf_linear`, token embedding → `gguf_embedding`, tied lm_head reuses the embedding blob via `gguf_linear`, with a quantized-tensor fallback for any other Q6_K module.
- Removed the `ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS` env-var workaround in `export.py` since the fused path no longer needs it.

**Refactor**

- Renamed `backends/mlx/model_ops/` → `backends/mlx/custom_kernel_ops/` (with a `test/` subpackage) and updated all imports (`turboquant_cache.py`, `qwen3_5_moe/mlx_source_transformations.py`).

### Test plan

- New/updated unit tests: `custom_kernel_ops/gguf/test/test_linear.py`, `test_embedding.py`; `backends/mlx/test/test_serialization_dedup.py` (asserts identical source/header are written once); `examples/models/gemma4_31b/quant/tests/test_gguf.py` and `examples/models/gemma4_31b/tests/test_mlx_pipeline.py`.
- CI (`.github/workflows/mlx.yml`) discovers op tests recursively (`custom_kernel_ops/**/test/test_*.py`) so per-format subpackage tests run with no per-op CI edit.

Run locally:

```bash
# Build the op runner once (per CI):
cmake --preset mlx-release -DEXECUTORCH_BUILD_TESTS=ON -DEXECUTORCH_MLX_ENABLE_SANITIZERS=OFF
cmake --build cmake-out --target op_test_runner -j

# GPU op tests (export + run on device):
python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear run -v
python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_embedding run -v

# Pure-Python checks:
python -m pytest backends/mlx/test/test_serialization_dedup.py \
    examples/models/gemma4_31b/quant/tests/test_gguf.py \
    examples/models/gemma4_31b/tests/test_mlx_pipeline.py -v
```
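
For orientation, the 210-byte Q6_K block that the kernels read can be sketched in plain Python. This is only an illustration, assuming llama.cpp's reference `block_q6_K` layout (256 weights per block: 128 bytes of low nibbles, 64 bytes of packed high 2-bit fields, 16 int8 sub-block scales, one fp16 super-block scale). The name `dequantize_q6_k_block` is hypothetical and is not the `dequantize_q6_k` reference in `gguf/q6k/common.py`.

```python
import struct

QK_K = 256          # weights per Q6_K super-block
BLOCK_BYTES = 210   # 128 (ql) + 64 (qh) + 16 (scales) + 2 (fp16 d)


def dequantize_q6_k_block(block: bytes) -> list:
    """Dequantize one 210-byte Q6_K block into 256 Python floats."""
    assert len(block) == BLOCK_BYTES
    ql = block[0:128]                               # low 4 bits of each quant
    qh = block[128:192]                             # high 2 bits, four per byte
    scales = struct.unpack("<16b", block[192:208])  # int8 scale per 16 weights
    (d,) = struct.unpack("<e", block[208:210])      # fp16 super-block scale

    out = [0.0] * QK_K
    for half in range(2):                           # two 128-weight halves
        y, lo, hi, sc = 128 * half, 64 * half, 32 * half, 8 * half
        for l in range(32):
            s = sc + l // 16
            h = qh[hi + l]
            # Each 6-bit quant is a 4-bit low part plus a 2-bit high part,
            # stored with an offset of 32.
            q1 = ((ql[lo + l] & 0xF) | (((h >> 0) & 3) << 4)) - 32
            q2 = ((ql[lo + l + 32] & 0xF) | (((h >> 2) & 3) << 4)) - 32
            q3 = ((ql[lo + l] >> 4) | (((h >> 4) & 3) << 4)) - 32
            q4 = ((ql[lo + l + 32] >> 4) | (((h >> 6) & 3) << 4)) - 32
            out[y + l] = d * scales[s + 0] * q1
            out[y + l + 32] = d * scales[s + 2] * q2
            out[y + l + 64] = d * scales[s + 4] * q3
            out[y + l + 96] = d * scales[s + 6] * q4
    return out
```

A row of `N` input features would then be `N / 256` such blocks laid end to end, which matches the `(N, n_blocks*210)` raw blob shape mentioned above.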
1 parent 2d42918 commit a9d5674

53 files changed

Lines changed: 3819 additions & 829 deletions

.github/workflows/mlx.yml

Lines changed: 13 additions & 14 deletions
@@ -13,6 +13,7 @@ on:
       - backends/mlx/**
       - extension/llm/export/**
       - extension/audio/**
+      - examples/models/gemma4_31b/**
       - examples/models/parakeet/**
       - examples/models/voxtral_realtime/**
       - examples/models/qwen3_5_moe/**
@@ -77,6 +78,8 @@ jobs:
             backends/mlx/test/test_passes.py \
             backends/mlx/test/test_pattern_utils.py \
             backends/mlx/test/test_partitioner.py \
+            backends/mlx/test/test_serialization_dedup.py \
+            examples/models/gemma4_31b/quant/tests/test_pack_mlx.py \
             examples/models/gemma4_31b/tests/test_mlx_pipeline.py \
             -v
           echo "::endgroup::"
@@ -89,20 +92,16 @@ jobs:
           ./cmake-out/backends/mlx/test/multi_thread_test_runner
           echo "::endgroup::"
 
-          echo "::group::Run gated_delta_rule op tests"
-          ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run -v
-          echo "::endgroup::"
-
-          echo "::group::Run tq_norm op tests"
-          ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_norm run -v
-          echo "::endgroup::"
-
-          echo "::group::Run tq4_compress op tests"
-          ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq4_compress run -v
-          echo "::endgroup::"
-
-          echo "::group::Run tq_dequant op tests"
-          ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_dequant run -v
+          echo "::group::Run custom_kernel_ops op tests"
+          # Run every custom_kernel_ops/**/test/test_*.py via its OpTestCase `run`
+          # CLI. Recurses into per-format subpackages (e.g. gguf/test), so adding a
+          # new op test file requires no change here.
+          set -e
+          for t in $(find backends/mlx/custom_kernel_ops -path '*/test/test_*.py' | sort); do
+            mod="executorch.$(echo "${t%.py}" | tr '/' '.')"
+            echo "--- ${mod} ---"
+            ${CONDA_RUN} python -m "${mod}" run -v
+          done
           echo "::endgroup::"
 
   test-mlx-qwen35-moe:
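
The path-to-module mapping done by the `find ... | tr '/' '.'` loop above can be sketched in Python for local experimentation. This is an approximation under stated assumptions: `discover_op_test_modules` is a hypothetical helper, the repository root is passed in explicitly, and only files sitting directly inside a `test/` directory are accepted (the shell `-path` glob is slightly more permissive).

```python
from pathlib import Path


def discover_op_test_modules(
    repo_root: Path, subdir: str = "backends/mlx/custom_kernel_ops"
) -> list:
    """Return sorted dotted module names for every */test/test_*.py under subdir."""
    modules = []
    for path in (repo_root / subdir).rglob("test_*.py"):
        if path.parent.name != "test":
            continue  # skip files that are not directly inside a test/ directory
        # backends/mlx/.../test/test_x.py -> executorch.backends.mlx....test.test_x
        rel = path.relative_to(repo_root).with_suffix("")
        modules.append("executorch." + ".".join(rel.parts))
    return sorted(modules)
```

Each returned name is what the workflow would pass to `python -m <module> run -v`.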

backends/mlx/builder/op_helpers.py

Lines changed: 101 additions & 0 deletions
@@ -329,6 +329,79 @@ def emit_quantized_biases(
     return biases
 
 
+def emit_quantized_gather(
+    P: MLXProgramBuilder,
+    out: Slot,
+    indices_slot: Slot,
+    qdata_slot: Slot,
+    scales_slot: Slot,
+    biases_slot: Optional[Slot],
+    *,
+    group_size: int,
+    bits: int,
+    mode: str,
+    out_dtype: torch.dtype,
+) -> None:
+    """Gather quantized rows by index and dequantize them into ``out``.
+
+    Emits ``TakeNode`` for qdata and scales (and biases when present), then a
+    ``DequantizeNode``.
+    """
+    from executorch.backends.mlx.serialization.mlx_graph_schema import (
+        DequantizeNode,
+        IntOrVidOrTid,
+        TakeNode,
+    )
+
+    ids_index = IntOrVidOrTid.from_tid(P.slot_to_tid(indices_slot))
+
+    _, wq_sel = P.make_tmp_slot()
+    P.emit(
+        TakeNode(
+            x=P.slot_to_tid(qdata_slot),
+            index=ids_index,
+            out=P.slot_to_tid(wq_sel),
+            axis=0,
+        )
+    )
+
+    _, sc_sel = P.make_tmp_slot()
+    P.emit(
+        TakeNode(
+            x=P.slot_to_tid(scales_slot),
+            index=ids_index,
+            out=P.slot_to_tid(sc_sel),
+            axis=0,
+        )
+    )
+
+    biases_tid = None
+    if biases_slot is not None:
+        _, b_sel = P.make_tmp_slot()
+        P.emit(
+            TakeNode(
+                x=P.slot_to_tid(biases_slot),
+                index=ids_index,
+                out=P.slot_to_tid(b_sel),
+                axis=0,
+            )
+        )
+        biases_tid = P.slot_to_tid(b_sel)
+
+    P.emit(
+        DequantizeNode(
+            w=P.slot_to_tid(wq_sel),
+            scales=P.slot_to_tid(sc_sel),
+            out=P.slot_to_tid(out),
+            biases=biases_tid,
+            group_size=group_size,
+            bits=bits,
+            mode=mode,
+            dtype=torch_dtype_to_scalar_type(out_dtype),
+        )
+    )
+
+
 def to_mlx_qparams(
     qdata: torch.Tensor,
     scale: torch.Tensor,
@@ -421,6 +494,34 @@ def parse_dequant_nvfp4_node(
     return qdata, scale, per_tensor_scale, output_dtype
 
 
+def parse_dequant_int4_node(
+    node: Node,
+) -> Optional[Tuple[Node, Node, Node, int, Optional[torch.dtype]]]:
+    """Parse a torchao.dequantize_int4_tensor node.
+
+    Returns (qdata, scale, zero_point, group_size, output_dtype) or None if not a
+    dequantize_int4_tensor node or the custom op is not registered.
+    """
+    target = get_aten_target(node.target)
+    try:
+        import executorch.extension.llm.export.int4  # noqa: F401
+    except ImportError:
+        return None
+
+    if target is not torch.ops.torchao.dequantize_int4_tensor.default:
+        return None
+
+    qdata, scale, zero_point, group_size = node.args[0:4]
+
+    output_dtype = None
+    if len(node.args) > 4:
+        output_dtype = node.args[4]
+    elif "output_dtype" in node.kwargs:
+        output_dtype = node.kwargs["output_dtype"]
+
+    return qdata, scale, zero_point, group_size, output_dtype
+
+
 def parse_dequant_node(
     node: Node,
 ) -> Optional[Tuple[Node, Node, Node, int, int, Optional[torch.dtype], int]]:
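
The idea behind `emit_quantized_gather` (take the selected rows of each quantized component first, then dequantize only those rows) can be illustrated without the MLX builder. The sketch below is an assumption-laden stand-in: it uses plain Python lists of already-unpacked integers, and it assumes an affine scheme `w = q * scale + bias` with one `(scale, bias)` pair per group, which may not match every `mode` the real `DequantizeNode` supports. All function names are hypothetical.

```python
def dequantize_row(q, scales, biases, group_size):
    """Affine dequantization of one row: w = q * scale + bias, per group."""
    return [
        q[i] * scales[i // group_size] + biases[i // group_size]
        for i in range(len(q))
    ]


def gather_then_dequantize(indices, qdata, scales, biases, group_size):
    """Select rows of qdata/scales/biases by index, then dequantize them."""
    # "Take" along axis 0 for every quantized component ...
    q_sel = [qdata[i] for i in indices]
    s_sel = [scales[i] for i in indices]
    b_sel = [biases[i] for i in indices]
    # ... then dequantize only the gathered rows.
    return [
        dequantize_row(q, s, b, group_size)
        for q, s, b in zip(q_sel, s_sel, b_sel)
    ]


def dequantize_then_gather(indices, qdata, scales, biases, group_size):
    """Reference ordering: dequantize the whole table, then select rows."""
    full = [
        dequantize_row(q, s, b, group_size)
        for q, s, b in zip(qdata, scales, biases)
    ]
    return [full[i] for i in indices]
```

Both orderings give the same rows; gathering first avoids materializing the full dequantized table, which is the point for a large embedding.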
File renamed without changes.
File renamed without changes.
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+#
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+"""GGUF-quantized weight lowering for the MLX backend.
+
+Import :mod:`.patterns` for its side effect to enable lowering of
+``torchao::dequantize_gguf -> linear/embedding`` to the Q6_K / Q4_K kernels::
+
+    import executorch.backends.mlx.custom_kernel_ops.gguf.patterns  # noqa: F401
+
+This ``__init__`` is side-effect free, so importing ``.q6k`` for the pure-torch
+dequant does not pull in the MLX builder/registry.
+"""
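
The "import for its side effect" convention described in this docstring can be sketched with a toy registry. Everything here is hypothetical (`PatternRegistry`, `load_patterns`, the empty handler classes); it only illustrates why nothing is registered until the pattern module's definitions actually run, and it does not reproduce the real `REGISTRY` in `builder/op_registry.py`.

```python
class PatternRegistry:
    """Toy registry: a decorator records handler classes under a name."""

    def __init__(self):
        self._patterns = {}

    def register_pattern(self, name):
        def decorator(cls):
            self._patterns[name] = cls
            return cls

        return decorator

    def names(self):
        return sorted(self._patterns)


REGISTRY = PatternRegistry()


def load_patterns():
    """Stand-in for importing a patterns module: running the class
    definitions is what triggers registration."""

    @REGISTRY.register_pattern(name="GGUF_QUANTIZED_LINEAR")
    class LinearHandler:
        pass

    @REGISTRY.register_pattern(name="GGUF_QUANTIZED_EMBEDDING")
    class EmbeddingHandler:
        pass
```

Keeping the package `__init__` free of such registrations means a caller can import a lightweight submodule without triggering them.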
Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
+#
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+"""MLX pattern handlers for GGUF-quantized weights.
+
+``ExportableGGUFTensor`` (extension/llm/export/gguf.py) lowers a quantized
+linear/embedding to::
+
+    linear(x, torchao::dequantize_gguf(weight, ggml_type, out_dtype), bias)
+    embedding(torchao::dequantize_gguf(weight, ggml_type, out_dtype), indices)
+
+These handlers match that ``dequantize_gguf -> linear/embedding`` subgraph and
+lower it without materializing the dequantized weight:
+
+* **Q6_K** -> fused custom Metal kernels in :mod:`.q6k`.
+* **Q4_K** -> MLX's native 4-bit affine ops via :mod:`.q4k` (GGUF blocks
+  repacked into MLX qparams at export time).
+
+Both cover linear and embedding.
+
+Other quant types are left unmatched (the caller is expected to convert them to a
+torchao ``Int4Tensor`` / ``IntxUnpackedToInt8Tensor`` first).
+
+Importing this module registers the patterns as a side effect.
+"""
+
+from __future__ import annotations
+
+from typing import Optional, Tuple
+
+import torch
+from executorch.backends.mlx.builder.op_helpers import get_aten_target
+from executorch.backends.mlx.builder.op_registry import PatternHandler, REGISTRY
+from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder
+from executorch.backends.mlx.builder.slot_manager import Slot
+from executorch.backends.mlx.pattern_utils import has_single_user, match_target
+from torch.export.exported_program import ExportedProgram
+from torch.fx.node import Node
+
+# Quant types each pattern can lower (Q6_K via custom Metal kernels, Q4_K via
+# MLX-native affine ops).
+_LINEAR_TYPES = {"q4_k", "q6_k"}
+_EMBEDDING_TYPES = {"q4_k", "q6_k"}
+
+
+def parse_dequantize_gguf_node(
+    node: Node,
+) -> Optional[Tuple[Node, str, torch.dtype]]:
+    """Parse a ``torchao::dequantize_gguf`` node.
+
+    Returns ``(weight_node, ggml_type, output_dtype)`` or ``None`` if ``node`` is
+    not a ``dequantize_gguf`` node (or the op isn't registered).
+    """
+    try:
+        import executorch.extension.llm.export.gguf  # noqa: F401 registers the op
+    except ImportError:
+        return None
+
+    if get_aten_target(node.target) is not torch.ops.torchao.dequantize_gguf.default:
+        return None
+
+    weight = node.args[0]
+    ggml_type = node.args[1]
+    output_dtype = torch.bfloat16
+    if len(node.args) > 2:
+        output_dtype = node.args[2]
+    elif "output_dtype" in node.kwargs:
+        output_dtype = node.kwargs["output_dtype"]
+    return weight, ggml_type, output_dtype
+
+
+@REGISTRY.register_pattern(name="GGUF_QUANTIZED_LINEAR")
+class GGUFQuantizedLinearHandler(PatternHandler):
+    """Lower ``dequantize_gguf + linear`` to a fused quantized matmul.
+
+    Matches ``linear(x, dequantize_gguf(weight, ggml_type, out_dtype), bias)``
+    and dispatches on ``ggml_type``: Q6_K -> custom Metal kernels, Q4_K -> MLX
+    4-bit ``quantized_matmul``.
+    """
+
+    def __init__(self, head, body, weight, ggml_type, output_dtype):
+        super().__init__(head, body)
+        self.weight = weight
+        self.ggml_type = ggml_type
+        self.output_dtype = output_dtype
+
+    @classmethod
+    def maybe_create(cls, ep: ExportedProgram, head: Node):
+        if not match_target(head, torch.ops.aten.linear.default):
+            return None
+        if len(head.args) < 2 or not isinstance(head.args[1], Node):
+            return None
+        dequant = head.args[1]
+        if not has_single_user(dequant):
+            return None
+        parsed = parse_dequantize_gguf_node(dequant)
+        if parsed is None:
+            return None
+        weight, ggml_type, output_dtype = parsed
+        if ggml_type not in _LINEAR_TYPES:
+            return None
+        return cls(head, [dequant], weight, ggml_type, output_dtype)
+
+    def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot:
+        assert n == self.head
+        x_node = n.args[0]
+        bias_node = n.args[2] if len(n.args) > 2 else None
+        if self.ggml_type == "q6_k":
+            from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.linear import (
+                emit_linear,
+            )
+        else:  # q4_k
+            from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear import (
+                emit_linear,
+            )
+        return emit_linear(P, n, x_node, self.weight, bias_node)
+
+
+@REGISTRY.register_pattern(name="GGUF_QUANTIZED_EMBEDDING")
+class GGUFQuantizedEmbeddingHandler(PatternHandler):
+    """Lower ``dequantize_gguf + embedding`` to a quantized gather.
+
+    Matches ``embedding(dequantize_gguf(weight, ggml_type, out_dtype), indices)``
+    and dispatches on ``ggml_type``: Q6_K -> custom Metal gather, Q4_K -> MLX
+    quantized gather.
+    """
+
+    def __init__(self, head, body, weight, ggml_type, output_dtype):
+        super().__init__(head, body)
+        self.weight = weight
+        self.ggml_type = ggml_type
+        self.output_dtype = output_dtype
+
+    @classmethod
+    def maybe_create(cls, ep: ExportedProgram, head: Node):
+        if not match_target(head, torch.ops.aten.embedding.default):
+            return None
+        if len(head.args) < 2 or not isinstance(head.args[0], Node):
+            return None
+        dequant = head.args[0]
+        if not has_single_user(dequant):
+            return None
+        parsed = parse_dequantize_gguf_node(dequant)
+        if parsed is None:
+            return None
+        weight, ggml_type, output_dtype = parsed
+        if ggml_type not in _EMBEDDING_TYPES:
+            return None
+        return cls(head, [dequant], weight, ggml_type, output_dtype)
+
+    def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot:
+        assert n == self.head
+        indices_node = n.args[1]
+        if self.ggml_type == "q6_k":
+            from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.embedding import (
+                emit_embedding,
+            )
+        else:  # q4_k
+            from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding import (
+                emit_embedding,
+            )
+        return emit_embedding(P, n, self.weight, indices_node, self.output_dtype)
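
The matching logic in `maybe_create` can be exercised on a toy graph to see which subgraphs are accepted. This sketch assumes a simplified `Node` with string targets and an explicit `users` list; it is not `torch.fx.Node`, and `match_gguf_linear` is a hypothetical function that mirrors only the order of checks used by the linear handler above.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass(eq=False)
class Node:
    """Toy graph node: records its users when it is passed as an argument."""

    target: str
    args: tuple = ()
    users: list = field(default_factory=list)

    def __post_init__(self):
        for a in self.args:
            if isinstance(a, Node):
                a.users.append(self)


SUPPORTED_TYPES = {"q4_k", "q6_k"}


def match_gguf_linear(head: Node) -> Optional[tuple]:
    """Return (weight, ggml_type) if head is linear(x, dequantize_gguf(w, t))."""
    if head.target != "linear" or len(head.args) < 2:
        return None
    dequant = head.args[1]
    if not isinstance(dequant, Node) or dequant.target != "dequantize_gguf":
        return None
    if len(dequant.users) != 1:
        # The dequantized weight feeds something else too, so it cannot be
        # folded into this one linear.
        return None
    weight, ggml_type = dequant.args[0], dequant.args[1]
    if ggml_type not in SUPPORTED_TYPES:
        return None
    return weight, ggml_type
```

The single-user check is what keeps the fused lowering safe: if another op also consumed the dequantized weight, removing the dequantize node would leave that op without an input.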
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+#
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+"""GGUF Q4_K format lowering for the MLX backend (native affine 4-bit).
+
+See :mod:`.linear` / :mod:`.embedding` for the ``emit_*`` lowerings (called by
+``custom_kernel_ops.gguf.patterns``); they are not imported here to keep the
+package import light.
+"""
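
Why Q4_K can ride on an affine 4-bit path at all can be seen from the arithmetic. In llama.cpp's reference dequantization, each Q4_K sub-block reconstructs `w = d * sc * q - dmin * m` for a 4-bit `q`, which has the same shape as an affine `w = scale * q + bias`. The sketch below shows only that algebraic mapping; `q4k_subblock_to_affine` is a hypothetical name, and the actual repacking in `.q4k` (bit packing, group size, dtype handling) is not shown in this diff and is not reproduced here.

```python
def q4k_subblock_to_affine(d: float, dmin: float, sc: int, m: int) -> tuple:
    """Map Q4_K sub-block parameters to an affine (scale, bias) pair.

    Q4_K:   w = d * sc * q - dmin * m     (q is a 4-bit value in [0, 15])
    affine: w = scale * q + bias
    """
    scale = d * sc
    bias = -(dmin * m)
    return scale, bias


def q4k_reference(d: float, dmin: float, sc: int, m: int, q: int) -> float:
    """Q4_K reconstruction of a single weight, written out directly."""
    return d * sc * q - dmin * m
```

For example, with `d=0.5, dmin=0.25, sc=3, m=7` the pair is `(1.5, -1.75)`, and `1.5 * q - 1.75` agrees with the Q4_K formula for every `q` from 0 to 15.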
