-
Notifications
You must be signed in to change notification settings - Fork 279
Initial version of cuda.pathfinder._find_nvidia_headers for nvshmem
#661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 24 commits
d174f5a
5ed86c5
a98de70
6a93fe7
3ab9643
c1e9385
c3464b3
627f6d0
26f94d1
6adddb0
363b649
2dd448c
d3f97e4
6f4d762
dfa3384
673c38c
80cece3
7c30292
a300419
3ae15e0
e855155
dc4de43
eb2e78a
c90c393
cee717a
a50da30
a185f03
7cfcbfe
510f470
2426260
b74a84c
6ee6529
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import functools | ||
| import glob | ||
| import os | ||
| from typing import Optional, cast | ||
|
|
||
| from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import IS_WINDOWS | ||
| from cuda.pathfinder._utils.conda_env import get_conda_prefix | ||
| from cuda.pathfinder._utils.env_vars_for_include import iter_env_vars_for_include_dirs | ||
| from cuda.pathfinder._utils.find_sub_dirs import find_sub_dirs_all_sitepackages | ||
|
|
||
|
|
||
| @functools.cache | ||
| def find_nvidia_header_directory(libname: str) -> Optional[str]: | ||
| if libname != "nvshmem": | ||
| raise RuntimeError(f"UNKNOWN {libname=}") | ||
|
|
||
| if libname == "nvshmem" and IS_WINDOWS: | ||
| # nvshmem has no Windows support. | ||
| return None | ||
|
|
||
| # Installed from a wheel | ||
| nvidia_sub_dirs = ("nvidia", "nvshmem", "include") | ||
| hdr_dir: str # help mypy | ||
| for hdr_dir in find_sub_dirs_all_sitepackages(nvidia_sub_dirs): | ||
| nvshmem_h_path = os.path.join(hdr_dir, "nvshmem.h") | ||
| if os.path.isfile(nvshmem_h_path): | ||
| return hdr_dir | ||
|
|
||
| conda_prefix = get_conda_prefix() | ||
| if conda_prefix and os.path.isdir(conda_prefix.path): | ||
| hdr_dir = os.path.join(conda_prefix.path, "include") | ||
| if os.path.isdir(hdr_dir): | ||
|
rwgk marked this conversation as resolved.
Outdated
|
||
| nvshmem_h_path = os.path.join(hdr_dir, "nvshmem.h") | ||
| if os.path.isfile(nvshmem_h_path): | ||
| return hdr_dir | ||
|
|
||
| for hdr_dir in sorted(glob.glob("/usr/include/nvshmem_*"), reverse=True): | ||
|
leofang marked this conversation as resolved.
|
||
| if os.path.isdir(hdr_dir): | ||
| nvshmem_h_path = os.path.join(hdr_dir, "nvshmem.h") | ||
| if os.path.isfile(nvshmem_h_path): | ||
| return hdr_dir | ||
|
|
||
| for hdr_dir in iter_env_vars_for_include_dirs(): | ||
| if os.path.isdir(hdr_dir): | ||
| nvshmem_h_path = os.path.join(hdr_dir, "nvshmem.h") | ||
| if os.path.isfile(nvshmem_h_path): | ||
| return cast(str, hdr_dir) # help mypy | ||
|
|
||
| return None | ||
|
rwgk marked this conversation as resolved.
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import functools | ||
| import os | ||
| from dataclasses import dataclass | ||
| from pathlib import Path | ||
| from typing import Literal, Optional | ||
|
|
||
| # https://docs.conda.io/projects/conda-build/en/stable/user-guide/environment-variables.html | ||
|
|
||
| BUILD_STATES = ("RENDER", "BUILD", "TEST") | ||
|
rwgk marked this conversation as resolved.
Outdated
|
||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class CondaPrefix: | ||
| env_state: Literal["RENDER", "BUILD", "TEST", "activated"] | ||
|
rwgk marked this conversation as resolved.
Outdated
|
||
| path: Path | ||
|
|
||
|
|
||
| @functools.cache | ||
| def get_conda_prefix() -> Optional[CondaPrefix]: | ||
| """ | ||
| Return the effective conda prefix. | ||
| - RENDER, BUILD, TEST: inside conda-build (host prefix at $PREFIX) | ||
|
rwgk marked this conversation as resolved.
Outdated
|
||
| - activated: user-activated env ($CONDA_PREFIX) | ||
| - None: neither detected | ||
| """ | ||
| state = os.getenv("CONDA_BUILD_STATE") | ||
| if state: | ||
| if state in BUILD_STATES: | ||
| p = os.getenv("PREFIX") | ||
|
rwgk marked this conversation as resolved.
Outdated
|
||
| if p: | ||
| return CondaPrefix(state, Path(p)) # type: ignore[arg-type] | ||
| return None | ||
|
|
||
| cp = os.getenv("CONDA_PREFIX") | ||
| if cp: | ||
| return CondaPrefix("activated", Path(cp)) | ||
|
|
||
| return None | ||
|
rwgk marked this conversation as resolved.
Outdated
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import os | ||
| import sys | ||
| from collections.abc import Iterable | ||
|
|
||
| IS_WINDOWS = sys.platform == "win32" | ||
|
|
||
| # GCC/Clang-style include vars | ||
| GCC_VNAMES = ("CPATH", "C_INCLUDE_PATH", "CPLUS_INCLUDE_PATH") | ||
|
|
||
| # MSVC: INCLUDE is the canonical header search variable | ||
| MSVC_GCC_VNAMES = ("INCLUDE",) | ||
|
|
||
| VNAMES: tuple[str, ...] = MSVC_GCC_VNAMES + GCC_VNAMES if IS_WINDOWS else GCC_VNAMES | ||
|
|
||
|
|
||
| def iter_env_vars_for_include_dirs() -> Iterable[str]: | ||
| for vname in VNAMES: | ||
| v = os.getenv(vname) | ||
| if v: | ||
| for d in v.split(os.pathsep): | ||
| if d: | ||
| yield d |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| # Currently these installations are only manually tested: | ||
|
|
||
| # conda create -y -n nvshmem python=3.12 | ||
| # conda activate nvshmem | ||
| # conda install -y conda-forge::libnvshmem3 conda-forge::libnvshmem-dev | ||
|
|
||
| # wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb | ||
| # sudo dpkg -i cuda-keyring_1.1-1_all.deb | ||
| # sudo apt update | ||
| # sudo apt install libnvshmem3-cuda-12 libnvshmem3-dev-cuda-12 | ||
| # sudo apt install libnvshmem3-cuda-13 libnvshmem3-dev-cuda-13 | ||
|
|
||
| import functools | ||
| import importlib.metadata | ||
| import os | ||
| import re | ||
|
|
||
| import pytest | ||
|
|
||
| from cuda.pathfinder import _find_nvidia_header_directory as find_nvidia_header_directory | ||
| from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import IS_WINDOWS | ||
|
|
||
| STRICTNESS = os.environ.get("CUDA_PATHFINDER_TEST_FIND_NVIDIA_HEADERS_STRICTNESS", "see_what_works") | ||
| assert STRICTNESS in ("see_what_works", "all_must_work") | ||
|
|
||
|
|
||
| @functools.cache | ||
| def have_nvidia_nvshmem_package() -> bool: | ||
| pattern = re.compile(r"^nvidia-nvshmem-.*$") | ||
| return any( | ||
| pattern.match(dist.metadata["Name"]) for dist in importlib.metadata.distributions() if "Name" in dist.metadata | ||
| ) | ||
|
|
||
|
|
||
| def test_unknown_libname(): | ||
| with pytest.raises(RuntimeError, match=r"^UNKNOWN libname='unknown-libname'$"): | ||
| find_nvidia_header_directory("unknown-libname") | ||
|
|
||
|
|
||
| def test_find_libname_nvshmem(info_summary_append): | ||
| hdr_dir = find_nvidia_header_directory("nvshmem") | ||
| info_summary_append(f"{hdr_dir=!r}") | ||
| if IS_WINDOWS: | ||
| assert hdr_dir is None | ||
| pytest.skip("nvshmem has no Windows support.") | ||
| if STRICTNESS == "all_must_work" or have_nvidia_nvshmem_package(): | ||
| assert hdr_dir is not None | ||
| if have_nvidia_nvshmem_package(): | ||
| hdr_dir_parts = hdr_dir.split(os.path.sep) | ||
| assert any( | ||
| sub_dir in hdr_dir_parts | ||
| for sub_dir in ( | ||
| "site-packages", # pip install | ||
| "dist-packages", # apt install | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to do anything for conda packages here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is gated by While at it, I added additional test (low-hanging fruits): commit b74a84c Interactive testing with conda and Without wheel or conda ( I also did negative tests for both, by using bad expected paths, and they fail as expected. |
||
| ) | ||
| ), hdr_dir | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import pytest | ||
|
|
||
| from cuda.pathfinder._utils.conda_env import ( | ||
| BUILD_STATES, | ||
| CondaPrefix, | ||
| get_conda_prefix, | ||
| ) | ||
|
|
||
| # Auto-clean environment & cache before every test ----------------------------- | ||
|
|
||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def _clean_env_and_cache(monkeypatch): | ||
| # Remove any possibly inherited variables from the test runner environment | ||
| for k in ("CONDA_BUILD_STATE", "PREFIX", "CONDA_PREFIX"): | ||
| monkeypatch.delenv(k, raising=False) | ||
| # Clear the cached result between tests | ||
| get_conda_prefix.cache_clear() | ||
| return | ||
| # (No teardown needed; monkeypatch auto-reverts) | ||
|
|
||
|
|
||
| # Tests ----------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def test_returns_none_when_no_relevant_env_vars(): | ||
| assert get_conda_prefix() is None | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("state", BUILD_STATES) | ||
| def test_build_state_returns_prefix_when_present(state, monkeypatch, tmp_path): | ||
| monkeypatch.setenv("CONDA_BUILD_STATE", state) | ||
| monkeypatch.setenv("PREFIX", str(tmp_path)) | ||
| res = get_conda_prefix() | ||
| assert isinstance(res, CondaPrefix) | ||
| assert res.env_state == state | ||
| assert res.path == tmp_path | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("state", BUILD_STATES) | ||
| def test_build_state_requires_prefix_otherwise_none(state, monkeypatch): | ||
| monkeypatch.setenv("CONDA_BUILD_STATE", state) | ||
| # No PREFIX set | ||
| assert get_conda_prefix() is None | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("state", BUILD_STATES) | ||
| def test_build_state_with_empty_prefix_returns_none(state, monkeypatch): | ||
| monkeypatch.setenv("CONDA_BUILD_STATE", state) | ||
| monkeypatch.setenv("PREFIX", "") | ||
| assert get_conda_prefix() is None | ||
|
|
||
|
|
||
| def test_activated_env_returns_conda_prefix(monkeypatch, tmp_path): | ||
| monkeypatch.setenv("CONDA_PREFIX", str(tmp_path)) | ||
| res = get_conda_prefix() | ||
| assert isinstance(res, CondaPrefix) | ||
| assert res.env_state == "activated" | ||
| assert res.path == tmp_path | ||
|
|
||
|
|
||
| def test_activated_env_ignores_empty_conda_prefix(monkeypatch): | ||
| monkeypatch.setenv("CONDA_PREFIX", "") | ||
| assert get_conda_prefix() is None | ||
|
|
||
|
|
||
| def test_build_state_wins_over_activated_when_valid(monkeypatch, tmp_path): | ||
| build_p = tmp_path / "host" | ||
| user_p = tmp_path / "user" | ||
| monkeypatch.setenv("CONDA_BUILD_STATE", "TEST") | ||
| monkeypatch.setenv("PREFIX", str(build_p)) | ||
| monkeypatch.setenv("CONDA_PREFIX", str(user_p)) | ||
| res = get_conda_prefix() | ||
| assert res | ||
| assert res.env_state == "TEST" | ||
| assert res.path == build_p | ||
|
|
||
|
|
||
| def test_unknown_build_state_returns_none_even_if_conda_prefix_set(monkeypatch, tmp_path): | ||
| # Any non-empty CONDA_BUILD_STATE that is not recognized -> None | ||
| monkeypatch.setenv("CONDA_BUILD_STATE", "SOMETHING_ELSE") | ||
| monkeypatch.setenv("CONDA_PREFIX", str(tmp_path)) | ||
| assert get_conda_prefix() is None | ||
|
|
||
|
|
||
| def test_empty_build_state_treated_as_absent_and_falls_back_to_activated(monkeypatch, tmp_path): | ||
| # Empty string is falsy -> treated like "not set" -> activated path | ||
| monkeypatch.setenv("CONDA_BUILD_STATE", "") | ||
| monkeypatch.setenv("CONDA_PREFIX", str(tmp_path)) | ||
| res = get_conda_prefix() | ||
| assert res | ||
| assert res.env_state == "activated" | ||
| assert res.path == tmp_path | ||
|
|
||
|
|
||
| def test_have_cache(monkeypatch, tmp_path): | ||
| monkeypatch.setenv("CONDA_PREFIX", str(tmp_path)) | ||
| res = get_conda_prefix() | ||
| assert res | ||
| assert res.path == tmp_path | ||
| res2 = get_conda_prefix() | ||
| assert res2 is res |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should do it in a follow up, but I think we should really move towards using an
importlibbased resolution method instead of walking the sitepackages ourselves.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought we discussed and agreed walking the paths is acceptable since on the PyPI side they are predictable? Any reason to prefer
importlib?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good point, tx, I created issue #949 to track this.
(I wrote this code before @ZzEeKkAa pointed me to
importlibwhile working on PR #864)Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lines crossed here. Let's review under #949 when we get to it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
864?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops, sorry, copy-paste mishap. Corrected (949).