Skip to content

Commit 02694c7

Browse files
committed
nvjitlink_linux.pyx load_library() enhancements, mainly to avoid os.path.join(None, "libnvJitLink.so")
1 parent 08c3041 commit 02694c7

1 file changed

Lines changed: 12 additions & 8 deletions

File tree

cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,20 +55,24 @@ cdef void* __nvJitLinkVersion = NULL
5555

5656

5757
cdef void* load_library(const int driver_ver) except* with gil:
58+
so_basename = "libnvJitLink.so"
5859
cdef void* handle = NULL;
5960
paths = path_finder.get_cuda_paths()
6061
paths_cudalib_dir = paths["cudalib_dir"]
61-
if paths_cudalib_dir:
62+
if (paths_cudalib_dir and
63+
paths_cudalib_dir.info and
64+
os.path.isdir(paths_cudalib_dir.info)):
6265
# TODO(rwgk): Produce the correct so_name in path_finder.py
63-
so_name = os.path.join(paths_cudalib_dir.info, "libnvJitLink.so")
66+
so_name = os.path.join(paths_cudalib_dir.info, so_basename)
6467
if not os.path.exists(so_name) and so_name.count("/lib64/") == 1:
6568
so_name = so_name.replace("/lib64/", "/lib/")
66-
handle = dlopen(so_name.encode(), RTLD_NOW | RTLD_GLOBAL)
67-
if handle == NULL:
68-
err_msg = dlerror()
69-
raise RuntimeError(f'Failed to dlopen {so_name} ({err_msg.decode()})')
70-
return handle
71-
raise RuntimeError('Unable to locate libnvJitLink.so')
69+
if os.path.exists(so_name):
70+
handle = dlopen(so_name.encode(), RTLD_NOW | RTLD_GLOBAL)
71+
if handle == NULL:
72+
err_msg = dlerror()
73+
raise RuntimeError(f'Failed to dlopen {so_name} ({err_msg.decode()})')
74+
return handle
75+
raise RuntimeError('Unable to locate {so_basename}')
7276

7377

7478
cdef int _check_or_init_nvjitlink() except -1 nogil:

0 commit comments

Comments
 (0)