Skip to content

Commit 772451b

Browse files
committed
Finalize the pathfinder CUDA driver version query API.
Expose DriverCudaVersion, QueryDriverCudaVersionError, and query_driver_cuda_version publicly, and align the internal naming, caching, docs, and test coverage around the CUDA-specific driver version query. Made-with: Cursor
1 parent 1369c17 commit 772451b

File tree

5 files changed

+60
-18
lines changed

5 files changed

+60
-18
lines changed

cuda_pathfinder/cuda/pathfinder/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@
5959
from cuda.pathfinder._static_libs.find_static_lib import (
6060
locate_static_lib as locate_static_lib,
6161
)
62+
from cuda.pathfinder._utils.driver_info import DriverCudaVersion as DriverCudaVersion
63+
from cuda.pathfinder._utils.driver_info import QueryDriverCudaVersionError as QueryDriverCudaVersionError
64+
from cuda.pathfinder._utils.driver_info import query_driver_cuda_version as query_driver_cuda_version
6265
from cuda.pathfinder._utils.env_vars import get_cuda_path_or_home as get_cuda_path_or_home
6366

6467
from cuda.pathfinder._version import __version__ # isort: skip

cuda_pathfinder/cuda/pathfinder/_utils/driver_info.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from __future__ import annotations
55

66
import ctypes
7+
import functools
78
from collections.abc import Callable
89
from dataclasses import dataclass
910

@@ -13,6 +14,10 @@
1314
from cuda.pathfinder._utils.platform_aware import IS_WINDOWS
1415

1516

17+
class QueryDriverCudaVersionError(RuntimeError):
18+
"""Raised when ``query_driver_cuda_version()`` cannot determine the CUDA driver version."""
19+
20+
1621
@dataclass(frozen=True, slots=True)
1722
class DriverCudaVersion:
1823
"""
@@ -38,17 +43,21 @@ class DriverCudaVersion:
3843
minor: int
3944

4045

41-
def query_driver_version() -> DriverCudaVersion:
46+
@functools.cache
47+
def query_driver_cuda_version() -> DriverCudaVersion:
4248
"""Return the CUDA driver version parsed into its major/minor components."""
43-
encoded = _query_driver_version_int()
44-
return DriverCudaVersion(
45-
encoded=encoded,
46-
major=encoded // 1000,
47-
minor=(encoded % 1000) // 10,
48-
)
49+
try:
50+
encoded = _query_driver_cuda_version_int()
51+
return DriverCudaVersion(
52+
encoded=encoded,
53+
major=encoded // 1000,
54+
minor=(encoded % 1000) // 10,
55+
)
56+
except Exception as exc:
57+
raise QueryDriverCudaVersionError("Failed to query the CUDA driver version.") from exc
4958

5059

51-
def _query_driver_version_int() -> int:
60+
def _query_driver_cuda_version_int() -> int:
5261
"""Return the encoded CUDA driver version from ``cuDriverGetVersion()``."""
5362
loaded_cuda = _load_nvidia_dynamic_lib("cuda")
5463
if IS_WINDOWS:

cuda_pathfinder/docs/source/api.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ CUDA bitcode and static libraries.
1818

1919
get_cuda_path_or_home
2020

21+
DriverCudaVersion
22+
QueryDriverCudaVersionError
23+
query_driver_cuda_version
24+
2125
SUPPORTED_NVIDIA_LIBNAMES
2226
load_nvidia_dynamic_lib
2327
LoadedDL

cuda_pathfinder/tests/test_driver_lib_loading.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,17 +160,19 @@ def raise_child_process_failed():
160160
assert os.path.isfile(abs_path)
161161

162162

163-
def test_real_query_driver_version(info_summary_append):
163+
def test_real_query_driver_cuda_version(info_summary_append):
164164
driver_info._load_nvidia_dynamic_lib.cache_clear()
165+
driver_info.query_driver_cuda_version.cache_clear()
165166
try:
166-
version = driver_info.query_driver_version()
167-
except Exception as exc:
167+
version = driver_info.query_driver_cuda_version()
168+
except driver_info.QueryDriverCudaVersionError as exc:
168169
if STRICTNESS == "all_must_work":
169170
raise
170171
info_summary_append(f"driver version unavailable: {exc.__class__.__name__}: {exc}")
171172
return
172173
finally:
173174
driver_info._load_nvidia_dynamic_lib.cache_clear()
175+
driver_info.query_driver_cuda_version.cache_clear()
174176

175177
info_summary_append(f"driver_version={version.major}.{version.minor} (encoded={version.encoded})")
176178
assert version.encoded > 0

cuda_pathfinder/tests/test_utils_driver_info.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@
99
from cuda.pathfinder._utils import driver_info
1010

1111

12+
@pytest.fixture(autouse=True)
13+
def _clear_driver_cuda_version_query_cache():
14+
driver_info.query_driver_cuda_version.cache_clear()
15+
yield
16+
driver_info.query_driver_cuda_version.cache_clear()
17+
18+
1219
class _FakeCuDriverGetVersion:
1320
def __init__(self, *, status: int, version: int):
1421
self.argtypes = None
@@ -35,7 +42,7 @@ def _loaded_cuda(abs_path: str) -> LoadedDL:
3542
)
3643

3744

38-
def test_query_driver_version_uses_windll_on_windows(monkeypatch):
45+
def test_query_driver_cuda_version_uses_windll_on_windows(monkeypatch):
3946
fake_driver_lib = _FakeDriverLib(status=0, version=12080)
4047
loaded_paths: list[str] = []
4148

@@ -52,26 +59,43 @@ def fake_windll(abs_path: str):
5259

5360
monkeypatch.setattr(driver_info.ctypes, "WinDLL", fake_windll, raising=False)
5461

55-
assert driver_info._query_driver_version_int() == 12080
62+
assert driver_info._query_driver_cuda_version_int() == 12080
5663
assert loaded_paths == [r"C:\Windows\System32\nvcuda.dll"]
5764

5865

59-
def test_query_driver_version_returns_parsed_dataclass(monkeypatch):
60-
monkeypatch.setattr(driver_info, "_query_driver_version_int", lambda: 12080)
66+
def test_query_driver_cuda_version_returns_parsed_dataclass(monkeypatch):
67+
monkeypatch.setattr(driver_info, "_query_driver_cuda_version_int", lambda: 12080)
6168

62-
assert driver_info.query_driver_version() == driver_info.DriverCudaVersion(
69+
assert driver_info.query_driver_cuda_version() == driver_info.DriverCudaVersion(
6370
encoded=12080,
6471
major=12,
6572
minor=8,
6673
)
6774

6875

69-
def test_query_driver_version_int_raises_when_cuda_call_fails(monkeypatch):
76+
def test_query_driver_cuda_version_wraps_internal_failures(monkeypatch):
77+
root_cause = RuntimeError("low-level query failed")
78+
79+
def fail_query_driver_cuda_version_int() -> int:
80+
raise root_cause
81+
82+
monkeypatch.setattr(driver_info, "_query_driver_cuda_version_int", fail_query_driver_cuda_version_int)
83+
84+
with pytest.raises(
85+
driver_info.QueryDriverCudaVersionError,
86+
match="Failed to query the CUDA driver version",
87+
) as exc_info:
88+
driver_info.query_driver_cuda_version()
89+
90+
assert exc_info.value.__cause__ is root_cause
91+
92+
93+
def test_query_driver_cuda_version_int_raises_when_cuda_call_fails(monkeypatch):
7094
fake_driver_lib = _FakeDriverLib(status=1, version=0)
7195

7296
monkeypatch.setattr(driver_info, "IS_WINDOWS", False)
7397
monkeypatch.setattr(driver_info, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_cuda("/usr/lib/libcuda.so.1"))
7498
monkeypatch.setattr(driver_info.ctypes, "CDLL", lambda _abs_path: fake_driver_lib)
7599

76100
with pytest.raises(RuntimeError, match=r"cuDriverGetVersion\(\) \(status=1\)"):
77-
driver_info._query_driver_version_int()
101+
driver_info._query_driver_cuda_version_int()

0 commit comments

Comments
 (0)