Skip to content

Commit 168794b

Browse files
committed
Add CUDA sort shim for AOTI export (thrust-based sort_stable fallback)
Inductor emits aten::sort.stable for ops like argsort, but lacks a native c-shim for it. This adds a thrust-based implementation (aoti_torch_cuda_sort_stable) that handles int64, int32, and float32 dtypes on contiguous innermost-dim tensors. Registered as a supported fallback kernel in CudaBackend so AOTI-compiled models can use sort. This PR was authored with the assistance of Claude.
1 parent 87e65ac commit 168794b

6 files changed

Lines changed: 624 additions & 1 deletion

File tree

backends/cuda/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
109109

110110
# Only build int4mm shim when CUDA language/toolchain is available.
111111
if(CMAKE_CUDA_COMPILER)
112-
list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu)
112+
list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu
113+
runtime/shims/sort.cu
114+
)
113115
endif()
114116

115117
add_library(aoti_cuda_shims SHARED ${_aoti_cuda_shim_sources})

backends/cuda/cuda_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def save_data_externally(cls) -> bool:
145145
def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
146146
return {
147147
"at::_ops::_weight_int4pack_mm::call": None,
148+
"at::_ops::sort_stable::call": None,
148149
}
149150

150151
@classmethod

backends/cuda/runtime/TARGETS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,15 @@ runtime.cxx_library(
3333
"shims/cuda_guard.cpp",
3434
"shims/int4mm.cu",
3535
"shims/memory.cpp",
36+
"shims/sort.cu",
3637
"shims/tensor_attribute.cpp",
3738
],
3839
headers = [
3940
"shims/cuda_guard.h",
4041
"shims/int4mm.cuh",
4142
"shims/int4mm.h",
4243
"shims/memory.h",
44+
"shims/sort.h",
4345
"shims/tensor_attribute.h",
4446
"utils.h",
4547
],

0 commit comments

Comments
 (0)