Skip to content

Commit f36eb60

Browse files
committed
init
1 parent 3a62fac commit f36eb60

3 files changed

Lines changed: 79 additions & 4 deletions

File tree

backends/cuda/CMakeLists.txt

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
110110
# Only build CUDA shims when CUDA language/toolchain is available.
111111
if(CMAKE_CUDA_COMPILER)
112112
list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu
113-
runtime/shims/sort.cu runtime/shims/rand.cu
113+
runtime/shims/sort.cu
114114
)
115115
endif()
116116

@@ -152,7 +152,7 @@ endif()
152152
# retention.
153153
if(_cuda_is_msvc_toolchain)
154154
target_link_libraries(
155-
aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart CUDA::curand
155+
aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart
156156
${CMAKE_DL_LIBS}
157157
)
158158
# Link object library directly so symbols are pulled exactly once while
@@ -163,7 +163,7 @@ else()
163163
aoti_cuda_shims
164164
PRIVATE cuda_platform
165165
PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive
166-
CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS}
166+
CUDA::cudart ${CMAKE_DL_LIBS}
167167
)
168168
endif()
169169

@@ -177,6 +177,74 @@ install(
177177
DESTINATION lib
178178
)
179179

180+
# aoti_cuda_sampler_shims: the CUDA-specific AOTI sampler shim symbols
# (rand/randint, implemented on top of curand) built from
# runtime/shims/rand.cu.
#
# These live in their own shared library rather than in aoti_cuda_shims so
# that the curand fatbin (~3.5MB of precalc tables plus Philox kernels per
# arch) and the CUDA::curand dependency are paid only by the few consumers
# that need them (e.g. qwen3_5_moe). The other CUDA examples (voxtral,
# parakeet, whisper, dinov2, ...) link only aoti_cuda_shims and stay small.
#
# The target is defined only when a CUDA compiler is available.
if(CMAKE_CUDA_COMPILER)
  add_library(aoti_cuda_sampler_shims SHARED runtime/shims/rand.cu)

  # Usage requirements that are also propagated to consumers (PUBLIC).
  target_include_directories(
    aoti_cuda_sampler_shims
    PUBLIC ${CUDAToolkit_INCLUDE_DIRS}
           $<BUILD_INTERFACE:${EXECUTORCH_ROOT}>
           $<INSTALL_INTERFACE:include>
  )
  target_compile_options(
    aoti_cuda_sampler_shims
    PUBLIC "$<$<COMPILE_LANGUAGE:CXX>:${_cuda_cxx_compile_options}>"
  )
  if(_cuda_export_dynamic_option)
    target_link_options(
      aoti_cuda_sampler_shims PUBLIC ${_cuda_export_dynamic_option}
    )
  endif()

  # Preprocessor defines mirror the ones used for aoti_cuda_shims so symbols
  # are exported the same way.
  target_compile_definitions(aoti_cuda_sampler_shims PRIVATE CUDA_AVAILABLE=1)
  if(WIN32)
    target_compile_definitions(
      aoti_cuda_sampler_shims PRIVATE EXPORT_AOTI_FUNCTIONS
    )
  endif()

  # NOTE(review): _cuda_is_windows_msvc is tested only here; the link logic
  # below keys off _cuda_is_msvc_toolchain. Confirm this variable is defined
  # earlier in the file: an undefined variable evaluates to false, which
  # would make this branch a no-op.
  if(WIN32 AND _cuda_is_windows_msvc)
    set_target_properties(
      aoti_cuda_sampler_shims PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS OFF
    )
  endif()

  # rand.cu calls into the slim helpers (empty_strided, getCurrentCUDAStream,
  # SlimTensor) that are linked into aoti_cuda_shims. Linking against that
  # target resolves those symbols from the shims library instead of pulling a
  # second copy of slim's static archive into this one.
  if(_cuda_is_msvc_toolchain)
    # MSVC toolchain: every dependency is PRIVATE.
    target_link_libraries(
      aoti_cuda_sampler_shims
      PRIVATE cuda_platform CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS}
              aoti_cuda_shims
    )
  else()
    # Other toolchains: only cuda_platform stays PRIVATE; the CUDA runtime,
    # curand, the dl libraries and aoti_cuda_shims are PUBLIC.
    target_link_libraries(
      aoti_cuda_sampler_shims
      PRIVATE cuda_platform
      PUBLIC CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS} aoti_cuda_shims
    )
    # The shared-lib link-option helper is applied only on non-MSVC
    # toolchains.
    executorch_target_link_options_shared_lib(aoti_cuda_sampler_shims)
  endif()

  install(
    TARGETS aoti_cuda_sampler_shims
    EXPORT ExecuTorchTargets
    DESTINATION lib
  )
endif()
247+
180248
# CUDA backend implementation
181249
set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp)
182250

backends/cuda/runtime/shims/tests/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,9 @@ foreach(test_name ${CUDA_SHIM_TESTS})
6767

6868
add_test(NAME ${test_name} COMMAND ${test_name})
6969
endforeach()
70+
71+
# The rand symbols are built into the separate aoti_cuda_sampler_shims shared
# library, which keeps the curand-induced binary-size cost out of
# aoti_cuda_shims. Link that library into the rand test here.
target_link_libraries(test_aoti_torch_cuda_rand PRIVATE aoti_cuda_sampler_shims)

examples/models/qwen3_5_moe/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ if(EXECUTORCH_BUILD_METAL)
4848
executorch_target_link_options_shared_lib(metal_backend)
4949
elseif(EXECUTORCH_BUILD_CUDA)
5050
find_package(CUDAToolkit REQUIRED)
51-
list(APPEND link_libraries aoti_cuda_backend)
51+
list(APPEND link_libraries aoti_cuda_backend aoti_cuda_sampler_shims)
5252
executorch_target_link_options_shared_lib(aoti_cuda_backend)
53+
executorch_target_link_options_shared_lib(aoti_cuda_sampler_shims)
5354
add_compile_definitions(EXECUTORCH_BUILD_CUDA)
5455
else()
5556
message(

0 commit comments

Comments
 (0)