Skip to content

Commit 0c3803e

Browse files
committed
Merge branch 'pathfinder_driver_info' into pathfinder_with_compatibility_checks_v0.
Bring the shared CUDA driver version query helper, public exports, docs, and tests into the compatibility-check branch before wiring the wrapper to that API. Made-with: Cursor
2 parents 9961248 + 44e3ba1 commit 0c3803e

5 files changed

Lines changed: 221 additions & 0 deletions

File tree

cuda_pathfinder/cuda/pathfinder/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@
6868
from cuda.pathfinder._static_libs.find_static_lib import (
6969
locate_static_lib as locate_static_lib,
7070
)
71+
from cuda.pathfinder._utils.driver_info import DriverCudaVersion as DriverCudaVersion
72+
from cuda.pathfinder._utils.driver_info import QueryDriverCudaVersionError as QueryDriverCudaVersionError
73+
from cuda.pathfinder._utils.driver_info import query_driver_cuda_version as query_driver_cuda_version
7174
from cuda.pathfinder._utils.env_vars import get_cuda_path_or_home as get_cuda_path_or_home
7275

7376
from cuda.pathfinder._version import __version__ # isort: skip
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
import ctypes
7+
import functools
8+
from collections.abc import Callable
9+
from dataclasses import dataclass
10+
11+
from cuda.pathfinder._dynamic_libs.load_nvidia_dynamic_lib import (
12+
load_nvidia_dynamic_lib as _load_nvidia_dynamic_lib,
13+
)
14+
from cuda.pathfinder._utils.platform_aware import IS_WINDOWS
15+
16+
17+
class QueryDriverCudaVersionError(RuntimeError):
18+
"""Raised when ``query_driver_cuda_version()`` cannot determine the CUDA driver version."""
19+
20+
21+
@dataclass(frozen=True, slots=True)
22+
class DriverCudaVersion:
23+
"""
24+
CUDA-facing driver version reported by ``cuDriverGetVersion()``.
25+
26+
The name ``DriverCudaVersion`` is intentionally specific: this dataclass
27+
models the version shown as ``CUDA Version`` in ``nvidia-smi``, not the
28+
graphics driver release shown as ``Driver Version``.
29+
30+
Example ``nvidia-smi`` output::
31+
32+
+---------------------------------------------------------------------+
33+
| NVIDIA-SMI 595.58.03 Driver Version: 595.58.03 CUDA Version: 13.2 |
34+
+---------------------------------------------------------------------+
35+
36+
For the example above, ``DriverCudaVersion(encoded=13020, major=13,
37+
minor=2)`` corresponds to ``CUDA Version: 13.2``. It does not correspond
38+
to ``Driver Version: 595.58.03``.
39+
"""
40+
41+
encoded: int
42+
major: int
43+
minor: int
44+
45+
46+
@functools.cache
47+
def query_driver_cuda_version() -> DriverCudaVersion:
48+
"""Return the CUDA driver version parsed into its major/minor components."""
49+
try:
50+
encoded = _query_driver_cuda_version_int()
51+
return DriverCudaVersion(
52+
encoded=encoded,
53+
major=encoded // 1000,
54+
minor=(encoded % 1000) // 10,
55+
)
56+
except Exception as exc:
57+
raise QueryDriverCudaVersionError("Failed to query the CUDA driver version.") from exc
58+
59+
60+
def _query_driver_cuda_version_int() -> int:
61+
"""Return the encoded CUDA driver version from ``cuDriverGetVersion()``."""
62+
loaded_cuda = _load_nvidia_dynamic_lib("cuda")
63+
if IS_WINDOWS:
64+
# `ctypes.WinDLL` exists on Windows at runtime. The ignore is only for
65+
# Linux mypy runs, where the platform stubs do not define that attribute.
66+
loader_cls: Callable[[str], ctypes.CDLL] = ctypes.WinDLL # type: ignore[attr-defined]
67+
else:
68+
loader_cls = ctypes.CDLL
69+
driver_lib = loader_cls(loaded_cuda.abs_path)
70+
cu_driver_get_version = driver_lib.cuDriverGetVersion
71+
cu_driver_get_version.argtypes = [ctypes.POINTER(ctypes.c_int)]
72+
cu_driver_get_version.restype = ctypes.c_int
73+
version = ctypes.c_int()
74+
status = cu_driver_get_version(ctypes.byref(version))
75+
if status != 0:
76+
raise RuntimeError(f"Failed to query CUDA driver version via cuDriverGetVersion() (status={status}).")
77+
return version.value

cuda_pathfinder/docs/source/api.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ CUDA bitcode and static libraries.
2222
CompatibilityCheckError
2323
CompatibilityInsufficientMetadataError
2424

25+
DriverCudaVersion
26+
QueryDriverCudaVersionError
27+
query_driver_cuda_version
28+
2529
SUPPORTED_NVIDIA_LIBNAMES
2630
load_nvidia_dynamic_lib
2731
LoadedDL

cuda_pathfinder/tests/test_driver_lib_loading.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
_load_lib_no_cache,
2626
)
2727
from cuda.pathfinder._dynamic_libs.subprocess_protocol import STATUS_NOT_FOUND, parse_dynamic_lib_subprocess_payload
28+
from cuda.pathfinder._utils import driver_info
2829
from cuda.pathfinder._utils.platform_aware import IS_WINDOWS, quote_for_shell
2930

3031
STRICTNESS = os.environ.get("CUDA_PATHFINDER_TEST_LOAD_NVIDIA_DYNAMIC_LIB_STRICTNESS", "see_what_works")
@@ -157,3 +158,23 @@ def raise_child_process_failed():
157158
assert abs_path is not None
158159
info_summary_append(f"abs_path={quote_for_shell(abs_path)}")
159160
assert os.path.isfile(abs_path)
161+
162+
163+
def test_real_query_driver_cuda_version(info_summary_append):
164+
driver_info._load_nvidia_dynamic_lib.cache_clear()
165+
driver_info.query_driver_cuda_version.cache_clear()
166+
try:
167+
version = driver_info.query_driver_cuda_version()
168+
except driver_info.QueryDriverCudaVersionError as exc:
169+
if STRICTNESS == "all_must_work":
170+
raise
171+
info_summary_append(f"driver version unavailable: {exc.__class__.__name__}: {exc}")
172+
return
173+
finally:
174+
driver_info._load_nvidia_dynamic_lib.cache_clear()
175+
driver_info.query_driver_cuda_version.cache_clear()
176+
177+
info_summary_append(f"driver_version={version.major}.{version.minor} (encoded={version.encoded})")
178+
assert version.encoded > 0
179+
assert version.major == version.encoded // 1000
180+
assert version.minor == (version.encoded % 1000) // 10
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import ctypes
5+
6+
import pytest
7+
8+
from cuda.pathfinder import (
9+
DriverCudaVersion as PublicDriverCudaVersion,
10+
)
11+
from cuda.pathfinder import (
12+
QueryDriverCudaVersionError as PublicQueryDriverCudaVersionError,
13+
)
14+
from cuda.pathfinder import (
15+
query_driver_cuda_version as public_query_driver_cuda_version,
16+
)
17+
from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL
18+
from cuda.pathfinder._utils import driver_info
19+
20+
21+
@pytest.fixture(autouse=True)
22+
def _clear_driver_cuda_version_query_cache():
23+
driver_info.query_driver_cuda_version.cache_clear()
24+
yield
25+
driver_info.query_driver_cuda_version.cache_clear()
26+
27+
28+
class _FakeCuDriverGetVersion:
29+
def __init__(self, *, status: int, version: int):
30+
self.argtypes = None
31+
self.restype = None
32+
self._status = status
33+
self._version = version
34+
35+
def __call__(self, version_ptr) -> int:
36+
ctypes.cast(version_ptr, ctypes.POINTER(ctypes.c_int)).contents.value = self._version
37+
return self._status
38+
39+
40+
class _FakeDriverLib:
41+
def __init__(self, *, status: int, version: int):
42+
self.cuDriverGetVersion = _FakeCuDriverGetVersion(status=status, version=version)
43+
44+
45+
def _loaded_cuda(abs_path: str) -> LoadedDL:
46+
return LoadedDL(
47+
abs_path=abs_path,
48+
was_already_loaded_from_elsewhere=False,
49+
_handle_uint=0xBEEF,
50+
found_via="system-search",
51+
)
52+
53+
54+
def test_driver_cuda_version_public_api_exports():
55+
assert PublicDriverCudaVersion is driver_info.DriverCudaVersion
56+
assert PublicQueryDriverCudaVersionError is driver_info.QueryDriverCudaVersionError
57+
assert public_query_driver_cuda_version is driver_info.query_driver_cuda_version
58+
59+
60+
def test_query_driver_cuda_version_uses_windll_on_windows(monkeypatch):
61+
fake_driver_lib = _FakeDriverLib(status=0, version=12080)
62+
loaded_paths: list[str] = []
63+
64+
monkeypatch.setattr(driver_info, "IS_WINDOWS", True)
65+
monkeypatch.setattr(
66+
driver_info,
67+
"_load_nvidia_dynamic_lib",
68+
lambda _libname: _loaded_cuda(r"C:\Windows\System32\nvcuda.dll"),
69+
)
70+
71+
def fake_windll(abs_path: str):
72+
loaded_paths.append(abs_path)
73+
return fake_driver_lib
74+
75+
monkeypatch.setattr(driver_info.ctypes, "WinDLL", fake_windll, raising=False)
76+
77+
assert driver_info._query_driver_cuda_version_int() == 12080
78+
assert loaded_paths == [r"C:\Windows\System32\nvcuda.dll"]
79+
80+
81+
def test_query_driver_cuda_version_returns_parsed_dataclass(monkeypatch):
82+
monkeypatch.setattr(driver_info, "_query_driver_cuda_version_int", lambda: 12080)
83+
84+
assert driver_info.query_driver_cuda_version() == driver_info.DriverCudaVersion(
85+
encoded=12080,
86+
major=12,
87+
minor=8,
88+
)
89+
90+
91+
def test_query_driver_cuda_version_wraps_internal_failures(monkeypatch):
92+
root_cause = RuntimeError("low-level query failed")
93+
94+
def fail_query_driver_cuda_version_int() -> int:
95+
raise root_cause
96+
97+
monkeypatch.setattr(driver_info, "_query_driver_cuda_version_int", fail_query_driver_cuda_version_int)
98+
99+
with pytest.raises(
100+
driver_info.QueryDriverCudaVersionError,
101+
match="Failed to query the CUDA driver version",
102+
) as exc_info:
103+
driver_info.query_driver_cuda_version()
104+
105+
assert exc_info.value.__cause__ is root_cause
106+
107+
108+
def test_query_driver_cuda_version_int_raises_when_cuda_call_fails(monkeypatch):
109+
fake_driver_lib = _FakeDriverLib(status=1, version=0)
110+
111+
monkeypatch.setattr(driver_info, "IS_WINDOWS", False)
112+
monkeypatch.setattr(driver_info, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_cuda("/usr/lib/libcuda.so.1"))
113+
monkeypatch.setattr(driver_info.ctypes, "CDLL", lambda _abs_path: fake_driver_lib)
114+
115+
with pytest.raises(RuntimeError, match=r"cuDriverGetVersion\(\) \(status=1\)"):
116+
driver_info._query_driver_cuda_version_int()

0 commit comments

Comments
 (0)