Skip to content

Commit f7e81ed

Browse files
committed
Add cuda.h toolkit version parser.
Introduce a small toolkit-info utility that reads the CUDA_VERSION macro from cuda.h so follow-up guard-rails changes can infer CTK major.minor from toolkit headers without depending on version.json. Made-with: Cursor
1 parent 3bf0e98 commit f7e81ed

2 files changed

Lines changed: 148 additions & 0 deletions

File tree

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
import functools
7+
import re
8+
from dataclasses import dataclass
9+
from pathlib import Path
10+
11+
_CUDA_VERSION_RE = re.compile(r"^\s*#\s*define\s+CUDA_VERSION\s+(?P<encoded>\d+)\b", re.MULTILINE)
12+
13+
14+
class ReadCudaHeaderVersionError(RuntimeError):
15+
"""Raised when ``read_cuda_header_version()`` cannot determine the CTK version from ``cuda.h``."""
16+
17+
18+
@dataclass(frozen=True, slots=True)
19+
class CudaToolkitVersion:
20+
"""CUDA Toolkit version encoded by the ``CUDA_VERSION`` macro in ``cuda.h``."""
21+
22+
encoded: int
23+
major: int
24+
minor: int
25+
26+
27+
def parse_cuda_header_version(header_text: str) -> CudaToolkitVersion | None:
28+
"""Parse the CUDA Toolkit major/minor version from ``cuda.h`` text."""
29+
match = _CUDA_VERSION_RE.search(header_text)
30+
if match is None:
31+
return None
32+
encoded = int(match.group("encoded"))
33+
return CudaToolkitVersion(
34+
encoded=encoded,
35+
major=encoded // 1000,
36+
minor=(encoded % 1000) // 10,
37+
)
38+
39+
40+
@functools.cache
41+
def read_cuda_header_version(cuda_header_path: str) -> CudaToolkitVersion:
42+
"""Read and parse the CUDA Toolkit major/minor version from ``cuda.h``."""
43+
try:
44+
header_text = Path(cuda_header_path).read_text(encoding="utf-8", errors="replace")
45+
version = parse_cuda_header_version(header_text)
46+
if version is None:
47+
raise RuntimeError(f"{cuda_header_path!r} does not define CUDA_VERSION.")
48+
return version
49+
except Exception as exc:
50+
raise ReadCudaHeaderVersionError(
51+
f"Failed to read the CUDA Toolkit version from cuda.h at {cuda_header_path!r}."
52+
) from exc
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import pytest
5+
6+
from cuda.pathfinder._utils import toolkit_info
7+
8+
9+
@pytest.fixture(autouse=True)
10+
def _clear_cuda_header_version_cache():
11+
toolkit_info.read_cuda_header_version.cache_clear()
12+
yield
13+
toolkit_info.read_cuda_header_version.cache_clear()
14+
15+
16+
def test_parse_cuda_header_version_returns_parsed_dataclass():
17+
header_text = """
18+
#ifndef CUDA_H
19+
#define CUDA_H
20+
#define CUDA_VERSION 13020
21+
#endif
22+
"""
23+
24+
assert toolkit_info.parse_cuda_header_version(header_text) == toolkit_info.CudaToolkitVersion(
25+
encoded=13020,
26+
major=13,
27+
minor=2,
28+
)
29+
30+
31+
def test_parse_cuda_header_version_returns_none_when_macro_is_missing():
32+
header_text = """
33+
#ifndef CUDA_H
34+
#define CUDA_H
35+
#define CUDA_API_PER_THREAD_DEFAULT_STREAM 1
36+
#endif
37+
"""
38+
39+
assert toolkit_info.parse_cuda_header_version(header_text) is None
40+
41+
42+
def test_read_cuda_header_version_reads_file_and_returns_parsed_dataclass(tmp_path):
43+
cuda_h_path = tmp_path / "cuda.h"
44+
cuda_h_path.write_text(
45+
"""
46+
#ifndef CUDA_H
47+
#define CUDA_H
48+
#define CUDA_VERSION 12090 /* CUDA 12.9 */
49+
#endif
50+
""",
51+
encoding="utf-8",
52+
)
53+
54+
assert toolkit_info.read_cuda_header_version(str(cuda_h_path)) == toolkit_info.CudaToolkitVersion(
55+
encoded=12090,
56+
major=12,
57+
minor=9,
58+
)
59+
60+
61+
def test_read_cuda_header_version_tolerates_non_utf8_bytes(tmp_path):
62+
cuda_h_path = tmp_path / "cuda.h"
63+
cuda_h_path.write_bytes(
64+
b"#ifndef CUDA_H\n"
65+
b"#define CUDA_H\n"
66+
b"\xff\xfe invalid bytes in comment or banner\n"
67+
b"#define CUDA_VERSION 12080\n"
68+
b"#endif\n"
69+
)
70+
71+
assert toolkit_info.read_cuda_header_version(str(cuda_h_path)) == toolkit_info.CudaToolkitVersion(
72+
encoded=12080,
73+
major=12,
74+
minor=8,
75+
)
76+
77+
78+
def test_read_cuda_header_version_wraps_parse_failures(tmp_path):
79+
cuda_h_path = tmp_path / "cuda.h"
80+
cuda_h_path.write_text(
81+
"""
82+
#ifndef CUDA_H
83+
#define CUDA_H
84+
#endif
85+
""",
86+
encoding="utf-8",
87+
)
88+
89+
with pytest.raises(
90+
toolkit_info.ReadCudaHeaderVersionError,
91+
match="Failed to read the CUDA Toolkit version from cuda.h",
92+
) as exc_info:
93+
toolkit_info.read_cuda_header_version(str(cuda_h_path))
94+
95+
assert isinstance(exc_info.value.__cause__, RuntimeError)
96+
assert "does not define CUDA_VERSION" in str(exc_info.value.__cause__)

0 commit comments

Comments
 (0)