@@ -365,8 +365,7 @@ cdef class StridedMemoryView:
365365 if self .dl_tensor != NULL :
366366 self ._dtype = dtype_dlpack_to_numpy(& self .dl_tensor.dtype)
367367 elif self .metadata is not None :
368- # TODO: this only works for built-in numeric types
369- self ._dtype = _typestr2dtype[self .metadata[" typestr" ]]
368+ self ._dtype = _typestr2dtype(self .metadata[" typestr" ])
370369 return self ._dtype
371370
372371
@@ -486,25 +485,14 @@ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
486485 return buf
487486
488487
489- _builtin_numeric_dtypes = [
490- numpy.dtype(" uint8" ),
491- numpy.dtype(" uint16" ),
492- numpy.dtype(" uint32" ),
493- numpy.dtype(" uint64" ),
494- numpy.dtype(" int8" ),
495- numpy.dtype(" int16" ),
496- numpy.dtype(" int32" ),
497- numpy.dtype(" int64" ),
498- numpy.dtype(" float16" ),
499- numpy.dtype(" float32" ),
500- numpy.dtype(" float64" ),
501- numpy.dtype(" complex64" ),
502- numpy.dtype(" complex128" ),
503- numpy.dtype(" bool" ),
504- ]
505- # Doing it once to avoid repeated overhead
506- _typestr2dtype = {dtype.str: dtype for dtype in _builtin_numeric_dtypes}
507- _typestr2itemsize = {dtype.str: dtype.itemsize for dtype in _builtin_numeric_dtypes}
488+ @functools.lru_cache
489+ def _typestr2dtype (str typestr ):
490+ return numpy.dtype(typestr)
491+
492+
493+ @functools.lru_cache
494+ def _typestr2itemsize (str typestr ):
495+ return _typestr2dtype(typestr).itemsize
508496
509497
510498cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
@@ -664,7 +652,7 @@ cdef _StridedLayout layout_from_cai(object metadata):
664652 cdef _StridedLayout layout = _StridedLayout.__new__ (_StridedLayout)
665653 cdef object shape = metadata[" shape" ]
666654 cdef object strides = metadata.get(" strides" )
667- cdef int itemsize = _typestr2itemsize[ metadata[" typestr" ]]
655+ cdef int itemsize = _typestr2itemsize( metadata[" typestr" ])
668656 layout.init_from_tuple(shape, strides, itemsize, True )
669657 return layout
670658
0 commit comments