feat(mlx): add handler for aten.roll#19038
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19038
Note: Links to docs will display an error until the docs builds have been completed.
|
This PR needs a
|
53b77ef to
726c721
Compare
|
Looks fantastic @IshanG97! Let's see what CI says, but is there a reason you didn't run the test locally? Ideally, I'd like to see the output of the test in the PR summary :) |
|
Ignore all of the test-mlx-llm result failures in CI, they have to do with an HF token issue on external contributions. |
Hi @metascroy! Sorry for the late response, I didn't initially because I couldn't get the latest version of Xcode on my machine, I had to use a proxy (to be honest, I was relying on the CI to catch anything haha!). I've now downloaded Xcode and run the tests locally, see my trimmed output: |
Maps torch.roll to mlx::core::roll via a new RollNode. Adds the schema table, the custom handler for the (shifts, dims) args, the exec_roll runtime, and test cases covering 1D, 2D, multi-axis, negative shifts, and negative dims. Flat roll (dims=[]) is explicitly NotImplementedError for now; all known use cases (Swin Transformer shift-window attention) pass dims. Fixes pytorch#18919
b2f3bf0 to
eb9cc01
Compare
Summary
Adds an MLX delegate handler for
aten.roll, mappingtorch.rollontomlx::core::rollvia a newRollNodein the schema. Replaces the default decomposition (index_select + arange + cat) with a single native kernel — needed by Swin Transformer's shift-window attention.Flat roll (
dims=[]) raisesNotImplementedErrorfor now; no known consumer needs it yet.Generated files (
MLXLoader.*,schema_generated.h,mlx_graph_schema.py,_generated_serializers.py,_generated_inspector.py,_generated/) are regenerated fromschema.fbsbybackends/mlx/CMakeLists.txtat build time and are deliberately not committed.Fixes #18919.
Test plan
python backends/mlx/serialization/generate.py— regenerates cleanly withRollNodein all expected outputs.lintrunner --skip MYPY --paths-cmd 'git diff --name-only upstream/main'— no issues.run_all_tests -k rollnot run locally (no executorch build on this machine); relying on CI. Happy to push fixes if it finds anything.