Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/mlx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ jobs:
ASTYPE_COUNT=$(${CONDA_RUN} python -m executorch.backends.mlx.pte_inspector \
/tmp/qwen35_moe_mlx_tiny/model.pte --mlx-instructions 2>&1 | grep -c "AsTypeNode" || true)
echo "AsType nodes: ${ASTYPE_COUNT}"
if [ "$ASTYPE_COUNT" -gt 23 ]; then
echo "Failed: expected no more than 23 AsType nodes, got ${ASTYPE_COUNT}"
if [ "$ASTYPE_COUNT" -gt 24 ]; then
echo "Failed: expected no more than 24 AsType nodes, got ${ASTYPE_COUNT}"
exit 1
fi
echo "::endgroup::"
Expand Down
12 changes: 11 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
#
# ==============================================================================

.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu qwen3_5_moe-cuda qwen3_5_moe-metal clean help
.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda voxtral_tts-mlx whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu qwen3_5_moe-cuda qwen3_5_moe-metal clean help

help:
@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
Expand All @@ -105,6 +105,7 @@ help:
@echo " voxtral_realtime-mlx - Build Voxtral Realtime runner with MLX backend"
@echo " voxtral_tts-cpu - Build Voxtral TTS runner (CPU)"
@echo " voxtral_tts-cuda - Build Voxtral TTS runner with CUDA backend"
@echo " voxtral_tts-mlx - Build Voxtral TTS runner with MLX backend (macOS only)"
@echo " whisper-cuda - Build Whisper runner with CUDA backend"
@echo " whisper-cuda-debug - Build Whisper runner with CUDA backend (debug mode)"
@echo " whisper-cpu - Build Whisper runner with CPU backend"
Expand Down Expand Up @@ -416,6 +417,15 @@ voxtral_tts-cuda:
@echo "✓ Build complete!"
@echo " Binary: cmake-out/examples/models/voxtral_tts/voxtral_tts_runner"

voxtral_tts-mlx:
@echo "==> Building and installing ExecuTorch with MLX..."
cmake --workflow --preset mlx-release
@echo "==> Building Voxtral TTS runner with MLX..."
cd examples/models/voxtral_tts && cmake --workflow --preset voxtral-tts-mlx
@echo ""
@echo "✓ Build complete!"
@echo " Binary: cmake-out/examples/models/voxtral_tts/voxtral_tts_runner"

qwen3_5_moe-cuda:
@echo "==> Building and installing ExecuTorch with CUDA..."
cmake --workflow --preset llm-release-cuda
Expand Down
208 changes: 147 additions & 61 deletions backends/mlx/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1678,82 +1678,115 @@ def _repeat_handler(P: MLXProgramBuilder, n: Node) -> Slot:
return out


@REGISTRY.register(target=[torch.ops.aten.index.Tensor])
def _index_handler(P: MLXProgramBuilder, n: Node) -> Slot:
args = P.args(n)
require_args(args, 2, 2, "aten.index.Tensor")
require_kwargs(P.kwargs(n), set(), "aten.index.Tensor")
x, idx_list = args
def _index_gather_permutation(
Comment thread
seyeong-han marked this conversation as resolved.
indexed_axes: Set[int],
x_ndim: int,
broadcast_ndim: int,
) -> List[int]:
# PyTorch decomposes F.unfold to aten.index.Tensor, and the Voxtral codec
# uses that path for ConvTranspose1d. Match PyTorch's advanced-index
# placement so MLX gather preserves the codec patch/channel order.
indexed_axes_sorted = sorted(indexed_axes)
expected_rank = broadcast_ndim + x_ndim
is_contiguous = indexed_axes_sorted == list(
range(indexed_axes_sorted[0], indexed_axes_sorted[-1] + 1)
)
if not is_contiguous:
return list(range(expected_rank))

non_indexed_axes = [i for i in range(x_ndim) if i not in indexed_axes]
before_axes = [i for i in non_indexed_axes if i < indexed_axes_sorted[0]]
after_axes = [i for i in non_indexed_axes if i > indexed_axes_sorted[-1]]
return (
[broadcast_ndim + i for i in before_axes]
+ list(range(broadcast_ndim))
+ [broadcast_ndim + i for i in after_axes]
+ [broadcast_ndim + i for i in indexed_axes_sorted]
)


def _non_none_index_tensors(idx_list: Any) -> List[Tuple[int, Slot]]:
    """Return (axis, index) pairs for the non-None entries of *idx_list*.

    aten.index.Tensor receives one entry per input dimension, where None
    marks a basic (full-slice) dimension; only the tensor indices matter for
    the gather lowering, paired with the axis they apply to.

    Raises:
        ValueError: if *idx_list* is not a non-empty list, or if every entry
            is None (no advanced index at all).
    """
    if not isinstance(idx_list, list) or len(idx_list) == 0:
        raise ValueError(
            f"aten.index.Tensor requires a list of index tensors, got {type(idx_list)}"
        )

    # Filter out None indices and track which axes they correspond to.
    non_none = [(i, idx) for i, idx in enumerate(idx_list) if idx is not None]

    if len(non_none) == 0:
        raise ValueError("aten.index.Tensor: all indices are None")
    return non_none

def _emit_single_index_handler(
P: MLXProgramBuilder,
n: Node,
x: Slot,
axis: int,
idx: Slot,
x_meta: Any,
) -> Slot:
idx_meta = n.args[1][axis].meta.get("val")
ndim_match = (
x_meta is not None
and idx_meta is not None
and len(x_meta.shape) == len(idx_meta.shape)
)
out = P.make_or_get_slot(n)
if ndim_match:
# Same ndim: use TakeAlongAxisNode (element-wise gather)
P.emit(
TakeAlongAxisNode(
x=P.slot_to_tid(x),
indices=P.slot_to_tid(idx),
out=P.slot_to_tid(out),
axis=axis,
)
)
out = P.make_or_get_slot(n)
if ndim_match:
# Same ndim: use TakeAlongAxisNode (element-wise gather)
P.emit(
TakeAlongAxisNode(
x=P.slot_to_tid(x),
indices=P.slot_to_tid(idx),
out=P.slot_to_tid(out),
axis=axis,
)
else:
# Different ndim (e.g. 1D indices into 3D tensor): use TakeNode
P.emit(
TakeNode(
x=P.slot_to_tid(x),
index=IntOrVidOrTid.from_tid(P.slot_to_tid(idx)),
out=P.slot_to_tid(out),
axis=axis,
)
else:
# Different ndim (e.g. 1D indices into 3D tensor): use TakeNode
P.emit(
TakeNode(
x=P.slot_to_tid(x),
index=IntOrVidOrTid.from_tid(P.slot_to_tid(idx)),
out=P.slot_to_tid(out),
axis=axis,
)
)
return out


def _index_slice_sizes(x_meta: Any, x_ndim: int, indexed_axes: Set[int]) -> List[int]:
slice_sizes = []
for dim in range(x_ndim):
if dim in indexed_axes:
slice_sizes.append(1)
continue

dim_size = x_meta.shape[dim]
if not isinstance(dim_size, int):
raise ValueError(
f"aten.index.Tensor: non-indexed dimension {dim} has dynamic size "
f"{dim_size}, which is not supported with multi-index gather"
)
return out
slice_sizes.append(dim_size)
return slice_sizes

# Multi-index: use GatherNode (maps to mlx::gather)
if x_meta is None or x_ndim is None:
raise ValueError(
"aten.index.Tensor with multiple indices requires input shape metadata"
)

def _emit_multi_index_handler(
P: MLXProgramBuilder,
n: Node,
x: Slot,
x_meta: Any,
x_ndim: int,
non_none: List[Tuple[int, Slot]],
) -> Slot:
indices = [P.slot_to_tid(idx) for _, idx in non_none]
axes = [i for i, _ in non_none]
indexed_axes = set(axes)

# slice_sizes: 1 for indexed axes, full dim size for non-indexed axes
# Use int() to handle SymInt values from dynamic shapes
indexed_axes = set(axes)
slice_sizes = []
for dim in range(x_ndim):
if dim in indexed_axes:
slice_sizes.append(1)
else:
dim_size = x_meta.shape[dim]
if not isinstance(dim_size, int):
raise ValueError(
f"aten.index.Tensor: non-indexed dimension {dim} has dynamic size "
f"{dim_size}, which is not supported with multi-index gather"
)
slice_sizes.append(dim_size)
slice_sizes = _index_slice_sizes(x_meta, x_ndim, indexed_axes)

# Emit gather — output shape is broadcast(indices).shape + slice_sizes
_, gather_slot = P.make_tmp_slot()
Expand All @@ -1767,26 +1800,79 @@ def _index_handler(P: MLXProgramBuilder, n: Node) -> Slot:
)
)

# Reshape to match aten.index.Tensor output shape, which strips the
# trailing dimensions introduced by gather's slice_sizes
out_meta = n.meta.get("val")
if out_meta is None:
raise ValueError(
"aten.index.Tensor: output shape metadata required for reshape after gather"
)
out_shape = [P.to_int_or_vid(int(d)) for d in out_meta.shape]

# MLX gather returns broadcast(indices).shape followed by one slice
# dimension per input dimension. For contiguous advanced-index groups,
# PyTorch keeps the broadcast dims at the indexed position, so reorder
# before stripping the singleton indexed slice dimensions via reshape.
non_indexed_axes = [i for i in range(x_ndim) if i not in indexed_axes]
broadcast_ndim = len(out_meta.shape) - len(non_indexed_axes)
if broadcast_ndim < 0:
raise ValueError(
"aten.index.Tensor: could not infer broadcast rank for multi-index gather"
)

reshape_input = gather_slot
expected_rank = broadcast_ndim + x_ndim
perm = _index_gather_permutation(indexed_axes, x_ndim, broadcast_ndim)
if len(perm) != expected_rank:
raise ValueError(
f"aten.index.Tensor: internal gather permutation has rank {len(perm)}, "
f"expected {expected_rank}"
)
if perm != list(range(expected_rank)):
_, ordered_slot = P.make_tmp_slot()
P.emit(
TransposeNode(
x=P.slot_to_tid(gather_slot),
out=P.slot_to_tid(ordered_slot),
perm=perm,
)
)
reshape_input = ordered_slot

# Reshape to match aten.index.Tensor output shape, stripping the singleton
# dimensions introduced by gather's slice_sizes for indexed axes.
out = P.make_or_get_slot(n)
P.emit(
ReshapeNode(
x=P.slot_to_tid(gather_slot),
x=P.slot_to_tid(reshape_input),
out=P.slot_to_tid(out),
shape=out_shape,
)
)
return out


@REGISTRY.register(target=[torch.ops.aten.index.Tensor])
def _index_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Lower aten.index.Tensor, dispatching on the number of tensor indices.

    A single tensor index maps onto take/take-along-axis; two or more go
    through the multi-index gather path, which needs static input-shape
    metadata.
    """
    args = P.args(n)
    require_args(args, 2, 2, "aten.index.Tensor")
    require_kwargs(P.kwargs(n), set(), "aten.index.Tensor")
    x, idx_list = args

    meta = n.args[0].meta.get("val")
    rank = None if meta is None else len(meta.shape)
    active = _non_none_index_tensors(idx_list)

    if len(active) == 1:
        only_axis, only_idx = active[0]
        return _emit_single_index_handler(P, n, x, only_axis, only_idx, meta)

    # Multi-index gather cannot proceed without the input's shape metadata.
    if meta is None or rank is None:
        raise ValueError(
            "aten.index.Tensor with multiple indices requires input shape metadata"
        )

    return _emit_multi_index_handler(P, n, x, meta, rank, active)


@REGISTRY.register(target=[torch.ops.aten.index_select.default])
def _index_select_handler(P: MLXProgramBuilder, n: Node) -> Slot:
"""Handle aten.index_select: select elements along an axis using a 1D index tensor.
Expand Down
50 changes: 50 additions & 0 deletions backends/mlx/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

# Import custom ops for RoPE and KV cache tests
from executorch.backends.mlx import ( # noqa: F401 - registers mlx ops # noqa: F401 - registers mlx.rope
Expand Down Expand Up @@ -3608,6 +3609,28 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]:
return (x, indices)


class Unfold1DModel(nn.Module):
    """Exercises F.unfold, whose 1D codec conv path decomposes to aten.index.Tensor."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Append a trailing unit spatial dim so unfold sees an (N, C, L, 1) input.
        widened = x.unsqueeze(-1)
        return F.unfold(widened, kernel_size=(3, 1), stride=(1, 1))


@register_test
class Unfold1DTest(OpTestCase):
    """Regression: aten.index.Tensor produced by F.unfold must keep patch order."""

    name = "unfold_1d"
    rtol = 1e-5
    atol = 1e-5

    def create_model(self) -> nn.Module:
        return Unfold1DModel()

    def create_inputs(self) -> Tuple[torch.Tensor, ...]:
        # A strictly increasing ramp makes any patch reordering visible.
        ramp = torch.arange(10, dtype=torch.float32)
        return (ramp.reshape(1, 2, 5),)


class AdvancedIndexModel(nn.Module):
"""Model that performs advanced (multi-index) tensor indexing.

Expand Down Expand Up @@ -3671,6 +3694,33 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]:
return (x, *indices)


class SeparatedAdvancedIndexModel(nn.Module):
    """Advanced indices on dims 0 and 2 with a full slice between them.

    When advanced indices are separated by a basic slice, PyTorch places the
    broadcast index dims at the front of the result.
    """

    def forward(
        self, x: torch.Tensor, idx0: torch.Tensor, idx2: torch.Tensor
    ) -> torch.Tensor:
        gathered = x[idx0, :, idx2]
        return gathered


@register_test
class SeparatedAdvancedIndexTest(OpTestCase):
    """Regression: broadcast dims must land in front for separated advanced indices."""

    name = "advanced_index_separated"
    rtol = 1e-5
    atol = 1e-5

    def create_model(self) -> nn.Module:
        return SeparatedAdvancedIndexModel()

    def create_inputs(self) -> Tuple[torch.Tensor, ...]:
        # Distinct values per cell so any dim misplacement changes the output.
        base = torch.arange(2 * 3 * 5, dtype=torch.float32)
        lead = torch.tensor([[0], [1]], dtype=torch.long)
        trail = torch.tensor([[0, 2, 4]], dtype=torch.long)
        return (base.reshape(2, 3, 5), lead, trail)


class IndexUpdateModel(nn.Module):
"""Model that performs index_copy on a mutable buffer.

Expand Down
Loading
Loading