Skip to content

Commit f316c6a

Browse files
committed
fixup! feat(core.utils): drop direct driver-version probe in linker branch (nvJitLink)
1 parent 3467c28 commit f316c6a

File tree

2 files changed

+44
-15
lines changed

2 files changed

+44
-15
lines changed

cuda_core/cuda/core/utils/_program_cache.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -645,13 +645,12 @@ def _probe(label: str, fn):
645645
nv_major, nv_minor = nvrtc_ver
646646
_update("nvrtc", f"{nv_major}.{nv_minor}".encode("ascii"))
647647
elif backend == "linker":
648-
# Only the linker (PTX inputs) actually invokes the driver for
649-
# codegen via cuLink, so the driver version belongs only here.
650-
# Keying NVRTC/NVVM on the driver would invalidate caches across
651-
# benign driver upgrades that don't affect compiled bytes.
652-
driver_ver = _probe("driver", _driver_version)
653-
if driver_ver is not None:
654-
_update("driver", str(driver_ver).encode("ascii"))
648+
# Only cuLink (driver-backed linker) goes through the CUDA driver
649+
# for codegen. nvJitLink is a separate library, so a driver upgrade
650+
# under it does not change the compiled bytes -- skip the driver
651+
# version there. ``_linker_backend_and_version`` already returns the
652+
# driver version when the driver backend is active, so the bytes
653+
# are still in the digest via ``linker_version``.
655654
linker = _probe("linker", _linker_backend_and_version)
656655
if linker is not None:
657656
lb_name, lb_version = linker

cuda_core/tests/test_program_cache.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,20 @@ def test_make_program_cache_key_accepts_side_effect_options_for_ptx(option_kw):
732732
_make_key(code=".version 7.0", code_type="ptx", options=_opts(**option_kw)) # no raise
733733

734734

735+
def test_make_program_cache_key_driver_version_does_not_perturb_ptx_under_nvjitlink(monkeypatch):
736+
"""nvJitLink does NOT route PTX compilation through cuLink, so a
737+
changing driver version must not invalidate PTX cache keys when
738+
nvJitLink is the active linker backend."""
739+
from cuda.core.utils import _program_cache
740+
741+
monkeypatch.setattr(_program_cache, "_linker_backend_and_version", lambda: ("nvJitLink", "12030"))
742+
monkeypatch.setattr(_program_cache, "_driver_version", lambda: 13200)
743+
k_a = _make_key(code=".version 7.0", code_type="ptx")
744+
monkeypatch.setattr(_program_cache, "_driver_version", lambda: 13300)
745+
k_b = _make_key(code=".version 7.0", code_type="ptx")
746+
assert k_a == k_b
747+
748+
735749
@pytest.mark.parametrize(
736750
"code_type, code, target_type",
737751
[
@@ -761,11 +775,6 @@ def _broken():
761775
@pytest.mark.parametrize(
762776
"probe_name, code_type, code",
763777
[
764-
# Driver version is only consumed on the linker (PTX) path now; NVRTC
765-
# and NVVM compile bytes don't depend on the driver, so a failing
766-
# driver probe should NOT perturb their cache keys (separately
767-
# verified by test_..._driver_probe_failure_does_not_perturb_non_linker).
768-
pytest.param("_driver_version", "ptx", ".ptx", id="driver"),
769778
pytest.param("_nvrtc_version", "c++", "a", id="nvrtc"),
770779
pytest.param("_cuda_core_version", "c++", "a", id="cuda_core"),
771780
pytest.param("_linker_backend_and_version", "ptx", ".ptx", id="linker"),
@@ -775,7 +784,9 @@ def test_make_program_cache_key_fails_closed_on_probe_failure(probe_name, code_t
775784
"""A failed probe (a) must produce a key that differs from a working
776785
probe (so environments never silently share cache entries), and (b)
777786
must produce a *stable* key across calls -- otherwise the persistent
778-
cache could not be reused in broken environments."""
787+
cache could not be reused in broken environments. ``_driver_version``
788+
is exercised separately because it's only invoked transitively from
789+
``_linker_backend_and_version`` on the cuLink driver path."""
779790
from cuda.core.utils import _program_cache
780791

781792
def _broken():
@@ -789,6 +800,25 @@ def _broken():
789800
assert k_broken1 == k_broken2 # stable: same failure -> same key
790801

791802

803+
def test_make_program_cache_key_driver_probe_failure_taints_ptx_under_cuLink(monkeypatch):
804+
"""When the driver linker is active, _linker_backend_and_version
805+
invokes _driver_version internally; a failing driver probe propagates
806+
through the linker probe and must perturb the PTX key."""
807+
from cuda.core.utils import _program_cache
808+
809+
def _broken():
810+
raise RuntimeError("driver probe failed")
811+
812+
# Force the cuLink driver path so _linker_backend_and_version reads driver_version.
813+
from cuda.core import _linker
814+
815+
monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", lambda: True)
816+
k_ok = _make_key(code=".ptx", code_type="ptx")
817+
monkeypatch.setattr(_program_cache, "_driver_version", _broken)
818+
k_broken = _make_key(code=".ptx", code_type="ptx")
819+
assert k_ok != k_broken
820+
821+
792822
# ---------------------------------------------------------------------------
793823
# SQLiteProgramCache -- basic CRUD
794824
# ---------------------------------------------------------------------------
@@ -1085,7 +1115,7 @@ def reader(thread_id: int):
10851115
for t in threads:
10861116
t.join(timeout=30)
10871117
assert not any(t.is_alive() for t in threads)
1088-
assert errors == []
1118+
assert not errors
10891119

10901120

10911121
@needs_sqlite3
@@ -1436,7 +1466,7 @@ def test_filestream_cache_clear_preserves_young_tmp_files(tmp_path):
14361466
# Filenames are hash-like (no extension), so use a file filter rather
14371467
# than a "*.*" glob.
14381468
remaining_entries = [p for p in (root / "entries").rglob("*") if p.is_file()]
1439-
assert remaining_entries == []
1469+
assert not remaining_entries
14401470
assert young_tmp.exists()
14411471
assert not ancient_tmp.exists()
14421472

0 commit comments

Comments
 (0)