Skip to content

Commit 53b77ef

Browse files
author
Ishan Godawatta
committed
feat(mlx): add handler for aten.roll
Maps torch.roll to mlx::core::roll via a new RollNode. Adds the schema table, the custom handler for the (shifts, dims) args, the exec_roll runtime, and test cases covering 1D, 2D, multi-axis, negative shifts, and negative dims. Flat roll (dims=[]) is explicitly NotImplementedError for now; all known use cases (Swin Transformer shift-window attention) pass dims. Fixes #18919 Authored-with: Claude <noreply@anthropic.com>
1 parent 54b0148 commit 53b77ef

4 files changed

Lines changed: 115 additions & 1 deletion

File tree

backends/mlx/ops.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
RepeatNode,
117117
ReshapeNode,
118118
RMSNormNode,
119+
RollNode,
119120
RopeNode,
120121
RoundNode,
121122
RsqrtNode,
@@ -1677,6 +1678,45 @@ def _repeat_handler(P: MLXProgramBuilder, n: Node) -> Slot:
16771678
return out
16781679

16791680

1681+
@REGISTRY.register(target=[torch.ops.aten.roll.default])
def _roll_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Lower aten.roll(input, shifts, dims) to a RollNode.

    Normalizes scalar `shifts`/`dims` to lists, validates that they are
    aligned 1:1, and emits a RollNode mapping onto mlx::core::roll.
    """
    args = P.args(n)
    require_args(args, 2, 3, "aten.roll")
    require_kwargs(P.kwargs(n), set(), "aten.roll")

    x = args[0]
    raw_shifts = args[1]
    raw_dims = args[2] if len(args) > 2 else []

    # Accept both the scalar and the list forms of shifts/dims.
    if isinstance(raw_shifts, int):
        shifts = [raw_shifts]
    else:
        shifts = list(raw_shifts)
    if isinstance(raw_dims, int):
        dims: List[int] = [raw_dims]
    else:
        dims = list(raw_dims)

    # Flat roll (torch.roll with dims=[]) would require reshape + roll +
    # reshape at the graph level. Not yet supported; Swin-style usage always
    # passes explicit dims.
    if not dims:
        raise NotImplementedError(
            "aten.roll without dims (flat roll) is not supported by the MLX "
            "delegate yet."
        )
    if len(shifts) != len(dims):
        raise ValueError(
            f"aten.roll: shifts and dims must have the same length, got "
            f"shifts={shifts} (len={len(shifts)}) dims={dims} (len={len(dims)})"
        )
    require_static_ints(dims, "dims", "aten.roll")

    out = P.make_or_get_slot(n)
    # Shifts may be dynamic (IntOrVid); axes must be static ints.
    P.emit(
        RollNode(
            x=P.slot_to_tid(x),
            out=P.slot_to_tid(out),
            shift=[P.to_int_or_vid(s) for s in shifts],
            axes=dims,
        )
    )
    return out
1718+
1719+
16801720
@REGISTRY.register(target=[torch.ops.aten.index.Tensor])
16811721
def _index_handler(P: MLXProgramBuilder, n: Node) -> Slot:
16821722
args = P.args(n)

backends/mlx/runtime/MLXInterpreter.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1726,6 +1726,13 @@ inline void exec_all(const AllNode& n, ExecutionState& st, StreamOrDevice s) {
17261726
}
17271727
}
17281728

1729+
inline void exec_roll(const RollNode& n, ExecutionState& st, StreamOrDevice s) {
1730+
const auto& x = st.const_tensor_ref(n.x);
1731+
auto shifts = to_shape(n.shift, st);
1732+
std::vector<int> axes(n.axes.begin(), n.axes.end());
1733+
st.set_tensor(n.out, roll(x, shifts, axes, s));
1734+
}
1735+
17291736
inline void
17301737
exec_repeat(const RepeatNode& n, ExecutionState& st, StreamOrDevice s) {
17311738
const auto& x = st.const_tensor_ref(n.x);
@@ -2199,6 +2206,9 @@ class Interpreter {
21992206
case OpCode::REPEAT:
22002207
ops::exec_repeat(std::get<RepeatNode>(instr.node), st, s);
22012208
break;
2209+
case OpCode::ROLL:
2210+
ops::exec_roll(std::get<RollNode>(instr.node), st, s);
2211+
break;
22022212
case OpCode::SORT:
22032213
ops::exec_sort(std::get<SortNode>(instr.node), st, s);
22042214
break;

backends/mlx/serialization/schema.fbs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,16 @@ table ArgPartitionNode {
668668
axis: int32;
669669
}
670670

671+
// Shift tensor elements along specified axes with wrap-around.
// Maps to mlx::core::roll(a, shifts, axes).
// Flat roll (torch.roll with dims=None) is not yet supported.
table RollNode {
  x: Tid (required);            // Input tensor.
  out: Tid (required);          // Output tensor (same shape as x).
  shift: [IntOrVid] (required); // Shift amount per axis (can be dynamic)
  axes: [int32] (required);     // Axes to roll along; len(shift) == len(axes)
}
680+
671681

672682
// =============================================================================
673683
// Math ops - Unary element-wise
@@ -1113,7 +1123,8 @@ union OpNode {
11131123
GatherMmNode,
11141124
GatherQmmNode,
11151125
ScanNode,
1116-
MetalKernelNode
1126+
MetalKernelNode,
1127+
RollNode
11171128
// BC: Add new op nodes here (append only)
11181129
}
11191130

backends/mlx/test/test_ops.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,59 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]:
855855
return (x,)
856856

857857

858+
class RollModel(nn.Module):
    """Wraps torch.roll with a fixed (shifts, dims) pair so it can be
    traced/exported as a module."""

    def __init__(self, shifts: Tuple[int, ...], dims: Tuple[int, ...]):
        super().__init__()
        # Captured at construction time; forward() only depends on x.
        self.shifts = shifts
        self.dims = dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Circularly shift x by self.shifts along self.dims.
        return torch.roll(x, shifts=self.shifts, dims=self.dims)
868+
869+
870+
@register_test
class RollTest(OpTestCase):
    """Test case for torch.roll()."""

    name = "roll"
    rtol = 1e-5
    atol = 1e-5

    def __init__(
        self,
        input_shape: Tuple[int, ...] = (4, 5),
        shifts: Tuple[int, ...] = (1,),
        dims: Tuple[int, ...] = (0,),
    ):
        self.input_shape = input_shape
        self.shifts = shifts
        self.dims = dims
        # Unique per-config name, e.g. "roll_shift(1,2)_dim(0,2)".
        self.name = "roll_shift({})_dim({})".format(
            ",".join(str(s) for s in shifts),
            ",".join(str(d) for d in dims),
        )

    @classmethod
    def get_test_configs(cls) -> List["RollTest"]:
        # (input_shape, shifts, dims) triples covering 1D, 2D, multi-axis,
        # negative shifts, and negative dims.
        configs = [
            ((8,), (2,), (0,)),
            ((4, 5), (1,), (0,)),
            ((4, 5), (-2,), (1,)),
            ((3, 4, 5), (3,), (2,)),
            ((3, 4, 5), (1, 2), (0, 2)),
            ((3, 4, 5), (-1, -2, -3), (0, 1, 2)),
            ((3, 4, 5), (2,), (-1,)),
        ]
        return [
            cls(input_shape=shape, shifts=shifts, dims=dims)
            for shape, shifts, dims in configs
        ]

    def create_model(self) -> nn.Module:
        return RollModel(self.shifts, self.dims)

    def create_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.randn(self.input_shape),)
909+
910+
858911
class CatNModel(nn.Module):
859912
"""Model that concatenates N tensors along a dimension."""
860913

0 commit comments

Comments
 (0)