forked from NVIDIA/cuda-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_find_nvidia_headers.py
More file actions
60 lines (47 loc) · 2.22 KB
/
test_find_nvidia_headers.py
File metadata and controls
60 lines (47 loc) · 2.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Currently these installations are only manually tested:
# conda create -y -n nvshmem python=3.12
# conda activate nvshmem
# conda install -y conda-forge::libnvshmem3 conda-forge::libnvshmem-dev
# wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
# sudo dpkg -i cuda-keyring_1.1-1_all.deb
# sudo apt update
# sudo apt install libnvshmem3-cuda-12 libnvshmem3-dev-cuda-12
# sudo apt install libnvshmem3-cuda-13 libnvshmem3-dev-cuda-13
import functools
import importlib.metadata
import os
import re
import pytest
from cuda.pathfinder import _find_nvidia_header_directory as find_nvidia_header_directory
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import IS_WINDOWS
# Test strictness is selected through an environment variable and defaults to
# the lenient "see_what_works" mode. "all_must_work" makes the nvshmem header
# test below require that the header directory is actually found.
STRICTNESS = os.getenv("CUDA_PATHFINDER_TEST_FIND_NVIDIA_HEADERS_STRICTNESS", "see_what_works")
# Fail at import time on an unrecognized value so a typo is not silently
# treated as the lenient mode.
assert STRICTNESS in {"see_what_works", "all_must_work"}
@functools.cache
def have_nvidia_nvshmem_package() -> bool:
    """Return True if an installed distribution is named ``nvidia-nvshmem-*``.

    Walks every distribution reported by ``importlib.metadata`` and tests its
    ``Name`` metadata field against the ``nvidia-nvshmem-`` prefix pattern.
    Distributions without a ``Name`` field are ignored. The result is
    memoized, so the scan happens at most once per process.
    """
    name_re = re.compile(r"^nvidia-nvshmem-.*$")
    for dist in importlib.metadata.distributions():
        meta = dist.metadata
        if "Name" in meta and name_re.match(meta["Name"]):
            return True
    return False
def test_unknown_libname():
    """An unrecognized libname must raise RuntimeError with the exact UNKNOWN message."""
    expected_message = r"^UNKNOWN libname='unknown-libname'$"
    with pytest.raises(RuntimeError, match=expected_message):
        find_nvidia_header_directory("unknown-libname")
def test_find_libname_nvshmem(info_summary_append):
    """Check the header directory reported for ``nvshmem``.

    * On Windows the lookup must yield ``None`` and the test is skipped.
    * Whenever a directory is reported, it must exist and contain ``nvshmem.h``.
    * When ``STRICTNESS`` is ``"all_must_work"`` or an ``nvidia-nvshmem-*``
      distribution is installed, a directory must be found, and its location
      is checked: under ``site-packages`` if the distribution is installed,
      otherwise under ``CONDA_PREFIX`` if that is set, otherwise under
      ``/usr/include/nvshmem_*``.

    ``info_summary_append`` is called with the repr of the result; it is
    presumably a fixture provided by the test suite's conftest (not visible
    here).
    """
    # NOTE: the local must stay named ``hdr_dir``; the ``=`` specifier below
    # embeds the variable name in the emitted summary line.
    hdr_dir = find_nvidia_header_directory("nvshmem")
    info_summary_append(f"{hdr_dir=!r}")

    if IS_WINDOWS:
        assert hdr_dir is None
        pytest.skip("nvshmem has no Windows support.")

    if hdr_dir:
        assert os.path.isdir(hdr_dir)
        nvshmem_header = os.path.join(hdr_dir, "nvshmem.h")
        assert os.path.isfile(nvshmem_header)

    # Lenient mode without the wheel installed: nothing further is required.
    if STRICTNESS != "all_must_work" and not have_nvidia_nvshmem_package():
        return

    assert hdr_dir is not None

    if have_nvidia_nvshmem_package():
        # Wheel install: headers are expected somewhere below site-packages.
        assert "site-packages" in hdr_dir.split(os.path.sep)
        return

    conda_prefix = os.getenv("CONDA_PREFIX")
    if conda_prefix:
        assert hdr_dir.startswith(conda_prefix)
    else:
        # Neither wheel nor conda: expect the system package layout.
        assert hdr_dir.startswith("/usr/include/nvshmem_")