Skip to content

Commit e4a4849

Browse files
committed
Factor out load_nvidia_dynamic_library() from _internal/nvjitlink_linux.pyx, nvvm_linux.pyx
1 parent 667d3ed commit e4a4849

4 files changed

Lines changed: 22 additions & 15 deletions

File tree

cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import os
88

9-
from libc.stdint cimport intptr_t
9+
from libc.stdint cimport intptr_t, uintptr_t
1010

1111
from .utils import FunctionNotFoundError, NotSupportedError
1212

@@ -54,13 +54,9 @@ cdef void* __nvJitLinkGetInfoLog = NULL
5454
cdef void* __nvJitLinkVersion = NULL
5555

5656

57-
cdef void* load_library(const int driver_ver) except* with gil:
58-
so_name = path_finder.find_nvidia_dynamic_library("nvJitLink")
59-
cdef void* handle = dlopen(so_name.encode(), RTLD_NOW | RTLD_GLOBAL)
60-
if handle != NULL:
61-
return handle
62-
err_msg = dlerror().decode(errors="backslashreplace")
63-
raise RuntimeError(f"Failed to dlopen {so_name}: {err_msg}")
57+
cdef void* load_library(int driver_ver) except* with gil:
58+
cdef uintptr_t handle = path_finder.load_nvidia_dynamic_library("nvJitLink")
59+
return <void*>handle
6460

6561

6662
cdef int _check_or_init_nvjitlink() except -1 nogil:

cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#
55
# This code was automatically generated across versions from 11.0.3 to 12.8.0. Do not modify it directly.
66

7-
from libc.stdint cimport intptr_t
7+
from libc.stdint cimport intptr_t, uintptr_t
88

99
from .utils import FunctionNotFoundError, NotSupportedError
1010

@@ -51,12 +51,8 @@ cdef void* __nvvmGetProgramLog = NULL
5151

5252

5353
cdef void* load_library(const int driver_ver) except* with gil:
54-
so_name = path_finder.find_nvidia_dynamic_library("nvvm")
55-
cdef void* handle = dlopen(so_name.encode(), RTLD_NOW | RTLD_GLOBAL)
56-
if handle != NULL:
57-
return handle
58-
err_msg = dlerror().decode(errors="backslashreplace")
59-
raise RuntimeError(f"Failed to dlopen {so_name}: {err_msg}")
54+
cdef uintptr_t handle = path_finder.load_nvidia_dynamic_library("nvvm")
55+
return <void*>handle
6056

6157

6258
cdef int _check_or_init_nvvm() except -1 nogil:
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import ctypes
2+
import os
3+
4+
from .find_nvidia_dynamic_library import find_nvidia_dynamic_library
5+
6+
7+
def load_nvidia_dynamic_library(name: str) -> int:
8+
path = find_nvidia_dynamic_library(name)
9+
try:
10+
handle = ctypes.CDLL(path, mode=os.RTLD_NOW | os.RTLD_GLOBAL)
11+
return handle._handle # This is the actual `void*` value as an int
12+
except OSError as e:
13+
raise RuntimeError(f"Failed to dlopen {path}: {e}") from e

cuda_bindings/cuda/bindings/path_finder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
get_system_ctk,
1818
)
1919
from cuda.bindings._path_finder_utils.find_nvidia_dynamic_library import find_nvidia_dynamic_library
20+
from cuda.bindings._path_finder_utils.load_nvidia_dynamic_library import load_nvidia_dynamic_library
2021

2122
__all__ = [
2223
"find_nvidia_dynamic_library",
24+
"load_nvidia_dynamic_library",
2325
"get_conda_ctk",
2426
"get_conda_include_dir",
2527
"get_cuda_home",

0 commit comments

Comments
 (0)