MLX Delegate for ExecuTorch

Note: The MLX delegate is experimental and under active development.

The MLX delegate compiles PyTorch models to run on Apple Silicon GPUs via the MLX framework. It consists of:

  • A Python compilation pipeline that converts ExportedPrograms (Edge IR) into a custom FlatBuffer bytecode format.
  • A C++ runtime that loads the bytecode and executes it using MLX GPU primitives.

Adding a new op? Jump to How to Add a New Op.

Getting Started

The MLX delegate requires Apple Silicon (M1 or later) and the Metal compiler, which ships with Xcode (not the standalone Command Line Tools).

Check if Metal is available:

xcrun -sdk macosx --find metal

If this prints a path (e.g. /Applications/Xcode.app/.../metal), you're set. If it errors, you either need to install Xcode from the App Store or https://developer.apple.com/xcode/, or — if Xcode is already installed but the command line developer directory points at Command Line Tools — switch it:

sudo xcode-select -s /Applications/Xcode.app/Contents/Developer

Python (pybindings)

The simplest way to get started is to install ExecuTorch with Python bindings. From the repo root:

python install_executorch.py

This builds and installs the executorch pip package with pybindings. On Apple Silicon, when the Metal compiler is available, the MLX backend is automatically included. You can then export models in Python using the MLX partitioner and run them via the ExecuTorch Python API.

C++ (CMake preset)

To build the C++ runtime with the MLX delegate, use the mlx-release CMake workflow preset from the repo root:

cmake --workflow --preset mlx-release

This configures and builds a Release build of the ExecuTorch runtime with the MLX delegate and installs artifacts into cmake-out/. The preset enables the MLX delegate along with commonly needed extensions (module, data loader, flat tensor, LLM runner, etc.).

Downstream C++ apps can then find_package(executorch) and link against mlxdelegate and mlx. See examples/models/llama/CMakeLists.txt for a working example.
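A downstream `CMakeLists.txt` might look roughly like this (a sketch; `my_app` and `main.cpp` are placeholder names, and the target names come from the paragraph above):

```cmake
# Sketch of a downstream consumer; point CMAKE_PREFIX_PATH at cmake-out/.
cmake_minimum_required(VERSION 3.24)
project(my_app LANGUAGES CXX)

find_package(executorch REQUIRED CONFIG)

add_executable(my_app main.cpp)
target_link_libraries(my_app PRIVATE executorch mlxdelegate mlx)
```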

There is also an mlx-debug preset that enables debug symbols and compiles in per-op logging support, which is useful during development:

cmake --workflow --preset mlx-debug

The debug build compiles in the logging code, but to actually see per-op output you must also set the environment variable when running the binary:

ET_MLX_ENABLE_OP_LOGGING=1 ./cmake-out/my_app

Debugging

Set ET_MLX_DEBUG=1 during AOT (export/compilation) to see detailed debug logging from the partitioner and preprocessor — including ops-to-not-decompose lists, graph dumps, per-node support decisions, and serialization details:

ET_MLX_DEBUG=1 python -m executorch.backends.mlx.examples.llm.export_llm_hf ...

Directory Layout

backends/mlx/
├── serialization/              # Schema + code generation
│   ├── schema.fbs              # ← Source of truth (FlatBuffer schema)
│   ├── generate.py             # Code generator (schema.fbs → everything else)
│   ├── mlx_graph_schema.py     # [GENERATED] Python dataclasses for IR nodes
│   ├── mlx_graph_serialize.py  # Serialization to FlatBuffer binary
│   ├── _generated_serializers.py # [GENERATED] Per-op FlatBuffer builders
│   └── _generated/             # [GENERATED] FlatBuffer Python bindings (flatc)
├── runtime/                    # C++ runtime (loaded at inference time)
│   ├── MLXBackend.cpp          # BackendInterface (init / execute / destroy)
│   ├── MLXLoader.h/.cpp        # [GENERATED] FlatBuffer → C++ structs
│   ├── MLXExecutor.h           # ExecutionState, constant loading, helpers
│   ├── MLXInterpreter.h        # Op dispatch loop + per-op exec_* functions
│   └── schema_generated.h      # [GENERATED] FlatBuffer C++ bindings (flatc)
├── llm/                        # LLM infrastructure (KV cache, attention, etc.)
│   ├── cache.py                # KV cache implementations (ET + HF static cache)
│   ├── et_attention.py         # ExecuTorch custom SDPA attention
│   ├── hf_attention.py         # HuggingFace custom SDPA attention
│   ├── quantization.py         # TorchAO quantization helpers
│   └── source_transformation.py # Source transforms for MLX export
├── _generated_inspector.py      # [GENERATED] Inspector utilities for .pte debugging
├── _logging.py                 # Debug logging utilities (ET_MLX_DEBUG)
├── builder/                    # Core build infrastructure
│   ├── op_registry.py          # REGISTRY (op handler registration)
│   ├── op_helpers.py           # Helper utilities for op handlers
│   ├── pattern_matcher.py      # Pattern matching for multi-node fusions
│   ├── program_builder.py      # MLXProgramBuilder
│   └── slot_manager.py         # Tensor/value slot allocation
├── ops.py                      # Op handlers  (ATen target → MLX IR node)
├── patterns.py                 # Pattern handlers (multi-node fusions)
├── passes.py                   # Graph passes (RMSNorm fusion, CSE, etc.)
├── pattern_utils.py            # Pattern matching utilities for passes
├── partitioner.py              # Decides which ops to delegate to MLX
├── preprocess.py               # BackendDetails.preprocess() entry point
├── custom_ops.py               # Custom torch ops (kv_cache_update, custom_sdpa, rope)
├── pte_inspector.py            # .pte file inspection/debugging tool
├── test/
│   ├── test_ops.py             # Op test definitions (models + configs)
│   ├── test_utils.py           # OpTestCase base class + helpers
│   ├── op_test_runner.cpp      # C++ test runner (loads .pte, runs, compares)
│   └── run_all_tests.py        # End-to-end: export → C++ run → compare
└── examples/
    ├── llm/                    # LLM export + run via HuggingFace
    └── whisper/                # Whisper export + run

Files marked [GENERATED] are NOT checked in; they are produced by running:

python backends/mlx/serialization/generate.py

Compilation Pipeline

The compilation pipeline converts a PyTorch model into a .pte file containing the MLX delegate payload. The high-level flow:

torch.export()           →  ExportedProgram (ATen IR)
to_edge_transform_and_lower()  →  Edge IR + partitioning + lowering

Within that flow, the MLX-specific steps are:

  1. Partitioning (partitioner.py) — MLXPartitioner walks the Edge IR graph and tags nodes that MLX can handle. It uses MLXProgramBuilder in a dry-run mode to determine support — so partitioning and compilation use the exact same logic. Unsupported ops fall back to ExecuTorch's portable runtime.

  2. Preprocessing (preprocess.py) — For each partitioned subgraph, MLXBackend.preprocess() is called. It builds an MLXGraph via MLXProgramBuilder, serializes it to FlatBuffer, and returns a PreprocessResult with the binary payload and constant data.

  3. Op handling (ops.py, patterns.py) — During the build, MLXProgramBuilder walks the FX graph node-by-node and dispatches to registered handlers. Single-op handlers live in ops.py; multi-node fused patterns (e.g., quantized linear, SDPA, KV cache update) live in patterns.py.

  4. Serialization (serialization/) — The MLXGraph dataclass tree is serialized to a FlatBuffer binary. See Serialization below.

The complete preprocessing flow:

ExportedProgram (subgraph)
  → MLXProgramBuilder.build()      # walks FX graph, calls op handlers
  → MLXGraph                       # Python IR (dataclasses from mlx_graph_schema.py)
  → MLXGraphSerializer.serialize() # FlatBuffer binary
  → PreprocessResult               # returned to ExecuTorch
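The handler-dispatch step at the heart of this flow can be illustrated with a self-contained toy version of the pattern (the names `Registry` and `ToyBuilder` are invented for this sketch; the real `MLXProgramBuilder` API differs):

```python
# Toy illustration of the registry/dispatch pattern: a registry maps op
# targets to handler functions, and the builder walks nodes in order,
# emitting one IR instruction per node.
class Registry:
    def __init__(self):
        self.handlers = {}

    def register(self, target):
        def deco(fn):
            self.handlers[target] = fn
            return fn
        return deco

REGISTRY = Registry()

@REGISTRY.register("aten.add")
def handle_add(builder, node):
    builder.emit(("AddNode", node["inputs"], node["out"]))

@REGISTRY.register("aten.mul")
def handle_mul(builder, node):
    builder.emit(("MulNode", node["inputs"], node["out"]))

class ToyBuilder:
    def __init__(self):
        self.instructions = []

    def emit(self, instr):
        self.instructions.append(instr)

    def build(self, nodes):
        for node in nodes:
            handler = REGISTRY.handlers.get(node["target"])
            if handler is None:
                # In the real pipeline, the partitioner's dry run would have
                # already rejected this node instead of raising here.
                raise KeyError(f"unsupported op: {node['target']}")
            handler(self, node)
        return self.instructions

program = ToyBuilder().build([
    {"target": "aten.add", "inputs": (0, 1), "out": 2},
    {"target": "aten.mul", "inputs": (2, 3), "out": 4},
])
```

Because the partitioner runs the same builder in dry-run mode, "is this node supported?" and "compile this node" are answered by one code path.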

How to Add a New Op

This section walks through adding a new op end-to-end, using aten.addmm as an example.

Step 1: Add the Node to schema.fbs

Add a new table in the "Op nodes" section and add it to the OpNode union:

table AddmmNode {
    mat1: Tid (required);
    mat2: Tid (required);
    out: Tid (required);
    bias: Tid;  // optional
    alpha: float = 1.0;
    beta: float = 1.0;
}

Then add AddmmNode to the union OpNode { ... } list.

Step 2: Run the Code Generator

python backends/mlx/serialization/generate.py

This regenerates:

  • mlx_graph_schema.py — adds AddmmNode Python dataclass
  • _generated_serializers.py — adds _build_AddmmNode serializer
  • runtime/MLXLoader.h — adds AddmmNode C++ struct, OpCode::ADDMM, loader
  • runtime/MLXLoader.cpp — adds FlatBuffer → AddmmNode deserialization
  • runtime/schema_generated.h — FlatBuffer C++ bindings

Step 3: Add the Python Op Handler (ops.py)

Register a handler that converts the ATen op to your new node. Make sure to import AddmmNode from mlx_graph_schema:

from executorch.backends.mlx.serialization.mlx_graph_schema import AddmmNode

@REGISTRY.register(target=[torch.ops.aten.addmm.default])
def _addmm_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    args = P.args(n)
    kwargs = P.kwargs(n)
    require_args(args, 3, 3, "aten.addmm")
    require_kwargs(kwargs, {"beta", "alpha"}, "aten.addmm")
    bias, mat1, mat2 = args[0], args[1], args[2]

    beta = kwargs.get("beta", 1)
    alpha = kwargs.get("alpha", 1)

    out = P.make_or_get_slot(n)
    P.emit(
        AddmmNode(
            mat1=P.slot_to_tid(mat1),
            mat2=P.slot_to_tid(mat2),
            out=P.slot_to_tid(out),
            bias=P.slot_to_tid(bias),
            alpha=float(alpha),
            beta=float(beta),
        )
    )
    return out

Key APIs:

  • P.args(n) — resolves FX node args to Slot objects (tensor/value references)
  • P.make_or_get_slot(n) — allocates the output tensor slot
  • P.slot_to_tid(slot) — converts a Slot to a Tid for the IR node
  • P.emit(node) — appends the instruction to the graph

Step 4: Add the C++ Op Handler (MLXInterpreter.h)

Add an exec_* function in the ops namespace:

inline void exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) {
    const auto& mat1 = st.const_tensor_ref(n.mat1);
    const auto& mat2 = st.const_tensor_ref(n.mat2);

    array Y = n.bias ? addmm(
                           st.const_tensor_ref(*n.bias),
                           mat1,
                           mat2,
                           /*alpha=*/n.alpha,
                           /*beta=*/n.beta,
                           s)
                     : matmul(mat1, mat2, s);

    st.set_tensor(n.out, std::move(Y));
}

Then add the dispatch case in Interpreter::dispatch():

case OpCode::ADDMM:
    ops::exec_addmm(std::get<AddmmNode>(instr.node), st, s);
    break;

Step 5: Write a Test (test/test_ops.py)

Each test follows a standard pattern:

  1. Define a nn.Module that uses the op.
  2. Define an OpTestCase subclass that specifies test configurations.
  3. Decorate with @register_test to register it with the test runner.

class AddmmModel(nn.Module):
    """Model that performs addmm: bias + (mat1 @ mat2)."""

    def __init__(self, in_features, out_features, bias=True, alpha=1.0, beta=1.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.randn(out_features))
        else:
            self.bias = None
        self.alpha = alpha
        self.beta = beta

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.bias is not None:
            return torch.addmm(
                self.bias, x, self.weight.t(), beta=self.beta, alpha=self.alpha
            )
        else:
            return torch.mm(x, self.weight.t())

@register_test
class AddmmTest(OpTestCase):
    name = "addmm"
    rtol = 1e-4
    atol = 1e-4

    def __init__(self, batch_size=2, in_features=64, out_features=32,
                 bias=True, alpha=1.0, beta=1.0):
        self.batch_size = batch_size
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.alpha = alpha
        self.beta = beta
        # Include every varied parameter in the name so configs that differ
        # only in bias/alpha/beta don't collide.
        suffix = "" if bias else "_nobias"
        if alpha != 1.0 or beta != 1.0:
            suffix += f"_a{alpha}_b{beta}"
        self.name = f"addmm_{in_features}x{out_features}{suffix}"

    @classmethod
    def get_test_configs(cls):
        return [
            cls(batch_size=2, in_features=64, out_features=32),
            cls(batch_size=2, in_features=64, out_features=32, bias=False),
            cls(batch_size=4, in_features=128, out_features=64),
            cls(batch_size=2, in_features=64, out_features=32, alpha=2.0, beta=0.5),
        ]

    def create_model(self):
        return AddmmModel(
            self.in_features, self.out_features,
            bias=self.bias, alpha=self.alpha, beta=self.beta,
        )

    def create_inputs(self):
        return (torch.randn(self.batch_size, self.in_features),)

Step 6: Run Tests

Tests are end-to-end: export .pte → run via C++ op_test_runner → compare outputs against PyTorch reference. Since adding a new op always involves C++ changes, use --rebuild to recompile the runtime:

python -m executorch.backends.mlx.test.run_all_tests --rebuild addmm

Run all tests in parallel:

python -m executorch.backends.mlx.test.run_all_tests --rebuild -j4 --clean-after

Other useful flags:

  • --rebuild — Rebuild the C++ op_test_runner before running
  • -j N / --parallel N — Run N tests in parallel
  • --clean-after — Remove generated test artifacts after running
  • --list — List all available test names and exit
  • -v / --verbose — Verbose output

Test artifacts are saved to test/op_tests/<test_name>/ (.pte, input/output .bin files). See test/README.md for full details on test architecture, prerequisites, and the OpTestCase API.

Checklist

  • Add *Node table to schema.fbs + add to OpNode union
  • Run python backends/mlx/serialization/generate.py
  • Add @REGISTRY.register handler in ops.py (and import the new node class)
  • Add exec_* function in runtime/MLXInterpreter.h
  • Add case OpCode::* in Interpreter::dispatch()
  • Add test model + OpTestCase in test/test_ops.py
  • Run python -m executorch.backends.mlx.test.run_all_tests --rebuild <test_name>

Serialization

Overview

The serialization system converts a Python MLXGraph dataclass tree into a FlatBuffer binary that the C++ runtime can load. The source of truth is schema.fbs — a single FlatBuffer schema file from which all code on both sides is generated.

Schema (schema.fbs)

The schema defines:

  • Tid (struct) — Tensor slot index (indexes into the runtime tensor array)
  • Vid (struct) — Value slot index (for scalar int32/float/bool values)
  • IntOrVid (table) — A field that is either a literal int64 or a runtime Vid reference (for dynamic shapes)
  • FloatOrVid (table) — Same idea for floats
  • TidOrVid (table) — Either a tensor or a scalar value
  • Op node tables — One table per op (e.g. AddNode, SiluNode, ReshapeNode). Each declares its inputs/outputs as Tid/Vid references and any scalar parameters.
  • OpNode (union) — Union of all op node tables
  • Instruction (table) — Wraps an OpNode union
  • MLXGraph (root table) — The complete program: slot counts, instruction list, I/O maps, named slots, tensor metadata

Key design points:

  • No embedded weights. Constants are stored in ExecuTorch's named_data_map and loaded by name at runtime. This enables zero-copy on unified memory.
  • Tensor IDs (Tid) are globally ordered: Constants → Inputs → Outputs → Mutable Buffers → Temps. The runtime uses this ordering for O(1) type lookup.
  • Dynamic shapes are supported via IntOrVid — a shape dimension can be either a literal integer or a reference to a runtime value produced by sym_size / item() ops.
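As an illustration of how an IntOrVid field resolves at runtime (a toy model in Python, not the runtime's actual types):

```python
# Toy model of IntOrVid resolution: a shape dimension is either a literal
# int or a Vid reference into the runtime value table.
from dataclasses import dataclass
from typing import Union

@dataclass(frozen=True)
class Vid:
    """Stand-in for a runtime value-slot reference."""
    idx: int

IntOrVid = Union[int, Vid]

def resolve_dim(dim: IntOrVid, values: dict) -> int:
    # Literal dims pass through; Vid dims are looked up in the value table
    # (populated at runtime by sym_size / item() ops).
    return dim if isinstance(dim, int) else values[dim.idx]

# A [batch, seq_len] shape where seq_len is only known at runtime:
shape = [2, Vid(0)]
resolved = [resolve_dim(d, {0: 128}) for d in shape]
```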

Code Generation (generate.py)

generate.py parses schema.fbs and generates all boilerplate on both the Python and C++ sides:

  • mlx_graph_schema.py — Python @dataclass for every op node, plus Tid, Vid, IntOrVid, etc.
  • _generated_serializers.py — GeneratedOpBuilders mixin class with _build_*Node methods for every op
  • _generated_inspector.py — Inspector utilities for debugging .pte files
  • runtime/MLXLoader.h — C++ structs for every op node, the OpCode enum, NodeVariant, Instruction, MLXProgram
  • runtime/MLXLoader.cpp — load_instruction() and load_program(): FlatBuffer → C++ struct conversion
  • runtime/schema_generated.h — Standard FlatBuffer C++ bindings (via flatc)
  • _generated/ directory — Standard FlatBuffer Python bindings (via flatc)

Running the generator:

python backends/mlx/serialization/generate.py

Use --skip-flatc if you only changed op node definitions (not core types) and want to skip the flatc invocation.

Serialization Format

The binary payload embedded in the .pte file has this layout:

[Header: 24 bytes]
    4 bytes   padding (zeros)
    4 bytes   magic ("MLX0")
    8 bytes   data_segment_offset (uint64 LE)
    8 bytes   data_segment_size   (uint64 LE)
[FlatBuffer payload]
[Padding to 16-byte alignment]
[Data segment (currently unused — constants go via named_data_map)]

The MLXGraphSerializer class (in mlx_graph_serialize.py) drives serialization. It inherits GeneratedOpBuilders for the per-op builders and adds the root-table construction, I/O maps, tensor metadata, and header.
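The header layout above can be packed and parsed in a few lines of Python (a sketch mirroring the documented layout, not the serializer's actual code):

```python
# Sketch: pack/parse the 24-byte MLX delegate header described above.
import struct

MAGIC = b"MLX0"
HEADER_SIZE = 24

def pack_header(data_segment_offset: int, data_segment_size: int) -> bytes:
    # 4 zero padding bytes, 4-byte magic, then two little-endian uint64s.
    return b"\x00" * 4 + MAGIC + struct.pack(
        "<QQ", data_segment_offset, data_segment_size
    )

def parse_header(buf: bytes):
    if len(buf) < HEADER_SIZE or buf[4:8] != MAGIC:
        raise ValueError("not an MLX delegate payload")
    return struct.unpack_from("<QQ", buf, 8)  # (offset, size)
```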


Runtime

Initialization (init)

When ExecuTorch loads a .pte with an MLX delegate blob, MLXBackend::init() is called:

  1. Parse FlatBuffer — loader::load_program() deserializes the binary into an MLXProgram struct (C++ mirrors of the schema).
  2. Load constants — Iterates named_slots, calls named_data_map->get_data(name) for each constant tensor, wraps the buffer as an mlx::core::array (zero-copy when possible on unified memory).
  3. Initialize mutable buffers — Creates zero-filled MLX arrays for persistent state (e.g., KV cache). These live across execute() calls.
  4. Bind execution state — ExecutionState::bind() pre-computes tensor ID ranges for O(1) routing.

Execution (execute)

Each execute() call:

  1. Reset per-execution state (inputs/outputs/temps cleared; mutable buffers and constants are retained).
  2. Bind inputs — Walk input_map, convert each ExecuTorch tensor to an mlx::core::array (zero-copy pointer wrap).
  3. Run instructions — Interpreter::run() dispatches each Instruction through a switch on OpCode, calling the corresponding exec_* function.
  4. Evaluate — Call mlx::core::eval() on output tensors to trigger lazy GPU computation.
  5. Copy outputs — Convert MLX arrays back to ExecuTorch tensors via memcpy.

Tensor ID Layout

Tensor slot IDs are assigned in a fixed order during compilation:

 ┌──────────┬──────────┬──────────┬────────────────┬──────────┐
 │ Constants│  Inputs  │ Outputs  │ Mutable Buffers│  Temps   │
 │  0..C-1  │  C..I-1  │  I..O-1  │   O..M-1       │  M..T-1  │
 └──────────┴──────────┴──────────┴────────────────┴──────────┘

The runtime stores constants and mutable buffers in separate containers (ConstantData, MutableBufferData). Inputs, outputs, and temps share a flat vector<optional<Tensor>> in ExecutionState.
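The constant-time routing can be pictured as a lookup over the cumulative boundaries in the diagram above (an illustrative sketch; ExecutionState stores its precomputed ranges differently):

```python
# Sketch: classify a tensor ID by its position in the fixed slot ordering
# Constants -> Inputs -> Outputs -> Mutable Buffers -> Temps.
from bisect import bisect_right

KINDS = ("constant", "input", "output", "mutable_buffer", "temp")

def slot_kind(tid: int, bounds: list) -> str:
    # bounds = [C, I, O, M, T]: the cumulative end indices from the diagram.
    if not 0 <= tid < bounds[-1]:
        raise IndexError(f"tid {tid} out of range")
    return KINDS[bisect_right(bounds, tid)]

# Example: 2 constants, 2 inputs, 1 output, 2 mutable buffers, 3 temps.
bounds = [2, 4, 5, 7, 10]
```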

Key Runtime Files

  • MLXBackend.cpp — init() / execute() / destroy(): the ExecuTorch BackendInterface
  • MLXLoader.h/.cpp — [GENERATED] Deserializes the FlatBuffer into MLXProgram (C++ structs)
  • MLXExecutor.h — ExecutionState, ConstantData, MutableBufferData, constant loading, dtype conversion
  • MLXInterpreter.h — The op dispatch switch + all exec_* implementations