Skip to content

Commit ca31ede

Browse files
committed
fix(core.utils): reject bytearray in NVRTC name_expressions cache keys
1 parent d2643af commit ca31ede

2 files changed

Lines changed: 30 additions & 9 deletions

File tree

cuda_core/cuda/core/utils/_program_cache.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -483,11 +483,12 @@ def make_program_cache_key(
483483
One of ``"ptx"``, ``"cubin"``, ``"ltoir"``.
484484
name_expressions:
485485
Optional iterable of mangled-name lookups. Order is not significant.
486-
Elements may be ``str``, ``bytes``, or ``bytearray``; ``"foo"`` and
487-
``b"foo"`` produce distinct keys because ``Program.compile`` records
488-
the original Python object as the ``ObjectCode.symbol_mapping`` key,
489-
and ``get_kernel`` lookups must use the same type the cache key
490-
recorded.
486+
Elements may be ``str`` or ``bytes``; ``"foo"`` and ``b"foo"`` produce
487+
distinct keys because ``Program.compile`` records the original Python
488+
object as the ``ObjectCode.symbol_mapping`` key, and ``get_kernel``
489+
lookups must use the same type the cache key recorded. ``bytearray``
490+
is rejected because ``Program.compile`` stores each element as a
491+
dict key and ``bytearray`` is unhashable.
491492
extra_digest:
492493
Caller-supplied bytes mixed into the key. Required whenever
493494
:class:`cuda.core.ProgramOptions` sets any option that pulls in
@@ -681,15 +682,25 @@ def make_program_cache_key(
681682
# because Program.compile records the original Python object as the
682683
# ObjectCode.symbol_mapping key (_program.pyx:759), so a cached
683684
# ObjectCode whose mapping-key type differs from what the caller's
684-
# later ``get_kernel`` passes would silently miss.
685+
# later ``get_kernel`` passes would silently miss. Reject
686+
# ``bytearray`` because Program.compile also uses the raw element as a
687+
# dict key -- bytearray is unhashable, so a cache miss would compile
688+
# then crash in ``symbol_mapping[n] = ...``. Accepting it here would
689+
# let the cache serve hits for inputs the uncached path can't handle.
685690
if backend == "nvrtc":
686691

687692
def _tag_name(n):
688-
if isinstance(n, (bytes, bytearray)):
689-
return b"b:" + bytes(n)
693+
if isinstance(n, bytes):
694+
return b"b:" + n
690695
if isinstance(n, str):
691696
return b"s:" + n.encode("utf-8")
692-
raise TypeError(f"name_expressions elements must be str, bytes, or bytearray; got {type(n).__name__}")
697+
if isinstance(n, bytearray):
698+
raise TypeError(
699+
"name_expressions elements must be str or bytes; "
700+
"bytearray is not accepted because Program.compile uses "
701+
"each element as a dict key and bytearray is unhashable."
702+
)
703+
raise TypeError(f"name_expressions elements must be str or bytes; got {type(n).__name__}")
693704

694705
names = tuple(sorted(_tag_name(n) for n in name_expressions))
695706
else:

cuda_core/tests/test_program_cache.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,16 @@ def test_make_program_cache_key_name_expressions_str_bytes_distinct():
306306
assert _make_key(name_expressions=("foo",)) != _make_key(name_expressions=(b"foo",))
307307

308308

309+
def test_make_program_cache_key_rejects_bytearray_in_name_expressions():
310+
"""``bytearray`` is unhashable, and ``Program.compile`` stores each
311+
element of ``name_expressions`` as a dict key
312+
(``symbol_mapping[n] = ...`` in ``_program.pyx``). Accepting it in the
313+
cache helper would mean hits served for inputs the uncached compile
314+
path crashes on -- so reject up front."""
315+
with pytest.raises(TypeError, match="bytearray"):
316+
_make_key(name_expressions=("ok", bytearray(b"bad")))
317+
318+
309319
@pytest.mark.parametrize(
310320
"code_type, target_type",
311321
[

0 commit comments

Comments
 (0)