-
Notifications
You must be signed in to change notification settings - Fork 935
Expand file tree
/
Copy pathops.py
More file actions
4083 lines (3532 loc) · 128 KB
/
ops.py
File metadata and controls
4083 lines (3532 loc) · 128 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
"""
MLX Op Handlers - registered handlers for converting ATen/custom ops to MLX.
This module contains all the op handler functions registered with the MLXOpRegistry.
Each handler converts a specific PyTorch operation to the corresponding MLX graph node.
"""
from __future__ import annotations
import operator
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import torch
from executorch.backends.mlx.builder.op_helpers import (
emit_lifted_constant,
emit_quantized_biases,
parse_dequant_node,
to_mlx_qparams,
torch_dtype_to_scalar_type,
)
from executorch.backends.mlx.builder.op_registry import REGISTRY
from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder
from executorch.backends.mlx.builder.slot_manager import IdType, Slot
from executorch.backends.mlx.serialization.mlx_graph_schema import (
AbsNode,
AddIntNode,
AddmmNode,
AddNode,
AllNode,
AnyNode,
ARangeNode,
ArccoshNode,
ArccosNode,
ArcsinhNode,
ArcsinNode,
ArctanhNode,
ArctanNode,
ArgmaxNode,
ArgminNode,
ArgPartitionNode,
ArgsortNode,
AsStridedNode,
AsTypeNode,
Atan2Node,
BroadcastToNode,
CeilNode,
ClipNode,
ConcatenateNode,
ContiguousNode,
Conv1DNode,
Conv2DNode,
Conv3DNode,
ConvTranspose1DNode,
ConvTranspose2DNode,
ConvTranspose3DNode,
CoshNode,
CosNode,
CumsumNode,
DequantizeNode,
DivideNode,
EqualNode,
ErfNode,
ExpandDimsNode,
Expm1Node,
ExpNode,
FloatOrVid,
FloorDivideIntNode,
FloorDivideNode,
FloorNode,
FullLikeNode,
FullNode,
GatherNode,
GeluNode,
GreaterEqualNode,
GreaterNode,
IdCopyNode,
IntOrVid,
IntOrVidOrTid,
ItemIntNode,
LayerNormNode,
LessEqualNode,
LessNode,
Log10Node,
Log1pNode,
Log2Node,
LogAddExpNode,
LogicalAndNode,
LogicalNotNode,
LogicalOrNode,
LogNode,
LogSumExpNode,
MaximumNode,
MaxNode,
MeanNode,
MinimumNode,
MinNode,
ModIntNode,
MultiplyIntNode,
MultiplyNode,
NegNode,
NotEqualNode,
PadNode,
PartitionNode,
PowerNode,
ProdNode,
ReciprocalNode,
RemainderNode,
RepeatNode,
ReshapeNode,
RMSNormNode,
RopeNode,
RoundNode,
RsqrtNode,
ScatterAddNode,
SigmoidNode,
SignNode,
SiluNode,
SinhNode,
SinNode,
SliceNode,
SliceUpdateNode,
SoftmaxNode,
SortNode,
SplitNode,
SqrtNode,
SquareNode,
SqueezeNode,
StackNode,
StdNode,
SubtractIntNode,
SubtractNode,
SumNode,
SymSizeNode,
TakeAlongAxisNode,
TakeNode,
TanhNode,
TanNode,
TileNode,
TransposeNode,
TrilNode,
TriuNode,
VarNode,
VidOrTid,
WhereNode,
)
# The coding style is for handlers to register against aten targets
# The corresponding edge ops are automatically registered
# For ops that are not in aten (e.g., dim order ops), directly register on exir_ops
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx.node import Node
def require_static_int(value: Any, param_name: str, op_name: str) -> None:
    """
    Ensure ``value`` is a static (compile-time) integer.

    ``Slot`` values are treated as dynamic and rejected, as is any object that
    is not an ``int``. ``bool`` is accepted because it subclasses ``int``.

    Args:
        value: The argument to inspect.
        param_name: Parameter name, used in the error message.
        op_name: Operation name, used in the error message.

    Raises:
        NotImplementedError: If ``value`` is a ``Slot`` or is not an ``int``.
    """
    if isinstance(value, int) and not isinstance(value, Slot):
        return
    raise NotImplementedError(
        f"{op_name} with dynamic {param_name} is not supported. "
        f"{param_name} requires a static int32 value, but got {value} (type={type(value).__name__})."
    )
def require_static_float(value: Any, param_name: str, op_name: str) -> None:
    """
    Ensure ``value`` is a static (compile-time) real number.

    Plain ``int`` and ``float`` values are accepted. ``Slot`` values are
    treated as dynamic and rejected, and so is any other type.

    Args:
        value: The argument to inspect.
        param_name: Parameter name, used in the error message.
        op_name: Operation name, used in the error message.

    Raises:
        NotImplementedError: If ``value`` is a ``Slot`` or not an int/float.
    """
    is_number = isinstance(value, (int, float))
    if is_number and not isinstance(value, Slot):
        return
    raise NotImplementedError(
        f"{op_name} with dynamic {param_name} is not supported. "
        f"{param_name} requires a static float value, but got {value} (type={type(value).__name__})."
    )
def require_static_ints(
    values: Union[List[Any], Tuple[Any, ...], Any], param_name: str, op_name: str
) -> None:
    """
    Validate that all values are static integers (not Slots/SymInts).

    Raises NotImplementedError if any value is dynamic.

    ``values`` may be a list, a tuple, or a single value. Lists and tuples are
    checked element by element. A tuple of ints is therefore accepted instead
    of being treated as one non-int value and rejected as "dynamic".

    Args:
        values: List/tuple of values to check, or a single value
        param_name: Name of the parameter (for error message)
        op_name: Name of the operation (for error message)

    Raises:
        NotImplementedError: If any element fails ``require_static_int``.
    """
    items = values if isinstance(values, (list, tuple)) else [values]
    for v in items:
        require_static_int(v, param_name, op_name)
def require_args(
    args: List[Any],
    min_count: int,
    max_count: int,
    op_name: str,
) -> None:
    """
    Check that the number of positional handler args is within bounds.

    Args:
        args: The handler's positional args.
        min_count: Smallest accepted count (inclusive).
        max_count: Largest accepted count (inclusive).
        op_name: Operation name, used in the error message.

    Raises:
        ValueError: If ``len(args)`` is outside ``[min_count, max_count]``.
            The message shows a single number when the bounds are equal,
            otherwise a ``min-max`` range.
    """
    count = len(args)
    if min_count <= count <= max_count:
        return
    if min_count != max_count:
        raise ValueError(f"{op_name}: expected {min_count}-{max_count} args, got {len(args)}")
    raise ValueError(f"{op_name}: expected {min_count} args, got {len(args)}")
def require_kwargs(
    kwargs: Dict[str, Any],
    allowed: Set[str],
    op_name: str,
) -> None:
    """
    Reject any keyword argument whose name is not in ``allowed``.

    Args:
        kwargs: The handler's keyword args.
        allowed: Names that may appear in ``kwargs``.
        op_name: Operation name, used in the error message.

    Raises:
        ValueError: Listing the set of unexpected kwarg names.
    """
    unexpected = set(kwargs) - allowed
    if not unexpected:
        return
    raise ValueError(f"{op_name}: unexpected kwargs: {unexpected}")
def require_contiguous_format(
    *,
    layout=None,
    memory_format=None,
    dim_order=None,
    op_name: str,
) -> None:
    """
    Reject layouts, memory formats, or dim orders that are not contiguous.

    Each argument is checked only when it is not None, in the order layout,
    memory_format, dim_order.

    Args:
        layout: Must be ``torch.strided`` if given.
        memory_format: Must be ``torch.contiguous_format`` or
            ``torch.preserve_format`` if given.
        dim_order: Must be the identity permutation ``[0, 1, ..., k-1]`` if
            given.
        op_name: Operation name, used in the error message.

    Raises:
        ValueError: On the first argument that does not satisfy its rule.
    """
    if layout not in (None, torch.strided):
        raise ValueError(f"{op_name}: only strided layout supported, got {layout}")

    accepted_formats = (None, torch.contiguous_format, torch.preserve_format)
    if memory_format not in accepted_formats:
        raise ValueError(
            f"{op_name}: only contiguous memory format supported, got {memory_format}"
        )

    if dim_order is None:
        return
    identity = list(range(len(dim_order)))
    if list(dim_order) != identity:
        raise ValueError(
            f"{op_name}: only contiguous dim_order supported, got {dim_order}"
        )
def is_static_value(value: Any) -> bool:
    """
    Check whether a value is static, i.e. not a ``Slot``.

    Only ``Slot`` instances count as dynamic. Every other object is reported
    as static: ints, floats and bools, but also ``None``, lists, dtypes, and
    so on. Callers that need a static *scalar* must check its type
    separately, for example with ``require_static_int`` or
    ``require_static_float``.

    Returns:
        False if ``value`` is a ``Slot``, True otherwise.
    """
    return not isinstance(value, Slot)
def used_getitem_indices(n: Node) -> Set[int]:
    """
    Collect the tuple indices of ``n`` that are actually consumed.

    Walks the users of ``n`` and keeps ``user.args[1]`` for each user whose
    target is ``operator.getitem`` and which itself has at least one user.
    Getitem nodes with no downstream users are ignored.
    """
    consumed = set()
    for user in n.users:
        if user.target != operator.getitem:
            continue
        if not user.users:
            # Dead getitem: nothing reads this output.
            continue
        consumed.add(user.args[1])
    return consumed
def normalize_reduction_dim(
    args: List[Any], start_idx: int = 1
) -> Tuple[List[int], bool]:
    """
    Normalize the dim/keepdim arguments of a reduction op.

    Reads ``args[start_idx]`` as the dim argument and
    ``args[start_idx + 1]`` (when present) as keepdim. Accepts both
    list-based dims (e.g. ``sum.dim_IntList``) and single int dims
    (e.g. ``prod.dim_int``).

    A missing or ``None`` dim means "reduce over all axes". In that case
    keepdim is still honoured when it is passed. For example,
    ``sum.dim_IntList(x, None, True)`` yields ``([], True)`` rather than
    dropping keepdim.

    Args:
        args: The handler args list
        start_idx: Index where the dim argument starts (default 1, after self)

    Returns:
        Tuple of (axes, keepdim) where:
        - axes: List of dimension indices, or empty list for reduce-all
        - keepdim: keepdim value from args, or False when not supplied
    """
    dim_arg = args[start_idx] if len(args) > start_idx else None
    keepdim = args[start_idx + 1] if len(args) > start_idx + 1 else False
    if isinstance(dim_arg, (list, tuple)):
        axes = list(dim_arg)
    elif isinstance(dim_arg, int):
        axes = [dim_arg]
    else:
        # None (or absent) dim: reduce over every axis.
        axes = []
    return axes, keepdim
# (target, MLX node class, op name) for every one-input op lowered through
# _make_unary_handler. Each entry is derived from the aten op's attribute
# name, so the target and the op name cannot drift apart.
_UNARY_OPS: List[Tuple[Any, Any, str]] = [
    (getattr(torch.ops.aten, _aten_name).default, _mlx_node, f"aten.{_aten_name}")
    for _aten_name, _mlx_node in (
        # Activations
        ("silu", SiluNode),
        ("sigmoid", SigmoidNode),
        ("tanh", TanhNode),
        # Reciprocal square root
        ("rsqrt", RsqrtNode),
        # Rounding
        ("floor", FloorNode),
        ("ceil", CeilNode),
        # Powers / roots
        ("square", SquareNode),
        ("exp", ExpNode),
        ("sqrt", SqrtNode),
        ("reciprocal", ReciprocalNode),
        # Trigonometric
        ("sin", SinNode),
        ("cos", CosNode),
        ("tan", TanNode),
        ("asin", ArcsinNode),
        ("acos", ArccosNode),
        ("atan", ArctanNode),
        # Hyperbolic
        ("sinh", SinhNode),
        ("cosh", CoshNode),
        ("asinh", ArcsinhNode),
        ("acosh", ArccoshNode),
        ("atanh", ArctanhNode),
        # Logarithmic
        ("log", LogNode),
        ("log2", Log2Node),
        ("log10", Log10Node),
        ("log1p", Log1pNode),
        # Special
        ("erf", ErfNode),
        ("expm1", Expm1Node),
        # Sign / magnitude
        ("abs", AbsNode),
        ("neg", NegNode),
        ("sign", SignNode),
        # Logical
        ("logical_not", LogicalNotNode),
    )
]
def _make_unary_handler(node_cls: Any, op_name: str):
"""Create a handler for a simple unary op: x → node_cls(x, out)."""
def handler(P: MLXProgramBuilder, n: Node) -> Slot:
args = P.args(n)
require_args(args, 1, 1, op_name)
require_kwargs(P.kwargs(n), set(), op_name)
x = args[0]
out = P.make_or_get_slot(n)
P.emit(node_cls(x=P.slot_to_tid(x), out=P.slot_to_tid(out)))
return out
handler.__name__ = f"_{op_name.replace('.', '_')}_handler"
handler.__doc__ = f"Handle {op_name} (table-driven unary op)."
return handler
# Register one freshly built handler per unary target.
for _target, _node_cls, _op_name in _UNARY_OPS:
    REGISTRY.register(target=[_target])(
        _make_unary_handler(node_cls=_node_cls, op_name=_op_name)
    )
# (targets, MLX node class, op name, lift_b) for every two-input op lowered
# through _make_binary_handler. The targets are the listed overloads of the
# named aten op. lift_b marks ops whose second operand may be a scalar that
# must be lifted to a constant tensor.
_BINARY_OPS: List[Tuple[List[Any], Any, str, bool]] = [
    (
        [getattr(getattr(torch.ops.aten, _name), _ov) for _ov in _overloads],
        _mlx_node,
        f"aten.{_name}",
        _lift,
    )
    for _name, _overloads, _mlx_node, _lift in (
        # Arithmetic
        ("mul", ("Tensor", "Scalar"), MultiplyNode, True),
        ("div", ("Tensor", "Scalar"), DivideNode, True),
        ("remainder", ("Tensor", "Scalar"), RemainderNode, True),
        ("pow", ("Tensor_Tensor", "Tensor_Scalar"), PowerNode, True),
        ("floor_divide", ("default",), FloorDivideNode, False),
        # Element-wise extrema / misc
        ("maximum", ("default",), MaximumNode, False),
        ("minimum", ("default",), MinimumNode, False),
        ("atan2", ("default",), Atan2Node, False),
        ("logaddexp", ("default",), LogAddExpNode, False),
        ("logical_or", ("default",), LogicalOrNode, False),
        # Comparisons
        ("lt", ("Tensor", "Scalar"), LessNode, True),
        ("le", ("Tensor", "Scalar"), LessEqualNode, True),
        ("gt", ("Tensor", "Scalar"), GreaterNode, True),
        ("ge", ("Tensor", "Scalar"), GreaterEqualNode, True),
        ("eq", ("Tensor", "Scalar"), EqualNode, True),
        ("ne", ("Tensor", "Scalar"), NotEqualNode, True),
    )
]
def _make_binary_handler(node_cls: Any, op_name: str, lift_b: bool):
"""Create a handler for a binary op: (a, b) -> node_cls(a, b, out).
When lift_b is True, scalar b values are lifted to 0-D constant tensors
via emit_lifted_constant, using a's dtype.
"""
def handler(P: MLXProgramBuilder, n: Node) -> Slot:
args = P.args(n)
require_args(args, 2, 2, op_name)
require_kwargs(P.kwargs(n), set(), op_name)
a, b = args[0], args[1]
if lift_b and (not isinstance(b, Slot) or b.id_type != IdType.Tensor):
input_meta = n.args[0].meta.get("val")
dtype = input_meta.dtype if input_meta is not None else torch.float32
b = emit_lifted_constant(P, b, dtype)
out = P.make_or_get_slot(n)
P.emit(node_cls(a=P.slot_to_tid(a), b=P.slot_to_tid(b), out=P.slot_to_tid(out)))
return out
handler.__name__ = f"_{op_name.replace('.', '_')}_handler"
handler.__doc__ = f"Handle {op_name} (table-driven binary op)."
return handler
# Register each binary handler under all of its overload targets at once.
for _targets, _node_cls, _op_name, _lift_b in _BINARY_OPS:
    REGISTRY.register(target=_targets)(
        _make_binary_handler(node_cls=_node_cls, op_name=_op_name, lift_b=_lift_b)
    )
# (python operator, MLX int node class, op name) for the symbolic-int
# arithmetic lowered through _make_scalar_int_handler.
_SCALAR_INT_OPS: List[Tuple[Any, Any, str]] = [
    (getattr(operator, _fn_name), _mlx_node, f"operator.{_fn_name}")
    for _fn_name, _mlx_node in (
        ("add", AddIntNode),
        ("sub", SubtractIntNode),
        ("mul", MultiplyIntNode),
        ("floordiv", FloorDivideIntNode),
        ("mod", ModIntNode),
    )
]
def _make_scalar_int_handler(node_cls: Any, op_name: str):
"""Create a handler for a scalar int op: (a, b) -> node_cls(a, b, out)."""
def handler(P: MLXProgramBuilder, n: Node) -> Slot:
args = P.args(n)
require_args(args, 2, 2, op_name)
require_kwargs(P.kwargs(n), set(), op_name)
a, b = args
out = P.make_or_get_slot(n)
P.emit(
node_cls(
a=P.to_int_or_vid(a),
b=P.to_int_or_vid(b),
out=P.slot_to_vid(out),
)
)
return out
handler.__name__ = f"_{op_name.replace('.', '_')}_handler"
handler.__doc__ = f"Handle {op_name} (table-driven scalar int op)."
return handler
# Register one handler per python operator used on symbolic ints.
for _target, _node_cls, _op_name in _SCALAR_INT_OPS:
    REGISTRY.register(target=[_target])(
        _make_scalar_int_handler(node_cls=_node_cls, op_name=_op_name)
    )
# (targets, MLX node class, op name, max positional args) for the reductions
# lowered through _make_reduction_handler. The targets are the listed
# overloads of the named aten op.
_REDUCTION_OPS: List[Tuple[List[Any], Any, str, int]] = [
    (
        [getattr(getattr(torch.ops.aten, _name), _ov) for _ov in _overloads],
        _mlx_node,
        f"aten.{_name}",
        _max_args,
    )
    for _name, _overloads, _mlx_node, _max_args in (
        ("sum", ("dim_IntList", "default"), SumNode, 4),
        ("mean", ("dim", "default"), MeanNode, 4),
        ("prod", ("dim_int", "default"), ProdNode, 4),
        ("amax", ("default",), MaxNode, 3),
        ("amin", ("default",), MinNode, 3),
        ("any", ("dim", "default"), AnyNode, 3),
        ("all", ("dim", "default"), AllNode, 3),
    )
]
def _make_reduction_handler(node_cls: Any, op_name: str, max_args: int):
"""Create a handler for a reduction op: x -> node_cls(x, out, axes, keepdims)."""
def handler(P: MLXProgramBuilder, n: Node) -> Slot:
args = P.args(n)
require_args(args, 1, max_args, op_name)
require_kwargs(P.kwargs(n), set(), op_name)
x = args[0]
axes, keepdim = normalize_reduction_dim(args)
out = P.make_or_get_slot(n)
P.emit(
node_cls(
x=P.slot_to_tid(x), out=P.slot_to_tid(out), axes=axes, keepdims=keepdim
)
)
return out
handler.__name__ = f"_{op_name.replace('.', '_')}_handler"
handler.__doc__ = f"Handle {op_name} (table-driven reduction op)."
return handler
# Register each reduction handler under all of its overload targets.
for _targets, _node_cls, _op_name, _max_args in _REDUCTION_OPS:
    REGISTRY.register(target=_targets)(
        _make_reduction_handler(
            node_cls=_node_cls, op_name=_op_name, max_args=_max_args
        )
    )
_FULL_OPS: List[Tuple[List[Any], str, Optional[float]]] = [
([torch.ops.aten.full.default], "aten.full", None),
([torch.ops.aten.zeros.default], "aten.zeros", 0.0),
([torch.ops.aten.ones.default], "aten.ones", 1.0),
]
def _make_full_handler(op_name: str, fixed_fill: Optional[float]):
"""Create a handler for full/zeros/ones: shape -> FullNode(shape, v, dtype)."""
has_fill_arg = fixed_fill is None
n_args = 2 if has_fill_arg else 1
def handler(P: MLXProgramBuilder, n: Node) -> Slot:
args = P.args(n)
require_args(args, n_args, n_args, op_name)
kwargs = P.kwargs(n)
require_kwargs(kwargs, {"dtype", "layout", "device", "pin_memory"}, op_name)
require_contiguous_format(layout=kwargs.get("layout"), op_name=op_name)
shape = args[0]
shape_iovs = [P.to_int_or_vid(d) for d in shape]
v = (
P.to_float_or_vid(args[1])
if has_fill_arg
else FloatOrVid.from_literal(fixed_fill)
)
dtype = n.kwargs.get("dtype")
if dtype is None:
dtype = torch.float32
out = P.make_or_get_slot(n)
P.emit(
FullNode(
out=P.slot_to_tid(out),
shape=shape_iovs,
v=v,
scalar_type=torch_dtype_to_scalar_type(dtype),
)
)
return out
handler.__name__ = f"_{op_name.replace('.', '_')}_handler"
handler.__doc__ = f"Handle {op_name} (table-driven full op)."
return handler
# Register the full/zeros/ones handlers.
for _targets, _op_name, _fixed_fill in _FULL_OPS:
    REGISTRY.register(target=_targets)(
        _make_full_handler(op_name=_op_name, fixed_fill=_fixed_fill)
    )
_FULL_LIKE_OPS: List[Tuple[List[Any], str, Optional[float]]] = [
([torch.ops.aten.full_like.default], "aten.full_like", None),
([torch.ops.aten.zeros_like.default], "aten.zeros_like", 0.0),
([torch.ops.aten.ones_like.default], "aten.ones_like", 1.0),
]
def _make_full_like_handler(op_name: str, fixed_fill: Optional[float]):
"""Create a handler for full_like/zeros_like/ones_like: x -> FullLikeNode(x, v, dtype)."""
has_fill_arg = fixed_fill is None
n_args = 2 if has_fill_arg else 1
def handler(P: MLXProgramBuilder, n: Node) -> Slot:
args = P.args(n)
require_args(args, n_args, n_args, op_name)
kwargs = P.kwargs(n)
require_kwargs(
kwargs,
{"dtype", "layout", "device", "pin_memory", "memory_format"},
op_name,
)
require_contiguous_format(
layout=kwargs.get("layout"),
memory_format=kwargs.get("memory_format"),
op_name=op_name,
)
x = args[0]
v = (
P.to_float_or_vid(args[1])
if has_fill_arg
else FloatOrVid.from_literal(fixed_fill)
)
dtype = n.kwargs.get("dtype")
out = P.make_or_get_slot(n)
P.emit(
FullLikeNode(
x=P.slot_to_tid(x),
out=P.slot_to_tid(out),
v=v,
scalar_type=(
torch_dtype_to_scalar_type(dtype) if dtype is not None else None
),
)
)
return out
handler.__name__ = f"_{op_name.replace('.', '_')}_handler"
handler.__doc__ = f"Handle {op_name} (table-driven full_like op)."
return handler
# Register the full_like/zeros_like/ones_like handlers.
for _targets, _op_name, _fixed_fill in _FULL_LIKE_OPS:
    REGISTRY.register(target=_targets)(
        _make_full_like_handler(op_name=_op_name, fixed_fill=_fixed_fill)
    )
@REGISTRY.register(target=["NOOP", torch.ops.aten._assert_scalar.default])
def _noop_handler(P: MLXProgramBuilder, n: Node) -> None:
    """Accept the node without emitting any MLX instruction.

    Registered for the "NOOP" target and for ``aten._assert_scalar``.
    Always returns None.
    """
    del P, n  # both intentionally unused
    return None
# Handler for the auto_functionalized_v2 higher-order op, which wraps
# mutating ops that have been functionalized.
@REGISTRY.register(target=[torch.ops.higher_order.auto_functionalized_v2])
def _auto_functionalized_v2_handler(P: MLXProgramBuilder, n: Node):
    """
    Handler for the auto_functionalized_v2 higher-order op.

    auto_functionalized_v2 wraps a mutating op after functionalization, and
    the wrapped op is ``n.args[0]``. No wrapped op is lowered here yet, so
    this handler never emits instructions. It always raises, naming the
    wrapped op so the unsupported case is reported explicitly.

    Raises:
        ValueError: If ``n`` has no positional args.
        NotImplementedError: Otherwise (no wrapped op is currently supported).
    """
    if len(n.args) < 1:
        raise ValueError(
            f"auto_functionalized_v2 requires at least 1 arg, got {len(n.args)}"
        )
    wrapped_op = n.args[0]
    # No wrapped op is supported yet. Lowerings for specific wrapped ops would
    # dispatch on wrapped_op before reaching this point.
    raise NotImplementedError(
        f"auto_functionalized_v2 wrapping '{wrapped_op}' is not supported."
    )
@REGISTRY.register(target=[torch.ops.aten.linear.default])
def _linear_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Handle aten.linear: ``x @ w.T (+ b)``.

    Schema: aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor

    The weight is transposed into a temporary slot with ``perm=[1, 0]``, so a
    2-D weight is assumed. The product, plus the optional bias, is then
    emitted as an AddmmNode.
    """
    args = P.args(n)
    require_args(args, 2, 3, "aten.linear")
    require_kwargs(P.kwargs(n), set(), "aten.linear")
    x, w = args[0], args[1]
    b = args[2] if len(args) > 2 else None
    out = P.make_or_get_slot(n)
    # Transpose weight: linear(x, w) = x @ w.T
    _, w_t = P.make_tmp_slot()
    P.emit(
        TransposeNode(
            x=P.slot_to_tid(w),
            out=P.slot_to_tid(w_t),
            perm=[1, 0],
        )
    )
    P.emit(
        AddmmNode(
            mat1=P.slot_to_tid(x),
            mat2=P.slot_to_tid(w_t),
            out=P.slot_to_tid(out),
            # Bias is optional: test for None explicitly rather than relying
            # on the truthiness of a Slot object.
            bias=P.slot_to_tid(b) if b is not None else None,
        )
    )
    return out
@REGISTRY.register(target=[torch.ops.aten.addmm.default])
def _addmm_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Lower ``addmm(self, mat1, mat2, *, beta=1, alpha=1)``.

    Computes ``beta * self + alpha * (mat1 @ mat2)`` with a single AddmmNode.
    ``self`` is passed as the bias, and alpha/beta are forwarded as floats
    (defaulting to 1 when absent). In Edge IR this pattern typically comes
    from a decomposed ``linear(x, w, b)``:
    ``permute(w) -> addmm(b, x, permuted_w)``. The weight is already
    transposed, so no extra transpose is emitted here.
    """
    positional = P.args(n)
    options = P.kwargs(n)
    require_args(positional, 3, 3, "aten.addmm")
    require_kwargs(options, {"beta", "alpha"}, "aten.addmm")
    bias, lhs, rhs = positional[0], positional[1], positional[2]
    beta_scale = options.get("beta", 1)
    alpha_scale = options.get("alpha", 1)
    result = P.make_or_get_slot(n)
    lhs_tid = P.slot_to_tid(lhs)
    rhs_tid = P.slot_to_tid(rhs)
    result_tid = P.slot_to_tid(result)
    bias_tid = P.slot_to_tid(bias)
    P.emit(
        AddmmNode(
            mat1=lhs_tid,
            mat2=rhs_tid,
            out=result_tid,
            bias=bias_tid,
            alpha=float(alpha_scale),
            beta=float(beta_scale),
        )
    )
    return result
@REGISTRY.register(
    target=[
        torch.ops.aten.mm.default,
        torch.ops.aten.bmm.default,
        torch.ops.aten.matmul.default,
    ]
)
def _mm_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Lower mm / bmm / matmul to a bias-free AddmmNode.

    The three aten ops differ only in the ranks they expect (mm: 2-D x 2-D,
    bmm: batched 3-D, matmul: NumPy-style arbitrary rank). All of them are
    emitted as the same AddmmNode with ``bias=None``.
    """
    op = "aten.mm/bmm/matmul"
    operands = P.args(n)
    require_args(operands, 2, 2, op)
    require_kwargs(P.kwargs(n), set(), op)
    lhs, rhs = operands[0], operands[1]
    result = P.make_or_get_slot(n)
    lhs_tid = P.slot_to_tid(lhs)
    rhs_tid = P.slot_to_tid(rhs)
    result_tid = P.slot_to_tid(result)
    P.emit(AddmmNode(mat1=lhs_tid, mat2=rhs_tid, out=result_tid, bias=None))
    return result
@REGISTRY.register(
    target=[
        torch.ops.aten.view.default,
        torch.ops.aten.view_copy.default,
        torch.ops.aten.reshape.default,
    ]
)
def _view_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Lower view / view_copy / reshape to a ReshapeNode.

    Every entry of the requested shape goes through ``P.to_int_or_vid``, so
    entries may be literal ints or (presumably symbolic) values that the
    builder maps to value ids.
    """
    operands = P.args(n)
    require_args(operands, 2, 2, "aten.view")
    require_kwargs(P.kwargs(n), set(), "aten.view")
    src, target_shape = operands
    result = P.make_or_get_slot(n)
    dims = [P.to_int_or_vid(dim) for dim in target_shape]
    P.emit(ReshapeNode(x=P.slot_to_tid(src), out=P.slot_to_tid(result), shape=dims))
    return result
@REGISTRY.register(
    target=[
        torch.ops.aten.clone.default,
        torch.ops.aten.alias.default,
        torch.ops.aten.alias_copy.default,
    ]
)
def _clone_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Lower clone / alias / alias_copy to a ContiguousNode copy.

    Only the ``memory_format`` kwarg is accepted, and it must be a contiguous
    or preserve format. Errors are reported under the name "aten.clone" for
    all three targets.
    """
    op = "aten.clone"
    positional = P.args(n)
    options = P.kwargs(n)
    require_args(positional, 1, 1, op)
    require_kwargs(options, {"memory_format"}, op)
    require_contiguous_format(memory_format=options.get("memory_format"), op_name=op)
    (src,) = positional
    result = P.make_or_get_slot(n)
    P.emit(ContiguousNode(x=P.slot_to_tid(src), out=P.slot_to_tid(result)))
    return result
@REGISTRY.register(target=[torch.ops.aten.copy.default])
def _copy_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Handle aten.copy - copy data from src into a tensor like self.

    Schema: aten::copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor

    This is the functional form of ``self.copy_(src)``: the result holds
    src's values but has self's dtype. When the traced dtype of self is known
    and differs from src's, src is cast with an AsTypeNode. Otherwise a plain
    ContiguousNode copy of src is emitted.

    NOTE(review): broadcasting src to self's shape is not handled here. Src
    is assumed to already have self's shape.
    """
    args = P.args(n)
    require_args(args, 2, 2, "aten.copy")
    require_kwargs(P.kwargs(n), {"non_blocking"}, "aten.copy")
    src = args[1]
    out = P.make_or_get_slot(n)

    self_arg, src_arg = n.args[0], n.args[1]
    self_val = self_arg.meta.get("val") if isinstance(self_arg, Node) else None
    src_val = src_arg.meta.get("val") if isinstance(src_arg, Node) else None
    self_dtype = getattr(self_val, "dtype", None)
    src_dtype = getattr(src_val, "dtype", None)

    if self_dtype is not None and self_dtype != src_dtype:
        # The result takes self's dtype, so cast src rather than copying it.
        P.emit(
            AsTypeNode(
                x=P.slot_to_tid(src),
                out=P.slot_to_tid(out),
                scalar_type=torch_dtype_to_scalar_type(self_dtype),
            )
        )
    else:
        P.emit(
            ContiguousNode(
                x=P.slot_to_tid(src),
                out=P.slot_to_tid(out),
            )
        )
    return out
@REGISTRY.register(target=[exir_ops.edge.dim_order_ops._clone_dim_order.default])
def _dim_order_clone_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Lower ``dim_order_ops._clone_dim_order`` to a ContiguousNode copy.

    Schema: _clone_dim_order(Tensor self, *, bool non_blocking=False,
    int[]? dim_order=None) -> Tensor

    Only ``non_blocking`` and ``dim_order`` kwargs are accepted, and
    ``dim_order`` (when given) must be the identity order.
    """
    op = "dim_order_ops._clone_dim_order"
    positional = P.args(n)
    options = P.kwargs(n)
    require_args(positional, 1, 1, op)
    require_kwargs(options, {"non_blocking", "dim_order"}, op)
    require_contiguous_format(dim_order=options.get("dim_order"), op_name=op)
    src = positional[0]
    result = P.make_or_get_slot(n)
    P.emit(ContiguousNode(x=P.slot_to_tid(src), out=P.slot_to_tid(result)))
    return result
# Edge IR's dim_order_ops._to_dim_order_copy is what x.to(dtype) becomes after
# the to_edge() transformation. It is lowered either as a cast or as a copy.
@REGISTRY.register(target=[exir_ops.edge.dim_order_ops._to_dim_order_copy.default])
def _dim_order_copy_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Lower ``dim_order_ops._to_dim_order_copy``.

    Schema: _to_dim_order_copy(Tensor self, *, ScalarType? dtype=None, ...)

    When a ``dtype`` kwarg is present, an AsTypeNode is emitted, even if the
    dtype equals the input's. Without one, the op is a plain layout copy and
    a ContiguousNode is emitted. ``layout`` and ``dim_order`` must describe a
    contiguous strided tensor.
    """
    op = "dim_order_ops._to_dim_order_copy"
    positional = P.args(n)
    options = P.kwargs(n)
    require_args(positional, 1, 1, op)
    require_kwargs(
        options,
        {"dtype", "device", "layout", "non_blocking", "dim_order"},
        op,
    )
    require_contiguous_format(
        layout=options.get("layout"),
        dim_order=options.get("dim_order"),
        op_name=op,
    )
    src = positional[0]
    result = P.make_or_get_slot(n)
    target_dtype = options.get("dtype")
    if target_dtype is None:
        # No dtype requested: only materialize a contiguous copy.
        P.emit(ContiguousNode(x=P.slot_to_tid(src), out=P.slot_to_tid(result)))
        return result
    P.emit(
        AsTypeNode(
            x=P.slot_to_tid(src),
            out=P.slot_to_tid(result),
            scalar_type=torch_dtype_to_scalar_type(target_dtype),
        )
    )
    return result
@REGISTRY.register(target=[torch.ops.aten._to_copy.default])