Skip to content

Commit 58447b2

Browse files
authored
[MLX] Support multiple KV cache sessions, with shared constant data (#20408)
The MLX backend already keeps its mutable state in an execution context separate from its constant data. This PR exposes a way for external callers to configure that separation, and uses it to support serve.py on MLX in the same way as the CUDA backend.
1 parent 4433d1f commit 58447b2

15 files changed

Lines changed: 888 additions & 61 deletions

.github/workflows/mlx.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,11 @@ jobs:
6666
echo "::endgroup::"
6767
6868
echo "::group::Build test runners"
69-
${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 ))
69+
${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner mlx_mutable_state_test -j$(( $(sysctl -n hw.ncpu) - 1 ))
70+
echo "::endgroup::"
71+
72+
echo "::group::Run mutable-state (multi-session) unit test"
73+
./cmake-out/backends/mlx/test/mlx_mutable_state_test
7074
echo "::endgroup::"
7175
7276
echo "::group::Run op unit tests"

backends/mlx/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,10 @@ option(ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION
255255
ON
256256
)
257257

258-
set(_mlx_backend__srcs ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp
259-
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp
258+
set(_mlx_backend__srcs
259+
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp
260+
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp
261+
${CMAKE_CURRENT_SOURCE_DIR}/runtime/mlx_mutable_state.cpp
260262
)
261263

262264
add_library(mlxdelegate ${_mlx_backend__srcs})

backends/mlx/custom_kernel_ops/gated_delta_rule.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,15 @@ def gated_delta_rule(
5353
B, T_len, Hk, Dk = q.shape
5454
Hv, Dv = v.shape[-2:]
5555

56+
# The Metal kernel maps each v-head to its k-head group
57+
# (hk_idx = hv_idx / (Hv / Hk)); mirror that here so the eager reference also
58+
# supports Hk != Hv (GQA) instead of relying on broadcasting, which requires
59+
# Hk == Hv. repeat_interleave on the head dim reproduces that index mapping.
60+
if Hk != Hv:
61+
q = q.repeat_interleave(Hv // Hk, dim=2)
62+
k = k.repeat_interleave(Hv // Hk, dim=2)
63+
Hk = Hv
64+
5665
s = state.clone()
5766

5867
ys = []
@@ -101,6 +110,7 @@ def gated_delta_rule_fake(
101110
IntOrVid,
102111
MetalKernelNode,
103112
MultiplyNode,
113+
RepeatNode,
104114
ScanNode,
105115
SubtractNode,
106116
SumNode,
@@ -450,6 +460,33 @@ def _emit_scan(self, P: MLXProgramBuilder, n: Node) -> Slot:
450460
]
451461
)
452462

463+
# GQA: q/k carry Hk heads but the recurrence state/v have Hv heads. Expand
464+
# q/k to Hv (repeat_interleave on the head axis) so the per-step broadcasts
465+
# match, mirroring the Metal kernel's hk_idx = hv_idx / (Hv / Hk).
466+
Hk = int(self.q_node.meta["val"].shape[-2])
467+
Hv = int(self.v_node.meta["val"].shape[-2])
468+
if Hk != Hv:
469+
rep = IntOrVid.from_literal(Hv // Hk)
470+
_, q_exp = P.make_tmp_slot()
471+
P.emit(
472+
RepeatNode(
473+
x=P.slot_to_tid(q_slot),
474+
out=P.slot_to_tid(q_exp),
475+
repeats=rep,
476+
axis=2,
477+
)
478+
)
479+
_, k_exp = P.make_tmp_slot()
480+
P.emit(
481+
RepeatNode(
482+
x=P.slot_to_tid(k_slot),
483+
out=P.slot_to_tid(k_exp),
484+
repeats=rep,
485+
axis=2,
486+
)
487+
)
488+
q_slot, k_slot = q_exp, k_exp
489+
453490
# Carry needs a writable slot. This is node n's persistent output (the
454491
# mutated state), so it must be a node-owned slot — not a temp slot, whose
455492
# id is reclaimed on tmp_scope exit and would be read as dead by a later

backends/mlx/custom_kernel_ops/test/test_gated_delta_rule.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,8 @@ def forward(
9696
g: torch.Tensor, # [B, T, Hv]
9797
beta: torch.Tensor, # [B, T, Hv]
9898
) -> torch.Tensor:
99-
if self.head_repeat > 1:
100-
q = q.repeat_interleave(self.head_repeat, dim=2)
101-
k = k.repeat_interleave(self.head_repeat, dim=2)
99+
# Pass native Hk (no repeat_interleave): the op itself must handle
100+
# GQA head expansion (kernel via hk_idx mapping, scan/eager internally).
102101
return torch.ops.mlx.gated_delta_rule(
103102
q, k, v, g, beta, self.state, use_custom_kernel=self.use_custom_kernel
104103
)

backends/mlx/runtime/MLXBackend.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "MLXExecutor.h"
1010
#include "MLXInterpreter.h"
1111
#include "MLXLoader.h"
12+
#include "mlx_mutable_state.h"
1213

1314
#include <executorch/runtime/backend/interface.h>
1415
#include <executorch/runtime/core/error.h>
@@ -277,6 +278,12 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
277278
eval(handle->constants.tensors);
278279
}
279280

281+
// Register the handle with the per-session mutable-state manager. This is
282+
// a no-op unless a multi-session owner is active for this load (see
283+
// mlx_mutable_state.h); single-session execution is unaffected.
284+
mutable_state_note_handle(
285+
handle, &handle->program, &handle->mutable_buffers);
286+
280287
} catch (const std::exception& e) {
281288
ET_LOG(Error, "Failed to load MLX program: %s", e.what());
282289
handle->~MLXHandle();
@@ -366,6 +373,14 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
366373
}
367374
}
368375

376+
// Select the active session's mutable buffers (KV cache, recurrent/conv
377+
// state) before running. No-op for single-session handles; weights stay
378+
// shared via ExecutionState::constants.
379+
if (Error rebind_err = mutable_state_rebind_for_execute(h, h->state);
380+
rebind_err != Error::Ok) {
381+
return rebind_err;
382+
}
383+
369384
// Run the MLX program (builds lazy computation graph)
370385
h->interpreter.run(program, h->state, h->stream);
371386

@@ -431,6 +446,7 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
431446
void destroy(DelegateHandle* handle) const override {
432447
std::lock_guard<std::mutex> lock(mlx_global_mutex());
433448
if (handle != nullptr) {
449+
mutable_state_forget_handle(handle);
434450
auto* mlx_handle = static_cast<MLXHandle*>(handle);
435451
mlx_handle->~MLXHandle();
436452
}

0 commit comments

Comments
 (0)