
Commit 11cf4ad

metascroy authored and jpiat committed
Mlx delegate part2 (pytorch#17828)
This is MLX delegate part 2, which focuses on adding ops.
1 parent 163dea0 · commit 11cf4ad

15 files changed: 14972 additions & 56 deletions


.github/workflows/mlx.yml

Lines changed: 10 additions & 1 deletion
@@ -9,6 +9,7 @@ on:
    paths:
      - .github/workflows/mlx.yml
      - backends/mlx/**
+     - extension/llm/export/**
  workflow_dispatch:

permissions: {}
@@ -36,7 +37,7 @@ jobs:
      ${CONDA_RUN} pip list

      echo "::group::Build test runners"
-     ${CONDA_RUN} cmake --build cmake-out --target op_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 ))
+     ${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 ))
      echo "::endgroup::"

      echo "::group::Run op unit tests"
@@ -51,6 +52,14 @@ jobs:
        -v
      echo "::endgroup::"

+     echo "::group::Run multi-thread stress test"
+     ${CONDA_RUN} python backends/mlx/test/export_multi_thread_test_model.py /tmp/multi_thread_test_model.pte
+     ET_TESTING_MODEL_PATH=/tmp/multi_thread_test_model.pte \
+       ET_TESTING_NUM_THREADS=50 \
+       ET_PREDICTIONS_PER_THREAD=100 \
+       ./cmake-out/backends/mlx/test/multi_thread_test_runner
+     echo "::endgroup::"
+
  backend-tester:
    strategy:
      fail-fast: false
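The new stress-test step exports a small model and then exercises it from many threads at once via the C++ multi_thread_test_runner binary built above. The sketch below is only a hypothetical Python illustration of the same pattern (ET_TESTING_NUM_THREADS threads, each making ET_PREDICTIONS_PER_THREAD predictions against the exported .pte); run_prediction is a placeholder, not part of the change.

import os
import threading

def run_prediction(model_path: str) -> None:
    # Placeholder: the real runner loads the .pte and executes a forward pass here.
    pass

def stress_test() -> None:
    model_path = os.environ["ET_TESTING_MODEL_PATH"]
    num_threads = int(os.environ.get("ET_TESTING_NUM_THREADS", "50"))
    preds_per_thread = int(os.environ.get("ET_PREDICTIONS_PER_THREAD", "100"))

    def worker() -> None:
        for _ in range(preds_per_thread):
            run_prediction(model_path)

    threads = [threading.Thread(target=worker) for _ in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

if __name__ == "__main__":
    stress_test()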

backends/mlx/builder/op_helpers.py

Lines changed: 107 additions & 15 deletions
@@ -18,6 +18,11 @@
if TYPE_CHECKING:
    from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder

+# When True, always serialize the biases tensor for quantized ops.
+# When False, use init-time computation when zero_point is all zeros,
+# computing biases = -scales * 2^(bits-1) during the init chain.
+QUANTIZED_SERIALIZE_BIASES = False
+

def get_aten_target(target):
    """
@@ -168,6 +173,50 @@ def emit_lifted_constant(P: "MLXProgramBuilder", value, dtype: torch.dtype) -> S
    return slot


+def emit_quantized_biases(
+    P: "MLXProgramBuilder",
+    zero_point_key: str,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    bits: int,
+    B: torch.Tensor,
+    scale_slot: "Slot",
+) -> "Slot":
+    """Emit biases for quantized ops, computing at init time when possible.
+
+    When zero_point is all zeros and QUANTIZED_SERIALIZE_BIASES is False,
+    avoids serializing the biases tensor by computing biases = scales * -offset
+    during the init chain instead.
+
+    Returns the biases Slot.
+    """
+    from executorch.backends.mlx.serialization.mlx_graph_schema import MultiplyNode
+    from torch._subclasses.fake_tensor import FakeTensor
+
+    is_scale_only = False
+    if not isinstance(zero_point, FakeTensor):
+        if torch.sum(torch.abs(zero_point)).item() == 0:
+            is_scale_only = True
+
+    if QUANTIZED_SERIALIZE_BIASES or not is_scale_only:
+        return P.make_or_get_constant(f"{zero_point_key}_to_biases", B)
+
+    scale_dtype = scale.dtype
+    offset = 1 << (bits - 1)
+    neg_offset = emit_lifted_constant(P, -offset, scale_dtype)
+    biases = P.make_or_get_constant(
+        f"{zero_point_key}_to_biases_dummy", torch.tensor(0.0, dtype=B.dtype)
+    )
+    P.emit_init(
+        MultiplyNode(
+            a=P.slot_to_tid(scale_slot),
+            b=P.slot_to_tid(neg_offset),
+            out=P.slot_to_tid(biases),
+        )
+    )
+    return biases
+
+
def to_mlx_qparams(
    qdata: torch.Tensor,
    scale: torch.Tensor,
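The init-time path relies on a simple identity: with zero_point all zeros, the serialized biases B = -scale * (zero_point + offset) reduce to scale * -offset, which is exactly what the emitted MultiplyNode reproduces from the scale slot and the lifted -offset constant. A minimal plain-torch check of that identity (not the builder API), with made-up scale values:

import torch

bits = 4
offset = 1 << (bits - 1)                      # 8 for 4-bit quantization
scale = torch.tensor([0.5, 0.25, 2.0])        # illustrative scales
zero_point = torch.zeros_like(scale)

serialized = -scale * (zero_point + offset)   # what to_mlx_qparams would store
init_time = scale * float(-offset)            # what the init-chain multiply computes

assert torch.equal(serialized, init_time)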
@@ -194,21 +243,36 @@ def to_mlx_qparams(
    """
    assert qdata.dtype == torch.int8
    offset = 2 ** (bits - 1)
-   Q = qdata.to(torch.int32) + offset

    # Pack data tightly into uint32
    assert 32 % bits == 0
    vals_per_uint32 = 32 // bits
    assert qdata.shape[1] % vals_per_uint32 == 0
-
-   Q = Q.reshape(-1, vals_per_uint32)
-   shifts = torch.arange(0, 32, bits, dtype=torch.int64)
-
-   # Convert to int64 for shift/packing
-   Q = Q.to(torch.int64)
-   Q = (Q << shifts).sum(dim=-1)
-   Q = Q.to(torch.uint32)
-   Q = Q.reshape(qdata.shape[0], -1)
+   rows, cols = qdata.shape
+
+   if bits == 4:
+       # 4-bit: view(uint8) + wrapping add + pack 2 nibbles per byte → view as uint32
+       q = qdata.view(torch.uint8) + offset
+       q3 = q.reshape(rows, cols // 2, 2)
+       Q = (q3[:, :, 0] | (q3[:, :, 1] << 4)).view(torch.uint32)
+   elif bits == 2:
+       # 2-bit: pack 4×2-bit values per byte in uint8, then view as uint32
+       Q = ((qdata.view(torch.uint8) + offset) & 0x3).reshape(rows, cols // 4, 4)
+       packed = Q[:, :, 0] | (Q[:, :, 1] << 2) | (Q[:, :, 2] << 4) | (Q[:, :, 3] << 6)
+       Q = packed.contiguous().view(torch.uint32)
+   elif bits == 8:
+       # 8-bit: each byte maps 1:1 to a uint32 slot — no shifting needed
+       q = qdata.view(torch.uint8) + offset
+       Q = q.contiguous().view(torch.uint32).reshape(rows, -1)
+   else:
+       # General fallback for other bit widths
+       Q = (qdata.to(torch.int32) + offset).reshape(-1, vals_per_uint32)
+       shifts = torch.arange(0, 32, bits, dtype=torch.int32)
+       shifted = Q << shifts
+       packed = shifted[:, 0]
+       for i in range(1, vals_per_uint32):
+           packed = packed | shifted[:, i]
+       Q = packed.view(torch.uint32).reshape(rows, -1)

    if compute_biases:
        B = -scale * (zero_point.to(scale.dtype) + offset)
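For concreteness, a small standalone check of the 4-bit fast path with made-up values: eight signed nibbles in [-8, 7] shift to unsigned [0, 15] via the wrapping uint8 add, pack two per byte, and the resulting bytes match the classic shift-and-sum packing.

import torch

# One row of eight illustrative 4-bit values in [-8, 7]
qdata = torch.tensor([[-8, -1, 0, 1, 2, 3, -4, 7]], dtype=torch.int8)
bits = 4
offset = 1 << (bits - 1)  # 8

# uint8 view makes the +offset add wrap, mapping [-8, 7] -> [0, 15]
q = qdata.view(torch.uint8) + offset

# Two nibbles per byte, four bytes (one uint32 lane) per row
pairs = q.reshape(1, 4, 2)
packed_bytes = pairs[:, :, 0] | (pairs[:, :, 1] << 4)   # uint8, shape (1, 4)
packed_u32 = packed_bytes.view(torch.uint32)             # shape (1, 1), as in to_mlx_qparams

# Reference: shift each unsigned value into its 4-bit lane and sum
ref = sum(int(v + offset) << (bits * i) for i, v in enumerate(qdata[0].tolist()))
assert packed_bytes[0].tolist() == [(ref >> (8 * j)) & 0xFF for j in range(4)]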
@@ -217,6 +281,34 @@
    return Q, None


+def parse_dequant_nvfp4_node(
+    node: Node,
+) -> Optional[Tuple[Node, Node, Node, torch.dtype]]:
+    """Parse a torchao.dequantize_nvfp4 node.
+
+    Returns (qdata, scale, per_tensor_scale, output_dtype) or None if not a
+    dequantize_nvfp4 node or the custom op is not registered.
+    """
+    target = get_aten_target(node.target)
+    try:
+        import executorch.extension.llm.export.nvfp4  # noqa: F401
+    except ImportError:
+        return None
+
+    if target is not torch.ops.torchao.dequantize_nvfp4.default:
+        return None
+
+    qdata, scale, per_tensor_scale = node.args[0:3]
+
+    output_dtype = torch.float32
+    if len(node.args) > 4:
+        output_dtype = node.args[4]
+    elif "output_dtype" in node.kwargs:
+        output_dtype = node.kwargs["output_dtype"]
+
+    return qdata, scale, per_tensor_scale, output_dtype
+
+
def parse_dequant_node(
    node: Node,
) -> Optional[Tuple[Node, Node, Node, int, int, Optional[torch.dtype], int]]:
@@ -244,11 +336,11 @@ def parse_dequant_node(
    quantized_dim, group_size = non_one[0]
    if group_size not in [32, 64, 128]:
        return None
-   if qmin == -8 and qmax == 7:
-       bits = 4
-   elif qmin == -128 and qmax == 127:
-       bits = 8
-   else:
+
+   # TODO: MLX supports 3, 5, and 7, but we need to figure out the
+   # packing story in to_mlx_qparams to use them
+   bits = (qmax - qmin + 1).bit_length() - 1
+   if bits not in [2, 4, 8]:
        return None
    return qdata, scale, zero_point, group_size, bits, out_dtype, quantized_dim
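The new bit-width derivation assumes the q-range spans the full signed range of some width, so the width falls directly out of qmin and qmax. A few spot checks of the expression used above:

def derive_bits(qmin: int, qmax: int) -> int:
    # Mirrors the expression in parse_dequant_node
    return (qmax - qmin + 1).bit_length() - 1

assert derive_bits(-2, 1) == 2        # 2-bit
assert derive_bits(-8, 7) == 4        # 4-bit (previously a hard-coded branch)
assert derive_bits(-128, 127) == 8    # 8-bit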
