Skip to content

Commit b622613

Browse files
committed
Accept wheel metadata version ranges in strict guard rails.
Infer the CUDA Toolkit line from both wildcard-pinned and range-based cuda-toolkit requirements so strict process-wide guard rails keep working for editable wheel installs used by nvrtc and nvJitLink. Made-with: Cursor
1 parent e2a0909 commit b622613

2 files changed

Lines changed: 150 additions & 10 deletions

File tree

cuda_pathfinder/cuda/pathfinder/_compatibility_guard_rails.py

Lines changed: 87 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,8 @@
5353
ConstraintArg: TypeAlias = int | str | tuple[str, int] | None
5454

5555
_CTK_VERSION_RE = re.compile(r"^(?P<major>\d+)\.(?P<minor>\d+)")
56-
_REQUIRES_DIST_RE = re.compile(
57-
r"^\s*(?P<name>[A-Za-z0-9_.-]+)\s*==\s*(?P<version>[0-9][A-Za-z0-9.+-]*?)(?:\.\*)?(?:\s*;|$)"
58-
)
56+
_REQUIRES_DIST_RE = re.compile(r"^\s*(?P<name>[A-Za-z0-9_.-]+)\s*(?P<specifier_text>[^;]*)(?:\s*;|$)")
57+
_VERSION_SPECIFIER_RE = re.compile(r"^\s*(?P<operator>==|<=|>=|<|>)\s*(?P<version>[0-9][A-Za-z0-9.+-]*?(?:\.\*)?)\s*$")
5958

6059
_STATIC_LIBS_PACKAGED_WITH: dict[str, PackagedWith] = {
6160
"cudadevrt": "ctk",
@@ -113,6 +112,12 @@ def __str__(self) -> str:
113112
return f"{self.operator}{self.value}"
114113

115114

115+
@dataclass(frozen=True, slots=True)
116+
class VersionSpecifier:
117+
operator: ConstraintOperator
118+
version: str
119+
120+
116121
@dataclass(frozen=True, slots=True)
117122
class ResolvedItem:
118123
name: str
@@ -185,6 +190,63 @@ def _distribution_name(dist: importlib.metadata.Distribution) -> str | None:
185190
return metadata.get("Name")
186191

187192

193+
def _release_version_parts(version: str) -> tuple[int, ...] | None:
194+
match = re.match(r"^\d+(?:\.\d+)*", version)
195+
if match is None:
196+
return None
197+
return tuple(int(part) for part in match.group(0).split("."))
198+
199+
200+
def _compare_release_versions(lhs: tuple[int, ...], rhs: tuple[int, ...]) -> int:
201+
max_len = max(len(lhs), len(rhs))
202+
lhs_padded = lhs + (0,) * (max_len - len(lhs))
203+
rhs_padded = rhs + (0,) * (max_len - len(rhs))
204+
if lhs_padded < rhs_padded:
205+
return -1
206+
if lhs_padded > rhs_padded:
207+
return 1
208+
return 0
209+
210+
211+
def _parse_version_specifiers(specifier_text: str) -> tuple[VersionSpecifier, ...]:
212+
stripped = specifier_text.strip()
213+
if not stripped:
214+
return ()
215+
parsed: list[VersionSpecifier] = []
216+
for raw_clause in stripped.split(","):
217+
match = _VERSION_SPECIFIER_RE.match(raw_clause)
218+
if match is None:
219+
return ()
220+
parsed.append(VersionSpecifier(operator=match.group("operator"), version=match.group("version")))
221+
return tuple(parsed)
222+
223+
224+
def _version_satisfies_specifiers(version: str, specifiers: tuple[VersionSpecifier, ...]) -> bool:
225+
if not specifiers:
226+
return False
227+
for specifier in specifiers:
228+
if specifier.operator == "==":
229+
prefix = specifier.version.removesuffix(".*")
230+
if version == prefix or version.startswith(prefix + "."):
231+
continue
232+
return False
233+
candidate_parts = _release_version_parts(version)
234+
required_parts = _release_version_parts(specifier.version)
235+
if candidate_parts is None or required_parts is None:
236+
return False
237+
comparison = _compare_release_versions(candidate_parts, required_parts)
238+
if specifier.operator == "<" and comparison < 0:
239+
continue
240+
if specifier.operator == "<=" and comparison <= 0:
241+
continue
242+
if specifier.operator == ">" and comparison > 0:
243+
continue
244+
if specifier.operator == ">=" and comparison >= 0:
245+
continue
246+
return False
247+
return True
248+
249+
188250
@functools.cache
189251
def _owned_distribution_candidates(abs_path: str) -> tuple[tuple[str, str], ...]:
190252
normalized_abs_path = os.path.normpath(os.path.abspath(abs_path))
@@ -201,27 +263,42 @@ def _owned_distribution_candidates(abs_path: str) -> tuple[tuple[str, str], ...]
201263

202264

203265
@functools.cache
204-
def _cuda_toolkit_requirement_maps() -> tuple[tuple[str, CtkVersion, dict[str, tuple[str, ...]]], ...]:
205-
results: list[tuple[str, CtkVersion, dict[str, tuple[str, ...]]]] = []
266+
def _cuda_toolkit_requirement_maps() -> tuple[
267+
tuple[str, CtkVersion, dict[str, tuple[tuple[VersionSpecifier, ...], ...]]], ...
268+
]:
269+
results: list[tuple[str, CtkVersion, dict[str, tuple[tuple[VersionSpecifier, ...], ...]]]] = []
206270
for dist in importlib.metadata.distributions():
207271
dist_name = _distribution_name(dist)
208272
if _normalize_distribution_name(dist_name or "") != "cuda-toolkit":
209273
continue
210274
ctk_version = _parse_ctk_version(dist.version)
211275
if ctk_version is None:
212276
continue
213-
requirement_map: dict[str, set[str]] = {}
277+
requirement_map: dict[str, set[tuple[VersionSpecifier, ...]]] = {}
214278
for requirement in dist.requires or ():
215279
match = _REQUIRES_DIST_RE.match(requirement)
216280
if match is None:
217281
continue
218282
req_name = _normalize_distribution_name(match.group("name"))
219-
requirement_map.setdefault(req_name, set()).add(match.group("version"))
283+
parsed_specifiers = _parse_version_specifiers(match.group("specifier_text"))
284+
if not parsed_specifiers:
285+
continue
286+
requirement_map.setdefault(req_name, set()).add(parsed_specifiers)
220287
results.append(
221288
(
222289
dist.version,
223290
ctk_version,
224-
{name: tuple(sorted(prefixes)) for name, prefixes in requirement_map.items()},
291+
{
292+
name: tuple(
293+
sorted(
294+
specifier_sets,
295+
key=lambda specifiers: tuple(
296+
(specifier.operator, specifier.version) for specifier in specifiers
297+
),
298+
)
299+
)
300+
for name, specifier_sets in requirement_map.items()
301+
},
225302
)
226303
)
227304
return tuple(results)
@@ -232,9 +309,9 @@ def _wheel_metadata_for_abs_path(abs_path: str) -> CtkMetadata | None:
232309
for owner_name, owner_version in _owned_distribution_candidates(abs_path):
233310
normalized_owner_name = _normalize_distribution_name(owner_name)
234311
for toolkit_dist_version, ctk_version, requirement_map in _cuda_toolkit_requirement_maps():
235-
requirement_prefixes = requirement_map.get(normalized_owner_name, ())
312+
requirement_specifier_sets = requirement_map.get(normalized_owner_name, ())
236313
if not any(
237-
owner_version == prefix or owner_version.startswith(prefix + ".") for prefix in requirement_prefixes
314+
_version_satisfies_specifiers(owner_version, specifiers) for specifiers in requirement_specifier_sets
238315
):
239316
continue
240317
matched_versions[ctk_version] = (

cuda_pathfinder/tests/test_compatibility_guard_rails.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,26 @@ def _driver_cuda_version(encoded: int) -> DriverCudaVersion:
8383
)
8484

8585

86+
class _FakeDistribution:
87+
def __init__(
88+
self,
89+
*,
90+
name: str,
91+
version: str,
92+
root: Path,
93+
files: tuple[str, ...] = (),
94+
requires: tuple[str, ...] = (),
95+
) -> None:
96+
self.metadata = {"Name": name}
97+
self.version = version
98+
self.files = tuple(Path(file) for file in files)
99+
self.requires = list(requires)
100+
self._root = root
101+
102+
def locate_file(self, file: Path) -> Path:
103+
return self._root / file
104+
105+
86106
def _assert_real_ctk_backed_path(path: str) -> None:
87107
norm_path = os.path.normpath(os.path.abspath(path))
88108
if "site-packages" in Path(norm_path).parts:
@@ -441,6 +461,49 @@ def fake_load_nvidia_dynamic_lib(libname: str) -> LoadedDL:
441461
guard_rails.find_nvidia_header_directory("nvrtc")
442462

443463

464+
@pytest.mark.parametrize(
465+
"requirement",
466+
(
467+
"nvidia-nvjitlink == 13.2.78.*; extra == 'nvjitlink'",
468+
"nvidia-nvjitlink<14,>=13.2.78; extra == 'nvjitlink'",
469+
),
470+
)
471+
def test_wheel_metadata_accepts_exact_and_range_requirements(monkeypatch, tmp_path, requirement):
472+
site_packages = tmp_path / "site-packages"
473+
lib_path = _touch(site_packages / "nvidia" / "cu13" / "lib" / "libnvJitLink.so.13")
474+
owner_dist = _FakeDistribution(
475+
name="nvidia-nvjitlink",
476+
version="13.2.78",
477+
root=site_packages,
478+
files=("nvidia/cu13/lib/libnvJitLink.so.13",),
479+
)
480+
cuda_toolkit_dist = _FakeDistribution(
481+
name="cuda-toolkit",
482+
version="13.2.1",
483+
root=site_packages,
484+
requires=(requirement,),
485+
)
486+
487+
compatibility_module._owned_distribution_candidates.cache_clear()
488+
compatibility_module._cuda_toolkit_requirement_maps.cache_clear()
489+
try:
490+
monkeypatch.setattr(
491+
compatibility_module.importlib.metadata,
492+
"distributions",
493+
lambda: (owner_dist, cuda_toolkit_dist),
494+
)
495+
496+
metadata = compatibility_module._wheel_metadata_for_abs_path(lib_path)
497+
finally:
498+
compatibility_module._owned_distribution_candidates.cache_clear()
499+
compatibility_module._cuda_toolkit_requirement_maps.cache_clear()
500+
501+
assert metadata is not None
502+
assert metadata.ctk_version.major == 13
503+
assert metadata.ctk_version.minor == 2
504+
assert metadata.source == "wheel metadata via nvidia-nvjitlink==13.2.78 pinned by cuda-toolkit==13.2.1"
505+
506+
444507
def test_constraints_accept_string_and_tuple_forms(monkeypatch, tmp_path):
445508
ctk_root = tmp_path / "cuda-12.9"
446509
_write_version_json(ctk_root, "12.9.20250531")

0 commit comments

Comments
 (0)