diff --git a/cuda_core/cuda/core/_program.pxd b/cuda_core/cuda/core/_program.pxd index b2feeba860d..cea430c3f20 100644 --- a/cuda_core/cuda/core/_program.pxd +++ b/cuda_core/cuda/core/_program.pxd @@ -17,5 +17,6 @@ cdef class Program: object _compile_lock # Per-instance lock for compile-time mutation bint _use_libdevice # Flag for libdevice loading bint _libdevice_added - bytes _nvrtc_code # Source code for NVRTC retry (PCH auto-resize) + bytes _code # Source code as bytes: used for key derivation and NVRTC PCH retry + str _code_type # Normalised code_type ("c++", "ptx", "nvvm") str _pch_status # PCH creation outcome after compile diff --git a/cuda_core/cuda/core/_program.pyx b/cuda_core/cuda/core/_program.pyx index 194ef6da53f..04a3b35cecb 100644 --- a/cuda_core/cuda/core/_program.pyx +++ b/cuda_core/cuda/core/_program.pyx @@ -85,7 +85,12 @@ cdef class Program: self._h_nvvm.reset() def compile( - self, target_type: str, name_expressions: tuple | list = (), logs = None + self, + target_type: str, + name_expressions: tuple | list = (), + logs=None, + *, + cache: "ProgramCacheResource | None" = None, ) -> ObjectCode: """Compile the program to the specified target type. @@ -98,13 +103,61 @@ cdef class Program: Used for template instantiation and similar cases. logs : object, optional Object with a ``write`` method to receive compilation logs. + cache : :class:`~cuda.core.utils.ProgramCacheResource`, optional + If provided, the compiled binary is looked up in ``cache`` via a + key derived from the program's code, options, ``target_type`` and + ``name_expressions``. On a hit the cached bytes are wrapped in a + fresh :class:`~cuda.core.ObjectCode` (with the same ``target_type`` + and ``ProgramOptions.name``) and returned without re-compiling; + on a miss the compile output is stored as raw bytes (the cache + extracts ``bytes(object_code.code)``). 
Note that + ``ObjectCode.symbol_mapping`` is not preserved across a cache + round-trip -- callers using ``name_expressions`` who need + ``get_kernel(name_expression)`` after a hit must compile fresh + or look the mangled symbol up by hand. Options that require an + ``extra_digest`` (``include_path``, ``pre_include``, ``pch``, + ``use_pch``, ``pch_dir``, NVVM ``use_libdevice=True``, or NVRTC + ``options.name`` with a directory component) raise ``ValueError`` + via :func:`~cuda.core.utils.make_program_cache_key`; for those + compiles, use the manual ``make_program_cache_key(...)`` pattern + directly. Returns ------- :class:`~cuda.core.ObjectCode` The compiled object code. """ - return Program_compile(self, target_type, name_expressions, logs) + if cache is None: + return _program_compile_uncached(self, target_type, name_expressions, logs) + + # Deferred import to avoid a circular import between _program and + # cuda.core.utils._program_cache (the cache module already imports + # ProgramOptions from this module). Import from the leaf module so + # tests that monkeypatch make_program_cache_key via that path + # intercept reliably. + from cuda.core.utils._program_cache import make_program_cache_key + + # ``self._code`` is always stored as bytes (see ``Program_init``), + # but ``make_program_cache_key`` only accepts bytes when + # ``code_type == "nvvm"`` -- c++/ptx must be ``str``. Decode back + # to the original str for the NVRTC/linker paths so the generated + # key matches keys callers build by passing the str source + # directly. 
+ code_for_key = self._code if self._code_type == "nvvm" else self._code.decode("utf-8") + + key = make_program_cache_key( + code=code_for_key, + code_type=self._code_type, + options=self._options, + target_type=target_type, + name_expressions=name_expressions, + ) + hit_bytes = cache.get(key) + if hit_bytes is not None: + return ObjectCode._init(hit_bytes, target_type, name=self._options.name) + compiled = _program_compile_uncached(self, target_type, name_expressions, logs) + cache[key] = compiled + return compiled @property def pch_status(self) -> str | None: @@ -503,6 +556,19 @@ class ProgramOptions: # Private Classes and Helper Functions # ============================================================================= + +def _program_compile_uncached(program, target_type, name_expressions, logs): + """Run ``Program_compile`` without the cache wrapper. + + Module-level Python function so tests can monkeypatch it from + ``cuda.core._program`` to avoid invoking NVRTC when exercising the cache + wrapper in :meth:`Program.compile`. ``Program`` itself is a ``cdef class`` + and its methods cannot be reassigned from Python, so the seam must live + outside the class. 
+ """ + return Program_compile(program, target_type, name_expressions, logs) + + # Module-level state for NVVM lazy loading _nvvm_module = None _nvvm_import_attempted = False @@ -618,6 +684,7 @@ cdef inline int Program_init(Program self, object code, str code_type, object op self._options = options = check_or_create_options(ProgramOptions, options, "Program options") code_type = code_type.lower() + self._code_type = code_type self._compile_lock = threading.Lock() self._use_libdevice = False self._libdevice_added = False @@ -638,7 +705,7 @@ cdef inline int Program_init(Program self, object code, str code_type, object op HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcCreateProgram( &nvrtc_prog, code_ptr, name_ptr, 0, NULL, NULL)) self._h_nvrtc = create_nvrtc_program_handle(nvrtc_prog) - self._nvrtc_code = code_bytes + self._code = code_bytes self._backend = "NVRTC" self._linker = None @@ -646,8 +713,10 @@ cdef inline int Program_init(Program self, object code, str code_type, object op assert_type(code, str) if options.extra_sources is not None: raise ValueError("extra_sources is not supported by the PTX backend.") + code_bytes = code.encode() + self._code = code_bytes self._linker = Linker( - ObjectCode._init(code.encode(), code_type), options=_translate_program_options(options) + ObjectCode._init(code_bytes, code_type), options=_translate_program_options(options) ) self._backend = self._linker.backend @@ -657,10 +726,13 @@ cdef inline int Program_init(Program self, object code, str code_type, object op code = code.encode("utf-8") elif not isinstance(code, (bytes, bytearray)): raise TypeError("NVVM IR code must be provided as str, bytes, or bytearray") + self._code = bytes(code) # Coerce bytearray -> bytes so retention type is stable - code_ptr = (code) + # Use self._code (strictly bytes) for the C pointer so a bytearray + # input doesn't trip the `code` cast at runtime. 
+ code_ptr = self._code name_ptr = options._name - code_len = len(code) + code_len = len(self._code) with nogil: HANDLE_RETURN_NVVM(NULL, cynvvm.nvvmCreateProgram(&nvvm_prog)) @@ -832,7 +904,7 @@ cdef object Program_compile_nvrtc(Program self, str target_type, object name_exp HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcSetPCHHeapSize(required)) cdef cynvrtc.nvrtcProgram retry_prog - cdef const char* code_ptr = self._nvrtc_code + cdef const char* code_ptr = self._code cdef const char* name_ptr = self._options._name with nogil: HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcCreateProgram( diff --git a/cuda_core/cuda/core/utils.py b/cuda_core/cuda/core/utils.py deleted file mode 100644 index f15d9242778..00000000000 --- a/cuda_core/cuda/core/utils.py +++ /dev/null @@ -1,8 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from cuda.core._memoryview import ( - StridedMemoryView, # noqa: F401 - args_viewable_as_strided_memory, # noqa: F401 -) diff --git a/cuda_core/cuda/core/utils/__init__.py b/cuda_core/cuda/core/utils/__init__.py new file mode 100644 index 00000000000..c5d560d3466 --- /dev/null +++ b/cuda_core/cuda/core/utils/__init__.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from cuda.core._memoryview import ( + StridedMemoryView, + args_viewable_as_strided_memory, +) + +__all__ = [ + "FileStreamProgramCache", + "InMemoryProgramCache", + "ProgramCacheResource", + "StridedMemoryView", + "args_viewable_as_strided_memory", + "make_program_cache_key", +] + +# Lazily expose the program-cache APIs so ``from cuda.core.utils import +# StridedMemoryView`` stays lightweight -- the cache backend pulls in driver, +# NVRTC, and module-load machinery that memoryview-only consumers do not need. 
# Public names that are resolved on first attribute access by ``__getattr__``
# rather than imported when this package is loaded.
_LAZY_CACHE_ATTRS = frozenset(
    (
        "FileStreamProgramCache",
        "InMemoryProgramCache",
        "ProgramCacheResource",
        "make_program_cache_key",
    )
)


def __getattr__(name):
    """Module-level attribute hook that loads the program-cache API on demand.

    Parameters
    ----------
    name : str
        Attribute being looked up on the module after normal lookup failed.

    Returns
    -------
    object
        The attribute of the same name from ``cuda.core.utils._program_cache``.

    Raises
    ------
    AttributeError
        If ``name`` is not one of the lazily exported names.
    """
    if name not in _LAZY_CACHE_ATTRS:
        raise AttributeError(f"module 'cuda.core.utils' has no attribute {name!r}")

    # Deferred so that importing the package does not import the cache module.
    from cuda.core.utils import _program_cache

    resolved = getattr(_program_cache, name)
    # Store the result in the module namespace; later lookups find it there
    # and never reach this hook again.
    globals()[name] = resolved
    return resolved


def __dir__():
    """Return the sorted union of the module globals and ``__all__``.

    Names listed in ``__all__`` that have not been resolved yet are included,
    alongside whatever is already present in the module namespace.
    """
    visible = set(globals())
    visible.update(__all__)
    return sorted(visible)
def _extract_bytes(value: object) -> bytes:
    """Convert a cache value into the raw bytes that get stored.

    ``bytes``, ``bytearray`` and ``memoryview`` inputs are copied into a new
    ``bytes`` object. For an :class:`ObjectCode`, its ``code`` attribute is
    used: when that attribute is a ``str`` it is treated as a filesystem path
    and the file is read immediately, so the stored entry is the binary
    content rather than a path that might later move or change; otherwise it
    is converted with ``bytes(...)``.

    Raises
    ------
    TypeError
        If ``value`` is neither bytes-like nor an :class:`ObjectCode`.
    """
    if isinstance(value, (bytes, bytearray, memoryview)):
        return bytes(value)
    if not isinstance(value, ObjectCode):
        raise TypeError(f"cache values must be bytes-like or ObjectCode, got {type(value).__name__}")
    payload = value.code
    # A ``str`` payload is a path on disk; anything else is already binary.
    return Path(payload).read_bytes() if isinstance(payload, str) else bytes(payload)


def _as_key_bytes(key: object) -> bytes:
    """Normalise a cache key to ``bytes`` (``str`` keys are UTF-8 encoded).

    Raises
    ------
    TypeError
        If ``key`` is not ``str``, ``bytes`` or ``bytearray``.
    """
    if isinstance(key, str):
        return key.encode("utf-8")
    if isinstance(key, (bytes, bytearray)):
        return bytes(key)
    raise TypeError(f"cache keys must be bytes or str, got {type(key).__name__}")


# ---------------------------------------------------------------------------
# Abstract base class
# ---------------------------------------------------------------------------


class ProgramCacheResource(abc.ABC):
    """Interface that every compiled-program cache implements.

    A cache maps a ``bytes`` or ``str`` key to **raw binary bytes**. ``str``
    keys are encoded as UTF-8 before use, which means ``"k"`` and ``b"k"``
    address the same entry. Keys normally come from
    :func:`make_program_cache_key`, which returns ``bytes``.

    What is stored is the compiled program itself (cubin, PTX, LTO-IR, ...),
    and reads hand the same raw bytes back, so entries stay usable by
    external NVIDIA tools (``cuobjdump``, ``nvdisasm``, ``cuda-gdb``, ...).

    Direct use of this object is rarely necessary. The recommended entry
    point is the ``cache=`` keyword of :meth:`cuda.core.Program.compile`,
    which derives the key, returns a fresh :class:`~cuda.core.ObjectCode`
    on a hit, and stores the compile result on a miss::

        with FileStreamProgramCache() as cache:
            obj = program.compile("cubin", cache=cache)

    When the compile inputs need an ``extra_digest`` (header / PCH content
    fingerprints, NVVM libdevice), build the key yourself with
    :func:`make_program_cache_key` and treat the cache as a plain ``bytes``
    mapping::

        from cuda.core._module import ObjectCode

        key = make_program_cache_key(
            code=source,
            code_type="c++",
            options=options,
            target_type="cubin",
            extra_digest=header_fingerprint(),
        )
        data = cache.get(key)
        if data is None:
            obj = program.compile("cubin")
            cache[key] = obj  # extracts bytes(obj.code)
        else:
            obj = ObjectCode._init(data, "cubin")

    No payload validation happens at this layer: bytes are stored and
    returned unchanged. Only the binary is kept, so the symbol-mapping
    metadata an :class:`~cuda.core.ObjectCode` carries when it was produced
    with NVRTC name expressions is **not** preserved by a cache round-trip.
    Callers that rely on ``symbol_mapping`` for
    ``get_kernel(name_expression)`` should compile fresh or look the mangled
    symbol up by hand.
    """

    @abc.abstractmethod
    def __getitem__(self, key: bytes | str) -> bytes:
        """Return the binary bytes stored under ``key``.

        Raises
        ------
        KeyError
            If the cache holds no entry for ``key``.
        """

    @abc.abstractmethod
    def __setitem__(self, key: bytes | str, value: bytes | bytearray | memoryview | ObjectCode) -> None:
        """Store ``value`` under ``key``.

        A path-backed :class:`~cuda.core.ObjectCode` is read from disk when
        it is written, so the entry contains the bytes and not the path.
        """

    @abc.abstractmethod
    def __contains__(self, key: bytes | str) -> bool:
        """Return ``True`` when the cache has an entry for ``key``."""

    @abc.abstractmethod
    def __delitem__(self, key: bytes | str) -> None:
        """Delete the entry stored under ``key``.

        Raises
        ------
        KeyError
            If the cache holds no entry for ``key``.
        """

    @abc.abstractmethod
    def __len__(self) -> int:
        """Return how many entries the cache currently holds."""

    @abc.abstractmethod
    def clear(self) -> None:
        """Delete every entry in the cache."""

    def get(self, key: bytes | str, default: bytes | None = None) -> bytes | None:
        """Look ``key`` up, returning ``default`` instead of raising ``KeyError``."""
        try:
            found = self[key]
        except KeyError:
            return default
        return found

    def update(
        self,
        items: (
            collections.abc.Mapping[bytes | str, bytes | bytearray | memoryview | ObjectCode]
            | collections.abc.Iterable[tuple[bytes | str, bytes | bytearray | memoryview | ObjectCode]]
        ),
        /,
    ) -> None:
        """Store several entries, one ``__setitem__`` call per entry.

        ``items`` is either a mapping or an iterable of ``(key, value)``
        pairs. Because every write is routed through ``__setitem__``, any
        backend-specific value coercion (for example extracting bytes from
        an :class:`~cuda.core.ObjectCode`) and size-cap enforcement applies
        to each entry. The operation is not transactional: if a write fails
        part-way through, the earlier writes remain committed.
        """
        pairs = items.items() if isinstance(items, collections.abc.Mapping) else items
        for key, value in pairs:
            self[key] = value

    def close(self) -> None:  # noqa: B027
        """Release backend resources; the base implementation does nothing."""

    def __enter__(self) -> ProgramCacheResource:
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Leaving the ``with`` block always closes the cache.
        self.close()


# ---------------------------------------------------------------------------
# Key construction
# ---------------------------------------------------------------------------


# Bump when the key schema changes in a way that invalidates existing caches.
_KEY_SCHEMA_VERSION = 1

# Accepted ``code_type`` / ``target_type`` spellings.
_VALID_CODE_TYPES = frozenset(("c++", "ptx", "nvvm"))
_VALID_TARGET_TYPES = frozenset(("ptx", "cubin", "ltoir"))

# For each code_type, the target_type values it may be compiled to. Mirrors
# Program.compile's SUPPORTED_TARGETS matrix in _program.pyx.
_SUPPORTED_TARGETS_BY_CODE_TYPE = {
    "c++": frozenset(("ptx", "cubin", "ltoir")),
    "ptx": frozenset(("cubin", "ptx")),
    "nvvm": frozenset(("ptx", "ltoir")),
}


def _backend_for_code_type(code_type: str) -> str:
    """Name the compile backend that handles ``code_type``.

    Returns ``"linker"`` for ``"ptx"`` (Program routes PTX through Linker,
    not NVRTC), ``"nvvm"`` for ``"nvvm"``, and ``"nvrtc"`` for every other
    value.
    """
    if code_type == "ptx":
        return "linker"
    return "nvvm" if code_type == "nvvm" else "nvrtc"


# The ProgramOptions fields that _translate_program_options forwards to the
# Linker (see cuda_core/cuda/core/_program.pyx). Every other ProgramOptions
# field is NVRTC-only and must NOT influence a PTX cache key; otherwise a PTX
# compile that shares a ProgramOptions carrying include_path/pch/frandom_seed
# would miss the cache for no reason.
_LINKER_RELEVANT_FIELDS = (
    "name",
    "arch",
    "max_register_count",
    "time",
    "link_time_optimization",
    "debug",
    "lineinfo",
    "ftz",
    "prec_div",
    "prec_sqrt",
    "fma",
    "split_compile",
    "ptxas_options",
    "no_cache",
)


# Map each linker-relevant ProgramOptions field to the gate the Linker uses
# to turn it into a flag (see ``_prepare_nvjitlink_options`` and
# ``_prepare_driver_options`` in _linker.pyx). Collapsing inputs through
# these gates means semantically-equivalent configurations
# (``debug=False`` vs ``None``, ``time=True`` vs ``time="path"``) hash to
# the same cache key instead of forcing spurious misses.
def _gate_presence(v):
    """Gate for flags emitted whenever the option is not ``None``."""
    return v is not None


def _gate_truthy(v):
    """Gate for flags emitted only when the option is truthy."""
    return True if v else False


def _gate_is_true(v):
    """Gate for flags emitted only when the option is exactly ``True``."""
    return v is True


def _gate_tristate_bool(v):
    """Keep ``None`` as ``None``; coerce every other value to ``bool``."""
    if v is None:
        return None
    return bool(v)


def _gate_identity(v):
    """Pass the option value through unchanged."""
    return v


def _gate_ptxas_options(v):
    """Canonicalise ``ptxas_options`` to a tuple of ``-Xptxas=`` flags.

    ``_prepare_nvjitlink_options`` emits one ``-Xptxas=`` per element and
    treats a ``str`` as a single-element sequence, so ``"-v"``, ``["-v"]``
    and ``("-v",)`` must all hash the same. ``None`` and an empty sequence
    emit no flags and both become ``None``. Values that are neither ``str``
    nor a ``collections.abc.Sequence`` are returned unchanged.
    """
    if v is None:
        return None
    if isinstance(v, str):
        return ("-Xptxas=" + v,)
    if not isinstance(v, collections.abc.Sequence):
        return v
    return tuple(f"-Xptxas={s}" for s in v) if len(v) else None


_LINKER_FIELD_GATES = {
    "name": _gate_identity,
    "arch": _gate_identity,
    "max_register_count": _gate_identity,
    # The linker emits ``-time`` iff the value is not None.
    "time": _gate_presence,
    "link_time_optimization": _gate_truthy,
    "debug": _gate_truthy,
    "lineinfo": _gate_truthy,
    "ftz": _gate_tristate_bool,
    "prec_div": _gate_tristate_bool,
    "prec_sqrt": _gate_tristate_bool,
    "fma": _gate_tristate_bool,
    "split_compile": _gate_identity,
    "ptxas_options": _gate_ptxas_options,
    "no_cache": _gate_is_true,
}


# LinkerOptions fields that the ``cuLink`` driver backend silently ignores
# (it only emits a DeprecationWarning; no flag reaches the compiler). With the
# driver backend active they collapse to a single sentinel in the fingerprint,
# so nvJitLink<->driver parity of ``ObjectCode`` doesn't cause cache misses
# between otherwise-equivalent configurations.
_DRIVER_IGNORED_LINKER_FIELDS = frozenset(("ftz", "prec_div", "prec_sqrt", "fma"))


def _linker_option_fingerprint(options: ProgramOptions, *, use_driver_linker: bool | None) -> list[bytes]:
    """Fingerprint the ProgramOptions fields the Linker consumes, per backend.

    One ``b"<field>=<repr of gated value>"`` item is produced for each name
    in ``_LINKER_RELEVANT_FIELDS``, in that order. Values go through the gate
    registered in ``_LINKER_FIELD_GATES`` so equivalent inputs (for example
    ``debug=False`` and ``debug=None``) produce the same bytes. A field
    missing from ``options`` is treated as ``None``.

    Parameters
    ----------
    options:
        Object whose attributes are read by field name.
    use_driver_linker:
        ``True`` when the driver (cuLink) backend is in use; the fields in
        ``_DRIVER_IGNORED_LINKER_FIELDS`` are then collapsed to a fixed
        sentinel instead of their gated value. ``None`` means the backend
        could not be probed; nothing is collapsed in that case (nor for
        ``False``), which is the conservative choice.

    Returns
    -------
    list[bytes]
        The encoded ``field=value`` items.
    """
    collapse_ignored = use_driver_linker is True

    def _encode_field(field):
        if collapse_ignored and field in _DRIVER_IGNORED_LINKER_FIELDS:
            return f"{field}=".encode()
        gate = _LINKER_FIELD_GATES[field]
        return f"{field}={gate(getattr(options, field, None))!r}".encode()

    return [_encode_field(field) for field in _LINKER_RELEVANT_FIELDS]


# ProgramOptions fields that map to LinkerOptions fields the cuLink (driver)
# backend rejects outright (see _prepare_driver_options in _linker.pyx).
# ``split_compile_extended`` exists on LinkerOptions but is not exposed via
# ProgramOptions / _translate_program_options, so it cannot reach the driver
# linker from the cache path and is omitted here.
_DRIVER_LINKER_UNSUPPORTED_FIELDS = ("time", "ptxas_options", "split_compile")


def _driver_version() -> int:
    """Return the value reported by ``cuDriverGetVersion`` as an ``int``."""
    raw_version = _handle_return(_driver.cuDriverGetVersion())
    return int(raw_version)


def _nvrtc_version() -> tuple[int, int]:
    """Return the ``(major, minor)`` pair reported by ``nvrtcVersion``."""
    major, minor = _handle_return(_nvrtc.nvrtcVersion())
    return (int(major), int(minor))


def _linker_backend_and_version() -> tuple[str, str]:
    """Return ``(backend, version)`` for the linker used on PTX inputs.

    ``backend`` is ``"driver"`` when ``_decide_nvjitlink_or_driver()`` is
    truthy (the version is then the driver version) and ``"nvJitLink"``
    otherwise (the version is ``nvjitlink.version()``). Nothing is caught
    here, so any exception raised by the probe reaches the caller.

    The nvJitLink module is taken from ``sys.modules`` when it is already
    loaded, so the version is read from the same module object the linker
    decision loaded; it is imported only when absent.
    """
    import sys

    from cuda.core._linker import _decide_nvjitlink_or_driver

    if _decide_nvjitlink_or_driver():
        return ("driver", str(_driver_version()))

    loaded = sys.modules.get("cuda.bindings.nvjitlink")
    if loaded is None:
        from cuda.bindings import nvjitlink as loaded

    return ("nvJitLink", str(loaded.version()))


def _nvvm_fingerprint() -> str:
    """Return a string identifying the loaded NVVM toolchain.

    The result has the form ``lib=<major>.<minor>;ir=<a>.<b>.<c>.<d>``,
    built from ``module.version()`` (libNVVM library version) and
    ``module.ir_version()`` (IR version and debug IR version). The library
    version is included because a libNVVM upgrade can change code generation
    while the IR version stays the same, so the IR version alone would let
    stale entries be reused.

    The module comes from ``_get_nvvm_module()``, the same accessor the
    compile path uses, and its errors are not caught here.
    """
    from cuda.core._program import _get_nvvm_module

    nvvm = _get_nvvm_module()
    lib_major, lib_minor = nvvm.version()
    ir_major, ir_minor, dbg_major, dbg_minor = nvvm.ir_version()
    lib_part = f"lib={lib_major}.{lib_minor}"
    ir_part = f"ir={ir_major}.{ir_minor}.{dbg_major}.{dbg_minor}"
    return f"{lib_part};{ir_part}"


# ProgramOptions fields that reference external files whose *contents* the
# cache key cannot observe without reading the filesystem. Callers that set
# any of these must supply an ``extra_digest`` covering the dependency surface
# (e.g.
a hash over all reachable headers / PCH bytes). +_EXTERNAL_CONTENT_OPTIONS = ( + "include_path", + "pre_include", + "pch", + "use_pch", + "pch_dir", +) + +# ProgramOptions fields whose compilation effect is not captured in the +# returned ``ObjectCode`` -- they produce a filesystem artifact as a side +# effect. A cache hit skips compilation, so that artifact would never be +# written. Reject these outright: the persistent cache is for pure ObjectCode +# reuse, not for replaying compile-time side effects. +# * create_pch -- writes a PCH file (NVRTC). +# * time -- writes NVRTC timing info to a file. +# * fdevice_time_trace -- writes a device-compilation time trace file (NVRTC). +# These are all NVRTC-specific; the Linker's ``-time`` logs to the info log +# (not a file) and NVVM explicitly rejects all three at compile time. The +# side-effect guard is therefore gated on ``backend == "nvrtc"`` below. +_SIDE_EFFECT_OPTIONS = ("create_pch", "time", "fdevice_time_trace") + + +# ProgramOptions fields gated by plain truthiness in ``_program.pyx`` (the +# compiler writes the flag only when the value is truthy). +_BOOLEAN_OPTION_FIELDS = frozenset({"pch"}) + +# Fields whose compiler emission requires ``isinstance(value, str)`` or a +# non-empty sequence; anything else (``False``, ``int``, ``None``, ``[]``) +# is silently ignored at compile time. +_STR_OR_SEQUENCE_OPTION_FIELDS = frozenset({"include_path", "pre_include"}) + + +def _option_is_set(options: ProgramOptions, name: str) -> bool: + """Match how ``_program.pyx`` gates option emission, per field shape. + + - Boolean flags (``pch``): truthy only. + - str-or-sequence fields (``include_path``, ``pre_include``): ``str`` + (including empty) or a non-empty ``collections.abc.Sequence`` (list, + tuple, range, user subclass, ...); everything else (``False``, ``int``, + empty sequence, ``None``) is ignored by the compiler and must not + trigger a cache-time guard. 
+ - Path/string-shaped fields (``create_pch``, ``time``, + ``fdevice_time_trace``, ``use_pch``, ``pch_dir``): ``is not None`` -- + the compiler emits ``--flag=`` for any non-None value, so + ``False`` / ``""`` / ``0`` must still count as set. + """ + value = getattr(options, name, None) + if value is None: + return False + if name in _BOOLEAN_OPTION_FIELDS: + return bool(value) + if name in _STR_OR_SEQUENCE_OPTION_FIELDS: + # Mirror ``_prepare_nvrtc_options_impl``: it checks ``isinstance(v, str)`` + # first, then ``is_sequence(v)`` (which is ``isinstance(v, Sequence)``). + # We therefore accept any ``collections.abc.Sequence`` (range, deque, + # user subclass, etc.), not just list/tuple. + if isinstance(value, str): + return True + if isinstance(value, collections.abc.Sequence): + return len(value) > 0 + return False + return True + + +def make_program_cache_key( + *, + code: str | bytes, + code_type: str, + options: ProgramOptions, + target_type: str, + name_expressions: Sequence[str | bytes | bytearray] = (), + extra_digest: bytes | None = None, +) -> bytes: + """Build a stable cache key from compile inputs. + + Parameters + ---------- + code: + Source text. ``str`` is encoded as UTF-8. + code_type: + One of ``"c++"``, ``"ptx"``, ``"nvvm"``. + options: + A :class:`cuda.core.ProgramOptions`. Its ``arch`` must be set (the + default ``ProgramOptions.__post_init__`` populates it from the current + device). + target_type: + One of ``"ptx"``, ``"cubin"``, ``"ltoir"``. + name_expressions: + Optional iterable of mangled-name lookups. Order is not significant. + Elements may be ``str`` or ``bytes``; ``"foo"`` and ``b"foo"`` produce + distinct keys because ``Program.compile`` records the original Python + object as the ``ObjectCode.symbol_mapping`` key, and ``get_kernel`` + lookups must use the same type the cache key recorded. ``bytearray`` + is rejected because ``Program.compile`` stores each element as a + dict key and ``bytearray`` is unhashable. 
+ extra_digest: + Caller-supplied bytes mixed into the key. Required whenever + :class:`cuda.core.ProgramOptions` sets any option that pulls in + external file content (``include_path``, ``pre_include``, ``pch``, + ``use_pch``, ``pch_dir``) -- the cache cannot read + those files on the caller's behalf, so the caller must fingerprint + the header / PCH surface and pass it here. Callers may pass this for + other inputs too (embedded kernels, generated sources, etc.). + + Returns + ------- + bytes + A 32-byte blake2b digest suitable for use as a cache key. + + Raises + ------ + ValueError + If ``options`` sets an option with compile-time side effects (such as + ``create_pch``) -- a cache hit skips compilation, so the side effect + would not occur. + ValueError + If ``extra_digest`` is ``None`` while ``options`` sets any option whose + compilation effect depends on external file content that the key + cannot otherwise observe. + + Examples + -------- + For most workflows you should not call ``make_program_cache_key`` + yourself -- pass ``cache=`` to :meth:`cuda.core.Program.compile`, + which derives the key, returns the cached + :class:`~cuda.core.ObjectCode` on hit, and stores the compile + result on miss:: + + from cuda.core import Program, ProgramOptions + from cuda.core.utils import FileStreamProgramCache + + source = 'extern "C" __global__ void k(int *a){ *a = 1; }' + options = ProgramOptions(arch="sm_80") + + with FileStreamProgramCache() as cache: + obj = Program(source, "c++", options=options).compile("cubin", cache=cache) + + Call ``make_program_cache_key`` directly when the compile inputs + require an ``extra_digest`` (the cache cannot read external file + content on the caller's behalf) -- ``Program.compile(cache=...)`` + refuses those inputs with a ``ValueError`` pointing here:: + + from cuda.core._module import ObjectCode + from cuda.core.utils import FileStreamProgramCache, make_program_cache_key + + with FileStreamProgramCache() as cache: + key = 
make_program_cache_key( + code=source, + code_type="c++", + options=options, + target_type="cubin", + extra_digest=fingerprint_headers(options.include_path), + ) + data = cache.get(key) + if data is None: + obj = Program(source, "c++", options=options).compile("cubin") + cache[key] = obj # extracts bytes(obj.code) + else: + obj = ObjectCode._init(data, "cubin") + + The cache stores raw binary bytes -- cubin / PTX / LTO-IR with no + pickle, JSON, or framing -- so entry files are directly consumable + by external NVIDIA tools (``cuobjdump``, ``nvdisasm``, ...). Note + that an :class:`~cuda.core.ObjectCode` round-tripped through the + cache loses ``symbol_mapping``: callers that compile with + ``name_expressions`` and rely on ``get_kernel(name_expression)`` + after a cache hit must either compile fresh or look up the mangled + symbol explicitly. + + Options that read external files (``include_path``, ``pre_include``, + ``pch``, ``use_pch``, ``pch_dir``; and ``use_libdevice=True`` on the + NVVM path) require ``extra_digest`` -- fingerprint the bytes the + compiler will pull in and pass that digest so changes to those files + force a cache miss. Options that have compile-time side effects + (``create_pch``, ``time``, ``fdevice_time_trace``) cannot be cached + and raise ``ValueError``; compile directly, or disable the flag, for + those cases. + """ + # Mirror Program.compile (_program.pyx Program_init lowercases code_type + # before dispatch); a caller that passes "PTX" or "C++" must get the + # same routing and the same cache key as the lowercase form. 
+ code_type = code_type.lower() if isinstance(code_type, str) else code_type + if code_type not in _VALID_CODE_TYPES: + raise ValueError(f"code_type={code_type!r} is not supported (must be one of {sorted(_VALID_CODE_TYPES)})") + if target_type not in _VALID_TARGET_TYPES: + raise ValueError(f"target_type={target_type!r} is not supported (must be one of {sorted(_VALID_TARGET_TYPES)})") + supported_for_code = _SUPPORTED_TARGETS_BY_CODE_TYPE[code_type] + if target_type not in supported_for_code: + raise ValueError( + f"target_type={target_type!r} is not valid for code_type={code_type!r}" + f" (supported: {sorted(supported_for_code)}). Program.compile() rejects" + f" this combination, so caching a key for it is meaningless." + ) + + backend = _backend_for_code_type(code_type) + + # Side-effect options are NVRTC-specific: ``time``/``fdevice_time_trace`` + # write artifacts via NVRTC, ``create_pch`` writes via NVRTC. The linker + # (PTX inputs) uses ``-time`` only to log to the info log (not a file), + # and NVVM explicitly rejects all three at compile time anyway, so the + # guard is only meaningful for the NVRTC path. + if backend == "nvrtc": + side_effects = [name for name in _SIDE_EFFECT_OPTIONS if _option_is_set(options, name)] + if side_effects: + raise ValueError( + f"make_program_cache_key() refuses to build a key for options that " + f"have compile-time side effects ({', '.join(side_effects)}); a " + f"cache hit skips compilation, so the side effect would not occur. " + f"Disable the option, or compile directly without the cache." + ) + + # NVVM with ``use_libdevice=True`` reads external libdevice bitcode at + # compile time (see Program_init in _program.pyx). The file is resolved + # from the active toolkit, so a changed CUDA_HOME / libdevice upgrade + # changes the linked output without touching any key input the cache can + # observe. 
Require the caller to supply an ``extra_digest`` that + # fingerprints the libdevice bytes (or simply disable use_libdevice for + # caching-sensitive workflows). + if backend == "nvvm" and extra_digest is None and getattr(options, "use_libdevice", None): + raise ValueError( + "make_program_cache_key() refuses to build an NVVM key with " + "use_libdevice=True and no extra_digest: the linked libdevice " + "bitcode can change out from under a cached ObjectCode. Pass an " + "extra_digest that fingerprints the libdevice file you intend " + "to link against, or disable use_libdevice." + ) + + # External-content options are NVRTC-only. ``Program.compile`` for PTX + # inputs runs ``_translate_program_options``, which drops + # ``include_path``/``pre_include``/``pch``/``use_pch``/``pch_dir`` + # entirely, and NVVM explicitly rejects them. Only NVRTC actually reads + # those external files, so gate the guard on the NVRTC backend. + if backend == "nvrtc" and extra_digest is None: + external = [name for name in _EXTERNAL_CONTENT_OPTIONS if _option_is_set(options, name)] + if external: + raise ValueError( + f"make_program_cache_key() refuses to build a key for options that " + f"pull in external file content ({', '.join(external)}) without an " + f"extra_digest; compute a digest over the header/PCH bytes the " + f"compile will read and pass it as extra_digest=..." + ) + + # PTX compiles go through Linker. When the driver (cuLink) backend is + # selected (nvJitLink unavailable), Program.compile rejects a subset of + # options that nvJitLink would accept; reject them here too so we never + # store a key for a compilation that can't succeed in this environment. + # If the probe fails we can't tell which backend will run, so skip -- the + # failed-probe branch below already taints the key. 
+ use_driver_linker: bool | None = None + if backend == "linker": + try: + from cuda.core._linker import _decide_nvjitlink_or_driver + + use_driver_linker = _decide_nvjitlink_or_driver() + except Exception: + use_driver_linker = None + if use_driver_linker is True: + # Mirror ``_prepare_driver_options``'s exact gate: the driver + # linker checks ``is not None`` for these fields, so ``time=False`` + # or ``ptxas_options=[]`` is still a rejection. Do NOT use the + # truthiness-based ``_option_is_set`` helper here. + unsupported = [ + name for name in _DRIVER_LINKER_UNSUPPORTED_FIELDS if getattr(options, name, None) is not None + ] + if unsupported: + raise ValueError( + f"the cuLink driver linker does not support these options: " + f"{', '.join(unsupported)}; Program.compile() would reject this " + f"configuration before producing an ObjectCode." + ) + + if isinstance(code, str): + code_bytes = code.encode("utf-8") + elif isinstance(code, (bytes, bytearray)): + # Program() only accepts bytes-like ``code`` for the NVVM backend + # (_program.pyx Program_init); c++/ptx require ``str``. Mirror that + # so the cache helper doesn't mint keys for inputs the real compile + # would reject. + if backend != "nvvm": + raise TypeError( + f"code must be str for code_type={code_type!r}; bytes/bytearray are only accepted for code_type='nvvm'." + ) + code_bytes = bytes(code) + else: + raise TypeError(f"code must be str or bytes, got {type(code).__name__}") + + # For PTX inputs the Linker path reads only a subset of ProgramOptions + # (see _translate_program_options in _program.pyx); fingerprint just those + # fields so shared ProgramOptions carrying NVRTC-only flags + # (include_path, pch_*, frandom_seed, ...) don't force spurious cache + # misses on PTX. For nvrtc/nvvm backends, ProgramOptions.as_bytes gives + # the real compile-time flag surface. 
+ if backend == "linker": + option_bytes = _linker_option_fingerprint(options, use_driver_linker=use_driver_linker) + else: + option_bytes = options.as_bytes(backend, target_type) + + # Preserve the original type of each name expression in the key: though + # ``name_expressions`` is only consumed (and only meaningful) on the + # NVRTC compile path; Program.compile silently ignores it for PTX/NVVM. + # Validation + tagging is therefore gated on the NVRTC backend so the + # cache helper doesn't reject inputs the real compile would accept. + # NVRTC tagging notes: ``"foo"`` and ``b"foo"`` get distinct tags + # because Program.compile records the original Python object as the + # ObjectCode.symbol_mapping key (_program.pyx:759), so a cached + # ObjectCode whose mapping-key type differs from what the caller's + # later ``get_kernel`` passes would silently miss. Reject + # ``bytearray`` because Program.compile also uses the raw element as a + # dict key -- bytearray is unhashable, so a cache miss would compile + # then crash in ``symbol_mapping[n] = ...``. Accepting it here would + # let the cache serve hits for inputs the uncached path can't handle. + if backend == "nvrtc": + + def _tag_name(n): + if isinstance(n, bytes): + return b"b:" + n + if isinstance(n, str): + return b"s:" + n.encode("utf-8") + if isinstance(n, bytearray): + raise TypeError( + "name_expressions elements must be str or bytes; " + "bytearray is not accepted because Program.compile uses " + "each element as a dict key and bytearray is unhashable." 
+ ) + raise TypeError(f"name_expressions elements must be str or bytes; got {type(n).__name__}") + + names = tuple(sorted(_tag_name(n) for n in name_expressions)) + else: + names = () + + hasher = hashlib.blake2b(digest_size=32) + + def _update(label: str, payload: bytes) -> None: + hasher.update(label.encode("ascii")) + hasher.update(len(payload).to_bytes(8, "big")) + hasher.update(payload) + + def _probe(label: str, fn): + """Run an environment probe; on failure, hash the exception's + CLASS NAME (not its message) under a ``*_probe_failed`` label. + + Using only the class name keeps the digest stable across repeated + calls within one process (e.g. NVVM's loader reports different + messages on first vs. cached-failure attempts) AND across processes + that hit the same failure mode. The ``_probe_failed`` label differs + from the success labels (``driver``/``nvrtc``/...), so a broken env + never collides with a working one -- the cache "fails closed" + between broken and working environments while staying persistent + within either. + """ + try: + return fn() + except Exception as exc: + _update(f"{label}_probe_failed", type(exc).__name__.encode()) + return None + + _update("schema", str(_KEY_SCHEMA_VERSION).encode("ascii")) + if backend == "nvrtc": + nvrtc_ver = _probe("nvrtc", _nvrtc_version) + if nvrtc_ver is not None: + nv_major, nv_minor = nvrtc_ver + _update("nvrtc", f"{nv_major}.{nv_minor}".encode("ascii")) + elif backend == "linker": + # Only cuLink (driver-backed linker) goes through the CUDA driver + # for codegen. nvJitLink is a separate library, so a driver upgrade + # under it does not change the compiled bytes -- skip the driver + # version there. ``_linker_backend_and_version`` already returns the + # driver version when the driver backend is active, so the bytes + # are still in the digest via ``linker_version``. 
+ linker = _probe("linker", _linker_backend_and_version) + if linker is not None: + lb_name, lb_version = linker + _update("linker_backend", lb_name.encode("ascii")) + _update("linker_version", lb_version.encode("ascii")) + else: + nvvm_fp = _probe("nvvm", _nvvm_fingerprint) + if nvvm_fp is not None: + _update("nvvm", nvvm_fp.encode("ascii")) + _update("code_type", code_type.encode("ascii")) + _update("target_type", target_type.encode("ascii")) + _update("code", code_bytes) + _update("option_count", str(len(option_bytes)).encode("ascii")) + for opt in option_bytes: + _update("option", bytes(opt)) + # Only NVRTC consumes ``name_expressions``; Program.compile ignores them + # on the NVVM and PTX/linker paths, so folding them into the key there + # would force spurious cache misses. + if backend == "nvrtc": + _update("names_count", str(len(names)).encode("ascii")) + for n in names: + _update("name", n) + + # ``extra_sources`` is NVVM-only -- ``Program`` raises for non-NVVM + # backends (_program.pyx). Reject up front so callers get the same + # error from the cache key path as from a real compile, and only hash + # for backend == "nvvm". + extra_sources = getattr(options, "extra_sources", None) + if extra_sources is not None and backend != "nvvm": + raise ValueError( + f"extra_sources is only valid for code_type='nvvm'; Program() rejects it for code_type={code_type!r}." + ) + if extra_sources: + _update("extra_sources_count", str(len(extra_sources)).encode("ascii")) + for item in extra_sources: + # extra_sources is a sequence of (name, source) tuples. + if isinstance(item, (tuple, list)) and len(item) == 2: + name, src = item + _update("extra_source_name", str(name).encode("utf-8")) + if isinstance(src, str): + _update("extra_source_code", src.encode("utf-8")) + elif isinstance(src, (bytes, bytearray)): + _update("extra_source_code", bytes(src)) + else: + _update("extra_source_code", str(src).encode("utf-8")) + else: + # Fallback for unexpected format. 
class InMemoryProgramCache(ProgramCacheResource):
    """Process-local program cache that evicts least-recently-read entries.

    Intended for single-process workflows that want to skip disk I/O: a
    typical application compiles each kernel once per process and then
    looks it up many times. Nothing outlives the process; reach for
    :class:`FileStreamProgramCache` when entries should persist across
    runs.

    The backend is bytes-in / bytes-out, like
    :class:`FileStreamProgramCache`: ``__setitem__`` accepts ``bytes``,
    ``bytearray``, ``memoryview``, or any :class:`~cuda.core.ObjectCode`
    (including path-backed ones -- the file is read at write time, so the
    stored entry is the binary content rather than a path), and
    ``__getitem__`` always returns ``bytes``.

    Parameters
    ----------
    max_size_bytes:
        Optional upper bound on the summed payload sizes. Once exceeded,
        the oldest entries are evicted until the total fits again.
        ``None`` disables the bound. The size-only bound mirrors
        :class:`FileStreamProgramCache`.

    Notes
    -----
    Only :meth:`__getitem__` refreshes recency; :meth:`__contains__` is
    read-only and leaves the LRU order untouched, matching
    :class:`FileStreamProgramCache`.

    Thread safety: every method runs under one :class:`threading.RLock`,
    so a single instance can be shared between threads without external
    locking.
    """

    def __init__(
        self,
        *,
        max_size_bytes: int | None = None,
    ) -> None:
        negative_cap = max_size_bytes is not None and max_size_bytes < 0
        if negative_cap:
            raise ValueError("max_size_bytes must be non-negative or None")
        # A reentrant lock lets lock-taking helpers nest inside the public
        # methods without deadlocking.
        self._lock = threading.RLock()
        self._max_size_bytes = max_size_bytes
        self._total_bytes = 0
        # The OrderedDict's order IS the LRU order (front = coldest, back =
        # hottest). Values are ``(payload, payload_size)`` pairs so the
        # eviction loop never has to call ``len`` on a payload again.
        self._entries: collections.OrderedDict[bytes, tuple[bytes, int]] = collections.OrderedDict()

    def __getitem__(self, key: object) -> bytes:
        raw_key = _as_key_bytes(key)
        with self._lock:
            if raw_key not in self._entries:
                # Report the caller's original key, not the normalised one.
                raise KeyError(key) from None
            payload = self._entries[raw_key][0]
            # A successful read marks the entry as most recently used, so
            # eviction goes after entries that are genuinely cold.
            self._entries.move_to_end(raw_key)
            return payload

    def __setitem__(self, key: object, value: bytes | bytearray | memoryview | ObjectCode) -> None:
        payload = _extract_bytes(value)
        payload_size = len(payload)
        raw_key = _as_key_bytes(key)
        with self._lock:
            # Drop any previous entry first: this both releases its bytes
            # from the running total and makes the re-insert land at the
            # "most recent" end of the order.
            previous = self._entries.pop(raw_key, None)
            if previous is not None:
                self._total_bytes -= previous[1]
            self._entries[raw_key] = (payload, payload_size)
            self._total_bytes += payload_size
            self._evict_to_caps()

    def __contains__(self, key: object) -> bool:
        # The key is validated even for a membership probe (as the
        # FileStream backend does): an unsupported key type is a
        # programming error and must surface rather than read as "absent".
        raw_key = _as_key_bytes(key)
        with self._lock:
            present = raw_key in self._entries
        return present

    def __delitem__(self, key: object) -> None:
        raw_key = _as_key_bytes(key)
        with self._lock:
            if raw_key not in self._entries:
                raise KeyError(key) from None
            removed_size = self._entries.pop(raw_key)[1]
            self._total_bytes -= removed_size

    def __len__(self) -> int:
        with self._lock:
            count = len(self._entries)
        return count

    def clear(self) -> None:
        """Drop every entry and reset the byte counter."""
        with self._lock:
            self._total_bytes = 0
            self._entries.clear()

    # -- eviction ------------------------------------------------------------

    def _evict_to_caps(self) -> None:
        """Pop coldest entries until the running total fits the size cap.

        Invoked by ``__setitem__`` after each insert or update, with the
        lock already held. Entries leave from the front of the
        OrderedDict (oldest first). An entry that alone is larger than
        ``max_size_bytes`` is evicted as well, mirroring
        :class:`FileStreamProgramCache`: a write that cannot fit does not
        survive its own size-cap pass.
        """
        cap = self._max_size_bytes
        if cap is None:
            return
        while self._total_bytes > cap and self._entries:
            _evicted_key, (_payload, evicted_size) = self._entries.popitem(last=False)
            self._total_bytes -= evicted_size
def _replace_with_sharing_retry(tmp_path: Path, target: Path) -> bool:
    """Atomically move ``tmp_path`` onto ``target``, retrying Windows lock errors.

    Returns ``True`` once ``os.replace`` succeeds. Returns ``False`` only
    when every attempt in ``_REPLACE_RETRY_DELAYS`` failed on Windows with
    a sharing/lock violation; the caller then treats the cache write as
    dropped. Every other ``PermissionError`` (ACLs, read-only directory,
    unexpected winerror, any POSIX failure) propagates unchanged.

    ``ERROR_ACCESS_DENIED`` (winerror 5) counts as a sharing violation:
    Windows reports it when a file is held open without
    ``FILE_SHARE_WRITE`` (Python's default for ``open(p, "wb")``) or while
    an earlier unlink is still ``PENDING_DELETE`` -- both are transient.
    """
    for pause in _REPLACE_RETRY_DELAYS:
        if pause:
            time.sleep(pause)
        try:
            os.replace(tmp_path, target)
        except PermissionError as exc:
            transient = _IS_WINDOWS and getattr(exc, "winerror", None) in _SHARING_VIOLATION_WINERRORS
            if not transient:
                raise
            # Transient Windows contention: fall through to the next
            # attempt, or out of the loop when the budget is spent.
        else:
            return True
    return False


def _stat_and_read_with_sharing_retry(path: Path) -> tuple[os.stat_result, bytes]:
    """Return ``(stat, bytes)`` for ``path``, riding out Windows lock errors.

    A read can race a rewriter's ``os.replace``: on Windows the
    destination may be briefly inaccessible (winerror 5/32/33) while the
    rename completes. Reusing the ``_replace_with_sharing_retry`` delay
    schedule keeps that contention from being reported as a real read
    failure.

    Raises ``FileNotFoundError`` when the file is missing, or when the
    Windows retry budget is exhausted (chained to the last
    ``PermissionError``). A ``PermissionError`` on non-Windows platforms
    propagates.

    On Windows a bare ``EACCES`` (errno 13) is retried too: ``io.open``
    sometimes reports a pending delete or share-mode mismatch that way,
    with no ``winerror`` attribute, so it cannot be told apart from a
    sharing violation here. A genuine ACL problem on a cache-owned path
    would fail consistently, and the bounded budget keeps the cost of
    retrying it negligible.
    """
    pending: BaseException | None = None
    for pause in _REPLACE_RETRY_DELAYS:
        if pause:
            time.sleep(pause)
        try:
            snapshot = path.stat()
            payload = path.read_bytes()
        except PermissionError as exc:
            if not _IS_WINDOWS:
                raise
            sharing_violation = getattr(exc, "winerror", None) in _SHARING_VIOLATION_WINERRORS
            if not (sharing_violation or exc.errno == errno.EACCES):
                raise
            pending = exc
        else:
            return snapshot, payload
    raise FileNotFoundError(path) from pending


_UTIME_SUPPORTS_FD = os.utime in os.supports_fd


def _touch_atime(path: Path, st_before: os.stat_result) -> None:
    """Set ``path``'s atime to "now" and keep its mtime, if it is unchanged.

    The touch only happens when the file's current
    ``(st_ino, st_size, st_mtime_ns)`` equals that of ``st_before``.

    Eviction orders entries by ``st_atime``, so a read must refresh atime
    whatever the OS or filesystem does by default:

    * Linux ``relatime`` (the default) only updates atime while it is
      older than mtime, which skews LRU after the first read.
    * NTFS on Windows Vista+ disables last-access updates by default
      (``NtfsDisableLastAccessUpdate``), so a plain read never bumps it.
    * ``noatime`` mounts never update it.

    ``os.utime`` with explicit times writes atime directly and sidesteps
    all of the above. The identity guard matters: if another process
    ``os.replace``-d a fresh entry into ``path`` after our read, writing
    ``st_before.st_mtime_ns`` back would roll the new entry's mtime to the
    old value and mislead the eviction guard, which compares the same
    triple.

    When ``os.utime`` accepts file descriptors, the ``fstat`` and the
    ``utime`` both go through one open fd, so they are guaranteed to
    address the same file even if the path is replaced meanwhile.
    Otherwise (path-only ``os.utime``, e.g. Windows) the path is
    re-statted and a small window remains between that stat and the
    ``utime``.

    Best-effort: any ``OSError`` (read-only mount, restrictive ACLs, ...)
    is swallowed. The size cap still holds, but eviction then degrades
    toward FIFO.
    """
    now_ns = time.time_ns()

    def _still_same_file(st_now: os.stat_result) -> bool:
        observed = (st_now.st_ino, st_now.st_size, st_now.st_mtime_ns)
        expected = (st_before.st_ino, st_before.st_size, st_before.st_mtime_ns)
        return observed == expected

    if not _UTIME_SUPPORTS_FD:
        # Path-based fallback; the residual stat -> utime window is
        # described in the docstring.
        try:
            current = path.stat()
        except OSError:
            return
        if _still_same_file(current):
            with contextlib.suppress(OSError):
                os.utime(path, ns=(now_ns, st_before.st_mtime_ns))
        return

    try:
        fd = os.open(path, os.O_RDONLY)
    except OSError:
        return
    try:
        try:
            current = os.fstat(fd)
        except OSError:
            return
        if _still_same_file(current):
            with contextlib.suppress(OSError):
                os.utime(fd, ns=(now_ns, st_before.st_mtime_ns))
    finally:
        os.close(fd)
class FileStreamProgramCache(ProgramCacheResource):
    """Persistent program cache backed by a directory of atomic files.

    Designed for multi-process use: writes stage a temporary file and then
    :func:`os.replace` it into place, so concurrent readers never observe a
    partially-written entry. Each entry on disk is the raw compiled binary
    -- cubin / PTX / LTO-IR -- with no header, framing, or pickle wrapper,
    so the files are directly consumable by external NVIDIA tools
    (``cuobjdump``, ``nvdisasm``, ``cuda-gdb``).

    Eviction is by least-recently-*read* time: every successful
    ``cache[key]`` read bumps the entry's ``atime``, and the size enforcer
    evicts oldest atime first. ``key in cache`` is a read-only existence
    probe and does not affect eviction order.

    .. note:: **Best-effort writes.**

        On Windows, ``os.replace`` raises ``PermissionError`` (winerror
        5 / 32 / 33) when another process holds the target file open. This
        backend retries with bounded backoff (~185 ms) and, if still
        failing, drops the cache write silently and returns success-shaped
        control flow. The next call will see no entry and recompile. POSIX
        and other ``PermissionError`` codes propagate.

    .. note:: **Atomic for readers, not crash-durable.**

        Each entry's temp file is ``fsync``-ed before ``os.replace``, but
        the containing directory is **not** ``fsync``-ed. A host crash
        between write and the next directory commit may lose recently
        added entries; surviving entries remain consistent.

    .. note:: **Cross-version sharing.**

        ``_FILESTREAM_SCHEMA_VERSION`` encodes both the on-disk storage
        format and the key-schema version, so a cache written by an
        incompatible version is wiped on open (bumping either
        ``_KEY_SCHEMA_VERSION`` or ``_FILESTREAM_BACKEND_SCHEMA`` forces
        cleanup instead of leaving orphaned entries on disk).

        Within a single schema version the cache is safe to share across
        ``cuda.core`` patch releases: every entry's key encodes the
        relevant backend/compiler/runtime fingerprints for its compilation
        path (NVRTC entries pin the NVRTC version, NVVM entries pin the
        libNVVM library and IR versions, PTX/linker entries pin the chosen
        linker backend and its version -- and, when the cuLink/driver
        backend is selected, the driver version too; nvJitLink-backed PTX
        entries are deliberately driver-version independent). Entries are
        stored verbatim as the compiled binary, so cross-patch sharing
        only requires that the compiler-pinning surface above stays stable
        -- there is no Python-pickle compatibility involved.

    Parameters
    ----------
    path:
        Directory that owns the cache. Created if missing. If omitted,
        the OS-conventional user cache directory is used:
        ``$XDG_CACHE_HOME/cuda-python/program-cache`` (Linux, defaulting
        to ``~/.cache/cuda-python/program-cache``),
        ``~/Library/Caches/cuda-python/program-cache`` (macOS), or
        ``%LOCALAPPDATA%\\cuda-python\\program-cache`` (Windows).
    max_size_bytes:
        Optional soft cap on total on-disk size. Enforced opportunistically
        on writes; concurrent writers may briefly exceed it. Eviction is by
        least-recently-read time (oldest ``st_atime`` first).

    Raises
    ------
    ValueError
        If ``max_size_bytes`` is negative.
    """

    def __init__(
        self,
        path: str | os.PathLike | None = None,
        *,
        max_size_bytes: int | None = None,
    ) -> None:
        if max_size_bytes is not None and max_size_bytes < 0:
            raise ValueError("max_size_bytes must be non-negative or None")
        self._root = Path(path) if path is not None else _default_cache_dir()
        self._entries = self._root / _ENTRIES_SUBDIR
        self._tmp = self._root / _TMP_SUBDIR
        self._schema_path = self._root / _SCHEMA_FILE
        self._max_size_bytes = max_size_bytes
        self._root.mkdir(parents=True, exist_ok=True)
        self._entries.mkdir(exist_ok=True)
        self._tmp.mkdir(exist_ok=True)
        expected = str(_FILESTREAM_SCHEMA_VERSION)
        if not self._schema_path.exists():
            self._schema_path.write_text(expected)
        else:
            existing = self._schema_path.read_text().strip()
            if existing != expected:
                # Schema mismatch: wipe incompatible entries. Losing cache
                # contents is safe; returning bytes from an old format is not.
                for entry in list(self._iter_entry_paths()):
                    with contextlib.suppress(FileNotFoundError):
                        entry.unlink()
                self._schema_path.write_text(expected)
        # Opportunistic startup sweep of orphaned temp files left by any
        # crashed writers. Age-based so concurrent in-flight writes from
        # other processes are preserved.
        self._sweep_stale_tmp_files()

    # -- key-to-path helpers -------------------------------------------------

    def _path_for_key(self, key: object) -> Path:
        """Map ``key`` to its entry path, ``entries/<2 hex>/<62 hex>``.

        The key is validated/normalised by ``_as_key_bytes`` and then
        hashed, so arbitrary-length user keys never exceed per-component
        filename limits (typically 255 on ext4 / NTFS). With a 256-bit
        blake2b digest the cache relies on cryptographic collision
        resistance for key uniqueness -- two distinct keys hashing to the
        same path is astronomically unlikely (~2^-128).
        """
        k = _as_key_bytes(key)
        # An empty key keeps its historical fixed name ("em/pty") so
        # existing on-disk caches stay addressable. Both branches produce
        # at least five characters, so the two-character shard split below
        # never yields an empty file name.
        digest = hashlib.blake2b(k, digest_size=32).hexdigest() if k else "empty"
        return self._entries / digest[:2] / digest[2:]

    # -- mapping API ---------------------------------------------------------

    def __contains__(self, key: object) -> bool:
        # Pure existence probe. Entries are raw binaries with no framing to
        # validate, so "the file exists" is the whole membership test (the
        # same rule ``__len__`` applies). Deliberately NOT routed through
        # ``__getitem__``: that would read the full payload and bump atime,
        # letting membership checks distort the least-recently-read
        # eviction order. The key is still validated by ``_path_for_key``,
        # so an unsupported key type surfaces instead of reading as absent.
        return self._path_for_key(key).is_file()

    def __getitem__(self, key: object) -> bytes:
        path = self._path_for_key(key)
        try:
            # The helper retries on Windows transient sharing-violation
            # PermissionErrors so a racing rewriter doesn't turn a hit
            # into a spurious propagated error.
            st, data = _stat_and_read_with_sharing_retry(path)
        except FileNotFoundError:
            raise KeyError(key) from None
        # Bump atime to "now" so eviction (which sorts by st_atime) treats
        # this read as the entry's most recent use. Best-effort: filesystems
        # mounted ``noatime`` or with restrictive ACLs may refuse, in which
        # case the cap still bounds size but eviction degrades toward FIFO
        # rather than true LRU.
        _touch_atime(path, st)
        return data

    def __setitem__(self, key: object, value: bytes | bytearray | memoryview | ObjectCode) -> None:
        data = _extract_bytes(value)
        target = self._path_for_key(key)
        target.parent.mkdir(parents=True, exist_ok=True)

        # Stage in the cache's own tmp dir and rename into place, so
        # readers only ever see complete entries.
        fd, tmp_name = tempfile.mkstemp(prefix="entry-", dir=self._tmp)
        tmp_path = Path(tmp_name)
        try:
            with os.fdopen(fd, "wb") as fh:
                fh.write(data)
                fh.flush()
                os.fsync(fh.fileno())
            # Retry os.replace under Windows sharing/lock violations; only
            # give up (and drop the cache write) after a bounded backoff, so
            # transient contention is not turned into a silent miss.
            # Non-sharing PermissionErrors and all POSIX PermissionErrors
            # propagate immediately (real config problem).
            if not _replace_with_sharing_retry(tmp_path, target):
                with contextlib.suppress(FileNotFoundError):
                    tmp_path.unlink()
                return
        except BaseException:
            # Never leave the staged file behind, whatever interrupted the
            # write (including KeyboardInterrupt); then re-raise untouched.
            with contextlib.suppress(FileNotFoundError):
                tmp_path.unlink()
            raise
        self._enforce_size_cap()

    def __delitem__(self, key: object) -> None:
        path = self._path_for_key(key)
        try:
            path.unlink()
        except FileNotFoundError:
            raise KeyError(key) from None

    def __len__(self) -> int:
        # Count present entry files. There is no payload validation at
        # the cache layer (entries are raw binary, not framed records),
        # so anything that exists in ``entries/`` is a member.
        # ``_iter_entry_paths`` already yields regular files only.
        return sum(1 for _ in self._iter_entry_paths())

    def clear(self) -> None:
        """Remove all committed entries and stale temp files."""
        # Snapshot stat alongside path so we can refuse to unlink an entry
        # that was concurrently replaced by another process between the
        # snapshot scan and the unlink. Same stat-guard contract as
        # ``_prune_if_stat_unchanged`` and ``_enforce_size_cap``.
        snapshot = []
        for path in self._iter_entry_paths():
            try:
                snapshot.append((path, path.stat()))
            except FileNotFoundError:
                continue
        for path, st_before in snapshot:
            _prune_if_stat_unchanged(path, st_before)
        # Sweep ONLY stale temp files. Deleting a young temp would race with
        # another process between ``mkstemp`` and ``os.replace`` and turn its
        # write into ``FileNotFoundError`` instead of a successful commit.
        self._sweep_stale_tmp_files()
        # Remove empty subdirs (best-effort; concurrent writers may re-create).
        if self._entries.exists():
            for sub in sorted(self._entries.iterdir(), reverse=True):
                if sub.is_dir():
                    with contextlib.suppress(OSError):
                        sub.rmdir()

    # -- internals -----------------------------------------------------------

    def _iter_entry_paths(self) -> Iterable[Path]:
        """Yield every regular file found under ``entries/<shard>/``."""
        if not self._entries.exists():
            return
        for sub in self._entries.iterdir():
            if not sub.is_dir():
                continue
            try:
                # Materialise the listing inside the ``try``: depending on
                # the Python version ``iterdir`` raises either when called
                # or on first iteration.
                children = list(sub.iterdir())
            except FileNotFoundError:
                # Another process's ``clear()`` removed this (empty) shard
                # directory between ``is_dir()`` and the listing; it holds
                # no entries, so skip it instead of failing the scan.
                continue
            for entry in children:
                if entry.is_file():
                    yield entry

    def _sweep_stale_tmp_files(self) -> None:
        """Remove temp files left behind by crashed writers.

        Age threshold is conservative (``_TMP_STALE_AGE_SECONDS``) so an
        in-flight write from another process is not interrupted. Best
        effort: a missing file or a permission failure is ignored.
        """
        if not self._tmp.exists():
            return
        cutoff = time.time() - _TMP_STALE_AGE_SECONDS
        for tmp in self._tmp.iterdir():
            if not tmp.is_file():
                continue
            try:
                if tmp.stat().st_mtime < cutoff:
                    tmp.unlink()
            except (FileNotFoundError, PermissionError):
                continue

    def _enforce_size_cap(self) -> None:
        """Evict least-recently-read entries until the size cap is met.

        No-op when ``max_size_bytes`` is ``None``. The measured total
        covers committed entries plus surviving temp files; only
        committed entries are eviction candidates.
        """
        if self._max_size_bytes is None:
            return
        # Sweep stale temp files first so a long-dead writer's leftovers
        # don't drag the apparent size up and force needless eviction.
        self._sweep_stale_tmp_files()
        entries = []
        total = 0
        # Count both committed entries AND surviving temp files: temp files
        # occupy disk too, even if they're young. Without this the soft cap
        # silently undercounts in-flight writes.
        for path in self._iter_entry_paths():
            try:
                st = path.stat()
            except FileNotFoundError:
                continue
            # Carry the full stat so eviction can guard against a concurrent
            # os.replace that swapped a fresh entry into this path between
            # snapshot and unlink. Eviction below sorts by ``st_atime`` so
            # entries that callers actually read recently survive
            # write-only churn (true LRU instead of FIFO).
            entries.append((st.st_atime, st.st_size, path, st))
            total += st.st_size
        if self._tmp.exists():
            for tmp in self._tmp.iterdir():
                if not tmp.is_file():
                    continue
                try:
                    total += tmp.stat().st_size
                except FileNotFoundError:
                    continue
        if total <= self._max_size_bytes:
            return
        entries.sort(key=lambda e: e[0])  # oldest atime first
        for _atime, size, path, st_before in entries:
            if total <= self._max_size_bytes:
                return
            # Re-stat and compare ``(ino, size, mtime_ns)`` with the
            # snapshot before unlinking, so eviction can't silently delete
            # an entry another process committed to this path meanwhile.
            try:
                stat_now = path.stat()
            except FileNotFoundError:
                # Already gone (deleted or evicted elsewhere): its bytes no
                # longer count against the cap.
                total -= size
                continue
            if (stat_now.st_ino, stat_now.st_size, stat_now.st_mtime_ns) != (
                st_before.st_ino,
                st_before.st_size,
                st_before.st_mtime_ns,
            ):
                # File was replaced -- don't unlink, but update ``total`` to
                # reflect the replacement's actual size or the cap check
                # above could declare us done while still over the limit.
                total += stat_now.st_size - size
                continue
            with contextlib.suppress(FileNotFoundError):
                path.unlink()
            total -= size
autosummary:: + :toctree: generated/ + + ProgramCacheResource + InMemoryProgramCache + FileStreamProgramCache diff --git a/cuda_core/pixi.lock b/cuda_core/pixi.lock index 4b7d2809cf1..1918ea7011d 100644 --- a/cuda_core/pixi.lock +++ b/cuda_core/pixi.lock @@ -167,6 +167,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/pango-1.56.4-hadf4263_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.47-haa7fec5_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pixman-0.46.4-h54a6638_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.9.6-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/psutil-7.2.2-py314h0f05182_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-hb9d3cd8_1002.conda @@ -372,6 +373,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pango-1.56.4-he55ef5b_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pcre2-10.47-hf841c20_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pixman-0.46.4-h7ac5ae9_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.9.6-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/psutil-7.2.2-py314h2e8dab5_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pthread-stubs-0.4-h86ecc28_1002.conda @@ -538,6 +540,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/pango-1.56.4-h03d888a_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/pcre2-10.47-hd2b5f0e_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/pixman-0.46.4-h5112557_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.9.6-pyhcf101f3_0.conda - 
conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda - conda: https://conda.anaconda.org/conda-forge/win-64/psutil-7.2.2-py314hc5dbbe4_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/py-cpuinfo-9.0.0-pyhd8ed1ab_1.conda @@ -732,6 +735,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/pango-1.56.4-hadf4263_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.47-haa7fec5_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pixman-0.46.4-h54a6638_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.9.6-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/psutil-7.2.2-py314h0f05182_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-hb9d3cd8_1002.conda @@ -933,6 +937,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pango-1.56.4-he55ef5b_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pcre2-10.47-hf841c20_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pixman-0.46.4-h7ac5ae9_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.9.6-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/psutil-7.2.2-py314h2e8dab5_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pthread-stubs-0.4-h86ecc28_1002.conda @@ -1093,6 +1098,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/pango-1.56.4-h03d888a_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/pcre2-10.47-hd2b5f0e_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/pixman-0.46.4-h5112557_1.conda + - conda: 
https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.9.6-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda - conda: https://conda.anaconda.org/conda-forge/win-64/psutil-7.2.2-py314hc5dbbe4_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/py-cpuinfo-9.0.0-pyhd8ed1ab_1.conda @@ -1290,6 +1296,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/pango-1.56.4-hadf4263_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.47-haa7fec5_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pixman-0.46.4-h54a6638_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.9.6-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/psutil-7.2.2-py314h0f05182_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-hb9d3cd8_1002.conda @@ -1491,6 +1498,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pango-1.56.4-he55ef5b_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pcre2-10.47-hf841c20_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pixman-0.46.4-h7ac5ae9_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.9.6-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/psutil-7.2.2-py314h2e8dab5_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pthread-stubs-0.4-h86ecc28_1002.conda @@ -1651,6 +1659,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/pango-1.56.4-h03d888a_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/pcre2-10.47-hd2b5f0e_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/pixman-0.46.4-h5112557_1.conda 
+ - conda: https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.9.6-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda - conda: https://conda.anaconda.org/conda-forge/win-64/psutil-7.2.2-py314hc5dbbe4_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/py-cpuinfo-9.0.0-pyhd8ed1ab_1.conda @@ -2412,6 +2421,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/pango-1.56.4-hadf4263_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.47-haa7fec5_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pixman-0.46.4-h54a6638_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.9.6-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-hb9d3cd8_1002.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pugixml-1.15-h3f63f65_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-client-17.0-h9a6aba3_3.conda @@ -2636,6 +2646,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pango-1.56.4-he55ef5b_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pcre2-10.47-hf841c20_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pixman-0.46.4-h7ac5ae9_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.9.6-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pthread-stubs-0.4-h86ecc28_1002.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pugixml-1.15-h6ef32b0_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-aarch64/pulseaudio-client-17.0-hcf98165_3.conda @@ -2764,6 +2775,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/pango-1.56.4-h03d888a_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/pcre2-10.47-hd2b5f0e_0.conda - conda: 
https://conda.anaconda.org/conda-forge/win-64/pixman-0.46.4-h5112557_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.9.6-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pycparser-2.22-pyh29332c3_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pyglet-2.1.13-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/python-3.14.3-h4b44e0e_101_cp314.conda @@ -3670,6 +3682,7 @@ packages: - numpy - cuda-bindings - cuda-pathfinder + - platformdirs >=3.0 - vc >=14.3,<15 - vc14_runtime >=14.44.35208 - ucrt >=10.0.20348.0 @@ -3693,6 +3706,7 @@ packages: - numpy - cuda-bindings - cuda-pathfinder + - platformdirs >=3.0 - vc >=14.3,<15 - vc14_runtime >=14.44.35208 - ucrt >=10.0.20348.0 @@ -3715,6 +3729,7 @@ packages: - numpy - cuda-bindings - cuda-pathfinder + - platformdirs >=3.0 - libgcc >=15 - libgcc >=15 - libstdcxx >=15 @@ -3737,6 +3752,7 @@ packages: - numpy - cuda-bindings - cuda-pathfinder + - platformdirs >=3.0 - libgcc >=15 - libgcc >=15 - libstdcxx >=15 @@ -3759,6 +3775,7 @@ packages: - numpy - cuda-bindings - cuda-pathfinder + - platformdirs >=3.0 - libgcc >=15 - libgcc >=15 - libstdcxx >=15 @@ -3781,6 +3798,7 @@ packages: - numpy - cuda-bindings - cuda-pathfinder + - platformdirs >=3.0 - libgcc >=15 - libgcc >=15 - libstdcxx >=15 @@ -11354,6 +11372,16 @@ packages: - pkg:pypi/platformdirs?source=compressed-mapping size: 25646 timestamp: 1773199142345 +- conda: https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.9.6-pyhcf101f3_0.conda + sha256: 8f29915c172f1f7f4f7c9391cd5dac3ebf5d13745c8b7c8006032615246345a5 + md5: 89c0b6d1793601a2a3a3f7d2d3d8b937 + depends: + - python >=3.10 + - python + license: MIT + license_family: MIT + size: 25862 + timestamp: 1775741140609 - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda sha256: e14aafa63efa0528ca99ba568eaf506eb55a0371d12e6250aaaa61718d2eb62e md5: d7585b6550ad04c8c5e21097ada2888e diff --git 
@nvvm_available
def test_nvvm_accepts_bytearray_input(nvvm_ir):
    """Program(..., 'nvvm') must accept bytearray input.

    Regression for a bug where the NVVM init branch retained the coerced
    ``self._code`` as bytes but still cast the original ``code`` object
    when obtaining the C pointer -- tripping a runtime type error for
    bytearray inputs before nvvmAddModuleToProgram was called.

    NOTE(review): the cast target originally named in this docstring was
    lost (it rendered as an empty literal); presumably a Cython
    ``<bytes>`` cast -- confirm against ``_program.pyx``.
    """
    # Encode the IR text so the constructor receives a ``bytearray`` rather
    # than ``str`` or ``bytes``.
    program = Program(bytearray(nvvm_ir, "utf-8"), "nvvm")
    try:
        # Reaching these asserts already means construction did not raise.
        assert program.backend == "NVVM"
        assert program.handle is not None
    finally:
        # Always close the program, even when an assertion above fails.
        program.close()
def test_program_cache_resource_requires_core_methods():
    """Every mapping-style core method must be declared abstract."""
    from cuda.core.utils import ProgramCacheResource

    # Names that are expected to be abstract but are not; must be empty.
    not_abstract = {
        "__getitem__",
        "__setitem__",
        "__contains__",
        "__delitem__",
        "__len__",
        "clear",
    } - ProgramCacheResource.__abstractmethods__
    assert not not_abstract
pass + assert closed == [True] + + +def test_cuda_core_utils_memoryview_import_is_lightweight(tmp_path): + """``from cuda.core.utils import StridedMemoryView`` must NOT transitively + import the program-cache backends; the cache modules pull in extra + driver/NVRTC machinery that memoryview-only consumers have no reason + to load.""" + import subprocess + import sys + import textwrap + + prog = textwrap.dedent(""" + import sys + # Touch the memoryview-only API. + from cuda.core.utils import StridedMemoryView, args_viewable_as_strided_memory # noqa: F401 + assert "cuda.core.utils._program_cache" not in sys.modules, ( + "importing the memoryview shim eagerly imported the cache backend: " + + str([m for m in sys.modules if m.startswith("cuda.core.utils")]) + ) + # And that the lazy attr still works on first access. + import cuda.core.utils as u + _ = u.make_program_cache_key + assert "cuda.core.utils._program_cache" in sys.modules + """) + # Run from a neutral cwd so Python's implicit ``sys.path[0]=''`` does not + # resolve to the unbuilt cuda_core source tree (which lacks the + # setuptools-scm-generated ``_version.py``). The subprocess must import + # the installed cuda.core from site-packages. + subprocess.run([sys.executable, "-c", prog], check=True, cwd=str(tmp_path)) # noqa: S603 + + +def test_cuda_core_utils_dir_includes_lazy_and_module_attrs(): + """``dir(cuda.core.utils)`` must surface BOTH the lazy public API AND + the regular module attributes (``__file__``, ``__spec__``, ...). 
def _make_key(**overrides):
    """Build a cache key from a default baseline plus per-test overrides.

    The baseline is a valid cubin-from-c++ compile over the source "a";
    a test passes only the keyword(s) it wants to vary and those replace
    the corresponding baseline values."""
    from cuda.core.utils import make_program_cache_key

    # The default options are always constructed, even when overridden.
    kwargs = {
        "code": "a",
        "code_type": "c++",
        "options": _opts(),
        "target_type": "cubin",
    }
    kwargs.update(overrides)
    return make_program_cache_key(**kwargs)
@pytest.mark.parametrize(
    "first, second",
    [
        pytest.param(("driver", "13200"), ("nvJitLink", "12030"), id="backend_flip"),
        pytest.param(("nvJitLink", "12030"), ("nvJitLink", "12040"), id="version_bump"),
    ],
)
def test_make_program_cache_key_ptx_linker_probe_changes(first, second, monkeypatch):
    """A PTX key must change when the linker probe result changes, whether
    the backend itself flips (nvJitLink vs driver) or only its version."""
    from cuda.core.utils import _program_cache

    def _key_under_probe(probe_result):
        # Pin the probe to a fixed (backend, version) pair, then derive a key.
        monkeypatch.setattr(_program_cache, "_linker_backend_and_version", lambda: probe_result)
        return _make_key(code=".version 7.0", code_type="ptx")

    assert _key_under_probe(first) != _key_under_probe(second)
@pytest.mark.parametrize("code_type", ["PTX", "C++", "NVVM", "Ptx", "c++"])
def test_make_program_cache_key_normalises_code_type_case(code_type):
    """The key must not depend on the letter case of ``code_type``.

    Program() normalises code_type to lower, so a caller writing
    ``Program(code, "PTX")`` must get the same routing and the same key
    as with the lowercase spelling."""
    lowered = code_type.lower()
    # (target_type, code) valid for each lowered code type; anything not
    # listed falls back to the c++ sample.
    samples = {
        "nvvm": ("ptx", "abc"),
        "ptx": ("cubin", ".version 7.0"),
    }
    target, code = samples.get(lowered, ("cubin", "void k(){}"))
    key_as_given = _make_key(code=code, code_type=code_type, target_type=target)
    key_lowered = _make_key(code=code, code_type=lowered, target_type=target)
    assert key_as_given == key_lowered
def test_make_program_cache_key_supported_targets_matches_program_compile():
    """``_SUPPORTED_TARGETS_BY_CODE_TYPE`` duplicates the backend target
    matrix in ``_program.pyx``. Guard against drift: scan the pyx source
    with :mod:`tokenize`, counting only operator tokens so braces inside
    string literals and comments are ignored, extract the
    ``SUPPORTED_TARGETS`` dict literal and assert the two views agree."""
    import ast
    import io
    import tokenize
    from pathlib import Path

    from cuda.core.utils._program_cache import _SUPPORTED_TARGETS_BY_CODE_TYPE

    backend_to_code_type = {"NVRTC": "c++", "NVVM": "nvvm"}
    linker_backends = ("nvJitLink", "driver")

    pyx = Path(__file__).parent.parent / "cuda" / "core" / "_program.pyx"
    # Decode explicitly: the locale default (e.g. cp1252 on Windows) can fail
    # or mis-decode a UTF-8 source file that contains non-ASCII characters.
    text = pyx.read_text(encoding="utf-8")
    marker_idx = text.index("cdef dict SUPPORTED_TARGETS")
    tail = text[marker_idx:]
    tokens = tokenize.generate_tokens(io.StringIO(tail).readline)

    depth = 0
    start_offset = None
    end_offset = None
    # tokenize reports (row, col); build a row -> character-offset table so
    # token positions can be turned into slice indices into ``tail``.
    lines = tail.splitlines(keepends=True)
    line_starts = [0]
    for line in lines[:-1]:
        line_starts.append(line_starts[-1] + len(line))

    def _offset(row, col):
        # Rows are 1-based; columns are offsets within the row.
        return line_starts[row - 1] + col

    for tok in tokens:
        if tok.type != tokenize.OP:
            continue
        if tok.string == "{":
            if depth == 0:
                start_offset = _offset(tok.start[0], tok.start[1])
            depth += 1
        elif tok.string == "}":
            depth -= 1
            if depth == 0:
                # Stop at the brace closing the outermost literal so the
                # rest of the Cython source is never tokenized.
                end_offset = _offset(tok.end[0], tok.end[1])
                break
    assert start_offset is not None and end_offset is not None, "could not locate SUPPORTED_TARGETS literal"
    pyx_targets = ast.literal_eval(text[marker_idx + start_offset : marker_idx + end_offset])

    for backend, code_type in backend_to_code_type.items():
        assert frozenset(pyx_targets[backend]) == _SUPPORTED_TARGETS_BY_CODE_TYPE[code_type], (
            backend,
            code_type,
        )
    # Both linker backends must agree with each other and with the "ptx" row.
    linker_sets = [frozenset(pyx_targets[b]) for b in linker_backends]
    assert all(s == linker_sets[0] for s in linker_sets)
    assert linker_sets[0] == _SUPPORTED_TARGETS_BY_CODE_TYPE["ptx"]
test_make_program_cache_key_ignores_name_expressions_for_non_nvrtc(code_type, code, target_type): + """Program.compile only forwards ``name_expressions`` on the NVRTC path + (_program.pyx). Folding them into the key for NVVM/PTX compiles would + cause identical compiles to miss the cache for no behavioural reason.""" + k_none = _make_key(code=code, code_type=code_type, target_type=target_type) + k_with = _make_key(code=code, code_type=code_type, target_type=target_type, name_expressions=("foo", "bar")) + assert k_none == k_with + + +@pytest.mark.parametrize( + "a, b", + [ + # ``debug`` / ``lineinfo`` / ``link_time_optimization`` are truthy-only + # gates in the linker; False and None produce identical output. + pytest.param({"debug": False}, {"debug": None}, id="debug_false_eq_none"), + pytest.param({"lineinfo": False}, {"lineinfo": None}, id="lineinfo_false_eq_none"), + pytest.param( + {"link_time_optimization": False}, + {"link_time_optimization": None}, + id="lto_false_eq_none", + ), + # ``time`` is a presence gate: the linker emits ``-time`` for any + # non-None value, so True / "path" produce the same flag. + pytest.param({"time": True}, {"time": "timing.csv"}, id="time_true_eq_path"), + # ``no_cache`` has an ``is True`` gate; False and None equivalent. + pytest.param({"no_cache": False}, {"no_cache": None}, id="no_cache_false_eq_none"), + ], +) +def test_make_program_cache_key_ptx_linker_equivalent_options_hash_same(a, b, monkeypatch): + """The linker folds several PTX-relevant fields through simple gates: + truthy-only (``debug``, ``lineinfo``, ``link_time_optimization``), + presence-only (``time``), ``is True`` (``no_cache``). Semantically + equivalent inputs under those gates must hash to the same key.""" + # Pin the linker probe so the only variable is the options gate. 
+ from cuda.core.utils import _program_cache + + monkeypatch.setattr(_program_cache, "_linker_backend_and_version", lambda: ("nvJitLink", "12030")) + k_a = _make_key(code=".version 7.0", code_type="ptx", options=_opts(**a)) + k_b = _make_key(code=".version 7.0", code_type="ptx", options=_opts(**b)) + assert k_a == k_b + + +@pytest.mark.parametrize( + "field, a, b", + [ + pytest.param("ftz", True, False, id="ftz"), + pytest.param("prec_div", True, False, id="prec_div"), + pytest.param("prec_sqrt", True, False, id="prec_sqrt"), + pytest.param("fma", True, False, id="fma"), + ], +) +def test_make_program_cache_key_ptx_driver_ignored_fields_collapse(field, a, b, monkeypatch): + """The driver (cuLink) linker silently ignores ftz/prec_div/prec_sqrt/fma + (only emits a DeprecationWarning). Under the driver backend, those + fields must not perturb the PTX cache key -- two otherwise-equivalent + compiles differing only in these flags produce identical ObjectCode.""" + from cuda.core import _linker + + monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", lambda: True) # driver + k_a = _make_key(code=".version 7.0", code_type="ptx", options=_opts(**{field: a})) + k_b = _make_key(code=".version 7.0", code_type="ptx", options=_opts(**{field: b})) + assert k_a == k_b + + +@pytest.mark.parametrize( + "a, b", + [ + pytest.param("-v", ["-v"], id="str_vs_list"), + pytest.param("-v", ("-v",), id="str_vs_tuple"), + pytest.param(["-v"], ("-v",), id="list_vs_tuple"), + # Empty sequence emits no -Xptxas flags; must match None. + pytest.param(None, [], id="none_vs_empty_list"), + pytest.param(None, (), id="none_vs_empty_tuple"), + pytest.param([], (), id="empty_list_vs_empty_tuple"), + ], +) +def test_make_program_cache_key_ptx_ptxas_options_canonicalized(a, b, monkeypatch): + """_prepare_nvjitlink_options emits the same -Xptxas= flags for str, + list, and tuple shapes of ptxas_options. 
def test_make_program_cache_key_ptx_driver_ignored_fields_still_matter_under_nvjitlink(monkeypatch):
    """nvJitLink honours ``ftz``, so toggling it must change the PTX key."""
    from cuda.core import _linker

    # False selects the nvJitLink backend.
    monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", lambda: False)
    key_ftz_on = _make_key(code=".version 7.0", code_type="ptx", options=_opts(ftz=True))
    key_ftz_off = _make_key(code=".version 7.0", code_type="ptx", options=_opts(ftz=False))
    assert key_ftz_on != key_ftz_off


@pytest.mark.parametrize(
    "code_type, code, target_type",
    [
        pytest.param("c++", "void k(){}", "cubin", id="nvrtc"),
        pytest.param("ptx", ".version 7.0", "cubin", id="ptx"),
    ],
)
def test_make_program_cache_key_use_libdevice_ignored_for_non_nvvm(code_type, code, target_type):
    """``use_libdevice`` is consumed only on the NVVM path.

    For NVRTC and PTX inputs, False, True and None must all give one key.
    """
    keys = [
        _make_key(
            code=code,
            code_type=code_type,
            target_type=target_type,
            options=_opts(use_libdevice=flag),
        )
        for flag in (False, True, None)
    ]
    assert keys[0] == keys[1] == keys[2]


def test_make_program_cache_key_nvvm_use_libdevice_requires_extra_digest():
    """NVVM with ``use_libdevice=True`` must demand an ``extra_digest``.

    The libdevice bitcode is an external file the cache cannot observe, so
    without a digest a cached ObjectCode could silently drift under a
    toolkit upgrade.
    """
    from cuda.core.utils import make_program_cache_key

    def _libdevice_key(**extra):
        return make_program_cache_key(
            code="abc",
            code_type="nvvm",
            options=_opts(use_libdevice=True),
            target_type="ptx",
            **extra,
        )

    with pytest.raises(ValueError, match="libdevice"):
        _libdevice_key()
    # With a digest the call succeeds, and distinct digests yield distinct
    # keys so a caller can represent a libdevice change.
    key_digest_a = _libdevice_key(extra_digest=b"libdev-a" * 4)
    key_digest_b = _libdevice_key(extra_digest=b"libdev-b" * 4)
    assert key_digest_a != key_digest_b
def test_make_program_cache_key_nvvm_use_libdevice_false_equals_none():
    """False and None for ``use_libdevice`` must give the same NVVM key.

    ``Program_init`` gates the flag on truthiness, so both compile the same
    way. True is only accepted together with an ``extra_digest`` and must
    then give a different key.
    """

    def _nvvm_key(flag, **extra):
        return _make_key(
            code="abc",
            code_type="nvvm",
            target_type="ptx",
            options=_opts(use_libdevice=flag),
            **extra,
        )

    key_unset = _nvvm_key(None)
    key_disabled = _nvvm_key(False)
    assert key_unset == key_disabled
    # An explicit digest makes True acceptable; the key must then differ.
    key_enabled = _nvvm_key(True, extra_digest=b"libdev" * 8)
    assert key_enabled != key_unset
def test_make_program_cache_key_nvvm_library_version_changes_key(monkeypatch):
    """A different libNVVM ``version()`` must invalidate NVVM keys.

    The IR version is held constant here because a patch upgrade can change
    codegen without bumping the IR tuple.
    """
    from cuda.core import _program

    class _FakeNVVM:
        def __init__(self, lib_version):
            self._lib_version = lib_version

        def version(self):
            return self._lib_version

        def ir_version(self):
            # Constant on purpose: only the library version varies.
            return (1, 8, 3, 0)

    keys = []
    for lib_version in ((12, 3), (12, 4)):
        fake_module = _FakeNVVM(lib_version)
        monkeypatch.setattr(_program, "_get_nvvm_module", lambda fake_module=fake_module: fake_module)
        keys.append(_make_key(code="abc", code_type="nvvm", target_type="ptx"))
    key_old, key_new = keys
    assert key_old != key_new


def test_make_program_cache_key_nvvm_fingerprint_uses_get_nvvm_module(monkeypatch):
    """The NVVM fingerprint must go through ``_get_nvvm_module()``.

    Importing ``cuda.bindings.nvvm`` directly would bypass the
    availability / cuda-bindings-version gate and could disagree with the
    real NVVM compile path.
    """
    from cuda.core import _program

    version_calls = []

    class _SentinelNVVM:
        def version(self):
            version_calls.append(1)
            return (12, 9)

        def ir_version(self):
            return (1, 8, 3, 0)

    monkeypatch.setattr(_program, "_get_nvvm_module", lambda: _SentinelNVVM())
    _make_key(code="abc", code_type="nvvm", target_type="ptx")
    assert len(version_calls) == 1


def test_make_program_cache_key_nvvm_probe_changes_key(monkeypatch):
    """NVVM keys must follow the toolchain fingerprint (IR version).

    Otherwise an upgraded libNVVM could silently reuse pre-upgrade entries.
    """
    from cuda.core.utils import _program_cache

    keys = []
    for fingerprint in ("ir=1.8.3.0", "ir=2.0.3.0"):
        monkeypatch.setattr(_program_cache, "_nvvm_fingerprint", lambda fingerprint=fingerprint: fingerprint)
        keys.append(_make_key(code="abc", code_type="nvvm", target_type="ptx"))
    assert keys[0] != keys[1]
@pytest.mark.parametrize(
    "option_kw",
    [
        pytest.param({"time": True}, id="time_true"),
        # ``_prepare_driver_options`` tests ``is not None``, so values that
        # are set but falsy must be rejected at key time as well.
        pytest.param({"time": False}, id="time_false"),
        pytest.param({"ptxas_options": "-v"}, id="ptxas_options_str"),
        pytest.param({"ptxas_options": ["-v", "-O2"]}, id="ptxas_options_list"),
        pytest.param({"ptxas_options": []}, id="ptxas_options_empty_list"),
        # ProgramOptions.ptxas_options also accepts tuples, and the empty
        # tuple ``()`` is falsy too. Lock in parity for every accepted shape.
        pytest.param({"ptxas_options": ("-v",)}, id="ptxas_options_tuple"),
        pytest.param({"ptxas_options": ()}, id="ptxas_options_empty_tuple"),
        pytest.param({"split_compile": 0}, id="split_compile_zero"),
        pytest.param({"split_compile": 4}, id="split_compile_nonzero"),
        # split_compile_extended exists only on LinkerOptions; ProgramOptions
        # does not expose it, so it cannot reach the driver linker through
        # Program.compile and is not covered by the cache-time guard.
    ],
)
def test_make_program_cache_key_ptx_rejects_driver_linker_unsupported(option_kw, monkeypatch):
    """Options the driver (cuLink) linker rejects must fail at key time.

    This keeps the cache from keying a compilation that would fail. The
    guard mirrors ``_prepare_driver_options``'s own ``is not None`` gate.
    """
    from cuda.core import _linker

    def _choose_driver():
        # True selects the driver backend.
        return True

    monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", _choose_driver)
    with pytest.raises(ValueError, match="driver linker"):
        _make_key(code=".version 7.0", code_type="ptx", options=_opts(**option_kw))
def test_make_program_cache_key_ptx_accepts_driver_linker_unsupported_with_nvjitlink(monkeypatch):
    """Options the driver linker rejects stay valid under nvJitLink."""
    from cuda.core import _linker

    # False selects the nvJitLink backend.
    monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", lambda: False)
    # Key generation must complete without raising.
    _make_key(code=".version 7.0", code_type="ptx", options=_opts(time=True))


def test_filestream_cache_replace_retries_on_sharing_violation(tmp_path, monkeypatch):
    """A transient Windows sharing violation must not lose the cache write.

    ``os.replace`` is retried with a bounded backoff; a violation that
    clears within the budget still ends in a committed entry.
    """
    import os as _os

    from cuda.core.utils import FileStreamProgramCache, _program_cache

    monkeypatch.setattr(_program_cache, "_IS_WINDOWS", True)

    original_replace = _os.replace
    attempts = []

    def _flaky_replace(src, dst):
        attempts.append((src, dst))
        if len(attempts) >= 3:
            return original_replace(src, dst)
        # The first two attempts fail with ERROR_SHARING_VIOLATION.
        error = PermissionError("sharing violation")
        error.winerror = 32
        raise error

    with FileStreamProgramCache(tmp_path / "fc") as cache:
        monkeypatch.setattr(_os, "replace", _flaky_replace)
        cache[b"k"] = _fake_object_code(b"v")  # succeeds on third attempt
        assert len(attempts) == 3
        assert cache[b"k"] == b"v"


@pytest.mark.parametrize(
    "option_kw",
    [
        # Path-like options with real values.
        pytest.param({"include_path": "/usr/local/include"}, id="include_path"),
        pytest.param({"pre_include": "stdint.h"}, id="pre_include"),
        pytest.param({"pch": True}, id="pch"),
        pytest.param({"pch_dir": "pch-cache"}, id="pch_dir"),
        # A Sequence that is neither list nor tuple: the compiler iterates it
        # via ``is_sequence`` (``isinstance(v, Sequence)``), so the guard
        # has to as well.
        pytest.param({"include_path": range(1)}, id="include_path_nonempty_range"),
        # Empty strings: NVRTC still emits a flag (``--use-pch=``,
        # ``--pch-dir=``, ``--pre-include=``), so the guard must fire.
        pytest.param({"use_pch": ""}, id="use_pch_empty_string"),
        pytest.param({"pch_dir": ""}, id="pch_dir_empty_string"),
        pytest.param({"pre_include": ""}, id="pre_include_empty_string"),
        # ``use_pch`` / ``pch_dir`` are gated on ``is not None`` by NVRTC, so
        # even False emits a real flag and must be caught.
        pytest.param({"use_pch": False}, id="use_pch_false"),
        pytest.param({"pch_dir": False}, id="pch_dir_false"),
        # ``include_path`` / ``pre_include`` are NOT in that group: they are
        # only emitted for str or non-empty sequences, so ``False`` is
        # silently ignored at compile time. That case belongs to the accept
        # test, not this reject test.
    ],
)
def test_make_program_cache_key_rejects_external_content_without_extra_digest(option_kw):
    """Options that pull in external file content must force an extra_digest.

    The cache cannot see header or PCH bytes, so omitting them would yield
    stale hits after header edits.
    """
    with pytest.raises(ValueError, match="extra_digest"):
        _make_key(options=_opts(**option_kw))
@pytest.mark.parametrize(
    "option_kw",
    [
        pytest.param({"include_path": []}, id="include_path_empty_list"),
        pytest.param({"include_path": ()}, id="include_path_empty_tuple"),
        pytest.param({"pre_include": []}, id="pre_include_empty_list"),
        # ``_prepare_nvrtc_options_impl`` emits include_path / pre_include
        # only for str or a non-empty sequence. False (or any other non-str,
        # non-sequence value) is ignored at compile time and must not trip
        # the guard.
        pytest.param({"include_path": False}, id="include_path_false"),
        pytest.param({"pre_include": False}, id="pre_include_false"),
        # An empty Sequence that is neither list nor tuple:
        # ``_prepare_nvrtc_options_impl`` relies on ``is_sequence``
        # (``isinstance(v, Sequence)``), and a zero-length sequence of any
        # type emits nothing.
        pytest.param({"include_path": range(0)}, id="include_path_empty_range"),
    ],
)
def test_make_program_cache_key_accepts_empty_external_content(option_kw):
    """Truly empty sequences mean "no external inputs" and need no digest.

    Empty *strings* are rejected separately, because NVRTC still emits a
    flag for them.
    """
    options = _opts(**option_kw)
    _make_key(options=options)  # Should not raise.
def test_make_program_cache_key_ptx_ignores_nvrtc_only_options():
    """NVRTC-only fields must leave the PTX cache key untouched.

    PTX compiles go through ``_translate_program_options``, which drops
    NVRTC-only fields (include_path, pch_*, frandom_seed, ...). If they
    leaked into the key, a shared ProgramOptions would cause spurious misses.
    """
    base = _make_key(code=".version 7.0", code_type="ptx", options=_opts())
    # Every entry affects NVRTC only, never the Linker.
    nvrtc_only_settings = (
        {"define_macro": "FOO"},
        {"frandom_seed": "1234"},
        {"ofast_compile": "min"},
        {"std": "c++17"},
        {"disable_warnings": True},
    )
    for kw in nvrtc_only_settings:
        key_with_setting = _make_key(code=".version 7.0", code_type="ptx", options=_opts(**kw))
        assert key_with_setting == base, kw
@pytest.mark.parametrize(
    "option_kw",
    [
        pytest.param({"include_path": "/usr/local/include"}, id="include_path"),
        pytest.param({"pre_include": "stdint.h"}, id="pre_include"),
        pytest.param({"pch": True}, id="pch"),
        pytest.param({"use_pch": "pch.file"}, id="use_pch"),
        pytest.param({"pch_dir": "pch-cache"}, id="pch_dir"),
    ],
)
def test_make_program_cache_key_accepts_external_content_options_for_ptx(option_kw):
    """The external-content guard applies to NVRTC only.

    For PTX inputs ``Program.compile`` translates options through
    ``_translate_program_options``, which drops include_path, pre_include
    and the PCH fields. A reused ProgramOptions carrying irrelevant header
    settings must not block a PTX compile.
    """
    options = _opts(**option_kw)
    _make_key(code=".version 7.0", code_type="ptx", options=options)  # no raise


def test_make_program_cache_key_accepts_external_content_with_extra_digest():
    """External-content options pass with an ``extra_digest``.

    Different digests must yield different keys, so callers can represent
    header edits.
    """
    shared_options = _opts(include_path="/usr/local/include")
    key_header_a = _make_key(options=shared_options, extra_digest=b"header-a" * 4)
    key_header_b = _make_key(options=shared_options, extra_digest=b"header-b" * 4)
    assert key_header_a != key_header_b


@pytest.mark.parametrize(
    "option_kw, extra_digest",
    [
        pytest.param({"create_pch": "out.pch"}, None, id="create_pch"),
        # create_pch is rejected even with an extra_digest: a cache hit skips
        # compilation, so the PCH would never be written.
        pytest.param({"create_pch": "out.pch"}, b"x" * 32, id="create_pch_with_extra_digest"),
        pytest.param({"create_pch": ""}, None, id="create_pch_empty_string"),
        # NVRTC emits ``--create-pch=False`` for any non-None value, so False
        # still triggers the side effect and must be rejected.
        pytest.param({"create_pch": False}, None, id="create_pch_false"),
        pytest.param({"time": "timing.csv"}, None, id="time"),
        pytest.param({"time": False}, None, id="time_false"),
        pytest.param({"fdevice_time_trace": "trace.json"}, None, id="fdevice_time_trace"),
        pytest.param({"fdevice_time_trace": False}, None, id="fdevice_time_trace_false"),
    ],
)
def test_make_program_cache_key_rejects_side_effect_options_nvrtc(option_kw, extra_digest):
    """NVRTC options that write files as a side effect must refuse a key.

    A cache hit would skip compilation, so the artifact would never be
    produced.
    """
    with pytest.raises(ValueError, match="side effect"):
        _make_key(options=_opts(**option_kw), extra_digest=extra_digest)
@pytest.mark.parametrize(
    "option_kw",
    [
        # ``time`` reaches the Linker as ``-time``, which only writes to the
        # info log. There is no filesystem side effect, so PTX compiles with
        # ``time=True`` must cache normally.
        pytest.param({"time": True}, id="time_true"),
        pytest.param({"time": "whatever.csv"}, id="time_path"),
    ],
)
def test_make_program_cache_key_accepts_side_effect_options_for_ptx(option_kw):
    """The side-effect guard is specific to NVRTC.

    PTX (linker) and NVVM must not be blocked by options whose side effects
    only exist under NVRTC.
    """
    options = _opts(**option_kw)
    _make_key(code=".version 7.0", code_type="ptx", options=options)  # no raise
@pytest.mark.parametrize(
    "code_type, code, target_type",
    [
        pytest.param("c++", "a", "cubin", id="nvrtc"),
        pytest.param("ptx", ".version 7.0", "cubin", id="linker"),
        pytest.param("nvvm", "abc", "ptx", id="nvvm"),
    ],
)
def test_make_program_cache_key_survives_cuda_core_version_change(code_type, code, target_type, monkeypatch):
    """cuda.core's own ``__version__`` must NOT be mixed into the digest.

    The docstring promises cross-patch sharing within a schema version.
    """
    import cuda.core._version as _version_mod

    keys = []
    for fake_version in ("0.0.0", "999.999.999"):
        monkeypatch.setattr(_version_mod, "__version__", fake_version)
        keys.append(_make_key(code=code, code_type=code_type, target_type=target_type))
    assert keys[0] == keys[1]


def test_make_program_cache_key_driver_version_does_not_perturb_ptx_under_nvjitlink(monkeypatch):
    """With nvJitLink active, the driver version must not affect PTX keys.

    nvJitLink does NOT route PTX compilation through cuLink.
    """
    from cuda.core.utils import _program_cache

    monkeypatch.setattr(_program_cache, "_linker_backend_and_version", lambda: ("nvJitLink", "12030"))
    keys = []
    for driver_version in (13200, 13300):
        monkeypatch.setattr(_program_cache, "_driver_version", lambda driver_version=driver_version: driver_version)
        keys.append(_make_key(code=".version 7.0", code_type="ptx"))
    assert keys[0] == keys[1]
@pytest.mark.parametrize(
    "code_type, code, target_type",
    [
        pytest.param("c++", "a", "cubin", id="nvrtc"),
        pytest.param("nvvm", "abc", "ptx", id="nvvm"),
    ],
)
def test_make_program_cache_key_driver_probe_failure_does_not_perturb_non_linker(
    code_type, code, target_type, monkeypatch
):
    """A failing driver probe must leave NVRTC and NVVM keys unchanged.

    The driver version is only consumed on the linker (PTX) path, because
    cuLink runs through the driver. NVRTC and NVVM output does not depend on
    it, so a failed probe must NOT invalidate their cache entries.
    """
    from cuda.core.utils import _program_cache

    def _broken():
        raise RuntimeError("driver probe failed")

    key_kwargs = {"code": code, "code_type": code_type, "target_type": target_type}
    key_healthy = _make_key(**key_kwargs)
    monkeypatch.setattr(_program_cache, "_driver_version", _broken)
    key_after_failure = _make_key(**key_kwargs)
    assert key_healthy == key_after_failure


@pytest.mark.parametrize(
    "probe_name, code_type, code",
    [
        pytest.param("_nvrtc_version", "c++", "a", id="nvrtc"),
        pytest.param("_linker_backend_and_version", "ptx", ".ptx", id="linker"),
    ],
)
def test_make_program_cache_key_fails_closed_on_probe_failure(probe_name, code_type, code, monkeypatch):
    """A failed probe must change the key, and change it stably.

    (a) The key must differ from the working-probe key, so environments
    never silently share entries. (b) It must be identical across calls, or
    the persistent cache could not be reused in a broken environment.
    ``_driver_version`` is covered separately because it is only invoked
    transitively from ``_linker_backend_and_version`` on the cuLink path.
    """
    from cuda.core.utils import _program_cache

    def _broken():
        raise RuntimeError("probe failed")

    key_healthy = _make_key(code=code, code_type=code_type)
    monkeypatch.setattr(_program_cache, probe_name, _broken)
    first_failure_key = _make_key(code=code, code_type=code_type)
    second_failure_key = _make_key(code=code, code_type=code_type)
    assert key_healthy != first_failure_key
    assert first_failure_key == second_failure_key  # stable: same failure -> same key
def test_make_program_cache_key_driver_probe_failure_taints_ptx_under_cuLink(monkeypatch):
    """Under the driver linker, a failing driver probe must taint PTX keys.

    ``_linker_backend_and_version`` invokes ``_driver_version`` internally
    when the driver linker is active. The failure must (a) move the key away
    from the success key and (b) be stable across calls, so the persistent
    cache stays usable in the failed environment.
    """
    from cuda.core import _linker
    from cuda.core.utils import _program_cache

    def _broken():
        raise RuntimeError("driver probe failed")

    # True selects the driver backend.
    monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", lambda: True)
    key_healthy = _make_key(code=".ptx", code_type="ptx")
    monkeypatch.setattr(_program_cache, "_driver_version", _broken)
    first_failure_key = _make_key(code=".ptx", code_type="ptx")
    second_failure_key = _make_key(code=".ptx", code_type="ptx")
    assert key_healthy != first_failure_key
    assert first_failure_key == second_failure_key  # stable: same failure -> same key


def _fake_object_code(payload: bytes = b"fake-cubin", name: str = "unit"):
    """Return a cubin ObjectCode wrapping *payload*, built without the driver."""
    from cuda.core._module import ObjectCode

    object_code = ObjectCode._init(payload, "cubin", name=name)
    return object_code


# ---------------------------------------------------------------------------
# FileStreamProgramCache -- single-process CRUD
# ---------------------------------------------------------------------------
def test_filestream_cache_empty_on_create(tmp_path):
    """A freshly created cache holds no entries and raises on lookup."""
    from cuda.core.utils import FileStreamProgramCache

    cache_dir = tmp_path / "fc"
    with FileStreamProgramCache(cache_dir) as cache:
        assert len(cache) == 0
        assert b"nope" not in cache
        with pytest.raises(KeyError):
            cache[b"nope"]


def test_filestream_cache_roundtrip(tmp_path):
    """Reading an entry returns exactly the bytes that were written.

    ObjectCode metadata (name, code_type, symbol_mapping) is NOT preserved;
    the cache stores just the binary.
    """
    from cuda.core.utils import FileStreamProgramCache

    cache_dir = tmp_path / "fc"
    with FileStreamProgramCache(cache_dir) as cache:
        cache[b"k1"] = _fake_object_code(b"v1", name="x")
        assert b"k1" in cache
        assert cache[b"k1"] == b"v1"


def test_filestream_cache_delete(tmp_path):
    """Deleting removes the entry; deleting it again raises KeyError."""
    from cuda.core.utils import FileStreamProgramCache

    cache_dir = tmp_path / "fc"
    with FileStreamProgramCache(cache_dir) as cache:
        cache[b"k"] = _fake_object_code()
        del cache[b"k"]
        assert b"k" not in cache
        with pytest.raises(KeyError):
            del cache[b"k"]


def test_filestream_cache_len_counts_all(tmp_path):
    """``len`` reports every stored entry."""
    from cuda.core.utils import FileStreamProgramCache

    cache_dir = tmp_path / "fc"
    with FileStreamProgramCache(cache_dir) as cache:
        for key, payload in ((b"a", b"1"), (b"b", b"2"), (b"c", b"3")):
            cache[key] = _fake_object_code(payload)
        assert len(cache) == 3


def test_filestream_cache_clear(tmp_path):
    """``clear`` leaves the cache empty."""
    from cuda.core.utils import FileStreamProgramCache

    root = tmp_path / "fc"
    with FileStreamProgramCache(root) as cache:
        cache[b"a"] = _fake_object_code()
        cache.clear()
        assert len(cache) == 0


def test_filestream_cache_persists_across_reopen(tmp_path):
    """An entry written by one cache instance is readable by the next."""
    from cuda.core.utils import FileStreamProgramCache

    root = tmp_path / "fc"
    with FileStreamProgramCache(root) as writer:
        writer[b"k"] = _fake_object_code(b"persisted")
    with FileStreamProgramCache(root) as reader:
        assert reader[b"k"] == b"persisted"


def test_filestream_cache_permission_error_propagates_on_posix(tmp_path, monkeypatch):
    """Off Windows, PermissionError from ``os.replace`` must propagate.

    There it is a real configuration error and must not be swallowed.
    """
    import os as _os

    from cuda.core.utils import FileStreamProgramCache, _program_cache

    monkeypatch.setattr(_program_cache, "_IS_WINDOWS", False)

    cache_dir = tmp_path / "fc"
    with FileStreamProgramCache(cache_dir) as cache:

        def _denied(src, dst):
            raise PermissionError("denied")

        monkeypatch.setattr(_os, "replace", _denied)
        with pytest.raises(PermissionError, match="denied"):
            cache[b"k"] = _fake_object_code(b"v")
def test_filestream_cache_write_phase_permission_error_propagates_on_windows(tmp_path, monkeypatch):
    """Even on Windows, a PermissionError from the write phase (mkstemp /
    fdopen / fsync) is a real config problem -- the Windows carve-out is
    only for the os.replace race. A write-phase error must propagate."""
    from cuda.core.utils import FileStreamProgramCache, _program_cache

    # Pretend to be on Windows so the Windows-specific handling is active.
    monkeypatch.setattr(_program_cache, "_IS_WINDOWS", True)

    def _denied(*args, **kwargs):
        raise PermissionError("mkstemp denied")

    # Patch the ``tempfile`` reference held by the cache module so the
    # failure happens while creating the temp file, before any os.replace.
    monkeypatch.setattr(_program_cache.tempfile, "mkstemp", _denied)

    with FileStreamProgramCache(tmp_path / "fc") as cache, pytest.raises(PermissionError, match="mkstemp"):
        cache[b"k"] = _fake_object_code(b"v")


@pytest.mark.parametrize(
    "winerror, should_raise",
    [
        # Transient codes: swallowed after the bounded retry.
        pytest.param(5, False, id="access_denied_swallowed"),
        pytest.param(32, False, id="sharing_violation_swallowed"),
        pytest.param(33, False, id="lock_violation_swallowed"),
        # Anything else, including a missing winerror, must propagate.
        pytest.param(1, True, id="other_winerror_propagates"),
        pytest.param(None, True, id="no_winerror_propagates"),
    ],
)
def test_filestream_cache_permission_error_windows_is_narrowed(tmp_path, monkeypatch, winerror, should_raise):
    """On Windows, ERROR_ACCESS_DENIED (5), ERROR_SHARING_VIOLATION (32) and
    ERROR_LOCK_VIOLATION (33) are all transient "target held open by another
    process / pending delete" cases worth swallowing after the bounded retry.
    Any other PermissionError -- unrelated winerrors, missing winerror
    attribute, etc. -- is a real problem and must propagate."""
    import os as _os

    from cuda.core.utils import FileStreamProgramCache, _program_cache

    monkeypatch.setattr(_program_cache, "_IS_WINDOWS", True)

    def _denied(src, dst):
        # Every os.replace attempt fails with the parametrized winerror.
        exc = PermissionError("simulated")
        exc.winerror = winerror
        raise exc

    with FileStreamProgramCache(tmp_path / "fc") as cache:
        monkeypatch.setattr(_os, "replace", _denied)
        if should_raise:
            with pytest.raises(PermissionError, match="simulated"):
                cache[b"k"] = _fake_object_code(b"v")
        else:
            cache[b"k"] = _fake_object_code(b"v")  # swallowed
            # The write was dropped, so nothing may be committed.
            assert b"k" not in cache
def test_filestream_cache_atomic_no_half_written_file(tmp_path, monkeypatch):
    """A write that fails at ``os.replace`` must leave no entry behind."""
    # Simulate a crash during write: patch os.replace to raise.
    import os as _os

    from cuda.core.utils import FileStreamProgramCache

    with FileStreamProgramCache(tmp_path / "fc") as cache:

        def _boom(src, dst):
            raise RuntimeError("crash during replace")

        monkeypatch.setattr(_os, "replace", _boom)
        with pytest.raises(RuntimeError, match="crash"):
            cache[b"k"] = _fake_object_code(b"v")
        # Restore the real os.replace before inspecting the cache.
        monkeypatch.undo()
        assert b"k" not in cache


def test_filestream_cache_prune_only_if_stat_unchanged(tmp_path):
    """The reader-unlink-vs-writer-replace race: if a concurrent writer
    atomically replaced a file between the reader's read and the reader's
    prune, the pruner must NOT delete the replacement."""
    from cuda.core.utils import FileStreamProgramCache
    from cuda.core.utils._program_cache import _prune_if_stat_unchanged

    with FileStreamProgramCache(tmp_path / "fc") as cache:
        cache[b"k"] = _fake_object_code(b"v1")
        path = cache._path_for_key(b"k")
        # Snapshot taken before the "concurrent" rewrite below.
        stale_stat = path.stat()
        # Simulate a concurrent writer replacing the file.
        time.sleep(0.02)
        cache[b"k"] = _fake_object_code(b"v2")

        # Reader decides to prune using the stale stat; the guard refuses.
        _prune_if_stat_unchanged(path, stale_stat)
        assert path.exists()

        # With a fresh stat matching the current file, pruning proceeds.
        _prune_if_stat_unchanged(path, path.stat())
        assert not path.exists()
def test_filestream_cache_touch_atime_only_if_stat_unchanged(tmp_path):
    """The atime-touch is also stat-guarded so a racing rewriter's freshly
    replaced file does NOT get its mtime rolled back to the previous
    entry's value. Without the guard, the eviction stat-check (which keys
    on (ino, size, mtime_ns)) would mistake the replacement for the old
    entry and delete a just-committed file."""
    from cuda.core.utils import FileStreamProgramCache
    from cuda.core.utils._program_cache import _touch_atime

    same_size_bytes = b"v" * 64
    with FileStreamProgramCache(tmp_path / "fc") as cache:
        cache[b"k"] = same_size_bytes
        path = cache._path_for_key(b"k")
        # Snapshot taken before the "concurrent" rewrite below.
        stale_stat = path.stat()
        # Concurrent writer replaces with same-size payload (same st_size,
        # different ino/mtime) -- this is the dangerous case: ino and
        # mtime differ, only the stat-guard saves us.
        time.sleep(0.02)
        cache[b"k"] = same_size_bytes
        new_mtime_ns = path.stat().st_mtime_ns

        _touch_atime(path, stale_stat)
        # The new file's mtime must be untouched.
        assert path.stat().st_mtime_ns == new_mtime_ns

        # With a stat that matches the current file, atime is updated and
        # mtime is preserved.
        fresh_stat = path.stat()
        _touch_atime(path, fresh_stat)
        after = path.stat()
        assert after.st_mtime_ns == fresh_stat.st_mtime_ns
        assert after.st_atime_ns >= fresh_stat.st_atime_ns


def test_filestream_cache_returns_bytes_verbatim(tmp_path):
    """The cache stores raw binary and does no payload validation: whatever
    bytes were written are returned exactly. External tools (cuobjdump,
    nvdisasm) can read the entry file directly."""
    from cuda.core.utils import FileStreamProgramCache

    root = tmp_path / "fc"
    payload = b"\x7fELF\x02\x01\x01\x00" + b"\xab" * 256  # plausible cubin header
    with FileStreamProgramCache(root) as cache:
        cache[b"k"] = payload
        assert cache[b"k"] == payload
        path = cache._path_for_key(b"k")
        # On-disk content equals the input bytes verbatim -- no header,
        # no pickle frame, no length prefix.
        assert path.read_bytes() == payload
def test_filestream_cache_accepts_bytes_directly(tmp_path):
    """A plain ``bytes`` value can be stored and read back unchanged."""
    from cuda.core.utils import FileStreamProgramCache

    cache_dir = tmp_path / "fc"
    with FileStreamProgramCache(cache_dir) as cache:
        cache[b"k"] = b"raw"
        assert cache[b"k"] == b"raw"


def test_filestream_cache_accepts_bytearray_and_memoryview(tmp_path):
    """``bytearray`` and ``memoryview`` values read back as equal bytes."""
    from cuda.core.utils import FileStreamProgramCache

    cache_dir = tmp_path / "fc"
    with FileStreamProgramCache(cache_dir) as cache:
        cache[b"a"] = bytearray(b"ba")
        cache[b"b"] = memoryview(b"mv")
        assert cache[b"a"] == b"ba"
        assert cache[b"b"] == b"mv"


def test_filestream_cache_rejects_non_bytes_non_object_code(tmp_path):
    """Storing a ``str`` raises TypeError naming the accepted kinds."""
    from cuda.core.utils import FileStreamProgramCache

    cache_dir = tmp_path / "fc"
    with FileStreamProgramCache(cache_dir) as cache:
        with pytest.raises(TypeError, match="bytes-like or ObjectCode"):
            cache[b"k"] = "a string"


def test_filestream_cache_accepts_path_backed_object_code(tmp_path):
    """A path-backed ObjectCode is stored by content, not by path.

    The file is read at write time, so the cached entry stays
    self-contained even if the source file is later moved or deleted.
    """
    from cuda.core._module import ObjectCode
    from cuda.core.utils import FileStreamProgramCache

    cache_dir = tmp_path / "fc"
    src = tmp_path / "src.cubin"
    src.write_bytes(b"hello-cubin-bytes")
    path_backed = ObjectCode.from_cubin(str(src), name="x")

    with FileStreamProgramCache(cache_dir) as cache:
        cache[b"k"] = path_backed
        assert cache[b"k"] == b"hello-cubin-bytes"

    # Removing the source file must leave the cached entry intact.
    src.unlink()
    with FileStreamProgramCache(cache_dir) as reopened:
        assert reopened[b"k"] == b"hello-cubin-bytes"
def test_program_cache_resource_update_accepts_mapping_and_pairs(tmp_path):
    """``update`` takes a Mapping or an iterable of (key, value) pairs.

    It is a default ABC method and must send each item through
    ``__setitem__`` so backend coercion (bytes extraction, size-cap
    enforcement) still runs.
    """
    from cuda.core.utils import FileStreamProgramCache

    with FileStreamProgramCache(tmp_path / "fc-mapping") as mapping_cache:
        mapping_cache.update({b"a": b"v-a", b"b": b"v-b"})
        assert mapping_cache[b"a"] == b"v-a"
        assert mapping_cache[b"b"] == b"v-b"

    with FileStreamProgramCache(tmp_path / "fc-pairs") as pairs_cache:
        pairs_cache.update([(b"x", b"v-x"), (b"y", b"v-y")])
        assert pairs_cache[b"x"] == b"v-x"
        assert pairs_cache[b"y"] == b"v-y"


def test_filestream_cache_input_forms_are_byte_equivalent(tmp_path):
    """Every accepted input form stores the same bytes.

    Raw bytes, a bytearray, a memoryview, a bytes-backed ObjectCode and a
    path-backed ObjectCode over a file with the same bytes must all read
    back as the payload, and the on-disk file must hold exactly those
    bytes. Callers need not normalise their input shape themselves.
    """
    from cuda.core._module import ObjectCode
    from cuda.core.utils import FileStreamProgramCache

    payload = b"\x7fELF\x02\x01\x01\x00fake-cubin-bytes"
    src = tmp_path / "src.cubin"
    src.write_bytes(payload)

    inputs = {
        b"raw-bytes": payload,
        b"bytearray": bytearray(payload),
        b"memoryview": memoryview(payload),
        b"obj-bytes-backed": ObjectCode._init(payload, "cubin", name="x"),
        b"obj-path-backed": ObjectCode.from_cubin(str(src), name="y"),
    }

    with FileStreamProgramCache(tmp_path / "fc") as cache:
        cache.update(inputs)
        for k in inputs:
            stored = cache[k]
            assert stored == payload, f"value for {k!r} round-tripped to a different byte string"
            entry_path = cache._path_for_key(k)
            on_disk = entry_path.read_bytes()
            assert on_disk == payload, f"on-disk file for {k!r} is not the raw payload"
def test_filestream_cache_rejects_negative_size_cap(tmp_path):
    """A negative ``max_size_bytes`` is rejected at construction time."""
    from cuda.core.utils import FileStreamProgramCache

    cache_dir = tmp_path / "fc"
    with pytest.raises(ValueError, match="non-negative"):
        FileStreamProgramCache(cache_dir, max_size_bytes=-1)


def test_default_cache_dir_lives_under_user_cache_root():
    """The default cache directory ends in ``cuda-python/program-cache``.

    Per-platform resolution is delegated to ``platformdirs``, so only the
    cuda-python-owned suffix is verified here.

    ``opinion=False`` matters: with the default ``opinion=True`` on
    Windows, platformdirs inserts an extra ``Cache`` component that would
    diverge from the documented layout. Asserting against the same flag
    pins the layout invariant.
    """
    import platformdirs

    from cuda.core.utils._program_cache import _default_cache_dir

    user_cache_root = platformdirs.user_cache_path("cuda-python", appauthor=False, opinion=False)
    expected = user_cache_root / "program-cache"
    assert _default_cache_dir() == expected
    # The path must end with cuda-python/program-cache on every platform;
    # no extra "Cache" component may sneak in.
    trailing_parts = _default_cache_dir().parts[-2:]
    assert trailing_parts == ("cuda-python", "program-cache")
+ assert _default_cache_dir().parts[-2:] == ("cuda-python", "program-cache") + + +def test_filestream_cache_uses_default_dir_when_path_omitted(tmp_path, monkeypatch): + from cuda.core.utils import FileStreamProgramCache, _program_cache + + monkeypatch.setattr(_program_cache, "_default_cache_dir", lambda: tmp_path / "default-fc") + + with FileStreamProgramCache() as cache: + cache[b"k"] = b"hello" + assert cache[b"k"] == b"hello" + assert (tmp_path / "default-fc" / "SCHEMA_VERSION").is_file() + + +def test_filestream_cache_sweeps_stale_tmp_files_on_open(tmp_path): + """A crashed writer can leave files in ``tmp/``; the next ``open`` must + sweep ones older than the staleness threshold so disk usage doesn't + grow without bound.""" + import os as _os + + from cuda.core.utils import FileStreamProgramCache, _program_cache + + root = tmp_path / "fc" + # Create the cache directory layout, then plant two temp files: one + # young (must be preserved as it could be an in-flight write) and one + # ancient (must be swept). + with FileStreamProgramCache(root): + pass + young = root / "tmp" / "entry-young" + young.write_bytes(b"in-flight") + ancient = root / "tmp" / "entry-ancient" + ancient.write_bytes(b"crashed-writer-leftover") + ancient_mtime = time.time() - _program_cache._TMP_STALE_AGE_SECONDS - 60 + _os.utime(ancient, (ancient_mtime, ancient_mtime)) + + with FileStreamProgramCache(root): + # Reopen triggers _sweep_stale_tmp_files. + assert young.exists(), "young temp file must not be swept" + assert not ancient.exists(), "ancient temp file should have been swept" + + +def test_filestream_cache_clear_preserves_young_tmp_files(tmp_path): + """clear() must not delete young temp files: another process could be + mid-write between ``mkstemp`` and ``os.replace``, and unlinking under + it turns the writer's harmless rename into ``FileNotFoundError``. 
+    Stale temps (older than the threshold) are still swept."""
+    import os as _os
+
+    from cuda.core.utils import FileStreamProgramCache, _program_cache
+
+    root = tmp_path / "fc"
+    with FileStreamProgramCache(root) as cache:
+        cache[b"k"] = _fake_object_code(b"v")
+    young_tmp = root / "tmp" / "entry-young"
+    young_tmp.write_bytes(b"in-flight")
+    ancient_tmp = root / "tmp" / "entry-ancient"
+    ancient_tmp.write_bytes(b"crashed")
+    # Back-date the ancient file to 60 s beyond the staleness threshold.
+    ancient_mtime = time.time() - _program_cache._TMP_STALE_AGE_SECONDS - 60
+    _os.utime(ancient_tmp, (ancient_mtime, ancient_mtime))
+
+    with FileStreamProgramCache(root) as cache:
+        cache.clear()
+    # Committed entry is gone, ancient orphan is gone, young temp survives.
+    # Filenames are hash-like (no extension), so use a file filter rather
+    # than a "*.*" glob.
+    remaining_entries = [p for p in (root / "entries").rglob("*") if p.is_file()]
+    assert not remaining_entries
+    assert young_tmp.exists()
+    assert not ancient_tmp.exists()
+
+
+def test_filestream_cache_clear_does_not_unlink_replaced_file(tmp_path):
+    """``clear()``'s scan-then-unlink loop must use the stat-guard so a
+    concurrent writer's ``os.replace`` between snapshot and unlink doesn't
+    delete the fresh entry. Race injection: subclass the cache and have
+    ``_iter_entry_paths``'s post-yield cleanup os.replace path_a, then call
+    ``clear()`` and verify the fresh contents survive."""
+    import os as _os
+
+    from cuda.core.utils import FileStreamProgramCache
+
+    root = tmp_path / "fc"
+    with FileStreamProgramCache(root) as cache:
+        cache[b"a"] = _fake_object_code(b"A" * 200, name="a")
+        cache[b"b"] = _fake_object_code(b"B" * 200, name="b")
+        path_a = cache._path_for_key(b"a")
+
+    class _RaceCache(FileStreamProgramCache):
+        # Class-level flag so the replacement is injected at most once.
+        race_armed = True
+
+        def _iter_entry_paths(self):
+            yield from super()._iter_entry_paths()
+            # Generator cleanup runs at StopIteration, between clear()'s
+            # scan and its unlink loop.
+            if _RaceCache.race_armed and path_a.exists():
+                _RaceCache.race_armed = False
+                # Stand in for another writer: stage a file next to path_a
+                # and rename it over the entry.
+                tmp = path_a.parent / "_inflight"
+                tmp.write_bytes(b"\x80\x05fresh-by-other-writer-" * 32)
+                _os.replace(tmp, path_a)
+
+    with _RaceCache(root) as cache:
+        cache.clear()
+
+    # The fresh file must survive: clear() saw a stat mismatch and skipped.
+    assert path_a.exists(), "stat guard failed -- clear() unlinked a concurrently-replaced file"
+    assert path_a.read_bytes().startswith(b"\x80\x05fresh-by-other-writer-")
+
+
+def test_filestream_cache_clear_does_not_break_concurrent_writer(tmp_path):
+    """Simulate a writer that has already produced a temp file but has not
+    yet executed ``os.replace``; a concurrent ``clear()`` from another
+    cache instance must NOT unlink that temp, so the writer's
+    ``os.replace`` still succeeds."""
+    import os as _os
+
+    from cuda.core.utils import FileStreamProgramCache
+
+    root = tmp_path / "fc"
+    with FileStreamProgramCache(root) as cache:
+        cache[b"seed"] = _fake_object_code(b"seed")
+
+    # Stage a temp file that mimics an in-flight write.
+    inflight_tmp = root / "tmp" / "entry-inflight"
+    inflight_tmp.write_bytes(b"in-flight payload")  # contents do not matter
+
+    # Concurrent clear() from another cache handle.
+    with FileStreamProgramCache(root) as other:
+        other.clear()
+
+    # The writer can now finish: rename the staged file into entries/.
+    # os.replace raises if the staged file was unlinked, so reaching the
+    # final assert also shows that clear() left it alone.
+    target = root / "entries" / "ab" / "cdef"
+    target.parent.mkdir(parents=True, exist_ok=True)
+    _os.replace(inflight_tmp, target)
+    assert target.exists()
+
+
+def test_filestream_cache_size_cap_does_not_unlink_replaced_file(tmp_path):
+    """The PRODUCTION ``_enforce_size_cap`` must compare the snapshot stat
+    to the current stat before unlinking; if the file was replaced under
+    us (a concurrent writer's ``os.replace``), the unlink is skipped.
+
+    Race injection without reimplementing the method: subclass the cache
+    and override only ``_iter_entry_paths`` so that the cleanup code
+    *after* the generator's last yield runs an ``os.replace`` on path_a.
+    Python's for-loop calls ``next()`` until ``StopIteration``; the
+    generator code after its last yield runs at that ``StopIteration``,
+    which is exactly between ``_enforce_size_cap``'s scan loop and its
+    eviction loop. Eviction's per-entry re-stat then sees a different
+    stat for path_a and the production code's stat-guard must skip it.
+    """
+    import os as _os
+
+    from cuda.core.utils import FileStreamProgramCache
+
+    # Cap fits two 2000-byte entries (raw payload only -- no per-entry
+    # framing) but not three.
+    cap = 5000
+    root = tmp_path / "fc"
+    with FileStreamProgramCache(root, max_size_bytes=cap) as cache:
+        cache[b"a"] = _fake_object_code(b"A" * 2000, name="a")
+        # Short pause between writes so the two entries get distinct timestamps.
+        time.sleep(0.02)
+        cache[b"b"] = _fake_object_code(b"B" * 2000, name="b")
+        path_a = cache._path_for_key(b"a")
+        assert path_a.exists(), "cap too small -- 'a' was evicted before the test ran"
+
+    class _RaceCache(FileStreamProgramCache):
+        # Class-level flag so the replacement is injected at most once.
+        race_armed = True
+
+        def _iter_entry_paths(self):
+            yield from super()._iter_entry_paths()
+            # Generator cleanup runs at StopIteration, between
+            # _enforce_size_cap's scan and its eviction loop. Fire the race
+            # here exactly once.
+            if _RaceCache.race_armed and path_a.exists():
+                _RaceCache.race_armed = False
+                tmp = path_a.parent / "_inflight"
+                tmp.write_bytes(b"\x80\x05fresh-by-other-writer-" * 32)
+                _os.replace(tmp, path_a)
+
+    with _RaceCache(root, max_size_bytes=cap) as cache:
+        # Trigger eviction by adding 'c'; eviction's scan exhausts our
+        # racing generator, the cleanup fires, then the eviction loop's
+        # re-stat sees the new stat and the production stat-guard MUST
+        # refuse to unlink path_a.
+        time.sleep(0.02)
+        cache[b"c"] = _fake_object_code(b"C" * 2000, name="c")
+
+    # The race-injected fresh file must survive: production stat-guard worked.
+    assert path_a.exists(), "stat guard failed -- evicted a concurrently-replaced file"
+    assert path_a.read_bytes().startswith(b"\x80\x05fresh-by-other-writer-")
+
+
+def test_filestream_cache_size_cap_counts_tmp_files(tmp_path):
+    """Surviving temp files occupy disk too; the soft cap must include
+    them, otherwise an attacker (or a flurry of crashed writers) could
+    inflate disk usage well past max_size_bytes."""
+    from cuda.core.utils import FileStreamProgramCache
+
+    # Two 1500-byte entries (3000) fit under the 4000-byte cap on their own.
+    cap = 4000
+    root = tmp_path / "fc"
+    with FileStreamProgramCache(root, max_size_bytes=cap) as cache:
+        cache[b"a"] = _fake_object_code(b"A" * 1500, name="a")
+        time.sleep(0.02)
+        cache[b"b"] = _fake_object_code(b"B" * 1500, name="b")
+    # Plant a young temp file that pushes total over the cap.
+    young_tmp = root / "tmp" / "entry-leftover"
+    young_tmp.write_bytes(b"X" * 2500)
+
+    with FileStreamProgramCache(root, max_size_bytes=cap) as cache:
+        # New write triggers _enforce_size_cap; 'a' must be evicted because
+        # the temp file's bytes count toward the cap now.
+        time.sleep(0.02)
+        cache[b"c"] = _fake_object_code(b"C" * 200, name="c")
+        assert b"a" not in cache
+        assert b"c" in cache
+
+
+def test_filestream_cache_handles_long_keys(tmp_path):
+    """Arbitrary-length keys must not overflow per-component filename limits.
+    The filename is a fixed-length 256-bit blake2b digest; key uniqueness
+    relies on the digest's collision resistance."""
+    from cuda.core.utils import FileStreamProgramCache
+
+    long_bytes_key = b"x" * 4096
+    long_str_key = "y" * 4096
+    with FileStreamProgramCache(tmp_path / "fc") as cache:
+        cache[long_bytes_key] = _fake_object_code(b"b", name="nb")
+        cache[long_str_key] = _fake_object_code(b"s", name="ns")
+        assert long_bytes_key in cache
+        assert long_str_key in cache
+        assert cache[long_bytes_key] == b"b"
+        assert cache[long_str_key] == b"s"
+
+
+def test_filestream_cache_accepts_str_keys(tmp_path):
+    """An entry written under a ``str`` key is found through both the
+    ``str`` key and its ``bytes`` spelling."""
+    from cuda.core.utils import FileStreamProgramCache
+
+    with FileStreamProgramCache(tmp_path / "fc") as cache:
+        cache["my-key"] = _fake_object_code(b"v")
+        assert "my-key" in cache
+        assert b"my-key" in cache
+
+
+def test_filestream_cache_size_cap_evicts_oldest(tmp_path):
+    """After three over-cap writes the first entry is gone and the last one
+    is still present."""
+    from cuda.core.utils import FileStreamProgramCache
+
+    # Big payloads, small cap: a single 2000-byte payload fits under the
+    # 3000-byte cap, two do not, so the cap is already exceeded by the
+    # second write. By the time all three are written the entry with the
+    # oldest atime (a) must have been evicted and the newest (c) must remain.
+    cap = 3000
+    root = tmp_path / "fc"
+    with FileStreamProgramCache(root, max_size_bytes=cap) as cache:
+        cache[b"a"] = b"A" * 2000
+        time.sleep(0.02)
+        cache[b"b"] = b"B" * 2000
+        time.sleep(0.02)
+        cache[b"c"] = b"C" * 2000
+
+    with FileStreamProgramCache(root, max_size_bytes=cap) as cache:
+        assert b"a" not in cache
+        assert b"c" in cache
+
+
+def test_filestream_cache_atime_lru_promotes_recently_read(tmp_path):
+    """Eviction sorts by ``st_atime``: an entry that was recently READ
+    survives even if it was the first one WRITTEN. This is the practical
+    win over mtime-based FIFO eviction."""
+    from cuda.core.utils import FileStreamProgramCache
+
+    # Cap fits two payloads but not three; the third write triggers exactly
+    # one eviction, and we want it to evict 'b' (oldest atime), not 'a'.
+    cap = 5000
+    root = tmp_path / "fc"
+    with FileStreamProgramCache(root, max_size_bytes=cap) as cache:
+        cache[b"a"] = b"A" * 2000  # oldest write
+        time.sleep(0.02)
+        cache[b"b"] = b"B" * 2000  # newer write
+        time.sleep(0.02)
+        # Bump 'a' to most-recently-used.
+        _ = cache[b"a"]
+        time.sleep(0.02)
+        cache[b"c"] = b"C" * 2000  # 6000 total -> one eviction
+
+    with FileStreamProgramCache(root, max_size_bytes=cap) as cache:
+        # 'b' is now the oldest atime -- evicted.
+        assert b"b" not in cache
+        # 'a' was read after 'b' so its atime is newer -- survives.
+        assert b"a" in cache
+        assert b"c" in cache
+
+
+def test_filestream_cache_unbounded_by_default(tmp_path):
+    """Without ``max_size_bytes`` all 20 entries written (1024 bytes each)
+    are still counted by ``len``."""
+    from cuda.core.utils import FileStreamProgramCache
+
+    with FileStreamProgramCache(tmp_path / "fc") as cache:
+        for i in range(20):
+            cache[f"k{i}".encode()] = _fake_object_code(b"X" * 1024, name=f"n{i}")
+        assert len(cache) == 20
+
+
+def test_filestream_cache_wipes_on_schema_mismatch(tmp_path):
+    """A cache written with an older schema must be wiped on open, not
+    silently mixed with a newer format."""
+    from cuda.core.utils import FileStreamProgramCache
+
+    root = tmp_path / "fc"
+    with FileStreamProgramCache(root) as cache:
+        cache[b"k"] = _fake_object_code(b"old-payload")
+    # Simulate an older schema by rewriting the version marker.
+    (root / "SCHEMA_VERSION").write_text("0")
+
+    with FileStreamProgramCache(root) as cache:
+        assert len(cache) == 0
+        assert b"k" not in cache
+    # Marker should be back at the current version.
+    assert (root / "SCHEMA_VERSION").read_text().strip() != "0"
+
+
+def test_filestream_cache_schema_version_encodes_key_schema(tmp_path, monkeypatch):
+    """As with the SQLite backend, bumping ``_KEY_SCHEMA_VERSION`` alone
+    must invalidate the on-disk cache so orphaned entries from the old
+    key-hash format do not linger after an upgrade."""
+    from cuda.core.utils import FileStreamProgramCache, _program_cache
+
+    root = tmp_path / "fc"
+    with FileStreamProgramCache(root) as cache:
+        cache[b"k"] = _fake_object_code(b"old-payload")
+        path = cache._path_for_key(b"k")
+    assert path.exists()
+
+    # Bump the key schema first; the second patch reads the already-bumped
+    # value when it rebuilds the combined filestream schema string.
+    monkeypatch.setattr(_program_cache, "_KEY_SCHEMA_VERSION", _program_cache._KEY_SCHEMA_VERSION + 1)
+    monkeypatch.setattr(
+        _program_cache,
+        "_FILESTREAM_SCHEMA_VERSION",
+        f"{_program_cache._FILESTREAM_BACKEND_SCHEMA}.{_program_cache._KEY_SCHEMA_VERSION}",
+    )
+
+    with FileStreamProgramCache(root) as cache:
+        assert len(cache) == 0
+        assert b"k" not in cache
+    assert not path.exists()
+
+
+# ---------------------------------------------------------------------------
+# End-to-end: real NVRTC compilation through persistent cache
+# ---------------------------------------------------------------------------
+
+
+def test_cache_roundtrip_with_real_compilation(tmp_path, init_cuda):
+    """Compile a real kernel, persist its bytes, reopen the cache, and
+    reconstruct an ``ObjectCode`` from the cached bytes.
+
+    Exercises the full user workflow: NVRTC compile → persistent store →
+    fresh process (simulated by closing and reopening the cache handle)
+    → driver-side module load from an ObjectCode rebuilt from the
+    cached bytes.
+    """
+    from cuda.core import Program, ProgramOptions
+    from cuda.core._module import Kernel, ObjectCode
+    from cuda.core.utils import FileStreamProgramCache, make_program_cache_key
+
+    code = 'extern "C" __global__ void my_kernel() {}'
+    code_type = "c++"
+    target_type = "cubin"
+    options = ProgramOptions(name="cached_kernel")
+
+    program = Program(code, code_type, options=options)
+    try:
+        compiled = program.compile(target_type)
+    finally:
+        # Close the program whether or not compilation succeeded.
+        program.close()
+
+    # Key is derived from the same inputs that were handed to the compile above.
+    key = make_program_cache_key(
+        code=code,
+        code_type=code_type,
+        options=options,
+        target_type=target_type,
+    )
+
+    # First "process": compile and store the binary.
+    with FileStreamProgramCache(tmp_path / "fc") as cache:
+        assert key not in cache
+        cache[key] = compiled  # extracts bytes(compiled.code)
+
+    # Second "process": reopen, retrieve bytes, rebuild ObjectCode.
+    with FileStreamProgramCache(tmp_path / "fc") as cache:
+        assert key in cache
+        cached_bytes = cache[key]
+
+    assert cached_bytes == bytes(compiled.code)
+    rebuilt = ObjectCode._init(cached_bytes, target_type, name="cached_kernel")
+    # The reconstructed ObjectCode must still be usable against the driver.
+    assert isinstance(rebuilt.get_kernel("my_kernel"), Kernel)
+
+
+# ---------------------------------------------------------------------------
+# InMemoryProgramCache
+# ---------------------------------------------------------------------------
+
+
+def test_inmemory_cache_empty_on_create():
+    """A fresh cache has length 0, reports a missing key as absent and
+    raises ``KeyError`` on lookup."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache()
+    assert len(cache) == 0
+    assert b"nope" not in cache
+    with pytest.raises(KeyError):
+        cache[b"nope"]
+
+
+def test_inmemory_cache_roundtrip_object_code():
+    """A value built by ``_fake_object_code(b"v1", ...)`` reads back as the
+    bytes ``b"v1"``."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache()
+    cache[b"k1"] = _fake_object_code(b"v1", name="x")
+    assert b"k1" in cache
+    assert cache[b"k1"] == b"v1"
+
+
+def test_inmemory_cache_accepts_bytes_directly():
+    """A plain ``bytes`` value is stored and read back unchanged."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache()
+    cache[b"k"] = b"raw-payload"
+    assert cache[b"k"] == b"raw-payload"
+
+
+def test_inmemory_cache_accepts_bytearray_and_memoryview():
+    """``bytearray`` and ``memoryview`` values read back equal to the
+    corresponding ``bytes``."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache()
+    cache[b"a"] = bytearray(b"ba-payload")
+    cache[b"b"] = memoryview(b"mv-payload")
+    assert cache[b"a"] == b"ba-payload"
+    assert cache[b"b"] == b"mv-payload"
+
+
+def test_inmemory_cache_accepts_path_backed_object_code(tmp_path):
+    """Path-backed ObjectCode should be read at write time so the cache
+    holds the bytes, not a path -- mirrors FileStreamProgramCache."""
+    from cuda.core._module import ObjectCode
+    from cuda.core.utils import InMemoryProgramCache
+
+    payload = b"\x7fELF\x02\x01fake-cubin-bytes"
+    src = tmp_path / "src.cubin"
+    src.write_bytes(payload)
+    obj = ObjectCode.from_cubin(str(src), name="x")
+
+    cache = InMemoryProgramCache()
+    cache[b"k"] = obj
+    # If we mutate the source file, the cached entry must be unchanged.
+    src.write_bytes(b"changed")
+    assert cache[b"k"] == payload
+
+
+def test_inmemory_cache_str_and_bytes_keys_alias():
+    """An entry written under the ``str`` key ``"k"`` is reachable under
+    ``b"k"`` too, for both lookup and membership."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache()
+    cache["k"] = b"v"
+    assert cache[b"k"] == b"v"
+    assert b"k" in cache
+    assert "k" in cache
+
+
+def test_inmemory_cache_rejects_non_str_non_bytes_key():
+    """An ``int`` key raises ``TypeError`` on set, on get and on the
+    membership test."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache()
+    with pytest.raises(TypeError):
+        cache[123] = b"v"
+    with pytest.raises(TypeError):
+        cache[123]
+    with pytest.raises(TypeError):
+        123 in cache  # noqa: B015
+
+
+def test_inmemory_cache_rejects_non_bytes_non_object_code_value():
+    """Assigning a ``str`` value raises ``TypeError``."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache()
+    with pytest.raises(TypeError):
+        cache[b"k"] = "a string"
+
+
+def test_inmemory_cache_delete():
+    """``del`` removes the entry; deleting the same key again raises
+    ``KeyError``."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache()
+    cache[b"k"] = _fake_object_code()
+    del cache[b"k"]
+    assert b"k" not in cache
+    with pytest.raises(KeyError):
+        del cache[b"k"]
+
+
+def test_inmemory_cache_len_counts_all():
+    """``len`` reports the number of stored entries (three here)."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache()
+    cache[b"a"] = _fake_object_code(b"1")
+    cache[b"b"] = _fake_object_code(b"2")
+    cache[b"c"] = _fake_object_code(b"3")
+    assert len(cache) == 3
+
+
+def test_inmemory_cache_clear():
+    """After ``clear()`` the cache is empty and a previously stored key is
+    absent."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache()
+    cache[b"a"] = _fake_object_code()
+    cache[b"b"] = _fake_object_code()
+    cache.clear()
+    assert len(cache) == 0
+    assert b"a" not in cache
+
+
+def test_inmemory_cache_get_returns_default_on_miss():
+    """``get`` on a missing key returns ``None``, or the supplied default."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache()
+    assert cache.get(b"missing") is None
+    assert cache.get(b"missing", b"fallback") == b"fallback"
+
+
+def test_inmemory_cache_get_returns_value_on_hit():
+    """``get`` on a present key returns the stored bytes."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache()
+    cache[b"k"] = b"v"
+    assert cache.get(b"k") == b"v"
+
+
+def test_inmemory_cache_update_accepts_mapping_and_pairs():
+    """``update`` accepts both a mapping and an iterable of ``(key, value)``
+    pairs."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache()
+    cache.update({b"a": b"v-a", b"b": b"v-b"})
+    assert cache[b"a"] == b"v-a"
+    assert cache[b"b"] == b"v-b"
+
+    cache2 = InMemoryProgramCache()
+    cache2.update([(b"x", b"v-x"), (b"y", b"v-y")])
+    assert cache2[b"x"] == b"v-x"
+    assert cache2[b"y"] == b"v-y"
+
+
+def test_inmemory_cache_overwrite_replaces_value_and_updates_size():
+    """Re-assigning a key replaces its value, keeps ``len`` at 1, and leaves
+    ``_total_bytes`` equal to the size of the new value only."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache(max_size_bytes=1000)
+    cache[b"k"] = b"x" * 100
+    assert cache[b"k"] == b"x" * 100
+    cache[b"k"] = b"y" * 50
+    assert cache[b"k"] == b"y" * 50
+    assert len(cache) == 1
+    # Internal accounting should track the replacement, not double-count.
+    assert cache._total_bytes == 50
+
+
+def test_inmemory_cache_rejects_negative_size_cap():
+    """``max_size_bytes=-1`` raises ``ValueError``."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    with pytest.raises(ValueError):
+        InMemoryProgramCache(max_size_bytes=-1)
+
+
+def test_inmemory_cache_size_cap_evicts_oldest():
+    """With a 250-byte cap, a third 100-byte write evicts the first-written
+    entry and keeps the other two."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    # Cap fits two 100-byte entries; the third write evicts the LRU one.
+    cache = InMemoryProgramCache(max_size_bytes=250)
+    cache[b"a"] = b"a" * 100
+    cache[b"b"] = b"b" * 100
+    cache[b"c"] = b"c" * 100  # forces eviction of b"a"
+    assert b"a" not in cache
+    assert b"b" in cache
+    assert b"c" in cache
+
+
+def test_inmemory_cache_read_promotes_lru():
+    """Reading an entry shields it from the next eviction: after ``a`` is
+    read, the third write evicts ``b`` instead."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache(max_size_bytes=250)
+    cache[b"a"] = b"a" * 100
+    cache[b"b"] = b"b" * 100
+    # Read of a promotes it past b in LRU order.
+    _ = cache[b"a"]
+    cache[b"c"] = b"c" * 100  # b is now the least recently used entry, so b is evicted
+    assert b"a" in cache
+    assert b"b" not in cache
+    assert b"c" in cache
+
+
+def test_inmemory_cache_contains_does_not_promote_lru():
+    """``__contains__`` is read-only and must not shift LRU order, mirroring
+    FileStream semantics."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache(max_size_bytes=250)
+    cache[b"a"] = b"a" * 100
+    cache[b"b"] = b"b" * 100
+    # Membership check must NOT promote a in LRU.
+    assert b"a" in cache
+    cache[b"c"] = b"c" * 100  # a is still the oldest entry, so a is evicted
+    assert b"a" not in cache
+    assert b"b" in cache
+    assert b"c" in cache
+
+
+def test_inmemory_cache_oversized_write_evicts_itself():
+    """A single write larger than max_size_bytes does not survive its own
+    size-cap pass -- mirrors FileStreamProgramCache."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache(max_size_bytes=10)
+    cache[b"big"] = b"x" * 100
+    assert b"big" not in cache
+    assert len(cache) == 0
+
+
+def test_inmemory_cache_unbounded_when_max_size_none():
+    """With no ``max_size_bytes`` argument, 50 writes of 1024 bytes leave 50
+    entries."""
+    from cuda.core.utils import InMemoryProgramCache
+
+    cache = InMemoryProgramCache()
+    for i in range(50):
+        cache[f"k{i}".encode()] = b"x" * 1024
+    assert len(cache) == 50
diff --git a/cuda_core/tests/test_program_cache_multiprocess.py b/cuda_core/tests/test_program_cache_multiprocess.py
new file mode 100644
index 00000000000..a7a4038dd6e
--- /dev/null
+++ b/cuda_core/tests/test_program_cache_multiprocess.py
@@ -0,0 +1,174 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+"""Multiprocess stress tests for FileStreamProgramCache.
+
+These run without a GPU. They exercise the atomic-rename write path from
+multiple processes launched via ``multiprocessing.get_context("spawn")``.
+"""
+
+from __future__ import annotations
+
+import multiprocessing as _mp
+
+
+def _worker_write(root: str, key: bytes, payload: bytes) -> None:
+    """Child-process entry point: open the cache at ``root`` and store one entry."""
+    from cuda.core.utils import FileStreamProgramCache
+
+    with FileStreamProgramCache(root) as cache:
+        cache[key] = payload
+
+
+def _worker_write_many(root: str, base: int, n: int) -> None:
+    """Child-process entry point: write ``n`` entries whose keys and payloads
+    embed ``base``, so different workers never share a key."""
+    from cuda.core.utils import FileStreamProgramCache
+
+    with FileStreamProgramCache(root) as cache:
+        for i in range(n):
+            key = f"proc-{base}-key-{i}".encode()
+            cache[key] = f"payload-{base}-{i}".encode()
+
+
+def _worker_reader(root: str, key: bytes, rounds: int, result_queue) -> None:
+    """Child-process entry point: reopen the cache ``rounds`` times, look up
+    ``key`` each time, and put the number of hits on ``result_queue``."""
+    from cuda.core.utils import FileStreamProgramCache
+
+    hits = 0
+    for _ in range(rounds):
+        with FileStreamProgramCache(root) as cache:
+            got = cache.get(key)
+        if got is not None:
+            hits += 1
+    result_queue.put(hits)
+
+
+def _join_or_kill(proc, timeout: float = 60) -> int | None:
+    """Join ``proc``; if it is still alive after ``timeout`` seconds, kill and reap it.
+
+    A bare ``join(timeout=...)`` returns silently when the child overruns,
+    leaving it running behind a test that is about to fail. Reaping here
+    keeps a hung worker from outliving its test.
+
+    Returns ``proc.exitcode`` as observed after the final join. A worker that
+    had to be killed is expected to report a non-zero code, so callers'
+    ``== 0`` assertions still fail for it.
+    """
+    proc.join(timeout=timeout)
+    if proc.is_alive():
+        proc.kill()
+        proc.join()
+    return proc.exitcode
+
+
+def test_concurrent_writers_same_key_no_corruption(tmp_path):
+    """Six processes write different values to the same key; afterwards the
+    key must hold one of the written values."""
+    from cuda.core.utils import FileStreamProgramCache
+
+    root = str(tmp_path / "fc")
+    ctx = _mp.get_context("spawn")
+    procs = [
+        ctx.Process(
+            target=_worker_write,
+            args=(root, b"shared", f"v{i}".encode() * 64),
+        )
+        for i in range(6)
+    ]
+    for p in procs:
+        p.start()
+    # Join every worker (killing any that overrun) before asserting, so a
+    # failing assertion never leaves child processes running.
+    exit_codes = [_join_or_kill(p) for p in procs]
+    for code in exit_codes:
+        assert code == 0, f"worker exited with {code}"
+
+    with FileStreamProgramCache(root) as cache:
+        # At least one writer must have succeeded; on Windows some writes
+        # may silently fail due to PermissionError on os.replace.
+        got = cache.get(b"shared")
+        assert got is not None, "no writer succeeded"
+        assert got.startswith(b"v")
+
+
+def test_concurrent_writers_distinct_keys_all_survive(tmp_path):
+    """Four processes each write 25 distinct keys; all 100 keys must be
+    present afterwards."""
+    from cuda.core.utils import FileStreamProgramCache
+
+    root = str(tmp_path / "fc")
+    n_procs = 4
+    per_proc = 25
+    ctx = _mp.get_context("spawn")
+    procs = [ctx.Process(target=_worker_write_many, args=(root, base, per_proc)) for base in range(n_procs)]
+    for p in procs:
+        p.start()
+    # Reap all workers first, then check their exit codes.
+    exit_codes = [_join_or_kill(p) for p in procs]
+    for code in exit_codes:
+        assert code == 0
+
+    with FileStreamProgramCache(root) as cache:
+        for base in range(n_procs):
+            for i in range(per_proc):
+                key = f"proc-{base}-key-{i}".encode()
+                assert key in cache
+
+
+def test_concurrent_reader_never_sees_torn_file(tmp_path):
+    """A reader polling a seeded key while another process writes unrelated
+    keys must hit on every one of its 200 lookups."""
+    from cuda.core.utils import FileStreamProgramCache
+
+    root = str(tmp_path / "fc")
+    # Seed 'k' so the reader can hit; the writer writes unrelated keys so 'k'
+    # is never overwritten while the reader is active.
+    with FileStreamProgramCache(root) as cache:
+        cache[b"k"] = b"seed" * 256
+
+    ctx = _mp.get_context("spawn")
+    queue = ctx.Queue()
+    writer = ctx.Process(target=_worker_write_many, args=(root, 99, 50))
+    reader = ctx.Process(target=_worker_reader, args=(root, b"k", 200, queue))
+    reader.start()
+    writer.start()
+    # Reap both children before any assertion can abort the test.
+    writer_code = _join_or_kill(writer)
+    reader_code = _join_or_kill(reader)
+    assert writer_code == 0
+    assert reader_code == 0
+    hits = queue.get(timeout=5)
+    # 'k' was never overwritten, so every read must hit.
+    assert hits == 200
+
+
+def _worker_size_cap_writer(root: str, prefix: bytes, payload: bytes, count: int, max_size_bytes: int) -> None:
+    """Hammer a small-cap cache with churning writes so eviction fires often."""
+    from cuda.core.utils import FileStreamProgramCache
+
+    with FileStreamProgramCache(root, max_size_bytes=max_size_bytes) as cache:
+        for i in range(count):
+            cache[prefix + str(i).encode()] = payload
+
+
+def _worker_size_cap_rewriter(root: str, key: bytes, payload: bytes, max_size_bytes: int, done_event) -> None:
+    """Repeatedly rewrite ``key`` with a fresh value until ``done_event`` fires;
+    afterwards land one final uncontested write so the test's end-state assertion
+    isn't sensitive to scheduler-dependent interleaving."""
+    from cuda.core.utils import FileStreamProgramCache
+
+    with FileStreamProgramCache(root, max_size_bytes=max_size_bytes) as cache:
+        i = 0
+        while not done_event.is_set():
+            cache[key] = payload + str(i).encode()
+            i += 1
+        cache[key] = payload + b"final"
+
+
+def test_concurrent_eviction_does_not_delete_replaced_file(tmp_path):
+    """Eviction is stat-guarded: while one process is evicting an entry to
+    bring the cache under its size cap, another process may have already
+    ``os.replace``-d a fresh value into the same path. The evictor must
+    refuse to unlink in that case, otherwise the racing rewriter's
+    just-committed entry vanishes."""
+    from cuda.core.utils import FileStreamProgramCache
+
+    root = str(tmp_path / "fc")
+    payload = b"X" * 2000
+    cap = 5000  # fits 2 raw 2000B entries; the third write triggers eviction
+
+    ctx = _mp.get_context("spawn")
+    done_event = ctx.Event()
+    rewriter = ctx.Process(
+        target=_worker_size_cap_rewriter,
+        args=(root, b"survivor", payload, cap, done_event),
+    )
+    # Churning writer creates new keys faster than eviction can drain them,
+    # forcing _enforce_size_cap to consider 'survivor' for eviction many times.
+    churner = ctx.Process(
+        target=_worker_size_cap_writer,
+        args=(root, b"churn-", payload, 80, cap),
+    )
+    rewriter.start()
+    churner.start()
+    # Order matters: the rewriter only stops once the event is set, so the
+    # churner is reaped first, then the rewriter is released and reaped.
+    churner_code = _join_or_kill(churner)
+    done_event.set()
+    rewriter_code = _join_or_kill(rewriter)
+    assert rewriter_code == 0
+    assert churner_code == 0
+
+    # The rewriter's final uncontested write must survive: if eviction
+    # blindly unlinked replaced files, this entry would be gone.
+    with FileStreamProgramCache(root, max_size_bytes=cap) as cache:
+        got = cache.get(b"survivor")
+        assert got is not None, "rewriter's entry was evicted by racing churner"
+        assert got.endswith(b"final")
diff --git a/cuda_core/tests/test_program_compile_cache.py b/cuda_core/tests/test_program_compile_cache.py
new file mode 100644
index 00000000000..69f239a6f23
--- /dev/null
+++ b/cuda_core/tests/test_program_compile_cache.py
@@ -0,0 +1,277 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the ``Program.compile(cache=...)`` convenience integration."""
+
+from __future__ import annotations
+
+import pytest
+
+from cuda.core import Program, ProgramOptions
+from cuda.core import _program as _program_module
+from cuda.core._module import ObjectCode
+from cuda.core.utils import (
+    FileStreamProgramCache,
+    make_program_cache_key,
+)
+
+
+class _RecordingCache:
+    """Minimal recording stub for the bytes-in / bytes-out cache protocol.
+
+    Mirrors :class:`FileStreamProgramCache`'s contract: ``__setitem__``
+    accepts bytes-like or :class:`ObjectCode` (extracts bytes), and
+    ``get`` returns the stored bytes (or ``None``).
+
+    Intentionally does NOT subclass ``ProgramCacheResource`` -- the wrapper
+    should be duck-typed, so we test the duck-typed surface directly.
+    """
+
+    def __init__(self, preseed=None):
+        """Start empty, or pre-populated from the ``preseed`` mapping with
+        each value normalised through ``_extract``."""
+        self._store: dict[bytes, bytes] = {}
+        for k, v in (preseed or {}).items():
+            self._store[k] = self._extract(v)
+        # Call logs inspected by the tests.
+        self.get_calls: list[bytes] = []
+        self.set_calls: list[tuple[bytes, bytes]] = []
+        # When set, the corresponding method raises this after logging the call.
+        self.get_side_effect: BaseException | None = None
+        self.set_side_effect: BaseException | None = None
+
+    @staticmethod
+    def _extract(value) -> bytes:
+        """Return the raw bytes of ``value``; raise ``TypeError`` for any
+        type other than ``ObjectCode`` or a bytes-like object."""
+        if isinstance(value, ObjectCode):
+            return bytes(value.code)
+        if isinstance(value, (bytes, bytearray, memoryview)):
+            return bytes(value)
+        raise TypeError(f"unexpected value type: {type(value).__name__}")
+
+    def get(self, key, default=None):
+        """Log the lookup, raise ``get_side_effect`` if one is set, otherwise
+        return the stored bytes or ``default``."""
+        self.get_calls.append(key)
+        if self.get_side_effect is not None:
+            raise self.get_side_effect
+        return self._store.get(key, default)
+
+    def __setitem__(self, key, value):
+        """Log the write as ``(key, extracted bytes)``, raise
+        ``set_side_effect`` if one is set, otherwise store the bytes."""
+        data = self._extract(value)
+        self.set_calls.append((key, data))
+        if self.set_side_effect is not None:
+            raise self.set_side_effect
+        self._store[key] = data
+
+
+_KERNEL = 'extern "C" __global__ void k() {}'
+_SENTINEL_BYTES = b"sentinel-cubin-bytes"
+
+
+def _make_sentinel_object_code():
+    """Construct a cache-safe ``ObjectCode`` that doesn't require compilation."""
+    return ObjectCode._init(_SENTINEL_BYTES, "cubin", name="sentinel")
+
+
+def test_cache_miss_runs_compile_then_stores(monkeypatch):
+    """On cache miss: get(key) once, _program_compile_uncached once, __setitem__ once."""
+    sentinel = _make_sentinel_object_code()
+    program = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80"))
+
+    def _return_sentinel(_program, *_args, **_kwargs):
+        return sentinel
+
+    monkeypatch.setattr(_program_module, "_program_compile_uncached", _return_sentinel)
+    cache = _RecordingCache()
+
+    result = program.compile("cubin", cache=cache)
+
+    # On miss the wrapper returns the freshly-compiled ObjectCode unchanged.
+    assert result is sentinel
+    assert len(cache.get_calls) == 1
+    assert len(cache.set_calls) == 1
+    # The cache stored the binary bytes extracted from the ObjectCode.
+    assert cache.set_calls[0][1] == _SENTINEL_BYTES
+
+
+def test_cache_hit_returns_object_code_reconstructed_from_bytes(monkeypatch):
+    """A hit is served from the cached bytes, without compiling or writing.
+
+    ``get(key)`` yields bytes and the wrapper wraps them in a new
+    ``ObjectCode`` carrying the requested code type and the
+    ``ProgramOptions.name``; ``_program_compile_uncached`` is never invoked
+    and ``__setitem__`` is never called.
+    """
+    opts = ProgramOptions(arch="sm_80", name="my_program")
+    prog = Program(_KERNEL, "c++", opts)
+    expected_key = make_program_cache_key(code=_KERNEL, code_type="c++", options=opts, target_type="cubin")
+
+    def _must_not_compile(_program, *_args, **_kwargs):
+        raise AssertionError("_program_compile_uncached must not be called on cache hit")
+
+    monkeypatch.setattr(_program_module, "_program_compile_uncached", _must_not_compile)
+    recorder = _RecordingCache(preseed={expected_key: _SENTINEL_BYTES})
+
+    rebuilt = prog.compile("cubin", cache=recorder)
+
+    # The returned object is a reconstruction around the cached bytes.
+    assert isinstance(rebuilt, ObjectCode)
+    assert bytes(rebuilt.code) == _SENTINEL_BYTES
+    assert rebuilt.code_type == "cubin"
+    assert rebuilt.name == "my_program"
+    # Exactly one lookup with the expected key, and no write-back.
+    assert recorder.get_calls == [expected_key]
+    assert recorder.set_calls == []
+
+
+def test_name_expressions_affects_cache_key(monkeypatch):
+    """Two compiles that differ only in ``name_expressions`` look up two
+    different cache keys."""
+    compiled_stub = _make_sentinel_object_code()
+
+    def _fake_compile(_self, *_args, **_kwargs):
+        return compiled_stub
+
+    prog = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80"))
+    monkeypatch.setattr(_program_module, "_program_compile_uncached", _fake_compile)
+    recorder = _RecordingCache()
+
+    for expressions in (("foo",), ("foo", "bar")):
+        prog.compile("cubin", name_expressions=expressions, cache=recorder)
+
+    assert len(recorder.get_calls) == 2
+    first_key, second_key = recorder.get_calls
+    assert first_key != second_key
+
+
+def test_cache_raises_for_extra_digest_required_option():
+    """Options that require an ``extra_digest`` propagate a ValueError."""
+    opts = ProgramOptions(arch="sm_80", include_path=["/some/dir"])
+    prog = Program(_KERNEL, "c++", opts)
+    recorder = _RecordingCache()
+
+    with pytest.raises(ValueError, match="extra_digest"):
+        prog.compile("cubin", cache=recorder)
+
+    # The error surfaces before the cache is consulted or written.
+    assert recorder.get_calls == []
+    assert recorder.set_calls == []
+
+
+def test_cache_raises_for_side_effect_option(tmp_path):
+    """Options with compile-time side effects can't be cached."""
+    pch_target = str(tmp_path / "k.pch")
+    opts = ProgramOptions(arch="sm_80", create_pch=pch_target)
+    prog = Program(_KERNEL, "c++", opts)
+    recorder = _RecordingCache()
+
+    with pytest.raises(ValueError):
+        prog.compile("cubin", cache=recorder)
+
+    # Neither a lookup nor a store was attempted.
+    assert recorder.get_calls == []
+    assert recorder.set_calls == []
+
+
+def test_cache_miss_compile_failure_does_not_store(monkeypatch):
+    """A compile error after a miss propagates and nothing is written to the cache."""
+
+    def _failing_compile(_program, *_args, **_kwargs):
+        raise RuntimeError("compile failed")
+
+    prog = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80"))
+    monkeypatch.setattr(_program_module, "_program_compile_uncached", _failing_compile)
+    recorder = _RecordingCache()
+
+    with pytest.raises(RuntimeError, match="compile failed"):
+        prog.compile("cubin", cache=recorder)
+
+    # The miss was observed (one lookup) but no entry was stored.
+    assert len(recorder.get_calls) == 1
+    assert recorder.set_calls == []
+
+
+def test_cache_read_exception_propagates(monkeypatch):
+    """An error raised by ``cache.get`` reaches the caller and compilation is never attempted."""
+
+    def _must_not_compile(_program, *_args, **_kwargs):
+        raise AssertionError("_program_compile_uncached must not be called when get raises")
+
+    prog = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80"))
+    monkeypatch.setattr(_program_module, "_program_compile_uncached", _must_not_compile)
+    recorder = _RecordingCache()
+    recorder.get_side_effect = RuntimeError("broken")
+
+    with pytest.raises(RuntimeError, match="broken"):
+        prog.compile("cubin", cache=recorder)
+
+
+def test_cache_write_exception_propagates(monkeypatch):
+    """Exceptions from cache.__setitem__ propagate
after compile runs.""" + sentinel = _make_sentinel_object_code() + program = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80")) + monkeypatch.setattr( + _program_module, "_program_compile_uncached", lambda _self, *_args, **_kwargs: sentinel + ) + cache = _RecordingCache() + cache.set_side_effect = RuntimeError("disk full") + + with pytest.raises(RuntimeError, match="disk full"): + program.compile("cubin", cache=cache) + + assert len(cache.get_calls) == 1 + assert len(cache.set_calls) == 1 + + +def test_no_cache_kwarg_does_not_derive_key(monkeypatch): + """Without cache=, no cache-module functions run; compile goes straight through.""" + sentinel = _make_sentinel_object_code() + program = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80")) + monkeypatch.setattr( + _program_module, "_program_compile_uncached", lambda _self, *_args, **_kwargs: sentinel + ) + + # If the implementation accidentally derived a key, it would call + # make_program_cache_key. Replace it with a raising stub to catch that. 
+ from cuda.core.utils import _program_cache as _pc + + def _cache_path_must_not_run(*_args, **_kwargs): + raise AssertionError("cache path must not run when cache= is omitted") + + monkeypatch.setattr(_pc, "make_program_cache_key", _cache_path_must_not_run) + + result = program.compile("cubin") + + assert result is sentinel + + +def test_filestream_hit_returns_byte_equal_object_code(init_cuda, tmp_path): + """End-to-end: real compile, FileStreamProgramCache roundtrip, second + compile returns an ObjectCode whose bytes match the first compile.""" + program = Program(_KERNEL, "c++", ProgramOptions(arch="sm_80")) + cache_dir = tmp_path / "fc" + + with FileStreamProgramCache(cache_dir) as cache: + first = program.compile("cubin", cache=cache) + + with FileStreamProgramCache(cache_dir) as cache: + second = program.compile("cubin", cache=cache) + + assert bytes(second.code) == bytes(first.code) + assert second.code_type == first.code_type + + +def test_cache_kwarg_roundtrip_across_reopen(init_cuda, tmp_path, monkeypatch): + """Compile with cache= in one 'session', reopen and fetch via cache= again.""" + init_args = (_KERNEL, "c++", ProgramOptions(arch="sm_80", name="cached_kernel")) + cache_path = tmp_path / "fc" + + with FileStreamProgramCache(cache_path) as cache: + program = Program(*init_args) + first = program.compile("cubin", cache=cache) + + # Fresh process / fresh Program and cache-handle -- same cache path. + with FileStreamProgramCache(cache_path) as cache: + program = Program(*init_args) + + # If the reopened cache misses, the wrapper would fall back to + # _program_compile_uncached -- replace it with a raising stub so + # the test can only succeed via a hit. 
+ def _must_not_recompile(_program, *_args, **_kwargs): + raise AssertionError("cache miss: reopened cache didn't serve entry") + + monkeypatch.setattr(_program_module, "_program_compile_uncached", _must_not_recompile) + second = program.compile("cubin", cache=cache) + + assert bytes(second.code) == bytes(first.code) + assert second.name == "cached_kernel"