Skip to content

Commit e3b402a

Browse files
committed
Share encoded CUDA version decoding logic.
Centralize encoded CUDA version parsing and validation so toolkit and driver version helpers stay aligned and cuda.h parsing gets consistent string conversion and error reporting. Made-with: Cursor
1 parent f7e81ed commit e3b402a

4 files changed

Lines changed: 99 additions & 21 deletions

File tree

cuda_pathfinder/cuda/pathfinder/_utils/driver_info.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,21 @@
77
import functools
88
from collections.abc import Callable
99
from dataclasses import dataclass
10+
from typing import cast
1011

1112
from cuda.pathfinder._dynamic_libs.load_nvidia_dynamic_lib import (
1213
load_nvidia_dynamic_lib as _load_nvidia_dynamic_lib,
1314
)
1415
from cuda.pathfinder._utils.platform_aware import IS_WINDOWS
16+
from cuda.pathfinder._utils.toolkit_info import EncodedCudaVersion
1517

1618

1719
class QueryDriverCudaVersionError(RuntimeError):
1820
"""Raised when ``query_driver_cuda_version()`` cannot determine the CUDA driver version."""
1921

2022

2123
@dataclass(frozen=True, slots=True)
22-
class DriverCudaVersion:
24+
class DriverCudaVersion(EncodedCudaVersion):
2325
"""
2426
CUDA-facing driver version reported by ``cuDriverGetVersion()``.
2527
@@ -41,21 +43,13 @@ class DriverCudaVersion:
4143
to ``Driver Version: 595.58.03``.
4244
"""
4345

44-
encoded: int
45-
major: int
46-
minor: int
47-
4846

4947
@functools.cache
5048
def query_driver_cuda_version() -> DriverCudaVersion:
5149
"""Return the CUDA driver version parsed into its major/minor components."""
5250
try:
5351
encoded = _query_driver_cuda_version_int()
54-
return DriverCudaVersion(
55-
encoded=encoded,
56-
major=encoded // 1000,
57-
minor=(encoded % 1000) // 10,
58-
)
52+
return cast(DriverCudaVersion, DriverCudaVersion.from_encoded(encoded))
5953
except Exception as exc:
6054
raise QueryDriverCudaVersionError("Failed to query the CUDA driver version.") from exc
6155

cuda_pathfinder/cuda/pathfinder/_utils/toolkit_info.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,34 +7,63 @@
77
import re
88
from dataclasses import dataclass
99
from pathlib import Path
10+
from typing import TypeVar
1011

1112
_CUDA_VERSION_RE = re.compile(r"^\s*#\s*define\s+CUDA_VERSION\s+(?P<encoded>\d+)\b", re.MULTILINE)
13+
EncodedCudaVersionT = TypeVar("EncodedCudaVersionT", bound="EncodedCudaVersion")
14+
15+
16+
@dataclass(frozen=True, slots=True)
17+
class EncodedCudaVersion:
18+
"""CUDA major/minor version represented in CUDA's integer ``encoded`` form."""
19+
20+
encoded: int
21+
major: int
22+
minor: int
23+
24+
@classmethod
25+
def from_encoded(cls: type[EncodedCudaVersionT], encoded: int | str) -> EncodedCudaVersionT:
26+
if isinstance(encoded, str):
27+
try:
28+
encoded_int = int(encoded)
29+
except ValueError as exc:
30+
raise ValueError(
31+
f"{cls.__name__}.from_encoded() expected an integer or decimal string, got {encoded!r}."
32+
) from exc
33+
elif isinstance(encoded, int):
34+
encoded_int = encoded
35+
else:
36+
raise TypeError(
37+
f"{cls.__name__}.from_encoded() expected an integer or decimal string, got {type(encoded).__name__}."
38+
)
39+
if encoded_int < 0:
40+
raise ValueError(
41+
f"{cls.__name__}.from_encoded() expected a non-negative encoded CUDA version, got {encoded_int}."
42+
)
43+
# CUDA encodes versions as major * 1000 + minor * 10. The least-significant
44+
# decimal is ignored here: it is 0 in all CUDA releases and is not a patch version.
45+
return cls(
46+
encoded=encoded_int,
47+
major=encoded_int // 1000,
48+
minor=(encoded_int % 1000) // 10,
49+
)
1250

1351

1452
class ReadCudaHeaderVersionError(RuntimeError):
1553
"""Raised when ``read_cuda_header_version()`` cannot determine the CTK version from ``cuda.h``."""
1654

1755

1856
@dataclass(frozen=True, slots=True)
19-
class CudaToolkitVersion:
57+
class CudaToolkitVersion(EncodedCudaVersion):
2058
"""CUDA Toolkit version encoded by the ``CUDA_VERSION`` macro in ``cuda.h``."""
2159

22-
encoded: int
23-
major: int
24-
minor: int
25-
2660

2761
def parse_cuda_header_version(header_text: str) -> CudaToolkitVersion | None:
2862
"""Parse the CUDA Toolkit major/minor version from ``cuda.h`` text."""
2963
match = _CUDA_VERSION_RE.search(header_text)
3064
if match is None:
3165
return None
32-
encoded = int(match.group("encoded"))
33-
return CudaToolkitVersion(
34-
encoded=encoded,
35-
major=encoded // 1000,
36-
minor=(encoded % 1000) // 10,
37-
)
66+
return CudaToolkitVersion.from_encoded(match.group("encoded"))
3867

3968

4069
@functools.cache

cuda_pathfinder/tests/test_utils_driver_info.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,17 @@ def test_query_driver_cuda_version_returns_parsed_dataclass(monkeypatch):
7373
)
7474

7575

76+
def test_driver_cuda_version_from_encoded_returns_subclass_instance():
77+
version = driver_info.DriverCudaVersion.from_encoded(12080)
78+
79+
assert version == driver_info.DriverCudaVersion(
80+
encoded=12080,
81+
major=12,
82+
minor=8,
83+
)
84+
assert type(version) is driver_info.DriverCudaVersion
85+
86+
7687
def test_query_driver_cuda_version_wraps_internal_failures(monkeypatch):
7788
root_cause = RuntimeError("low-level query failed")
7889

cuda_pathfinder/tests/test_utils_toolkit_info.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,39 @@ def _clear_cuda_header_version_cache():
1313
toolkit_info.read_cuda_header_version.cache_clear()
1414

1515

16+
def test_encoded_cuda_version_from_encoded_decodes_major_minor():
17+
assert toolkit_info.EncodedCudaVersion.from_encoded(13020) == toolkit_info.EncodedCudaVersion(
18+
encoded=13020,
19+
major=13,
20+
minor=2,
21+
)
22+
23+
24+
def test_encoded_cuda_version_from_encoded_accepts_decimal_string():
25+
assert toolkit_info.EncodedCudaVersion.from_encoded("13020") == toolkit_info.EncodedCudaVersion(
26+
encoded=13020,
27+
major=13,
28+
minor=2,
29+
)
30+
31+
32+
def test_encoded_cuda_version_from_encoded_raises_helpful_error_for_invalid_string():
33+
with pytest.raises(
34+
ValueError,
35+
match=r"EncodedCudaVersion\.from_encoded\(\) expected an integer or decimal string, got '13\.2'",
36+
):
37+
toolkit_info.EncodedCudaVersion.from_encoded("13.2")
38+
39+
40+
@pytest.mark.parametrize("encoded", [-1, "-1"])
41+
def test_encoded_cuda_version_from_encoded_rejects_negative_values(encoded):
42+
with pytest.raises(
43+
ValueError,
44+
match=r"EncodedCudaVersion\.from_encoded\(\) expected a non-negative encoded CUDA version, got -1",
45+
):
46+
toolkit_info.EncodedCudaVersion.from_encoded(encoded)
47+
48+
1649
def test_parse_cuda_header_version_returns_parsed_dataclass():
1750
header_text = """
1851
#ifndef CUDA_H
@@ -28,6 +61,17 @@ def test_parse_cuda_header_version_returns_parsed_dataclass():
2861
)
2962

3063

64+
def test_cuda_toolkit_version_from_encoded_returns_subclass_instance():
65+
version = toolkit_info.CudaToolkitVersion.from_encoded(12090)
66+
67+
assert version == toolkit_info.CudaToolkitVersion(
68+
encoded=12090,
69+
major=12,
70+
minor=9,
71+
)
72+
assert type(version) is toolkit_info.CudaToolkitVersion
73+
74+
3175
def test_parse_cuda_header_version_returns_none_when_macro_is_missing():
3276
header_text = """
3377
#ifndef CUDA_H

0 commit comments

Comments
 (0)