Skip to content

Commit 2dc5c8f

Browse files
committed
fixup! feat(core.utils): reject path-backed ObjectCode on read too
1 parent c534df1 commit 2dc5c8f

File tree

2 files changed

+78
-4
lines changed

2 files changed

+78
-4
lines changed

cuda_core/cuda/core/utils/_program_cache.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,18 @@
5757
_IS_WINDOWS = os.name == "nt"
5858

5959

60+
def _is_cacheable_object_code(value: object) -> bool:
61+
"""True iff ``value`` is a bytes-backed :class:`ObjectCode`.
62+
63+
Used by read paths to reject deserialized entries that would fail the
64+
same check writes enforce: path-backed ``ObjectCode`` (where
65+
``value.code`` is a ``str`` path) pickles only the path, so reopening
66+
an older cache or a manually-seeded entry could return stale or
67+
missing on-disk content. Treat such entries as misses and prune.
68+
"""
69+
return isinstance(value, ObjectCode) and not isinstance(value.code, str)
70+
71+
6072
def _require_object_code(value: object) -> ObjectCode:
6173
if not isinstance(value, ObjectCode):
6274
raise TypeError(f"cache values must be ObjectCode instances, got {type(value).__name__}")
@@ -961,7 +973,7 @@ def _load(self, key: object, *, touch_lru: bool) -> ObjectCode | None:
961973
except Exception:
962974
conn.execute("DELETE FROM entries WHERE key = ?", (k,))
963975
return None
964-
if not isinstance(value, ObjectCode):
976+
if not _is_cacheable_object_code(value):
965977
conn.execute("DELETE FROM entries WHERE key = ?", (k,))
966978
return None
967979
if touch_lru:
@@ -1014,7 +1026,7 @@ def __len__(self) -> int:
10141026
except Exception:
10151027
conn.execute("DELETE FROM entries WHERE key = ?", (k,))
10161028
continue
1017-
if not isinstance(value, ObjectCode):
1029+
if not _is_cacheable_object_code(value):
10181030
conn.execute("DELETE FROM entries WHERE key = ?", (k,))
10191031
continue
10201032
count += 1
@@ -1289,7 +1301,7 @@ def __getitem__(self, key: object) -> ObjectCode:
12891301
except Exception:
12901302
_prune_if_stat_unchanged(path, st_before)
12911303
raise KeyError(key) from None
1292-
if not isinstance(value, ObjectCode):
1304+
if not _is_cacheable_object_code(value):
12931305
_prune_if_stat_unchanged(path, st_before)
12941306
raise KeyError(key) from None
12951307
return value
@@ -1360,7 +1372,7 @@ def __len__(self) -> int:
13601372
except Exception:
13611373
_prune_if_stat_unchanged(path, st_before)
13621374
continue
1363-
if not isinstance(value, ObjectCode):
1375+
if not _is_cacheable_object_code(value):
13641376
_prune_if_stat_unchanged(path, st_before)
13651377
continue
13661378
count += 1

cuda_core/tests/test_program_cache.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,6 +1179,41 @@ def test_sqlite_cache_rejects_path_backed_object_code(tmp_path):
11791179
cache[b"k"] = path_backed
11801180

11811181

1182+
@needs_sqlite3
1183+
def test_sqlite_cache_treats_path_backed_payload_as_miss_and_prunes(tmp_path):
1184+
"""An older version, a corrupt/injected entry, or a direct DB write may
1185+
leave a pickled path-backed ObjectCode on disk. Reopen must refuse to
1186+
surface it (pickling stores only the path, so the bytes would be stale
1187+
or missing) and prune the row so subsequent access doesn't re-hit."""
1188+
import pickle
1189+
import sqlite3
1190+
import time as _time
1191+
1192+
from cuda.core._module import ObjectCode
1193+
from cuda.core.utils import SQLiteProgramCache
1194+
1195+
db = tmp_path / "cache.db"
1196+
with SQLiteProgramCache(db):
1197+
pass # materialise schema
1198+
path_backed = ObjectCode.from_cubin(str(tmp_path / "nonexistent.cubin"), name="x")
1199+
payload = pickle.dumps(path_backed)
1200+
now = _time.time()
1201+
conn = sqlite3.connect(db)
1202+
try:
1203+
conn.execute(
1204+
"INSERT INTO entries(key, payload, size_bytes, created_at, accessed_at) VALUES (?, ?, ?, ?, ?)",
1205+
(b"k", payload, len(payload), now, now),
1206+
)
1207+
conn.commit()
1208+
finally:
1209+
conn.close()
1210+
1211+
with SQLiteProgramCache(db) as cache:
1212+
assert cache.get(b"k") is None
1213+
assert b"k" not in cache
1214+
assert len(cache) == 0
1215+
1216+
11821217
@needs_sqlite3
11831218
def test_sqlite_cache_accepts_str_keys(tmp_path):
11841219
from cuda.core.utils import SQLiteProgramCache
@@ -1694,6 +1729,33 @@ def test_filestream_cache_rejects_path_backed_object_code(tmp_path):
16941729
cache[b"k"] = path_backed
16951730

16961731

1732+
def test_filestream_cache_treats_path_backed_payload_as_miss_and_prunes(tmp_path):
1733+
"""Same guarantee as the SQLite backend: a pickled path-backed ObjectCode
1734+
on disk (older version, corruption, direct injection) must not surface
1735+
on reopen, and the file must be pruned so it doesn't keep re-hitting."""
1736+
import pickle
1737+
import time as _time
1738+
1739+
from cuda.core._module import ObjectCode
1740+
from cuda.core.utils import FileStreamProgramCache
1741+
from cuda.core.utils._program_cache import _FILESTREAM_SCHEMA_VERSION
1742+
1743+
root = tmp_path / "fc"
1744+
with FileStreamProgramCache(root) as cache:
1745+
target = cache._path_for_key(b"k")
1746+
path_backed = ObjectCode.from_cubin(str(tmp_path / "nonexistent.cubin"), name="x")
1747+
payload = pickle.dumps(path_backed)
1748+
record = pickle.dumps((_FILESTREAM_SCHEMA_VERSION, b"k", payload, _time.time()))
1749+
target.parent.mkdir(parents=True, exist_ok=True)
1750+
target.write_bytes(record)
1751+
1752+
with FileStreamProgramCache(root) as cache:
1753+
assert cache.get(b"k") is None
1754+
assert b"k" not in cache
1755+
assert len(cache) == 0
1756+
assert not target.exists()
1757+
1758+
16971759
def test_filestream_cache_rejects_negative_size_cap(tmp_path):
16981760
from cuda.core.utils import FileStreamProgramCache
16991761

0 commit comments

Comments
 (0)