Skip to content

Commit e08e89e

Browse files
committed
feat(core): Program.compile(cache=...) convenience wrapper
Adds a ``cache=`` keyword to :meth:`cuda.core.Program.compile` that threads the persistent cache machinery into the high-level compile path. With ``cache=None`` (the default) the call is byte-identical to the un-cached path -- no key derivation, no extra import, no behavior change. When a cache is provided, the wrapper derives a key via :func:`~cuda.core.utils.make_program_cache_key` from the program's source, options, and target type; checks the cache; on hit, returns a fresh ``ObjectCode._init(hit_bytes, target_type, name=self._options.name)``; on miss, runs the underlying compile and stores ``cache[key] = compiled`` (the cache extracts ``bytes(obj.code)``). Two compile-time guards close obvious footguns: * ``name_expressions`` plus ``cache=`` raises ``ValueError``. NVRTC populates ``ObjectCode.symbol_mapping`` from name-expression mangling at compile time, and that mapping isn't carried in the binary the cache stores. Without this guard the first call (miss) would return an ObjectCode with mappings populated, while every subsequent call (hit) would return one without -- silently breaking later ``get_kernel(name_expression)`` lookups that work on the uncached path. Compiles that need name_expressions should run without ``cache=``, or look up mangled symbols by hand from the cached ``ObjectCode``. * Inputs whose compilation effect isn't captured by the key (``include_path``, ``pre_include``, ``pch``, ``use_pch``, ``pch_dir``, NVVM ``use_libdevice=True``, NVRTC ``options.name`` with a directory component, side-effect options like ``create_pch`` / ``time`` / ``fdevice_time_trace``) propagate the ``ValueError`` from ``make_program_cache_key`` -- those callers should use ``make_program_cache_key`` directly with an ``extra_digest`` covering the external content. Supporting refactors: * Unify ``Program``'s source retention into a single ``_code`` field (was split between ``_code`` for NVVM and a separate ``_source`` for c++/ptx). ``_code`` is now always bytes; the cache wrapper decodes back to ``str`` for c++/ptx before passing to ``make_program_cache_key`` (which only accepts bytes for NVVM). * Move the actual compile call into a module-level ``_program_compile_uncached`` so tests can monkeypatch the seam without going through NVRTC. ``Program`` is a ``cdef class``, so its methods cannot be reassigned from Python -- the seam has to live outside the class. * The unified ``_code`` field also exposed a pre-existing bug on the NVVM path: the C pointer was being recomputed from the caller's original ``code`` argument rather than from ``self._code``, which crashed for ``bytearray`` inputs that the field's bytes coercion handled cleanly. Fixed; regression test added in ``test_program.py``. Tests in ``test_program_compile_cache.py`` cover both halves of the contract: the wrapper-level miss/hit/error paths against a recording stub (verifying it's duck-typed and doesn't require subclassing ``ProgramCacheResource``), the rejection paths (name_expressions, extra_digest-required options, side-effect options, NVRTC ``options.name`` with a directory component), and a real NVRTC end-to-end roundtrip using ``FileStreamProgramCache`` across reopen so the bytes match across processes.
1 parent 88b472f commit e08e89e

4 files changed

Lines changed: 429 additions & 8 deletions

File tree

cuda_core/cuda/core/_program.pxd

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@ cdef class Program:
1717
object _compile_lock # Per-instance lock for compile-time mutation
1818
bint _use_libdevice # Flag for libdevice loading
1919
bint _libdevice_added
20-
bytes _nvrtc_code # Source code for NVRTC retry (PCH auto-resize)
20+
bytes _code # Source code as bytes: used for key derivation and NVRTC PCH retry
21+
str _code_type # Normalised code_type ("c++", "ptx", "nvvm")
2122
str _pch_status # PCH creation outcome after compile

cuda_core/cuda/core/_program.pyx

Lines changed: 100 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,12 @@ cdef class Program:
8585
self._h_nvvm.reset()
8686

8787
def compile(
88-
self, target_type: str, name_expressions: tuple | list = (), logs = None
88+
self,
89+
target_type: str,
90+
name_expressions: tuple | list = (),
91+
logs=None,
92+
*,
93+
cache: "ProgramCacheResource | None" = None,
8994
) -> ObjectCode:
9095
"""Compile the program to the specified target type.
9196

@@ -98,13 +103,82 @@ cdef class Program:
98103
Used for template instantiation and similar cases.
99104
logs : object, optional
100105
Object with a ``write`` method to receive compilation logs.
106+
cache : :class:`~cuda.core.utils.ProgramCacheResource`, optional
107+
If provided, the compiled binary is looked up in ``cache`` via a
108+
key derived from the program's code, options, and ``target_type``.
109+
On a hit the cached bytes are wrapped in a fresh
110+
:class:`~cuda.core.ObjectCode` (with the same ``target_type``
111+
and ``ProgramOptions.name``) and returned without re-compiling;
112+
on a miss the compile output is stored as raw bytes (the cache
113+
extracts ``bytes(object_code.code)``). Passing a non-empty
114+
``name_expressions`` together with ``cache=`` raises
115+
``ValueError``: NVRTC populates
116+
``ObjectCode.symbol_mapping`` at compile time and that mapping
117+
is not carried in the binary the cache stores, so cache hits
118+
would silently miss ``get_kernel(name_expression)`` lookups.
119+
Options that require an ``extra_digest`` (``include_path``,
120+
``pre_include``, ``pch``, ``use_pch``, ``pch_dir``, NVVM
121+
``use_libdevice=True``, or NVRTC ``options.name`` with a
122+
directory component) raise ``ValueError`` via
123+
:func:`~cuda.core.utils.make_program_cache_key`; for those
124+
compiles, use the manual ``make_program_cache_key(...)``
125+
pattern directly.
101126

102127
Returns
103128
-------
104129
:class:`~cuda.core.ObjectCode`
105130
The compiled object code.
106131
"""
107-
return Program_compile(self, target_type, name_expressions, logs)
132+
if cache is None:
133+
return _program_compile_uncached(self, target_type, name_expressions, logs)
134+
135+
# ``name_expressions`` is incompatible with the cache: NVRTC
136+
# populates ``ObjectCode.symbol_mapping`` from name-expression
137+
# mangling at compile time, and that mapping isn't carried in
138+
# the binary bytes the cache stores. Without this guard the
139+
# first call (cache miss) would return an ObjectCode with
140+
# symbol_mapping populated, while every subsequent call (hit)
141+
# would return one without -- silently breaking later
142+
# ``get_kernel(name_expression)`` lookups that work on the
143+
# uncached path. Fail loud here instead.
144+
if name_expressions:
145+
raise ValueError(
146+
"Program.compile(cache=...) does not support name_expressions: "
147+
"ObjectCode.symbol_mapping is populated by NVRTC at compile "
148+
"time and is not preserved across a cache round-trip, so cache "
149+
"hits would silently break get_kernel(name_expression) lookups "
150+
"that the uncached path supports. Compile without cache= when "
151+
"name_expressions are needed, or look up mangled symbols by "
152+
"hand from the cached ObjectCode."
153+
)
154+
155+
# Deferred import to avoid a circular import between _program and
156+
# cuda.core.utils._program_cache (the cache module already imports
157+
# ProgramOptions from this module). Import from the leaf module so
158+
# tests that monkeypatch make_program_cache_key via that path
159+
# intercept reliably.
160+
from cuda.core.utils._program_cache import make_program_cache_key
161+
162+
# ``self._code`` is always stored as bytes (see ``Program_init``),
163+
# but ``make_program_cache_key`` only accepts bytes when
164+
# ``code_type == "nvvm"`` -- c++/ptx must be ``str``. Decode back
165+
# to the original str for the NVRTC/linker paths so the generated
166+
# key matches keys callers build by passing the str source
167+
# directly.
168+
code_for_key = self._code if self._code_type == "nvvm" else self._code.decode("utf-8")
169+
170+
key = make_program_cache_key(
171+
code=code_for_key,
172+
code_type=self._code_type,
173+
options=self._options,
174+
target_type=target_type,
175+
)
176+
hit_bytes = cache.get(key)
177+
if hit_bytes is not None:
178+
return ObjectCode._init(hit_bytes, target_type, name=self._options.name)
179+
compiled = _program_compile_uncached(self, target_type, name_expressions, logs)
180+
cache[key] = compiled
181+
return compiled
108182

109183
@property
110184
def pch_status(self) -> str | None:
@@ -503,6 +577,19 @@ class ProgramOptions:
503577
# Private Classes and Helper Functions
504578
# =============================================================================
505579

580+
581+
def _program_compile_uncached(program, target_type, name_expressions, logs):
582+
"""Run ``Program_compile`` without the cache wrapper.
583+
584+
Module-level Python function so tests can monkeypatch it from
585+
``cuda.core._program`` to avoid invoking NVRTC when exercising the cache
586+
wrapper in :meth:`Program.compile`. ``Program`` itself is a ``cdef class``
587+
and its methods cannot be reassigned from Python, so the seam must live
588+
outside the class.
589+
"""
590+
return Program_compile(program, target_type, name_expressions, logs)
591+
592+
506593
# Module-level state for NVVM lazy loading
507594
_nvvm_module = None
508595
_nvvm_import_attempted = False
@@ -618,6 +705,7 @@ cdef inline int Program_init(Program self, object code, str code_type, object op
618705

619706
self._options = options = check_or_create_options(ProgramOptions, options, "Program options")
620707
code_type = code_type.lower()
708+
self._code_type = code_type
621709
self._compile_lock = threading.Lock()
622710
self._use_libdevice = False
623711
self._libdevice_added = False
@@ -638,16 +726,18 @@ cdef inline int Program_init(Program self, object code, str code_type, object op
638726
HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcCreateProgram(
639727
&nvrtc_prog, code_ptr, name_ptr, 0, NULL, NULL))
640728
self._h_nvrtc = create_nvrtc_program_handle(nvrtc_prog)
641-
self._nvrtc_code = code_bytes
729+
self._code = code_bytes
642730
self._backend = "NVRTC"
643731
self._linker = None
644732

645733
elif code_type == "ptx":
646734
assert_type(code, str)
647735
if options.extra_sources is not None:
648736
raise ValueError("extra_sources is not supported by the PTX backend.")
737+
code_bytes = code.encode()
738+
self._code = code_bytes
649739
self._linker = Linker(
650-
ObjectCode._init(code.encode(), code_type), options=_translate_program_options(options)
740+
ObjectCode._init(code_bytes, code_type), options=_translate_program_options(options)
651741
)
652742
self._backend = self._linker.backend
653743

@@ -657,10 +747,13 @@ cdef inline int Program_init(Program self, object code, str code_type, object op
657747
code = code.encode("utf-8")
658748
elif not isinstance(code, (bytes, bytearray)):
659749
raise TypeError("NVVM IR code must be provided as str, bytes, or bytearray")
750+
self._code = bytes(code) # Coerce bytearray -> bytes so retention type is stable
660751

661-
code_ptr = <const char*>(<bytes>code)
752+
# Use self._code (strictly bytes) for the C pointer so a bytearray
753+
# input doesn't trip the `<bytes>code` cast at runtime.
754+
code_ptr = <const char*>self._code
662755
name_ptr = <const char*>options._name
663-
code_len = len(code)
756+
code_len = len(self._code)
664757

665758
with nogil:
666759
HANDLE_RETURN_NVVM(NULL, cynvvm.nvvmCreateProgram(&nvvm_prog))
@@ -832,7 +925,7 @@ cdef object Program_compile_nvrtc(Program self, str target_type, object name_exp
832925
HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcSetPCHHeapSize(required))
833926

834927
cdef cynvrtc.nvrtcProgram retry_prog
835-
cdef const char* code_ptr = <const char*>self._nvrtc_code
928+
cdef const char* code_ptr = <const char*>self._code
836929
cdef const char* name_ptr = <const char*>self._options._name
837930
with nogil:
838931
HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcCreateProgram(

cuda_core/tests/test_program.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,23 @@ def test_nvvm_compile_invalid_target(nvvm_ir):
434434
program.close()
435435

436436

437+
@nvvm_available
438+
def test_nvvm_accepts_bytearray_input(nvvm_ir):
439+
"""Program(..., 'nvvm') must accept bytearray input.
440+
441+
Regression for a bug where the NVVM init branch retained the coerced
442+
``self._code`` as bytes but still cast the original ``code`` object to
443+
``<bytes>`` for the C pointer -- tripping a runtime type error for
444+
bytearray inputs before nvvmAddModuleToProgram was called.
445+
"""
446+
program = Program(bytearray(nvvm_ir, "utf-8"), "nvvm")
447+
try:
448+
assert program.backend == "NVVM"
449+
assert program.handle is not None
450+
finally:
451+
program.close()
452+
453+
437454
@nvvm_available
438455
def test_nvvm_compile_invalid_ir():
439456
"""Compiling invalid NVVM IR exercises the HANDLE_RETURN_NVVM error path."""

0 commit comments

Comments
 (0)