Skip to content

Commit 643efa6

Browse files
authored
chore: optimize metal backend performance (#1669)
1 parent c19f1fe commit 643efa6

8 files changed

Lines changed: 494 additions & 46 deletions

File tree

aphrodite/metal/config.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,33 @@ def reset_config() -> None:
160160
"""Reset the global config (useful for testing)."""
161161
global _config
162162
_config = None
163+
164+
165+
def should_use_contiguous_kv_fast_path(
166+
config: MetalConfig,
167+
*,
168+
model_config: object | None,
169+
scheduler_config: object,
170+
) -> bool:
171+
"""Return whether Metal should prefer MLX's contiguous KV cache.
172+
173+
Paged attention is still the default for higher concurrency and features
174+
that need block-managed KV state. For dense, low-concurrency text serving,
175+
MLX's contiguous cache is currently much faster on decode and does not
176+
require an environment variable from the user.
177+
"""
178+
return (
179+
"APHRODITE_METAL_USE_PAGED_ATTENTION" not in os.environ
180+
and config.use_paged_attention
181+
and config.is_auto_memory
182+
and not config.turboquant
183+
and model_config is not None
184+
and not getattr(model_config, "is_hybrid", False)
185+
and getattr(scheduler_config, "max_num_seqs") <= 2
186+
)
187+
188+
189+
def enable_contiguous_kv_fast_path(config: MetalConfig) -> None:
190+
"""Switch a Metal config to the contiguous MLX KV cache path."""
191+
config.use_paged_attention = False
192+
config.kv_sharing_fast_prefill = False

aphrodite/metal/metal_kernel_backend/attention_sdpa.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def _pick_kernel_block_size(cache_block_size: int) -> int:
104104

105105

106106
def _build_block_tables(
107-
raw_block_tables: list[list[int]],
107+
ctx: PagedAttentionContext,
108108
cache_block_size: int,
109109
) -> tuple[mx.array, int]:
110110
"""Build kernel-compatible block tables, translating if necessary.
@@ -117,14 +117,23 @@ def _build_block_tables(
117117
Returns:
118118
(block_tables, kernel_block_size)
119119
"""
120+
cached = ctx.block_tables_cache.get(cache_block_size)
121+
if cached is not None:
122+
return cached
123+
124+
raw_block_tables = ctx.block_tables
120125
if not raw_block_tables:
121-
return mx.zeros((0, 0), dtype=mx.int32), cache_block_size
126+
result = (mx.zeros((0, 0), dtype=mx.int32), cache_block_size)
127+
ctx.block_tables_cache[cache_block_size] = result
128+
return result
122129

123130
if cache_block_size in _KERNEL_BLOCK_SIZES:
124131
# Fast path — no translation needed.
125132
max_blocks = max(len(bt) for bt in raw_block_tables)
126133
padded = [bt + [0] * (max_blocks - len(bt)) for bt in raw_block_tables]
127-
return mx.array(padded, dtype=mx.int32), cache_block_size
134+
result = (mx.array(padded, dtype=mx.int32), cache_block_size)
135+
ctx.block_tables_cache[cache_block_size] = result
136+
return result
128137

129138
# Hybrid path — translate large block_size to a kernel-compatible one.
130139
# Vectorized: each vLLM block b → [b*ratio, b*ratio+1, …, b*ratio+ratio-1].
@@ -139,7 +148,9 @@ def _build_block_tables(
139148
expanded = (bt_arr[:, :, None] * ratio + offsets[None, None, :]).reshape(
140149
bt_arr.shape[0], -1
141150
)
142-
return expanded, kernel_bs
151+
result = (expanded, kernel_bs)
152+
ctx.block_tables_cache[cache_block_size] = result
153+
return result
143154

144155

145156
# === Q/K/V preparation (YOCO, K-eq-V, v_norm variants) ===
@@ -424,20 +435,24 @@ def sdpa_forward(
424435
k_3d = mx.contiguous(keys[0].transpose(1, 0, 2).astype(kv_cache.dtype))
425436
v_3d = mx.contiguous(values[0].transpose(1, 0, 2).astype(kv_cache.dtype))
426437

427-
slot_mapping = mx.array(ctx.slot_mapping, dtype=mx.int64)
428-
seq_lens = mx.array(ctx.context_lens, dtype=mx.int32)
429-
cu_seqlens_q = mx.array(ctx.cu_seqlens, dtype=mx.int32)
430-
max_seq_len = max(ctx.context_lens)
438+
slot_mapping = ctx.slot_mapping_mx
439+
if slot_mapping is None:
440+
slot_mapping = mx.array(ctx.slot_mapping, dtype=mx.int64)
441+
seq_lens = ctx.context_lens_mx
442+
if seq_lens is None:
443+
seq_lens = mx.array(ctx.context_lens, dtype=mx.int32)
444+
cu_seqlens_q = ctx.cu_seqlens_mx
445+
if cu_seqlens_q is None:
446+
cu_seqlens_q = mx.array(ctx.cu_seqlens, dtype=mx.int32)
447+
max_seq_len = ctx.max_context_len or max(ctx.context_lens)
431448

432449
# --- Block tables (with hybrid block-size translation) ---
433450
# vLLM may inflate block_size (e.g. 544) to align attention pages with
434451
# mamba pages in hybrid models. The Metal kernel only supports small
435452
# block sizes (8, 16, 32). _build_block_tables handles the translation:
436453
# it expands each vLLM block into multiple kernel blocks and returns the
437454
# kernel-compatible block_size. The cache is reshaped to match (zero-copy).
438-
block_tables, kernel_block_size = _build_block_tables(
439-
ctx.block_tables, kv_cache.block_size
440-
)
455+
block_tables, kernel_block_size = _build_block_tables(ctx, kv_cache.block_size)
441456

442457
if shared_kv is not None:
443458
# YOCO shared layer: the reference layer already scattered the

aphrodite/metal/paged_attention_common.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from dataclasses import dataclass, field
1818
from typing import Any
1919

20+
import mlx.core as mx
2021
from mlx_lm.models.base import create_causal_mask
2122

2223
# ---------------------------------------------------------------------------
@@ -52,6 +53,13 @@ class PagedAttentionContext:
5253
# GDN state pool slot mapping: request batch position → stable slot ID.
5354
# Populated by model_runner for hybrid models; None for non-hybrid.
5455
gdn_slot_mapping: list[int] | None = None
56+
# MLX forms of per-step metadata. These are shared by all layers in the
57+
# same forward pass to avoid rebuilding identical arrays per layer.
58+
slot_mapping_mx: mx.array | None = None
59+
context_lens_mx: mx.array | None = None
60+
cu_seqlens_mx: mx.array | None = None
61+
max_context_len: int = 0
62+
block_tables_cache: dict[int, tuple[mx.array, int]] = field(default_factory=dict)
5563

5664

5765
def set_context(ctx: PagedAttentionContext) -> None:
@@ -200,12 +208,15 @@ def prepare_unified(
200208
context_lens.append(start_pos + num_tokens)
201209
offsets.append(start_pos)
202210

203-
set_context(
204-
PagedAttentionContext(
205-
slot_mapping=slot_mapping,
206-
block_tables=block_tables,
207-
context_lens=context_lens,
208-
cu_seqlens=cu_seqlens,
209-
offsets=offsets,
210-
)
211+
ctx = PagedAttentionContext(
212+
slot_mapping=slot_mapping,
213+
block_tables=block_tables,
214+
context_lens=context_lens,
215+
cu_seqlens=cu_seqlens,
216+
offsets=offsets,
217+
slot_mapping_mx=mx.array(slot_mapping, dtype=mx.int64),
218+
context_lens_mx=mx.array(context_lens, dtype=mx.int32),
219+
cu_seqlens_mx=mx.array(cu_seqlens, dtype=mx.int32),
220+
max_context_len=max(context_lens, default=0),
211221
)
222+
set_context(ctx)

aphrodite/metal/platform.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
import torch
1010
from aphrodite.platforms.interface import DeviceCapability, Platform, PlatformEnum
1111

12-
from aphrodite.metal.config import get_config
12+
from aphrodite.metal.config import (
13+
enable_contiguous_kv_fast_path,
14+
get_config,
15+
should_use_contiguous_kv_fast_path,
16+
)
1317

1418
if TYPE_CHECKING:
1519
from aphrodite.config import AphroditeConfig
@@ -253,6 +257,20 @@ def check_and_update_config(cls, aphrodite_config: "AphroditeConfig") -> None:
253257
f"k_quant={config.k_quant}, v_quant={config.v_quant}"
254258
)
255259

260+
scheduler_config = aphrodite_config.scheduler_config
261+
if should_use_contiguous_kv_fast_path(
262+
config,
263+
model_config=model_config,
264+
scheduler_config=scheduler_config,
265+
):
266+
enable_contiguous_kv_fast_path(config)
267+
logger.info(
268+
"Metal: using contiguous MLX KV cache for low-concurrency "
269+
"dense serving (max_num_seqs=%d). Set "
270+
"APHRODITE_METAL_USE_PAGED_ATTENTION=1 to force paged attention.",
271+
scheduler_config.max_num_seqs,
272+
)
273+
256274
if config.debug:
257275
logger.info(f"Metal config: {config}")
258276

@@ -267,7 +285,6 @@ def check_and_update_config(cls, aphrodite_config: "AphroditeConfig") -> None:
267285
# Disable features not supported on Metal
268286
parallel_config.disable_custom_all_reduce = True
269287

270-
scheduler_config = aphrodite_config.scheduler_config
271288
if getattr(scheduler_config, "enable_chunked_prefill", False):
272289
if config.use_paged_attention:
273290
# The paged path uses a unified varlen Metal kernel that

aphrodite/metal/v1/cache_policy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -646,11 +646,15 @@ def determine_available_memory(self) -> int:
646646
)
647647
return available
648648

649-
available = self._worker._one_sequence_kv_bytes()
649+
one_sequence_bytes = self._worker._one_sequence_kv_bytes()
650+
max_num_seqs = self._worker.model_runner.scheduler_config.max_num_seqs
651+
available = one_sequence_bytes * max_num_seqs
650652
logger.info(
651653
"MLX path: reporting %.2f GB for scheduler admission control "
652-
"(one max-length sequence, max_model_len=%d)",
654+
"(%d max-length sequence%s, max_model_len=%d)",
653655
available / 1e9,
656+
max_num_seqs,
657+
"" if max_num_seqs == 1 else "s",
654658
self._worker.model_config.max_model_len,
655659
)
656660
return available

0 commit comments

Comments
 (0)