Skip to content

Commit 605241e

Browse files
committed
ensure allocation stream can be used for deallocation; suppress casting warning
1 parent bf712bf commit 605241e

File tree

1 file changed

+66
-24
lines changed

1 file changed

+66
-24
lines changed

cuda_core/cuda/core/experimental/_memory.pyx

Lines changed: 66 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,29 @@ if TYPE_CHECKING:
4242
# MemoryResource both inherit from it
4343

4444

45+
cdef extern from * nogil:
46+
"""
47+
#if defined(__GNUC__)
48+
#pragma GCC diagnostic push
49+
#pragma GCC diagnostic ignored "-Wint-to-pointer-cast"
50+
#elif defined(_MSC_VER)
51+
#pragma warning(push)
52+
#pragma warning(disable : 4312)
53+
#endif
54+
55+
void* unsafe_cast_from_int(int x) {
56+
return (void*)x;
57+
}
58+
59+
#if defined(__GNUC__)
60+
#pragma GCC diagnostic pop
61+
#elif defined(_MSC_VER)
62+
#pragma warning(pop)
63+
#endif
64+
"""
65+
void* unsafe_cast_from_int(int x)
66+
67+
4568
PyCapsule = TypeVar("PyCapsule")
4669
"""Represent the capsule type."""
4770

@@ -64,23 +87,33 @@ cdef class Buffer:
6487
size_t _size
6588
object _mr
6689
object _ptr_obj
90+
cyStream _alloc_stream
91+
92+
def __cinit__(self):
93+
self._ptr = 0
94+
self._size = 0
95+
self._mr = None
96+
self._ptr_obj = None
97+
self._alloc_stream = None
6798

6899
def __init__(self, *args, **kwargs):
69100
raise RuntimeError("Buffer objects cannot be instantiated directly. Please use MemoryResource APIs.")
70101

71102
@classmethod
72-
def _init(cls, ptr: DevicePointerT, size_t size, mr: MemoryResource | None = None):
103+
def _init(cls, ptr: DevicePointerT, size_t size, mr: MemoryResource | None = None, stream: Stream | None = None):
73104
cdef Buffer self = Buffer.__new__(cls)
74105
self._ptr = <intptr_t>(int(ptr))
75106
self._ptr_obj = ptr
76107
self._size = size
77108
self._mr = mr
109+
self._alloc_stream = <cyStream>(stream) if stream is not None else None
78110
return self
79111

80112
def __dealloc__(self):
81-
self.close()
113+
self.close(self._alloc_stream)
82114

83115
def __reduce__(self):
116+
# Must not serialize the parent's stream!
84117
return Buffer.from_ipc_descriptor, (self.memory_resource, self.get_ipc_descriptor())
85118

86119
cpdef close(self, stream: Stream = None):
@@ -95,18 +128,23 @@ cdef class Buffer:
95128
The stream object to use for asynchronous deallocation. If None,
96129
the behavior depends on the underlying memory resource.
97130
"""
131+
cdef cyStream s
98132
if self._ptr and self._mr is not None:
99133
if isinstance(self._mr, _cyMemoryResource):
100-
# FIXME
101-
if stream is None:
102-
stream = Stream.__new__(Stream)
103-
(<cyStream>(stream))._handle = <cydriver.CUstream>(0)
104-
(<_cyMemoryResource>(self._mr))._deallocate(self._ptr, self._size, <cyStream>stream)
134+
s = self._alloc_stream if stream is None else <cyStream>stream
135+
(<_cyMemoryResource>(self._mr))._deallocate(self._ptr, self._size, s)
105136
else:
137+
if stream is None:
138+
if self._alloc_stream is not None:
139+
stream = self._alloc_stream
140+
else:
141+
# TODO: remove this branch when from_handle takes a stream
142+
stream = default_stream()
106143
self._mr.deallocate(self._ptr, self._size, stream)
107144
self._ptr = 0
108145
self._mr = None
109146
self._ptr_obj = None
147+
self._alloc_stream = None
110148

111149
@property
112150
def handle(self) -> DevicePointerT:
@@ -167,16 +205,19 @@ cdef class Buffer:
167205
return IPCBufferDescriptor._init(data_b, self.size)
168206

169207
@classmethod
170-
def from_ipc_descriptor(cls, mr: DeviceMemoryResource, ipc_buffer: IPCBufferDescriptor) -> Buffer:
208+
def from_ipc_descriptor(cls, mr: DeviceMemoryResource, ipc_buffer: IPCBufferDescriptor, stream: Stream = None) -> Buffer:
171209
"""Import a buffer that was exported from another process."""
172210
if not mr.is_ipc_enabled:
173211
raise RuntimeError("Memory resource is not IPC-enabled")
212+
if stream is None:
213+
# Note: match this behavior to DeviceMemoryResource.allocate()
214+
stream = default_stream()
174215
cdef cydriver.CUmemPoolPtrExportData share_data
175216
memcpy(share_data.reserved, <const void*><const char*>(ipc_buffer._reserved), sizeof(share_data.reserved))
176217
cdef cydriver.CUdeviceptr ptr
177218
with nogil:
178219
HANDLE_RETURN(cydriver.cuMemPoolImportPointer(&ptr, mr._mempool_handle, &share_data))
179-
return Buffer.from_handle(<intptr_t>ptr, ipc_buffer.size, mr)
220+
return Buffer._init(<intptr_t>ptr, ipc_buffer.size, mr, stream)
180221

181222
def copy_to(self, dst: Buffer = None, *, stream: Stream) -> Buffer:
182223
"""Copy from this buffer to the dst buffer asynchronously on the given stream.
@@ -297,6 +338,7 @@ cdef class Buffer:
297338
mr : :obj:`~_memory.MemoryResource`, optional
298339
Memory resource associated with the buffer
299340
"""
341+
# TODO: It is better to take a stream for latter deallocation
300342
return Buffer._init(ptr, size, mr=mr)
301343

302344

@@ -839,7 +881,7 @@ cdef class DeviceMemoryResource(_cyMemoryResource, MemoryResource):
839881
cdef int handle = int(alloc_handle)
840882
with nogil:
841883
HANDLE_RETURN(cydriver.cuMemPoolImportFromShareableHandle(
842-
&(self._mempool_handle), <void*>handle, _IPC_HANDLE_TYPE, 0)
884+
&(self._mempool_handle), unsafe_cast_from_int(handle), _IPC_HANDLE_TYPE, 0)
843885
)
844886
if uuid is not None:
845887
registered = self.register(uuid)
@@ -889,6 +931,7 @@ cdef class DeviceMemoryResource(_cyMemoryResource, MemoryResource):
889931
buf._ptr_obj = None
890932
buf._size = size
891933
buf._mr = self
934+
buf._alloc_stream = stream
892935
return buf
893936

894937
def allocate(self, size_t size, stream: Stream = None) -> Buffer:
@@ -921,7 +964,7 @@ cdef class DeviceMemoryResource(_cyMemoryResource, MemoryResource):
921964
HANDLE_RETURN(cydriver.cuMemFreeAsync(devptr, s))
922965
return 0
923966

924-
def deallocate(self, ptr: DevicePointerT, size_t size, stream: Stream = None):
967+
def deallocate(self, ptr: DevicePointerT, size_t size, stream: Stream):
925968
"""Deallocate a buffer previously allocated by this resource.
926969
927970
Parameters
@@ -932,10 +975,9 @@ cdef class DeviceMemoryResource(_cyMemoryResource, MemoryResource):
932975
The size of the buffer to deallocate, in bytes.
933976
stream : Stream, optional
934977
The stream on which to perform the deallocation asynchronously.
935-
If None, an internal stream is used.
978+
If the buffer is deallocated without an explicit stream, the allocation stream
979+
is used.
936980
"""
937-
if stream is None:
938-
stream = default_stream()
939981
self._deallocate(<intptr_t>ptr, size, <cyStream>stream)
940982

941983
@property
@@ -1018,11 +1060,13 @@ class LegacyPinnedMemoryResource(MemoryResource):
10181060
Buffer
10191061
The allocated buffer object, which is accessible on both host and device.
10201062
"""
1063+
if stream is None:
1064+
stream = default_stream()
10211065
err, ptr = driver.cuMemAllocHost(size)
10221066
raise_if_driver_error(err)
1023-
return Buffer._init(ptr, size, self)
1067+
return Buffer._init(ptr, size, self, stream)
10241068

1025-
def deallocate(self, ptr: DevicePointerT, size_t size, stream: Stream = None):
1069+
def deallocate(self, ptr: DevicePointerT, size_t size, stream: Stream):
10261070
"""Deallocate a buffer previously allocated by this resource.
10271071
10281072
Parameters
@@ -1031,12 +1075,10 @@ class LegacyPinnedMemoryResource(MemoryResource):
10311075
The pointer or handle to the buffer to deallocate.
10321076
size : int
10331077
The size of the buffer to deallocate, in bytes.
1034-
stream : Stream, optional
1035-
The stream on which to perform the deallocation asynchronously.
1036-
If None, no synchronization would happen.
1078+
stream : Stream
1079+
The stream on which to perform the deallocation synchronously.
10371080
"""
1038-
if stream:
1039-
stream.sync()
1081+
stream.sync()
10401082
err, = driver.cuMemFreeHost(ptr)
10411083
raise_if_driver_error(err)
10421084

@@ -1064,13 +1106,13 @@ class _SynchronousMemoryResource(MemoryResource):
10641106
self._dev_id = getattr(device_id, 'device_id', device_id)
10651107

10661108
def allocate(self, size, stream=None) -> Buffer:
1109+
if stream is None:
1110+
stream = default_stream()
10671111
err, ptr = driver.cuMemAlloc(size)
10681112
raise_if_driver_error(err)
10691113
return Buffer._init(ptr, size, self)
10701114

1071-
def deallocate(self, ptr, size, stream=None):
1072-
if stream is None:
1073-
stream = default_stream()
1115+
def deallocate(self, ptr, size, stream):
10741116
stream.sync()
10751117
err, = driver.cuMemFree(ptr)
10761118
raise_if_driver_error(err)

0 commit comments

Comments
 (0)