Commit ca37250

Pin Gemma 4 MLX flow to validated model revision
1 parent ee272c3 commit ca37250
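
Pins the Hugging Face download of google/gemma-4-E2B-it to commit b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf across the MLX LLM example flow: the export and run scripts gain an optional --revision flag, and the CI workflow and README commands pass the validated commit for Gemma 4.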

4 files changed: 35 additions & 4 deletions

.github/workflows/mlx.yml

Lines changed: 6 additions & 0 deletions
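
The CI job computes MODEL_REVISION for the Gemma 4 matrix entry only and forwards it to both the export and run steps; every other model keeps an empty revision and an unchanged command line.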
@@ -512,6 +512,10 @@ jobs:
           MODEL_NAME="${{ matrix.model.name }}"
           USE_CUSTOM="${{ matrix.use-custom }}"
           QCONFIG="${{ matrix.qconfig }}"
+          MODEL_REVISION=""
+          if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then
+            MODEL_REVISION="b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf"
+          fi

           CUSTOM_ARGS=""
           if [ "${USE_CUSTOM}" = "true" ]; then

@@ -547,6 +551,7 @@ jobs:
           echo "::group::Export ${MODEL_NAME}"
           ${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.export_llm_hf \
             --model-id "${MODEL_ID}" \
+            ${MODEL_REVISION:+--revision "${MODEL_REVISION}"} \
             --output /tmp/${MODEL_NAME}.pte \
             --qlinear ${QCONFIG} \
             ${QEMBEDDING_ARGS} \

@@ -557,6 +562,7 @@
           OUTPUT=$(${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.run_llm_hf \
             --pte /tmp/${MODEL_NAME}.pte \
             --model-id "${MODEL_ID}" \
+            ${MODEL_REVISION:+--revision "${MODEL_REVISION}"} \
             --prompt "What is the capital of France?" \
             --max-new-tokens 50 2>&1)
           echo "$OUTPUT"

backends/mlx/examples/llm/README.md

Lines changed: 2 additions & 0 deletions
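
Both documented Gemma 4 commands now pin the checkpoint to the validated commit.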
@@ -57,6 +57,7 @@ python -m executorch.backends.mlx.examples.llm.export_llm_hf \
 # Gemma 4 text-only export
 python -m executorch.backends.mlx.examples.llm.export_llm_hf \
   --model-id "google/gemma-4-E2B-it" \
+  --revision "b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf" \
   --output gemma4_hf_int4.pte \
   --use-custom-sdpa \
   --use-custom-kv-cache \

@@ -108,6 +109,7 @@ Validated Gemma 4 run command:
 python -m executorch.backends.mlx.examples.llm.run_llm_hf \
   --pte gemma4_hf_int4.pte \
   --model-id google/gemma-4-E2B-it \
+  --revision b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf \
   --prompt "What is the capital of France?" \
   --max-new-tokens 50
 ```
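
On the Hugging Face Hub, `--revision` accepts a branch name, a tag, or a commit hash; only a full commit hash is immutable, so pinning to it keeps these commands reproducible even if the model repository's main branch moves.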

backends/mlx/examples/llm/export_llm_hf.py

Lines changed: 15 additions & 0 deletions
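
Both export paths gain the new parameter: the optimum-executorch pipeline forwards it to `load_causal_lm_model`, while the custom-components path adds it to the `from_pretrained` kwargs only when it is set.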
@@ -50,6 +50,7 @@

 def _export_with_optimum(
     model_id: str,
+    revision: Optional[str],
     output_path: str,
     max_seq_len: int,
     dtype: str,

@@ -73,6 +74,7 @@ def _export_with_optimum(
     logger.info(f"Loading model using optimum-executorch: {model_id}")
     exportable = load_causal_lm_model(
         model_id,
+        revision=revision,
         dtype=dtype_str,
         max_seq_len=max_seq_len,
     )

@@ -124,6 +126,7 @@ def _export_with_optimum(

 def _export_with_custom_components(
     model_id: str,
+    revision: Optional[str],
     output_path: str,
     max_seq_len: int,
     dtype: str,

@@ -171,6 +174,8 @@ def _export_with_custom_components(
         "torch_dtype": torch_dtype,
         "low_cpu_mem_usage": True,
     }
+    if revision is not None:
+        load_kwargs["revision"] = revision
     if attn_implementation:
         load_kwargs["attn_implementation"] = attn_implementation
     model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

@@ -345,6 +350,7 @@ def _save_program(executorch_program, output_path: str) -> None:

 def export_llama_hf(
     model_id: str,
+    revision: Optional[str],
     output_path: str,
     max_seq_len: int = 1024,
     dtype: str = "bf16",

@@ -376,6 +382,7 @@ def export_llama_hf(
     )
     _export_with_custom_components(
         model_id=model_id,
+        revision=revision,
         output_path=output_path,
         max_seq_len=max_seq_len,
         dtype=dtype,

@@ -391,6 +398,7 @@
     logger.info("Using optimum-executorch pipeline (no custom components)")
     _export_with_optimum(
         model_id=model_id,
+        revision=revision,
         output_path=output_path,
         max_seq_len=max_seq_len,
         dtype=dtype,

@@ -412,6 +420,12 @@ def main():
         default="unsloth/Llama-3.2-1B-Instruct",
         help="HuggingFace model ID",
     )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        help="Optional HuggingFace model revision/commit to pin",
+    )
     parser.add_argument(
         "--output",
         type=str,

@@ -451,6 +465,7 @@ def main():

     export_llama_hf(
         model_id=args.model_id,
+        revision=args.revision,
         output_path=args.output,
         max_seq_len=args.max_seq_len,
         dtype=args.dtype,
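
The `revision` value is ultimately handled by the standard Hugging Face `from_pretrained` loaders. A minimal sketch of the pinned load, using identifiers from this diff and omitting the ExecuTorch export plumbing:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "google/gemma-4-E2B-it"
# Full commit SHA from this commit; unlike a branch or tag, it never moves.
REVISION = "b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf"

# Both loaders accept `revision`, so the tokenizer and the weights come
# from the same validated snapshot of the model repository.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, revision=REVISION)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, revision=REVISION)
```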

backends/mlx/examples/llm/run_llm_hf.py

Lines changed: 12 additions & 4 deletions
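
The runner threads the same `revision` through tokenizer and processor loading, so decoding uses the tokenizer files from the exact snapshot the model was exported against.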
@@ -47,7 +47,7 @@ def _get_max_input_seq_len(program) -> int:
     return sizes[1] if len(sizes) >= 2 else 1


-def _load_text_processor(model_id: str):
+def _load_text_processor(model_id: str, revision: str | None):
     """
     Load a text processor for the model.

@@ -58,13 +58,13 @@ def _load_text_processor(model_id: str):
     """
     logger.info(f"Loading tokenizer from HuggingFace: {model_id}...")
     try:
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
         return tokenizer, False
     except Exception as exc:
         logger.info(f"AutoTokenizer unavailable for {model_id}: {exc}")

     try:
-        processor = AutoProcessor.from_pretrained(model_id)
+        processor = AutoProcessor.from_pretrained(model_id, revision=revision)
         if hasattr(processor, "apply_chat_template") and hasattr(processor, "decode"):
             logger.info(f"Loaded processor from HuggingFace: {model_id}")
             return processor, True

@@ -101,11 +101,12 @@ def _get_eos_token_id(text_processor):
 def run_inference(
     pte_path: str,
     model_id: str,
+    revision: str | None,
     prompt: str,
     max_new_tokens: int = 50,
 ) -> str:
     """Run inference on the exported HuggingFace model."""
-    text_processor, uses_processor = _load_text_processor(model_id)
+    text_processor, uses_processor = _load_text_processor(model_id, revision)

     logger.info(f"Loading model from {pte_path}...")
     et_runtime = Runtime.get()

@@ -208,6 +209,12 @@ def main():
         default="unsloth/Llama-3.2-1B-Instruct",
         help="HuggingFace model ID (used to load tokenizer or processor)",
     )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        help="Optional HuggingFace model revision/commit to pin",
+    )
     parser.add_argument(
         "--prompt",
         type=str,

@@ -226,6 +233,7 @@
     generated_text = run_inference(
         pte_path=args.pte,
         model_id=args.model_id,
+        revision=args.revision,
         prompt=args.prompt,
         max_new_tokens=args.max_new_tokens,
     )
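
Going by the `run_inference` signature above, a programmatic equivalent of the validated README command could look like the following sketch (assuming the module is importable from an ExecuTorch checkout; the CLI shown in the README is the documented entry point):

```python
from executorch.backends.mlx.examples.llm.run_llm_hf import run_inference

# Mirrors the README's validated Gemma 4 run command, pinned revision included.
text = run_inference(
    pte_path="gemma4_hf_int4.pte",
    model_id="google/gemma-4-E2B-it",
    revision="b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf",
    prompt="What is the capital of France?",
    max_new_tokens=50,
)
print(text)
```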
