diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index cc83c90e23e..53c7b9360cd 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -9,6 +9,7 @@ on: paths: - .github/workflows/mlx.yml - backends/mlx/** + - extension/llm/export/** workflow_dispatch: permissions: {} @@ -36,7 +37,7 @@ jobs: ${CONDA_RUN} pip list echo "::group::Build test runners" - ${CONDA_RUN} cmake --build cmake-out --target op_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 )) + ${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 )) echo "::endgroup::" echo "::group::Run op unit tests" @@ -51,6 +52,14 @@ jobs: -v echo "::endgroup::" + echo "::group::Run multi-thread stress test" + ${CONDA_RUN} python backends/mlx/test/export_multi_thread_test_model.py /tmp/multi_thread_test_model.pte + ET_TESTING_MODEL_PATH=/tmp/multi_thread_test_model.pte \ + ET_TESTING_NUM_THREADS=50 \ + ET_PREDICTIONS_PER_THREAD=100 \ + ./cmake-out/backends/mlx/test/multi_thread_test_runner + echo "::endgroup::" + backend-tester: strategy: fail-fast: false diff --git a/backends/mlx/builder/op_helpers.py b/backends/mlx/builder/op_helpers.py index 5e082cdf386..40e71e0bdab 100644 --- a/backends/mlx/builder/op_helpers.py +++ b/backends/mlx/builder/op_helpers.py @@ -18,6 +18,11 @@ if TYPE_CHECKING: from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +# When True, always serialize the biases tensor for quantized ops. +# When False, use init-time computation when zero_point is all zeros, +# computing biases = -scales * 2^(bits-1) during the init chain. +QUANTIZED_SERIALIZE_BIASES = False + def get_aten_target(target): """ @@ -168,6 +173,50 @@ def emit_lifted_constant(P: "MLXProgramBuilder", value, dtype: torch.dtype) -> S return slot +def emit_quantized_biases( + P: "MLXProgramBuilder", + zero_point_key: str, + scale: torch.Tensor, + zero_point: torch.Tensor, + bits: int, + B: torch.Tensor, + scale_slot: "Slot", +) -> "Slot": + """Emit biases for quantized ops, computing at init time when possible. + + When zero_point is all zeros and QUANTIZED_SERIALIZE_BIASES is False, + avoids serializing the biases tensor by computing biases = scales * -offset + during the init chain instead. + + Returns the biases Slot. 
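+    For example (illustrative): with bits=4 and an all-zero zero_point,
+    offset = 1 << (4 - 1) = 8, so the init chain computes
+    biases = scales * -8 instead of storing a per-group biases tensor
+    in the serialized program.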
+ """ + from executorch.backends.mlx.serialization.mlx_graph_schema import MultiplyNode + from torch._subclasses.fake_tensor import FakeTensor + + is_scale_only = False + if not isinstance(zero_point, FakeTensor): + if torch.sum(torch.abs(zero_point)).item() == 0: + is_scale_only = True + + if QUANTIZED_SERIALIZE_BIASES or not is_scale_only: + return P.make_or_get_constant(f"{zero_point_key}_to_biases", B) + + scale_dtype = scale.dtype + offset = 1 << (bits - 1) + neg_offset = emit_lifted_constant(P, -offset, scale_dtype) + biases = P.make_or_get_constant( + f"{zero_point_key}_to_biases_dummy", torch.tensor(0.0, dtype=B.dtype) + ) + P.emit_init( + MultiplyNode( + a=P.slot_to_tid(scale_slot), + b=P.slot_to_tid(neg_offset), + out=P.slot_to_tid(biases), + ) + ) + return biases + + def to_mlx_qparams( qdata: torch.Tensor, scale: torch.Tensor, @@ -194,21 +243,36 @@ def to_mlx_qparams( """ assert qdata.dtype == torch.int8 offset = 2 ** (bits - 1) - Q = qdata.to(torch.int32) + offset # Pack data tightly into uint32 assert 32 % bits == 0 vals_per_uint32 = 32 // bits assert qdata.shape[1] % vals_per_uint32 == 0 - - Q = Q.reshape(-1, vals_per_uint32) - shifts = torch.arange(0, 32, bits, dtype=torch.int64) - - # Convert to int64 for shift/packing - Q = Q.to(torch.int64) - Q = (Q << shifts).sum(dim=-1) - Q = Q.to(torch.uint32) - Q = Q.reshape(qdata.shape[0], -1) + rows, cols = qdata.shape + + if bits == 4: + # 4-bit: view(uint8) + wrapping add + pack 2 nibbles per byte → view as uint32 + q = qdata.view(torch.uint8) + offset + q3 = q.reshape(rows, cols // 2, 2) + Q = (q3[:, :, 0] | (q3[:, :, 1] << 4)).view(torch.uint32) + elif bits == 2: + # 2-bit: pack 4×2-bit values per byte in uint8, then view as uint32 + Q = ((qdata.view(torch.uint8) + offset) & 0x3).reshape(rows, cols // 4, 4) + packed = Q[:, :, 0] | (Q[:, :, 1] << 2) | (Q[:, :, 2] << 4) | (Q[:, :, 3] << 6) + Q = packed.contiguous().view(torch.uint32) + elif bits == 8: + # 8-bit: each byte maps 1:1 to a uint32 slot — no shifting needed + q = qdata.view(torch.uint8) + offset + Q = q.contiguous().view(torch.uint32).reshape(rows, -1) + else: + # General fallback for other bit widths + Q = (qdata.to(torch.int32) + offset).reshape(-1, vals_per_uint32) + shifts = torch.arange(0, 32, bits, dtype=torch.int32) + shifted = Q << shifts + packed = shifted[:, 0] + for i in range(1, vals_per_uint32): + packed = packed | shifted[:, i] + Q = packed.view(torch.uint32).reshape(rows, -1) if compute_biases: B = -scale * (zero_point.to(scale.dtype) + offset) @@ -217,6 +281,34 @@ def to_mlx_qparams( return Q, None +def parse_dequant_nvfp4_node( + node: Node, +) -> Optional[Tuple[Node, Node, Node, torch.dtype]]: + """Parse a torchao.dequantize_nvfp4 node. + + Returns (qdata, scale, per_tensor_scale, output_dtype) or None if not a + dequantize_nvfp4 node or the custom op is not registered. 
+ """ + target = get_aten_target(node.target) + try: + import executorch.extension.llm.export.nvfp4 # noqa: F401 + except ImportError: + return None + + if target is not torch.ops.torchao.dequantize_nvfp4.default: + return None + + qdata, scale, per_tensor_scale = node.args[0:3] + + output_dtype = torch.float32 + if len(node.args) > 4: + output_dtype = node.args[4] + elif "output_dtype" in node.kwargs: + output_dtype = node.kwargs["output_dtype"] + + return qdata, scale, per_tensor_scale, output_dtype + + def parse_dequant_node( node: Node, ) -> Optional[Tuple[Node, Node, Node, int, int, Optional[torch.dtype], int]]: @@ -244,11 +336,11 @@ def parse_dequant_node( quantized_dim, group_size = non_one[0] if group_size not in [32, 64, 128]: return None - if qmin == -8 and qmax == 7: - bits = 4 - elif qmin == -128 and qmax == 127: - bits = 8 - else: + + # TODO: MLX supports 3, 5, and 7, but we need to figure out the + # packing story in to_mlx_qparams to use them + bits = (qmax - qmin + 1).bit_length() - 1 + if bits not in [2, 4, 8]: return None return qdata, scale, zero_point, group_size, bits, out_dtype, quantized_dim diff --git a/backends/mlx/custom_ops.py b/backends/mlx/custom_ops.py index 81853adbd6d..8ad891e3568 100644 --- a/backends/mlx/custom_ops.py +++ b/backends/mlx/custom_ops.py @@ -13,3 +13,259 @@ These ops are used during model export to represent operations that MLX can execute efficiently but may not have direct PyTorch equivalents. """ + +from typing import Optional + +import torch +from torch import Tensor + + +@torch.library.custom_op("mlx::kv_cache_update", mutates_args=("cache",)) +def kv_cache_update( + cache: Tensor, # [B, H, S_max, D] - mutated in place + new_values: Tensor, # [B, H, S, D] + start_pos: int, + ring_size: int = 0, +) -> Tensor: + """ + Mutating KV cache update that modifies cache in place. + + This op updates the cache at positions [start_pos, start_pos + S) with + new_values. The cache is mutated in place, similar to llama.update_cache. + + Args: + cache: Cache tensor of shape [B, H, S_max, D] (BHSD layout) - mutated + new_values: New values to insert of shape [B, H, S, D] + start_pos: Starting position index for insertion + ring_size: If > 0, treat as ring buffer of this size: write position + is start_pos % ring_size and writes wrap around. If 0 (default), + linear update at start_pos with no wrapping. + + Returns: + A dummy tensor (1,) - the return value is not semantically meaningful + but is required for slot management during export. This follows the + same pattern as llama.update_cache. + + Note: + The BHSD layout matches what torch SDPA expects, avoiding transposition. + """ + seq_len = new_values.size(2) + + if ring_size > 0: + write_pos = start_pos % ring_size + end_pos = write_pos + seq_len + if end_pos <= ring_size: + cache[:, :, write_pos:end_pos, :] = new_values + else: + first_part = ring_size - write_pos + cache[:, :, write_pos:ring_size, :] = new_values[:, :, :first_part, :] + cache[:, :, 0 : seq_len - first_part, :] = new_values[:, :, first_part:, :] + else: + end_pos = start_pos + seq_len + assert end_pos <= cache.size(2), ( + f"kv_cache_update: write [{start_pos}, {end_pos}) exceeds " + f"cache size {cache.size(2)}. Use ring_size > 0 for wrapping." 
+ ) + cache[:, :, start_pos:end_pos, :] = new_values + + return torch.empty((1,), dtype=new_values.dtype, device=new_values.device) + + +@torch.library.register_fake("mlx::kv_cache_update") +def kv_cache_update_fake( + cache: Tensor, + new_values: Tensor, + start_pos: int, + ring_size: int = 0, +) -> Tensor: + """Fake implementation for tracing - returns dummy tensor like llama.update_cache.""" + return torch.empty((1,), dtype=new_values.dtype, device="meta") + + +@torch.library.custom_op("mlx::custom_sdpa", mutates_args=()) +def mlx_custom_sdpa( + query: Tensor, # [B, num_heads, seq_len, head_dim] - BHSD + key: Tensor, # [B, num_kv_heads, kv_len, head_dim] - BHSD (FULL cache) + value: Tensor, # [B, num_kv_heads, kv_len, head_dim] - BHSD (FULL cache) + start_pos: int, # FIRST position in current batch (0-indexed) + attn_mask: Optional[Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, +) -> Tensor: + """ + MLX custom SDPA with K/V cache slicing. + + This op uses BHSD layout (matching PyTorch SDPA and MLX's SdpaNode). + It receives the FULL K/V cache and slices to [0:stop_pos] before computing + attention, where stop_pos = start_pos + query_seq_len. + + The semantics follow executorch's llama.custom_sdpa: + - start_pos: FIRST position of the current query batch + - For prefill with 7 tokens at positions [0,1,2,3,4,5,6]: start_pos=0, stop_pos=7 + - For decode at position 10: start_pos=10, stop_pos=11 + + Args: + query: Query tensor [B, num_heads, seq_len, head_dim] + key: Key cache [B, num_kv_heads, kv_len, head_dim] - FULL cache + value: Value cache [B, num_kv_heads, kv_len, head_dim] - FULL cache + start_pos: FIRST position in current batch (SymInt) + attn_mask: Optional attention mask (only used when is_causal=False) + dropout_p: Dropout probability (default 0.0) + is_causal: Whether to apply causal masking (default False) + scale: Attention scale factor (default 1/sqrt(head_dim)) + + Returns: + Attention output [B, num_heads, seq_len, head_dim] - BHSD + """ + if scale is None: + scale = query.shape[-1] ** -0.5 + + # Compute stop_pos = start_pos + query_seq_len + # BHSD layout: seq_len is at dim 2 + query_seq_len = query.shape[2] + stop_pos = start_pos + query_seq_len + + # Constrain symbolic shapes so torch.export can resolve guards. + # start_pos is data-dependent (from input_pos), so the slice + # stop_pos > kv_len comparison is unresolvable without these hints. + torch._check(start_pos >= 0) + torch._check(stop_pos <= key.shape[2]) + + # Slice K/V to valid cache entries [0:stop_pos] + key_sliced = key[:, :, :stop_pos, :] + value_sliced = value[:, :, :stop_pos, :] + + # Handle GQA: expand K/V heads to match query heads + num_heads = query.shape[1] + num_kv_heads = key.shape[1] + if num_kv_heads != num_heads: + num_groups = num_heads // num_kv_heads + key_sliced = key_sliced.repeat_interleave(num_groups, dim=1) + value_sliced = value_sliced.repeat_interleave(num_groups, dim=1) + + # Build explicit lower-right aligned causal mask to match MLX's SdpaNode. + # PyTorch's is_causal=True uses upper-left alignment when Q_len != K_len, + # but for KV-cache inference q[i] is at context position (start_pos + i) + # and should attend to all positions 0..start_pos+i (lower-right). 
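+    # Worked example (illustrative): a single decode query (L=1) at start_pos=4
+    # sees S=5 cached positions, so offset = 5 - 1 = 4 and tril(diagonal=4) lets
+    # that query row attend to positions 0..4. Upper-left alignment
+    # (PyTorch's is_causal=True) would let it attend to position 0 only.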
+ if is_causal: + L, S = query.shape[2], key_sliced.shape[2] + offset = S - L # equals start_pos + mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril( + diagonal=offset + ) + attn_mask = torch.where(mask, 0.0, float("-inf")).to(query.dtype) + + # Compute SDPA - returns BHSD + return torch.nn.functional.scaled_dot_product_attention( + query, + key_sliced, + value_sliced, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=False, + scale=scale, + ) + + +@torch.library.register_fake("mlx::custom_sdpa") +def mlx_custom_sdpa_fake( + query: Tensor, + key: Tensor, + value: Tensor, + start_pos: int, + attn_mask: Optional[Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, +) -> Tensor: + """Fake implementation for tracing - returns BHSD shape (same as query).""" + return query.new_empty(query.shape) + + +@torch.library.custom_op("mlx::rope", mutates_args=()) +def rope( + x: Tensor, # (B, H, T, D) + dims: int, + pos: int, # int, not tensor + traditional: bool = False, + base: float = 500000.0, + scale: float = 1.0, + freqs: Optional[Tensor] = None, +) -> Tensor: + """ + Apply Rotary Position Embedding to a single tensor. + + Args: + x: Input tensor of shape (B, H, T, D) + dims: Number of feature dimensions to rotate. If less than D, + only the first `dims` dimensions are rotated and the rest + are left unchanged. + pos: Starting position index (int, not tensor) + traditional: Whether to use traditional RoPE formulation + base: Base for frequency computation + scale: Scale factor for frequencies + freqs: Optional precomputed frequencies + + Returns: + Rotated tensor of the same shape + """ + Dh = int(dims) + + B, H, T, _ = x.shape + half = Dh // 2 + + if freqs is None: + # [1, 1, 1, half] to broadcast over B,H,T + i = torch.arange(half, device=x.device, dtype=torch.float32) + inv_freq = (base ** (-2.0 * i / Dh)).view(1, 1, 1, half) + + # positions: [1, 1, T, 1] + pos_range = torch.arange( + pos, pos + T, device=x.device, dtype=torch.float32 + ).view(1, 1, T, 1) + + # final angles: [1, 1, T, half] + angles = (pos_range * inv_freq) * float(scale) + else: + # assume freqs is already per-position, just reshape to [1,1,T,half] + angles = freqs.to(torch.float32).view(1, 1, T, half) + + cos = angles.cos().to(x.dtype) # [1,1,T,half] + sin = angles.sin().to(x.dtype) # [1,1,T,half] + + # Split into rotated and unrotated portions + x_rot = x[..., :Dh] + x_pass = x[..., Dh:] + + if traditional: + # Interleaved pairs: (x[0],x[1]), (x[2],x[3]), ... + x1 = x_rot[..., 0::2] # even indices + x2 = x_rot[..., 1::2] # odd indices + xr = x1 * cos - x2 * sin + xi = x1 * sin + x2 * cos + rotated = torch.stack([xr, xi], dim=-1).flatten(-2) + else: + # Split-half: first half paired with second half + x1, x2 = x_rot[..., :half], x_rot[..., half:] + xr = x1 * cos - x2 * sin + xi = x1 * sin + x2 * cos + rotated = torch.cat([xr, xi], dim=-1) + + if x_pass.shape[-1] > 0: + return torch.cat([rotated, x_pass], dim=-1) + return rotated + + +@torch.library.register_fake("mlx::rope") +def rope_fake( + x: Tensor, + dims: int, + pos: int, + traditional: bool = False, + base: float = 500000.0, + scale: float = 1.0, + freqs: Optional[Tensor] = None, +) -> Tensor: + """Fake implementation for tracing.""" + return x.new_empty(x.shape) diff --git a/backends/mlx/llm/__init__.py b/backends/mlx/llm/__init__.py new file mode 100644 index 00000000000..f557ef26c5b --- /dev/null +++ b/backends/mlx/llm/__init__.py @@ -0,0 +1,6 @@ +# +# Copyright (c) Meta Platforms, Inc. 
and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/mlx/llm/cache.py b/backends/mlx/llm/cache.py new file mode 100644 index 00000000000..9709980689b --- /dev/null +++ b/backends/mlx/llm/cache.py @@ -0,0 +1,429 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Shared KV cache utilities for MLX delegate examples. + +Provides reusable KV cache implementations optimized for the MLX backend: +""" + +from typing import Tuple + +import torch +import torch.nn as nn + +# Import MLX custom ops to register mlx::kv_cache_update +from executorch.backends.mlx import custom_ops as _mlx_custom_ops # noqa: F401 + + +class KVCache(nn.Module): + """ + MLX-optimized KV cache with ExecutorTorch llama KVCache interface. + + This class follows the same interface as examples/models/llama/attention.py KVCache, + making it a drop-in replacement, but uses the mlx::kv_cache_update op internally + which is optimized for the MLX delegate. + + The cache uses BHSD layout [B, H, S, D] which matches what torch SDPA expects. + + The ``update`` method accepts ``input_pos`` as either a ``torch.Tensor`` or a + plain ``int`` / SymInt. When a tensor is passed, ``item()`` is called internally + to extract the start position, which introduces an unbacked SymInt during + ``torch.export``. Extracting a SymInt has a cost because it creates a new + symbolic variable and associated constraints in the exported program. In a + multi-layer model, prefer extracting the SymInt once and passing the resulting + int/SymInt to every layer's ``update`` call rather than passing the tensor + repeatedly: + + .. code-block:: python + + # Preferred: extract once, pass to all layers + start_pos = input_pos[0].item() + for layer_cache in caches: + layer_cache.update(start_pos, k_val, v_val) + + # Avoid: each layer re-extracts from the tensor + for layer_cache in caches: + layer_cache.update(input_pos, k_val, v_val) + + Example: + >>> cache = KVCache( + ... max_batch_size=1, + ... max_context_length=4096, + ... n_heads=32, + ... head_dim=128, + ... enable_dynamic_shape=True, + ... ) + >>> # With tensor input_pos + >>> input_pos = torch.tensor([0]) + >>> k_val = torch.randn(1, 32, 10, 128) # [B, H, S, D] + >>> v_val = torch.randn(1, 32, 10, 128) # [B, H, S, D] + >>> k_cache, v_cache = cache.update(input_pos, k_val, v_val) + >>> + >>> # With int/SymInt input_pos (preferred in multi-layer loops) + >>> start_pos = input_pos[0].item() + >>> k_cache, v_cache = cache.update(start_pos, k_val, v_val) + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool, + dtype: torch.dtype = torch.float32, + ): + """ + Initialize KV cache buffers. 
+ + Args: + max_batch_size: Maximum batch size + max_context_length: Maximum sequence length the cache can hold + n_heads: Number of attention heads (key/value heads for GQA) + head_dim: Dimension per head + enable_dynamic_shape: Whether dynamic shapes are enabled (kept for interface + compatibility, but MLX always uses dynamic-style update) + dtype: Data type for cache buffers + """ + super().__init__() + assert ( + max_batch_size == 1 + ), f"Only max_batch_size=1 is supported, but got {max_batch_size}" + self.max_batch_size = max_batch_size + self.max_context_length = max_context_length + self.n_heads = n_heads + self.head_dim = head_dim + self.enable_dynamic_shape = enable_dynamic_shape + + # Initialize cache buffers [B, H, T_max, D] - BHSD layout + cache_shape = (max_batch_size, n_heads, max_context_length, head_dim) + self.register_buffer( + "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") + ) + self.register_buffer( + "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") + ) + + def update( + self, input_pos: torch.Tensor | int, k_val: torch.Tensor, v_val: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update cache with new K/V states and return FULL cache. + + This method follows the same signature as examples/models/llama/attention.py KVCache. + + Args: + input_pos: Start position — either a position tensor [S] or an int/SymInt + k_val: New key states [B, H, S, D] + v_val: New value states [B, H, S, D] + + Returns: + Tuple of (k_cache, v_cache) - slices of the FULL cache buffers + """ + + if isinstance(input_pos, torch.Tensor): + start_pos = input_pos[0].item() + seq_len = k_val.size(2) + torch._check(seq_len == v_val.size(2)) + torch._check(start_pos >= 0) + torch._check(start_pos + seq_len <= self.max_context_length) + else: + start_pos = input_pos + + torch.ops.mlx.kv_cache_update(self.k_cache, k_val, start_pos) + torch.ops.mlx.kv_cache_update(self.v_cache, v_val, start_pos) + + # Return full slices of the cache (creates new tensor nodes in the graph) + # This avoids the issue where the same tensor is both BUFFER_MUTATION and USER_OUTPUT + return self.k_cache[:, :, :, :], self.v_cache[:, :, :, :] + + +class RingBufferKVCache(nn.Module): + """ + Ring buffer KV cache for sliding window attention. + + Instead of a linear cache that fills up and stops, this cache wraps around: + write_pos = start_pos % window_size. When the cache is full, new tokens + overwrite the oldest ones, enabling infinite-length generation. + + The attention mask is computed branchlessly from ``start_pos`` and + ``window_size`` alone using ``torch.where`` — no mutable position-tracking + buffers and no Python if/else that would create torch.export guards. + + Mask creation is NOT done here — following optimum-executorch's pattern, + the attention function creates the mask lazily by accessing the cache + via a closure. This avoids tracing issues with torch.export. + + Layout: BHSD [batch_size, num_heads, window_size, head_dim] + + Example: + >>> cache = RingBufferKVCache( + ... max_batch_size=1, + ... max_context_length=512, + ... n_heads=4, + ... head_dim=256, + ... dtype=torch.bfloat16, + ... 
) + >>> k_val = torch.randn(1, 4, 1, 256) + >>> v_val = torch.randn(1, 4, 1, 256) + >>> k_cache, v_cache = cache.update(start_pos=0, k_val=k_val, v_val=v_val) + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + dtype: torch.dtype = torch.float32, + ): + super().__init__() + assert ( + max_batch_size == 1 + ), f"Only max_batch_size=1 is supported, but got {max_batch_size}" + self.max_batch_size = max_batch_size + self.max_context_length = max_context_length + self.window_size = max_context_length + self.buffer_size = 2 * max_context_length + self.n_heads = n_heads + self.head_dim = head_dim + + # Cache buffers [B, H, 2*window_size, D] + # 2× buffer ensures multi-token writes never overwrite data that + # earlier queries in the same batch still need (matches ET behavior). + cache_shape = (max_batch_size, n_heads, self.buffer_size, head_dim) + self.register_buffer( + "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") + ) + self.register_buffer( + "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") + ) + + def update( + self, input_pos: torch.Tensor | int, k_val: torch.Tensor, v_val: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update cache with new K/V states using ring buffer semantics. + + Args: + input_pos: Start position — either a position tensor [S] or an int/SymInt + k_val: New key states [B, H, S, D] + v_val: New value states [B, H, S, D] + + Returns: + Tuple of (k_cache, v_cache) — full ring buffer slices + """ + if isinstance(input_pos, torch.Tensor): + start_pos = input_pos[0].item() + seq_len = k_val.size(2) + torch._check(seq_len == v_val.size(2)) + torch._check(start_pos >= 0) + torch._check(seq_len <= self.window_size) + else: + start_pos = input_pos + + torch.ops.mlx.kv_cache_update( + self.k_cache, k_val, start_pos, ring_size=self.buffer_size + ) + torch.ops.mlx.kv_cache_update( + self.v_cache, v_val, start_pos, ring_size=self.buffer_size + ) + + return self.k_cache[:, :, :, :], self.v_cache[:, :, :, :] + + def create_sliding_window_mask(self, start_pos: int, seq_len: int) -> torch.Tensor: + """ + Build attention mask for the ring buffer — branchless, no mutable state. + + Reconstructs the slot→position mapping from ``start_pos`` and + ``buffer_size`` alone using ``torch.where``, avoiding both Python + if/else (which creates torch.export guards) and mutable position- + tracking buffers (which require extra kv_cache_update calls and + complicate partitioning). + + Returns: + Additive mask [1, 1, seq_len, buffer_size] in the cache's dtype, + where 0 = attend, -inf = block. + """ + w = self.window_size + b = self.buffer_size + end_pos = start_pos + seq_len + + # Slot indices [buffer_size] + slots = torch.arange(b, dtype=torch.long) + + last_write_slot = (end_pos - 1) % b + current_cycle_base = end_pos - 1 - last_write_slot + pos_current = current_cycle_base + slots + pos_previous = current_cycle_base - b + slots + + cache_pos = torch.where(slots <= last_write_slot, pos_current, pos_previous) + + # Query positions [seq_len, 1] + pos_q = (start_pos + torch.arange(seq_len, dtype=torch.long)).view(-1, 1) + + # Delta from query to each cached position [seq_len, buffer_size] + delta = pos_q - cache_pos + + # A slot is attendable if: filled (pos >= 0), causal (delta >= 0), + # and within the sliding window (delta < w) + attn_mask = (cache_pos >= 0) & (delta >= 0) & (delta < w) + + # Use cache dtype (e.g. 
bf16) to avoid float32 AsTypeNode casts in SDPA + dtype = self.k_cache.dtype + zero = torch.zeros(1, dtype=dtype) + neg_inf = torch.full((1,), float("-inf"), dtype=dtype) + return torch.where(attn_mask, zero, neg_inf).unsqueeze(0).unsqueeze(0) + + +from transformers.cache_utils import StaticCache + + +class HFStaticCache(StaticCache): + """ + MLX-optimized Static KV Cache that follows HuggingFace's StaticCache interface. + + This cache is designed to be a drop-in replacement for HuggingFace's StaticCache + when exporting models for the MLX backend. It uses mlx::kv_cache_update internally + which is optimized for the MLX delegate. + + The cache supports multi-layer models by maintaining separate K/V buffers per layer, + matching the HF StaticCache behavior where `update()` takes a `layer_idx` argument. + + Layout: BHSD [batch_size, num_heads, max_cache_len, head_dim] + + Example: + >>> from transformers import AutoConfig + >>> config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> cache = HFStaticCache(config, max_batch_size=1, max_cache_len=4096) + >>> # In attention layer: + >>> k_out, v_out = cache.update(k_states, v_states, layer_idx=0, + ... cache_kwargs={"cache_position": pos_tensor}) + """ + + def __init__( + self, + config, + max_batch_size: int = 1, + max_cache_len: int | None = None, + device: torch.device | str | None = None, + dtype: torch.dtype = torch.float32, + ): + """ + Initialize MLX Static Cache. + + Args: + config: HuggingFace model config with num_hidden_layers, num_key_value_heads, + num_attention_heads, hidden_size, and optionally head_dim + max_batch_size: Maximum batch size (default: 1) + max_cache_len: Maximum cache length. If None, uses config.max_position_embeddings + device: Device for cache tensors (default: None = CPU) + dtype: Data type for cache tensors (default: torch.float32) + """ + # Resolve dimensions from config BEFORE calling parent + num_layers = config.num_hidden_layers + num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + actual_max_cache_len = max_cache_len or getattr( + config, "max_position_embeddings", 2048 + ) + + # Initialize parent StaticCache with required arguments + super().__init__( + config=config, + max_batch_size=max_batch_size, + max_cache_len=actual_max_cache_len, + device=device, + dtype=dtype, + ) + # Call early_initialization to ensure parent's layers are fully initialized + self.early_initialization( + batch_size=max_batch_size, + num_heads=num_heads, + head_dim=head_dim, + dtype=dtype, + device=device, + ) + + # Store dimensions as instance attributes + self.num_layers = num_layers + self.num_heads = num_heads + self.head_dim = head_dim + + # Create KVCache wrappers for each layer - these use mlx::kv_cache_update + # Named 'kv_cache' to match optimum-executorch's ETCustomStaticCache pattern + self.kv_cache = nn.ModuleList( + [ + KVCache( + max_batch_size=max_batch_size, + max_context_length=actual_max_cache_len, + n_heads=num_heads, + head_dim=head_dim, + enable_dynamic_shape=True, + dtype=dtype, + ) + for _ in range(num_layers) + ] + ) + + # Move to device if specified + if device is not None: + self.to(device) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict | None = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update the cache with new key/value states for a specific layer. 
+ + This method follows HuggingFace's StaticCache.update() signature. + + Args: + key_states: New key states [batch_size, num_heads, seq_len, head_dim] + value_states: New value states [batch_size, num_heads, seq_len, head_dim] + layer_idx: Index of the layer to update + cache_kwargs: Dictionary containing 'cache_position' tensor with start position + + Returns: + Tuple of (key_cache, value_cache) for the full cache after update + """ + assert ( + cache_kwargs is not None + ), "cache_kwargs must be provided with 'cache_position'" + cache_position = cache_kwargs.get("cache_position") + assert ( + cache_position is not None + ), "cache_position must be provided in cache_kwargs" + assert isinstance( + cache_position, torch.Tensor + ), "cache_position must be a tensor" + + # Pass cache_position tensor directly to KVCache.update() + # KVCache extracts start_pos internally via input_pos[0].item() + return self.kv_cache[layer_idx].update(cache_position, key_states, value_states) + + def get_seq_length(self, layer_idx: int = 0) -> int: + """Approximate sequence length (counts non-zero cache positions).""" + k_cache = self.kv_cache[layer_idx].k_cache + # Check if any value in the head_dim is non-zero for each position + return (k_cache[0, 0, :, 0] != 0).sum().item() + + def get_max_cache_shape(self, layer_idx: int = 0) -> int: + return self.max_cache_len + + def reset(self): + for layer_cache in self.kv_cache: + layer_cache.k_cache.zero_() + layer_cache.v_cache.zero_() diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 4c9e0d6f796..439d4569313 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -19,10 +19,142 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union import torch +from executorch.backends.mlx.builder.op_helpers import ( + emit_lifted_constant, + emit_quantized_biases, + parse_dequant_node, + to_mlx_qparams, + torch_dtype_to_scalar_type, +) from executorch.backends.mlx.builder.op_registry import REGISTRY from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder -from executorch.backends.mlx.builder.slot_manager import Slot -from executorch.backends.mlx.serialization.mlx_graph_schema import AddmmNode +from executorch.backends.mlx.builder.slot_manager import IdType, Slot +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + AbsNode, + AddIntNode, + AddmmNode, + AddNode, + AllNode, + AnyNode, + ARangeNode, + ArccoshNode, + ArccosNode, + ArcsinhNode, + ArcsinNode, + ArctanhNode, + ArctanNode, + ArgmaxNode, + ArgminNode, + ArgPartitionNode, + ArgsortNode, + AsStridedNode, + AsTypeNode, + Atan2Node, + BroadcastToNode, + CeilNode, + ClipNode, + ConcatenateNode, + ContiguousNode, + Conv1DNode, + Conv2DNode, + Conv3DNode, + ConvTranspose1DNode, + ConvTranspose2DNode, + ConvTranspose3DNode, + CoshNode, + CosNode, + CumsumNode, + DequantizeNode, + DivideNode, + EqualNode, + ErfNode, + ExpandDimsNode, + Expm1Node, + ExpNode, + FloatOrVid, + FloorDivideIntNode, + FloorDivideNode, + FloorNode, + FullLikeNode, + FullNode, + GatherNode, + GeluNode, + GreaterEqualNode, + GreaterNode, + IdCopyNode, + IntOrVid, + IntOrVidOrTid, + ItemIntNode, + LayerNormNode, + LessEqualNode, + LessNode, + Log10Node, + Log1pNode, + Log2Node, + LogAddExpNode, + LogicalAndNode, + LogicalNotNode, + LogicalOrNode, + LogNode, + LogSumExpNode, + MaximumNode, + MaxNode, + MeanNode, + MinimumNode, + MinNode, + ModIntNode, + MultiplyIntNode, + MultiplyNode, + NegNode, + NotEqualNode, + PadNode, + PartitionNode, + PowerNode, + ProdNode, + ReciprocalNode, + 
RemainderNode, + RepeatNode, + ReshapeNode, + RMSNormNode, + RopeNode, + RoundNode, + RsqrtNode, + SigmoidNode, + SignNode, + SiluNode, + SinhNode, + SinNode, + SliceNode, + SliceUpdateNode, + SoftmaxNode, + SortNode, + SplitNode, + SqrtNode, + SquareNode, + SqueezeNode, + StackNode, + StdNode, + SubtractIntNode, + SubtractNode, + SumNode, + SymSizeNode, + TakeAlongAxisNode, + TakeNode, + TanhNode, + TanNode, + TileNode, + TransposeNode, + TrilNode, + TriuNode, + VarNode, + VidOrTid, + WhereNode, +) + +# The coding style is for handlers to register against aten targets +# The corresponding edge ops are automatically registered +# For ops that are not in aten (e.g., dim order ops), directly register on exir_ops +from executorch.exir.dialects._ops import ops as exir_ops from torch.fx.node import Node @@ -219,12 +351,434 @@ def normalize_reduction_dim( return dim, keepdim +_UNARY_OPS: List[Tuple[Any, Any, str]] = [ + # Activations + (torch.ops.aten.silu.default, SiluNode, "aten.silu"), + (torch.ops.aten.sigmoid.default, SigmoidNode, "aten.sigmoid"), + (torch.ops.aten.tanh.default, TanhNode, "aten.tanh"), + # Reciprocal square root + (torch.ops.aten.rsqrt.default, RsqrtNode, "aten.rsqrt"), + # Rounding + (torch.ops.aten.floor.default, FloorNode, "aten.floor"), + (torch.ops.aten.ceil.default, CeilNode, "aten.ceil"), + # Powers / roots + (torch.ops.aten.square.default, SquareNode, "aten.square"), + (torch.ops.aten.exp.default, ExpNode, "aten.exp"), + (torch.ops.aten.sqrt.default, SqrtNode, "aten.sqrt"), + (torch.ops.aten.reciprocal.default, ReciprocalNode, "aten.reciprocal"), + # Trigonometric + (torch.ops.aten.sin.default, SinNode, "aten.sin"), + (torch.ops.aten.cos.default, CosNode, "aten.cos"), + (torch.ops.aten.tan.default, TanNode, "aten.tan"), + (torch.ops.aten.asin.default, ArcsinNode, "aten.asin"), + (torch.ops.aten.acos.default, ArccosNode, "aten.acos"), + (torch.ops.aten.atan.default, ArctanNode, "aten.atan"), + # Hyperbolic + (torch.ops.aten.sinh.default, SinhNode, "aten.sinh"), + (torch.ops.aten.cosh.default, CoshNode, "aten.cosh"), + (torch.ops.aten.asinh.default, ArcsinhNode, "aten.asinh"), + (torch.ops.aten.acosh.default, ArccoshNode, "aten.acosh"), + (torch.ops.aten.atanh.default, ArctanhNode, "aten.atanh"), + # Logarithmic + (torch.ops.aten.log.default, LogNode, "aten.log"), + (torch.ops.aten.log2.default, Log2Node, "aten.log2"), + (torch.ops.aten.log10.default, Log10Node, "aten.log10"), + (torch.ops.aten.log1p.default, Log1pNode, "aten.log1p"), + # Special + (torch.ops.aten.erf.default, ErfNode, "aten.erf"), + (torch.ops.aten.expm1.default, Expm1Node, "aten.expm1"), + # Sign / magnitude + (torch.ops.aten.abs.default, AbsNode, "aten.abs"), + (torch.ops.aten.neg.default, NegNode, "aten.neg"), + (torch.ops.aten.sign.default, SignNode, "aten.sign"), + # Logical + (torch.ops.aten.logical_not.default, LogicalNotNode, "aten.logical_not"), +] + + +def _make_unary_handler(node_cls: Any, op_name: str): + """Create a handler for a simple unary op: x → node_cls(x, out).""" + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 1, 1, op_name) + require_kwargs(P.kwargs(n), set(), op_name) + x = args[0] + out = P.make_or_get_slot(n) + P.emit(node_cls(x=P.slot_to_tid(x), out=P.slot_to_tid(out))) + return out + + handler.__name__ = f"_{op_name.replace('.', '_')}_handler" + handler.__doc__ = f"Handle {op_name} (table-driven unary op)." 
+ return handler + + +for _target, _node_cls, _op_name in _UNARY_OPS: + REGISTRY.register(target=[_target])(_make_unary_handler(_node_cls, _op_name)) + + +_BINARY_OPS: List[Tuple[List[Any], Any, str, bool]] = [ + ( + [torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar], + MultiplyNode, + "aten.mul", + True, + ), + ( + [torch.ops.aten.div.Tensor, torch.ops.aten.div.Scalar], + DivideNode, + "aten.div", + True, + ), + ( + [torch.ops.aten.remainder.Tensor, torch.ops.aten.remainder.Scalar], + RemainderNode, + "aten.remainder", + True, + ), + ( + [torch.ops.aten.pow.Tensor_Tensor, torch.ops.aten.pow.Tensor_Scalar], + PowerNode, + "aten.pow", + True, + ), + ( + [torch.ops.aten.floor_divide.default], + FloorDivideNode, + "aten.floor_divide", + False, + ), + ([torch.ops.aten.maximum.default], MaximumNode, "aten.maximum", False), + ([torch.ops.aten.minimum.default], MinimumNode, "aten.minimum", False), + ([torch.ops.aten.atan2.default], Atan2Node, "aten.atan2", False), + ([torch.ops.aten.logaddexp.default], LogAddExpNode, "aten.logaddexp", False), + ([torch.ops.aten.logical_or.default], LogicalOrNode, "aten.logical_or", False), + ( + [torch.ops.aten.lt.Tensor, torch.ops.aten.lt.Scalar], + LessNode, + "aten.lt", + True, + ), + ( + [torch.ops.aten.le.Tensor, torch.ops.aten.le.Scalar], + LessEqualNode, + "aten.le", + True, + ), + ( + [torch.ops.aten.gt.Tensor, torch.ops.aten.gt.Scalar], + GreaterNode, + "aten.gt", + True, + ), + ( + [torch.ops.aten.ge.Tensor, torch.ops.aten.ge.Scalar], + GreaterEqualNode, + "aten.ge", + True, + ), + ( + [torch.ops.aten.eq.Tensor, torch.ops.aten.eq.Scalar], + EqualNode, + "aten.eq", + True, + ), + ( + [torch.ops.aten.ne.Tensor, torch.ops.aten.ne.Scalar], + NotEqualNode, + "aten.ne", + True, + ), +] + + +def _make_binary_handler(node_cls: Any, op_name: str, lift_b: bool): + """Create a handler for a binary op: (a, b) -> node_cls(a, b, out). + + When lift_b is True, scalar b values are lifted to 0-D constant tensors + via emit_lifted_constant, using a's dtype. + """ + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, op_name) + require_kwargs(P.kwargs(n), set(), op_name) + a, b = args[0], args[1] + if lift_b and (not isinstance(b, Slot) or b.id_type != IdType.Tensor): + input_meta = n.args[0].meta.get("val") + dtype = input_meta.dtype if input_meta is not None else torch.float32 + b = emit_lifted_constant(P, b, dtype) + out = P.make_or_get_slot(n) + P.emit(node_cls(a=P.slot_to_tid(a), b=P.slot_to_tid(b), out=P.slot_to_tid(out))) + return out + + handler.__name__ = f"_{op_name.replace('.', '_')}_handler" + handler.__doc__ = f"Handle {op_name} (table-driven binary op)." 
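+    # Illustrative: for aten.mul.Scalar(x, 2.0) the scalar is not a tensor Slot,
+    # so with lift_b=True it is lifted to a 0-D constant in x's dtype via
+    # emit_lifted_constant before MultiplyNode(a, b, out) is emitted.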
+ return handler + + +for _targets, _node_cls, _op_name, _lift_b in _BINARY_OPS: + REGISTRY.register(target=_targets)( + _make_binary_handler(_node_cls, _op_name, _lift_b) + ) + + +_SCALAR_INT_OPS: List[Tuple[Any, Any, str]] = [ + (operator.add, AddIntNode, "operator.add"), + (operator.sub, SubtractIntNode, "operator.sub"), + (operator.mul, MultiplyIntNode, "operator.mul"), + (operator.floordiv, FloorDivideIntNode, "operator.floordiv"), + (operator.mod, ModIntNode, "operator.mod"), +] + + +def _make_scalar_int_handler(node_cls: Any, op_name: str): + """Create a handler for a scalar int op: (a, b) -> node_cls(a, b, out).""" + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, op_name) + require_kwargs(P.kwargs(n), set(), op_name) + a, b = args + out = P.make_or_get_slot(n) + P.emit( + node_cls( + a=P.to_int_or_vid(a), + b=P.to_int_or_vid(b), + out=P.slot_to_vid(out), + ) + ) + return out + + handler.__name__ = f"_{op_name.replace('.', '_')}_handler" + handler.__doc__ = f"Handle {op_name} (table-driven scalar int op)." + return handler + + +for _target, _node_cls, _op_name in _SCALAR_INT_OPS: + REGISTRY.register(target=[_target])(_make_scalar_int_handler(_node_cls, _op_name)) + + +_REDUCTION_OPS: List[Tuple[List[Any], Any, str, int]] = [ + ( + [torch.ops.aten.sum.dim_IntList, torch.ops.aten.sum.default], + SumNode, + "aten.sum", + 4, + ), + ([torch.ops.aten.mean.dim, torch.ops.aten.mean.default], MeanNode, "aten.mean", 4), + ( + [torch.ops.aten.prod.dim_int, torch.ops.aten.prod.default], + ProdNode, + "aten.prod", + 4, + ), + ([torch.ops.aten.amax.default], MaxNode, "aten.amax", 3), + ([torch.ops.aten.amin.default], MinNode, "aten.amin", 3), + ([torch.ops.aten.any.dim, torch.ops.aten.any.default], AnyNode, "aten.any", 3), + ([torch.ops.aten.all.dim, torch.ops.aten.all.default], AllNode, "aten.all", 3), +] + + +def _make_reduction_handler(node_cls: Any, op_name: str, max_args: int): + """Create a handler for a reduction op: x -> node_cls(x, out, axes, keepdims).""" + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 1, max_args, op_name) + require_kwargs(P.kwargs(n), set(), op_name) + x = args[0] + axes, keepdim = normalize_reduction_dim(args) + out = P.make_or_get_slot(n) + P.emit( + node_cls( + x=P.slot_to_tid(x), out=P.slot_to_tid(out), axes=axes, keepdims=keepdim + ) + ) + return out + + handler.__name__ = f"_{op_name.replace('.', '_')}_handler" + handler.__doc__ = f"Handle {op_name} (table-driven reduction op)." 
+ return handler + + +for _targets, _node_cls, _op_name, _max_args in _REDUCTION_OPS: + REGISTRY.register(target=_targets)( + _make_reduction_handler(_node_cls, _op_name, _max_args) + ) + + +_FULL_OPS: List[Tuple[List[Any], str, Optional[float]]] = [ + ([torch.ops.aten.full.default], "aten.full", None), + ([torch.ops.aten.zeros.default], "aten.zeros", 0.0), + ([torch.ops.aten.ones.default], "aten.ones", 1.0), +] + + +def _make_full_handler(op_name: str, fixed_fill: Optional[float]): + """Create a handler for full/zeros/ones: shape -> FullNode(shape, v, dtype).""" + + has_fill_arg = fixed_fill is None + n_args = 2 if has_fill_arg else 1 + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, n_args, n_args, op_name) + kwargs = P.kwargs(n) + require_kwargs(kwargs, {"dtype", "layout", "device", "pin_memory"}, op_name) + require_contiguous_format(layout=kwargs.get("layout"), op_name=op_name) + + shape = args[0] + shape_iovs = [P.to_int_or_vid(d) for d in shape] + v = ( + P.to_float_or_vid(args[1]) + if has_fill_arg + else FloatOrVid.from_literal(fixed_fill) + ) + dtype = n.kwargs.get("dtype") + if dtype is None: + dtype = torch.float32 + + out = P.make_or_get_slot(n) + P.emit( + FullNode( + out=P.slot_to_tid(out), + shape=shape_iovs, + v=v, + scalar_type=torch_dtype_to_scalar_type(dtype), + ) + ) + return out + + handler.__name__ = f"_{op_name.replace('.', '_')}_handler" + handler.__doc__ = f"Handle {op_name} (table-driven full op)." + return handler + + +for _targets, _op_name, _fixed_fill in _FULL_OPS: + REGISTRY.register(target=_targets)(_make_full_handler(_op_name, _fixed_fill)) + + +_FULL_LIKE_OPS: List[Tuple[List[Any], str, Optional[float]]] = [ + ([torch.ops.aten.full_like.default], "aten.full_like", None), + ([torch.ops.aten.zeros_like.default], "aten.zeros_like", 0.0), + ([torch.ops.aten.ones_like.default], "aten.ones_like", 1.0), +] + + +def _make_full_like_handler(op_name: str, fixed_fill: Optional[float]): + """Create a handler for full_like/zeros_like/ones_like: x -> FullLikeNode(x, v, dtype).""" + + has_fill_arg = fixed_fill is None + n_args = 2 if has_fill_arg else 1 + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, n_args, n_args, op_name) + kwargs = P.kwargs(n) + require_kwargs( + kwargs, + {"dtype", "layout", "device", "pin_memory", "memory_format"}, + op_name, + ) + require_contiguous_format( + layout=kwargs.get("layout"), + memory_format=kwargs.get("memory_format"), + op_name=op_name, + ) + + x = args[0] + v = ( + P.to_float_or_vid(args[1]) + if has_fill_arg + else FloatOrVid.from_literal(fixed_fill) + ) + dtype = n.kwargs.get("dtype") + + out = P.make_or_get_slot(n) + P.emit( + FullLikeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + v=v, + scalar_type=( + torch_dtype_to_scalar_type(dtype) if dtype is not None else None + ), + ) + ) + return out + + handler.__name__ = f"_{op_name.replace('.', '_')}_handler" + handler.__doc__ = f"Handle {op_name} (table-driven full_like op)." 
+ return handler + + +for _targets, _op_name, _fixed_fill in _FULL_LIKE_OPS: + REGISTRY.register(target=_targets)(_make_full_like_handler(_op_name, _fixed_fill)) + + @REGISTRY.register(target=["NOOP", torch.ops.aten._assert_scalar.default]) def _noop_handler(P: MLXProgramBuilder, n: Node) -> None: """No-op handler for nodes that don't emit any MLX instructions.""" return None +# Handler for auto_functionalized_v2 higher-order op +# This handles mutating ops that have been functionalized +@REGISTRY.register(target=[torch.ops.higher_order.auto_functionalized_v2]) +def _auto_functionalized_v2_handler(P: MLXProgramBuilder, n: Node): + """ + Handler for auto_functionalized_v2 higher-order op. + + auto_functionalized_v2 wraps mutating ops after functionalization. + It returns a tuple of (token, mutated_values...). + + This handler emits the actual lowering instructions and returns a tuple + of slots that getitem can index into. + """ + if len(n.args) < 1: + raise ValueError( + f"auto_functionalized_v2 requires at least 1 arg, got {len(n.args)}" + ) + + wrapped_op = n.args[0] + + # Unknown wrapped op - not supported + raise NotImplementedError( + f"auto_functionalized_v2 wrapping '{wrapped_op}' is not supported." + ) + + +@REGISTRY.register(target=[torch.ops.aten.linear.default]) +def _linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 3, "aten.linear") + require_kwargs(P.kwargs(n), set(), "aten.linear") + x, w = args[0], args[1] + b = args[2] if len(args) > 2 else None + out = P.make_or_get_slot(n) + + # Transpose weight: linear(x, w) = x @ w.T + _, w_t = P.make_tmp_slot() + P.emit( + TransposeNode( + x=P.slot_to_tid(w), + out=P.slot_to_tid(w_t), + perm=[1, 0], + ) + ) + + P.emit( + AddmmNode( + mat1=P.slot_to_tid(x), + mat2=P.slot_to_tid(w_t), + out=P.slot_to_tid(out), + bias=P.slot_to_tid(b) if b else None, + ) + ) + return out + + @REGISTRY.register(target=[torch.ops.aten.addmm.default]) def _addmm_handler(P: MLXProgramBuilder, n: Node) -> Slot: """Handle addmm: self + (mat1 @ mat2). @@ -298,3 +852,3107 @@ def _mm_handler(P: MLXProgramBuilder, n: Node) -> Slot: ) ) return out + + +@REGISTRY.register( + target=[ + torch.ops.aten.view.default, + torch.ops.aten.view_copy.default, + torch.ops.aten.reshape.default, + ] +) +def _view_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, "aten.view") + require_kwargs(P.kwargs(n), set(), "aten.view") + x, shape = args + out = P.make_or_get_slot(n) + + shape_iovs = [P.to_int_or_vid(s) for s in shape] + P.emit( + ReshapeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + shape=shape_iovs, + ) + ) + return out + + +@REGISTRY.register( + target=[ + torch.ops.aten.clone.default, + torch.ops.aten.alias.default, + torch.ops.aten.alias_copy.default, + ] +) +def _clone_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 1, 1, "aten.clone") + require_kwargs(kwargs, {"memory_format"}, "aten.clone") + require_contiguous_format( + memory_format=kwargs.get("memory_format"), + op_name="aten.clone", + ) + (x,) = args + out = P.make_or_get_slot(n) + P.emit( + ContiguousNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.copy.default]) +def _copy_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.copy - copy data from src to self. 
+ + Schema: aten::copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor + In functionalized Edge IR, this returns a copy of src (args[1]). + """ + args = P.args(n) + require_args(args, 2, 2, "aten.copy") + require_kwargs(P.kwargs(n), {"non_blocking"}, "aten.copy") + src = args[1] + out = P.make_or_get_slot(n) + P.emit( + ContiguousNode( + x=P.slot_to_tid(src), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[exir_ops.edge.dim_order_ops._clone_dim_order.default]) +def _dim_order_clone_handler(P: MLXProgramBuilder, n: Node) -> Slot: + # dim_order_ops._clone_dim_order(Tensor self, *, bool non_blocking=False, int[]? dim_order=None) -> Tensor + # This is essentially a contiguous/clone operation for memory layout + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 1, 1, "dim_order_ops._clone_dim_order") + require_kwargs( + kwargs, {"non_blocking", "dim_order"}, "dim_order_ops._clone_dim_order" + ) + require_contiguous_format( + dim_order=kwargs.get("dim_order"), + op_name="dim_order_ops._clone_dim_order", + ) + x = args[0] + out = P.make_or_get_slot(n) + P.emit( + ContiguousNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + ) + ) + return out + + +# Handle Edge IR's dim_order_ops._to_dim_order_copy (dtype conversion) +# This is what x.to(dtype) becomes after to_edge() transformation +@REGISTRY.register(target=[exir_ops.edge.dim_order_ops._to_dim_order_copy.default]) +def _dim_order_copy_handler(P: MLXProgramBuilder, n: Node) -> Slot: + # dim_order_ops._to_dim_order_copy(Tensor self, *, ScalarType? dtype=None, ...) + # If dtype is specified, this is a dtype conversion (use AsTypeNode) + # If dtype is None/same, this is just a memory layout copy (use ContiguousNode) + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 1, 1, "dim_order_ops._to_dim_order_copy") + require_kwargs( + kwargs, + {"dtype", "device", "layout", "non_blocking", "dim_order"}, + "dim_order_ops._to_dim_order_copy", + ) + require_contiguous_format( + layout=kwargs.get("layout"), + dim_order=kwargs.get("dim_order"), + op_name="dim_order_ops._to_dim_order_copy", + ) + x = args[0] + out = P.make_or_get_slot(n) + + dtype = kwargs.get("dtype") + if dtype is not None: + # Dtype conversion + P.emit( + AsTypeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + scalar_type=torch_dtype_to_scalar_type(dtype), + ) + ) + else: + # No dtype change, just memory layout (contiguous) + P.emit( + ContiguousNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten._to_copy.default]) +def _to_copy_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten._to_copy - lower-level dtype/device conversion.""" + # aten._to_copy(Tensor self, *, ScalarType? dtype=None, ...) 
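+    # Illustrative: a dtype-only conversion such as x.to(torch.float16) typically
+    # reaches this handler with kwargs {"dtype": torch.float16} and lowers to
+    # AsTypeNode; with no dtype it lowers to ContiguousNode.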
+ args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 1, 1, "aten._to_copy") + require_kwargs( + kwargs, {"dtype", "device", "layout", "memory_format"}, "aten._to_copy" + ) + require_contiguous_format( + layout=kwargs.get("layout"), + memory_format=kwargs.get("memory_format"), + op_name="aten._to_copy", + ) + x = args[0] + out = P.make_or_get_slot(n) + + dtype = kwargs.get("dtype") + if dtype is not None: + # Dtype conversion + P.emit( + AsTypeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + scalar_type=torch_dtype_to_scalar_type(dtype), + ) + ) + else: + # No dtype change, just copy (use contiguous) + P.emit( + ContiguousNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.embedding.default]) +def _embedding_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 3, "aten.embedding") + # "padding_idx", "scale_grad_by_freq", "sparse" are training only args + # and ignored + require_kwargs( + P.kwargs(n), {"padding_idx", "scale_grad_by_freq", "sparse"}, "aten.embedding" + ) + w, x = args[0], args[1] + # padding_idx (args[2] if present) is ignored - only affects gradients + out = P.make_or_get_slot(n) + P.emit( + TakeNode( + x=P.slot_to_tid(w), + index=IntOrVidOrTid.from_tid(P.slot_to_tid(x)), + out=P.slot_to_tid(out), + axis=0, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.add.Tensor, torch.ops.aten.add.Scalar]) +def _add_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.add.Tensor: a + alpha * b.""" + args = P.args(n) + require_args(args, 2, 2, "aten.add.Tensor") + require_kwargs(P.kwargs(n), {"alpha"}, "aten.add.Tensor") + a, b = args + input_meta = n.args[0].meta.get("val") + dtype = input_meta.dtype if input_meta is not None else torch.float32 + if not isinstance(b, Slot): + b = emit_lifted_constant(P, b, dtype) + alpha = P.kwargs(n).get("alpha", 1) + if alpha != 1: + alpha_slot = emit_lifted_constant(P, alpha, dtype) + _, tmp = P.make_tmp_slot() + P.emit( + MultiplyNode( + a=P.slot_to_tid(b), + b=P.slot_to_tid(alpha_slot), + out=P.slot_to_tid(tmp), + ) + ) + b = tmp + out = P.make_or_get_slot(n) + P.emit( + AddNode( + a=P.slot_to_tid(a), + b=P.slot_to_tid(b), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.div.Tensor_mode]) +def _div_tensor_mode_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.div.Tensor_mode with rounding mode.""" + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 2, 2, "aten.div.Tensor_mode") + require_kwargs(kwargs, {"rounding_mode"}, "aten.div.Tensor_mode") + out = P.make_or_get_slot(n) + a = args[0] + b = args[1] + rounding_mode = kwargs.get("rounding_mode", None) + + # Handle scalar b by creating a constant tensor + if not isinstance(b, Slot): + b = P.make_or_get_constant( + f"_scalar_{b}", torch.tensor([b], dtype=n.meta["val"].dtype) + ) + + # Handle scalar a + if not isinstance(a, Slot): + a = P.make_or_get_constant( + f"_scalar_{a}", torch.tensor([a], dtype=n.meta["val"].dtype) + ) + + if rounding_mode == "trunc": + raise NotImplementedError( + "aten.div.Tensor_mode with rounding_mode='trunc' is not supported. " + "MLX does not have a truncate operation." 
+ ) + elif rounding_mode == "floor": + P.emit( + FloorDivideNode( + a=P.slot_to_tid(a), + b=P.slot_to_tid(b), + out=P.slot_to_tid(out), + ) + ) + else: + # rounding_mode is None - true division + P.emit( + DivideNode( + a=P.slot_to_tid(a), + b=P.slot_to_tid(b), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten._softmax.default]) +def _softmax_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle softmax: computes softmax along the specified dimension. + + aten._softmax(self, dim, half_to_float) computes: + softmax(self, axis=dim) + + The half_to_float parameter is for type conversion and is ignored for MLX. + """ + args = P.args(n) + require_args(args, 3, 3, "aten._softmax") + require_kwargs(P.kwargs(n), set(), "aten._softmax") + x, dim, _ = args[0], args[1], args[2] # half_to_float is unused for MLX + + out = P.make_or_get_slot(n) + + # Emit SoftmaxNode with the specified axis + P.emit( + SoftmaxNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + axis=dim, + precise=False, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.gelu.default]) +def _gelu_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 1, 1, "aten.gelu") + require_kwargs(kwargs, {"approximate"}, "aten.gelu") + (x,) = args + # GELU approximate mode: 'none' (default) or 'tanh' + approximate = kwargs.get("approximate", "none") + out = P.make_or_get_slot(n) + P.emit( + GeluNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + approximate=approximate, + ) + ) + return out + + +@REGISTRY.register( + target=[torch.ops.aten.permute.default, torch.ops.aten.permute_copy.default] +) +def _permute_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, "aten.permute") + require_kwargs(P.kwargs(n), set(), "aten.permute") + x, dims = args + out = P.make_or_get_slot(n) + P.emit( + TransposeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + perm=list(dims), + ) + ) + return out + + +@REGISTRY.register( + target=[torch.ops.aten.transpose.int, torch.ops.aten.transpose_copy.int] +) +def _transpose_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 3, 3, "aten.transpose") + require_kwargs(P.kwargs(n), set(), "aten.transpose") + x, dim0, dim1 = args + perm = list(range(len(n.meta["val"].shape))) + perm[dim0], perm[dim1] = perm[dim1], perm[dim0] + out = P.make_or_get_slot(n) + P.emit( + TransposeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + perm=perm, + ) + ) + return out + + +@REGISTRY.register( + target=[torch.ops.aten.slice.Tensor, torch.ops.aten.slice_copy.Tensor] +) +def _slice_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 4, 5, "aten.slice") + require_kwargs(P.kwargs(n), set(), "aten.slice") + x, dim, start, stop = args[0], args[1], args[2], args[3] + step = args[4] if len(args) > 4 else 1 + if start is None: + start = 0 + require_static_int(step, "step", "aten.slice") + assert step >= 1, f"aten.slice: step must be >= 1, got {step}" + out = P.make_or_get_slot(n) + P.emit( + SliceNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + axis=P.to_int_or_vid(dim), + start=P.to_int_or_vid(start), + stop=P.to_int_or_vid(stop), + step=step, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.narrow.default]) +def _narrow_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """ + Handle narrow(input, dim, start, length) -> slice(input, dim, start, start+length). 
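+    For example (illustrative): narrow(cache, 2, start_pos, seq_len) selects the
+    same elements as cache[:, :, start_pos:start_pos + seq_len, :].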
+ + This is needed for KV cache updates with dynamic positions where narrow + is preferred over slice syntax for better torch.export compatibility. + """ + args = P.args(n) + require_args(args, 4, 4, "aten.narrow") + require_kwargs(P.kwargs(n), set(), "aten.narrow") + x, dim, start, length = args + out = P.make_or_get_slot(n) + + # Convert narrow (start, length) to slice (start, end) + # The end is start + length + start_iov = P.to_int_or_vid(start) + length_iov = P.to_int_or_vid(length) + + # For stop = start + length, we need to emit an ADD_SCALAR if either is a Vid + if isinstance(start_iov, IntOrVid) and start_iov.vid is not None: + # start is a Vid, need to add at runtime + if isinstance(length_iov, IntOrVid) and length_iov.vid is not None: + # Both are Vids - emit add to compute stop + _, stop_slot = P.make_tmp_value_slot() + stop_vid = P.slot_to_vid(stop_slot) + P.emit( + AddIntNode( + a=start_iov.vid, + b=length_iov.vid, + out=stop_vid, + ) + ) + stop_iov = IntOrVid(int64=None, vid=stop_vid) + else: + # start is Vid, length is int - emit add scalar + _, stop_slot = P.make_tmp_value_slot() + stop_vid = P.slot_to_vid(stop_slot) + P.emit( + AddIntNode( + a=start_iov.vid, + b=( + length_iov.int64 + if isinstance(length_iov, IntOrVid) + else length_iov + ), + out=stop_vid, + ) + ) + stop_iov = IntOrVid(int64=None, vid=stop_vid) + elif isinstance(length_iov, IntOrVid) and length_iov.vid is not None: + # length is Vid, start is int - emit add scalar + start_val = start_iov.int64 if isinstance(start_iov, IntOrVid) else start_iov + _, stop_slot = P.make_tmp_value_slot() + stop_vid = P.slot_to_vid(stop_slot) + P.emit( + AddIntNode( + a=length_iov.vid, + b=start_val, + out=stop_vid, + ) + ) + stop_iov = IntOrVid(int64=None, vid=stop_vid) + else: + # Both are concrete ints + start_val = start_iov.int64 if isinstance(start_iov, IntOrVid) else start_iov + length_val = ( + length_iov.int64 if isinstance(length_iov, IntOrVid) else length_iov + ) + stop_iov = IntOrVid(int64=start_val + length_val, vid=None) + + P.emit( + SliceNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + axis=P.to_int_or_vid(dim), + start=start_iov, + stop=stop_iov, + ) + ) + return out + + +@REGISTRY.register( + target=[torch.ops.aten.unsqueeze.default, torch.ops.aten.unsqueeze_copy.default] +) +def _unsqueeze_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, "aten.unsqueeze") + require_kwargs(P.kwargs(n), set(), "aten.unsqueeze") + x, dim = args + out = P.make_or_get_slot(n) + P.emit( + ExpandDimsNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + axis=dim, + ) + ) + return out + + +@REGISTRY.register( + target=[torch.ops.aten.squeeze.dims, torch.ops.aten.squeeze_copy.dims] +) +def _squeeze_dims_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle squeeze operation for specific dimensions. + + Removes dimensions of size 1 from the tensor at specified positions. + """ + args = P.args(n) + require_args(args, 2, 2, "aten.squeeze.dims") + require_kwargs(P.kwargs(n), set(), "aten.squeeze.dims") + x, dims = args + out = P.make_or_get_slot(n) + + dims_list = list(dims) if dims is not None else None + + P.emit( + SqueezeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + dims=dims_list, + ) + ) + return out + + +@REGISTRY.register( + target=[torch.ops.aten.squeeze.default, torch.ops.aten.squeeze_copy.default] +) +def _squeeze_default_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle squeeze operation without specified dimensions. 
+ + Removes all dimensions of size 1 from the tensor. + """ + args = P.args(n) + require_args(args, 1, 1, "aten.squeeze.default") + require_kwargs(P.kwargs(n), set(), "aten.squeeze.default") + (x,) = args + out = P.make_or_get_slot(n) + + P.emit( + SqueezeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + dims=None, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.cat.default]) +def _cat_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle concatenation of a list of tensors. + + Concatenates tensors along a specified dimension. + All tensors must have the same shape except in the concatenating dimension. + """ + args = P.args(n) + require_args(args, 1, 2, "aten.cat") + require_kwargs(P.kwargs(n), set(), "aten.cat") + # aten.cat.default signature: cat(Tensor[] tensors, int dim=0) -> Tensor + # args can be (tensors_list,) or (tensors_list, dim) + tensors_list = args[0] + dim = args[1] if len(args) > 1 else 0 + + out = P.make_or_get_slot(n) + + # Convert list of tensor slots to list of Tids + tensor_tids = [P.slot_to_tid(t) for t in tensors_list] + + # dim is typically an int + axis = dim if dim is not None else 0 + + P.emit( + ConcatenateNode( + tensors=tensor_tids, + out=P.slot_to_tid(out), + axis=axis, + ) + ) + return out + + +@REGISTRY.register( + target=[ + torch.ops.aten.split_with_sizes.default, + torch.ops.aten.split_with_sizes_copy.default, + ] +) +def _split_with_sizes_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle split_with_sizes operation. + + Splits a tensor into chunks with specified sizes along a dimension. + Returns a tuple of output slots that getitem can extract from. + + PyTorch: split_with_sizes(x, [2, 3, 4], dim=1) + MLX: split(x, indices=[2, 5], axis=1) # indices are cumulative positions + """ + args = P.args(n) + require_args(args, 2, 3, "aten.split_with_sizes") + require_kwargs(P.kwargs(n), set(), "aten.split_with_sizes") + x = args[0] + sizes = args[1] + dim = args[2] if len(args) > 2 else 0 # dim has default value of 0 + + # Convert sizes to IntOrVid (supports both static ints and dynamic Vids) + sizes_int_or_vid = [P.to_int_or_vid(s) for s in sizes] + + axis = dim if dim is not None else 0 + + # Create output slots for multi-output operation + # make_or_get_slots automatically creates slots based on node.meta["val"] + output_slots = P.make_or_get_slots(n) + + # Emit SplitNode with all output slots + P.emit( + SplitNode( + x=P.slot_to_tid(x), + outs=[P.slot_to_tid(s) for s in output_slots], + sizes=sizes_int_or_vid, + axis=axis, + ) + ) + + # Return tuple of slots - getitem will extract individual elements + return output_slots + + +@REGISTRY.register( + target=[torch.ops.aten.split.Tensor, torch.ops.aten.split_copy.Tensor] +) +def _split_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle split operation with uniform chunk size. + + Splits a tensor into chunks of a given size along a dimension. + The last chunk may be smaller if the dimension does not divide evenly. + + PyTorch: split(x, split_size, dim=0) + + We pass [split_size] to the interpreter, which computes the actual + chunk sizes based on the tensor dimension. 
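+    For example, split(x, 4, dim=0) on a tensor with x.shape[0] == 10 yields
+    chunks of sizes 4, 4, and 2.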
+ """ + args = P.args(n) + require_args(args, 2, 3, "aten.split") + require_kwargs(P.kwargs(n), set(), "aten.split") + x = args[0] + split_size = args[1] + dim = args[2] if len(args) > 2 else 0 + + axis = dim if dim is not None else 0 + if axis < 0: + x_meta = n.args[0].meta.get("val") + if x_meta is None: + raise RuntimeError("split: missing tensor metadata for negative axis") + axis += len(x_meta.shape) + + # Create output slots for multi-output operation + output_slots = P.make_or_get_slots(n) + + # Emit SplitNode - interpreter computes actual chunk sizes from split_size + P.emit( + SplitNode( + x=P.slot_to_tid(x), + outs=[P.slot_to_tid(s) for s in output_slots], + sizes=[P.to_int_or_vid(split_size)], + axis=axis, + ) + ) + + return output_slots + + +@REGISTRY.register(target=[torch.ops.aten.repeat.default]) +def _repeat_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, "aten.repeat") + require_kwargs(P.kwargs(n), set(), "aten.repeat") + x, reps = args + + # Convert reps to IntOrVid (supports both static ints and dynamic Vids) + reps_int_or_vid = [P.to_int_or_vid(r) for r in reps] + + out = P.make_or_get_slot(n) + P.emit( + TileNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + reps=reps_int_or_vid, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.index.Tensor]) +def _index_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, "aten.index.Tensor") + require_kwargs(P.kwargs(n), set(), "aten.index.Tensor") + x, idx_list = args + if not isinstance(idx_list, list) or len(idx_list) == 0: + raise ValueError( + f"aten.index.Tensor requires a list of index tensors, " + f"got {type(idx_list)}" + ) + + x_meta = n.args[0].meta.get("val") + x_ndim = len(x_meta.shape) if x_meta is not None else None + + # Filter out None indices and track which axes they correspond to + non_none = [(i, idx) for i, idx in enumerate(idx_list) if idx is not None] + + if len(non_none) == 0: + raise ValueError("aten.index.Tensor: all indices are None") + + if len(non_none) == 1: + axis, idx = non_none[0] + idx_meta = n.args[1][axis].meta.get("val") + ndim_match = ( + x_meta is not None + and idx_meta is not None + and len(x_meta.shape) == len(idx_meta.shape) + ) + out = P.make_or_get_slot(n) + if ndim_match: + # Same ndim: use TakeAlongAxisNode (element-wise gather) + P.emit( + TakeAlongAxisNode( + x=P.slot_to_tid(x), + indices=P.slot_to_tid(idx), + out=P.slot_to_tid(out), + axis=axis, + ) + ) + else: + # Different ndim (e.g. 
1D indices into 3D tensor): use TakeNode + P.emit( + TakeNode( + x=P.slot_to_tid(x), + index=IntOrVidOrTid.from_tid(P.slot_to_tid(idx)), + out=P.slot_to_tid(out), + axis=axis, + ) + ) + return out + + # Multi-index: use GatherNode (maps to mlx::gather) + if x_meta is None or x_ndim is None: + raise ValueError( + "aten.index.Tensor with multiple indices requires input shape metadata" + ) + + indices = [P.slot_to_tid(idx) for _, idx in non_none] + axes = [i for i, _ in non_none] + + # slice_sizes: 1 for indexed axes, full dim size for non-indexed axes + # Use int() to handle SymInt values from dynamic shapes + indexed_axes = set(axes) + slice_sizes = [] + for dim in range(x_ndim): + if dim in indexed_axes: + slice_sizes.append(1) + else: + dim_size = x_meta.shape[dim] + if not isinstance(dim_size, int): + raise ValueError( + f"aten.index.Tensor: non-indexed dimension {dim} has dynamic size " + f"{dim_size}, which is not supported with multi-index gather" + ) + slice_sizes.append(dim_size) + + # Emit gather — output shape is broadcast(indices).shape + slice_sizes + _, gather_slot = P.make_tmp_slot() + P.emit( + GatherNode( + x=P.slot_to_tid(x), + indices=indices, + out=P.slot_to_tid(gather_slot), + axes=axes, + slice_sizes=slice_sizes, + ) + ) + + # Reshape to match aten.index.Tensor output shape, which strips the + # trailing dimensions introduced by gather's slice_sizes + out_meta = n.meta.get("val") + if out_meta is None: + raise ValueError( + "aten.index.Tensor: output shape metadata required for reshape after gather" + ) + out_shape = [P.to_int_or_vid(int(d)) for d in out_meta.shape] + + out = P.make_or_get_slot(n) + P.emit( + ReshapeNode( + x=P.slot_to_tid(gather_slot), + out=P.slot_to_tid(out), + shape=out_shape, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.index_select.default]) +def _index_select_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.index_select: select elements along an axis using a 1D index tensor. + + index_select(input, dim, index) returns input.take(index, axis=dim). + Unlike select (which takes a scalar index and removes the dim), + index_select takes a tensor of indices and preserves the dim. 
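+    For example, index_select(x, 0, tensor([0, 2])) on a (4, 5) input returns a
+    (2, 5) tensor containing rows 0 and 2.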
+ """ + args = P.args(n) + require_args(args, 3, 3, "aten.index_select") + require_kwargs(P.kwargs(n), set(), "aten.index_select") + x, dim, indices = args + out = P.make_or_get_slot(n) + P.emit( + TakeNode( + x=P.slot_to_tid(x), + index=IntOrVidOrTid.from_tid(P.slot_to_tid(indices)), + out=P.slot_to_tid(out), + axis=dim, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.slice_scatter.default]) +def _slice_scatter_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.slice_scatter: return a copy of self with self[dim][start:end:step] = src.""" + args = P.args(n) + require_args(args, 2, 6, "aten.slice_scatter") + require_kwargs(P.kwargs(n), set(), "aten.slice_scatter") + self_tensor = args[0] + src = args[1] + dim = args[2] if len(args) > 2 else 0 + start = args[3] if len(args) > 3 else 0 + end = args[4] if len(args) > 4 else None + step = args[5] if len(args) > 5 else 1 + + # If end is None, default to dim size + if end is None: + input_meta = n.args[0].meta.get("val") + if input_meta is not None: + end = input_meta.shape[dim] + else: + raise ValueError( + "aten.slice_scatter: end=None requires input shape metadata" + ) + + require_static_int(step, "step", "aten.slice_scatter") + assert step >= 1, f"aten.slice_scatter: step must be >= 1, got {step}" + + out = P.make_or_get_slot(n) + P.emit( + SliceUpdateNode( + dst=P.slot_to_tid(self_tensor), + update=P.slot_to_tid(src), + out=P.slot_to_tid(out), + axis=P.to_int_or_vid(dim), + start=P.to_int_or_vid(start), + stop=P.to_int_or_vid(end), + step=step, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.select.int, torch.ops.aten.select_copy.int]) +def _select_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """ + Handle aten.select_copy.int - select a single index along a dimension. + + select_copy(input, dim, index) returns input[..., index, ...] where the + indexing happens at dimension `dim`. The selected dimension is removed. + + Maps to MLX's take(array, int index, axis) which also removes the dimension. + """ + args = P.args(n) + require_args(args, 3, 3, "aten.select_copy.int") + require_kwargs(P.kwargs(n), set(), "aten.select_copy.int") + x, dim, index = args + out = P.make_or_get_slot(n) + P.emit( + TakeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + index=P.to_int_or_vid_or_tid(index), + axis=dim, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.sym_size.int]) +def _sym_size_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 2, "aten.sym_size.int") + require_kwargs(P.kwargs(n), set(), "aten.sym_size.int") + a, dim = args + out = P.make_or_get_slot(n) + P.emit( + SymSizeNode( + a=P.slot_to_tid(a), + dim=dim, + out=P.slot_to_vid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.item.default]) +def _item_handler(P: MLXProgramBuilder, n: Node) -> Slot: + if not isinstance(n.meta["val"], torch.SymInt): + raise ValueError("item only supported if it returns a SymInt") + args = P.args(n) + require_args(args, 1, 1, "aten.item") + require_kwargs(P.kwargs(n), set(), "aten.item") + (x,) = args + out = P.make_or_get_slot(n) + P.emit( + ItemIntNode( + x=P.slot_to_tid(x), + out=P.slot_to_vid(out), + ) + ) + return out + + +@REGISTRY.register(target=[operator.getitem]) +def _getitem_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """ + Handle getitem(tuple, idx) - extracts element from a tuple of slots. + + The source tuple comes from ops that return multiple values (like + auto_functionalized_v2). 
Those handlers return tuples of slots, + and we just ID_COPY the selected element to a new output slot. + """ + args = P.args(n) + require_args(args, 2, 2, "operator.getitem") + require_kwargs(P.kwargs(n), set(), "operator.getitem") + a, idx = args + out = P.make_or_get_slot(n) + P.emit( + IdCopyNode( + x=P.slot_to_tid(a[idx]), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.layer_norm.default]) +def _layer_norm_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 5, "aten.layer_norm") + require_kwargs(P.kwargs(n), set(), "aten.layer_norm") + x, shape = args[0:2] + if len(shape) > 1: + raise ValueError( + "LayerNorm is only supported when normalizing over the last dimension" + ) + w = args[2] if len(args) > 2 else None + bias = args[3] if len(args) > 3 else None + eps = args[4] if len(args) > 4 else 1e-5 + + out = P.make_or_get_slot(n) + P.emit( + LayerNormNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + weight=P.slot_to_tid(w) if w else None, + bias=P.slot_to_tid(bias) if bias else None, + eps=eps, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.native_layer_norm.default]) +def _native_layer_norm_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle native_layer_norm which returns (output, mean, rstd). + + Only the normalized output (index 0) is computed via fast::layer_norm; + mean and rstd (indices 1 and 2) are needed only for backward. + """ + # Verify mean/rstd outputs are unused — we only compute the normalized output. + unsupported = used_getitem_indices(n) & {1, 2} + if unsupported: + raise ValueError( + f"native_layer_norm outputs {unsupported} (mean/rstd) are used, " + "but only the normalized output (index 0) is supported" + ) + + args = P.args(n) + require_args(args, 2, 5, "aten.native_layer_norm") + require_kwargs(P.kwargs(n), set(), "aten.native_layer_norm") + x, shape = args[0:2] + if len(shape) > 1: + raise ValueError( + "LayerNorm is only supported when normalizing over the last dimension" + ) + w = args[2] if len(args) > 2 else None + bias = args[3] if len(args) > 3 else None + eps = args[4] if len(args) > 4 else 1e-5 + + # native_layer_norm returns (output, mean, rstd) — allocate all 3 slots + output_slots = P.make_or_get_slots(n) + + P.emit( + LayerNormNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(output_slots[0]), + weight=P.slot_to_tid(w) if w else None, + bias=P.slot_to_tid(bias) if bias else None, + eps=eps, + ) + ) + return output_slots + + +@REGISTRY.register(target=[torch.ops.aten.arange.default]) +def _arange_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle arange with just stop, or (start, stop) or (start, stop, step). + + Supports both static (literal int) and dynamic (Slot from item()) values. 
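+    For example, arange(5) produces [0, 1, 2, 3, 4] (int64 by default), and
+    arange(2, 8, 2) produces [2, 4, 6].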
+ """ + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 1, 3, "aten.arange") + require_kwargs(kwargs, {"dtype", "layout", "device", "pin_memory"}, "aten.arange") + require_contiguous_format( + layout=kwargs.get("layout"), + op_name="aten.arange", + ) + if len(args) == 1: + start = 0 + stop = args[0] + else: + start, stop = args[0:2] + step = args[2] if len(args) > 2 else 1 + + # arange defaults to int64 when dtype is not specified (like torch.arange) + dtype = kwargs.get("dtype", torch.int64) + scalar_type_val = torch_dtype_to_scalar_type(dtype) + + out = P.make_or_get_slot(n) + P.emit( + ARangeNode( + out=P.slot_to_tid(out), + start=P.to_int_or_vid(start), + stop=P.to_int_or_vid(stop), + step=P.to_int_or_vid(step), + scalar_type=scalar_type_val, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.arange.start_step]) +def _arange_start_step_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle arange with start, end, and step arguments. + + Supports both static (literal int) and dynamic (Slot from item()) start/stop/step. + """ + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 2, 3, "aten.arange.start_step") + require_kwargs( + kwargs, {"dtype", "layout", "device", "pin_memory"}, "aten.arange.start_step" + ) + require_contiguous_format( + layout=kwargs.get("layout"), + op_name="aten.arange.start_step", + ) + start = args[0] + stop = args[1] + step = args[2] if len(args) > 2 else 1 + + # arange defaults to int64 when dtype is not specified (like torch.arange) + dtype = kwargs.get("dtype", torch.int64) + scalar_type_val = torch_dtype_to_scalar_type(dtype) + + out = P.make_or_get_slot(n) + P.emit( + ARangeNode( + out=P.slot_to_tid(out), + start=P.to_int_or_vid(start), + stop=P.to_int_or_vid(stop), + step=P.to_int_or_vid(step), + scalar_type=scalar_type_val, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.rms_norm.default]) +def _aten_rms_norm_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 2, 4, "aten.rms_norm") + require_kwargs(P.kwargs(n), set(), "aten.rms_norm") + x, normalized_shape = args[0], args[1] + if len(normalized_shape) > 1: + raise ValueError( + "RMSNorm is only supported when normalizing over the last dimension" + ) + w = args[2] if len(args) > 2 else None + eps = args[3] if len(args) > 3 else 1e-5 + out = P.make_or_get_slot(n) + P.emit( + RMSNormNode( + x=P.slot_to_tid(x), + weight=P.slot_to_tid(w) if w else None, + out=P.slot_to_tid(out), + eps=eps, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.mlx.rope.default]) +def _rope_handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + require_args(args, 3, 7, "mlx.rope") + require_kwargs(P.kwargs(n), set(), "mlx.rope") + x, dims, pos = args[0], args[1], args[2] + traditional = args[3] if len(args) > 3 else False + base = args[4] if len(args) > 4 else 500000.0 + scale = args[5] if len(args) > 5 else 1.0 + freqs = args[6] if len(args) > 6 else None + out = P.make_or_get_slot(n) + + # pos must be a Slot (SymInt) from input_pos.item() during tracing + # The schema supports both Vid (scalar) and Tid (tensor) for offset + if not isinstance(pos, Slot): + raise ValueError( + f"RopeNode.offset must be a SymInt (traced via tensor.item()), got {type(pos)}. " + "Make sure input_pos is a tensor and you call input_pos.item() to get a SymInt." 
+ ) + + P.emit( + RopeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + dims=dims, + offset=VidOrTid.from_vid(P.slot_to_vid(pos)), + freqs=P.slot_to_tid(freqs) if freqs else None, + traditional=traditional, + base=base, + scale=scale, + ) + ) + + return out + + +def _emit_channel_last_weight(P: MLXProgramBuilder, w_node: Node, perm: list) -> Slot: + """Get convolution weight in channel-last format. + + If the weight is a placeholder (static parameter), permute at compile time + and store as a constant. If it comes from another node (e.g. dequantize + output), emit a runtime TransposeNode instead. + """ + if w_node.op == "placeholder": + w_target, w_tensor = P.get_placeholder_target_and_tensor(w_node) + return P.make_or_get_constant( + f"{w_target}_channel_last", w_tensor.permute(perm).contiguous() + ) + else: + w_slot = P.slot_map([w_node])[0] + _, w = P.make_tmp_slot() + P.emit( + TransposeNode( + x=P.slot_to_tid(w_slot), + out=P.slot_to_tid(w), + perm=perm, + ) + ) + return w + + +def _emit_conv_transpose_weight( + P: MLXProgramBuilder, w_node: Node, groups: int, ndim: int +) -> Slot: + """Get conv_transpose weight in MLX format, handling grouped convolutions. + + PyTorch conv_transpose weight shape: [C_in, C_out/G, *K] + MLX expects: [C_out, *K, C_in/G] + + For groups=1, a simple permute suffices (C_in==C_in/G, C_out/G==C_out). + For groups>1, we need reshape-permute-reshape to rearrange the group dim: + [C_in, C_out/G, *K] -> [G, C_in/G, C_out/G, *K] + -> [G, C_out/G, *K, C_in/G] + -> [C_out, *K, C_in/G] + """ + if groups == 1: + # Simple permute: [C_in, C_out, *K] -> [C_out, *K, C_in] + # e.g. 1D: [1, 2, 0], 2D: [1, 2, 3, 0], 3D: [1, 2, 3, 4, 0] + perm = list(range(1, ndim + 2)) + [0] + return _emit_channel_last_weight(P, w_node, perm) + + # Grouped: need reshape-permute-reshape at compile time + if w_node.op != "placeholder": + raise ValueError( + f"conv_transpose with groups > 1 requires static weights, " + f"got dynamic weight from {w_node.op}" + ) + + w_target, w_tensor = P.get_placeholder_target_and_tensor(w_node) + c_in = w_tensor.shape[0] + c_out_per_g = w_tensor.shape[1] + kernel_shape = list(w_tensor.shape[2:]) + c_in_per_g = c_in // groups + + # [C_in, C_out/G, *K] -> [G, C_in/G, C_out/G, *K] + w = w_tensor.reshape([groups, c_in_per_g, c_out_per_g] + kernel_shape) + # [G, C_in/G, C_out/G, *K] -> [G, C_out/G, *K, C_in/G] + # perm: [0, 2, 3, ..., ndim+1, 1] + perm = [0, 2] + list(range(3, ndim + 3)) + [1] + w = w.permute(perm).contiguous() + # [G, C_out/G, *K, C_in/G] -> [C_out, *K, C_in/G] + c_out = groups * c_out_per_g + w = w.reshape([c_out] + kernel_shape + [c_in_per_g]) + + return P.make_or_get_constant(f"{w_target}_channel_last", w) + + +def _emit_conv_bias( + P: MLXProgramBuilder, bias: Optional[Slot], tmp: Slot, ndim: int +) -> None: + """Reshape conv bias to channel-last broadcast shape and add to tmp in-place. + + After the convolution the activation is in channel-last layout, so the bias + (shape ``[C_out]``) must be reshaped to ``[1, …, 1, -1]`` with *ndim* + dimensions before being added. Does nothing when *bias* is ``None``. 
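+    For example, for a 2D convolution (*ndim* = 4) a bias of shape ``[C]`` is
+    reshaped to ``[1, 1, 1, C]`` so it broadcasts over the ``[N, H, W, C]``
+    activation.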
+ """ + if bias is None: + return + _, tmp2 = P.make_tmp_slot() + shape = [IntOrVid.from_literal(1)] * (ndim - 1) + [IntOrVid.from_literal(-1)] + P.emit( + ReshapeNode( + x=P.slot_to_tid(bias), + out=P.slot_to_tid(tmp2), + shape=shape, + ) + ) + P.emit( + AddNode( + a=P.slot_to_tid(tmp), + b=P.slot_to_tid(tmp2), + out=P.slot_to_tid(tmp), + ) + ) + + +def _emit_conv( + P: MLXProgramBuilder, + n: Node, + x_node: Node, + w_node: Node, + bias_node, + stride: list, + padding: list, + dilation: list, + groups: int, + ndim: int, +) -> Slot: + """Shared logic for regular convolution emission. + + Handles weight transform, input/output transposition, bias, and node emission + for all spatial dimensions (1D, 2D, 3D). + + Weight: [C_out, C_in/G, *K] -> [C_out, *K, C_in/G] + Input: (N, C, *spatial) -> (N, *spatial, C) + Output: (N, *spatial, C) -> (N, C, *spatial) + """ + if ndim == 3 and groups != 1: + raise ValueError( + "conv3d with groups != 1 is not supported by MLX. " f"Got groups={groups}." + ) + + # Permutation: channels-first [N, C, *spatial] <-> channels-last [N, *spatial, C] + ch_first_to_last = [0] + list(range(2, ndim + 2)) + [1] + ch_last_to_first = [0, ndim + 1] + list(range(1, ndim + 1)) + + # Weight: [C_out, C_in/G, *K] -> [C_out, *K, C_in/G] (same permutation) + w = _emit_channel_last_weight(P, w_node, ch_first_to_last) + + x, bias = P.slot_map([x_node, bias_node]) + + _, tmp = P.make_tmp_slot() + P.emit( + TransposeNode(x=P.slot_to_tid(x), out=P.slot_to_tid(tmp), perm=ch_first_to_last) + ) + + if ndim == 1: + P.emit( + Conv1DNode( + x=P.slot_to_tid(tmp), + w=P.slot_to_tid(w), + out=P.slot_to_tid(tmp), + stride=stride[0], + padding=padding[0], + dilation=dilation[0], + groups=groups, + ) + ) + elif ndim == 2: + P.emit( + Conv2DNode( + x=P.slot_to_tid(tmp), + w=P.slot_to_tid(w), + out=P.slot_to_tid(tmp), + stride_h=stride[0], + stride_w=stride[1], + padding_h=padding[0], + padding_w=padding[1], + dilation_h=dilation[0], + dilation_w=dilation[1], + groups=groups, + ) + ) + elif ndim == 3: + P.emit( + Conv3DNode( + x=P.slot_to_tid(tmp), + w=P.slot_to_tid(w), + out=P.slot_to_tid(tmp), + stride_d=stride[0], + stride_h=stride[1], + stride_w=stride[2], + padding_d=padding[0], + padding_h=padding[1], + padding_w=padding[2], + dilation_d=dilation[0], + dilation_h=dilation[1], + dilation_w=dilation[2], + groups=groups, + ) + ) + + _emit_conv_bias(P, bias, tmp, ndim=ndim + 2) + + out = P.make_or_get_slot(n) + P.emit( + TransposeNode( + x=P.slot_to_tid(tmp), out=P.slot_to_tid(out), perm=ch_last_to_first + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.conv1d.default]) +def _conv1d_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.conv1d: (input, weight, bias, stride, padding, dilation, groups).""" + require_args(n.args, 2, 7, "aten.conv1d") + require_kwargs(P.kwargs(n), set(), "aten.conv1d") + x_node, w_node = n.args[0:2] + bias_node = n.args[2] if len(n.args) > 2 else None + groups = n.args[6] if len(n.args) > 6 else 1 + stride = _normalize_conv_param(n.args[3] if len(n.args) > 3 else 1, 1, 1) + padding = _normalize_conv_param(n.args[4] if len(n.args) > 4 else 0, 1, 0) + dilation = _normalize_conv_param(n.args[5] if len(n.args) > 5 else 1, 1, 1) + return _emit_conv( + P, n, x_node, w_node, bias_node, stride, padding, dilation, groups, ndim=1 + ) + + +@REGISTRY.register(target=[torch.ops.aten.conv2d.default]) +def _conv2d_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.conv2d: (input, weight, bias, stride, padding, dilation, groups).""" + 
require_args(n.args, 2, 7, "aten.conv2d") + require_kwargs(P.kwargs(n), set(), "aten.conv2d") + x_node, w_node = n.args[0:2] + bias_node = n.args[2] if len(n.args) > 2 else None + groups = n.args[6] if len(n.args) > 6 else 1 + stride = _normalize_conv_param(n.args[3] if len(n.args) > 3 else [1, 1], 2, 1) + padding = _normalize_conv_param(n.args[4] if len(n.args) > 4 else [0, 0], 2, 0) + dilation = _normalize_conv_param(n.args[5] if len(n.args) > 5 else [1, 1], 2, 1) + return _emit_conv( + P, n, x_node, w_node, bias_node, stride, padding, dilation, groups, ndim=2 + ) + + +@REGISTRY.register(target=[torch.ops.aten.conv3d.default]) +def _conv3d_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.conv3d: (input, weight, bias, stride, padding, dilation, groups).""" + require_args(n.args, 2, 7, "aten.conv3d") + require_kwargs(P.kwargs(n), set(), "aten.conv3d") + x_node, w_node = n.args[0:2] + bias_node = n.args[2] if len(n.args) > 2 else None + groups = n.args[6] if len(n.args) > 6 else 1 + stride = _normalize_conv_param(n.args[3] if len(n.args) > 3 else [1, 1, 1], 3, 1) + padding = _normalize_conv_param(n.args[4] if len(n.args) > 4 else [0, 0, 0], 3, 0) + dilation = _normalize_conv_param(n.args[5] if len(n.args) > 5 else [1, 1, 1], 3, 1) + return _emit_conv( + P, n, x_node, w_node, bias_node, stride, padding, dilation, groups, ndim=3 + ) + + +def _emit_conv_transpose( + P: MLXProgramBuilder, + n: Node, + x_node: Node, + w_node: Node, + bias_node, + stride: list, + padding: list, + dilation: list, + output_padding: list, + groups: int, + ndim: int, +) -> Slot: + """Shared logic for transposed convolution emission. + + Handles weight transform, input/output transposition, bias, and node emission + for all spatial dimensions. Called by both the specific conv_transpose handlers + and the unified aten.convolution.default handler. 
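+    For example, a 2D transposed convolution flows as
+    (N, C_in, H, W) -> (N, H, W, C_in) -> conv_transpose -> (N, H', W', C_out)
+    -> (N, C_out, H', W').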
+ """ + if ndim == 3 and groups != 1: + raise ValueError( + "conv_transpose with groups != 1 is not supported for 3D by MLX" + ) + + w = _emit_conv_transpose_weight(P, w_node, groups, ndim=ndim) + x, bias = P.slot_map([x_node, bias_node]) + + # Transpose input: channels-first -> channels-last + ch_first_to_last = list(range(ndim + 2)) + ch_first_to_last = [0] + list(range(2, ndim + 2)) + [1] + ch_last_to_first = [0, ndim + 1] + list(range(1, ndim + 1)) + + _, tmp = P.make_tmp_slot() + P.emit( + TransposeNode(x=P.slot_to_tid(x), out=P.slot_to_tid(tmp), perm=ch_first_to_last) + ) + + if ndim == 1: + P.emit( + ConvTranspose1DNode( + x=P.slot_to_tid(tmp), + w=P.slot_to_tid(w), + out=P.slot_to_tid(tmp), + stride=stride[0], + padding=padding[0], + dilation=dilation[0], + output_padding=output_padding[0], + groups=groups, + ) + ) + elif ndim == 2: + P.emit( + ConvTranspose2DNode( + x=P.slot_to_tid(tmp), + w=P.slot_to_tid(w), + out=P.slot_to_tid(tmp), + stride_h=stride[0], + stride_w=stride[1], + padding_h=padding[0], + padding_w=padding[1], + dilation_h=dilation[0], + dilation_w=dilation[1], + output_padding_h=output_padding[0], + output_padding_w=output_padding[1], + groups=groups, + ) + ) + elif ndim == 3: + P.emit( + ConvTranspose3DNode( + x=P.slot_to_tid(tmp), + w=P.slot_to_tid(w), + out=P.slot_to_tid(tmp), + stride_d=stride[0], + stride_h=stride[1], + stride_w=stride[2], + padding_d=padding[0], + padding_h=padding[1], + padding_w=padding[2], + dilation_d=dilation[0], + dilation_h=dilation[1], + dilation_w=dilation[2], + output_padding_d=output_padding[0], + output_padding_h=output_padding[1], + output_padding_w=output_padding[2], + groups=groups, + ) + ) + + _emit_conv_bias(P, bias, tmp, ndim=ndim + 2) + + out = P.make_or_get_slot(n) + P.emit( + TransposeNode( + x=P.slot_to_tid(tmp), out=P.slot_to_tid(out), perm=ch_last_to_first + ) + ) + return out + + +def _normalize_conv_param(val, ndim, default=0): + """Normalize a conv parameter (stride/padding/etc.) to a list of length ndim.""" + if isinstance(val, int): + return [val] * ndim + if isinstance(val, list): + if len(val) == 1: + return val * ndim + return val + return [default] * ndim + + +@REGISTRY.register(target=[torch.ops.aten.convolution.default]) +def _convolution_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.convolution.default — the unified convolution op. + + Args layout: convolution(input, weight, bias, stride, padding, dilation, + transposed, output_padding, groups) + + This op appears when PyTorch doesn't decompose to specific conv ops + (e.g. grouped conv_transpose). 
+ """ + raw_args = n.args + x_node, w_node = raw_args[0], raw_args[1] + bias_node = raw_args[2] if len(raw_args) > 2 else None + transposed = raw_args[6] if len(raw_args) > 6 else False + groups = raw_args[8] if len(raw_args) > 8 else 1 + + if not transposed: + raise ValueError( + "aten.convolution with transposed=False: use aten.conv{1,2,3}d instead" + ) + + x_meta = x_node.meta.get("val") + if x_meta is None: + raise ValueError("aten.convolution: input shape metadata required") + ndim = len(x_meta.shape) - 2 + + stride = _normalize_conv_param(raw_args[3] if len(raw_args) > 3 else 1, ndim, 1) + padding = _normalize_conv_param(raw_args[4] if len(raw_args) > 4 else 0, ndim, 0) + dilation = _normalize_conv_param(raw_args[5] if len(raw_args) > 5 else 1, ndim, 1) + output_padding = _normalize_conv_param( + raw_args[7] if len(raw_args) > 7 else 0, ndim, 0 + ) + + return _emit_conv_transpose( + P, + n, + x_node, + w_node, + bias_node, + stride, + padding, + dilation, + output_padding, + groups, + ndim, + ) + + +@REGISTRY.register(target=[torch.ops.aten.conv_transpose1d.default]) +def _conv_transpose1d_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.conv_transpose1d: (input, weight, bias, stride, padding, output_padding, groups, dilation).""" + require_args(n.args, 2, 8, "aten.conv_transpose1d") + require_kwargs(P.kwargs(n), set(), "aten.conv_transpose1d") + x_node, w_node = n.args[0:2] + bias_node = n.args[2] if len(n.args) > 2 else None + groups = n.args[6] if len(n.args) > 6 else 1 + + stride = _normalize_conv_param(n.args[3] if len(n.args) > 3 else 1, 1, 1) + padding = _normalize_conv_param(n.args[4] if len(n.args) > 4 else 0, 1, 0) + output_padding = _normalize_conv_param(n.args[5] if len(n.args) > 5 else 0, 1, 0) + dilation = _normalize_conv_param(n.args[7] if len(n.args) > 7 else 1, 1, 1) + + return _emit_conv_transpose( + P, + n, + x_node, + w_node, + bias_node, + stride, + padding, + dilation, + output_padding, + groups, + ndim=1, + ) + + +@REGISTRY.register(target=[torch.ops.aten.conv_transpose2d.input]) +def _conv_transpose2d_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.conv_transpose2d: (input, weight, bias, stride, padding, output_padding, groups, dilation).""" + require_args(n.args, 2, 8, "aten.conv_transpose2d") + require_kwargs(P.kwargs(n), set(), "aten.conv_transpose2d") + x_node, w_node = n.args[0:2] + bias_node = n.args[2] if len(n.args) > 2 else None + groups = n.args[6] if len(n.args) > 6 else 1 + + stride = _normalize_conv_param(n.args[3] if len(n.args) > 3 else [1, 1], 2, 1) + padding = _normalize_conv_param(n.args[4] if len(n.args) > 4 else [0, 0], 2, 0) + output_padding = _normalize_conv_param( + n.args[5] if len(n.args) > 5 else [0, 0], 2, 0 + ) + dilation = _normalize_conv_param(n.args[7] if len(n.args) > 7 else [1, 1], 2, 1) + + return _emit_conv_transpose( + P, + n, + x_node, + w_node, + bias_node, + stride, + padding, + dilation, + output_padding, + groups, + ndim=2, + ) + + +@REGISTRY.register(target=[torch.ops.aten.conv_transpose3d.input]) +def _conv_transpose3d_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.conv_transpose3d: (input, weight, bias, stride, padding, output_padding, groups, dilation).""" + require_args(n.args, 2, 8, "aten.conv_transpose3d") + require_kwargs(P.kwargs(n), set(), "aten.conv_transpose3d") + x_node, w_node = n.args[0:2] + bias_node = n.args[2] if len(n.args) > 2 else None + groups = n.args[6] if len(n.args) > 6 else 1 + + stride = _normalize_conv_param(n.args[3] if len(n.args) > 3 else 
[1, 1, 1], 3, 1) + padding = _normalize_conv_param(n.args[4] if len(n.args) > 4 else [0, 0, 0], 3, 0) + output_padding = _normalize_conv_param( + n.args[5] if len(n.args) > 5 else [0, 0, 0], 3, 0 + ) + dilation = _normalize_conv_param(n.args[7] if len(n.args) > 7 else [1, 1, 1], 3, 1) + + return _emit_conv_transpose( + P, + n, + x_node, + w_node, + bias_node, + stride, + padding, + dilation, + output_padding, + groups, + ndim=3, + ) + + +@REGISTRY.register(target=[torch.ops.aten.sub.Tensor, torch.ops.aten.sub.Scalar]) +def _sub_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.sub.Tensor: a - alpha * b.""" + args = P.args(n) + require_args(args, 2, 2, "aten.sub.Tensor") + require_kwargs(P.kwargs(n), {"alpha"}, "aten.sub.Tensor") + a, b = args + input_meta = n.args[0].meta.get("val") + dtype = input_meta.dtype if input_meta is not None else torch.float32 + if not isinstance(b, Slot): + b = emit_lifted_constant(P, b, dtype) + alpha = P.kwargs(n).get("alpha", 1) + if alpha != 1: + alpha_slot = emit_lifted_constant(P, alpha, dtype) + _, tmp = P.make_tmp_slot() + P.emit( + MultiplyNode( + a=P.slot_to_tid(b), + b=P.slot_to_tid(alpha_slot), + out=P.slot_to_tid(tmp), + ) + ) + b = tmp + out = P.make_or_get_slot(n) + P.emit( + SubtractNode( + a=P.slot_to_tid(a), + b=P.slot_to_tid(b), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.relu.default]) +def _relu_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.relu.default - rectified linear unit. + + ReLU(x) = max(x, 0), implemented using MaximumNode with a scalar zero. + Uses broadcasting in maximum operation for efficiency. + """ + args = P.args(n) + require_args(args, 1, 1, "aten.relu") + require_kwargs(P.kwargs(n), set(), "aten.relu") + (x,) = args # x is already a Slot + + x_meta = n.args[0].meta.get("val") + if x_meta is None: + raise ValueError("Input tensor metadata not found for relu") + dtype = x_meta.dtype + + zero_slot = emit_lifted_constant(P, 0.0, dtype) + + out = P.make_or_get_slot(n) + P.emit( + MaximumNode( + a=P.slot_to_tid(x), + b=P.slot_to_tid(zero_slot), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten._log_softmax.default]) +def _log_softmax_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten._log_softmax.default - log of softmax. + + LogSoftmax(x, dim) = x - logsumexp(x, dim, keepdims=True) + + This is numerically stable because it avoids computing softmax + (which can underflow to 0) followed by log (which gives -inf for 0). + """ + args = P.args(n) + require_args(args, 3, 3, "aten._log_softmax") + require_kwargs(P.kwargs(n), set(), "aten._log_softmax") + x, dim, _half_to_float = args # x is already a Slot + + # Create temporary slot for logsumexp output + _, logsumexp_slot = P.make_tmp_slot() + + # Emit LogSumExpNode with keepdims=True + P.emit( + LogSumExpNode( + x=P.slot_to_tid(x), + axes=[dim], + keepdims=True, + out=P.slot_to_tid(logsumexp_slot), + ) + ) + + # Emit SubtractNode: x - logsumexp(x) + out = P.make_or_get_slot(n) + P.emit( + SubtractNode( + a=P.slot_to_tid(x), + b=P.slot_to_tid(logsumexp_slot), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.constant_pad_nd.default]) +def _constant_pad_nd_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.constant_pad_nd - pad with a constant value. + + PyTorch pad format: [left_0, right_0, left_1, right_1, ...] + MLX pad_width format: [(before_0, after_0), (before_1, after_1), ...] 
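+    For example, pad=[1, 2, 3, 4] on a 2D input pads the last dim by (1, 2) and
+    the first dim by (3, 4), i.e. pad_width=[(3, 4), (1, 2)].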
+ + Note: PyTorch pads in reverse order (last dimensions first). + """ + args = P.args(n) + require_args(args, 2, 3, "aten.constant_pad_nd") + require_kwargs(P.kwargs(n), set(), "aten.constant_pad_nd") + x_node, pad = args[0], args[1] + value = args[2] if len(args) > 2 else 0 + + if not isinstance(value, (int, float)): + raise ValueError( + f"aten.constant_pad_nd: constant value must be a scalar, got {type(value)}" + ) + + # Convert PyTorch pad format to MLX pad_width format + # PyTorch: [left_D, right_D, left_D-1, right_D-1, ...] + # MLX: [(before_0, after_0), (before_1, after_1), ..., (before_D, after_D)] + if len(pad) % 2 != 0: + raise ValueError( + f"aten.constant_pad_nd: pad length must be even, got {len(pad)}" + ) + + x = P.slot_map([x_node])[0] + x_meta = n.args[0].meta.get("val") + if x_meta is None: + raise ValueError("Input tensor metadata not found for constant_pad_nd") + + ndim = len(x_meta.shape) + num_pad_dims = len(pad) // 2 + + if num_pad_dims > ndim: + raise ValueError( + f"aten.constant_pad_nd: trying to pad {num_pad_dims} dimensions " + f"but input has only {ndim} dimensions" + ) + + # Build MLX pad_width: start with zeros for non-padded dims + pad_width = [] + for _ in range(ndim - num_pad_dims): + pad_width.extend([0, 0]) # No padding for these dimensions + + # Add padding for the padded dimensions (reverse order) + for i in range(num_pad_dims - 1, -1, -1): + left = pad[i * 2] + right = pad[i * 2 + 1] + pad_width.extend([left, right]) + + out = P.make_or_get_slot(n) + P.emit( + PadNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + pad_width=[P.to_int_or_vid(v) for v in pad_width], + mode="constant", + constant_value=float(value), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.clamp.default, torch.ops.aten.clamp.Tensor]) +def _clamp_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.clamp - clamp values to [min, max] range. + + clamp(input, min=None, max=None) -> Tensor + + Clamps all elements in input into the range [min, max]. + If min is None, there is no lower bound. If max is None, there is no upper bound. 
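+    For example, clamp(x, 0.0, 6.0) implements ReLU6, while clamp(x, None, 1.0)
+    applies only the upper bound.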
+ """ + args = P.args(n) + require_args(args, 1, 3, "aten.clamp") + require_kwargs(P.kwargs(n), set(), "aten.clamp") + + x = args[0] + min_val = args[1] if len(args) > 1 else None + max_val = args[2] if len(args) > 2 else None + + x_meta = n.args[0].meta.get("val") + if x_meta is None: + raise ValueError("Input tensor metadata not found for clamp") + dtype = x_meta.dtype + + out = P.make_or_get_slot(n) + + # Lift scalar bounds to 0-D constant tensors + a_min_tid = None + a_max_tid = None + if min_val is not None: + if isinstance(min_val, Slot) and min_val.id_type == IdType.Tensor: + a_min_tid = P.slot_to_tid(min_val) + else: + a_min_tid = P.slot_to_tid(emit_lifted_constant(P, float(min_val), dtype)) + if max_val is not None: + if isinstance(max_val, Slot) and max_val.id_type == IdType.Tensor: + a_max_tid = P.slot_to_tid(max_val) + else: + a_max_tid = P.slot_to_tid(emit_lifted_constant(P, float(max_val), dtype)) + + P.emit( + ClipNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + a_min=a_min_tid, + a_max=a_max_tid, + ) + ) + return out + + +@REGISTRY.register( + target=[torch.ops.aten.expand.default, torch.ops.aten.expand_copy.default] +) +def _expand_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle expand: broadcasts dimensions of size 1 to larger sizes.""" + args = P.args(n) + require_args(args, 2, 2, "aten.expand") + require_kwargs(P.kwargs(n), set(), "aten.expand") + x, size = args + out = P.make_or_get_slot(n) + + shape_iovs = [P.to_int_or_vid(s) for s in size] + P.emit( + BroadcastToNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + shape=shape_iovs, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten._native_batch_norm_legit_no_training.default]) +def _native_batch_norm_legit_no_training_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle batch norm inference (no training). + + Formula: output = (input - mean) / sqrt(var + eps) * weight + bias + + Args: + input: [N, C, ...] tensor + weight: [C] gamma parameter + bias: [C] beta parameter + running_mean: [C] + running_var: [C] + momentum: float (unused in inference) + eps: float + + Returns: + Tuple of (output, empty, empty) - save_mean and save_invstd are empty for no_training + """ + args = P.args(n) + require_args(args, 7, 7, "aten._native_batch_norm_legit_no_training") + require_kwargs(P.kwargs(n), set(), "aten._native_batch_norm_legit_no_training") + x = args[0] + weight = args[1] # gamma [C] - optional (None if affine=False) + bias = args[2] # beta [C] - optional (None if affine=False) + mean = args[3] # running_mean [C] + var = args[4] # running_var [C] + # momentum = args[5] - not used in inference + eps = args[6] # epsilon + + # Get output slots (3 outputs: normalized, save_mean, save_invstd) + output_slots = P.make_or_get_slots(n) + out = output_slots[0] # Main output + + # Get input ndim to determine reshape dimensions + # For BatchNorm1d: input is [N, C, L] -> reshape params to [1, C, 1] + # For BatchNorm2d: input is [N, C, H, W] -> reshape params to [1, C, 1, 1] + input_node = n.args[0] + input_ndim = len(input_node.meta["val"].shape) + + # Validate input dimensionality (only 3D and 4D supported) + if input_ndim not in (3, 4): + raise NotImplementedError( + f"MLX batch norm handler only supports 3D (BatchNorm1d) and 4D (BatchNorm2d) inputs. " + f"Got {input_ndim}D input." 
+ ) + + def reshape_for_broadcast(slot, name_suffix): + """Reshape a [C] tensor for broadcasting with input.""" + _, reshaped = P.make_tmp_slot() + # Build shape: [1, -1] + [1] * (ndim - 2) + shape = [P.to_int_or_vid(1), P.to_int_or_vid(-1)] + for _ in range(input_ndim - 2): + shape.append(P.to_int_or_vid(1)) + P.emit( + ReshapeNode( + x=P.slot_to_tid(slot), + shape=shape, + out=P.slot_to_tid(reshaped), + ) + ) + return reshaped + + mean_reshaped = reshape_for_broadcast(mean, "mean") + var_reshaped = reshape_for_broadcast(var, "var") + + # Step 1: x_centered = x - mean + _, tmp_centered = P.make_tmp_slot() + P.emit( + SubtractNode( + a=P.slot_to_tid(x), + b=P.slot_to_tid(mean_reshaped), + out=P.slot_to_tid(tmp_centered), + ) + ) + + # Step 2: var_eps = var + eps + eps_slot = emit_lifted_constant(P, float(eps), torch.float32) + _, tmp_var_eps = P.make_tmp_slot() + P.emit( + AddNode( + a=P.slot_to_tid(var_reshaped), + b=P.slot_to_tid(eps_slot), + out=P.slot_to_tid(tmp_var_eps), + ) + ) + + # Step 3: inv_std = rsqrt(var_eps) + _, tmp_inv_std = P.make_tmp_slot() + P.emit(RsqrtNode(x=P.slot_to_tid(tmp_var_eps), out=P.slot_to_tid(tmp_inv_std))) + + # Step 4: x_normalized = x_centered * inv_std + _, tmp_normalized = P.make_tmp_slot() + P.emit( + MultiplyNode( + a=P.slot_to_tid(tmp_centered), + b=P.slot_to_tid(tmp_inv_std), + out=P.slot_to_tid(tmp_normalized), + ) + ) + + # Step 5: x_scaled = x_normalized * weight (skip if weight is None, i.e. affine=False) + if weight is not None: + weight_reshaped = reshape_for_broadcast(weight, "weight") + _, tmp_scaled = P.make_tmp_slot() + P.emit( + MultiplyNode( + a=P.slot_to_tid(tmp_normalized), + b=P.slot_to_tid(weight_reshaped), + out=P.slot_to_tid(tmp_scaled), + ) + ) + current_result = tmp_scaled + else: + current_result = tmp_normalized + + # Step 6: out = current_result + bias (skip if bias is None, i.e. affine=False) + if bias is not None: + bias_reshaped = reshape_for_broadcast(bias, "bias") + P.emit( + AddNode( + a=P.slot_to_tid(current_result), + b=P.slot_to_tid(bias_reshaped), + out=P.slot_to_tid(out), + ) + ) + else: + # No bias - just copy the result to output + P.emit( + IdCopyNode( + x=P.slot_to_tid(current_result), + out=P.slot_to_tid(out), + ) + ) + + # For no_training mode, outputs 1 and 2 (save_mean, save_invstd) are empty + # They should already be allocated by make_or_get_slots but we don't write to them + # PyTorch returns empty tensors for these in no_training mode + + return output_slots + + +@REGISTRY.register(target=[torch.ops.aten.where.self]) +def _where_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle where: select from x or y according to condition. + + where(condition, x, y) returns elements from x where condition is True, + and elements from y where condition is False. 
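+    For example, where(x > 0, x, y) takes each element from x where it is
+    positive and from y elsewhere.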
+ """ + args = P.args(n) + require_args(args, 3, 3, "aten.where") + require_kwargs(P.kwargs(n), set(), "aten.where") + condition, x, y = args + out = P.make_or_get_slot(n) + + P.emit( + WhereNode( + condition=P.slot_to_tid(condition), + x=P.slot_to_tid(x), + y=P.slot_to_tid(y), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.bitwise_not.default]) +def _bitwise_not_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.bitwise_not - for boolean tensors, dispatch to logical_not.""" + args = P.args(n) + require_args(args, 1, 1, "aten.bitwise_not") + require_kwargs(P.kwargs(n), set(), "aten.bitwise_not") + x_meta = n.args[0].meta.get("val") + + if x_meta is not None and x_meta.dtype == torch.bool: + # For boolean tensors, bitwise_not is equivalent to logical_not + out = P.make_or_get_slot(n) + P.emit( + LogicalNotNode( + x=P.slot_to_tid(args[0]), + out=P.slot_to_tid(out), + ) + ) + return out + else: + raise NotImplementedError( + f"aten.bitwise_not is only supported for boolean tensors. " + f"Got dtype={x_meta.dtype if x_meta else 'unknown'}" + ) + + +@REGISTRY.register( + target=[torch.ops.aten.logical_and.default, torch.ops.aten.bitwise_and.Tensor] +) +def _logical_and_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.logical_and / aten.bitwise_and on bool tensors.""" + args = P.args(n) + require_args(args, 2, 2, "aten.logical_and/bitwise_and") + require_kwargs(P.kwargs(n), set(), "aten.logical_and/bitwise_and") + + # bitwise_and is only equivalent to logical_and for bool tensors. + if n.target == torch.ops.aten.bitwise_and.Tensor: + dtype = n.args[0].meta.get("val", None) + if dtype is not None and hasattr(dtype, "dtype") and dtype.dtype != torch.bool: + raise ValueError( + f"aten.bitwise_and on non-bool dtype {dtype.dtype} is not supported; " + "only bool tensors can be lowered via LogicalAndNode" + ) + out = P.make_or_get_slot(n) + P.emit( + LogicalAndNode( + a=P.slot_to_tid(args[0]), + b=P.slot_to_tid(args[1]), + out=P.slot_to_tid(out), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.scalar_tensor.default]) +def _scalar_tensor_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """This is equivalent to torch.full([], scalar, dtype=dtype).""" + args = P.args(n) + kwargs = P.kwargs(n) + require_args(args, 1, 1, "aten.scalar_tensor") + require_kwargs( + kwargs, {"dtype", "layout", "device", "pin_memory"}, "aten.scalar_tensor" + ) + require_contiguous_format( + layout=kwargs.get("layout"), + op_name="aten.scalar_tensor", + ) + scalar_value = args[0] + + out = P.make_or_get_slot(n) + + # Get dtype from kwargs, default to float32 + dtype = n.kwargs.get("dtype") + if dtype is None: + # Infer dtype from scalar type + if isinstance(scalar_value, bool): + dtype = torch.bool + elif isinstance(scalar_value, int): + dtype = torch.int64 + else: + dtype = torch.float32 + + P.emit( + FullNode( + out=P.slot_to_tid(out), + shape=[], # 0-D tensor (scalar) + v=P.to_float_or_vid(scalar_value), + scalar_type=torch_dtype_to_scalar_type(dtype), + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.tril.default]) +def _tril_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.tril - extract lower triangular part of matrix. + + tril(input, diagonal=0) -> Tensor + + Returns the lower triangular part of the matrix, with all elements above + the diagonal set to zero. The diagonal parameter controls which diagonal + to consider: 0 = main diagonal, positive = above main, negative = below main. 
+ """ + args = P.args(n) + require_args(args, 1, 2, "aten.tril") + require_kwargs(P.kwargs(n), set(), "aten.tril") + x = args[0] + diagonal = args[1] if len(args) > 1 else 0 + + out = P.make_or_get_slot(n) + P.emit( + TrilNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + k=diagonal, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.triu.default]) +def _triu_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.triu - extract upper triangular part of matrix. + + triu(input, diagonal=0) -> Tensor + + Returns the upper triangular part of the matrix, with all elements below + the diagonal set to zero. The diagonal parameter controls which diagonal + to consider: 0 = main diagonal, positive = above main, negative = below main. + """ + args = P.args(n) + require_args(args, 1, 2, "aten.triu") + require_kwargs(P.kwargs(n), set(), "aten.triu") + x = args[0] + diagonal = args[1] if len(args) > 1 else 0 + + out = P.make_or_get_slot(n) + P.emit( + TriuNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + k=diagonal, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.round.default]) +def _round_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.round - round elements to nearest integer. + + Note: round.decimals variant is not supported as it's not in Core ATen. + """ + args = P.args(n) + require_args(args, 1, 1, "aten.round") + require_kwargs(P.kwargs(n), set(), "aten.round") + x = args[0] + out = P.make_or_get_slot(n) + P.emit(RoundNode(x=P.slot_to_tid(x), out=P.slot_to_tid(out), decimals=0)) + return out + + +@REGISTRY.register(target=[torch.ops.aten.logsumexp.default]) +def _logsumexp_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.logsumexp - log(sum(exp(x))) along axes.""" + args = P.args(n) + require_args(args, 1, 3, "aten.logsumexp") + require_kwargs(P.kwargs(n), set(), "aten.logsumexp") + x = args[0] + dim = args[1] if len(args) > 1 else None + keepdim = args[2] if len(args) > 2 else False + + # Normalize dim to list + if dim is None: + axes = [] + elif isinstance(dim, int): + axes = [dim] + else: + axes = list(dim) + + out = P.make_or_get_slot(n) + P.emit( + LogSumExpNode( + x=P.slot_to_tid(x), out=P.slot_to_tid(out), axes=axes, keepdims=keepdim + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.var.correction, torch.ops.aten.var.dim]) +def _var_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.var - variance of elements along axes.""" + args = P.args(n) + require_args(args, 1, 2, "aten.var") + require_kwargs(P.kwargs(n), {"correction", "keepdim"}, "aten.var") + x = args[0] + axes, _ = normalize_reduction_dim(args) + + # Get correction/ddof and keepdim from kwargs + correction = n.kwargs.get("correction", None) + keepdim = n.kwargs.get("keepdim", False) + ddof = int(correction) if correction is not None else 1 + + out = P.make_or_get_slot(n) + P.emit( + VarNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + axes=axes, + keepdims=keepdim, + ddof=ddof, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.std.correction]) +def _std_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.std - standard deviation of elements along axes.""" + args = P.args(n) + require_args(args, 1, 2, "aten.std") + require_kwargs(P.kwargs(n), {"correction", "keepdim"}, "aten.std") + x = args[0] + axes, _ = normalize_reduction_dim(args) + + correction = n.kwargs.get("correction", None) + keepdim = n.kwargs.get("keepdim", False) + ddof = int(correction) if correction is not 
None else 1 + + out = P.make_or_get_slot(n) + P.emit( + StdNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + axes=axes, + keepdims=keepdim, + ddof=ddof, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.max.default]) +def _max_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.max.default - global max (reduce all axes).""" + args = P.args(n) + require_args(args, 1, 1, "aten.max") + require_kwargs(P.kwargs(n), set(), "aten.max") + x = args[0] + + out = P.make_or_get_slot(n) + P.emit(MaxNode(x=P.slot_to_tid(x), out=P.slot_to_tid(out), axes=[], keepdims=False)) + return out + + +@REGISTRY.register(target=[torch.ops.aten.min.default]) +def _min_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.min.default - global min (reduce all axes).""" + args = P.args(n) + require_args(args, 1, 1, "aten.min") + require_kwargs(P.kwargs(n), set(), "aten.min") + x = args[0] + + out = P.make_or_get_slot(n) + P.emit(MinNode(x=P.slot_to_tid(x), out=P.slot_to_tid(out), axes=[], keepdims=False)) + return out + + +@REGISTRY.register(target=[torch.ops.aten.argmax.default]) +def _argmax_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.argmax - index of max element along axis.""" + args = P.args(n) + require_args(args, 1, 3, "aten.argmax") + require_kwargs(P.kwargs(n), set(), "aten.argmax") + x = args[0] + dim = args[1] if len(args) > 1 else None + keepdim = args[2] if len(args) > 2 else False + + out = P.make_or_get_slot(n) + + if dim is None: + # argmax without dim: flatten tensor to 1D, then argmax over axis 0 + # Result is a scalar index into the flattened tensor + _, flat_slot = P.make_tmp_slot() + + # Get total number of elements from input shape + x_meta = n.args[0].meta.get("val") + if x_meta is None: + raise ValueError("Input tensor metadata not found for argmax") + numel = x_meta.numel() + + P.emit( + ReshapeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(flat_slot), + shape=[P.to_int_or_vid(numel)], + ) + ) + P.emit( + ArgmaxNode( + x=P.slot_to_tid(flat_slot), + out=P.slot_to_tid(out), + axis=0, + keepdims=False, + ) + ) + else: + P.emit( + ArgmaxNode( + x=P.slot_to_tid(x), out=P.slot_to_tid(out), axis=dim, keepdims=keepdim + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.argmin.default]) +def _argmin_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.argmin - index of min element along axis.""" + args = P.args(n) + require_args(args, 1, 3, "aten.argmin") + require_kwargs(P.kwargs(n), set(), "aten.argmin") + x = args[0] + dim = args[1] if len(args) > 1 else None + keepdim = args[2] if len(args) > 2 else False + + out = P.make_or_get_slot(n) + + if dim is None: + # argmin without dim: flatten tensor to 1D, then argmin over axis 0 + # Result is a scalar index into the flattened tensor + _, flat_slot = P.make_tmp_slot() + + # Get total number of elements from input shape + x_meta = n.args[0].meta.get("val") + if x_meta is None: + raise ValueError("Input tensor metadata not found for argmin") + numel = x_meta.numel() + + P.emit( + ReshapeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(flat_slot), + shape=[P.to_int_or_vid(numel)], + ) + ) + P.emit( + ArgminNode( + x=P.slot_to_tid(flat_slot), + out=P.slot_to_tid(out), + axis=0, + keepdims=False, + ) + ) + else: + P.emit( + ArgminNode( + x=P.slot_to_tid(x), out=P.slot_to_tid(out), axis=dim, keepdims=keepdim + ) + ) + return out + + +def _parse_pool_args(args, ndim, op_name, is_avg_pool=False): # noqa: C901 + """Parse pooling op arguments, normalizing scalars to 
lists. + + ATen pooling signatures: + max_pool{N}d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode) + avg_pool{N}d(input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + + Extra args beyond (input, kernel_size, stride, padding) are accepted only + when they match safe defaults: + max_pool: dilation=1, ceil_mode=False + avg_pool: ceil_mode=False, count_include_pad=True, divisor_override=None + + Returns (kernel_size, stride, padding) as lists of length ndim. + """ + if is_avg_pool: + require_args(args, 2, 7, op_name) + # args[4] = ceil_mode (must be False) + if len(args) > 4 and args[4]: + raise ValueError(f"{op_name}: ceil_mode=True is not supported.") + # args[5] = count_include_pad (must be True) + if len(args) > 5 and not args[5]: + raise ValueError(f"{op_name}: count_include_pad=False is not supported.") + # args[6] = divisor_override (must be None) + if len(args) > 6 and args[6] is not None: + raise ValueError(f"{op_name}: divisor_override is not supported.") + else: + require_args(args, 2, 6, op_name) + # args[4] = dilation (must be 1) + if len(args) > 4: + dilation = args[4] + if isinstance(dilation, list): + if any(d != 1 for d in dilation): + raise ValueError( + f"{op_name}: dilation != 1 is not supported, got {dilation}." + ) + elif dilation != 1: + raise ValueError( + f"{op_name}: dilation != 1 is not supported, got {dilation}." + ) + # args[5] = ceil_mode (must be False) + if len(args) > 5 and args[5]: + raise ValueError(f"{op_name}: ceil_mode=True is not supported.") + + kernel_size = args[1] + if isinstance(kernel_size, int): + kernel_size = [kernel_size] * ndim + + stride = args[2] if len(args) > 2 and args[2] else kernel_size + if isinstance(stride, int): + stride = [stride] * ndim + if not stride: # empty list means default to kernel_size + stride = list(kernel_size) + + padding = args[3] if len(args) > 3 else [0] * ndim + if isinstance(padding, int): + padding = [padding] * ndim + + return list(kernel_size), list(stride), list(padding) + + +def _emit_pool_nd( + P: MLXProgramBuilder, + n: Node, + ndim: int, + reduce_node_cls: type, + padding_value: float, + kernel_size: List[int], + stride: List[int], + padding: List[int], +) -> Slot: + """Emit IR nodes for N-dimensional pooling. + + Decomposes pooling into: + Transpose (channels-first -> channels-last) + -> Pad (if needed) + -> Reshape+Transpose (fast path) or AsStrided (general path) + -> Max/Mean reduction over kernel dims + -> Transpose (channels-last -> channels-first) + + Works for 1D, 2D, and 3D pooling uniformly. + + Args: + P: Program builder. + n: FX graph node for the pooling op. + ndim: Spatial dimensionality (1, 2, or 3). + reduce_node_cls: MaxNode or MeanNode. + padding_value: Padding fill value (-inf for max, 0 for avg). + kernel_size: Kernel size per spatial dim, length ndim. + stride: Stride per spatial dim, length ndim. + padding: Padding per spatial dim, length ndim. + + Returns: + Output Slot with shape [N, C, *out_spatial]. + """ + x_node = P.args(n)[0] + (x,) = P.slot_map([x_node]) + x_meta = n.args[0].meta["val"] + shape = list(x_meta.shape) # [N, C, *spatial] + + N = shape[0] + C = shape[1] + spatial = shape[2:] # length == ndim + + # 1. Transpose: channels-first [N, C, *spatial] -> channels-last [N, *spatial, C] + to_cl = [0] + list(range(2, ndim + 2)) + [1] + _, cur = P.make_tmp_slot() + P.emit( + TransposeNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(cur), + perm=to_cl, + ) + ) + + # 2. 
Pad spatial dims if needed + spatial_padded = [s + 2 * p for s, p in zip(spatial, padding)] + if any(p > 0 for p in padding): + pad_width = [0, 0] # batch dim: no pad + for p in padding: + pad_width += [p, p] + pad_width += [0, 0] # channel dim: no pad + P.emit( + PadNode( + x=P.slot_to_tid(cur), + out=P.slot_to_tid(cur), + pad_width=[P.to_int_or_vid(v) for v in pad_width], + mode="constant", + constant_value=padding_value, + ) + ) + + # 3. Sliding windows -> [N, *out_spatial, *kernel_size, C] + out_spatial = [ + (sp - k) // s + 1 for sp, k, s in zip(spatial_padded, kernel_size, stride) + ] + + can_fast_path = all( + k == s and sp % k == 0 for k, s, sp in zip(kernel_size, stride, spatial_padded) + ) + + if can_fast_path: + # Fast path: reshape + transpose (no AsStridedNode needed). + # [N, *spatial_padded, C] + # -> reshape [N, sp0//k0, k0, sp1//k1, k1, ..., C] + # -> transpose to gather output-spatial dims, then kernel dims, then C + reshape_shape = [N] + for sp, k in zip(spatial_padded, kernel_size): + reshape_shape += [sp // k, k] + reshape_shape += [C] + + P.emit( + ReshapeNode( + x=P.slot_to_tid(cur), + out=P.slot_to_tid(cur), + shape=[IntOrVid.from_literal(d) for d in reshape_shape], + ) + ) + + # Transpose: gather output-spatial (odd indices), then kernel (even indices after batch) + # Reshaped tensor axes: [0=batch, 1=out0, 2=k0, 3=out1, 4=k1, ..., last=C] + last = 2 * ndim + 1 + out_spatial_axes = list(range(1, last, 2)) # [1, 3, 5, ...] + kernel_axes = list(range(2, last, 2)) # [2, 4, 6, ...] + perm = [0] + out_spatial_axes + kernel_axes + [last] + + P.emit( + TransposeNode( + x=P.slot_to_tid(cur), + out=P.slot_to_tid(cur), + perm=perm, + ) + ) + else: + # General path: as_strided to create sliding window view. + # Input layout: [N, *spatial_padded, C] (channels-last, row-major) + dims = [N] + spatial_padded + [C] + elem_strides = [] + acc = 1 + for d in reversed(dims): + elem_strides.append(acc) + acc *= d + elem_strides.reverse() + + # as_strided shape: [N, *out_spatial, *kernel_size, C] + as_shape = [N] + out_spatial + kernel_size + [C] + + # as_strided strides: + # batch: elem_strides[0] + # out_spatial[i]: elem_strides[i+1] * stride[i] (skip by pool stride) + # kernel[i]: elem_strides[i+1] (consecutive rows/cols) + # channel: 1 + as_strides = [elem_strides[0]] + for i in range(ndim): + as_strides.append(elem_strides[i + 1] * stride[i]) + for i in range(ndim): + as_strides.append(elem_strides[i + 1]) + as_strides.append(1) + + P.emit( + AsStridedNode( + x=P.slot_to_tid(cur), + out=P.slot_to_tid(cur), + shape=[IntOrVid.from_literal(d) for d in as_shape], + strides=[IntOrVid.from_literal(d) for d in as_strides], + offset=0, + ) + ) + + # 4. Reduce over kernel dims (axes [ndim+1 .. 2*ndim]) + reduce_axes = list(range(ndim + 1, 2 * ndim + 1)) + _, reduced = P.make_tmp_slot() + P.emit( + reduce_node_cls( + x=P.slot_to_tid(cur), + out=P.slot_to_tid(reduced), + axes=reduce_axes, + keepdims=False, + ) + ) + + # 5. 
Transpose: channels-last [N, *out_spatial, C] -> channels-first [N, C, *out_spatial] + to_cf = [0, ndim + 1] + list(range(1, ndim + 1)) + output_slots = P.make_or_get_slots(n) + out = output_slots[0] + P.emit( + TransposeNode( + x=P.slot_to_tid(reduced), + out=P.slot_to_tid(out), + perm=to_cf, + ) + ) + return out + + +_POOL_OPS: List[Tuple[Any, int, type, float, str, bool]] = [ + # (target, ndim, reduce_cls, pad_value, op_name, returns_indices) + ( + torch.ops.aten.max_pool1d.default, + 1, + MaxNode, + float("-inf"), + "aten.max_pool1d", + False, + ), + ( + torch.ops.aten.max_pool1d_with_indices.default, + 1, + MaxNode, + float("-inf"), + "aten.max_pool1d_with_indices", + True, + ), + ( + torch.ops.aten.max_pool2d_with_indices.default, + 2, + MaxNode, + float("-inf"), + "aten.max_pool2d_with_indices", + True, + ), + ( + torch.ops.aten.max_pool3d_with_indices.default, + 3, + MaxNode, + float("-inf"), + "aten.max_pool3d_with_indices", + True, + ), + (torch.ops.aten.avg_pool1d.default, 1, MeanNode, 0.0, "aten.avg_pool1d", False), + (torch.ops.aten.avg_pool2d.default, 2, MeanNode, 0.0, "aten.avg_pool2d", False), + (torch.ops.aten.avg_pool3d.default, 3, MeanNode, 0.0, "aten.avg_pool3d", False), +] + + +def _make_pool_handler( + ndim: int, + reduce_node_cls: type, + padding_value: float, + op_name: str, + returns_indices: bool, +): + """Create a handler for an N-dimensional pooling op.""" + + is_avg = reduce_node_cls is MeanNode + + def handler(P: MLXProgramBuilder, n: Node) -> Slot: + args = P.args(n) + kernel_size, stride, padding = _parse_pool_args( + args, ndim, op_name, is_avg_pool=is_avg + ) + result = _emit_pool_nd( + P, n, ndim, reduce_node_cls, padding_value, kernel_size, stride, padding + ) + if not returns_indices: + return result + + handler.__name__ = f"_{op_name.replace('.', '_')}_handler" + handler.__doc__ = f"Handle {op_name} (table-driven pool op)." + return handler + + +for _target, _ndim, _cls, _pad, _name, _indices in _POOL_OPS: + REGISTRY.register(target=[_target])( + _make_pool_handler(_ndim, _cls, _pad, _name, _indices) + ) + + +@REGISTRY.register(target=[torch.ops.torchao.dequantize_affine.default]) +def _dequantize_affine_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle standalone torchao.dequantize_affine (not fused with linear/embedding). + + MLX's dequantize always operates along the last axis. When the quantized + dimension is not last (e.g. Conv2d with block_size=[1,32,1,1]), we permute + the constant weight/scale/zero_point tensors at compile time so the + quantized dim becomes last, emit the DequantizeNode, then emit a + TransposeNode with the inverse permutation to restore the original layout. 
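+
+    For example (illustrative): a Conv2d weight of shape [O, I, kH, kW]
+    quantized along dim 1 uses perm = [0, 2, 3, 1], so qdata/scale/zero_point
+    are packed as [O, kH, kW, I]; after the DequantizeNode, the inverse
+    permutation [0, 3, 1, 2] restores the original [O, I, kH, kW] layout.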
+ """ + parsed = parse_dequant_node(n) + if parsed is None: + raise NotImplementedError( + f"dequantize_affine: unsupported quantization config at {n}" + ) + ( + qdata_node, + scale_node, + zero_point_node, + group_size, + bits, + out_dtype, + quantized_dim, + ) = parsed + + qdata_target, qdata = P.get_placeholder_target_and_tensor(qdata_node) + zero_point_target, zero_point = P.get_placeholder_target_and_tensor(zero_point_node) + scale_target, scale = P.get_placeholder_target_and_tensor(scale_node) + + if out_dtype is None: + out_dtype = scale_node.meta["val"].dtype + out_scalar_type = torch_dtype_to_scalar_type(out_dtype) + + ndim = qdata.ndim + needs_permute = quantized_dim != ndim - 1 + + if needs_permute: + perm = list(range(ndim)) + perm.remove(quantized_dim) + perm.append(quantized_dim) + qdata = qdata.permute(perm).contiguous() + scale = scale.permute(perm).contiguous() + zero_point = zero_point.permute(perm).contiguous() + + # to_mlx_qparams expects 2D tensors; flatten N-D to 2D for packing, + # then restore the (possibly permuted) leading dimensions afterward. + permuted_shape = qdata.shape + qdata_2d = qdata.reshape(-1, qdata.shape[-1]) + scale_2d = scale.reshape(-1, scale.shape[-1]) + zero_point_2d = zero_point.reshape(-1, zero_point.shape[-1]) + + Q, B = to_mlx_qparams(qdata_2d, scale_2d, zero_point_2d, bits) + + leading_dims = permuted_shape[:-1] + Q = Q.reshape(*leading_dims, Q.shape[-1]) + scale_nd = scale_2d.reshape(*leading_dims, scale_2d.shape[-1]) + if B is not None: + B = B.reshape(*leading_dims, B.shape[-1]) + + w = P.make_or_get_constant(f"{qdata_target}_to_packed", Q) + scale_const = P.make_or_get_constant(f"{scale_target}_scale", scale_nd) + biases = emit_quantized_biases( + P, zero_point_target, scale, zero_point, bits, B, scale_const + ) + + if needs_permute: + _, dequant_tmp = P.make_tmp_slot() + else: + dequant_tmp = P.make_or_get_slot(n) + + P.emit( + DequantizeNode( + w=P.slot_to_tid(w), + scales=P.slot_to_tid(scale_const), + out=P.slot_to_tid(dequant_tmp), + biases=P.slot_to_tid(biases), + group_size=group_size, + bits=bits, + mode="affine", + dtype=out_scalar_type, + ) + ) + + if needs_permute: + inv_perm = [0] * ndim + for i, p in enumerate(perm): + inv_perm[p] = i + out = P.make_or_get_slot(n) + P.emit( + TransposeNode( + x=P.slot_to_tid(dequant_tmp), + out=P.slot_to_tid(out), + perm=inv_perm, + ) + ) + else: + out = dequant_tmp + + return out + + +@REGISTRY.register(target=[torch.ops.aten.cumsum.default]) +def _cumsum_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.cumsum - cumulative sum along an axis.""" + args = P.args(n) + require_args(args, 2, 3, "aten.cumsum") + require_kwargs(P.kwargs(n), {"dtype"}, "aten.cumsum") + x = args[0] + dim = args[1] + + out = P.make_or_get_slot(n) + P.emit( + CumsumNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + axis=dim, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.stack.default]) +def _stack_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.stack - stack tensors along a new axis.""" + args = P.args(n) + require_args(args, 1, 2, "aten.stack") + require_kwargs(P.kwargs(n), set(), "aten.stack") + tensors_list = args[0] + dim = args[1] if len(args) > 1 else 0 + + out = P.make_or_get_slot(n) + tensor_tids = [P.slot_to_tid(t) for t in tensors_list] + P.emit( + StackNode( + tensors=tensor_tids, + out=P.slot_to_tid(out), + axis=dim, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.repeat_interleave.self_int]) +def _repeat_interleave_handler(P: 
MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.repeat_interleave - repeat each element along an axis.""" + args = P.args(n) + require_args(args, 2, 4, "aten.repeat_interleave") + require_kwargs(P.kwargs(n), {"output_size"}, "aten.repeat_interleave") + x = args[0] + repeats = args[1] + dim = args[2] if len(args) > 2 else 0 + + out = P.make_or_get_slot(n) + P.emit( + RepeatNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + repeats=P.to_int_or_vid(repeats), + axis=dim, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.sort.default]) +def _sort_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.sort - sort elements along an axis. + + Returns (values, indices) as a tuple of output slots. + """ + args = P.args(n) + require_args(args, 1, 3, "aten.sort") + require_kwargs(P.kwargs(n), set(), "aten.sort") + x = args[0] + dim = args[1] if len(args) > 1 else -1 + + # torch.sort returns (values, indices) - 2 outputs + output_slots = P.make_or_get_slots(n) + values_slot, indices_slot = output_slots + + used = used_getitem_indices(n) + + if 0 in used: + P.emit( + SortNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(values_slot), + axis=dim, + ) + ) + if 1 in used: + P.emit( + ArgsortNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(indices_slot), + axis=dim, + ) + ) + + return output_slots + + +@REGISTRY.register(target=[torch.ops.aten.argsort.default]) +def _argsort_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.argsort - indices that sort elements along an axis.""" + args = P.args(n) + require_args(args, 1, 3, "aten.argsort") + require_kwargs(P.kwargs(n), set(), "aten.argsort") + x = args[0] + dim = args[1] if len(args) > 1 else -1 + + out = P.make_or_get_slot(n) + P.emit( + ArgsortNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(out), + axis=dim, + ) + ) + return out + + +@REGISTRY.register(target=[torch.ops.aten.topk.default]) +def _topk_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.topk - top-k elements along an axis. + + Decomposes into: partition → slice → sort → reverse (for values) + argpartition → slice → gather → argsort → reverse → reorder (for indices) + + torch.topk returns (values, indices) sorted descending. 
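+
+    Worked example (illustrative): x = [3, 1, 4, 1, 5], k = 2 along dim 0.
+    partition(kth=-2) moves the two largest values into the last two slots
+    (in unspecified order), the slice keeps those two (e.g. [4, 5]), sort
+    yields [4, 5] ascending, and the reverse slice (step=-1) produces the
+    final descending values [5, 4]; the index path applies the same
+    reordering to the argpartition output, giving indices [4, 2].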
+ """ + args = P.args(n) + require_args(args, 2, 5, "aten.topk") + require_kwargs(P.kwargs(n), set(), "aten.topk") + x = args[0] + k = args[1] + dim = args[2] if len(args) > 2 else -1 + + output_slots = P.make_or_get_slots(n) + values_slot, indices_slot = output_slots + + used = used_getitem_indices(n) + + # Get dim size from input metadata for forward slice stop + x_meta = n.args[0].meta.get("val") + if x_meta is None: + raise ValueError("Input tensor metadata not found for topk") + norm_axis = dim if dim >= 0 else dim + len(x_meta.shape) + dim_size = x_meta.shape[norm_axis] + + # Compute -k for partition index and forward slice start + if isinstance(k, int): + neg_k = P.to_int_or_vid(-k) + # Reverse slice: start=k-1, stop=-(k+1) on the k-sized sliced tensor + rev_start = P.to_int_or_vid(k - 1) + rev_stop = P.to_int_or_vid(-(k + 1)) + else: + # k is dynamic — emit neg_k = k * -1 at runtime + _, neg_k_slot = P.make_tmp_value_slot() + P.emit( + MultiplyIntNode( + a=P.to_int_or_vid(k), + b=IntOrVid.from_literal(-1), + out=P.slot_to_vid(neg_k_slot), + ) + ) + neg_k = P.to_int_or_vid(neg_k_slot) + # rev_start = k - 1 + _, rev_start_slot = P.make_tmp_value_slot() + P.emit( + AddIntNode( + a=P.to_int_or_vid(k), + b=IntOrVid.from_literal(-1), + out=P.slot_to_vid(rev_start_slot), + ) + ) + rev_start = P.to_int_or_vid(rev_start_slot) + # rev_stop = -(k + 1) = neg_k - 1 + _, rev_stop_slot = P.make_tmp_value_slot() + P.emit( + AddIntNode( + a=neg_k, + b=IntOrVid.from_literal(-1), + out=P.slot_to_vid(rev_stop_slot), + ) + ) + rev_stop = P.to_int_or_vid(rev_stop_slot) + + stop_val = P.to_int_or_vid(dim_size) + + def emit_partition_and_slice(node_cls): + """Emit partition/argpartition → slice last k elements.""" + _, part_tmp = P.make_tmp_slot() + P.emit( + node_cls( + x=P.slot_to_tid(x), + out=P.slot_to_tid(part_tmp), + kth=neg_k, + axis=dim, + ) + ) + _, slice_tmp = P.make_tmp_slot() + P.emit( + SliceNode( + x=P.slot_to_tid(part_tmp), + out=P.slot_to_tid(slice_tmp), + axis=P.to_int_or_vid(dim), + start=neg_k, + stop=stop_val, + step=1, + ) + ) + return slice_tmp + + def emit_reverse(in_slot, out_slot): + """Reverse a tensor along dim using slice with step=-1.""" + P.emit( + SliceNode( + x=P.slot_to_tid(in_slot), + out=P.slot_to_tid(out_slot), + axis=P.to_int_or_vid(dim), + start=rev_start, + stop=rev_stop, + step=-1, + ) + ) + + if 0 in used: + # partition → slice last k → sort ascending → reverse to descending + slice_tmp = emit_partition_and_slice(PartitionNode) + _, sort_tmp = P.make_tmp_slot() + P.emit( + SortNode( + x=P.slot_to_tid(slice_tmp), + out=P.slot_to_tid(sort_tmp), + axis=dim, + ) + ) + emit_reverse(sort_tmp, values_slot) + + if 1 in used: + # argpartition → slice last k → gather values → argsort → reverse → reorder + idx_slice_tmp = emit_partition_and_slice(ArgPartitionNode) + # Gather original values at the partitioned indices + _, gathered_tmp = P.make_tmp_slot() + P.emit( + TakeAlongAxisNode( + x=P.slot_to_tid(x), + indices=P.slot_to_tid(idx_slice_tmp), + out=P.slot_to_tid(gathered_tmp), + axis=dim, + ) + ) + # Argsort gathered values ascending → reverse → descending order + _, order_tmp = P.make_tmp_slot() + P.emit( + ArgsortNode( + x=P.slot_to_tid(gathered_tmp), + out=P.slot_to_tid(order_tmp), + axis=dim, + ) + ) + _, rev_order_tmp = P.make_tmp_slot() + emit_reverse(order_tmp, rev_order_tmp) + # Apply descending order to indices + P.emit( + TakeAlongAxisNode( + x=P.slot_to_tid(idx_slice_tmp), + indices=P.slot_to_tid(rev_order_tmp), + out=P.slot_to_tid(indices_slot), + axis=dim, + ) 
+ ) + + return output_slots diff --git a/backends/mlx/patterns.py b/backends/mlx/patterns.py index c8bef1f91ca..29e5e326c69 100644 --- a/backends/mlx/patterns.py +++ b/backends/mlx/patterns.py @@ -12,3 +12,1164 @@ This module contains pattern handlers that match multi-node subgraphs and lower them to optimized MLX operations. """ + +from __future__ import annotations + +from typing import Any, List, Optional, Tuple + +import torch +from executorch.backends.mlx.builder.op_helpers import ( + emit_quantized_biases, + emit_stop_position, + parse_dequant_node, + parse_dequant_nvfp4_node, + to_mlx_qparams, + torch_dtype_to_scalar_type, +) +from executorch.backends.mlx.builder.op_registry import PatternHandler, REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.pattern_utils import ( + has_single_user, + match_target, + OpStep, + walk_back, +) +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + AddIntNode, + AddNode, + AsTypeNode, + DequantizeNode, + IndexCopyNode, + IntOrVid, + IntOrVidOrTid, + ModIntNode, + MultiplyNode, + QuantizedMatmulNode, + SdpaNode, + SliceNode, + SliceUpdateNode, + SubtractIntNode, + SymSizeNode, + TakeNode, +) +from torch.export.exported_program import ExportedProgram +from torch.fx.node import Node + + +@REGISTRY.register_pattern(name="INDEX_COPY") +class IndexCopyHandler(PatternHandler): + """ + Pattern for index-based updates on mutable buffers. + """ + + def __init__( + self, + head: Node, + body: List[Node], + dst: Node, + update: Node, + indices: Node, + axis: int, + ): + super().__init__(head, body) + self.dst = dst + self.update = update + self.indices = indices + self.axis = axis + + @classmethod + def maybe_create( # noqa: C901 + cls, ep: ExportedProgram, head: Node + ) -> Optional["IndexCopyHandler"]: + index_copy_node = head + if not match_target(index_copy_node, torch.ops.aten.index_copy.default): + return None + + # index_copy should write to a mutable input/buffer to be an index update. + if (index_copy_node.name not in ep.graph_signature.buffers_to_mutate) and ( + index_copy_node.name not in ep.graph_signature.user_inputs_to_mutate + ): + return None + + # index_copy(dst, axis, indices, update) + if len(index_copy_node.args) != 4: + return None + dst, axis, indices, update = index_copy_node.args + + # axis must be a literal int + if not isinstance(axis, int): + return None + + return cls( + head=index_copy_node, + body=[], + dst=dst, + update=update, + indices=indices, + axis=axis, + ) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + dst, update, indices = P.slot_map([self.dst, self.update, self.indices]) + + P.emit( + IndexCopyNode( + dst=P.slot_to_tid(dst), + update=P.slot_to_tid(update), + indices=P.slot_to_tid(indices), + out=P.slot_to_tid(dst), + axis=self.axis, + ) + ) + + P.set_slot(n, dst) + return dst + + +@REGISTRY.register_pattern(name="ET_KV_CACHE_UPDATE") +class ETKVCacheUpdateHandler(PatternHandler): + """ + Pattern for KV cache updates using torch.ops.mlx.kv_cache_update. + + Matches: auto_functionalized → getitem[1] + HEAD = getitem[1] (no alias_copy required) + + Graph structure: + auto_func = auto_functionalized_v2(mlx.kv_cache_update, new_values=k_val, ...) 
+ getitem_1 = getitem(auto_func, 1) # HEAD - updated cache + """ + + def __init__( + self, + head: Node, + body: List[Node], + cache: Node, + update: Node, + start_pos: Any, + ring_size: int = 0, + ): + super().__init__(head, body) + self.cache = cache + self.update = update + self.start_pos = start_pos + self.ring_size = ring_size + + @staticmethod + def _is_auto_func_et_kv_cache_update(node: Node) -> bool: + """Check if a node is auto_functionalized_v2 wrapping mlx.kv_cache_update.""" + if node.op != "call_function": + return False + target_str = str(node.target) + if "auto_functionalized" not in target_str: + return False + if len(node.args) < 1: + return False + func_arg = node.args[0] + func_str = str(func_arg) if func_arg else "" + return "kv_cache_update" in func_str and "mlx" in func_str + + @classmethod + def maybe_create( + cls, ep: ExportedProgram, head: Node + ) -> Optional["ETKVCacheUpdateHandler"]: + """ + Match the ET_KV_CACHE_UPDATE pattern. + + Pattern (HEAD = getitem): + auto_func = auto_functionalized_v2(mlx.kv_cache_update, ...) + getitem_1 = getitem(auto_func, 1) # HEAD + """ + + # HEAD must be getitem with idx=1 + if head.op != "call_function" or "getitem" not in str(head.target): + return None + + if len(head.args) < 2 or head.args[1] != 1: + return None + + # getitem's source should be auto_functionalized_v2 wrapping mlx.kv_cache_update + if not isinstance(head.args[0], Node): + return None + + auto_func_node = head.args[0] + if not cls._is_auto_func_et_kv_cache_update(auto_func_node): + return None + + # Extract info from auto_functionalized_v2 kwargs + kwargs = auto_func_node.kwargs + new_values_node = kwargs.get("new_values") + start_pos_node = kwargs.get("start_pos") + all_bases = kwargs.get("_all_bases", []) + + if not new_values_node or not all_bases: + return None + + cache_node = all_bases[0] + + body = [auto_func_node] + + return cls( + head=head, + body=body, + cache=cache_node, + update=new_values_node, + start_pos=start_pos_node, + ring_size=kwargs.get("ring_size", 0), + ) + + def __call__(self, P: "MLXProgramBuilder", n: Node) -> Slot: + assert n == self.head + + cache_slot, update_slot, start_slot = P.slot_map( + [self.cache, self.update, self.start_pos] + ) + + if self.ring_size > 0: + self._emit_ring_buffer(P, cache_slot, update_slot, start_slot) + else: + self._emit_linear(P, cache_slot, update_slot, start_slot) + + P.set_slot(n, cache_slot) + return cache_slot + + def _emit_linear(self, P: "MLXProgramBuilder", cache_slot, update_slot, start_slot): + """Emit a single SliceUpdate for linear (non-ring) cache.""" + update_meta = self.update.meta.get("val") + stop_slot = emit_stop_position( + P, + start=start_slot, + length_tensor=update_slot, + length_dim=2, # S_step is dim 2 in [B, H, S_step, D] + length_meta=update_meta, + ) + + # This updates cache[:, :, start:stop, :] = update + # SliceUpdateNode on axis=2 + # cache is [B, H, S, D], update is [B, H, S_step, D] + P.emit( + SliceUpdateNode( + dst=P.slot_to_tid(cache_slot), + update=P.slot_to_tid(update_slot), + out=P.slot_to_tid(cache_slot), + axis=IntOrVid.from_literal(2), # S dimension in [B, H, S, D] + start=P.to_int_or_vid(start_slot), + stop=P.to_int_or_vid(stop_slot), + ) + ) + + def _emit_ring_buffer( + self, P: "MLXProgramBuilder", cache_slot, update_slot, start_slot + ): + """ + Emit two unconditional SliceUpdates for ring buffer wrapping. 
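+
+        Worked example (illustrative): ring_size=8, start_pos=6, seq_len=4
+        gives write_pos=6 and first_len=2, so the first two tokens are
+        written to cache[..., 6:8, :] and the remaining two wrap around to
+        cache[..., 0:2, :] via the second SliceUpdate.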
+ + write_pos = start_pos % ring_size + first_len = ring_size - write_pos + first_chunk = update[:, :, :first_len, :] (Slice clamps to seq_len) + actual_first = first_chunk.shape[2] (min(first_len, seq_len)) + rest_chunk = update[:, :, actual_first:seq_len, :] + overflow = seq_len - actual_first + SliceUpdate(cache, first_chunk, write_pos, write_pos + actual_first) + SliceUpdate(cache, rest_chunk, 0, overflow) + + When no wrap: actual_first == seq_len, rest_chunk is zero-length, + second SliceUpdate is a no-op (guarded in exec_slice_update). + """ + ring_size = self.ring_size + + # write_pos = start_pos % ring_size + _, write_pos_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + ModIntNode( + a=P.to_int_or_vid(start_slot), + b=IntOrVid.from_literal(ring_size), + out=P.slot_to_vid(write_pos_slot), + ) + ) + + # seq_len = update.shape[2] + _, seq_len_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + SymSizeNode( + a=P.slot_to_tid(update_slot), + dim=2, + out=P.slot_to_vid(seq_len_slot), + ) + ) + + # first_len = ring_size - write_pos (may be > seq_len) + _, first_len_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + SubtractIntNode( + a=IntOrVid.from_literal(ring_size), + b=P.to_int_or_vid(write_pos_slot), + out=P.slot_to_vid(first_len_slot), + ) + ) + + # first_chunk = update[:, :, :first_len, :] (Slice clamps to seq_len) + _, first_chunk_slot = P.make_tmp_slot() + P.emit( + SliceNode( + x=P.slot_to_tid(update_slot), + out=P.slot_to_tid(first_chunk_slot), + axis=IntOrVid.from_literal(2), + start=IntOrVid.from_literal(0), + stop=P.to_int_or_vid(first_len_slot), + ) + ) + + # actual_first = first_chunk.shape[2] (= min(first_len, seq_len)) + _, actual_first_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + SymSizeNode( + a=P.slot_to_tid(first_chunk_slot), + dim=2, + out=P.slot_to_vid(actual_first_slot), + ) + ) + + # rest_chunk = update[:, :, actual_first:seq_len, :] + _, rest_chunk_slot = P.make_tmp_slot() + P.emit( + SliceNode( + x=P.slot_to_tid(update_slot), + out=P.slot_to_tid(rest_chunk_slot), + axis=IntOrVid.from_literal(2), + start=P.to_int_or_vid(actual_first_slot), + stop=P.to_int_or_vid(seq_len_slot), + ) + ) + + # stop1 = write_pos + actual_first + _, stop1_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + AddIntNode( + a=P.to_int_or_vid(write_pos_slot), + b=P.to_int_or_vid(actual_first_slot), + out=P.slot_to_vid(stop1_slot), + ) + ) + + # overflow = seq_len - actual_first + _, overflow_slot = P.slot_manager.make_tmp_value_slot() + P.emit( + SubtractIntNode( + a=P.to_int_or_vid(seq_len_slot), + b=P.to_int_or_vid(actual_first_slot), + out=P.slot_to_vid(overflow_slot), + ) + ) + + # SliceUpdate 1: cache[:, :, write_pos:stop1, :] = first_chunk + P.emit( + SliceUpdateNode( + dst=P.slot_to_tid(cache_slot), + update=P.slot_to_tid(first_chunk_slot), + out=P.slot_to_tid(cache_slot), + axis=IntOrVid.from_literal(2), + start=P.to_int_or_vid(write_pos_slot), + stop=P.to_int_or_vid(stop1_slot), + ) + ) + + # SliceUpdate 2: cache[:, :, 0:overflow, :] = rest_chunk + # Zero-length no-op when no wrap (overflow=0) + P.emit( + SliceUpdateNode( + dst=P.slot_to_tid(cache_slot), + update=P.slot_to_tid(rest_chunk_slot), + out=P.slot_to_tid(cache_slot), + axis=IntOrVid.from_literal(2), + start=IntOrVid.from_literal(0), + stop=P.to_int_or_vid(overflow_slot), + ) + ) + + +@REGISTRY.register_pattern(name="SDPA") +class SDPAHandler(PatternHandler): + """ + Pattern for Scaled Dot Product Attention with optional GQA. 
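+
+    For grouped query attention (illustrative: n_heads=32, n_kv_heads=8), the
+    graph repeats K/V from [B, 8, T, D] to [B, 32, T, D] before SDPA; this
+    handler absorbs that repeat and passes the original [B, 8, T, D] tensors
+    to SdpaNode, which handles the head grouping internally.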
+ + Matches: scaled_dot_product_attention + Optionally with repeat_interleave for grouped query attention. + """ + + def __init__( + self, + head: Node, + body: List[Node], + q_node: Node, + k_node: Node, + v_node: Node, + ): + super().__init__(head, body) + self.q_node = q_node + self.k_node = k_node + self.v_node = v_node + + @classmethod + def _parse_sdpa_args_and_kwargs(cls, sdpa_node: Node): + q, k, v = sdpa_node.args[0:3] + attn_mask = sdpa_node.args[3] if len(sdpa_node.args) > 3 else None + dropout_p = sdpa_node.args[4] if len(sdpa_node.args) > 4 else 0.0 + is_causal = sdpa_node.args[5] if len(sdpa_node.args) > 5 else False + enable_gqa = sdpa_node.args[6] if len(sdpa_node.args) > 6 else False + scale = sdpa_node.kwargs.get("scale", None) + return q, k, v, attn_mask, dropout_p, is_causal, scale, enable_gqa + + @classmethod + def _try_unwrap_repeat_kv(cls, node: Node) -> Optional[Tuple[Node, List[Node]]]: + """Try to unwrap a HuggingFace repeat_kv pattern. + + HuggingFace's repeat_kv expands KV heads for grouped query attention: + hidden_states[:, :, None, :, :].expand(B, n_kv, n_rep, T, D) + .clone().reshape(B, n_heads, T, D) + + In Edge IR this becomes: + unsqueeze_copy(x, 2) → expand_copy → clone → view_copy + + Returns: + (base_node, body_nodes) if pattern matches, else None. + base_node is the original [B, n_kv, T, D] tensor. + body_nodes are the intermediate nodes to absorb. + """ + result = walk_back( + node, + [ + OpStep(op=torch.ops.aten.view.default, nargs=2), + OpStep(op=torch.ops.aten.clone.default, optional=True), + OpStep(op=torch.ops.aten.expand.default, nargs=2), + OpStep(op=torch.ops.aten.unsqueeze.default, nargs=2), + ], + ) + if result is None: + return None + + base, entries = result + _view, _clone, _expand, unsqueeze = entries + + # unsqueeze must be on dim=2 + if unsqueeze.args[1] != 2: + return None + + body = [e for e in entries if e is not None] + return base, body + + @classmethod + def maybe_create(cls, ep: ExportedProgram, head: Node) -> Optional["SDPAHandler"]: + sdpa_node = head + if not match_target( + sdpa_node, torch.ops.aten.scaled_dot_product_attention.default + ): + return None + + q, k, v, _, _, _, _, _ = cls._parse_sdpa_args_and_kwargs(sdpa_node) + + # Detect grouped kv attention pattern with repeat_interleave before SDPA + is_grouped_kv = False + k_base = k + v_base = v + body: List[Node] = [] + if ( + match_target(k, torch.ops.aten.repeat_interleave.self_int) + and has_single_user(k) + and (len(k.args) == 3) + and (len(k.kwargs) == 0) + and match_target(v, torch.ops.aten.repeat_interleave.self_int) + and has_single_user(v) + and (len(v.args) == 3) + and (len(v.kwargs) == 0) + ): + k_unrepeated, k_reps, k_dim = k.args + v_unrepeated, v_reps, v_dim = v.args + + if (k_dim == 1 and v_dim == 1) and (k_reps == v_reps): + is_grouped_kv = True + k_base = k_unrepeated + v_base = v_unrepeated + body = [k, v] + + # Detect HuggingFace repeat_kv pattern: + # unsqueeze(dim=2) → expand → clone → view + if not is_grouped_kv: + k_unwrap = cls._try_unwrap_repeat_kv(k) + v_unwrap = cls._try_unwrap_repeat_kv(v) + if k_unwrap is not None and v_unwrap is not None: + k_base, k_body = k_unwrap + v_base, v_body = v_unwrap + is_grouped_kv = True + body = k_body + v_body + + head = sdpa_node + if not is_grouped_kv: + body = [] + return SDPAHandler(head, body, q_node=q, k_node=k_base, v_node=v_base) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + q, k, v, attn_mask, dropout_p, is_causal, scale, enable_gqa = ( + 
SDPAHandler._parse_sdpa_args_and_kwargs(n) + ) + head_dim = q.meta["val"].shape[-1] + if scale is None: + scale = head_dim**-0.5 + + q = self.q_node + k = self.k_node + v = self.v_node + + assert dropout_p == 0.0, "SDPA with dropout is not supported" + + q, k, v, attn_mask = P.slot_map([q, k, v, attn_mask]) + out = P.make_or_get_slot(n) + + P.emit( + SdpaNode( + q=P.slot_to_tid(q), + k=P.slot_to_tid(k), + v=P.slot_to_tid(v), + out=P.slot_to_tid(out), + scale=scale, + mask=P.slot_to_tid(attn_mask) if attn_mask else None, + causal=is_causal, + ) + ) + return out + + +@REGISTRY.register_pattern(name="NVFP4_QUANTIZED_EMBEDDING") +class NVFP4QuantizedEmbeddingHandler(PatternHandler): + """Fuse dequantize_nvfp4 + embedding into gather + DequantizeNode(mode="nvfp4"). + + Matches: + embedding(dequantize_nvfp4(qdata, scale, per_tensor_scale, ...), indices) + + Emits: + TakeNode(qdata) → TakeNode(scales) → DequantizeNode(mode="nvfp4") + [→ MultiplyNode(per_tensor_scale)] [→ AsTypeNode] + """ + + def __init__(self, head, body, qdata, scale, per_tensor_scale, output_dtype): + super().__init__(head, body) + self.qdata = qdata + self.scale = scale + self.per_tensor_scale = per_tensor_scale + self.output_dtype = output_dtype + + @classmethod + def maybe_create(cls, ep, head): + if not match_target(head, torch.ops.aten.embedding.default): + return None + + w, x = head.args[0:2] + if not isinstance(w, Node): + return None + if not has_single_user(w): + return None + parsed = parse_dequant_nvfp4_node(w) + if parsed is None: + return None + qdata, scale, per_tensor_scale, output_dtype = parsed + return cls(head, [w], qdata, scale, per_tensor_scale, output_dtype) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + w_node, x_node = n.args[0:2] + + has_per_tensor_scale = True + _, per_tensor_scale_value = P.get_placeholder_target_and_tensor( + self.per_tensor_scale + ) + from torch._subclasses.fake_tensor import FakeTensor + + if not isinstance(per_tensor_scale_value, FakeTensor): + if per_tensor_scale_value.item() == 1.0: + has_per_tensor_scale = False + + x_dtype = x_node.meta["val"].dtype + needs_cast = self.output_dtype != x_dtype + + x, scales_slot, per_tensor_scale, qdata_slot = P.slot_map( + [x_node, self.scale, self.per_tensor_scale, self.qdata] + ) + + ids_index = IntOrVidOrTid.from_tid(P.slot_to_tid(x)) + + # Gather quantized weights by indices + _, wq_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(qdata_slot), + index=ids_index, + out=P.slot_to_tid(wq_sel), + axis=0, + ) + ) + + # Gather scales by indices + _, sc_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(scales_slot), + index=ids_index, + out=P.slot_to_tid(sc_sel), + axis=0, + ) + ) + + # Dequantize the gathered slices + out = P.make_or_get_slot(n) + P.emit( + DequantizeNode( + w=P.slot_to_tid(wq_sel), + scales=P.slot_to_tid(sc_sel), + out=P.slot_to_tid(out), + biases=None, + group_size=16, + bits=4, + mode="nvfp4", + dtype=torch_dtype_to_scalar_type(self.output_dtype), + ) + ) + + if has_per_tensor_scale: + P.emit( + MultiplyNode( + a=P.slot_to_tid(out), + b=P.slot_to_tid(per_tensor_scale), + out=P.slot_to_tid(out), + ) + ) + + if needs_cast: + P.emit( + AsTypeNode( + x=P.slot_to_tid(out), + out=P.slot_to_tid(out), + scalar_type=torch_dtype_to_scalar_type(self.output_dtype), + ) + ) + + return out + + +@REGISTRY.register_pattern(name="MLX_CUSTOM_SDPA") +class MLXCustomSdpaHandler(PatternHandler): + """ + Pattern handler for mlx::custom_sdpa custom op. 
+ + This op follows the optimum-executorch pattern: + - Input: Q, K, V in BHSD format [B, num_heads, seq_len, head_dim] + - start_pos: FIRST position of current query batch (not last!) + - stop_pos: computed as start_pos + query_seq_len + - K/V are FULL cache, sliced internally to [:, :, :stop_pos, :] + + For prefill with 7 tokens at positions [0,1,2,3,4,5,6]: start_pos=0, stop_pos=7 + For decode at position 10: start_pos=10, stop_pos=11 + + Decomposes into: + - SliceNode (K): slice to [:, :, :stop_pos, :] + - SliceNode (V): slice to [:, :, :stop_pos, :] + - SdpaNode: scaled dot-product attention (handles GQA internally) + """ + + def __init__( + self, + head: Node, + body: List[Node], + query: Node, + key: Node, + value: Node, + start_pos: Any, # int or Node (SymInt) + attn_mask: Optional[Node], + scale: Optional[float], + is_causal: bool, + ): + super().__init__(head, body) + self.query = query + self.key = key + self.value = value + self.start_pos = start_pos + self.attn_mask = attn_mask + self.scale = scale + self.is_causal = is_causal + + @classmethod + def maybe_create( + cls, ep: ExportedProgram, head: Node + ) -> Optional["MLXCustomSdpaHandler"]: + """Match the mlx::custom_sdpa custom op.""" + if head.op != "call_function": + return None + + target_str = str(head.target) + if "custom_sdpa" not in target_str or "mlx" not in target_str: + return None + + # Op signature: custom_sdpa(query, key, value, start_pos, attn_mask, dropout_p, is_causal, scale) + # start_pos is a SymInt (int), not a Tensor + args = head.args + kwargs = head.kwargs + + if len(args) < 4: + return None + + query = args[0] + key = args[1] + value = args[2] + start_pos = args[3] # int or SymInt (Node) + + # Get optional args + attn_mask = args[4] if len(args) > 4 else kwargs.get("attn_mask", None) + dropout_p = args[5] if len(args) > 5 else kwargs.get("dropout_p", 0.0) + is_causal = args[6] if len(args) > 6 else kwargs.get("is_causal", False) + scale = args[7] if len(args) > 7 else kwargs.get("scale", None) + + if dropout_p != 0.0: + return None + + return MLXCustomSdpaHandler( + head=head, + body=[], + query=query, + key=key, + value=value, + start_pos=start_pos, + attn_mask=attn_mask, + scale=scale, + is_causal=is_causal, + ) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + SdpaNode, + SliceNode, + ) + + assert n == self.head + + # Get slots for Q, K, V + q_slot, k_slot, v_slot = P.slot_map([self.query, self.key, self.value]) + + # Get scale from metadata if not provided + q_meta = self.query.meta.get("val") + head_dim = q_meta.shape[-1] + scale = self.scale if self.scale is not None else head_dim**-0.5 + + # Resolve start_pos to int or Slot (same pattern as KVCacheUpdateHandler) + if isinstance(self.start_pos, Node): + start_slot = P.slot_map([self.start_pos])[0] + else: + start_slot = self.start_pos + + # Compute stop = start_pos + seq_len using emit_stop_position, + # which handles static/dynamic seq_len (SymInt) and start_pos correctly. + # BHSD layout: q is [B, num_heads, seq_len, head_dim], seq_len is dim 2. 
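+        # Illustrative values (matching the class docstring): prefill of 7
+        # tokens -> start_pos=0, stop=7; decode at position 10 with a single
+        # query token -> start_pos=10, stop=11, so K/V below are sliced to
+        # [:, :, :11, :].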
+ stop = emit_stop_position( + P, + start=start_slot, + length_tensor=q_slot, + length_dim=2, + length_meta=q_meta, + ) + slice_stop = P.to_int_or_vid(stop) + + # Step 1: Slice K to [:, :, :stop_pos, :] where stop_pos = start_pos + query_seq_len + _, k_sliced_slot = P.make_tmp_slot() + P.emit( + SliceNode( + x=P.slot_to_tid(k_slot), + out=P.slot_to_tid(k_sliced_slot), + axis=IntOrVid.from_literal(2), + start=IntOrVid.from_literal(0), + stop=slice_stop, + ) + ) + + # Step 2: Slice V to [:, :, :stop_pos, :] where stop_pos = start_pos + query_seq_len + _, v_sliced_slot = P.make_tmp_slot() + P.emit( + SliceNode( + x=P.slot_to_tid(v_slot), + out=P.slot_to_tid(v_sliced_slot), + axis=IntOrVid.from_literal(2), + start=IntOrVid.from_literal(0), + stop=slice_stop, + ) + ) + + # Step 3: SDPA (handles GQA internally) - outputs BHSD + out_slot = P.make_or_get_slot(n) + P.emit( + SdpaNode( + q=P.slot_to_tid(q_slot), + k=P.slot_to_tid(k_sliced_slot), + v=P.slot_to_tid(v_sliced_slot), + out=P.slot_to_tid(out_slot), + scale=scale, + mask=( + P.slot_to_tid(P.slot_map([self.attn_mask])[0]) + if self.attn_mask is not None + else None + ), + causal=self.is_causal, + ) + ) + + return out_slot + + +@REGISTRY.register_pattern(name="QUANTIZED_LINEAR") +class QuantizedLinearHandler(PatternHandler): + """ + Pattern for quantized linear: dequantize_affine + linear. + """ + + def __init__( + self, + head: Node, + body: List[Node], + qdata: Node, + scale: Node, + zero_point: Node, + group_size: int, + bits: int, + out_dtype: torch.dtype, + ): + super().__init__(head, body) + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.group_size = group_size + self.bits = bits + self.out_dtype = out_dtype + + @classmethod + def maybe_create( + cls, ep: ExportedProgram, head: Node + ) -> Optional["QuantizedLinearHandler"]: + linear_node = head + if not match_target(linear_node, torch.ops.aten.linear.default): + return None + + x, w = linear_node.args[0:2] + dequant_node = w + if not match_target(dequant_node, torch.ops.torchao.dequantize_affine.default): + return None + if not has_single_user(dequant_node): + return None + + parsed = parse_dequant_node(dequant_node) + if parsed is None: + return None + qdata, scale, zero_point, group_size, bits, out_dtype, _quantized_dim = parsed + out_dtype = x.meta["val"].dtype if out_dtype is None else out_dtype + + head = linear_node + body = [dequant_node] + return QuantizedLinearHandler( + head, + body, + qdata=qdata, + scale=scale, + zero_point=zero_point, + group_size=group_size, + bits=bits, + out_dtype=out_dtype, + ) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + + x_node, w_node = n.args[0:2] + b_node = n.args[2] if len(n.args) > 2 else None + + qdata_target, qdata = P.get_placeholder_target_and_tensor(self.qdata) + zero_point_target, zero_point = P.get_placeholder_target_and_tensor( + self.zero_point + ) + _, scale = P.get_placeholder_target_and_tensor(self.scale) + + x_slot, scale_slot, b_slot = P.slot_map([x_node, self.scale, b_node]) + + Q, B = to_mlx_qparams(qdata, scale, zero_point, self.bits) + w = P.make_or_get_constant(f"{qdata_target}_to_packed", Q) + biases = emit_quantized_biases( + P, zero_point_target, scale, zero_point, self.bits, B, scale_slot + ) + + out = P.make_or_get_slot(n) + has_bias = b_node is not None + x_dtype = x_node.meta["val"].dtype + needs_cast = self.out_dtype != x_dtype + + P.emit( + QuantizedMatmulNode( + x=P.slot_to_tid(x_slot), + w=P.slot_to_tid(w), + 
scales=P.slot_to_tid(scale_slot), + out=P.slot_to_tid(out), + biases=P.slot_to_tid(biases), + group_size=self.group_size, + bits=self.bits, + mode="affine", + transpose=True, + ) + ) + + if has_bias: + P.emit( + AddNode( + a=P.slot_to_tid(out), + b=P.slot_to_tid(b_slot), + out=P.slot_to_tid(out), + ) + ) + + if needs_cast: + P.emit( + AsTypeNode( + x=P.slot_to_tid(out), + out=P.slot_to_tid(out), + scalar_type=torch_dtype_to_scalar_type(self.out_dtype), + ) + ) + + return out + + +@REGISTRY.register_pattern(name="QUANTIZED_EMBEDDING") +class QuantizedEmbeddingHandler(PatternHandler): + """ + Pattern for quantized embedding: dequantize_affine + embedding. + """ + + def __init__( + self, + head: Node, + body: List[Node], + qdata: Node, + scale: Node, + zero_point: Node, + group_size: int, + bits: int, + out_dtype: torch.dtype, + ): + super().__init__(head, body) + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.group_size = group_size + self.bits = bits + self.out_dtype = out_dtype + + @classmethod + def maybe_create( + cls, ep: ExportedProgram, head: Node + ) -> Optional["QuantizedEmbeddingHandler"]: + embedding_node = head + if not match_target(embedding_node, torch.ops.aten.embedding.default): + return None + + w, x = embedding_node.args[0:2] + + dequant_node = w + if not match_target(dequant_node, torch.ops.torchao.dequantize_affine.default): + return None + if not has_single_user(dequant_node): + return None + + parsed = parse_dequant_node(dequant_node) + if parsed is None: + return None + qdata, scale, zero_point, group_size, bits, out_dtype, _quantized_dim = parsed + out_dtype = scale.meta["val"].dtype if out_dtype is None else out_dtype + + head = embedding_node + body = [dequant_node] + return QuantizedEmbeddingHandler( + head, + body, + qdata=qdata, + scale=scale, + zero_point=zero_point, + group_size=group_size, + bits=bits, + out_dtype=out_dtype, + ) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + w, x = n.args[0:2] + + qdata_target, qdata = P.get_placeholder_target_and_tensor(self.qdata) + zero_point_target, zero_point = P.get_placeholder_target_and_tensor( + self.zero_point + ) + _, scale = P.get_placeholder_target_and_tensor(self.scale) + + Q, B = to_mlx_qparams(qdata, scale, zero_point, self.bits) + out_scalar_type = torch_dtype_to_scalar_type(self.out_dtype) + + w = P.make_or_get_constant(f"{qdata_target}_to_packed", Q) + + x, scale_slot = P.slot_map([x, self.scale]) + biases = emit_quantized_biases( + P, zero_point_target, scale, zero_point, self.bits, B, scale_slot + ) + ids_index = IntOrVidOrTid.from_tid(P.slot_to_tid(x)) + + # Gather quantized weights by ids + _, wq_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(w), + index=ids_index, + out=P.slot_to_tid(wq_sel), + axis=0, + ) + ) + + # Gather scales by ids + _, sc_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(scale_slot), + index=ids_index, + out=P.slot_to_tid(sc_sel), + axis=0, + ) + ) + + # Gather biases by ids + _, b_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(biases), + index=ids_index, + out=P.slot_to_tid(b_sel), + axis=0, + ) + ) + + # Dequantize the gathered slices + out = P.make_or_get_slot(n) + P.emit( + DequantizeNode( + w=P.slot_to_tid(wq_sel), + scales=P.slot_to_tid(sc_sel), + out=P.slot_to_tid(out), + biases=P.slot_to_tid(b_sel), + group_size=self.group_size, + bits=self.bits, + mode="affine", + dtype=out_scalar_type, + ) + ) + return out + + 
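+# Illustrative shapes for the QUANTIZED_EMBEDDING lowering above (assuming a
+# [V, D] table quantized to 4 bits with group_size=32):
+#   packed qdata (uint32): [V, D * 4 / 32]
+#   scales / biases:       [V, D / 32]
+# TakeNode(axis=0) gathers the rows for the requested ids, and DequantizeNode
+# reconstructs each group as w ~= scales * q + biases, the affine relation
+# produced by to_mlx_qparams.
+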
+@REGISTRY.register_pattern(name="NVFP4_QUANTIZED_LINEAR") +class NVFP4QuantizedLinearHandler(PatternHandler): + """Fuse dequantize_nvfp4 + linear into QuantizedMatmulNode(mode="nvfp4"). + + Matches: + linear(x, dequantize_nvfp4(qdata, scale, block_size, [per_tensor_scale]), bias) + + Emits: + QuantizedMatmulNode [→ MultiplyNode(per_tensor_scale)] [→ AddNode(bias)] + """ + + def __init__(self, head, body, qdata, scale, per_tensor_scale, output_dtype): + super().__init__(head, body) + self.qdata = qdata + self.scale = scale + self.per_tensor_scale = per_tensor_scale + self.output_dtype = output_dtype + + @classmethod + def maybe_create(cls, ep, head): + if not match_target(head, torch.ops.aten.linear.default): + return None + x, dequant = head.args[0:2] + if not isinstance(dequant, Node): + return None + if not has_single_user(dequant): + return None + parsed = parse_dequant_nvfp4_node(dequant) + if parsed is None: + return None + qdata, scale, per_tensor_scale, output_dtype = parsed + return cls(head, [dequant], qdata, scale, per_tensor_scale, output_dtype) + + def __call__(self, P, n): + assert n == self.head + + x_node, w_node = n.args[0:2] + b_node = n.args[2] if len(n.args) > 2 else None + + needs_cast = x_node.meta["val"].dtype != self.output_dtype + has_bias = b_node is not None + has_per_tensor_scale = True + + _, per_tensor_scale_value = P.get_placeholder_target_and_tensor( + self.per_tensor_scale + ) + from torch._subclasses.fake_tensor import FakeTensor + + if not isinstance(per_tensor_scale_value, FakeTensor): + if per_tensor_scale_value.item() == 1.0: + has_per_tensor_scale = False + + x, w, scales, bias, per_tensor_scale = P.slot_map( + [x_node, self.qdata, self.scale, b_node, self.per_tensor_scale] + ) + + out = P.make_or_get_slot(n) + P.emit( + QuantizedMatmulNode( + x=P.slot_to_tid(x), + w=P.slot_to_tid(w), + scales=P.slot_to_tid(scales), + out=P.slot_to_tid(out), + biases=None, + group_size=16, + bits=4, + mode="nvfp4", + transpose=True, + ) + ) + + if has_per_tensor_scale: + P.emit( + MultiplyNode( + a=P.slot_to_tid(out), + b=P.slot_to_tid(per_tensor_scale), + out=P.slot_to_tid(out), + ) + ) + + if has_bias: + P.emit( + AddNode( + a=P.slot_to_tid(out), + b=P.slot_to_tid(bias), + out=P.slot_to_tid(out), + ) + ) + + if needs_cast: + P.emit( + AsTypeNode( + x=P.slot_to_tid(out), + out=P.slot_to_tid(out), + scalar_type=torch_dtype_to_scalar_type(self.output_dtype), + ) + ) + + return out diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index bfd593c162b..a351fcfb619 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -96,28 +96,1449 @@ inline std::vector infer_shape_with_minus_one( return resolved_shape; } +// Formula: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))) +inline array gelu_tanh_impl(const array& x, StreamOrDevice s = {}) { + constexpr float sqrt_2_over_pi = 0.7978845608f; + auto dtype = x.dtype(); + + auto x3 = multiply(x, multiply(x, x, s), s); + auto term = multiply(array(0.044715f, dtype), x3, s); + auto inner = add(x, term, s); + inner = multiply(array(sqrt_2_over_pi, dtype), inner, s); + auto tanh_val = tanh(inner, s); + auto one_plus_tanh = add(array(1.0f, dtype), tanh_val, s); + auto out = multiply(x, one_plus_tanh, s); + out = multiply(array(0.5f, dtype), out, s); + return out; +} + +// Formula: 0.5 * x * (1 + erf(x / sqrt(2))) +inline array gelu_none_impl(const array& x, StreamOrDevice s = {}) { + constexpr float inv_sqrt_2 = 0.7071067812f; + auto dtype = 
x.dtype(); + + auto scaled = multiply(array(inv_sqrt_2, dtype), x, s); + auto erf_val = erf(scaled, s); + auto one_plus_erf = add(array(1.0f, dtype), erf_val, s); + auto out = multiply(x, one_plus_erf, s); + out = multiply(array(0.5f, dtype), out, s); + return out; +} + inline void exec_noop(const NoopNode&, ExecutionState&, StreamOrDevice) {} inline void -exec_id_copy(const IdCopyNode& n, ExecutionState& st, StreamOrDevice) { - st.set_tensor(n.out, st.const_tensor_ref(n.x)); +exec_id_copy(const IdCopyNode& n, ExecutionState& st, StreamOrDevice) { + st.set_tensor(n.out, st.const_tensor_ref(n.x)); +} + +inline void +exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& mat1 = st.const_tensor_ref(n.mat1); + const auto& mat2 = st.const_tensor_ref(n.mat2); + + array Y = n.bias ? addmm( + st.const_tensor_ref(*n.bias), + mat1, + mat2, + /*alpha=*/n.alpha, + /*beta=*/n.beta, + s) + : matmul(mat1, mat2, s); + + st.set_tensor(n.out, std::move(Y)); +} + +inline void +exec_item_int(const ItemIntNode& n, ExecutionState& st, StreamOrDevice) { + // Intentional sync: item() requires a concrete scalar value for SymInt + // shape computation, so we must force GPU evaluation here. + auto x = st.const_tensor_ref(n.x); + eval(x); + int item = x.item(); + st.set_value(n.out, item); +} + +inline void exec_expand_dims( + const ExpandDimsNode& n, + ExecutionState& st, + StreamOrDevice s) { + st.set_tensor(n.out, expand_dims(st.const_tensor_ref(n.x), n.axis, s)); +} + +inline void exec_tile(const TileNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + auto reps = resolve_ints(n.reps, st); + st.set_tensor(n.out, tile(x, reps, s)); +} + +inline void exec_take_along_axis( + const TakeAlongAxisNode& n, + ExecutionState& st, + StreamOrDevice s) { + st.set_tensor( + n.out, + take_along_axis( + st.const_tensor_ref(n.x), st.const_tensor_ref(n.indices), n.axis, s)); +} + +inline void exec_take(const TakeNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + int axis = normalize_axis(n.axis, static_cast(x.ndim()), "Take"); + switch (n.index.kind) { + case 0: { // literal int + int index = normalize_axis( + clamp_to_int32(n.index.literal), x.shape(axis), "Take"); + st.set_tensor(n.out, take(x, index, axis, s)); + break; + } + case 1: { // Vid (dynamic int) + int index = normalize_axis( + st.const_value_ref(n.index.vid), x.shape(axis), "Take"); + st.set_tensor(n.out, take(x, index, axis, s)); + break; + } + case 2: { // Tid (tensor of indices) + const auto& indices = st.const_tensor_ref(n.index.tid); + st.set_tensor(n.out, take(x, indices, axis, s)); + break; + } + default: + throw std::runtime_error( + "TakeNode: invalid index kind: " + std::to_string(n.index.kind)); + } +} + +inline void +exec_rms_norm(const RMSNormNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::optional w = std::nullopt; + if (n.weight) { + w = st.const_tensor_ref(*n.weight); + } + st.set_tensor(n.out, fast::rms_norm(x, w, n.eps, s)); +} + +inline void +exec_layer_norm(const LayerNormNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + + std::optional w = std::nullopt; + if (n.weight) { + w = st.const_tensor_ref(*n.weight); + } + std::optional bias = std::nullopt; + if (n.bias) { + bias = st.const_tensor_ref(*n.bias); + } + st.set_tensor(n.out, fast::layer_norm(x, w, bias, n.eps, s)); +} + +inline void exec_rope(const RopeNode& n, ExecutionState& st, 
StreamOrDevice s) { + const array& x = st.const_tensor_ref(n.x); + + std::optional freqs_arr = std::nullopt; + if (n.freqs) { + freqs_arr = st.const_tensor_ref(*n.freqs); + } + + // MLX has two overloads: rope(..., int offset, ...) and rope(..., const + // array& offset, ...) Call the appropriate one based on is_vid + if (n.offset.is_vid) { + // Scalar offset from Vid + int offset = st.const_value_ref(n.offset.vid); + st.set_tensor( + n.out, + fast::rope( + x, n.dims, n.traditional, n.base, n.scale, offset, freqs_arr, s)); + } else { + // Tensor offset from Tid + const array& offset = st.const_tensor_ref(n.offset.tid); + st.set_tensor( + n.out, + fast::rope( + x, n.dims, n.traditional, n.base, n.scale, offset, freqs_arr, s)); + } +} + +inline void exec_sdpa(const SdpaNode& n, ExecutionState& st, StreamOrDevice s) { + array Q = st.const_tensor_ref(n.q); + array K = st.const_tensor_ref(n.k); + array V = st.const_tensor_ref(n.v); + + std::string mask_mode = ""; + std::optional mask_arr = std::nullopt; + std::optional sinks = std::nullopt; + + if (n.mask) { + array M = st.const_tensor_ref(*n.mask); + // MLX's SDPA handles bool masks natively (True=attend, False=masked) + // For non-bool masks, ensure dtype matches Q + if (M.dtype() != bool_ && M.dtype() != Q.dtype()) { + M = astype(M, Q.dtype(), s); + } + mask_arr = std::move(M); + } + if (n.causal) { + mask_mode = "causal"; + } + + array out = fast::scaled_dot_product_attention( + Q, K, V, static_cast(n.scale), mask_mode, mask_arr, sinks, s); + st.set_tensor(n.out, std::move(out)); +} + +inline void exec_add(const AddNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, add(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_add_int(const AddIntNode& n, ExecutionState& st, StreamOrDevice) { + int64_t a = resolve_int(n.a, st); + int64_t b = resolve_int(n.b, st); + int64_t result = a + b; + if (result > std::numeric_limits::max() || + result < std::numeric_limits::min()) { + throw std::runtime_error("add_int: overflow"); + } + st.set_value(n.out, static_cast(result)); +} + +inline void exec_subtract_int( + const SubtractIntNode& n, + ExecutionState& st, + StreamOrDevice) { + int64_t a = resolve_int(n.a, st); + int64_t b = resolve_int(n.b, st); + int64_t result = a - b; + if (result > std::numeric_limits::max() || + result < std::numeric_limits::min()) { + throw std::runtime_error("subtract_int: overflow"); + } + st.set_value(n.out, static_cast(result)); +} + +inline void exec_multiply_int( + const MultiplyIntNode& n, + ExecutionState& st, + StreamOrDevice) { + int64_t a = resolve_int(n.a, st); + int64_t b = resolve_int(n.b, st); + int64_t result = a * b; + if (result > std::numeric_limits::max() || + result < std::numeric_limits::min()) { + throw std::runtime_error("multiply_int: overflow"); + } + st.set_value(n.out, static_cast(result)); +} + +inline void exec_floor_divide_int( + const FloorDivideIntNode& n, + ExecutionState& st, + StreamOrDevice) { + int32_t a = resolve_int(n.a, st); + int32_t b = resolve_int(n.b, st); + if (b == 0) { + throw std::runtime_error("floor_divide_int: division by zero"); + } + if (a == std::numeric_limits::min() && b == -1) { + throw std::runtime_error("floor_divide_int: overflow (INT32_MIN / -1)"); + } + // Floor division for integers (Python semantics: rounds towards negative + // infinity) + int32_t result = a / b; + // Adjust for floor division when signs differ and there's a remainder + if ((a % b != 0) && ((a < 0) != (b < 0))) { + result -= 1; + } + 
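+  // Illustrative: a = -7, b = 2 -> truncated quotient -3, remainder -1,
+  // signs differ, so result becomes -4 (matching Python's -7 // 2).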
st.set_value(n.out, result); +} + +inline void +exec_mod_int(const ModIntNode& n, ExecutionState& st, StreamOrDevice) { + int32_t a = resolve_int(n.a, st); + int32_t b = resolve_int(n.b, st); + if (b == 0) { + throw std::runtime_error("mod_int: division by zero"); + } + // Python modulo semantics: result has same sign as divisor + int32_t result = a % b; + if ((result != 0) && ((result < 0) != (b < 0))) { + result += b; + } + st.set_value(n.out, result); +} + +inline void +exec_sym_size(const SymSizeNode& n, ExecutionState& st, StreamOrDevice) { + const array& a = st.const_tensor_ref(n.a); + int rank = static_cast(a.ndim()); + int dim = n.dim; + if (dim < 0) { + dim += rank; + } + if (dim < 0 || dim >= rank) { + throw std::out_of_range("SYM_SIZE: dim out of range"); + } + int32_t size = static_cast(a.shape()[static_cast(dim)]); + st.set_value(n.out, size); +} + +inline void +exec_multiply(const MultiplyNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, multiply(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_divide(const DivideNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, divide(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_subtract(const SubtractNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, subtract(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_conv1d(const Conv1DNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const auto& w = st.const_tensor_ref(n.w); + auto out = conv1d(x, w, n.stride, n.padding, n.dilation, n.groups, s); + st.set_tensor(n.out, std::move(out)); +} + +inline void +exec_conv2d(const Conv2DNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const auto& w = st.const_tensor_ref(n.w); + + std::pair stride = {n.stride_h, n.stride_w}; + std::pair padding = {n.padding_h, n.padding_w}; + std::pair dilation = {n.dilation_h, n.dilation_w}; + + auto out = conv2d(x, w, stride, padding, dilation, n.groups, s); + st.set_tensor(n.out, std::move(out)); +} + +inline void +exec_conv3d(const Conv3DNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const auto& w = st.const_tensor_ref(n.w); + + std::tuple stride = {n.stride_d, n.stride_h, n.stride_w}; + std::tuple padding = {n.padding_d, n.padding_h, n.padding_w}; + std::tuple dilation = { + n.dilation_d, n.dilation_h, n.dilation_w}; + + auto out = conv3d(x, w, stride, padding, dilation, n.groups, s); + st.set_tensor(n.out, std::move(out)); +} + +inline void exec_conv_transpose1d( + const ConvTranspose1DNode& n, + ExecutionState& st, + StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const auto& w = st.const_tensor_ref(n.w); + auto out = conv_transpose1d( + x, w, n.stride, n.padding, n.dilation, n.output_padding, n.groups, s); + st.set_tensor(n.out, std::move(out)); +} + +inline void exec_conv_transpose2d( + const ConvTranspose2DNode& n, + ExecutionState& st, + StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const auto& w = st.const_tensor_ref(n.w); + + std::pair stride = {n.stride_h, n.stride_w}; + std::pair padding = {n.padding_h, n.padding_w}; + std::pair dilation = {n.dilation_h, n.dilation_w}; + std::pair output_padding = {n.output_padding_h, n.output_padding_w}; + + auto out = conv_transpose2d( + x, w, stride, padding, dilation, output_padding, n.groups, s); + st.set_tensor(n.out, 
std::move(out)); +} + +inline void exec_conv_transpose3d( + const ConvTranspose3DNode& n, + ExecutionState& st, + StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const auto& w = st.const_tensor_ref(n.w); + + std::tuple stride = {n.stride_d, n.stride_h, n.stride_w}; + std::tuple padding = {n.padding_d, n.padding_h, n.padding_w}; + std::tuple dilation = { + n.dilation_d, n.dilation_h, n.dilation_w}; + std::tuple output_padding = { + n.output_padding_d, n.output_padding_h, n.output_padding_w}; + + auto out = conv_transpose3d( + x, w, stride, padding, dilation, output_padding, n.groups, s); + st.set_tensor(n.out, std::move(out)); +} + +inline void exec_gelu(const GeluNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + if (n.approximate == "tanh") { + st.set_tensor(n.out, gelu_tanh_impl(x, s)); + } else { + // "none" or any other value uses exact GELU + st.set_tensor(n.out, gelu_none_impl(x, s)); + } +} + +inline void +exec_arange(const ARangeNode& n, ExecutionState& st, StreamOrDevice s) { + // Get start, stop, step - may be literal int64 or dynamic Vid + int start_val = resolve_int(n.start, st); + int stop_val = resolve_int(n.stop, st); + int step_val = resolve_int(n.step, st); + + if (step_val == 0) { + throw std::runtime_error("arange: step must not be zero"); + } + + // Bound the output size: numel = ceil((stop - start) / step) + int64_t range = static_cast(stop_val) - start_val; + int64_t numel = 0; + if ((range > 0 && step_val > 0) || (range < 0 && step_val < 0)) { + numel = (range / step_val) + (range % step_val != 0 ? 1 : 0); + } + auto dtype = n.scalar_type.has_value() ? resolve_dtype(n.scalar_type.value()) + : ::mlx::core::int32; + check_allocation_bounded( + {static_cast(std::min( + numel, static_cast(std::numeric_limits::max())))}, + dtype, + "arange"); + + if (n.scalar_type.has_value()) { + st.set_tensor(n.out, arange(start_val, stop_val, step_val, dtype, s)); + } else { + // No dtype specified - use MLX's default (infers from inputs). + // The bounds check above conservatively assumes int32. + st.set_tensor(n.out, arange(start_val, stop_val, step_val, s)); + } +} + +inline void exec_silu(const SiluNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, multiply(x, sigmoid(x, s), s)); +} + +inline void +exec_sigmoid(const SigmoidNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, sigmoid(x, s)); +} + +inline void exec_tanh(const TanhNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, tanh(x, s)); +} + +inline void +exec_squeeze(const SqueezeNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const auto& dims_fb = n.dims; + + if (dims_fb.size() > 0) { + // Squeeze specific dimensions, filtering out non-size-1 dims to match + // PyTorch semantics where squeeze on a non-size-1 dim is a no-op. + std::vector dims; + for (auto d : dims_fb) { + int axis = d < 0 ? 
d + static_cast(x.ndim()) : d; + if (axis >= 0 && axis < static_cast(x.ndim()) && + x.shape(axis) == 1) { + dims.push_back(d); + } + } + if (dims.size() > 0) { + st.set_tensor(n.out, squeeze(x, dims, s)); + } else { + st.set_tensor(n.out, x); + } + } else { + // Squeeze all dimensions of size 1 + st.set_tensor(n.out, squeeze(x, s)); + } +} + +inline void +exec_split(const SplitNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + + // Resolve dynamic sizes to std::vector + std::vector sizes_vec = resolve_ints(n.sizes, st); + + // Get results based on split mode + auto outs_fb = n.outs; + + if (sizes_vec.size() == 1) { + // Single value means split_size (chunk size) + // Compute actual sizes: e.g., dim_size=10, split_size=3 -> [3, 3, 3, 1] + int split_size = sizes_vec[0]; + if (split_size <= 0) { + throw std::runtime_error( + "split: split_size must be positive, got " + + std::to_string(split_size)); + } + int axis = n.axis < 0 ? n.axis + static_cast(x.ndim()) : n.axis; + int dim_size = x.shape(axis); + + std::vector indices; + for (int pos = split_size; pos < dim_size; pos += split_size) { + indices.push_back(pos); + } + + auto results = split(x, to_shape(indices), n.axis, s); + if (results.size() != outs_fb.size()) { + throw std::runtime_error("Split: output count mismatch"); + } + for (size_t i = 0; i < results.size(); ++i) { + st.set_tensor(outs_fb[i], std::move(results[i])); + } + } else { + // Multiple sizes: convert to cumulative indices for MLX + // sizes=[10, 20, 30] -> indices=[10, 30] (split at positions 10 and 30) + std::vector indices; + indices.reserve(sizes_vec.size() - 1); + int64_t cumsum = 0; + for (size_t i = 0; i < sizes_vec.size() - 1; ++i) { + cumsum += static_cast(sizes_vec[i]); + if (cumsum > std::numeric_limits::max() || + cumsum < std::numeric_limits::min()) { + throw std::runtime_error("split: cumulative size overflow"); + } + indices.push_back(static_cast(cumsum)); + } + auto results = split(x, to_shape(indices), n.axis, s); + if (results.size() != outs_fb.size()) { + throw std::runtime_error("Split: output count mismatch"); + } + for (size_t i = 0; i < results.size(); ++i) { + st.set_tensor(outs_fb[i], std::move(results[i])); + } + } +} + +inline void +exec_rsqrt(const RsqrtNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, rsqrt(x, s)); +} + +inline void +exec_maximum(const MaximumNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, maximum(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_minimum(const MinimumNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, minimum(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void exec_log(const LogNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, log(x, s)); +} + +inline void +exec_softmax(const SoftmaxNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, softmax(x, n.axis, n.precise, s)); +} + +inline void exec_broadcast_to( + const BroadcastToNode& n, + ExecutionState& st, + StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + auto shape_vec = resolve_ints(n.shape, st); + + // Replace -1 with actual input dimensions (PyTorch expand semantics: + // -1 means "keep this dimension unchanged from input"). + // Dimensions are aligned from the RIGHT (broadcast semantics). 
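+  // Worked example of the -1 resolution below:
+  //   x.shape = [4, 1], requested shape = [3, -1, 5]
+  //   offset = 3 - 2 = 1, so shape_vec[1] == -1 maps to input dim 0 (size 4)
+  //   resolved target = [3, 4, 5]; the trailing size-1 input dim broadcasts to 5.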
+ const auto& x_shape = x.shape(); + int offset = + static_cast(shape_vec.size()) - static_cast(x_shape.size()); + for (size_t i = 0; i < shape_vec.size(); i++) { + if (shape_vec[i] == -1) { + int input_dim = static_cast(i) - offset; + if (input_dim >= 0 && input_dim < static_cast(x_shape.size())) { + shape_vec[i] = + static_cast(x_shape[static_cast(input_dim)]); + } + } + } + + st.set_tensor( + n.out, + broadcast_to( + x, ::mlx::core::Shape(shape_vec.begin(), shape_vec.end()), s)); +} + +inline void exec_pad(const PadNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + + // Convert flat pad_width to vector of pairs + std::vector> pad_width_pairs; + auto pad_width_resolved = resolve_ints(n.pad_width, st); + if (pad_width_resolved.size() % 2 != 0) { + throw std::runtime_error( + "pad: pad_width must have even length, got " + + std::to_string(pad_width_resolved.size())); + } + for (size_t i = 0; i < pad_width_resolved.size(); i += 2) { + pad_width_pairs.push_back( + {pad_width_resolved[i], pad_width_resolved[i + 1]}); + } + + // MLX pad signature: pad(array, pad_width, pad_value, mode, stream) + if (n.mode == "constant") { + array pad_value(n.constant_value); + st.set_tensor(n.out, pad(x, pad_width_pairs, pad_value, "constant", s)); + } else if (n.mode == "edge") { + array pad_value(0.0f); + st.set_tensor(n.out, pad(x, pad_width_pairs, pad_value, "edge", s)); + } else { + throw std::runtime_error("Unsupported pad mode: " + n.mode); + } +} + +inline void +exec_where(const WhereNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& condition = st.const_tensor_ref(n.condition); + const auto& x = st.const_tensor_ref(n.x); + const auto& y = st.const_tensor_ref(n.y); + st.set_tensor(n.out, where(condition, x, y, s)); +} + +inline void +exec_reshape(const ReshapeNode& n, ExecutionState& st, StreamOrDevice s) { + auto new_shape = to_shape(n.shape, st); + st.set_tensor(n.out, reshape(st.const_tensor_ref(n.x), new_shape, s)); +} + +inline void +exec_transpose(const TransposeNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, transpose(st.const_tensor_ref(n.x), n.perm, s)); +} + +inline void +exec_as_strided(const AsStridedNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + auto shape = to_shape(n.shape, st); + auto resolved_strides = resolve_ints(n.strides, st); + Strides strides(resolved_strides.begin(), resolved_strides.end()); + st.set_tensor(n.out, as_strided(x, shape, strides, n.offset, s)); +} + +inline void +exec_contiguous(const ContiguousNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, contiguous(st.const_tensor_ref(n.x), false, s)); +} + +inline void +exec_gather(const GatherNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + const int rank = static_cast(x.ndim()); + + if (n.indices.size() != n.axes.size()) { + throw std::runtime_error( + "GatherNode: indices count (" + std::to_string(n.indices.size()) + + ") must match axes count (" + std::to_string(n.axes.size()) + ")"); + } + + if (static_cast(n.slice_sizes.size()) != rank) { + throw std::runtime_error( + "GatherNode: slice_sizes length (" + + std::to_string(n.slice_sizes.size()) + ") must match input ndim (" + + std::to_string(rank) + ")"); + } + + for (auto axis : n.axes) { + if (axis < 0 || axis >= rank) { + throw std::runtime_error( + "GatherNode: axis " + std::to_string(axis) + + " out of range for input with ndim " + std::to_string(rank)); + } + } + + Shape 
slice_sizes(n.slice_sizes.begin(), n.slice_sizes.end()); + check_allocation_bounded(slice_sizes, x.dtype(), "gather"); + + std::vector indices; + indices.reserve(n.indices.size()); + for (auto tid : n.indices) { + indices.push_back(st.const_tensor_ref(tid)); + } + + st.set_tensor(n.out, gather(x, indices, n.axes, slice_sizes, s)); +} + +inline void +exec_slice(const SliceNode& n, ExecutionState& st, StreamOrDevice s) { + const array& x = st.const_tensor_ref(n.x); + const int rank = static_cast(x.ndim()); + + int axis = normalize_axis(resolve_int(n.axis, st), rank, "Slice"); + int start = resolve_int(n.start, st); + int stop = resolve_int(n.stop, st); + + std::vector vstart(static_cast(rank), 0); + std::vector vstop; + vstop.reserve(static_cast(rank)); + auto sh = x.shape(); + for (size_t i = 0; i < static_cast(rank); ++i) { + vstop.push_back(static_cast(sh[i])); + } + + if (n.step == 0) { + throw std::invalid_argument("Slice: step must not be 0"); + } + + vstart[static_cast(axis)] = start; + vstop[static_cast(axis)] = stop; + + std::vector vstrides(static_cast(rank), 1); + vstrides[static_cast(axis)] = n.step; + st.set_tensor( + n.out, + slice(x, to_shape(vstart), to_shape(vstop), to_shape(vstrides), s)); +} + +inline void +exec_astype(const AsTypeNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, astype(st.const_tensor_ref(n.x), resolve_dtype(n.scalar_type), s)); +} + +inline void exec_quantized_matmul( + const QuantizedMatmulNode& n, + ExecutionState& st, + StreamOrDevice s) { + array X = st.const_tensor_ref(n.x); + array Wq = st.const_tensor_ref(n.w); + array Sc = st.const_tensor_ref(n.scales); + + std::optional Qb = std::nullopt; + if (n.biases.has_value()) { + Qb = st.const_tensor_ref(*n.biases); + } + + array Y = quantized_matmul( + X, Wq, Sc, Qb, n.transpose, n.group_size, n.bits, n.mode, s); + + st.set_tensor(n.out, std::move(Y)); +} + +inline void exec_concatenate( + const ConcatenateNode& n, + ExecutionState& st, + StreamOrDevice s) { + auto tensors_fb = n.tensors; + std::vector tensors; + for (auto tid : tensors_fb) { + tensors.push_back(st.const_tensor_ref(tid)); + } + st.set_tensor(n.out, concatenate(tensors, n.axis, s)); +} + +inline void exec_full(const FullNode& n, ExecutionState& st, StreamOrDevice s) { + auto shape = to_shape(n.shape, st); + auto dtype = resolve_dtype(n.scalar_type); + check_allocation_bounded(shape, dtype, "full"); + st.set_tensor(n.out, full(shape, resolve_float(n.v, st), dtype, s)); +} + +inline void +exec_full_like(const FullLikeNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + // Use input dtype if not specified + auto dtype = n.scalar_type.has_value() ? resolve_dtype(n.scalar_type.value()) + : x.dtype(); + st.set_tensor(n.out, full_like(x, resolve_float(n.v, st), dtype, s)); +} + +inline void exec_slice_update( + const SliceUpdateNode& n, + ExecutionState& st, + StreamOrDevice s) { + // When out == dst, use direct assignment to preserve MLX buffer donation. + // TODO: I'm not sure if this is needed as a special case since the standard + // st.set_tensor does a std::move. Keeping for now, but should investigate and + // possibly remove in future. 
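+  // Worked example of the start/stop normalization below:
+  //   dst has size 8 along `axis`, start = -3, stop = 100, step = 1
+  //   -> start wraps to 5, stop clamps to 8, so dst[..., 5:8, ...] = update.
+  //   If start == stop after clamping, the update is a no-op.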
+ const bool in_place = (n.out.idx == n.dst.idx); + array& dst = st.tensor_ref(n.dst); + const array& upd = st.const_tensor_ref(n.update); + + const int rank = static_cast(dst.ndim()); + + int axis = normalize_axis(resolve_int(n.axis, st), rank, "SliceUpdate"); + int start = resolve_int(n.start, st); + int stop = resolve_int(n.stop, st); + + std::vector vstart(static_cast(rank), 0); + std::vector vstop; + vstop.reserve(static_cast(rank)); + auto sh = dst.shape(); + for (size_t i = 0; i < static_cast(rank); ++i) { + vstop.push_back(static_cast(sh[i])); + } + + const int dst_dim = vstop[static_cast(axis)]; + + if (start < 0) + start += dst_dim; + start = std::max(0, std::min(start, dst_dim)); + if (stop < 0) + stop += dst_dim; + stop = std::max(0, std::min(stop, dst_dim)); + + vstart[static_cast(axis)] = start; + vstop[static_cast(axis)] = stop; + + std::vector vstrides(static_cast(rank), 1); + if (n.step < 1) { + throw std::invalid_argument( + "SliceUpdate: step must be >= 1, got " + std::to_string(n.step) + ""); + } + vstrides[static_cast(axis)] = n.step; + + if (in_place) { + if (start == stop) { + return; + } + if (n.step == 1) { + dst = slice_update(dst, upd, to_shape(vstart), to_shape(vstop), s); + } else { + dst = slice_update( + dst, upd, to_shape(vstart), to_shape(vstop), to_shape(vstrides), s); + } + + } else { + if (start == stop) { + st.set_tensor(n.out, dst); + return; + } + if (n.step == 1) { + st.set_tensor( + n.out, slice_update(dst, upd, to_shape(vstart), to_shape(vstop), s)); + } else { + st.set_tensor( + n.out, + slice_update( + dst, + upd, + to_shape(vstart), + to_shape(vstop), + to_shape(vstrides), + s)); + } + } +} + +// Helper: finds next contiguous run in indices starting at offset +// Returns (dst_start, dst_stop, upd_start, upd_stop) for the run +// Returns (0, 0, 0, 0) when no more runs +inline std::tuple next_contiguous_run( + const std::vector& indices, + size_t offset) { + if (offset >= indices.size()) + return {0, 0, 0, 0}; + + int dst_start = indices[offset]; + int upd_start = static_cast(offset); + size_t len = 1; + while (offset + len < indices.size() && + len < static_cast(std::numeric_limits::max()) && + indices[offset + len] == dst_start + static_cast(len)) { + ++len; + } + int dst_stop = dst_start + static_cast(len); + int upd_stop = upd_start + static_cast(len); + return {dst_start, dst_stop, upd_start, upd_stop}; +} + +// Copies update tensor into dst at positions specified by 1D indices along axis +// Optimizes into slice_update calls for contiguous runs +inline void +exec_index_copy(const IndexCopyNode& n, ExecutionState& st, StreamOrDevice s) { + array& dst = st.tensor_ref(n.dst); + const array& upd = st.const_tensor_ref(n.update); + const array& indices = st.const_tensor_ref(n.indices); + if (indices.ndim() != 1) { + throw std::invalid_argument("IndexCopyNode: indices must be 1D"); + } + const int rank = static_cast(dst.ndim()); + int axis = normalize_axis(n.axis, rank, "IndexCopyNode"); + const size_t uaxis = static_cast(axis); + const int dst_dim = static_cast(dst.shape()[uaxis]); + + // Get indices as a vector of ints, handling negative indices + // Note: PyTorch uses int64 for indices, so we read as int64_t + eval(indices); // Ensure indices are materialized before accessing data + if (indices.dtype() != ::mlx::core::int64) { + throw std::invalid_argument( + std::string("IndexCopyNode: expected int64 indices, got ") + + ExecutionState::dtype_str(indices.dtype())); + } + std::vector idx_vec(indices.size()); + auto idx_data = indices.data(); + 
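+  // Worked example of the run coalescing performed further below:
+  //   indices = [0, 1, 2, 5, 6] -> two slice_update calls:
+  //   dst[..., 0:3, ...] = upd[..., 0:3, ...] and
+  //   dst[..., 5:7, ...] = upd[..., 3:5, ...]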
for (size_t i = 0; i < indices.size(); ++i) { + int64_t idx = idx_data[i]; + if (idx < 0) { + idx += dst_dim; + } + if (idx < 0 || idx >= dst_dim) { + throw std::out_of_range( + "IndexCopyNode: index " + std::to_string(idx_data[i]) + + " out of range for axis " + std::to_string(axis) + " with size " + + std::to_string(dst_dim)); + } + if (idx > std::numeric_limits::max()) { + throw std::out_of_range( + "IndexCopyNode: index " + std::to_string(idx) + + " exceeds int32 range"); + } + idx_vec[i] = static_cast(idx); + } + + // When out == dst, use direct assignment to preserve MLX buffer donation. + // TODO: I'm not sure if this is needed as a special case since the standard + // st.set_tensor does a std::move. Keeping for now, but should investigate and + // possibly remove in future. + const bool in_place = (n.out.idx == n.dst.idx); + + if (idx_vec.empty()) { + if (!in_place) { + st.set_tensor(n.out, dst); + } + return; + } + + // Build base start/stop vectors for slice_update + const size_t urank = static_cast(rank); + std::vector dst_vstart(urank, 0); + std::vector dst_vstop; + dst_vstop.reserve(urank); + auto sh = dst.shape(); + for (size_t i = 0; i < urank; ++i) { + dst_vstop.push_back(static_cast(sh[i])); + } + + std::vector upd_vstart(urank, 0); + std::vector upd_vstop; + upd_vstop.reserve(urank); + auto upd_sh = upd.shape(); + for (size_t i = 0; i < urank; ++i) { + upd_vstop.push_back(static_cast(upd_sh[i])); + } + + array result = dst; // copy of dst to accumulate into + + // Process contiguous runs + size_t offset = 0; + while (offset < idx_vec.size()) { + auto [dst_start, dst_stop, upd_start, upd_stop] = + next_contiguous_run(idx_vec, offset); + + // Set axis range for dst + dst_vstart[uaxis] = dst_start; + dst_vstop[uaxis] = dst_stop; + + // Set axis range for upd slice + upd_vstart[uaxis] = upd_start; + upd_vstop[uaxis] = upd_stop; + + // Slice update - skip slicing if using entire update tensor + array upd_slice = + (upd_start == 0 && upd_stop == static_cast(upd_sh[uaxis])) + ? 
upd + : slice(upd, to_shape(upd_vstart), to_shape(upd_vstop), s); + + if (in_place) { + dst = slice_update( + dst, upd_slice, to_shape(dst_vstart), to_shape(dst_vstop), s); + } else { + result = slice_update( + result, upd_slice, to_shape(dst_vstart), to_shape(dst_vstop), s); + } + + offset = static_cast(upd_stop); + } + + if (!in_place) { + st.set_tensor(n.out, result); + } +} + +inline void +exec_dequantize(const DequantizeNode& n, ExecutionState& st, StreamOrDevice s) { + array Wq = st.const_tensor_ref(n.w); + array Sc = st.const_tensor_ref(n.scales); + + std::optional Qb = std::nullopt; + if (n.biases) { + Qb = st.const_tensor_ref(*n.biases); + } + + std::optional global_scale = std::nullopt; + if (n.global_scale) { + global_scale = st.const_tensor_ref(*n.global_scale); + } + + std::optional dtype = std::nullopt; + if (n.dtype) { + dtype = resolve_dtype(*n.dtype); + } + + array Y = dequantize( + Wq, Sc, Qb, n.group_size, n.bits, n.mode, global_scale, dtype, s); + + st.set_tensor(n.out, std::move(Y)); +} + +inline void exec_less(const LessNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, less(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_less_equal(const LessEqualNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, less_equal(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_greater(const GreaterNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, greater(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void exec_greater_equal( + const GreaterEqualNode& n, + ExecutionState& st, + StreamOrDevice s) { + st.set_tensor( + n.out, + greater_equal(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_equal(const EqualNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, equal(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_not_equal(const NotEqualNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, not_equal(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void exec_logical_not( + const LogicalNotNode& n, + ExecutionState& st, + StreamOrDevice s) { + st.set_tensor(n.out, logical_not(st.const_tensor_ref(n.x), s)); +} + +inline void exec_logical_and( + const LogicalAndNode& n, + ExecutionState& st, + StreamOrDevice s) { + st.set_tensor( + n.out, + logical_and(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_logical_or(const LogicalOrNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, logical_or(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void exec_tri(const TriNode& n, ExecutionState& st, StreamOrDevice s) { + int rows = resolve_int(n.n, st); + int cols = resolve_int(n.m, st); + auto dtype = resolve_dtype(n.scalar_type); + check_allocation_bounded({rows, cols}, dtype, "tri"); + st.set_tensor(n.out, tri(rows, cols, n.k, dtype, s)); +} + +inline void exec_tril(const TrilNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, tril(x, n.k, s)); +} + +inline void exec_triu(const TriuNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, triu(x, n.k, s)); +} + +inline void +exec_floor(const FloorNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, floor(st.const_tensor_ref(n.x), s)); +} + +inline void 
exec_ceil(const CeilNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, ceil(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_square(const SquareNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, square(st.const_tensor_ref(n.x), s)); +} + +inline void exec_exp(const ExpNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, exp(st.const_tensor_ref(n.x), s)); +} + +inline void exec_sin(const SinNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, sin(st.const_tensor_ref(n.x), s)); +} + +inline void exec_cos(const CosNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, cos(st.const_tensor_ref(n.x), s)); +} + +inline void exec_tan(const TanNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, tan(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_arcsin(const ArcsinNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, arcsin(st.const_tensor_ref(n.x), s)); } inline void -exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) { - const auto& mat1 = st.const_tensor_ref(n.mat1); - const auto& mat2 = st.const_tensor_ref(n.mat2); +exec_arccos(const ArccosNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, arccos(st.const_tensor_ref(n.x), s)); +} - array Y = n.bias ? addmm( - st.const_tensor_ref(*n.bias), - mat1, - mat2, - /*alpha=*/n.alpha, - /*beta=*/n.beta, - s) - : matmul(mat1, mat2, s); +inline void +exec_arctan(const ArctanNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, arctan(st.const_tensor_ref(n.x), s)); +} - st.set_tensor(n.out, std::move(Y)); +inline void exec_sinh(const SinhNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, sinh(st.const_tensor_ref(n.x), s)); +} + +inline void exec_cosh(const CoshNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, cosh(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_arcsinh(const ArcsinhNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, arcsinh(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_arccosh(const ArccoshNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, arccosh(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_arctanh(const ArctanhNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, arctanh(st.const_tensor_ref(n.x), s)); +} + +inline void exec_log2(const Log2Node& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, log2(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_log10(const Log10Node& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, log10(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_log1p(const Log1pNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, log1p(st.const_tensor_ref(n.x), s)); +} + +inline void exec_erf(const ErfNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, erf(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_expm1(const Expm1Node& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, expm1(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_round(const RoundNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, round(st.const_tensor_ref(n.x), n.decimals, s)); +} + +inline void +exec_reciprocal(const ReciprocalNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, reciprocal(st.const_tensor_ref(n.x), s)); +} + +inline void exec_sqrt(const SqrtNode& n, ExecutionState& st, 
StreamOrDevice s) { + st.set_tensor(n.out, sqrt(st.const_tensor_ref(n.x), s)); +} + +inline void exec_abs(const AbsNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, abs(st.const_tensor_ref(n.x), s)); +} + +inline void exec_neg(const NegNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, negative(st.const_tensor_ref(n.x), s)); +} + +inline void +exec_atan2(const Atan2Node& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, arctan2(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_logaddexp(const LogAddExpNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, logaddexp(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void exec_floor_divide( + const FloorDivideNode& n, + ExecutionState& st, + StreamOrDevice s) { + const array& a = st.const_tensor_ref(n.a); + const array& b = st.const_tensor_ref(n.b); + + if (!issubdtype(a.dtype(), inexact)) { + // mlx::floor_divide for integer types uses C++ truncation toward zero, + // but PyTorch floor_divide floors toward negative infinity. + // Adjust: floor_div(a, b) = trunc_div(a, b) - ((a % b != 0) & (sign(a) != + // sign(b))) + auto quot = divide(a, b, s); + auto rem = remainder(a, b, s); + auto zero = array(0, a.dtype()); + auto has_rem = not_equal(rem, zero, s); + auto a_neg = less(a, zero, s); + auto b_neg = less(b, zero, s); + auto signs_differ = not_equal(a_neg, b_neg, s); + auto adjust = logical_and(has_rem, signs_differ, s); + st.set_tensor(n.out, subtract(quot, astype(adjust, a.dtype(), s), s)); + } else { + st.set_tensor(n.out, floor_divide(a, b, s)); + } +} + +inline void +exec_remainder(const RemainderNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, remainder(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_power(const PowerNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor( + n.out, power(st.const_tensor_ref(n.a), st.const_tensor_ref(n.b), s)); +} + +inline void +exec_logsumexp(const LogSumExpNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + st.set_tensor(n.out, logsumexp(x, axes, n.keepdims, s)); +} + +inline void exec_sum(const SumNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, sum(x, n.keepdims, s)); + } else { + st.set_tensor(n.out, sum(x, axes, n.keepdims, s)); + } +} + +inline void exec_mean(const MeanNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, mean(x, n.keepdims, s)); + } else { + st.set_tensor(n.out, mean(x, axes, n.keepdims, s)); + } +} + +inline void exec_var(const VarNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, var(x, n.keepdims, n.ddof, s)); + } else { + st.set_tensor(n.out, var(x, axes, n.keepdims, n.ddof, s)); + } +} + +inline void exec_std(const StdNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, ::mlx::core::std(x, n.keepdims, n.ddof, s)); + } else { + st.set_tensor(n.out, 
::mlx::core::std(x, axes, n.keepdims, n.ddof, s)); + } +} + +inline void exec_prod(const ProdNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, prod(x, n.keepdims, s)); + } else { + st.set_tensor(n.out, prod(x, axes, n.keepdims, s)); + } +} + +inline void exec_max(const MaxNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, max(x, n.keepdims, s)); + } else { + st.set_tensor(n.out, max(x, axes, n.keepdims, s)); + } +} + +inline void exec_min(const MinNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, min(x, n.keepdims, s)); + } else { + st.set_tensor(n.out, min(x, axes, n.keepdims, s)); + } +} + +inline void +exec_argmax(const ArgmaxNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, argmax(x, n.axis, n.keepdims, s)); +} + +inline void +exec_argmin(const ArgminNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, argmin(x, n.axis, n.keepdims, s)); +} + +inline void +exec_median(const MedianNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, median(x, n.keepdims, s)); + } else { + st.set_tensor(n.out, median(x, axes, n.keepdims, s)); + } +} + +inline void exec_clip(const ClipNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::optional a_min = n.a_min + ? std::optional(st.const_tensor_ref(*n.a_min)) + : std::nullopt; + std::optional a_max = n.a_max + ? 
std::optional(st.const_tensor_ref(*n.a_max)) + : std::nullopt; + st.set_tensor(n.out, clip(x, a_min, a_max, s)); +} + +inline void +exec_cumsum(const CumsumNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + st.set_tensor(n.out, cumsum(x, n.axis, n.reverse, n.inclusive, s)); +} + +inline void +exec_stack(const StackNode& n, ExecutionState& st, StreamOrDevice s) { + std::vector tensors; + for (auto tid : n.tensors) { + tensors.push_back(st.const_tensor_ref(tid)); + } + st.set_tensor(n.out, stack(tensors, n.axis, s)); +} + +inline void exec_sign(const SignNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, sign(st.const_tensor_ref(n.x), s)); +} + +inline void exec_any(const AnyNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, any(x, n.keepdims, s)); + } else { + st.set_tensor(n.out, any(x, axes, n.keepdims, s)); + } +} + +inline void exec_all(const AllNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + std::vector axes(n.axes.begin(), n.axes.end()); + if (axes.empty()) { + st.set_tensor(n.out, all(x, n.keepdims, s)); + } else { + st.set_tensor(n.out, all(x, axes, n.keepdims, s)); + } +} + +inline void +exec_repeat(const RepeatNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + int repeats = static_cast(resolve_int(n.repeats, st)); + if (repeats < 0) { + throw std::invalid_argument( + "repeat: repeats must be non-negative, got " + std::to_string(repeats)); + } + auto out_shape = x.shape(); + int axis = n.axis < 0 ? n.axis + static_cast(x.ndim()) : n.axis; + out_shape[static_cast(axis)] *= repeats; + check_allocation_bounded(out_shape, x.dtype(), "repeat"); + st.set_tensor(n.out, repeat(x, repeats, n.axis, s)); +} + +inline void exec_sort(const SortNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, sort(st.const_tensor_ref(n.x), n.axis, s)); +} + +inline void +exec_argsort(const ArgsortNode& n, ExecutionState& st, StreamOrDevice s) { + st.set_tensor(n.out, argsort(st.const_tensor_ref(n.x), n.axis, s)); +} + +inline void +exec_partition(const PartitionNode& n, ExecutionState& st, StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + int kth = static_cast(resolve_int(n.kth, st)); + st.set_tensor(n.out, partition(x, kth, n.axis, s)); +} + +inline void exec_argpartition( + const ArgPartitionNode& n, + ExecutionState& st, + StreamOrDevice s) { + const auto& x = st.const_tensor_ref(n.x); + int kth = static_cast(resolve_int(n.kth, st)); + st.set_tensor(n.out, argpartition(x, kth, n.axis, s)); } } // namespace ops @@ -165,6 +1586,369 @@ class Interpreter { case OpCode::ADDMM: ops::exec_addmm(std::get(instr.node), st, s); break; + case OpCode::ITEM_INT: + ops::exec_item_int(std::get(instr.node), st, s); + break; + case OpCode::EXPAND_DIMS: + ops::exec_expand_dims(std::get(instr.node), st, s); + break; + case OpCode::TILE: + ops::exec_tile(std::get(instr.node), st, s); + break; + case OpCode::TAKE_ALONG_AXIS: + ops::exec_take_along_axis( + std::get(instr.node), st, s); + break; + case OpCode::TAKE: + ops::exec_take(std::get(instr.node), st, s); + break; + case OpCode::RMS_NORM: + ops::exec_rms_norm(std::get(instr.node), st, s); + break; + case OpCode::LAYER_NORM: + ops::exec_layer_norm(std::get(instr.node), st, s); + break; + case OpCode::ROPE: + ops::exec_rope(std::get(instr.node), st, 
s); + break; + case OpCode::SDPA: + ops::exec_sdpa(std::get(instr.node), st, s); + break; + case OpCode::ADD: + ops::exec_add(std::get(instr.node), st, s); + break; + case OpCode::ADD_INT: + ops::exec_add_int(std::get(instr.node), st, s); + break; + case OpCode::SUBTRACT_INT: + ops::exec_subtract_int(std::get(instr.node), st, s); + break; + case OpCode::MULTIPLY_INT: + ops::exec_multiply_int(std::get(instr.node), st, s); + break; + case OpCode::FLOOR_DIVIDE_INT: + ops::exec_floor_divide_int( + std::get(instr.node), st, s); + break; + case OpCode::MOD_INT: + ops::exec_mod_int(std::get(instr.node), st, s); + break; + case OpCode::SYM_SIZE: + ops::exec_sym_size(std::get(instr.node), st, s); + break; + case OpCode::MULTIPLY: + ops::exec_multiply(std::get(instr.node), st, s); + break; + case OpCode::DIVIDE: + ops::exec_divide(std::get(instr.node), st, s); + break; + case OpCode::SUBTRACT: + ops::exec_subtract(std::get(instr.node), st, s); + break; + case OpCode::CONV1D: + ops::exec_conv1d(std::get(instr.node), st, s); + break; + case OpCode::CONV2D: + ops::exec_conv2d(std::get(instr.node), st, s); + break; + case OpCode::CONV3D: + ops::exec_conv3d(std::get(instr.node), st, s); + break; + case OpCode::GELU: + ops::exec_gelu(std::get(instr.node), st, s); + break; + case OpCode::ARANGE: + ops::exec_arange(std::get(instr.node), st, s); + break; + case OpCode::SILU: + ops::exec_silu(std::get(instr.node), st, s); + break; + case OpCode::SIGMOID: + ops::exec_sigmoid(std::get(instr.node), st, s); + break; + case OpCode::TANH: + ops::exec_tanh(std::get(instr.node), st, s); + break; + case OpCode::SQUEEZE: + ops::exec_squeeze(std::get(instr.node), st, s); + break; + case OpCode::SPLIT: + ops::exec_split(std::get(instr.node), st, s); + break; + case OpCode::RSQRT: + ops::exec_rsqrt(std::get(instr.node), st, s); + break; + case OpCode::MAXIMUM: + ops::exec_maximum(std::get(instr.node), st, s); + break; + case OpCode::MINIMUM: + ops::exec_minimum(std::get(instr.node), st, s); + break; + case OpCode::LOG: + ops::exec_log(std::get(instr.node), st, s); + break; + case OpCode::SOFTMAX: + ops::exec_softmax(std::get(instr.node), st, s); + break; + case OpCode::BROADCAST_TO: + ops::exec_broadcast_to(std::get(instr.node), st, s); + break; + case OpCode::PAD: + ops::exec_pad(std::get(instr.node), st, s); + break; + case OpCode::WHERE: + ops::exec_where(std::get(instr.node), st, s); + break; + case OpCode::RESHAPE: + ops::exec_reshape(std::get(instr.node), st, s); + break; + case OpCode::TRANSPOSE: + ops::exec_transpose(std::get(instr.node), st, s); + break; + case OpCode::AS_STRIDED: + ops::exec_as_strided(std::get(instr.node), st, s); + break; + case OpCode::CONTIGUOUS: + ops::exec_contiguous(std::get(instr.node), st, s); + break; + case OpCode::GATHER: + ops::exec_gather(std::get(instr.node), st, s); + break; + case OpCode::SLICE: + ops::exec_slice(std::get(instr.node), st, s); + break; + case OpCode::ASTYPE: + ops::exec_astype(std::get(instr.node), st, s); + break; + case OpCode::CONCATENATE: + ops::exec_concatenate(std::get(instr.node), st, s); + break; + case OpCode::FULL: + ops::exec_full(std::get(instr.node), st, s); + break; + case OpCode::FULL_LIKE: + ops::exec_full_like(std::get(instr.node), st, s); + break; + case OpCode::ARGMAX: + ops::exec_argmax(std::get(instr.node), st, s); + break; + case OpCode::SLICE_UPDATE: + ops::exec_slice_update(std::get(instr.node), st, s); + break; + case OpCode::INDEX_COPY: + ops::exec_index_copy(std::get(instr.node), st, s); + break; + case OpCode::DEQUANTIZE: + 
ops::exec_dequantize(std::get(instr.node), st, s); + break; + case OpCode::LESS: + ops::exec_less(std::get(instr.node), st, s); + break; + case OpCode::LESS_EQUAL: + ops::exec_less_equal(std::get(instr.node), st, s); + break; + case OpCode::GREATER: + ops::exec_greater(std::get(instr.node), st, s); + break; + case OpCode::GREATER_EQUAL: + ops::exec_greater_equal(std::get(instr.node), st, s); + break; + case OpCode::EQUAL: + ops::exec_equal(std::get(instr.node), st, s); + break; + case OpCode::NOT_EQUAL: + ops::exec_not_equal(std::get(instr.node), st, s); + break; + case OpCode::LOGICAL_NOT: + ops::exec_logical_not(std::get(instr.node), st, s); + break; + case OpCode::LOGICAL_AND: + ops::exec_logical_and(std::get(instr.node), st, s); + break; + case OpCode::LOGICAL_OR: + ops::exec_logical_or(std::get(instr.node), st, s); + break; + case OpCode::TRI: + ops::exec_tri(std::get(instr.node), st, s); + break; + case OpCode::TRIL: + ops::exec_tril(std::get(instr.node), st, s); + break; + case OpCode::TRIU: + ops::exec_triu(std::get(instr.node), st, s); + break; + // Math ops - Unary + case OpCode::FLOOR: + ops::exec_floor(std::get(instr.node), st, s); + break; + case OpCode::CEIL: + ops::exec_ceil(std::get(instr.node), st, s); + break; + case OpCode::SQUARE: + ops::exec_square(std::get(instr.node), st, s); + break; + case OpCode::EXP: + ops::exec_exp(std::get(instr.node), st, s); + break; + case OpCode::SIN: + ops::exec_sin(std::get(instr.node), st, s); + break; + case OpCode::COS: + ops::exec_cos(std::get(instr.node), st, s); + break; + case OpCode::TAN: + ops::exec_tan(std::get(instr.node), st, s); + break; + case OpCode::ARCSIN: + ops::exec_arcsin(std::get(instr.node), st, s); + break; + case OpCode::ARCCOS: + ops::exec_arccos(std::get(instr.node), st, s); + break; + case OpCode::ARCTAN: + ops::exec_arctan(std::get(instr.node), st, s); + break; + case OpCode::SINH: + ops::exec_sinh(std::get(instr.node), st, s); + break; + case OpCode::COSH: + ops::exec_cosh(std::get(instr.node), st, s); + break; + case OpCode::ARCSINH: + ops::exec_arcsinh(std::get(instr.node), st, s); + break; + case OpCode::ARCCOSH: + ops::exec_arccosh(std::get(instr.node), st, s); + break; + case OpCode::ARCTANH: + ops::exec_arctanh(std::get(instr.node), st, s); + break; + case OpCode::LOG2: + ops::exec_log2(std::get(instr.node), st, s); + break; + case OpCode::LOG10: + ops::exec_log10(std::get(instr.node), st, s); + break; + case OpCode::LOG1P: + ops::exec_log1p(std::get(instr.node), st, s); + break; + case OpCode::ERF: + ops::exec_erf(std::get(instr.node), st, s); + break; + case OpCode::EXPM1: + ops::exec_expm1(std::get(instr.node), st, s); + break; + case OpCode::ROUND: + ops::exec_round(std::get(instr.node), st, s); + break; + case OpCode::RECIPROCAL: + ops::exec_reciprocal(std::get(instr.node), st, s); + break; + case OpCode::SQRT: + ops::exec_sqrt(std::get(instr.node), st, s); + break; + case OpCode::ABS: + ops::exec_abs(std::get(instr.node), st, s); + break; + case OpCode::NEG: + ops::exec_neg(std::get(instr.node), st, s); + break; + // Math ops - Binary + case OpCode::ATAN2: + ops::exec_atan2(std::get(instr.node), st, s); + break; + case OpCode::LOG_ADD_EXP: + ops::exec_logaddexp(std::get(instr.node), st, s); + break; + case OpCode::FLOOR_DIVIDE: + ops::exec_floor_divide(std::get(instr.node), st, s); + break; + case OpCode::REMAINDER: + ops::exec_remainder(std::get(instr.node), st, s); + break; + case OpCode::POWER: + ops::exec_power(std::get(instr.node), st, s); + break; + // Math ops - Reduction + case 
OpCode::LOG_SUM_EXP: + ops::exec_logsumexp(std::get(instr.node), st, s); + break; + case OpCode::SUM: + ops::exec_sum(std::get(instr.node), st, s); + break; + case OpCode::MEAN: + ops::exec_mean(std::get(instr.node), st, s); + break; + case OpCode::VAR: + ops::exec_var(std::get(instr.node), st, s); + break; + case OpCode::STD: + ops::exec_std(std::get(instr.node), st, s); + break; + case OpCode::PROD: + ops::exec_prod(std::get(instr.node), st, s); + break; + case OpCode::MAX: + ops::exec_max(std::get(instr.node), st, s); + break; + case OpCode::MIN: + ops::exec_min(std::get(instr.node), st, s); + break; + case OpCode::ARGMIN: + ops::exec_argmin(std::get(instr.node), st, s); + break; + case OpCode::MEDIAN: + ops::exec_median(std::get(instr.node), st, s); + break; + case OpCode::CONV_TRANSPOSE1D: + ops::exec_conv_transpose1d( + std::get(instr.node), st, s); + break; + case OpCode::CONV_TRANSPOSE2D: + ops::exec_conv_transpose2d( + std::get(instr.node), st, s); + break; + case OpCode::CONV_TRANSPOSE3D: + ops::exec_conv_transpose3d( + std::get(instr.node), st, s); + break; + case OpCode::CLIP: + ops::exec_clip(std::get(instr.node), st, s); + break; + case OpCode::CUMSUM: + ops::exec_cumsum(std::get(instr.node), st, s); + break; + case OpCode::STACK: + ops::exec_stack(std::get(instr.node), st, s); + break; + case OpCode::SIGN: + ops::exec_sign(std::get(instr.node), st, s); + break; + case OpCode::ANY: + ops::exec_any(std::get(instr.node), st, s); + break; + case OpCode::ALL: + ops::exec_all(std::get(instr.node), st, s); + break; + case OpCode::REPEAT: + ops::exec_repeat(std::get(instr.node), st, s); + break; + case OpCode::SORT: + ops::exec_sort(std::get(instr.node), st, s); + break; + case OpCode::ARGSORT: + ops::exec_argsort(std::get(instr.node), st, s); + break; + case OpCode::PARTITION: + ops::exec_partition(std::get(instr.node), st, s); + break; + case OpCode::ARG_PARTITION: + ops::exec_argpartition(std::get(instr.node), st, s); + break; + case OpCode::QUANTIZED_MATMUL: + ops::exec_quantized_matmul( + std::get(instr.node), st, s); + break; default: throw std::runtime_error( "Unknown opcode: " + std::to_string(static_cast(instr.op))); diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs index 8b159314760..b101b5756f7 100644 --- a/backends/mlx/serialization/schema.fbs +++ b/backends/mlx/serialization/schema.fbs @@ -86,6 +86,819 @@ table AddmmNode { beta: float = 1.0; // Scalar multiplier for bias } +table ItemIntNode { + x: Tid (required); + out: Vid (required); +} + +table ExpandDimsNode { + x: Tid (required); + out: Tid (required); + axis: int32; +} + +table TileNode { + x: Tid (required); + out: Tid (required); + reps: [IntOrVid] (required); +} + +table TakeAlongAxisNode { + x: Tid (required); + indices: Tid (required); + out: Tid (required); + axis: int32; +} + +table TakeNode { + x: Tid (required); + out: Tid (required); + index: IntOrVidOrTid (required); // Scalar int, dynamic Vid, or tensor of indices + axis: int32; // Axis along which to select +} + +table RMSNormNode { + x: Tid (required); + weight: Tid; // optional (None = no per-element scaling, same as ones) + out: Tid (required); + eps: float; +} + +table LayerNormNode { + x: Tid (required); + out: Tid (required); + weight: Tid; // optional + bias: Tid; // optional + eps: float; +} + +table RopeNode { + x: Tid (required); + out: Tid (required); + dims: int32; + offset: VidOrTid (required); // Position offset: scalar (Vid) or tensor of positions (Tid) + freqs: Tid; // optional + traditional: 
bool = false; + base: float = 500000.0; // Llama 3 default + scale: float = 1.0; +} + +table SdpaNode { + q: Tid (required); + k: Tid (required); + v: Tid (required); + out: Tid (required); + scale: float; + mask: Tid; // optional + causal: bool = false; +} + +table AddNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table AddIntNode { + a: IntOrVid (required); + b: IntOrVid (required); + out: Vid (required); +} + +table SubtractIntNode { + a: IntOrVid (required); + b: IntOrVid (required); + out: Vid (required); +} + +table MultiplyIntNode { + a: IntOrVid (required); + b: IntOrVid (required); + out: Vid (required); +} + +table FloorDivideIntNode { + a: IntOrVid (required); + b: IntOrVid (required); + out: Vid (required); +} + +table ModIntNode { + a: IntOrVid (required); + b: IntOrVid (required); + out: Vid (required); +} + +table SymSizeNode { + a: Tid (required); + dim: int32; + out: Vid (required); +} + +table MultiplyNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table DivideNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table SubtractNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table Conv1DNode { + x: Tid (required); + w: Tid (required); + out: Tid (required); + stride: int32 = 1; + padding: int32 = 0; + dilation: int32 = 1; + groups: int32 = 1; +} + +table Conv2DNode { + x: Tid (required); + w: Tid (required); + out: Tid (required); + stride_h: int32 = 1; + stride_w: int32 = 1; + padding_h: int32 = 0; + padding_w: int32 = 0; + dilation_h: int32 = 1; + dilation_w: int32 = 1; + groups: int32 = 1; +} + +table Conv3DNode { + x: Tid (required); + w: Tid (required); + out: Tid (required); + stride_d: int32 = 1; + stride_h: int32 = 1; + stride_w: int32 = 1; + padding_d: int32 = 0; + padding_h: int32 = 0; + padding_w: int32 = 0; + dilation_d: int32 = 1; + dilation_h: int32 = 1; + dilation_w: int32 = 1; + groups: int32 = 1; +} + +table ConvTranspose1DNode { + x: Tid (required); + w: Tid (required); + out: Tid (required); + stride: int32 = 1; + padding: int32 = 0; + dilation: int32 = 1; + output_padding: int32 = 0; + groups: int32 = 1; +} + +table ConvTranspose2DNode { + x: Tid (required); + w: Tid (required); + out: Tid (required); + stride_h: int32 = 1; + stride_w: int32 = 1; + padding_h: int32 = 0; + padding_w: int32 = 0; + dilation_h: int32 = 1; + dilation_w: int32 = 1; + output_padding_h: int32 = 0; + output_padding_w: int32 = 0; + groups: int32 = 1; +} + +table ConvTranspose3DNode { + x: Tid (required); + w: Tid (required); + out: Tid (required); + stride_d: int32 = 1; + stride_h: int32 = 1; + stride_w: int32 = 1; + padding_d: int32 = 0; + padding_h: int32 = 0; + padding_w: int32 = 0; + dilation_d: int32 = 1; + dilation_h: int32 = 1; + dilation_w: int32 = 1; + output_padding_d: int32 = 0; + output_padding_h: int32 = 0; + output_padding_w: int32 = 0; + groups: int32 = 1; +} + +table GeluNode { + x: Tid (required); + out: Tid (required); + approximate: string (required); // "none" or "tanh" +} + +table ARangeNode { + out: Tid (required); + start: IntOrVid (required); // Can be literal or dynamic (from item()) + stop: IntOrVid (required); // Can be literal or dynamic (from item()) + step: IntOrVid (required); // Can be literal or dynamic + scalar_type: int8 = null; // ET ScalarType (optional - None means infer from context) +} + +table SiluNode { + x: Tid (required); + out: Tid (required); +} + +table SigmoidNode { + x: Tid (required); + out: Tid (required); +} + 
+table TanhNode { + x: Tid (required); + out: Tid (required); +} + +table SqueezeNode { + x: Tid (required); + out: Tid (required); + dims: [int32]; // Optional list of dimensions to squeeze. If empty, squeeze all dims of size 1 +} + +table SplitNode { + x: Tid (required); + outs: [Tid] (required); // Multiple output tensor IDs (one for each split chunk) + sizes: [IntOrVid] (required); // Split sizes (can be dynamic) + axis: int32; +} + +table RsqrtNode { + x: Tid (required); + out: Tid (required); +} + +table MaximumNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table MinimumNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table LogNode { + x: Tid (required); + out: Tid (required); +} + +table SoftmaxNode { + x: Tid (required); + out: Tid (required); + axis: int32; // Dimension to compute softmax over + precise: bool = false; // Use precise (slow) implementation +} + +table BroadcastToNode { + x: Tid (required); + out: Tid (required); + shape: [IntOrVid] (required); // Target shape to broadcast to +} + +table PadNode { + x: Tid (required); + out: Tid (required); + pad_width: [IntOrVid] (required); // Padding pairs: [(before_0, after_0), (before_1, after_1), ...] + mode: string (required); // "constant" or "edge" + constant_value: float = 0.0; // Value to pad with (for constant mode) +} + +table WhereNode { + condition: Tid (required); + x: Tid (required); + y: Tid (required); + out: Tid (required); +} + +table ReshapeNode { + x: Tid (required); + out: Tid (required); + shape: [IntOrVid] (required); +} + +table TransposeNode { + x: Tid (required); + out: Tid (required); + perm: [int32] (required); +} + +table AsStridedNode { + x: Tid (required); + out: Tid (required); + shape: [IntOrVid] (required); // Output view shape (can be dynamic) + strides: [IntOrVid] (required); // Element strides per dimension (can be dynamic) + offset: uint64 = 0; // Element offset into source +} + +table ContiguousNode { + x: Tid (required); + out: Tid (required); +} + +table GatherNode { + x: Tid (required); + indices: [Tid] (required); // Index tensors (one per indexed axis) + out: Tid (required); + axes: [int32] (required); // Which axes to gather along + slice_sizes: [int32] (required); // Size of slice per dimension of x +} + +table SliceNode { + x: Tid (required); + out: Tid (required); + axis: IntOrVid (required); + start: IntOrVid (required); + stop: IntOrVid (required); + step: int32 = 1; +} + +table AsTypeNode { + x: Tid (required); + out: Tid (required); + scalar_type: int8; // ET ScalarType +} + +table QuantizedMatmulNode { + x: Tid (required); + w: Tid (required); + scales: Tid (required); + out: Tid (required); + biases: Tid; // optional - required for affine mode, null for nvfp4 + group_size: int32; + bits: int32; + mode: string (required); + transpose: bool = true; +} + +table ConcatenateNode { + tensors: [Tid] (required); // List of tensors to concatenate + out: Tid (required); + axis: int32; +} + +table FullNode { + out: Tid (required); + shape: [IntOrVid] (required); + v: FloatOrVid (required); // Fill value (can be dynamic from item()) + scalar_type: int8; // ET ScalarType +} + +table FullLikeNode { + x: Tid (required); // Input tensor to copy shape from + out: Tid (required); + v: FloatOrVid (required); // Fill value (can be dynamic from item()) + scalar_type: int8 = null; // ET ScalarType (optional - if null, use x's dtype) +} + +table ArgmaxNode { + x: Tid (required); + out: Tid (required); + axis: int32; + keepdims: bool = false; 
+} + +table SliceUpdateNode { + dst: Tid (required); + update: Tid (required); + out: Tid (required); // Can be same as dst + axis: IntOrVid (required); + start: IntOrVid (required); + stop: IntOrVid (required); + step: int32 = 1; +} + +// Index-based update: copies update tensor into dst at positions specified by 1D indices +// Runtime optimizes these into slice_update calls for contiguous runs +table IndexCopyNode { + dst: Tid (required); // destination tensor to update + update: Tid (required); // source tensor to copy from + indices: Tid (required); // 1D tensor of indices along axis + out: Tid (required); // output tensor (can be same as dst) + axis: int32; // dimension to index along +} + + +table DequantizeNode { + w: Tid (required); // Quantized matrix to dequantize + scales: Tid (required); // Scales per group_size elements + out: Tid (required); + biases: Tid; // optional - biases per group_size elements + group_size: int32; + bits: int32; + mode: string (required); // Quantization mode (e.g. "affine") + global_scale: Tid; // optional - global scale for nvfp4 + dtype: int8 = null; // ET ScalarType for output dtype +} + +// Comparison ops (return bool arrays) +table LessNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table LessEqualNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table GreaterNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table GreaterEqualNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table EqualNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table NotEqualNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +// Logical ops +table LogicalNotNode { + x: Tid (required); + out: Tid (required); +} + +table LogicalAndNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table LogicalOrNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +// Triangular matrix ops +table TriNode { + out: Tid (required); + n: IntOrVid (required); // Number of rows + m: IntOrVid (required); // Number of columns + k: int32 = 0; // Diagonal offset: 0=main, +above, -below + scalar_type: int8; // ET ScalarType +} + +table TrilNode { + x: Tid (required); + out: Tid (required); + k: int32 = 0; // Diagonal offset: 0=main, +above, -below +} + +table TriuNode { + x: Tid (required); + out: Tid (required); + k: int32 = 0; // Diagonal offset: 0=main, +above, -below +} + +table ClipNode { + x: Tid (required); + out: Tid (required); + a_min: Tid; // optional lower bound + a_max: Tid; // optional upper bound +} + +table CumsumNode { + x: Tid (required); + out: Tid (required); + axis: int32; + reverse: bool = false; + inclusive: bool = true; +} + +table StackNode { + tensors: [Tid] (required); + out: Tid (required); + axis: int32 = 0; +} + +table SignNode { + x: Tid (required); + out: Tid (required); +} + +table AnyNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; +} + +table AllNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; +} + +table RepeatNode { + x: Tid (required); + out: Tid (required); + repeats: IntOrVid (required); // Number of times to repeat each element (can be dynamic) + axis: int32; // Axis along which to repeat +} + +table SortNode { + x: Tid (required); + out: Tid (required); + axis: int32; +} + +table ArgsortNode { + x: Tid (required); 
+ out: Tid (required); + axis: int32; +} + +table PartitionNode { + x: Tid (required); + out: Tid (required); + kth: IntOrVid (required); // Partition index + axis: int32; +} + +table ArgPartitionNode { + x: Tid (required); + out: Tid (required); + kth: IntOrVid (required); // Partition index + axis: int32; +} + + +// ============================================================================= +// Math ops - Unary element-wise +// ============================================================================= + +table FloorNode { + x: Tid (required); + out: Tid (required); +} + +table CeilNode { + x: Tid (required); + out: Tid (required); +} + +table SquareNode { + x: Tid (required); + out: Tid (required); +} + +table ExpNode { + x: Tid (required); + out: Tid (required); +} + +table SinNode { + x: Tid (required); + out: Tid (required); +} + +table CosNode { + x: Tid (required); + out: Tid (required); +} + +table TanNode { + x: Tid (required); + out: Tid (required); +} + +table ArcsinNode { + x: Tid (required); + out: Tid (required); +} + +table ArccosNode { + x: Tid (required); + out: Tid (required); +} + +table ArctanNode { + x: Tid (required); + out: Tid (required); +} + +table SinhNode { + x: Tid (required); + out: Tid (required); +} + +table CoshNode { + x: Tid (required); + out: Tid (required); +} + +table ArcsinhNode { + x: Tid (required); + out: Tid (required); +} + +table ArccoshNode { + x: Tid (required); + out: Tid (required); +} + +table ArctanhNode { + x: Tid (required); + out: Tid (required); +} + +table Log2Node { + x: Tid (required); + out: Tid (required); +} + +table Log10Node { + x: Tid (required); + out: Tid (required); +} + +table Log1pNode { + x: Tid (required); + out: Tid (required); +} + +table ErfNode { + x: Tid (required); + out: Tid (required); +} + +table Expm1Node { + x: Tid (required); + out: Tid (required); +} + +table RoundNode { + x: Tid (required); + out: Tid (required); + decimals: int32 = 0; +} + +table ReciprocalNode { + x: Tid (required); + out: Tid (required); +} + +table SqrtNode { + x: Tid (required); + out: Tid (required); +} + +table AbsNode { + x: Tid (required); + out: Tid (required); +} + +table NegNode { + x: Tid (required); + out: Tid (required); +} + +// ============================================================================= +// Math ops - Binary element-wise +// ============================================================================= + +table Atan2Node { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table LogAddExpNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table FloorDivideNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table RemainderNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +table PowerNode { + a: Tid (required); + b: Tid (required); + out: Tid (required); +} + +// ============================================================================= +// Math ops - Reduction +// ============================================================================= + +table LogSumExpNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; + keepdims: bool = false; +} + +table SumNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; +} + +table MeanNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; +} + +table VarNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = 
reduce all + keepdims: bool = false; + ddof: int32 = 0; // Delta degrees of freedom (0=population var, 1=sample var) +} + +table StdNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; + ddof: int32 = 0; // Delta degrees of freedom +} + +table ProdNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; +} + +table MaxNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; +} + +table MinNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; +} + +table ArgminNode { + x: Tid (required); + out: Tid (required); + axis: int32; + keepdims: bool = false; +} + +table MedianNode { + x: Tid (required); + out: Tid (required); + axes: [int32]; // Empty = reduce all + keepdims: bool = false; +} + // ============================================================================= // Union of all op types // ============================================================================= @@ -95,8 +908,126 @@ table AddmmNode { union OpNode { NoopNode, IdCopyNode, - AddmmNode - // BC: Add new op nodes here (append only) + AddmmNode, + ItemIntNode, + ExpandDimsNode, + TileNode, + TakeAlongAxisNode, + TakeNode, + RMSNormNode, + LayerNormNode, + RopeNode, + SdpaNode, + AddNode, + AddIntNode, + SubtractIntNode, + MultiplyIntNode, + FloorDivideIntNode, + SymSizeNode, + MultiplyNode, + DivideNode, + SubtractNode, + Conv1DNode, + Conv2DNode, + Conv3DNode, + GeluNode, + ARangeNode, + SiluNode, + SigmoidNode, + TanhNode, + SqueezeNode, + SplitNode, + RsqrtNode, + MaximumNode, + MinimumNode, + LogNode, + SoftmaxNode, + BroadcastToNode, + PadNode, + WhereNode, + ReshapeNode, + TransposeNode, + AsStridedNode, + ContiguousNode, + GatherNode, + SliceNode, + AsTypeNode, + ConcatenateNode, + FullNode, + FullLikeNode, + ArgmaxNode, + SliceUpdateNode, + IndexCopyNode, + DequantizeNode, + LessNode, + LessEqualNode, + GreaterNode, + GreaterEqualNode, + EqualNode, + NotEqualNode, + LogicalNotNode, + LogicalAndNode, + LogicalOrNode, + TriNode, + TrilNode, + TriuNode, + FloorNode, + CeilNode, + SquareNode, + ExpNode, + SinNode, + CosNode, + TanNode, + ArcsinNode, + ArccosNode, + ArctanNode, + SinhNode, + CoshNode, + ArcsinhNode, + ArccoshNode, + ArctanhNode, + Log2Node, + Log10Node, + Log1pNode, + ErfNode, + Expm1Node, + RoundNode, + ReciprocalNode, + SqrtNode, + AbsNode, + NegNode, + Atan2Node, + LogAddExpNode, + FloorDivideNode, + PowerNode, + LogSumExpNode, + SumNode, + MeanNode, + VarNode, + StdNode, + ProdNode, + MaxNode, + MinNode, + ArgminNode, + MedianNode, + ModIntNode, + RemainderNode, + ConvTranspose1DNode, + ConvTranspose2DNode, + ConvTranspose3DNode, + ClipNode, + CumsumNode, + StackNode, + SignNode, + AnyNode, + AllNode, + RepeatNode, + SortNode, + ArgsortNode, + PartitionNode, + ArgPartitionNode, + QuantizedMatmulNode + // BC: Add new op nodes here (append only) } // ============================================================================= diff --git a/backends/mlx/test/CMakeLists.txt b/backends/mlx/test/CMakeLists.txt index 2a709a63412..39024639d1d 100644 --- a/backends/mlx/test/CMakeLists.txt +++ b/backends/mlx/test/CMakeLists.txt @@ -49,3 +49,23 @@ target_link_libraries( strict_compile_test PRIVATE mlx_schema executorch_core mlx ) add_dependencies(op_test_runner strict_compile_test) + +# Multi-threaded inference test +include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) + 
+et_cxx_test( + multi_thread_test_runner + SOURCES + ${CMAKE_CURRENT_LIST_DIR}/multi_thread_test_runner.cpp + EXTRA_LIBS + extension_module + extension_tensor + mlxdelegate +) + +# Add sanitizer link flags to multi_thread_test_runner if enabled +if(EXECUTORCH_MLX_ENABLE_SANITIZERS) + target_link_options( + multi_thread_test_runner PRIVATE ${_mlx_sanitizer_link_options} + ) +endif() diff --git a/backends/mlx/test/export_multi_thread_test_model.py b/backends/mlx/test/export_multi_thread_test_model.py new file mode 100644 index 00000000000..3c6500cad78 --- /dev/null +++ b/backends/mlx/test/export_multi_thread_test_model.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Export a test model for the multi-threaded inference test. + +The model exercises multiple ops and a mutable buffer (KV cache), +producing deterministic outputs that the C++ test can verify. + +Model behavior (accumulation via KV cache): + forward(x, input_pos): + x: [1, 1, 1, dim] (input tensor) + input_pos: [1] (cache write position, always 0) + + z = relu(x * 2 + 1) # always 3.0 with ones input + old_k = cache.k_cache[:, :, :1, :] # read old cache at pos 0 + new_val = z + old_k # accumulate: 3 + old + k_cache, v_cache = cache.update(input_pos, new_val, new_val) + return k_cache[:, :, :1, :] + v_cache[:, :, :1, :] + +With all-ones input and input_pos=[0], calling forward N times: + Call 1: old=0, new_val=3, cache=3. Output = 3 + 3 = 6.0 + Call 2: old=3, new_val=6, cache=6. Output = 6 + 6 = 12.0 + Call N: Output = 6.0 * N + +The C++ test can verify: output == 6.0 * call_number (all elements). + +Usage: + python export_multi_thread_test_model.py /tmp/multi_thread_test_model.pte +""" + +import argparse + +import torch +import torch.nn as nn + +from executorch.backends.mlx.llm.cache import KVCache +from executorch.backends.mlx.partitioner import MLXPartitioner +from executorch.exir import to_edge_transform_and_lower +from executorch.exir.capture._config import ExecutorchBackendConfig + + +class MultiOpCacheModel(nn.Module): + """ + A model with multiple ops and a mutable KV cache buffer that accumulates. + + Each forward() call: + 1. Computes z = relu(x * 2 + 1) — mul, add, relu (= 3.0 with ones) + 2. Reads old cache value at pos 0 — old_k + 3. Accumulates: new_val = z + old_k — add + 4. Writes new_val to both k and v caches — mutable buffer via kv_cache_update + 5. Returns k_cache + v_cache at pos 0 — sum of both cache slices + + With ones input, output = 6.0 * call_number (all elements). 
+ """ + + def __init__(self, dim=4, max_len=8): + super().__init__() + self.cache = KVCache( + max_batch_size=1, + max_context_length=max_len, + n_heads=1, + head_dim=dim, + enable_dynamic_shape=True, + ) + + def forward(self, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor: + z = torch.relu(x * 2.0 + 1.0) + old_k = self.cache.k_cache[:, :, :1, :] + new_val = z + old_k + k_cache, v_cache = self.cache.update(input_pos, new_val, new_val) + return k_cache[:, :, :1, :] + v_cache[:, :, :1, :] + + +def export_model(output_path: str, dim=4, max_len=8): + model = MultiOpCacheModel(dim=dim, max_len=max_len) + example_inputs = ( + torch.randn(1, 1, 1, dim), # x: [B, H, S, D] + torch.tensor([0], dtype=torch.int64), # input_pos + ) + + with torch.no_grad(): + exported = torch.export.export(model, example_inputs) + exported = exported.run_decompositions({}) + + et_program = to_edge_transform_and_lower(exported, partitioner=[MLXPartitioner()]) + et_program = et_program.to_executorch( + config=ExecutorchBackendConfig(extract_delegate_segments=True) + ) + with open(output_path, "wb") as f: + f.write(et_program.buffer) + print(f"Exported model to {output_path}") + + # Verify accumulation pattern + model_ref = MultiOpCacheModel(dim=dim, max_len=max_len) + x = torch.ones(1, 1, 1, dim) + input_pos = torch.tensor([0], dtype=torch.int64) + print(f"Reference (ones input, dim={dim}, max_len={max_len}):") + for i in range(1, 4): + result = model_ref(x, input_pos) + expected = 6.0 * i + actual = result[0, 0, 0, 0].item() + status = "OK" if abs(actual - expected) < 1e-6 else "FAIL" + print(f" Call {i}: output={actual:.1f}, expected={expected:.1f} [{status}]") + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "output", + nargs="?", + default="/tmp/multi_thread_test_model.pte", + help="Output .pte path (default: /tmp/multi_thread_test_model.pte)", + ) + args = parser.parse_args() + export_model(args.output) + + +if __name__ == "__main__": + main() diff --git a/backends/mlx/test/multi_thread_test_runner.cpp b/backends/mlx/test/multi_thread_test_runner.cpp new file mode 100644 index 00000000000..72c0917d81e --- /dev/null +++ b/backends/mlx/test/multi_thread_test_runner.cpp @@ -0,0 +1,204 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * Multi-threaded inference stress test for the MLX delegate. + * + * Loads a .pte model on multiple threads (each with its own Module instance) + * and runs forward passes in parallel, verifying that all succeed and + * produce correct outputs. + * + * The model accumulates via KV cache: with all-ones input and input_pos=[0], + * call N produces output = 6.0 * N (all elements). Each thread has its own + * Module (and cache state), so correctness is verified independently. + * + * The test expects a model exported by export_multi_thread_test_model.py. 
+ *
+ * Build:
+ *   cmake --preset mlx
+ *   cmake --build cmake-out --target multi_thread_test_runner
+ *
+ * Usage:
+ *   ET_TESTING_MODEL_PATH=/tmp/multi_thread_test_model.pte \
+ *     ./cmake-out/backends/mlx/test/multi_thread_test_runner
+ *
+ * Environment variables:
+ *   ET_TESTING_MODEL_PATH      Path to .pte model file (required)
+ *   ET_TESTING_NUM_THREADS     Number of parallel threads (default: 4)
+ *   ET_PREDICTIONS_PER_THREAD  Inferences per thread (default: 10)
+ */
+
+#include <gtest/gtest.h>
+
+#include <cmath>
+#include <cstdlib>
+#include <iostream>
+#include <stdexcept>
+#include <string>
+#include <thread>
+#include <vector>
+
+#include <executorch/extension/module/module.h>
+#include <executorch/extension/tensor/tensor.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+
+using namespace ::executorch::runtime;
+using namespace ::executorch::extension;
+
+const std::string kTestPTEPath = [] {
+  if (const char* env_p = std::getenv("ET_TESTING_MODEL_PATH")) {
+    return std::string(env_p);
+  }
+  return std::string("model.pte");
+}();
+
+const int kNumThreads = [] {
+  if (const char* env_p = std::getenv("ET_TESTING_NUM_THREADS")) {
+    try {
+      return std::stoi(env_p);
+    } catch (...) {
+    }
+  }
+  return 4;
+}();
+
+const int kPredictionsPerThread = [] {
+  if (const char* env_p = std::getenv("ET_PREDICTIONS_PER_THREAD")) {
+    try {
+      return std::stoi(env_p);
+    } catch (...) {
+    }
+  }
+  return 10;
+}();
+
+std::vector<TensorPtr> get_ones_inputs(Module& module) {
+  const auto method_meta = module.method_meta("forward");
+  const auto num_inputs = method_meta->num_inputs();
+
+  std::vector<TensorPtr> tensors;
+  tensors.reserve(num_inputs);
+
+  for (auto index = 0; index < num_inputs; ++index) {
+    const auto input_tag = method_meta->input_tag(index);
+
+    switch (*input_tag) {
+      case Tag::Tensor: {
+        const auto tensor_meta = method_meta->input_tensor_meta(index);
+        const auto sizes = tensor_meta->sizes();
+        if (tensor_meta->scalar_type() == exec_aten::ScalarType::Long) {
+          // Integer inputs (input_pos) stay at zero so the cache is always
+          // written at position 0.
+          tensors.emplace_back(
+              zeros({sizes.begin(), sizes.end()}, tensor_meta->scalar_type()));
+        } else {
+          tensors.emplace_back(
+              ones({sizes.begin(), sizes.end()}, tensor_meta->scalar_type()));
+        }
+      } break;
+      default:
+        throw std::runtime_error(
+            "Unsupported input tag at index " + std::to_string(index));
+    }
+  }
+  return tensors;
+}
+
+struct ThreadResult {
+  size_t success_count{0};
+  size_t correctness_failures{0};
+  std::string error_message;
+};
+
+void run_predict(
+    int thread_id,
+    const std::string& model_path,
+    ThreadResult& result) {
+  Module module(model_path);
+
+  for (int pred = 0; pred < kPredictionsPerThread; pred++) {
+    auto inputs = get_ones_inputs(module);
+    for (int i = 0; i < inputs.size(); i++) {
+      if (module.set_input(inputs[i], i) != Error::Ok) {
+        std::cerr << "Thread " << thread_id << ", prediction " << pred
+                  << ": set_input(" << i << ") failed" << std::endl;
+        break;
+      }
+    }
+
+    const auto forward_result = module.forward();
+
+    if (!forward_result.ok()) {
+      std::cerr << "Thread " << thread_id << ", prediction " << pred
+                << ": forward() failed with error "
+                << static_cast<int>(forward_result.error()) << std::endl;
+      continue;
+    }
+
+    const auto outputs = forward_result.get();
+    if (outputs.empty() || !outputs[0].isTensor()) {
+      std::cerr << "Thread " << thread_id << ", prediction " << pred
+                << ": no tensor output" << std::endl;
+      continue;
+    }
+
+    const auto& output_tensor = outputs[0].toTensor();
+    const float* data = output_tensor.const_data_ptr<float>();
+    const float expected = 6.0f * (pred + 1);
+    bool correct = true;
+    for (ssize_t j = 0; j < output_tensor.numel(); j++) {
+      if (std::fabs(data[j] - expected) > 1e-4f) {
+        std::cerr << "Thread " << thread_id << ", prediction " << pred
+                  << ": output[" << j << "] = " << data[j] << ", expected "
+                  << expected << std::endl;
+        correct = false;
+        break;
+      }
+    }
+    if (!correct) {
+      result.correctness_failures++;
+    }
+
+    result.success_count++;
+  }
+}
+
+TEST(MLXMultiThreadTest, LoadAndRunParallel) {
+  ASSERT_FALSE(kTestPTEPath.empty()) << "ET_TESTING_MODEL_PATH must be set";
+  ASSERT_GT(kNumThreads, 0) << "ET_TESTING_NUM_THREADS must be > 0";
+  ASSERT_GT(kPredictionsPerThread, 0)
+      << "ET_PREDICTIONS_PER_THREAD must be > 0";
+
+  std::cout << "Running " << kNumThreads << " threads x "
+            << kPredictionsPerThread
+            << " predictions with model: " << kTestPTEPath << std::endl;
+
+  std::vector<std::thread> threads(kNumThreads);
+  std::vector<ThreadResult> results(kNumThreads);
+
+  for (int i = 0; i < kNumThreads; i++) {
+    threads[i] =
+        std::thread([&, i]() { run_predict(i, kTestPTEPath, results[i]); });
+  }
+  for (int i = 0; i < kNumThreads; i++) {
+    threads[i].join();
+  }
+
+  size_t total_success = 0;
+  size_t total_correctness_failures = 0;
+  for (int i = 0; i < kNumThreads; i++) {
+    total_success += results[i].success_count;
+    total_correctness_failures += results[i].correctness_failures;
+  }
+
+  const size_t total = kNumThreads * kPredictionsPerThread;
+  std::cout << "Success: " << total_success << "/" << total << std::endl;
+  std::cout << "Correctness failures: " << total_correctness_failures
+            << std::endl;
+
+  ASSERT_EQ(total_success, total) << "Some forward() calls failed";
+  ASSERT_EQ(total_correctness_failures, 0) << "Some outputs were incorrect";
+}
diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py
index 0ba98b532ad..35514f4df04 100644
--- a/backends/mlx/test/test_ops.py
+++ b/backends/mlx/test/test_ops.py
@@ -24,7 +24,7 @@
 See README.md in this directory for full documentation.
 """
 
-from typing import List, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -34,52 +34,6107 @@
     custom_ops,
     ops,
 )
+from torch.export import Dim
 
 from .test_utils import OpTestCase, register_test
 
 
+class AddTensorModel(nn.Module):
+    """Add two tensors, optionally with alpha."""
+
+    def __init__(self, alpha: Optional[float] = None):
+        super().__init__()
+        self.alpha = alpha
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        if self.alpha is not None:
+            return torch.add(x, y, alpha=self.alpha)
+        return x + y
+
+
+class AddScalarModel(nn.Module):
+    """Add tensor and scalar."""
+
+    def __init__(self, scalar: float = 1.0):
+        super().__init__()
+        self.scalar = scalar
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x + self.scalar
+
+
+@register_test
+class AddTest(OpTestCase):
+    """Test case for add op."""
+
+    name = "add"
+    rtol = 1e-5
+    atol = 1e-5
+
+    def __init__(
+        self,
+        shape: Tuple[int, ...]
= (2, 16, 64), + scalar: Optional[float] = None, + alpha: Optional[float] = None, + ): + self.shape = shape + self.scalar = scalar + self.alpha = alpha + + if alpha is not None: + self.name = "add_alpha" + elif scalar is not None: + self.name = "add_scalar" + else: + self.name = "add" + + @classmethod + def get_test_configs(cls) -> List["AddTest"]: + return [ + cls(), # tensor + tensor + cls(scalar=2.5), # tensor + scalar + cls(alpha=2.0), # tensor + alpha * tensor + ] + + def create_model(self) -> nn.Module: + if self.scalar is not None: + return AddScalarModel(self.scalar) + else: + return AddTensorModel(self.alpha) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + if self.scalar is not None: + return (x,) + else: + y = torch.randn(self.shape) + return (x, y) + + +class SubModel(nn.Module): + """Model that performs element-wise subtraction, optionally with alpha.""" + + def __init__(self, alpha: Optional[float] = None): + super().__init__() + self.alpha = alpha + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + if self.alpha is not None: + return torch.sub(x, y, alpha=self.alpha) + return torch.sub(x, y) + + +@register_test +class SubTest(OpTestCase): + name = "sub" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 3, 4), + scalar_sub: bool = False, + alpha: Optional[float] = None, + ): + self.shape = shape + self.scalar_sub = scalar_sub + self.alpha = alpha + shape_str = "x".join(str(s) for s in shape) + if alpha is not None: + self.name = f"sub_{shape_str}_alpha" + elif scalar_sub: + self.name = f"sub_{shape_str}_scalar" + else: + self.name = f"sub_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["SubTest"]: + return [ + cls(shape=(2, 3, 4)), + cls(shape=(10,)), + cls(shape=(4, 8)), + cls(shape=(2, 8, 16)), + cls(shape=(1, 128, 128)), + cls(shape=(2, 3, 4), scalar_sub=True), + cls(shape=(2, 3, 4), alpha=2.0), + ] + + def create_model(self) -> nn.Module: + return SubModel(self.alpha) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + if self.scalar_sub: + y = torch.randn(()) + else: + y = torch.randn(self.shape) + return (x, y) + + +class MulTensorModel(nn.Module): + """Multiply two tensors.""" + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x * y + + +class MulScalarModel(nn.Module): + """Multiply tensor and scalar.""" + + def __init__(self, scalar: float = 1.0): + super().__init__() + self.scalar = scalar + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * self.scalar + + +@register_test +class MulTest(OpTestCase): + """Test case for mul op.""" + + name = "mul" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] 
= (2, 16, 64), + scalar: Optional[float] = None, + ): + self.shape = shape + self.scalar = scalar + + if scalar is not None: + self.name = "mul_scalar" + else: + self.name = "mul" + + @classmethod + def get_test_configs(cls) -> List["MulTest"]: + return [ + cls(), + cls(scalar=2.5), + ] + + def create_model(self) -> nn.Module: + if self.scalar is not None: + return MulScalarModel(self.scalar) + else: + return MulTensorModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + if self.scalar is not None: + return (x,) + else: + y = torch.randn(self.shape) + return (x, y) + + +class DivModel(nn.Module): + """Model that performs element-wise division.""" + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.div(x, y) + + +@register_test +class DivTest(OpTestCase): + name = "div" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 3, 4), + scalar_divisor: bool = False, + ): + self.shape = shape + self.scalar_divisor = scalar_divisor + shape_str = "x".join(str(s) for s in shape) + if scalar_divisor: + self.name = f"div_{shape_str}_scalar" + else: + self.name = f"div_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["DivTest"]: + return [ + cls(shape=(2, 3, 4)), + cls(shape=(10,)), + cls(shape=(4, 8)), + cls(shape=(2, 8, 16)), + cls(shape=(1, 128, 64)), + cls(shape=(2, 3, 4), scalar_divisor=True), + ] + + def create_model(self) -> nn.Module: + return DivModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + 2.0 + if self.scalar_divisor: + y = torch.randn(()) + 2.0 + else: + y = torch.randn(self.shape) + 2.0 + return (x, y) + + +class ClampModel(nn.Module): + """Model that applies clamp with min and max.""" + + def __init__(self, min_val: Optional[float], max_val: Optional[float]): + super().__init__() + self.min_val = min_val + self.max_val = max_val + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.clamp(x, min=self.min_val, max=self.max_val) + + +@register_test +class ClampTest(OpTestCase): + """Test case for clamp op with various min/max combinations.""" + + name = "clamp" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] 
= (2, 3, 4), + min_val: Optional[float] = None, + max_val: Optional[float] = None, + ): + self.shape = shape + self.min_val = min_val + self.max_val = max_val + + # Build descriptive name + parts = ["clamp"] + if min_val is not None: + parts.append(f"min{min_val}") + if max_val is not None: + parts.append(f"max{max_val}") + if min_val is None and max_val is None: + parts.append("none") + shape_str = "x".join(str(s) for s in shape) + parts.append(shape_str) + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["ClampTest"]: + return [ + # Only min specified + cls(shape=(2, 3, 4), min_val=-0.5, max_val=None), + # Only max specified + cls(shape=(2, 3, 4), min_val=None, max_val=0.5), + # Both min and max specified + cls(shape=(2, 3, 4), min_val=-0.5, max_val=0.5), + # Different shapes + cls(shape=(10,), min_val=-1.0, max_val=1.0), + cls(shape=(4, 8), min_val=0.0, max_val=None), # ReLU-like + cls(shape=(2, 8, 16), min_val=-0.25, max_val=0.75), + ] + + def create_model(self) -> nn.Module: + return ClampModel(self.min_val, self.max_val) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Create inputs with values that span beyond typical clamp range + x = torch.randn(self.shape) * 2 # values roughly in [-4, 4] + return (x,) + + +class GELUModel(nn.Module): + """Simple model using GELU activation.""" + + def __init__(self, approximate: str = "none"): + super().__init__() + self.gelu = nn.GELU(approximate=approximate) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.gelu(x) + + +@register_test +class GELUTest(OpTestCase): + """Test case for GELU activation.""" + + name = "gelu" + + def __init__(self, shape: Tuple[int, ...] = (2, 16, 64), approximate: str = "none"): + self.shape = shape + self.approximate = approximate + self.name = f"gelu_{approximate}" if approximate != "none" else "gelu" + + @classmethod + def get_test_configs(cls) -> List["GELUTest"]: + return [ + cls(), + cls(shape=(4, 32, 128)), + cls(approximate="tanh"), + cls(shape=(4, 32, 128), approximate="tanh"), + ] + + def create_model(self) -> nn.Module: + return GELUModel(approximate=self.approximate) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class SoftmaxModel(nn.Module): + """Model that performs softmax along a specified dimension.""" + + def __init__(self, dim: int = -1): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.softmax(x, dim=self.dim) + + +@register_test +class SoftmaxTest(OpTestCase): + """Test case for softmax op.""" + + name = "softmax" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + shape: Tuple[int, ...] 
= (2, 3, 4), + dim: int = -1, + ): + self.shape = shape + self.dim = dim + shape_str = "x".join(str(s) for s in shape) + self.name = f"softmax_{shape_str}_dim{dim}" + + @classmethod + def get_test_configs(cls) -> List["SoftmaxTest"]: + return [ + cls(shape=(2, 3, 4), dim=-1), + cls(shape=(2, 3, 4), dim=1), + cls(shape=(4, 8), dim=-1), + cls(shape=(2, 4, 8, 16), dim=-1), + ] + + def create_model(self) -> nn.Module: + return SoftmaxModel(dim=self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class LogSoftmaxModel(nn.Module): + """Model that applies log_softmax.""" + + def __init__(self, dim: int = -1): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.log_softmax(x, dim=self.dim) + + +@register_test +class LogSoftmaxTest(OpTestCase): + name = "log_softmax" + rtol = 1e-5 + atol = 1e-5 + + def __init__(self, shape: Tuple[int, ...] = (2, 3, 4), dim: int = -1): + self.shape = shape + self.dim = dim + shape_str = "x".join(str(s) for s in shape) + self.name = f"log_softmax_{shape_str}_dim{dim}" + + @classmethod + def get_test_configs(cls) -> List["LogSoftmaxTest"]: + return [ + cls(shape=(2, 3, 4), dim=-1), + cls(shape=(10,), dim=0), + cls(shape=(4, 8), dim=1), + cls(shape=(2, 8, 16), dim=1), + cls(shape=(1, 128, 512), dim=-1), + ] + + def create_model(self) -> nn.Module: + return LogSoftmaxModel(dim=self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class SqueezeModel(nn.Module): + """Model that squeezes a tensor at specified dimensions.""" + + def __init__(self, dims: Optional[Tuple[int, ...]] = None): + super().__init__() + self.dims = dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.dims is None: + return torch.squeeze(x) + else: + return torch.squeeze(x, dim=self.dims) + + +@register_test +class SqueezeTest(OpTestCase): + """Test case for squeeze op.""" + + name = "squeeze" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (1, 3, 1, 4), + dims: Optional[Tuple[int, ...]] = (0, 2), + ): + self.shape = shape + self.dims = dims + shape_str = "x".join(str(s) for s in shape) + if dims is None: + dims_str = "all" + elif len(dims) == 0: + dims_str = "empty" + else: + dims_str = "_".join(str(d) for d in dims) + self.name = f"squeeze_{shape_str}_dims{dims_str}" + + @classmethod + def get_test_configs(cls) -> List["SqueezeTest"]: + return [ + cls(shape=(1, 3, 1, 4), dims=(0, 2)), + cls(shape=(1, 5, 1, 1), dims=(0,)), + cls(shape=(3, 1, 4), dims=(1,)), + cls(shape=(1, 1, 8), dims=(0, 1)), + cls(shape=(2, 1, 3, 1), dims=(1, 3)), + # Squeeze all singleton dims (no dims specified) + cls(shape=(1, 3, 1, 4), dims=None), + # Dims include non-size-1 axes (should be no-op for those axes) + cls(shape=(1, 1, 1, 8198), dims=(0, 1, 2, 3)), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + def create_model(self) -> nn.Module: + return SqueezeModel(self.dims) + + +class UnsqueezeModel(nn.Module): + """Model that unsqueezes a tensor at a given dimension.""" + + def __init__(self, dim: int = 0): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.unsqueeze(self.dim) + + +@register_test +class UnsqueezeTest(OpTestCase): + """Test case for unsqueeze op.""" + + name = "unsqueeze" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] 
= (2, 16, 64), + dim: int = 0, + ): + self.shape = shape + self.dim = dim + self.name = f"unsqueeze_dim{dim}" + + @classmethod + def get_test_configs(cls) -> List["UnsqueezeTest"]: + return [ + cls(dim=0), + cls(dim=1), + cls(dim=-1), + ] + + def create_model(self) -> nn.Module: + return UnsqueezeModel(self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class PermuteModel(nn.Module): + """Model that permutes tensor dimensions.""" + + def __init__(self, dims: Tuple[int, ...] = (0, 2, 1, 3)): + super().__init__() + self.dims = dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(self.dims) + + +class TransposeModel(nn.Module): + """Model that transposes two dimensions.""" + + def __init__(self, dim0: int = 1, dim1: int = 2): + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.transpose(self.dim0, self.dim1) + + +@register_test +class PermuteTest(OpTestCase): + """Test case for permute and transpose ops.""" + + name = "permute" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 8, 16, 64), + variant: str = "permute", + permute_dims: Tuple[int, ...] = (0, 2, 1, 3), + transpose_dims: Tuple[int, int] = (1, 2), + ): + self.shape = shape + self.variant = variant + self.permute_dims = permute_dims + self.transpose_dims = transpose_dims + + if variant == "transpose": + self.name = "transpose" + else: + self.name = "permute" + + @classmethod + def get_test_configs(cls) -> List["PermuteTest"]: + return [ + cls(variant="permute", permute_dims=(0, 2, 1, 3)), + cls(variant="transpose", transpose_dims=(1, 2)), + ] + + def create_model(self) -> nn.Module: + if self.variant == "transpose": + return TransposeModel(self.transpose_dims[0], self.transpose_dims[1]) + else: + return PermuteModel(self.permute_dims) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class NarrowModel(nn.Module): + """Model that narrows a tensor along a dimension.""" + + def __init__(self, dim: int, start: int, length: int): + super().__init__() + self.dim = dim + self.start = start + self.length = length + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.narrow(self.dim, self.start, self.length) + + +@register_test +class NarrowTest(OpTestCase): + """Test case for tensor.narrow().""" + + name = "narrow" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + shape: Tuple[int, ...] = (4, 16, 8), + dim: int = 1, + start: int = 2, + length: int = 8, + ): + self.shape = shape + self.dim = dim + self.start = start + self.length = length + self.name = f"narrow_dim{dim}_start{start}_len{length}" + + @classmethod + def get_test_configs(cls) -> List["NarrowTest"]: + return [ + cls(shape=(4, 16, 8), dim=1, start=2, length=8), + cls(shape=(8, 8), dim=0, start=1, length=4), + cls(shape=(2, 32, 4), dim=1, start=0, length=16), + ] + + def create_model(self) -> nn.Module: + return NarrowModel(self.dim, self.start, self.length) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class SelectModel(nn.Module): + """Model that selects a single index along a dimension. + + torch.select(input, dim, index) returns input[..., index, ...] where + the indexing happens at dimension `dim`. The selected dimension is removed. + Maps to aten.select_copy.int -> MLX take(array, index, axis). 
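+
+    For example, with x of shape (4, 8, 16), torch.select(x, 1, 3) is
+    equivalent to x[:, 3, :] and returns a tensor of shape (4, 16).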
+ """ + + def __init__(self, dim: int, index: int): + super().__init__() + self.dim = dim + self.index = index + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.select(x, self.dim, self.index) + + +@register_test +class SelectTest(OpTestCase): + """Test case for torch.select (aten.select_copy.int -> TakeNode).""" + + name = "select" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (4, 8, 16), + dim: int = 1, + index: int = 3, + ): + self.shape = shape + self.dim = dim + self.index = index + self.name = f"select_dim{dim}_idx{index}" + + @classmethod + def get_test_configs(cls) -> List["SelectTest"]: + return [ + cls(shape=(4, 8, 16), dim=0, index=2), + cls(shape=(4, 8, 16), dim=1, index=3), + cls(shape=(4, 8, 16), dim=2, index=0), + cls(shape=(4, 8, 16), dim=-1, index=5), + cls(shape=(2, 3), dim=0, index=1), + ] + + def create_model(self) -> nn.Module: + return SelectModel(self.dim, self.index) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class SliceModel(nn.Module): + """Model that slices a tensor along dimension 1.""" + + def __init__(self, start: int, stop: int): + super().__init__() + self.start = start + self.stop = stop + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x[:, self.start : self.stop] + + +class SliceDim0Model(nn.Module): + """Model that slices a tensor along dimension 0.""" + + def __init__(self, start: int, stop: int): + super().__init__() + self.start = start + self.stop = stop + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x[self.start : self.stop] + + +@register_test +class SliceTest(OpTestCase): + """Test case for tensor slicing.""" + + name = "slice" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + shape: Tuple[int, ...] = (4, 16, 8), + dim: int = 1, + start: int = 2, + stop: int = 10, + ): + self.shape = shape + self.dim = dim + self.start = start + self.stop = stop + self.name = f"slice_dim{dim}_{start}to{stop}" + + @classmethod + def get_test_configs(cls) -> List["SliceTest"]: + return [ + cls(shape=(4, 16, 8), dim=1, start=2, stop=10), + cls(shape=(8, 8), dim=0, start=1, stop=5), + cls(shape=(2, 32, 4), dim=1, start=0, stop=16), + ] + + def create_model(self) -> nn.Module: + if self.dim == 0: + return SliceDim0Model(self.start, self.stop) + return SliceModel(self.start, self.stop) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class RepeatModel(nn.Module): + """Model that repeats a tensor along specified dimensions.""" + + def __init__(self, repeats: Tuple[int, ...]): + super().__init__() + self.repeats = repeats + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.repeat(*self.repeats) + + +@register_test +class RepeatTest(OpTestCase): + """Test case for tensor.repeat().""" + + name = "repeat" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + input_shape: Tuple[int, ...] = (2, 3, 4), + repeats: Tuple[int, ...] 
= (2, 1, 3), + ): + self.input_shape = input_shape + self.repeats = repeats + repeat_str = "x".join(str(r) for r in repeats) + self.name = f"repeat_{repeat_str}" + + @classmethod + def get_test_configs(cls) -> List["RepeatTest"]: + return [ + cls(input_shape=(2, 3), repeats=(2, 3)), + cls(input_shape=(2, 3, 4), repeats=(1, 2, 1)), + cls(input_shape=(4, 4), repeats=(3, 3)), + ] + + def create_model(self) -> nn.Module: + return RepeatModel(self.repeats) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.input_shape) + return (x,) + + +class CatNModel(nn.Module): + """Model that concatenates N tensors along a dimension.""" + + def __init__(self, dim: int = 0, n: int = 3): + super().__init__() + self.dim = dim + self.n = n + + def forward(self, *tensors: torch.Tensor) -> torch.Tensor: + return torch.cat(tensors[: self.n], dim=self.dim) + + +@register_test +class CatTest(OpTestCase): + name = "cat" + rtol = 1e-5 + atol = 1e-5 + + def __init__(self, shapes: List[Tuple[int, ...]], dim: int = 0, tag: str = ""): + self.shapes = shapes + self.dim = dim + self.name = f"cat_{tag}" if tag else "cat" + + @classmethod + def get_test_configs(cls) -> List["CatTest"]: + return [ + cls(shapes=[(2, 3), (4, 3), (1, 3)], dim=0, tag="2d_dim0"), + cls(shapes=[(3, 2), (3, 4), (3, 1)], dim=1, tag="2d_dim1"), + cls(shapes=[(2, 3, 4), (5, 3, 4), (3, 3, 4)], dim=0, tag="3d_dim0"), + cls(shapes=[(3, 4), (2, 4)], dim=0, tag="two_tensors"), + cls(shapes=[(3, 2, 4), (3, 5, 4), (3, 1, 4)], dim=-2, tag="neg_dim"), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return tuple(torch.randn(s) for s in self.shapes) + + def create_model(self) -> nn.Module: + return CatNModel(dim=self.dim, n=len(self.shapes)) + + +class WhereModel(nn.Module): + """Model that conditionally selects from x or y based on condition.""" + + def forward( + self, condition: torch.Tensor, x: torch.Tensor, y: torch.Tensor + ) -> torch.Tensor: + return torch.where(condition, x, y) + + +@register_test +class WhereTest(OpTestCase): + """Test case for where op.""" + + name = "where" + rtol = 1e-5 + atol = 1e-5 + + def __init__(self, shape: Tuple[int, ...] = (2, 3, 4)): + self.shape = shape + shape_str = "x".join(str(s) for s in shape) + self.name = f"where_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["WhereTest"]: + return [ + cls(shape=(2, 3, 4)), + cls(shape=(4, 8)), + cls(shape=(2, 8, 16, 16)), + cls(shape=(1, 1, 128, 128)), + ] + + def create_model(self) -> nn.Module: + return WhereModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + condition = torch.rand(self.shape) > 0.5 + x = torch.randn(self.shape) + y = torch.randn(self.shape) + return (condition, x, y) + + +class PadModel(nn.Module): + """Model that pads a tensor with a constant value.""" + + def __init__(self, pad: Tuple[int, ...], value: float = 0.0): + super().__init__() + self.pad = pad + self.value = value + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.pad(x, self.pad, mode="constant", value=self.value) + + +@register_test +class PadTest(OpTestCase): + name = "pad" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 3, 4), + pad: Tuple[int, ...] 
= (1, 1, 1, 1), + value: float = 0.0, + ): + self.shape = shape + self.pad = pad + self.value = value + shape_str = "x".join(str(s) for s in shape) + pad_str = "_".join(str(p) for p in pad) + self.name = f"pad_{shape_str}_p{pad_str}_v{int(value)}" + + @classmethod + def get_test_configs(cls) -> List["PadTest"]: + return [ + cls(shape=(2, 3, 4), pad=(1, 1, 1, 1), value=0.0), + cls(shape=(10,), pad=(2, 3), value=0.0), + cls(shape=(4, 8), pad=(1, 2), value=0.0), + cls(shape=(2, 8, 16), pad=(1, 1, 2, 2), value=0.0), + cls(shape=(1, 3, 32, 32), pad=(1, 1, 1, 1), value=0.0), + cls(shape=(2, 3, 4), pad=(1, 1, 1, 1), value=1.0), + ] + + def create_model(self) -> nn.Module: + return PadModel(self.pad, self.value) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.shape) + return (x,) + + +class LinearModel(nn.Module): + """Simple linear layer for testing.""" + + def __init__( + self, in_features: int = 64, out_features: int = 128, bias: bool = True + ): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +@register_test +class LinearTest(OpTestCase): + """Test case for nn.Linear.""" + + name = "linear" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + in_features: int = 64, + out_features: int = 128, + batch_size: int = 2, + seq_len: int = 16, + bias: bool = True, + ): + self.in_features = in_features + self.out_features = out_features + self.batch_size = batch_size + self.seq_len = seq_len + self.bias = bias + + if not bias: + self.name = "linear_no_bias" + else: + self.name = "linear" + + @classmethod + def get_test_configs(cls) -> List["LinearTest"]: + return [ + cls(), + cls(bias=False), + ] + + def create_model(self) -> nn.Module: + return LinearModel(self.in_features, self.out_features, bias=self.bias) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.seq_len, self.in_features) + return (x,) + + +class EmbeddingModel(nn.Module): + """Simple embedding layer for testing.""" + + def __init__(self, num_embeddings: int = 1000, embedding_dim: int = 64): + super().__init__() + self.embedding = nn.Embedding(num_embeddings, embedding_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.embedding(x) + + +@register_test +class EmbeddingTest(OpTestCase): + """Test case for nn.Embedding.""" + + name = "embedding" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + num_embeddings: int = 1000, + embedding_dim: int = 64, + batch_size: int = 2, + seq_len: int = 16, + ): + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.batch_size = batch_size + self.seq_len = seq_len + self.name = "embedding" + + @classmethod + def get_test_configs(cls) -> List["EmbeddingTest"]: + return [ + cls(), + cls(num_embeddings=512, embedding_dim=128), + ] + + def create_model(self) -> nn.Module: + return EmbeddingModel(self.num_embeddings, self.embedding_dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randint(0, self.num_embeddings, (self.batch_size, self.seq_len)) + return (x,) + + +class MaxPool1dModel(nn.Module): + def __init__(self, kernel_size, stride=None, padding=0): + super().__init__() + self.pool = nn.MaxPool1d(kernel_size, stride=stride, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.pool(x) + + +@register_test +class MaxPool1dTest(OpTestCase): + name = "max_pool1d" + rtol = 1e-5 + atol = 1e-5 + + def 
__init__( + self, + kernel_size=2, + stride=2, + padding=0, + in_channels: int = 8, + seq_len: int = 32, + batch_size: int = 1, + tag: str = "", + ): + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.in_channels = in_channels + self.seq_len = seq_len + self.batch_size = batch_size + + if tag: + self.name = f"max_pool1d_{tag}" + else: + parts = ["max_pool1d", f"k{kernel_size}", f"s{stride}"] + if padding != 0: + parts.append(f"p{padding}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["MaxPool1dTest"]: + return [ + # Fast path: kernel == stride + cls(kernel_size=2, stride=2), + # General path: overlapping windows with padding + cls(kernel_size=3, stride=2, padding=1), + # Fast path: larger kernel + cls(kernel_size=4, stride=4, seq_len=64), + # stride=None (defaults to kernel_size) + cls(kernel_size=4, stride=None, seq_len=64, tag="stride_none"), + # Single channel + cls(kernel_size=2, stride=2, in_channels=1, tag="c1"), + # Global pooling: kernel == spatial size + cls(kernel_size=32, stride=32, tag="global"), + # Batch > 1 + cls(kernel_size=2, stride=2, batch_size=4, tag="batch4"), + # Stride > kernel (gaps between windows) + cls(kernel_size=2, stride=3, seq_len=32, tag="stride_gt_kernel"), + ] + + def create_model(self) -> nn.Module: + return MaxPool1dModel(self.kernel_size, self.stride, self.padding) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.batch_size, self.in_channels, self.seq_len),) + + +class MaxPool2dModel(nn.Module): + def __init__(self, kernel_size, stride=None, padding=0): + super().__init__() + self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.pool(x) + + +@register_test +class MaxPool2dTest(OpTestCase): + name = "max_pool2d" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + kernel_size=2, + stride=2, + padding=0, + in_channels: int = 16, + input_size: Tuple[int, int] = (32, 32), + batch_size: int = 1, + tag: str = "", + ): + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.in_channels = in_channels + self.input_size = input_size + self.batch_size = batch_size + + if tag: + self.name = f"max_pool2d_{tag}" + else: + parts = ["max_pool2d", f"k{kernel_size}", f"s{stride}"] + if padding != 0: + parts.append(f"p{padding}") + parts.append(f"{input_size[0]}x{input_size[1]}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["MaxPool2dTest"]: + return [ + # Fast path: kernel == stride, evenly divisible + cls(kernel_size=2, stride=2, input_size=(32, 32)), + # General path: overlapping windows + cls(kernel_size=3, stride=2, padding=1, input_size=(32, 32)), + # Fast path: 4x4 pooling + cls(kernel_size=4, stride=4, input_size=(64, 64)), + # General path: stride != kernel, no padding + cls(kernel_size=3, stride=1, input_size=(16, 16)), + # Batch > 1 + cls(kernel_size=2, stride=2, input_size=(32, 32), batch_size=4), + # stride=None (defaults to kernel_size) + cls(kernel_size=2, stride=None, input_size=(32, 32), tag="stride_none"), + # Single channel + cls(kernel_size=2, stride=2, in_channels=1, input_size=(16, 16), tag="c1"), + # Global pooling + cls(kernel_size=8, stride=8, input_size=(8, 8), tag="global"), + # Non-square kernel/stride + cls( + kernel_size=(2, 3), + stride=(2, 3), + input_size=(16, 18), + tag="nonsquare_fast", + ), + cls( + kernel_size=(3, 2), + stride=(1, 2), + padding=(1, 0), + input_size=(16, 
16), + tag="nonsquare_general", + ), + # Stride > kernel (gaps between windows) + cls(kernel_size=2, stride=3, input_size=(16, 16), tag="stride_gt_kernel"), + # Non-square input with square kernel + cls(kernel_size=2, stride=2, input_size=(16, 32), tag="nonsquare_input"), + ] + + def create_model(self) -> nn.Module: + return MaxPool2dModel(self.kernel_size, self.stride, self.padding) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return ( + torch.randn( + self.batch_size, + self.in_channels, + self.input_size[0], + self.input_size[1], + ), + ) + + +class MaxPool3dModel(nn.Module): + def __init__(self, kernel_size, stride=None, padding=0): + super().__init__() + self.pool = nn.MaxPool3d(kernel_size, stride=stride, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.pool(x) + + +@register_test +class MaxPool3dTest(OpTestCase): + name = "max_pool3d" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + kernel_size=2, + stride=2, + padding=0, + in_channels: int = 8, + input_size: Tuple[int, int, int] = (8, 16, 16), + batch_size: int = 1, + tag: str = "", + ): + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.in_channels = in_channels + self.input_size = input_size + self.batch_size = batch_size + + if tag: + self.name = f"max_pool3d_{tag}" + else: + parts = ["max_pool3d", f"k{kernel_size}", f"s{stride}"] + if padding != 0: + parts.append(f"p{padding}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["MaxPool3dTest"]: + return [ + # Fast path: kernel == stride + cls(kernel_size=2, stride=2), + # General path: overlapping windows with padding + cls(kernel_size=3, stride=2, padding=1), + # Batch > 1 + cls(kernel_size=2, stride=2, batch_size=2, tag="batch2"), + # Single channel + cls(kernel_size=2, stride=2, in_channels=1, tag="c1"), + # Non-cubic kernel/stride + cls( + kernel_size=(2, 2, 4), + stride=(2, 2, 4), + input_size=(8, 16, 16), + tag="noncubic_fast", + ), + # Stride > kernel + cls( + kernel_size=2, stride=3, input_size=(8, 16, 16), tag="stride_gt_kernel" + ), + # stride=None (defaults to kernel_size) + cls(kernel_size=2, stride=None, tag="stride_none"), + # Global pooling: kernel == spatial + cls( + kernel_size=(8, 16, 16), + stride=(8, 16, 16), + input_size=(8, 16, 16), + tag="global", + ), + # Non-cubic general path (stride != kernel) + cls( + kernel_size=(3, 2, 2), + stride=(1, 2, 2), + padding=(1, 0, 0), + input_size=(8, 16, 16), + tag="noncubic_general", + ), + # Non-cubic input with cubic kernel + cls(kernel_size=2, stride=2, input_size=(4, 8, 16), tag="nonsquare_input"), + ] + + def create_model(self) -> nn.Module: + return MaxPool3dModel(self.kernel_size, self.stride, self.padding) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return ( + torch.randn( + self.batch_size, + self.in_channels, + *self.input_size, + ), + ) + + +class AvgPool1dModel(nn.Module): + def __init__(self, kernel_size, stride=None, padding=0): + super().__init__() + self.pool = nn.AvgPool1d(kernel_size, stride=stride, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.pool(x) + + +@register_test +class AvgPool1dTest(OpTestCase): + name = "avg_pool1d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + kernel_size=2, + stride=2, + padding=0, + in_channels: int = 8, + seq_len: int = 32, + batch_size: int = 1, + tag: str = "", + ): + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.in_channels = 
in_channels + self.seq_len = seq_len + self.batch_size = batch_size + + if tag: + self.name = f"avg_pool1d_{tag}" + else: + parts = ["avg_pool1d", f"k{kernel_size}", f"s{stride}"] + if padding != 0: + parts.append(f"p{padding}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["AvgPool1dTest"]: + return [ + # Fast path: kernel == stride + cls(kernel_size=2, stride=2), + # General path: overlapping windows with padding + cls(kernel_size=3, stride=2, padding=1), + # Fast path: larger kernel + cls(kernel_size=4, stride=4, seq_len=64), + # stride=None (defaults to kernel_size) + cls(kernel_size=4, stride=None, seq_len=64, tag="stride_none"), + # Single channel + cls(kernel_size=2, stride=2, in_channels=1, tag="c1"), + # Global pooling + cls(kernel_size=32, stride=32, tag="global"), + # Batch > 1 + cls(kernel_size=2, stride=2, batch_size=4, tag="batch4"), + # Stride > kernel (gaps between windows) + cls(kernel_size=2, stride=3, seq_len=32, tag="stride_gt_kernel"), + ] + + def create_model(self) -> nn.Module: + return AvgPool1dModel(self.kernel_size, self.stride, self.padding) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.batch_size, self.in_channels, self.seq_len),) + + +class AvgPool2dModel(nn.Module): + def __init__(self, kernel_size, stride=None, padding=0): + super().__init__() + self.pool = nn.AvgPool2d(kernel_size, stride=stride, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.pool(x) + + +@register_test +class AvgPool2dTest(OpTestCase): + name = "avg_pool2d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + kernel_size=2, + stride=2, + padding=0, + in_channels: int = 16, + input_size: Tuple[int, int] = (32, 32), + batch_size: int = 1, + tag: str = "", + ): + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.in_channels = in_channels + self.input_size = input_size + self.batch_size = batch_size + + if tag: + self.name = f"avg_pool2d_{tag}" + else: + parts = ["avg_pool2d", f"k{kernel_size}", f"s{stride}"] + if padding != 0: + parts.append(f"p{padding}") + parts.append(f"{input_size[0]}x{input_size[1]}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["AvgPool2dTest"]: + return [ + # Fast path: kernel == stride + cls(kernel_size=2, stride=2, input_size=(32, 32)), + # General path: overlapping windows + cls(kernel_size=3, stride=2, padding=1, input_size=(32, 32)), + # Fast path: 4x4 pooling + cls(kernel_size=4, stride=4, input_size=(64, 64)), + # General path: stride != kernel + cls(kernel_size=3, stride=1, input_size=(16, 16)), + # Batch > 1 + cls(kernel_size=2, stride=2, input_size=(32, 32), batch_size=4), + # stride=None + cls(kernel_size=2, stride=None, input_size=(32, 32), tag="stride_none"), + # Single channel + cls(kernel_size=2, stride=2, in_channels=1, input_size=(16, 16), tag="c1"), + # Global pooling + cls(kernel_size=8, stride=8, input_size=(8, 8), tag="global"), + # Non-square kernel/stride + cls( + kernel_size=(2, 3), + stride=(2, 3), + input_size=(16, 18), + tag="nonsquare_fast", + ), + cls( + kernel_size=(3, 2), + stride=(1, 2), + padding=(1, 0), + input_size=(16, 16), + tag="nonsquare_general", + ), + # Stride > kernel (gaps between windows) + cls(kernel_size=2, stride=3, input_size=(16, 16), tag="stride_gt_kernel"), + # Non-square input with square kernel + cls(kernel_size=2, stride=2, input_size=(16, 32), tag="nonsquare_input"), + ] + + def create_model(self) -> nn.Module: + return 
AvgPool2dModel(self.kernel_size, self.stride, self.padding) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return ( + torch.randn( + self.batch_size, + self.in_channels, + self.input_size[0], + self.input_size[1], + ), + ) + + +class AvgPool3dModel(nn.Module): + def __init__(self, kernel_size, stride=None, padding=0): + super().__init__() + self.pool = nn.AvgPool3d(kernel_size, stride=stride, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.pool(x) + + +@register_test +class AvgPool3dTest(OpTestCase): + name = "avg_pool3d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + kernel_size=2, + stride=2, + padding=0, + in_channels: int = 8, + input_size: Tuple[int, int, int] = (8, 16, 16), + batch_size: int = 1, + tag: str = "", + ): + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.in_channels = in_channels + self.input_size = input_size + self.batch_size = batch_size + + if tag: + self.name = f"avg_pool3d_{tag}" + else: + parts = ["avg_pool3d", f"k{kernel_size}", f"s{stride}"] + if padding != 0: + parts.append(f"p{padding}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["AvgPool3dTest"]: + return [ + # Fast path: kernel == stride + cls(kernel_size=2, stride=2), + # General path: overlapping windows with padding + cls(kernel_size=3, stride=2, padding=1), + # Batch > 1 + cls(kernel_size=2, stride=2, batch_size=2, tag="batch2"), + # Single channel + cls(kernel_size=2, stride=2, in_channels=1, tag="c1"), + # Non-cubic kernel/stride + cls( + kernel_size=(2, 2, 4), + stride=(2, 2, 4), + input_size=(8, 16, 16), + tag="noncubic_fast", + ), + # Stride > kernel + cls( + kernel_size=2, stride=3, input_size=(8, 16, 16), tag="stride_gt_kernel" + ), + # stride=None (defaults to kernel_size) + cls(kernel_size=2, stride=None, tag="stride_none"), + # Global pooling: kernel == spatial + cls( + kernel_size=(8, 16, 16), + stride=(8, 16, 16), + input_size=(8, 16, 16), + tag="global", + ), + # Non-cubic general path (stride != kernel) + cls( + kernel_size=(3, 2, 2), + stride=(1, 2, 2), + padding=(1, 0, 0), + input_size=(8, 16, 16), + tag="noncubic_general", + ), + # Non-cubic input with cubic kernel + cls(kernel_size=2, stride=2, input_size=(4, 8, 16), tag="nonsquare_input"), + ] + + def create_model(self) -> nn.Module: + return AvgPool3dModel(self.kernel_size, self.stride, self.padding) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return ( + torch.randn( + self.batch_size, + self.in_channels, + *self.input_size, + ), + ) + + +class RMSNormModel(nn.Module): + """Model using torch.nn.functional.rms_norm.""" + + def __init__(self, hidden_dim: int = 64, eps: float = 1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_dim)) + self.hidden_dim = hidden_dim + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.rms_norm( + x, (self.hidden_dim,), self.weight, self.eps + ) + + +@register_test +class RMSNormTest(OpTestCase): + """Test case for torch.nn.functional.rms_norm (aten.rms_norm).""" + + name = "aten_rms_norm" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + hidden_dim: int = 64, + batch_size: int = 2, + seq_len: int = 16, + eps: float = 1e-5, + ): + self.hidden_dim = hidden_dim + self.batch_size = batch_size + self.seq_len = seq_len + self.eps = eps + self.name = "aten_rms_norm" + + @classmethod + def get_test_configs(cls) -> List["RMSNormTest"]: + return [ + cls(), + cls(hidden_dim=128, eps=1e-6), 
+ ] + + def create_model(self) -> nn.Module: + return RMSNormModel(self.hidden_dim, self.eps) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.seq_len, self.hidden_dim) + return (x,) + + +class RopeModel(nn.Module): + """Model that applies RoPE with dynamic position.""" + + def __init__( + self, + dims: int = 64, + traditional: bool = False, + base: float = 500000.0, + scale: float = 1.0, + ): + super().__init__() + self.dims = dims + self.traditional = traditional + self.base = base + self.scale = scale + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + pos_tensor: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + pos = pos_tensor.item() + q_rot = torch.ops.mlx.rope( + q, self.dims, pos, self.traditional, self.base, self.scale, None + ) + k_rot = torch.ops.mlx.rope( + k, self.dims, pos, self.traditional, self.base, self.scale, None + ) + return q_rot, k_rot + + +@register_test +class RopeTest(OpTestCase): + """Test case for RoPE.""" + + name = "rope" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + batch_size: int = 1, + num_heads: int = 8, + seq_len: int = 16, + head_dim: int = 64, + dims: Optional[int] = None, + pos: int = 0, + traditional: bool = False, + base: float = 500000.0, + scale: float = 1.0, + ): + self.batch_size = batch_size + self.num_heads = num_heads + self.seq_len = seq_len + self.head_dim = head_dim + self.dims = dims if dims is not None else head_dim + self.pos = pos + self.traditional = traditional + self.base = base + self.scale = scale + self.name = "rope" + + @classmethod + def get_test_configs(cls) -> List["RopeTest"]: + configs = [ + cls(), + cls(traditional=True), + cls(head_dim=64, dims=32), + cls(head_dim=64, dims=32, traditional=True), + ] + for cfg in configs: + parts = ["rope"] + if cfg.traditional: + parts.append("traditional") + if cfg.dims != cfg.head_dim: + parts.append(f"dims{cfg.dims}") + cfg.name = "_".join(parts) + return configs + + def create_model(self) -> nn.Module: + return RopeModel( + dims=self.dims, + traditional=self.traditional, + base=self.base, + scale=self.scale, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + q = torch.randn(self.batch_size, self.num_heads, self.seq_len, self.head_dim) + k = torch.randn(self.batch_size, self.num_heads, self.seq_len, self.head_dim) + pos_tensor = torch.tensor(self.pos, dtype=torch.int64) + return (q, k, pos_tensor) + + +from executorch.backends.mlx.llm.cache import KVCache + + +class KVCacheModel(nn.Module): + """ + Test model wrapping KVCache from cache.py. + + This tests the ExecutorTorch llama KVCache-compatible interface that uses + the mlx::kv_cache_update op internally. + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool = True, + ): + super().__init__() + self.cache = KVCache( + max_batch_size=max_batch_size, + max_context_length=max_context_length, + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=enable_dynamic_shape, + ) + + def forward( + self, + input_pos: torch.Tensor, # [S] position tensor + k_val: torch.Tensor, # [B, H, S, D] + v_val: torch.Tensor, # [B, H, S, D] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Update cache and return full cache tensors.""" + k_cache, v_cache = self.cache.update(input_pos, k_val, v_val) + return k_cache, v_cache + + +@register_test +class KVCacheTest(OpTestCase): + """ + Test case for MLX KVCache with ExecutorTorch llama KVCache interface. 
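+
+    Illustrative call, using the default config below: k_val and v_val are
+    [1, 4, S, 64] with S dynamic up to max_context_length=128, and
+    cache.update(input_pos, k_val, v_val) returns the full [1, 4, 128, 64]
+    k/v cache tensors.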
+ + This verifies that KVCache: + 1. Accepts the ET llama KVCache update interface + 2. Correctly delegates to mlx::kv_cache_update custom op + 3. Produces correct outputs for both export and test inputs + """ + + name = "kv_cache" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + max_batch_size: int = 1, + n_heads: int = 4, + head_dim: int = 64, + max_context_length: int = 128, + seq_step: int = 8, + enable_dynamic_shape: bool = True, + ): + self.max_batch_size = max_batch_size + self.n_heads = n_heads + self.head_dim = head_dim + self.max_context_length = max_context_length + self.seq_step = seq_step + self.enable_dynamic_shape = enable_dynamic_shape + + @classmethod + def get_test_configs(cls) -> List["KVCacheTest"]: + return [ + cls(), # default config + cls(n_heads=8, head_dim=32), # different head config + cls(enable_dynamic_shape=False), # static shape mode + ] + + def create_model(self) -> nn.Module: + return KVCacheModel( + max_batch_size=self.max_batch_size, + max_context_length=self.max_context_length, + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=self.enable_dynamic_shape, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + input_pos = torch.tensor([0], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, self.n_heads, self.seq_step, self.head_dim + ) + v_val = torch.randn( + self.max_batch_size, self.n_heads, self.seq_step, self.head_dim + ) + return (input_pos, k_val, v_val) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + # Note: KVCache.update() takes (input_pos, k_val, v_val) - position first + # Test with different position and different seq_step + test_seq_step = self.seq_step + 4 + input_pos = torch.tensor([16], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, self.n_heads, test_seq_step, self.head_dim + ) + v_val = torch.randn( + self.max_batch_size, self.n_heads, test_seq_step, self.head_dim + ) + return (input_pos, k_val, v_val) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + # seq_step (dim 2) is dynamic for k_val and v_val + seq_dim = Dim("seq_step", min=1, max=self.max_context_length) + return { + "input_pos": None, + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + } + + +class KVCacheIntModel(nn.Module): + """ + Test model that passes int/SymInt (not tensor) to KVCache.update(). + + This tests the "int route" where the caller extracts the start position + from the tensor before calling update, which is the preferred pattern + in multi-layer models to avoid redundant SymInt extraction. + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool = True, + ): + super().__init__() + self.cache = KVCache( + max_batch_size=max_batch_size, + max_context_length=max_context_length, + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=enable_dynamic_shape, + ) + + def forward( + self, + input_pos: torch.Tensor, # [S] position tensor + k_val: torch.Tensor, # [B, H, S, D] + v_val: torch.Tensor, # [B, H, S, D] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Extract int from tensor, then pass to update (the int route).""" + start_pos = input_pos[0].item() + return self.cache.update(start_pos, k_val, v_val) + + +@register_test +class KVCacheIntTest(OpTestCase): + """ + Test case for MLX KVCache with int/SymInt input_pos. + + This verifies the "int route" where the caller extracts the start position + before calling update, matching the recommended pattern for multi-layer models. 
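+ 
+     A minimal usage sketch of the two call styles (illustrative only; the
+     constructor arguments and the update() signature are taken from the wrapper
+     models above, not from a separate API reference):
+ 
+         cache = KVCache(max_batch_size=1, max_context_length=128,
+                         n_heads=4, head_dim=64)
+         k = torch.randn(1, 4, 8, 64)
+         v = torch.randn(1, 4, 8, 64)
+         # tensor route: pass the position tensor directly
+         cache.update(torch.tensor([0]), k, v)
+         # int route: extract the SymInt once, then reuse it across layers
+         start_pos = torch.tensor([8])[0].item()
+         cache.update(start_pos, k, v)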
+ """ + + name = "kv_cache_int" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + max_batch_size: int = 1, + n_heads: int = 4, + head_dim: int = 64, + max_context_length: int = 128, + seq_step: int = 8, + enable_dynamic_shape: bool = True, + ): + self.max_batch_size = max_batch_size + self.n_heads = n_heads + self.head_dim = head_dim + self.max_context_length = max_context_length + self.seq_step = seq_step + self.enable_dynamic_shape = enable_dynamic_shape + + @classmethod + def get_test_configs(cls) -> List["KVCacheIntTest"]: + return [ + cls(), # default config + cls(n_heads=8, head_dim=32), # different head config + ] + + def create_model(self) -> nn.Module: + return KVCacheIntModel( + max_batch_size=self.max_batch_size, + max_context_length=self.max_context_length, + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=self.enable_dynamic_shape, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + input_pos = torch.tensor([0], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, self.n_heads, self.seq_step, self.head_dim + ) + v_val = torch.randn( + self.max_batch_size, self.n_heads, self.seq_step, self.head_dim + ) + return (input_pos, k_val, v_val) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + test_seq_step = self.seq_step + 4 + input_pos = torch.tensor([16], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, self.n_heads, test_seq_step, self.head_dim + ) + v_val = torch.randn( + self.max_batch_size, self.n_heads, test_seq_step, self.head_dim + ) + return (input_pos, k_val, v_val) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + seq_dim = Dim("seq_step", min=1, max=self.max_context_length) + return { + "input_pos": None, + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + } + + +class KVCacheSliceModel(nn.Module): + """ + Test model that updates KVCache then slices the result. + + This tests that operations on the returned cache work correctly, + matching the pattern used in attention implementations. + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + enable_dynamic_shape: bool = True, + ): + super().__init__() + self.max_context_length = max_context_length + self.cache = KVCache( + max_batch_size=max_batch_size, + max_context_length=max_context_length, + n_heads=n_heads, + head_dim=head_dim, + enable_dynamic_shape=enable_dynamic_shape, + ) + + def forward( + self, + input_pos: torch.Tensor, # [S] position tensor + k_val: torch.Tensor, # [B, H, S, D] + v_val: torch.Tensor, # [B, H, S, D] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Update cache and return sliced result.""" + start_pos = input_pos[0].item() + seq_len = k_val.size(2) + end_pos = start_pos + seq_len + + torch._check(start_pos >= 0) + torch._check(end_pos <= self.max_context_length) + torch._check(end_pos >= 0) + + k_cache, v_cache = self.cache.update(input_pos, k_val, v_val) + + k_valid = k_cache[:, :, :end_pos, :] + v_valid = v_cache[:, :, :end_pos, :] + return k_valid, v_valid + + +@register_test +class KVCacheSliceTest(OpTestCase): + """ + Test case for MLX KVCache update followed by slicing. + + This verifies that: + 1. The ET llama KVCache-compatible interface works correctly + 2. 
Subsequent slice operations on the returned cache work correctly + """ + + name = "kv_cache_slice" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + max_batch_size: int = 1, + n_heads: int = 4, + head_dim: int = 64, + max_context_length: int = 128, + seq_step: int = 8, + enable_dynamic_shape: bool = True, + ): + self.max_batch_size = max_batch_size + self.n_heads = n_heads + self.head_dim = head_dim + self.max_context_length = max_context_length + self.seq_step = seq_step + self.enable_dynamic_shape = enable_dynamic_shape + + @classmethod + def get_test_configs(cls) -> List["KVCacheSliceTest"]: + return [ + cls(), + cls(n_heads=8, head_dim=32), + ] + + def create_model(self) -> nn.Module: + return KVCacheSliceModel( + max_batch_size=self.max_batch_size, + max_context_length=self.max_context_length, + n_heads=self.n_heads, + head_dim=self.head_dim, + enable_dynamic_shape=self.enable_dynamic_shape, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + input_pos = torch.tensor([0], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, self.n_heads, self.seq_step, self.head_dim + ) + v_val = torch.randn( + self.max_batch_size, self.n_heads, self.seq_step, self.head_dim + ) + return (input_pos, k_val, v_val) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + test_seq_step = self.seq_step + 4 + input_pos = torch.tensor([16], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, self.n_heads, test_seq_step, self.head_dim + ) + v_val = torch.randn( + self.max_batch_size, self.n_heads, test_seq_step, self.head_dim + ) + return (input_pos, k_val, v_val) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + seq_dim = Dim("seq_step", min=1, max=self.max_context_length) + return { + "input_pos": None, + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + } + + +class RingBufferKVCacheModel(nn.Module): + """ + Test model wrapping RingBufferKVCache from cache.py. + + Updates the ring buffer cache and returns the full cache contents. + Uses kv_cache_update with ring_size > 0, which should emit + ModIntNode + SubtractIntNode + two SliceUpdateNodes. + """ + + def __init__( + self, + max_batch_size: int, + max_context_length: int, + n_heads: int, + head_dim: int, + ): + super().__init__() + from executorch.backends.mlx.llm.cache import RingBufferKVCache + + self.cache = RingBufferKVCache( + max_batch_size=max_batch_size, + max_context_length=max_context_length, + n_heads=n_heads, + head_dim=head_dim, + ) + + def forward( + self, + input_pos: torch.Tensor, + k_val: torch.Tensor, + v_val: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + start_pos = input_pos[0].item() + torch._check(start_pos >= 0) + + k_cache, v_cache = self.cache.update(start_pos, k_val, v_val) + return k_cache, v_cache + + +@register_test +class RingBufferKVCacheTest(OpTestCase): + """ + Test case for RingBufferKVCache with ring_size > 0. + + Verifies that kv_cache_update with ring_size emits the ring buffer + SliceUpdate pattern (ModInt + SubtractInt + 2x Slice + 2x SliceUpdate) + and produces correct results. 
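+ 
+     Illustrative wrap-around arithmetic (a sketch inferred from the node pattern
+     above, not a spec of the lowering): with max_context_length=64, a 4-token
+     update starting at position 62 targets slots
+ 
+         [62 % 64, 63 % 64, 64 % 64, 65 % 64] == [62, 63, 0, 1]
+ 
+     so the single logical write is split into two contiguous SliceUpdates, one at
+     the tail of the buffer and one wrapping back to the head.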
+ """ + + name = "ring_buffer_kv_cache" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + max_batch_size: int = 1, + n_heads: int = 4, + head_dim: int = 64, + max_context_length: int = 64, + seq_step: int = 4, + ): + self.max_batch_size = max_batch_size + self.n_heads = n_heads + self.head_dim = head_dim + self.max_context_length = max_context_length + self.seq_step = seq_step + + @classmethod + def get_test_configs(cls) -> List["RingBufferKVCacheTest"]: + return [ + cls(), + cls(n_heads=8, head_dim=32, max_context_length=32, seq_step=2), + ] + + def create_model(self) -> nn.Module: + return RingBufferKVCacheModel( + max_batch_size=self.max_batch_size, + max_context_length=self.max_context_length, + n_heads=self.n_heads, + head_dim=self.head_dim, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + input_pos = torch.tensor([0], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, self.n_heads, self.seq_step, self.head_dim + ) + v_val = torch.randn( + self.max_batch_size, self.n_heads, self.seq_step, self.head_dim + ) + return (input_pos, k_val, v_val) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + test_seq_step = self.seq_step + 2 + input_pos = torch.tensor([8], dtype=torch.int64) + k_val = torch.randn( + self.max_batch_size, self.n_heads, test_seq_step, self.head_dim + ) + v_val = torch.randn( + self.max_batch_size, self.n_heads, test_seq_step, self.head_dim + ) + return (input_pos, k_val, v_val) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + seq_dim = Dim("seq_step", min=1, max=self.max_context_length) + return { + "input_pos": None, + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + } + + def get_expected_node_counts(self) -> Optional[Dict[str, int]]: + return { + "ItemIntNode": 1, + "ModIntNode": 2, + "SymSizeNode": 4, + "SubtractIntNode": 4, + "AddIntNode": 2, + "SliceNode": 4, + "SliceUpdateNode": 4, + "IdCopyNode": 2, + } + + +class MockModelConfig: + """ + Mock HuggingFace model config for testing HFStaticCache. + + This simulates the config structure expected by HFStaticCache. + """ + + def __init__( + self, + num_hidden_layers: int = 2, + num_attention_heads: int = 4, + num_key_value_heads: int | None = None, + hidden_size: int = 256, + head_dim: int | None = None, + max_position_embeddings: int = 128, + ): + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads or num_attention_heads + self.hidden_size = hidden_size + self.head_dim = head_dim or (hidden_size // num_attention_heads) + self.max_position_embeddings = max_position_embeddings + + def get_text_config(self, **kwargs): + """Return self for HF StaticCache compatibility.""" + return self + + +class HFStaticCacheModel(nn.Module): + """ + Test model wrapping HFStaticCache from cache.py. + + This tests the HuggingFace-compatible StaticCache interface. + """ + + def __init__( + self, + config: MockModelConfig, + layer_idx: int = 0, + ): + super().__init__() + from executorch.backends.mlx.llm.cache import HFStaticCache + + self.cache = HFStaticCache(config) + self.layer_idx = layer_idx + + # Register buffers explicitly so torch.export treats them as mutable + # buffers rather than constants. This mirrors what replace_hf_cache_with_mlx() does. 
+ for i, layer_cache in enumerate(self.cache.kv_cache): + self.register_buffer( + f"key_cache_{i}", layer_cache.k_cache, persistent=False + ) + self.register_buffer( + f"value_cache_{i}", layer_cache.v_cache, persistent=False + ) + + def forward( + self, + k_val: torch.Tensor, # [B, H, S, D] + v_val: torch.Tensor, # [B, H, S, D] + cache_position: torch.Tensor, # 1D tensor with start position + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Update cache using HuggingFace-style interface.""" + return self.cache.update( + k_val, + v_val, + self.layer_idx, + cache_kwargs={"cache_position": cache_position}, + ) + + +@register_test +class HFStaticCacheTest(OpTestCase): + """Test case for HFStaticCache with HuggingFace-compatible interface.""" + + name = "hf_static_cache" + rtol = 1e-5 + atol = 1e-5 + expected_node_counts = { + "ItemIntNode": 1, + "SymSizeNode": 2, + "AddIntNode": 2, + "SliceUpdateNode": 2, + "IdCopyNode": 2, + } + + def __init__( + self, + num_heads: int = 4, + head_dim: int = 64, + num_layers: int = 2, + max_seq_len: int = 128, + seq_step: int = 8, + layer_idx: int = 0, + ): + self.num_heads = num_heads + self.head_dim = head_dim + self.num_layers = num_layers + self.max_seq_len = max_seq_len + self.seq_step = seq_step + self.layer_idx = layer_idx + + @classmethod + def get_test_configs(cls) -> List["HFStaticCacheTest"]: + return [ + cls(), # default config, layer 0 + cls(num_heads=8, head_dim=32, layer_idx=1), # different config, layer 1 + ] + + def create_model(self) -> nn.Module: + config = MockModelConfig( + num_hidden_layers=self.num_layers, + num_attention_heads=self.num_heads, + hidden_size=self.num_heads * self.head_dim, + head_dim=self.head_dim, + max_position_embeddings=self.max_seq_len, + ) + return HFStaticCacheModel(config, layer_idx=self.layer_idx) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # BHSD layout [B, H, S, D] + k_val = torch.randn(1, self.num_heads, self.seq_step, self.head_dim) + v_val = torch.randn(1, self.num_heads, self.seq_step, self.head_dim) + cache_position = torch.tensor([0], dtype=torch.int64) # 1D tensor + return (k_val, v_val, cache_position) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + # Test with different position and different seq_step + test_seq_step = self.seq_step + 4 # Different from export seq_step + k_val = torch.randn(1, self.num_heads, test_seq_step, self.head_dim) + v_val = torch.randn(1, self.num_heads, test_seq_step, self.head_dim) + cache_position = torch.tensor([16], dtype=torch.int64) # 1D tensor + return (k_val, v_val, cache_position) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + # seq_step (dim 2) is dynamic for k_val and v_val + seq_dim = Dim("seq_step", min=1, max=self.max_seq_len) + return { + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + "cache_position": None, + } + + +class HFStaticCacheSliceModel(nn.Module): + """ + Test model that updates HFStaticCache then slices the result. + + This tests that operations on the returned cache work correctly + with the HuggingFace-compatible interface. + """ + + def __init__( + self, + config: MockModelConfig, + layer_idx: int = 0, + ): + super().__init__() + from executorch.backends.mlx.llm.cache import HFStaticCache + + self.max_seq_len = config.max_position_embeddings + self.cache = HFStaticCache(config) + self.layer_idx = layer_idx + + # Register buffers explicitly so torch.export treats them as mutable + # buffers rather than constants. This mirrors what replace_hf_cache_with_mlx() does. 
+ for i, layer_cache in enumerate(self.cache.kv_cache): + self.register_buffer( + f"key_cache_{i}", layer_cache.k_cache, persistent=False + ) + self.register_buffer( + f"value_cache_{i}", layer_cache.v_cache, persistent=False + ) + + def forward( + self, + k_val: torch.Tensor, # [B, H, S, D] + v_val: torch.Tensor, # [B, H, S, D] + cache_position: torch.Tensor, # 1D tensor with start position + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Update cache and return sliced cache (only the valid portion).""" + pos = cache_position[0].item() + seq_len = k_val.size(2) + end_pos = pos + seq_len + + # Add constraints for dynamic shapes + torch._check(pos >= 0) + torch._check(end_pos <= self.max_seq_len) + torch._check(end_pos >= 0) + + # Update cache using HuggingFace-style interface + k_cache, v_cache = self.cache.update( + k_val, + v_val, + self.layer_idx, + cache_kwargs={"cache_position": cache_position}, + ) + + # Slice to get only the valid portion [0:end_pos] + k_valid = k_cache[:, :, :end_pos, :] + v_valid = v_cache[:, :, :end_pos, :] + + return k_valid, v_valid + + +@register_test +class HFStaticCacheSliceTest(OpTestCase): + """ + Test case for HFStaticCache update followed by slicing. + + This verifies that: + 1. The HuggingFace-compatible interface works correctly + 2. Subsequent slice operations on the returned cache work correctly + """ + + name = "hf_static_cache_slice" + rtol = 1e-5 + atol = 1e-5 + expected_node_counts = { + "ItemIntNode": 2, + "SymSizeNode": 3, + "AddIntNode": 3, + "SliceUpdateNode": 2, + "IdCopyNode": 2, + "SliceNode": 2, + } + + def __init__( + self, + num_heads: int = 4, + head_dim: int = 64, + num_layers: int = 2, + max_seq_len: int = 128, + seq_step: int = 8, + layer_idx: int = 0, + ): + self.num_heads = num_heads + self.head_dim = head_dim + self.num_layers = num_layers + self.max_seq_len = max_seq_len + self.seq_step = seq_step + self.layer_idx = layer_idx + + @classmethod + def get_test_configs(cls) -> List["HFStaticCacheSliceTest"]: + return [ + cls(), # default config, layer 0 + cls(num_heads=8, head_dim=32, layer_idx=1), # different config, layer 1 + ] + + def create_model(self) -> nn.Module: + config = MockModelConfig( + num_hidden_layers=self.num_layers, + num_attention_heads=self.num_heads, + hidden_size=self.num_heads * self.head_dim, + head_dim=self.head_dim, + max_position_embeddings=self.max_seq_len, + ) + return HFStaticCacheSliceModel(config, layer_idx=self.layer_idx) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # BHSD layout [B, H, S, D] + k_val = torch.randn(1, self.num_heads, self.seq_step, self.head_dim) + v_val = torch.randn(1, self.num_heads, self.seq_step, self.head_dim) + cache_position = torch.tensor([0], dtype=torch.int64) # 1D tensor + return (k_val, v_val, cache_position) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + # Test with different position and different seq_step + test_seq_step = self.seq_step + 4 # Different from export seq_step + k_val = torch.randn(1, self.num_heads, test_seq_step, self.head_dim) + v_val = torch.randn(1, self.num_heads, test_seq_step, self.head_dim) + cache_position = torch.tensor([16], dtype=torch.int64) # 1D tensor + return (k_val, v_val, cache_position) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + # seq_step (dim 2) is dynamic for k_val and v_val + seq_dim = Dim("seq_step", min=1, max=self.max_seq_len) + return { + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + "cache_position": None, + } + + +class DynamicArangeModel(nn.Module): + """Model that uses 
arange with dynamic start/stop from tensor.item().""" + + def __init__(self, length: int, vocab_size: int = 32): + super().__init__() + self.length = length + self.embed = nn.Embedding(vocab_size, 16) + + def forward(self, pos: torch.Tensor) -> torch.Tensor: + torch._check(pos.numel() == 1) + pos_int = pos.item() + torch._check(pos_int >= 0) + positions = torch.arange( + pos_int, pos_int + self.length, device=pos.device, dtype=torch.long + ) + return self.embed(positions) + + +@register_test +class DynamicArangeTest(OpTestCase): + """Test case for torch.arange() with dynamic start/stop.""" + + name = "arange_dynamic" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + position: int = 4, + length: int = 4, + vocab_size: int = 32, + ): + self.position = position + self.length = length + self.vocab_size = vocab_size + self.name = f"arange_dynamic_pos{position}_len{length}" + + @classmethod + def get_test_configs(cls) -> List["DynamicArangeTest"]: + return [ + cls(position=0, length=4), + cls(position=4, length=4), + cls(position=10, length=8), + ] + + def create_model(self) -> nn.Module: + return DynamicArangeModel(length=self.length, vocab_size=self.vocab_size) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + pos = torch.tensor([self.position], dtype=torch.long) + return (pos,) + + +class LayerNormModel(nn.Module): + """Simple model using LayerNorm.""" + + def __init__(self, normalized_shape: int = 64, eps: float = 1e-5): + super().__init__() + self.layer_norm = nn.LayerNorm(normalized_shape, eps=eps) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.layer_norm(x) + + +@register_test +class LayerNormTest(OpTestCase): + """Test case for nn.LayerNorm.""" + + name = "layer_norm" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + normalized_shape: int = 64, + batch_size: int = 2, + seq_len: int = 16, + eps: float = 1e-5, + ): + self.normalized_shape = normalized_shape + self.batch_size = batch_size + self.seq_len = seq_len + self.eps = eps + self.name = "layer_norm" + + @classmethod + def get_test_configs(cls) -> List["LayerNormTest"]: + return [ + cls(), + cls(normalized_shape=128, eps=1e-6), + ] + + def create_model(self) -> nn.Module: + return LayerNormModel(self.normalized_shape, self.eps) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.seq_len, self.normalized_shape) + return (x,) + + +class Conv1dModel(nn.Module): + """Simple model using Conv1d.""" + + def __init__( + self, + in_channels: int = 16, + out_channels: int = 32, + kernel_size: int = 3, + stride: int = 1, + padding: int = 1, + bias: bool = True, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +@register_test +class Conv1dTest(OpTestCase): + """Test case for nn.Conv1d.""" + + name = "conv1d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + in_channels: int = 16, + out_channels: int = 32, + kernel_size: int = 3, + stride: int = 1, + padding: int = 1, + bias: bool = True, + batch_size: int = 2, + seq_len: int = 64, + ): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.bias = bias + self.batch_size = batch_size + self.seq_len = seq_len + + parts = ["conv1d"] + if not bias: + parts.append("no_bias") + self.name = "_".join(parts) + + @classmethod + def 
get_test_configs(cls) -> List["Conv1dTest"]: + return [ + cls(), + cls(bias=False), + ] + + def create_model(self) -> nn.Module: + return Conv1dModel( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.bias, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.in_channels, self.seq_len) + return (x,) + + +class Conv2DModel(nn.Module): + """Model that performs 2D convolution.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + bias: bool = True, + ): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +@register_test +class Conv2DTest(OpTestCase): + """Test case for conv2d op.""" + + name = "conv2d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 16, + kernel_size: int = 3, + stride: int = 1, + padding: int = 0, + input_size: Tuple[int, int] = (32, 32), + batch_size: int = 1, + bias: bool = True, + ): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.input_size = input_size + self.batch_size = batch_size + self.bias = bias + + parts = [ + "conv2d", + f"in{in_channels}", + f"out{out_channels}", + f"k{kernel_size}", + ] + if stride != 1: + parts.append(f"s{stride}") + if padding != 0: + parts.append(f"p{padding}") + parts.append(f"{input_size[0]}x{input_size[1]}") + if batch_size != 1: + parts.append(f"b{batch_size}") + if not bias: + parts.append("nobias") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["Conv2DTest"]: + return [ + cls( + in_channels=3, + out_channels=16, + kernel_size=3, + padding=1, + input_size=(32, 32), + ), + cls( + in_channels=16, + out_channels=32, + kernel_size=3, + stride=2, + padding=1, + input_size=(64, 64), + ), + cls(in_channels=64, out_channels=128, kernel_size=1, input_size=(16, 16)), + # 5x5 conv + cls( + in_channels=3, + out_channels=8, + kernel_size=5, + padding=2, + input_size=(28, 28), + ), + # Batch size > 1 + cls( + in_channels=3, + out_channels=16, + kernel_size=3, + padding=1, + input_size=(32, 32), + batch_size=4, + ), + cls( + in_channels=3, + out_channels=16, + kernel_size=3, + padding=1, + input_size=(32, 32), + bias=False, + ), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, self.in_channels, self.input_size[0], self.input_size[1] + ) + return (x,) + + def create_model(self) -> nn.Module: + return Conv2DModel( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.bias, + ) + + +class Conv3DModel(nn.Module): + """Model that performs 3D convolution.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + bias: bool = True, + ): + super().__init__() + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +@register_test +class Conv3DTest(OpTestCase): + """Test case for conv3d op.""" + + name = "conv3d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 
16, + kernel_size: int = 3, + stride: int = 1, + padding: int = 0, + input_size: Tuple[int, int, int] = (8, 16, 16), + batch_size: int = 1, + bias: bool = True, + ): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.input_size = input_size + self.batch_size = batch_size + self.bias = bias + + parts = [ + "conv3d", + f"in{in_channels}", + f"out{out_channels}", + f"k{kernel_size}", + ] + if stride != 1: + parts.append(f"s{stride}") + if padding != 0: + parts.append(f"p{padding}") + parts.append(f"{input_size[0]}x{input_size[1]}x{input_size[2]}") + if batch_size != 1: + parts.append(f"b{batch_size}") + if not bias: + parts.append("nobias") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["Conv3DTest"]: + return [ + cls( + in_channels=3, + out_channels=16, + kernel_size=3, + padding=1, + input_size=(8, 16, 16), + ), + cls( + in_channels=16, + out_channels=32, + kernel_size=3, + stride=2, + padding=1, + input_size=(8, 16, 16), + ), + cls(in_channels=64, out_channels=128, kernel_size=1, input_size=(4, 8, 8)), + # Batch size > 1 + cls( + in_channels=3, + out_channels=16, + kernel_size=3, + padding=1, + input_size=(8, 16, 16), + batch_size=2, + ), + cls( + in_channels=3, + out_channels=16, + kernel_size=3, + padding=1, + input_size=(8, 16, 16), + bias=False, + ), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, + self.in_channels, + self.input_size[0], + self.input_size[1], + self.input_size[2], + ) + return (x,) + + def create_model(self) -> nn.Module: + return Conv3DModel( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.bias, + ) + + +class ConvTranspose1dModel(nn.Module): + """Simple model using ConvTranspose1d.""" + + def __init__( + self, + in_channels: int = 16, + out_channels: int = 32, + kernel_size: int = 3, + stride: int = 1, + padding: int = 0, + output_padding: int = 0, + bias: bool = True, + groups: int = 1, + ): + super().__init__() + self.conv_transpose = nn.ConvTranspose1d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + bias=bias, + groups=groups, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv_transpose(x) + + +@register_test +class ConvTranspose1dTest(OpTestCase): + """Test case for nn.ConvTranspose1d.""" + + name = "conv_transpose1d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + in_channels: int = 16, + out_channels: int = 32, + kernel_size: int = 3, + stride: int = 1, + padding: int = 0, + output_padding: int = 0, + bias: bool = True, + groups: int = 1, + batch_size: int = 2, + seq_len: int = 64, + ): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.output_padding = output_padding + self.bias = bias + self.groups = groups + self.batch_size = batch_size + self.seq_len = seq_len + + parts = ["conv_transpose1d"] + if stride != 1: + parts.append(f"s{stride}") + if padding != 0: + parts.append(f"p{padding}") + if output_padding != 0: + parts.append(f"op{output_padding}") + if not bias: + parts.append("no_bias") + if groups != 1: + parts.append(f"g{groups}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["ConvTranspose1dTest"]: + return [ + cls(), + cls(bias=False), + cls(stride=2), + cls(stride=2, 
output_padding=1), + cls(padding=1), + cls(in_channels=8, out_channels=8, groups=8), # depthwise + cls(in_channels=6, out_channels=6, groups=3), # grouped + ] + + def create_model(self) -> nn.Module: + return ConvTranspose1dModel( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.output_padding, + self.bias, + self.groups, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.in_channels, self.seq_len) + return (x,) + + +class ConvTranspose2DModel(nn.Module): + """Model that performs 2D transposed convolution.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + output_padding: int = 0, + bias: bool = True, + groups: int = 1, + ): + super().__init__() + self.conv_transpose = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + bias=bias, + groups=groups, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv_transpose(x) + + +@register_test +class ConvTranspose2DTest(OpTestCase): + """Test case for nn.ConvTranspose2d.""" + + name = "conv_transpose2d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 16, + kernel_size: int = 3, + stride: int = 1, + padding: int = 0, + output_padding: int = 0, + input_size: Tuple[int, int] = (32, 32), + batch_size: int = 1, + bias: bool = True, + groups: int = 1, + ): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.output_padding = output_padding + self.input_size = input_size + self.batch_size = batch_size + self.bias = bias + self.groups = groups + + parts = [ + "conv_transpose2d", + f"in{in_channels}", + f"out{out_channels}", + f"k{kernel_size}", + ] + if stride != 1: + parts.append(f"s{stride}") + if padding != 0: + parts.append(f"p{padding}") + if output_padding != 0: + parts.append(f"op{output_padding}") + parts.append(f"{input_size[0]}x{input_size[1]}") + if not bias: + parts.append("nobias") + if groups != 1: + parts.append(f"g{groups}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["ConvTranspose2DTest"]: + return [ + cls(in_channels=3, out_channels=16, kernel_size=3, padding=1), + cls(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), + cls( + in_channels=16, + out_channels=32, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + ), + cls(in_channels=64, out_channels=128, kernel_size=1), + cls(in_channels=3, out_channels=16, kernel_size=3, padding=1, bias=False), + cls( + in_channels=8, out_channels=8, kernel_size=3, padding=1, groups=8 + ), # depthwise + cls( + in_channels=6, out_channels=6, kernel_size=3, padding=1, groups=3 + ), # grouped + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, + self.in_channels, + self.input_size[0], + self.input_size[1], + ) + return (x,) + + def create_model(self) -> nn.Module: + return ConvTranspose2DModel( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.output_padding, + self.bias, + self.groups, + ) + + +class ConvTranspose3DModel(nn.Module): + """Model that performs 3D transposed convolution.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 
0, + output_padding: int = 0, + bias: bool = True, + ): + super().__init__() + self.conv_transpose = nn.ConvTranspose3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv_transpose(x) + + +@register_test +class ConvTranspose3DTest(OpTestCase): + """Test case for nn.ConvTranspose3d.""" + + name = "conv_transpose3d" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 16, + kernel_size: int = 3, + stride: int = 1, + padding: int = 0, + output_padding: int = 0, + input_size: Tuple[int, int, int] = (8, 16, 16), + batch_size: int = 1, + bias: bool = True, + ): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.output_padding = output_padding + self.input_size = input_size + self.batch_size = batch_size + self.bias = bias + + parts = [ + "conv_transpose3d", + f"in{in_channels}", + f"out{out_channels}", + f"k{kernel_size}", + ] + if stride != 1: + parts.append(f"s{stride}") + if padding != 0: + parts.append(f"p{padding}") + if output_padding != 0: + parts.append(f"op{output_padding}") + parts.append(f"{input_size[0]}x{input_size[1]}x{input_size[2]}") + if not bias: + parts.append("nobias") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["ConvTranspose3DTest"]: + return [ + cls(in_channels=3, out_channels=16, kernel_size=3, padding=1), + cls(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), + cls(in_channels=64, out_channels=128, kernel_size=1), + cls(in_channels=3, out_channels=16, kernel_size=3, padding=1, bias=False), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, + self.in_channels, + self.input_size[0], + self.input_size[1], + self.input_size[2], + ) + return (x,) + + def create_model(self) -> nn.Module: + return ConvTranspose3DModel( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.padding, + self.output_padding, + self.bias, + ) + + +class SliceScatterModel(nn.Module): + """Model that performs slice_scatter.""" + + def __init__(self, dim: int = 0, start: int = 0, end: int = 2, step: int = 1): + super().__init__() + self.dim = dim + self.start = start + self.end = end + self.step = step + + def forward(self, x: torch.Tensor, src: torch.Tensor) -> torch.Tensor: + return x.slice_scatter( + src, dim=self.dim, start=self.start, end=self.end, step=self.step + ) + + +@register_test +class SliceScatterTest(OpTestCase): + """Test case for aten.slice_scatter.""" + + name = "slice_scatter" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + input_shape: Tuple[int, ...] 
= (4, 8), + dim: int = 0, + start: int = 0, + end: int = 2, + step: int = 1, + ): + self.input_shape = input_shape + self.dim = dim + self.start = start + self.end = end + self.step = step + + parts = ["slice_scatter", f"d{dim}", f"s{start}", f"e{end}"] + if step != 1: + parts.append(f"step{step}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["SliceScatterTest"]: + return [ + # Basic: replace first 2 rows + cls(input_shape=(4, 8), dim=0, start=0, end=2), + # Replace middle rows + cls(input_shape=(4, 8), dim=0, start=1, end=3), + # Along dim=1 + cls(input_shape=(4, 8), dim=1, start=2, end=6), + # With step=2 + cls(input_shape=(4, 8), dim=0, start=0, end=4, step=2), + # 3D tensor + cls(input_shape=(2, 4, 8), dim=1, start=0, end=2), + ] + + def create_model(self) -> nn.Module: + return SliceScatterModel( + dim=self.dim, + start=self.start, + end=self.end, + step=self.step, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.input_shape) + # Compute the src shape: same as x but with the slice size along dim + src_shape = list(self.input_shape) + slice_len = len(range(self.start, self.end, self.step)) + src_shape[self.dim] = slice_len + src = torch.randn(src_shape) + return (x, src) + + class BmmModel(nn.Module): """Model that performs batch matrix multiplication.""" - def __init__(self, batch_size: int, n: int, m: int, p: int): + def __init__(self, batch_size: int, n: int, m: int, p: int): + super().__init__() + self.weight = nn.Parameter(torch.randn(batch_size, m, p)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.bmm(x, self.weight) + + +@register_test +class BmmTest(OpTestCase): + """Test case for bmm (batch matrix multiplication).""" + + name = "bmm" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + batch_size: int = 4, + n: int = 8, + m: int = 16, + p: int = 32, + ): + self.batch_size = batch_size + self.n = n + self.m = m + self.p = p + self.name = f"bmm_{batch_size}x{n}x{m}x{p}" + + @classmethod + def get_test_configs(cls) -> List["BmmTest"]: + return [ + cls(batch_size=4, n=8, m=16, p=32), + cls(batch_size=2, n=64, m=64, p=32), + ] + + def create_model(self) -> nn.Module: + return BmmModel(self.batch_size, self.n, self.m, self.p) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.n, self.m) + return (x,) + + +class AddmmModel(nn.Module): + """Model that performs addmm: bias + (mat1 @ mat2).""" + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + alpha: float = 1.0, + beta: float = 1.0, + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.randn(out_features)) + else: + self.bias = None + self.alpha = alpha + self.beta = beta + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.bias is not None: + return torch.addmm( + self.bias, x, self.weight.t(), beta=self.beta, alpha=self.alpha + ) + else: + return torch.mm(x, self.weight.t()) + + +@register_test +class AddmmTest(OpTestCase): + """Test case for addmm.""" + + name = "addmm" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + batch_size: int = 2, + in_features: int = 64, + out_features: int = 32, + bias: bool = True, + alpha: float = 1.0, + beta: float = 1.0, + ): + self.batch_size = batch_size + self.in_features = in_features + self.out_features = out_features + self.bias = bias + self.alpha = alpha + self.beta = beta + + # Build unique test name + if 
not bias: + name = f"addmm_{in_features}x{out_features}_no_bias" + elif alpha != 1.0 or beta != 1.0: + name = f"addmm_{in_features}x{out_features}_a{alpha}_b{beta}" + else: + name = f"addmm_{in_features}x{out_features}" + self.name = name + + @classmethod + def get_test_configs(cls) -> List["AddmmTest"]: + return [ + cls( + batch_size=2, in_features=64, out_features=32 + ), # with bias, default alpha/beta + cls( + batch_size=2, in_features=64, out_features=32, bias=False + ), # without bias + cls(batch_size=4, in_features=128, out_features=64), # larger size + cls( + batch_size=2, in_features=64, out_features=32, alpha=2.0, beta=0.5 + ), # custom alpha/beta + ] + + def create_model(self) -> nn.Module: + return AddmmModel( + self.in_features, + self.out_features, + bias=self.bias, + alpha=self.alpha, + beta=self.beta, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.in_features) + return (x,) + + +class ExpandModel(nn.Module): + """Model that expands a tensor to a larger shape.""" + + def __init__(self, target_shape: Tuple[int, ...]): + super().__init__() + self.target_shape = target_shape + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.expand(self.target_shape) + + +@register_test +class ExpandTest(OpTestCase): + """Test case for expand (expand_copy) op.""" + + name = "expand" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + input_shape: Tuple[int, ...] = (1, 3, 1), + target_shape: Tuple[int, ...] = (2, 3, 4), + ): + self.input_shape = input_shape + self.target_shape = target_shape + + input_str = "x".join(str(s) for s in input_shape) + target_str = "x".join(str(s) for s in target_shape) + self.name = f"expand_{input_str}_to_{target_str}" + + @classmethod + def get_test_configs(cls) -> List["ExpandTest"]: + return [ + cls(input_shape=(2, 3, 1), target_shape=(2, 3, 4)), + cls(input_shape=(1, 3, 4), target_shape=(2, 3, 4)), + cls(input_shape=(1, 1, 4), target_shape=(2, 3, 4)), + cls(input_shape=(1, 1, 1), target_shape=(2, 3, 4)), + cls(input_shape=(1, 8), target_shape=(4, 8)), + cls(input_shape=(1, 1, 1, 64), target_shape=(2, 8, 16, 64)), + # Expand with -1 (keep dimension unchanged from input) + cls(input_shape=(93,), target_shape=(1, -1)), + # Multiple -1 dimensions (keep all but first) + cls(input_shape=(1, 1, 5, 8), target_shape=(1, -1, -1, -1)), + # Multiple -1 with actual expansion on first dim + cls(input_shape=(1, 3, 5, 8), target_shape=(2, -1, -1, -1)), + # Two -1 dimensions at start + cls(input_shape=(2, 3, 4), target_shape=(-1, -1, 4)), + ] + + def create_model(self) -> nn.Module: + return ExpandModel(self.target_shape) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.input_shape) + return (x,) + + +class IndexModel(nn.Module): + """Model that indexes a tensor using another tensor.""" + + def forward(self, x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + return x[indices] + + +@register_test +class IndexTest(OpTestCase): + """Test case for tensor indexing.""" + + name = "index" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + table_size: int = 100, + num_indices: int = 10, + ): + self.table_size = table_size + self.num_indices = num_indices + self.name = f"index_{table_size}_idx{num_indices}" + + @classmethod + def get_test_configs(cls) -> List["IndexTest"]: + return [ + cls(table_size=100, num_indices=10), + cls(table_size=50, num_indices=5), + ] + + def create_model(self) -> nn.Module: + return IndexModel() + + def create_inputs(self) -> 
Tuple[torch.Tensor, ...]: + x = torch.randn(self.table_size) + indices = torch.randint(0, self.table_size, (self.num_indices,)) + return (x, indices) + + +class AdvancedIndexModel(nn.Module): + """Model that performs advanced (multi-index) tensor indexing. + + Implements x[i0, i1, ...] with multiple index tensors, which maps to + aten.index.Tensor with multiple non-None indices. + """ + + def __init__(self, num_indexed_dims: int): + super().__init__() + self.num_indexed_dims = num_indexed_dims + + def forward(self, x: torch.Tensor, *indices: torch.Tensor) -> torch.Tensor: + idx_list = list(indices) + return x[tuple(idx_list)] + + +@register_test +class AdvancedIndexTest(OpTestCase): + """Test case for multi-index tensor indexing (advanced/fancy indexing).""" + + name = "advanced_index" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + input_shape: Tuple[int, ...] = (4, 5, 6), + num_indexed_dims: int = 2, + num_indices: int = 3, + ): + self.input_shape = input_shape + self.num_indexed_dims = num_indexed_dims + self.num_indices = num_indices + self.name = ( + f"advanced_index_{'x'.join(str(s) for s in input_shape)}" + f"_dims{num_indexed_dims}_idx{num_indices}" + ) + + @classmethod + def get_test_configs(cls) -> List["AdvancedIndexTest"]: + return [ + # 2D input, index both dims + cls(input_shape=(8, 6), num_indexed_dims=2, num_indices=4), + # 3D input, index all 3 dims + cls(input_shape=(4, 5, 6), num_indexed_dims=3, num_indices=3), + # 4D input, index all 4 dims (the original failing case) + cls(input_shape=(2, 3, 4, 5), num_indexed_dims=4, num_indices=2), + # 3D input, index 2 of 3 dims + cls(input_shape=(4, 5, 6), num_indexed_dims=2, num_indices=5), + ] + + def create_model(self) -> nn.Module: + return AdvancedIndexModel(self.num_indexed_dims) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.input_shape) + indices = [] + for dim in range(self.num_indexed_dims): + idx = torch.randint(0, self.input_shape[dim], (self.num_indices,)) + indices.append(idx) + return (x, *indices) + + +class IndexUpdateModel(nn.Module): + """Model that performs index_copy on a mutable buffer. + + This triggers the INDEX_UPDATE pattern which matches aten.index_copy.default + on a mutable buffer and lowers it to IndexUpdateNode. + """ + + def __init__( + self, + buffer_size: int = 128, + feature_dim: int = 64, + axis: int = 0, + ): + super().__init__() + self.axis = axis + if axis == 0: + self.register_buffer("data", torch.zeros(buffer_size, feature_dim)) + else: + # axis == 1 + self.register_buffer("data", torch.zeros(feature_dim, buffer_size)) + + def forward(self, indices: torch.Tensor, update: torch.Tensor) -> torch.Tensor: + """Update buffer at indices along axis using index_copy.""" + self.data.index_copy_(self.axis, indices, update) + return self.data.clone() + + +@register_test +class IndexUpdateTest(OpTestCase): + """Test case for index_update pattern (index_copy on mutable buffer). + + This tests the INDEX_UPDATE pattern handler which recognizes + aten.index_copy.default on a mutable buffer and lowers it to IndexUpdateNode. + The buffer is managed internally by the MLX backend. 
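+ 
+     For orientation, the eager-mode semantics are plain torch.Tensor.index_copy_
+     (a standalone sketch, not the lowered MLX graph):
+ 
+         data = torch.zeros(4, 2)
+         data.index_copy_(0, torch.tensor([0, 2]), torch.ones(2, 2))
+         # rows 0 and 2 are now ones; rows 1 and 3 remain zero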
+ """ + + name = "index_update" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + buffer_size: int = 128, + feature_dim: int = 64, + num_indices: int = 8, + axis: int = 0, + ): + self.buffer_size = buffer_size + self.feature_dim = feature_dim + self.num_indices = num_indices + self.axis = axis + self.name = ( + f"index_update_axis{axis}_{buffer_size}x{feature_dim}_idx{num_indices}" + ) + + @classmethod + def get_test_configs(cls) -> List["IndexUpdateTest"]: + return [ + # Basic case: update along axis 0 + cls(buffer_size=128, feature_dim=64, num_indices=8, axis=0), + # Smaller buffer + cls(buffer_size=32, feature_dim=16, num_indices=4, axis=0), + # Update along axis 1 + cls(buffer_size=64, feature_dim=32, num_indices=8, axis=1), + ] + + def create_model(self) -> nn.Module: + return IndexUpdateModel( + buffer_size=self.buffer_size, + feature_dim=self.feature_dim, + axis=self.axis, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Create unique indices (no duplicates) for index_copy + # PyTorch requires int64 (long) for indices + indices = torch.randperm(self.buffer_size)[: self.num_indices].to(torch.int64) + + # Create update tensor with shape matching the indexed dimension + if self.axis == 0: + update = torch.randn(self.num_indices, self.feature_dim) + else: + update = torch.randn(self.feature_dim, self.num_indices) + + return (indices, update) + + +class SplitWithSizesModel(nn.Module): + """Model that splits a tensor into chunks with specified sizes.""" + + def __init__(self, sizes, dim=0): + super().__init__() + self.sizes = sizes + self.dim = dim + + def forward(self, x): + chunks = torch.ops.aten.split_with_sizes_copy.default(x, self.sizes, self.dim) + return chunks[0] + + +class SplitWithSizesMultiOutputModel(nn.Module): + """Model that splits with specified sizes and uses multiple outputs.""" + + def __init__(self, sizes, dim=0): + super().__init__() + self.sizes = sizes + self.dim = dim + + def forward(self, x): + chunks = torch.ops.aten.split_with_sizes_copy.default(x, self.sizes, self.dim) + return chunks[0] + chunks[-1] + + +class SplitUniformModel(nn.Module): + """Model that splits a tensor into chunks of uniform size using torch.split.""" + + def __init__(self, split_size, dim=0): + super().__init__() + self.split_size = split_size + self.dim = dim + + def forward(self, x): + chunks = torch.split(x, self.split_size, dim=self.dim) + return chunks[0] + + +class SplitUniformMultiOutputModel(nn.Module): + """Model that splits uniformly and uses multiple outputs.""" + + def __init__(self, split_size, dim=0): + super().__init__() + self.split_size = split_size + self.dim = dim + + def forward(self, x): + chunks = torch.split(x, self.split_size, dim=self.dim) + return torch.cat([chunks[0], chunks[-1]], dim=self.dim) + + +@register_test +class SplitTest(OpTestCase): + name = "split" + rtol = 1e-5 + atol = 1e-5 + + def __init__(self, shape, model_cls, model_kwargs, tag=""): + self.shape = shape + self.model_cls = model_cls + self.model_kwargs = model_kwargs + self.name = f"split_{tag}" if tag else "split" + + @classmethod + def get_test_configs(cls) -> List["SplitTest"]: + return [ + # split_with_sizes_copy tests + cls( + shape=(9, 4), + model_cls=SplitWithSizesModel, + model_kwargs={"sizes": [2, 3, 4], "dim": 0}, + tag="sizes_dim0", + ), + cls( + shape=(3, 10), + model_cls=SplitWithSizesModel, + model_kwargs={"sizes": [2, 3, 5], "dim": 1}, + tag="sizes_dim1", + ), + cls( + shape=(2, 12, 4), + model_cls=SplitWithSizesModel, + model_kwargs={"sizes": [3, 4, 5], 
"dim": 1}, + tag="sizes_3d", + ), + cls( + shape=(8, 4), + model_cls=SplitWithSizesModel, + model_kwargs={"sizes": [3, 5], "dim": 0}, + tag="sizes_two", + ), + cls( + shape=(10, 3), + model_cls=SplitWithSizesMultiOutputModel, + model_kwargs={"sizes": [5, 5], "dim": 0}, + tag="sizes_multi", + ), + # torch.split (uniform) tests + cls( + shape=(10, 4), + model_cls=SplitUniformModel, + model_kwargs={"split_size": 3, "dim": 0}, + tag="uniform_dim0", + ), + cls( + shape=(3, 7), + model_cls=SplitUniformModel, + model_kwargs={"split_size": 4, "dim": 1}, + tag="uniform_dim1", + ), + cls( + shape=(11, 5), + model_cls=SplitUniformMultiOutputModel, + model_kwargs={"split_size": 3, "dim": 0}, + tag="uniform_multi", + ), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape),) + + def create_model(self) -> nn.Module: + return self.model_cls(**self.model_kwargs) + + +class ArangeModel(nn.Module): + """Model that creates a tensor using arange and multiplies with input.""" + + def __init__(self, stop: int, use_dtype: bool = True): + super().__init__() + self.stop = stop + self.use_dtype = use_dtype + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_dtype: + indices = torch.arange(self.stop, dtype=x.dtype, device=x.device) + else: + # No dtype - let MLX infer (defaults to int64 for integer inputs) + indices = torch.arange(self.stop, device=x.device) + indices = indices.to(x.dtype) # Cast for multiplication + return x * indices + + +@register_test +class ArangeTest(OpTestCase): + """Test case for torch.arange().""" + + name = "arange" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + stop: int = 10, + dtype: torch.dtype = torch.float32, + use_dtype: bool = True, + ): + self.stop = stop + self.dtype = dtype + self.use_dtype = use_dtype + dtype_name = str(dtype).split(".")[-1] + if use_dtype: + self.name = f"arange_{stop}_{dtype_name}" + else: + self.name = f"arange_{stop}_no_dtype" + + @classmethod + def get_test_configs(cls) -> List["ArangeTest"]: + return [ + # With explicit dtype + cls(stop=10, dtype=torch.float32, use_dtype=True), + cls(stop=32, dtype=torch.float32, use_dtype=True), + cls(stop=100, dtype=torch.float32, use_dtype=True), + cls(stop=16, dtype=torch.int32, use_dtype=True), + cls(stop=16, dtype=torch.int64, use_dtype=True), + # Without dtype (let MLX infer) + cls(stop=10, dtype=torch.float32, use_dtype=False), + cls(stop=32, dtype=torch.float32, use_dtype=False), + ] + + def create_model(self) -> nn.Module: + return ArangeModel(self.stop, use_dtype=self.use_dtype) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + if self.dtype in (torch.int32, torch.int64): + x = torch.randint(1, 10, (self.stop,), dtype=self.dtype) + else: + x = torch.randn(self.stop, dtype=self.dtype) + return (x,) + + +class UnaryOpModel(nn.Module): + """Generic model that applies a single unary torch op.""" + + def __init__(self, op_fn: Callable): + super().__init__() + self.op_fn = op_fn + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.op_fn(x) + + +def _input_fn( + uniform: bool = False, scale: float = 1.0, offset: float = 0.0, abs: bool = False +): + """Return a callable(shape, dtype) that generates a single-element input tuple. + + Args: + uniform: Use torch.rand (uniform [0,1]) instead of torch.randn (normal). + scale: Multiply the base tensor by this value. + offset: Add this value after scaling. + abs: Apply .abs() to the base tensor before scale/offset. 
+ """ + + def fn(shape, dtype): + base = ( + torch.rand(shape, dtype=dtype) + if uniform + else torch.randn(shape, dtype=dtype) + ) + if abs: + base = base.abs() + return (base * scale + offset,) + + return fn + + +def _bool_input_fn(): + """Return a callable(shape, dtype) that generates a single-element bool tensor tuple.""" + + def fn(shape, _dtype): + return (torch.randint(0, 2, shape, dtype=torch.bool),) + + return fn + + +def _int_input_fn(low: int = -100, high: int = 100): + """Return a callable(shape, dtype) that generates a single-element integer tensor tuple.""" + + def fn(shape, dtype): + return (torch.randint(low, high, shape, dtype=dtype),) + + return fn + + +# Standard shape and dtype configs used by unary tests. +_SHAPES_3 = [(16,), (4, 4), (2, 3, 4)] +_SHAPES_2 = [(16,), (4, 4)] +_UNARY_DTYPES = [torch.float32, torch.float16, torch.bfloat16] + + +def _make_unary_op_test( + op_name: str, + op_fn: Callable, + shapes: List[Tuple[int, ...]] = None, + dtypes: List[torch.dtype] = None, + input_fn: Callable = None, +) -> type: + """Generate a registered OpTestCase subclass for a unary math op. + + Args: + op_name: Name used for test registration and output directories. + op_fn: The torch function to test (e.g. torch.floor). + shapes: List of input shapes. Defaults to _SHAPES_2. + dtypes: List of dtypes to test. Defaults to _UNARY_DTYPES. + input_fn: Callable(shape, dtype) -> Tuple[Tensor, ...] that creates inputs. + Defaults to _input_fn() (standard randn). + """ + if shapes is None: + shapes = _SHAPES_2 + if dtypes is None: + dtypes = _UNARY_DTYPES + if input_fn is None: + input_fn = _input_fn() + + class _Test(OpTestCase): + name = op_name + + def __init__( + self, + shape: Tuple[int, ...], + dtype: torch.dtype, + ): + self.shape = shape + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"{op_name}_{shape_str}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["_Test"]: + return [cls(shape=s, dtype=d) for s in shapes for d in dtypes] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return input_fn(self.shape, self.dtype) + + def create_model(self) -> nn.Module: + return UnaryOpModel(op_fn) + + _Test.__name__ = f"{op_name.title().replace('_', '')}Test" + _Test.__qualname__ = _Test.__name__ + return _Test + + +# fmt: off +# Each entry is a dict with required keys "op_name" and "op_fn". +# Optional keys: "shapes" (default _SHAPES_2), "dtypes" (default _UNARY_DTYPES), +# "input_fn" (default _input_fn()). +# _input_fn(uniform, scale, offset) — uniform=True uses rand, False uses randn. 
+_UNARY_OP_TESTS = [ + {"op_name": "floor", "op_fn": torch.floor, "shapes": _SHAPES_3, "input_fn": _input_fn(scale=10)}, + {"op_name": "ceil", "op_fn": torch.ceil, "shapes": _SHAPES_3, "input_fn": _input_fn(scale=10)}, + {"op_name": "square", "op_fn": torch.square, "shapes": _SHAPES_3}, + {"op_name": "exp", "op_fn": torch.exp, "shapes": _SHAPES_3}, + {"op_name": "sin", "op_fn": torch.sin, "shapes": _SHAPES_3, "input_fn": _input_fn(scale=3.14159)}, + {"op_name": "cos", "op_fn": torch.cos, "shapes": _SHAPES_3, "input_fn": _input_fn(scale=3.14159)}, + {"op_name": "tan", "op_fn": torch.tan, "input_fn": _input_fn(scale=0.5)}, + {"op_name": "asin", "op_fn": torch.asin, "input_fn": _input_fn(uniform=True, scale=2, offset=-1)}, + {"op_name": "acos", "op_fn": torch.acos, "input_fn": _input_fn(uniform=True, scale=2, offset=-1)}, + {"op_name": "atan", "op_fn": torch.atan}, + {"op_name": "sinh", "op_fn": torch.sinh}, + {"op_name": "cosh", "op_fn": torch.cosh}, + {"op_name": "asinh", "op_fn": torch.asinh}, + {"op_name": "acosh", "op_fn": torch.acosh, "input_fn": _input_fn(uniform=True, offset=1.0)}, + {"op_name": "atanh", "op_fn": torch.atanh, "input_fn": _input_fn(uniform=True, scale=1.8, offset=-0.9)}, + {"op_name": "log2", "op_fn": torch.log2, "input_fn": _input_fn(uniform=True, offset=0.1)}, + {"op_name": "log10", "op_fn": torch.log10, "input_fn": _input_fn(uniform=True, offset=0.1)}, + {"op_name": "log1p", "op_fn": torch.log1p, "input_fn": _input_fn(uniform=True)}, + {"op_name": "erf", "op_fn": torch.erf}, + {"op_name": "expm1", "op_fn": torch.expm1}, + {"op_name": "round", "op_fn": torch.round, "input_fn": _input_fn(scale=10)}, + {"op_name": "reciprocal", "op_fn": torch.reciprocal, "input_fn": _input_fn(offset=1.0)}, + {"op_name": "sqrt", "op_fn": torch.sqrt, "input_fn": _input_fn(uniform=True, offset=0.1)}, + {"op_name": "abs", "op_fn": torch.abs}, + {"op_name": "neg", "op_fn": torch.neg}, + {"op_name": "logical_not","op_fn": torch.logical_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()}, + # activations + {"op_name": "relu", "op_fn": torch.relu, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 128, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2, offset=-1)}, + {"op_name": "sigmoid", "op_fn": torch.sigmoid, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 1, 128)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2)}, + {"op_name": "tanh", "op_fn": torch.tanh, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 1, 128)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=3)}, + {"op_name": "silu", "op_fn": nn.SiLU(), "shapes": [(2, 16, 64), (4, 32, 128)], "dtypes": [torch.float32]}, + # math + {"op_name": "rsqrt", "op_fn": torch.rsqrt, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(uniform=True, offset=0.1)}, + {"op_name": "clone", "op_fn": torch.clone, "shapes": [(2, 3, 4), (8, 8), (16,)], "dtypes": [torch.float32]}, +] +# fmt: on + +# Generate and register all unary math op test classes. 
+for _entry in _UNARY_OP_TESTS: + _cls = _make_unary_op_test(**_entry) + register_test(_cls) + globals()[_cls.__name__] = _cls + + +class BinaryOpModel(nn.Module): + def __init__(self, op_fn: Callable): + super().__init__() + self.op_fn = op_fn + + def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return self.op_fn(a, b) + + +class PowerScalarModel(nn.Module): + def __init__(self, exponent: float): + super().__init__() + self.exponent = exponent + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return torch.pow(a, self.exponent) + + +_BINARY_DTYPES = [torch.float32] + + +def _make_binary_op_test( + op_name: str, + op_fn: Callable, + shapes: List[Tuple[int, ...]] = None, + dtypes: List[torch.dtype] = None, + input_fn_a: Callable = None, + input_fn_b: Callable = None, +) -> type: + """Generate a registered OpTestCase subclass for a binary math op.""" + if shapes is None: + shapes = _SHAPES_3 + if dtypes is None: + dtypes = _BINARY_DTYPES + if input_fn_a is None: + input_fn_a = _input_fn() + if input_fn_b is None: + input_fn_b = _input_fn() + + class _Test(OpTestCase): + name = op_name + + def __init__( + self, + shape: Tuple[int, ...], + dtype: torch.dtype, + ): + self.shape = shape + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"{op_name}_{shape_str}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["_Test"]: + return [cls(shape=s, dtype=d) for s in shapes for d in dtypes] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return input_fn_a(self.shape, self.dtype) + input_fn_b( + self.shape, self.dtype + ) + + def create_model(self) -> nn.Module: + return BinaryOpModel(op_fn) + + _Test.__name__ = f"{op_name.title().replace('_', '')}Test" + _Test.__qualname__ = _Test.__name__ + return _Test + + +# fmt: off +_BINARY_OP_TESTS = [ + # math + {"op_name": "maximum", "op_fn": torch.maximum}, + {"op_name": "minimum", "op_fn": torch.minimum}, + {"op_name": "atan2", "op_fn": torch.atan2}, + {"op_name": "logaddexp", "op_fn": torch.logaddexp}, + {"op_name": "floor_divide", "op_fn": torch.floor_divide, "input_fn_a": _input_fn(scale=10), "input_fn_b": _input_fn(abs=True, offset=1)}, + {"op_name": "floor_divide_int", "op_fn": torch.floor_divide, "dtypes": [torch.int32], "input_fn_a": _int_input_fn(-100, 100), "input_fn_b": _int_input_fn(1, 10)}, + {"op_name": "remainder", "op_fn": torch.remainder, "input_fn_a": _input_fn(scale=10), "input_fn_b": _input_fn(abs=True, offset=1)}, + {"op_name": "remainder_int", "op_fn": torch.remainder, "dtypes": [torch.int32], "input_fn_a": _int_input_fn(-100, 100), "input_fn_b": _int_input_fn(1, 10)}, + {"op_name": "power", "op_fn": torch.pow, "input_fn_a": _input_fn(uniform=True, offset=0.5), "input_fn_b": _input_fn(uniform=True, scale=2)}, + # comparison + {"op_name": "less", "op_fn": torch.lt, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.float32, torch.bfloat16]}, + {"op_name": "less_equal", "op_fn": torch.le, "shapes": [(2, 3, 4), (10,)], "dtypes": [torch.float32]}, + {"op_name": "greater", "op_fn": torch.gt, "shapes": [(2, 3, 4), (10,)], "dtypes": [torch.float32]}, + {"op_name": "greater_equal", "op_fn": torch.ge, "shapes": [(2, 3, 4), (10,)], "dtypes": [torch.float32]}, + {"op_name": "equal", "op_fn": torch.eq, "shapes": [(2, 3, 4), (10,)], "dtypes": [torch.float32]}, + {"op_name": "not_equal", "op_fn": torch.ne, "shapes": [(2, 3, 4), (10,)], "dtypes": [torch.float32]}, + # logical + {"op_name": "logical_and", "op_fn": 
torch.logical_and, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn_a": _bool_input_fn(), "input_fn_b": _bool_input_fn()}, + {"op_name": "logical_or", "op_fn": torch.logical_or, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn_a": _bool_input_fn(), "input_fn_b": _bool_input_fn()}, +] +# fmt: on + + +for _entry in _BINARY_OP_TESTS: + _cls = _make_binary_op_test(**_entry) + register_test(_cls) + globals()[_cls.__name__] = _cls + + +@register_test +class PowerScalarTest(OpTestCase): + """Test case for aten.pow op (Tensor_Scalar variant).""" + + name = "power_scalar" + + def __init__( + self, + shape: Tuple[int, ...] = (4, 4), + exponent: float = 2.0, + dtype: torch.dtype = torch.float32, + ): + self.shape = shape + self.exponent = exponent + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"power_scalar_{shape_str}_exp{exponent}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["PowerScalarTest"]: + return [ + cls(shape=(16,), exponent=2.0, dtype=torch.float32), + cls(shape=(4, 4), exponent=0.5, dtype=torch.float32), + cls(shape=(4, 4), exponent=3.0, dtype=torch.float32), + cls(shape=(2, 3, 4), exponent=-1.0, dtype=torch.float32), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.rand(self.shape, dtype=self.dtype) + 0.5,) + + def create_model(self) -> nn.Module: + return PowerScalarModel(self.exponent) + + +class CompareScalarModel(nn.Module): + def __init__(self, op_fn: Callable, scalar: float): + super().__init__() + self.op_fn = op_fn + self.scalar = scalar + + def forward(self, a: torch.Tensor) -> torch.Tensor: + return self.op_fn(a, self.scalar) + + +def _make_compare_scalar_test( + op_name: str, + op_fn: Callable, +) -> type: + """Generate a registered OpTestCase subclass for a comparison Scalar op.""" + + class _Test(OpTestCase): + name = op_name + + def __init__( + self, + shape: Tuple[int, ...], + scalar: float, + dtype: torch.dtype, + ): + self.shape = shape + self.scalar = scalar + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"{op_name}_{shape_str}_s{scalar}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["_Test"]: + return [ + cls(shape=(16,), scalar=0.0, dtype=torch.float32), + cls(shape=(4, 4), scalar=0.5, dtype=torch.float32), + cls(shape=(2, 3, 4), scalar=-1.0, dtype=torch.float32), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape, dtype=self.dtype),) + + def create_model(self) -> nn.Module: + return CompareScalarModel(op_fn, self.scalar) + + _Test.__name__ = f"{op_name.title().replace('_', '')}Test" + _Test.__qualname__ = _Test.__name__ + return _Test + + +_COMPARE_SCALAR_TESTS = [ + {"op_name": "less_scalar", "op_fn": torch.lt}, + {"op_name": "less_equal_scalar", "op_fn": torch.le}, + {"op_name": "greater_scalar", "op_fn": torch.gt}, + {"op_name": "greater_equal_scalar", "op_fn": torch.ge}, + {"op_name": "equal_scalar", "op_fn": torch.eq}, + {"op_name": "not_equal_scalar", "op_fn": torch.ne}, +] + +for _entry in _COMPARE_SCALAR_TESTS: + _cls = _make_compare_scalar_test(**_entry) + register_test(_cls) + globals()[_cls.__name__] = _cls + + +class ReductionOpModel(nn.Module): + def __init__(self, op_fn: Callable, dim=None, keepdim: bool = False): + super().__init__() + self.op_fn = op_fn + self.dim = dim + self.keepdim = keepdim + + def forward(self, x: torch.Tensor) -> 
torch.Tensor: + if self.dim is None: + return self.op_fn(x) + return self.op_fn(x, dim=self.dim, keepdim=self.keepdim) + + +class CorrectionReductionOpModel(nn.Module): + def __init__( + self, op_fn: Callable, dim=None, keepdim: bool = False, correction: int = 1 + ): + super().__init__() + self.op_fn = op_fn + self.dim = dim + self.keepdim = keepdim + self.correction = correction + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.dim is None: + return self.op_fn(x, correction=self.correction) + return self.op_fn( + x, dim=self.dim, keepdim=self.keepdim, correction=self.correction + ) + + +def _make_reduction_op_test( + op_name: str, + op_fn: Callable, + configs: List[dict], + input_fn: Callable = None, + has_correction: bool = False, +) -> type: + """Generate a registered OpTestCase subclass for a reduction op. + + Args: + op_name: Name used for test registration. + op_fn: The torch function (e.g. torch.sum). + configs: List of dicts with keys: shape, dim, keepdim, dtype, and + optionally correction (for var/std). + input_fn: Callable(shape, dtype) -> Tuple[Tensor, ...]. + has_correction: If True, use CorrectionReductionOpModel. + """ + if input_fn is None: + input_fn = _input_fn() + + class _Test(OpTestCase): + name = op_name + + def __init__(self, shape, dim, keepdim, dtype, correction=1): + self.shape = shape + self.dim = dim + self.keepdim = keepdim + self.dtype = dtype + self.correction = correction + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + dim_str = f"_dim{dim}" if dim is not None else "_all" + kd_str = "_kd" if keepdim else "" + corr_str = f"_corr{correction}" if has_correction else "" + self.name = f"{op_name}_{shape_str}{dim_str}{kd_str}{corr_str}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["_Test"]: + return [cls(**c) for c in configs] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return input_fn(self.shape, self.dtype) + + def create_model(self) -> nn.Module: + if has_correction: + return CorrectionReductionOpModel( + op_fn, + dim=self.dim, + keepdim=self.keepdim, + correction=self.correction, + ) + return ReductionOpModel(op_fn, dim=self.dim, keepdim=self.keepdim) + + _Test.__name__ = f"{op_name.title().replace('_', '')}Test" + _Test.__qualname__ = _Test.__name__ + return _Test + + +_REDUCTION_CONFIGS_6 = [ + {"shape": (16,), "dim": 0, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": -1, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": 0, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": -1, "keepdim": True, "dtype": torch.float32}, + {"shape": (2, 3, 4), "dim": 1, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": None, "keepdim": False, "dtype": torch.float32}, +] + +_REDUCTION_CONFIGS_5 = [ + {"shape": (16,), "dim": 0, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": -1, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": -1, "keepdim": True, "dtype": torch.float32}, + { + "shape": (4, 4), + "dim": -1, + "keepdim": False, + "dtype": torch.float32, + "correction": 0, + }, + {"shape": (2, 3, 4), "dim": 1, "keepdim": False, "dtype": torch.float32}, +] + +_PROD_CONFIGS = [ + {"shape": (8,), "dim": 0, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": -1, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": 0, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": -1, "keepdim": True, "dtype": torch.float32}, + 
{"shape": (2, 3, 4), "dim": 1, "keepdim": False, "dtype": torch.float32}, +] + +_LOGSUMEXP_CONFIGS = [ + {"shape": (16,), "dim": 0, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": -1, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": 0, "keepdim": False, "dtype": torch.float32}, + {"shape": (4, 4), "dim": -1, "keepdim": True, "dtype": torch.float32}, + {"shape": (2, 3, 4), "dim": 1, "keepdim": False, "dtype": torch.float32}, +] + +# fmt: off +_REDUCTION_OP_TESTS = [ + {"op_name": "sum", "op_fn": torch.sum, "configs": _REDUCTION_CONFIGS_6}, + {"op_name": "mean", "op_fn": torch.mean, "configs": _REDUCTION_CONFIGS_6}, + {"op_name": "amax", "op_fn": torch.amax, "configs": _REDUCTION_CONFIGS_6}, + {"op_name": "amin", "op_fn": torch.amin, "configs": _REDUCTION_CONFIGS_6}, + {"op_name": "argmax", "op_fn": torch.argmax, "configs": _REDUCTION_CONFIGS_6}, + {"op_name": "argmin", "op_fn": torch.argmin, "configs": _REDUCTION_CONFIGS_6}, + {"op_name": "prod", "op_fn": torch.prod, "configs": _PROD_CONFIGS, "input_fn": _input_fn(scale=0.5, offset=1.0)}, + {"op_name": "var", "op_fn": torch.var, "configs": _REDUCTION_CONFIGS_5, "has_correction": True}, + {"op_name": "std", "op_fn": torch.std, "configs": _REDUCTION_CONFIGS_5, "has_correction": True}, + {"op_name": "logsumexp", "op_fn": torch.logsumexp, "configs": _LOGSUMEXP_CONFIGS}, +] +# fmt: on + +for _entry in _REDUCTION_OP_TESTS: + _cls = _make_reduction_op_test(**_entry) + register_test(_cls) + globals()[_cls.__name__] = _cls + + +# --- Global max (aten.max.default) - no dim argument --- + + +class MaxGlobalModel(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.max(x) + + +@register_test +class MaxGlobalTest(OpTestCase): + name = "max_global" + + def __init__(self, shape=(3, 4), dtype=torch.float32): + self.shape = shape + self.dtype = dtype + + @classmethod + def get_test_configs(cls): + return [ + cls(shape=(16,)), + cls(shape=(3, 4)), + cls(shape=(2, 3, 4)), + cls(shape=(3, 4), dtype=torch.bfloat16), + ] + + def create_model(self): + return MaxGlobalModel() + + def create_inputs(self): + return (torch.randn(self.shape, dtype=self.dtype),) + + +class MinGlobalModel(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.min(x) + + +@register_test +class MinGlobalTest(OpTestCase): + name = "min_global" + + def __init__(self, shape=(3, 4), dtype=torch.float32): + self.shape = shape + self.dtype = dtype + + @classmethod + def get_test_configs(cls): + return [ + cls(shape=(16,)), + cls(shape=(3, 4)), + cls(shape=(2, 3, 4)), + cls(shape=(3, 4), dtype=torch.bfloat16), + ] + + def create_model(self): + return MinGlobalModel() + + def create_inputs(self): + return (torch.randn(self.shape, dtype=self.dtype),) + + +class TriangularModel(nn.Module): + def __init__(self, mode: str = "tril", diagonal: int = 0): + super().__init__() + self.op_fn = torch.tril if mode == "tril" else torch.triu + self.diagonal = diagonal + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.op_fn(x, diagonal=self.diagonal) + + +_TRIANGULAR_CONFIGS = [ + {"shape": (4, 4), "diagonal": 0, "dtype": torch.float32}, + {"shape": (8, 8), "diagonal": 0, "dtype": torch.float32}, + {"shape": (4, 6), "diagonal": 0, "dtype": torch.float32}, + {"shape": (6, 4), "diagonal": 0, "dtype": torch.float32}, + {"shape": (4, 4), "diagonal": 1, "dtype": torch.float32}, + {"shape": (4, 4), "diagonal": -1, "dtype": torch.float32}, + {"shape": (4, 4), "diagonal": 2, "dtype": torch.float32}, + 
{"shape": (4, 4), "diagonal": 0, "dtype": torch.bfloat16}, + {"shape": (2, 4, 4), "diagonal": 0, "dtype": torch.float32}, + {"shape": (2, 3, 4, 4), "diagonal": 0, "dtype": torch.float32}, +] + + +def _make_triangular_test(mode: str) -> type: + """Generate a registered OpTestCase subclass for tril or triu.""" + + class _Test(OpTestCase): + name = mode + + def __init__( + self, + shape: Tuple[int, ...] = (4, 4), + diagonal: int = 0, + dtype: torch.dtype = torch.float32, + ): + self.shape = shape + self.diagonal = diagonal + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + diag_str = f"d{diagonal}" if diagonal != 0 else "" + self.name = f"{mode}_{shape_str}_{dtype_str}{diag_str}" + + @classmethod + def get_test_configs(cls) -> List["_Test"]: + return [cls(**c) for c in _TRIANGULAR_CONFIGS] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape, dtype=self.dtype),) + + def create_model(self) -> nn.Module: + return TriangularModel(mode=mode, diagonal=self.diagonal) + + _Test.__name__ = f"{mode.title()}Test" + _Test.__qualname__ = _Test.__name__ + return _Test + + +TrilTest = _make_triangular_test("tril") +TriuTest = _make_triangular_test("triu") +register_test(TrilTest) +register_test(TriuTest) + + +class ZerosLikeModel(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.zeros_like(x) + + +class OnesLikeModel(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ones_like(x) + + +class FullLikeModel(nn.Module): + def __init__(self, fill_value: float, dtype: Optional[torch.dtype] = None): + super().__init__() + self.fill_value = fill_value + self.dtype = dtype + + def forward(self, x: torch.Tensor) -> torch.Tensor: + t = torch.full_like(x, self.fill_value, dtype=self.dtype) + if self.dtype is not None and self.dtype != x.dtype: + return x * t.to(x.dtype) + return t + + +@register_test +class ZerosLikeTest(OpTestCase): + """Test case for aten.zeros_like op.""" + + name = "zeros_like" + + def __init__( + self, + shape: Tuple[int, ...] = (4, 4), + dtype: torch.dtype = torch.float32, + ): + self.shape = shape + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"zeros_like_{shape_str}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["ZerosLikeTest"]: + return [ + cls(shape=(16,), dtype=torch.float32), + cls(shape=(4, 4), dtype=torch.float32), + cls(shape=(2, 3, 4), dtype=torch.float32), + cls(shape=(4, 4), dtype=torch.bfloat16), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape, dtype=self.dtype),) + + def create_model(self) -> nn.Module: + return ZerosLikeModel() + + +@register_test +class OnesLikeTest(OpTestCase): + """Test case for aten.ones_like op.""" + + name = "ones_like" + + def __init__( + self, + shape: Tuple[int, ...] 
= (4, 4), + dtype: torch.dtype = torch.float32, + ): + self.shape = shape + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"ones_like_{shape_str}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["OnesLikeTest"]: + return [ + cls(shape=(16,), dtype=torch.float32), + cls(shape=(4, 4), dtype=torch.float32), + cls(shape=(2, 3, 4), dtype=torch.float32), + cls(shape=(4, 4), dtype=torch.bfloat16), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape, dtype=self.dtype),) + + def create_model(self) -> nn.Module: + return OnesLikeModel() + + +@register_test +class FullLikeTest(OpTestCase): + """Test case for aten.full_like op.""" + + name = "full_like" + + def __init__( + self, + shape: Tuple[int, ...] = (4, 4), + fill_value: float = 3.14, + dtype: torch.dtype = torch.float32, + fill_dtype: Optional[torch.dtype] = None, + rtol: Optional[float] = None, + atol: Optional[float] = None, + ): + self.shape = shape + self.fill_value = fill_value + self.dtype = dtype + self.fill_dtype = fill_dtype + if rtol is not None: + self.rtol = rtol + if atol is not None: + self.atol = atol + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + fill_dtype_str = ( + f"_as_{str(fill_dtype).replace('torch.', '')}" if fill_dtype else "" + ) + self.name = f"full_like_{shape_str}_v{fill_value}_{dtype_str}{fill_dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["FullLikeTest"]: + return [ + cls(shape=(16,), fill_value=3.14, dtype=torch.float32), + cls(shape=(4, 4), fill_value=2.71, dtype=torch.float32), + cls(shape=(2, 3, 4), fill_value=-1.0, dtype=torch.float32), + cls(shape=(4, 4), fill_value=0.5, dtype=torch.bfloat16), + # Explicit fill_dtype exercises scalar_type serialization (optional_int). + # 1.005859375 rounds differently in bf16 vs f32, so the model multiplies + # the bf16 mask back into the f32 input to make the precision loss observable. + cls( + shape=(4, 4), + fill_value=1.005859375, + fill_dtype=torch.bfloat16, + rtol=0.0, + atol=0.0, + ), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + if self.fill_dtype is not None: + torch.manual_seed(42) + return (torch.randn(self.shape, dtype=self.dtype) * 100,) + return (torch.randn(self.shape, dtype=self.dtype),) + + def create_model(self) -> nn.Module: + return FullLikeModel(fill_value=self.fill_value, dtype=self.fill_dtype) + + +class FullModel(nn.Module): + def __init__(self, shape: Tuple[int, ...], fill_value: float, dtype: torch.dtype): + super().__init__() + self.shape = shape + self.fill_value = fill_value + self.dtype = dtype + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.full(self.shape, self.fill_value, dtype=self.dtype) + + +class ZerosModel(nn.Module): + def __init__(self, shape: Tuple[int, ...], dtype: torch.dtype): + super().__init__() + self.shape = shape + self.dtype = dtype + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.zeros(self.shape, dtype=self.dtype) + + +class OnesModel(nn.Module): + def __init__(self, shape: Tuple[int, ...], dtype: torch.dtype): + super().__init__() + self.shape = shape + self.dtype = dtype + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ones(self.shape, dtype=self.dtype) + + +@register_test +class FullTest(OpTestCase): + """Test case for aten.full op.""" + + name = "full" + rtol = 1e-3 + atol = 1e-3 + + def __init__( + self, + shape: Tuple[int, ...] 
= (2, 3, 4), + fill_value: float = 1.5, + dtype: torch.dtype = torch.float32, + ): + self.shape = shape + self.fill_value = fill_value + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"full_{shape_str}_{fill_value}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["FullTest"]: + return [ + cls(shape=(2, 3, 4), fill_value=1.5, dtype=torch.float32), + cls(shape=(10,), fill_value=0.0, dtype=torch.float32), + cls(shape=(1, 128), fill_value=-2.5, dtype=torch.float32), + cls(shape=(4, 8, 16), fill_value=3.14159, dtype=torch.float32), + cls(shape=(2, 3, 4), fill_value=1.0, dtype=torch.bfloat16), + cls(shape=(8, 16), fill_value=-1.0, dtype=torch.bfloat16), + cls(shape=(2, 3, 4), fill_value=2.0, dtype=torch.float16), + # Integer fill values (matching individual test file) + cls(shape=(2, 3, 4), fill_value=0.0, dtype=torch.float32), + cls(shape=(2, 3, 4), fill_value=1.0, dtype=torch.float32), + cls(shape=(2, 3, 4), fill_value=-1.0, dtype=torch.float32), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(1, dtype=torch.float32) + return (x,) + + def create_model(self) -> nn.Module: + return FullModel(self.shape, self.fill_value, self.dtype) + + +@register_test +class ZerosTest(OpTestCase): + """Test case for aten.zeros op.""" + + name = "zeros" + rtol = 1e-3 + atol = 1e-3 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 3, 4), + dtype: torch.dtype = torch.float32, + ): + self.shape = shape + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"zeros_{shape_str}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["ZerosTest"]: + return [ + cls(shape=(2, 3, 4), dtype=torch.float32), + cls(shape=(10,), dtype=torch.float32), + cls(shape=(1, 128), dtype=torch.float32), + cls(shape=(4, 8, 16), dtype=torch.float32), + cls(shape=(2, 3, 4), dtype=torch.bfloat16), + cls(shape=(8, 16), dtype=torch.bfloat16), + cls(shape=(2, 3, 4), dtype=torch.float16), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(1, dtype=torch.float32) + return (x,) + + def create_model(self) -> nn.Module: + return ZerosModel(self.shape, self.dtype) + + +@register_test +class OnesTest(OpTestCase): + """Test case for aten.ones op.""" + + name = "ones" + rtol = 1e-3 + atol = 1e-3 + + def __init__( + self, + shape: Tuple[int, ...] 
= (2, 3, 4), + dtype: torch.dtype = torch.float32, + ): + self.shape = shape + self.dtype = dtype + shape_str = "x".join(str(s) for s in shape) + dtype_str = str(dtype).replace("torch.", "") + self.name = f"ones_{shape_str}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["OnesTest"]: + return [ + cls(shape=(2, 3, 4), dtype=torch.float32), + cls(shape=(10,), dtype=torch.float32), + cls(shape=(1, 128), dtype=torch.float32), + cls(shape=(4, 8, 16), dtype=torch.float32), + cls(shape=(2, 3, 4), dtype=torch.bfloat16), + cls(shape=(8, 16), dtype=torch.bfloat16), + cls(shape=(2, 3, 4), dtype=torch.float16), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(1, dtype=torch.float32) + return (x,) + + def create_model(self) -> nn.Module: + return OnesModel(self.shape, self.dtype) + + +class ToDtypeModel(nn.Module): + def __init__(self, target_dtype: torch.dtype): + super().__init__() + self.target_dtype = target_dtype + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.to(self.target_dtype) + + +@register_test +class ToDtypeTest(OpTestCase): + """Test case for to.dtype op.""" + + name = "to_dtype" + rtol = 1e-3 + atol = 1e-3 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 3, 4), + source_dtype: torch.dtype = torch.float32, + target_dtype: torch.dtype = torch.bfloat16, + ): + self.shape = shape + self.source_dtype = source_dtype + self.target_dtype = target_dtype + shape_str = "x".join(str(s) for s in shape) + src_str = str(source_dtype).replace("torch.", "") + tgt_str = str(target_dtype).replace("torch.", "") + self.name = f"to_dtype_{shape_str}_{src_str}_to_{tgt_str}" + + @classmethod + def get_test_configs(cls) -> List["ToDtypeTest"]: + return [ + cls( + shape=(2, 3, 4), source_dtype=torch.float32, target_dtype=torch.bfloat16 + ), + cls(shape=(10,), source_dtype=torch.float32, target_dtype=torch.bfloat16), + cls( + shape=(1, 128), source_dtype=torch.float32, target_dtype=torch.bfloat16 + ), + cls( + shape=(2, 3, 4), source_dtype=torch.bfloat16, target_dtype=torch.float32 + ), + cls( + shape=(4, 8, 16), + source_dtype=torch.bfloat16, + target_dtype=torch.float32, + ), + cls( + shape=(2, 3, 4), source_dtype=torch.float32, target_dtype=torch.float16 + ), + cls( + shape=(2, 3, 4), source_dtype=torch.float16, target_dtype=torch.float32 + ), + cls(shape=(2, 3, 4), source_dtype=torch.float32, target_dtype=torch.int32), + cls(shape=(2, 3, 4), source_dtype=torch.int32, target_dtype=torch.float32), + cls(shape=(2, 3, 4), source_dtype=torch.float32, target_dtype=torch.int64), + cls(shape=(2, 3, 4), source_dtype=torch.int64, target_dtype=torch.float32), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + if self.source_dtype in (torch.int32, torch.int64): + x = torch.randint(-100, 100, self.shape, dtype=self.source_dtype) + else: + x = torch.randn(self.shape, dtype=self.source_dtype) + return (x,) + + def create_model(self) -> nn.Module: + return ToDtypeModel(self.target_dtype) + + +class BatchNormModel(nn.Module): + def __init__(self, num_features: int, dtype: torch.dtype, affine: bool = True): + super().__init__() + self.bn = nn.BatchNorm2d(num_features, affine=affine, dtype=dtype) + self.bn.eval() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.bn(x) + + +class BatchNorm1dModel(nn.Module): + def __init__(self, num_features: int, dtype: torch.dtype, affine: bool = True): + super().__init__() + self.bn = nn.BatchNorm1d(num_features, affine=affine, dtype=dtype) + self.bn.eval() + + def forward(self, x: 
torch.Tensor) -> torch.Tensor: + return self.bn(x) + + +@register_test +class BatchNorm2dTest(OpTestCase): + """Test case for aten._native_batch_norm_legit_no_training op with 2D input.""" + + name = "batch_norm_2d" + rtol = 1e-3 + atol = 1e-3 + + def __init__( + self, + batch_size: int = 2, + num_features: int = 16, + height: int = 8, + width: int = 8, + dtype: torch.dtype = torch.float32, + affine: bool = True, + ): + self.batch_size = batch_size + self.num_features = num_features + self.height = height + self.width = width + self.dtype = dtype + self.affine = affine + dtype_str = str(dtype).replace("torch.", "") + prefix = "batch_norm_2d_no_affine" if not affine else "batch_norm_2d" + self.name = f"{prefix}_{batch_size}x{num_features}x{height}x{width}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["BatchNorm2dTest"]: + return [ + cls(batch_size=1, num_features=16, height=8, width=8, dtype=torch.float32), + cls( + batch_size=2, num_features=32, height=16, width=16, dtype=torch.float32 + ), + cls(batch_size=4, num_features=64, height=4, width=4, dtype=torch.float32), + cls(batch_size=2, num_features=16, height=8, width=8, dtype=torch.bfloat16), + cls(batch_size=1, num_features=32, height=4, width=4, dtype=torch.bfloat16), + cls(batch_size=2, num_features=16, height=8, width=8, dtype=torch.float16), + # No-affine variants (no weight/bias) + cls( + batch_size=1, + num_features=16, + height=8, + width=8, + dtype=torch.float32, + affine=False, + ), + cls( + batch_size=2, + num_features=32, + height=4, + width=4, + dtype=torch.bfloat16, + affine=False, + ), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, + self.num_features, + self.height, + self.width, + dtype=self.dtype, + ) + return (x,) + + def create_model(self) -> nn.Module: + return BatchNormModel(self.num_features, self.dtype, affine=self.affine) + + +@register_test +class BatchNorm1dTest(OpTestCase): + """Test case for aten._native_batch_norm_legit_no_training op with 1D input.""" + + name = "batch_norm_1d" + rtol = 1e-3 + atol = 1e-3 + + def __init__( + self, + batch_size: int = 2, + num_features: int = 16, + seq_len: int = 32, + dtype: torch.dtype = torch.float32, + affine: bool = True, + ): + self.batch_size = batch_size + self.num_features = num_features + self.seq_len = seq_len + self.dtype = dtype + self.affine = affine + dtype_str = str(dtype).replace("torch.", "") + prefix = "batch_norm_1d_no_affine" if not affine else "batch_norm_1d" + self.name = f"{prefix}_{batch_size}x{num_features}x{seq_len}_{dtype_str}" + + @classmethod + def get_test_configs(cls) -> List["BatchNorm1dTest"]: + return [ + cls(batch_size=1, num_features=16, seq_len=32, dtype=torch.float32), + cls(batch_size=2, num_features=32, seq_len=64, dtype=torch.float32), + cls(batch_size=2, num_features=16, seq_len=32, dtype=torch.bfloat16), + cls(batch_size=2, num_features=16, seq_len=32, dtype=torch.float16), + # No-affine variants (no weight/bias) + cls( + batch_size=1, + num_features=16, + seq_len=32, + dtype=torch.float32, + affine=False, + ), + cls( + batch_size=2, + num_features=32, + seq_len=64, + dtype=torch.bfloat16, + affine=False, + ), + ] + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, self.num_features, self.seq_len, dtype=self.dtype + ) + return (x,) + + def create_model(self) -> nn.Module: + return BatchNorm1dModel(self.num_features, self.dtype, affine=self.affine) + + +class SDPAModel(nn.Module): + """Basic scaled dot product 
attention.""" + + def __init__(self, is_causal: bool = False): + super().__init__() + self.is_causal = is_causal + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention( + q, k, v, is_causal=self.is_causal + ) + + +class SDPAWithMaskModel(nn.Module): + """SDPA with explicit attention mask (additive float format).""" + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor + ) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) + + +class SDPAWithBoolMaskModel(nn.Module): + """SDPA with boolean attention mask. + + This tests the case where a boolean mask is passed to SDPA. + PyTorch expects: True = attend, False = masked out. + """ + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor + ) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) + + +class GQAModel(nn.Module): + """Grouped Query Attention - fewer KV heads than Q heads.""" + + def __init__(self, num_heads: int, num_kv_heads: int, is_causal: bool = False): + super().__init__() + self.num_groups = num_heads // num_kv_heads + self.is_causal = is_causal + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> torch.Tensor: + k = k.repeat_interleave(self.num_groups, dim=1) + v = v.repeat_interleave(self.num_groups, dim=1) + return torch.nn.functional.scaled_dot_product_attention( + q, k, v, is_causal=self.is_causal + ) + + +@register_test +class SDPATest(OpTestCase): + """Test case for SDPA.""" + + name = "sdpa" + rtol = 1e-3 + atol = 1e-3 + expected_node_counts = {"SdpaNode": 1, "ExpandDimsNode": 0} + + def __init__( + self, + batch_size: int = 2, + num_heads: int = 8, + seq_len: int = 32, + head_dim: int = 64, + num_kv_heads: Optional[int] = None, + is_causal: bool = False, + use_mask: bool = False, + use_bool_mask: bool = False, + ): + self.batch_size = batch_size + self.num_heads = num_heads + self.seq_len = seq_len + self.head_dim = head_dim + self.num_kv_heads = num_kv_heads + self.is_causal = is_causal + self.use_mask = use_mask + self.use_bool_mask = use_bool_mask + + parts = ["sdpa"] + if num_kv_heads is not None: + parts.append(f"gqa{num_kv_heads}") + if is_causal: + parts.append("causal") + if use_mask: + parts.append("mask") + if use_bool_mask: + parts.append("bool_mask") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["SDPATest"]: + return [ + cls(), + cls(is_causal=True), + cls(num_kv_heads=4), + cls(use_mask=True), + cls(use_bool_mask=True), # Test boolean mask conversion + ] + + def create_model(self) -> nn.Module: + if self.use_mask: + return SDPAWithMaskModel() + elif self.use_bool_mask: + return SDPAWithBoolMaskModel() + elif self.num_kv_heads is not None: + return GQAModel(self.num_heads, self.num_kv_heads, self.is_causal) + else: + return SDPAModel(self.is_causal) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + q = torch.randn(self.batch_size, self.num_heads, self.seq_len, self.head_dim) + kv_heads = self.num_kv_heads if self.num_kv_heads else self.num_heads + k = torch.randn(self.batch_size, kv_heads, self.seq_len, self.head_dim) + v = torch.randn(self.batch_size, kv_heads, self.seq_len, self.head_dim) + + if self.use_mask: + # Additive float mask: 0 = attend, -inf = masked + mask = torch.zeros(self.batch_size, 1, self.seq_len, self.seq_len) + mask[:, :, :, : self.seq_len // 
4] = float("-inf") + return (q, k, v, mask) + elif self.use_bool_mask: + # Boolean mask: True = attend, False = masked + # This tests that the backend correctly converts bool -> additive format + mask = torch.ones( + self.batch_size, 1, self.seq_len, self.seq_len, dtype=torch.bool + ) + mask[:, :, :, : self.seq_len // 4] = False # Mask out first quarter + return (q, k, v, mask) + return (q, k, v) + + +class CustomSDPAModel(nn.Module): + """ + Test model for mlx::custom_sdpa with KVCache. + + Simulates a single attention layer: updates the KV cache, then calls + mlx::custom_sdpa which slices K/V to [0:start_pos+seq_len] and runs SDPA. + """ + + def __init__( + self, + max_context_length: int, + n_heads: int, + n_kv_heads: int, + head_dim: int, + ): + super().__init__() + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + self.cache = KVCache( + max_batch_size=1, + max_context_length=max_context_length, + n_heads=n_kv_heads, + head_dim=head_dim, + enable_dynamic_shape=True, + ) + + def forward( + self, + input_pos: torch.Tensor, # [S] position indices + q: torch.Tensor, # [B, n_heads, S, D] + k_val: torch.Tensor, # [B, n_kv_heads, S, D] + v_val: torch.Tensor, # [B, n_kv_heads, S, D] + ) -> torch.Tensor: + # Update KV cache and get full cache tensors + k_cache, v_cache = self.cache.update(input_pos, k_val, v_val) + + start_pos = input_pos[0].item() + + output = torch.ops.mlx.custom_sdpa( + q, + k_cache, + v_cache, + start_pos=start_pos, + is_causal=True, + scale=self.head_dim**-0.5, + ) + return output + + +@register_test +class CustomSDPATest(OpTestCase): + """ + Test case for mlx::custom_sdpa with KV cache slicing. + + Verifies that custom_sdpa: + 1. Correctly slices K/V cache to [0:start_pos+seq_len] + 2. Produces numerically correct attention output + 3. Handles GQA (fewer KV heads than Q heads) + 4. 
Works with dynamic shapes (varying seq_len and start_pos) + """ + + name = "custom_sdpa" + rtol = 1e-3 + atol = 1e-3 + expected_node_counts = { + "SdpaNode": 1, + "SliceUpdateNode": 2, + "SliceNode": 2, + "IdCopyNode": 2, + "ExpandDimsNode": 0, + } + + def __init__( + self, + n_heads: int = 8, + n_kv_heads: int = 8, + head_dim: int = 64, + max_context_length: int = 128, + seq_len: int = 8, + ): + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + self.max_context_length = max_context_length + self.seq_len = seq_len + + parts = ["custom_sdpa"] + if n_kv_heads != n_heads: + parts.append(f"gqa{n_kv_heads}") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["CustomSDPATest"]: + return [ + cls(), # MHA + cls(n_kv_heads=4), # GQA (8 Q heads, 4 KV heads) + cls(n_kv_heads=1), # MQA (8 Q heads, 1 KV head) + ] + + def create_model(self) -> nn.Module: + return CustomSDPAModel( + max_context_length=self.max_context_length, + n_heads=self.n_heads, + n_kv_heads=self.n_kv_heads, + head_dim=self.head_dim, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + input_pos = torch.tensor([0], dtype=torch.int64) + q = torch.randn(1, self.n_heads, self.seq_len, self.head_dim) + k = torch.randn(1, self.n_kv_heads, self.seq_len, self.head_dim) + v = torch.randn(1, self.n_kv_heads, self.seq_len, self.head_dim) + return (input_pos, q, k, v) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + test_seq_len = self.seq_len + 4 + input_pos = torch.tensor([16], dtype=torch.int64) + q = torch.randn(1, self.n_heads, test_seq_len, self.head_dim) + k = torch.randn(1, self.n_kv_heads, test_seq_len, self.head_dim) + v = torch.randn(1, self.n_kv_heads, test_seq_len, self.head_dim) + return (input_pos, q, k, v) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + seq_dim = Dim("seq_len", min=1, max=self.max_context_length) + return { + "input_pos": None, + "q": {2: seq_dim}, + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + } + + +@register_test +class QuantizedLinearTest(OpTestCase): + """Test case for TorchAO int4 quantized nn.Linear.""" + + name = "quantized_linear" + rtol = 0.1 + atol = 0.1 + + def __init__( + self, + in_features: int = 128, + out_features: int = 128, + batch_size: int = 2, + seq_len: int = 16, + bias: bool = True, + group_size: int = 32, + dtype: torch.dtype = torch.bfloat16, + qdtype: torch.dtype = torch.int4, + ): + self.in_features = in_features + self.out_features = out_features + self.batch_size = batch_size + self.seq_len = seq_len + self.bias = bias + self.group_size = group_size + self.dtype = dtype + self.qdtype = qdtype + + parts = ["quantized_linear", f"{qdtype}", f"g{group_size}"] + if not bias: + parts.append("no_bias") + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["QuantizedLinearTest"]: + return [ + cls(), + cls(bias=False), + cls(group_size=64), + cls(group_size=128), + cls(qdtype=torch.int2), + cls(qdtype=torch.int8), + ] + + def create_model(self) -> nn.Module: + model = LinearModel(self.in_features, self.out_features, bias=self.bias) + model = model.to(self.dtype) + + from torchao.quantization.granularity import PerGroup + from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ + + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=self.qdtype, granularity=PerGroup(self.group_size) + ), + ) + + return model + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, self.seq_len, 
self.in_features, dtype=self.dtype
+        )
+        return (x,)
+
+
+@register_test
+class QuantizedEmbeddingTest(OpTestCase):
+    """Test case for TorchAO int4 quantized nn.Embedding."""
+
+    name = "quantized_embedding"
+    rtol = 0.1
+    atol = 0.1
+
+    def __init__(
+        self,
+        num_embeddings: int = 1000,
+        embedding_dim: int = 128,
+        batch_size: int = 2,
+        seq_len: int = 16,
+        group_size: int = 32,
+        dtype: torch.dtype = torch.bfloat16,
+        qdtype: torch.dtype = torch.int4,
+    ):
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.batch_size = batch_size
+        self.seq_len = seq_len
+        self.group_size = group_size
+        self.dtype = dtype
+        self.qdtype = qdtype
+
+        parts = ["quantized_embedding", f"{qdtype}", f"g{group_size}"]
+        self.name = "_".join(parts)
+
+    @classmethod
+    def get_test_configs(cls) -> List["QuantizedEmbeddingTest"]:
+        return [
+            cls(),
+            cls(group_size=64),
+            cls(group_size=128),
+            cls(qdtype=torch.int2),
+            cls(qdtype=torch.int8),
+        ]
+
+    def create_model(self) -> nn.Module:
+        model = EmbeddingModel(self.num_embeddings, self.embedding_dim)
+        model = model.to(self.dtype)
+
+        from torchao.quantization.granularity import PerGroup
+        from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
+
+        def embedding_filter(module: nn.Module, fqn: str) -> bool:
+            return isinstance(module, nn.Embedding)
+
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=self.qdtype, granularity=PerGroup(self.group_size)
+            ),
+            embedding_filter,
+        )
+
+        return model
+
+    def create_inputs(self) -> Tuple[torch.Tensor, ...]:
+        x = torch.randint(0, self.num_embeddings, (self.batch_size, self.seq_len))
+        return (x,)
+
+
+class DequantizeConv2dModel(nn.Module):
+    """Conv2d layer whose weight will be quantized.
+
+    The pattern matcher only fuses dequantize_affine with linear and embedding.
+    A quantized Conv2d produces a standalone dequantize_affine node in the graph,
+    exercising the DequantizeNode path.
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 32,
+        out_channels: int = 64,
+        kernel_size: int = 3,
+    ):
         super().__init__()
-        self.weight = nn.Parameter(torch.randn(batch_size, m, p))
+        self.conv = nn.Conv2d(
+            in_channels, out_channels, kernel_size, padding=1, bias=False
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return torch.bmm(x, self.weight)
+        return self.conv(x)


 @register_test
-class BmmTest(OpTestCase):
-    """Test case for bmm (batch matrix multiplication)."""
+class DequantizeTest(OpTestCase):
+    """Test case for standalone TorchAO dequantize_affine (DequantizeNode).

-    name = "bmm"
-    rtol = 1e-4
-    atol = 1e-4
+    Uses a quantized Conv2d to produce a standalone dequantize_affine node,
+    since the pattern matcher only fuses dequantize with linear/embedding.
+ """ + + name = "dequantize" + rtol = 0.1 + atol = 0.1 def __init__( self, - batch_size: int = 4, - n: int = 8, - m: int = 16, - p: int = 32, + in_channels: int = 32, + out_channels: int = 64, + kernel_size: int = 3, + height: int = 8, + width: int = 8, + batch_size: int = 1, + group_size: int = 32, + dtype: torch.dtype = torch.bfloat16, ): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.height = height + self.width = width self.batch_size = batch_size + self.group_size = group_size + self.dtype = dtype + + parts = ["dequantize", f"g{group_size}"] + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["DequantizeTest"]: + return [ + cls(), + ] + + def create_model(self) -> nn.Module: + model = DequantizeConv2dModel( + self.in_channels, self.out_channels, self.kernel_size + ) + model = model.to(self.dtype) + + from torchao.quantization.granularity import PerGroup + from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ + + def conv2d_filter(module: nn.Module, fqn: str) -> bool: + return isinstance(module, nn.Conv2d) + + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, granularity=PerGroup(self.group_size) + ), + conv2d_filter, + ) + + return model + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, + self.in_channels, + self.height, + self.width, + dtype=self.dtype, + ) + return (x,) + + +class CumsumModel(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.cumsum(x, dim=self.dim) + + +@register_test +class CumsumTest(OpTestCase): + name = "cumsum" + rtol = 1e-5 + atol = 1e-5 + + def __init__(self, shape: Tuple[int, ...] = (3, 4), dim: int = 0): + self.shape = shape + self.dim = dim + shape_str = "x".join(str(s) for s in shape) + self.name = f"cumsum_dim{dim}_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["CumsumTest"]: + return [ + cls(shape=(8,), dim=0), + cls(shape=(3, 4), dim=0), + cls(shape=(3, 4), dim=1), + cls(shape=(2, 3, 4), dim=-1), + ] + + def create_model(self) -> nn.Module: + return CumsumModel(self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape),) + + +class StackModel(nn.Module): + def __init__(self, dim: int = 0, n: int = 3): + super().__init__() + self.dim = dim self.n = n - self.m = m - self.p = p - self.name = f"bmm_{batch_size}x{n}x{m}x{p}" + + def forward(self, *tensors: torch.Tensor) -> torch.Tensor: + return torch.stack(tensors[: self.n], dim=self.dim) + + +@register_test +class StackTest(OpTestCase): + name = "stack" + rtol = 1e-5 + atol = 1e-5 + + def __init__(self, shape: Tuple[int, ...] 
= (3, 4), dim: int = 0, n: int = 3): + self.shape = shape + self.dim = dim + self.n = n + shape_str = "x".join(str(s) for s in shape) + self.name = f"stack_dim{dim}_n{n}_{shape_str}" @classmethod - def get_test_configs(cls) -> List["BmmTest"]: + def get_test_configs(cls) -> List["StackTest"]: return [ - cls(batch_size=4, n=8, m=16, p=32), - cls(batch_size=2, n=64, m=64, p=32), + cls(shape=(3, 4), dim=0, n=3), + cls(shape=(3, 4), dim=1, n=2), + cls(shape=(2, 3), dim=-1, n=4), ] def create_model(self) -> nn.Module: - return BmmModel(self.batch_size, self.n, self.m, self.p) + return StackModel(dim=self.dim, n=self.n) def create_inputs(self) -> Tuple[torch.Tensor, ...]: - x = torch.randn(self.batch_size, self.n, self.m) + return tuple(torch.randn(self.shape) for _ in range(self.n)) + + +class SignModel(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) + + +@register_test +class SignTest(OpTestCase): + name = "sign" + rtol = 0.0 + atol = 0.0 + + def __init__(self, shape: Tuple[int, ...] = (3, 4)): + self.shape = shape + shape_str = "x".join(str(s) for s in shape) + self.name = f"sign_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["SignTest"]: + return [ + cls(shape=(8,)), + cls(shape=(3, 4)), + cls(shape=(2, 3, 4)), + ] + + def create_model(self) -> nn.Module: + return SignModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape),) + + +class AnyModel(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.any(x, dim=self.dim) + + +@register_test +class AnyTest(OpTestCase): + name = "any" + rtol = 0.0 + atol = 0.0 + + def __init__(self, shape: Tuple[int, ...] = (3, 4), dim: int = 0): + self.shape = shape + self.dim = dim + shape_str = "x".join(str(s) for s in shape) + self.name = f"any_dim{dim}_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["AnyTest"]: + return [ + cls(shape=(4, 6), dim=0), + cls(shape=(4, 6), dim=1), + cls(shape=(2, 3, 4), dim=-1), + ] + + def create_model(self) -> nn.Module: + return AnyModel(self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Mix of True/False values + return (torch.randint(0, 2, self.shape).bool(),) + + +class AllModel(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.all(x, dim=self.dim) + + +@register_test +class AllTest(OpTestCase): + name = "all" + rtol = 0.0 + atol = 0.0 + + def __init__(self, shape: Tuple[int, ...] 
= (3, 4), dim: int = 0): + self.shape = shape + self.dim = dim + shape_str = "x".join(str(s) for s in shape) + self.name = f"all_dim{dim}_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["AllTest"]: + return [ + cls(shape=(4, 6), dim=0), + cls(shape=(4, 6), dim=1), + cls(shape=(2, 3, 4), dim=-1), + ] + + def create_model(self) -> nn.Module: + return AllModel(self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + # Mostly True with some False + x = torch.ones(self.shape, dtype=torch.bool) + x[0] = False + return (x,) + + +class RepeatInterleaveModel(nn.Module): + def __init__(self, repeats: int, dim: int): + super().__init__() + self.repeats = repeats + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.repeat_interleave(self.repeats, dim=self.dim) + + +@register_test +class RepeatInterleaveTest(OpTestCase): + name = "repeat_interleave" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] = (2, 3, 4), + repeats: int = 2, + dim: int = 0, + ): + self.shape = shape + self.repeats = repeats + self.dim = dim + shape_str = "x".join(str(s) for s in shape) + self.name = f"repeat_interleave_r{repeats}_dim{dim}_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["RepeatInterleaveTest"]: + return [ + cls(shape=(2, 4), repeats=3, dim=0), + cls(shape=(2, 4), repeats=2, dim=1), + cls(shape=(1, 8, 4, 16), repeats=4, dim=1), # GQA-like pattern + ] + + def create_model(self) -> nn.Module: + return RepeatInterleaveModel(self.repeats, self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape),) + + +class SortModel(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Only return sorted values + return torch.sort(x, dim=self.dim)[0] + + +class SortIndicesModel(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Only return sort indices + return torch.sort(x, dim=self.dim)[1] + + +class SortBothModel(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + values, indices = torch.sort(x, dim=self.dim) + return values, indices + + +@register_test +class SortTest(OpTestCase): + name = "sort" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] 
= (3, 4), + dim: int = -1, + output: str = "values", + ): + self.shape = shape + self.dim = dim + self.output = output + shape_str = "x".join(str(s) for s in shape) + self.name = f"sort_{output}_dim{dim}_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["SortTest"]: + return [ + cls(shape=(8,), dim=0, output="values"), + cls(shape=(3, 4), dim=-1, output="values"), + cls(shape=(3, 4), dim=0, output="indices"), + cls(shape=(2, 3, 4), dim=1, output="both"), + ] + + def create_model(self) -> nn.Module: + if self.output == "values": + return SortModel(self.dim) + elif self.output == "indices": + return SortIndicesModel(self.dim) + else: + return SortBothModel(self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape),) + + def get_expected_node_counts(self) -> Optional[Dict[str, int]]: + if self.output == "values": + return {"SortNode": 1, "ArgsortNode": 0} + elif self.output == "indices": + return {"SortNode": 0, "ArgsortNode": 1} + else: + return {"SortNode": 1, "ArgsortNode": 1} + + +class ArgsortModel(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.argsort(x, dim=self.dim) + + +@register_test +class ArgsortTest(OpTestCase): + name = "argsort" + rtol = 0.0 + atol = 0.0 + + def __init__(self, shape: Tuple[int, ...] = (3, 4), dim: int = -1): + self.shape = shape + self.dim = dim + shape_str = "x".join(str(s) for s in shape) + self.name = f"argsort_dim{dim}_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["ArgsortTest"]: + return [ + cls(shape=(8,), dim=0), + cls(shape=(3, 4), dim=-1), + cls(shape=(3, 4), dim=0), + ] + + def create_model(self) -> nn.Module: + return ArgsortModel(self.dim) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.randn(self.shape),) + + +class TopKValuesModel(nn.Module): + def __init__(self, k: int, dim: int): + super().__init__() + self.k = k + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.topk(x, self.k, dim=self.dim)[0] + + +class TopKIndicesModel(nn.Module): + def __init__(self, k: int, dim: int): + super().__init__() + self.k = k + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.topk(x, self.k, dim=self.dim)[1] + + +class TopKBothModel(nn.Module): + def __init__(self, k: int, dim: int): + super().__init__() + self.k = k + self.dim = dim + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + values, indices = torch.topk(x, self.k, dim=self.dim) + return values, indices + + +class TopKDynamicKModel(nn.Module): + """TopK with k derived from a dynamic tensor shape (exercises dynamic k path).""" + + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: torch.Tensor, k_source: torch.Tensor) -> torch.Tensor: + k = k_source.shape[0] + return torch.topk(x, k, dim=self.dim)[0] + + +@register_test +class TopKTest(OpTestCase): + name = "topk" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + shape: Tuple[int, ...] 
= (3, 8), + k: int = 3, + dim: int = -1, + output: str = "values", + ): + self.shape = shape + self.k = k + self.dim = dim + self.output = output + shape_str = "x".join(str(s) for s in shape) + self.name = f"topk_k{k}_dim{dim}_{output}_{shape_str}" + + @classmethod + def get_test_configs(cls) -> List["TopKTest"]: + return [ + # Values only + cls(shape=(16,), k=5, dim=0, output="values"), + cls(shape=(4, 8), k=3, dim=-1, output="values"), + cls(shape=(2, 4, 16), k=4, dim=-1, output="values"), + # Indices only + cls(shape=(4, 8), k=3, dim=-1, output="indices"), + # Both values and indices + cls(shape=(4, 8), k=3, dim=-1, output="both"), + # Dynamic k + cls(shape=(4, 8), k=3, dim=-1, output="dynamic_k"), + ] + + def create_model(self) -> nn.Module: + if self.output == "values": + return TopKValuesModel(self.k, self.dim) + elif self.output == "indices": + return TopKIndicesModel(self.k, self.dim) + elif self.output == "dynamic_k": + return TopKDynamicKModel(self.dim) + else: + return TopKBothModel(self.k, self.dim) + + def get_dynamic_shapes(self) -> Optional[Dict[str, any]]: + if self.output == "dynamic_k": + k_dim = Dim("k", min=1, max=self.shape[self.dim]) + return {"x": None, "k_source": {0: k_dim}} + return None + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + if self.output == "dynamic_k": + return (torch.randn(self.shape), torch.randn(self.k)) + return (torch.randn(self.shape),) + + +class NVFP4QuantizedLinearModel(nn.Module): + """Simple linear layer that will be quantized with NVFP4.""" + + def __init__( + self, in_features: int = 64, out_features: int = 128, bias: bool = True + ): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +@register_test +class NVFP4QuantizedLinearTest(OpTestCase): + """Test case for NVFP4 quantized nn.Linear.""" + + name = "nvfp4_quantized_linear" + rtol = 0.1 + atol = 0.1 + + def __init__( + self, + in_features: int = 64, + out_features: int = 128, + batch_size: int = 2, + seq_len: int = 16, + bias: bool = True, + use_per_tensor_scale: bool = True, + dtype: torch.dtype = torch.float32, + ): + self.in_features = in_features + self.out_features = out_features + self.batch_size = batch_size + self.seq_len = seq_len + self.bias = bias + self.use_per_tensor_scale = use_per_tensor_scale + self.dtype = dtype + + parts = ["nvfp4_quantized_linear"] + if not bias: + parts.append("no_bias") + if not use_per_tensor_scale: + parts.append("no_pts") + if dtype != torch.float32: + parts.append(str(dtype).split(".")[-1]) + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["NVFP4QuantizedLinearTest"]: + return [ + cls(), + cls(bias=False), + cls(use_per_tensor_scale=False), + cls(bias=False, use_per_tensor_scale=False), + cls(dtype=torch.bfloat16), + cls(bias=False, dtype=torch.bfloat16), + cls(use_per_tensor_scale=False, dtype=torch.bfloat16), + cls(bias=False, use_per_tensor_scale=False, dtype=torch.bfloat16), + ] + + def get_edge_compile_config(self): + from executorch.exir import EdgeCompileConfig + + return EdgeCompileConfig(_check_ir_validity=False) + + def create_model(self) -> nn.Module: + model = NVFP4QuantizedLinearModel( + self.in_features, self.out_features, bias=self.bias + ) + model = model.to(self.dtype) + + from executorch.extension.llm.export.nvfp4 import ExportableNVFP4Config + from torchao.quantization import quantize_ + + quantize_( + model, + 
+            ExportableNVFP4Config(use_per_tensor_scale=self.use_per_tensor_scale),
+        )
+
+        return model
+
+    def create_inputs(self) -> Tuple[torch.Tensor, ...]:
+        x = torch.randn(
+            self.batch_size, self.seq_len, self.in_features, dtype=self.dtype
+        )
+        return (x,)
diff --git a/backends/mlx/third-party/mlx b/backends/mlx/third-party/mlx
index 365d6f29b47..ce45c52505c 160000
--- a/backends/mlx/third-party/mlx
+++ b/backends/mlx/third-party/mlx
@@ -1 +1 @@
-Subproject commit 365d6f29b47686a9f5401f6a9ec5825fee162d69
+Subproject commit ce45c52505c8158ea48d2a54e8caae05efd86bfe
diff --git a/extension/llm/export/nvfp4.py b/extension/llm/export/nvfp4.py
new file mode 100644
index 00000000000..feeb95f50a6
--- /dev/null
+++ b/extension/llm/export/nvfp4.py
@@ -0,0 +1,187 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+NVFP4 export-compatible quantization.
+
+Upstream NVFP4Tensor's dequantize() uses raw Python ops that don't survive
+run_decompositions. This module registers a torch.library custom op
+(torchao::dequantize_nvfp4) so the dequant node persists through export,
+similar to how dequantize_affine works for int4.
+
+Usage:
+    from executorch.extension.llm.export.nvfp4 import ExportableNVFP4Config
+    from torchao.quantization import quantize_
+
+    quantize_(model, ExportableNVFP4Config())
+"""
+
+import types
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torchao.core.config import AOBaseConfig
+from torchao.prototype.mx_formats.kernels import f4_unpacked_to_f32, unpack_uint4
+from torchao.prototype.mx_formats.nvfp4_tensor import (
+    nvfp4_quantize,
+    per_tensor_amax_to_scale,
+)
+from torchao.quantization.quant_api import _quantization_type
+from torchao.quantization.transform_module import register_quantize_module_handler
+from torchao.utils import TorchAOBaseTensor
+
+aten = torch.ops.aten
+
+
+@torch.library.custom_op("torchao::dequantize_nvfp4", mutates_args=())
+def nvfp4_dequantize(
+    qdata: Tensor,
+    scale: Tensor,
+    per_tensor_scale: Tensor,
+    block_size: int,
+    output_dtype: torch.dtype = torch.float32,
+) -> Tensor:
+    """Dequantize NVFP4 packed data."""
+    data_unpacked = unpack_uint4(qdata.view(torch.uint8).contiguous())
+    data_f32 = f4_unpacked_to_f32(data_unpacked)
+
+    M = data_f32.shape[0]
+    K = data_f32.shape[1]
+
+    data_f32 = data_f32.view(M, K // block_size, block_size)
+    scale_fp8 = scale.view(torch.float8_e4m3fn)
+    scale_f32 = scale_fp8.to(torch.float32).view(M, K // block_size, 1)
+    scale_f32 = per_tensor_scale * scale_f32
+    result = (data_f32 * scale_f32).view(M, K)
+    return result.to(output_dtype)
+
+
+@nvfp4_dequantize.register_fake
+def _(qdata, scale, per_tensor_scale, block_size, output_dtype=torch.float32):
+    M = qdata.shape[0]
+    K = qdata.shape[1] * 8  # 8 FP4 values per uint32
+    return torch.empty(M, K, dtype=output_dtype, device=qdata.device)
+
+
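+# Packed layout assumed by the op above and produced by
+# _exportable_nvfp4_transform below (block_size=16):
+#   qdata:            uint32, shape (M, K // 8), eight FP4 values per word
+#   scale:            uint8 view of float8_e4m3fn, one scale per block of 16
+#   per_tensor_scale: float32 scalar tensor
+# For example, a 128x64 float32 weight packs to qdata (128, 8) and scale (128, 4).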
+class ExportableNVFP4Tensor(TorchAOBaseTensor):
+    """NVFP4 tensor subclass that dequantizes via a registered custom op."""
+
+    tensor_data_names = ["qdata", "scale", "per_tensor_scale"]
+    tensor_attribute_names = ["block_size", "orig_dtype"]
+
+    def __new__(cls, qdata, scale, per_tensor_scale, block_size, orig_dtype):
+        K = qdata.shape[-1] * 8  # 8 FP4 values per uint32
+        shape = (qdata.shape[0], K)
+        self = torch.Tensor._make_wrapper_subclass(
+            cls, shape, dtype=orig_dtype, device=qdata.device, requires_grad=False
+        )
+        self.qdata = qdata
+        self.scale = scale
+        self.per_tensor_scale = per_tensor_scale
+        self.block_size = block_size
+        self.orig_dtype = orig_dtype
+        return self
+
+    def dequantize(self, output_dtype=None):
+        dtype = output_dtype or self.orig_dtype
+        return torch.ops.torchao.dequantize_nvfp4(
+            self.qdata,
+            self.scale,
+            self.per_tensor_scale,
+            self.block_size,
+            output_dtype=dtype,
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+
+implements = ExportableNVFP4Tensor.implements
+
+
+@implements([aten.linear.default])
+def _(func, types, args, kwargs):
+    input_tensor = args[0]
+    weight_tensor = args[1]
+    bias = args[2] if len(args) > 2 else None
+    weight_dequant = weight_tensor.dequantize()
+    return torch.nn.functional.linear(input_tensor, weight_dequant, bias)
+
+
+@implements([aten.embedding.default])
+def _(func, types, args, kwargs):
+    weight_tensor = args[0]
+    indices = args[1]
+    weight_dequant = weight_tensor.dequantize()
+    return torch.nn.functional.embedding(indices, weight_dequant)
+
+
+@implements([aten.t.default])
+def _(func, types, args, kwargs):
+    return args[0].dequantize().t()
+
+
+@implements([aten.detach.default])
+def _(func, types, args, kwargs):
+    return args[0]
+
+
+@implements([aten._to_copy.default])
+def _(func, types, args, kwargs):
+    dtype = kwargs.get("dtype", args[0].orig_dtype)
+    return args[0].dequantize(output_dtype=dtype)
+
+
+@dataclass
+class ExportableNVFP4Config(AOBaseConfig):
+    """NVFP4 weight-only quantization config for torch.export."""
+
+    use_per_tensor_scale: bool = True
+
+
+def _linear_extra_repr(self):
+    return (
+        f"in_features={self.weight.shape[1]}, "
+        f"out_features={self.weight.shape[0]}, "
+        f"weight={_quantization_type(self.weight)}"
+    )
+
+
+@register_quantize_module_handler(ExportableNVFP4Config)
+def _exportable_nvfp4_transform(module: nn.Module, config: ExportableNVFP4Config):
+    weight = module.weight
+
+    if weight.shape[-2] % 16 != 0 or weight.shape[-1] % 16 != 0:
+        raise RuntimeError(
+            f"NVFP4 requires weight dims divisible by 16, got {weight.shape}"
+        )
+
+    per_tensor_scale = 1.0
+    if config.use_per_tensor_scale:
+        tensor_amax = torch.max(torch.abs(weight))
+        per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
+
+    scales_fp8, qdata_packed = nvfp4_quantize(
+        weight, block_size=16, per_tensor_scale=per_tensor_scale
+    )
+
+    qdata_u32 = qdata_packed.view(torch.uint32)
+    scales_u8 = scales_fp8.view(torch.uint8)
+
+    pts = torch.tensor(per_tensor_scale, dtype=torch.float32)
+    quantized_weight = ExportableNVFP4Tensor(
+        qdata_u32,
+        scales_u8,
+        pts,
+        block_size=16,
+        orig_dtype=weight.dtype,
+    )
+    module.weight = nn.Parameter(quantized_weight, requires_grad=False)
+    if isinstance(module, nn.Linear):
+        module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module