Skip to content

Commit 45cde8c

Browse files
committed
fixup! feat(core.utils): couple key schema with backend schema; document pickle compat; add usage example
1 parent 5da111b commit 45cde8c

File tree

2 files changed

+112
-13
lines changed

2 files changed

+112
-13
lines changed

cuda_core/cuda/core/utils/_program_cache.py

Lines changed: 56 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -508,6 +508,34 @@ def make_program_cache_key(
508508
If ``extra_digest`` is ``None`` while ``options`` sets any option whose
509509
compilation effect depends on external file content that the key
510510
cannot otherwise observe.
511+
512+
Examples
513+
--------
514+
Wiring a cache around :class:`~cuda.core.Program` compile::
515+
516+
from cuda.core import Program, ProgramOptions
517+
from cuda.core.utils import FileStreamProgramCache, make_program_cache_key
518+
519+
source = "extern \"C\" __global__ void k(int *a){ *a = 1; }"
520+
options = ProgramOptions(arch="sm_80")
521+
522+
with FileStreamProgramCache("/var/cache/myapp/cuda") as cache:
523+
key = make_program_cache_key(
524+
code=source, code_type="c++", options=options, target_type="cubin"
525+
)
526+
obj = cache.get(key)
527+
if obj is None:
528+
obj = Program(source, "c++", options=options).compile("cubin")
529+
cache[key] = obj
530+
531+
Options that read external files (``include_path``, ``pre_include``,
532+
``pch``, ``use_pch``, ``pch_dir``; and ``use_libdevice=True`` on the
533+
NVVM path) require ``extra_digest`` -- fingerprint the bytes the
534+
compiler will pull in and pass that digest so changes to those files
535+
force a cache miss. Options that have compile-time side effects
536+
(``create_pch``, ``time``, ``fdevice_time_trace``) cannot be cached
537+
and raise ``ValueError``; compile directly, or disable the flag, for
538+
those cases.
511539
"""
512540
# Mirror Program.compile (_program.pyx Program_init lowercases code_type
513541
# before dispatch); a caller that passes "PTX" or "C++" must get the
@@ -764,7 +792,12 @@ def _probe(label: str, fn):
764792
# ---------------------------------------------------------------------------
765793

766794

767-
_SQLITE_SCHEMA_VERSION = "1"
795+
# Composite of (backend-storage schema, key schema): a bump in either one
796+
# forces a wipe-on-open. This keeps ``_KEY_SCHEMA_VERSION`` bumps from
797+
# leaving unreachable entries on disk (they would be orphaned forever,
798+
# since the new hash never collides with the old one).
799+
_SQLITE_BACKEND_SCHEMA = "1"
800+
_SQLITE_SCHEMA_VERSION = f"{_SQLITE_BACKEND_SCHEMA}.{_KEY_SCHEMA_VERSION}"
768801

769802

770803
class SQLiteProgramCache(ProgramCacheResource):
@@ -1066,7 +1099,9 @@ def _enforce_size_cap(self) -> None:
10661099
# ---------------------------------------------------------------------------
10671100

10681101

1069-
_FILESTREAM_SCHEMA_VERSION = 2
1102+
# Composite of (backend-storage schema, key schema) -- see SQLite comment.
1103+
_FILESTREAM_BACKEND_SCHEMA = 2
1104+
_FILESTREAM_SCHEMA_VERSION = f"{_FILESTREAM_BACKEND_SCHEMA}.{_KEY_SCHEMA_VERSION}"
10701105
_ENTRIES_SUBDIR = "entries"
10711106
_TMP_SUBDIR = "tmp"
10721107
_SCHEMA_FILE = "SCHEMA_VERSION"
@@ -1197,17 +1232,25 @@ class FileStreamProgramCache(ProgramCacheResource):
11971232
11981233
.. note:: **Cross-version sharing.**
11991234
1200-
``_FILESTREAM_SCHEMA_VERSION`` guards on-disk format changes: a
1201-
cache written by an incompatible version is wiped on open. Within
1202-
a single schema version, the cache is safe to share across
1203-
``cuda.core`` patch releases because every entry's key encodes
1204-
the relevant backend/compiler/runtime fingerprints for its
1205-
compilation path (NVRTC entries pin the NVRTC version, NVVM
1206-
entries pin the libNVVM library and IR versions, PTX/linker
1207-
entries pin the chosen linker backend and its version -- and,
1208-
when the cuLink/driver backend is selected, the driver version
1209-
too; nvJitLink-backed PTX entries are deliberately driver-version
1210-
independent).
1235+
``_FILESTREAM_SCHEMA_VERSION`` encodes both the on-disk storage
1236+
format and the key-schema version, so a cache written by an
1237+
incompatible version is wiped on open (bumping either
1238+
``_KEY_SCHEMA_VERSION`` or ``_FILESTREAM_BACKEND_SCHEMA`` forces
1239+
cleanup instead of leaving orphaned entries on disk).
1240+
1241+
Within a single schema version the cache is safe to share across
1242+
``cuda.core`` patch releases on a best-effort basis: every entry's
1243+
key encodes the relevant backend/compiler/runtime fingerprints for
1244+
its compilation path (NVRTC entries pin the NVRTC version, NVVM
1245+
entries pin the libNVVM library and IR versions, PTX/linker entries
1246+
pin the chosen linker backend and its version -- and, when the
1247+
cuLink/driver backend is selected, the driver version too;
1248+
nvJitLink-backed PTX entries are deliberately driver-version
1249+
independent). Entries are stored as pickled :class:`ObjectCode`,
1250+
so the sharing guarantee also assumes ``ObjectCode``'s pickle
1251+
representation stays compatible across the patch releases in
1252+
question; a change to its ``__reduce__`` protocol would require
1253+
bumping the schema version.
12111254
12121255
Parameters
12131256
----------

cuda_core/tests/test_program_cache.py

Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1481,6 +1481,37 @@ def test_sqlite_cache_wipes_on_schema_mismatch(tmp_path):
14811481
assert b"k" not in cache
14821482

14831483

1484+
@needs_sqlite3
def test_sqlite_cache_schema_version_encodes_key_schema(tmp_path, monkeypatch):
    """Bumping ``_KEY_SCHEMA_VERSION`` must invalidate existing on-disk
    entries even when the backend storage layout is unchanged -- otherwise
    old rows linger unreachable after a key-hash format change.

    The test writes one entry, simulates a key-schema bump by patching the
    module constants, reopens the cache, and then inspects the database
    file directly to confirm the old row was physically removed.
    """
    import contextlib
    import sqlite3

    from cuda.core.utils import SQLiteProgramCache, _program_cache

    db = tmp_path / "cache.db"
    with SQLiteProgramCache(db) as cache:
        cache[b"k"] = _fake_object_code(b"old-payload")

    # Bump just the key-schema version; the backend version stays the same.
    monkeypatch.setattr(_program_cache, "_KEY_SCHEMA_VERSION", _program_cache._KEY_SCHEMA_VERSION + 1)
    # The composite version string is a module-level value, so patching the
    # key-schema constant alone does not change it; re-derive it explicitly.
    monkeypatch.setattr(
        _program_cache,
        "_SQLITE_SCHEMA_VERSION",
        f"{_program_cache._SQLITE_BACKEND_SCHEMA}.{_program_cache._KEY_SCHEMA_VERSION}",
    )

    with SQLiteProgramCache(db) as cache:
        assert len(cache) == 0
        assert b"k" not in cache

    # The old row must be physically gone, not just invisible.
    # ``sqlite3.Connection`` used as a context manager only commits or rolls
    # back the transaction; it never closes the handle. ``closing`` makes sure
    # the connection is released so the temp directory can be cleaned up and
    # no ResourceWarning is emitted.
    with contextlib.closing(sqlite3.connect(db)) as conn:
        row_count = conn.execute("SELECT COUNT(*) FROM entries").fetchone()[0]
    assert row_count == 0
1513+
1514+
14841515
@needs_sqlite3
14851516
def test_sqlite_cache_drops_tables_on_schema_mismatch(tmp_path):
14861517
"""On a schema mismatch the cache must DROP the old tables, not just
@@ -2040,6 +2071,31 @@ def test_filestream_cache_wipes_on_schema_mismatch(tmp_path):
20402071
assert (root / "SCHEMA_VERSION").read_text().strip() != "0"
20412072

20422073

2074+
def test_filestream_cache_schema_version_encodes_key_schema(tmp_path, monkeypatch):
    """A ``_KEY_SCHEMA_VERSION`` bump on its own must wipe the file-stream
    cache on the next open, so entries written under the old key-hash
    format are removed from disk rather than left behind as orphans."""
    from cuda.core.utils import FileStreamProgramCache, _program_cache

    cache_root = tmp_path / "fc"
    with FileStreamProgramCache(cache_root) as old_cache:
        old_cache[b"k"] = _fake_object_code(b"old-payload")
        entry_path = old_cache._path_for_key(b"k")
        assert entry_path.exists()

    # Simulate an upgrade that changes only the key schema: patch the key
    # version, then patch the composite version string derived from it.
    bumped_key_schema = _program_cache._KEY_SCHEMA_VERSION + 1
    monkeypatch.setattr(_program_cache, "_KEY_SCHEMA_VERSION", bumped_key_schema)
    composite_version = f"{_program_cache._FILESTREAM_BACKEND_SCHEMA}.{bumped_key_schema}"
    monkeypatch.setattr(_program_cache, "_FILESTREAM_SCHEMA_VERSION", composite_version)

    with FileStreamProgramCache(cache_root) as reopened:
        assert len(reopened) == 0
        assert b"k" not in reopened
        # The entry file itself must be deleted, not merely unreachable.
        assert not entry_path.exists()
2097+
2098+
20432099
# ---------------------------------------------------------------------------
20442100
# End-to-end: real NVRTC compilation through persistent cache
20452101
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)