Skip to content

Commit b1d011c

Browse files
authored
Handle both Windows conda static-lib layouts. (#2015)
Look for cudadevrt under both Library/lib/x64 and Library/lib so CUDA 12 conda environments resolve the real static library instead of falling through to a misleading CUDA_PATH error. Made-with: Cursor
1 parent 98df790 commit b1d011c

2 files changed

Lines changed: 34 additions & 7 deletions

File tree

cuda_pathfinder/cuda/pathfinder/_static_libs/find_static_lib.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@ class LocatedStaticLib:
2828
class _StaticLibInfo(TypedDict):
2929
filename: str
3030
ctk_rel_paths: tuple[str, ...]
31-
conda_rel_path: str
31+
conda_rel_paths: tuple[str, ...]
3232
site_packages_dirs: tuple[str, ...]
3333

3434

3535
_SUPPORTED_STATIC_LIBS_INFO: dict[str, _StaticLibInfo] = {
3636
"cudadevrt": {
3737
"filename": "cudadevrt.lib" if IS_WINDOWS else "libcudadevrt.a",
3838
"ctk_rel_paths": (os.path.join("lib", "x64"),) if IS_WINDOWS else ("lib64", "lib"),
39-
"conda_rel_path": os.path.join("lib", "x64") if IS_WINDOWS else "lib",
39+
"conda_rel_paths": ((os.path.join("lib", "x64"), "lib") if IS_WINDOWS else ("lib",)),
4040
"site_packages_dirs": (
4141
("nvidia/cu13/lib/x64", "nvidia/cuda_runtime/lib/x64")
4242
if IS_WINDOWS
@@ -66,7 +66,7 @@ def __init__(self, name: str) -> None:
6666
self.config: _StaticLibInfo = _SUPPORTED_STATIC_LIBS_INFO[name]
6767
self.filename: str = self.config["filename"]
6868
self.ctk_rel_paths: tuple[str, ...] = self.config["ctk_rel_paths"]
69-
self.conda_rel_path: str = self.config["conda_rel_path"]
69+
self.conda_rel_paths: tuple[str, ...] = self.config["conda_rel_paths"]
7070
self.site_packages_dirs: tuple[str, ...] = self.config["site_packages_dirs"]
7171
self.error_messages: list[str] = []
7272
self.attachments: list[str] = []
@@ -86,9 +86,10 @@ def try_with_conda_prefix(self) -> str | None:
8686
return None
8787

8888
anchor = os.path.join(conda_prefix, "Library") if IS_WINDOWS else conda_prefix
89-
file_path = os.path.join(anchor, self.conda_rel_path, self.filename)
90-
if os.path.isfile(file_path):
91-
return file_path
89+
for rel_path in self.conda_rel_paths:
90+
file_path = os.path.join(anchor, rel_path, self.filename)
91+
if os.path.isfile(file_path):
92+
return file_path
9293
return None
9394

9495
def try_with_cuda_home(self) -> str | None:

cuda_pathfinder/tests/test_find_static_lib.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_locate_static_lib(info_summary_append, libname):
7878
@pytest.mark.usefixtures("clear_find_static_lib_cache")
7979
def test_locate_static_lib_search_order(monkeypatch, tmp_path):
8080
filename = CUDADEVRT_INFO["filename"]
81-
conda_rel_path = CUDADEVRT_INFO["conda_rel_path"]
81+
conda_rel_path = CUDADEVRT_INFO["conda_rel_paths"][0]
8282

8383
site_pkg_rel = CUDADEVRT_INFO["site_packages_dirs"][0]
8484
site_packages_lib_dir = tmp_path / "site-packages" / Path(site_pkg_rel.replace("/", os.sep))
@@ -117,6 +117,32 @@ def test_locate_static_lib_search_order(monkeypatch, tmp_path):
117117
assert located_lib.found_via == "CUDA_PATH"
118118

119119

120+
@pytest.mark.usefixtures("clear_find_static_lib_cache")
121+
def test_locate_static_lib_conda_rel_path_fallback(monkeypatch, tmp_path):
122+
filename = CUDADEVRT_INFO["filename"]
123+
conda_rel_paths = CUDADEVRT_INFO["conda_rel_paths"]
124+
if len(conda_rel_paths) == 1:
125+
monkeypatch.setitem(CUDADEVRT_INFO, "conda_rel_paths", ("missing-first", conda_rel_paths[0]))
126+
conda_rel_paths = CUDADEVRT_INFO["conda_rel_paths"]
127+
128+
conda_prefix = tmp_path / "conda-prefix"
129+
conda_lib_dir = _conda_anchor(conda_prefix) / Path(conda_rel_paths[1])
130+
conda_path = _make_static_lib_file(conda_lib_dir, filename)
131+
132+
monkeypatch.setattr(
133+
find_static_lib_module,
134+
"find_sub_dirs_all_sitepackages",
135+
lambda _sub_dir: [],
136+
)
137+
monkeypatch.setenv("CONDA_PREFIX", str(conda_prefix))
138+
monkeypatch.delenv("CUDA_HOME", raising=False)
139+
monkeypatch.delenv("CUDA_PATH", raising=False)
140+
141+
located_lib = locate_static_lib("cudadevrt")
142+
assert located_lib.abs_path == conda_path
143+
assert located_lib.found_via == "conda"
144+
145+
120146
@pytest.mark.usefixtures("clear_find_static_lib_cache")
121147
def test_find_static_lib_not_found_error_includes_cuda_home_directory_listing(monkeypatch, tmp_path):
122148
filename = CUDADEVRT_INFO["filename"]

0 commit comments

Comments
 (0)