Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 80 additions & 5 deletions backends/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,7 @@ set(_aoti_cuda_shim_sources runtime/cuda_allocator.cpp runtime/shims/memory.cpp
# Only build CUDA shims when CUDA language/toolchain is available.
if(CMAKE_CUDA_COMPILER)
list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu
runtime/shims/int4_plain_mm.cu runtime/shims/sort.cu
runtime/shims/rand.cu
runtime/shims/sort.cu runtime/shims/int4_plain_mm.cu
)
endif()

Expand Down Expand Up @@ -153,8 +152,7 @@ endif()
# retention.
if(_cuda_is_msvc_toolchain)
target_link_libraries(
aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart CUDA::curand
${CMAKE_DL_LIBS}
aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart ${CMAKE_DL_LIBS}
)
# Link object library directly so symbols are pulled exactly once while
# avoiding duplicate static/object inclusion and interface leakage.
Expand All @@ -164,7 +162,7 @@ else()
aoti_cuda_shims
PRIVATE cuda_platform
PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive
CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS}
CUDA::cudart ${CMAKE_DL_LIBS}
)
endif()

Expand All @@ -178,6 +176,83 @@ install(
DESTINATION lib
)

# CUDA-specific AOTI sampler shim symbols (rand/randint via curand). Split out
# of aoti_cuda_shims so the curand fatbin (~3.5MB precalc tables + Philox
# kernels per arch) and the CUDA::curand dependency are only paid by the small
# set of consumers that actually use them (e.g. qwen3_5_moe). Other CUDA
# examples (voxtral, parakeet, whisper, dinov2, ...) link only aoti_cuda_shims
# and stay small.
if(CMAKE_CUDA_COMPILER)
add_library(aoti_cuda_sampler_shims SHARED runtime/shims/rand.cu)

# Match aoti_cuda_shims preprocessor defines for symbol export.
target_compile_definitions(aoti_cuda_sampler_shims PRIVATE CUDA_AVAILABLE=1)
if(WIN32)
target_compile_definitions(
aoti_cuda_sampler_shims PRIVATE EXPORT_AOTI_FUNCTIONS
)
if(_cuda_is_windows_msvc)
set_target_properties(
aoti_cuda_sampler_shims PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS OFF
)
endif()
endif()

target_include_directories(
aoti_cuda_sampler_shims
PUBLIC ${CUDAToolkit_INCLUDE_DIRS} $<BUILD_INTERFACE:${EXECUTORCH_ROOT}>
$<INSTALL_INTERFACE:include>
)

target_compile_options(
aoti_cuda_sampler_shims
PUBLIC "$<$<COMPILE_LANGUAGE:CXX>:${_cuda_cxx_compile_options}>"
)

if(_cuda_export_dynamic_option)
target_link_options(
aoti_cuda_sampler_shims PUBLIC ${_cuda_export_dynamic_option}
)
endif()

# rand.cu calls into slim helpers (empty_strided, getCurrentCUDAStream,
# SlimTensor) which are linked into aoti_cuda_shims. Depend on that target so
# we resolve those symbols from the already-loaded shims library instead of
# duplicating slim's static archive into both DLLs.
#
# Also link `slimtensor` (INTERFACE / header-only) directly so the c10 include
# root (runtime/core/portable_type/c10) is on this target's compile command.
# aoti_cuda_shims links aoti_common_shims_slim PUBLIC on non-MSVC (so includes
# propagate transitively on Linux) but only PRIVATELY via the *_obj OBJECT lib
# on MSVC, which does NOT forward the slimtensor INTERFACE include dirs.
# Linking slimtensor here makes the include path explicit on both toolchains
# and keeps Windows MSVC happy without changing aoti_cuda_shims' propagation
# semantics.
if(_cuda_is_msvc_toolchain)
target_link_libraries(
aoti_cuda_sampler_shims
PRIVATE cuda_platform CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS}
aoti_cuda_shims slimtensor
)
else()
target_link_libraries(
aoti_cuda_sampler_shims
PRIVATE cuda_platform slimtensor
PUBLIC CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS} aoti_cuda_shims
)
endif()

if(NOT _cuda_is_msvc_toolchain)
executorch_target_link_options_shared_lib(aoti_cuda_sampler_shims)
endif()

install(
TARGETS aoti_cuda_sampler_shims
EXPORT ExecuTorchTargets
DESTINATION lib
)
endif()

# CUDA backend implementation
set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp)
if(_cuda_is_msvc_toolchain)
Expand Down
4 changes: 4 additions & 0 deletions backends/cuda/runtime/shims/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,7 @@ foreach(test_name ${CUDA_KERNEL_TESTS})

add_test(NAME ${test_name} COMMAND ${test_name})
endforeach()

# rand symbols live in the separate aoti_cuda_sampler_shims DLL to keep the
# curand-induced binary-size cost out of aoti_cuda_shims.
target_link_libraries(test_aoti_torch_cuda_rand PRIVATE aoti_cuda_sampler_shims)
3 changes: 2 additions & 1 deletion examples/models/qwen3_5_moe/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ if(EXECUTORCH_BUILD_METAL)
executorch_target_link_options_shared_lib(metal_backend)
elseif(EXECUTORCH_BUILD_CUDA)
find_package(CUDAToolkit REQUIRED)
list(APPEND link_libraries aoti_cuda_backend)
list(APPEND link_libraries aoti_cuda_backend aoti_cuda_sampler_shims)
executorch_target_link_options_shared_lib(aoti_cuda_backend)
executorch_target_link_options_shared_lib(aoti_cuda_sampler_shims)
add_compile_definitions(EXECUTORCH_BUILD_CUDA)
else()
message(
Expand Down
Loading