Skip to content

Commit 945506e

Browse files
committed
fixup! feat(core.utils): validate name_expressions; sweep stale tmp files
1 parent 8552ca9 commit 945506e

File tree

2 files changed

+134
-1
lines changed

2 files changed

+134
-1
lines changed

cuda_core/cuda/core/utils/_program_cache.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,10 +597,15 @@ def make_program_cache_key(
597597
# (_program.pyx:759). Returning a cached ObjectCode whose mapping-key
598598
# type differs from what the caller's later ``get_kernel`` passes would
599599
# cause that lookup to silently miss -- so treat ``"foo"`` and ``b"foo"`` as distinct here.
600+
# Reject anything other than str/bytes/bytearray up front; Program.compile
601+
# would fail at compile time anyway, and persisting a key for an invalid
602+
# input is just a foot-gun.
600603
def _tag_name(n):
601604
if isinstance(n, (bytes, bytearray)):
602605
return b"b:" + bytes(n)
603-
return b"s:" + str(n).encode("utf-8")
606+
if isinstance(n, str):
607+
return b"s:" + n.encode("utf-8")
608+
raise TypeError(f"name_expressions elements must be str, bytes, or bytearray; got {type(n).__name__}")
604609

605610
names = tuple(sorted(_tag_name(n) for n in name_expressions))
606611

@@ -988,6 +993,11 @@ def _enforce_size_cap(self) -> None:
988993
_ENTRIES_SUBDIR = "entries"
989994
_TMP_SUBDIR = "tmp"
990995
_SCHEMA_FILE = "SCHEMA_VERSION"
996+
# Temp files older than this are assumed to belong to a crashed writer and
997+
# are eligible for cleanup. Chosen large enough that no healthy writer's
998+
# commit sequence (mkstemp + write + fsync + replace, all fast on working
999+
# disks) could plausibly still be in flight after this long.
1000+
_TMP_STALE_AGE_SECONDS = 3600
9911001

9921002

9931003
_SHARING_VIOLATION_WINERRORS = (32, 33) # ERROR_SHARING_VIOLATION, ERROR_LOCK_VIOLATION
@@ -1114,6 +1124,10 @@ def __init__(
11141124
with contextlib.suppress(FileNotFoundError):
11151125
entry.unlink()
11161126
self._schema_path.write_text(expected)
1127+
# Opportunistic startup sweep of orphaned temp files left by any
1128+
# crashed writers. Age-based so concurrent in-flight writes from
1129+
# other processes are preserved.
1130+
self._sweep_stale_tmp_files()
11171131

11181132
# -- key-to-path helpers -------------------------------------------------
11191133

@@ -1219,6 +1233,15 @@ def clear(self) -> None:
12191233
for path in list(self._iter_entry_paths()):
12201234
with contextlib.suppress(FileNotFoundError):
12211235
path.unlink()
1236+
# The user explicitly asked to wipe this cache, so also drop every
1237+
# temp file we can see (whether stale or in flight from this process).
1238+
# NOTE(review): a concurrent writer whose staging file vanishes will see
1239+
# its ``os.replace`` raise FileNotFoundError -- confirm writers treat a
# failed commit as best-effort; the explicit wipe takes precedence here.
1240+
if self._tmp.exists():
1241+
for tmp in list(self._tmp.iterdir()):
1242+
if tmp.is_file():
1243+
with contextlib.suppress(FileNotFoundError):
1244+
tmp.unlink()
12221245
# Remove empty subdirs (best-effort; concurrent writers may re-create).
12231246
if self._entries.exists():
12241247
for sub in sorted(self._entries.iterdir(), reverse=True):
@@ -1238,18 +1261,51 @@ def _iter_entry_paths(self) -> Iterable[Path]:
12381261
if entry.is_file():
12391262
yield entry
12401263

1264+
def _sweep_stale_tmp_files(self) -> None:
1265+
"""Remove temp files left behind by crashed writers.
1266+
1267+
Age threshold is conservative (``_TMP_STALE_AGE_SECONDS``) so an
1268+
in-flight write from another process is not interrupted. Best
1269+
effort: a missing file or a permission failure is ignored.
1270+
"""
1271+
if not self._tmp.exists():
1272+
return
1273+
cutoff = time.time() - _TMP_STALE_AGE_SECONDS
1274+
for tmp in self._tmp.iterdir():
1275+
if not tmp.is_file():
1276+
continue
1277+
try:
1278+
if tmp.stat().st_mtime < cutoff:
1279+
tmp.unlink()
1280+
except (FileNotFoundError, PermissionError):
1281+
continue
1282+
12411283
def _enforce_size_cap(self) -> None:
12421284
if self._max_size_bytes is None:
12431285
return
1286+
# Sweep stale temp files first so a long-dead writer's leftovers
1287+
# don't drag the apparent size up and force needless eviction.
1288+
self._sweep_stale_tmp_files()
12441289
entries = []
12451290
total = 0
1291+
# Count both committed entries AND surviving temp files: temp files
1292+
# occupy disk too, even if they're young. Without this the soft cap
1293+
# silently undercounts in-flight writes.
12461294
for path in self._iter_entry_paths():
12471295
try:
12481296
st = path.stat()
12491297
except FileNotFoundError:
12501298
continue
12511299
entries.append((st.st_mtime, st.st_size, path))
12521300
total += st.st_size
1301+
if self._tmp.exists():
1302+
for tmp in self._tmp.iterdir():
1303+
if not tmp.is_file():
1304+
continue
1305+
try:
1306+
total += tmp.stat().st_size
1307+
except FileNotFoundError:
1308+
continue
12531309
if total <= self._max_size_bytes:
12541310
return
12551311
entries.sort() # oldest mtime first

cuda_core/tests/test_program_cache.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,14 @@ def test_make_program_cache_key_name_expressions_order_insensitive():
232232
assert _make_key(name_expressions=("f", "g")) == _make_key(name_expressions=("g", "f"))
233233

234234

235+
@pytest.mark.parametrize("bad", [123, 1.5, object(), None])
236+
def test_make_program_cache_key_rejects_invalid_name_expressions_element(bad):
237+
"""Program.compile only forwards str/bytes name_expressions to NVRTC;
238+
persisting a key for an invalid input is just a foot-gun. Reject up front."""
239+
with pytest.raises(TypeError, match="name_expressions"):
240+
_make_key(name_expressions=("ok", bad))
241+
242+
235243
def test_make_program_cache_key_name_expressions_str_bytes_distinct():
236244
"""``Program.compile`` records the *original* Python object as the key in
237245
``ObjectCode.symbol_mapping``. Returning a cached ObjectCode whose
@@ -1376,6 +1384,75 @@ def test_filestream_cache_rejects_negative_size_cap(tmp_path):
13761384
FileStreamProgramCache(tmp_path / "fc", max_size_bytes=-1)
13771385

13781386

1387+
def test_filestream_cache_sweeps_stale_tmp_files_on_open(tmp_path):
1388+
"""A crashed writer can leave files in ``tmp/``; the next ``open`` must
1389+
sweep ones older than the staleness threshold so disk usage doesn't
1390+
grow without bound."""
1391+
import os as _os
1392+
1393+
from cuda.core.utils import FileStreamProgramCache, _program_cache
1394+
1395+
root = tmp_path / "fc"
1396+
# Create the cache directory layout, then plant two temp files: one
1397+
# young (must be preserved as it could be an in-flight write) and one
1398+
# ancient (must be swept).
1399+
with FileStreamProgramCache(root):
1400+
pass
1401+
young = root / "tmp" / "entry-young"
1402+
young.write_bytes(b"in-flight")
1403+
ancient = root / "tmp" / "entry-ancient"
1404+
ancient.write_bytes(b"crashed-writer-leftover")
1405+
ancient_mtime = time.time() - _program_cache._TMP_STALE_AGE_SECONDS - 60
1406+
_os.utime(ancient, (ancient_mtime, ancient_mtime))
1407+
1408+
with FileStreamProgramCache(root):
1409+
# Reopen triggers _sweep_stale_tmp_files.
1410+
assert young.exists(), "young temp file must not be swept"
1411+
assert not ancient.exists(), "ancient temp file should have been swept"
1412+
1413+
1414+
def test_filestream_cache_clear_drops_all_tmp_files(tmp_path):
1415+
"""clear() is an explicit user wipe, so it removes every temp file too
1416+
-- including young ones (a concurrent writer whose staging file vanishes
1417+
will see its os.replace fail; an explicit user wipe takes precedence)."""
1418+
from cuda.core.utils import FileStreamProgramCache
1419+
1420+
root = tmp_path / "fc"
1421+
with FileStreamProgramCache(root) as cache:
1422+
cache[b"k"] = _fake_object_code(b"v")
1423+
young_tmp = root / "tmp" / "entry-young"
1424+
young_tmp.write_bytes(b"in-flight")
1425+
1426+
with FileStreamProgramCache(root) as cache:
1427+
cache.clear()
1428+
assert not young_tmp.exists()
1429+
1430+
1431+
def test_filestream_cache_size_cap_counts_tmp_files(tmp_path):
1432+
"""Surviving temp files occupy disk too; the soft cap must include
1433+
them, otherwise an attacker (or a flurry of crashed writers) could
1434+
inflate disk usage well past max_size_bytes."""
1435+
from cuda.core.utils import FileStreamProgramCache
1436+
1437+
cap = 4000
1438+
root = tmp_path / "fc"
1439+
with FileStreamProgramCache(root, max_size_bytes=cap) as cache:
1440+
cache[b"a"] = _fake_object_code(b"A" * 1500, name="a")
1441+
time.sleep(0.02)
1442+
cache[b"b"] = _fake_object_code(b"B" * 1500, name="b")
1443+
# Plant a young temp file that pushes total over the cap.
1444+
young_tmp = root / "tmp" / "entry-leftover"
1445+
young_tmp.write_bytes(b"X" * 2500)
1446+
1447+
with FileStreamProgramCache(root, max_size_bytes=cap) as cache:
1448+
# New write triggers _enforce_size_cap; 'a' must be evicted because
1449+
# the temp file's bytes count toward the cap now.
1450+
time.sleep(0.02)
1451+
cache[b"c"] = _fake_object_code(b"C" * 200, name="c")
1452+
assert b"a" not in cache
1453+
assert b"c" in cache
1454+
1455+
13791456
def test_filestream_cache_handles_long_keys(tmp_path):
13801457
"""Arbitrary-length keys must not overflow per-component filename limits.
13811458
The filename is a fixed-length hash; the original key is verified from

0 commit comments

Comments
 (0)