Skip to content

Commit 298888e

Browse files
committed
Make real guard-rails tests derive their CTK line from cuda.h.
This keeps the host-backed compatibility checks aligned with the resolved toolkit layout and skips cleanly when cudart headers or cuda.h are unavailable. Made-with: Cursor
1 parent c6c38e3 commit 298888e

2 files changed

Lines changed: 95 additions & 4 deletions

File tree

cuda_pathfinder/tests/local_helpers.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,25 @@
44
import functools
55
import importlib.metadata
66
import re
7+
from dataclasses import dataclass
8+
from pathlib import Path
9+
10+
import pytest
11+
12+
from cuda.pathfinder._headers.find_nvidia_headers import (
13+
locate_nvidia_header_directory as locate_nvidia_header_directory_raw,
14+
)
15+
from cuda.pathfinder._utils.toolkit_info import CudaToolkitVersion, read_cuda_header_version
16+
17+
18+
@dataclass(frozen=True, slots=True)
19+
class LocatedRealCudaToolkitVersion:
20+
"""Real-host CTK version discovered from ``cuda.h`` next to resolved ``cudart`` headers."""
21+
22+
version: CudaToolkitVersion
23+
cuda_h_path: str
24+
header_dir: str
25+
found_via: str
726

827

928
@functools.cache
@@ -14,3 +33,38 @@ def have_distribution(name_pattern: str) -> bool:
1433
for dist in importlib.metadata.distributions()
1534
if "Name" in dist.metadata
1635
)
36+
37+
38+
@functools.cache
39+
def locate_real_cuda_toolkit_version_from_cuda_h() -> LocatedRealCudaToolkitVersion | None:
40+
"""Return the real-host CTK version from ``cuda.h`` if ``cudart`` headers can be located."""
41+
located = locate_nvidia_header_directory_raw("cudart")
42+
if located is None or located.abs_path is None:
43+
return None
44+
cuda_h_path = Path(located.abs_path) / "cuda.h"
45+
if not cuda_h_path.is_file():
46+
return None
47+
return LocatedRealCudaToolkitVersion(
48+
version=read_cuda_header_version(str(cuda_h_path)),
49+
cuda_h_path=str(cuda_h_path),
50+
header_dir=located.abs_path,
51+
found_via=located.found_via,
52+
)
53+
54+
55+
def require_real_cuda_toolkit_version_from_cuda_h() -> LocatedRealCudaToolkitVersion:
56+
"""Return the real-host CTK version from ``cuda.h`` or skip if it cannot be located."""
57+
located = locate_nvidia_header_directory_raw("cudart")
58+
if located is None or located.abs_path is None:
59+
pytest.skip("Could not locate cudart headers, so could not find cuda.h for a real CTK installation.")
60+
cuda_h_path = Path(located.abs_path) / "cuda.h"
61+
if not cuda_h_path.is_file():
62+
pytest.skip(
63+
f"Located cudart headers via {located.found_via} at {located.abs_path!r}, but could not find cuda.h."
64+
)
65+
return LocatedRealCudaToolkitVersion(
66+
version=read_cuda_header_version(str(cuda_h_path)),
67+
cuda_h_path=str(cuda_h_path),
68+
header_dir=located.abs_path,
69+
found_via=located.found_via,
70+
)

cuda_pathfinder/tests/test_compatibility_guard_rails.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,17 @@
2222
StaticLibNotFoundError,
2323
process_wide_compatibility_guard_rails,
2424
)
25+
from cuda.pathfinder._dynamic_libs.load_nvidia_dynamic_lib import _resolve_system_loaded_abs_path_in_subprocess
26+
from cuda.pathfinder._headers.find_nvidia_headers import (
27+
locate_nvidia_header_directory as locate_nvidia_header_directory_raw,
28+
)
29+
from cuda.pathfinder._utils.env_vars import get_cuda_path_or_home
2530
from cuda.pathfinder._utils.driver_info import DriverCudaVersion, QueryDriverCudaVersionError
31+
from cuda.pathfinder._utils.toolkit_info import read_cuda_header_version
32+
from local_helpers import (
33+
locate_real_cuda_toolkit_version_from_cuda_h,
34+
require_real_cuda_toolkit_version_from_cuda_h,
35+
)
2636

2737
STRICTNESS = os.environ.get("CUDA_PATHFINDER_TEST_COMPATIBILITY_GUARD_RAILS_STRICTNESS", "see_what_works")
2838
assert STRICTNESS in ("see_what_works", "all_must_work")
@@ -35,6 +45,21 @@ def _default_process_wide_guard_rails_mode(monkeypatch):
3545
monkeypatch.delenv(COMPATIBILITY_GUARD_RAILS_ENV_VAR, raising=False)
3646

3747

48+
@pytest.fixture
49+
def clear_real_cuda_h_probe_caches():
50+
locate_real_cuda_toolkit_version_from_cuda_h.cache_clear()
51+
locate_nvidia_header_directory_raw.cache_clear()
52+
_resolve_system_loaded_abs_path_in_subprocess.cache_clear()
53+
get_cuda_path_or_home.cache_clear()
54+
read_cuda_header_version.cache_clear()
55+
yield
56+
locate_real_cuda_toolkit_version_from_cuda_h.cache_clear()
57+
locate_nvidia_header_directory_raw.cache_clear()
58+
_resolve_system_loaded_abs_path_in_subprocess.cache_clear()
59+
get_cuda_path_or_home.cache_clear()
60+
read_cuda_header_version.cache_clear()
61+
62+
3863
def _write_cuda_h(
3964
ctk_root: Path,
4065
toolkit_version: str,
@@ -653,10 +678,16 @@ def test_find_nvidia_header_directory_returns_none_when_unresolved(monkeypatch):
653678
assert guard_rails.find_nvidia_header_directory("nvrtc") is None
654679

655680

681+
@pytest.mark.usefixtures("clear_real_cuda_h_probe_caches")
656682
def test_real_wheel_ctk_items_are_compatible(info_summary_append):
683+
real_ctk = require_real_cuda_toolkit_version_from_cuda_h()
684+
info_summary_append(
685+
f"real cuda.h CTK version={real_ctk.version.major}.{real_ctk.version.minor} "
686+
f"via {real_ctk.found_via} at {real_ctk.cuda_h_path!r}"
687+
)
657688
guard_rails = CompatibilityGuardRails(
658-
ctk_major=13,
659-
ctk_minor=2,
689+
ctk_major=real_ctk.version.major,
690+
ctk_minor=real_ctk.version.minor,
660691
driver_cuda_version=_driver_cuda_version(13000),
661692
)
662693

@@ -691,10 +722,16 @@ def test_real_wheel_ctk_items_are_compatible(info_summary_append):
691722
_assert_real_ctk_backed_path(path)
692723

693724

725+
@pytest.mark.usefixtures("clear_real_cuda_h_probe_caches")
694726
def test_real_wheel_component_version_does_not_override_ctk_line(info_summary_append):
727+
real_ctk = require_real_cuda_toolkit_version_from_cuda_h()
728+
info_summary_append(
729+
f"real cuda.h CTK version={real_ctk.version.major}.{real_ctk.version.minor} "
730+
f"via {real_ctk.found_via} at {real_ctk.cuda_h_path!r}"
731+
)
695732
guard_rails = CompatibilityGuardRails(
696-
ctk_major=13,
697-
ctk_minor=2,
733+
ctk_major=real_ctk.version.major,
734+
ctk_minor=real_ctk.version.minor,
698735
driver_cuda_version=_driver_cuda_version(13000),
699736
)
700737

0 commit comments

Comments
 (0)