Skip to content

Commit fc963ff

Browse files
rwgk and cursoragent committed
pathfinder: refine guard-rails env-var validation, binary classification, and reset naming
* Defer the platform check for CUDA_PATHFINDER_DRIVER_COMPATIBILITY until after the CUDA_PATHFINDER_COMPATIBILITY_GUARD_RAILS=off short-circuit, so users who turn guard rails off entirely are no longer forced to also unset the override on non-Linux platforms. The value-validation RuntimeError still fires unconditionally so typos are caught early.
* Move the binary packaged_with mapping next to the binary registry as SUPPORTED_BINARIES_PACKAGED_WITH and reclassify nsys / nsight-sys / ncu / nsight-compute as packaged_with="other", so strict-mode lookups for separately packaged Nsight tools no longer raise misleading "missing CTK metadata" errors.
* Rename CompatibilityGuardRails._reset_for_testing to _reset_state and document that production cache_clear callers also drive it; configured driver overrides are intentionally re-applied while lazily-queried values are dropped.

Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent c3ddf79 commit fc963ff

5 files changed

Lines changed: 50 additions & 7 deletions

File tree

cuda_pathfinder/cuda/pathfinder/_binaries/supported_nvidia_binaries.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,4 +41,14 @@
4141
"nvprune": ("toolchain_cuda_nvcc",),
4242
}
4343

44+
# Nsight Systems and Nsight Compute ship in their own PyPI/Conda packages
45+
# (`nvidia/nsight_systems`, `nvidia/nsight_compute`) and are not pinned by the
46+
# `cuda-toolkit` distribution, so they cannot participate in CTK-coherence
47+
# checks. They are tagged "other" so the guard rails treat them as separately
48+
# packaged tools rather than reporting them as missing CTK metadata.
49+
SUPPORTED_BINARIES_PACKAGED_WITH = {
50+
name: ("other" if name in {"nsys", "nsight-sys", "ncu", "nsight-compute"} else "ctk")
51+
for name in SITE_PACKAGES_BINDIRS
52+
}
53+
4454
SUPPORTED_BINARIES_ALL = SUPPORTED_BINARIES = tuple(SITE_PACKAGES_BINDIRS.keys())

cuda_pathfinder/cuda/pathfinder/_compatibility_guard_rails.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,8 @@
2020
find_nvidia_binary_utility as _find_nvidia_binary_utility,
2121
)
2222
from cuda.pathfinder._binaries.supported_nvidia_binaries import (
23-
SUPPORTED_BINARIES_ALL,
2423
SUPPORTED_BINARIES_CTK_COMPANION_TAGS,
24+
SUPPORTED_BINARIES_PACKAGED_WITH,
2525
)
2626
from cuda.pathfinder._dynamic_libs.lib_descriptor import LIB_DESCRIPTORS
2727
from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL
@@ -119,7 +119,7 @@ class DeclaredDynamicLibPipeline:
119119
"nccl_device": "other",
120120
"nvshmem_device": "other",
121121
}
122-
_BINARY_PACKAGED_WITH: dict[str, PackagedWith] = dict.fromkeys(SUPPORTED_BINARIES_ALL, "ctk")
122+
_BINARY_PACKAGED_WITH: dict[str, PackagedWith] = dict(SUPPORTED_BINARIES_PACKAGED_WITH)
123123

124124

125125
class CompatibilityCheckError(RuntimeError):
@@ -971,7 +971,15 @@ def _declare_dynamic_lib_pipeline(
971971
self._declared_dynamic_lib_pipelines.add(pipeline)
972972
self._enforce_declared_dynamic_lib_pipeline_if_ready(pipeline)
973973

974-
def _reset_for_testing(self) -> None:
974+
def _reset_state(self) -> None:
975+
"""Clear remembered items and pipelines while preserving constructor overrides.
976+
977+
Called both from tests and from the public ``cache_clear`` helpers in
978+
``_process_wide_compatibility_guard_rails`` so a fresh search cycle does
979+
not see leftover compatibility state. Driver versions that the caller
980+
explicitly passed to ``__init__`` are intentionally re-applied; only the
981+
lazily-queried values are dropped.
982+
"""
975983
self._driver_cuda_version = self._configured_driver_cuda_version
976984
self._driver_release_version = self._configured_driver_release_version
977985
self._resolved_items.clear()

cuda_pathfinder/cuda/pathfinder/_process_wide_compatibility_guard_rails.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -99,6 +99,13 @@ def _compatibility_guard_rails_mode() -> str:
9999

100100

101101
def _driver_compatibility_mode() -> str:
102+
"""Return the configured driver-compatibility mode after validating its value.
103+
104+
The platform-specific restriction for ``assume_forward_compatibility`` is
105+
deferred to ``_enforce_driver_compatibility_platform``: if guard rails are
106+
turned off entirely, an unsupported platform should not raise just because
107+
this env var happens to be set.
108+
"""
102109
value = os.environ.get(_DRIVER_COMPATIBILITY_ENV_VAR)
103110
if not value:
104111
return _DRIVER_COMPATIBILITY_DEFAULT_MODE
@@ -109,11 +116,14 @@ def _driver_compatibility_mode() -> str:
109116
f"Allowed values: {allowed_values}. "
110117
f"Unset or empty defaults to {_DRIVER_COMPATIBILITY_DEFAULT_MODE!r}."
111118
)
112-
if value == "assume_forward_compatibility" and not sys.platform.startswith("linux"):
113-
raise RuntimeError(f"{_DRIVER_COMPATIBILITY_ENV_VAR}={value!r} is only supported on Linux.")
114119
return value
115120

116121

122+
def _enforce_driver_compatibility_platform(driver_compatibility_mode: str) -> None:
123+
if driver_compatibility_mode == "assume_forward_compatibility" and not sys.platform.startswith("linux"):
124+
raise RuntimeError(f"{_DRIVER_COMPATIBILITY_ENV_VAR}={driver_compatibility_mode!r} is only supported on Linux.")
125+
126+
117127
def _driver_compatibility_override_hint() -> str:
118128
return (
119129
"On supported Linux systems that intentionally rely on NVIDIA forward compatibility "
@@ -146,7 +156,7 @@ def _current_process_wide_compatibility_guard_rails() -> _ProcessWideGuardRailsA
146156
def _reset_process_wide_compatibility_guard_rails() -> None:
147157
current = _current_process_wide_compatibility_guard_rails()
148158
if isinstance(current, CompatibilityGuardRails):
149-
current._reset_for_testing()
159+
current._reset_state()
150160
return
151161
public_module = _public_module()
152162
if public_module is None:
@@ -161,6 +171,7 @@ def _try_process_wide_guard_rails_then_fallback(guard_rails_call: Callable[[], _
161171
mode = _compatibility_guard_rails_mode()
162172
if mode == "off":
163173
return raw_call()
174+
_enforce_driver_compatibility_platform(driver_compatibility_mode)
164175
try:
165176
return guard_rails_call()
166177
except CompatibilityInsufficientMetadataError:

cuda_pathfinder/tests/test_compatibility_guard_rails.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -474,7 +474,8 @@ def test_resolve_binary_item_covers_every_supported_name(tmp_path, utility_name)
474474
item = compatibility_module._resolve_binary_item(utility_name, abs_path)
475475
assert item.name == utility_name
476476
assert item.kind == "binary"
477-
assert item.packaged_with in ("ctk", "other")
477+
expected_packaged_with = "other" if utility_name in {"nsys", "nsight-sys", "ncu", "nsight-compute"} else "ctk"
478+
assert item.packaged_with == expected_packaged_with
478479

479480

480481
def test_static_bitcode_and_binary_methods_participate_in_checks(monkeypatch, tmp_path):

cuda_pathfinder/tests/test_compatibility_guard_rails_public.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -206,6 +206,19 @@ def test_driver_compatibility_override_is_linux_only(monkeypatch):
206206
pathfinder.find_nvidia_binary_utility("nvcc")
207207

208208

209+
def test_driver_compatibility_override_is_not_validated_when_guard_rails_off(monkeypatch):
210+
raw_loaded = _loaded_dl("/opt/mock/libnvrtc.so.12", found_via="system-search")
211+
212+
monkeypatch.setenv(COMPATIBILITY_GUARD_RAILS_ENV_VAR, "off")
213+
monkeypatch.setenv(DRIVER_COMPATIBILITY_ENV_VAR, "assume_forward_compatibility")
214+
monkeypatch.setattr(process_wide_module.sys, "platform", "win32")
215+
monkeypatch.setattr(process_wide_module, "_load_nvidia_dynamic_lib", lambda _libname: raw_loaded)
216+
217+
loaded = pathfinder.load_nvidia_dynamic_lib("nvrtc")
218+
219+
assert loaded is raw_loaded
220+
221+
209222
@pytest.mark.skipif(
210223
not sys.platform.startswith("linux"),
211224
reason="driver forward-compatibility override is Linux-only",

0 commit comments

Comments
 (0)