Commit 53b77ef
Ishan Godawatta
feat(mlx): add handler for aten.roll
Maps torch.roll to mlx::core::roll via a new RollNode. Adds the schema
table, the custom handler for the (shifts, dims) args, the exec_roll
runtime, and test cases covering 1D, 2D, multi-axis, negative shifts,
and negative dims.
Flat roll (dims=[]) is explicitly NotImplementedError for now; all
known use cases (Swin Transformer shift-window attention) pass dims.
Fixes #18919
Authored-with: Claude <noreply@anthropic.com>1 parent 54b0148 commit 53b77ef
4 files changed
Lines changed: 115 additions & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
116 | 116 | | |
117 | 117 | | |
118 | 118 | | |
| 119 | + | |
119 | 120 | | |
120 | 121 | | |
121 | 122 | | |
| |||
1677 | 1678 | | |
1678 | 1679 | | |
1679 | 1680 | | |
| 1681 | + | |
| 1682 | + | |
| 1683 | + | |
| 1684 | + | |
| 1685 | + | |
| 1686 | + | |
| 1687 | + | |
| 1688 | + | |
| 1689 | + | |
| 1690 | + | |
| 1691 | + | |
| 1692 | + | |
| 1693 | + | |
| 1694 | + | |
| 1695 | + | |
| 1696 | + | |
| 1697 | + | |
| 1698 | + | |
| 1699 | + | |
| 1700 | + | |
| 1701 | + | |
| 1702 | + | |
| 1703 | + | |
| 1704 | + | |
| 1705 | + | |
| 1706 | + | |
| 1707 | + | |
| 1708 | + | |
| 1709 | + | |
| 1710 | + | |
| 1711 | + | |
| 1712 | + | |
| 1713 | + | |
| 1714 | + | |
| 1715 | + | |
| 1716 | + | |
| 1717 | + | |
| 1718 | + | |
| 1719 | + | |
1680 | 1720 | | |
1681 | 1721 | | |
1682 | 1722 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1726 | 1726 | | |
1727 | 1727 | | |
1728 | 1728 | | |
| 1729 | + | |
| 1730 | + | |
| 1731 | + | |
| 1732 | + | |
| 1733 | + | |
| 1734 | + | |
| 1735 | + | |
1729 | 1736 | | |
1730 | 1737 | | |
1731 | 1738 | | |
| |||
2199 | 2206 | | |
2200 | 2207 | | |
2201 | 2208 | | |
| 2209 | + | |
| 2210 | + | |
| 2211 | + | |
2202 | 2212 | | |
2203 | 2213 | | |
2204 | 2214 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
668 | 668 | | |
669 | 669 | | |
670 | 670 | | |
| 671 | + | |
| 672 | + | |
| 673 | + | |
| 674 | + | |
| 675 | + | |
| 676 | + | |
| 677 | + | |
| 678 | + | |
| 679 | + | |
| 680 | + | |
671 | 681 | | |
672 | 682 | | |
673 | 683 | | |
| |||
1113 | 1123 | | |
1114 | 1124 | | |
1115 | 1125 | | |
1116 | | - | |
| 1126 | + | |
| 1127 | + | |
1117 | 1128 | | |
1118 | 1129 | | |
1119 | 1130 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
855 | 855 | | |
856 | 856 | | |
857 | 857 | | |
| 858 | + | |
| 859 | + | |
| 860 | + | |
| 861 | + | |
| 862 | + | |
| 863 | + | |
| 864 | + | |
| 865 | + | |
| 866 | + | |
| 867 | + | |
| 868 | + | |
| 869 | + | |
| 870 | + | |
| 871 | + | |
| 872 | + | |
| 873 | + | |
| 874 | + | |
| 875 | + | |
| 876 | + | |
| 877 | + | |
| 878 | + | |
| 879 | + | |
| 880 | + | |
| 881 | + | |
| 882 | + | |
| 883 | + | |
| 884 | + | |
| 885 | + | |
| 886 | + | |
| 887 | + | |
| 888 | + | |
| 889 | + | |
| 890 | + | |
| 891 | + | |
| 892 | + | |
| 893 | + | |
| 894 | + | |
| 895 | + | |
| 896 | + | |
| 897 | + | |
| 898 | + | |
| 899 | + | |
| 900 | + | |
| 901 | + | |
| 902 | + | |
| 903 | + | |
| 904 | + | |
| 905 | + | |
| 906 | + | |
| 907 | + | |
| 908 | + | |
| 909 | + | |
| 910 | + | |
858 | 911 | | |
859 | 912 | | |
860 | 913 | | |
| |||
0 commit comments