
Commit 1867cfc

MLX delegate
1 parent 27b778d

96 files changed: 25,484 additions & 2 deletions


CMakeLists.txt

Lines changed: 16 additions & 0 deletions
@@ -48,6 +48,13 @@
 # TODO Lower to 3.24 when XNNPACK dependency is updated to include
 # https://github.com/google/XNNPACK/commit/c690daa67f883e1b627aadf7684c06797e9a0684
 cmake_minimum_required(VERSION 3.29)
+
+# Set minimum macOS deployment target for Apple platforms.
+# This must be set before the project() call and before any subdirectory processing.
+# MLX requires macOS >= 14.0, so we default to 14.0 if not set.
+if(APPLE AND (NOT CMAKE_OSX_DEPLOYMENT_TARGET OR CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL ""))
+  set(CMAKE_OSX_DEPLOYMENT_TARGET "14.0" CACHE STRING "Minimum macOS version" FORCE)
+endif()
 project(executorch)

 set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
@@ -563,6 +570,11 @@ if(EXECUTORCH_BUILD_MPS)
   list(APPEND _executorch_backends mpsdelegate)
 endif()

+if(EXECUTORCH_BUILD_MLX)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/apple/mlx)
+  list(APPEND _executorch_backends mlxdelegate)
+endif()
+
 if(EXECUTORCH_BUILD_NEURON)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/mediatek)
   list(APPEND _executorch_backends neuron_backend)
@@ -842,6 +854,10 @@ if(EXECUTORCH_BUILD_PYBIND)
   list(APPEND _dep_libs mpsdelegate)
 endif()

+if(EXECUTORCH_BUILD_MLX)
+  list(APPEND _dep_libs mlxdelegate)
+endif()
+
 if(EXECUTORCH_BUILD_OPENVINO)
   list(APPEND _dep_libs openvino_backend)
 endif()

CMakePresets.json

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@
       "inherits": ["common"],
       "cacheVariables": {
         "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/pybind.cmake",
-        "CMAKE_OSX_DEPLOYMENT_TARGET": "12.0"
+        "CMAKE_OSX_DEPLOYMENT_TARGET": "14.0"
       },
       "condition": {
         "type": "inList",

backends/apple/mlx/CMakeLists.txt

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
+#
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+cmake_minimum_required(VERSION 3.19)
+
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+if(NOT CMAKE_CXX_STANDARD)
+  set(CMAKE_CXX_STANDARD 17)
+endif()
+
+# Source root directory for executorch.
+if(NOT EXECUTORCH_ROOT)
+  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
+endif()
+
+include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
+
+set(_common_compile_options -Wno-deprecated-declarations)
+
+# -----------------------------------------------------------------------------
+# FlatBuffer schema generation
+# -----------------------------------------------------------------------------
+
+set(_mlx_schema__include_dir "${CMAKE_BINARY_DIR}/schema/include")
+set(_mlx_schema__srcs
+    ${CMAKE_CURRENT_SOURCE_DIR}/serialization/schema.fbs
+)
+
+# Paths to headers generated from the .fbs files.
+set(_mlx_schema__outputs
+    "${_mlx_schema__include_dir}/executorch/backends/apple/mlx/serialization/schema_generated.h"
+)
+
+# Generate the headers from the .fbs files.
+add_custom_command(
+  OUTPUT ${_mlx_schema__outputs}
+  COMMAND
+    flatc --cpp --cpp-std c++11 --scoped-enums -o
+    "${_mlx_schema__include_dir}/executorch/backends/apple/mlx/serialization"
+    ${_mlx_schema__srcs}
+  WORKING_DIRECTORY ${EXECUTORCH_ROOT}
+  DEPENDS flatc ${_mlx_schema__srcs}
+  COMMENT "Generating mlx_schema headers"
+  VERBATIM
+)
+
+add_library(mlx_schema INTERFACE ${_mlx_schema__outputs})
+set_target_properties(mlx_schema PROPERTIES LINKER_LANGUAGE CXX)
+target_include_directories(
+  mlx_schema
+  INTERFACE
+    $<BUILD_INTERFACE:${_mlx_schema__include_dir}>
+    $<BUILD_INTERFACE:${EXECUTORCH_ROOT}/third-party/flatbuffers/include>
+)
+
+# -----------------------------------------------------------------------------
+# MLX dependency (fetched via FetchContent)
+# -----------------------------------------------------------------------------
+
+include(FetchContent)
+
+# MLX build options - we only need the C++ library.
+set(MLX_BUILD_PYTHON_BINDINGS OFF CACHE BOOL "" FORCE)
+set(MLX_BUILD_TESTS OFF CACHE BOOL "" FORCE)
+set(MLX_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
+set(MLX_BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE)
+set(MLX_BUILD_CPU OFF CACHE BOOL "" FORCE)
+set(MLX_BUILD_METAL ON CACHE BOOL "" FORCE)
+set(MLX_BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE)
+set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
+set(MLX_BUILD_SAFETENSORS OFF CACHE BOOL "" FORCE)
+set(MLX_METAL_JIT OFF CACHE BOOL "" FORCE)
+
+# MLX uses FetchContent for json. When FetchContent_MakeAvailable(json) is
+# called, it will run add_subdirectory on the json source. ExecuTorch already
+# adds json via add_subdirectory(third-party/json) BEFORE this backend is
+# processed.
+#
+# To prevent the conflict, we patch MLX's CMakeLists.txt to wrap the json
+# fetch in a target check. The patch file is in backends/apple/mlx/patches/.
+
+# Ensure CMAKE_OSX_DEPLOYMENT_TARGET is set for MLX's version check.
+# MLX requires macOS >= 14.0. If not set, default to 14.0.
+if(NOT CMAKE_OSX_DEPLOYMENT_TARGET OR CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
+  set(CMAKE_OSX_DEPLOYMENT_TARGET "14.0" CACHE STRING "Minimum macOS version" FORCE)
+endif()
+
+FetchContent_Declare(
+  mlx
+  GIT_REPOSITORY https://github.com/ml-explore/mlx.git
+  GIT_TAG v0.30.3
+  PATCH_COMMAND git apply ${CMAKE_CURRENT_SOURCE_DIR}/patches/mlx_json.patch || true
+)
+
+message(STATUS "Fetching MLX...")
+FetchContent_MakeAvailable(mlx)
+
+# -----------------------------------------------------------------------------
+# MLX Backend library
+# -----------------------------------------------------------------------------
+
+set(_mlx_backend__srcs
+    ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp
+)
+
+add_library(mlxdelegate ${_mlx_backend__srcs})
+
+# Ensure the schema is generated before compiling.
+add_dependencies(mlxdelegate mlx_schema flatc)
+
+target_include_directories(
+  mlxdelegate
+  PRIVATE
+    ${CMAKE_CURRENT_SOURCE_DIR}/runtime
+    ${_mlx_schema__include_dir}
+    ${mlx_SOURCE_DIR}
+)
+
+# Link against MLX and executorch.
+target_link_libraries(
+  mlxdelegate
+  PRIVATE
+    mlx_schema
+    executorch_core
+    mlx
+)
+
+executorch_target_link_options_shared_lib(mlxdelegate)
+target_compile_options(mlxdelegate PUBLIC ${_common_compile_options})
+
+install(
+  TARGETS mlxdelegate mlx_schema
+  EXPORT ExecuTorchTargets
+  DESTINATION ${CMAKE_INSTALL_LIBDIR}
+)
+
+# -----------------------------------------------------------------------------
+# Tests
+# -----------------------------------------------------------------------------
+
+add_subdirectory(test)

backends/apple/mlx/__init__.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+#
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+"""MLX backend for ExecuTorch - executes models on Apple Silicon using MLX."""
+
+# Import custom_ops module to register custom ATen ops before anything else.
+from executorch.backends.apple.mlx import custom_ops as _custom_ops  # noqa: F401
+
+from executorch.backends.apple.mlx.preprocess import MLXBackend
+from executorch.backends.apple.mlx.partitioner import MLXPartitioner
+
+__all__ = ["MLXBackend", "MLXPartitioner"]
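
For orientation, here is a minimal sketch of how these exports would plug into the standard ExecuTorch lowering flow. It is illustrative only: the toy module, its example inputs, and the no-argument MLXPartitioner() construction are assumptions, not part of this commit; to_edge_transform_and_lower is the usual entry point in executorch.exir.

import torch

from executorch.backends.apple.mlx import MLXPartitioner
from executorch.exir import to_edge_transform_and_lower


# Hypothetical toy model; any exportable nn.Module would do here.
class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.silu(x)


# Export, then hand the partitioner to the lowering pipeline so that every
# subgraph it claims is delegated to the MLX backend.
exported = torch.export.export(Toy(), (torch.randn(1, 16),))
edge = to_edge_transform_and_lower(exported, partitioner=[MLXPartitioner()])
et_program = edge.to_executorch()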

backends/apple/mlx/custom_ops.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
+#
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+"""
+Custom MLX operator definitions.
+
+This module defines custom operators that are supported by the MLX backend.
+These ops are used during model export to represent operations that MLX
+can execute efficiently but may not have direct PyTorch equivalents.
+
+The ops are registered using torch.library and include:
+- rms_norm: RMSNorm normalization
+- apply_rope: Rotary Position Embedding application
+"""
+
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor
+
+
+# =============================================================================
+# rms_norm: RMSNorm normalization
+# =============================================================================
+
+
+@torch.library.custom_op("mlx::rms_norm", mutates_args=())
+def rms_norm(x: Tensor, weight: Tensor, eps: float = 1e-5) -> Tensor:
+    """
+    RMSNorm normalization.
+
+    Args:
+        x: Input tensor of shape (..., hidden_dim)
+        weight: Weight tensor of shape (hidden_dim,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        Normalized tensor of the same shape as x
+    """
+    x_f = x.to(torch.float32)
+    var = x_f.pow(2).mean(dim=-1, keepdim=True)
+    y = x_f * torch.rsqrt(var + eps)
+    y = y.to(x.dtype)
+    return y * weight.to(x.dtype)
+
+
+@torch.library.register_fake("mlx::rms_norm")
+def rms_norm_fake(x: Tensor, weight: Tensor, eps: float = 1e-5) -> Tensor:
+    """Fake implementation for tracing."""
+    return x.new_empty(x.shape)
+
+
+# =============================================================================
+# apply_rope: Rotary Position Embedding
+# =============================================================================
+
+
+@torch.library.custom_op("mlx::apply_rope", mutates_args=())
+def apply_rope(
+    q_in: Tensor,  # (B, Hq, T, D)
+    k_in: Tensor,  # (B, Hk, T, D)
+    head_dim: int,
+    pos: int,  # int, not tensor
+    traditional: bool = False,
+    base: float = 500000.0,
+    scale: float = 1.0,
+    freqs: Optional[Tensor] = None,
+) -> Tuple[Tensor, Tensor]:
+    """
+    Apply Rotary Position Embedding to query and key tensors.
+
+    Args:
+        q_in: Query tensor of shape (B, Hq, T, D)
+        k_in: Key tensor of shape (B, Hk, T, D)
+        head_dim: Dimension of each attention head
+        pos: Starting position index (int, not tensor)
+        traditional: Whether to use traditional RoPE formulation
+        base: Base for frequency computation
+        scale: Scale factor for frequencies
+        freqs: Optional precomputed frequencies
+
+    Returns:
+        Tuple of (rotated_q, rotated_k)
+    """
+    Dh = int(head_dim)
+    assert q_in.size(-1) == Dh and k_in.size(-1) == Dh, "head_dim mismatch"
+
+    # unpack as (B, H, T, D)
+    B, Hq, T, _ = q_in.shape
+    B2, Hk, T2, _ = k_in.shape
+    assert B == B2 and T == T2, "RoPE expects q and k to have same B,T"
+    half = Dh // 2
+
+    if freqs is None:
+        # [1, 1, 1, half] to broadcast over B,H,T
+        i = torch.arange(half, device=q_in.device, dtype=torch.float32)
+        inv_freq = (base ** (-2.0 * i / Dh)).view(1, 1, 1, half)
+
+        # positions: [1, 1, T, 1]
+        pos_range = torch.arange(
+            pos, pos + T, device=q_in.device, dtype=torch.float32
+        ).view(1, 1, T, 1)
+
+        # final angles: [1, 1, T, half]
+        angles = (pos_range * inv_freq) * float(scale)
+    else:
+        # assume freqs is already per-position, just reshape to [1,1,T,half]
+        angles = freqs.to(torch.float32).view(1, 1, T, half)
+
+    cos = angles.cos().to(q_in.dtype)  # [1,1,T,half]
+    sin = angles.sin().to(q_in.dtype)  # [1,1,T,half]
+
+    def rot(x: Tensor) -> Tensor:
+        # x: [B, H, T, D]
+        x1, x2 = x[..., :half], x[..., half : 2 * half]
+        xr = x1 * cos - x2 * sin
+        xi = x1 * sin + x2 * cos
+        if 2 * half != Dh:
+            return torch.cat([xr, xi, x[..., 2 * half :]], dim=-1)
+        return torch.cat([xr, xi], dim=-1)
+
+    q_out = rot(q_in)
+    k_out = rot(k_in)
+    return q_out, k_out
+
+
+@torch.library.register_fake("mlx::apply_rope")
+def apply_rope_fake(
+    q_in: Tensor,
+    k_in: Tensor,
+    head_dim: int,
+    pos: int,
+    traditional: bool = False,
+    base: float = 500000.0,
+    scale: float = 1.0,
+    freqs: Optional[Tensor] = None,
+) -> Tuple[Tensor, Tensor]:
+    """Fake implementation for tracing."""
+    return (
+        q_in.new_empty(q_in.shape),
+        k_in.new_empty(k_in.shape),
+    )
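
Because both ops are registered through torch.library.custom_op under the mlx namespace, they are callable as torch.ops.mlx.*. The following sanity check is a sketch with illustrative shapes and tolerances, not a test shipped in this commit:

import torch

import executorch.backends.apple.mlx.custom_ops  # noqa: F401  # registers mlx::*

# rms_norm: compare against the definition, computed inline in float32.
x = torch.randn(2, 8, 64)
w = torch.randn(64)
y = torch.ops.mlx.rms_norm(x, w, 1e-5)
ref = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-5) * w
torch.testing.assert_close(y, ref, rtol=1e-4, atol=1e-4)

# apply_rope preserves shapes; Hq and Hk may differ, only B and T must match.
q = torch.randn(1, 4, 16, 64)  # (B, Hq, T, D)
k = torch.randn(1, 2, 16, 64)  # (B, Hk, T, D)
q_rot, k_rot = torch.ops.mlx.apply_rope(q, k, 64, 0)
assert q_rot.shape == q.shape and k_rot.shape == k.shape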
