Skip to content

Commit fab5620

Browse files
committed
up
1 parent ede3a6e commit fab5620

3 files changed

Lines changed: 68 additions & 40 deletions

File tree

backends/mlx/CMakeLists.txt

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -212,25 +212,32 @@ set(MLX_METAL_JIT
212212
CACHE BOOL "Use JIT compiled Metal kernels"
213213
)
214214

215-
# Auto-apply json patch so MLX reuses executorch's nlohmann_json instead of
216-
# downloading its own copy via FetchContent. TODO: upstream a patch to MLX
217-
set(MLX_JSON_PATCH "${CMAKE_CURRENT_SOURCE_DIR}/patches/mlx_json.patch")
218-
if(EXISTS "${MLX_JSON_PATCH}" AND EXISTS "${MLX_SOURCE_DIR}")
219-
execute_process(
220-
COMMAND git apply --check ${MLX_JSON_PATCH}
221-
WORKING_DIRECTORY ${MLX_SOURCE_DIR}
222-
RESULT_VARIABLE _mlx_json_patch_check
223-
OUTPUT_QUIET ERROR_QUIET
224-
)
225-
if(_mlx_json_patch_check EQUAL 0)
215+
# Auto-apply patches to MLX submodule. Each patch is applied idempotently: `git
216+
# apply --check` tests whether the patch is still applicable (i.e. not yet
217+
# applied), and only then applies it.
218+
set(_mlx_patches
219+
"${CMAKE_CURRENT_SOURCE_DIR}/patches/mlx_json.patch"
220+
"${CMAKE_CURRENT_SOURCE_DIR}/patches/mlx_metal_device_retain.patch"
221+
)
222+
foreach(_patch IN LISTS _mlx_patches)
223+
if(EXISTS "${_patch}" AND EXISTS "${MLX_SOURCE_DIR}")
224+
get_filename_component(_patch_name "${_patch}" NAME)
226225
execute_process(
227-
COMMAND git apply ${MLX_JSON_PATCH} WORKING_DIRECTORY ${MLX_SOURCE_DIR}
226+
COMMAND git apply --check "${_patch}"
227+
WORKING_DIRECTORY ${MLX_SOURCE_DIR}
228+
RESULT_VARIABLE _patch_check
229+
OUTPUT_QUIET ERROR_QUIET
228230
)
229-
message(STATUS "Applied mlx_json.patch to MLX submodule")
230-
else()
231-
message(STATUS "mlx_json.patch already applied or not applicable")
231+
if(_patch_check EQUAL 0)
232+
execute_process(
233+
COMMAND git apply "${_patch}" WORKING_DIRECTORY ${MLX_SOURCE_DIR}
234+
)
235+
message(STATUS "Applied ${_patch_name} to MLX submodule")
236+
else()
237+
message(STATUS "${_patch_name} already applied or not applicable")
238+
endif()
232239
endif()
233-
endif()
240+
endforeach()
234241

235242
# Add MLX subdirectory
236243
message(STATUS "Adding MLX from submodule: ${MLX_SOURCE_DIR}")

backends/mlx/builder/program_builder.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
)
4545
from executorch.backends.mlx.serialization.mlx_graph_schema import (
4646
FloatOrVid,
47+
IdCopyNode,
4748
Instruction,
4849
InstructionChain,
4950
IntOrVid,
@@ -557,11 +558,51 @@ def check_support_only(self) -> None:
557558
# SymInts and corrupts the shape_env. This method is used for
558559
# ops_to_not_decompose() where we only need support status.
559560

561+
def _emit_buffer_mutation_writebacks(self):
562+
"""Emit copy-back instructions for BUFFER_MUTATION outputs.
563+
564+
When a model mutates a buffer (e.g., via .copy_() or .mul_()),
565+
torch.export functionalizes it: the new value is a computation result,
566+
and the output spec marks it as BUFFER_MUTATION with a target buffer.
567+
568+
This method emits an IdCopyNode for each BUFFER_MUTATION output,
569+
copying the computation result back to the mutable buffer slot so
570+
the updated value persists across execution calls.
571+
"""
572+
from torch.export.graph_signature import InputKind, OutputKind
573+
574+
# Map buffer target name -> input placeholder name
575+
target_to_placeholder = {}
576+
for ispec in self.ep.graph_signature.input_specs:
577+
if ispec.kind == InputKind.BUFFER and ispec.target is not None:
578+
target_to_placeholder[ispec.target] = ispec.arg.name
579+
580+
for ospec in self.ep.graph_signature.output_specs:
581+
if ospec.kind != OutputKind.BUFFER_MUTATION:
582+
continue
583+
584+
result_slot = self.slot_manager.get_slot(ospec.arg.name)
585+
placeholder_name = target_to_placeholder.get(ospec.target)
586+
if result_slot is None or placeholder_name is None:
587+
continue
588+
589+
buffer_slot = self.slot_manager.get_slot(placeholder_name)
590+
if buffer_slot is None or buffer_slot.id_space != IdSpace.MutableBuffer:
591+
continue
592+
593+
self.emit(
594+
IdCopyNode(
595+
x=self.slot_to_tid(result_slot),
596+
out=self.slot_to_tid(buffer_slot),
597+
)
598+
)
599+
560600
def build(self) -> MLXGraph:
561601
if self._mlx_graph is not None:
562602
return self._mlx_graph
563603

564604
self._process_nodes()
605+
self._emit_buffer_mutation_writebacks()
565606
self._verify_build()
566607
self._mlx_graph = self._build_mlx_graph()
567608
return self._mlx_graph

backends/mlx/patterns.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from executorch.backends.mlx.serialization.mlx_graph_schema import (
3737
AddIntNode,
3838
DequantizeNode,
39-
IdCopyNode,
4039
IndexCopyNode,
4140
IntOrVid,
4241
IntOrVidOrTid,
@@ -125,19 +124,8 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot:
125124
)
126125
)
127126

128-
# index_copy returns the updated dst (same buffer, in-place update)
129-
existing_slot = P.slot_manager.get_slot(n)
130-
if existing_slot is not None and existing_slot != dst:
131-
P.emit(
132-
IdCopyNode(
133-
x=P.slot_to_tid(dst),
134-
out=P.slot_to_tid(existing_slot),
135-
)
136-
)
137-
return existing_slot
138-
else:
139-
P.set_slot(n, dst)
140-
return dst
127+
P.set_slot(n, dst)
128+
return dst
141129

142130

143131
@REGISTRY.register_pattern(name="ET_KV_CACHE_UPDATE")
@@ -238,21 +226,13 @@ def __call__(self, P: "MLXProgramBuilder", n: Node) -> Slot:
238226
[self.cache, self.update, self.start_pos]
239227
)
240228

241-
out_slot = P.make_or_get_slot(n)
242-
243229
if self.ring_size > 0:
244230
self._emit_ring_buffer(P, cache_slot, update_slot, start_slot)
245231
else:
246232
self._emit_linear(P, cache_slot, update_slot, start_slot)
247233

248-
P.emit(
249-
IdCopyNode(
250-
x=P.slot_to_tid(cache_slot),
251-
out=P.slot_to_tid(out_slot),
252-
)
253-
)
254-
255-
return out_slot
234+
P.set_slot(n, cache_slot)
235+
return cache_slot
256236

257237
def _emit_linear(self, P: "MLXProgramBuilder", cache_slot, update_slot, start_slot):
258238
"""Emit a single SliceUpdate for linear (non-ring) cache."""

0 commit comments

Comments
 (0)