Skip to content

Commit 2cb0d68

Browse files
Introduce LocatedHeaderDir to report the header discovery method (#1536)
* initial * resolve a few tests * precommit * address reviews * test updates * address reviews * address reviews
1 parent c94557c commit 2cb0d68

3 files changed

Lines changed: 90 additions & 23 deletions

File tree

cuda_pathfinder/cuda/pathfinder/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import (
1010
SUPPORTED_LIBNAMES as SUPPORTED_NVIDIA_LIBNAMES, # noqa: F401
1111
)
12+
from cuda.pathfinder._headers.find_nvidia_headers import LocatedHeaderDir as LocatedHeaderDir
1213
from cuda.pathfinder._headers.find_nvidia_headers import find_nvidia_header_directory as find_nvidia_header_directory
14+
from cuda.pathfinder._headers.find_nvidia_headers import (
15+
locate_nvidia_header_directory as locate_nvidia_header_directory,
16+
)
1317
from cuda.pathfinder._headers.supported_nvidia_headers import SUPPORTED_HEADERS_CTK as _SUPPORTED_HEADERS_CTK
1418

1519
from cuda.pathfinder._version import __version__ # isort: skip # noqa: F401

cuda_pathfinder/cuda/pathfinder/_headers/find_nvidia_headers.py

Lines changed: 70 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,28 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from __future__ import annotations
5+
46
import functools
57
import glob
68
import os
9+
from dataclasses import dataclass
710

811
from cuda.pathfinder._headers import supported_nvidia_headers
912
from cuda.pathfinder._utils.env_vars import get_cuda_home_or_path
1013
from cuda.pathfinder._utils.find_sub_dirs import find_sub_dirs_all_sitepackages
1114
from cuda.pathfinder._utils.platform_aware import IS_WINDOWS
1215

1316

17+
@dataclass
18+
class LocatedHeaderDir:
19+
abs_path: str | None
20+
found_via: str
21+
22+
def __post_init__(self) -> None:
23+
self.abs_path = _abs_norm(self.abs_path)
24+
25+
1426
def _abs_norm(path: str | None) -> str | None:
1527
if path:
1628
return os.path.normpath(os.path.abspath(path))
@@ -21,16 +33,16 @@ def _joined_isfile(dirpath: str, basename: str) -> bool:
2133
return os.path.isfile(os.path.join(dirpath, basename))
2234

2335

24-
def _find_under_site_packages(sub_dir: str, h_basename: str) -> str | None:
36+
def _locate_under_site_packages(sub_dir: str, h_basename: str) -> LocatedHeaderDir | None:
2537
# Installed from a wheel
2638
hdr_dir: str # help mypy
2739
for hdr_dir in find_sub_dirs_all_sitepackages(tuple(sub_dir.split("/"))):
2840
if _joined_isfile(hdr_dir, h_basename):
29-
return hdr_dir
41+
return LocatedHeaderDir(abs_path=hdr_dir, found_via="site-packages")
3042
return None
3143

3244

33-
def _find_based_on_ctk_layout(libname: str, h_basename: str, anchor_point: str) -> str | None:
45+
def _locate_based_on_ctk_layout(libname: str, h_basename: str, anchor_point: str) -> str | None:
3446
parts = [anchor_point]
3547
if libname == "nvvm":
3648
parts.append(libname)
@@ -52,7 +64,7 @@ def _find_based_on_ctk_layout(libname: str, h_basename: str, anchor_point: str)
5264
return None
5365

5466

55-
def _find_based_on_conda_layout(libname: str, h_basename: str, ctk_layout: bool) -> str | None:
67+
def _find_based_on_conda_layout(libname: str, h_basename: str, ctk_layout: bool) -> LocatedHeaderDir | None:
5668
conda_prefix = os.environ.get("CONDA_PREFIX")
5769
if not conda_prefix:
5870
return None
@@ -73,39 +85,43 @@ def _find_based_on_conda_layout(libname: str, h_basename: str, ctk_layout: bool)
7385
else:
7486
include_path = os.path.join(conda_prefix, "include")
7587
anchor_point = os.path.dirname(include_path)
76-
return _find_based_on_ctk_layout(libname, h_basename, anchor_point)
88+
found_header_path = _locate_based_on_ctk_layout(libname, h_basename, anchor_point)
89+
if found_header_path:
90+
return LocatedHeaderDir(abs_path=found_header_path, found_via="conda")
91+
return None
7792

7893

79-
def _find_ctk_header_directory(libname: str) -> str | None:
94+
def _find_ctk_header_directory(libname: str) -> LocatedHeaderDir | None:
8095
h_basename = supported_nvidia_headers.SUPPORTED_HEADERS_CTK[libname]
8196
candidate_dirs = supported_nvidia_headers.SUPPORTED_SITE_PACKAGE_HEADER_DIRS_CTK[libname]
8297

8398
for cdir in candidate_dirs:
84-
if hdr_dir := _find_under_site_packages(cdir, h_basename):
99+
if hdr_dir := _locate_under_site_packages(cdir, h_basename):
85100
return hdr_dir
86101

87102
if hdr_dir := _find_based_on_conda_layout(libname, h_basename, True):
88103
return hdr_dir
89104

90105
cuda_home = get_cuda_home_or_path()
91106
if cuda_home: # noqa: SIM102
92-
if result := _find_based_on_ctk_layout(libname, h_basename, cuda_home):
93-
return result
107+
if result := _locate_based_on_ctk_layout(libname, h_basename, cuda_home):
108+
return LocatedHeaderDir(abs_path=result, found_via="CUDA_HOME")
94109

95110
return None
96111

97112

98113
@functools.cache
99-
def find_nvidia_header_directory(libname: str) -> str | None:
114+
def locate_nvidia_header_directory(libname: str) -> LocatedHeaderDir | None:
100115
"""Locate the header directory for a supported NVIDIA library.
101116
102117
Args:
103118
libname (str): The short name of the library whose headers are needed
104119
(e.g., ``"nvrtc"``, ``"cusolver"``, ``"nvshmem"``).
105120
106121
Returns:
107-
str or None: Absolute path to the discovered header directory, or ``None``
108-
if the headers cannot be found.
122+
LocatedHeaderDir or None: A LocatedHeaderDir object containing the absolute path
123+
to the discovered header directory and information about where it was found,
124+
or ``None`` if the headers cannot be found.
109125
110126
Raises:
111127
RuntimeError: If ``libname`` is not in the supported set.
@@ -127,25 +143,59 @@ def find_nvidia_header_directory(libname: str) -> str | None:
127143
"""
128144

129145
if libname in supported_nvidia_headers.SUPPORTED_HEADERS_CTK:
130-
return _abs_norm(_find_ctk_header_directory(libname))
146+
return _find_ctk_header_directory(libname)
131147

132148
h_basename = supported_nvidia_headers.SUPPORTED_HEADERS_NON_CTK.get(libname)
133149
if h_basename is None:
134150
raise RuntimeError(f"UNKNOWN {libname=}")
135151

136152
candidate_dirs = supported_nvidia_headers.SUPPORTED_SITE_PACKAGE_HEADER_DIRS_NON_CTK.get(libname, [])
137-
hdr_dir: str | None # help mypy
153+
138154
for cdir in candidate_dirs:
139-
if hdr_dir := _find_under_site_packages(cdir, h_basename):
140-
return _abs_norm(hdr_dir)
155+
if found_hdr := _locate_under_site_packages(cdir, h_basename):
156+
return found_hdr
141157

142-
if hdr_dir := _find_based_on_conda_layout(libname, h_basename, False):
143-
return _abs_norm(hdr_dir)
158+
if found_hdr := _find_based_on_conda_layout(libname, h_basename, False):
159+
return found_hdr
144160

161+
# Fall back to system install directories
145162
candidate_dirs = supported_nvidia_headers.SUPPORTED_INSTALL_DIRS_NON_CTK.get(libname, [])
146163
for cdir in candidate_dirs:
147164
for hdr_dir in sorted(glob.glob(cdir), reverse=True):
148165
if _joined_isfile(hdr_dir, h_basename):
149-
return _abs_norm(hdr_dir)
150-
166+
# For system installs, we don't have a clear found_via, so use "system"
167+
return LocatedHeaderDir(abs_path=hdr_dir, found_via="supported_install_dir")
151168
return None
169+
170+
171+
def find_nvidia_header_directory(libname: str) -> str | None:
172+
"""Locate the header directory for a supported NVIDIA library.
173+
174+
Args:
175+
libname (str): The short name of the library whose headers are needed
176+
(e.g., ``"nvrtc"``, ``"cusolver"``, ``"nvshmem"``).
177+
178+
Returns:
179+
str or None: Absolute path to the discovered header directory, or ``None``
180+
if the headers cannot be found.
181+
182+
Raises:
183+
RuntimeError: If ``libname`` is not in the supported set.
184+
185+
Search order:
186+
1. **NVIDIA Python wheels**
187+
188+
- Scan installed distributions (``site-packages``) for header layouts
189+
shipped in NVIDIA wheels (e.g., ``cuda-toolkit[nvrtc]``).
190+
191+
2. **Conda environments**
192+
193+
- Check Conda-style installation prefixes, which use platform-specific
194+
include directory layouts.
195+
196+
3. **CUDA Toolkit environment variables**
197+
198+
- Use ``CUDA_HOME`` or ``CUDA_PATH`` (in that order).
199+
"""
200+
found = locate_nvidia_header_directory(libname)
201+
return found.abs_path if found else None

cuda_pathfinder/tests/test_find_nvidia_headers.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import pytest
2121

22-
from cuda.pathfinder import find_nvidia_header_directory
22+
from cuda.pathfinder import LocatedHeaderDir, find_nvidia_header_directory, locate_nvidia_header_directory
2323
from cuda.pathfinder._headers.supported_nvidia_headers import (
2424
SUPPORTED_HEADERS_CTK,
2525
SUPPORTED_HEADERS_CTK_ALL,
@@ -44,6 +44,11 @@ def test_unknown_libname():
4444
find_nvidia_header_directory("unknown-libname")
4545

4646

47+
def _located_hdr_dir_asserts(located_hdr_dir):
48+
assert isinstance(located_hdr_dir, LocatedHeaderDir)
49+
assert located_hdr_dir.found_via in ("site-packages", "conda", "CUDA_HOME", "supported_install_dir")
50+
51+
4752
def test_non_ctk_importlib_metadata_distributions_names():
4853
# Ensure the dict keys above stay in sync with supported_nvidia_headers
4954
assert sorted(NON_CTK_IMPORTLIB_METADATA_DISTRIBUTIONS_NAMES) == sorted(SUPPORTED_HEADERS_NON_CTK_ALL)
@@ -58,10 +63,14 @@ def have_distribution_for(libname: str) -> bool:
5863

5964

6065
@pytest.mark.parametrize("libname", SUPPORTED_HEADERS_NON_CTK.keys())
61-
def test_find_non_ctk_headers(info_summary_append, libname):
66+
def test_locate_non_ctk_headers(info_summary_append, libname):
6267
hdr_dir = find_nvidia_header_directory(libname)
68+
located_hdr_dir = locate_nvidia_header_directory(libname)
69+
assert hdr_dir is None if not located_hdr_dir else hdr_dir == located_hdr_dir.abs_path
70+
6371
info_summary_append(f"{hdr_dir=!r}")
6472
if hdr_dir:
73+
_located_hdr_dir_asserts(located_hdr_dir)
6574
assert os.path.isdir(hdr_dir)
6675
assert os.path.isfile(os.path.join(hdr_dir, SUPPORTED_HEADERS_NON_CTK[libname]))
6776
if have_distribution_for(libname):
@@ -88,10 +97,14 @@ def test_supported_headers_site_packages_ctk_consistency():
8897

8998

9099
@pytest.mark.parametrize("libname", SUPPORTED_HEADERS_CTK.keys())
91-
def test_find_ctk_headers(info_summary_append, libname):
100+
def test_locate_ctk_headers(info_summary_append, libname):
92101
hdr_dir = find_nvidia_header_directory(libname)
102+
located_hdr_dir = locate_nvidia_header_directory(libname)
103+
assert hdr_dir is None if not located_hdr_dir else hdr_dir == located_hdr_dir.abs_path
104+
93105
info_summary_append(f"{hdr_dir=!r}")
94106
if hdr_dir:
107+
_located_hdr_dir_asserts(located_hdr_dir)
95108
assert os.path.isdir(hdr_dir)
96109
h_filename = SUPPORTED_HEADERS_CTK[libname]
97110
assert os.path.isfile(os.path.join(hdr_dir, h_filename))

0 commit comments

Comments
 (0)