Skip to content

Commit 1b72df0

Browse files
authored
feat: add from_* style constructor classmethods to StridedMemoryView and deprecate __init__ constructor (NVIDIA#1250)
* feat: add from_* style constructor classmethods to StridedMemoryView and deprecate `__init__` constructor * chore: type the `buf` variable * chore: use DeprecationWarning * chore: match the pytest assertion with the new DeprecationWarning type
1 parent 07ba4ca commit 1b72df0

3 files changed

Lines changed: 94 additions & 30 deletions

File tree

cuda_core/cuda/core/experimental/_memoryview.pyx

Lines changed: 72 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ._dlpack cimport *
66

77
import functools
8+
import warnings
89
from typing import Optional
910

1011
import numpy
@@ -78,30 +79,78 @@ cdef class StridedMemoryView:
7879
bint readonly
7980
object exporting_obj
8081

81-
# If using dlpack, this is a strong reference to the result of
82-
# obj.__dlpack__() so we can lazily create shape and strides from
83-
# it later. If using CAI, this is a reference to the source
84-
# `__cuda_array_interface__` object.
85-
cdef object metadata
86-
87-
# The tensor object if has obj has __dlpack__, otherwise must be NULL
88-
cdef DLTensor *dl_tensor
89-
90-
# Memoized properties
91-
cdef tuple _shape
92-
cdef tuple _strides
93-
cdef bint _strides_init # Has the strides tuple been init'ed?
94-
cdef object _dtype
95-
96-
def __init__(self, obj=None, stream_ptr=None):
82+
cdef:
83+
# If using dlpack, this is a strong reference to the result of
84+
# obj.__dlpack__() so we can lazily create shape and strides from
85+
# it later. If using CAI, this is a reference to the source
86+
# `__cuda_array_interface__` object.
87+
object metadata
88+
89+
# The tensor object if has obj has __dlpack__, otherwise must be NULL
90+
DLTensor *dl_tensor
91+
92+
# Memoized properties
93+
tuple _shape
94+
tuple _strides
95+
# a `None` value for _strides has defined meaning in dlpack and
96+
# the cuda array interface, meaning C order, contiguous.
97+
#
98+
# this flag helps prevent unnecessary recompuation of _strides
99+
bint _strides_init
100+
object _dtype
101+
102+
def __init__(self, obj: object = None, stream_ptr: int | None = None) -> None:
103+
cdef str clsname = self.__class__.__name__
97104
if obj is not None:
98105
# populate self's attributes
99106
if check_has_dlpack(obj):
107+
warnings.warn(
108+
f"Constructing a {clsname} directly from a DLPack-supporting object is deprecated; "
109+
"Use `StridedMemoryView.from_dlpack` or `StridedMemoryView.from_any_interface` instead.",
110+
DeprecationWarning,
111+
stacklevel=2,
112+
)
100113
view_as_dlpack(obj, stream_ptr, self)
101114
else:
115+
warnings.warn(
116+
f"Constructing a {clsname} directly from a CUDA-array-interface-supporting object is deprecated; "
117+
"Use `StridedMemoryView.from_cuda_array_interface` or `StridedMemoryView.from_any_interface` instead.",
118+
DeprecationWarning,
119+
stacklevel=2,
120+
)
102121
view_as_cai(obj, stream_ptr, self)
103122
else:
104-
pass
123+
warnings.warn(
124+
f"Constructing an empty {clsname} is deprecated; "
125+
"use one of the classmethods `from_dlpack`, `from_cuda_array_interface` or `from_any_interface` "
126+
"to construct a StridedMemoryView from an object",
127+
DeprecationWarning,
128+
stacklevel=2,
129+
)
130+
131+
@classmethod
132+
def from_dlpack(cls, obj: object, stream_ptr: int | None=None) -> StridedMemoryView:
133+
cdef StridedMemoryView buf
134+
with warnings.catch_warnings():
135+
warnings.simplefilter("ignore")
136+
buf = cls()
137+
view_as_dlpack(obj, stream_ptr, buf)
138+
return buf
139+
140+
@classmethod
141+
def from_cuda_array_interface(cls, obj: object, stream_ptr: int | None=None) -> StridedMemoryView:
142+
cdef StridedMemoryView buf
143+
with warnings.catch_warnings():
144+
warnings.simplefilter("ignore")
145+
buf = cls()
146+
view_as_cai(obj, stream_ptr, buf)
147+
return buf
148+
149+
@classmethod
150+
def from_any_interface(cls, obj: object, stream_ptr: int | None = None) -> StridedMemoryView:
151+
if check_has_dlpack(obj):
152+
return cls.from_dlpack(obj, stream_ptr)
153+
return cls.from_cuda_array_interface(obj, stream_ptr)
105154

106155
def __dealloc__(self):
107156
if self.dl_tensor == NULL:
@@ -121,7 +170,7 @@ cdef class StridedMemoryView:
121170
dlm_tensor.deleter(dlm_tensor)
122171

123172
@property
124-
def shape(self) -> tuple[int]:
173+
def shape(self) -> tuple[int, ...]:
125174
if self._shape is None:
126175
if self.exporting_obj is not None:
127176
if self.dl_tensor != NULL:
@@ -136,7 +185,7 @@ cdef class StridedMemoryView:
136185
return self._shape
137186

138187
@property
139-
def strides(self) -> Optional[tuple[int]]:
188+
def strides(self) -> Optional[tuple[int, ...]]:
140189
cdef int itemsize
141190
if self._strides_init is False:
142191
if self.exporting_obj is not None:
@@ -193,6 +242,7 @@ cdef str get_simple_repr(obj):
193242
return obj_repr
194243

195244

245+
196246
cdef bint check_has_dlpack(obj) except*:
197247
cdef bint has_dlpack
198248
if hasattr(obj, "__dlpack__") and hasattr(obj, "__dlpack_device__"):
@@ -206,8 +256,7 @@ cdef bint check_has_dlpack(obj) except*:
206256

207257

208258
cdef class _StridedMemoryViewProxy:
209-
210-
cdef:
259+
cdef readonly:
211260
object obj
212261
bint has_dlpack
213262

@@ -217,9 +266,9 @@ cdef class _StridedMemoryViewProxy:
217266

218267
cpdef StridedMemoryView view(self, stream_ptr=None):
219268
if self.has_dlpack:
220-
return view_as_dlpack(self.obj, stream_ptr)
269+
return StridedMemoryView.from_dlpack(self.obj, stream_ptr)
221270
else:
222-
return view_as_cai(self.obj, stream_ptr)
271+
return StridedMemoryView.from_cuda_array_interface(self.obj, stream_ptr)
223272

224273

225274
cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
@@ -354,7 +403,6 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
354403
return numpy.dtype(np_dtype)
355404

356405

357-
# Also generate for Python so we can test this code path
358406
cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
359407
cdef dict cai_data = obj.__cuda_array_interface__
360408
if cai_data["version"] < 3:

cuda_core/tests/test_memory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -696,15 +696,15 @@ def test_strided_memory_view_leak():
696696
arr = np.zeros(1048576, dtype=np.uint8)
697697
before = sys.getrefcount(arr)
698698
for idx in range(10):
699-
StridedMemoryView(arr, stream_ptr=-1)
699+
StridedMemoryView.from_any_interface(arr, stream_ptr=-1)
700700
after = sys.getrefcount(arr)
701701
assert before == after
702702

703703

704704
def test_strided_memory_view_refcnt():
705705
# Use Fortran ordering so strides is used
706706
a = np.zeros((64, 4), dtype=np.uint8, order="F")
707-
av = StridedMemoryView(a, stream_ptr=-1)
707+
av = StridedMemoryView.from_any_interface(a, stream_ptr=-1)
708708
# segfaults if refcnt is wrong
709709
assert av.shape[0] == 64
710710
assert sys.getrefcount(av.shape) >= 2

cuda_core/tests/test_utils.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import numpy as np
1515
import pytest
1616
from cuda.core.experimental import Device
17-
from cuda.core.experimental._memoryview import view_as_cai
1817
from cuda.core.experimental.utils import StridedMemoryView, args_viewable_as_strided_memory
1918

2019

@@ -78,7 +77,13 @@ def my_func(arr):
7877

7978
def test_strided_memory_view_cpu(self, in_arr):
8079
# stream_ptr=-1 means "the consumer does not care"
81-
view = StridedMemoryView(in_arr, stream_ptr=-1)
80+
view = StridedMemoryView.from_any_interface(in_arr, stream_ptr=-1)
81+
self._check_view(view, in_arr)
82+
83+
def test_strided_memory_view_cpu_init(self, in_arr):
84+
# stream_ptr=-1 means "the consumer does not care"
85+
with pytest.deprecated_call(match="deprecated"):
86+
view = StridedMemoryView(in_arr, stream_ptr=-1)
8287
self._check_view(view, in_arr)
8388

8489
def _check_view(self, view, in_arr):
@@ -147,7 +152,18 @@ def test_strided_memory_view_cpu(self, in_arr, use_stream):
147152
# This is the consumer stream
148153
s = dev.create_stream() if use_stream else None
149154

150-
view = StridedMemoryView(in_arr, stream_ptr=s.handle if s else -1)
155+
view = StridedMemoryView.from_any_interface(in_arr, stream_ptr=s.handle if s else -1)
156+
self._check_view(view, in_arr, dev)
157+
158+
def test_strided_memory_view_init(self, in_arr, use_stream):
159+
# TODO: use the device fixture?
160+
dev = Device()
161+
dev.set_current()
162+
# This is the consumer stream
163+
s = dev.create_stream() if use_stream else None
164+
165+
with pytest.deprecated_call(match="deprecated"):
166+
view = StridedMemoryView(in_arr, stream_ptr=s.handle if s else -1)
151167
self._check_view(view, in_arr, dev)
152168

153169
def _check_view(self, view, in_arr, dev):
@@ -179,7 +195,7 @@ def test_cuda_array_interface_gpu(self, in_arr, use_stream):
179195
# The usual path in `StridedMemoryView` prefers the DLPack interface
180196
# over __cuda_array_interface__, so we call `view_as_cai` directly
181197
# here so we can test the CAI code path.
182-
view = view_as_cai(in_arr, stream_ptr=s.handle if s else -1)
198+
view = StridedMemoryView.from_cuda_array_interface(in_arr, stream_ptr=s.handle if s else -1)
183199
self._check_view(view, in_arr, dev)
184200

185201
def _check_view(self, view, in_arr, dev):

0 commit comments

Comments
 (0)