Skip to content

Commit a140f5b

Browse files
committed
Generalize load_nvidia_dynamic_library.py to also work under Windows.
1 parent e4a4849 commit a140f5b

3 files changed

Lines changed: 24 additions & 48 deletions

File tree

cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#
55
# This code was automatically generated across versions from 12.0.1 to 12.6.2. Do not modify it directly.
66

7-
from libc.stdint cimport intptr_t
7+
from libc.stdint cimport intptr_t, uintptr_t
88

99
from .utils import FunctionNotFoundError, NotSupportedError
1010

@@ -14,7 +14,6 @@ import os
1414
import site
1515

1616
import win32api
17-
import pywintypes
1817

1918

2019
###############################################################################
@@ -44,25 +43,8 @@ cdef void* __nvJitLinkVersion = NULL
4443

4544

4645
cdef load_library(const int driver_ver):
47-
cdef str dll_path = path_finder.find_nvidia_dynamic_library("nvJitLink")
48-
cdef str dll_name = os.path.basename(dll_path)
49-
cdef intptr_t handle = 0
50-
51-
# Check if already loaded
52-
try:
53-
handle = win32api.GetModuleHandle(dll_name)
54-
except pywintypes.error:
55-
pass
56-
else:
57-
return handle
58-
59-
# Not already loaded; load it
60-
try:
61-
handle = win32api.LoadLibrary(dll_path)
62-
except pywintypes.error as e:
63-
raise RuntimeError(f"Failed to load NVVM DLL at {dll_path}: {e}")
64-
65-
return handle
46+
cdef uintptr_t handle = path_finder.load_nvidia_dynamic_library("nvJitLink")
47+
return <void*>handle
6648

6749

6850
cdef int _check_or_init_nvjitlink() except -1 nogil:

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#
55
# This code was automatically generated across versions from 11.0.3 to 12.8.0. Do not modify it directly.
66

7-
from libc.stdint cimport intptr_t
7+
from libc.stdint cimport intptr_t, uintptr_t
88

99
from .utils import FunctionNotFoundError, NotSupportedError
1010

@@ -14,7 +14,6 @@ import os
1414
import site
1515

1616
import win32api
17-
import pywintypes
1817

1918

2019
###############################################################################
@@ -42,25 +41,8 @@ cdef void* __nvvmGetProgramLog = NULL
4241

4342

4443
cdef load_library(const int driver_ver):
45-
cdef str dll_path = path_finder.find_nvidia_dynamic_library("nvvm")
46-
cdef str dll_name = os.path.basename(dll_path)
47-
cdef intptr_t handle = 0
48-
49-
# Check if already loaded
50-
try:
51-
handle = win32api.GetModuleHandle(dll_name)
52-
except pywintypes.error:
53-
pass
54-
else:
55-
return handle
56-
57-
# Not already loaded; load it
58-
try:
59-
handle = win32api.LoadLibrary(dll_path)
60-
except pywintypes.error as e:
61-
raise RuntimeError(f"Failed to load NVVM DLL at {dll_path}: {e}")
62-
63-
return handle
44+
cdef uintptr_t handle = path_finder.load_nvidia_dynamic_library("nvvm")
45+
return <void*>handle
6446

6547

6648
cdef int _check_or_init_nvvm() except -1 nogil:
Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,25 @@
11
import ctypes
2+
import functools
23
import os
4+
import sys
35

46
from .find_nvidia_dynamic_library import find_nvidia_dynamic_library
57

68

9+
@functools.cache
710
def load_nvidia_dynamic_library(name: str) -> int:
8-
path = find_nvidia_dynamic_library(name)
9-
try:
10-
handle = ctypes.CDLL(path, mode=os.RTLD_NOW | os.RTLD_GLOBAL)
11-
return handle._handle # This is the actual `void*` value as an int
12-
except OSError as e:
13-
raise RuntimeError(f"Failed to dlopen {path}: {e}") from e
11+
dl_path = find_nvidia_dynamic_library(name)
12+
if sys.platform == "win32":
13+
try:
14+
handle = ctypes.windll.kernel32.LoadLibraryW(dl_path)
15+
if not handle:
16+
raise ctypes.WinError(ctypes.get_last_error())
17+
except Exception as e:
18+
raise RuntimeError(f"Failed to load DLL at {dl_path}: {e}") from e
19+
return handle
20+
else:
21+
try:
22+
handle = ctypes.CDLL(dl_path, mode=os.RTLD_NOW | os.RTLD_GLOBAL)
23+
return handle._handle # Raw void* as int
24+
except OSError as e:
25+
raise RuntimeError(f"Failed to dlopen {dl_path}: {e}") from e

0 commit comments

Comments
 (0)