Skip to content

Commit e0b46cb

Browse files
committed
fix(core): two robustness fixes from roborev review
Roborev jobs #1770 and #1771:

* MEDIUM @ _program_cache.py:1235 -- ``__setitem__`` calls ``tempfile.mkstemp(dir=self._tmp)`` without ensuring ``tmp/`` still exists. If something deletes it after ``__init__`` (operators clearing the cache by hand, another process's wipe), every subsequent write crashes with FileNotFoundError even though we could trivially recreate it. Fixed: ``self._tmp.mkdir(parents=True, exist_ok=True)`` before mkstemp -- same defensive recreation we already do for ``entries/<2-char>/`` shard subdirs.

* LOW @ _program.pyx:177 -- the uncached NVRTC path warns when the active driver can't load freshly-generated PTX. Cache hits skipped this warning, so ``Program.compile("ptx", cache=cache)`` could silently hand back PTX that won't actually load on the active driver. Loadability is a property of the driver, not of how the bytes were produced -- mirror the warning before returning a cached hit so cached and uncached calls behave the same on loadability-vs-driver mismatches.

Tests added:

* ``test_filestream_cache_recreates_tmp_dir_if_missing`` -- ``shutil.rmtree(root/"tmp")`` between two writes; the second write must still succeed and recreate the dir.

* ``test_cache_hit_emits_ptx_loadability_warning_when_driver_too_old`` -- monkeypatch ``_can_load_generated_ptx`` to False on a preseeded cache; cache hit must emit the RuntimeWarning.

* ``test_cache_hit_no_ptx_warning_when_driver_supports_it`` -- inverse: warning must NOT fire when the driver can load it.

Roborev #1770 also raised a MEDIUM about ``extra_sources`` being hashed in caller-provided order. NVVM module linking is order-dependent in the general case (overlapping symbols, weak definitions, etc.), so canonicalizing the order would be unsound without proving the input subset NVVM actually treats as order-insensitive. Keeping the order-sensitive hash is the safer default; documented as part of ``_NvvmBackend.hash_extra_payload``'s comment in the strategy class.
1 parent e08e89e commit e0b46cb

4 files changed

Lines changed: 91 additions & 0 deletions

File tree

cuda_core/cuda/core/_program.pyx

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,23 @@ cdef class Program:
175175
)
176176
hit_bytes = cache.get(key)
177177
if hit_bytes is not None:
178+
# The uncached NVRTC path warns when the active driver can't
179+
# load freshly-generated PTX; that loadability is a property
180+
# of the driver, not of how the bytes were produced, so the
181+
# warning applies equally to cached PTX. Mirror it here so a
182+
# cache hit doesn't silently hide an incompatibility that the
183+
# uncached call would have surfaced.
184+
if (
185+
self._backend == "NVRTC"
186+
and target_type == "ptx"
187+
and not _can_load_generated_ptx()
188+
):
189+
warn(
190+
"The CUDA driver version is older than the backend version. "
191+
"The generated ptx will not be loadable by the current driver.",
192+
stacklevel=2,
193+
category=RuntimeWarning,
194+
)
178195
return ObjectCode._init(hit_bytes, target_type, name=self._options.name)
179196
compiled = _program_compile_uncached(self, target_type, name_expressions, logs)
180197
cache[key] = compiled

cuda_core/cuda/core/utils/_program_cache.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1513,6 +1513,12 @@ def __setitem__(self, key: object, value: bytes | bytearray | memoryview | Objec
15131513
data = _extract_bytes(value)
15141514
target = self._path_for_key(key)
15151515
target.parent.mkdir(parents=True, exist_ok=True)
1516+
# Re-create ``tmp/`` if something deleted it after ``__init__``
1517+
# (operators clearing the cache by hand, ``rm -rf cache_dir/tmp``,
1518+
# another process's overzealous wipe). Cheap and idempotent;
1519+
# without it, every subsequent write would crash with
1520+
# FileNotFoundError even though we could trivially recover.
1521+
self._tmp.mkdir(parents=True, exist_ok=True)
15161522

15171523
fd, tmp_name = tempfile.mkstemp(prefix="entry-", dir=self._tmp)
15181524
tmp_path = Path(tmp_name)

cuda_core/tests/test_program_cache.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1576,6 +1576,27 @@ def test_filestream_cache_uses_default_dir_when_path_omitted(tmp_path, monkeypat
15761576
assert (tmp_path / "default-fc" / "entries").is_dir()
15771577

15781578

1579+
def test_filestream_cache_recreates_tmp_dir_if_missing(tmp_path):
1580+
"""If ``tmp/`` is deleted out from under a live cache (operator clearing
1581+
by hand, another process's wipe, etc.), the next write must recreate
1582+
it rather than crash with ``FileNotFoundError``. The cache already
1583+
recreates the per-shard ``entries/<2-char>/`` directory; ``tmp/``
1584+
deserves the same treatment since ``mkstemp`` writes there."""
1585+
import shutil
1586+
1587+
from cuda.core.utils import FileStreamProgramCache
1588+
1589+
root = tmp_path / "fc"
1590+
with FileStreamProgramCache(root) as cache:
1591+
cache[b"k"] = b"first"
1592+
# Nuke tmp/ between writes; the second write must still succeed.
1593+
shutil.rmtree(root / "tmp")
1594+
cache[b"k2"] = b"second"
1595+
assert cache[b"k"] == b"first"
1596+
assert cache[b"k2"] == b"second"
1597+
assert (root / "tmp").is_dir()
1598+
1599+
15791600
def test_filestream_cache_sweeps_stale_tmp_files_on_open(tmp_path):
15801601
"""A crashed writer can leave files in ``tmp/``; the next ``open`` must
15811602
sweep ones older than the staleness threshold so disk usage doesn't

cuda_core/tests/test_program_compile_cache.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,53 @@ def _explode(_program, *_args, **_kwargs):
116116
assert cache.set_calls == []
117117

118118

119+
def test_cache_hit_emits_ptx_loadability_warning_when_driver_too_old(monkeypatch):
120+
"""The uncached NVRTC path warns when the active driver can't load
121+
freshly-generated PTX. That loadability is a property of the driver,
122+
not of how the bytes were produced, so a cache hit on the same
123+
(NVRTC + ptx target_type) compile must emit the same warning. Without
124+
this mirror, ``compile("ptx", cache=cache)`` would silently hand back
125+
PTX that won't actually load on the active driver."""
126+
options = ProgramOptions(arch="sm_80", name="warn_program")
127+
program = Program(_KERNEL, "c++", options)
128+
key = make_program_cache_key(code=_KERNEL, code_type="c++", options=options, target_type="ptx")
129+
130+
monkeypatch.setattr(_program_module, "_can_load_generated_ptx", lambda: False)
131+
132+
def _explode(_program, *_args, **_kwargs):
133+
raise AssertionError("_program_compile_uncached must not be called on cache hit")
134+
135+
monkeypatch.setattr(_program_module, "_program_compile_uncached", _explode)
136+
cache = _RecordingCache(preseed={key: _SENTINEL_BYTES})
137+
138+
with pytest.warns(RuntimeWarning, match="driver"):
139+
result = program.compile("ptx", cache=cache)
140+
141+
assert bytes(result.code) == _SENTINEL_BYTES
142+
143+
144+
def test_cache_hit_no_ptx_warning_when_driver_supports_it(monkeypatch):
145+
"""Inverse of the warning test: when the driver can load the PTX, no
146+
warning is emitted on cache hit (the wrapper must not be over-eager)."""
147+
import warnings
148+
149+
options = ProgramOptions(arch="sm_80", name="quiet_program")
150+
program = Program(_KERNEL, "c++", options)
151+
key = make_program_cache_key(code=_KERNEL, code_type="c++", options=options, target_type="ptx")
152+
153+
monkeypatch.setattr(_program_module, "_can_load_generated_ptx", lambda: True)
154+
155+
def _explode(_program, *_args, **_kwargs):
156+
raise AssertionError("_program_compile_uncached must not be called on cache hit")
157+
158+
monkeypatch.setattr(_program_module, "_program_compile_uncached", _explode)
159+
cache = _RecordingCache(preseed={key: _SENTINEL_BYTES})
160+
161+
with warnings.catch_warnings():
162+
warnings.simplefilter("error") # any warning becomes an exception
163+
program.compile("ptx", cache=cache)
164+
165+
119166
def test_cache_rejects_name_expressions():
120167
"""``name_expressions`` is incompatible with ``cache=``: the cache stores
121168
raw binary bytes, but ``ObjectCode.symbol_mapping`` (populated by

0 commit comments

Comments
 (0)