Skip to content

Commit 55f4d47

Browse files
committed
fixup! feat(core.utils): NVVM fingerprint, Windows replace retry, driver-linker validation
1 parent 47b47da commit 55f4d47

File tree

2 files changed

+163
-26
lines changed

2 files changed

+163
-26
lines changed

cuda_core/cuda/core/utils/_program_cache.py

Lines changed: 94 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,11 @@ def _linker_option_fingerprint(options: ProgramOptions) -> list[bytes]:
219219
return [f"{name}={getattr(options, name, None)!r}".encode() for name in _LINKER_RELEVANT_FIELDS]
220220

221221

222+
# LinkerOptions fields that the cuLink (driver) backend rejects outright
223+
# (_linker.pyx _prepare_driver_options). nvJitLink accepts all of them.
224+
_DRIVER_LINKER_UNSUPPORTED_FIELDS = ("time", "ptxas_options", "split_compile", "split_compile_extended")
225+
226+
222227
def _driver_version() -> int:
223228
return int(_handle_return(_driver.cuDriverGetVersion()))
224229

@@ -247,6 +252,22 @@ def _linker_backend_and_version() -> tuple[str, str]:
247252
return ("nvJitLink", str(nvjitlink.version()))
248253

249254

255+
def _nvvm_fingerprint() -> str:
256+
"""Stable identifier for the loaded NVVM toolchain.
257+
258+
NVVM lacks a direct version API (nvbugs 5312315), but ``ir_version()``
259+
reports the IR major/minor/debug pair the toolchain emits -- enough to
260+
keep pre-/post-upgrade caches separate. Paired with the driver and
261+
cuda-core versions already in the digest, this is a practical substitute
262+
for a true libNVVM version.
263+
"""
264+
from cuda.core._program import _get_nvvm_module
265+
266+
nvvm = _get_nvvm_module()
267+
major, minor, debug_major, debug_minor = nvvm.ir_version()
268+
return f"ir={major}.{minor}.{debug_major}.{debug_minor}"
269+
270+
250271
def _cuda_core_version() -> str:
251272
from cuda.core._version import __version__
252273

@@ -389,6 +410,28 @@ def make_program_cache_key(
389410
f"compile will read and pass it as extra_digest=..."
390411
)
391412

413+
# PTX compiles go through Linker. When the driver (cuLink) backend is
414+
# selected (nvJitLink unavailable), Program.compile rejects a subset of
415+
# options that nvJitLink would accept; reject them here too so we never
416+
# store a key for a compilation that can't succeed in this environment.
417+
# If the probe fails we can't tell which backend will run, so skip -- the
418+
# failed-probe branch below already taints the key.
419+
if backend == "linker":
420+
try:
421+
from cuda.core._linker import _decide_nvjitlink_or_driver
422+
423+
use_driver_linker = _decide_nvjitlink_or_driver()
424+
except Exception:
425+
use_driver_linker = None
426+
if use_driver_linker is True:
427+
unsupported = [name for name in _DRIVER_LINKER_UNSUPPORTED_FIELDS if _option_is_set(options, name)]
428+
if unsupported:
429+
raise ValueError(
430+
f"the cuLink driver linker does not support these options: "
431+
f"{', '.join(unsupported)}; Program.compile() would reject this "
432+
f"configuration before producing an ObjectCode."
433+
)
434+
392435
if isinstance(code, str):
393436
code_bytes = code.encode("utf-8")
394437
elif isinstance(code, (bytes, bytearray)):
@@ -431,17 +474,22 @@ def _update(label: str, payload: bytes) -> None:
431474
hasher.update(payload)
432475

433476
def _probe(label: str, fn):
434-
"""Run an environment probe; on failure, hash the exception's class
435-
and message under a ``*_probe_failed`` label. That label differs
436-
from the success label (``driver``/``nvrtc``/...), so a broken env
437-
never collides with a working one; and because the digest is
438-
derived from the *stable* exception signature -- not a random
439-
per-process marker -- two processes with the same failure produce
440-
the same key and can reuse on-disk cache entries."""
477+
"""Run an environment probe; on failure, hash the exception's
478+
CLASS NAME (not its message) under a ``*_probe_failed`` label.
479+
480+
Using only the class name keeps the digest stable across repeated
481+
calls within one process (e.g. NVVM's loader reports different
482+
messages on first vs. cached-failure attempts) AND across processes
483+
that hit the same failure mode. The ``_probe_failed`` label differs
484+
from the success labels (``driver``/``nvrtc``/...), so a broken env
485+
never collides with a working one -- the cache "fails closed"
486+
between broken and working environments while staying persistent
487+
within either.
488+
"""
441489
try:
442490
return fn()
443491
except Exception as exc:
444-
_update(f"{label}_probe_failed", f"{type(exc).__name__}:{exc}".encode())
492+
_update(f"{label}_probe_failed", type(exc).__name__.encode())
445493
return None
446494

447495
_update("schema", str(_KEY_SCHEMA_VERSION).encode("ascii"))
@@ -463,8 +511,9 @@ def _probe(label: str, fn):
463511
_update("linker_backend", lb_name.encode("ascii"))
464512
_update("linker_version", lb_version.encode("ascii"))
465513
else:
466-
# NVVM lacks a direct version API; proxy via driver + cuda-core above.
467-
_update("nvvm", b"proxied-by-driver-and-cuda-core-version")
514+
nvvm_fp = _probe("nvvm", _nvvm_fingerprint)
515+
if nvvm_fp is not None:
516+
_update("nvvm", nvvm_fp.encode("ascii"))
468517
_update("code_type", code_type.encode("ascii"))
469518
_update("target_type", target_type.encode("ascii"))
470519
_update("code", code_bytes)
@@ -796,6 +845,34 @@ def _enforce_size_cap(self) -> None:
796845
_SCHEMA_FILE = "SCHEMA_VERSION"
797846

798847

848+
_SHARING_VIOLATION_WINERRORS = (32, 33) # ERROR_SHARING_VIOLATION, ERROR_LOCK_VIOLATION
849+
_REPLACE_RETRY_DELAYS = (0.0, 0.005, 0.010, 0.020, 0.050, 0.100) # ~185ms budget
850+
851+
852+
def _replace_with_sharing_retry(tmp_path: Path, target: Path) -> bool:
853+
"""Atomic rename with Windows-specific retry on sharing/lock violations.
854+
855+
Returns True on success. Returns False only after the retry budget is
856+
exhausted on Windows with a genuine sharing violation -- the caller then
857+
treats the cache write as dropped. Any other ``PermissionError`` (ACLs,
858+
read-only dir, unexpected winerror, or any POSIX failure) propagates.
859+
"""
860+
for i, delay in enumerate(_REPLACE_RETRY_DELAYS):
861+
if delay:
862+
time.sleep(delay)
863+
try:
864+
os.replace(tmp_path, target)
865+
return True
866+
except PermissionError as exc:
867+
if not _IS_WINDOWS or getattr(exc, "winerror", None) not in _SHARING_VIOLATION_WINERRORS:
868+
raise
869+
# Windows sharing violation; loop and try again unless this was the
870+
# last attempt, in which case fall through and return False.
871+
if i == len(_REPLACE_RETRY_DELAYS) - 1:
872+
return False
873+
return False
874+
875+
799876
def _prune_if_stat_unchanged(path: Path, st_before: os.stat_result) -> None:
800877
"""Unlink ``path`` iff its stat still matches ``st_before``.
801878
@@ -938,24 +1015,15 @@ def __setitem__(self, key: object, value: object) -> None:
9381015
fh.write(record)
9391016
fh.flush()
9401017
os.fsync(fh.fileno())
941-
# Narrow PermissionError suppression to os.replace only. Earlier
942-
# failures (mkdir / mkstemp / fdopen / write / fsync) indicate a
943-
# real configuration problem and must propagate.
944-
try:
945-
os.replace(tmp_path, target)
946-
except PermissionError as exc:
1018+
# Retry os.replace under Windows sharing/lock violations; only
1019+
# give up (and drop the cache write) after a bounded backoff, so
1020+
# transient contention is not turned into a silent miss.
1021+
# Non-sharing PermissionErrors and all POSIX PermissionErrors
1022+
# propagate immediately (real config problem).
1023+
if not _replace_with_sharing_retry(tmp_path, target):
9471024
with contextlib.suppress(FileNotFoundError):
9481025
tmp_path.unlink()
949-
# Windows raises PermissionError from os.replace specifically
950-
# when the target is held open by another process (winerror
951-
# 32 = ERROR_SHARING_VIOLATION, 33 = ERROR_LOCK_VIOLATION);
952-
# swallow those as a cache miss. Any other winerror (ACL
953-
# issues, read-only dir, etc.) is a real config problem and
954-
# must propagate. POSIX has no such sharing-violation case
955-
# and always propagates.
956-
if _IS_WINDOWS and getattr(exc, "winerror", None) in (32, 33):
957-
return
958-
raise
1026+
return
9591027
except BaseException:
9601028
with contextlib.suppress(FileNotFoundError):
9611029
tmp_path.unlink()

cuda_core/tests/test_program_cache.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,75 @@ def test_make_program_cache_key_ignores_name_expressions_for_non_nvrtc(code_type
340340
assert k_none == k_with
341341

342342

343+
def test_make_program_cache_key_nvvm_probe_changes_key(monkeypatch):
344+
"""NVVM keys must reflect the NVVM toolchain identity (IR version)
345+
so an upgraded libNVVM does not silently reuse pre-upgrade entries."""
346+
from cuda.core.utils import _program_cache
347+
348+
monkeypatch.setattr(_program_cache, "_nvvm_fingerprint", lambda: "ir=1.8.3.0")
349+
k1 = _make_key(code="abc", code_type="nvvm", target_type="ptx")
350+
monkeypatch.setattr(_program_cache, "_nvvm_fingerprint", lambda: "ir=2.0.3.0")
351+
k2 = _make_key(code="abc", code_type="nvvm", target_type="ptx")
352+
assert k1 != k2
353+
354+
355+
@pytest.mark.parametrize(
356+
"option_kw",
357+
[
358+
pytest.param({"time": True}, id="time"),
359+
pytest.param({"ptxas_options": "-v"}, id="ptxas_options"),
360+
pytest.param({"split_compile": 0}, id="split_compile"),
361+
],
362+
)
363+
def test_make_program_cache_key_ptx_rejects_driver_linker_unsupported(option_kw, monkeypatch):
364+
"""When the driver (cuLink) linker backend is selected, options that
365+
``_prepare_driver_options`` rejects must also be rejected at key time
366+
so we never cache a compilation that would fail."""
367+
from cuda.core import _linker
368+
369+
monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", lambda: True) # driver
370+
with pytest.raises(ValueError, match="driver linker"):
371+
_make_key(code=".version 7.0", code_type="ptx", options=_opts(**option_kw))
372+
373+
374+
def test_make_program_cache_key_ptx_accepts_driver_linker_unsupported_with_nvjitlink(monkeypatch):
375+
"""Under nvJitLink those same options are valid and must not be
376+
rejected at key time."""
377+
from cuda.core import _linker
378+
379+
monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", lambda: False) # nvJitLink
380+
# Should not raise.
381+
_make_key(code=".version 7.0", code_type="ptx", options=_opts(time=True))
382+
383+
384+
def test_filestream_cache_replace_retries_on_sharing_violation(tmp_path, monkeypatch):
385+
"""Under Windows sharing/lock violations, os.replace is retried with a
386+
bounded backoff; a transient violation that clears within the budget
387+
must still produce a successful cache write."""
388+
import os as _os
389+
390+
from cuda.core.utils import FileStreamProgramCache, _program_cache
391+
392+
monkeypatch.setattr(_program_cache, "_IS_WINDOWS", True)
393+
394+
real_replace = _os.replace
395+
calls = {"n": 0}
396+
397+
def _flaky_replace(src, dst):
398+
calls["n"] += 1
399+
if calls["n"] < 3:
400+
exc = PermissionError("sharing violation")
401+
exc.winerror = 32
402+
raise exc
403+
return real_replace(src, dst)
404+
405+
with FileStreamProgramCache(tmp_path / "fc") as cache:
406+
monkeypatch.setattr(_os, "replace", _flaky_replace)
407+
cache[b"k"] = _fake_object_code(b"v") # succeeds on third attempt
408+
assert calls["n"] == 3
409+
assert bytes(cache[b"k"].code) == b"v"
410+
411+
343412
@pytest.mark.parametrize(
344413
"option_kw",
345414
[

0 commit comments

Comments
 (0)