11# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22# SPDX-License-Identifier: Apache-2.0
33
4+ from __future__ import annotations
5+
46import functools
57import glob
68import os
9+ from dataclasses import dataclass
710
811from cuda .pathfinder ._headers import supported_nvidia_headers
912from cuda .pathfinder ._utils .env_vars import get_cuda_home_or_path
1013from cuda .pathfinder ._utils .find_sub_dirs import find_sub_dirs_all_sitepackages
1114from cuda .pathfinder ._utils .platform_aware import IS_WINDOWS
1215
1316
17+ @dataclass
18+ class LocatedHeaderDir :
19+ abs_path : str | None
20+ found_via : str
21+
22+ def __post_init__ (self ) -> None :
23+ self .abs_path = _abs_norm (self .abs_path )
24+
25+
1426def _abs_norm (path : str | None ) -> str | None :
1527 if path :
1628 return os .path .normpath (os .path .abspath (path ))
@@ -21,16 +33,16 @@ def _joined_isfile(dirpath: str, basename: str) -> bool:
2133 return os .path .isfile (os .path .join (dirpath , basename ))
2234
2335
24- def _find_under_site_packages (sub_dir : str , h_basename : str ) -> str | None :
36+ def _locate_under_site_packages (sub_dir : str , h_basename : str ) -> LocatedHeaderDir | None :
2537 # Installed from a wheel
2638 hdr_dir : str # help mypy
2739 for hdr_dir in find_sub_dirs_all_sitepackages (tuple (sub_dir .split ("/" ))):
2840 if _joined_isfile (hdr_dir , h_basename ):
29- return hdr_dir
41+ return LocatedHeaderDir ( abs_path = hdr_dir , found_via = "site-packages" )
3042 return None
3143
3244
33- def _find_based_on_ctk_layout (libname : str , h_basename : str , anchor_point : str ) -> str | None :
45+ def _locate_based_on_ctk_layout (libname : str , h_basename : str , anchor_point : str ) -> str | None :
3446 parts = [anchor_point ]
3547 if libname == "nvvm" :
3648 parts .append (libname )
@@ -52,7 +64,7 @@ def _find_based_on_ctk_layout(libname: str, h_basename: str, anchor_point: str)
5264 return None
5365
5466
55- def _find_based_on_conda_layout (libname : str , h_basename : str , ctk_layout : bool ) -> str | None :
67+ def _find_based_on_conda_layout (libname : str , h_basename : str , ctk_layout : bool ) -> LocatedHeaderDir | None :
5668 conda_prefix = os .environ .get ("CONDA_PREFIX" )
5769 if not conda_prefix :
5870 return None
@@ -73,39 +85,43 @@ def _find_based_on_conda_layout(libname: str, h_basename: str, ctk_layout: bool)
7385 else :
7486 include_path = os .path .join (conda_prefix , "include" )
7587 anchor_point = os .path .dirname (include_path )
76- return _find_based_on_ctk_layout (libname , h_basename , anchor_point )
88+ found_header_path = _locate_based_on_ctk_layout (libname , h_basename , anchor_point )
89+ if found_header_path :
90+ return LocatedHeaderDir (abs_path = found_header_path , found_via = "conda" )
91+ return None
7792
7893
79- def _find_ctk_header_directory (libname : str ) -> str | None :
94+ def _find_ctk_header_directory (libname : str ) -> LocatedHeaderDir | None :
8095 h_basename = supported_nvidia_headers .SUPPORTED_HEADERS_CTK [libname ]
8196 candidate_dirs = supported_nvidia_headers .SUPPORTED_SITE_PACKAGE_HEADER_DIRS_CTK [libname ]
8297
8398 for cdir in candidate_dirs :
84- if hdr_dir := _find_under_site_packages (cdir , h_basename ):
99+ if hdr_dir := _locate_under_site_packages (cdir , h_basename ):
85100 return hdr_dir
86101
87102 if hdr_dir := _find_based_on_conda_layout (libname , h_basename , True ):
88103 return hdr_dir
89104
90105 cuda_home = get_cuda_home_or_path ()
91106 if cuda_home : # noqa: SIM102
92- if result := _find_based_on_ctk_layout (libname , h_basename , cuda_home ):
93- return result
107+ if result := _locate_based_on_ctk_layout (libname , h_basename , cuda_home ):
108+ return LocatedHeaderDir ( abs_path = result , found_via = "CUDA_HOME" )
94109
95110 return None
96111
97112
98113@functools .cache
99- def find_nvidia_header_directory (libname : str ) -> str | None :
114+ def locate_nvidia_header_directory (libname : str ) -> LocatedHeaderDir | None :
100115 """Locate the header directory for a supported NVIDIA library.
101116
102117 Args:
103118 libname (str): The short name of the library whose headers are needed
104119 (e.g., ``"nvrtc"``, ``"cusolver"``, ``"nvshmem"``).
105120
106121 Returns:
107- str or None: Absolute path to the discovered header directory, or ``None``
108- if the headers cannot be found.
122+ LocatedHeaderDir or None: A LocatedHeaderDir object containing the absolute path
123+ to the discovered header directory and information about where it was found,
124+ or ``None`` if the headers cannot be found.
109125
110126 Raises:
111127 RuntimeError: If ``libname`` is not in the supported set.
@@ -127,25 +143,59 @@ def find_nvidia_header_directory(libname: str) -> str | None:
127143 """
128144
129145 if libname in supported_nvidia_headers .SUPPORTED_HEADERS_CTK :
130- return _abs_norm ( _find_ctk_header_directory (libname ) )
146+ return _find_ctk_header_directory (libname )
131147
132148 h_basename = supported_nvidia_headers .SUPPORTED_HEADERS_NON_CTK .get (libname )
133149 if h_basename is None :
134150 raise RuntimeError (f"UNKNOWN { libname = } " )
135151
136152 candidate_dirs = supported_nvidia_headers .SUPPORTED_SITE_PACKAGE_HEADER_DIRS_NON_CTK .get (libname , [])
137- hdr_dir : str | None # help mypy
153+
138154 for cdir in candidate_dirs :
139- if hdr_dir := _find_under_site_packages (cdir , h_basename ):
140- return _abs_norm ( hdr_dir )
155+ if found_hdr := _locate_under_site_packages (cdir , h_basename ):
156+ return found_hdr
141157
142- if hdr_dir := _find_based_on_conda_layout (libname , h_basename , False ):
143- return _abs_norm ( hdr_dir )
158+ if found_hdr := _find_based_on_conda_layout (libname , h_basename , False ):
159+ return found_hdr
144160
161+ # Fall back to system install directories
145162 candidate_dirs = supported_nvidia_headers .SUPPORTED_INSTALL_DIRS_NON_CTK .get (libname , [])
146163 for cdir in candidate_dirs :
147164 for hdr_dir in sorted (glob .glob (cdir ), reverse = True ):
148165 if _joined_isfile (hdr_dir , h_basename ):
149- return _abs_norm ( hdr_dir )
150-
166+ # For system installs, we don't have a clear found_via, so use "system"
167+ return LocatedHeaderDir ( abs_path = hdr_dir , found_via = "supported_install_dir" )
151168 return None
169+
170+
171+ def find_nvidia_header_directory (libname : str ) -> str | None :
172+ """Locate the header directory for a supported NVIDIA library.
173+
174+ Args:
175+ libname (str): The short name of the library whose headers are needed
176+ (e.g., ``"nvrtc"``, ``"cusolver"``, ``"nvshmem"``).
177+
178+ Returns:
179+ str or None: Absolute path to the discovered header directory, or ``None``
180+ if the headers cannot be found.
181+
182+ Raises:
183+ RuntimeError: If ``libname`` is not in the supported set.
184+
185+ Search order:
186+ 1. **NVIDIA Python wheels**
187+
188+ - Scan installed distributions (``site-packages``) for header layouts
189+ shipped in NVIDIA wheels (e.g., ``cuda-toolkit[nvrtc]``).
190+
191+ 2. **Conda environments**
192+
193+ - Check Conda-style installation prefixes, which use platform-specific
194+ include directory layouts.
195+
196+ 3. **CUDA Toolkit environment variables**
197+
198+ - Use ``CUDA_HOME`` or ``CUDA_PATH`` (in that order).
199+ """
200+ found = locate_nvidia_header_directory (libname )
201+ return found .abs_path if found else None
0 commit comments