Commit 168794b
committed
Add CUDA sort shim for AOTI export (thrust-based sort_stable fallback)
Inductor emits aten::sort.stable for ops like argsort, but lacks a
native c-shim for it. This adds a thrust-based implementation
(aoti_torch_cuda_sort_stable) that handles int64, int32, and float32
dtypes on contiguous innermost-dim tensors. Registered as a supported
fallback kernel in CudaBackend so AOTI-compiled models can use sort.
This PR was authored with the assistance of Claude.1 parent 87e65ac commit 168794b
6 files changed
Lines changed: 624 additions & 1 deletion
File tree
- backends/cuda
- runtime
- shims
- tests
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
109 | 109 | | |
110 | 110 | | |
111 | 111 | | |
112 | | - | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
113 | 115 | | |
114 | 116 | | |
115 | 117 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
145 | 145 | | |
146 | 146 | | |
147 | 147 | | |
| 148 | + | |
148 | 149 | | |
149 | 150 | | |
150 | 151 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
33 | 33 | | |
34 | 34 | | |
35 | 35 | | |
| 36 | + | |
36 | 37 | | |
37 | 38 | | |
38 | 39 | | |
39 | 40 | | |
40 | 41 | | |
41 | 42 | | |
42 | 43 | | |
| 44 | + | |
43 | 45 | | |
44 | 46 | | |
45 | 47 | | |
| |||
0 commit comments