Skip to content

Commit b4e3d1a

Browse files
feat: improve structured dtype array support in StridedMemoryView (NVIDIA#1425) (NVIDIA#1472)
Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
1 parent 7e3574b commit b4e3d1a

2 files changed

Lines changed: 27 additions & 22 deletions

File tree

cuda_core/cuda/core/_memoryview.pyx

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,7 @@ cdef class StridedMemoryView:
365365
if self.dl_tensor != NULL:
366366
self._dtype = dtype_dlpack_to_numpy(&self.dl_tensor.dtype)
367367
elif self.metadata is not None:
368-
# TODO: this only works for built-in numeric types
369-
self._dtype = _typestr2dtype[self.metadata["typestr"]]
368+
self._dtype = _typestr2dtype(self.metadata["typestr"])
370369
return self._dtype
371370

372371

@@ -486,25 +485,14 @@ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
486485
return buf
487486

488487

489-
_builtin_numeric_dtypes = [
490-
numpy.dtype("uint8"),
491-
numpy.dtype("uint16"),
492-
numpy.dtype("uint32"),
493-
numpy.dtype("uint64"),
494-
numpy.dtype("int8"),
495-
numpy.dtype("int16"),
496-
numpy.dtype("int32"),
497-
numpy.dtype("int64"),
498-
numpy.dtype("float16"),
499-
numpy.dtype("float32"),
500-
numpy.dtype("float64"),
501-
numpy.dtype("complex64"),
502-
numpy.dtype("complex128"),
503-
numpy.dtype("bool"),
504-
]
505-
# Doing it once to avoid repeated overhead
506-
_typestr2dtype = {dtype.str: dtype for dtype in _builtin_numeric_dtypes}
507-
_typestr2itemsize = {dtype.str: dtype.itemsize for dtype in _builtin_numeric_dtypes}
488+
@functools.lru_cache
489+
def _typestr2dtype(str typestr):
490+
return numpy.dtype(typestr)
491+
492+
493+
@functools.lru_cache
494+
def _typestr2itemsize(str typestr):
495+
return _typestr2dtype(typestr).itemsize
508496

509497

510498
cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
@@ -664,7 +652,7 @@ cdef _StridedLayout layout_from_cai(object metadata):
664652
cdef _StridedLayout layout = _StridedLayout.__new__(_StridedLayout)
665653
cdef object shape = metadata["shape"]
666654
cdef object strides = metadata.get("strides")
667-
cdef int itemsize = _typestr2itemsize[metadata["typestr"]]
655+
cdef int itemsize = _typestr2itemsize(metadata["typestr"])
668656
layout.init_from_tuple(shape, strides, itemsize, True)
669657
return layout
670658

cuda_core/tests/test_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,3 +446,20 @@ def test_from_buffer_with_non_power_of_two_itemsize():
446446
buffer = dev.memory_resource.allocate(required_size)
447447
view = StridedMemoryView.from_buffer(buffer, shape=shape, strides=layout.strides, dtype=dtype, is_readonly=True)
448448
assert view.dtype == dtype
449+
450+
451+
def test_struct_array():
452+
cp = pytest.importorskip("cupy")
453+
454+
x = np.array([(1.0, 2), (2.0, 3)], dtype=[("array1", np.float64), ("array2", np.int64)])
455+
456+
y = cp.empty(2, dtype=x.dtype)
457+
y.set(x)
458+
459+
smv = StridedMemoryView.from_cuda_array_interface(y, stream_ptr=0)
460+
assert smv.size * smv.dtype.itemsize == x.nbytes
461+
assert smv.size == x.size
462+
assert smv.shape == x.shape
463+
# full dtype information doesn't seem to be preserved due to use of type strings,
464+
# which are lossy, e.g., dtype([("a", "int")]).str == "V8"
465+
assert smv.dtype == np.dtype(f"V{x.itemsize}")

0 commit comments

Comments
 (0)