Skip to content

Commit e2a0909

Browse files
committed
Allow strict guard rails for driver libraries.
Treat driver-packaged libraries as compatibility-neutral so strict mode can load NVML and other driver libs without a raw fallback, while CTK-backed artifacts remain the only items that establish and enforce the process-wide CTK anchor. Made-with: Cursor
1 parent 0b15665 commit e2a0909

2 files changed

Lines changed: 93 additions & 4 deletions

File tree

cuda_pathfinder/cuda/pathfinder/_compatibility_guard_rails.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,8 @@ def _enforce_supported_packaging(self, item: ResolvedItem) -> None:
459459
return
460460
raise CompatibilityInsufficientMetadataError(
461461
"v1 compatibility checks only give definitive answers for "
462-
f"packaged_with='ctk' items. {item.describe()} is packaged_with={item.packaged_with!r}."
462+
f"packaged_with='ctk' items, plus compatibility-neutral driver libraries. "
463+
f"{item.describe()} is packaged_with={item.packaged_with!r}."
463464
)
464465

465466
def _enforce_ctk_metadata(self, item: ResolvedItem) -> None:
@@ -485,9 +486,10 @@ def _enforce_constraints(self, item: ResolvedItem) -> None:
485486
)
486487

487488
def _anchor_item(self) -> ResolvedItem | None:
488-
if not self._resolved_items:
489-
return None
490-
return self._resolved_items[0]
489+
for item in self._resolved_items:
490+
if item.packaged_with == "ctk":
491+
return item
492+
return None
491493

492494
def _remember(self, item: ResolvedItem) -> None:
493495
if item not in self._resolved_items:
@@ -498,6 +500,12 @@ def _reset_for_testing(self) -> None:
498500
self._resolved_items.clear()
499501

500502
def _register_and_check(self, item: ResolvedItem) -> None:
503+
# Driver libraries come from the installed display driver rather than a
504+
# CUDA Toolkit line, so they do not need CTK metadata and must not lock
505+
# the process-wide CTK anchor.
506+
if item.packaged_with == "driver":
507+
self._remember(item)
508+
return
501509
self._enforce_supported_packaging(item)
502510
self._enforce_ctk_metadata(item)
503511
self._enforce_constraints(item)

cuda_pathfinder/tests/test_compatibility_guard_rails.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,30 @@ def test_public_apis_route_through_process_wide_guard_rails(
174174
assert fake_guard_rails.calls == [(guard_rails_method_name, args)]
175175

176176

177+
def test_public_driver_libs_are_allowed_in_strict_mode(monkeypatch, tmp_path):
178+
driver_lib_path = _touch(tmp_path / "driver-root" / "libnvidia-ml.so.1")
179+
180+
monkeypatch.setattr(
181+
compatibility_module,
182+
"_load_nvidia_dynamic_lib",
183+
lambda _libname: _loaded_dl(driver_lib_path, found_via="system-search"),
184+
)
185+
monkeypatch.setattr(
186+
pathfinder,
187+
"process_wide_compatibility_guard_rails",
188+
CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000)),
189+
)
190+
191+
def fail_raw_fallback(_libname: str) -> LoadedDL:
192+
pytest.fail("strict mode must not fall back to raw loading")
193+
194+
monkeypatch.setattr(process_wide_module, "_load_nvidia_dynamic_lib", fail_raw_fallback)
195+
196+
loaded = pathfinder.load_nvidia_dynamic_lib("nvml")
197+
198+
assert loaded.abs_path == driver_lib_path
199+
200+
177201
@pytest.mark.parametrize("env_value", [None, ""])
178202
def test_public_apis_default_to_strict_when_env_var_is_unset_or_empty(monkeypatch, tmp_path, env_value):
179203
lib_path = _touch(tmp_path / "no-version-json" / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
@@ -360,6 +384,63 @@ def test_other_packaging_raises_insufficient_metadata(monkeypatch, tmp_path):
360384
guard_rails.find_bitcode_lib("nvshmem_device")
361385

362386

387+
def test_driver_libs_do_not_lock_ctk_anchor(monkeypatch, tmp_path):
388+
driver_lib_path = _touch(tmp_path / "driver-root" / "libnvidia-ml.so.1")
389+
ctk_root = tmp_path / "cuda-12.9"
390+
_write_version_json(ctk_root, "12.9.20250531")
391+
ctk_lib_path = _touch(ctk_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
392+
393+
def fake_load_nvidia_dynamic_lib(libname: str) -> LoadedDL:
394+
if libname == "nvml":
395+
return _loaded_dl(driver_lib_path, found_via="system-search")
396+
if libname == "nvrtc":
397+
return _loaded_dl(ctk_lib_path)
398+
raise AssertionError(f"Unexpected libname: {libname!r}")
399+
400+
monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", fake_load_nvidia_dynamic_lib)
401+
402+
guard_rails = CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000))
403+
404+
driver_loaded = guard_rails.load_nvidia_dynamic_lib("nvml")
405+
ctk_loaded = guard_rails.load_nvidia_dynamic_lib("nvrtc")
406+
407+
assert driver_loaded.abs_path == driver_lib_path
408+
assert ctk_loaded.abs_path == ctk_lib_path
409+
410+
411+
def test_driver_libs_do_not_mask_later_ctk_mismatch(monkeypatch, tmp_path):
412+
driver_lib_path = _touch(tmp_path / "driver-root" / "libnvidia-ml.so.1")
413+
lib_root = tmp_path / "cuda-12.8"
414+
hdr_root = tmp_path / "cuda-12.9"
415+
_write_version_json(lib_root, "12.8.20250303")
416+
_write_version_json(hdr_root, "12.9.20250531")
417+
418+
lib_path = _touch(lib_root / "targets" / "x86_64-linux" / "lib" / "libnvrtc.so.12")
419+
hdr_dir = hdr_root / "targets" / "x86_64-linux" / "include"
420+
_touch(hdr_dir / "nvrtc.h")
421+
422+
def fake_load_nvidia_dynamic_lib(libname: str) -> LoadedDL:
423+
if libname == "nvml":
424+
return _loaded_dl(driver_lib_path, found_via="system-search")
425+
if libname == "nvrtc":
426+
return _loaded_dl(lib_path)
427+
raise AssertionError(f"Unexpected libname: {libname!r}")
428+
429+
monkeypatch.setattr(compatibility_module, "_load_nvidia_dynamic_lib", fake_load_nvidia_dynamic_lib)
430+
monkeypatch.setattr(
431+
compatibility_module,
432+
"_locate_nvidia_header_directory",
433+
lambda _libname: LocatedHeaderDir(abs_path=str(hdr_dir), found_via="CUDA_PATH"),
434+
)
435+
436+
guard_rails = CompatibilityGuardRails(driver_cuda_version=_driver_cuda_version(13000))
437+
guard_rails.load_nvidia_dynamic_lib("nvml")
438+
guard_rails.load_nvidia_dynamic_lib("nvrtc")
439+
440+
with pytest.raises(CompatibilityCheckError, match="exact CTK major.minor match"):
441+
guard_rails.find_nvidia_header_directory("nvrtc")
442+
443+
363444
def test_constraints_accept_string_and_tuple_forms(monkeypatch, tmp_path):
364445
ctk_root = tmp_path / "cuda-12.9"
365446
_write_version_json(ctk_root, "12.9.20250531")

0 commit comments

Comments
 (0)