118118 TakeAlongAxisNode ,
119119 TanhNode ,
120120 TanNode ,
121+ TidOrVid ,
121122 TileNode ,
122123 TransposeNode ,
123124 TrilNode ,
@@ -234,6 +235,44 @@ def require_kwargs(
234235 raise ValueError (f"{ op_name } : unexpected kwargs: { unexpected } " )
235236
236237
def require_contiguous_format(
    *,
    layout=None,
    memory_format=None,
    dim_order=None,
    op_name: str,
) -> None:
    """
    Validate that layout/memory_format/dim_order specify contiguous format.

    MLX only supports contiguous (strided) tensors. Raises ValueError if
    sparse layouts or non-contiguous memory formats are requested.

    Args:
        layout: The torch layout (e.g., torch.strided, torch.sparse_coo)
        memory_format: The torch memory format (e.g., torch.contiguous_format,
            torch.channels_last)
        dim_order: The dimension order (list of ints, identity = contiguous)
        op_name: Name of the operation (for error message)
    """
    # A None layout means "unspecified", which is acceptable.
    if layout not in (None, torch.strided):
        raise ValueError(f"{op_name}: only strided layout supported, got {layout}")

    # preserve_format is fine: the input is already contiguous under MLX.
    acceptable_formats = (torch.contiguous_format, torch.preserve_format)
    if memory_format is not None and memory_format not in acceptable_formats:
        raise ValueError(
            f"{op_name}: only contiguous memory format supported, got {memory_format}"
        )

    # Identity permutation (0, 1, ..., n-1) is the contiguous dim order.
    if dim_order is not None and tuple(dim_order) != tuple(range(len(dim_order))):
        raise ValueError(
            f"{op_name}: only contiguous dim_order supported, got {dim_order}"
        )
274+
275+
237276def is_static_value (value : Any ) -> bool :
238277 """
239278 Check if a value is static (not a Slot/SymInt).
@@ -420,7 +459,9 @@ def _emit_update_cache(
420459
421460# Import custom ops to register llama.update_cache
422461try :
423- from executorch .extension .llm .custom_ops import custom_ops as _llama_ops # noqa: F401
462+ from executorch .extension .llm .custom_ops import ( # noqa: F401
463+ custom_ops as _llama_ops ,
464+ )
424465except ImportError :
425466 pass # Custom ops not available
426467
@@ -584,11 +625,22 @@ def _view_handler(P: MLXProgramBuilder, n: Node) -> Slot:
584625 return out
585626
586627
587- @REGISTRY .register (target = [torch .ops .aten .clone .default , torch .ops .aten .alias .default ])
628+ @REGISTRY .register (
629+ target = [
630+ torch .ops .aten .clone .default ,
631+ torch .ops .aten .alias .default ,
632+ torch .ops .aten .alias_copy .default ,
633+ ]
634+ )
588635def _clone_handler (P : MLXProgramBuilder , n : Node ) -> Slot :
589636 args = P .args (n )
637+ kwargs = P .kwargs (n )
590638 require_args (args , 1 , 1 , "aten.clone" )
591- require_kwargs (P .kwargs (n ), set (), "aten.clone" )
639+ require_kwargs (kwargs , {"memory_format" }, "aten.clone" )
640+ require_contiguous_format (
641+ memory_format = kwargs .get ("memory_format" ),
642+ op_name = "aten.clone" ,
643+ )
592644 (x ,) = args
593645 out = P .make_or_get_slot (n )
594646 P .emit (
@@ -612,9 +664,14 @@ def _dim_order_clone_handler(P: MLXProgramBuilder, n: Node) -> Slot:
612664 # dim_order_ops._clone_dim_order(Tensor self, *, bool non_blocking=False, int[]? dim_order=None) -> Tensor
613665 # This is essentially a contiguous/clone operation for memory layout
614666 args = P .args (n )
667+ kwargs = P .kwargs (n )
615668 require_args (args , 1 , 1 , "dim_order_ops._clone_dim_order" )
616669 require_kwargs (
617- P .kwargs (n ), {"non_blocking" , "dim_order" }, "dim_order_ops._clone_dim_order"
670+ kwargs , {"non_blocking" , "dim_order" }, "dim_order_ops._clone_dim_order"
671+ )
672+ require_contiguous_format (
673+ dim_order = kwargs .get ("dim_order" ),
674+ op_name = "dim_order_ops._clone_dim_order" ,
618675 )
619676 x = args [0 ]
620677 out = P .make_or_get_slot (n )
@@ -643,6 +700,11 @@ def _dim_order_copy_handler(P: MLXProgramBuilder, n: Node) -> Slot:
643700 {"dtype" , "device" , "layout" , "non_blocking" , "dim_order" },
644701 "dim_order_ops._to_dim_order_copy" ,
645702 )
703+ require_contiguous_format (
704+ layout = kwargs .get ("layout" ),
705+ dim_order = kwargs .get ("dim_order" ),
706+ op_name = "dim_order_ops._to_dim_order_copy" ,
707+ )
646708 x = args [0 ]
647709 out = P .make_or_get_slot (n )
648710
@@ -681,6 +743,11 @@ def _to_copy_handler(P: MLXProgramBuilder, n: Node) -> Slot:
681743 require_kwargs (
682744 kwargs , {"dtype" , "device" , "layout" , "memory_format" }, "aten._to_copy"
683745 )
746+ require_contiguous_format (
747+ layout = kwargs .get ("layout" ),
748+ memory_format = kwargs .get ("memory_format" ),
749+ op_name = "aten._to_copy" ,
750+ )
684751 x = args [0 ]
685752 out = P .make_or_get_slot (n )
686753
@@ -707,10 +774,15 @@ def _to_copy_handler(P: MLXProgramBuilder, n: Node) -> Slot:
707774
708775@REGISTRY .register (target = [torch .ops .aten .embedding .default ])
709776def _embedding_handler (P : MLXProgramBuilder , n : Node ) -> Slot :
777+ # aten::embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1,
778+ # bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
779+ # padding_idx is only relevant for training (gradient computation)
780+ # scale_grad_by_freq and sparse are also training-only
710781 args = P .args (n )
711- require_args (args , 2 , 2 , "aten.embedding" )
782+ require_args (args , 2 , 3 , "aten.embedding" )
712783 require_kwargs (P .kwargs (n ), set (), "aten.embedding" )
713784 w , x = args [0 ], args [1 ]
785+ # padding_idx (args[2] if present) is ignored - only affects gradients
714786 out = P .make_or_get_slot (n )
715787 P .emit (
716788 GatherNode (
@@ -1505,6 +1577,10 @@ def _arange_handler(P: MLXProgramBuilder, n: Node) -> Slot:
15051577 kwargs = P .kwargs (n )
15061578 require_args (args , 1 , 3 , "aten.arange" )
15071579 require_kwargs (kwargs , {"dtype" , "layout" , "device" , "pin_memory" }, "aten.arange" )
1580+ require_contiguous_format (
1581+ layout = kwargs .get ("layout" ),
1582+ op_name = "aten.arange" ,
1583+ )
15081584 if len (args ) == 1 :
15091585 start = 0
15101586 stop = args [0 ]
@@ -1541,6 +1617,10 @@ def _arange_start_step_handler(P: MLXProgramBuilder, n: Node) -> Slot:
15411617 require_kwargs (
15421618 kwargs , {"dtype" , "layout" , "device" , "pin_memory" }, "aten.arange.start_step"
15431619 )
1620+ require_contiguous_format (
1621+ layout = kwargs .get ("layout" ),
1622+ op_name = "aten.arange.start_step" ,
1623+ )
15441624 start = args [0 ]
15451625 stop = args [1 ]
15461626 step = args [2 ] if len (args ) > 2 else 1
@@ -1586,6 +1666,30 @@ def _rms_norm_handler(P: MLXProgramBuilder, n: Node) -> Slot:
15861666 return out
15871667
15881668
@REGISTRY.register(target=[torch.ops.aten.rms_norm.default])
def _aten_rms_norm_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """
    Lower aten.rms_norm to an MLX RMSNormNode.

    Schema: aten::rms_norm(Tensor input, int[] normalized_shape,
        Tensor? weight=None, float? eps=None) -> Tensor

    Only normalization over the last dimension is supported, so
    normalized_shape must have at most one entry.

    Raises:
        ValueError: if normalized_shape spans more than one dimension.
    """
    args = P.args(n)
    require_args(args, 2, 4, "aten.rms_norm")
    require_kwargs(P.kwargs(n), set(), "aten.rms_norm")
    x, normalized_shape = args[0], args[1]
    if len(normalized_shape) > 1:
        raise ValueError(
            "RMSNorm is only supported when normalizing over the last dimension"
        )
    w = args[2] if len(args) > 2 else None
    # The schema allows an explicit eps=None; treat it the same as "omitted"
    # and fall back to the conventional 1e-5 default in both cases.
    eps = args[3] if len(args) > 3 and args[3] is not None else 1e-5
    out = P.make_or_get_slot(n)
    P.emit(
        RMSNormNode(
            x=P.slot_to_tid(x),
            # Explicit None check: do not rely on Slot truthiness to decide
            # whether an optional weight was provided.
            weight=P.slot_to_tid(w) if w is not None else None,
            out=P.slot_to_tid(out),
            eps=eps,
        )
    )
    return out
1691+
1692+
15891693@REGISTRY .register (target = [torch .ops .mlx .rope .default ])
15901694def _rope_handler (P : MLXProgramBuilder , n : Node ) -> Slot :
15911695 args = P .args (n )
@@ -1599,10 +1703,10 @@ def _rope_handler(P: MLXProgramBuilder, n: Node) -> Slot:
15991703 out = P .make_or_get_slot (n )
16001704
16011705 # pos must be a Slot (SymInt) from input_pos.item() during tracing
1602- # The schema only supports Vid for pos, not literal int
1706+ # The schema supports both Vid (scalar) and Tid (tensor) for offset
16031707 if not isinstance (pos , Slot ):
16041708 raise ValueError (
1605- f"RopeNode.pos must be a SymInt (traced via tensor.item()), got { type (pos )} . "
1709+ f"RopeNode.offset must be a SymInt (traced via tensor.item()), got { type (pos )} . "
16061710 "Make sure input_pos is a tensor and you call input_pos.item() to get a SymInt."
16071711 )
16081712
@@ -1611,7 +1715,7 @@ def _rope_handler(P: MLXProgramBuilder, n: Node) -> Slot:
16111715 x = P .slot_to_tid (x ),
16121716 out = P .slot_to_tid (out ),
16131717 head_dim = head_dim ,
1614- pos = P .slot_to_vid (pos ),
1718+ offset = TidOrVid . from_vid ( P .slot_to_vid (pos ) ),
16151719 freqs = P .slot_to_tid (freqs ) if freqs else None ,
16161720 traditional = traditional ,
16171721 base = base ,
@@ -2353,8 +2457,11 @@ def _full_handler(P: MLXProgramBuilder, n: Node) -> Slot:
23532457 # Use P.args to properly convert Nodes to Slots for dynamic shapes
23542458 args = P .args (n )
23552459 require_args (args , 2 , 2 , "aten.full" )
2356- require_kwargs (
2357- P .kwargs (n ), {"dtype" , "layout" , "device" , "pin_memory" }, "aten.full"
2460+ kwargs = P .kwargs (n )
2461+ require_kwargs (kwargs , {"dtype" , "layout" , "device" , "pin_memory" }, "aten.full" )
2462+ require_contiguous_format (
2463+ layout = kwargs .get ("layout" ),
2464+ op_name = "aten.full" ,
23582465 )
23592466 out = P .make_or_get_slot (n )
23602467 shape = args [0 ]
@@ -2384,8 +2491,11 @@ def _zeros_handler(P: MLXProgramBuilder, n: Node) -> Slot:
23842491 """Handle aten.zeros - create tensor filled with zeros."""
23852492 args = P .args (n )
23862493 require_args (args , 1 , 1 , "aten.zeros" )
2387- require_kwargs (
2388- P .kwargs (n ), {"dtype" , "layout" , "device" , "pin_memory" }, "aten.zeros"
2494+ kwargs = P .kwargs (n )
2495+ require_kwargs (kwargs , {"dtype" , "layout" , "device" , "pin_memory" }, "aten.zeros" )
2496+ require_contiguous_format (
2497+ layout = kwargs .get ("layout" ),
2498+ op_name = "aten.zeros" ,
23892499 )
23902500 out = P .make_or_get_slot (n )
23912501
@@ -2416,8 +2526,11 @@ def _ones_handler(P: MLXProgramBuilder, n: Node) -> Slot:
24162526 """Handle aten.ones - create tensor filled with ones."""
24172527 args = P .args (n )
24182528 require_args (args , 1 , 1 , "aten.ones" )
2419- require_kwargs (
2420- P .kwargs (n ), {"dtype" , "layout" , "device" , "pin_memory" }, "aten.ones"
2529+ kwargs = P .kwargs (n )
2530+ require_kwargs (kwargs , {"dtype" , "layout" , "device" , "pin_memory" }, "aten.ones" )
2531+ require_contiguous_format (
2532+ layout = kwargs .get ("layout" ),
2533+ op_name = "aten.ones" ,
24212534 )
24222535 out = P .make_or_get_slot (n )
24232536
@@ -2447,12 +2560,18 @@ def _ones_handler(P: MLXProgramBuilder, n: Node) -> Slot:
24472560def _zeros_like_handler (P : MLXProgramBuilder , n : Node ) -> Slot :
24482561 """Handle aten.zeros_like - create zero-filled tensor with same shape as input."""
24492562 args = P .args (n )
2563+ kwargs = P .kwargs (n )
24502564 require_args (args , 1 , 1 , "aten.zeros_like" )
24512565 require_kwargs (
2452- P . kwargs ( n ) ,
2566+ kwargs ,
24532567 {"dtype" , "layout" , "device" , "pin_memory" , "memory_format" },
24542568 "aten.zeros_like" ,
24552569 )
2570+ require_contiguous_format (
2571+ layout = kwargs .get ("layout" ),
2572+ memory_format = kwargs .get ("memory_format" ),
2573+ op_name = "aten.zeros_like" ,
2574+ )
24562575 x = args [0 ]
24572576 out = P .make_or_get_slot (n )
24582577
@@ -2475,12 +2594,18 @@ def _zeros_like_handler(P: MLXProgramBuilder, n: Node) -> Slot:
24752594def _ones_like_handler (P : MLXProgramBuilder , n : Node ) -> Slot :
24762595 """Handle aten.ones_like - create one-filled tensor with same shape as input."""
24772596 args = P .args (n )
2597+ kwargs = P .kwargs (n )
24782598 require_args (args , 1 , 1 , "aten.ones_like" )
24792599 require_kwargs (
2480- P . kwargs ( n ) ,
2600+ kwargs ,
24812601 {"dtype" , "layout" , "device" , "pin_memory" , "memory_format" },
24822602 "aten.ones_like" ,
24832603 )
2604+ require_contiguous_format (
2605+ layout = kwargs .get ("layout" ),
2606+ memory_format = kwargs .get ("memory_format" ),
2607+ op_name = "aten.ones_like" ,
2608+ )
24842609 x = args [0 ]
24852610 out = P .make_or_get_slot (n )
24862611
@@ -2503,12 +2628,18 @@ def _ones_like_handler(P: MLXProgramBuilder, n: Node) -> Slot:
25032628def _full_like_handler (P : MLXProgramBuilder , n : Node ) -> Slot :
25042629 """Handle aten.full_like - create tensor filled with value with same shape."""
25052630 args = P .args (n )
2631+ kwargs = P .kwargs (n )
25062632 require_args (args , 2 , 2 , "aten.full_like" )
25072633 require_kwargs (
2508- P . kwargs ( n ) ,
2634+ kwargs ,
25092635 {"dtype" , "layout" , "device" , "pin_memory" , "memory_format" },
25102636 "aten.full_like" ,
25112637 )
2638+ require_contiguous_format (
2639+ layout = kwargs .get ("layout" ),
2640+ memory_format = kwargs .get ("memory_format" ),
2641+ op_name = "aten.full_like" ,
2642+ )
25122643 x = args [0 ]
25132644 fill_value = args [1 ]
25142645 out = P .make_or_get_slot (n )
@@ -2724,9 +2855,14 @@ def _scalar_tensor_handler(P: MLXProgramBuilder, n: Node) -> Slot:
27242855 This is equivalent to torch.full([], scalar, dtype=dtype).
27252856 """
27262857 args = P .args (n )
2858+ kwargs = P .kwargs (n )
27272859 require_args (args , 1 , 1 , "aten.scalar_tensor" )
27282860 require_kwargs (
2729- P .kwargs (n ), {"dtype" , "layout" , "device" , "pin_memory" }, "aten.scalar_tensor"
2861+ kwargs , {"dtype" , "layout" , "device" , "pin_memory" }, "aten.scalar_tensor"
2862+ )
2863+ require_contiguous_format (
2864+ layout = kwargs .get ("layout" ),
2865+ op_name = "aten.scalar_tensor" ,
27302866 )
27312867 scalar_value = args [0 ]
27322868
0 commit comments