Skip to content

Commit 83a0942

Browse files
committed
up
1 parent 540ad17 commit 83a0942

12 files changed

Lines changed: 14447 additions & 40 deletions

File tree

.github/workflows/mlx.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
${CONDA_RUN} pip list
3939
4040
echo "::group::Build test runners"
41-
${CONDA_RUN} cmake --build cmake-out --target op_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 ))
41+
${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 ))
4242
echo "::endgroup::"
4343
4444
echo "::group::Run op unit tests"
@@ -53,6 +53,14 @@ jobs:
5353
-v
5454
echo "::endgroup::"
5555
56+
echo "::group::Run multi-thread stress test"
57+
${CONDA_RUN} python backends/mlx/test/export_multi_thread_test_model.py /tmp/multi_thread_test_model.pte
58+
ET_TESTING_MODEL_PATH=/tmp/multi_thread_test_model.pte \
59+
ET_TESTING_NUM_THREADS=50 \
60+
ET_PREDICTIONS_PER_THREAD=100 \
61+
./cmake-out/backends/mlx/test/multi_thread_test_runner
62+
echo "::endgroup::"
63+
5664
backend-tester:
5765
strategy:
5866
fail-fast: false

backends/mlx/custom_ops.py

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,259 @@
1313
These ops are used during model export to represent operations that MLX
1414
can execute efficiently but may not have direct PyTorch equivalents.
1515
"""
16+
17+
from typing import Optional
18+
19+
import torch
20+
from torch import Tensor
21+
22+
23+
@torch.library.custom_op("mlx::kv_cache_update", mutates_args=("cache",))
24+
def kv_cache_update(
25+
cache: Tensor, # [B, H, S_max, D] - mutated in place
26+
new_values: Tensor, # [B, H, S, D]
27+
start_pos: int,
28+
ring_size: int = 0,
29+
) -> Tensor:
30+
"""
31+
Mutating KV cache update that modifies cache in place.
32+
33+
This op updates the cache at positions [start_pos, start_pos + S) with
34+
new_values. The cache is mutated in place, similar to llama.update_cache.
35+
36+
Args:
37+
cache: Cache tensor of shape [B, H, S_max, D] (BHSD layout) - mutated
38+
new_values: New values to insert of shape [B, H, S, D]
39+
start_pos: Starting position index for insertion
40+
ring_size: If > 0, treat as ring buffer of this size: write position
41+
is start_pos % ring_size and writes wrap around. If 0 (default),
42+
linear update at start_pos with no wrapping.
43+
44+
Returns:
45+
A dummy tensor (1,) - the return value is not semantically meaningful
46+
but is required for slot management during export. This follows the
47+
same pattern as llama.update_cache.
48+
49+
Note:
50+
The BHSD layout matches what torch SDPA expects, avoiding transposition.
51+
"""
52+
seq_len = new_values.size(2)
53+
54+
if ring_size > 0:
55+
write_pos = start_pos % ring_size
56+
end_pos = write_pos + seq_len
57+
if end_pos <= ring_size:
58+
cache[:, :, write_pos:end_pos, :] = new_values
59+
else:
60+
first_part = ring_size - write_pos
61+
cache[:, :, write_pos:ring_size, :] = new_values[:, :, :first_part, :]
62+
cache[:, :, 0 : seq_len - first_part, :] = new_values[:, :, first_part:, :]
63+
else:
64+
end_pos = start_pos + seq_len
65+
assert end_pos <= cache.size(2), (
66+
f"kv_cache_update: write [{start_pos}, {end_pos}) exceeds "
67+
f"cache size {cache.size(2)}. Use ring_size > 0 for wrapping."
68+
)
69+
cache[:, :, start_pos:end_pos, :] = new_values
70+
71+
return torch.empty((1,), dtype=new_values.dtype, device=new_values.device)
72+
73+
74+
@torch.library.register_fake("mlx::kv_cache_update")
75+
def kv_cache_update_fake(
76+
cache: Tensor,
77+
new_values: Tensor,
78+
start_pos: int,
79+
ring_size: int = 0,
80+
) -> Tensor:
81+
"""Fake implementation for tracing - returns dummy tensor like llama.update_cache."""
82+
return torch.empty((1,), dtype=new_values.dtype, device="meta")
83+
84+
85+
@torch.library.custom_op("mlx::custom_sdpa", mutates_args=())
86+
def mlx_custom_sdpa(
87+
query: Tensor, # [B, num_heads, seq_len, head_dim] - BHSD
88+
key: Tensor, # [B, num_kv_heads, kv_len, head_dim] - BHSD (FULL cache)
89+
value: Tensor, # [B, num_kv_heads, kv_len, head_dim] - BHSD (FULL cache)
90+
start_pos: int, # FIRST position in current batch (0-indexed)
91+
attn_mask: Optional[Tensor] = None,
92+
dropout_p: float = 0.0,
93+
is_causal: bool = False,
94+
scale: Optional[float] = None,
95+
) -> Tensor:
96+
"""
97+
MLX custom SDPA with K/V cache slicing.
98+
99+
This op uses BHSD layout (matching PyTorch SDPA and MLX's SdpaNode).
100+
It receives the FULL K/V cache and slices to [0:stop_pos] before computing
101+
attention, where stop_pos = start_pos + query_seq_len.
102+
103+
The semantics follow executorch's llama.custom_sdpa:
104+
- start_pos: FIRST position of the current query batch
105+
- For prefill with 7 tokens at positions [0,1,2,3,4,5,6]: start_pos=0, stop_pos=7
106+
- For decode at position 10: start_pos=10, stop_pos=11
107+
108+
Args:
109+
query: Query tensor [B, num_heads, seq_len, head_dim]
110+
key: Key cache [B, num_kv_heads, kv_len, head_dim] - FULL cache
111+
value: Value cache [B, num_kv_heads, kv_len, head_dim] - FULL cache
112+
start_pos: FIRST position in current batch (SymInt)
113+
attn_mask: Optional attention mask (only used when is_causal=False)
114+
dropout_p: Dropout probability (default 0.0)
115+
is_causal: Whether to apply causal masking (default False)
116+
scale: Attention scale factor (default 1/sqrt(head_dim))
117+
118+
Returns:
119+
Attention output [B, num_heads, seq_len, head_dim] - BHSD
120+
"""
121+
if scale is None:
122+
scale = query.shape[-1] ** -0.5
123+
124+
# Compute stop_pos = start_pos + query_seq_len
125+
# BHSD layout: seq_len is at dim 2
126+
query_seq_len = query.shape[2]
127+
stop_pos = start_pos + query_seq_len
128+
129+
# Constrain symbolic shapes so torch.export can resolve guards.
130+
# start_pos is data-dependent (from input_pos), so the slice
131+
# stop_pos > kv_len comparison is unresolvable without these hints.
132+
torch._check(start_pos >= 0)
133+
torch._check(stop_pos <= key.shape[2])
134+
135+
# Slice K/V to valid cache entries [0:stop_pos]
136+
key_sliced = key[:, :, :stop_pos, :]
137+
value_sliced = value[:, :, :stop_pos, :]
138+
139+
# Handle GQA: expand K/V heads to match query heads
140+
num_heads = query.shape[1]
141+
num_kv_heads = key.shape[1]
142+
if num_kv_heads != num_heads:
143+
num_groups = num_heads // num_kv_heads
144+
key_sliced = key_sliced.repeat_interleave(num_groups, dim=1)
145+
value_sliced = value_sliced.repeat_interleave(num_groups, dim=1)
146+
147+
# Build explicit lower-right aligned causal mask to match MLX's SdpaNode.
148+
# PyTorch's is_causal=True uses upper-left alignment when Q_len != K_len,
149+
# but for KV-cache inference q[i] is at context position (start_pos + i)
150+
# and should attend to all positions 0..start_pos+i (lower-right).
151+
if is_causal:
152+
L, S = query.shape[2], key_sliced.shape[2]
153+
offset = S - L # equals start_pos
154+
mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(
155+
diagonal=offset
156+
)
157+
attn_mask = torch.where(mask, 0.0, float("-inf")).to(query.dtype)
158+
159+
# Compute SDPA - returns BHSD
160+
return torch.nn.functional.scaled_dot_product_attention(
161+
query,
162+
key_sliced,
163+
value_sliced,
164+
attn_mask=attn_mask,
165+
dropout_p=dropout_p,
166+
is_causal=False,
167+
scale=scale,
168+
)
169+
170+
171+
@torch.library.register_fake("mlx::custom_sdpa")
172+
def mlx_custom_sdpa_fake(
173+
query: Tensor,
174+
key: Tensor,
175+
value: Tensor,
176+
start_pos: int,
177+
attn_mask: Optional[Tensor] = None,
178+
dropout_p: float = 0.0,
179+
is_causal: bool = False,
180+
scale: Optional[float] = None,
181+
) -> Tensor:
182+
"""Fake implementation for tracing - returns BHSD shape (same as query)."""
183+
return query.new_empty(query.shape)
184+
185+
186+
@torch.library.custom_op("mlx::rope", mutates_args=())
187+
def rope(
188+
x: Tensor, # (B, H, T, D)
189+
dims: int,
190+
pos: int, # int, not tensor
191+
traditional: bool = False,
192+
base: float = 500000.0,
193+
scale: float = 1.0,
194+
freqs: Optional[Tensor] = None,
195+
) -> Tensor:
196+
"""
197+
Apply Rotary Position Embedding to a single tensor.
198+
199+
Args:
200+
x: Input tensor of shape (B, H, T, D)
201+
dims: Number of feature dimensions to rotate. If less than D,
202+
only the first `dims` dimensions are rotated and the rest
203+
are left unchanged.
204+
pos: Starting position index (int, not tensor)
205+
traditional: Whether to use traditional RoPE formulation
206+
base: Base for frequency computation
207+
scale: Scale factor for frequencies
208+
freqs: Optional precomputed frequencies
209+
210+
Returns:
211+
Rotated tensor of the same shape
212+
"""
213+
Dh = int(dims)
214+
215+
B, H, T, _ = x.shape
216+
half = Dh // 2
217+
218+
if freqs is None:
219+
# [1, 1, 1, half] to broadcast over B,H,T
220+
i = torch.arange(half, device=x.device, dtype=torch.float32)
221+
inv_freq = (base ** (-2.0 * i / Dh)).view(1, 1, 1, half)
222+
223+
# positions: [1, 1, T, 1]
224+
pos_range = torch.arange(
225+
pos, pos + T, device=x.device, dtype=torch.float32
226+
).view(1, 1, T, 1)
227+
228+
# final angles: [1, 1, T, half]
229+
angles = (pos_range * inv_freq) * float(scale)
230+
else:
231+
# assume freqs is already per-position, just reshape to [1,1,T,half]
232+
angles = freqs.to(torch.float32).view(1, 1, T, half)
233+
234+
cos = angles.cos().to(x.dtype) # [1,1,T,half]
235+
sin = angles.sin().to(x.dtype) # [1,1,T,half]
236+
237+
# Split into rotated and unrotated portions
238+
x_rot = x[..., :Dh]
239+
x_pass = x[..., Dh:]
240+
241+
if traditional:
242+
# Interleaved pairs: (x[0],x[1]), (x[2],x[3]), ...
243+
x1 = x_rot[..., 0::2] # even indices
244+
x2 = x_rot[..., 1::2] # odd indices
245+
xr = x1 * cos - x2 * sin
246+
xi = x1 * sin + x2 * cos
247+
rotated = torch.stack([xr, xi], dim=-1).flatten(-2)
248+
else:
249+
# Split-half: first half paired with second half
250+
x1, x2 = x_rot[..., :half], x_rot[..., half:]
251+
xr = x1 * cos - x2 * sin
252+
xi = x1 * sin + x2 * cos
253+
rotated = torch.cat([xr, xi], dim=-1)
254+
255+
if x_pass.shape[-1] > 0:
256+
return torch.cat([rotated, x_pass], dim=-1)
257+
return rotated
258+
259+
260+
@torch.library.register_fake("mlx::rope")
261+
def rope_fake(
262+
x: Tensor,
263+
dims: int,
264+
pos: int,
265+
traditional: bool = False,
266+
base: float = 500000.0,
267+
scale: float = 1.0,
268+
freqs: Optional[Tensor] = None,
269+
) -> Tensor:
270+
"""Fake implementation for tracing."""
271+
return x.new_empty(x.shape)

backends/mlx/llm/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.

0 commit comments

Comments
 (0)