36 changes: 34 additions & 2 deletions .github/workflows/mlx.yml
@@ -489,11 +489,25 @@ jobs:
name: "gemma3-1b"
use-custom: [false, true]
qconfig: ["4w", "nvfp4"]
runner: ["macos-14-xlarge"]
include:
- model:
id: "google/gemma-4-E2B-it"
name: "gemma4-e2b"
use-custom: true
qconfig: "4w"
runner: "macos-15-xlarge"
- model:
id: "google/gemma-4-E2B-it"
name: "gemma4-e2b"
use-custom: false
qconfig: "4w"
runner: "macos-15-xlarge"
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
secrets: inherit
with:
job-name: test-mlx-llm-${{ matrix.model.name }}${{ matrix.use-custom && '-custom' || '' }}-${{ matrix.qconfig }}
runner: macos-14-xlarge
runner: ${{ matrix.runner }}
python-version: "3.12"
submodules: recursive
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
@@ -506,12 +520,21 @@ jobs:
MODEL_NAME="${{ matrix.model.name }}"
USE_CUSTOM="${{ matrix.use-custom }}"
QCONFIG="${{ matrix.qconfig }}"
MODEL_REVISION=""
if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then
MODEL_REVISION="b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf"
fi

CUSTOM_ARGS=""
if [ "${USE_CUSTOM}" = "true" ]; then
CUSTOM_ARGS="--use-custom-sdpa --use-custom-kv-cache"
fi

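# Gemma 4 export is validated without embedding quantization (see the MLX LLM
# example README), so skip --qembedding for it.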
QEMBEDDING_ARGS="--qembedding ${QCONFIG}"
if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then
QEMBEDDING_ARGS=""
fi

echo "::group::Install ExecuTorch and configure MLX build"
${CONDA_RUN} python install_executorch.py > /dev/null
${CONDA_RUN} cmake --preset mlx-release
@@ -522,23 +545,32 @@
${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN
OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt)
${CONDA_RUN} pip install transformers "optimum-executorch @ git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}"
if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then
# Gemma 4 requires a newer Transformers build than the CI-wide
# optimum-executorch pin currently brings in. Keep this pinned to the
# locally validated commit instead of floating on Transformers HEAD.
GEMMA4_TRANSFORMERS_COMMIT=61461a7bcb458db7cf6eeea49678b9ab776a7821
${CONDA_RUN} pip install -U "transformers @ git+https://github.com/huggingface/transformers.git@${GEMMA4_TRANSFORMERS_COMMIT}"
fi
echo "::endgroup::"

${CONDA_RUN} pip list

echo "::group::Export ${MODEL_NAME}"
${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.export_llm_hf \
--model-id "${MODEL_ID}" \
${MODEL_REVISION:+--revision "${MODEL_REVISION}"} \
--output /tmp/${MODEL_NAME}.pte \
--qlinear ${QCONFIG} \
--qembedding ${QCONFIG} \
${QEMBEDDING_ARGS} \
${CUSTOM_ARGS}
echo "::endgroup::"

echo "::group::Run ${MODEL_NAME} inference"
OUTPUT=$(${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.run_llm_hf \
--pte /tmp/${MODEL_NAME}.pte \
--model-id "${MODEL_ID}" \
${MODEL_REVISION:+--revision "${MODEL_REVISION}"} \
--prompt "What is the capital of France?" \
--max-new-tokens 50 2>&1)
echo "$OUTPUT"
35 changes: 34 additions & 1 deletion backends/mlx/examples/llm/README.md
@@ -9,6 +9,7 @@ This example demonstrates how to export and run LLMs using the MLX delegate for
- **KV Cache**: Efficient KV cache implementation for autoregressive generation
- **Custom Ops**: Uses `mlx::custom_sdpa` and `mlx::kv_cache_update` for optimal execution on MLX
- **Pybindings**: Run inference using ExecuTorch Python bindings
- **Gemma 4**: Text-only export and run flow supports processor-backed checkpoints such as `google/gemma-4-E2B-it`

## Requirements

@@ -52,6 +53,25 @@ python -m executorch.backends.mlx.examples.llm.export_llm_hf \
--use-custom-kv-cache \
--qlinear 4w \
--qembedding 4w

# Gemma 4 text-only export
python -m executorch.backends.mlx.examples.llm.export_llm_hf \
--model-id "google/gemma-4-E2B-it" \
--revision "b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf" \
--output gemma4_hf_int4.pte \
--use-custom-sdpa \
--use-custom-kv-cache \
--qlinear 4w
```

Gemma 4 support is currently validated for the text-only path using
`--use-custom-sdpa --use-custom-kv-cache --qlinear 4w`.

Validated with `transformers` commit
`61461a7bcb458db7cf6eeea49678b9ab776a7821`:

```bash
pip install -U "transformers @ git+https://github.com/huggingface/transformers.git@61461a7bcb458db7cf6eeea49678b9ab776a7821"
```

### Options
@@ -81,12 +101,25 @@ python -m executorch.backends.mlx.examples.llm.run_llm_hf \
--prompt "Explain quantum computing in simple terms"
```

Gemma 4 checkpoints may use `AutoProcessor` instead of `AutoTokenizer`; `run_llm_hf` now supports both paths automatically for text-only prompts.
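As a rough illustration, a text-only loader can try the processor first and fall back to the tokenizer. This is a minimal sketch under assumed names (the helper `load_text_tokenizer` is hypothetical), not the exact `run_llm_hf` implementation:

```python
# Minimal sketch: prefer AutoProcessor when the checkpoint ships one,
# otherwise fall back to AutoTokenizer. Not the exact run_llm_hf code.
from typing import Optional

from transformers import AutoProcessor, AutoTokenizer


def load_text_tokenizer(model_id: str, revision: Optional[str] = None):
    try:
        processor = AutoProcessor.from_pretrained(model_id, revision=revision)
        # Processor-backed checkpoints (e.g. Gemma 4) expose the text
        # tokenizer as `.tokenizer`; fall back to the processor itself if not.
        return getattr(processor, "tokenizer", processor)
    except Exception:
        # Plain text-only checkpoints without a processor config land here.
        return AutoTokenizer.from_pretrained(model_id, revision=revision)
```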

Validated Gemma 4 run command:

```bash
python -m executorch.backends.mlx.examples.llm.run_llm_hf \
--pte gemma4_hf_int4.pte \
--model-id google/gemma-4-E2B-it \
--revision b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf \
--prompt "What is the capital of France?" \
--max-new-tokens 50
```

### Options

| Option | Default | Description |
|--------|---------|-------------|
| `--pte` | `llama_hf.pte` | Path to .pte file |
| `--model-id` | `unsloth/Llama-3.2-1B-Instruct` | HuggingFace model ID (for tokenizer) |
| `--model-id` | `unsloth/Llama-3.2-1B-Instruct` | HuggingFace model ID (for tokenizer or processor) |
| `--prompt` | `The quick brown fox` | Input prompt |
| `--max-new-tokens` | `50` | Maximum tokens to generate |

85 changes: 39 additions & 46 deletions backends/mlx/examples/llm/export_llm_hf.py
@@ -50,6 +50,7 @@

def _export_with_optimum(
model_id: str,
revision: Optional[str],
output_path: str,
max_seq_len: int,
dtype: str,
@@ -73,6 +74,7 @@ def _export_with_optimum(
logger.info(f"Loading model using optimum-executorch: {model_id}")
exportable = load_causal_lm_model(
model_id,
revision=revision,
dtype=dtype_str,
max_seq_len=max_seq_len,
)
@@ -124,6 +126,7 @@ def _export_with_optimum(

def _export_with_custom_components(
model_id: str,
revision: Optional[str],
output_path: str,
max_seq_len: int,
dtype: str,
@@ -166,20 +169,21 @@

attn_implementation = "mlx" if use_custom_sdpa else None

# Detect sliding window models (e.g., gemma)
sliding_window = None

logger.info(f"Loading HuggingFace model: {model_id}")
load_kwargs = {
"torch_dtype": torch_dtype,
"low_cpu_mem_usage": True,
}
if revision is not None:
load_kwargs["revision"] = revision
if attn_implementation:
load_kwargs["attn_implementation"] = attn_implementation
model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

# Check if model uses sliding window attention
sliding_window = getattr(model.config, "sliding_window", None)
# Check if model uses sliding window attention. Multimodal configs like
# Gemma 4 keep transformer attributes under text_config.
text_config = model.config.get_text_config()
sliding_window = getattr(text_config, "sliding_window", None)
if sliding_window is not None:
logger.info(f"Model has sliding_window={sliding_window}")
# Cap max_seq_len to sliding window size for cache allocation
@@ -188,11 +192,16 @@
else:
effective_cache_len = max_seq_len

# The HF ExecuTorch cache wrappers validate both generation_config.use_cache
# and the text config's use_cache flag before constructing static caches.
model.generation_config.use_cache = True
model.generation_config.cache_implementation = "static"
model.generation_config.cache_config = {
"batch_size": 1,
"max_cache_len": effective_cache_len,
}
text_config = model.config.get_text_config()
text_config.use_cache = True
model.eval()

# Use HybridCache wrapper for sliding window models (stores cache as .cache),
@@ -219,52 +228,26 @@
)

if use_custom_kv_cache:
if sliding_window is not None:
# Use ring buffer cache for sliding window models
from executorch.backends.mlx.llm.source_transformation import (
replace_hf_cache_with_mlx_ring_buffer,
)
from executorch.backends.mlx.llm.source_transformation import (
replace_hf_cache_with_mlx,
)

if sliding_window is not None:
logger.info(
f"Replacing StaticCache with RingBuffer KV cache "
f"(window_size={effective_cache_len})..."
"Replacing HuggingFace StaticCache with HFStaticCache "
f"(capped to sliding window: {effective_cache_len})..."
)
replace_hf_cache_with_mlx_ring_buffer(
exportable,
model.config,
max_batch_size=1,
window_size=effective_cache_len,
dtype=torch_dtype,
)

if use_custom_sdpa:
# Re-register attention with sliding window closure
from executorch.backends.mlx.llm.hf_attention import (
register_mlx_sliding_window_attention,
)

register_mlx_sliding_window_attention(exportable)
model.config._attn_implementation = "mlx_sliding_window"
logger.info(
" Registered sliding window attention (mlx_sliding_window)"
)

logger.info(" RingBuffer KV cache installed successfully")
else:
# Use standard linear cache for non-sliding-window models
from executorch.backends.mlx.llm.source_transformation import (
replace_hf_cache_with_mlx,
)

logger.info("Replacing HuggingFace StaticCache with HFStaticCache...")
replace_hf_cache_with_mlx(
exportable,
model.config,
max_batch_size=1,
max_cache_len=effective_cache_len,
dtype=torch_dtype,
)
logger.info(" HFStaticCache installed successfully")

replace_hf_cache_with_mlx(
exportable,
model.config,
max_batch_size=1,
max_cache_len=effective_cache_len,
dtype=torch_dtype,
)
logger.info(" HFStaticCache installed successfully")

from executorch.backends.mlx.llm.quantization import quantize_model_

@@ -341,6 +324,7 @@ def _save_program(executorch_program, output_path: str) -> None:

def export_llama_hf(
model_id: str,
revision: Optional[str],
output_path: str,
max_seq_len: int = 1024,
dtype: str = "bf16",
@@ -372,6 +356,7 @@ def export_llama_hf(
)
_export_with_custom_components(
model_id=model_id,
revision=revision,
output_path=output_path,
max_seq_len=max_seq_len,
dtype=dtype,
@@ -387,6 +372,7 @@
logger.info("Using optimum-executorch pipeline (no custom components)")
_export_with_optimum(
model_id=model_id,
revision=revision,
output_path=output_path,
max_seq_len=max_seq_len,
dtype=dtype,
Expand All @@ -408,6 +394,12 @@ def main():
default="unsloth/Llama-3.2-1B-Instruct",
help="HuggingFace model ID",
)
parser.add_argument(
"--revision",
type=str,
default=None,
help="Optional HuggingFace model revision/commit to pin",
)
parser.add_argument(
"--output",
type=str,
@@ -447,6 +439,7 @@

export_llama_hf(
model_id=args.model_id,
revision=args.revision,
output_path=args.output,
max_seq_len=args.max_seq_len,
dtype=args.dtype,