Skip to content

Commit 9bd35c0

Browse files
committed
feat(core.utils): add SQLiteProgramCache backend
Persistent program cache backed by a single sqlite3 file in WAL mode. Keys are arbitrary bytes (str keys are UTF-8 encoded); values are ObjectCode instances serialised with pickle. Corrupt entries are treated as cache misses and pruned on read. A max_size_bytes cap, when supplied, triggers LRU eviction on writes -- the tracking infrastructure is in place even though dedicated eviction tests land in a follow-up commit. Part of issue #178.
1 parent 2cb74f7 commit 9bd35c0

File tree

3 files changed

+334
-1
lines changed

3 files changed

+334
-1
lines changed

cuda_core/cuda/core/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@
88
)
99
from cuda.core.utils._program_cache import (
1010
ProgramCacheResource, # noqa: F401
11+
SQLiteProgramCache, # noqa: F401
1112
make_program_cache_key, # noqa: F401
1213
)

cuda_core/cuda/core/utils/_program_cache.py

Lines changed: 208 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,12 @@
1919

2020
import abc
2121
import hashlib
22-
from typing import Hashable, Sequence
22+
import os
23+
import pickle
24+
import sqlite3
25+
import time
26+
from pathlib import Path
27+
from typing import Hashable, Iterable, Optional, Sequence
2328

2429
from cuda.core._module import ObjectCode
2530
from cuda.core._program import ProgramOptions
@@ -31,10 +36,30 @@
3136

3237
__all__ = [
3338
"ProgramCacheResource",
39+
"SQLiteProgramCache",
3440
"make_program_cache_key",
3541
]
3642

3743

44+
# Pin the protocol once so every writer in this process serialises entries
# the same way; HIGHEST_PROTOCOL is the most compact/fast for local reuse.
_PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL
45+
46+
47+
def _require_object_code(value: object) -> ObjectCode:
    """Validate that *value* is an :class:`ObjectCode` and hand it back.

    Raises
    ------
    TypeError
        If *value* is any other type.
    """
    if isinstance(value, ObjectCode):
        return value
    raise TypeError(
        f"cache values must be ObjectCode instances, got {type(value).__name__}"
    )
53+
54+
55+
def _as_key_bytes(key: object) -> bytes:
56+
if isinstance(key, (bytes, bytearray)):
57+
return bytes(key)
58+
if isinstance(key, str):
59+
return key.encode("utf-8")
60+
raise TypeError(f"cache keys must be bytes or str, got {type(key).__name__}")
61+
62+
3863
# ---------------------------------------------------------------------------
3964
# Abstract base class
4065
# ---------------------------------------------------------------------------
@@ -224,3 +249,185 @@ def _update(label: str, payload: bytes) -> None:
224249
_update("name", n.encode("utf-8"))
225250

226251
return hasher.digest()
252+
253+
254+
# ---------------------------------------------------------------------------
255+
# SQLite backend
256+
# ---------------------------------------------------------------------------
257+
258+
259+
# Recorded in schema_meta when the database is first created; bump when the
# on-disk table layout changes so readers can detect incompatible files.
_SQLITE_SCHEMA_VERSION = "1"
260+
261+
262+
class SQLiteProgramCache(ProgramCacheResource):
    """Persistent program cache backed by a single sqlite3 database file.

    Suitable for single-process workflows. Multiple processes *can* share the
    file (sqlite3 WAL mode serialises writes), but
    :class:`FileStreamProgramCache` is the recommended choice for concurrent
    workers.

    Keys are arbitrary ``bytes`` (``str`` keys are UTF-8 encoded via
    ``_as_key_bytes``); values are :class:`ObjectCode` instances serialised
    with :mod:`pickle`.  A payload that fails to unpickle is deleted on read
    and surfaces as a plain cache miss (``KeyError``).

    NOTE(review): the accompanying tests use instances as context managers
    (``with SQLiteProgramCache(...) as cache``); ``__enter__``/``__exit__``
    are presumably inherited from :class:`ProgramCacheResource` -- confirm.

    Parameters
    ----------
    path:
        Filesystem path to the sqlite3 database. The parent directory is
        created if missing.
    max_size_bytes:
        Optional size cap in bytes. When the sum of stored payload sizes
        exceeds the cap, the least-recently-used entries are evicted until
        the total is at or below the cap. ``None`` means unbounded.
    """

    def __init__(
        self,
        path: str | os.PathLike,
        *,
        max_size_bytes: Optional[int] = None,
    ) -> None:
        if max_size_bytes is not None and max_size_bytes < 0:
            raise ValueError("max_size_bytes must be non-negative or None")
        self._path = Path(path)
        # Create the containing directory up front so sqlite3.connect() can
        # create the database file on first use.
        self._path.parent.mkdir(parents=True, exist_ok=True)
        self._max_size_bytes = max_size_bytes
        self._conn: Optional[sqlite3.Connection] = None
        self._open()

    # -- lifecycle -----------------------------------------------------------

    def _open(self) -> None:
        """Open the connection, set pragmas, and create the schema if absent."""
        # ``isolation_level=None`` puts the connection in autocommit mode so
        # each statement is its own transaction; ``check_same_thread=False``
        # lets a cache be created in one thread and used from another (writes
        # are still serialised by sqlite's own lock).
        self._conn = sqlite3.connect(
            self._path,
            isolation_level=None,
            check_same_thread=False,
            timeout=5.0,
        )
        # WAL lets readers proceed while a write is in flight; NORMAL
        # synchronous is the usual durability/speed trade-off in WAL mode.
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.execute("PRAGMA synchronous=NORMAL")
        self._conn.execute("PRAGMA foreign_keys=ON")
        self._conn.execute("PRAGMA busy_timeout=5000")
        # Idempotent schema creation (IF NOT EXISTS): safe on every open.
        self._conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS schema_meta (
                key TEXT PRIMARY KEY,
                value TEXT NOT NULL
            );
            CREATE TABLE IF NOT EXISTS entries (
                key BLOB PRIMARY KEY,
                payload BLOB NOT NULL,
                size_bytes INTEGER NOT NULL,
                created_at REAL NOT NULL,
                accessed_at REAL NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_accessed_at
                ON entries(accessed_at);
            """
        )
        # INSERT OR IGNORE records the schema version only on first creation;
        # an existing row is left untouched.
        self._conn.execute(
            "INSERT OR IGNORE INTO schema_meta(key, value) VALUES (?, ?)",
            ("schema_version", _SQLITE_SCHEMA_VERSION),
        )

    def close(self) -> None:
        """Close the underlying connection; later use raises RuntimeError."""
        if self._conn is not None:
            try:
                self._conn.close()
            finally:
                # Drop the reference even if close() raised, so the cache is
                # consistently treated as closed afterwards.
                self._conn = None

    def _require_open(self) -> sqlite3.Connection:
        """Return the live connection, raising if the cache was closed."""
        if self._conn is None:
            raise RuntimeError("SQLiteProgramCache is closed")
        return self._conn

    # -- mapping API ---------------------------------------------------------

    def __contains__(self, key: object) -> bool:
        """Return True if *key* has an entry (does not bump LRU recency)."""
        k = _as_key_bytes(key)
        row = self._require_open().execute(
            "SELECT 1 FROM entries WHERE key = ?", (k,)
        ).fetchone()
        return row is not None

    def __getitem__(self, key: object) -> ObjectCode:
        """Return the cached :class:`ObjectCode` stored under *key*.

        A successful hit refreshes ``accessed_at`` (LRU recency).

        Raises
        ------
        KeyError
            If *key* is absent, or the stored payload is corrupt / not an
            ObjectCode (the bad row is deleted so later lookups miss cleanly).
        """
        k = _as_key_bytes(key)
        conn = self._require_open()
        row = conn.execute(
            "SELECT payload FROM entries WHERE key = ?", (k,)
        ).fetchone()
        if row is None:
            raise KeyError(key)
        payload = row[0]
        # NOTE(review): pickle.loads executes arbitrary code if the database
        # file is attacker-writable; payloads are assumed trusted local data.
        try:
            value = pickle.loads(payload)
        except Exception:
            # Corrupt entry -- delete and treat as a miss.
            conn.execute("DELETE FROM entries WHERE key = ?", (k,))
            raise KeyError(key)
        if not isinstance(value, ObjectCode):
            conn.execute("DELETE FROM entries WHERE key = ?", (k,))
            raise KeyError(key)
        # Successful hit: bump the timestamp _enforce_size_cap orders by.
        conn.execute(
            "UPDATE entries SET accessed_at = ? WHERE key = ?",
            (time.time(), k),
        )
        return value

    def __setitem__(self, key: object, value: object) -> None:
        """Store *value* under *key*, then evict LRU entries if over the cap.

        Raises
        ------
        TypeError
            If *value* is not an ObjectCode, or *key* is not bytes/str.
        """
        obj = _require_object_code(value)
        k = _as_key_bytes(key)
        payload = pickle.dumps(obj, protocol=_PICKLE_PROTOCOL)
        now = time.time()
        conn = self._require_open()
        # Upsert: on overwrite, ON CONFLICT does not touch created_at, so it
        # keeps the original insertion time; payload/size/recency refresh.
        conn.execute(
            """
            INSERT INTO entries(key, payload, size_bytes, created_at, accessed_at)
            VALUES (?, ?, ?, ?, ?)
            ON CONFLICT(key) DO UPDATE SET
                payload = excluded.payload,
                size_bytes = excluded.size_bytes,
                accessed_at = excluded.accessed_at
            """,
            (k, payload, len(payload), now, now),
        )
        self._enforce_size_cap()

    def __delitem__(self, key: object) -> None:
        """Remove *key*; raise KeyError if no row was deleted."""
        k = _as_key_bytes(key)
        conn = self._require_open()
        cur = conn.execute("DELETE FROM entries WHERE key = ?", (k,))
        if cur.rowcount == 0:
            raise KeyError(key)

    def __len__(self) -> int:
        """Return the number of stored entries."""
        (n,) = self._require_open().execute(
            "SELECT COUNT(*) FROM entries"
        ).fetchone()
        return int(n)

    def clear(self) -> None:
        """Delete every entry; the schema and schema_meta rows are kept."""
        self._require_open().execute("DELETE FROM entries")

    # -- eviction ------------------------------------------------------------

    def _enforce_size_cap(self) -> None:
        """Evict least-recently-used entries until total size is within cap.

        No-op when ``max_size_bytes`` is ``None``. Runs after every write;
        note the just-written entry is itself eligible for eviction.
        """
        if self._max_size_bytes is None:
            return
        conn = self._require_open()
        (total,) = conn.execute(
            "SELECT COALESCE(SUM(size_bytes), 0) FROM entries"
        ).fetchone()
        if total <= self._max_size_bytes:
            return
        # Delete oldest (least-recently-used) until at or under the cap.
        rows: Iterable[tuple[bytes, int]] = conn.execute(
            "SELECT key, size_bytes FROM entries ORDER BY accessed_at ASC"
        ).fetchall()
        for k, sz in rows:
            if total <= self._max_size_bytes:
                return
            conn.execute("DELETE FROM entries WHERE key = ?", (k,))
            total -= sz

cuda_core/tests/test_program_cache.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,128 @@ def test_make_program_cache_key_rejects_non_str_bytes_code():
251251
make_program_cache_key(
252252
code=12345, code_type="c++", options=_opts(), target_type="cubin"
253253
)
254+
255+
256+
# ---------------------------------------------------------------------------
257+
# SQLiteProgramCache -- basic CRUD
258+
# ---------------------------------------------------------------------------
259+
260+
261+
def _fake_object_code(payload: bytes = b"fake-cubin", name: str = "unit"):
    """Build an ObjectCode without touching the driver."""
    from cuda.core._module import ObjectCode

    code = ObjectCode._init(payload, "cubin", name=name)
    return code
266+
267+
268+
def test_sqlite_cache_empty_on_create(tmp_path):
    """A freshly created cache has no entries and misses every lookup."""
    from cuda.core.utils import SQLiteProgramCache

    cache_file = tmp_path / "cache.db"
    with SQLiteProgramCache(cache_file) as pc:
        assert len(pc) == 0
        assert b"nope" not in pc
        assert pc.get(b"nope") is None
        with pytest.raises(KeyError):
            pc[b"nope"]
278+
279+
280+
def test_sqlite_cache_set_get_roundtrip(tmp_path):
    """Storing an ObjectCode and reading it back preserves payload and name."""
    from cuda.core.utils import SQLiteProgramCache

    with SQLiteProgramCache(tmp_path / "cache.db") as pc:
        pc[b"k1"] = _fake_object_code(b"bytes-1", name="a")

        assert b"k1" in pc
        assert len(pc) == 1
        fetched = pc[b"k1"]
        assert bytes(fetched._module) == b"bytes-1"
        assert fetched._name == "a"
        assert fetched._code_type == "cubin"
294+
295+
296+
def test_sqlite_cache_overwrite_same_key(tmp_path):
    """Re-assigning a key replaces its value without growing the cache."""
    from cuda.core.utils import SQLiteProgramCache

    with SQLiteProgramCache(tmp_path / "cache.db") as pc:
        for payload in (b"v1", b"v2"):
            pc[b"k"] = _fake_object_code(payload)
        assert len(pc) == 1
        assert bytes(pc[b"k"]._module) == b"v2"
305+
306+
307+
def test_sqlite_cache_delete(tmp_path):
    """``del`` removes an entry; deleting a missing key raises KeyError."""
    from cuda.core.utils import SQLiteProgramCache

    with SQLiteProgramCache(tmp_path / "cache.db") as pc:
        pc[b"k"] = _fake_object_code()
        del pc[b"k"]
        assert len(pc) == 0
        assert b"k" not in pc
        with pytest.raises(KeyError):
            del pc[b"k"]
318+
319+
320+
def test_sqlite_cache_clear(tmp_path):
    """``clear`` drops every stored entry."""
    from cuda.core.utils import SQLiteProgramCache

    with SQLiteProgramCache(tmp_path / "cache.db") as pc:
        pc[b"a"] = _fake_object_code(b"1")
        pc[b"b"] = _fake_object_code(b"2")
        pc.clear()
        assert len(pc) == 0
329+
330+
331+
def test_sqlite_cache_persists_across_open(tmp_path):
    """Entries written by one cache instance are visible to a later one."""
    from cuda.core.utils import SQLiteProgramCache

    cache_file = tmp_path / "cache.db"
    with SQLiteProgramCache(cache_file) as pc:
        pc[b"k"] = _fake_object_code(b"persisted")
    with SQLiteProgramCache(cache_file) as pc:
        assert bytes(pc[b"k"]._module) == b"persisted"
339+
340+
341+
def test_sqlite_cache_corruption_is_reported_as_miss(tmp_path):
    """A payload that no longer unpickles behaves like a miss and is pruned."""
    import sqlite3

    from cuda.core.utils import SQLiteProgramCache

    cache_file = tmp_path / "cache.db"
    with SQLiteProgramCache(cache_file) as pc:
        pc[b"k"] = _fake_object_code(b"ok")
    # Corrupt the stored payload behind the cache's back.
    with sqlite3.connect(cache_file) as conn:
        conn.execute(
            "UPDATE entries SET payload = ? WHERE key = ?",
            (b"\x00\x01garbage", b"k"),
        )
        conn.commit()
    with SQLiteProgramCache(cache_file) as pc:
        with pytest.raises(KeyError):
            pc[b"k"]
        assert b"k" not in pc  # corrupt entry was pruned
360+
361+
362+
def test_sqlite_cache_rejects_non_object_code(tmp_path):
    """Only ObjectCode values may be stored; anything else is a TypeError."""
    from cuda.core.utils import SQLiteProgramCache

    with SQLiteProgramCache(tmp_path / "cache.db") as pc:
        with pytest.raises(TypeError, match="ObjectCode"):
            pc[b"k"] = b"not an ObjectCode"
368+
369+
370+
def test_sqlite_cache_accepts_str_keys(tmp_path):
    """str keys are UTF-8 encoded, so the equivalent bytes key also hits."""
    from cuda.core.utils import SQLiteProgramCache

    with SQLiteProgramCache(tmp_path / "cache.db") as pc:
        pc["str-key"] = _fake_object_code(b"v")
        assert "str-key" in pc
        assert b"str-key" in pc

0 commit comments

Comments
 (0)