Skip to content

Commit f4214df

Browse files
committed
fix(core): tighten cache-wrapper guards and fix __contains__ atime bump
Roborev jobs NVIDIA#1756 and NVIDIA#1757 surfaced four medium findings: * (NVIDIA#1756 medium @ _program_cache.py:1390) ``FileStreamProgramCache.__contains__`` routed through ``__getitem__``, which read the full payload and called ``_touch_atime``. Membership probes thus counted as LRU reads, inverting eviction relative to genuine reads. Fixed by making ``__contains__`` a stat-only ``self._path_for_key(key).exists()`` check. * (NVIDIA#1756 medium @ _program_cache.py:440) NVRTC uses ``options.name`` as the source filename and resolves quoted ``#include "x.h"`` directives relative to its directory. The cache cannot observe edits to neighbour headers, so an ``options.name`` with a directory component must require an ``extra_digest`` -- the same treatment ``include_path``/``pre_include`` already get. Added the guard in ``make_program_cache_key`` (rejecting ``"/"`` and ``"\\"`` in ``options.name`` on the NVRTC backend when ``extra_digest`` is ``None``). * (NVIDIA#1757 medium @ _program.pyx:156) Cache hits dropped ``ObjectCode.symbol_mapping`` even when ``name_expressions`` were provided. The first call (miss) returned an ObjectCode WITH mappings; every subsequent call (hit) returned one WITHOUT -- silently breaking later ``get_kernel(name_expression)`` lookups that worked on the uncached path. Fixed by rejecting non-empty ``name_expressions`` in ``Program.compile(cache=...)`` so hit and miss behavior cannot diverge. Compiles that need ``name_expressions`` should run without ``cache=``, or look up mangled symbols by hand from the cached ``ObjectCode``. * (NVIDIA#1757 medium @ _program.pyx:142) The ``Program.compile`` docstring claimed NVRTC ``options.name`` with a directory component is rejected, but the wrapper just delegated to ``make_program_cache_key`` without that helper enforcing it. Now enforced (via the helper, per the previous bullet) and tested end-to-end via the wrapper. Tests cover the new rejections (parametrized over ``/``, ``\\``, absolute paths, parent-relative paths), the ``extra_digest`` escape hatch, the ``name_expressions`` rejection (and that an empty ``name_expressions`` is still accepted), and the ``__contains__`` atime-preservation invariant (hammer membership on a cold key, then write a third entry; the cold key must evict, proving the membership probes did not promote it).
1 parent 2607234 commit f4214df

4 files changed

Lines changed: 183 additions & 46 deletions

File tree

cuda_core/cuda/core/_program.pyx

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -105,22 +105,24 @@ cdef class Program:
105105
Object with a ``write`` method to receive compilation logs.
106106
cache : :class:`~cuda.core.utils.ProgramCacheResource`, optional
107107
If provided, the compiled binary is looked up in ``cache`` via a
108-
key derived from the program's code, options, ``target_type`` and
109-
``name_expressions``. On a hit the cached bytes are wrapped in a
110-
fresh :class:`~cuda.core.ObjectCode` (with the same ``target_type``
108+
key derived from the program's code, options, and ``target_type``.
109+
On a hit the cached bytes are wrapped in a fresh
110+
:class:`~cuda.core.ObjectCode` (with the same ``target_type``
111111
and ``ProgramOptions.name``) and returned without re-compiling;
112112
on a miss the compile output is stored as raw bytes (the cache
113-
extracts ``bytes(object_code.code)``). Note that
114-
``ObjectCode.symbol_mapping`` is not preserved across a cache
115-
round-trip -- callers using ``name_expressions`` who need
116-
``get_kernel(name_expression)`` after a hit must compile fresh
117-
or look the mangled symbol up by hand. Options that require an
118-
``extra_digest`` (``include_path``, ``pre_include``, ``pch``,
119-
``use_pch``, ``pch_dir``, NVVM ``use_libdevice=True``, or NVRTC
120-
``options.name`` with a directory component) raise ``ValueError``
121-
via :func:`~cuda.core.utils.make_program_cache_key`; for those
122-
compiles, use the manual ``make_program_cache_key(...)`` pattern
123-
directly.
113+
extracts ``bytes(object_code.code)``). Passing a non-empty
114+
``name_expressions`` together with ``cache=`` raises
115+
``ValueError``: NVRTC populates
116+
``ObjectCode.symbol_mapping`` at compile time and that mapping
117+
is not carried in the binary the cache stores, so cache hits
118+
would silently miss ``get_kernel(name_expression)`` lookups.
119+
Options that require an ``extra_digest`` (``include_path``,
120+
``pre_include``, ``pch``, ``use_pch``, ``pch_dir``, NVVM
121+
``use_libdevice=True``, or NVRTC ``options.name`` with a
122+
directory component) raise ``ValueError`` via
123+
:func:`~cuda.core.utils.make_program_cache_key`; for those
124+
compiles, use the manual ``make_program_cache_key(...)``
125+
pattern directly.
124126

125127
Returns
126128
-------
@@ -130,6 +132,26 @@ cdef class Program:
130132
if cache is None:
131133
return _program_compile_uncached(self, target_type, name_expressions, logs)
132134

135+
# ``name_expressions`` is incompatible with the cache: NVRTC
136+
# populates ``ObjectCode.symbol_mapping`` from name-expression
137+
# mangling at compile time, and that mapping isn't carried in
138+
# the binary bytes the cache stores. Without this guard the
139+
# first call (cache miss) would return an ObjectCode with
140+
# symbol_mapping populated, while every subsequent call (hit)
141+
# would return one without -- silently breaking later
142+
# ``get_kernel(name_expression)`` lookups that work on the
143+
# uncached path. Fail loud here instead.
144+
if name_expressions:
145+
raise ValueError(
146+
"Program.compile(cache=...) does not support name_expressions: "
147+
"ObjectCode.symbol_mapping is populated by NVRTC at compile "
148+
"time and is not preserved across a cache round-trip, so cache "
149+
"hits would silently break get_kernel(name_expression) lookups "
150+
"that the uncached path supports. Compile without cache= when "
151+
"name_expressions are needed, or look up mangled symbols by "
152+
"hand from the cached ObjectCode."
153+
)
154+
133155
# Deferred import to avoid a circular import between _program and
134156
# cuda.core.utils._program_cache (the cache module already imports
135157
# ProgramOptions from this module). Import from the leaf module so
@@ -150,7 +172,6 @@ cdef class Program:
150172
code_type=self._code_type,
151173
options=self._options,
152174
target_type=target_type,
153-
name_expressions=name_expressions,
154175
)
155176
hit_bytes = cache.get(key)
156177
if hit_bytes is not None:

cuda_core/cuda/core/utils/_program_cache.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -607,13 +607,15 @@ def make_program_cache_key(
607607
symbol explicitly.
608608
609609
Options that read external files (``include_path``, ``pre_include``,
610-
``pch``, ``use_pch``, ``pch_dir``; and ``use_libdevice=True`` on the
611-
NVVM path) require ``extra_digest`` -- fingerprint the bytes the
612-
compiler will pull in and pass that digest so changes to those files
613-
force a cache miss. Options that have compile-time side effects
614-
(``create_pch``, ``time``, ``fdevice_time_trace``) cannot be cached
615-
and raise ``ValueError``; compile directly, or disable the flag, for
616-
those cases.
610+
``pch``, ``use_pch``, ``pch_dir``; ``use_libdevice=True`` on the NVVM
611+
path; and on NVRTC, an ``options.name`` with a directory component,
612+
which NVRTC uses for relative-include resolution) require
613+
``extra_digest`` -- fingerprint the bytes the compiler will pull in
614+
and pass that digest so changes to those files force a cache miss.
615+
Options that have compile-time side effects (``create_pch``,
616+
``time``, ``fdevice_time_trace``) cannot be cached and raise
617+
``ValueError``; compile directly, or disable the flag, for those
618+
cases.
617619
"""
618620
# Mirror Program.compile (_program.pyx Program_init lowercases code_type
619621
# before dispatch); a caller that passes "PTX" or "C++" must get the
@@ -678,6 +680,23 @@ def make_program_cache_key(
678680
f"extra_digest; compute a digest over the header/PCH bytes the "
679681
f"compile will read and pass it as extra_digest=..."
680682
)
683+
# NVRTC uses ``options.name`` as the source filename and resolves
684+
# quoted ``#include "x.h"`` directives relative to the directory
685+
# component of that name. The directory's contents are external
686+
# to anything else the key observes, so a name with a directory
687+
# component requires the same extra_digest treatment as
688+
# ``include_path`` etc.: changes to neighbour headers must
689+
# invalidate the cache, and the cache itself can't read those
690+
# files on the caller's behalf.
691+
options_name = getattr(options, "name", None)
692+
if isinstance(options_name, str) and ("/" in options_name or "\\" in options_name):
693+
raise ValueError(
694+
f"make_program_cache_key() refuses to build a key for options.name="
695+
f"{options_name!r} (NVRTC source-filename with a directory "
696+
f"component) without an extra_digest; NVRTC resolves quoted "
697+
f"#include directives relative to that directory, so a digest "
698+
f"covering the headers it may pull in must be supplied."
699+
)
681700

682701
# PTX compiles go through Linker. When the driver (cuLink) backend is
683702
# selected (nvJitLink unavailable), Program.compile rejects a subset of
@@ -1388,14 +1407,11 @@ def _path_for_key(self, key: object) -> Path:
13881407
# -- mapping API ---------------------------------------------------------
13891408

13901409
def __contains__(self, key: object) -> bool:
1391-
# Route through __getitem__ so corrupt records / schema mismatches /
1392-
# stored-key mismatches are treated as absent (and pruned), matching
1393-
# the semantics of ``cache[key]``.
1394-
try:
1395-
self[key]
1396-
except KeyError:
1397-
return False
1398-
return True
1410+
# Membership is a stat-only check: it must not read the payload (a
1411+
# full file read just to answer ``key in cache`` is wasteful) and
1412+
# must not bump atime (otherwise probing keeps cold entries hot
1413+
# and skews LRU eviction).
1414+
return self._path_for_key(key).exists()
13991415

14001416
def __getitem__(self, key: object) -> bytes:
14011417
path = self._path_for_key(key)

cuda_core/tests/test_program_cache.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,48 @@ def test_make_program_cache_key_rejects_external_content_without_extra_digest(op
945945
_make_key(options=_opts(**option_kw))
946946

947947

948+
@pytest.mark.parametrize(
949+
"name",
950+
[
951+
pytest.param("kernels/foo.cu", id="forward_slash"),
952+
pytest.param("kernels\\foo.cu", id="backslash"),
953+
pytest.param("/abs/foo.cu", id="absolute_unix"),
954+
pytest.param("../parent/foo.cu", id="parent_relative"),
955+
],
956+
)
957+
def test_make_program_cache_key_rejects_nvrtc_name_with_dir_component(name):
958+
"""NVRTC uses ``options.name`` as the source filename and resolves
959+
quoted ``#include "x.h"`` directives relative to its directory. Without
960+
an ``extra_digest`` the cache cannot observe edits to those neighbour
961+
headers, so a stale cached binary could be served. Reject the input so
962+
callers either pass an extra_digest or strip the directory component."""
963+
with pytest.raises(ValueError, match="directory component"):
964+
_make_key(options=_opts(name=name))
965+
966+
967+
@pytest.mark.parametrize(
968+
"name",
969+
[
970+
pytest.param("foo.cu", id="bare_filename"),
971+
pytest.param("default_program", id="default"),
972+
pytest.param("", id="empty"),
973+
],
974+
)
975+
def test_make_program_cache_key_accepts_nvrtc_name_without_dir_component(name):
976+
"""Names without a directory component are fine: NVRTC's relative-include
977+
resolution doesn't reach outside the in-memory program."""
978+
_make_key(options=_opts(name=name)) # Should not raise.
979+
980+
981+
def test_make_program_cache_key_accepts_nvrtc_name_with_dir_when_extra_digest_supplied():
982+
"""``extra_digest`` is the escape hatch: the caller has fingerprinted
983+
whatever the directory contributes, so the guard stands down."""
984+
_make_key(
985+
options=_opts(name="kernels/foo.cu"),
986+
extra_digest=b"caller-fingerprint",
987+
)
988+
989+
948990
@pytest.mark.parametrize(
949991
"option_kw",
950992
[
@@ -1239,6 +1281,31 @@ def test_filestream_cache_persists_across_reopen(tmp_path):
12391281
assert cache[b"k"] == b"persisted"
12401282

12411283

1284+
def test_filestream_cache_contains_does_not_read_or_promote_lru(tmp_path):
1285+
"""``__contains__`` is a stat-only check: it must not read the payload
1286+
(a full file read just to answer membership is wasteful) and must not
1287+
bump atime (otherwise probing keeps cold entries hot and inverts LRU
1288+
eviction relative to genuine reads)."""
1289+
from cuda.core.utils import FileStreamProgramCache
1290+
1291+
cap = 250
1292+
with FileStreamProgramCache(tmp_path / "fc", max_size_bytes=cap) as cache:
1293+
cache[b"a"] = b"a" * 100
1294+
cache[b"b"] = b"b" * 100
1295+
1296+
# Hammer membership for 'a' (cold) -- must NOT promote it.
1297+
for _ in range(50):
1298+
assert b"a" in cache
1299+
1300+
# The next size-cap-triggering write should evict 'a' (the genuinely
1301+
# oldest by atime). If membership had bumped atime, 'b' would evict
1302+
# instead.
1303+
cache[b"c"] = b"c" * 100
1304+
assert b"a" not in cache
1305+
assert b"b" in cache
1306+
assert b"c" in cache
1307+
1308+
12421309
def test_filestream_cache_permission_error_propagates_on_posix(tmp_path, monkeypatch):
12431310
"""On non-Windows, PermissionError from os.replace is a real config error
12441311
and must not be silently swallowed."""

cuda_core/tests/test_program_compile_cache.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -116,20 +116,38 @@ def _explode(_program, *_args, **_kwargs):
116116
assert cache.set_calls == []
117117

118118

119-
def test_name_expressions_affects_cache_key(monkeypatch):
120-
"""Different ``name_expressions`` must produce different cache keys."""
121-
sentinel = _make_sentinel_object_code()
119+
def test_cache_rejects_name_expressions():
120+
"""``name_expressions`` is incompatible with ``cache=``: the cache stores
121+
raw binary bytes, but ``ObjectCode.symbol_mapping`` (populated by
122+
NVRTC name-expression mangling) is not preserved across a cache
123+
round-trip. Without an explicit rejection, the first call (miss)
124+
would return an ObjectCode with mappings while every subsequent call
125+
(hit) would return one without -- silently breaking later
126+
``get_kernel(name_expression)`` lookups."""
122127
program = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80"))
123-
monkeypatch.setattr(
124-
_program_module, "_program_compile_uncached", lambda _self, *_args, **_kwargs: sentinel
125-
)
126128
cache = _RecordingCache()
127129

128-
program.compile("cubin", name_expressions=("foo",), cache=cache)
129-
program.compile("cubin", name_expressions=("foo", "bar"), cache=cache)
130+
with pytest.raises(ValueError, match="name_expressions"):
131+
program.compile("cubin", name_expressions=("foo",), cache=cache)
132+
133+
# Wrapper rejects BEFORE touching the cache.
134+
assert cache.get_calls == []
135+
assert cache.set_calls == []
136+
130137

131-
assert len(cache.get_calls) == 2
132-
assert cache.get_calls[0] != cache.get_calls[1]
138+
def test_cache_accepts_empty_name_expressions(monkeypatch):
139+
"""An empty ``name_expressions`` (the default) must NOT be rejected --
140+
it's the no-op case, fully supported by the cache."""
141+
sentinel = _make_sentinel_object_code()
142+
program = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80"))
143+
monkeypatch.setattr(_program_module, "_program_compile_uncached", lambda _self, *_args, **_kwargs: sentinel)
144+
cache = _RecordingCache()
145+
146+
# Default empty tuple, explicit empty tuple, and explicit empty list
147+
# all go through.
148+
program.compile("cubin", cache=cache)
149+
program.compile("cubin", name_expressions=(), cache=cache)
150+
program.compile("cubin", name_expressions=[], cache=cache)
133151

134152

135153
def test_cache_raises_for_extra_digest_required_option():
@@ -148,6 +166,25 @@ def test_cache_raises_for_extra_digest_required_option():
148166
assert cache.set_calls == []
149167

150168

169+
def test_cache_raises_for_nvrtc_name_with_dir_component():
170+
"""NVRTC ``options.name`` with a directory component must propagate a
171+
ValueError: NVRTC resolves quoted ``#include`` directives relative to
172+
that directory, so neighbour-header changes wouldn't invalidate the
173+
cache without an extra_digest."""
174+
program = Program(
175+
_KERNEL,
176+
"c++",
177+
ProgramOptions(arch="sm_80", name="kernels/foo.cu"),
178+
)
179+
cache = _RecordingCache()
180+
181+
with pytest.raises(ValueError, match="directory component"):
182+
program.compile("cubin", cache=cache)
183+
184+
assert cache.get_calls == []
185+
assert cache.set_calls == []
186+
187+
151188
def test_cache_raises_for_side_effect_option(tmp_path):
152189
"""Options with compile-time side effects can't be cached."""
153190
program = Program(
@@ -200,9 +237,7 @@ def test_cache_write_exception_propagates(monkeypatch):
200237
"""Exceptions from cache.__setitem__ propagate after compile runs."""
201238
sentinel = _make_sentinel_object_code()
202239
program = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80"))
203-
monkeypatch.setattr(
204-
_program_module, "_program_compile_uncached", lambda _self, *_args, **_kwargs: sentinel
205-
)
240+
monkeypatch.setattr(_program_module, "_program_compile_uncached", lambda _self, *_args, **_kwargs: sentinel)
206241
cache = _RecordingCache()
207242
cache.set_side_effect = RuntimeError("disk full")
208243

@@ -217,9 +252,7 @@ def test_no_cache_kwarg_does_not_derive_key(monkeypatch):
217252
"""Without cache=, no cache-module functions run; compile goes straight through."""
218253
sentinel = _make_sentinel_object_code()
219254
program = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80"))
220-
monkeypatch.setattr(
221-
_program_module, "_program_compile_uncached", lambda _self, *_args, **_kwargs: sentinel
222-
)
255+
monkeypatch.setattr(_program_module, "_program_compile_uncached", lambda _self, *_args, **_kwargs: sentinel)
223256

224257
# If the implementation accidentally derived a key, it would call
225258
# make_program_cache_key. Replace it with a raising stub to catch that.

0 commit comments

Comments
 (0)