Note: The MLX delegate is experimental and under active development.
The MLX delegate compiles PyTorch models to run on Apple Silicon GPUs via the MLX framework. It consists of:
- A Python compilation pipeline that converts ExportedPrograms (Edge IR) into a custom FlatBuffer bytecode format.
- A C++ runtime that loads the bytecode and executes it using MLX GPU primitives.
Adding a new op? Jump to How to Add a New Op.
The MLX delegate requires Apple Silicon (M1 or later) and the Metal compiler, which ships with Xcode (not the standalone Command Line Tools).
Check if Metal is available:
```bash
xcrun -sdk macosx --find metal
```

If this prints a path (e.g. `/Applications/Xcode.app/.../metal`), you're set.
If it errors, you either need to install Xcode from the
App Store or
https://developer.apple.com/xcode/, or — if Xcode is already installed but the
command line developer directory points at Command Line Tools — switch it:
```bash
sudo xcode-select -s /Applications/Xcode.app/Contents/Developer
```

The simplest way to get started is to install ExecuTorch with Python bindings. From the repo root:
```bash
python install_executorch.py
```

This builds and installs the executorch pip package with pybindings. On Apple
Silicon, when the Metal compiler is available, the MLX backend is automatically
included. You can then export models in Python using the MLX partitioner and run
them via the ExecuTorch Python API.
To build the C++ runtime with the MLX delegate, use the mlx-release CMake
workflow preset from the repo root:
```bash
cmake --workflow --preset mlx-release
```

This configures and builds a Release build of the ExecuTorch runtime with the
MLX delegate and installs artifacts into cmake-out/. The preset enables the
MLX delegate along with commonly needed extensions (module, data loader, flat
tensor, LLM runner, etc.).
Downstream C++ apps can then find_package(executorch) and link against
mlxdelegate and mlx. See
examples/models/llama/CMakeLists.txt
for a working example.
There is also an mlx-debug preset that enables debug symbols and compiles in
per-op logging support, which is useful during development:
```bash
cmake --workflow --preset mlx-debug
```

The debug build compiles in the logging code, but to actually see per-op output you must also set the environment variable when running the binary:
```bash
ET_MLX_ENABLE_OP_LOGGING=1 ./cmake-out/my_app
```

Set `ET_MLX_DEBUG=1` during AOT (export/compilation) to see detailed debug
logging from the partitioner and preprocessor — including ops-to-not-decompose
lists, graph dumps, per-node support decisions, and serialization details:
```bash
ET_MLX_DEBUG=1 python -m executorch.backends.mlx.examples.llm.export_llm_hf ...
```

```
backends/mlx/
├── serialization/                # Schema + code generation
│   ├── schema.fbs                # ← Source of truth (FlatBuffer schema)
│   ├── generate.py               # Code generator (schema.fbs → everything else)
│   ├── mlx_graph_schema.py       # [GENERATED] Python dataclasses for IR nodes
│   ├── mlx_graph_serialize.py    # Serialization to FlatBuffer binary
│   ├── _generated_serializers.py # [GENERATED] Per-op FlatBuffer builders
│   └── _generated/               # [GENERATED] FlatBuffer Python bindings (flatc)
├── runtime/                      # C++ runtime (loaded at inference time)
│   ├── MLXBackend.cpp            # BackendInterface (init / execute / destroy)
│   ├── MLXLoader.h/.cpp          # [GENERATED] FlatBuffer → C++ structs
│   ├── MLXExecutor.h             # ExecutionState, constant loading, helpers
│   ├── MLXInterpreter.h          # Op dispatch loop + per-op exec_* functions
│   └── schema_generated.h        # [GENERATED] FlatBuffer C++ bindings (flatc)
├── llm/                          # LLM infrastructure (KV cache, attention, etc.)
│   ├── cache.py                  # KV cache implementations (ET + HF static cache)
│   ├── et_attention.py           # ExecuTorch custom SDPA attention
│   ├── hf_attention.py           # HuggingFace custom SDPA attention
│   ├── quantization.py           # TorchAO quantization helpers
│   └── source_transformation.py  # Source transforms for MLX export
├── _generated_inspector.py       # [GENERATED] Inspector utilities for .pte debugging
├── _logging.py                   # Debug logging utilities (ET_MLX_DEBUG)
├── builder/                      # Core build infrastructure
│   ├── op_registry.py            # REGISTRY (op handler registration)
│   ├── op_helpers.py             # Helper utilities for op handlers
│   ├── pattern_matcher.py        # Pattern matching for multi-node fusions
│   ├── program_builder.py        # MLXProgramBuilder
│   └── slot_manager.py           # Tensor/value slot allocation
├── ops.py                        # Op handlers (ATen target → MLX IR node)
├── patterns.py                   # Pattern handlers (multi-node fusions)
├── passes.py                     # Graph passes (RMSNorm fusion, CSE, etc.)
├── pattern_utils.py              # Pattern matching utilities for passes
├── partitioner.py                # Decides which ops to delegate to MLX
├── preprocess.py                 # BackendDetails.preprocess() entry point
├── custom_ops.py                 # Custom torch ops (kv_cache_update, custom_sdpa, rope)
├── pte_inspector.py              # .pte file inspection/debugging tool
├── test/
│   ├── test_ops.py               # Op test definitions (models + configs)
│   ├── test_utils.py             # OpTestCase base class + helpers
│   ├── op_test_runner.cpp        # C++ test runner (loads .pte, runs, compares)
│   └── run_all_tests.py          # End-to-end: export → C++ run → compare
└── examples/
    ├── llm/                      # LLM export + run via HuggingFace
    └── whisper/                  # Whisper export + run
```
Files marked `[GENERATED]` are **not checked in** and are produced by running:
```bash
python backends/mlx/serialization/generate.py
```

The compilation pipeline converts a PyTorch model into a .pte file containing
the MLX delegate payload. The high-level flow:
```
torch.export()                 → ExportedProgram (ATen IR)
to_edge_transform_and_lower()  → Edge IR + partitioning + lowering
```
Within that flow, the MLX-specific steps are:
1. **Partitioning** (`partitioner.py`) — `MLXPartitioner` walks the Edge IR graph and tags nodes that MLX can handle. It uses `MLXProgramBuilder` in a dry-run mode to determine support — so partitioning and compilation use the exact same logic. Unsupported ops fall back to ExecuTorch's portable runtime.
2. **Preprocessing** (`preprocess.py`) — For each partitioned subgraph, `MLXBackend.preprocess()` is called. It builds an `MLXGraph` via `MLXProgramBuilder`, serializes it to FlatBuffer, and returns a `PreprocessResult` with the binary payload and constant data.
3. **Op handling** (`ops.py`, `patterns.py`) — During the build, `MLXProgramBuilder` walks the FX graph node-by-node and dispatches to registered handlers. Single-op handlers live in `ops.py`; multi-node fused patterns (e.g., quantized linear, SDPA, KV cache update) live in `patterns.py`.
4. **Serialization** (`serialization/`) — The `MLXGraph` dataclass tree is serialized to a FlatBuffer binary. See Serialization below.
The complete preprocessing flow:

```
ExportedProgram (subgraph)
  → MLXProgramBuilder.build()       # walks FX graph, calls op handlers
  → MLXGraph                        # Python IR (dataclasses from mlx_graph_schema.py)
  → MLXGraphSerializer.serialize()  # FlatBuffer binary
  → PreprocessResult                # returned to ExecuTorch
```
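The register-a-handler-and-dry-run pattern used by the partitioner can be sketched in plain Python. All names below (`REGISTRY`, `register`, `is_supported`, `UnsupportedOp`) are illustrative stand-ins, not the actual ExecuTorch API:

```python
# Minimal sketch: a handler registry where the partitioner reuses the same
# handler in dry-run mode to decide whether a node is supported.
class UnsupportedOp(Exception):
    """Raised by a handler when it cannot lower the given node."""

REGISTRY = {}

def register(target):
    """Associate a handler function with an op target name."""
    def deco(fn):
        REGISTRY[target] = fn
        return fn
    return deco

@register("aten.addmm.default")
def handle_addmm(node, dry_run=False):
    if len(node["args"]) != 3:
        raise UnsupportedOp("addmm expects 3 args")
    if dry_run:
        return None                      # supported; emit nothing
    return ("ADDMM", tuple(node["args"]))  # pretend-IR instruction

def is_supported(node):
    """Partitioner-side check: run the handler in dry-run mode."""
    handler = REGISTRY.get(node["target"])
    if handler is None:
        return False
    try:
        handler(node, dry_run=True)
        return True
    except UnsupportedOp:
        return False

good = {"target": "aten.addmm.default", "args": ["b", "m1", "m2"]}
bad = {"target": "aten.addmm.default", "args": ["m1", "m2"]}
print(is_supported(good))  # True
print(is_supported(bad))   # False
```

Because partitioning and compilation share one handler, a node is never tagged for MLX unless the compiler can actually lower it.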
This section walks through adding a new op end-to-end, using aten.addmm
as an example.
Add a new table in the "Op nodes" section and add it to the OpNode union:
```
table AddmmNode {
  mat1: Tid (required);
  mat2: Tid (required);
  out: Tid (required);
  bias: Tid;     // optional
  alpha: float;
  beta: float;
}
```

Then add `AddmmNode` to the `union OpNode { ... }` list.
```bash
python backends/mlx/serialization/generate.py
```

This regenerates:

- `mlx_graph_schema.py` — adds the `AddmmNode` Python dataclass
- `_generated_serializers.py` — adds the `_build_AddmmNode` serializer
- `runtime/MLXLoader.h` — adds the `AddmmNode` C++ struct, `OpCode::ADDMM`, and its loader
- `runtime/MLXLoader.cpp` — adds FlatBuffer → `AddmmNode` deserialization
- `runtime/schema_generated.h` — FlatBuffer C++ bindings
Register a handler that converts the ATen op to your new node. Make sure to
import AddmmNode from mlx_graph_schema:
```python
from executorch.backends.mlx.serialization.mlx_graph_schema import AddmmNode

@REGISTRY.register(target=[torch.ops.aten.addmm.default])
def _addmm_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    args = P.args(n)
    kwargs = P.kwargs(n)
    require_args(args, 3, 3, "aten.addmm")
    require_kwargs(kwargs, {"beta", "alpha"}, "aten.addmm")

    bias, mat1, mat2 = args[0], args[1], args[2]
    beta = kwargs.get("beta", 1)
    alpha = kwargs.get("alpha", 1)

    out = P.make_or_get_slot(n)
    P.emit(
        AddmmNode(
            mat1=P.slot_to_tid(mat1),
            mat2=P.slot_to_tid(mat2),
            out=P.slot_to_tid(out),
            bias=P.slot_to_tid(bias),
            alpha=float(alpha),
            beta=float(beta),
        )
    )
    return out
```

Key APIs:
- `P.args(n)` — resolves FX node args to `Slot` objects (tensor/value references)
- `P.make_or_get_slot(n)` — allocates the output tensor slot
- `P.slot_to_tid(slot)` — converts a `Slot` to a `Tid` for the IR node
- `P.emit(node)` — appends the instruction to the graph
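The memoized allocation behind `make_or_get_slot` can be sketched in a few lines of plain Python (the `SlotManager` below is hypothetical; the real one lives in `builder/slot_manager.py`):

```python
# Illustrative sketch of per-node tensor slot allocation, not the real API.
class SlotManager:
    def __init__(self):
        self._next_id = 0
        self._by_node = {}  # FX node -> slot id (memoized)

    def make_or_get(self, node):
        """Return the slot for `node`, allocating a fresh id on first use."""
        if node not in self._by_node:
            self._by_node[node] = self._next_id
            self._next_id += 1
        return self._by_node[node]

sm = SlotManager()
a = sm.make_or_get("node_a")
b = sm.make_or_get("node_b")
print(a, b, sm.make_or_get("node_a"))  # 0 1 0 -- repeat lookups reuse the slot
```

Memoizing per node means a value produced once and consumed by several later ops always resolves to the same slot id.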
Add an exec_* function in the ops namespace:
```cpp
inline void exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) {
  const auto& mat1 = st.const_tensor_ref(n.mat1);
  const auto& mat2 = st.const_tensor_ref(n.mat2);
  array Y = n.bias ? addmm(
                         st.const_tensor_ref(*n.bias),
                         mat1,
                         mat2,
                         /*alpha=*/n.alpha,
                         /*beta=*/n.beta,
                         s)
                   : matmul(mat1, mat2, s);
  st.set_tensor(n.out, std::move(Y));
}
```

Then add the dispatch case in `Interpreter::dispatch()`:

```cpp
case OpCode::ADDMM:
  ops::exec_addmm(std::get<AddmmNode>(instr.node), st, s);
  break;
```

Each test follows a standard pattern:
1. Define an `nn.Module` that uses the op.
2. Define an `OpTestCase` subclass that specifies test configurations.
3. Decorate it with `@register_test` to register it with the test runner.
```python
class AddmmModel(nn.Module):
    """Model that performs addmm: bias + (mat1 @ mat2)."""

    def __init__(self, in_features, out_features, bias=True, alpha=1.0, beta=1.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.randn(out_features))
        else:
            self.bias = None
        self.alpha = alpha
        self.beta = beta

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.bias is not None:
            return torch.addmm(
                self.bias, x, self.weight.t(), beta=self.beta, alpha=self.alpha
            )
        else:
            return torch.mm(x, self.weight.t())


@register_test
class AddmmTest(OpTestCase):
    name = "addmm"
    rtol = 1e-4
    atol = 1e-4

    def __init__(self, batch_size=2, in_features=64, out_features=32,
                 bias=True, alpha=1.0, beta=1.0):
        self.batch_size = batch_size
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.alpha = alpha
        self.beta = beta
        self.name = f"addmm_{in_features}x{out_features}"

    @classmethod
    def get_test_configs(cls):
        return [
            cls(batch_size=2, in_features=64, out_features=32),
            cls(batch_size=2, in_features=64, out_features=32, bias=False),
            cls(batch_size=4, in_features=128, out_features=64),
            cls(batch_size=2, in_features=64, out_features=32, alpha=2.0, beta=0.5),
        ]

    def create_model(self):
        return AddmmModel(
            self.in_features, self.out_features,
            bias=self.bias, alpha=self.alpha, beta=self.beta,
        )

    def create_inputs(self):
        return (torch.randn(self.batch_size, self.in_features),)
```

Tests are end-to-end: export .pte → run via C++ `op_test_runner` → compare
outputs against PyTorch reference. Since adding a new op always involves C++
changes, use --rebuild to recompile the runtime:
```bash
python -m executorch.backends.mlx.test.run_all_tests --rebuild addmm
```

Run all tests in parallel:
```bash
python -m executorch.backends.mlx.test.run_all_tests --rebuild -j4 --clean-after
```

Other useful flags:
| Flag | Purpose |
|---|---|
| `--rebuild` | Rebuild the C++ `op_test_runner` before running |
| `-j N` / `--parallel N` | Run N tests in parallel |
| `--clean-after` | Remove generated test artifacts after running |
| `--list` | List all available test names and exit |
| `-v` / `--verbose` | Verbose output |
Test artifacts are saved to test/op_tests/<test_name>/ (.pte, input/output
.bin files). See test/README.md for full details on test
architecture, prerequisites, and the OpTestCase API.
1. Add a `*Node` table to `schema.fbs` and add it to the `OpNode` union
2. Run `python backends/mlx/serialization/generate.py`
3. Add a `@REGISTRY.register` handler in `ops.py` (and import the new node class)
4. Add an `exec_*` function in `runtime/MLXInterpreter.h`
5. Add a `case OpCode::*` in `Interpreter::dispatch()`
6. Add a test model + `OpTestCase` in `test/test_ops.py`
7. Run `python -m executorch.backends.mlx.test.run_all_tests --rebuild <test_name>`
The serialization system converts a Python MLXGraph dataclass tree into a
FlatBuffer binary that the C++ runtime can load. The source of truth is
schema.fbs — a single FlatBuffer schema file from which all code on both
sides is generated.
The schema defines:
| Concept | FlatBuffer type | Purpose |
|---|---|---|
| `Tid` | struct | Tensor slot index (indexes into the runtime tensor array) |
| `Vid` | struct | Value slot index (for scalar int32/float/bool values) |
| `IntOrVid` | table | A field that is either a literal int64 or a runtime `Vid` reference (for dynamic shapes) |
| `FloatOrVid` | table | Same idea for floats |
| `TidOrVid` | table | Either a tensor or a scalar value |
| Op node tables | table | One per op (e.g. `AddNode`, `SiluNode`, `ReshapeNode`). Each declares its inputs/outputs as `Tid`/`Vid` references and any scalar parameters. |
| `OpNode` | union | Union of all op node tables |
| `Instruction` | table | Wraps an `OpNode` union |
| `MLXGraph` | table (root) | The complete program: slot counts, instruction list, I/O maps, named slots, tensor metadata |
Key design points:
- **No embedded weights.** Constants are stored in ExecuTorch's `named_data_map` and loaded by name at runtime. This enables zero-copy on unified memory.
- **Tensor IDs (`Tid`) are globally ordered:** Constants → Inputs → Outputs → Mutable Buffers → Temps. The runtime uses this ordering for O(1) type lookup.
- **Dynamic shapes** are supported via `IntOrVid` — a shape dimension can be either a literal integer or a reference to a runtime value produced by `sym_size`/`item()` ops.
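The `IntOrVid` idea can be illustrated with a small stand-alone sketch (the dataclass and `resolve` helper below are hypothetical, not the generated code):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class IntOrVid:
    """Either a literal int64 or a reference to a runtime value slot."""
    literal: Optional[int] = None
    vid: Optional[int] = None

def resolve(dim: IntOrVid, value_slots: list) -> int:
    """Resolve a shape dimension against the runtime value slots."""
    if dim.literal is not None:
        return dim.literal
    return value_slots[dim.vid]

# A static dim resolves to its literal; a dynamic dim reads a slot that a
# sym_size/item() op filled in at execution time.
value_slots = [0, 7]  # e.g. slot 1 holds the current sequence length
print(resolve(IntOrVid(literal=4), value_slots))  # 4
print(resolve(IntOrVid(vid=1), value_slots))      # 7
```

This is why a single serialized graph can run at multiple sequence lengths: only the value slot changes between calls, not the instruction stream.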
generate.py parses schema.fbs and generates all boilerplate on both the
Python and C++ sides:
| Generated file | What it contains |
|---|---|
| `mlx_graph_schema.py` | Python `@dataclass` for every op node, `Tid`, `Vid`, `IntOrVid`, etc. |
| `_generated_serializers.py` | `GeneratedOpBuilders` mixin class with `_build_*Node` methods for every op |
| `_generated_inspector.py` | Inspector utilities for debugging .pte files |
| `runtime/MLXLoader.h` | C++ structs for every op node, `OpCode` enum, `NodeVariant`, `Instruction`, `MLXProgram` |
| `runtime/MLXLoader.cpp` | `load_instruction()` and `load_program()` — FlatBuffer → C++ struct conversion |
| `runtime/schema_generated.h` | Standard FlatBuffer C++ bindings (via flatc) |
| `_generated/` directory | Standard FlatBuffer Python bindings (via flatc) |
Running the generator:
```bash
python backends/mlx/serialization/generate.py
```

Use `--skip-flatc` if you only changed op node definitions (not core types) and
want to skip the flatc invocation.
The binary payload embedded in the .pte file has this layout:
```
[Header: 24 bytes]
  4 bytes  padding (zeros)
  4 bytes  magic ("MLX0")
  8 bytes  data_segment_offset (uint64 LE)
  8 bytes  data_segment_size (uint64 LE)
[FlatBuffer payload]
[Padding to 16-byte alignment]
[Data segment (currently unused — constants go via named_data_map)]
```
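As a sanity check, the 24-byte header layout above can be packed and parsed with Python's `struct` module (a sketch for illustration, not the actual serializer code):

```python
import struct

# "<4s4sQQ" = 4 pad bytes, 4-byte magic, two little-endian uint64s (24 bytes).
HEADER_FMT = "<4s4sQQ"

def pack_header(data_offset: int, data_size: int) -> bytes:
    return struct.pack(HEADER_FMT, b"\x00" * 4, b"MLX0", data_offset, data_size)

def parse_header(blob: bytes):
    _pad, magic, offset, size = struct.unpack_from(HEADER_FMT, blob, 0)
    if magic != b"MLX0":
        raise ValueError(f"bad magic: {magic!r}")
    return offset, size

hdr = pack_header(1024, 0)
print(len(hdr))           # 24
print(parse_header(hdr))  # (1024, 0)
```

Fixed little-endian widths keep the loader free of any parsing beyond a single `memcpy`-style read before it hands the rest of the blob to FlatBuffers.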
The MLXGraphSerializer class (in mlx_graph_serialize.py) drives
serialization. It inherits GeneratedOpBuilders for the per-op builders and
adds the root-table construction, I/O maps, tensor metadata, and header.
When ExecuTorch loads a .pte with an MLX delegate blob, MLXBackend::init()
is called:
1. **Parse FlatBuffer** — `loader::load_program()` deserializes the binary into an `MLXProgram` struct (C++ mirrors of the schema).
2. **Load constants** — Iterates `named_slots`, calls `named_data_map->get_data(name)` for each constant tensor, and wraps the buffer as an `mlx::core::array` (zero-copy when possible on unified memory).
3. **Initialize mutable buffers** — Creates zero-filled MLX arrays for persistent state (e.g., KV cache). These live across `execute()` calls.
4. **Bind execution state** — `ExecutionState::bind()` pre-computes tensor ID ranges for O(1) routing.
Each execute() call:
1. **Reset per-execution state** — Inputs/outputs/temps are cleared; mutable buffers and constants are retained.
2. **Bind inputs** — Walk `input_map`, converting each ExecuTorch tensor to an `mlx::core::array` (zero-copy pointer wrap).
3. **Run instructions** — `Interpreter::run()` dispatches each `Instruction` through a `switch` on `OpCode`, calling the corresponding `exec_*` function.
4. **Evaluate** — Call `mlx::core::eval()` on output tensors to trigger lazy GPU computation.
5. **Copy outputs** — Convert MLX arrays back to ExecuTorch tensors via `memcpy`.
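The dispatch in step 3 boils down to a loop that switches on the opcode; a minimal sketch in plain Python (names and the dict-based IR are illustrative only, the real interpreter is C++):

```python
# Each instruction names an opcode and carries operand slot ids; exec
# functions read from and write to a flat slot table.
def exec_add(node, slots):
    slots[node["out"]] = slots[node["a"]] + slots[node["b"]]

def exec_mul(node, slots):
    slots[node["out"]] = slots[node["a"]] * slots[node["b"]]

DISPATCH = {"ADD": exec_add, "MUL": exec_mul}

def run(instructions, slots):
    for instr in instructions:
        DISPATCH[instr["op"]](instr, slots)

# Compute (2 + 3) * 4 through two instructions.
slots = {0: 2, 1: 3, 2: 4, 3: None, 4: None}
program = [
    {"op": "ADD", "a": 0, "b": 1, "out": 3},
    {"op": "MUL", "a": 3, "b": 2, "out": 4},
]
run(program, slots)
print(slots[4])  # 20
```

In the real runtime the slot values are lazy MLX arrays, so the loop only builds the compute graph; nothing runs on the GPU until `eval()` in step 4.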
Tensor slot IDs are assigned in a fixed order during compilation:
```
┌──────────┬──────────┬──────────┬────────────────┬──────────┐
│ Constants│ Inputs   │ Outputs  │ Mutable Buffers│ Temps    │
│ 0..C-1   │ C..I-1   │ I..O-1   │ O..M-1         │ M..T-1   │
└──────────┴──────────┴──────────┴────────────────┴──────────┘
```
The runtime stores constants and mutable buffers in separate containers
(`ConstantData`, `MutableBufferData`). Inputs, outputs, and temps share a flat
`vector<optional<Tensor>>` in `ExecutionState`.
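The range pre-computation that `ExecutionState::bind()` performs reduces to a handful of boundary comparisons; a hypothetical Python sketch of the idea:

```python
# Classify a tensor id into its region using the precomputed boundaries
# C <= I <= O <= M from the diagram above.
def make_classifier(num_constants, num_inputs, num_outputs, num_mutable):
    c = num_constants
    i = c + num_inputs
    o = i + num_outputs
    m = o + num_mutable

    def classify(tid: int) -> str:
        if tid < c:
            return "constant"
        if tid < i:
            return "input"
        if tid < o:
            return "output"
        if tid < m:
            return "mutable"
        return "temp"

    return classify

classify = make_classifier(num_constants=2, num_inputs=1, num_outputs=1, num_mutable=1)
print([classify(t) for t in range(6)])
# ['constant', 'constant', 'input', 'output', 'mutable', 'temp']
```

Because the regions are contiguous, routing a `Tid` to the right container is a few integer comparisons with no per-tensor metadata lookup.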
| File | Role |
|---|---|
| `MLXBackend.cpp` | `init()` / `execute()` / `destroy()` — the ExecuTorch `BackendInterface` |
| `MLXLoader.h/.cpp` | [GENERATED] Deserializes FlatBuffer into `MLXProgram` (C++ structs) |
| `MLXExecutor.h` | `ExecutionState`, `ConstantData`, `MutableBufferData`, constant loading, dtype conversion |
| `MLXInterpreter.h` | The op dispatch `switch` + all `exec_*` implementations |