Skip to content

Commit 493d9ea

Browse files
committed
up
1 parent 6a2d455 commit 493d9ea

4 files changed

Lines changed: 65 additions & 10 deletions

File tree

backends/mlx/builder/program_builder.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,6 @@
2727
from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Union
2828

2929
import torch
30-
3130
from executorch.backends.mlx._logging import logger
3231
from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type
3332
from executorch.backends.mlx.builder.op_registry import (
@@ -132,7 +131,9 @@ class MLXProgramBuilder:
132131

133132
def __init__(self, ep: ExportedProgram, named_data_key_prefix: str = ""):
134133
self.ep: ExportedProgram = ep
135-
self._instrs: List[Instruction] = []
134+
self._chains: List[List[Instruction]] = [[]] # chain 0 = main
135+
self._current_chain: int = 0
136+
self.init_chain_idx: int = -1
136137
self.extra_constants: Dict[str, torch.Tensor] = {}
137138
self.slot_manager = SlotManager()
138139
self.node_info: DefaultDict[Node, NodeInfo] = defaultdict(NodeInfo)
@@ -163,7 +164,13 @@ def _prefix_key(self, name: str) -> str:
163164
return name
164165

165166
def emit(self, op: OpNodeUnion) -> None:
166-
self._instrs.append(Instruction(op=op))
167+
self._chains[self._current_chain].append(Instruction(op=op))
168+
169+
def emit_init(self, op: OpNodeUnion) -> None:
170+
if self.init_chain_idx == -1:
171+
self.init_chain_idx = len(self._chains)
172+
self._chains.append([])
173+
self._chains[self.init_chain_idx].append(Instruction(op=op))
167174

168175
def args(self, node: Node) -> Tuple[Any, ...]:
169176
return self.slot_map(node.args)
@@ -934,9 +941,11 @@ def _build_mlx_graph(self) -> MLXGraph:
934941
num_mutable_buffer_tensors=num_tensors[IdSpace.MutableBuffer],
935942
num_temp_tensors=num_temp_tensors,
936943
num_values=num_values_count,
937-
instruction_chains=[InstructionChain(instructions=self._instrs)],
944+
instruction_chains=[
945+
InstructionChain(instructions=chain) for chain in self._chains
946+
],
938947
main_chain_idx=0,
939-
init_chain_idx=-1,
948+
init_chain_idx=self.init_chain_idx,
940949
input_map=input_map,
941950
output_map=output_map,
942951
mutable_buffer_map=mutable_buffer_map,

backends/mlx/runtime/MLXBackend.cpp

Lines changed: 23 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -219,10 +219,24 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
219219
static_cast<const uint8_t*>(processed->data()), processed->size());
220220

221221
// Validate schema version
222-
if (handle->program.version != "1") {
222+
int schema_version = 1;
223+
if (!handle->program.version.empty()) {
224+
try {
225+
schema_version = std::stoi(handle->program.version);
226+
} catch (...) {
227+
throw std::runtime_error(
228+
"Invalid MLX schema version '" + handle->program.version +
229+
"' (expected integer)");
230+
}
231+
}
232+
constexpr int kMaxSupportedVersion = 1;
233+
if (schema_version > kMaxSupportedVersion) {
223234
throw std::runtime_error(
224-
"Unsupported MLX schema version '" + handle->program.version +
225-
"' (expected '1'). Rebuild the .pte with a matching SDK version.");
235+
"This .pte requires ExecuTorch MLX runtime version " +
236+
std::to_string(schema_version) +
237+
" but this runtime only supports up to version " +
238+
std::to_string(kMaxSupportedVersion) +
239+
". Upgrade ExecuTorch to a newer version.");
226240
}
227241

228242
// Load constants from named_data_map
@@ -251,11 +265,17 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
251265
// SAFETY: The >= 0 check ensures init_chain_idx is non-negative, so the
252266
// static_cast<uint32_t> cannot produce UINT32_MAX from a -1 sentinel.
253267
if (handle->program.init_chain_idx >= 0) {
268+
handle->state.is_init_chain = true;
254269
handle->interpreter.run_chain(
255270
handle->program,
256271
static_cast<uint32_t>(handle->program.init_chain_idx),
257272
handle->state,
258273
handle->stream);
274+
handle->state.is_init_chain = false;
275+
276+
// Evaluate any constants written by the init chain so the first
277+
// execute() doesn't pay the cost of materializing them.
278+
eval(handle->constants.tensors);
259279
}
260280

261281
} catch (const std::exception& e) {

backends/mlx/runtime/MLXExecutor.h

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,13 @@ struct ConstantData {
9797
return tensors[id.idx];
9898
}
9999

100+
inline void set(Tid id, Tensor t) {
101+
if (id.idx >= tensors.size()) {
102+
throw std::out_of_range("ConstantData::set: id out of range");
103+
}
104+
tensors[id.idx] = std::move(t);
105+
}
106+
100107
inline void add(Tensor t) {
101108
tensors.push_back(std::move(t));
102109
}
@@ -153,6 +160,9 @@ struct ExecutionState {
153160
// Non-constant values (SymInt, etc.)
154161
std::vector<std::optional<Value>> values;
155162

163+
// Init chain flag: when true, set_tensor allows writing to constants
164+
bool is_init_chain{false};
165+
156166
// Logging context
157167
size_t current_op_idx{0};
158168
const char* current_op_name{nullptr};
@@ -478,7 +488,15 @@ struct ExecutionState {
478488
throw std::runtime_error("set_tensor: Program not bound");
479489
}
480490
if (id.idx < program->num_constant_tensors) {
481-
throw std::runtime_error("set_tensor: cannot write to constant tensor");
491+
if (!is_init_chain) {
492+
throw std::runtime_error("set_tensor: cannot write to constant tensor");
493+
}
494+
// Init chain can write over constants
495+
if (!constants) {
496+
throw std::runtime_error("set_tensor: constants not bound");
497+
}
498+
const_cast<ConstantData*>(constants)->set(id, std::move(arr));
499+
return;
482500
}
483501
// Route to mutable buffers or per-execution tensors
484502
if (is_mutable_buffer(id)) {

backends/mlx/test/test_utils.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
from pathlib import Path
2424
from typing import Dict, List, Optional, Tuple, Union
2525

26+
import executorch.exir as exir
2627
import numpy as np
2728
import torch
2829

@@ -268,6 +269,7 @@ def export_model_to_pte(
268269
output_path: Union[str, Path],
269270
dynamic_shapes: Optional[Dict] = None,
270271
verbose: bool = False,
272+
edge_compile_config: Optional[exir.EdgeCompileConfig] = None,
271273
) -> None:
272274
"""
273275
Export a PyTorch model to a .pte file using the MLX delegate.
@@ -281,7 +283,6 @@ def export_model_to_pte(
281283
Example: {0: {0: Dim("batch", min=1, max=32)}} for dynamic batch on first input.
282284
verbose: Whether to print the exported program for debugging.
283285
"""
284-
import executorch.exir as exir
285286
from executorch.backends.mlx import MLXPartitioner
286287
from executorch.exir.capture._config import ExecutorchBackendConfig
287288
from torch.export import export
@@ -301,9 +302,11 @@ def export_model_to_pte(
301302
print(exported_program)
302303

303304
# Lower to edge and delegate to MLX
305+
compile_config = edge_compile_config or exir.EdgeCompileConfig()
304306
edge_program = exir.to_edge_transform_and_lower(
305307
exported_program,
306308
partitioner=[MLXPartitioner()],
309+
compile_config=compile_config,
307310
)
308311

309312
# Print edge program if verbose
@@ -865,6 +868,10 @@ def get_dynamic_shapes(self) -> Optional[Dict]:
865868
"""Return dynamic shapes specification for torch.export, or None for static shapes."""
866869
return None
867870

871+
def get_edge_compile_config(self) -> Optional[exir.EdgeCompileConfig]:
872+
"""Return EdgeCompileConfig for export, or None for default."""
873+
return None
874+
868875
def get_test_dir(self) -> Path:
869876
"""Get the directory for this test's files."""
870877
test_dir = Path(__file__).parent / "op_tests" / self.name
@@ -924,6 +931,7 @@ def generate_test_files(self, verbose: bool = False) -> Tuple[Path, Path, Path]:
924931
pte_path,
925932
dynamic_shapes=dynamic_shapes,
926933
verbose=verbose,
934+
edge_compile_config=self.get_edge_compile_config(),
927935
)
928936

929937
# Save test inputs

0 commit comments

Comments
 (0)