-
Notifications
You must be signed in to change notification settings - Fork 283
Expand file tree
/
Copy pathfind_static_lib.py
More file actions
167 lines (134 loc) · 5.7 KB
/
find_static_lib.py
File metadata and controls
167 lines (134 loc) · 5.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import functools
import os
from dataclasses import dataclass
from typing import NoReturn, TypedDict
from cuda.pathfinder._utils.env_vars import get_cuda_path_or_home
from cuda.pathfinder._utils.find_sub_dirs import find_sub_dirs_all_sitepackages
from cuda.pathfinder._utils.platform_aware import IS_WINDOWS
class StaticLibNotFoundError(RuntimeError):
    """Error signalling that no probed location contains the requested static library.

    It derives from ``RuntimeError``, so generic ``except RuntimeError``
    handlers also catch it.
    """
@dataclass(frozen=True)
class LocatedStaticLib:
    """Immutable record describing where a static library was found.

    Field order is part of the interface (positional construction), so keep it stable.
    """

    # Short library name as requested by the caller, e.g. "cudadevrt".
    name: str
    # Full path of the library file that was found.
    abs_path: str
    # Platform-specific file name, e.g. "libcudadevrt.a" or "cudadevrt.lib".
    filename: str
    # Label of the search strategy that succeeded: "site-packages", "conda" or "CUDA_PATH".
    found_via: str
class _StaticLibInfo(TypedDict):
filename: str
ctk_rel_paths: tuple[str, ...]
conda_rel_paths: tuple[str, ...]
site_packages_dirs: tuple[str, ...]
# Search configuration for every supported static library, keyed by the short
# name accepted by locate_static_lib() / find_static_lib(). Platform-specific
# values are selected once, at import time, from IS_WINDOWS.
_SUPPORTED_STATIC_LIBS_INFO: dict[str, _StaticLibInfo] = {
    "cudadevrt": {
        "filename": "libcudadevrt.a" if not IS_WINDOWS else "cudadevrt.lib",
        # Relative to the directory from CUDA_PATH / CUDA_HOME.
        "ctk_rel_paths": ("lib64", "lib") if not IS_WINDOWS else (os.path.join("lib", "x64"),),
        # Relative to $CONDA_PREFIX (or $CONDA_PREFIX/Library on Windows).
        "conda_rel_paths": ("lib",) if not IS_WINDOWS else (os.path.join("lib", "x64"), "lib"),
        # Always "/"-separated; the finder splits these into path components.
        "site_packages_dirs": (
            ("nvidia/cu13/lib", "nvidia/cuda_runtime/lib")
            if not IS_WINDOWS
            else ("nvidia/cu13/lib/x64", "nvidia/cuda_runtime/lib/x64")
        ),
    },
}

# Sorted names of all supported static libraries (used in error messages).
SUPPORTED_STATIC_LIBS: tuple[str, ...] = tuple(sorted(_SUPPORTED_STATIC_LIBS_INFO))
def _no_such_file_in_dir(dir_path: str, filename: str, error_messages: list[str], attachments: list[str]) -> None:
error_messages.append(f"No such file: {os.path.join(dir_path, filename)}")
if os.path.isdir(dir_path):
attachments.append(f' listdir("{dir_path}"):')
for node in sorted(os.listdir(dir_path)):
attachments.append(f" {node}")
else:
attachments.append(f' Directory does not exist: "{dir_path}"')
class _FindStaticLib:
    """Probe the supported locations for one static library.

    Each ``try_*`` method checks one kind of location and returns the path of
    the library file, or ``None``. Every unsuccessful probe records why in
    ``error_messages``, plus directory listings in ``attachments`` where a
    directory was inspected. :meth:`raise_not_found_error` can then report
    everything that was searched, not just the last location.
    """

    def __init__(self, name: str) -> None:
        """Load the search configuration for ``name``.

        Raises:
            ValueError: If ``name`` is not a supported static library.
        """
        if name not in _SUPPORTED_STATIC_LIBS_INFO:
            raise ValueError(f"Unknown static library: '{name}'. Supported: {', '.join(SUPPORTED_STATIC_LIBS)}")
        self.name: str = name
        self.config: _StaticLibInfo = _SUPPORTED_STATIC_LIBS_INFO[name]
        self.filename: str = self.config["filename"]
        self.ctk_rel_paths: tuple[str, ...] = self.config["ctk_rel_paths"]
        self.conda_rel_paths: tuple[str, ...] = self.config["conda_rel_paths"]
        self.site_packages_dirs: tuple[str, ...] = self.config["site_packages_dirs"]
        # Diagnostics collected by the try_* methods, consumed by raise_not_found_error().
        self.error_messages: list[str] = []
        self.attachments: list[str] = []

    def _try_rel_paths(self, anchor: str, rel_paths: tuple[str, ...]) -> str | None:
        """Return the first ``anchor/<rel_path>/<filename>`` that is a file.

        If none exists, record a "No such file" diagnostic for EVERY candidate
        directory (not only the first one) and return ``None``.
        """
        candidate_dirs = [os.path.join(anchor, rel_path) for rel_path in rel_paths]
        for dir_path in candidate_dirs:
            file_path = os.path.join(dir_path, self.filename)
            if os.path.isfile(file_path):
                return file_path
        for dir_path in candidate_dirs:
            _no_such_file_in_dir(dir_path, self.filename, self.error_messages, self.attachments)
        return None

    def try_site_packages(self) -> str | None:
        """Search each configured sub-directory of every site-packages directory."""
        for rel_dir in self.site_packages_dirs:
            # Config entries use "/" on all platforms; split them into path components.
            sub_dir = tuple(rel_dir.split("/"))
            for abs_dir in find_sub_dirs_all_sitepackages(sub_dir):
                file_path = os.path.join(abs_dir, self.filename)
                if os.path.isfile(file_path):
                    return file_path
        if self.site_packages_dirs:
            searched = " or ".join(self.site_packages_dirs)
            self.error_messages.append(f"Not found in site-packages under {searched}")
        return None

    def try_with_conda_prefix(self) -> str | None:
        """Search under ``$CONDA_PREFIX`` (``$CONDA_PREFIX/Library`` on Windows)."""
        conda_prefix = os.environ.get("CONDA_PREFIX")
        if not conda_prefix:
            self.error_messages.append("CONDA_PREFIX not set")
            return None
        anchor = os.path.join(conda_prefix, "Library") if IS_WINDOWS else conda_prefix
        return self._try_rel_paths(anchor, self.conda_rel_paths)

    def try_with_cuda_home(self) -> str | None:
        """Search under the directory returned by ``get_cuda_path_or_home()`` (CUDA_HOME/CUDA_PATH)."""
        cuda_home = get_cuda_path_or_home()
        if cuda_home is None:
            self.error_messages.append("CUDA_HOME/CUDA_PATH not set")
            return None
        return self._try_rel_paths(cuda_home, self.ctk_rel_paths)

    def raise_not_found_error(self) -> NoReturn:
        """Raise StaticLibNotFoundError summarizing every recorded diagnostic.

        Raises:
            StaticLibNotFoundError: Always.
        """
        err = ", ".join(self.error_messages) if self.error_messages else "No search paths available"
        message = f'Failure finding "{self.filename}": {err}'
        if self.attachments:
            # Directory listings go on their own lines after the one-line summary.
            message += "\n" + "\n".join(self.attachments)
        raise StaticLibNotFoundError(message)
def locate_static_lib(name: str) -> LocatedStaticLib:
    """Find a supported static library and describe where it was found.

    Locations are probed in a fixed order, stopping at the first hit:
    site-packages, then the conda prefix, then CUDA_PATH/CUDA_HOME.

    Returns:
        A ``LocatedStaticLib`` whose ``found_via`` names the successful probe.

    Raises:
        ValueError: If ``name`` is not a supported static library.
        StaticLibNotFoundError: If the static library cannot be found.
    """
    finder = _FindStaticLib(name)
    # Bound methods are only invoked inside the loop, so later probes run only if earlier ones miss.
    probes = (
        (finder.try_site_packages, "site-packages"),
        (finder.try_with_conda_prefix, "conda"),
        (finder.try_with_cuda_home, "CUDA_PATH"),
    )
    for probe, label in probes:
        hit = probe()
        if hit is None:
            continue
        return LocatedStaticLib(name=name, abs_path=hit, filename=finder.filename, found_via=label)
    finder.raise_not_found_error()
@functools.cache
def find_static_lib(name: str) -> str:
    """Return the path of the supported static library ``name``.

    This is a thin wrapper over ``locate_static_lib`` that keeps only the path.
    Successful results are memoized per ``name`` for the life of the process
    (``functools.cache``), so later environment changes are not observed.
    Calls that raise are not cached.

    Raises:
        ValueError: If ``name`` is not a supported static library.
        StaticLibNotFoundError: If the static library cannot be found.
    """
    located = locate_static_lib(name)
    return located.abs_path