Commit a789bcd

Add GGUF import, dequantize_weight, pack_one, and test reorganization

- quant/gguf.py: unpack Q4_K/Q6_K GGUF blocks to CanonicalQuantizedWeight, with iter_gguf_tensors for streaming (low peak memory). Validated against original bf16 weights (Q4_K: 7.9%, Q6_K: 1.9% error).
- gguf_loader.py: Gemma 4 31B GGUF key mapping + load_gguf_model. Handles tied embed/lm_head: embedding dequantized to bf16 (gather), lm_head keeps Q4_K (tinygemm matmul).
- export.py and inference.py: --gguf flag for direct GGUF file loading.
- quant/quantize.py: dequantize_weight (inverse of quantize_weight).
- quant/pack.py: pack_one for single-weight streaming; pack_model delegates to pack_one for unquantized, groups quantized by parent for multi-weight modules (MoE-compatible).
- quant/serialize.py: CanonicalQuantizedWeight.__post_init__ validation (dtype, shape, symmetric/zero consistency).
- Tests moved to tests/ folders (quant/tests/ and tests/).

1 parent 9108a5b commit a789bcd
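
For context on the Q4_K and Q6_K block formats named in the commit message, the sketch below works out their storage cost per weight. It assumes the super-block layouts defined by ggml/llama.cpp (256 weights per block); the field tables and function names are illustrative and do not come from this commit.

```python
# Assumed ggml super-block layouts (sizes in bytes); not taken from this commit.
QK_K = 256  # weights per super-block

Q4_K_FIELDS = {
    "d": 2,            # fp16 super-block scale
    "dmin": 2,         # fp16 super-block min
    "scales": 12,      # packed 6-bit sub-block scales and mins
    "qs": QK_K // 2,   # 4-bit quants, two per byte
}
Q6_K_FIELDS = {
    "ql": QK_K // 2,       # low 4 bits of each quant
    "qh": QK_K // 4,       # high 2 bits of each quant
    "scales": QK_K // 16,  # int8 sub-block scales
    "d": 2,                # fp16 super-block scale
}


def block_bytes(fields: dict) -> int:
    """Total bytes occupied by one super-block."""
    return sum(fields.values())


def bits_per_weight(fields: dict) -> float:
    """Average storage cost per weight, including scale metadata."""
    return block_bytes(fields) * 8 / QK_K


if __name__ == "__main__":
    for name, fields in (("Q4_K", Q4_K_FIELDS), ("Q6_K", Q6_K_FIELDS)):
        print(name, block_bytes(fields), "bytes/block,", bits_per_weight(fields), "bits/weight")
```

The metadata overhead is why a "4-bit" GGUF tensor costs slightly more than 4 bits per weight on disk.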

19 files changed

Lines changed: 917 additions & 85 deletions

.github/workflows/cuda.yml

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ jobs:
         python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts="

         # Run Gemma 4 31B tests (quant unit tests + pipeline integration tests)
-        python -m pytest examples/models/gemma4_31b/quant/ examples/models/gemma4_31b/test_pipeline.py examples/models/gemma4_31b/test_cuda_pipeline.py -v -o "addopts="
+        python -m pytest examples/models/gemma4_31b/quant/tests/ examples/models/gemma4_31b/tests/ -v -o "addopts="

   export-model-cuda-artifact:
     name: export-model-cuda-artifact

examples/models/gemma4_31b/README.md

Lines changed: 10 additions & 0 deletions
@@ -16,6 +16,7 @@ both export and eager inference:
 | `quantize_and_save.py` | bf16 HF checkpoint → quantized checkpoint (one-time) | ~30 GB CPU |
 | `export.py --prequantized <dir>` | quantized checkpoint → `model.pte` + `model.ptd` | ~24 GB CPU + CUDA for packing |
 | `inference.py --prequantized <dir>` | quantized checkpoint → eager generation under `torch.compile` | ~24 GB GPU |
+| `inference.py --gguf <file>` | GGUF file (Q4_K_M, etc.) → eager generation | ~24 GB GPU |
 | `export.py --model-dir <hf>` | one-shot bf16 → quantize → export (no intermediate file) | ~30 GB CPU + CUDA for packing |

 The quantized checkpoint is a safetensors file with int values + per-group
@@ -85,6 +86,15 @@ python examples/models/gemma4_31b/inference.py \
     --temperature 0.8
 ```

+GGUF files from the community (e.g., Q4_K_M) can also be used directly:
+
+```bash
+python examples/models/gemma4_31b/inference.py \
+    --gguf ./gemma-4-31B-it-Q4_K_M.gguf \
+    --tokenizer-path /path/to/tokenizer.json \
+    --prompt "Hello"
+```
+
 Useful before spending the export+lowering time to confirm the quantized
 model produces sensible text.

examples/models/gemma4_31b/export.py

Lines changed: 13 additions & 1 deletion
@@ -10,9 +10,10 @@
 - "decode": T=1, static shape, returns the next sampled token.
 - "prefill": T>=2, dynamic shape, returns the next sampled token.

-Two input paths:
+Three input paths:
     --prequantized <dir>  Load a quantized checkpoint (from quantize_and_save.py)
                           and pack for the target backend. No re-quantization.
+    --gguf <file>         Load a GGUF file (e.g., Q4_K_M from the community).
     --model-dir <hf>      Load bf16 checkpoint, quantize, pack, and export
                           in one shot.

@@ -251,6 +252,11 @@ def main() -> None:
         default=None,
         help="Path to a quantized checkpoint directory. Skips quantization.",
     )
+    src.add_argument(
+        "--gguf",
+        default=None,
+        help="Path to a GGUF file (e.g., gemma-4-31B-it-Q4_K_M.gguf).",
+    )
     parser.add_argument(
         "--output-dir",
         default="./gemma4_31b_exports",
@@ -285,6 +291,12 @@ def main() -> None:
             max_seq_len=args.max_seq_len,
             backend=args.backend,
         )
+    elif args.gguf:
+        from executorch.examples.models.gemma4_31b.gguf_loader import load_gguf_model
+
+        model, config = load_gguf_model(
+            args.gguf, max_seq_len=args.max_seq_len, backend=args.backend
+        )
     else:
         model, config = load_and_quantize(
             args.model_dir,

examples/models/gemma4_31b/gguf_loader.py

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Load a GGUF file into a Gemma 4 31B model.
+
+Streams tensors one at a time via ``iter_gguf_tensors`` for low peak
+memory, remaps GGUF names to model FQNs, handles tied embed/lm_head,
+and packs for the target backend.
+
+Usage:
+    model, config = load_gguf_model("model.gguf", backend="cuda")
+"""
+
+from typing import Optional
+
+import torch
+
+# GGUF pattern → model FQN pattern. ``{}`` is the layer index.
+_KEY_MAP = {
+    "token_embd.weight": "embed_tokens.weight",
+    "output_norm.weight": "norm.weight",
+    # Per-layer attention
+    "blk.{}.attn_q.weight": "layers.{}.self_attn.q_proj.weight",
+    "blk.{}.attn_k.weight": "layers.{}.self_attn.k_proj.weight",
+    "blk.{}.attn_v.weight": "layers.{}.self_attn.v_proj.weight",
+    "blk.{}.attn_output.weight": "layers.{}.self_attn.o_proj.weight",
+    "blk.{}.attn_q_norm.weight": "layers.{}.self_attn.q_norm.weight",
+    "blk.{}.attn_k_norm.weight": "layers.{}.self_attn.k_norm.weight",
+    # Per-layer norms
+    "blk.{}.attn_norm.weight": "layers.{}.input_layernorm.weight",
+    "blk.{}.post_attention_norm.weight": "layers.{}.post_attention_layernorm.weight",
+    "blk.{}.ffn_norm.weight": "layers.{}.pre_feedforward_layernorm.weight",
+    "blk.{}.post_ffw_norm.weight": "layers.{}.post_feedforward_layernorm.weight",
+    # Per-layer MLP
+    "blk.{}.ffn_gate.weight": "layers.{}.mlp.gate_proj.weight",
+    "blk.{}.ffn_up.weight": "layers.{}.mlp.up_proj.weight",
+    "blk.{}.ffn_down.weight": "layers.{}.mlp.down_proj.weight",
+    # Per-layer scalar
+    "blk.{}.layer_output_scale.weight": "layers.{}.layer_scalar",
+}
+
+_IGNORED_KEYS = {"rope_freqs.weight"}
+
+
+def gguf_to_model_key(gguf_key: str) -> Optional[str]:
+    """Map a GGUF tensor name to a model FQN, or ``None`` to skip."""
+    if gguf_key in _IGNORED_KEYS:
+        return None
+
+    for gguf_pat, model_pat in _KEY_MAP.items():
+        if "{}" not in gguf_pat:
+            if gguf_key == gguf_pat:
+                return model_pat
+            continue
+        prefix, suffix = gguf_pat.split("{}")
+        if gguf_key.startswith(prefix) and gguf_key.endswith(suffix):
+            layer_str = gguf_key[len(prefix) : len(gguf_key) - len(suffix)]
+            if layer_str.isdigit():
+                return model_pat.replace("{}", layer_str)
+
+    return None
+
+
+def _resolve_tied_lm_head(model, embed_cw, packers):
+    """Handle tied embed/lm_head after streaming all tensors."""
+    from executorch.examples.models.gemma4_31b.quant import pack_one
+
+    lm_head = getattr(model.lm_head, "weight", None)
+    if lm_head is None or lm_head.device.type != "meta":
+        return
+    if embed_cw is not None:
+        pack_one(model, "lm_head.weight", embed_cw, packers)
+    else:
+        pack_one(
+            model,
+            "lm_head.weight",
+            model.embed_tokens.weight.data.clone(),
+            packers,
+        )
+
+
+def _validate_no_meta(model):
+    """Ensure all parameters have been loaded."""
+    for fqn, p in model.named_parameters():
+        if p.device.type == "meta":
+            raise RuntimeError(
+                f"Weight '{fqn}' not found in GGUF file "
+                f"(model/checkpoint version mismatch?)"
+            )
+    for p in model.parameters():
+        p.requires_grad_(False)
+
+
+def load_gguf_model(
+    gguf_path: str,
+    max_seq_len: int = 4096,
+    backend: str = "cuda",
+) -> tuple:
+    """Load a GGUF file, remap keys, and pack for the target backend.
+
+    Streams tensors one at a time for low peak memory.
+
+    GGUF ties ``embed_tokens`` and ``lm_head`` into a single Q4_K tensor.
+    We untie them: the embedding is dequantized to bf16 (``nn.Embedding``
+    needs gather, which ``Int4TilePackedTo4dTensor`` does not support),
+    while ``lm_head`` keeps the original Q4_K quantization (``nn.Linear``
+    matmul via tinygemm).
+
+    Returns ``(model, config)``.
+    """
+    from executorch.examples.models.gemma4_31b.model import Gemma4_31B, Gemma4_31BConfig
+    from executorch.examples.models.gemma4_31b.quant import dequantize_weight, pack_one
+    from executorch.examples.models.gemma4_31b.quant.gguf import iter_gguf_tensors
+    from executorch.examples.models.gemma4_31b.quant.serialize import (
+        CanonicalQuantizedWeight,
+    )
+
+    if backend == "cuda":
+        from executorch.examples.models.gemma4_31b.quant import DEFAULT_CUDA_PACKERS
+
+        packers = DEFAULT_CUDA_PACKERS
+    else:
+        raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda'.")
+
+    config = Gemma4_31BConfig(max_seq_len=max_seq_len)
+
+    print("Building model on meta device...")
+    with torch.device("meta"):
+        model = Gemma4_31B(config)
+
+    embed_cw = None
+    n_processed = 0
+
+    print(f"Streaming GGUF from {gguf_path}...")
+    for gguf_name, result in iter_gguf_tensors(gguf_path):
+        model_key = gguf_to_model_key(gguf_name)
+        if model_key is None:
+            continue
+
+        if isinstance(result, torch.Tensor) and result.dtype == torch.float32:
+            result = result.to(torch.bfloat16)
+
+        if model_key == "embed_tokens.weight" and isinstance(
+            result, CanonicalQuantizedWeight
+        ):
+            embed_cw = result
+            result = dequantize_weight(result, torch.bfloat16)
+
+        pack_one(model, model_key, result, packers)
+
+        n_processed += 1
+        if n_processed % 100 == 0:
+            print(f"  Processed {n_processed} tensors...")
+
+    _resolve_tied_lm_head(model, embed_cw, packers)
+    del embed_cw
+
+    _validate_no_meta(model)
+    model.eval()
+
+    print(f"Model: {config.num_hidden_layers} layers, hidden={config.hidden_size}")
+    return model, config
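
The name remapping in `gguf_to_model_key` can be exercised without torch or a GGUF file. The sketch below copies the matching logic with a trimmed three-entry table; the table subset and the `map_key` name are illustrative, not part of the commit.

```python
from typing import Optional

# Trimmed, illustrative subset of the commit's _KEY_MAP; "{}" is the layer index.
KEY_MAP = {
    "token_embd.weight": "embed_tokens.weight",
    "blk.{}.attn_q.weight": "layers.{}.self_attn.q_proj.weight",
    "blk.{}.ffn_down.weight": "layers.{}.mlp.down_proj.weight",
}
IGNORED_KEYS = {"rope_freqs.weight"}


def map_key(gguf_key: str) -> Optional[str]:
    """Return the model FQN for a GGUF tensor name, or None to skip it."""
    if gguf_key in IGNORED_KEYS:
        return None
    for gguf_pat, model_pat in KEY_MAP.items():
        if "{}" not in gguf_pat:
            # Global tensors must match exactly.
            if gguf_key == gguf_pat:
                return model_pat
            continue
        # Per-layer tensors: whatever sits between prefix and suffix
        # must be a non-empty run of digits.
        prefix, suffix = gguf_pat.split("{}")
        if gguf_key.startswith(prefix) and gguf_key.endswith(suffix):
            layer = gguf_key[len(prefix) : len(gguf_key) - len(suffix)]
            if layer.isdigit():
                return model_pat.replace("{}", layer)
    return None


if __name__ == "__main__":
    for key in ("token_embd.weight", "blk.12.attn_q.weight", "rope_freqs.weight", "blk.x.attn_q.weight"):
        print(key, "->", map_key(key))
```

Unknown names and explicitly ignored names both return `None`, so an unmapped tensor is skipped silently during streaming; in the commit, the later `_validate_no_meta` check is what surfaces a weight that never got loaded.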

examples/models/gemma4_31b/inference.py

Lines changed: 51 additions & 18 deletions
@@ -4,20 +4,26 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-"""Eager inference on a prequantized Gemma 4 31B-IT model (CUDA + torch.compile).
+"""Eager inference on Gemma 4 31B-IT (CUDA + torch.compile).

-Loads a quantized checkpoint (from ``quantize_and_save.py``), packs for CUDA,
-materializes runtime buffers, optionally compiles with ``torch.compile``, and
-generates text autoregressively. The model performs Gumbel-max sampling
-on-device, so each forward returns the next token ID as a float tensor of
-shape ``[B, 1]``.
+Two input paths:
+    --prequantized <dir>  Load a quantized checkpoint (from quantize_and_save.py).
+    --gguf <file>         Load a GGUF file (e.g., Q4_K_M from the community).
+
+Packs for the target backend (--backend cuda), materializes runtime buffers,
+optionally compiles with ``torch.compile``, and generates text autoregressively.

 Usage:
     python inference.py \\
         --prequantized ./gemma4_31b_int4 \\
         --prompt "Write a short joke about saving RAM." \\
         --max-new-tokens 128 \\
         --temperature 0.8
+
+    python inference.py \\
+        --gguf ./gemma-4-31B-it-Q4_K_M.gguf \\
+        --tokenizer-path ./tokenizer.json \\
+        --prompt "Hello"
 """

 import argparse
@@ -113,14 +119,23 @@ def generate(


 def main() -> None:
-    parser = argparse.ArgumentParser(
-        description="Eager inference on prequantized Gemma 4 31B-IT (CUDA)."
-    )
-    parser.add_argument(
+    parser = argparse.ArgumentParser(description="Eager inference on Gemma 4 31B-IT.")
+    src = parser.add_mutually_exclusive_group(required=True)
+    src.add_argument(
         "--prequantized",
-        required=True,
+        default=None,
         help="Path to a quantized checkpoint directory.",
     )
+    src.add_argument(
+        "--gguf",
+        default=None,
+        help="Path to a GGUF file (e.g., gemma-4-31B-it-Q4_K_M.gguf).",
+    )
+    parser.add_argument(
+        "--tokenizer-path",
+        default=None,
+        help="Path to tokenizer.json (required with --gguf, optional with --prequantized).",
+    )
     parser.add_argument("--prompt", default="Hello", help="Input prompt.")
     parser.add_argument(
         "--max-new-tokens",
@@ -145,23 +160,41 @@ def main() -> None:
         action="store_true",
         help="Skip torch.compile (slower, but easier to debug).",
     )
+    parser.add_argument(
+        "--backend",
+        default="cuda",
+        choices=["cuda"],
+        help="Target backend.",
+    )
     args = parser.parse_args()

-    if not torch.cuda.is_available():
-        parser.error("CUDA is required for inference.")
+    if args.backend == "cuda" and not torch.cuda.is_available():
+        parser.error("CUDA is required for the cuda backend.")

-    print(f"Loading prequantized model from {args.prequantized}...")
-    model, config = load_prequantized_model(
-        args.prequantized, max_seq_len=args.max_seq_len
-    )
+    if args.gguf:
+        from executorch.examples.models.gemma4_31b.gguf_loader import load_gguf_model
+
+        model, config = load_gguf_model(
+            args.gguf, args.max_seq_len, backend=args.backend
+        )
+    else:
+        print(f"Loading prequantized model from {args.prequantized}...")
+        model, config = load_prequantized_model(
+            args.prequantized, max_seq_len=args.max_seq_len, backend=args.backend
+        )
     _move_to_cuda(model, config)
     model.eval()

     if not args.no_compile:
         print("Compiling model with torch.compile...")
         model = torch.compile(model, mode="default")

-    tokenizer_path = os.path.join(args.prequantized, "tokenizer.json")
+    if args.tokenizer_path:
+        tokenizer_path = args.tokenizer_path
+    elif args.prequantized:
+        tokenizer_path = os.path.join(args.prequantized, "tokenizer.json")
+    else:
+        parser.error("--tokenizer-path is required with --gguf.")
     from tokenizers import Tokenizer

     tokenizer = Tokenizer.from_file(tokenizer_path)
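
The argument handling in this hunk combines a required mutually exclusive source group with a tokenizer-path fallback. A minimal standalone sketch of that pattern follows; `build_parser` and `resolve_tokenizer_path` are hypothetical helper names, and the sketch raises `ValueError` where the script calls `parser.error`. Since the diff performs this check only after the model has been loaded and compiled, a helper like this could also be called right after `parse_args()` to fail fast, which is a possible refinement rather than what the commit does.

```python
import argparse
import os


def build_parser() -> argparse.ArgumentParser:
    """Exactly one of --prequantized / --gguf must be supplied."""
    parser = argparse.ArgumentParser(description="Input-source sketch.")
    src = parser.add_mutually_exclusive_group(required=True)
    src.add_argument("--prequantized", default=None)
    src.add_argument("--gguf", default=None)
    parser.add_argument("--tokenizer-path", default=None)
    return parser


def resolve_tokenizer_path(args: argparse.Namespace) -> str:
    """Explicit path wins; otherwise fall back to the checkpoint directory."""
    if args.tokenizer_path:
        return args.tokenizer_path
    if args.prequantized:
        return os.path.join(args.prequantized, "tokenizer.json")
    # A GGUF file is a single file, so there is no directory to fall back to.
    raise ValueError("--tokenizer-path is required with --gguf.")


if __name__ == "__main__":
    demo = build_parser().parse_args(["--prequantized", "./gemma4_31b_int4"])
    print(resolve_tokenizer_path(demo))
```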

examples/models/gemma4_31b/model.md

Lines changed: 11 additions & 12 deletions
@@ -121,20 +121,19 @@ identical logits to sequential one-token-at-a-time prefill.

 ## Quantization

-Three modules in `quant/`:
+Modules in `quant/`:

-- **Recipe** (`recipe.py`): `QuantConfig` (bits, group_size, symmetric,
-  method) + `QuantRule` (regex pattern, config, optional layer filter) +
-  `QuantRecipe` (ordered rules, first match wins). Declares what to
-  quantize and how — says nothing about packing or backends.
+- **Recipe** (`recipe.py`): `QuantConfig` + `QuantRule` + `QuantRecipe`.
+  Declares what to quantize — says nothing about packing or backends.
+- **Quantize** (`quantize.py`): `quantize_weight` / `dequantize_weight` /
+  `quantize_model`. Produces `CanonicalQuantizedWeight` from fp weights.
 - **Serialize** (`serialize.py`): `CanonicalQuantizedWeight` (int8 qdata +
-  bf16 scale + optional zero). `save()` / `load()` persist to safetensors
-  with a JSON header per weight. Packing-agnostic — any backend can read
-  the file.
-- **Packer** (`pack_cuda.py`): converts `CanonicalQuantizedWeight` to
-  backend runtime format at load time via `pack_model()`. Dispatches per
-  parent module type (`nn.Linear` → `Int4TilePackedTo4dTensor` for
-  tinygemm). Extensible via a packers dict.
+  bf16 scale + optional zero). `save()` / `load()` persist to safetensors.
+- **Pack** (`pack.py` + `pack_cuda.py`): `pack_model` groups weights by
+  parent module, `pack_one` handles single weights. Per-module packers
+  dispatch by module type (`nn.Linear`, `nn.Embedding`, extensible for MoE).
+- **GGUF** (`gguf.py`): `unpack_gguf_tensor` / `iter_gguf_tensors` for
+  loading community-quantized GGUF files (Q4_K, Q6_K).
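
To make the `quantize_weight` / `dequantize_weight` pairing concrete, here is a minimal pure-Python sketch of symmetric per-group quantization and its inverse. It is not the code in `quant/quantize.py`: the function names, the 4-bit width, the group size, and the rounding rule are all assumptions chosen for illustration, and the relative-error metric is one plausible choice (the commit does not say how its error percentages were computed).

```python
import math


def quantize_groups(weights, group_size=4, bits=4):
    """Symmetric per-group quantization: one scale per group, ints in [-2^(b-1), 2^(b-1)-1]."""
    qmax = 2 ** (bits - 1) - 1
    qdata, scales = [], []
    for start in range(0, len(weights), group_size):
        group = weights[start : start + group_size]
        amax = max(abs(v) for v in group)
        scale = amax / qmax if amax > 0 else 1.0  # avoid dividing by zero
        scales.append(scale)
        qdata.extend(max(-qmax - 1, min(qmax, round(v / scale))) for v in group)
    return qdata, scales


def dequantize_groups(qdata, scales, group_size=4):
    """Inverse mapping: multiply each int by the scale of its group."""
    return [q * scales[i // group_size] for i, q in enumerate(qdata)]


def relative_error(original, restored):
    """L2 norm of the difference divided by L2 norm of the original (assumed metric)."""
    num = math.sqrt(sum((a - b) ** 2 for a, b in zip(original, restored)))
    den = math.sqrt(sum(a * a for a in original))
    return num / den


if __name__ == "__main__":
    w = [0.7, -0.35, 0.1, 0.0, 1.4, -1.4, 0.2, 0.6]
    q, s = quantize_groups(w)
    print(q, s, relative_error(w, dequantize_groups(q, s)))
```

The round trip is lossy: each restored value can differ from the original by up to half of its group's scale, which is the kind of error the commit message reports for Q4_K versus Q6_K.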

 The quantize-once flow:
