Commit 065b50e

Clean up Gemma 4 MLX path for macOS 15 CI
1 parent 719d2e8 commit 065b50e

4 files changed

Lines changed: 33 additions & 144 deletions

.github/workflows/mlx.yml

Lines changed: 9 additions & 8 deletions
@@ -489,17 +489,25 @@ jobs:
             name: "gemma3-1b"
         use-custom: [false, true]
         qconfig: ["4w", "nvfp4"]
+        runner: ["macos-14-xlarge"]
         include:
           - model:
               id: "google/gemma-4-E2B-it"
               name: "gemma4-e2b"
             use-custom: true
             qconfig: "4w"
+            runner: "macos-15-xlarge"
+          - model:
+              id: "google/gemma-4-E2B-it"
+              name: "gemma4-e2b"
+            use-custom: false
+            qconfig: "4w"
+            runner: "macos-15-xlarge"
     uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
     secrets: inherit
     with:
       job-name: test-mlx-llm-${{ matrix.model.name }}${{ matrix.use-custom && '-custom' || '' }}-${{ matrix.qconfig }}
-      runner: macos-14-xlarge
+      runner: ${{ matrix.runner }}
       python-version: "3.12"
       submodules: recursive
       ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
@@ -512,11 +520,6 @@ jobs:
           MODEL_NAME="${{ matrix.model.name }}"
           USE_CUSTOM="${{ matrix.use-custom }}"
           QCONFIG="${{ matrix.qconfig }}"
-          MODEL_REVISION=""
-          if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then
-            MODEL_REVISION="b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf"
-          fi
-
           CUSTOM_ARGS=""
           if [ "${USE_CUSTOM}" = "true" ]; then
            CUSTOM_ARGS="--use-custom-sdpa --use-custom-kv-cache"
@@ -551,7 +554,6 @@ jobs:
          echo "::group::Export ${MODEL_NAME}"
          ${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.export_llm_hf \
            --model-id "${MODEL_ID}" \
-           ${MODEL_REVISION:+--revision "${MODEL_REVISION}"} \
            --output /tmp/${MODEL_NAME}.pte \
            --qlinear ${QCONFIG} \
            ${QEMBEDDING_ARGS} \
@@ -562,7 +564,6 @@ jobs:
          OUTPUT=$(${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.run_llm_hf \
            --pte /tmp/${MODEL_NAME}.pte \
            --model-id "${MODEL_ID}" \
-           ${MODEL_REVISION:+--revision "${MODEL_REVISION}"} \
            --prompt "What is the capital of France?" \
            --max-new-tokens 50 2>&1)
          echo "$OUTPUT"
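For intuition, here is a minimal sketch in plain Python (not GitHub's actual matrix engine) of how the updated strategy expands: the base axes cross-multiply onto the shared macos-14-xlarge runner, while each `include` entry contributes a standalone combination pinned to macos-15-xlarge, and the job consumes whichever value lands in `matrix.runner`. Real `include` semantics also merge into matching base combinations, which this sketch simplifies away.

```python
# Sketch of the matrix expansion above; values are taken from the workflow diff.
from itertools import product

base = {
    "model": ["gemma3-1b"],
    "use-custom": [False, True],
    "qconfig": ["4w", "nvfp4"],
    "runner": ["macos-14-xlarge"],
}
include = [
    {"model": "gemma4-e2b", "use-custom": True, "qconfig": "4w", "runner": "macos-15-xlarge"},
    {"model": "gemma4-e2b", "use-custom": False, "qconfig": "4w", "runner": "macos-15-xlarge"},
]

# Cartesian product of the base axes, then the standalone include entries.
jobs = [dict(zip(base, values)) for values in product(*base.values())] + include
for job in jobs:
    print(f"{job['model']} qconfig={job['qconfig']} -> runs on {job['runner']}")
```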

backends/mlx/builder/program_builder.py

Lines changed: 16 additions & 40 deletions
@@ -444,50 +444,26 @@ def _make_io_slots(self): # noqa: C901
             else:
                 raise NotImplementedError(f"Support for input {arg} is not implemented")

-        placeholder_nodes = {
-            node.name: node for node in self.ep.graph.nodes if node.op == "placeholder"
-        }
-
-        # Allocate placeholder-backed slots in graph-signature order instead of
-        # raw FX node traversal order. This keeps lifted constant tids stable
-        # across equivalent exports, which matters for models like Gemma 4 that
-        # carry multiple rotary constant placeholders with similar structure.
-        for name in constant_tensors:
-            node = placeholder_nodes.get(name)
-            if node is None or node.users == {}:
-                continue
-            self.make_or_get_slot(node, id_space=IdSpace.Constant)
-
-        for name in user_inputs:
-            node = placeholder_nodes.get(name)
-            if node is None or node.users == {}:
-                continue
-            val = node.meta.get("val", None)
-            if isinstance(val, torch.Tensor) and not val.is_contiguous():
-                raise ValueError(
-                    f"MLX backend requires contiguous input tensors, "
-                    f"but input '{node.name}' has non-contiguous strides. "
-                    f"shape={list(val.shape)}, stride={list(val.stride())}. "
-                    f"Ensure example inputs passed to torch.export.export() "
-                    f"are contiguous (call .contiguous() on them)."
-                )
-            self.make_or_get_slot(node, id_space=IdSpace.Input)
-
-        for name in mutable_buffers:
-            node = placeholder_nodes.get(name)
-            if node is None or node.users == {}:
-                continue
-            self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer)
-
-        classified_placeholders = (
-            set(constant_tensors) | set(user_inputs) | set(mutable_buffers)
-        )
-
         for node in self.ep.graph.nodes:
             if node.op == "placeholder":
                 if node.users == {}:
                     continue
-                if node.name not in classified_placeholders:
+                if node.name in constant_tensors:
+                    self.make_or_get_slot(node, id_space=IdSpace.Constant)
+                elif node.name in user_inputs:
+                    val = node.meta.get("val", None)
+                    if isinstance(val, torch.Tensor) and not val.is_contiguous():
+                        raise ValueError(
+                            f"MLX backend requires contiguous input tensors, "
+                            f"but input '{node.name}' has non-contiguous strides. "
+                            f"shape={list(val.shape)}, stride={list(val.stride())}. "
+                            f"Ensure example inputs passed to torch.export.export() "
+                            f"are contiguous (call .contiguous() on them)."
+                        )
+                    self.make_or_get_slot(node, id_space=IdSpace.Input)
+                elif node.name in mutable_buffers:
+                    self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer)
+                else:
                     raise NotImplementedError(
                         f"Support for placeholder {node.name} is not implemented"
                     )
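The revert restores single-pass classification: placeholders are visited in FX graph order and bucketed by the graph-signature name sets, instead of being pre-allocated in signature order. A self-contained sketch of that pattern, with a dict standing in for `make_or_get_slot` bookkeeping and hypothetical name sets, looks like this:

```python
# Sketch of the single-pass placeholder classification restored above.
# `constant_tensors`, `user_inputs`, and `mutable_buffers` stand in for the
# name sets derived from the exported program's graph signature.
import torch

def classify_placeholders(ep, constant_tensors, user_inputs, mutable_buffers):
    slots = {}  # stand-in for the builder's slot bookkeeping
    for node in ep.graph.nodes:
        if node.op != "placeholder" or not node.users:
            continue  # skip non-placeholders and unused inputs
        if node.name in constant_tensors:
            slots[node.name] = "constant"
        elif node.name in user_inputs:
            val = node.meta.get("val", None)
            # Mirror the backend's contiguity check for user inputs.
            if isinstance(val, torch.Tensor) and not val.is_contiguous():
                raise ValueError(f"input '{node.name}' must be contiguous")
            slots[node.name] = "input"
        elif node.name in mutable_buffers:
            slots[node.name] = "mutable_buffer"
        else:
            raise NotImplementedError(f"placeholder {node.name}")
    return slots
```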

backends/mlx/examples/llm/README.md

Lines changed: 0 additions & 2 deletions
@@ -57,7 +57,6 @@ python -m executorch.backends.mlx.examples.llm.export_llm_hf \
 # Gemma 4 text-only export
 python -m executorch.backends.mlx.examples.llm.export_llm_hf \
     --model-id "google/gemma-4-E2B-it" \
-    --revision "b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf" \
     --output gemma4_hf_int4.pte \
     --use-custom-sdpa \
     --use-custom-kv-cache \
@@ -109,7 +108,6 @@ Validated Gemma 4 run command:
 python -m executorch.backends.mlx.examples.llm.run_llm_hf \
     --pte gemma4_hf_int4.pte \
     --model-id google/gemma-4-E2B-it \
-    --revision b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf \
     --prompt "What is the capital of France?" \
     --max-new-tokens 50
 ```

backends/mlx/examples/llm/export_llm_hf.py

Lines changed: 8 additions & 94 deletions
@@ -47,53 +47,6 @@
 logging.basicConfig(level=logging.INFO, format=FORMAT)
 logger = logging.getLogger(__name__)

-_GEMMA4_MODEL_ID = "google/gemma-4-E2B-it"
-_GEMMA4_PROBLEM_LAYER_FQN = "model.language_model.layers.22.mlp.down_proj"
-
-
-def _get_submodule_by_fqn(root: torch.nn.Module, fqn: str) -> torch.nn.Module:
-    cur = root
-    for part in fqn.split("."):
-        if part.isdigit():
-            cur = cur[int(part)]  # type: ignore[index]
-        else:
-            cur = getattr(cur, part)
-    return cur
-
-
-def _capture_gemma4_float_fallback_weight(
-    model_id: str,
-    qlinear: Optional[str],
-    model: torch.nn.Module,
-) -> Optional[torch.Tensor]:
-    if model_id != _GEMMA4_MODEL_ID or qlinear != "4w":
-        return None
-
-    layer = _get_submodule_by_fqn(model, _GEMMA4_PROBLEM_LAYER_FQN)
-    weight = layer.weight.detach().clone()
-    logger.info(
-        "Saving %s in floating point to avoid the current Gemma 4 4w mismatch",
-        _GEMMA4_PROBLEM_LAYER_FQN,
-    )
-    return weight
-
-
-def _restore_gemma4_float_fallback_weight(
-    model_id: str,
-    qlinear: Optional[str],
-    model: torch.nn.Module,
-    weight: Optional[torch.Tensor],
-) -> None:
-    if weight is None or model_id != _GEMMA4_MODEL_ID or qlinear != "4w":
-        return
-
-    layer = _get_submodule_by_fqn(model, _GEMMA4_PROBLEM_LAYER_FQN)
-    layer.weight = torch.nn.Parameter(weight, requires_grad=False)
-    logger.info(
-        "Restored %s in floating point after quantization",
-        _GEMMA4_PROBLEM_LAYER_FQN,
-    )
-

 def _export_with_optimum(
     model_id: str,
@@ -128,10 +81,6 @@ def _export_with_optimum(

     from executorch.backends.mlx.llm.quantization import quantize_model_

-    gemma4_float_weight = _capture_gemma4_float_fallback_weight(
-        model_id, qlinear, exportable.model
-    )
-
     quantize_model_(
         exportable.model,
         qlinear_config=qlinear,
@@ -143,9 +92,6 @@ def _export_with_optimum(
         )
         and not no_tie_word_embeddings,
     )
-    _restore_gemma4_float_fallback_weight(
-        model_id, qlinear, exportable.model, gemma4_float_weight
-    )

     logger.info("Exporting model with torch.export...")
     exported_progs = exportable.export()
@@ -215,24 +161,13 @@ def _export_with_custom_components(
     }
     torch_dtype = torch_dtype_map.get(dtype, torch.bfloat16)

-    effective_use_custom_sdpa = use_custom_sdpa
-    effective_use_custom_kv_cache = use_custom_kv_cache
-    if model_id == _GEMMA4_MODEL_ID and use_custom_sdpa:
-        logger.info(
-            "Disabling custom SDPA for Gemma 4 while keeping the custom cache path"
-        )
-        effective_use_custom_sdpa = False
-    if model_id == _GEMMA4_MODEL_ID and use_custom_kv_cache:
-        logger.info("Disabling custom KV cache for Gemma 4")
-        effective_use_custom_kv_cache = False
-
-    if effective_use_custom_sdpa:
+    if use_custom_sdpa:
         from executorch.backends.mlx.llm.hf_attention import register_mlx_attention

         register_mlx_attention()
         logger.info("Registered MLX custom SDPA attention")

-    attn_implementation = "mlx" if effective_use_custom_sdpa else None
+    attn_implementation = "mlx" if use_custom_sdpa else None

     logger.info(f"Loading HuggingFace model: {model_id}")
     load_kwargs = {
@@ -292,7 +227,7 @@ def _export_with_custom_components(
         max_cache_len=effective_cache_len,
     )

-    if effective_use_custom_kv_cache:
+    if use_custom_kv_cache:
         from executorch.backends.mlx.llm.source_transformation import (
             replace_hf_cache_with_mlx,
         )
@@ -316,10 +251,6 @@

     from executorch.backends.mlx.llm.quantization import quantize_model_

-    gemma4_float_weight = _capture_gemma4_float_fallback_weight(
-        model_id, qlinear, exportable.model
-    )
-
     quantize_model_(
         exportable.model,
         qlinear_config=qlinear,
@@ -329,9 +260,6 @@ def _export_with_custom_components(
         tie_word_embeddings=getattr(model.config, "tie_word_embeddings", False)
         and not no_tie_word_embeddings,
     )
-    _restore_gemma4_float_fallback_weight(
-        model_id, qlinear, exportable.model, gemma4_float_weight
-    )

     logger.info("Exporting model with torch.export...")
     seq_length = 3
@@ -421,24 +349,10 @@ def export_llama_hf(
         use_custom_sdpa: Use MLX custom SDPA (mlx::custom_sdpa)
         use_custom_kv_cache: Use MLX custom KV cache (mlx::kv_cache_update)
     """
-    effective_use_custom_sdpa = use_custom_sdpa
-    effective_use_custom_kv_cache = use_custom_kv_cache
-    if model_id == _GEMMA4_MODEL_ID:
-        if effective_use_custom_sdpa:
-            logger.info(
-                "Disabling custom SDPA for Gemma 4 and falling back to the baseline export path"
-            )
-            effective_use_custom_sdpa = False
-        if effective_use_custom_kv_cache:
-            logger.info(
-                "Disabling custom KV cache for Gemma 4 and falling back to the baseline export path"
-            )
-            effective_use_custom_kv_cache = False
-
-    if effective_use_custom_sdpa or effective_use_custom_kv_cache:
+    if use_custom_sdpa or use_custom_kv_cache:
         logger.info(
-            f"Using custom components: sdpa={effective_use_custom_sdpa}, "
-            f"kv_cache={effective_use_custom_kv_cache}"
+            f"Using custom components: sdpa={use_custom_sdpa}, "
+            f"kv_cache={use_custom_kv_cache}"
         )
         _export_with_custom_components(
             model_id=model_id,
@@ -448,8 +362,8 @@ def export_llama_hf(
             dtype=dtype,
             qlinear=qlinear,
             qembedding=qembedding,
-            use_custom_sdpa=effective_use_custom_sdpa,
-            use_custom_kv_cache=effective_use_custom_kv_cache,
+            use_custom_sdpa=use_custom_sdpa,
+            use_custom_kv_cache=use_custom_kv_cache,
             no_tie_word_embeddings=no_tie_word_embeddings,
             qlinear_group_size=qlinear_group_size,
             qembedding_group_size=qembedding_group_size,
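The deleted capture/restore pair implemented a "keep one layer in float" escape hatch around quantization. If such a fallback is ever needed again, note that `torch.nn.Module.get_submodule` already resolves dotted FQNs, including numeric `ModuleList` indices, so the hand-rolled `_get_submodule_by_fqn` walker is unnecessary. A minimal sketch under that assumption, with `quantize_with_float_fallback` and `quantize_fn` as hypothetical names:

```python
# Minimal sketch of a float-fallback around quantization using the built-in
# torch.nn.Module.get_submodule instead of a hand-rolled FQN walker.
import torch

def quantize_with_float_fallback(model: torch.nn.Module, fqn: str, quantize_fn) -> None:
    layer = model.get_submodule(fqn)       # resolves e.g. "layers.22.mlp.down_proj"
    saved = layer.weight.detach().clone()  # capture the float weight
    quantize_fn(model)                     # quantize the whole model
    layer = model.get_submodule(fqn)       # re-resolve in case modules were swapped
    layer.weight = torch.nn.Parameter(saved, requires_grad=False)  # restore in float
```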
