Skip to content

Commit 3f98344

Browse files
committed
backends/mlx: reduce index handler lint complexity
1 parent f16101f commit 3f98344

1 file changed

Lines changed: 87 additions & 58 deletions

File tree

backends/mlx/ops.py

Lines changed: 87 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1702,82 +1702,88 @@ def _index_gather_permutation(
17021702
)
17031703

17041704

1705-
@REGISTRY.register(target=[torch.ops.aten.index.Tensor])
1706-
def _index_handler(P: MLXProgramBuilder, n: Node) -> Slot:
1707-
args = P.args(n)
1708-
require_args(args, 2, 2, "aten.index.Tensor")
1709-
require_kwargs(P.kwargs(n), set(), "aten.index.Tensor")
1710-
x, idx_list = args
1705+
def _non_none_index_tensors(idx_list: Any) -> List[Tuple[int, Slot]]:
17111706
if not isinstance(idx_list, list) or len(idx_list) == 0:
17121707
raise ValueError(
1713-
f"aten.index.Tensor requires a list of index tensors, "
1714-
f"got {type(idx_list)}"
1708+
f"aten.index.Tensor requires a list of index tensors, got {type(idx_list)}"
17151709
)
17161710

1717-
x_meta = n.args[0].meta.get("val")
1718-
x_ndim = len(x_meta.shape) if x_meta is not None else None
1719-
1720-
# Filter out None indices and track which axes they correspond to
17211711
non_none = [(i, idx) for i, idx in enumerate(idx_list) if idx is not None]
1722-
17231712
if len(non_none) == 0:
17241713
raise ValueError("aten.index.Tensor: all indices are None")
1714+
return non_none
17251715

1726-
if len(non_none) == 1:
1727-
axis, idx = non_none[0]
1728-
idx_meta = n.args[1][axis].meta.get("val")
1729-
ndim_match = (
1730-
x_meta is not None
1731-
and idx_meta is not None
1732-
and len(x_meta.shape) == len(idx_meta.shape)
1716+
1717+
def _emit_single_index_handler(
1718+
P: MLXProgramBuilder,
1719+
n: Node,
1720+
x: Slot,
1721+
axis: int,
1722+
idx: Slot,
1723+
x_meta: Any,
1724+
) -> Slot:
1725+
idx_meta = n.args[1][axis].meta.get("val")
1726+
ndim_match = (
1727+
x_meta is not None
1728+
and idx_meta is not None
1729+
and len(x_meta.shape) == len(idx_meta.shape)
1730+
)
1731+
out = P.make_or_get_slot(n)
1732+
if ndim_match:
1733+
# Same ndim: use TakeAlongAxisNode (element-wise gather)
1734+
P.emit(
1735+
TakeAlongAxisNode(
1736+
x=P.slot_to_tid(x),
1737+
indices=P.slot_to_tid(idx),
1738+
out=P.slot_to_tid(out),
1739+
axis=axis,
1740+
)
17331741
)
1734-
out = P.make_or_get_slot(n)
1735-
if ndim_match:
1736-
# Same ndim: use TakeAlongAxisNode (element-wise gather)
1737-
P.emit(
1738-
TakeAlongAxisNode(
1739-
x=P.slot_to_tid(x),
1740-
indices=P.slot_to_tid(idx),
1741-
out=P.slot_to_tid(out),
1742-
axis=axis,
1743-
)
1742+
else:
1743+
# Different ndim (e.g. 1D indices into 3D tensor): use TakeNode
1744+
P.emit(
1745+
TakeNode(
1746+
x=P.slot_to_tid(x),
1747+
index=IntOrVidOrTid.from_tid(P.slot_to_tid(idx)),
1748+
out=P.slot_to_tid(out),
1749+
axis=axis,
17441750
)
1745-
else:
1746-
# Different ndim (e.g. 1D indices into 3D tensor): use TakeNode
1747-
P.emit(
1748-
TakeNode(
1749-
x=P.slot_to_tid(x),
1750-
index=IntOrVidOrTid.from_tid(P.slot_to_tid(idx)),
1751-
out=P.slot_to_tid(out),
1752-
axis=axis,
1753-
)
1751+
)
1752+
return out
1753+
1754+
1755+
def _index_slice_sizes(x_meta: Any, x_ndim: int, indexed_axes: Set[int]) -> List[int]:
1756+
slice_sizes = []
1757+
for dim in range(x_ndim):
1758+
if dim in indexed_axes:
1759+
slice_sizes.append(1)
1760+
continue
1761+
1762+
dim_size = x_meta.shape[dim]
1763+
if not isinstance(dim_size, int):
1764+
raise ValueError(
1765+
f"aten.index.Tensor: non-indexed dimension {dim} has dynamic size "
1766+
f"{dim_size}, which is not supported with multi-index gather"
17541767
)
1755-
return out
1768+
slice_sizes.append(dim_size)
1769+
return slice_sizes
17561770

1757-
# Multi-index: use GatherNode (maps to mlx::gather)
1758-
if x_meta is None or x_ndim is None:
1759-
raise ValueError(
1760-
"aten.index.Tensor with multiple indices requires input shape metadata"
1761-
)
17621771

1772+
def _emit_multi_index_handler(
1773+
P: MLXProgramBuilder,
1774+
n: Node,
1775+
x: Slot,
1776+
x_meta: Any,
1777+
x_ndim: int,
1778+
non_none: List[Tuple[int, Slot]],
1779+
) -> Slot:
17631780
indices = [P.slot_to_tid(idx) for _, idx in non_none]
17641781
axes = [i for i, _ in non_none]
1782+
indexed_axes = set(axes)
17651783

17661784
# slice_sizes: 1 for indexed axes, full dim size for non-indexed axes
17671785
# Use int() to handle SymInt values from dynamic shapes
1768-
indexed_axes = set(axes)
1769-
slice_sizes = []
1770-
for dim in range(x_ndim):
1771-
if dim in indexed_axes:
1772-
slice_sizes.append(1)
1773-
else:
1774-
dim_size = x_meta.shape[dim]
1775-
if not isinstance(dim_size, int):
1776-
raise ValueError(
1777-
f"aten.index.Tensor: non-indexed dimension {dim} has dynamic size "
1778-
f"{dim_size}, which is not supported with multi-index gather"
1779-
)
1780-
slice_sizes.append(dim_size)
1786+
slice_sizes = _index_slice_sizes(x_meta, x_ndim, indexed_axes)
17811787

17821788
# Emit gather — output shape is broadcast(indices).shape + slice_sizes
17831789
_, gather_slot = P.make_tmp_slot()
@@ -1841,6 +1847,29 @@ def _index_handler(P: MLXProgramBuilder, n: Node) -> Slot:
18411847
return out
18421848

18431849

1850+
@REGISTRY.register(target=[torch.ops.aten.index.Tensor])
1851+
def _index_handler(P: MLXProgramBuilder, n: Node) -> Slot:
1852+
args = P.args(n)
1853+
require_args(args, 2, 2, "aten.index.Tensor")
1854+
require_kwargs(P.kwargs(n), set(), "aten.index.Tensor")
1855+
x, idx_list = args
1856+
1857+
x_meta = n.args[0].meta.get("val")
1858+
x_ndim = len(x_meta.shape) if x_meta is not None else None
1859+
non_none = _non_none_index_tensors(idx_list)
1860+
1861+
if len(non_none) == 1:
1862+
axis, idx = non_none[0]
1863+
return _emit_single_index_handler(P, n, x, axis, idx, x_meta)
1864+
1865+
if x_meta is None or x_ndim is None:
1866+
raise ValueError(
1867+
"aten.index.Tensor with multiple indices requires input shape metadata"
1868+
)
1869+
1870+
return _emit_multi_index_handler(P, n, x, x_meta, x_ndim, non_none)
1871+
1872+
18441873
@REGISTRY.register(target=[torch.ops.aten.index_select.default])
18451874
def _index_select_handler(P: MLXProgramBuilder, n: Node) -> Slot:
18461875
"""Handle aten.index_select: select elements along an axis using a 1D index tensor.

0 commit comments

Comments
 (0)