Skip to content

Commit 113d831

Browse files
committed
fixup! feat(core.utils): scope driver-version salt to linker; document cache semantics
1 parent dbc6348 commit 113d831

File tree

2 files changed

+62
-4
lines changed

2 files changed

+62
-4
lines changed

cuda_core/cuda/core/utils/_program_cache.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -634,15 +634,19 @@ def _probe(label: str, fn):
634634
cuda_core_ver = _probe("cuda_core", _cuda_core_version)
635635
if cuda_core_ver is not None:
636636
_update("cuda_core", cuda_core_ver.encode("ascii"))
637-
driver_ver = _probe("driver", _driver_version)
638-
if driver_ver is not None:
639-
_update("driver", str(driver_ver).encode("ascii"))
640637
if backend == "nvrtc":
641638
nvrtc_ver = _probe("nvrtc", _nvrtc_version)
642639
if nvrtc_ver is not None:
643640
nv_major, nv_minor = nvrtc_ver
644641
_update("nvrtc", f"{nv_major}.{nv_minor}".encode("ascii"))
645642
elif backend == "linker":
643+
# Only the linker (PTX inputs) actually invokes the driver for
644+
# codegen via cuLink, so the driver version belongs only here.
645+
# Keying NVRTC/NVVM on the driver would invalidate caches across
646+
# benign driver upgrades that don't affect compiled bytes.
647+
driver_ver = _probe("driver", _driver_version)
648+
if driver_ver is not None:
649+
_update("driver", str(driver_ver).encode("ascii"))
646650
linker = _probe("linker", _linker_backend_and_version)
647651
if linker is not None:
648652
lb_name, lb_version = linker
@@ -1046,6 +1050,30 @@ class FileStreamProgramCache(ProgramCacheResource):
10461050
partially-written entry. There is no cross-process LRU tracking; size
10471051
enforcement is best-effort by file mtime.
10481052
1053+
.. note:: **Best-effort writes.**
1054+
1055+
On Windows, ``os.replace`` raises ``PermissionError`` (winerror
1056+
32 / 33) when another process holds the target file open. This
1057+
backend retries with bounded backoff (~185 ms) and, if still
1058+
failing, drops the cache write silently and returns normally, as if
1059+
the write had succeeded. The next call will see no entry and recompile.
1060+
On POSIX, and for any other Windows error code, ``PermissionError`` propagates.
1061+
1062+
.. note:: **Atomic for readers, not crash-durable.**
1063+
1064+
Each entry's temp file is ``fsync``-ed before ``os.replace``, but
1065+
the containing directory is **not** ``fsync``-ed. A host crash
1066+
between write and the next directory commit may lose recently
1067+
added entries; surviving entries remain consistent.
1068+
1069+
.. note:: **Cross-version sharing.**
1070+
1071+
``_FILESTREAM_SCHEMA_VERSION`` guards on-disk format changes: a
1072+
cache written by an incompatible version is wiped on open. Within
1073+
a single schema version, the cache is safe to share across
1074+
``cuda.core`` patch releases because every entry's key already
1075+
encodes ``cuda_core`` plus its own backend's fingerprints (NVRTC, NVVM, or driver and linker).
1076+
10491077
Parameters
10501078
----------
10511079
path:

cuda_core/tests/test_program_cache.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,10 +724,40 @@ def test_make_program_cache_key_accepts_side_effect_options_for_ptx(option_kw):
724724
_make_key(code=".version 7.0", code_type="ptx", options=_opts(**option_kw)) # no raise
725725

726726

727+
@pytest.mark.parametrize(
728+
"code_type, code, target_type",
729+
[
730+
pytest.param("c++", "a", "cubin", id="nvrtc"),
731+
pytest.param("nvvm", "abc", "ptx", id="nvvm"),
732+
],
733+
)
734+
def test_make_program_cache_key_driver_probe_failure_does_not_perturb_non_linker(
735+
code_type, code, target_type, monkeypatch
736+
):
737+
"""The driver version is only consumed on the linker (PTX) path because
738+
cuLink runs through the driver. NVRTC and NVVM produce identical bytes
739+
regardless of the driver version, so a failed driver probe must NOT
740+
perturb their cache keys -- otherwise driver upgrades would invalidate
741+
perfectly good caches."""
742+
from cuda.core.utils import _program_cache
743+
744+
def _broken():
745+
raise RuntimeError("driver probe failed")
746+
747+
k_ok = _make_key(code=code, code_type=code_type, target_type=target_type)
748+
monkeypatch.setattr(_program_cache, "_driver_version", _broken)
749+
k_broken = _make_key(code=code, code_type=code_type, target_type=target_type)
750+
assert k_ok == k_broken
751+
752+
727753
@pytest.mark.parametrize(
728754
"probe_name, code_type, code",
729755
[
730-
pytest.param("_driver_version", "c++", "a", id="driver"),
756+
# Driver version is only consumed on the linker (PTX) path now; NVRTC
757+
# and NVVM compile bytes don't depend on the driver, so a failing
758+
# driver probe should NOT perturb their cache keys (separately
759+
# verified by test_..._driver_probe_failure_does_not_perturb_non_linker).
760+
pytest.param("_driver_version", "ptx", ".ptx", id="driver"),
731761
pytest.param("_nvrtc_version", "c++", "a", id="nvrtc"),
732762
pytest.param("_cuda_core_version", "c++", "a", id="cuda_core"),
733763
pytest.param("_linker_backend_and_version", "ptx", ".ptx", id="linker"),

0 commit comments

Comments
 (0)