Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -212,4 +212,5 @@ images/
**/generated/**
aphrodite/_version.py
shellcheck-stable/
ep_kernels_workspace/
ep_kernels_workspace/
*.metallib
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ if (NOT APHRODITE_TARGET_DEVICE STREQUAL "cuda" AND
NOT APHRODITE_TARGET_DEVICE STREQUAL "rocm")
if (APHRODITE_TARGET_DEVICE STREQUAL "cpu")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
elseif(APHRODITE_TARGET_DEVICE STREQUAL "mps")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/mps_extension.cmake)
else()
return()
endif()
Expand Down
8 changes: 8 additions & 0 deletions aphrodite/attention/ops/paged_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch

from aphrodite import _custom_ops as ops
from aphrodite.platforms import current_platform
from aphrodite.triton_utils import HAS_TRITON

if HAS_TRITON:
Expand Down Expand Up @@ -128,6 +129,13 @@ def forward_decode(

if use_v1:
# Run PagedAttention V1.
if current_platform.is_mps():
# Ensure helper tensors are on the same device as query (e.g., MPS)
device = query.device
if block_tables.device != device:
block_tables = block_tables.to(device)
if seq_lens.device != device:
seq_lens = seq_lens.to(device)
ops.paged_attention_v1(
output,
query,
Expand Down
6 changes: 6 additions & 0 deletions aphrodite/modeling/_custom_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def forward_xpu(self, *args, **kwargs):
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)

def forward_mps(self, *args, **kwargs):
# On MPS, prefer the PyTorch-native implementation unless explicitly overridden.
return self.forward_native(*args, **kwargs)

def forward_cpu(self, *args, **kwargs):
# By default, we assume that CPU ops are compatible with CUDA ops.
return self.forward_cuda(*args, **kwargs)
Expand Down Expand Up @@ -101,6 +105,8 @@ def dispatch_forward(self):
return self.forward_tpu
elif current_platform.is_xpu():
return self.forward_xpu
elif current_platform.is_mps():
return self.forward_mps
elif current_platform.is_neuron():
return self.forward_neuron
elif current_platform.is_out_of_tree():
Expand Down
14 changes: 7 additions & 7 deletions aphrodite/modeling/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class FatreluAndMul(CustomOp):
def __init__(self, threshold: float = 0.):
super().__init__()
self.threshold = threshold
if current_platform.is_cuda_alike():
if current_platform.is_cuda_alike() or current_platform.is_mps():
self.op = torch.ops._C.fatrelu_and_mul
elif current_platform.is_cpu():
self._forward_method = self.forward_native
Expand Down Expand Up @@ -63,7 +63,7 @@ class SiluAndMul(CustomOp):

def __init__(self):
super().__init__()
if current_platform.is_cuda_alike():
if current_platform.is_cuda_alike() or current_platform.is_mps():
self.op = torch.ops._C.silu_and_mul
elif current_platform.is_xpu():
from aphrodite._ipex_ops import ipex_ops
Expand Down Expand Up @@ -111,7 +111,7 @@ class MulAndSilu(CustomOp):

def __init__(self):
super().__init__()
if current_platform.is_cuda_alike():
if current_platform.is_cuda_alike() or current_platform.is_mps():
self.op = torch.ops._C.mul_and_silu
elif current_platform.is_xpu():
from aphrodite._ipex_ops import ipex_ops
Expand Down Expand Up @@ -202,7 +202,7 @@ def __init__(self, approximate: str = "none"):
self.approximate = approximate
if approximate not in ("none", "tanh"):
raise ValueError(f"Unknown approximate mode: {approximate}")
if current_platform.is_cuda_alike() or current_platform.is_cpu():
if current_platform.is_cuda_alike() or current_platform.is_cpu() or current_platform.is_mps():
if approximate == "none":
self.op = torch.ops._C.gelu_and_mul
elif approximate == "tanh":
Expand Down Expand Up @@ -242,7 +242,7 @@ class NewGELU(CustomOp):

def __init__(self):
super().__init__()
if current_platform.is_cuda_alike() or current_platform.is_cpu():
if current_platform.is_cuda_alike() or current_platform.is_cpu() or current_platform.is_mps():
self.op = torch.ops._C.gelu_new
elif current_platform.is_xpu():
from aphrodite._ipex_ops import ipex_ops
Expand All @@ -268,7 +268,7 @@ class FastGELU(CustomOp):

def __init__(self):
super().__init__()
if current_platform.is_cuda_alike() or current_platform.is_cpu():
if current_platform.is_cuda_alike() or current_platform.is_cpu() or current_platform.is_mps():
self.op = torch.ops._C.gelu_fast
elif current_platform.is_xpu():
from aphrodite._ipex_ops import ipex_ops
Expand All @@ -293,7 +293,7 @@ class QuickGELU(CustomOp):
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
def __init__(self):
super().__init__()
if current_platform.is_cuda_alike() or current_platform.is_cpu():
if current_platform.is_cuda_alike() or current_platform.is_cpu() or current_platform.is_mps():
self.op = torch.ops._C.gelu_quick
elif current_platform.is_xpu():
from aphrodite._ipex_ops import ipex_ops
Expand Down
22 changes: 21 additions & 1 deletion aphrodite/modeling/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,12 @@ def forward_native(
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype)
if self.has_weight:
x = x * self.weight
weight = self.weight
if weight.device != x.device:
weight = weight.to(device=x.device)
if weight.dtype != orig_dtype:
weight = weight.to(dtype=orig_dtype)
x = x * weight
if residual is None:
return x
else:
Expand All @@ -169,6 +174,14 @@ def forward_cuda(
else:
return norm_func(x, self.weight.data, self.variance_epsilon)

def forward_mps(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
# Use native computation on MPS for correctness
return self.forward_native(x, residual)

def forward_xpu(
self,
x: torch.Tensor,
Expand Down Expand Up @@ -264,3 +277,10 @@ def forward_cuda(
self.forward_static)
self._is_compiled = True
return self.forward_native(x, residual)

def forward_mps(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
return self.forward_native(x, residual)
40 changes: 34 additions & 6 deletions aphrodite/platforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,22 +147,32 @@ def xpu_platform_plugin() -> Optional[str]:


def cpu_platform_plugin() -> Optional[str]:
is_cpu = False
logger.debug("Checking if CPU platform is available.")
try:
import sys
# On macOS, if MPS is built and available, do NOT activate CPU
if sys.platform.startswith("darwin"):
try:
import torch # type: ignore
if torch.backends.mps.is_built() and torch.backends.mps.is_available():
logger.debug("MPS detected on macOS; CPU platform will not be activated.")
return None
except Exception:
# If torch import fails, fall through to CPU detection
pass

is_cpu = aphrodite_version_matches_substr("cpu")
if is_cpu:
logger.debug("Confirmed CPU platform is available because"
" Aphrodite is built with CPU.")
logger.debug("Confirmed CPU platform is available because Aphrodite is built with CPU.")
if not is_cpu:
import sys
# As a fallback, allow CPU on macOS only if MPS is not available
is_cpu = sys.platform.startswith("darwin")
if is_cpu:
logger.debug("Confirmed CPU platform is available"
" because the machine is MacOS.")
logger.debug("Confirmed CPU platform is available on macOS (no MPS).")

except Exception as e:
logger.debug("CPU platform is not available because: {}", str(e))
return None

return "aphrodite.platforms.cpu.CpuPlatform" if is_cpu else None

Expand Down Expand Up @@ -190,6 +200,23 @@ def neuron_platform_plugin() -> Optional[str]:
is_neuron = tnx_installed or nxd_installed
return "aphrodite.platforms.neuron.NeuronPlatform" if is_neuron else None

def mps_platform_plugin() -> Optional[str]:
is_mps = False
logger.debug("Checking if MPS platform is available.")
try:
import torch
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
is_mps = True
logger.debug("Confirmed MPS platform is available.")
return "aphrodite.platforms.mps.MpsPlatform"
else:
logger.debug("MPS platform is not available.")
return None
except Exception as e:
logger.debug("MPS platform is not available because: {}", str(e))
return None
return "aphrodite.platforms.mps.MpsPlatform" if is_mps else None


builtin_platform_plugins = {
'tpu': tpu_platform_plugin,
Expand All @@ -198,6 +225,7 @@ def neuron_platform_plugin() -> Optional[str]:
'xpu': xpu_platform_plugin,
'cpu': cpu_platform_plugin,
'neuron': neuron_platform_plugin,
'mps': mps_platform_plugin,
}


Expand Down
4 changes: 4 additions & 0 deletions aphrodite/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class PlatformEnum(enum.Enum):
XPU = enum.auto()
CPU = enum.auto()
NEURON = enum.auto()
MPS = enum.auto()
OOT = enum.auto()
UNSPECIFIED = enum.auto()

Expand Down Expand Up @@ -159,6 +160,9 @@ def is_cpu(self) -> bool:
def is_neuron(self) -> bool:
return self._enum == PlatformEnum.NEURON

def is_mps(self) -> bool:
return self._enum == PlatformEnum.MPS

def is_out_of_tree(self) -> bool:
return self._enum == PlatformEnum.OOT

Expand Down
105 changes: 105 additions & 0 deletions aphrodite/platforms/mps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import TYPE_CHECKING, Optional

import torch
from loguru import logger

from .interface import Platform, PlatformEnum, _Backend

if TYPE_CHECKING:
from aphrodite.common.config import AphroditeConfig, ModelConfig


class MpsPlatform(Platform):
_enum = PlatformEnum.MPS
device_name: str = "mps"
device_type: str = "mps"
dispatch_key: str = "MPS"

@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool,
use_mla: bool) -> str:
# V1 engine: use the MPS SDPA backend (routes to torch SDPA on MPS and our paged attention ops)
if use_v1:
return "aphrodite.v1.attention.backends.mps_attn.MpsSDPABackend"
# V0 engine: default to XFormers backend (works on CPU; on MPS it will route via PyTorch ops)
return "aphrodite.attention.backends.xformers.XFormersBackend"

@classmethod
def set_device(cls, device: torch.device) -> None:
pass

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
return "Apple MPS"

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
import psutil
return psutil.virtual_memory().total

@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
from aphrodite.common import envs
if enforce_eager and not envs.APHRODITE_USE_V1:
logger.warning(
"To see benefits of async output processing, enable MPS "
"graph. Since, enforce-eager is enabled, async output "
"processor cannot be used")
return False
return True

@classmethod
def inference_mode(cls):
"""A device-specific wrapper of `torch.inference_mode`."""
return torch.no_grad()

@classmethod
def check_and_update_config(cls, aphrodite_config: "AphroditeConfig") -> None:
"""Check and update the configuration for MPS platform."""
parallel_config = aphrodite_config.parallel_config
if parallel_config.worker_cls == "auto":
# Use V1 MPS worker when V1 engine is enabled; else use V0 worker
from aphrodite.common import envs as _envs
if _envs.APHRODITE_USE_V1:
parallel_config.worker_cls = "aphrodite.v1.worker.mps_worker.Worker"
else:
parallel_config.worker_cls = "aphrodite.worker.worker.Worker"

if parallel_config.tensor_parallel_size > 1:
raise RuntimeError("MPS backend does not support tensor parallelism")
if parallel_config.pipeline_parallel_size > 1:
raise RuntimeError("MPS backend does not support pipeline parallelism")

cache_config = aphrodite_config.cache_config
if cache_config.block_size is None:
cache_config.block_size = 16

compilation_config = aphrodite_config.compilation_config
# Disable compilation on MPS
from aphrodite.common.config import CompilationLevel
compilation_config.level = CompilationLevel.NO_COMPILATION
compilation_config.use_cudagraph = False
compilation_config.full_cuda_graph = False

@classmethod
def get_current_memory_usage(cls, device: torch.device) -> int:
"""Get current memory usage for MPS device."""
return torch.mps.current_allocated_memory()

@classmethod
def is_pin_memory_available(cls) -> bool:
logger.warning("Pin memory is not supported on MPS.")
return False

@classmethod
def supports_v1(cls, model_config: "ModelConfig") -> bool: # type: ignore
# V1 engine is supported on MPS for standard decoder/encoder models
# given we use torch SDPA and our MPS paged attention kernels.
return True

@classmethod
def default_v1(cls, model_config: "ModelConfig") -> bool: # type: ignore
# Enable V1 by default on MPS.
return True
13 changes: 13 additions & 0 deletions aphrodite/v1/attention/backends/cpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from aphrodite.attention.backends.utils import CommonAttentionState
from aphrodite.common.config import AphroditeConfig
from aphrodite.common.logger import log_once
from aphrodite.platforms import current_platform
from aphrodite.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from aphrodite.v1.core.sched.output import SchedulerOutput
Expand Down Expand Up @@ -610,6 +611,14 @@ def forward(
block_tables_arg,
) = decode_meta.get_seq_len_block_table_args(attn_type)

if current_platform.is_mps():
# Ensure auxiliary tensors are on the same device as query (e.g., MPS)
device = query.device
if block_tables_arg is not None and block_tables_arg.device != device:
block_tables_arg = block_tables_arg.to(device)
if seq_lens_arg is not None and seq_lens_arg.device != device:
seq_lens_arg = seq_lens_arg.to(device)

self.paged_attn_impl.forward_decode(
output[attn_metadata.num_prefill_tokens:, :, :],
query[attn_metadata.num_prefill_tokens:, :, :],
Expand Down Expand Up @@ -656,6 +665,10 @@ def _run_sdpa_forward(
else:
seq_lens, _ = attn_metadata.get_seq_lens(attn_type)
attn_masks = [None] * len(seq_lens)
# Ensure masks are placed on the same device as query (e.g., MPS)
attn_masks = [
(m if m is None else m.to(device=query.device)) for m in attn_masks
]
attn_metadata.set_attn_bias(attn_masks, attn_type)

query = query.movedim(0, query.dim() - 2)
Expand Down
Loading
Loading