
Commit 3332741

Commit message: up
1 parent c952c5a

6 files changed

Lines changed: 987 additions & 45 deletions

File tree

backends/apple/mlx/ops.py

Lines changed: 154 additions & 18 deletions
@@ -118,6 +118,7 @@
     TakeAlongAxisNode,
     TanhNode,
     TanNode,
+    TidOrVid,
     TileNode,
     TransposeNode,
     TrilNode,
@@ -234,6 +235,44 @@ def require_kwargs(
        raise ValueError(f"{op_name}: unexpected kwargs: {unexpected}")


+def require_contiguous_format(
+    *,
+    layout=None,
+    memory_format=None,
+    dim_order=None,
+    op_name: str,
+) -> None:
+    """
+    Validate that layout/memory_format/dim_order specify contiguous format.
+
+    MLX only supports contiguous (strided) tensors. Raises ValueError if
+    sparse layouts or non-contiguous memory formats are requested.
+
+    Args:
+        layout: The torch layout (e.g., torch.strided, torch.sparse_coo)
+        memory_format: The torch memory format (e.g., torch.contiguous_format,
+            torch.channels_last)
+        dim_order: The dimension order (list of ints, identity = contiguous)
+        op_name: Name of the operation (for error message)
+    """
+    if layout is not None and layout != torch.strided:
+        raise ValueError(f"{op_name}: only strided layout supported, got {layout}")
+
+    if memory_format is not None and memory_format not in (
+        torch.contiguous_format,
+        torch.preserve_format,
+    ):
+        raise ValueError(
+            f"{op_name}: only contiguous memory format supported, got {memory_format}"
+        )
+
+    if dim_order is not None:
+        if list(dim_order) != list(range(len(dim_order))):
+            raise ValueError(
+                f"{op_name}: only contiguous dim_order supported, got {dim_order}"
+            )
+
+
 def is_static_value(value: Any) -> bool:
     """
     Check if a value is static (not a Slot/SymInt).
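
Note on the dim_order branch above: a dim_order is contiguous exactly when it is the identity permutation; the channels_last order of a 4-D tensor, [0, 2, 3, 1], is the typical rejected case. A quick eager-mode check of that correspondence (illustration only, not part of the diff; assumes a recent torch with Tensor.dim_order()):

import torch

# channels_last layout of a 4-D tensor corresponds to dim_order (0, 2, 3, 1),
# which is not the identity permutation, so the validator rejects it.
x = torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last)
assert x.dim_order() == (0, 2, 3, 1)
assert list(x.dim_order()) != list(range(x.dim()))

# A contiguous tensor has the identity dim_order and passes.
y = torch.empty(2, 3, 4, 5)
assert y.dim_order() == (0, 1, 2, 3)
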
@@ -420,7 +459,9 @@ def _emit_update_cache(

    # Import custom ops to register llama.update_cache
    try:
-        from executorch.extension.llm.custom_ops import custom_ops as _llama_ops # noqa: F401
+        from executorch.extension.llm.custom_ops import ( # noqa: F401
+            custom_ops as _llama_ops,
+        )
    except ImportError:
        pass # Custom ops not available

@@ -584,11 +625,22 @@ def _view_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    return out


-@REGISTRY.register(target=[torch.ops.aten.clone.default, torch.ops.aten.alias.default])
+@REGISTRY.register(
+    target=[
+        torch.ops.aten.clone.default,
+        torch.ops.aten.alias.default,
+        torch.ops.aten.alias_copy.default,
+    ]
+)
 def _clone_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    args = P.args(n)
+    kwargs = P.kwargs(n)
    require_args(args, 1, 1, "aten.clone")
-    require_kwargs(P.kwargs(n), set(), "aten.clone")
+    require_kwargs(kwargs, {"memory_format"}, "aten.clone")
+    require_contiguous_format(
+        memory_format=kwargs.get("memory_format"),
+        op_name="aten.clone",
+    )
    (x,) = args
    out = P.make_or_get_slot(n)
    P.emit(
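
With the relaxed kwarg set, clone(memory_format=...) now reaches the handler and is validated rather than rejected outright. A quick eager-mode illustration of which formats the new check allows (not part of the diff):

import torch

x = torch.randn(1, 3, 8, 8)
x.clone(memory_format=torch.contiguous_format)  # allowed by require_contiguous_format
x.clone(memory_format=torch.preserve_format)    # allowed (the torch.clone default)
# torch.channels_last is valid in eager torch but would now raise a ValueError
# in this MLX lowering, since MLX only supports contiguous tensors.
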
@@ -612,9 +664,14 @@ def _dim_order_clone_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    # dim_order_ops._clone_dim_order(Tensor self, *, bool non_blocking=False, int[]? dim_order=None) -> Tensor
    # This is essentially a contiguous/clone operation for memory layout
    args = P.args(n)
+    kwargs = P.kwargs(n)
    require_args(args, 1, 1, "dim_order_ops._clone_dim_order")
    require_kwargs(
-        P.kwargs(n), {"non_blocking", "dim_order"}, "dim_order_ops._clone_dim_order"
+        kwargs, {"non_blocking", "dim_order"}, "dim_order_ops._clone_dim_order"
+    )
+    require_contiguous_format(
+        dim_order=kwargs.get("dim_order"),
+        op_name="dim_order_ops._clone_dim_order",
    )
    x = args[0]
    out = P.make_or_get_slot(n)
@@ -643,6 +700,11 @@ def _dim_order_copy_handler(P: MLXProgramBuilder, n: Node) -> Slot:
        {"dtype", "device", "layout", "non_blocking", "dim_order"},
        "dim_order_ops._to_dim_order_copy",
    )
+    require_contiguous_format(
+        layout=kwargs.get("layout"),
+        dim_order=kwargs.get("dim_order"),
+        op_name="dim_order_ops._to_dim_order_copy",
+    )
    x = args[0]
    out = P.make_or_get_slot(n)

@@ -681,6 +743,11 @@ def _to_copy_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    require_kwargs(
        kwargs, {"dtype", "device", "layout", "memory_format"}, "aten._to_copy"
    )
+    require_contiguous_format(
+        layout=kwargs.get("layout"),
+        memory_format=kwargs.get("memory_format"),
+        op_name="aten._to_copy",
+    )
    x = args[0]
    out = P.make_or_get_slot(n)

@@ -707,10 +774,15 @@ def _to_copy_handler(P: MLXProgramBuilder, n: Node) -> Slot:

 @REGISTRY.register(target=[torch.ops.aten.embedding.default])
 def _embedding_handler(P: MLXProgramBuilder, n: Node) -> Slot:
+    # aten::embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1,
+    #     bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
+    # padding_idx is only relevant for training (gradient computation)
+    # scale_grad_by_freq and sparse are also training-only
    args = P.args(n)
-    require_args(args, 2, 2, "aten.embedding")
+    require_args(args, 2, 3, "aten.embedding")
    require_kwargs(P.kwargs(n), set(), "aten.embedding")
    w, x = args[0], args[1]
+    # padding_idx (args[2] if present) is ignored - only affects gradients
    out = P.make_or_get_slot(n)
    P.emit(
        GatherNode(
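
The claim in the new comments, that padding_idx affects only gradients, can be verified in eager mode; the forward lookup returns raw weight rows either way (illustration only, not part of the diff):

import torch

w = torch.randn(10, 4)
idx = torch.tensor([1, 3, 3])

# Forward results are identical with and without padding_idx; during training,
# padding_idx only zeroes the gradient for the padded row of the weight.
torch.testing.assert_close(
    torch.nn.functional.embedding(idx, w),
    torch.nn.functional.embedding(idx, w, padding_idx=3),
)
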
@@ -1505,6 +1577,10 @@ def _arange_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    kwargs = P.kwargs(n)
    require_args(args, 1, 3, "aten.arange")
    require_kwargs(kwargs, {"dtype", "layout", "device", "pin_memory"}, "aten.arange")
+    require_contiguous_format(
+        layout=kwargs.get("layout"),
+        op_name="aten.arange",
+    )
    if len(args) == 1:
        start = 0
        stop = args[0]
@@ -1541,6 +1617,10 @@ def _arange_start_step_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    require_kwargs(
        kwargs, {"dtype", "layout", "device", "pin_memory"}, "aten.arange.start_step"
    )
+    require_contiguous_format(
+        layout=kwargs.get("layout"),
+        op_name="aten.arange.start_step",
+    )
    start = args[0]
    stop = args[1]
    step = args[2] if len(args) > 2 else 1
@@ -1586,6 +1666,30 @@ def _rms_norm_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    return out


+@REGISTRY.register(target=[torch.ops.aten.rms_norm.default])
+def _aten_rms_norm_handler(P: MLXProgramBuilder, n: Node) -> Slot:
+    args = P.args(n)
+    require_args(args, 2, 4, "aten.rms_norm")
+    require_kwargs(P.kwargs(n), set(), "aten.rms_norm")
+    x, normalized_shape = args[0], args[1]
+    if len(normalized_shape) > 1:
+        raise ValueError(
+            "RMSNorm is only supported when normalizing over the last dimension"
+        )
+    w = args[2] if len(args) > 2 else None
+    eps = args[3] if len(args) > 3 else 1e-5
+    out = P.make_or_get_slot(n)
+    P.emit(
+        RMSNormNode(
+            x=P.slot_to_tid(x),
+            weight=P.slot_to_tid(w) if w else None,
+            out=P.slot_to_tid(out),
+            eps=eps,
+        )
+    )
+    return out
+
+
 @REGISTRY.register(target=[torch.ops.mlx.rope.default])
 def _rope_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    args = P.args(n)
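
For reference, the computation the new handler lowers to RMSNormNode: x is scaled by the reciprocal root-mean-square over the last dimension, optionally multiplied by a weight. A minimal eager-mode sketch (assumes torch >= 2.4 for torch.nn.functional.rms_norm; illustration only, not part of the diff):

import torch

def rms_norm_ref(x, weight=None, eps=1e-5):
    # x / sqrt(mean(x**2) + eps) over the last dimension, then optional scale.
    out = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return out * weight if weight is not None else out

x = torch.randn(2, 8)
torch.testing.assert_close(
    rms_norm_ref(x),
    torch.nn.functional.rms_norm(x, (8,), eps=1e-5),
)
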
@@ -1599,10 +1703,10 @@ def _rope_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    out = P.make_or_get_slot(n)

    # pos must be a Slot (SymInt) from input_pos.item() during tracing
-    # The schema only supports Vid for pos, not literal int
+    # The schema supports both Vid (scalar) and Tid (tensor) for offset
    if not isinstance(pos, Slot):
        raise ValueError(
-            f"RopeNode.pos must be a SymInt (traced via tensor.item()), got {type(pos)}. "
+            f"RopeNode.offset must be a SymInt (traced via tensor.item()), got {type(pos)}. "
            "Make sure input_pos is a tensor and you call input_pos.item() to get a SymInt."
        )

@@ -1611,7 +1715,7 @@ def _rope_handler(P: MLXProgramBuilder, n: Node) -> Slot:
            x=P.slot_to_tid(x),
            out=P.slot_to_tid(out),
            head_dim=head_dim,
-            pos=P.slot_to_vid(pos),
+            offset=TidOrVid.from_vid(P.slot_to_vid(pos)),
            freqs=P.slot_to_tid(freqs) if freqs else None,
            traditional=traditional,
            base=base,
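
The offset argument now takes a TidOrVid wrapper instead of a bare Vid, so the schema can carry either a scalar value id or a tensor id. The commit does not show TidOrVid's definition; a hypothetical sketch of such a tagged union (all names and fields assumed, for illustration only):

from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class TidOrVid:
    # Hypothetical shape: exactly one of tid (tensor id) or vid (value id) is set.
    tid: Optional[int] = None
    vid: Optional[int] = None

    @classmethod
    def from_vid(cls, vid: int) -> "TidOrVid":
        return cls(vid=vid)

    @classmethod
    def from_tid(cls, tid: int) -> "TidOrVid":
        return cls(tid=tid)
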
@@ -2353,8 +2457,11 @@ def _full_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    # Use P.args to properly convert Nodes to Slots for dynamic shapes
    args = P.args(n)
    require_args(args, 2, 2, "aten.full")
-    require_kwargs(
-        P.kwargs(n), {"dtype", "layout", "device", "pin_memory"}, "aten.full"
+    kwargs = P.kwargs(n)
+    require_kwargs(kwargs, {"dtype", "layout", "device", "pin_memory"}, "aten.full")
+    require_contiguous_format(
+        layout=kwargs.get("layout"),
+        op_name="aten.full",
    )
    out = P.make_or_get_slot(n)
    shape = args[0]
@@ -2384,8 +2491,11 @@ def _zeros_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Handle aten.zeros - create tensor filled with zeros."""
    args = P.args(n)
    require_args(args, 1, 1, "aten.zeros")
-    require_kwargs(
-        P.kwargs(n), {"dtype", "layout", "device", "pin_memory"}, "aten.zeros"
+    kwargs = P.kwargs(n)
+    require_kwargs(kwargs, {"dtype", "layout", "device", "pin_memory"}, "aten.zeros")
+    require_contiguous_format(
+        layout=kwargs.get("layout"),
+        op_name="aten.zeros",
    )
    out = P.make_or_get_slot(n)

@@ -2416,8 +2526,11 @@ def _ones_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Handle aten.ones - create tensor filled with ones."""
    args = P.args(n)
    require_args(args, 1, 1, "aten.ones")
-    require_kwargs(
-        P.kwargs(n), {"dtype", "layout", "device", "pin_memory"}, "aten.ones"
+    kwargs = P.kwargs(n)
+    require_kwargs(kwargs, {"dtype", "layout", "device", "pin_memory"}, "aten.ones")
+    require_contiguous_format(
+        layout=kwargs.get("layout"),
+        op_name="aten.ones",
    )
    out = P.make_or_get_slot(n)

@@ -2447,12 +2560,18 @@
 def _zeros_like_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Handle aten.zeros_like - create zero-filled tensor with same shape as input."""
    args = P.args(n)
+    kwargs = P.kwargs(n)
    require_args(args, 1, 1, "aten.zeros_like")
    require_kwargs(
-        P.kwargs(n),
+        kwargs,
        {"dtype", "layout", "device", "pin_memory", "memory_format"},
        "aten.zeros_like",
    )
+    require_contiguous_format(
+        layout=kwargs.get("layout"),
+        memory_format=kwargs.get("memory_format"),
+        op_name="aten.zeros_like",
+    )
    x = args[0]
    out = P.make_or_get_slot(n)

@@ -2475,12 +2594,18 @@ def _zeros_like_handler(P: MLXProgramBuilder, n: Node) -> Slot:
 def _ones_like_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Handle aten.ones_like - create one-filled tensor with same shape as input."""
    args = P.args(n)
+    kwargs = P.kwargs(n)
    require_args(args, 1, 1, "aten.ones_like")
    require_kwargs(
-        P.kwargs(n),
+        kwargs,
        {"dtype", "layout", "device", "pin_memory", "memory_format"},
        "aten.ones_like",
    )
+    require_contiguous_format(
+        layout=kwargs.get("layout"),
+        memory_format=kwargs.get("memory_format"),
+        op_name="aten.ones_like",
+    )
    x = args[0]
    out = P.make_or_get_slot(n)

@@ -2503,12 +2628,18 @@ def _ones_like_handler(P: MLXProgramBuilder, n: Node) -> Slot:
 def _full_like_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Handle aten.full_like - create tensor filled with value with same shape."""
    args = P.args(n)
+    kwargs = P.kwargs(n)
    require_args(args, 2, 2, "aten.full_like")
    require_kwargs(
-        P.kwargs(n),
+        kwargs,
        {"dtype", "layout", "device", "pin_memory", "memory_format"},
        "aten.full_like",
    )
+    require_contiguous_format(
+        layout=kwargs.get("layout"),
+        memory_format=kwargs.get("memory_format"),
+        op_name="aten.full_like",
+    )
    x = args[0]
    fill_value = args[1]
    out = P.make_or_get_slot(n)
@@ -2724,9 +2855,14 @@ def _scalar_tensor_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    This is equivalent to torch.full([], scalar, dtype=dtype).
    """
    args = P.args(n)
+    kwargs = P.kwargs(n)
    require_args(args, 1, 1, "aten.scalar_tensor")
    require_kwargs(
-        P.kwargs(n), {"dtype", "layout", "device", "pin_memory"}, "aten.scalar_tensor"
+        kwargs, {"dtype", "layout", "device", "pin_memory"}, "aten.scalar_tensor"
+    )
+    require_contiguous_format(
+        layout=kwargs.get("layout"),
+        op_name="aten.scalar_tensor",
    )
    scalar_value = args[0]
