Skip to content

Commit 7968023

Browse files
leofang and claude
authored and committed
Fix tensor bridge DLL import failure on Windows (NVIDIA#1988)
* Fix tensor bridge DLL import failure on Windows aoti_torch_get_current_cuda_stream lives in torch_cuda.dll, not torch_cpu.dll. The stub import library pointed at the wrong DLL, causing "The specified procedure could not be found" on Windows. - Move aoti_torch_get_current_cuda_stream from aoti_shim.def (torch_cpu.dll) to new aoti_shim_cuda.def (torch_cuda.dll) - Update build_hooks.py to generate stub libs for both DLLs via a loop - Add torch_cuda.dll to delvewheel exclude list Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Add SPDX headers to aoti_shim_cuda.def Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> * Resolve aoti_torch_get_current_cuda_stream lazily at runtime The symbol lives in torch_cuda (not torch_cpu), so linking against it at build time breaks CPU-only PyTorch installs and requires a second stub import library on Windows. Instead, resolve it lazily on first use via dlsym (Linux) / LoadLibrary+GetProcAddress (Windows). The cached function pointer keeps subsequent calls fully in C with zero Python overhead. This reverts the two-def-file approach from the previous commit and replaces it with a self-contained inline C helper that handles both platforms. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 66eb984 commit 7968023

4 files changed

Lines changed: 54 additions & 11 deletions

File tree

cuda_core/build_hooks.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,9 @@ def get_sources(mod_name):
186186
# On Windows, _tensor_bridge.pyx needs a stub import library so the MSVC
187187
# linker can resolve the AOTI symbols (they live in torch_cpu.dll at
188188
# runtime). We generate the .lib from a .def file at build time.
189+
# Note: aoti_torch_get_current_cuda_stream lives in torch_cuda.dll and
190+
# is resolved lazily at runtime (not via the stub lib) — see
191+
# _tensor_bridge.pyx.
189192
_aoti_extra_link_args = []
190193
if sys.platform == "win32":
191194
_def_file = os.path.join("cuda", "core", "_include", "aoti_shim.def")

cuda_core/cuda/core/_include/aoti_shim.def

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,3 @@ EXPORTS
3434
aoti_torch_get_device_index
3535
aoti_torch_device_type_cpu
3636
aoti_torch_device_type_cuda
37-
aoti_torch_get_current_cuda_stream

cuda_core/cuda/core/_include/aoti_shim.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,13 @@ typedef struct AtenTensorOpaque* AtenTensorHandle;
5252

5353
/*
5454
* IMPORTANT: Keep the AOTI_SHIM_API declaration list below in sync with
55-
* aoti_shim.def. On Windows, build_hooks.py turns that .def file into the
55+
* aoti_shim.def. On Windows, build_hooks.py turns that .def file into the
5656
* stub import library that MSVC needs to link _tensor_bridge without making
57-
* PyTorch a build-time dependency. If you add, remove, or rename an imported
58-
* AOTI symbol here, update aoti_shim.def in the same change.
57+
* PyTorch a build-time dependency. If you add, remove, or rename an
58+
* imported AOTI symbol here, update aoti_shim.def in the same change.
59+
*
60+
* Exception: aoti_torch_get_current_cuda_stream lives in torch_cuda (not
61+
* torch_cpu) and is resolved lazily at runtime — see _tensor_bridge.pyx.
5962
*/
6063

6164
/* ---- tensor metadata --------------------------------------------------- */
@@ -105,10 +108,11 @@ AOTI_SHIM_API AOTITorchError aoti_torch_get_device_index(
105108
AOTI_SHIM_API int32_t aoti_torch_device_type_cpu(void);
106109
AOTI_SHIM_API int32_t aoti_torch_device_type_cuda(void);
107110

108-
/* ---- stream -------------------------------------------------------------- */
109-
110-
AOTI_SHIM_API AOTITorchError aoti_torch_get_current_cuda_stream(
111-
int32_t device_index, void** ret_stream);
111+
/* ---- stream --------------------------------------------------------------
112+
* aoti_torch_get_current_cuda_stream is NOT declared here — it lives in
113+
* torch_cuda (not torch_cpu) and is resolved at runtime. See the inline
114+
* C helper _resolve_cuda_stream_fn() in _tensor_bridge.pyx.
115+
* ---------------------------------------------------------------------- */
112116

113117
#ifdef __cplusplus
114118
} /* extern "C" */

cuda_core/cuda/core/_tensor_bridge.pyx

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,38 @@ cdef extern from "_include/aoti_shim.h":
103103
int32_t aoti_torch_device_type_cpu()
104104
int32_t aoti_torch_device_type_cuda()
105105

106-
# stream
107-
AOTITorchError aoti_torch_get_current_cuda_stream(int32_t, void**)
106+
# Note: aoti_torch_get_current_cuda_stream is NOT declared here because
107+
# it lives in torch_cuda.dll (not torch_cpu.dll). It is resolved lazily
108+
# at runtime via dlsym / GetProcAddress — see _resolve_cuda_stream_fn().
109+
110+
# Runtime resolution for aoti_torch_get_current_cuda_stream.
111+
# This symbol lives in torch_cuda.dll (Windows) / libtorch_cuda.so (Linux),
112+
# NOT in torch_cpu. We resolve it lazily on first use so that the module
113+
# can be imported even with CPU-only PyTorch.
114+
ctypedef AOTITorchError (*_get_cuda_stream_fn_t)(int32_t, void**) nogil
115+
116+
cdef extern from *:
117+
"""
118+
#ifdef _WIN32
119+
#include <windows.h>
120+
static void* _resolve_cuda_stream_fn(void) {
121+
HMODULE h = LoadLibraryA("torch_cuda.dll");
122+
if (!h) return NULL;
123+
return (void*)GetProcAddress(h, "aoti_torch_get_current_cuda_stream");
124+
}
125+
#else
126+
#include <dlfcn.h>
127+
#ifndef RTLD_DEFAULT
128+
#define RTLD_DEFAULT ((void*)0)
129+
#endif
130+
static void* _resolve_cuda_stream_fn(void) {
131+
return dlsym(RTLD_DEFAULT, "aoti_torch_get_current_cuda_stream");
132+
}
133+
#endif
134+
"""
135+
void* _resolve_cuda_stream_fn() nogil
136+
137+
cdef _get_cuda_stream_fn_t _cached_get_cuda_stream = NULL
108138

109139
import numpy
110140
import sys
@@ -274,10 +304,17 @@ cpdef int sync_torch_stream(int32_t device_index,
274304
the consumer stream wait on it. This is a no-op if both streams are
275305
the same.
276306
"""
307+
global _cached_get_cuda_stream
277308
cdef void* producer_s
278309
cdef EventHandle h_event
279310

280-
check_aoti(aoti_torch_get_current_cuda_stream(device_index, &producer_s),
311+
if _cached_get_cuda_stream == NULL:
312+
_cached_get_cuda_stream = <_get_cuda_stream_fn_t>_resolve_cuda_stream_fn()
313+
if _cached_get_cuda_stream == NULL:
314+
raise RuntimeError(
315+
"Cannot resolve aoti_torch_get_current_cuda_stream from "
316+
"torch_cuda — is CUDA-enabled PyTorch installed?")
317+
check_aoti(_cached_get_cuda_stream(device_index, &producer_s),
281318
b"aoti_torch_get_current_cuda_stream")
282319
if <intptr_t>producer_s != consumer_s:
283320
with nogil:

0 commit comments

Comments (0)