Skip to content

Commit 8900cd2

Browse files
committed
Make real guard-rails tests query the driver version.
This keeps the host-backed compatibility checks aligned with the actual driver state instead of relying on a fixed encoded version in the real-environment tests. Made-with: Cursor
1 parent 298888e commit 8900cd2

2 files changed

Lines changed: 30 additions & 5 deletions

File tree

cuda_pathfinder/tests/local_helpers.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from cuda.pathfinder._headers.find_nvidia_headers import (
1313
locate_nvidia_header_directory as locate_nvidia_header_directory_raw,
1414
)
15+
from cuda.pathfinder._utils import driver_info
1516
from cuda.pathfinder._utils.toolkit_info import CudaToolkitVersion, read_cuda_header_version
1617

1718

@@ -68,3 +69,11 @@ def require_real_cuda_toolkit_version_from_cuda_h() -> LocatedRealCudaToolkitVer
6869
header_dir=located.abs_path,
6970
found_via=located.found_via,
7071
)
72+
73+
74+
def require_real_driver_cuda_version() -> driver_info.DriverCudaVersion:
75+
"""Return the real-host CUDA driver version or skip if it cannot be queried."""
76+
try:
77+
return driver_info.query_driver_cuda_version()
78+
except driver_info.QueryDriverCudaVersionError as exc:
79+
pytest.skip(f"Could not query the CUDA driver version for a real driver installation: {exc}")

cuda_pathfinder/tests/test_compatibility_guard_rails.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@
2727
locate_nvidia_header_directory as locate_nvidia_header_directory_raw,
2828
)
2929
from cuda.pathfinder._utils.env_vars import get_cuda_path_or_home
30+
from cuda.pathfinder._utils import driver_info
3031
from cuda.pathfinder._utils.driver_info import DriverCudaVersion, QueryDriverCudaVersionError
3132
from cuda.pathfinder._utils.toolkit_info import read_cuda_header_version
3233
from local_helpers import (
3334
locate_real_cuda_toolkit_version_from_cuda_h,
35+
require_real_driver_cuda_version,
3436
require_real_cuda_toolkit_version_from_cuda_h,
3537
)
3638

@@ -46,18 +48,22 @@ def _default_process_wide_guard_rails_mode(monkeypatch):
4648

4749

4850
@pytest.fixture
49-
def clear_real_cuda_h_probe_caches():
51+
def clear_real_host_probe_caches():
5052
locate_real_cuda_toolkit_version_from_cuda_h.cache_clear()
5153
locate_nvidia_header_directory_raw.cache_clear()
5254
_resolve_system_loaded_abs_path_in_subprocess.cache_clear()
5355
get_cuda_path_or_home.cache_clear()
5456
read_cuda_header_version.cache_clear()
57+
driver_info._load_nvidia_dynamic_lib.cache_clear()
58+
driver_info.query_driver_cuda_version.cache_clear()
5559
yield
5660
locate_real_cuda_toolkit_version_from_cuda_h.cache_clear()
5761
locate_nvidia_header_directory_raw.cache_clear()
5862
_resolve_system_loaded_abs_path_in_subprocess.cache_clear()
5963
get_cuda_path_or_home.cache_clear()
6064
read_cuda_header_version.cache_clear()
65+
driver_info._load_nvidia_dynamic_lib.cache_clear()
66+
driver_info.query_driver_cuda_version.cache_clear()
6167

6268

6369
def _write_cuda_h(
@@ -678,17 +684,22 @@ def test_find_nvidia_header_directory_returns_none_when_unresolved(monkeypatch):
678684
assert guard_rails.find_nvidia_header_directory("nvrtc") is None
679685

680686

681-
@pytest.mark.usefixtures("clear_real_cuda_h_probe_caches")
687+
@pytest.mark.usefixtures("clear_real_host_probe_caches")
682688
def test_real_wheel_ctk_items_are_compatible(info_summary_append):
683689
real_ctk = require_real_cuda_toolkit_version_from_cuda_h()
690+
real_driver = require_real_driver_cuda_version()
684691
info_summary_append(
685692
f"real cuda.h CTK version={real_ctk.version.major}.{real_ctk.version.minor} "
686693
f"via {real_ctk.found_via} at {real_ctk.cuda_h_path!r}"
687694
)
695+
info_summary_append(
696+
"real driver CUDA version="
697+
f"{real_driver.major}.{real_driver.minor} (encoded={real_driver.encoded})"
698+
)
688699
guard_rails = CompatibilityGuardRails(
689700
ctk_major=real_ctk.version.major,
690701
ctk_minor=real_ctk.version.minor,
691-
driver_cuda_version=_driver_cuda_version(13000),
702+
driver_cuda_version=real_driver,
692703
)
693704

694705
try:
@@ -722,17 +733,22 @@ def test_real_wheel_ctk_items_are_compatible(info_summary_append):
722733
_assert_real_ctk_backed_path(path)
723734

724735

725-
@pytest.mark.usefixtures("clear_real_cuda_h_probe_caches")
736+
@pytest.mark.usefixtures("clear_real_host_probe_caches")
726737
def test_real_wheel_component_version_does_not_override_ctk_line(info_summary_append):
727738
real_ctk = require_real_cuda_toolkit_version_from_cuda_h()
739+
real_driver = require_real_driver_cuda_version()
728740
info_summary_append(
729741
f"real cuda.h CTK version={real_ctk.version.major}.{real_ctk.version.minor} "
730742
f"via {real_ctk.found_via} at {real_ctk.cuda_h_path!r}"
731743
)
744+
info_summary_append(
745+
"real driver CUDA version="
746+
f"{real_driver.major}.{real_driver.minor} (encoded={real_driver.encoded})"
747+
)
732748
guard_rails = CompatibilityGuardRails(
733749
ctk_major=real_ctk.version.major,
734750
ctk_minor=real_ctk.version.minor,
735-
driver_cuda_version=_driver_cuda_version(13000),
751+
driver_cuda_version=real_driver,
736752
)
737753

738754
try:

0 commit comments

Comments
 (0)