Skip to content

Commit 22e72df

Browse files
leofangclaude
andcommitted
Fix tensor bridge DLL import failure on Windows
aoti_torch_get_current_cuda_stream lives in torch_cuda.dll, not torch_cpu.dll. The stub import library pointed at the wrong DLL, causing "The specified procedure could not be found" on Windows. - Move aoti_torch_get_current_cuda_stream from aoti_shim.def (torch_cpu.dll) to new aoti_shim_cuda.def (torch_cuda.dll) - Update build_hooks.py to generate stub libs for both DLLs via a loop - Add torch_cuda.dll to delvewheel exclude list Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 6d4e82d commit 22e72df

4 files changed

Lines changed: 19 additions & 12 deletions

File tree

cuda_core/build_hooks.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -183,19 +183,22 @@ def get_sources(mod_name):
183183
# related to free-threading builds.
184184
extra_compile_args += ["-DCYTHON_TRACE_NOGIL=1", "-DCYTHON_USE_SYS_MONITORING=0"]
185185

186-
# On Windows, _tensor_bridge.pyx needs a stub import library so the MSVC
187-
# linker can resolve the AOTI symbols (they live in torch_cpu.dll at
188-
# runtime). We generate the .lib from a .def file at build time.
186+
# On Windows, _tensor_bridge.pyx needs stub import libraries so the MSVC
187+
# linker can resolve the AOTI symbols at link time. At runtime the symbols
188+
# resolve from the actual DLLs loaded by 'import torch'.
189+
# - aoti_shim.def -> torch_cpu.dll (dtype, device, tensor metadata)
190+
# - aoti_shim_cuda.def -> torch_cuda.dll (CUDA stream access)
189191
_aoti_extra_link_args = []
190192
if sys.platform == "win32":
191-
_def_file = os.path.join("cuda", "core", "_include", "aoti_shim.def")
192-
_lib_file = os.path.join("build", "aoti_shim.lib")
193193
os.makedirs("build", exist_ok=True)
194-
subprocess.check_call( # noqa: S603
195-
["lib", f"/DEF:{_def_file}", f"/OUT:{_lib_file}", "/MACHINE:X64"], # noqa: S607
196-
stdout=subprocess.DEVNULL,
197-
)
198-
_aoti_extra_link_args = [_lib_file]
194+
for def_name in ("aoti_shim", "aoti_shim_cuda"):
195+
def_file = os.path.join("cuda", "core", "_include", f"{def_name}.def")
196+
lib_file = os.path.join("build", f"{def_name}.lib")
197+
subprocess.check_call( # noqa: S603
198+
["lib", f"/DEF:{def_file}", f"/OUT:{lib_file}", "/MACHINE:X64"], # noqa: S607
199+
stdout=subprocess.DEVNULL,
200+
)
201+
_aoti_extra_link_args.append(lib_file)
199202

200203
def get_extra_link_args(mod_name):
201204
if mod_name == "_tensor_bridge" and _aoti_extra_link_args:

cuda_core/cuda/core/_include/aoti_shim.def

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,3 @@ EXPORTS
3434
aoti_torch_get_device_index
3535
aoti_torch_device_type_cpu
3636
aoti_torch_device_type_cuda
37-
aoti_torch_get_current_cuda_stream
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
; Stub import library for CUDA-specific AOTI symbols (torch_cuda.dll).
2+
; See aoti_shim.def for the torch_cpu.dll counterpart.
3+
LIBRARY torch_cuda.dll
4+
EXPORTS
5+
aoti_torch_get_current_cuda_stream

cuda_core/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,4 +117,4 @@ archs = "native"
117117
[tool.cibuildwheel.windows]
118118
archs = "AMD64"
119119
before-build = "pip install delvewheel"
120-
repair-wheel-command = "delvewheel repair --namespace-pkg cuda --exclude \"torch_cpu.dll;torch_python.dll\" -w {dest_dir} {wheel}"
120+
repair-wheel-command = "delvewheel repair --namespace-pkg cuda --exclude \"torch_cpu.dll;torch_python.dll;torch_cuda.dll\" -w {dest_dir} {wheel}"

0 commit comments

Comments
 (0)