Skip to content

Commit 0eab1ab

Browse files
committed
Use the shared pathfinder driver query in compatibility checks.
Route `WithCompatibilityChecks` through `query_driver_cuda_version()` so the wrapper reuses the public driver-info helper and preserves a compatibility-layer error when the implicit driver query fails. Made-with: Cursor
1 parent 0c3803e commit 0eab1ab

2 files changed

Lines changed: 39 additions & 31 deletions

File tree

cuda_pathfinder/cuda/pathfinder/_compatibility_checks.py

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33

44
from __future__ import annotations
55

6-
import ctypes
76
import functools
87
import importlib.metadata
98
import json
109
import os
1110
import re
12-
from collections.abc import Callable, Mapping
11+
from collections.abc import Mapping
1312
from dataclasses import dataclass
1413
from pathlib import Path
1514
from typing import TypeAlias, cast
@@ -42,7 +41,10 @@
4241
from cuda.pathfinder._static_libs.find_static_lib import (
4342
locate_static_lib as _locate_static_lib,
4443
)
45-
from cuda.pathfinder._utils.platform_aware import IS_WINDOWS
44+
from cuda.pathfinder._utils.driver_info import (
45+
QueryDriverCudaVersionError,
46+
query_driver_cuda_version,
47+
)
4648

4749
ItemKind: TypeAlias = str
4850
PackagedWith: TypeAlias = str
@@ -429,30 +431,6 @@ def compatibility_check(driver_version: int, item1: ResolvedItem, item2: Resolve
429431
)
430432

431433

432-
def _query_driver_version() -> int:
433-
loaded_cuda = _load_nvidia_dynamic_lib("cuda")
434-
if loaded_cuda.abs_path is None:
435-
raise CompatibilityCheckError('Could not determine an absolute path for the driver library "cuda".')
436-
if IS_WINDOWS:
437-
loader_cls_obj = vars(ctypes).get("WinDLL")
438-
if loader_cls_obj is None:
439-
raise CompatibilityCheckError("ctypes.WinDLL is unavailable on this platform.")
440-
loader_cls = cast(Callable[[str], ctypes.CDLL], loader_cls_obj)
441-
else:
442-
loader_cls = ctypes.CDLL
443-
driver_lib = loader_cls(loaded_cuda.abs_path)
444-
cu_driver_get_version = driver_lib.cuDriverGetVersion
445-
cu_driver_get_version.argtypes = [ctypes.POINTER(ctypes.c_int)]
446-
cu_driver_get_version.restype = ctypes.c_int
447-
version = ctypes.c_int()
448-
status = cu_driver_get_version(ctypes.byref(version))
449-
if status != 0:
450-
raise CompatibilityCheckError(
451-
f"Failed to query CUDA driver version via cuDriverGetVersion() (status={status})."
452-
)
453-
return version.value
454-
455-
456434
class WithCompatibilityChecks:
457435
"""Resolve CUDA artifacts while enforcing minimal v1 compatibility guard rails."""
458436

@@ -470,7 +448,12 @@ def __init__(
470448

471449
def _get_driver_version(self) -> int:
472450
if self._driver_version is None:
473-
self._driver_version = _query_driver_version()
451+
try:
452+
self._driver_version = query_driver_cuda_version().encoded
453+
except QueryDriverCudaVersionError as exc:
454+
raise CompatibilityCheckError(
455+
"Failed to query the CUDA driver version needed for compatibility checks."
456+
) from exc
474457
return self._driver_version
475458

476459
def _enforce_supported_packaging(self, item: ResolvedItem) -> None:

cuda_pathfinder/tests/test_with_compatibility_checks.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
BitcodeLibNotFoundError,
1313
CompatibilityCheckError,
1414
CompatibilityInsufficientMetadataError,
15+
DriverCudaVersion,
1516
DynamicLibNotFoundError,
1617
LoadedDL,
1718
LocatedBitcodeLib,
1819
LocatedHeaderDir,
1920
LocatedStaticLib,
21+
QueryDriverCudaVersionError,
2022
StaticLibNotFoundError,
2123
WithCompatibilityChecks,
2224
)
@@ -252,11 +254,11 @@ def test_wrapper_queries_driver_version_by_default(monkeypatch, tmp_path):
252254

253255
monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path))
254256

255-
def fake_query_driver_version() -> int:
257+
def fake_query_driver_cuda_version() -> DriverCudaVersion:
256258
query_calls.append(1)
257-
return 13000
259+
return DriverCudaVersion(encoded=13000, major=13, minor=0)
258260

259-
monkeypatch.setattr(compatibility_module, "_query_driver_version", fake_query_driver_version)
261+
monkeypatch.setattr(compatibility_module, "query_driver_cuda_version", fake_query_driver_cuda_version)
260262

261263
pfchecks = WithCompatibilityChecks()
262264

@@ -266,6 +268,29 @@ def fake_query_driver_version() -> int:
266268
assert len(query_calls) == 1
267269

268270

271+
def test_wrapper_wraps_driver_query_failures(monkeypatch, tmp_path):
272+
ctk_root = tmp_path / "cuda-12.9"
273+
_write_version_json(ctk_root, "12.9.20250531")
274+
lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
275+
276+
monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path))
277+
278+
def fail_query_driver_cuda_version() -> DriverCudaVersion:
279+
raise QueryDriverCudaVersionError("driver query failed")
280+
281+
monkeypatch.setattr(compatibility_module, "query_driver_cuda_version", fail_query_driver_cuda_version)
282+
283+
pfchecks = WithCompatibilityChecks()
284+
285+
with pytest.raises(
286+
CompatibilityCheckError,
287+
match="Failed to query the CUDA driver version needed for compatibility checks",
288+
) as exc_info:
289+
pfchecks.load_nvidia_dynamic_lib("nvrtc")
290+
291+
assert isinstance(exc_info.value.__cause__, QueryDriverCudaVersionError)
292+
293+
269294
def test_find_nvidia_header_directory_returns_none_when_unresolved(monkeypatch):
270295
monkeypatch.setattr(
271296
compatibility_module,

0 commit comments

Comments
 (0)