Skip to content

Commit f9d90f9

Browse files
committed
feat(core.utils): add FileStreamProgramCache backend
Persistent program cache backed by a directory of entry files, one per key hash. Writes are staged into a tmp/ subdirectory and promoted via os.replace, so concurrent readers never observe a torn file. Corrupt entries are treated as cache misses and pruned on read. The max_size_bytes cap is enforced opportunistically on writes, evicting entries oldest-mtime-first; this is deliberately best-effort under multi-process use. Use SQLiteProgramCache for strict LRU semantics within a single process. Part of issue #178.
1 parent 76567d7 commit f9d90f9

File tree

3 files changed

+312
-0
lines changed

3 files changed

+312
-0
lines changed

cuda_core/cuda/core/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
args_viewable_as_strided_memory, # noqa: F401
88
)
99
from cuda.core.utils._program_cache import (
10+
FileStreamProgramCache, # noqa: F401
1011
ProgramCacheResource, # noqa: F401
1112
SQLiteProgramCache, # noqa: F401
1213
make_program_cache_key, # noqa: F401

cuda_core/cuda/core/utils/_program_cache.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import os
2323
import pickle
2424
import sqlite3
25+
import tempfile
2526
import time
2627
from pathlib import Path
2728
from typing import Hashable, Iterable, Optional, Sequence
@@ -35,6 +36,7 @@
3536
)
3637

3738
__all__ = [
39+
"FileStreamProgramCache",
3840
"ProgramCacheResource",
3941
"SQLiteProgramCache",
4042
"make_program_cache_key",
@@ -431,3 +433,187 @@ def _enforce_size_cap(self) -> None:
431433
return
432434
conn.execute("DELETE FROM entries WHERE key = ?", (k,))
433435
total -= sz
436+
437+
438+
# ---------------------------------------------------------------------------
# FileStream backend
# ---------------------------------------------------------------------------


# Bump when the on-disk record layout changes; __getitem__ treats records
# with an unrecognized version as corrupt (cache miss, entry pruned).
_FILESTREAM_SCHEMA_VERSION = 1
# Subdirectory holding entry files, fanned out by the first two characters
# of the hex-encoded key bytes (see _path_for_key).
_ENTRIES_SUBDIR = "entries"
# Staging area for in-progress writes; files are promoted via os.replace().
_TMP_SUBDIR = "tmp"
# Marker file recording the schema version of an existing cache directory.
_SCHEMA_FILE = "SCHEMA_VERSION"
447+
448+
449+
class FileStreamProgramCache(ProgramCacheResource):
    """Persistent program cache backed by a directory of atomic files.

    Designed for multi-process use: writes stage a temporary file and then
    :func:`os.replace` it into place, so concurrent readers never observe a
    partially-written entry. There is no cross-process LRU tracking; size
    enforcement is best-effort by file mtime.

    Parameters
    ----------
    path:
        Directory that owns the cache. Created if missing.
    max_size_bytes:
        Optional soft cap on total on-disk size. Enforced opportunistically
        on writes; concurrent writers may briefly exceed it.
    """

    def __init__(
        self,
        path: str | os.PathLike,
        *,
        max_size_bytes: Optional[int] = None,
    ) -> None:
        if max_size_bytes is not None and max_size_bytes < 0:
            raise ValueError("max_size_bytes must be non-negative or None")
        self._root = Path(path)
        self._entries = self._root / _ENTRIES_SUBDIR
        self._tmp = self._root / _TMP_SUBDIR
        self._schema_path = self._root / _SCHEMA_FILE
        self._max_size_bytes = max_size_bytes
        self._root.mkdir(parents=True, exist_ok=True)
        self._entries.mkdir(exist_ok=True)
        self._tmp.mkdir(exist_ok=True)
        if not self._schema_path.exists():
            self._schema_path.write_text(str(_FILESTREAM_SCHEMA_VERSION))

    # -- key-to-path helpers -------------------------------------------------

    def _path_for_key(self, key: object) -> Path:
        """Map *key* to its entry file: <entries>/<hex[:2]>/<hex[2:]>."""
        k = _as_key_bytes(key)
        hex_ = k.hex() if k else "empty"
        if len(hex_) < 3:
            # Pad so both the fan-out subdir and the filename are non-empty.
            hex_ = hex_.rjust(3, "0")
        return self._entries / hex_[:2] / hex_[2:]

    @staticmethod
    def _discard_entry(path: Path) -> None:
        """Best-effort removal of a bad or stale entry file."""
        # missing_ok: a concurrent process may have pruned it already.
        path.unlink(missing_ok=True)

    # -- mapping API ---------------------------------------------------------

    def __contains__(self, key: object) -> bool:
        return self._path_for_key(key).exists()

    def __getitem__(self, key: object) -> ObjectCode:
        path = self._path_for_key(key)
        try:
            data = path.read_bytes()
        except FileNotFoundError:
            # `from None`: the FileNotFoundError is an implementation detail;
            # callers should see a plain KeyError without chained context.
            raise KeyError(key) from None
        k = _as_key_bytes(key)
        try:
            schema, stored_key, payload, _created_at = pickle.loads(data)
            if schema != _FILESTREAM_SCHEMA_VERSION:
                raise ValueError(f"unknown schema {schema}")
            if stored_key != k:
                raise ValueError("key mismatch")
            value = pickle.loads(payload)
        except Exception:
            # Corrupt entry -- delete and treat as a miss.
            self._discard_entry(path)
            raise KeyError(key) from None
        if not isinstance(value, ObjectCode):
            # Well-formed record carrying an unexpected payload type: prune
            # it and report a miss as well.
            self._discard_entry(path)
            raise KeyError(key)
        return value

    def __setitem__(self, key: object, value: object) -> None:
        obj = _require_object_code(value)
        k = _as_key_bytes(key)
        payload = pickle.dumps(obj, protocol=_PICKLE_PROTOCOL)
        record = pickle.dumps(
            (_FILESTREAM_SCHEMA_VERSION, k, payload, time.time()),
            protocol=_PICKLE_PROTOCOL,
        )

        target = self._path_for_key(key)
        target.parent.mkdir(parents=True, exist_ok=True)

        # Stage under tmp/ (inside the cache root), fsync, then promote with
        # os.replace so readers never observe a partially-written entry.
        fd, tmp_name = tempfile.mkstemp(prefix="entry-", dir=self._tmp)
        tmp_path = Path(tmp_name)
        try:
            with os.fdopen(fd, "wb") as fh:
                fh.write(record)
                fh.flush()
                os.fsync(fh.fileno())
            os.replace(tmp_path, target)
        except BaseException:
            # Never leave a stray staging file behind on failure.
            tmp_path.unlink(missing_ok=True)
            raise
        self._enforce_size_cap()

    def __delitem__(self, key: object) -> None:
        try:
            self._path_for_key(key).unlink()
        except FileNotFoundError:
            raise KeyError(key) from None

    def __len__(self) -> int:
        # Counts whatever entry files exist on disk right now.
        return sum(1 for _ in self._iter_entry_paths())

    def clear(self) -> None:
        """Remove all entries (best-effort under concurrent access)."""
        for path in list(self._iter_entry_paths()):
            path.unlink(missing_ok=True)
        # Remove empty subdirs (best-effort; concurrent writers may re-create).
        if self._entries.exists():
            for sub in sorted(self._entries.iterdir(), reverse=True):
                if sub.is_dir():
                    try:
                        sub.rmdir()
                    except OSError:
                        pass

    # -- internals -----------------------------------------------------------

    def _iter_entry_paths(self) -> Iterable[Path]:
        """Yield every entry file currently on disk."""
        if not self._entries.exists():
            return
        for sub in self._entries.iterdir():
            if not sub.is_dir():
                continue
            for entry in sub.iterdir():
                if entry.is_file():
                    yield entry

    def _enforce_size_cap(self) -> None:
        """Evict oldest-mtime entries until the soft size cap is respected."""
        if self._max_size_bytes is None:
            return
        entries = []
        total = 0
        for path in self._iter_entry_paths():
            try:
                st = path.stat()
            except FileNotFoundError:
                continue  # pruned by a concurrent process; skip it
            entries.append((st.st_mtime, st.st_size, path))
            total += st.st_size
        if total <= self._max_size_bytes:
            return
        entries.sort()  # oldest mtime first
        for _mtime, size, path in entries:
            if total <= self._max_size_bytes:
                return
            try:
                path.unlink()
            except FileNotFoundError:
                continue  # already gone; do not double-count the reclaim
            total -= size

cuda_core/tests/test_program_cache.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,3 +434,128 @@ def test_sqlite_cache_unbounded_by_default(tmp_path):
434434
for i in range(25):
435435
cache[f"k{i}".encode()] = _fake_object_code(b"X" * 1024, name=f"n{i}")
436436
assert len(cache) == 25
437+
438+
439+
# ---------------------------------------------------------------------------
440+
# FileStreamProgramCache -- single-process CRUD
441+
# ---------------------------------------------------------------------------
442+
443+
444+
def test_filestream_cache_empty_on_create(tmp_path):
    from cuda.core.utils import FileStreamProgramCache

    # A freshly created cache directory behaves as fully empty.
    with FileStreamProgramCache(tmp_path / "fc") as fresh:
        assert not len(fresh)
        assert b"nope" not in fresh
        with pytest.raises(KeyError):
            fresh[b"nope"]
452+
453+
454+
def test_filestream_cache_roundtrip(tmp_path):
    from cuda.core.utils import FileStreamProgramCache

    # Store one entry and read it straight back.
    with FileStreamProgramCache(tmp_path / "fc") as store:
        store[b"k1"] = _fake_object_code(b"v1", name="x")
        assert b"k1" in store
        fetched = store[b"k1"]
        assert (bytes(fetched._module), fetched._name) == (b"v1", "x")
        assert fetched._code_type == "cubin"
464+
465+
466+
def test_filestream_cache_delete(tmp_path):
    from cuda.core.utils import FileStreamProgramCache

    with FileStreamProgramCache(tmp_path / "fc") as store:
        store[b"k"] = _fake_object_code()
        del store[b"k"]
        assert b"k" not in store
        # Deleting a missing key surfaces KeyError, like a dict.
        with pytest.raises(KeyError):
            del store[b"k"]
475+
476+
477+
def test_filestream_cache_len_counts_all(tmp_path):
    from cuda.core.utils import FileStreamProgramCache

    with FileStreamProgramCache(tmp_path / "fc") as store:
        # Three distinct keys -> three entries reported by len().
        for idx, key in enumerate((b"a", b"b", b"c"), start=1):
            store[key] = _fake_object_code(str(idx).encode())
        assert len(store) == 3
485+
486+
487+
def test_filestream_cache_clear(tmp_path):
    from cuda.core.utils import FileStreamProgramCache

    cache_dir = tmp_path / "fc"
    with FileStreamProgramCache(cache_dir) as store:
        store[b"a"] = _fake_object_code()
        store.clear()
        # Nothing remains after clear().
        assert len(store) == 0
495+
496+
497+
def test_filestream_cache_persists_across_reopen(tmp_path):
    from cuda.core.utils import FileStreamProgramCache

    cache_dir = tmp_path / "fc"
    # One handle writes the entry...
    with FileStreamProgramCache(cache_dir) as writer:
        writer[b"k"] = _fake_object_code(b"persisted")
    # ...and a second, independent handle over the same directory reads it.
    with FileStreamProgramCache(cache_dir) as reader:
        assert bytes(reader[b"k"]._module) == b"persisted"
505+
506+
507+
def test_filestream_cache_atomic_no_half_written_file(tmp_path, monkeypatch):
    # Simulate a crash during write: patch os.replace to raise.
    import os as _os

    from cuda.core.utils import FileStreamProgramCache

    with FileStreamProgramCache(tmp_path / "fc") as store:
        def _explode(src, dst):
            raise RuntimeError("crash during replace")

        monkeypatch.setattr(_os, "replace", _explode)
        with pytest.raises(RuntimeError, match="crash"):
            store[b"k"] = _fake_object_code(b"v")
        monkeypatch.undo()
        # The interrupted write must not have left a visible entry behind.
        assert b"k" not in store
522+
523+
524+
def test_filestream_cache_corruption_is_reported_as_miss(tmp_path):
    from cuda.core.utils import FileStreamProgramCache

    cache_dir = tmp_path / "fc"
    with FileStreamProgramCache(cache_dir) as store:
        store[b"k"] = _fake_object_code(b"ok")
        entry_path = store._path_for_key(b"k")

    # Corrupt the file on disk.
    entry_path.write_bytes(b"\x00not-a-pickle")
    with FileStreamProgramCache(cache_dir) as store:
        with pytest.raises(KeyError):
            store[b"k"]
        # The failed lookup pruned the corrupt entry, so it is gone now.
        assert b"k" not in store
538+
539+
540+
def test_filestream_cache_rejects_non_object_code(tmp_path):
    from cuda.core.utils import FileStreamProgramCache

    # Only ObjectCode values are storable; anything else is a TypeError.
    with FileStreamProgramCache(tmp_path / "fc") as store:
        with pytest.raises(TypeError, match="ObjectCode"):
            store[b"k"] = b"not an ObjectCode"
546+
547+
548+
def test_filestream_cache_rejects_negative_size_cap(tmp_path):
    from cuda.core.utils import FileStreamProgramCache

    # Negative caps are rejected by the constructor.
    with pytest.raises(ValueError, match="non-negative"):
        FileStreamProgramCache(tmp_path / "fc", max_size_bytes=-1)
553+
554+
555+
def test_filestream_cache_accepts_str_keys(tmp_path):
    from cuda.core.utils import FileStreamProgramCache

    with FileStreamProgramCache(tmp_path / "fc") as store:
        store["my-key"] = _fake_object_code(b"v")
        # The str spelling and its bytes spelling address the same entry.
        assert "my-key" in store
        assert b"my-key" in store

0 commit comments

Comments
 (0)