Skip to content

Commit 42581f1

Browse files
authored
Add GGUF → MLX export support for Gemma 4 31B (pytorch#19829)
Enable loading GGUF files (e.g. Q4_K_M) and exporting to the MLX backend. Three areas of change: GGUF loader (gguf_loader.py): - Add MLX backend support alongside CUDA - Keep embedding quantized for MLX (QuantizedEmbeddingHandler supports quantized gather natively, unlike CUDA's Int4Tensor) - Fix stale docstring references to Int4TilePackedTo4dTensor/tinygemm MLX backend (op_helpers.py, patterns.py): - Accept group_size=16 in parse_dequant_node for GGUF Q6_K tensors - For group_size < 32, emit DequantizeNode + TransposeNode + AddmmNode instead of QuantizedMatmulNode, since MLX Metal kernels are only instantiated for group_size >= 32. Weights stay packed as int8 in the .pte file and are dequantized on-device at runtime — same strategy CUDA/Inductor uses (separate Triton dequant + cuBLAS mm). Packer (pack_mlx.py): - Add 16 to supported group sizes so Q6_K IntxUnpackedToInt8Tensor passes through to export unchanged Tests (test_ops.py): - Add group_size=16 configs for int8, int4, and no-bias variants Test Plan: Export and run this model https://huggingface.co/unsloth/gemma-4-31B-it-GGUF/blob/main/gemma-4-31B-it-Q4_K_M.gguf On M1 32GB machine (exported on Linux A100) ``` (executorch_dev) mnachin@mnachin-mbp executorch % ./cmake-out/examples/models/gemma4_31b/gemma4_31b_runner \ --model_path /Users/mnachin/repos/models/gemma-4-31B-it-GGUF/model.pte \ --tokenizer_path /Users/mnachin/repos/models/gemma-4-31B-it-HQQ-INT4/tokenizer.json \ --prompt "Tell me a joke about RAM usage" \ --max_new_tokens 128 \ --temperature 0.8 I tokenizers:regex.cpp:27] Registering override fallback regex WARNING: All log messages before absl::InitializeLog() is called are written to STDERR E0000 00:00:1779926968.603672 54889180 re2.cc:237] Error parsing '((\<pad\>|ool\|\>1\x00\x00\ �\<t|respo|\<tool_call\|\>|\<bos\>|\<\|tool_response\>|\<\|think\|\>|\x0...': invalid UTF-8 I tokenizers:re2_regex.cpp:27] Re2 failed to compile regex: ((\<pad\>|ool\|\>1\x00\x00\ �\<t|respo|\<tool_call\|\>|\<bos\>|\<\|tool_response\>|\<\|think\|\>|\x00\x00\\\<|\<tool_response\|\>|\<mask\>|\<\|\"\|\>|all\|\>j\x00\x00\\|\<channel\|\>|\<\|turn\>|\<turn\|\>|\<\|image\>|\<\|$ I tokenizers:regex_lookahead.cpp:27] Creating PCRE2 regex I tokenizers:pcre2_regex.cpp:48] PCRE2 UTF-8 validation failed at offset 27: UTF-8 error: byte 2 top bits not 0x80. Retrying without UTF flags. Loading model... Prompt tokens: 23 Why did the computer go to therapy? Because it had too many **unresolved dependencies** and it just couldn't stop **dwelling on the past**... but it forgot everything the moment it took a nap.<turn|> PyTorchObserver {"prefill_token_per_sec":2.49539,"decode_token_per_sec":0.0880671,"prompt_tokens":23,"generated_tokens":44,"model_load_start_ms":1779926968052,"model_load_end_ms":1779926982494,"inference_start_ms":1779926982497,"inference_end_ms":1779927491333,"prompt_eval_end_ms":1779926991714,"first_token_ms":1779926991714,"aggregate_sampling_time_ms":0,"SCALING_FACTOR_UNITS_PER_SECOND":1000} ``` For reference, here's the this model: https://huggingface.co/SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4 ``` (executorch_dev) mnachin@mnachin-mbp executorch % ./cmake-out/examples/models/gemma4_31b/gemma4_31b_runner \ --model_path /Users/mnachin/repos/models/gemma-4-31B-it-HQQ-INT4/model.pte \ --tokenizer_path /Users/mnachin/repos/models/gemma-4-31B-it-HQQ-INT4/tokenizer.json \ --prompt "Tell me a joke about RAM usage" \ --max_new_tokens 128 \ --temperature 0.8 I tokenizers:regex.cpp:27] Registering override fallback regex WARNING: All log messages before absl::InitializeLog() is called are written to STDERR E0000 00:00:1779927592.109382 54914733 re2.cc:237] Error parsing '((\<pad\>|ool\|\>1\x00\x00\ �\<t|respo|\<tool_call\|\>|\<bos\>|\<\|tool_response\>|\<\|think\|\>|\x0...': invalid UTF-8 I tokenizers:re2_regex.cpp:27] Re2 failed to compile regex: ((\<pad\>|ool\|\>1\x00\x00\ �\<t|respo|\<tool_call\|\>|\<bos\>|\<\|tool_response\>|\<\|think\|\>|\x00\x00\\\<|\<tool_response\|\>|\<mask\>|\<\|\"\|\>|all\|\>j\x00\x00\\|\<channel\|\>|\<\|turn\>|\<turn\|\>|\<\|image\>|\<\|$ I tokenizers:regex_lookahead.cpp:27] Creating PCRE2 regex I tokenizers:pcre2_regex.cpp:48] PCRE2 UTF-8 validation failed at offset 27: UTF-8 error: byte 2 top bits not 0x80. Retrying without UTF flags. Loading model... Prompt tokens: 23 Why did the computer go to therapy? Because it had too many **unresolved dependencies** and couldn't stop **dwelling on the past**, but it still couldn't remember why it was there. *** Alternatively, a shorter one: **Why was the RAM so stressed?** Because it had too much on its mind, but it knew that as soon as it slept, it would forget everything.<turn|> PyTorchObserver {"prefill_token_per_sec":9.11975,"decode_token_per_sec":5.24998,"prompt_tokens":23,"generated_tokens":86,"model_load_start_ms":1779927591719,"model_load_end_ms":1779927603575,"inference_start_ms":1779927603579,"inference_end_ms":1779927622482,"prompt_eval_end_ms":1779927606101,"first_token_ms":1779927606101,"aggregate_sampling_time_ms":0,"SCALING_FACTOR_UNITS_PER_SECOND":1000} ``` There's definitely performance degradation when running GGUF
1 parent 000d810 commit 42581f1

11 files changed

Lines changed: 233 additions & 26 deletions

File tree

.github/workflows/mlx.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ jobs:
4747
4848
${CONDA_RUN} pip list
4949
50+
echo "::group::Install Python test requirements"
51+
${CONDA_RUN} pip install gguf
52+
echo "::endgroup::"
53+
5054
echo "::group::Build test runners"
5155
${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 ))
5256
echo "::endgroup::"

backends/mlx/builder/op_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def parse_dequant_node(
334334
if len(non_one) != 1:
335335
return None
336336
quantized_dim, group_size = non_one[0]
337-
if group_size not in [32, 64, 128]:
337+
if group_size not in [16, 32, 64, 128]:
338338
return None
339339

340340
# TODO: MLX supports 3, 5, and 7, but we need to figure out the

backends/mlx/patterns.py

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from __future__ import annotations
1717

18+
import os
1819
from typing import Any, List, Optional, Tuple
1920

2021
import torch
@@ -37,6 +38,7 @@
3738
)
3839
from executorch.backends.mlx.serialization.mlx_graph_schema import (
3940
AddIntNode,
41+
AddmmNode,
4042
AddNode,
4143
AsTypeNode,
4244
DequantizeNode,
@@ -52,6 +54,7 @@
5254
SubtractIntNode,
5355
SymSizeNode,
5456
TakeNode,
57+
TransposeNode,
5558
)
5659
from torch.export.exported_program import ExportedProgram
5760
from torch.fx.node import Node
@@ -883,6 +886,18 @@ def maybe_create(
883886
out_dtype=out_dtype,
884887
)
885888

889+
# MLX's quantized_matmul Metal kernels are only instantiated for
890+
# group_size in {32, 64, 128}. For smaller group sizes (e.g. GGUF
891+
# Q6_K with group_size=16), emit DequantizeNode + matmul instead.
892+
# Weights stay packed in the .pte file; dequantized on-device.
893+
# This non-fused path is significantly slower and must be opted in
894+
# via ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS=1.
895+
_MIN_FUSED_GROUP_SIZE = 32
896+
897+
@staticmethod
898+
def _allow_non_fused() -> bool:
899+
return os.environ.get("ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS", "0") == "1"
900+
886901
def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot:
887902
assert n == self.head
888903

@@ -908,19 +923,59 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot:
908923
x_dtype = x_node.meta["val"].dtype
909924
needs_cast = self.out_dtype != x_dtype
910925

911-
P.emit(
912-
QuantizedMatmulNode(
913-
x=P.slot_to_tid(x_slot),
914-
w=P.slot_to_tid(w),
915-
scales=P.slot_to_tid(scale_slot),
916-
out=P.slot_to_tid(out),
917-
biases=P.slot_to_tid(biases),
918-
group_size=self.group_size,
919-
bits=self.bits,
920-
mode="affine",
921-
transpose=True,
926+
if self.group_size >= self._MIN_FUSED_GROUP_SIZE:
927+
P.emit(
928+
QuantizedMatmulNode(
929+
x=P.slot_to_tid(x_slot),
930+
w=P.slot_to_tid(w),
931+
scales=P.slot_to_tid(scale_slot),
932+
out=P.slot_to_tid(out),
933+
biases=P.slot_to_tid(biases),
934+
group_size=self.group_size,
935+
bits=self.bits,
936+
mode="affine",
937+
transpose=True,
938+
)
922939
)
923-
)
940+
else:
941+
if not self._allow_non_fused():
942+
raise ValueError(
943+
f"Quantized linear with group_size={self.group_size} requires "
944+
f"the non-fused dequantize+matmul path, which is significantly "
945+
f"slower than the fused QuantizedMatmulNode (group_size >= 32). "
946+
f"Set ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS=1 to allow this."
947+
)
948+
out_scalar_type = torch_dtype_to_scalar_type(self.out_dtype)
949+
_, w_deq = P.make_tmp_slot()
950+
P.emit(
951+
DequantizeNode(
952+
w=P.slot_to_tid(w),
953+
scales=P.slot_to_tid(scale_slot),
954+
out=P.slot_to_tid(w_deq),
955+
biases=P.slot_to_tid(biases),
956+
group_size=self.group_size,
957+
bits=self.bits,
958+
mode="affine",
959+
dtype=out_scalar_type,
960+
)
961+
)
962+
_, w_t = P.make_tmp_slot()
963+
P.emit(
964+
TransposeNode(
965+
x=P.slot_to_tid(w_deq),
966+
out=P.slot_to_tid(w_t),
967+
perm=[1, 0],
968+
)
969+
)
970+
P.emit(
971+
AddmmNode(
972+
mat1=P.slot_to_tid(x_slot),
973+
mat2=P.slot_to_tid(w_t),
974+
out=P.slot_to_tid(out),
975+
)
976+
)
977+
# DequantizeNode already produces the correct dtype.
978+
needs_cast = False
924979

925980
if has_bias:
926981
P.emit(

backends/mlx/test/test_ops.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
See README.md in this directory for full documentation.
2525
"""
2626

27+
import os
2728
from typing import Callable, Dict, List, Optional, Tuple
2829

2930
import torch
@@ -5621,8 +5622,21 @@ def get_test_configs(cls) -> List["QuantizedLinearTest"]:
56215622
cls(group_size=128),
56225623
cls(qdtype=torch.int2),
56235624
cls(qdtype=torch.int8),
5625+
# group_size=16: exercises the non-fused dequantize+matmul path
5626+
# (requires ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS=1).
5627+
cls(qdtype=torch.int8, group_size=16),
5628+
cls(qdtype=torch.int4, group_size=16),
5629+
cls(qdtype=torch.int8, group_size=16, bias=False),
56245630
]
56255631

5632+
def generate_test_files(self, verbose=False):
5633+
if self.group_size < 32:
5634+
os.environ["ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS"] = "1"
5635+
try:
5636+
return super().generate_test_files(verbose=verbose)
5637+
finally:
5638+
os.environ.pop("ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS", None)
5639+
56265640
def create_model(self) -> nn.Module:
56275641
model = LinearModel(self.in_features, self.out_features, bias=self.bias)
56285642
model = model.to(self.dtype)

examples/models/gemma4_31b/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ both export and eager inference:
1515
|---|---|---|
1616
| `quantize_and_save.py` | bf16 HF checkpoint → quantized checkpoint (one-time) | ~30 GB CPU |
1717
| `export.py --prequantized <dir>` | quantized checkpoint → `model.pte` + `model.ptd` | ~24 GB CPU + CUDA for packing |
18+
| `export.py --gguf <file> [--backend mlx]` | GGUF file (Q4_K_M, etc.) → `model.pte` + `model.ptd` | ~24 GB CPU |
1819
| `inference.py --prequantized <dir>` | quantized checkpoint → eager generation under `torch.compile` | ~24 GB GPU |
1920
| `inference.py --gguf <file>` | GGUF file (Q4_K_M, etc.) → eager generation | ~24 GB GPU |
2021
| `export.py --model-dir <hf>` | one-shot bf16 → quantize → export (no intermediate file) | ~30 GB CPU + CUDA for packing |

examples/models/gemma4_31b/export.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,12 @@ def main() -> None:
443443
backend=args.backend,
444444
)
445445

446-
export_and_lower(model, config, args.output_dir, backend=args.backend)
446+
if args.gguf and args.backend == "mlx":
447+
os.environ["ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS"] = "1"
448+
try:
449+
export_and_lower(model, config, args.output_dir, backend=args.backend)
450+
finally:
451+
os.environ.pop("ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS", None)
447452

448453

449454
if __name__ == "__main__":

examples/models/gemma4_31b/gguf_loader.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
1313
Usage:
1414
model, config = load_gguf_model("model.gguf", backend="cuda")
15+
model, config = load_gguf_model("model.gguf", backend="mlx")
1516
"""
1617

1718
from typing import Optional
@@ -104,10 +105,11 @@ def load_gguf_model(
104105
Streams tensors one at a time for low peak memory.
105106
106107
GGUF ties ``embed_tokens`` and ``lm_head`` into a single Q4_K tensor.
107-
We untie them: the embedding is dequantized to bf16 (``nn.Embedding``
108-
needs gather, which ``Int4TilePackedTo4dTensor`` does not support),
109-
while ``lm_head`` keeps the original Q4_K quantization (``nn.Linear``
110-
matmul via tinygemm).
108+
We untie them so ``lm_head`` keeps the original Q4_K quantization.
109+
On CUDA, the embedding is dequantized to bf16 because ``Int4Tensor``
110+
does not support the gather op that ``nn.Embedding`` requires. On
111+
MLX, the embedding stays quantized — ``QuantizedEmbeddingHandler``
112+
handles quantized gather natively.
111113
112114
Returns ``(model, config)``.
113115
"""
@@ -120,8 +122,12 @@ def load_gguf_model(
120122
from executorch.examples.models.gemma4_31b.quant import DEFAULT_CUDA_PACKERS
121123

122124
packers = DEFAULT_CUDA_PACKERS
125+
elif backend == "mlx":
126+
from executorch.examples.models.gemma4_31b.quant import DEFAULT_MLX_PACKERS
127+
128+
packers = DEFAULT_MLX_PACKERS
123129
else:
124-
raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.")
130+
raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda', 'mlx'.")
125131

126132
config = Gemma4_31BConfig(max_seq_len=max_seq_len)
127133

@@ -143,7 +149,8 @@ def load_gguf_model(
143149

144150
if model_key == "embed_tokens.weight" and isinstance(result, Int4Tensor):
145151
embed_quant = result
146-
result = dequantize_weight(result, torch.bfloat16)
152+
if backend == "cuda":
153+
result = dequantize_weight(result, torch.bfloat16)
147154

148155
pack_one(model, model_key, result, packers)
149156

examples/models/gemma4_31b/quant/README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,3 @@ The format is compatible with torchao's `save_pretrained` / `load_pretrained`.
5050

5151
- `pack_metal.py` — Metal backend packer.
5252
- `gguf.py` — extend with Q5_K, Q8_0 GGUF quant types.
53-
- Upstream `Int4TilePackedTo4dTensor.from_int4_tensor()` to torchao
54-
to replace the manual conversion in `pack_int4_for_cuda`.

examples/models/gemma4_31b/quant/pack_mlx.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from .pack import ModulePackerFn, pack_model # noqa: F401
2424

25-
_MLX_SUPPORTED_GROUP_SIZES = (128, 64, 32)
25+
_MLX_SUPPORTED_GROUP_SIZES = (128, 64, 32, 16)
2626

2727

2828
# ---------------------------------------------------------------------------
@@ -126,7 +126,9 @@ def pack_for_mlx(module: nn.Module, weights: dict[str, torch.Tensor]) -> None:
126126
default dispatch produces the ``dequantize_affine → linear`` pattern
127127
MLX expects. Regroups to a compatible group_size when needed (e.g.
128128
per-axis group_size=5376 → group_size=128) since MLX's
129-
``parse_dequant_node`` only accepts group_size in {32, 64, 128}.
129+
``parse_dequant_node`` only accepts group_size in {16, 32, 64, 128}.
130+
Group sizes ≥ 32 use the fused ``QuantizedMatmulNode``; group_size=16
131+
(e.g. GGUF Q6_K) falls back to ``DequantizeNode`` + matmul at export.
130132
"""
131133
from torchao.quantization import IntxUnpackedToInt8Tensor
132134
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor

examples/models/gemma4_31b/quant/tests/test_pack_mlx.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def test_regroup_preserves_dequant(self):
146146

147147
class TestMlxGroupSize(unittest.TestCase):
148148
def test_passthrough(self):
149-
for gs in (32, 64, 128):
149+
for gs in (16, 32, 64, 128):
150150
self.assertEqual(_mlx_group_size(gs, 256), gs)
151151

152152
def test_regroup_5376(self):
@@ -157,7 +157,49 @@ def test_regroup_256(self):
157157

158158
def test_rejects_indivisible(self):
159159
with self.assertRaises(ValueError):
160-
_mlx_group_size(48, 48)
160+
_mlx_group_size(7, 7)
161+
162+
163+
class TestPackLinearGroupSize16(unittest.TestCase):
164+
"""Packing group_size=16 weights (GGUF Q6_K) preserves semantics."""
165+
166+
def _make_gs16_tensor(self, N=64, K=128):
167+
from torchao.quantization import IntxUnpackedToInt8Tensor
168+
169+
return IntxUnpackedToInt8Tensor(
170+
qdata=torch.randint(-32, 31, (N, K), dtype=torch.int8),
171+
scale=torch.randn(N, K // 16, dtype=torch.bfloat16),
172+
zero_point=torch.zeros(N, K // 16, dtype=torch.int8),
173+
target_dtype=torch.int8,
174+
block_size=(1, 16),
175+
dtype=torch.bfloat16,
176+
activation_quantization=None,
177+
)
178+
179+
def test_dequant_preserves_values(self):
180+
"""Packing preserves the dequantized weight values."""
181+
w = self._make_gs16_tensor(64, 128)
182+
before = dequantize_weight(w, torch.float32)
183+
184+
module = nn.Linear(128, 64, bias=False)
185+
pack_for_mlx(module, {"weight": w})
186+
after = dequantize_weight(module.weight.data, torch.float32)
187+
188+
self.assertTrue(
189+
torch.allclose(before, after, atol=1e-5),
190+
f"max diff: {(before - after).abs().max():.6g}",
191+
)
192+
193+
def test_forward_produces_valid_output(self):
194+
"""Packed gs=16 weight produces finite output in a linear forward."""
195+
w = self._make_gs16_tensor(64, 128)
196+
module = nn.Linear(128, 64, bias=False)
197+
pack_for_mlx(module, {"weight": w})
198+
199+
x = torch.randn(1, 128, dtype=torch.bfloat16)
200+
out = torch.nn.functional.linear(x, module.weight.data.dequantize())
201+
self.assertEqual(out.shape, torch.Size([1, 64]))
202+
self.assertFalse(torch.isnan(out).any())
161203

162204

163205
class TestPackEmbeddingForMlx(unittest.TestCase):

0 commit comments

Comments
 (0)