Skip to content

Commit f1fbab7

Browse files
committed
fixup! adapt Program.compile(cache=...) wrapper + tests for bytes API
Salvage the high-level wrapper from the diverged origin branch and adapt it to the raw-bytes cache API:

* The wrapper rebuilds an ObjectCode on a hit via ObjectCode._init(hit_bytes, target_type, name=self._options.name), since the cache returns bytes, not an ObjectCode.
* Document in the Program.compile docstring that symbol_mapping is not preserved across a cache round-trip.
* The test stub _RecordingCache mirrors the FileStream contract: __setitem__ accepts an ObjectCode and stores its bytes; get returns bytes (or None).
* Drop the SQLite/InMemory imports and tests left over from the cherry-pick (those backends are gone in this branch).
* Fix the import order in _program_cache.py (the cherry-pick interleaved platformdirs between the pathlib and typing imports).
1 parent ca31ede commit f1fbab7

3 files changed

Lines changed: 78 additions & 93 deletions

File tree

cuda_core/cuda/core/_program.pyx

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,17 @@ cdef class Program:
104104
logs : object, optional
105105
Object with a ``write`` method to receive compilation logs.
106106
cache : :class:`~cuda.core.utils.ProgramCacheResource`, optional
107-
If provided, the compiled :class:`~cuda.core.ObjectCode` is looked
108-
up in ``cache`` via a key derived from the program's code, options,
109-
``target_type`` and ``name_expressions``. On a hit the cached
110-
``ObjectCode`` is returned without re-compiling; on a miss the
111-
fresh compile result is stored. Options that require an
107+
If provided, the compiled binary is looked up in ``cache`` via a
108+
key derived from the program's code, options, ``target_type`` and
109+
``name_expressions``. On a hit the cached bytes are wrapped in a
110+
fresh :class:`~cuda.core.ObjectCode` (with the same ``target_type``
111+
and ``ProgramOptions.name``) and returned without re-compiling;
112+
on a miss the compile output is stored as raw bytes (the cache
113+
extracts ``bytes(object_code.code)``). Note that
114+
``ObjectCode.symbol_mapping`` is not preserved across a cache
115+
round-trip -- callers using ``name_expressions`` who need
116+
``get_kernel(name_expression)`` after a hit must compile fresh
117+
or look the mangled symbol up by hand. Options that require an
112118
``extra_digest`` (``include_path``, ``pre_include``, ``pch``,
113119
``use_pch``, ``pch_dir``, NVVM ``use_libdevice=True``, or NVRTC
114120
``options.name`` with a directory component) raise ``ValueError``
@@ -146,9 +152,9 @@ cdef class Program:
146152
target_type=target_type,
147153
name_expressions=name_expressions,
148154
)
149-
hit = cache.get(key)
150-
if hit is not None:
151-
return hit
155+
hit_bytes = cache.get(key)
156+
if hit_bytes is not None:
157+
return ObjectCode._init(hit_bytes, target_type, name=self._options.name)
152158
compiled = _program_compile_uncached(self, target_type, name_expressions, logs)
153159
cache[key] = compiled
154160
return compiled

cuda_core/cuda/core/utils/_program_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
import threading
2424
import time
2525
from pathlib import Path
26+
from typing import Iterable, Sequence
2627

2728
import platformdirs
28-
from typing import Iterable, Sequence
2929

3030
from cuda.core._module import ObjectCode
3131
from cuda.core._program import ProgramOptions

cuda_core/tests/test_program_compile_cache.py

Lines changed: 63 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2-
#
32
# SPDX-License-Identifier: Apache-2.0
43
"""Tests for the ``Program.compile(cache=...)`` convenience integration."""
54

@@ -12,34 +11,37 @@
1211
from cuda.core._module import ObjectCode
1312
from cuda.core.utils import (
1413
FileStreamProgramCache,
15-
InMemoryProgramCache,
16-
SQLiteProgramCache,
1714
make_program_cache_key,
1815
)
1916

20-
try:
21-
import sqlite3 # noqa: F401
22-
23-
_has_sqlite3 = True
24-
except ImportError:
25-
_has_sqlite3 = False
26-
27-
needs_sqlite3 = pytest.mark.skipif(not _has_sqlite3, reason="libsqlite3 not available")
28-
2917

3018
class _RecordingCache:
31-
"""Minimal recording stub for the two-method cache protocol.
19+
"""Minimal recording stub for the bytes-in / bytes-out cache protocol.
20+
21+
Mirrors :class:`FileStreamProgramCache`'s contract: ``__setitem__``
22+
accepts bytes-like or :class:`ObjectCode` (extracts bytes), and
23+
``get`` returns the stored bytes (or ``None``).
3224
3325
Intentionally does NOT subclass ``ProgramCacheResource`` -- the wrapper
3426
should be duck-typed, so we test the duck-typed surface directly.
3527
"""
3628

3729
def __init__(self, preseed=None):
38-
self._store = dict(preseed or {})
39-
self.get_calls = []
40-
self.set_calls = []
41-
self.get_side_effect = None
42-
self.set_side_effect = None
30+
self._store: dict[bytes, bytes] = {}
31+
for k, v in (preseed or {}).items():
32+
self._store[k] = self._extract(v)
33+
self.get_calls: list[bytes] = []
34+
self.set_calls: list[tuple[bytes, bytes]] = []
35+
self.get_side_effect: BaseException | None = None
36+
self.set_side_effect: BaseException | None = None
37+
38+
@staticmethod
39+
def _extract(value) -> bytes:
40+
if isinstance(value, ObjectCode):
41+
return bytes(value.code)
42+
if isinstance(value, (bytes, bytearray, memoryview)):
43+
return bytes(value)
44+
raise TypeError(f"unexpected value type: {type(value).__name__}")
4345

4446
def get(self, key, default=None):
4547
self.get_calls.append(key)
@@ -48,22 +50,24 @@ def get(self, key, default=None):
4850
return self._store.get(key, default)
4951

5052
def __setitem__(self, key, value):
51-
self.set_calls.append((key, value))
53+
data = self._extract(value)
54+
self.set_calls.append((key, data))
5255
if self.set_side_effect is not None:
5356
raise self.set_side_effect
54-
self._store[key] = value
57+
self._store[key] = data
5558

5659

5760
_KERNEL = 'extern "C" __global__ void k() {}'
61+
_SENTINEL_BYTES = b"sentinel-cubin-bytes"
5862

5963

6064
def _make_sentinel_object_code():
6165
"""Construct a cache-safe ``ObjectCode`` that doesn't require compilation."""
62-
return ObjectCode._init(b"sentinel-cubin-bytes", "cubin", name="sentinel")
66+
return ObjectCode._init(_SENTINEL_BYTES, "cubin", name="sentinel")
6367

6468

6569
def test_cache_miss_runs_compile_then_stores(monkeypatch):
66-
"""On cache miss: get(key) once, _program_compile_uncached once, __setitem__(key, obj) once."""
70+
"""On cache miss: get(key) once, _program_compile_uncached once, __setitem__ once."""
6771
sentinel = _make_sentinel_object_code()
6872
program = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80"))
6973

@@ -75,16 +79,19 @@ def _return_sentinel(_program, *_args, **_kwargs):
7579

7680
result = program.compile("cubin", cache=cache)
7781

82+
# On miss the wrapper returns the freshly-compiled ObjectCode unchanged.
7883
assert result is sentinel
7984
assert len(cache.get_calls) == 1
8085
assert len(cache.set_calls) == 1
81-
assert cache.set_calls[0][1] is sentinel
86+
# The cache stored the binary bytes extracted from the ObjectCode.
87+
assert cache.set_calls[0][1] == _SENTINEL_BYTES
8288

8389

84-
def test_cache_hit_short_circuits_compile(monkeypatch):
85-
"""On cache hit: get(key) returns, _program_compile_uncached is not called, no __setitem__."""
86-
sentinel = _make_sentinel_object_code()
87-
options = ProgramOptions(arch="sm_80")
90+
def test_cache_hit_returns_object_code_reconstructed_from_bytes(monkeypatch):
91+
"""On hit: get(key) returns bytes, the wrapper rebuilds an ObjectCode with
92+
the same code_type and ProgramOptions.name. _program_compile_uncached is
93+
NOT called and there is no __setitem__."""
94+
options = ProgramOptions(arch="sm_80", name="my_program")
8895
program = Program(_KERNEL, "c++", options)
8996
key = make_program_cache_key(
9097
code=_KERNEL,
@@ -97,11 +104,14 @@ def _explode(_program, *_args, **_kwargs):
97104
raise AssertionError("_program_compile_uncached must not be called on cache hit")
98105

99106
monkeypatch.setattr(_program_module, "_program_compile_uncached", _explode)
100-
cache = _RecordingCache(preseed={key: sentinel})
107+
cache = _RecordingCache(preseed={key: _SENTINEL_BYTES})
101108

102109
result = program.compile("cubin", cache=cache)
103110

104-
assert result is sentinel
111+
assert isinstance(result, ObjectCode)
112+
assert bytes(result.code) == _SENTINEL_BYTES
113+
assert result.code_type == "cubin"
114+
assert result.name == "my_program"
105115
assert cache.get_calls == [key]
106116
assert cache.set_calls == []
107117

@@ -110,7 +120,9 @@ def test_name_expressions_affects_cache_key(monkeypatch):
110120
"""Different ``name_expressions`` must produce different cache keys."""
111121
sentinel = _make_sentinel_object_code()
112122
program = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80"))
113-
monkeypatch.setattr(_program_module, "_program_compile_uncached", lambda _self, *_args, **_kwargs: sentinel)
123+
monkeypatch.setattr(
124+
_program_module, "_program_compile_uncached", lambda _self, *_args, **_kwargs: sentinel
125+
)
114126
cache = _RecordingCache()
115127

116128
program.compile("cubin", name_expressions=("foo",), cache=cache)
@@ -188,7 +200,9 @@ def test_cache_write_exception_propagates(monkeypatch):
188200
"""Exceptions from cache.__setitem__ propagate after compile runs."""
189201
sentinel = _make_sentinel_object_code()
190202
program = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80"))
191-
monkeypatch.setattr(_program_module, "_program_compile_uncached", lambda _self, *_args, **_kwargs: sentinel)
203+
monkeypatch.setattr(
204+
_program_module, "_program_compile_uncached", lambda _self, *_args, **_kwargs: sentinel
205+
)
192206
cache = _RecordingCache()
193207
cache.set_side_effect = RuntimeError("disk full")
194208

@@ -203,7 +217,9 @@ def test_no_cache_kwarg_does_not_derive_key(monkeypatch):
203217
"""Without cache=, no cache-module functions run; compile goes straight through."""
204218
sentinel = _make_sentinel_object_code()
205219
program = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80"))
206-
monkeypatch.setattr(_program_module, "_program_compile_uncached", lambda _self, *_args, **_kwargs: sentinel)
220+
monkeypatch.setattr(
221+
_program_module, "_program_compile_uncached", lambda _self, *_args, **_kwargs: sentinel
222+
)
207223

208224
# If the implementation accidentally derived a key, it would call
209225
# make_program_cache_key. Replace it with a raising stub to catch that.
@@ -219,35 +235,9 @@ def _cache_path_must_not_run(*_args, **_kwargs):
219235
assert result is sentinel
220236

221237

222-
def test_inmemory_hit_returns_same_instance(init_cuda):
223-
"""InMemoryProgramCache stores by reference; a hit returns the same object."""
224-
cache = InMemoryProgramCache()
225-
program = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80"))
226-
227-
first = program.compile("cubin", cache=cache)
228-
second = program.compile("cubin", cache=cache)
229-
230-
assert second is first
231-
232-
233-
@needs_sqlite3
234-
def test_sqlite_hit_returns_byte_equal_object_code(init_cuda, tmp_path):
235-
"""SQLiteProgramCache pickle-roundtrips; a hit returns byte-identical code."""
236-
program = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80"))
237-
db_path = tmp_path / "cache.sqlite"
238-
239-
with SQLiteProgramCache(db_path) as cache:
240-
first = program.compile("cubin", cache=cache)
241-
242-
with SQLiteProgramCache(db_path) as cache:
243-
second = program.compile("cubin", cache=cache)
244-
245-
assert second.code == first.code
246-
assert second.code_type == first.code_type
247-
248-
249238
def test_filestream_hit_returns_byte_equal_object_code(init_cuda, tmp_path):
250-
"""FileStreamProgramCache pickle-roundtrips; a hit returns byte-identical code."""
239+
"""End-to-end: real compile, FileStreamProgramCache roundtrip, second
240+
compile returns an ObjectCode whose bytes match the first compile."""
251241
program = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80"))
252242
cache_dir = tmp_path / "fc"
253243

@@ -257,42 +247,31 @@ def test_filestream_hit_returns_byte_equal_object_code(init_cuda, tmp_path):
257247
with FileStreamProgramCache(cache_dir) as cache:
258248
second = program.compile("cubin", cache=cache)
259249

260-
assert second.code == first.code
250+
assert bytes(second.code) == bytes(first.code)
261251
assert second.code_type == first.code_type
262252

263253

264-
@pytest.mark.parametrize(
265-
"backend",
266-
[
267-
pytest.param("sqlite", marks=needs_sqlite3),
268-
"filestream",
269-
],
270-
)
271-
def test_cache_kwarg_roundtrip_across_reopen(init_cuda, tmp_path, backend, monkeypatch):
254+
def test_cache_kwarg_roundtrip_across_reopen(init_cuda, tmp_path, monkeypatch):
272255
"""Compile with cache= in one 'session', reopen and fetch via cache= again."""
273-
program_init_args = (_KERNEL, "c++", ProgramOptions(arch="sm_80", name="cached_kernel"))
274-
275-
if backend == "sqlite":
276-
cache_path = tmp_path / "cache.sqlite"
277-
cache_cls = SQLiteProgramCache
278-
else:
279-
cache_path = tmp_path / "fc"
280-
cache_cls = FileStreamProgramCache
256+
init_args = (_KERNEL, "c++", ProgramOptions(arch="sm_80", name="cached_kernel"))
257+
cache_path = tmp_path / "fc"
281258

282-
with cache_cls(cache_path) as cache:
283-
program = Program(*program_init_args)
259+
with FileStreamProgramCache(cache_path) as cache:
260+
program = Program(*init_args)
284261
first = program.compile("cubin", cache=cache)
285262

286263
# Fresh process / fresh Program and cache-handle -- same cache path.
287-
with cache_cls(cache_path) as cache:
288-
program = Program(*program_init_args)
264+
with FileStreamProgramCache(cache_path) as cache:
265+
program = Program(*init_args)
289266

290-
# Monkeypatch the module-level compile seam so a cache miss would raise --
291-
# the only way this succeeds is if the reopen finds the prior entry.
267+
# If the reopened cache misses, the wrapper would fall back to
268+
# _program_compile_uncached -- replace it with a raising stub so
269+
# the test can only succeed via a hit.
292270
def _must_not_recompile(_program, *_args, **_kwargs):
293271
raise AssertionError("cache miss: reopened cache didn't serve entry")
294272

295273
monkeypatch.setattr(_program_module, "_program_compile_uncached", _must_not_recompile)
296274
second = program.compile("cubin", cache=cache)
297275

298-
assert second.code == first.code
276+
assert bytes(second.code) == bytes(first.code)
277+
assert second.name == "cached_kernel"

0 commit comments

Comments
 (0)