Skip to content

Commit 9b301ba

Browse files
authored
feat: allow constructing SMV from numpy arrays (#1428)
* feat: add `StridedMemoryView.from_array_interface` * test: add test demonstrating the lazy failure with invalid strides * fix: ignore specific warnings in `StridedMemoryView.from_*` APIs * docs: add comment about `DeprecationWarning`
1 parent 1dc914a commit 9b301ba

3 files changed

Lines changed: 95 additions & 3 deletions

File tree

cuda_core/cuda/core/_memoryview.pyx

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,9 @@ cdef class StridedMemoryView:
139139
def from_dlpack(cls, obj: object, stream_ptr: int | None=None) -> StridedMemoryView:
140140
cdef StridedMemoryView buf
141141
with warnings.catch_warnings():
142-
warnings.simplefilter("ignore")
142+
# ignore the warning triggered by calling the constructor
143+
# inside the library we're allowed to do this
144+
warnings.simplefilter("ignore", DeprecationWarning)
143145
buf = cls()
144146
view_as_dlpack(obj, stream_ptr, buf)
145147
return buf
@@ -148,11 +150,20 @@ cdef class StridedMemoryView:
148150
def from_cuda_array_interface(cls, obj: object, stream_ptr: int | None=None) -> StridedMemoryView:
149151
cdef StridedMemoryView buf
150152
with warnings.catch_warnings():
151-
warnings.simplefilter("ignore")
153+
warnings.simplefilter("ignore", DeprecationWarning)
152154
buf = cls()
153155
view_as_cai(obj, stream_ptr, buf)
154156
return buf
155157

158+
@classmethod
159+
def from_array_interface(cls, obj: object) -> StridedMemoryView:
160+
cdef StridedMemoryView buf
161+
with warnings.catch_warnings():
162+
warnings.simplefilter("ignore", DeprecationWarning)
163+
buf = cls()
164+
view_as_array_interface(obj, buf)
165+
return buf
166+
156167
@classmethod
157168
def from_any_interface(cls, obj: object, stream_ptr: int | None = None) -> StridedMemoryView:
158169
if check_has_dlpack(obj):
@@ -597,6 +608,23 @@ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
597608
return buf
598609

599610

611+
cpdef StridedMemoryView view_as_array_interface(obj, view=None):
612+
cdef dict data = obj.__array_interface__
613+
if data["version"] < 3:
614+
raise BufferError("only NumPy Array Interface v3 or above is supported")
615+
if data.get("mask") is not None:
616+
raise BufferError("mask is not supported")
617+
618+
cdef StridedMemoryView buf = StridedMemoryView() if view is None else view
619+
buf.exporting_obj = obj
620+
buf.metadata = data
621+
buf.dl_tensor = NULL
622+
buf.ptr, buf.readonly = data["data"]
623+
buf.is_device_accessible = False
624+
buf.device_id = handle_return(driver.cuCtxGetDevice())
625+
return buf
626+
627+
600628
def args_viewable_as_strided_memory(tuple arg_indices):
601629
"""
602630
Decorator to create proxy objects to :obj:`StridedMemoryView` for the

cuda_core/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def init_cuda():
7272
driver.cuDevicePrimaryCtxSetFlags(device.device_id, driver.CUctx_flags.CU_CTX_SCHED_BLOCKING_SYNC)
7373
)
7474

75-
yield
75+
yield device
7676
_ = _device_unset_current()
7777

7878

cuda_core/tests/test_utils.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from cuda.core import Device
1919
from cuda.core._layout import _StridedLayout
2020
from cuda.core.utils import StridedMemoryView, args_viewable_as_strided_memory
21+
from pytest import param
2122

2223

2324
def test_cast_to_3_tuple_success():
@@ -460,3 +461,66 @@ def test_struct_array():
460461
# full dtype information doesn't seem to be preserved due to use of type strings,
461462
# which are lossy, e.g., dtype([("a", "int")]).str == "V8"
462463
assert smv.dtype == np.dtype(f"V{x.itemsize}")
464+
465+
466+
@pytest.mark.parametrize(
467+
("x", "expected_dtype"),
468+
[
469+
# 1D arrays with different dtypes
470+
param(np.array([1, 2, 3], dtype=np.int32), "int32", id="1d-int32"),
471+
param(np.array([1.0, 2.0, 3.0], dtype=np.float64), "float64", id="1d-float64"),
472+
param(np.array([1 + 2j, 3 + 4j], dtype=np.complex128), "complex128", id="1d-complex128"),
473+
param(np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64), "complex64", id="1d-complex64"),
474+
param(np.array([1, 2, 3, 4, 5], dtype=np.uint8), "uint8", id="1d-uint8"),
475+
param(np.array([1, 2], dtype=np.int64), "int64", id="1d-int64"),
476+
param(np.array([100, 200, 300], dtype=np.int16), "int16", id="1d-int16"),
477+
param(np.array([1000, 2000, 3000], dtype=np.uint16), "uint16", id="1d-uint16"),
478+
param(np.array([10000, 20000, 30000], dtype=np.uint64), "uint64", id="1d-uint64"),
479+
# 2D arrays - C-contiguous
480+
param(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32), "int32", id="2d-c-int32"),
481+
param(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), "float32", id="2d-c-float32"),
482+
# 2D arrays - Fortran-contiguous
483+
param(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32, order="F"), "int32", id="2d-f-int32"),
484+
param(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64, order="F"), "float64", id="2d-f-float64"),
485+
# 3D arrays
486+
param(np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int32), "int32", id="3d-int32"),
487+
param(np.ones((2, 3, 4), dtype=np.float64), "float64", id="3d-float64"),
488+
# Sliced/strided arrays
489+
param(np.array([1, 2, 3, 4, 5, 6], dtype=np.int32)[::2], "int32", id="1d-strided-int32"),
490+
param(np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.float64)[:, ::2], "float64", id="2d-strided-float64"),
491+
param(np.arange(20, dtype=np.int32).reshape(4, 5)[::2, ::2], "int32", id="2d-strided-2x2-int32"),
492+
# Scalar (0-D array)
493+
param(np.array(42, dtype=np.int32), "int32", id="scalar-int32"),
494+
param(np.array(3.14, dtype=np.float64), "float64", id="scalar-float64"),
495+
# Empty arrays
496+
param(np.array([], dtype=np.int32), "int32", id="empty-1d-int32"),
497+
param(np.empty((0, 3), dtype=np.float64), "float64", id="empty-2d-float64"),
498+
# Single element
499+
param(np.array([1], dtype=np.int32), "int32", id="single-element"),
500+
# Structured dtype
501+
param(np.array([(1, 2.0), (3, 4.0)], dtype=[("a", "i4"), ("b", "f8")]), "V12", id="structured-dtype"),
502+
],
503+
)
504+
def test_from_array_interface(x, init_cuda, expected_dtype):
505+
smv = StridedMemoryView.from_array_interface(x)
506+
assert smv.size == x.size
507+
assert smv.dtype == np.dtype(expected_dtype)
508+
assert smv.shape == x.shape
509+
assert smv.ptr == x.ctypes.data
510+
assert smv.device_id == init_cuda.device_id
511+
assert smv.is_device_accessible is False
512+
assert smv.exporting_obj is x
513+
assert smv.readonly is not x.flags.writeable
514+
# Check strides
515+
strides_in_counts = convert_strides_to_counts(x.strides, x.dtype.itemsize)
516+
assert (x.flags.c_contiguous and smv.strides is None) or smv.strides == strides_in_counts
517+
518+
519+
def test_from_array_interface_unsupported_strides(init_cuda):
520+
# Create an array with strides that aren't a multiple of itemsize
521+
x = np.array([(1, 2.0), (3, 4.0)], dtype=[("a", "i4"), ("b", "f8")])
522+
b = x["b"]
523+
smv = StridedMemoryView.from_array_interface(b)
524+
with pytest.raises(ValueError, match="strides must be divisible by itemsize"):
525+
# TODO: ideally this would raise on construction
526+
smv.strides # noqa: B018

0 commit comments

Comments
 (0)