13 changes: 12 additions & 1 deletion .github/workflows/mlx.yml
@@ -470,6 +470,12 @@ jobs:
name: "gemma3-1b"
use-custom: [false, true]
qconfig: ["4w", "nvfp4"]
include:
- model:
id: "google/gemma-4-E2B-it"
name: "gemma4-e2b"
use-custom: true
qconfig: "4w"
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
secrets: inherit
with:
@@ -493,6 +499,11 @@ jobs:
CUSTOM_ARGS="--use-custom-sdpa --use-custom-kv-cache"
fi

QEMBEDDING_ARGS="--qembedding ${QCONFIG}"
if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then
QEMBEDDING_ARGS=""
**Contributor:** Why no embedding?

fi

echo "::group::Install ExecuTorch and configure MLX build"
${CONDA_RUN} python install_executorch.py > /dev/null
${CONDA_RUN} cmake --preset mlx-release
@@ -512,7 +523,7 @@ jobs:
--model-id "${MODEL_ID}" \
--output /tmp/${MODEL_NAME}.pte \
--qlinear ${QCONFIG} \
--qembedding ${QCONFIG} \
${QEMBEDDING_ARGS} \
${CUSTOM_ARGS}
echo "::endgroup::"

26 changes: 25 additions & 1 deletion backends/mlx/examples/llm/README.md
@@ -9,6 +9,7 @@ This example demonstrates how to export and run LLMs using the MLX delegate for
- **KV Cache**: Efficient KV cache implementation for autoregressive generation
- **Custom Ops**: Uses `mlx::custom_sdpa` and `mlx::kv_cache_update` for optimal execution on MLX
- **Pybindings**: Run inference using ExecuTorch Python bindings
- **Gemma 4**: Text-only export and run flow supports processor-backed checkpoints such as `google/gemma-4-E2B-it`

## Requirements

@@ -52,8 +53,19 @@ python -m executorch.backends.mlx.examples.llm.export_llm_hf \
--use-custom-kv-cache \
--qlinear 4w \
--qembedding 4w

# Gemma 4 text-only export
python -m executorch.backends.mlx.examples.llm.export_llm_hf \
--model-id "google/gemma-4-E2B-it" \
--output gemma4_hf_int4.pte \
--use-custom-sdpa \
--use-custom-kv-cache \
--qlinear 4w
```

Gemma 4 support is currently validated for the text-only path using
`--use-custom-sdpa --use-custom-kv-cache --qlinear 4w`.

### Options

| Option | Default | Description |
@@ -81,12 +93,24 @@ python -m executorch.backends.mlx.examples.llm.run_llm_hf \
--prompt "Explain quantum computing in simple terms"
```

Gemma 4 checkpoints may use `AutoProcessor` instead of `AutoTokenizer`; `run_llm_hf` now supports both paths automatically for text-only prompts.
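
For context, the fallback `run_llm_hf` applies is roughly the following (a sketch mirroring this PR's internal `_load_text_processor` helper; `load_text_frontend` is a hypothetical name used here for illustration):

```python
from transformers import AutoProcessor, AutoTokenizer

def load_text_frontend(model_id: str):
    # Prefer AutoProcessor: multimodal checkpoints such as Gemma 4 ship one,
    # and it exposes both apply_chat_template() and decode().
    try:
        processor = AutoProcessor.from_pretrained(model_id)
        if hasattr(processor, "apply_chat_template") and hasattr(processor, "decode"):
            return processor, True  # processor path
    except Exception:
        pass  # fall through to the plain tokenizer
    # Text-only checkpoints: AutoTokenizer handles encode/decode directly.
    return AutoTokenizer.from_pretrained(model_id), False
```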

Validated Gemma 4 run command:

```bash
python -m executorch.backends.mlx.examples.llm.run_llm_hf \
--pte gemma4_hf_int4.pte \
--model-id google/gemma-4-E2B-it \
--prompt "What is the capital of France?" \
--max-new-tokens 50
```

### Options

| Option | Default | Description |
|--------|---------|-------------|
| `--pte` | `llama_hf.pte` | Path to .pte file |
| `--model-id` | `unsloth/Llama-3.2-1B-Instruct` | HuggingFace model ID (for tokenizer) |
| `--model-id` | `unsloth/Llama-3.2-1B-Instruct` | HuggingFace model ID (for tokenizer or processor) |
| `--prompt` | `The quick brown fox` | Input prompt |
| `--max-new-tokens` | `50` | Maximum tokens to generate |

14 changes: 9 additions & 5 deletions backends/mlx/examples/llm/export_llm_hf.py
@@ -166,9 +166,6 @@ def _export_with_custom_components(

attn_implementation = "mlx" if use_custom_sdpa else None

# Detect sliding window models (e.g., gemma)
sliding_window = None

logger.info(f"Loading HuggingFace model: {model_id}")
load_kwargs = {
"torch_dtype": torch_dtype,
@@ -178,8 +175,10 @@
load_kwargs["attn_implementation"] = attn_implementation
model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

# Check if model uses sliding window attention
sliding_window = getattr(model.config, "sliding_window", None)
# Check if model uses sliding window attention. Multimodal configs like
# Gemma 4 keep transformer attributes under text_config.
text_config = model.config.get_text_config()
sliding_window = getattr(text_config, "sliding_window", None)

**Contributor:** Does this regress gemma3?

**Author:** I don't expect this to regress Gemma 3. The change just switches the sliding-window lookup to `model.config.get_text_config()`, which also covers the plain text-config case and is needed for Gemma 4, where those attributes live under `text_config`. I scoped the logic to the same attribute lookup, not a Gemma-4-specific branch. I can also rerun a Gemma 3 smoke test and report back.

**Contributor:** Yeah, it would be great to try gemma3 as a smoke test. If you are unable to access the version from Google, try the unsloth version unsloth/gemma-3-1b-it (https://github.com/pytorch/executorch/blob/main/.github/workflows/mlx.yml#L469C18-L469C39)
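As a follow-up to that suggestion, the smoke test would presumably mirror the README's documented export command with the suggested checkpoint swapped in (a sketch; the flags come from the README in this PR, and the output filename is arbitrary):

```bash
# Hypothetical Gemma 3 smoke test through the same text-only export path
python -m executorch.backends.mlx.examples.llm.export_llm_hf \
  --model-id "unsloth/gemma-3-1b-it" \
  --output gemma3_hf_int4.pte \
  --use-custom-sdpa \
  --use-custom-kv-cache \
  --qlinear 4w \
  --qembedding 4w
```
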
if sliding_window is not None:
logger.info(f"Model has sliding_window={sliding_window}")
# Cap max_seq_len to sliding window size for cache allocation
@@ -188,11 +187,16 @@
else:
effective_cache_len = max_seq_len

# The HF ExecuTorch cache wrappers validate both generation_config.use_cache
# and the text config's use_cache flag before constructing static caches.
model.generation_config.use_cache = True
model.generation_config.cache_implementation = "static"
model.generation_config.cache_config = {
"batch_size": 1,
"max_cache_len": effective_cache_len,
}
text_config = model.config.get_text_config()
text_config.use_cache = True
model.eval()

# Use HybridCache wrapper for sliding window models (stores cache as .cache),
147 changes: 133 additions & 14 deletions backends/mlx/examples/llm/run_llm_hf.py
@@ -7,10 +7,11 @@
# LICENSE file in the root directory of this source tree.

"""
Run exported Llama model (from HuggingFace) using ExecuTorch pybindings.
Run exported HuggingFace LLM using ExecuTorch pybindings.

This script runs models exported using export_llm_hf.py. It loads the tokenizer
directly from HuggingFace using the same model ID used during export.
or processor directly from HuggingFace using the same model ID used during
export.

Usage:
python -m executorch.backends.mlx.examples.llm.run_llm_hf \
@@ -20,18 +21,89 @@
"""

import argparse
import ctypes
import logging
import os
import shutil
import time
from pathlib import Path

import torch
from executorch.runtime import Runtime, Verification
from transformers import AutoTokenizer
from transformers import AutoProcessor, AutoTokenizer

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)


def _iter_mlx_backend_candidates():
**Contributor:** This code should not be needed. Did you do `python install_executorch.py --editable` on a Mac machine with Xcode installed? If so, in the install logs, did you see a comment about the MLX installation being skipped for some reason?

env_path = os.environ.get("ET_MLX_BACKEND_DYLIB")
if env_path:
yield Path(env_path)

for parent in Path(__file__).resolve().parents:
pip_out = parent / "pip-out"
if pip_out.exists():
yield from sorted(
pip_out.glob(
"temp.*/cmake-out/backends/mlx/libmlxdelegate_runtime.dylib"
)
)
yield from sorted(
pip_out.glob("temp.*/cmake-out/backends/mlx/libmlxdelegate.dylib")
)
break


def _ensure_mlx_metallib(dylib_path: Path) -> None:
metallib_path = dylib_path.with_name("mlx.metallib")
if metallib_path.exists():
return

for parent in Path(__file__).resolve().parents:
pip_out = parent / "pip-out"
if not pip_out.exists():
continue
matches = sorted(
pip_out.glob(
"temp.*/cmake-out/backends/mlx/mlx/mlx/backend/metal/kernels/mlx.metallib"
)
)
if not matches:
continue
shutil.copyfile(matches[0], metallib_path)
logger.info(f"Copied MLX metallib next to runtime library: {metallib_path}")
return


def _ensure_mlx_backend_registered() -> Runtime:
runtime = Runtime.get()
if runtime.backend_registry.is_available("MLXBackend"):
return runtime

for candidate in _iter_mlx_backend_candidates():
if not candidate.is_file():
continue
try:
_ensure_mlx_metallib(candidate)
ctypes.CDLL(str(candidate), mode=ctypes.RTLD_GLOBAL)
except OSError as exc:
logger.info(f"Failed to load MLX backend library {candidate}: {exc}")
continue

runtime = Runtime.get()
if runtime.backend_registry.is_available("MLXBackend"):
logger.info(f"Loaded MLX backend runtime library: {candidate}")
return runtime

logger.warning(
"MLXBackend is not registered. If you built mlxdelegate locally, "
"set ET_MLX_BACKEND_DYLIB to the path of libmlxdelegate_runtime.dylib."
)
return runtime


def _get_max_input_seq_len(program) -> int:
"""Inspect the .pte program metadata to determine the max input_ids seq len.

@@ -46,18 +118,61 @@ def _get_max_input_seq_len(program) -> int:
return sizes[1] if len(sizes) >= 2 else 1


def _load_text_processor(model_id: str):
"""
Load a text processor for the model.

Prefer AutoProcessor for multimodal/text-hybrid models like Gemma 4, and
fall back to AutoTokenizer for text-only checkpoints.
"""
try:
processor = AutoProcessor.from_pretrained(model_id)
if hasattr(processor, "apply_chat_template") and hasattr(processor, "decode"):
logger.info(f"Loaded processor from HuggingFace: {model_id}")
return processor, True
except Exception as exc:
logger.info(f"AutoProcessor unavailable for {model_id}: {exc}")

logger.info(f"Loading tokenizer from HuggingFace: {model_id}...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
return tokenizer, False


def _apply_chat_template(text_processor, messages) -> str:
try:
return text_processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
except TypeError:
return text_processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)


def _get_eos_token_id(text_processor):
eos_token_id = getattr(text_processor, "eos_token_id", None)
if eos_token_id is not None:
return eos_token_id
tokenizer = getattr(text_processor, "tokenizer", None)
return getattr(tokenizer, "eos_token_id", None)


def run_inference(
pte_path: str,
model_id: str,
prompt: str,
max_new_tokens: int = 50,
) -> str:
"""Run inference on the exported HuggingFace model."""
logger.info(f"Loading tokenizer from HuggingFace: {model_id}...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
text_processor, uses_processor = _load_text_processor(model_id)

logger.info(f"Loading model from {pte_path}...")
et_runtime = Runtime.get()
et_runtime = _ensure_mlx_backend_registered()
**Contributor:** This shouldn't be needed; see the comment on the install process.

**Author:** That's fair. I added this while debugging the installed-package path locally because MLXBackend was not being registered from the installed package, and I wanted a way to keep validating the runtime path. Since the install-path issue is now fixed, I'll remove it and rely on the normal install flow.

program = et_runtime.load_program(pte_path, verification=Verification.Minimal)

max_seq_len = _get_max_input_seq_len(program)
@@ -67,14 +182,18 @@ def run_inference(

logger.info(f"Encoding prompt: {prompt!r}")
messages = [{"role": "user", "content": prompt}]
formatted_prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
input_ids = tokenizer.encode(formatted_prompt, return_tensors="pt")
formatted_prompt = _apply_chat_template(text_processor, messages)
if uses_processor:
input_ids = text_processor(text=formatted_prompt, return_tensors="pt")[
"input_ids"
]
else:
input_ids = text_processor.encode(formatted_prompt, return_tensors="pt")
logger.info(f"Input shape: {input_ids.shape}")

generated_tokens = input_ids[0].tolist()
seq_len = input_ids.shape[1]
eos_token_id = _get_eos_token_id(text_processor)

start_time = time.time()

@@ -120,7 +239,7 @@ def run_inference(
next_token = torch.argmax(next_token_logits).item()
generated_tokens.append(next_token)

if next_token == tokenizer.eos_token_id:
if eos_token_id is not None and next_token == eos_token_id:
logger.info(f"EOS token reached at position {i + 1}")
break

@@ -135,12 +254,12 @@ def run_inference(

# Decode only the newly generated tokens (not the input prompt)
new_tokens = generated_tokens[seq_len:]
generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
generated_text = text_processor.decode(new_tokens, skip_special_tokens=True)
**Contributor:** Does this break the path where uses_processor=False? Can we unify these two paths somehow?

**Author:** I ended up unifying this path. `text_processor` is now either an `AutoProcessor` or an `AutoTokenizer`, and both decode through `text_processor.decode(...)`, so the `uses_processor=False` case should still work. The remaining split is only at encode time, where `AutoProcessor` needs `processor(text=..., return_tensors="pt")` and `AutoTokenizer` still uses `encode(...)`.

return generated_text


def main():
parser = argparse.ArgumentParser(description="Run exported HuggingFace Llama model")
parser = argparse.ArgumentParser(description="Run exported HuggingFace LLM")
parser.add_argument(
"--pte",
type=str,
@@ -151,7 +270,7 @@ def main():
"--model-id",
type=str,
default="unsloth/Llama-3.2-1B-Instruct",
help="HuggingFace model ID (used to load tokenizer)",
help="HuggingFace model ID (used to load tokenizer or processor)",
)
parser.add_argument(
"--prompt",