@@ -110,7 +110,7 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
110110# Only build CUDA shims when CUDA language/toolchain is available.
111111if (CMAKE_CUDA_COMPILER)
112112 list (APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu
113- runtime/shims/sort .cu runtime/shims/rand.cu
113+ runtime/shims/sort .cu
114114 )
115115endif ()
116116
@@ -152,7 +152,7 @@ endif()
152152# retention.
153153if (_cuda_is_msvc_toolchain)
154154 target_link_libraries (
155- aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart CUDA::curand
155+ aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart
156156 ${CMAKE_DL_LIBS}
157157 )
158158 # Link object library directly so symbols are pulled exactly once while
@@ -163,7 +163,7 @@ else()
163163 aoti_cuda_shims
164164 PRIVATE cuda_platform
165165 PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive
166- CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS}
166+ CUDA::cudart ${CMAKE_DL_LIBS}
167167 )
168168endif ()
169169
@@ -177,6 +177,74 @@ install(
177177 DESTINATION lib
178178)
179179
180+ # CUDA-specific AOTI sampler shim symbols (rand/randint via curand). Split out
181+ # of aoti_cuda_shims so the curand fatbin (~3.5MB precalc tables + Philox
182+ # kernels per arch) and the CUDA::curand dependency are only paid by the
183+ # small set of consumers that actually use them (e.g. qwen3_5_moe). Other
184+ # CUDA examples (voxtral, parakeet, whisper, dinov2, ...) link only
185+ # aoti_cuda_shims and stay small.
186+ if (CMAKE_CUDA_COMPILER)
187+ add_library (aoti_cuda_sampler_shims SHARED runtime/shims/rand.cu )
188+
189+ # Match aoti_cuda_shims preprocessor defines for symbol export.
190+ target_compile_definitions (aoti_cuda_sampler_shims PRIVATE CUDA_AVAILABLE=1 )
191+ if (WIN32 )
192+ target_compile_definitions (
193+ aoti_cuda_sampler_shims PRIVATE EXPORT_AOTI_FUNCTIONS
194+ )
195+ if (_cuda_is_windows_msvc)
196+ set_target_properties (
197+ aoti_cuda_sampler_shims PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS OFF
198+ )
199+ endif ()
200+ endif ()
201+
202+ target_include_directories (
203+ aoti_cuda_sampler_shims
204+ PUBLIC ${CUDAToolkit_INCLUDE_DIRS} $<BUILD_INTERFACE :${EXECUTORCH_ROOT} >
205+ $<INSTALL_INTERFACE :include >
206+ )
207+
208+ target_compile_options (
209+ aoti_cuda_sampler_shims
210+ PUBLIC "$<$<COMPILE_LANGUAGE :CXX >:${_cuda_cxx_compile_options} >"
211+ )
212+
213+ if (_cuda_export_dynamic_option)
214+ target_link_options (
215+ aoti_cuda_sampler_shims PUBLIC ${_cuda_export_dynamic_option}
216+ )
217+ endif ()
218+
219+ # rand.cu calls into slim helpers (empty_strided, getCurrentCUDAStream,
220+ # SlimTensor) which are linked into aoti_cuda_shims. Depend on that target
221+ # so we resolve those symbols from the already-loaded shims library
222+ # instead of duplicating slim's static archive into both DLLs.
223+ if (_cuda_is_msvc_toolchain)
224+ target_link_libraries (
225+ aoti_cuda_sampler_shims
226+ PRIVATE cuda_platform CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS}
227+ aoti_cuda_shims
228+ )
229+ else ()
230+ target_link_libraries (
231+ aoti_cuda_sampler_shims
232+ PRIVATE cuda_platform
233+ PUBLIC CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS} aoti_cuda_shims
234+ )
235+ endif ()
236+
237+ if (NOT _cuda_is_msvc_toolchain)
238+ executorch_target_link_options_shared_lib (aoti_cuda_sampler_shims )
239+ endif ()
240+
241+ install (
242+ TARGETS aoti_cuda_sampler_shims
243+ EXPORT ExecuTorchTargets
244+ DESTINATION lib
245+ )
246+ endif ()
247+
180248# CUDA backend implementation
181249set (_aoti_cuda_backend_sources runtime/cuda_backend.cpp)
182250
0 commit comments