@@ -483,11 +483,12 @@ def make_program_cache_key(
483483 One of ``"ptx"``, ``"cubin"``, ``"ltoir"``.
484484 name_expressions:
485485 Optional iterable of mangled-name lookups. Order is not significant.
486- Elements may be ``str``, ``bytes``, or ``bytearray``; ``"foo"`` and
487- ``b"foo"`` produce distinct keys because ``Program.compile`` records
488- the original Python object as the ``ObjectCode.symbol_mapping`` key,
489- and ``get_kernel`` lookups must use the same type the cache key
490- recorded.
486+ Elements may be ``str`` or ``bytes``; ``"foo"`` and ``b"foo"`` produce
487+ distinct keys because ``Program.compile`` records the original Python
488+ object as the ``ObjectCode.symbol_mapping`` key, and ``get_kernel``
489+ lookups must use the same type the cache key recorded. ``bytearray``
490+ is rejected because ``Program.compile`` stores each element as a
491+ dict key and ``bytearray`` is unhashable.
491492 extra_digest:
492493 Caller-supplied bytes mixed into the key. Required whenever
493494 :class:`cuda.core.ProgramOptions` sets any option that pulls in
@@ -681,15 +682,25 @@ def make_program_cache_key(
681682 # because Program.compile records the original Python object as the
682683 # ObjectCode.symbol_mapping key (_program.pyx:759), so a cached
683684 # ObjectCode whose mapping-key type differs from what the caller's
684- # later ``get_kernel`` passes would silently miss.
685+ # later ``get_kernel`` passes would silently miss. Reject
686+ # ``bytearray`` because Program.compile also uses the raw element as a
687+ # dict key -- bytearray is unhashable, so a cache miss would compile
688+ # then crash in ``symbol_mapping[n] = ...``. Accepting it here would
689+ # let the cache serve hits for inputs the uncached path can't handle.
685690 if backend == "nvrtc" :
686691
687692 def _tag_name (n ):
688- if isinstance (n , ( bytes , bytearray ) ):
689- return b"b:" + bytes ( n )
693+ if isinstance (n , bytes ):
694+ return b"b:" + n
690695 if isinstance (n , str ):
691696 return b"s:" + n .encode ("utf-8" )
692- raise TypeError (f"name_expressions elements must be str, bytes, or bytearray; got { type (n ).__name__ } " )
697+ if isinstance (n , bytearray ):
698+ raise TypeError (
699+ "name_expressions elements must be str or bytes; "
700+ "bytearray is not accepted because Program.compile uses "
701+ "each element as a dict key and bytearray is unhashable."
702+ )
703+ raise TypeError (f"name_expressions elements must be str or bytes; got { type (n ).__name__ } " )
693704
694705 names = tuple (sorted (_tag_name (n ) for n in name_expressions ))
695706 else :
0 commit comments