Skip to content

Commit 0ea73b1

Browse files
committed
Rewrite nvjitlink_linux.pyx load_library() to produce detailed error messages.
1 parent f8b3dd5 commit 0ea73b1

1 file changed

Lines changed: 24 additions & 13 deletions

File tree

cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,20 +65,31 @@ cdef void* load_library(const int driver_ver) except* with gil:
6565
cdef void* handle = NULL;
6666
paths = path_finder.get_cuda_paths()
6767
paths_cudalib_dir = paths["cudalib_dir"]
68-
if (paths_cudalib_dir and
69-
paths_cudalib_dir.info and
70-
os.path.isdir(paths_cudalib_dir.info)):
71-
# TODO(rwgk): Produce the correct so_name in path_finder.py
72-
so_name = os.path.join(paths_cudalib_dir.info, so_basename)
73-
if not os.path.exists(so_name) and so_name.count("/lib64/") == 1:
74-
so_name = so_name.replace("/lib64/", "/lib/")
75-
if os.path.exists(so_name):
68+
if not paths_cudalib_dir:
69+
raise RuntimeError("Failure obtaining paths_cudalib_dir")
70+
if not paths_cudalib_dir.info:
71+
raise RuntimeError("Failure obtaining paths_cudalib_dir.info")
72+
candidate_so_dirs = [paths_cudalib_dir.info]
73+
libs = ["/lib/", "/lib64/"]
74+
for _ in range(2):
75+
alt_dir = libs[0].join(paths_cudalib_dir.info.rsplit(libs[1], 1))
76+
if alt_dir not in candidate_so_dirs:
77+
candidate_so_dirs.append(alt_dir)
78+
libs.reverse()
79+
candidate_so_names = [
80+
os.path.join(so_dirname, so_basename)
81+
for so_dirname in candidate_so_dirs]
82+
error_messages = []
83+
for so_name in candidate_so_names:
84+
if not os.path.exists(so_name):
85+
error_messages.append(f"No such file: {so_name}")
86+
else:
7687
handle = dlopen(so_name.encode(), RTLD_NOW | RTLD_GLOBAL)
77-
if handle == NULL:
78-
err_msg = dlerror()
79-
raise RuntimeError(f'Failed to dlopen {so_name} ({err_msg.decode()})')
80-
return handle
81-
raise RuntimeError(f'Unable to locate {so_basename}')
88+
if handle != NULL:
89+
return handle
90+
err_msg = dlerror().decode(errors="backslashreplace")
91+
error_messages.append(f"Failed to dlopen {so_name}: {err_msg}")
92+
raise RuntimeError(f"Unable to load {so_basename}: {', '.join(error_messages)}")
8293

8394

8495
cdef int _check_or_init_nvjitlink() except -1 nogil:

0 commit comments

Comments
 (0)