Skip to content

Commit c19f1fe

Browse files
authored
feat: add native Metal support (#1668)
1 parent 18f852d commit c19f1fe

96 files changed

Lines changed: 20399 additions & 13 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ if (NOT APHRODITE_TARGET_DEVICE STREQUAL "cuda" AND
105105
NOT APHRODITE_TARGET_DEVICE STREQUAL "rocm")
106106
if (APHRODITE_TARGET_DEVICE STREQUAL "cpu")
107107
include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
108+
elseif (APHRODITE_TARGET_DEVICE STREQUAL "metal")
109+
include(${CMAKE_CURRENT_LIST_DIR}/cmake/metal_extension.cmake)
108110
else()
109111
return()
110112
endif()

MANIFEST.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ include requirements/cuda.txt
44
include requirements/rocm.txt
55
include requirements/neuron.txt
66
include requirements/cpu.txt
7+
include requirements/metal.txt
78
include CMakeLists.txt
89
include aphrodite/endpoints/kobold/klite.embd
910

1011
recursive-include kernels *
1112
recursive-include cmake *
13+
recursive-include csrc/metal *
1214
prune csrc/xqa/cubin

aphrodite/benchmarks/perf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,9 @@ def main(args: argparse.Namespace) -> None:
286286

287287
if args.max_num_batched_tokens is None:
288288
args.max_num_batched_tokens = args.chunk_size
289+
max_required_len = args.max_length + args.gen_tokens + 1
290+
if args.max_model_len is None:
291+
args.max_model_len = max_required_len
289292

290293
# Keep request-level timing metrics enabled, but avoid interleaving the
291294
# normal "Request completed" log line with the perf.py-style table.
@@ -300,7 +303,6 @@ def main(args: argparse.Namespace) -> None:
300303
engine_args = EngineArgs.from_cli_args(args)
301304
llm = LLM.from_engine_args(engine_args)
302305

303-
max_required_len = args.max_length + args.gen_tokens + 1
304306
assert llm.llm_engine.model_config.max_model_len >= max_required_len, (
305307
f"Please ensure max_model_len is at least {max_required_len} tokens for this benchmark."
306308
)

aphrodite/config/device.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from aphrodite.config.utils import config
1111
from aphrodite.utils.hashing import safe_hash
1212

13-
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
13+
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu", "metal"]
1414

1515

1616
@config(config=ConfigDict(arbitrary_types_allowed=True))
@@ -65,9 +65,13 @@ def __post_init__(self):
6565
elif isinstance(self.device, torch.device):
6666
self.device_type = self.device.type
6767

68-
# Some device types require processing inputs on CPU
68+
# Some device types require processing inputs on CPU. Metal is an
69+
# MLX-backed platform, but Aphrodite tensors still flow through CPU/MPS
70+
# compatible paths rather than a torch "metal" device.
6971
if self.device_type in ["tpu"]:
7072
self.device = None
73+
elif self.device_type == "metal":
74+
self.device = torch.device("cpu")
7175
else:
7276
# Set device with device type
7377
self.device = torch.device(self.device_type)

aphrodite/envs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ def _get_or_set_default() -> str:
499499
environment_variables: dict[str, Callable[[], Any]] = {
500500
# ================== Installation Time Env Vars ==================
501501
# Target device of Aphrodite, supporting [cuda (by default),
502-
# rocm, cpu]
502+
# rocm, cpu, metal]
503503
"APHRODITE_TARGET_DEVICE": lambda: os.getenv("APHRODITE_TARGET_DEVICE", "cuda").lower(),
504504
# Main CUDA version of Aphrodite. This follows PyTorch but can be overridden.
505505
"APHRODITE_MAIN_CUDA_VERSION": lambda: os.getenv("APHRODITE_MAIN_CUDA_VERSION", "").lower() or "13.0",

aphrodite/metal/__init__.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Aphrodite Metal runtime - high-performance LLM inference on Apple Silicon.
3+
4+
This runtime enables Aphrodite to run on Apple Silicon Macs using MLX as the
5+
primary compute backend, with PyTorch for model loading and interoperability.
6+
"""
7+
8+
import logging
9+
import os
10+
import sys
11+
12+
__version__ = "0.2.0"
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
def _configure_logging() -> None:
    """Align the ``aphrodite.metal`` logger with Aphrodite's logging setup.

    The logger level is taken from ``APHRODITE_LOGGING_LEVEL``. When the
    parent ``aphrodite`` logger already has handlers and the Metal logger has
    none, those handlers are attached to the Metal logger and propagation is
    turned off.
    """
    from aphrodite.envs import APHRODITE_LOGGING_LEVEL

    parent = logging.getLogger("aphrodite")
    child = logging.getLogger("aphrodite.metal")
    child.setLevel(logging.getLevelName(APHRODITE_LOGGING_LEVEL))

    # Nothing to share, or the Metal logger already has its own handlers.
    if child.handlers or not parent.handlers:
        return

    for shared_handler in parent.handlers:
        child.addHandler(shared_handler)
    # The handlers now live on the child as well; with propagation left on,
    # each record would reach the same handlers a second time via the parent.
    child.propagate = False
29+
30+
31+
def _apply_macos_defaults() -> None:
    """Default the worker start method to ``spawn`` on macOS.

    Aphrodite's v1 engine launches a worker process. With the `fork` start
    method, macOS can crash the child if the parent has already imported
    libraries that touched the Objective-C runtime (commonly surfaced as
    `objc_initializeAfterForkError`). `spawn` starts a fresh interpreter and
    so avoids inheriting that partially-initialized state.
    See: https://www.sealiesoftware.com/blog/archive/2017/6/5/Objective-C_and_fork_in_macOS_1013.html

    Does nothing on other platforms, or when
    ``APHRODITE_WORKER_MULTIPROC_METHOD`` is already present in the
    environment (an explicit user choice always wins).
    """
    env_key = "APHRODITE_WORKER_MULTIPROC_METHOD"
    if sys.platform != "darwin" or env_key in os.environ:
        return

    os.environ[env_key] = "spawn"
    logger.debug(
        "macOS detected + Metal plugin active: defaulting APHRODITE_WORKER_MULTIPROC_METHOD "
        "to 'spawn' to avoid Objective-C runtime fork-safety crashes. "
        "Set APHRODITE_WORKER_MULTIPROC_METHOD explicitly to override."
    )
57+
58+
59+
# Lazy imports to avoid loading heavy runtime dependencies on plain import.
60+
def __getattr__(name):
    """Resolve public package attributes on first access.

    Each supported name triggers the import of its defining submodule only
    when requested, so a plain import of this package stays lightweight.

    Raises:
        AttributeError: if ``name`` is not one of the lazily exported names.
    """
    # The registration hook lives in this module; no submodule import needed.
    if name == "register":
        return _register

    if name == "MetalConfig":
        from aphrodite.metal.config import MetalConfig

        return MetalConfig

    if name == "get_config":
        from aphrodite.metal.config import get_config

        return get_config

    if name == "reset_config":
        from aphrodite.metal.config import reset_config

        return reset_config

    if name == "MetalPlatform":
        from aphrodite.metal.platform import MetalPlatform

        return MetalPlatform

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
81+
82+
83+
__all__ = [
84+
"MetalConfig",
85+
"MetalPlatform",
86+
"get_config",
87+
"reset_config",
88+
"register",
89+
]
90+
91+
92+
def _register() -> str | None:
    """Run Metal start-up hooks and report the platform class to load.

    Kept for compatibility with Aphrodite's plugin-style platform loader even
    though Metal is wired as a built-in platform.

    Steps, in order: configure logging, apply macOS process defaults, merge
    the Metal env-var definitions into Aphrodite's registry, apply the
    compatibility patches, then probe platform availability.

    Returns:
        The dotted path of the platform class when ``MetalPlatform`` reports
        itself available, otherwise ``None``.
    """
    _configure_logging()
    _apply_macos_defaults()

    # Merge our env vars into Aphrodite's registry so validate_environ()
    # does not warn about unknown APHRODITE_METAL_* / APHRODITE_MLX_* variables.
    import aphrodite.envs

    from aphrodite.metal.envs import environment_variables as extra_env_vars

    aphrodite.envs.environment_variables.update(extra_env_vars)

    from aphrodite.metal.compat import apply_compat_patches

    apply_compat_patches()

    from aphrodite.metal.platform import MetalPlatform

    # NOTE(review): the returned path names `aphrodite.platforms.metal`, while
    # the class probed here is imported from `aphrodite.metal.platform` —
    # confirm both locations exist and refer to the same platform class.
    return "aphrodite.platforms.metal.MetalPlatform" if MetalPlatform.is_available() else None

0 commit comments

Comments
 (0)