Skip to content

Commit 10ac5b5

Browse files
rwgk and cursoragent committed
pathfinder: add display-driver release metadata and NVML query helper
Introduce ``DriverReleaseVersion`` and ``query_driver_release_version()`` in ``cuda.pathfinder._utils.driver_info`` so callers can read the display-driver release shown as ``Driver Version`` in ``nvidia-smi`` (for example ``595.58.03``, branch ``595``). The helper queries ``nvmlSystemGetDriverVersion()`` via ``ctypes`` against the NVML library loaded through pathfinder's existing dynamic-lib loader, and parses the result into a frozen dataclass exposing both the raw text and the branch number used by NVIDIA's published minor-version compatibility tables. Tests cover the dataclass parser, the cache-clear lifecycle, the ``QueryDriverReleaseVersionError`` wrapping, and a fake-NVML end-to-end path that asserts the shutdown call always runs. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent e2da47a commit 10ac5b5

2 files changed

Lines changed: 213 additions & 0 deletions

File tree

cuda_pathfinder/cuda/pathfinder/_utils/driver_info.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import ctypes
77
import functools
8+
import re
89
from collections.abc import Callable
910
from dataclasses import dataclass
1011
from typing import cast
@@ -15,11 +16,19 @@
1516
from cuda.pathfinder._utils.platform_aware import IS_WINDOWS
1617
from cuda.pathfinder._utils.toolkit_info import EncodedCudaVersion
1718

19+
# Status value that the NVML calls in this module are compared against to detect success.
_NVML_SUCCESS = 0
20+
# Size of the string buffer allocated for, and passed to, nvmlSystemGetDriverVersion().
_NVML_SYSTEM_DRIVER_VERSION_BUFFER_LENGTH = 80
21+
# Accepts two or three dot-separated groups of digits, e.g. "595.58" or "595.58.03".
_DRIVER_RELEASE_VERSION_RE = re.compile(r"^\d+(?:\.\d+){1,2}$")
22+
1823

1924
class QueryDriverCudaVersionError(RuntimeError):
2025
"""Raised when ``query_driver_cuda_version()`` cannot determine the CUDA driver version."""
2126

2227

28+
class QueryDriverReleaseVersionError(RuntimeError):
    """Signals that ``query_driver_release_version()`` failed to obtain the display-driver release version."""
30+
31+
2332
@dataclass(frozen=True, slots=True)
2433
class DriverCudaVersion(EncodedCudaVersion):
2534
"""
@@ -44,6 +53,37 @@ class DriverCudaVersion(EncodedCudaVersion):
4453
"""
4554

4655

56+
@dataclass(frozen=True, slots=True)
57+
class DriverReleaseVersion:
58+
"""
59+
Display-driver release version shown as ``Driver Version`` in ``nvidia-smi``.
60+
61+
Example ``nvidia-smi`` output::
62+
63+
+---------------------------------------------------------------------+
64+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
65+
+---------------------------------------------------------------------+
66+
67+
For the example above, ``DriverReleaseVersion(text="595.58.03",
68+
components=(595, 58, 3), branch=595)`` corresponds to ``Driver Version:
69+
595.58.03``. The ``branch`` field is the first numeric component because
70+
NVIDIA's compatibility docs publish minimum display-driver requirements in
71+
branch form such as ``>= 580`` for CUDA 13.x minor-version compatibility.
72+
"""
73+
74+
text: str
75+
components: tuple[int, ...]
76+
branch: int
77+
78+
@classmethod
79+
def from_text(cls, text: str) -> DriverReleaseVersion:
80+
normalized_text = text.strip()
81+
if not _DRIVER_RELEASE_VERSION_RE.fullmatch(normalized_text):
82+
raise ValueError(f"Invalid driver release version text: {text!r}")
83+
components = tuple(int(component) for component in normalized_text.split("."))
84+
return cls(text=normalized_text, components=components, branch=components[0])
85+
86+
4787
@functools.cache
4888
def query_driver_cuda_version() -> DriverCudaVersion:
4989
"""Return the CUDA driver version parsed into its major/minor components."""
@@ -54,6 +94,15 @@ def query_driver_cuda_version() -> DriverCudaVersion:
5494
raise QueryDriverCudaVersionError("Failed to query the CUDA driver version.") from exc
5595

5696

97+
@functools.cache
def query_driver_release_version() -> DriverReleaseVersion:
    """
    Query the display-driver release version and return it in parsed form.

    The result is memoized by ``functools.cache``. Any ``Exception`` raised
    while querying or parsing is re-raised as ``QueryDriverReleaseVersionError``
    with the original exception attached as its cause.
    """
    try:
        release_text = _query_driver_release_version_text()
        return DriverReleaseVersion.from_text(release_text)
    except Exception as exc:
        raise QueryDriverReleaseVersionError("Failed to query the display-driver release version.") from exc
104+
105+
57106
def _query_driver_cuda_version_int() -> int:
58107
"""Return the encoded CUDA driver version from ``cuDriverGetVersion()``."""
59108
loaded_cuda = _load_nvidia_dynamic_lib("cuda")
@@ -72,3 +121,44 @@ def _query_driver_cuda_version_int() -> int:
72121
if status != 0:
73122
raise RuntimeError(f"Failed to query CUDA driver version via cuDriverGetVersion() (status={status}).")
74123
return version.value
124+
125+
126+
def _query_driver_release_version_text() -> str:
    """
    Read the display-driver release string through ``nvmlSystemGetDriverVersion()``.

    The NVML library located by ``_load_nvidia_dynamic_lib("nvml")`` is opened
    with ``ctypes.CDLL``, initialized with ``nvmlInit_v2()``, queried, and then
    shut down with ``nvmlShutdown()``. Once initialization has succeeded the
    shutdown call is made on both the success and the failure path.

    Raises ``RuntimeError`` when initialization, the query, or the shutdown
    reports a status other than ``_NVML_SUCCESS``.
    """
    lib = ctypes.CDLL(_load_nvidia_dynamic_lib("nvml").abs_path)

    # Declare the ctypes signatures for the three NVML entry points used below.
    init_fn = lib.nvmlInit_v2
    init_fn.argtypes = []
    init_fn.restype = ctypes.c_int

    get_version_fn = lib.nvmlSystemGetDriverVersion
    get_version_fn.argtypes = [ctypes.POINTER(ctypes.c_char), ctypes.c_uint]
    get_version_fn.restype = ctypes.c_int

    shutdown_fn = lib.nvmlShutdown
    shutdown_fn.argtypes = []
    shutdown_fn.restype = ctypes.c_int

    def shutdown_failure():
        # Call nvmlShutdown(); hand back the error to raise, or None when it succeeded.
        rc = shutdown_fn()
        if rc == _NVML_SUCCESS:
            return None
        return RuntimeError(f"Failed to shut down NVML via nvmlShutdown() (status={rc}).")

    init_rc = init_fn()
    if init_rc != _NVML_SUCCESS:
        raise RuntimeError(f"Failed to initialize NVML via nvmlInit_v2() (status={init_rc}).")

    try:
        buffer = ctypes.create_string_buffer(_NVML_SYSTEM_DRIVER_VERSION_BUFFER_LENGTH)
        query_rc = get_version_fn(buffer, _NVML_SYSTEM_DRIVER_VERSION_BUFFER_LENGTH)
        if query_rc != _NVML_SUCCESS:
            raise RuntimeError(
                f"Failed to query driver release version via nvmlSystemGetDriverVersion() (status={query_rc})."
            )
        release_text = buffer.value.decode()
    except BaseException as exc:
        # Still shut NVML down; a shutdown failure is raised with the query failure as its cause.
        failure = shutdown_failure()
        if failure is not None:
            raise failure from exc
        raise

    failure = shutdown_failure()
    if failure is not None:
        raise failure
    return release_text

cuda_pathfinder/tests/test_utils_driver_info.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
@pytest.fixture(autouse=True)
def _clear_driver_cuda_version_query_cache():
    """Reset the memoized driver-version queries before and after every test."""

    def reset_caches():
        # Attributes are looked up on each call, exactly as the inline statements would.
        driver_info.query_driver_cuda_version.cache_clear()
        driver_info.query_driver_release_version.cache_clear()

    reset_caches()
    yield
    reset_caches()
1719

1820

1921
class _FakeCuDriverGetVersion:
@@ -33,6 +35,47 @@ def __init__(self, *, status: int, version: int):
3335
self.cuDriverGetVersion = _FakeCuDriverGetVersion(status=status, version=version)
3436

3537

38+
class _FakeNvmlFunction:
39+
def __init__(self, func):
40+
self.argtypes = None
41+
self.restype = None
42+
self._func = func
43+
44+
def __call__(self, *args):
45+
return self._func(*args)
46+
47+
48+
class _FakeNvmlLib:
    """
    Test double exposing ``nvmlInit_v2``, ``nvmlSystemGetDriverVersion`` and ``nvmlShutdown``.

    ``shutdown_calls`` counts how often ``nvmlShutdown`` was invoked. Successive
    shutdown calls return the entries of ``shutdown_statuses`` in order, then ``0``
    once those are used up.
    """

    def __init__(
        self,
        *,
        init_status: int = 0,
        driver_release_version: str = "595.58.03",
        query_status: int = 0,
        shutdown_statuses: tuple[int, ...] = (0,),
    ):
        self.shutdown_calls = 0
        # Snapshot the statuses so each shutdown call consumes the next one.
        pending_statuses = iter(list(shutdown_statuses))

        def fake_init() -> int:
            return init_status

        self.nvmlInit_v2 = _FakeNvmlFunction(fake_init)

        def fake_get_driver_version(version_buffer, _buffer_length) -> int:
            # The buffer is only written when the configured query status is success.
            if query_status == 0:
                version_buffer.value = driver_release_version.encode()
                return 0
            return query_status

        self.nvmlSystemGetDriverVersion = _FakeNvmlFunction(fake_get_driver_version)

        def fake_shutdown() -> int:
            self.shutdown_calls += 1
            return next(pending_statuses, 0)

        self.nvmlShutdown = _FakeNvmlFunction(fake_shutdown)
77+
78+
3679
def _loaded_cuda(abs_path: str) -> LoadedDL:
3780
return LoadedDL(
3881
abs_path=abs_path,
@@ -42,6 +85,86 @@ def _loaded_cuda(abs_path: str) -> LoadedDL:
4285
)
4386

4487

88+
def _loaded_nvml(abs_path: str) -> LoadedDL:
    """Build the ``LoadedDL`` record the NVML tests return from the patched library loader."""
    loaded = LoadedDL(
        abs_path=abs_path,
        was_already_loaded_from_elsewhere=False,
        _handle_uint=0xCAFE,
        found_via="system-search",
    )
    return loaded
95+
96+
97+
def test_driver_release_version_from_text_parses_branch():
    """``from_text`` turns ``595.58.03`` into integer components with the first one as the branch."""
    parsed = driver_info.DriverReleaseVersion.from_text("595.58.03")
    expected = driver_info.DriverReleaseVersion(
        text="595.58.03",
        components=(595, 58, 3),
        branch=595,
    )
    assert parsed == expected
103+
104+
105+
def test_query_driver_release_version_returns_parsed_dataclass(monkeypatch):
    """The public query parses whatever text the low-level helper returns."""

    def fake_release_text() -> str:
        return "595.58.03"

    monkeypatch.setattr(driver_info, "_query_driver_release_version_text", fake_release_text)

    result = driver_info.query_driver_release_version()
    expected = driver_info.DriverReleaseVersion(
        text="595.58.03",
        components=(595, 58, 3),
        branch=595,
    )
    assert result == expected
113+
114+
115+
def test_query_driver_release_version_wraps_internal_failures(monkeypatch):
    """A failure in the low-level helper surfaces as ``QueryDriverReleaseVersionError`` with the cause kept."""
    original_error = RuntimeError("low-level release query failed")

    def raising_release_text() -> str:
        raise original_error

    monkeypatch.setattr(driver_info, "_query_driver_release_version_text", raising_release_text)

    expected_exception = driver_info.QueryDriverReleaseVersionError
    expected_message = "Failed to query the display-driver release version"
    with pytest.raises(expected_exception, match=expected_message) as captured:
        driver_info.query_driver_release_version()

    # The wrapper must chain the exact root-cause object, not a copy.
    assert captured.value.__cause__ is original_error
130+
131+
132+
def test_query_driver_release_version_text_uses_nvml(monkeypatch):
    """The low-level helper opens the located NVML path, returns its version text, and shuts down once."""
    fake_lib = _FakeNvmlLib(driver_release_version="595.58.03")
    opened_paths: list[str] = []

    def fake_load(_libname):
        return _loaded_nvml("/usr/lib/libnvidia-ml.so.1")

    def fake_cdll(abs_path: str):
        # Record which path the code under test hands to ctypes.CDLL.
        opened_paths.append(abs_path)
        return fake_lib

    monkeypatch.setattr(driver_info, "_load_nvidia_dynamic_lib", fake_load)
    monkeypatch.setattr(driver_info.ctypes, "CDLL", fake_cdll)

    assert driver_info._query_driver_release_version_text() == "595.58.03"
    assert opened_paths == ["/usr/lib/libnvidia-ml.so.1"]
    assert fake_lib.shutdown_calls == 1
151+
152+
153+
def test_query_driver_release_version_text_raises_when_nvml_call_fails(monkeypatch):
    """A non-success query status raises ``RuntimeError`` and NVML is still shut down once."""
    fake_lib = _FakeNvmlLib(query_status=1)

    def fake_load(_libname):
        return _loaded_nvml("/usr/lib/libnvidia-ml.so.1")

    def fake_cdll(_abs_path):
        return fake_lib

    monkeypatch.setattr(driver_info, "_load_nvidia_dynamic_lib", fake_load)
    monkeypatch.setattr(driver_info.ctypes, "CDLL", fake_cdll)

    expected_message = r"nvmlSystemGetDriverVersion\(\) \(status=1\)"
    with pytest.raises(RuntimeError, match=expected_message):
        driver_info._query_driver_release_version_text()
    assert fake_lib.shutdown_calls == 1
166+
167+
45168
def test_query_driver_cuda_version_uses_windll_on_windows(monkeypatch):
46169
fake_driver_lib = _FakeDriverLib(status=0, version=12080)
47170
loaded_paths: list[str] = []

0 commit comments

Comments
 (0)