26 changes: 25 additions & 1 deletion .github/workflows/mlx.yml
@@ -489,6 +489,12 @@ jobs:
name: "gemma3-1b"
use-custom: [false, true]
qconfig: ["4w", "nvfp4"]
include:
- model:
id: "google/gemma-4-E2B-it"
name: "gemma4-e2b"
use-custom: true
qconfig: "4w"
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
secrets: inherit
with:
@@ -506,12 +512,21 @@ jobs:
MODEL_NAME="${{ matrix.model.name }}"
USE_CUSTOM="${{ matrix.use-custom }}"
QCONFIG="${{ matrix.qconfig }}"
MODEL_REVISION=""
if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then
MODEL_REVISION="b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf"
fi

CUSTOM_ARGS=""
if [ "${USE_CUSTOM}" = "true" ]; then
CUSTOM_ARGS="--use-custom-sdpa --use-custom-kv-cache"
fi

QEMBEDDING_ARGS="--qembedding ${QCONFIG}"
if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then
QEMBEDDING_ARGS=""
Contributor: Why no embedding?
fi

echo "::group::Install ExecuTorch and configure MLX build"
${CONDA_RUN} python install_executorch.py > /dev/null
${CONDA_RUN} cmake --preset mlx-release
@@ -522,23 +537,32 @@ jobs:
${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN
OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt)
${CONDA_RUN} pip install transformers "optimum-executorch @ git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}"
if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then
# Gemma 4 requires a newer Transformers build than the CI-wide
# optimum-executorch pin currently brings in. Keep this pinned to the
# locally validated commit instead of floating on Transformers HEAD.
GEMMA4_TRANSFORMERS_COMMIT=61461a7bcb458db7cf6eeea49678b9ab776a7821
${CONDA_RUN} pip install -U "transformers @ git+https://github.com/huggingface/transformers.git@${GEMMA4_TRANSFORMERS_COMMIT}"
fi
echo "::endgroup::"

${CONDA_RUN} pip list

echo "::group::Export ${MODEL_NAME}"
${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.export_llm_hf \
--model-id "${MODEL_ID}" \
${MODEL_REVISION:+--revision "${MODEL_REVISION}"} \
--output /tmp/${MODEL_NAME}.pte \
--qlinear ${QCONFIG} \
--qembedding ${QCONFIG} \
${QEMBEDDING_ARGS} \
${CUSTOM_ARGS}
echo "::endgroup::"

echo "::group::Run ${MODEL_NAME} inference"
OUTPUT=$(${CONDA_RUN} python -m executorch.backends.mlx.examples.llm.run_llm_hf \
--pte /tmp/${MODEL_NAME}.pte \
--model-id "${MODEL_ID}" \
${MODEL_REVISION:+--revision "${MODEL_REVISION}"} \
--prompt "What is the capital of France?" \
--max-new-tokens 50 2>&1)
echo "$OUTPUT"
56 changes: 40 additions & 16 deletions backends/mlx/builder/program_builder.py
@@ -444,26 +444,50 @@ def _make_io_slots(self): # noqa: C901
else:
raise NotImplementedError(f"Support for input {arg} is not implemented")

placeholder_nodes = {
node.name: node for node in self.ep.graph.nodes if node.op == "placeholder"
}

Contributor: I don't follow this change. Why is gemma4 sensitive to this?

Author: I got here by diffing a previously working Gemma 4 .pte against a fresh export. What changed there was the slot assignment for the two rotary constants used by sliding-window vs. full attention. This change just makes that assignment deterministic instead of depending on raw placeholder traversal order. Gemma 4 is where I noticed it because that model exercises both constants in the same path. If you'd prefer, I can drop this.

# Allocate placeholder-backed slots in graph-signature order instead of
# raw FX node traversal order. This keeps lifted constant tids stable
# across equivalent exports, which matters for models like Gemma 4 that
# carry multiple rotary constant placeholders with similar structure.
for name in constant_tensors:
node = placeholder_nodes.get(name)
if node is None or node.users == {}:
continue
self.make_or_get_slot(node, id_space=IdSpace.Constant)

for name in user_inputs:
node = placeholder_nodes.get(name)
if node is None or node.users == {}:
continue
val = node.meta.get("val", None)
if isinstance(val, torch.Tensor) and not val.is_contiguous():
raise ValueError(
f"MLX backend requires contiguous input tensors, "
f"but input '{node.name}' has non-contiguous strides. "
f"shape={list(val.shape)}, stride={list(val.stride())}. "
f"Ensure example inputs passed to torch.export.export() "
f"are contiguous (call .contiguous() on them)."
)
self.make_or_get_slot(node, id_space=IdSpace.Input)

for name in mutable_buffers:
node = placeholder_nodes.get(name)
if node is None or node.users == {}:
continue
self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer)

classified_placeholders = (
set(constant_tensors) | set(user_inputs) | set(mutable_buffers)
)

for node in self.ep.graph.nodes:
if node.op == "placeholder":
if node.users == {}:
continue
if node.name in constant_tensors:
self.make_or_get_slot(node, id_space=IdSpace.Constant)
elif node.name in user_inputs:
val = node.meta.get("val", None)
if isinstance(val, torch.Tensor) and not val.is_contiguous():
raise ValueError(
f"MLX backend requires contiguous input tensors, "
f"but input '{node.name}' has non-contiguous strides. "
f"shape={list(val.shape)}, stride={list(val.stride())}. "
f"Ensure example inputs passed to torch.export.export() "
f"are contiguous (call .contiguous() on them)."
)
self.make_or_get_slot(node, id_space=IdSpace.Input)
elif node.name in mutable_buffers:
self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer)
else:
if node.name not in classified_placeholders:
raise NotImplementedError(
f"Support for placeholder {node.name} is not implemented"
)
35 changes: 34 additions & 1 deletion backends/mlx/examples/llm/README.md
@@ -9,6 +9,7 @@ This example demonstrates how to export and run LLMs using the MLX delegate for
- **KV Cache**: Efficient KV cache implementation for autoregressive generation
- **Custom Ops**: Uses `mlx::custom_sdpa` and `mlx::kv_cache_update` for optimal execution on MLX
- **Pybindings**: Run inference using ExecuTorch Python bindings
- **Gemma 4**: Text-only export and run flow supports processor-backed checkpoints such as `google/gemma-4-E2B-it`

## Requirements

@@ -52,6 +53,25 @@ python -m executorch.backends.mlx.examples.llm.export_llm_hf \
--use-custom-kv-cache \
--qlinear 4w \
--qembedding 4w

# Gemma 4 text-only export
python -m executorch.backends.mlx.examples.llm.export_llm_hf \
--model-id "google/gemma-4-E2B-it" \
--revision "b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf" \
--output gemma4_hf_int4.pte \
--use-custom-sdpa \
--use-custom-kv-cache \
--qlinear 4w
```

Gemma 4 support is currently validated for the text-only path using
`--use-custom-sdpa --use-custom-kv-cache --qlinear 4w`.

Validated with `transformers` commit
`61461a7bcb458db7cf6eeea49678b9ab776a7821`:

```bash
pip install -U "transformers @ git+https://github.com/huggingface/transformers.git@61461a7bcb458db7cf6eeea49678b9ab776a7821"
```

### Options
@@ -81,12 +101,25 @@ python -m executorch.backends.mlx.examples.llm.run_llm_hf \
--prompt "Explain quantum computing in simple terms"
```

Gemma 4 checkpoints may use `AutoProcessor` instead of `AutoTokenizer`; `run_llm_hf` now supports both paths automatically for text-only prompts.
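
As a rough sketch of that fallback (the helper name, exception handling, and fallback order below are illustrative assumptions, not the actual `run_llm_hf` implementation), the logic amounts to:

```python
# Illustrative sketch only; run_llm_hf's real loader may differ in naming,
# fallback order, and error handling.
from typing import Optional

from transformers import AutoProcessor, AutoTokenizer


def load_text_tokenizer(model_id: str, revision: Optional[str] = None):
    """Return an object that can encode/decode text-only prompts."""
    try:
        # Plain text checkpoints (e.g. Llama) resolve directly to a tokenizer.
        return AutoTokenizer.from_pretrained(model_id, revision=revision)
    except (OSError, ValueError):
        # Processor-backed checkpoints such as Gemma 4 ship an AutoProcessor;
        # its wrapped tokenizer is enough for text-only prompts.
        processor = AutoProcessor.from_pretrained(model_id, revision=revision)
        return getattr(processor, "tokenizer", processor)
```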

Validated Gemma 4 run command:

```bash
python -m executorch.backends.mlx.examples.llm.run_llm_hf \
--pte gemma4_hf_int4.pte \
--model-id google/gemma-4-E2B-it \
--revision b4a601102c3d45e2b7b50e2057a6d5ec8ed4adcf \
--prompt "What is the capital of France?" \
--max-new-tokens 50
```

### Options

| Option | Default | Description |
|--------|---------|-------------|
| `--pte` | `llama_hf.pte` | Path to .pte file |
| `--model-id` | `unsloth/Llama-3.2-1B-Instruct` | HuggingFace model ID (for tokenizer) |
| `--model-id` | `unsloth/Llama-3.2-1B-Instruct` | HuggingFace model ID (for tokenizer or processor) |
| `--prompt` | `The quick brown fox` | Input prompt |
| `--max-new-tokens` | `50` | Maximum tokens to generate |

29 changes: 24 additions & 5 deletions backends/mlx/examples/llm/export_llm_hf.py
@@ -50,6 +50,7 @@

def _export_with_optimum(
model_id: str,
revision: Optional[str],
output_path: str,
max_seq_len: int,
dtype: str,
@@ -73,6 +74,7 @@ def _export_with_optimum(
logger.info(f"Loading model using optimum-executorch: {model_id}")
exportable = load_causal_lm_model(
model_id,
revision=revision,
dtype=dtype_str,
max_seq_len=max_seq_len,
)
@@ -124,6 +126,7 @@ def _export_with_optimum(

def _export_with_custom_components(
model_id: str,
revision: Optional[str],
output_path: str,
max_seq_len: int,
dtype: str,
@@ -166,20 +169,21 @@ def _export_with_custom_components(

attn_implementation = "mlx" if use_custom_sdpa else None

# Detect sliding window models (e.g., gemma)
sliding_window = None

logger.info(f"Loading HuggingFace model: {model_id}")
load_kwargs = {
"torch_dtype": torch_dtype,
"low_cpu_mem_usage": True,
}
if revision is not None:
load_kwargs["revision"] = revision
if attn_implementation:
load_kwargs["attn_implementation"] = attn_implementation
model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

# Check if model uses sliding window attention
sliding_window = getattr(model.config, "sliding_window", None)
Contributor: Does this regress gemma3?

Author: I don't expect this to regress Gemma 3. The change just switches the sliding-window lookup to model.config.get_text_config(), which also covers the plain text config case and is needed for Gemma 4, where those attrs live under text_config. I scoped the logic to the same attribute lookup, not a Gemma-4-specific branch. I can also rerun a Gemma 3 smoke test and report back.

Contributor: Yeah, it would be great to try gemma3 as a smoke test. If you are unable to access the version from Google, try the unsloth version unsloth/gemma-3-1b-it (https://github.com/pytorch/executorch/blob/main/.github/workflows/mlx.yml#L469C18-L469C39).

# Check if model uses sliding window attention. Multimodal configs like
# Gemma 4's keep transformer attributes under text_config.
text_config = model.config.get_text_config()
sliding_window = getattr(text_config, "sliding_window", None)
if sliding_window is not None:
logger.info(f"Model has sliding_window={sliding_window}")
# Cap max_seq_len to sliding window size for cache allocation
@@ -188,11 +192,16 @@
else:
effective_cache_len = max_seq_len

# The HF ExecuTorch cache wrappers validate both generation_config.use_cache
# and the text config's use_cache flag before constructing static caches.
model.generation_config.use_cache = True
model.generation_config.cache_implementation = "static"
model.generation_config.cache_config = {
"batch_size": 1,
"max_cache_len": effective_cache_len,
}
text_config = model.config.get_text_config()
text_config.use_cache = True
model.eval()

# Use HybridCache wrapper for sliding window models (stores cache as .cache),
@@ -341,6 +350,7 @@ def _save_program(executorch_program, output_path: str) -> None:

def export_llama_hf(
model_id: str,
revision: Optional[str],
output_path: str,
max_seq_len: int = 1024,
dtype: str = "bf16",
@@ -372,6 +382,7 @@
)
_export_with_custom_components(
model_id=model_id,
revision=revision,
output_path=output_path,
max_seq_len=max_seq_len,
dtype=dtype,
@@ -387,6 +398,7 @@
logger.info("Using optimum-executorch pipeline (no custom components)")
_export_with_optimum(
model_id=model_id,
revision=revision,
output_path=output_path,
max_seq_len=max_seq_len,
dtype=dtype,
@@ -408,6 +420,12 @@ def main():
default="unsloth/Llama-3.2-1B-Instruct",
help="HuggingFace model ID",
)
parser.add_argument(
"--revision",
type=str,
default=None,
help="Optional HuggingFace model revision/commit to pin",
)
parser.add_argument(
"--output",
type=str,
@@ -447,6 +465,7 @@

export_llama_hf(
model_id=args.model_id,
revision=args.revision,
output_path=args.output,
max_seq_len=args.max_seq_len,
dtype=args.dtype,