@@ -1702,82 +1702,88 @@ def _index_gather_permutation(
17021702 )
17031703
17041704
1705- @REGISTRY .register (target = [torch .ops .aten .index .Tensor ])
1706- def _index_handler (P : MLXProgramBuilder , n : Node ) -> Slot :
1707- args = P .args (n )
1708- require_args (args , 2 , 2 , "aten.index.Tensor" )
1709- require_kwargs (P .kwargs (n ), set (), "aten.index.Tensor" )
1710- x , idx_list = args
1705+ def _non_none_index_tensors (idx_list : Any ) -> List [Tuple [int , Slot ]]:
17111706 if not isinstance (idx_list , list ) or len (idx_list ) == 0 :
17121707 raise ValueError (
1713- f"aten.index.Tensor requires a list of index tensors, "
1714- f"got { type (idx_list )} "
1708+ f"aten.index.Tensor requires a list of index tensors, got { type (idx_list )} "
17151709 )
17161710
1717- x_meta = n .args [0 ].meta .get ("val" )
1718- x_ndim = len (x_meta .shape ) if x_meta is not None else None
1719-
1720- # Filter out None indices and track which axes they correspond to
17211711 non_none = [(i , idx ) for i , idx in enumerate (idx_list ) if idx is not None ]
1722-
17231712 if len (non_none ) == 0 :
17241713 raise ValueError ("aten.index.Tensor: all indices are None" )
1714+ return non_none
17251715
1726- if len (non_none ) == 1 :
1727- axis , idx = non_none [0 ]
1728- idx_meta = n .args [1 ][axis ].meta .get ("val" )
1729- ndim_match = (
1730- x_meta is not None
1731- and idx_meta is not None
1732- and len (x_meta .shape ) == len (idx_meta .shape )
1716+
1717+ def _emit_single_index_handler (
1718+ P : MLXProgramBuilder ,
1719+ n : Node ,
1720+ x : Slot ,
1721+ axis : int ,
1722+ idx : Slot ,
1723+ x_meta : Any ,
1724+ ) -> Slot :
1725+ idx_meta = n .args [1 ][axis ].meta .get ("val" )
1726+ ndim_match = (
1727+ x_meta is not None
1728+ and idx_meta is not None
1729+ and len (x_meta .shape ) == len (idx_meta .shape )
1730+ )
1731+ out = P .make_or_get_slot (n )
1732+ if ndim_match :
1733+ # Same ndim: use TakeAlongAxisNode (element-wise gather)
1734+ P .emit (
1735+ TakeAlongAxisNode (
1736+ x = P .slot_to_tid (x ),
1737+ indices = P .slot_to_tid (idx ),
1738+ out = P .slot_to_tid (out ),
1739+ axis = axis ,
1740+ )
17331741 )
1734- out = P .make_or_get_slot (n )
1735- if ndim_match :
1736- # Same ndim: use TakeAlongAxisNode (element-wise gather)
1737- P .emit (
1738- TakeAlongAxisNode (
1739- x = P .slot_to_tid (x ),
1740- indices = P .slot_to_tid (idx ),
1741- out = P .slot_to_tid (out ),
1742- axis = axis ,
1743- )
1742+ else :
1743+ # Different ndim (e.g. 1D indices into 3D tensor): use TakeNode
1744+ P .emit (
1745+ TakeNode (
1746+ x = P .slot_to_tid (x ),
1747+ index = IntOrVidOrTid .from_tid (P .slot_to_tid (idx )),
1748+ out = P .slot_to_tid (out ),
1749+ axis = axis ,
17441750 )
1745- else :
1746- # Different ndim (e.g. 1D indices into 3D tensor): use TakeNode
1747- P .emit (
1748- TakeNode (
1749- x = P .slot_to_tid (x ),
1750- index = IntOrVidOrTid .from_tid (P .slot_to_tid (idx )),
1751- out = P .slot_to_tid (out ),
1752- axis = axis ,
1753- )
1751+ )
1752+ return out
1753+
1754+
1755+ def _index_slice_sizes (x_meta : Any , x_ndim : int , indexed_axes : Set [int ]) -> List [int ]:
1756+ slice_sizes = []
1757+ for dim in range (x_ndim ):
1758+ if dim in indexed_axes :
1759+ slice_sizes .append (1 )
1760+ continue
1761+
1762+ dim_size = x_meta .shape [dim ]
1763+ if not isinstance (dim_size , int ):
1764+ raise ValueError (
1765+ f"aten.index.Tensor: non-indexed dimension { dim } has dynamic size "
1766+ f"{ dim_size } , which is not supported with multi-index gather"
17541767 )
1755- return out
1768+ slice_sizes .append (dim_size )
1769+ return slice_sizes
17561770
1757- # Multi-index: use GatherNode (maps to mlx::gather)
1758- if x_meta is None or x_ndim is None :
1759- raise ValueError (
1760- "aten.index.Tensor with multiple indices requires input shape metadata"
1761- )
17621771
1772+ def _emit_multi_index_handler (
1773+ P : MLXProgramBuilder ,
1774+ n : Node ,
1775+ x : Slot ,
1776+ x_meta : Any ,
1777+ x_ndim : int ,
1778+ non_none : List [Tuple [int , Slot ]],
1779+ ) -> Slot :
17631780 indices = [P .slot_to_tid (idx ) for _ , idx in non_none ]
17641781 axes = [i for i , _ in non_none ]
1782+ indexed_axes = set (axes )
17651783
17661784 # slice_sizes: 1 for indexed axes, full dim size for non-indexed axes
17671785 # Use int() to handle SymInt values from dynamic shapes
1768- indexed_axes = set (axes )
1769- slice_sizes = []
1770- for dim in range (x_ndim ):
1771- if dim in indexed_axes :
1772- slice_sizes .append (1 )
1773- else :
1774- dim_size = x_meta .shape [dim ]
1775- if not isinstance (dim_size , int ):
1776- raise ValueError (
1777- f"aten.index.Tensor: non-indexed dimension { dim } has dynamic size "
1778- f"{ dim_size } , which is not supported with multi-index gather"
1779- )
1780- slice_sizes .append (dim_size )
1786+ slice_sizes = _index_slice_sizes (x_meta , x_ndim , indexed_axes )
17811787
17821788 # Emit gather — output shape is broadcast(indices).shape + slice_sizes
17831789 _ , gather_slot = P .make_tmp_slot ()
@@ -1841,6 +1847,29 @@ def _index_handler(P: MLXProgramBuilder, n: Node) -> Slot:
18411847 return out
18421848
18431849
1850+ @REGISTRY .register (target = [torch .ops .aten .index .Tensor ])
1851+ def _index_handler (P : MLXProgramBuilder , n : Node ) -> Slot :
1852+ args = P .args (n )
1853+ require_args (args , 2 , 2 , "aten.index.Tensor" )
1854+ require_kwargs (P .kwargs (n ), set (), "aten.index.Tensor" )
1855+ x , idx_list = args
1856+
1857+ x_meta = n .args [0 ].meta .get ("val" )
1858+ x_ndim = len (x_meta .shape ) if x_meta is not None else None
1859+ non_none = _non_none_index_tensors (idx_list )
1860+
1861+ if len (non_none ) == 1 :
1862+ axis , idx = non_none [0 ]
1863+ return _emit_single_index_handler (P , n , x , axis , idx , x_meta )
1864+
1865+ if x_meta is None or x_ndim is None :
1866+ raise ValueError (
1867+ "aten.index.Tensor with multiple indices requires input shape metadata"
1868+ )
1869+
1870+ return _emit_multi_index_handler (P , n , x , x_meta , x_ndim , non_none )
1871+
1872+
18441873@REGISTRY .register (target = [torch .ops .aten .index_select .default ])
18451874def _index_select_handler (P : MLXProgramBuilder , n : Node ) -> Slot :
18461875 """Handle aten.index_select: select elements along an axis using a 1D index tensor.
0 commit comments