Skip to content

Commit c6c38e3

Browse files
committed
Use cuda.h for CTK guard-rails metadata.
Replace version.json-based CTK root metadata with cuda.h parsing so that compatibility checks use a simpler, more universal toolkit source while preserving wheel-based metadata inference.

Made-with: Cursor
1 parent e3b402a commit c6c38e3

2 files changed

Lines changed: 105 additions & 67 deletions

File tree

cuda_pathfinder/cuda/pathfinder/_compatibility_guard_rails.py

Lines changed: 50 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55

66
import functools
77
import importlib.metadata
8-
import json
98
import os
109
import re
1110
from collections.abc import Mapping
@@ -46,6 +45,7 @@
4645
QueryDriverCudaVersionError,
4746
query_driver_cuda_version,
4847
)
48+
from cuda.pathfinder._utils.toolkit_info import ReadCudaHeaderVersionError, read_cuda_header_version
4949

5050
ItemKind: TypeAlias = str
5151
PackagedWith: TypeAlias = str
@@ -323,46 +323,62 @@ def _wheel_metadata_for_abs_path(abs_path: str) -> CtkMetadata | None:
323323
return CtkMetadata(ctk_version=ctk_version, ctk_root=None, source=source)
324324

325325

326+
def _normalized_ctk_root_for_cuda_header(cuda_header_path: Path) -> Path:
327+
ctk_root = cuda_header_path.parent.parent
328+
if ctk_root.parent.name == "targets":
329+
return ctk_root.parent.parent
330+
return ctk_root
331+
332+
326333
@functools.cache
327-
def _read_ctk_version(ctk_root: str) -> CtkVersion | None:
328-
version_json_path = os.path.join(ctk_root, "version.json")
329-
if not os.path.isfile(version_json_path):
330-
return None
331-
with open(version_json_path, encoding="utf-8") as fobj:
332-
payload = json.load(fobj)
333-
if not isinstance(payload, dict):
334-
return None
335-
cuda_entry = payload.get("cuda")
336-
if not isinstance(cuda_entry, dict):
334+
def _cuda_header_metadata_for_ctk_root_candidate(ctk_root_candidate: str) -> CtkMetadata | None:
335+
candidate_path = Path(ctk_root_candidate)
336+
header_paths: list[Path] = []
337+
338+
direct_header = candidate_path / "include" / "cuda.h"
339+
if direct_header.is_file():
340+
header_paths.append(direct_header)
341+
342+
targets_dir = candidate_path / "targets"
343+
if targets_dir.is_dir():
344+
header_paths.extend(sorted(path for path in targets_dir.glob("*/include/cuda.h") if path.is_file()))
345+
346+
matches: list[tuple[CtkVersion, Path, Path]] = []
347+
for cuda_header_path in header_paths:
348+
try:
349+
version = read_cuda_header_version(str(cuda_header_path))
350+
except ReadCudaHeaderVersionError:
351+
continue
352+
matches.append(
353+
(
354+
CtkVersion(major=version.major, minor=version.minor),
355+
_normalized_ctk_root_for_cuda_header(cuda_header_path),
356+
cuda_header_path,
357+
)
358+
)
359+
360+
if not matches:
337361
return None
338-
cuda_version = cuda_entry.get("version")
339-
if not isinstance(cuda_version, str):
362+
363+
ctk_version, ctk_root, source_path = matches[0]
364+
if any(other_version != ctk_version for other_version, _other_root, _other_source in matches[1:]):
340365
return None
341-
return _parse_ctk_version(cuda_version)
366+
367+
return CtkMetadata(
368+
ctk_version=ctk_version,
369+
ctk_root=str(ctk_root),
370+
source=f"cuda.h at {source_path}",
371+
)
342372

343373

344-
def _find_enclosing_ctk_root(abs_path: str) -> str | None:
374+
def _ctk_metadata_for_abs_path(abs_path: str) -> CtkMetadata | None:
345375
current = Path(abs_path)
346376
if current.is_file():
347377
current = current.parent
348378
for candidate in (current, *current.parents):
349-
ctk_root = str(candidate)
350-
if _read_ctk_version(ctk_root) is not None:
351-
return ctk_root
352-
return None
353-
354-
355-
def _ctk_metadata_for_abs_path(abs_path: str) -> CtkMetadata | None:
356-
ctk_root = _find_enclosing_ctk_root(abs_path)
357-
if ctk_root is not None:
358-
ctk_version = _read_ctk_version(ctk_root)
359-
if ctk_version is not None:
360-
version_json_path = os.path.join(ctk_root, "version.json")
361-
return CtkMetadata(
362-
ctk_version=ctk_version,
363-
ctk_root=ctk_root,
364-
source=f"version.json at {version_json_path}",
365-
)
379+
ctk_metadata = _cuda_header_metadata_for_ctk_root_candidate(str(candidate))
380+
if ctk_metadata is not None:
381+
return ctk_metadata
366382
return _wheel_metadata_for_abs_path(abs_path)
367383

368384

@@ -468,7 +484,7 @@ def compatibility_check(
468484
status="insufficient_metadata",
469485
message=(
470486
"v1 compatibility checks require either an enclosing CUDA Toolkit root "
471-
"with version.json or wheel metadata that can be traced to an installed "
487+
"with cuda.h or wheel metadata that can be traced to an installed "
472488
f"cuda-toolkit distribution. Could not determine the CTK version for {item.describe()}."
473489
),
474490
)
@@ -545,7 +561,7 @@ def _enforce_ctk_metadata(self, item: ResolvedItem) -> None:
545561
return
546562
raise CompatibilityInsufficientMetadataError(
547563
"v1 compatibility checks require either an enclosing CUDA Toolkit root "
548-
"with version.json or wheel metadata that can be traced to an installed "
564+
"with cuda.h or wheel metadata that can be traced to an installed "
549565
f"cuda-toolkit distribution. Could not determine the CTK version for {item.describe()}."
550566
)
551567

cuda_pathfinder/tests/test_compatibility_guard_rails.py

Lines changed: 55 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import importlib
5-
import json
65
import os
76
from pathlib import Path
87

@@ -36,10 +35,22 @@ def _default_process_wide_guard_rails_mode(monkeypatch):
3635
monkeypatch.delenv(COMPATIBILITY_GUARD_RAILS_ENV_VAR, raising=False)
3736

3837

39-
def _write_version_json(ctk_root: Path, toolkit_version: str) -> None:
40-
ctk_root.mkdir(parents=True, exist_ok=True)
41-
payload = {"cuda": {"version": toolkit_version}}
42-
(ctk_root / "version.json").write_text(json.dumps(payload), encoding="utf-8")
38+
def _write_cuda_h(
39+
ctk_root: Path,
40+
toolkit_version: str,
41+
*,
42+
include_dir_parts: tuple[str, ...] = ("targets", "x86_64-linux", "include"),
43+
) -> None:
44+
parts = toolkit_version.split(".")
45+
if len(parts) < 2:
46+
raise AssertionError(f"Expected at least major.minor in toolkit version, got {toolkit_version!r}")
47+
encoded = int(parts[0]) * 1000 + int(parts[1]) * 10
48+
cuda_h_path = ctk_root.joinpath(*include_dir_parts, "cuda.h")
49+
cuda_h_path.parent.mkdir(parents=True, exist_ok=True)
50+
cuda_h_path.write_text(
51+
f"#ifndef CUDA_H\n#define CUDA_H\n#define CUDA_VERSION {encoded}\n#endif\n",
52+
encoding="utf-8",
53+
)
4354

4455

4556
def _touch(path: Path) -> str:
@@ -76,11 +87,7 @@ def _located_bitcode_lib(name: str, abs_path: str) -> LocatedBitcodeLib:
7687

7788

7889
def _driver_cuda_version(encoded: int) -> DriverCudaVersion:
79-
return DriverCudaVersion(
80-
encoded=encoded,
81-
major=encoded // 1000,
82-
minor=(encoded % 1000) // 10,
83-
)
90+
return DriverCudaVersion.from_encoded(encoded)
8491

8592

8693
class _FakeDistribution:
@@ -111,8 +118,9 @@ def _assert_real_ctk_backed_path(path: str) -> None:
111118
if current.is_file():
112119
current = current.parent
113120
for candidate in (current, *current.parents):
114-
version_json_path = candidate / "version.json"
115-
if version_json_path.is_file():
121+
if (candidate / "include" / "cuda.h").is_file():
122+
return
123+
if any(path.is_file() for path in (candidate / "targets").glob("*/include/cuda.h")):
116124
return
117125
for env_var in ("CUDA_PATH", "CUDA_HOME"):
118126
ctk_root = os.environ.get(env_var)
@@ -122,7 +130,7 @@ def _assert_real_ctk_backed_path(path: str) -> None:
122130
if os.path.commonpath((norm_path, norm_ctk_root)) == norm_ctk_root:
123131
return
124132
raise AssertionError(
125-
"Expected a site-packages path, a path under a CTK root with version.json, "
133+
"Expected a site-packages path, a path under a CTK root with cuda.h, "
126134
f"or a path under CUDA_PATH/CUDA_HOME, got {path!r}"
127135
)
128136

@@ -220,7 +228,7 @@ def fail_raw_fallback(_libname: str) -> LoadedDL:
220228

221229
@pytest.mark.parametrize("env_value", [None, ""])
222230
def test_public_apis_default_to_strict_when_env_var_is_unset_or_empty(monkeypatch, tmp_path, env_value):
223-
lib_path = _touch(tmp_path / "no-version-json" / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
231+
lib_path = _touch(tmp_path / "no-cuda-h" / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
224232

225233
monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path))
226234
monkeypatch.setattr(
@@ -238,12 +246,12 @@ def fail_raw_fallback(_libname: str) -> LoadedDL:
238246
else:
239247
monkeypatch.setenv(COMPATIBILITY_GUARD_RAILS_ENV_VAR, env_value)
240248

241-
with pytest.raises(CompatibilityInsufficientMetadataError, match="version.json"):
249+
with pytest.raises(CompatibilityInsufficientMetadataError, match="cuda.h"):
242250
pathfinder.load_nvidia_dynamic_lib("nvrtc")
243251

244252

245253
def test_public_apis_best_effort_fall_back_on_insufficient_metadata(monkeypatch, tmp_path):
246-
guarded_lib_path = _touch(tmp_path / "no-version-json" / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
254+
guarded_lib_path = _touch(tmp_path / "no-cuda-h" / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
247255
raw_loaded = _loaded_dl("/opt/mock/libnvrtc.so.12", found_via="system-search")
248256

249257
monkeypatch.setenv(COMPATIBILITY_GUARD_RAILS_ENV_VAR, "best_effort")
@@ -292,8 +300,8 @@ def test_public_apis_reject_invalid_guard_rails_mode(monkeypatch):
292300
def test_public_apis_share_process_wide_guard_rails_state(monkeypatch, tmp_path):
293301
lib_root = tmp_path / "cuda-12.8"
294302
hdr_root = tmp_path / "cuda-12.9"
295-
_write_version_json(lib_root, "12.8.20250303")
296-
_write_version_json(hdr_root, "12.9.20250531")
303+
_write_cuda_h(lib_root, "12.8.20250303")
304+
_write_cuda_h(hdr_root, "12.9.20250531")
297305

298306
lib_path = _touch(lib_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
299307
hdr_dir = hdr_root / "targets" / "x86_64-linux" / "include"
@@ -320,7 +328,7 @@ def test_public_apis_share_process_wide_guard_rails_state(monkeypatch, tmp_path)
320328

321329
def test_load_dynamic_lib_then_find_headers_same_ctk_version(monkeypatch, tmp_path):
322330
ctk_root = tmp_path / "cuda-12.9"
323-
_write_version_json(ctk_root, "12.9.20250531")
331+
_write_cuda_h(ctk_root, "12.9.20250531")
324332
lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
325333
hdr_dir = ctk_root / "targets" / "x86_64-linux" / "include"
326334
_touch(hdr_dir / "nvrtc.h")
@@ -344,8 +352,8 @@ def test_load_dynamic_lib_then_find_headers_same_ctk_version(monkeypatch, tmp_pa
344352
def test_exact_ctk_major_minor_match_is_required(monkeypatch, tmp_path):
345353
lib_root = tmp_path / "cuda-12.8"
346354
hdr_root = tmp_path / "cuda-12.9"
347-
_write_version_json(lib_root, "12.8.20250303")
348-
_write_version_json(hdr_root, "12.9.20250531")
355+
_write_cuda_h(lib_root, "12.8.20250303")
356+
_write_cuda_h(hdr_root, "12.9.20250531")
349357

350358
lib_path = _touch(lib_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
351359
hdr_dir = hdr_root / "targets" / "x86_64-linux" / "include"
@@ -367,7 +375,7 @@ def test_exact_ctk_major_minor_match_is_required(monkeypatch, tmp_path):
367375

368376
def test_driver_major_must_not_be_older_than_ctk_major(monkeypatch, tmp_path):
369377
ctk_root = tmp_path / "cuda-13.0"
370-
_write_version_json(ctk_root, "13.0.20251003")
378+
_write_cuda_h(ctk_root, "13.0.20251003")
371379
lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.13")
372380

373381
monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path))
@@ -378,17 +386,31 @@ def test_driver_major_must_not_be_older_than_ctk_major(monkeypatch, tmp_path):
378386
guard_rails.load_nvidia_dynamic_lib("nvrtc")
379387

380388

381-
def test_missing_version_json_raises_insufficient_metadata(monkeypatch, tmp_path):
382-
lib_path = _touch(tmp_path / "no-version-json" / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
389+
def test_missing_cuda_h_raises_insufficient_metadata(monkeypatch, tmp_path):
390+
lib_path = _touch(tmp_path / "no-cuda-h" / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
383391

384392
monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path))
385393

386394
guard_rails = CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000))
387395

388-
with pytest.raises(CompatibilityInsufficientMetadataError, match="version.json"):
396+
with pytest.raises(CompatibilityInsufficientMetadataError, match="cuda.h"):
389397
guard_rails.load_nvidia_dynamic_lib("nvrtc")
390398

391399

400+
def test_windows_style_ctk_root_uses_root_include_cuda_h(monkeypatch, tmp_path):
401+
ctk_root = tmp_path / "cuda-13.2"
402+
_write_cuda_h(ctk_root, "13.2.20251003", include_dir_parts=("include",))
403+
lib_path = _touch(ctk_root / "bin" / "x64" / "nvrtc64_130_0.dll")
404+
405+
monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path))
406+
407+
guard_rails = CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000))
408+
409+
loaded = guard_rails.load_nvidia_dynamic_lib("nvrtc")
410+
411+
assert loaded.abs_path == lib_path
412+
413+
392414
def test_other_packaging_raises_insufficient_metadata(monkeypatch, tmp_path):
393415
abs_path = _touch(tmp_path / "site-packages" / "nvidia" / "nvshmem" / "lib" / "libnvshmem_device.bc")
394416

@@ -407,7 +429,7 @@ def test_other_packaging_raises_insufficient_metadata(monkeypatch, tmp_path):
407429
def test_driver_libs_do_not_lock_ctk_anchor(monkeypatch, tmp_path):
408430
driver_lib_path = _touch(tmp_path / "driver-root" / "libnvidia-ml.so.1")
409431
ctk_root = tmp_path / "cuda-12.9"
410-
_write_version_json(ctk_root, "12.9.20250531")
432+
_write_cuda_h(ctk_root, "12.9.20250531")
411433
ctk_lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
412434

413435
def fake_load_nvidia_dynamic_lib(libname: str) -> LoadedDL:
@@ -432,8 +454,8 @@ def test_driver_libs_do_not_mask_later_ctk_mismatch(monkeypatch, tmp_path):
432454
driver_lib_path = _touch(tmp_path / "driver-root" / "libnvidia-ml.so.1")
433455
lib_root = tmp_path / "cuda-12.8"
434456
hdr_root = tmp_path / "cuda-12.9"
435-
_write_version_json(lib_root, "12.8.20250303")
436-
_write_version_json(hdr_root, "12.9.20250531")
457+
_write_cuda_h(lib_root, "12.8.20250303")
458+
_write_cuda_h(hdr_root, "12.9.20250531")
437459

438460
lib_path = _touch(lib_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
439461
hdr_dir = hdr_root / "targets" / "x86_64-linux" / "include"
@@ -506,7 +528,7 @@ def test_wheel_metadata_accepts_exact_and_range_requirements(monkeypatch, tmp_pa
506528

507529
def test_constraints_accept_string_and_tuple_forms(monkeypatch, tmp_path):
508530
ctk_root = tmp_path / "cuda-12.9"
509-
_write_version_json(ctk_root, "12.9.20250531")
531+
_write_cuda_h(ctk_root, "12.9.20250531")
510532
lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
511533

512534
monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path))
@@ -524,7 +546,7 @@ def test_constraints_accept_string_and_tuple_forms(monkeypatch, tmp_path):
524546

525547
def test_constraint_failure_raises(monkeypatch, tmp_path):
526548
ctk_root = tmp_path / "cuda-12.9"
527-
_write_version_json(ctk_root, "12.9.20250531")
549+
_write_cuda_h(ctk_root, "12.9.20250531")
528550
lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
529551

530552
monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path))
@@ -541,7 +563,7 @@ def test_constraint_failure_raises(monkeypatch, tmp_path):
541563

542564
def test_static_bitcode_and_binary_methods_participate_in_checks(monkeypatch, tmp_path):
543565
ctk_root = tmp_path / "cuda-12.9"
544-
_write_version_json(ctk_root, "12.9.20250531")
566+
_write_cuda_h(ctk_root, "12.9.20250531")
545567

546568
lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
547569
static_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libcudadevrt.a")
@@ -575,7 +597,7 @@ def test_static_bitcode_and_binary_methods_participate_in_checks(monkeypatch, tm
575597

576598
def test_guard_rails_query_driver_cuda_version_by_default(monkeypatch, tmp_path):
577599
ctk_root = tmp_path / "cuda-12.9"
578-
_write_version_json(ctk_root, "12.9.20250531")
600+
_write_cuda_h(ctk_root, "12.9.20250531")
579601
lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
580602

581603
query_calls: list[int] = []
@@ -598,7 +620,7 @@ def fake_query_driver_cuda_version() -> DriverCudaVersion:
598620

599621
def test_guard_rails_wrap_driver_query_failures(monkeypatch, tmp_path):
600622
ctk_root = tmp_path / "cuda-12.9"
601-
_write_version_json(ctk_root, "12.9.20250531")
623+
_write_cuda_h(ctk_root, "12.9.20250531")
602624
lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
603625

604626
monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", lambda _libname: _loaded_dl(lib_path))

0 commit comments

Comments
 (0)