Merged
3 changes: 3 additions & 0 deletions cuda_core/build_hooks.py
@@ -186,6 +186,9 @@ def get_sources(mod_name):
 # On Windows, _tensor_bridge.pyx needs a stub import library so the MSVC
 # linker can resolve the AOTI symbols (they live in torch_cpu.dll at
 # runtime). We generate the .lib from a .def file at build time.
+# Note: aoti_torch_get_current_cuda_stream lives in torch_cuda.dll and
+# is resolved lazily at runtime (not via the stub lib) — see
+# _tensor_bridge.pyx.
 _aoti_extra_link_args = []
 if sys.platform == "win32":
     _def_file = os.path.join("cuda", "core", "_include", "aoti_shim.def")
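As background on the stub trick described in the comment above: on Windows, MSVC's lib.exe can produce an import library straight from an export list, with no real DLL needed at build time. A minimal sketch of that step (illustrative only, not the actual build_hooks.py code; it assumes lib.exe is on PATH and that the .def names the runtime DLL via a LIBRARY statement):

# Illustrative sketch, not the actual build_hooks.py implementation.
import subprocess

def make_stub_import_lib(def_file: str, out_lib: str, machine: str = "x64") -> None:
    # lib.exe emits a stub import library from the export list alone;
    # the loader binds the symbols to the real DLL (torch_cpu.dll) at runtime.
    subprocess.check_call(
        ["lib", f"/def:{def_file}", f"/out:{out_lib}", f"/machine:{machine}"]
    )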
1 change: 0 additions & 1 deletion cuda_core/cuda/core/_include/aoti_shim.def
@@ -34,4 +34,3 @@ EXPORTS
 aoti_torch_get_device_index
 aoti_torch_device_type_cpu
 aoti_torch_device_type_cuda
-aoti_torch_get_current_cuda_stream
18 changes: 11 additions & 7 deletions cuda_core/cuda/core/_include/aoti_shim.h
@@ -52,10 +52,13 @@ typedef struct AtenTensorOpaque* AtenTensorHandle;

 /*
  * IMPORTANT: Keep the AOTI_SHIM_API declaration list below in sync with
- * aoti_shim.def. On Windows, build_hooks.py turns that .def file into the
+ * aoti_shim.def. On Windows, build_hooks.py turns that .def file into the
  * stub import library that MSVC needs to link _tensor_bridge without making
- * PyTorch a build-time dependency. If you add, remove, or rename an imported
- * AOTI symbol here, update aoti_shim.def in the same change.
+ * PyTorch a build-time dependency. If you add, remove, or rename an
+ * imported AOTI symbol here, update aoti_shim.def in the same change.
+ *
+ * Exception: aoti_torch_get_current_cuda_stream lives in torch_cuda (not
+ * torch_cpu) and is resolved lazily at runtime — see _tensor_bridge.pyx.
  */
 
 /* ---- tensor metadata --------------------------------------------------- */
@@ -105,10 +108,11 @@ AOTI_SHIM_API AOTITorchError aoti_torch_get_device_index(
 AOTI_SHIM_API int32_t aoti_torch_device_type_cpu(void);
 AOTI_SHIM_API int32_t aoti_torch_device_type_cuda(void);
 
-/* ---- stream -------------------------------------------------------------- */
-
-AOTI_SHIM_API AOTITorchError aoti_torch_get_current_cuda_stream(
-    int32_t device_index, void** ret_stream);
+/* ---- stream --------------------------------------------------------------
+ * aoti_torch_get_current_cuda_stream is NOT declared here — it lives in
+ * torch_cuda (not torch_cpu) and is resolved at runtime. See the inline
+ * C helper _resolve_cuda_stream_fn() in _tensor_bridge.pyx.
+ * ---------------------------------------------------------------------- */
 
 #ifdef __cplusplus
 } /* extern "C" */
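The comment above turns the header/.def pairing into a contract with one documented exception. A hypothetical drift check (not in this PR) could enforce it in CI; this sketch assumes every imported declaration carries the AOTI_SHIM_API marker and that the .def is a plain EXPORTS list:

# Hypothetical drift check, not part of this PR.
import re

def check_shim_in_sync(header_path: str, def_path: str) -> None:
    with open(header_path) as f:
        header = f.read()
    # Each imported declaration looks like "AOTI_SHIM_API <ret> <name>(".
    declared = set(re.findall(r"AOTI_SHIM_API\s+\w+\s+(\w+)\s*\(", header))
    with open(def_path) as f:
        exported = {
            ln.strip() for ln in f
            if ln.strip() and not ln.strip().startswith(("EXPORTS", "LIBRARY", ";"))
        }
    # aoti_torch_get_current_cuda_stream is absent from both files, so the
    # documented exception needs no special-casing here.
    missing, extra = declared - exported, exported - declared
    if missing or extra:
        raise SystemExit(f"aoti_shim drift: missing={missing}, extra={extra}")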
43 changes: 40 additions & 3 deletions cuda_core/cuda/core/_tensor_bridge.pyx
@@ -103,8 +103,38 @@ cdef extern from "_include/aoti_shim.h":
     int32_t aoti_torch_device_type_cpu()
     int32_t aoti_torch_device_type_cuda()
 
-    # stream
-    AOTITorchError aoti_torch_get_current_cuda_stream(int32_t, void**)
+    # Note: aoti_torch_get_current_cuda_stream is NOT declared here because
+    # it lives in torch_cuda.dll (not torch_cpu.dll). It is resolved lazily
+    # at runtime via dlsym / GetProcAddress — see _resolve_cuda_stream_fn().
+
+# Runtime resolution for aoti_torch_get_current_cuda_stream.
+# This symbol lives in torch_cuda.dll (Windows) / libtorch_cuda.so (Linux),
+# NOT in torch_cpu. We resolve it lazily on first use so that the module
+# can be imported even with CPU-only PyTorch.
+ctypedef AOTITorchError (*_get_cuda_stream_fn_t)(int32_t, void**) nogil
+
+cdef extern from *:
+    """
+    #ifdef _WIN32
+    #include <windows.h>
+    static void* _resolve_cuda_stream_fn(void) {
+        HMODULE h = LoadLibraryA("torch_cuda.dll");
+        if (!h) return NULL;
+        return (void*)GetProcAddress(h, "aoti_torch_get_current_cuda_stream");
+    }
+    #else
+    #include <dlfcn.h>
+    #ifndef RTLD_DEFAULT
+    #define RTLD_DEFAULT ((void*)0)
+    #endif
+    static void* _resolve_cuda_stream_fn(void) {
+        return dlsym(RTLD_DEFAULT, "aoti_torch_get_current_cuda_stream");
+    }
+    #endif
+    """
+    void* _resolve_cuda_stream_fn() nogil
+
+cdef _get_cuda_stream_fn_t _cached_get_cuda_stream = NULL
 
 import numpy
 import sys
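As a side note, the resolver above can be sanity-checked from plain Python. A hedged probe (not in the PR) that mirrors _resolve_cuda_stream_fn(); on Linux it assumes a CUDA-enabled PyTorch has already been imported so libtorch_cuda.so is loaded into the process:

# Hypothetical probe, not part of the PR: can this process resolve the
# symbol the same way _resolve_cuda_stream_fn() does?
import ctypes
import sys

def can_resolve_cuda_stream() -> bool:
    try:
        if sys.platform == "win32":
            lib = ctypes.WinDLL("torch_cuda.dll")  # like LoadLibraryA + GetProcAddress
        else:
            lib = ctypes.CDLL(None)                # like dlsym(RTLD_DEFAULT, ...)
        lib.aoti_torch_get_current_cuda_stream     # raises AttributeError if absent
        return True
    except (OSError, AttributeError):
        return False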
@@ -274,10 +274,17 @@ cpdef int sync_torch_stream(int32_t device_index,
     the consumer stream wait on it. This is a no-op if both streams are
     the same.
     """
+    global _cached_get_cuda_stream
     cdef void* producer_s
     cdef EventHandle h_event
 
-    check_aoti(aoti_torch_get_current_cuda_stream(device_index, &producer_s),
+    if _cached_get_cuda_stream == NULL:
+        _cached_get_cuda_stream = <_get_cuda_stream_fn_t>_resolve_cuda_stream_fn()
Contributor:

We have "freethreading_compatible": True in cuda_core/build_hooks.py, therefore this can be a C-level data race.

Could we leave it as-is?

This is what Cursor GPT-5.4 Extra High Fast said when I asked about "worst-case consequences":


If both threads ultimately resolve the same correct function address, there are two answers.

The formal C answer is still: undefined behavior.

C does not say "same value is okay." Unsynchronized concurrent read/write or write/write to the same object is a data race regardless of whether the bits happen to match. Once you have that, the compiler is allowed to assume the race does not exist.

For this exact snippet in cuda_core/cuda/core/_tensor_bridge.pyx, the practical risk is probably much smaller than the formal one:

  • most likely, both threads just call _resolve_cuda_stream_fn(), both get the same pointer, and both store it;
  • on Windows that may mean an extra LoadLibraryA("torch_cuda.dll") refcount bump;
  • after that, things keep working.

The worst credible consequences are:

  • a thread observes a stale or torn value and calls through a bad pointer or NULL, leading to a crash;
  • compiler optimization based on the "no races exist" assumption produces behavior that is not obviously tied to the source;
  • a thread sees the uninitialized state longer than expected and does duplicate initialization or raises spuriously.

That said, in this particular pattern, on mainstream 64-bit Linux/Windows builds, I would rate the real-world risk as "probably benign most of the time," not "likely to explode." It is still a real bug for a module that claims free-threading compatibility, but it is more of a correctness-contract violation / sanitizer-finding than an obvious imminent crash bug.


I also asked Cursor for a fix; this is what it generated (untested):

diff --git a/cuda_core/cuda/core/_tensor_bridge.pyx b/cuda_core/cuda/core/_tensor_bridge.pyx
index 07eec56537b..THREADSAFE000 100644
--- a/cuda_core/cuda/core/_tensor_bridge.pyx
+++ b/cuda_core/cuda/core/_tensor_bridge.pyx
@@
 ctypedef AOTITorchError (*_get_cuda_stream_fn_t)(int32_t, void**) nogil
 
 cdef extern from *:
     """
     #ifdef _WIN32
     #include <windows.h>
-    static void* _resolve_cuda_stream_fn(void) {
-        HMODULE h = LoadLibraryA("torch_cuda.dll");
-        if (!h) return NULL;
-        return (void*)GetProcAddress(h, "aoti_torch_get_current_cuda_stream");
-    }
+    static INIT_ONCE _cuda_stream_init_once = INIT_ONCE_STATIC_INIT;
+    static void* _cached_cuda_stream_fn = NULL;
+
+    static BOOL CALLBACK _init_cuda_stream_fn(
+            PINIT_ONCE init_once, PVOID param, PVOID* context) {
+        HMODULE h = LoadLibraryA("torch_cuda.dll");
+        if (h) {
+            _cached_cuda_stream_fn =
+                (void*)GetProcAddress(h, "aoti_torch_get_current_cuda_stream");
+        }
+        return TRUE;
+    }
+
+    static void* _resolve_cuda_stream_fn(void) {
+        InitOnceExecuteOnce(&_cuda_stream_init_once, _init_cuda_stream_fn, NULL, NULL);
+        return _cached_cuda_stream_fn;
+    }
     #else
     #include <dlfcn.h>
+    #include <pthread.h>
     #ifndef RTLD_DEFAULT
     #define RTLD_DEFAULT ((void*)0)
     #endif
+    static pthread_once_t _cuda_stream_once = PTHREAD_ONCE_INIT;
+    static void* _cached_cuda_stream_fn = NULL;
+
+    static void _init_cuda_stream_fn(void) {
+        _cached_cuda_stream_fn =
+            dlsym(RTLD_DEFAULT, "aoti_torch_get_current_cuda_stream");
+    }
+
     static void* _resolve_cuda_stream_fn(void) {
-        return dlsym(RTLD_DEFAULT, "aoti_torch_get_current_cuda_stream");
+        pthread_once(&_cuda_stream_once, _init_cuda_stream_fn);
+        return _cached_cuda_stream_fn;
     }
     #endif
     """
     void* _resolve_cuda_stream_fn() nogil
-
-cdef _get_cuda_stream_fn_t _cached_get_cuda_stream = NULL
 
 @@
 cpdef int sync_torch_stream(int32_t device_index,
                             intptr_t consumer_s) except? -1:
 @@
-    global _cached_get_cuda_stream
     cdef void* producer_s
     cdef EventHandle h_event
+    cdef _get_cuda_stream_fn_t get_cuda_stream
 
-    if _cached_get_cuda_stream == NULL:
-        _cached_get_cuda_stream = <_get_cuda_stream_fn_t>_resolve_cuda_stream_fn()
-        if _cached_get_cuda_stream == NULL:
-            raise RuntimeError(
-                "Cannot resolve aoti_torch_get_current_cuda_stream from "
-                "torch_cuda — is CUDA-enabled PyTorch installed?")
-    check_aoti(_cached_get_cuda_stream(device_index, &producer_s),
+    get_cuda_stream = <_get_cuda_stream_fn_t>_resolve_cuda_stream_fn()
+    if get_cuda_stream == NULL:
+        raise RuntimeError(
+            "Cannot resolve aoti_torch_get_current_cuda_stream from "
+            "torch_cuda — is CUDA-enabled PyTorch installed?")
+    check_aoti(get_cuda_stream(device_index, &producer_s),
                b"aoti_torch_get_current_cuda_stream")

+        if _cached_get_cuda_stream == NULL:
+            raise RuntimeError(
+                "Cannot resolve aoti_torch_get_current_cuda_stream from "
+                "torch_cuda — is CUDA-enabled PyTorch installed?")
+    check_aoti(_cached_get_cuda_stream(device_index, &producer_s),
                b"aoti_torch_get_current_cuda_stream")
     if <intptr_t>producer_s != consumer_s:
         with nogil:
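For readers following the free-threading thread above: the once-only semantics the suggested fix gets from pthread_once / InitOnceExecuteOnce can also be expressed at the Python layer with an ordinary lock. A hedged sketch of that pattern (names illustrative; the C-level fix in the quoted diff is the natural fit here, since the resolved pointer is called from Cython code):

# Illustrative lock-guarded lazy resolution; not the PR's code.
import ctypes
import sys
import threading

_resolve_lock = threading.Lock()
_cached_fn = None

def _get_cuda_stream_fn():
    global _cached_fn
    if _cached_fn is None:             # fast path: plain reference read
        with _resolve_lock:            # one thread performs the resolution
            if _cached_fn is None:     # re-check after acquiring the lock
                lib = (ctypes.WinDLL("torch_cuda.dll")
                       if sys.platform == "win32" else ctypes.CDLL(None))
                _cached_fn = lib.aoti_torch_get_current_cuda_stream
    return _cached_fn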