Skip to content

Commit f8f571f

Browse files
committed
up
1 parent 0c76afa commit f8f571f

22 files changed

Lines changed: 3495 additions & 3109 deletions

backends/mlx/builder/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
#
8+
9+
# Trigger op/pattern handler registration.
10+
# ops.py and patterns.py use @REGISTRY.register() decorators at import time.
11+
# This must happen after REGISTRY is defined (in op_registry.py).
12+
from executorch.backends.mlx import ops, patterns # noqa: F401
13+
from executorch.backends.mlx.builder.op_registry import REGISTRY # noqa: F401
14+
from executorch.backends.mlx.builder.program_builder import ( # noqa: F401
15+
MLXProgramBuilder,
16+
)

backends/mlx/builder/op_helpers.py

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
#
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
#
8+
9+
from __future__ import annotations
10+
11+
from typing import Dict, Optional, Tuple, TYPE_CHECKING, Union
12+
13+
import torch
14+
from executorch.backends.mlx.builder.slot_manager import Slot
15+
from executorch.exir.scalar_type import ScalarType
16+
from torch.fx.node import Node
17+
18+
if TYPE_CHECKING:
19+
from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder
20+
21+
22+
def get_aten_target(target):
    """
    Return the ATen op wrapped by an Edge-dialect overload, if there is one.

    A target is treated as wrapped when it exposes an ``_op`` attribute and
    the name of its type contains ``"EdgeOpOverload"``; in that case the
    ``_op`` attribute is returned. Any other target is returned unchanged,
    so callers can compare targets uniformly.
    """
    # Attribute check comes first, then the type-name check.
    if not hasattr(target, "_op"):
        return target
    if "EdgeOpOverload" not in type(target).__name__:
        return target
    return target._op
32+
33+
34+
# Lookup table: aten ``<op>_copy.<overload>`` -> aten ``<op>.<overload>``.
# Edge IR uses the _copy variants for certain ops, but pattern matching
# wants to compare against the semantic (non-copy) operation.
# Each (op, overload) pair below contributes one entry, in listed order.
_COPY_TO_NON_COPY = {
    getattr(getattr(torch.ops.aten, f"{op_name}_copy"), overload): getattr(
        getattr(torch.ops.aten, op_name), overload
    )
    for op_name, overload in (
        ("slice", "Tensor"),
        ("transpose", "int"),
        ("view", "default"),
        ("permute", "default"),
        ("unsqueeze", "default"),
        ("squeeze", "dim"),
        ("squeeze", "dims"),
        ("squeeze", "default"),
        ("expand", "default"),
        ("alias", "default"),
    )
}
49+
50+
51+
def get_aten_target_normalized(target):
    """
    Resolve a target to its ATen op, folding _copy variants into the
    corresponding non-copy op.

    The target is first unwrapped with ``get_aten_target``. If the result
    is a key of ``_COPY_TO_NON_COPY`` the mapped op is returned; otherwise
    the unwrapped op is returned as is. Intended for pattern matching,
    where Edge IR carries _copy variants but the semantic op is what
    should be matched.

    Example: aten.transpose_copy.int -> aten.transpose.int
    """
    aten_op = get_aten_target(target)
    if aten_op in _COPY_TO_NON_COPY:
        return _COPY_TO_NON_COPY[aten_op]
    return aten_op
62+
63+
64+
def emit_stop_position(
    P: "MLXProgramBuilder",
    start: "Union[int, Slot]",
    length_tensor: "Slot",
    length_dim: int,
    length_meta: "Optional[torch.Tensor]" = None,
) -> "Union[int, Slot]":
    """
    Compute ``stop = start + length`` for a slice, emitting nodes as needed.

    The length is the size of ``length_tensor`` along ``length_dim``. It is
    treated as static only when ``length_meta`` is given and the matching
    entry of ``length_meta.shape`` has no ``node`` attribute; otherwise it
    is read at runtime through a SymSizeNode.

    Args:
        P: The program builder used to allocate slots and emit nodes.
        start: Start position, either a plain int or a Slot.
        length_tensor: Tensor slot whose dimension supplies the length.
        length_dim: Dimension of ``length_tensor`` holding the length.
        length_meta: Optional tensor metadata used to read a static length.

    Returns:
        An int when both start and length are static; otherwise the Slot
        that receives the result of the emitted AddIntNode.
    """
    from executorch.backends.mlx.serialization.mlx_graph_schema import (
        AddIntNode,
        IntOrVid,
        SymSizeNode,
    )

    start_is_slot = isinstance(start, Slot)

    # A concrete length is only known when metadata exists and the
    # requested dimension is not symbolic.
    static_length = None
    if length_meta is not None:
        dim_size = length_meta.shape[length_dim]
        if not hasattr(dim_size, "node"):
            static_length = int(dim_size)

    if static_length is not None:
        if not start_is_slot:
            # Everything is known at build time: no node is emitted.
            return start + static_length

        # Dynamic start + static length.
        _, stop_slot = P.slot_manager.make_tmp_value_slot()
        P.emit(
            AddIntNode(
                a=P.to_int_or_vid(start),
                b=IntOrVid.from_literal(static_length),
                out=P.slot_to_vid(stop_slot),
            )
        )
        return stop_slot

    # Dynamic length: read the dimension at runtime, then add the start.
    _, length_slot = P.slot_manager.make_tmp_value_slot()
    P.emit(
        SymSizeNode(
            a=P.slot_to_tid(length_tensor),
            dim=length_dim,
            out=P.slot_to_vid(length_slot),
        )
    )

    _, stop_slot = P.slot_manager.make_tmp_value_slot()
    start_operand = (
        P.to_int_or_vid(start) if start_is_slot else IntOrVid.from_literal(int(start))
    )
    P.emit(
        AddIntNode(
            a=start_operand,
            b=IntOrVid.from_vid(P.slot_to_vid(length_slot)),
            out=P.slot_to_vid(stop_slot),
        )
    )
    return stop_slot
143+
144+
145+
def emit_lifted_constant(P: "MLXProgramBuilder", value, dtype: torch.dtype) -> Slot:
    """Lift a scalar to a 0-D tensor.

    Concrete scalars (int/float/bool) are registered through
    ``P.make_or_get_constant`` under a name built from both the value and
    the requested dtype. Any other value (e.g. a SymInt Slot) is
    materialized at runtime by emitting a FullNode with an empty shape.

    Args:
        P: The program builder.
        value: A Python scalar, or a dynamic value accepted by
            ``P.to_float_or_vid``.
        dtype: Torch dtype of the resulting 0-D tensor.

    Returns:
        The slot holding the 0-D tensor.
    """
    if isinstance(value, (int, float, bool)):
        # The dtype is part of the name so that the same scalar lifted with
        # two dtypes (e.g. 1 as float32 and 1 as int64) yields two distinct
        # constants rather than colliding on one name.
        # NOTE(review): assumes make_or_get_constant deduplicates by name -
        # confirm against the program builder.
        return P.make_or_get_constant(
            f"_scalar_{value}_{dtype}", torch.tensor(value, dtype=dtype)  # 0-D
        )

    from executorch.backends.mlx.serialization.mlx_graph_schema import FullNode

    _, slot = P.make_tmp_slot()
    P.emit(
        FullNode(
            shape=[],
            v=P.to_float_or_vid(value),
            scalar_type=torch_dtype_to_scalar_type(dtype),
            out=P.slot_to_tid(slot),
        )
    )
    return slot
169+
170+
171+
def to_mlx_qparams(
172+
qdata: torch.Tensor,
173+
scale: torch.Tensor,
174+
zero_point: torch.Tensor,
175+
bits: int,
176+
compute_biases: bool = True,
177+
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
178+
"""
179+
Convert TorchAO quantization params to MLX format.
180+
181+
TorchAO uses: s * (q - z), with q signed
182+
MLX uses: S * Q + B, with Q unsigned
183+
184+
s * (q - z)
185+
= s ((q + offset) - (z + offset))
186+
= s Q + B,
187+
where Q = q + offset, B = -s * (z + offset)
188+
189+
Args:
190+
compute_biases: If False, skip bias computation (for scale_only mode).
191+
Returns (Q, None) in this case. This is valid when
192+
zero_point is all zeros, as the C++ runtime will compute
193+
biases = -scales * 2^(bits-1).
194+
"""
195+
assert qdata.dtype == torch.int8
196+
offset = 2 ** (bits - 1)
197+
Q = qdata.to(torch.int32) + offset
198+
199+
# Pack data tightly into uint32
200+
assert 32 % bits == 0
201+
vals_per_uint32 = 32 // bits
202+
assert qdata.shape[1] % vals_per_uint32 == 0
203+
204+
Q = Q.reshape(-1, vals_per_uint32)
205+
shifts = torch.arange(0, 32, bits, dtype=torch.int64)
206+
207+
# Convert to int64 for shift/packing
208+
Q = Q.to(torch.int64)
209+
Q = (Q << shifts).sum(dim=-1)
210+
Q = Q.to(torch.uint32)
211+
Q = Q.reshape(qdata.shape[0], -1)
212+
213+
if compute_biases:
214+
B = -scale * (zero_point.to(scale.dtype) + offset)
215+
return Q, B
216+
else:
217+
return Q, None
218+
219+
220+
def parse_dequant_node(
    node: Node,
) -> Optional[Tuple[Node, Node, Node, int, int, Optional[torch.dtype], int]]:
    """Extract quantization parameters from a torchao.dequantize_affine node.

    ``block_size`` may be N-dimensional but must contain exactly one entry
    different from 1; its index is the quantized dimension and its value is
    the group size. For example:
      - Linear weights (2D): block_size=[1, 32] -> quantized_dim=1
      - Conv2d weights (4D): block_size=[1, 32, 1, 1] -> quantized_dim=1

    The node is rejected (None is returned) unless the quantized dtype is
    int8, the group size is 32, 64 or 128, and (qmin, qmax) is (-8, 7)
    for 4-bit or (-128, 127) for 8-bit.

    Returns (qdata, scale, zero_point, group_size, bits, out_dtype, quantized_dim)
    or None if unsupported.
    """
    qdata, block_size, scale, zero_point, dtype, qmin, qmax = node.args[:7]

    # The output dtype is an optional 8th positional arg or a kwarg.
    if len(node.args) > 7:
        out_dtype = node.args[7]
    else:
        out_dtype = node.kwargs.get("output_dtype", None)

    if dtype != torch.int8:
        return None
    if len(block_size) < 2:
        return None

    grouped_axes = [
        (axis, extent) for axis, extent in enumerate(block_size) if extent != 1
    ]
    if len(grouped_axes) != 1:
        return None

    quantized_dim, group_size = grouped_axes[0]
    if group_size not in (32, 64, 128):
        return None

    # Supported signed ranges and the bit width each one implies.
    for (low, high), bits in (((-8, 7), 4), ((-128, 127), 8)):
        if qmin == low and qmax == high:
            return qdata, scale, zero_point, group_size, bits, out_dtype, quantized_dim
    return None
254+
255+
256+
# Lookup table from torch dtype to the matching ET ScalarType member.
# The ScalarType enum is defined in executorch/exir/scalar_type.py.
# Dtypes missing from this table are rejected by torch_dtype_to_scalar_type.
_TORCH_DTYPE_TO_SCALAR_TYPE: Dict[torch.dtype, int] = {
    # Floating point
    torch.float16: ScalarType.HALF,
    torch.float32: ScalarType.FLOAT,
    torch.bfloat16: ScalarType.BFLOAT16,
    # Integers
    torch.int32: ScalarType.INT,
    torch.int64: ScalarType.LONG,
    torch.uint32: ScalarType.UINT32,
    torch.uint8: ScalarType.BYTE,
    # Boolean
    torch.bool: ScalarType.BOOL,
    # 8-bit signed integer
    torch.int8: ScalarType.CHAR,
}
269+
270+
271+
def torch_dtype_to_scalar_type(dtype: torch.dtype) -> int:
    """Return the ET ScalarType value for ``dtype`` as a plain int.

    Raises:
        ValueError: If ``dtype`` has no entry in _TORCH_DTYPE_TO_SCALAR_TYPE.
    """
    scalar_type = _TORCH_DTYPE_TO_SCALAR_TYPE.get(dtype)
    if scalar_type is None:
        raise ValueError(f"Unsupported dtype: {dtype}")
    return int(scalar_type)

0 commit comments

Comments
 (0)