Skip to content

Commit eb9cc01

Browse files
Ishan Godawatta (IshanG97)
authored and committed
feat(mlx): add handler for aten.roll
Maps torch.roll to mlx::core::roll via a new RollNode. Adds the schema table, the custom handler for the (shifts, dims) args, the exec_roll runtime, and test cases covering 1D, 2D, multi-axis, negative shifts, and negative dims. Flat roll (dims=[]) is explicitly NotImplementedError for now; all known use cases (Swin Transformer shift-window attention) pass dims. Fixes #18919
1 parent 3be4546 commit eb9cc01

4 files changed

Lines changed: 115 additions & 1 deletion

File tree

backends/mlx/ops.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
RepeatNode,
118118
ReshapeNode,
119119
RMSNormNode,
120+
RollNode,
120121
RopeNode,
121122
RoundNode,
122123
RsqrtNode,
@@ -1678,6 +1679,45 @@ def _repeat_handler(P: MLXProgramBuilder, n: Node) -> Slot:
16781679
return out
16791680

16801681

1682+
@REGISTRY.register(target=[torch.ops.aten.roll.default])
def _roll_handler(P: MLXProgramBuilder, n: Node) -> Slot:
    """Lower aten.roll(x, shifts, dims) to a RollNode (mlx::core::roll).

    Shifts may be dynamic (IntOrVid); dims must be static ints.
    """
    args = P.args(n)
    require_args(args, 2, 3, "aten.roll")
    require_kwargs(P.kwargs(n), set(), "aten.roll")

    x = args[0]
    raw_shifts = args[1]
    raw_dims = args[2] if len(args) > 2 else []

    # torch.roll accepts either a scalar or a sequence for both arguments;
    # normalize each to a list before validating.
    shifts = [raw_shifts] if isinstance(raw_shifts, int) else list(raw_shifts)
    dims: List[int] = [raw_dims] if isinstance(raw_dims, int) else list(raw_dims)

    # Flat roll (torch.roll with dims=[]) would require reshape + roll +
    # reshape at the graph level. Not yet supported; Swin-style usage always
    # passes explicit dims.
    if not dims:
        raise NotImplementedError(
            "aten.roll without dims (flat roll) is not supported by the MLX "
            "delegate yet."
        )
    if len(shifts) != len(dims):
        raise ValueError(
            f"aten.roll: shifts and dims must have the same length, got "
            f"shifts={shifts} (len={len(shifts)}) dims={dims} (len={len(dims)})"
        )
    require_static_ints(dims, "dims", "aten.roll")

    out = P.make_or_get_slot(n)
    node = RollNode(
        x=P.slot_to_tid(x),
        out=P.slot_to_tid(out),
        shift=[P.to_int_or_vid(s) for s in shifts],
        axes=dims,
    )
    P.emit(node)
    return out
1719+
1720+
16811721
@REGISTRY.register(target=[torch.ops.aten.index.Tensor])
16821722
def _index_handler(P: MLXProgramBuilder, n: Node) -> Slot:
16831723
args = P.args(n)

backends/mlx/runtime/MLXInterpreter.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1733,6 +1733,13 @@ inline void exec_all(const AllNode& n, ExecutionState& st, StreamOrDevice s) {
17331733
}
17341734
}
17351735

1736+
inline void exec_roll(const RollNode& n, ExecutionState& st, StreamOrDevice s) {
1737+
const auto& x = st.const_tensor_ref(n.x);
1738+
auto shifts = to_shape(n.shift, st);
1739+
std::vector<int> axes(n.axes.begin(), n.axes.end());
1740+
st.set_tensor(n.out, roll(x, shifts, axes, s));
1741+
}
1742+
17361743
inline void
17371744
exec_repeat(const RepeatNode& n, ExecutionState& st, StreamOrDevice s) {
17381745
const auto& x = st.const_tensor_ref(n.x);
@@ -2210,6 +2217,9 @@ class Interpreter {
22102217
case OpCode::REPEAT:
22112218
ops::exec_repeat(std::get<RepeatNode>(instr.node), st, s);
22122219
break;
2220+
case OpCode::ROLL:
2221+
ops::exec_roll(std::get<RollNode>(instr.node), st, s);
2222+
break;
22132223
case OpCode::SORT:
22142224
ops::exec_sort(std::get<SortNode>(instr.node), st, s);
22152225
break;

backends/mlx/serialization/schema.fbs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,16 @@ table ArgPartitionNode {
673673
axis: int32;
674674
}
675675

676+
// Shift tensor elements along specified axes with wrap-around.
// Maps to mlx::core::roll(a, shifts, axes).
// Flat roll (torch.roll with dims=None) is not yet supported.
// NOTE: FlatBuffers field order is part of the wire format — append only.
table RollNode {
x: Tid (required); // Input tensor.
out: Tid (required); // Output tensor.
shift: [IntOrVid] (required); // Shift amount per axis (can be dynamic)
axes: [int32] (required); // Axes to roll along; len(shift) == len(axes)
}
685+
676686

677687
// =============================================================================
678688
// Math ops - Unary element-wise
@@ -1119,7 +1129,8 @@ union OpNode {
11191129
GatherQmmNode,
11201130
ScanNode,
11211131
MetalKernelNode,
1122-
BitwiseInvertNode
1132+
BitwiseInvertNode,
1133+
RollNode
11231134
// BC: Add new op nodes here (append only)
11241135
}
11251136

backends/mlx/test/test_ops.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,59 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]:
855855
return (x,)
856856

857857

858+
class RollModel(nn.Module):
    """Applies torch.roll with a fixed (shifts, dims) configuration."""

    def __init__(self, shifts: Tuple[int, ...], dims: Tuple[int, ...]):
        super().__init__()
        self.shifts = shifts
        self.dims = dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Tensor.roll is equivalent to torch.roll(x, shifts, dims).
        return x.roll(self.shifts, self.dims)
868+
869+
870+
@register_test
class RollTest(OpTestCase):
    """Test case for torch.roll()."""

    name = "roll"
    rtol = 1e-5
    atol = 1e-5

    def __init__(
        self,
        input_shape: Tuple[int, ...] = (4, 5),
        shifts: Tuple[int, ...] = (1,),
        dims: Tuple[int, ...] = (0,),
    ):
        self.input_shape = input_shape
        self.shifts = shifts
        self.dims = dims
        # Encode the configuration in the test name,
        # e.g. "roll_shift(1,2)_dim(0,2)".
        self.name = "roll_shift({})_dim({})".format(
            ",".join(map(str, shifts)), ",".join(map(str, dims))
        )

    @classmethod
    def get_test_configs(cls) -> List["RollTest"]:
        # (input_shape, shifts, dims) covering 1D, 2D, multi-axis,
        # negative shifts, and negative dims.
        configs = [
            ((8,), (2,), (0,)),
            ((4, 5), (1,), (0,)),
            ((4, 5), (-2,), (1,)),
            ((3, 4, 5), (3,), (2,)),
            ((3, 4, 5), (1, 2), (0, 2)),
            ((3, 4, 5), (-1, -2, -3), (0, 1, 2)),
            ((3, 4, 5), (2,), (-1,)),
        ]
        return [
            cls(input_shape=shape, shifts=sh, dims=dm)
            for shape, sh, dm in configs
        ]

    def create_model(self) -> nn.Module:
        return RollModel(self.shifts, self.dims)

    def create_inputs(self) -> Tuple[torch.Tensor, ...]:
        return (torch.randn(self.input_shape),)
909+
910+
858911
class CatNModel(nn.Module):
859912
"""Model that concatenates N tensors along a dimension."""
860913

0 commit comments

Comments
 (0)