Commit 4efa007

Add Gemma 4 text-decoder export to CoreML
The Gemma 4 text decoder shipped with examples/models/gemma4 already implements hybrid sliding/full attention, partial RoPE, per-layer head_dim, MQA, and YOCO KV sharing in plain PyTorch. That implementation lowers cleanly through torch.export and CoreMLPartitioner: every call_function node in the resulting edge program is either an executorch_call_delegate or a getitem.

This script wires up the small amount of glue needed for an on-device-friendly default:

* compile_specs targeting iOS18+ so the YOCO KV caches can be taken over as stateful tensors.
* fp16 by default (the ANE requires fp16).
* compute_unit=CPU_AND_NE so the runtime is free to keep ops on the ANE.
* Optional --random_weights mode for smoke-testing the export without a HuggingFace checkpoint, plus --config_json / --sliding_window / --sliding_window_pattern overrides.

Audio and vision encoders are intentionally out of scope here; the existing ATen pipeline in examples/models/gemma4 is more appropriate for those.

### Test plan

`test.py` builds a 10-layer synthetic Gemma 4 (4 sliding + 1 full × 2) and runs the full export pipeline, asserting the resulting .pte exists.

    $ python -m pytest examples/apple/coreml/gemma4/test.py -v
    test.py::TestGemma4CoreMLExport::test_eager_forward_runs PASSED
    test.py::TestGemma4CoreMLExport::test_full_export_pipeline_lowers_to_coreml PASSED
    ============================== 2 passed in 15.32s ==============================

I also ran the export by hand against the synthetic config and confirmed the lowered edge program contains only `executorch_call_delegate` and `getitem` at the top level.

Authored with Claude.
1 parent 94d2881 commit 4efa007

4 files changed

Lines changed: 493 additions & 0 deletions

examples/apple/coreml/gemma4/BUCK

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target")
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

fbcode_target(_kind = runtime.python_binary,
    name = "export_gemma4_text_decoder_coreml",
    srcs = [
        "export_gemma4_text_decoder_coreml.py",
    ],
    main_module = "executorch.examples.apple.coreml.gemma4.export_gemma4_text_decoder_coreml",
    _is_external_target = True,
    base_module = "executorch.examples.apple.coreml.gemma4",
    visibility = ["PUBLIC"],
    deps = [
        "//caffe2:torch",
        "//executorch/backends/apple/coreml:backend",
        "//executorch/backends/apple/coreml:partitioner",
        "//executorch/examples/models/gemma4:text_decoder",
        "//executorch/exir:lib",
        "//executorch/extension/export_util:export_util",
    ],
)

examples/apple/coreml/gemma4/export_gemma4_text_decoder_coreml.py

Lines changed: 310 additions & 0 deletions
@@ -0,0 +1,310 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Export Gemma 4 text decoder to a CoreML-delegated ExecuTorch program.

Gemma 4's hybrid sliding/full attention is structurally compatible with
CoreML's MLProgram backend: the existing Gemma4TextModel implementation
in ``examples/models/gemma4/text_decoder/`` lowers cleanly through
``torch.export`` and ``CoreMLPartitioner``. This script wraps that
pipeline with the CoreML-specific defaults (iOS18+ for stateful KV
caches, fp16, MQA-friendly mutable-buffer handling) so users do not
have to reassemble it themselves.

Usage::

    # From a HuggingFace checkpoint directory:
    python export_gemma4_text_decoder_coreml.py \\
        --checkpoint_path /path/to/gemma4-e2b-it \\
        --output gemma4_text_decoder.pte

    # From a JSON config alone (random weights, smoke-test mode):
    python export_gemma4_text_decoder_coreml.py \\
        --config_json /path/to/config.json --random_weights \\
        --max_seq_len 1024 --output gemma4_synthetic.pte

The audio / vision encoders shipped with Gemma 4 are not part of this
export; for those, the existing ``examples/models/gemma4`` ATen pipeline
is more appropriate.
"""

import argparse
import logging
import os
from typing import Optional, Tuple

import coremltools as ct
import torch

import executorch.exir
from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.examples.models.gemma4.text_decoder.gemma4_config import Gemma4Config
from executorch.examples.models.gemma4.text_decoder.gemma4_transformer import (
    Gemma4TextModel,
)
from executorch.exir import EdgeCompileConfig
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.extension.export_util.utils import save_pte_program


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def _load_config(
    checkpoint_path: Optional[str],
    config_json: Optional[str],
    max_seq_len: int,
    sliding_window: Optional[int],
    sliding_window_pattern: Optional[int],
) -> Gemma4Config:
    """Build a Gemma4Config from a checkpoint dir, a JSON file, or defaults."""
    if checkpoint_path is not None:
        config = Gemma4Config.from_json(os.path.join(checkpoint_path, "config.json"))
    elif config_json is not None:
        config = Gemma4Config.from_json(config_json)
    else:
        config = Gemma4Config()

    config.max_seq_len = max_seq_len
    config.max_context_len = max_seq_len
    if sliding_window is not None:
        config.sliding_window = sliding_window
    if sliding_window_pattern is not None:
        config.sliding_window_pattern = sliding_window_pattern
    return config


def _load_weights(
    model: Gemma4TextModel,
    config: Gemma4Config,
    checkpoint_path: str,
    dtype: torch.dtype,
) -> None:
    """Load Gemma 4 text-decoder weights from a HuggingFace checkpoint dir.

    Reuses the same convert_weights flow that examples/models/gemma4 uses
    so the loaded model exactly matches what ``examples/models/gemma4``
    would produce on the ATen path.
    """
    from executorch.examples.models.gemma4.text_decoder.convert_weights import (
        convert_hf_to_custom,
    )

    state_dict = convert_hf_to_custom(checkpoint_path, config, dtype=dtype)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        logger.warning(
            "Missing %d keys when loading weights (first 5: %s)",
            len(missing),
            missing[:5],
        )
    if unexpected:
        logger.warning(
            "Unexpected %d keys (first 5: %s)", len(unexpected), unexpected[:5]
        )


def build_model(
    config: Gemma4Config,
    checkpoint_path: Optional[str],
    dtype: torch.dtype,
) -> Gemma4TextModel:
    """Build the text decoder in eval mode, optionally loading weights.

    With ``checkpoint_path=None`` the model keeps its randomly initialized
    weights (smoke-test mode). The model is cast to ``dtype`` either way.
    """
    model = Gemma4TextModel(config).eval()
    if checkpoint_path is not None:
        _load_weights(model, config, checkpoint_path, dtype)
    return model.to(dtype)


def _example_inputs(input_len: int) -> Tuple[torch.Tensor, ...]:
    """Inputs for prefill: a single batch with `input_len` placeholder tokens."""
    return (torch.zeros(1, input_len, dtype=torch.long),)


def export(
    model: Gemma4TextModel,
    input_len: int,
    minimum_deployment_target: ct.target,
    compute_precision: ct.precision,
    output_path: str,
) -> None:
    """Run the Gemma 4 text-decoder model through to_edge_transform_and_lower."""
    example_inputs = _example_inputs(input_len)

    logger.info("Eager smoke-test (input_len=%d)...", input_len)
    with torch.no_grad():
        model(*example_inputs)

    logger.info("torch.export...")
    ep = torch.export.export(model, example_inputs, strict=False)
    logger.info(
        " exported program: %d nodes",
        sum(1 for _ in ep.graph_module.graph.nodes),
    )

    compile_specs = CoreMLBackend.generate_compile_specs(
        minimum_deployment_target=minimum_deployment_target,
        compute_precision=compute_precision,
        compute_unit=ct.ComputeUnit.CPU_AND_NE,
        model_type=CoreMLBackend.MODEL_TYPE.MODEL,
    )
    partitioner = CoreMLPartitioner(
        compile_specs=compile_specs,
        # Gemma 4's text decoder owns its KV caches as torch buffers; let
        # CoreML take them over as stateful tensors. CoreML state requires
        # iOS18+, so leave the buffers alone for older deployment targets.
        take_over_mutable_buffer=minimum_deployment_target >= ct.target.iOS18,
    )

    logger.info("to_edge_transform_and_lower with CoreMLPartitioner...")
    edge = executorch.exir.to_edge_transform_and_lower(
        ep,
        partitioner=[partitioner],
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )

    # Collect every call_function target that is neither a delegate call nor
    # a getitem; an empty result means the graph was fully delegated.
    delegated_targets = ("executorch_call_delegate", "getitem")
    leftover = sorted(
        {
            node.target.__name__
            for node in edge.exported_program().graph.nodes
            if node.op == "call_function"
            and node.target.__name__ not in delegated_targets
        }
    )
    if not leftover:
        logger.info(" fully delegated: every call_function is a CoreML call.")
    else:
        logger.warning(
            " %d op type(s) fell back to portable: %s",
            len(leftover),
            leftover,
        )

    logger.info("to_executorch...")
    program = edge.to_executorch(
        ExecutorchBackendConfig(extract_delegate_segments=True)
    )
    save_pte_program(program, output_path)
    logger.info("Saved %s (%.2f MB)", output_path, os.path.getsize(output_path) / 1e6)


def main() -> int:
    logging.basicConfig(level=logging.INFO, format="%(message)s")

    parser = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default=None,
        help="Path to a HuggingFace Gemma 4 checkpoint directory.",
    )
    parser.add_argument(
        "--config_json",
        type=str,
        default=None,
        help="Path to a Gemma 4 config.json (used if --checkpoint_path is omitted).",
    )
    parser.add_argument(
        "--random_weights",
        action="store_true",
        help="Skip checkpoint loading; use random weights (smoke-test only).",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="gemma4_text_decoder.pte",
        help="Output .pte path.",
    )
    parser.add_argument("--max_seq_len", type=int, default=2048)
    parser.add_argument(
        "--input_len",
        type=int,
        default=64,
        help="Prefill sequence length used to build example inputs for export.",
    )
    parser.add_argument(
        "--sliding_window",
        type=int,
        default=None,
        help="Override the model's sliding window (default: from config).",
    )
    parser.add_argument(
        "--sliding_window_pattern",
        type=int,
        default=None,
        help="Override the sliding/full attention pattern (default: from config).",
    )
    parser.add_argument("--dtype", choices=["fp16", "fp32"], default="fp16")
    parser.add_argument(
        "--minimum_deployment_target",
        type=str,
        default="iOS18",
        choices=["iOS17", "iOS18", "iOS26"],
        help="Minimum CoreML deployment target. Stateful KV caches require iOS18+.",
    )
    args = parser.parse_args()

    # --random_weights may be combined with --config_json (synthetic export),
    # but not with --checkpoint_path: the checkpoint's config would be loaded
    # while its weights were ignored, which would be confusing.
    if args.random_weights and args.checkpoint_path:
        parser.error("--random_weights conflicts with --checkpoint_path")
    if not args.random_weights and not args.checkpoint_path:
        parser.error("either --checkpoint_path or --random_weights is required")

    config = _load_config(
        checkpoint_path=args.checkpoint_path if not args.random_weights else None,
        config_json=args.config_json,
        max_seq_len=args.max_seq_len,
        sliding_window=args.sliding_window,
        sliding_window_pattern=args.sliding_window_pattern,
    )

    dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype]
    target = {
        "iOS17": ct.target.iOS17,
        "iOS18": ct.target.iOS18,
        "iOS26": ct.target.iOS26,
    }[args.minimum_deployment_target]
    precision = {
        torch.float16: ct.precision.FLOAT16,
        torch.float32: ct.precision.FLOAT32,
    }[dtype]

    logger.info("Gemma 4 text decoder export -> CoreML")
    logger.info(" dtype=%s target=%s", args.dtype, args.minimum_deployment_target)
    logger.info(
        " layers=%d hidden=%d kv_heads=%d sliding_window=%d pattern=%d",
        config.num_hidden_layers,
        config.hidden_size,
        config.num_key_value_heads,
        config.sliding_window,
        config.sliding_window_pattern,
    )

    model = build_model(
        config,
        checkpoint_path=args.checkpoint_path if not args.random_weights else None,
        dtype=dtype,
    )

    export(
        model,
        input_len=args.input_len,
        minimum_deployment_target=target,
        compute_precision=precision,
        output_path=args.output,
    )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
# Gemma 4 text decoder on CoreML

This directory exports the Gemma 4 text decoder shipped with
`examples/models/gemma4` to a CoreML-delegated ExecuTorch program.

Gemma 4's hybrid sliding/full attention, partial RoPE, per-layer
head_dim, MQA, and YOCO KV sharing are all expressed in plain PyTorch
in the upstream `examples/models/gemma4/text_decoder/` package, and that
implementation lowers cleanly through `torch.export` and
`CoreMLPartitioner`: every top-level call in the lowered program is
either an `executorch_call_delegate` or a `getitem`. This script
assembles the small amount of glue (CoreML compile specs, iOS18+
deployment target for stateful KV caches, fp16 conversion) needed to run
that lowering with sensible defaults for on-device deployment.

The audio and vision encoders are intentionally **not** exported here;
the existing ATen pipeline in `examples/models/gemma4` is more
appropriate for those.
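
The "fully delegated" property described above can be checked mechanically. The following is a minimal sketch of that check, not the script's actual code path: `FakeNode` is a hypothetical stand-in for a graph node that carries only an `op` kind and a target name, so the example runs without ExecuTorch installed. The op name `aten.index_put.default` is used purely as an illustrative fallback op.

```python
from collections import namedtuple

# Hypothetical stand-in for a graph node: only the two fields the check reads.
FakeNode = namedtuple("FakeNode", ["op", "target_name"])

# Targets that are expected at the top level of a fully delegated program.
DELEGATED = ("executorch_call_delegate", "getitem")


def leftover_ops(nodes):
    """Return sorted names of call_function targets that were not delegated."""
    return sorted(
        {
            node.target_name
            for node in nodes
            if node.op == "call_function" and node.target_name not in DELEGATED
        }
    )


fully_delegated = [
    FakeNode("placeholder", "tokens"),
    FakeNode("call_function", "executorch_call_delegate"),
    FakeNode("call_function", "getitem"),
    FakeNode("output", "output"),
]
# Same graph plus one op that (hypothetically) fell back to the portable kernels.
partially_delegated = fully_delegated + [
    FakeNode("call_function", "aten.index_put.default")
]

print(leftover_ops(fully_delegated))
print(leftover_ops(partially_delegated))
```

An empty list means every call is a delegate call or a `getitem`; any other entry names an op type that stayed outside the CoreML partition.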

## Usage

### From a HuggingFace checkpoint

```
python export_gemma4_text_decoder_coreml.py \
    --checkpoint_path /path/to/gemma4-e2b-it \
    --output gemma4_text_decoder.pte
```

### Synthetic config (smoke test, no weights)

```
python export_gemma4_text_decoder_coreml.py \
    --random_weights \
    --max_seq_len 1024 \
    --output /tmp/gemma4_synthetic.pte
```

## Options

| Option | Default | Description |
|---|---|---|
| `--checkpoint_path` | (required if no `--random_weights`) | HuggingFace Gemma 4 checkpoint dir |
| `--config_json` | (off) | Gemma 4 `config.json` to use when `--checkpoint_path` is omitted (that is, together with `--random_weights`) |
| `--random_weights` | (off) | Skip weight loading; smoke-test only. Conflicts with `--checkpoint_path`. |
| `--output` | `gemma4_text_decoder.pte` | Output `.pte` path |
| `--max_seq_len` | 2048 | Maximum context length |
| `--input_len` | 64 | Prefill seqlen used for example inputs |
| `--sliding_window` | (from config) | Override sliding-attention window |
| `--sliding_window_pattern` | (from config) | Override hybrid pattern (P=5 for Gemma 4 E2B) |
| `--dtype` | `fp16` | `fp16` or `fp32`. ANE requires fp16. |
| `--minimum_deployment_target` | `iOS18` | iOS17 / iOS18 / iOS26. Stateful KV caches need iOS18+. |
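
The interaction between `--checkpoint_path`, `--config_json`, and `--random_weights` is easier to see in code. The sketch below reproduces only those three flags and the precedence used to pick a config source; `parse_flags` and `config_source` are hypothetical helper names introduced for illustration, not functions in the export script.

```python
import argparse


def parse_flags(argv):
    """Parse the three weight/config flags and apply the conflict rules."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", default=None)
    parser.add_argument("--config_json", default=None)
    parser.add_argument("--random_weights", action="store_true")
    args = parser.parse_args(argv)
    # A checkpoint supplies both config and weights, so it cannot be
    # combined with random weights.
    if args.random_weights and args.checkpoint_path:
        parser.error("--random_weights conflicts with --checkpoint_path")
    # Weights must come from somewhere: a checkpoint or random init.
    if not args.random_weights and not args.checkpoint_path:
        parser.error("either --checkpoint_path or --random_weights is required")
    return args


def config_source(args):
    """Name the config source: checkpoint first, then JSON, then defaults."""
    if args.checkpoint_path is not None:
        return "checkpoint"
    if args.config_json is not None:
        return "config_json"
    return "defaults"


print(config_source(parse_flags(["--random_weights"])))
print(config_source(parse_flags(["--random_weights", "--config_json", "c.json"])))
print(config_source(parse_flags(["--checkpoint_path", "/path/to/ckpt"])))
```

Passing `--config_json` on its own is rejected, because neither a checkpoint nor `--random_weights` was given; `argparse` reports the error and exits with status 2.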

## Tests

`test.py` builds a 10-layer synthetic Gemma 4 model (4 sliding + 1 full
× 2) and runs the full export pipeline, asserting that the resulting
`.pte` exists and is non-empty:

```
$ python -m pytest examples/apple/coreml/gemma4/test.py -v
============================== 2 passed in 15.32s ==============================
```
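
To make the "4 sliding + 1 full × 2" layout concrete, here is a small sketch of how a pattern value of 5 could map layer indices to attention types. It assumes that every `pattern`-th layer uses full attention and all others use sliding attention; that convention matches the layout described above, but the real `Gemma4Config` may derive layer types differently. `layer_types` is a hypothetical helper, not part of the exported model.

```python
def layer_types(num_layers: int, pattern: int) -> list:
    """Label each layer, assuming every `pattern`-th layer is full attention."""
    return [
        "full" if (index + 1) % pattern == 0 else "sliding"
        for index in range(num_layers)
    ]


# 10 layers with pattern 5: two groups of four sliding layers plus one full layer.
types = layer_types(num_layers=10, pattern=5)
for index, kind in enumerate(types):
    print(index, kind)
```

Under this assumption, layers 4 and 9 (zero-based) are the full-attention layers and the remaining eight use the sliding window.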