11# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2- #
32# SPDX-License-Identifier: Apache-2.0
43"""Tests for the ``Program.compile(cache=...)`` convenience integration."""
54
1211from cuda .core ._module import ObjectCode
1312from cuda .core .utils import (
1413 FileStreamProgramCache ,
15- InMemoryProgramCache ,
16- SQLiteProgramCache ,
1714 make_program_cache_key ,
1815)
1916
20- try :
21- import sqlite3 # noqa: F401
22-
23- _has_sqlite3 = True
24- except ImportError :
25- _has_sqlite3 = False
26-
27- needs_sqlite3 = pytest .mark .skipif (not _has_sqlite3 , reason = "libsqlite3 not available" )
28-
2917
3018class _RecordingCache :
31- """Minimal recording stub for the two-method cache protocol.
19+ """Minimal recording stub for the bytes-in / bytes-out cache protocol.
20+
21+ Mirrors :class:`FileStreamProgramCache`'s contract: ``__setitem__``
22+ accepts bytes-like or :class:`ObjectCode` (extracts bytes), and
23+ ``get`` returns the stored bytes (or ``None``).
3224
3325 Intentionally does NOT subclass ``ProgramCacheResource`` -- the wrapper
3426 should be duck-typed, so we test the duck-typed surface directly.
3527 """
3628
3729 def __init__ (self , preseed = None ):
38- self ._store = dict (preseed or {})
39- self .get_calls = []
40- self .set_calls = []
41- self .get_side_effect = None
42- self .set_side_effect = None
30+ self ._store : dict [bytes , bytes ] = {}
31+ for k , v in (preseed or {}).items ():
32+ self ._store [k ] = self ._extract (v )
33+ self .get_calls : list [bytes ] = []
34+ self .set_calls : list [tuple [bytes , bytes ]] = []
35+ self .get_side_effect : BaseException | None = None
36+ self .set_side_effect : BaseException | None = None
37+
38+ @staticmethod
39+ def _extract (value ) -> bytes :
40+ if isinstance (value , ObjectCode ):
41+ return bytes (value .code )
42+ if isinstance (value , (bytes , bytearray , memoryview )):
43+ return bytes (value )
44+ raise TypeError (f"unexpected value type: { type (value ).__name__ } " )
4345
4446 def get (self , key , default = None ):
4547 self .get_calls .append (key )
@@ -48,22 +50,24 @@ def get(self, key, default=None):
4850 return self ._store .get (key , default )
4951
5052 def __setitem__ (self , key , value ):
51- self .set_calls .append ((key , value ))
53+ data = self ._extract (value )
54+ self .set_calls .append ((key , data ))
5255 if self .set_side_effect is not None :
5356 raise self .set_side_effect
54- self ._store [key ] = value
57+ self ._store [key ] = data
5558
5659
5760_KERNEL = 'extern "C" __global__ void k() {}'
61+ _SENTINEL_BYTES = b"sentinel-cubin-bytes"
5862
5963
6064def _make_sentinel_object_code ():
6165 """Construct a cache-safe ``ObjectCode`` that doesn't require compilation."""
62- return ObjectCode ._init (b"sentinel-cubin-bytes" , "cubin" , name = "sentinel" )
66+ return ObjectCode ._init (_SENTINEL_BYTES , "cubin" , name = "sentinel" )
6367
6468
6569def test_cache_miss_runs_compile_then_stores (monkeypatch ):
66- """On cache miss: get(key) once, _program_compile_uncached once, __setitem__(key, obj) once."""
70+ """On cache miss: get(key) once, _program_compile_uncached once, __setitem__ once."""
6771 sentinel = _make_sentinel_object_code ()
6872 program = Program (_KERNEL , "c++" , ProgramOptions (arch = "sm_80" ))
6973
@@ -75,16 +79,19 @@ def _return_sentinel(_program, *_args, **_kwargs):
7579
7680 result = program .compile ("cubin" , cache = cache )
7781
82+ # On miss the wrapper returns the freshly-compiled ObjectCode unchanged.
7883 assert result is sentinel
7984 assert len (cache .get_calls ) == 1
8085 assert len (cache .set_calls ) == 1
81- assert cache .set_calls [0 ][1 ] is sentinel
86+ # The cache stored the binary bytes extracted from the ObjectCode.
87+ assert cache .set_calls [0 ][1 ] == _SENTINEL_BYTES
8288
8389
84- def test_cache_hit_short_circuits_compile (monkeypatch ):
85- """On cache hit: get(key) returns, _program_compile_uncached is not called, no __setitem__."""
86- sentinel = _make_sentinel_object_code ()
87- options = ProgramOptions (arch = "sm_80" )
90+ def test_cache_hit_returns_object_code_reconstructed_from_bytes (monkeypatch ):
91+ """On hit: get(key) returns bytes, the wrapper rebuilds an ObjectCode with
92+ the same code_type and ProgramOptions.name. _program_compile_uncached is
93+ NOT called and there is no __setitem__."""
94+ options = ProgramOptions (arch = "sm_80" , name = "my_program" )
8895 program = Program (_KERNEL , "c++" , options )
8996 key = make_program_cache_key (
9097 code = _KERNEL ,
@@ -97,11 +104,14 @@ def _explode(_program, *_args, **_kwargs):
97104 raise AssertionError ("_program_compile_uncached must not be called on cache hit" )
98105
99106 monkeypatch .setattr (_program_module , "_program_compile_uncached" , _explode )
100- cache = _RecordingCache (preseed = {key : sentinel })
107+ cache = _RecordingCache (preseed = {key : _SENTINEL_BYTES })
101108
102109 result = program .compile ("cubin" , cache = cache )
103110
104- assert result is sentinel
111+ assert isinstance (result , ObjectCode )
112+ assert bytes (result .code ) == _SENTINEL_BYTES
113+ assert result .code_type == "cubin"
114+ assert result .name == "my_program"
105115 assert cache .get_calls == [key ]
106116 assert cache .set_calls == []
107117
@@ -110,7 +120,9 @@ def test_name_expressions_affects_cache_key(monkeypatch):
110120 """Different ``name_expressions`` must produce different cache keys."""
111121 sentinel = _make_sentinel_object_code ()
112122 program = Program (_KERNEL , "c++" , ProgramOptions (arch = "sm_80" ))
113- monkeypatch .setattr (_program_module , "_program_compile_uncached" , lambda _self , * _args , ** _kwargs : sentinel )
123+ monkeypatch .setattr (
124+ _program_module , "_program_compile_uncached" , lambda _self , * _args , ** _kwargs : sentinel
125+ )
114126 cache = _RecordingCache ()
115127
116128 program .compile ("cubin" , name_expressions = ("foo" ,), cache = cache )
@@ -188,7 +200,9 @@ def test_cache_write_exception_propagates(monkeypatch):
188200 """Exceptions from cache.__setitem__ propagate after compile runs."""
189201 sentinel = _make_sentinel_object_code ()
190202 program = Program (_KERNEL , "c++" , ProgramOptions (arch = "sm_80" ))
191- monkeypatch .setattr (_program_module , "_program_compile_uncached" , lambda _self , * _args , ** _kwargs : sentinel )
203+ monkeypatch .setattr (
204+ _program_module , "_program_compile_uncached" , lambda _self , * _args , ** _kwargs : sentinel
205+ )
192206 cache = _RecordingCache ()
193207 cache .set_side_effect = RuntimeError ("disk full" )
194208
@@ -203,7 +217,9 @@ def test_no_cache_kwarg_does_not_derive_key(monkeypatch):
203217 """Without cache=, no cache-module functions run; compile goes straight through."""
204218 sentinel = _make_sentinel_object_code ()
205219 program = Program (_KERNEL , "c++" , ProgramOptions (arch = "sm_80" ))
206- monkeypatch .setattr (_program_module , "_program_compile_uncached" , lambda _self , * _args , ** _kwargs : sentinel )
220+ monkeypatch .setattr (
221+ _program_module , "_program_compile_uncached" , lambda _self , * _args , ** _kwargs : sentinel
222+ )
207223
208224 # If the implementation accidentally derived a key, it would call
209225 # make_program_cache_key. Replace it with a raising stub to catch that.
@@ -219,35 +235,9 @@ def _cache_path_must_not_run(*_args, **_kwargs):
219235 assert result is sentinel
220236
221237
222- def test_inmemory_hit_returns_same_instance (init_cuda ):
223- """InMemoryProgramCache stores by reference; a hit returns the same object."""
224- cache = InMemoryProgramCache ()
225- program = Program (_KERNEL , "c++" , ProgramOptions (arch = "sm_80" ))
226-
227- first = program .compile ("cubin" , cache = cache )
228- second = program .compile ("cubin" , cache = cache )
229-
230- assert second is first
231-
232-
233- @needs_sqlite3
234- def test_sqlite_hit_returns_byte_equal_object_code (init_cuda , tmp_path ):
235- """SQLiteProgramCache pickle-roundtrips; a hit returns byte-identical code."""
236- program = Program (_KERNEL , "c++" , ProgramOptions (arch = "sm_80" ))
237- db_path = tmp_path / "cache.sqlite"
238-
239- with SQLiteProgramCache (db_path ) as cache :
240- first = program .compile ("cubin" , cache = cache )
241-
242- with SQLiteProgramCache (db_path ) as cache :
243- second = program .compile ("cubin" , cache = cache )
244-
245- assert second .code == first .code
246- assert second .code_type == first .code_type
247-
248-
249238def test_filestream_hit_returns_byte_equal_object_code (init_cuda , tmp_path ):
250- """FileStreamProgramCache pickle-roundtrips; a hit returns byte-identical code."""
239+ """End-to-end: real compile, FileStreamProgramCache roundtrip, second
240+ compile returns an ObjectCode whose bytes match the first compile."""
251241 program = Program (_KERNEL , "c++" , ProgramOptions (arch = "sm_80" ))
252242 cache_dir = tmp_path / "fc"
253243
@@ -257,42 +247,31 @@ def test_filestream_hit_returns_byte_equal_object_code(init_cuda, tmp_path):
257247 with FileStreamProgramCache (cache_dir ) as cache :
258248 second = program .compile ("cubin" , cache = cache )
259249
260- assert second .code == first .code
250+ assert bytes ( second .code ) == bytes ( first .code )
261251 assert second .code_type == first .code_type
262252
263253
264- @pytest .mark .parametrize (
265- "backend" ,
266- [
267- pytest .param ("sqlite" , marks = needs_sqlite3 ),
268- "filestream" ,
269- ],
270- )
271- def test_cache_kwarg_roundtrip_across_reopen (init_cuda , tmp_path , backend , monkeypatch ):
254+ def test_cache_kwarg_roundtrip_across_reopen (init_cuda , tmp_path , monkeypatch ):
272255 """Compile with cache= in one 'session', reopen and fetch via cache= again."""
273- program_init_args = (_KERNEL , "c++" , ProgramOptions (arch = "sm_80" , name = "cached_kernel" ))
274-
275- if backend == "sqlite" :
276- cache_path = tmp_path / "cache.sqlite"
277- cache_cls = SQLiteProgramCache
278- else :
279- cache_path = tmp_path / "fc"
280- cache_cls = FileStreamProgramCache
256+ init_args = (_KERNEL , "c++" , ProgramOptions (arch = "sm_80" , name = "cached_kernel" ))
257+ cache_path = tmp_path / "fc"
281258
282- with cache_cls (cache_path ) as cache :
283- program = Program (* program_init_args )
259+ with FileStreamProgramCache (cache_path ) as cache :
260+ program = Program (* init_args )
284261 first = program .compile ("cubin" , cache = cache )
285262
286263 # Fresh process / fresh Program and cache-handle -- same cache path.
287- with cache_cls (cache_path ) as cache :
288- program = Program (* program_init_args )
264+ with FileStreamProgramCache (cache_path ) as cache :
265+ program = Program (* init_args )
289266
290- # Monkeypatch the module-level compile seam so a cache miss would raise --
291- # the only way this succeeds is if the reopen finds the prior entry.
267+ # If the reopened cache misses, the wrapper would fall back to
268+ # _program_compile_uncached -- replace it with a raising stub so
269+ # the test can only succeed via a hit.
292270 def _must_not_recompile (_program , * _args , ** _kwargs ):
293271 raise AssertionError ("cache miss: reopened cache didn't serve entry" )
294272
295273 monkeypatch .setattr (_program_module , "_program_compile_uncached" , _must_not_recompile )
296274 second = program .compile ("cubin" , cache = cache )
297275
298- assert second .code == first .code
276+ assert bytes (second .code ) == bytes (first .code )
277+ assert second .name == "cached_kernel"
0 commit comments