forked from NVIDIA/cuda-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfind_nvidia_headers.py
More file actions
150 lines (112 loc) · 5.12 KB
/
find_nvidia_headers.py
File metadata and controls
150 lines (112 loc) · 5.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import functools
import glob
import os
from typing import Optional
from cuda.pathfinder._headers import supported_nvidia_headers
from cuda.pathfinder._utils.env_vars import get_cuda_home_or_path
from cuda.pathfinder._utils.find_sub_dirs import find_sub_dirs_all_sitepackages
from cuda.pathfinder._utils.platform_aware import IS_WINDOWS
def _abs_norm(path: Optional[str]) -> Optional[str]:
    """Return ``path`` as a normalized absolute path.

    Empty strings and ``None`` both map to ``None`` so that "not found"
    results from the search helpers pass through unchanged.
    """
    return os.path.normpath(os.path.abspath(path)) if path else None
def _joined_isfile(dirpath: str, basename: str) -> bool:
    """Report whether ``dirpath/basename`` exists and is a regular file."""
    candidate = os.path.join(dirpath, basename)
    return os.path.isfile(candidate)
def _nvshmem_version_key(path: str) -> tuple:
    """Sort key that orders ``nvshmem_<version>`` directories by version number.

    Digit runs in the basename are compared as integers, so ``nvshmem_10``
    ranks above ``nvshmem_9`` (a plain string comparison ranks it below).
    """
    import re

    pieces = re.split(r"(\d+)", os.path.basename(path))
    # re.split with a capturing group alternates non-digit text (even indices)
    # and digit runs (odd indices), so every position of the key always holds
    # the same type and keys of different paths remain comparable.
    return tuple(int(piece) if idx % 2 else piece for idx, piece in enumerate(pieces))


def _find_nvshmem_header_directory() -> Optional[str]:
    """Return a directory containing ``nvshmem.h``, or ``None`` if none is found.

    Search order:
      1. ``nvidia/nvshmem/include`` under any site-packages directory (wheel install).
      2. ``$CONDA_PREFIX/include``, if ``CONDA_PREFIX`` is set and is a directory.
      3. ``/usr/include/nvshmem_*`` directories, highest version first.

    Always returns ``None`` on Windows.
    """
    if IS_WINDOWS:
        # nvshmem has no Windows support.
        return None

    # Installed from a wheel
    nvidia_sub_dirs = ("nvidia", "nvshmem", "include")
    hdr_dir: str  # help mypy
    for hdr_dir in find_sub_dirs_all_sitepackages(nvidia_sub_dirs):
        if _joined_isfile(hdr_dir, "nvshmem.h"):
            return hdr_dir

    # Installed into a conda environment
    conda_prefix = os.environ.get("CONDA_PREFIX")
    if conda_prefix and os.path.isdir(conda_prefix):
        hdr_dir = os.path.join(conda_prefix, "include")
        if _joined_isfile(hdr_dir, "nvshmem.h"):
            return hdr_dir

    # System directories: prefer the highest version. Compare numeric parts as
    # integers; a plain string sort would rank "nvshmem_9" above "nvshmem_10".
    system_dirs = sorted(
        glob.glob("/usr/include/nvshmem_*"),
        key=_nvshmem_version_key,
        reverse=True,
    )
    for hdr_dir in system_dirs:
        if _joined_isfile(hdr_dir, "nvshmem.h"):
            return hdr_dir

    return None
def _find_based_on_ctk_layout(libname: str, h_basename: str, anchor_point: str) -> Optional[str]:
    """Look for ``h_basename`` in a CUDA Toolkit-style tree rooted at ``anchor_point``.

    The include directory is ``<anchor_point>/include``, except for ``nvvm``,
    which uses ``<anchor_point>/nvvm/include``. For ``cccl`` the ``cccl``
    subdirectory of the include directory is tried before the include
    directory itself.

    Returns:
        The first candidate directory containing ``h_basename``, or ``None``.
    """
    if libname == "nvvm":
        include_dir = os.path.join(anchor_point, libname, "include")
    else:
        include_dir = os.path.join(anchor_point, "include")

    candidates = [include_dir]
    if libname == "cccl":
        # CTK 13 layout: include/cccl. Checked ahead of include/.
        candidates.insert(0, os.path.join(include_dir, "cccl"))

    for candidate in candidates:
        if _joined_isfile(candidate, h_basename):
            return candidate
    return None
def _find_based_on_conda_layout(libname: str, h_basename: str, conda_prefix: str) -> Optional[str]:
    """Search a conda environment for ``h_basename`` using the CTK directory layout.

    On Windows the CTK-style tree is rooted at ``<conda_prefix>/Library``.
    Elsewhere it is rooted at the single ``<conda_prefix>/targets/<arch>``
    directory that has an ``include`` subdirectory.

    Returns:
        The header directory, or ``None`` if the expected root is missing,
        ambiguous (more than one ``targets/*/include``), or lacks the header.
    """
    if IS_WINDOWS:
        anchor_point = os.path.join(conda_prefix, "Library")
        if not os.path.isdir(anchor_point):
            return None
    else:
        # Escape the prefix so glob metacharacters in it ("[", "*", "?") are
        # matched literally; only the architecture component is a wildcard.
        pattern = os.path.join(glob.escape(conda_prefix), "targets", "*", "include")
        targets_include_path = glob.glob(pattern)
        if len(targets_include_path) != 1:
            # Zero matches: no targets/*/include directory in this prefix.
            # More than one: Conda does not support multiple architectures.
            # QUESTION(PR#956): Do we want to issue a warning?
            return None
        anchor_point = os.path.dirname(targets_include_path[0])
    return _find_based_on_ctk_layout(libname, h_basename, anchor_point)
def _find_ctk_header_directory(libname: str) -> Optional[str]:
    """Find the header directory of a CUDA Toolkit library.

    Locations are tried in this order; the first directory that contains the
    library's canonical header (from ``SUPPORTED_HEADERS_CTK``) wins:

      1. Each sub-path listed in ``SUPPORTED_SITE_PACKAGE_HEADER_DIRS_CTK``
         under every site-packages directory (wheel installs).
      2. The conda environment at ``$CONDA_PREFIX``, if that is set.
      3. The CTK root returned by ``get_cuda_home_or_path()``, if any.

    Returns ``None`` when no location contains the header.
    """
    h_basename = supported_nvidia_headers.SUPPORTED_HEADERS_CTK[libname]

    # Installed from a wheel
    for sub_path in supported_nvidia_headers.SUPPORTED_SITE_PACKAGE_HEADER_DIRS_CTK[libname]:
        site_dir: str  # help mypy
        for site_dir in find_sub_dirs_all_sitepackages(tuple(sub_path.split("/"))):
            if _joined_isfile(site_dir, h_basename):
                return site_dir

    conda_prefix = os.environ.get("CONDA_PREFIX")
    if conda_prefix:
        found = _find_based_on_conda_layout(libname, h_basename, conda_prefix)
        if found:
            return found

    cuda_home = get_cuda_home_or_path()
    if cuda_home:
        found = _find_based_on_ctk_layout(libname, h_basename, cuda_home)
        if found:
            return found

    return None
@functools.cache
def find_nvidia_header_directory(libname: str) -> Optional[str]:
    """Locate the header directory for a supported NVIDIA library.

    Args:
        libname (str): The short name of the library whose headers are needed
            (e.g., ``"nvrtc"``, ``"cusolver"``, ``"nvshmem"``).

    Returns:
        str or None: Absolute, normalized path to the discovered header
        directory, or ``None`` if the headers cannot be found.

    Raises:
        RuntimeError: If ``libname`` is not in the supported set.

    Search order for CUDA Toolkit (CTK) libraries:
        1. **NVIDIA Python wheels**
            - Scan installed distributions (``site-packages``) for header layouts
              shipped in NVIDIA wheels (e.g., ``cuda-toolkit[nvrtc]``).
        2. **Conda environments**
            - Check the ``CONDA_PREFIX`` installation prefix, which uses
              platform-specific include directory layouts.
        3. **CUDA Toolkit environment variables**
            - Use ``CUDA_HOME`` or ``CUDA_PATH`` (in that order).

    Search order for ``nvshmem``:
        - On Windows, always returns ``None`` (nvshmem has no Windows support).
        1. ``nvidia/nvshmem/include`` in NVIDIA Python wheels.
        2. ``$CONDA_PREFIX/include``.
        3. System directories matching ``/usr/include/nvshmem_*``.
        ``CUDA_HOME`` / ``CUDA_PATH`` are not consulted for ``nvshmem``.

    Notes:
        - Results are cached per ``libname`` for the lifetime of the process
          (``functools.cache``), including ``None`` results, so headers that
          are installed after the first lookup are not picked up.
        - The ``SUPPORTED_HEADERS_CTK`` dictionary maps each supported CUDA Toolkit
          (CTK) library to the name of its canonical header (e.g., ``"cublas" →
          "cublas.h"``). This is used to verify that the located directory is valid.
        - The only supported non-CTK library at present is ``nvshmem``.
    """
    if libname == "nvshmem":
        return _abs_norm(_find_nvshmem_header_directory())

    if libname in supported_nvidia_headers.SUPPORTED_HEADERS_CTK:
        return _abs_norm(_find_ctk_header_directory(libname))

    raise RuntimeError(f"UNKNOWN {libname=}")