@@ -732,6 +732,20 @@ def test_make_program_cache_key_accepts_side_effect_options_for_ptx(option_kw):
732732 _make_key (code = ".version 7.0" , code_type = "ptx" , options = _opts (** option_kw )) # no raise
733733
734734
735+ def test_make_program_cache_key_driver_version_does_not_perturb_ptx_under_nvjitlink (monkeypatch ):
736+ """nvJitLink does NOT route PTX compilation through cuLink, so a
737+ changing driver version must not invalidate PTX cache keys when
738+ nvJitLink is the active linker backend."""
739+ from cuda .core .utils import _program_cache
740+
741+ monkeypatch .setattr (_program_cache , "_linker_backend_and_version" , lambda : ("nvJitLink" , "12030" ))
742+ monkeypatch .setattr (_program_cache , "_driver_version" , lambda : 13200 )
743+ k_a = _make_key (code = ".version 7.0" , code_type = "ptx" )
744+ monkeypatch .setattr (_program_cache , "_driver_version" , lambda : 13300 )
745+ k_b = _make_key (code = ".version 7.0" , code_type = "ptx" )
746+ assert k_a == k_b
747+
748+
735749@pytest .mark .parametrize (
736750 "code_type, code, target_type" ,
737751 [
@@ -761,11 +775,6 @@ def _broken():
761775@pytest .mark .parametrize (
762776 "probe_name, code_type, code" ,
763777 [
764- # Driver version is only consumed on the linker (PTX) path now; NVRTC
765- # and NVVM compile bytes don't depend on the driver, so a failing
766- # driver probe should NOT perturb their cache keys (separately
767- # verified by test_..._driver_probe_failure_does_not_perturb_non_linker).
768- pytest .param ("_driver_version" , "ptx" , ".ptx" , id = "driver" ),
769778 pytest .param ("_nvrtc_version" , "c++" , "a" , id = "nvrtc" ),
770779 pytest .param ("_cuda_core_version" , "c++" , "a" , id = "cuda_core" ),
771780 pytest .param ("_linker_backend_and_version" , "ptx" , ".ptx" , id = "linker" ),
@@ -775,7 +784,9 @@ def test_make_program_cache_key_fails_closed_on_probe_failure(probe_name, code_t
775784 """A failed probe (a) must produce a key that differs from a working
776785 probe (so environments never silently share cache entries), and (b)
777786 must produce a *stable* key across calls -- otherwise the persistent
778- cache could not be reused in broken environments."""
787+ cache could not be reused in broken environments. ``_driver_version``
788+ is exercised separately because it's only invoked transitively from
789+ ``_linker_backend_and_version`` on the cuLink driver path."""
779790 from cuda .core .utils import _program_cache
780791
781792 def _broken ():
@@ -789,6 +800,25 @@ def _broken():
789800 assert k_broken1 == k_broken2 # stable: same failure -> same key
790801
791802
803+ def test_make_program_cache_key_driver_probe_failure_taints_ptx_under_cuLink (monkeypatch ):
804+ """When the driver linker is active, _linker_backend_and_version
805+ invokes _driver_version internally; a failing driver probe propagates
806+ through the linker probe and must perturb the PTX key."""
807+ from cuda .core .utils import _program_cache
808+
809+ def _broken ():
810+ raise RuntimeError ("driver probe failed" )
811+
812+ # Force the cuLink driver path so _linker_backend_and_version reads driver_version.
813+ from cuda .core import _linker
814+
815+ monkeypatch .setattr (_linker , "_decide_nvjitlink_or_driver" , lambda : True )
816+ k_ok = _make_key (code = ".ptx" , code_type = "ptx" )
817+ monkeypatch .setattr (_program_cache , "_driver_version" , _broken )
818+ k_broken = _make_key (code = ".ptx" , code_type = "ptx" )
819+ assert k_ok != k_broken
820+
821+
792822# ---------------------------------------------------------------------------
793823# SQLiteProgramCache -- basic CRUD
794824# ---------------------------------------------------------------------------
@@ -1085,7 +1115,7 @@ def reader(thread_id: int):
10851115 for t in threads :
10861116 t .join (timeout = 30 )
10871117 assert not any (t .is_alive () for t in threads )
1088- assert errors == []
1118+ assert not errors
10891119
10901120
10911121@needs_sqlite3
@@ -1436,7 +1466,7 @@ def test_filestream_cache_clear_preserves_young_tmp_files(tmp_path):
14361466 # Filenames are hash-like (no extension), so use a file filter rather
14371467 # than a "*.*" glob.
14381468 remaining_entries = [p for p in (root / "entries" ).rglob ("*" ) if p .is_file ()]
1439- assert remaining_entries == []
1469+ assert not remaining_entries
14401470 assert young_tmp .exists ()
14411471 assert not ancient_tmp .exists ()
14421472
0 commit comments