-
Notifications
You must be signed in to change notification settings - Fork 935
Expand file tree
/
Copy pathpreprocess.py
More file actions
168 lines (144 loc) · 6.49 KB
/
preprocess.py
File metadata and controls
168 lines (144 loc) · 6.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
"""
MLX Backend preprocessing - converts EdgeIR to MLX delegate payload.
This module implements the BackendDetails.preprocess() method which:
1. Takes an ExportedProgram (edge dialect)
2. Builds an MLXGraph using MLXProgramBuilder
3. Serializes to FlatBuffer (no embedded constants - those come via named_data_map)
4. Returns PreprocessResult with the binary and data_store_output for constants
"""
from __future__ import annotations
import hashlib
from typing import ClassVar, final, List
from executorch.backends.mlx._logging import logger
from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder
from executorch.backends.mlx.serialization.mlx_graph_serialize import (
HEADER_LENGTH,
MAGIC,
serialize_mlx_graph,
)
from executorch.exir.backend.backend_details import (
BackendDetails,
CompileSpec,
PreprocessResult,
)
from torch.export.exported_program import ExportedProgram
@final
class MLXBackend(BackendDetails):
    """
    ExecuTorch backend targeting MLX, Apple's GPU compute framework for
    Apple Silicon.

    EdgeIR programs are lowered to a custom bytecode format that the MLX
    C++ runtime executes. Weight constants are not embedded in the delegate
    payload: they are placed in ExecuTorch's named_data_map, so ExecuTorch
    owns the constant data and supplies it to the backend at runtime.
    """

    # Byte ranges of the fields in the serialized payload header
    # (layout defined by mlx_graph_serialize).
    MAGIC_IX: ClassVar[slice] = slice(4, 8)
    DATA_SEGMENT_OFFSET_IX: ClassVar[slice] = slice(8, 16)
    DATA_SEGMENT_SIZE_IX: ClassVar[slice] = slice(16, 24)
    EXPECTED_MAGIC: ClassVar[bytes] = MAGIC
    EXPECTED_LENGTH: ClassVar[int] = HEADER_LENGTH

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """
        Compile an ExportedProgram into an MLX delegate payload.

        Args:
            edge_program: The ExportedProgram in edge dialect to compile.
            compile_specs: List of compilation options.

        Returns:
            PreprocessResult holding the serialized MLX program plus a
            data_store_output carrying the constant tensor data.
        """
        logger.debug("MLXBackend.preprocess() called")
        logger.debug(f"Edge program:\n{edge_program}")
        # Derive a deterministic 4-hex-digit prefix from the program text and
        # use it to namespace named_data keys. In multi-method programs,
        # different methods may lift tensor constants under identical
        # auto-generated names; the prefix keeps those keys from colliding.
        key_prefix = hashlib.sha256(str(edge_program).encode()).hexdigest()[:4]
        graph_builder = MLXProgramBuilder(
            edge_program, named_data_key_prefix=key_prefix
        )
        graph = graph_builder.build()
        # Constant data flows through a NamedDataStore so ExecuTorch owns it.
        store = graph_builder.get_named_data_store()
        logger.debug(f"  named_data_store entries: {len(store.pte_data)}")
        _log_mlx_graph(graph)
        # The serialized payload itself carries no constant bytes.
        payload = serialize_mlx_graph(graph)
        logger.debug(f"MLXBackend.preprocess() complete: {len(payload)} bytes")
        return PreprocessResult(
            processed_bytes=payload,
            data_store_output=store.get_named_data_store_output(),
        )
def _format_tensor_meta(meta) -> str:
"""Format a TensorMeta for display."""
shape_parts = []
for dim in meta.shape:
if dim.value == -1:
# Dynamic dim
if dim.max_value == -1:
shape_parts.append(f"dyn(min={dim.min_value})")
else:
shape_parts.append(f"dyn({dim.min_value}..{dim.max_value})")
else:
shape_parts.append(str(dim.value))
shape_str = f"[{', '.join(shape_parts)}]"
dtype_str = f"dtype={meta.scalar_type}" if meta.scalar_type is not None else ""
dim_order_str = f"dim_order={meta.dim_order}" if meta.dim_order is not None else ""
parts = [shape_str]
if dtype_str:
parts.append(dtype_str)
if dim_order_str:
parts.append(dim_order_str)
return ", ".join(parts)
def _log_mlx_graph(mlx_graph) -> None:  # noqa: C901
    """Dump the contents of an MLXGraph to the DEBUG log for inspection."""

    def _dump_slot_map(title: str, slots) -> None:
        # Shared formatting for the input/output/mutable-buffer slot maps;
        # empty maps are skipped entirely.
        if slots:
            logger.debug(f"  {title} ({len(slots)}):")
            for idx, s in enumerate(slots):
                logger.debug(f"    [{idx}]: {s}")

    logger.debug("MLXGraph:")
    logger.debug(f"  version: {mlx_graph.version}")
    logger.debug(f"  num_constant_tensors: {mlx_graph.num_constant_tensors}")
    logger.debug(f"  num_input_tensors: {mlx_graph.num_input_tensors}")
    logger.debug(f"  num_output_tensors: {mlx_graph.num_output_tensors}")
    logger.debug(
        f"  num_mutable_buffer_tensors: {mlx_graph.num_mutable_buffer_tensors}"
    )
    logger.debug(f"  num_temp_tensors: {mlx_graph.num_temp_tensors}")
    logger.debug(f"  num_values: {mlx_graph.num_values}")
    logger.debug(f"  instruction_chains ({len(mlx_graph.instruction_chains)}):")
    for chain_idx, chain in enumerate(mlx_graph.instruction_chains):
        # Tag the main/init chains so they stand out in the dump.
        marker = (
            " (main)"
            if chain_idx == mlx_graph.main_chain_idx
            else " (init)"
            if chain_idx == mlx_graph.init_chain_idx
            else ""
        )
        logger.debug(
            f"    chain {chain_idx}{marker} ({len(chain.instructions)} instructions):"
        )
        for idx, instruction in enumerate(chain.instructions):
            logger.debug(f"      [{idx}]: {type(instruction.op).__name__}")
    _dump_slot_map("input_map", mlx_graph.input_map)
    _dump_slot_map("output_map", mlx_graph.output_map)
    _dump_slot_map("mutable_buffer_map", mlx_graph.mutable_buffer_map)
    if mlx_graph.named_slots:
        logger.debug(f"  named_slots ({len(mlx_graph.named_slots)}):")
        for named in mlx_graph.named_slots:
            logger.debug(f"    {named.name}: {named.slot}")
    if mlx_graph.tensor_meta:
        logger.debug(f"  tensor_meta ({len(mlx_graph.tensor_meta)}):")
        for idx, tm in enumerate(mlx_graph.tensor_meta):
            logger.debug(f"    t{idx}: {_format_tensor_meta(tm)}")