|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +import pytest |
| 5 | + |
| 6 | +from cuda.pathfinder._utils import toolkit_info |
| 7 | + |
| 8 | + |
| 9 | +@pytest.fixture(autouse=True) |
| 10 | +def _clear_cuda_header_version_cache(): |
| 11 | + toolkit_info.read_cuda_header_version.cache_clear() |
| 12 | + yield |
| 13 | + toolkit_info.read_cuda_header_version.cache_clear() |
| 14 | + |
| 15 | + |
| 16 | +def test_parse_cuda_header_version_returns_parsed_dataclass(): |
| 17 | + header_text = """ |
| 18 | + #ifndef CUDA_H |
| 19 | + #define CUDA_H |
| 20 | + #define CUDA_VERSION 13020 |
| 21 | + #endif |
| 22 | + """ |
| 23 | + |
| 24 | + assert toolkit_info.parse_cuda_header_version(header_text) == toolkit_info.CudaToolkitVersion( |
| 25 | + encoded=13020, |
| 26 | + major=13, |
| 27 | + minor=2, |
| 28 | + ) |
| 29 | + |
| 30 | + |
| 31 | +def test_parse_cuda_header_version_returns_none_when_macro_is_missing(): |
| 32 | + header_text = """ |
| 33 | + #ifndef CUDA_H |
| 34 | + #define CUDA_H |
| 35 | + #define CUDA_API_PER_THREAD_DEFAULT_STREAM 1 |
| 36 | + #endif |
| 37 | + """ |
| 38 | + |
| 39 | + assert toolkit_info.parse_cuda_header_version(header_text) is None |
| 40 | + |
| 41 | + |
| 42 | +def test_read_cuda_header_version_reads_file_and_returns_parsed_dataclass(tmp_path): |
| 43 | + cuda_h_path = tmp_path / "cuda.h" |
| 44 | + cuda_h_path.write_text( |
| 45 | + """ |
| 46 | + #ifndef CUDA_H |
| 47 | + #define CUDA_H |
| 48 | + #define CUDA_VERSION 12090 /* CUDA 12.9 */ |
| 49 | + #endif |
| 50 | + """, |
| 51 | + encoding="utf-8", |
| 52 | + ) |
| 53 | + |
| 54 | + assert toolkit_info.read_cuda_header_version(str(cuda_h_path)) == toolkit_info.CudaToolkitVersion( |
| 55 | + encoded=12090, |
| 56 | + major=12, |
| 57 | + minor=9, |
| 58 | + ) |
| 59 | + |
| 60 | + |
| 61 | +def test_read_cuda_header_version_tolerates_non_utf8_bytes(tmp_path): |
| 62 | + cuda_h_path = tmp_path / "cuda.h" |
| 63 | + cuda_h_path.write_bytes( |
| 64 | + b"#ifndef CUDA_H\n" |
| 65 | + b"#define CUDA_H\n" |
| 66 | + b"\xff\xfe invalid bytes in comment or banner\n" |
| 67 | + b"#define CUDA_VERSION 12080\n" |
| 68 | + b"#endif\n" |
| 69 | + ) |
| 70 | + |
| 71 | + assert toolkit_info.read_cuda_header_version(str(cuda_h_path)) == toolkit_info.CudaToolkitVersion( |
| 72 | + encoded=12080, |
| 73 | + major=12, |
| 74 | + minor=8, |
| 75 | + ) |
| 76 | + |
| 77 | + |
| 78 | +def test_read_cuda_header_version_wraps_parse_failures(tmp_path): |
| 79 | + cuda_h_path = tmp_path / "cuda.h" |
| 80 | + cuda_h_path.write_text( |
| 81 | + """ |
| 82 | + #ifndef CUDA_H |
| 83 | + #define CUDA_H |
| 84 | + #endif |
| 85 | + """, |
| 86 | + encoding="utf-8", |
| 87 | + ) |
| 88 | + |
| 89 | + with pytest.raises( |
| 90 | + toolkit_info.ReadCudaHeaderVersionError, |
| 91 | + match="Failed to read the CUDA Toolkit version from cuda.h", |
| 92 | + ) as exc_info: |
| 93 | + toolkit_info.read_cuda_header_version(str(cuda_h_path)) |
| 94 | + |
| 95 | + assert isinstance(exc_info.value.__cause__, RuntimeError) |
| 96 | + assert "does not define CUDA_VERSION" in str(exc_info.value.__cause__) |
0 commit comments