@@ -42,6 +42,29 @@ if TYPE_CHECKING:
4242# MemoryResource both inherit from it
4343
4444
cdef extern from * nogil:
    """
    #include <stdint.h>

    /* Turn an integer handle into the opaque void* expected by the driver
     * API.  Widening through intptr_t first makes the integer the same width
     * as a pointer before the cast, so neither GCC (-Wint-to-pointer-cast)
     * nor MSVC (C4312) has anything to warn about and no compiler-specific
     * pragmas are required.  `static` keeps the helper local to this
     * translation unit so it cannot clash with a same-named symbol elsewhere.
     */
    static void* unsafe_cast_from_int(int x) {
        return (void*)(intptr_t)x;
    }
    """
    # Cython-side declaration of the helper defined in the verbatim C above.
    void* unsafe_cast_from_int(int x)
66+
67+
# Type variable named after PyCapsule; per the string below it stands for the
# capsule type. NOTE(review): its usage sites are outside this hunk.
4568PyCapsule = TypeVar(" PyCapsule" )
4669""" Represent the capsule type."""
4770
@@ -64,23 +87,33 @@ cdef class Buffer:
6487 size_t _size
6588 object _mr
6689 object _ptr_obj
90+ cyStream _alloc_stream
91+
def __cinit__(self):
    # Start every instance in the "closed" state: no memory resource, no
    # pointer object, no allocation stream, and a zero pointer/size. The
    # factory paths (Buffer._init, the memory resources) populate the fields
    # afterwards.
    self._mr = None
    self._ptr_obj = None
    self._alloc_stream = None
    self._size = 0
    self._ptr = 0
6798
# Direct construction is disallowed: this always raises RuntimeError, and the
# message points callers at the MemoryResource APIs. Instances are created
# through Buffer.__new__ in Buffer._init instead.
6899 def __init__ (self , *args , **kwargs ):
69100 raise RuntimeError (" Buffer objects cannot be instantiated directly. Please use MemoryResource APIs." )
70101
@classmethod
def _init(cls, ptr: DevicePointerT, size_t size, mr: MemoryResource | None = None, stream: Stream | None = None):
    """Internal factory that builds a Buffer around an existing pointer.

    Parameters
    ----------
    ptr : DevicePointerT
        Pointer object; it is kept as-is and also converted with ``int()``
        to obtain the raw address.
    size : int
        Size of the buffer in bytes.
    mr : MemoryResource, optional
        Memory resource that ``close()`` uses to deallocate the buffer.
    stream : Stream, optional
        Stream remembered as the allocation stream; ``close()`` falls back
        to it when called without an explicit stream.

    Returns
    -------
    Buffer
        The newly created buffer (``__init__`` is bypassed via ``__new__``).
    """
    cdef Buffer buf = Buffer.__new__(cls)
    buf._ptr = <intptr_t>(int(ptr))
    buf._ptr_obj = ptr
    buf._size = size
    buf._mr = mr
    if stream is None:
        buf._alloc_stream = None
    else:
        buf._alloc_stream = <cyStream>(stream)
    return buf
79111
# Release the allocation when the object is destroyed by delegating to close().
# The new line passes the recorded allocation stream explicitly; close() would
# also fall back to self._alloc_stream when given None.
# NOTE(review): Cython's documentation advises against calling methods on the
# object from __dealloc__ because Python-object attributes may already have
# been cleared at that point -- confirm _mr/_alloc_stream are still valid here.
80112 def __dealloc__ (self ):
81- self .close()
113+ self .close(self ._alloc_stream )
82114
# Pickle support: the object is rebuilt by calling
# Buffer.from_ipc_descriptor(memory_resource, ipc_descriptor).
# No stream is included in the arguments, so from_ipc_descriptor falls back to
# default_stream() when the buffer is reconstructed.
82114 def __reduce__ (self ):
116+ # Must not serialize the parent's stream!
83117 return Buffer.from_ipc_descriptor, (self .memory_resource, self .get_ipc_descriptor())
85118
86119 cpdef close(self , stream: Stream = None ):
@@ -95,18 +128,23 @@ cdef class Buffer:
95128 The stream object to use for asynchronous deallocation. If None,
96129 the behavior depends on the underlying memory resource.
97130 """
131+ cdef cyStream s
98132 if self ._ptr and self ._mr is not None :
99133 if isinstance (self ._mr, _cyMemoryResource):
100- # FIXME
101- if stream is None :
102- stream = Stream.__new__ (Stream)
103- (< cyStream> (stream))._handle = < cydriver.CUstream> (0 )
104- (< _cyMemoryResource> (self ._mr))._deallocate(self ._ptr, self ._size, < cyStream> stream)
134+ s = self ._alloc_stream if stream is None else < cyStream> stream
135+ (< _cyMemoryResource> (self ._mr))._deallocate(self ._ptr, self ._size, s)
105136 else :
137+ if stream is None :
138+ if self ._alloc_stream is not None :
139+ stream = self ._alloc_stream
140+ else :
141+ # TODO: remove this branch when from_handle takes a stream
142+ stream = default_stream()
106143 self ._mr.deallocate(self ._ptr, self ._size, stream)
107144 self ._ptr = 0
108145 self ._mr = None
109146 self ._ptr_obj = None
147+ self ._alloc_stream = None
110148
111149 @property
112150 def handle (self ) -> DevicePointerT:
@@ -167,16 +205,19 @@ cdef class Buffer:
167205 return IPCBufferDescriptor._init(data_b , self.size )
168206
@classmethod
def from_ipc_descriptor(cls, mr: DeviceMemoryResource, ipc_buffer: IPCBufferDescriptor, stream: Stream = None) -> Buffer:
    """Reconstruct a buffer that another process exported.

    Parameters
    ----------
    mr : DeviceMemoryResource
        IPC-enabled memory resource whose pool handle is used for the import.
    ipc_buffer : IPCBufferDescriptor
        Descriptor carrying the exported pointer data and the buffer size.
    stream : Stream, optional
        Stream recorded on the returned buffer for its later deallocation.
        When omitted, ``default_stream()`` is used.

    Returns
    -------
    Buffer
        A buffer wrapping the imported device pointer.

    Raises
    ------
    RuntimeError
        If ``mr`` is not IPC-enabled.
    """
    cdef cydriver.CUmemPoolPtrExportData export_data
    cdef cydriver.CUdeviceptr dptr

    if not mr.is_ipc_enabled:
        raise RuntimeError("Memory resource is not IPC-enabled")
    # Note: match this behavior to DeviceMemoryResource.allocate()
    if stream is None:
        stream = default_stream()

    # Copy the opaque export payload from the descriptor into the driver struct.
    memcpy(export_data.reserved, <const void*><const char*>(ipc_buffer._reserved), sizeof(export_data.reserved))
    with nogil:
        HANDLE_RETURN(cydriver.cuMemPoolImportPointer(&dptr, mr._mempool_handle, &export_data))
    return Buffer._init(<intptr_t>dptr, ipc_buffer.size, mr, stream)
180221
181222 def copy_to(self , dst: Buffer = None , *, stream: Stream ) -> Buffer:
182223 """Copy from this buffer to the dst buffer asynchronously on the given stream.
@@ -297,6 +338,7 @@ cdef class Buffer:
297338 mr : :obj:`~_memory.MemoryResource`, optional
298339 Memory resource associated with the buffer
299340 """
341+ # TODO: It would be better to take a stream here for later deallocation
300342 return Buffer._init(ptr , size , mr = mr)
301343
302344
@@ -839,7 +881,7 @@ cdef class DeviceMemoryResource(_cyMemoryResource, MemoryResource):
839881 cdef int handle = int (alloc_handle)
840882 with nogil:
841883 HANDLE_RETURN(cydriver.cuMemPoolImportFromShareableHandle(
842- &(self._mempool_handle ), <void*> handle , _IPC_HANDLE_TYPE , 0)
884+ &(self._mempool_handle ), unsafe_cast_from_int( handle ) , _IPC_HANDLE_TYPE , 0)
843885 )
844886 if uuid is not None:
845887 registered = self .register(uuid)
@@ -889,6 +931,7 @@ cdef class DeviceMemoryResource(_cyMemoryResource, MemoryResource):
889931 buf._ptr_obj = None
890932 buf._size = size
891933 buf._mr = self
934+ buf._alloc_stream = stream
892935 return buf
893936
894937 def allocate (self , size_t size , stream: Stream = None ) -> Buffer:
@@ -921,7 +964,7 @@ cdef class DeviceMemoryResource(_cyMemoryResource, MemoryResource):
921964 HANDLE_RETURN(cydriver.cuMemFreeAsync(devptr , s ))
922965 return 0
923966
924- def deallocate(self , ptr: DevicePointerT , size_t size , stream: Stream = None ):
967+ def deallocate(self , ptr: DevicePointerT , size_t size , stream: Stream ):
925968 """ Deallocate a buffer previously allocated by this resource.
926969
927970 Parameters
@@ -932,10 +975,9 @@ cdef class DeviceMemoryResource(_cyMemoryResource, MemoryResource):
932975 The size of the buffer to deallocate, in bytes.
933976 stream : Stream
934977 The stream on which to perform the deallocation asynchronously.
935- If None, an internal stream is used.
978+ If the buffer is deallocated without an explicit stream, the allocation stream
979+ is used.
936980 """
937- if stream is None :
938- stream = default_stream()
939981 self ._deallocate(< intptr_t> ptr, size, < cyStream> stream)
940982
941983 @property
@@ -1018,11 +1060,13 @@ class LegacyPinnedMemoryResource(MemoryResource):
10181060 Buffer
10191061 The allocated buffer object , which is accessible on both host and device.
10201062 """
1063+ if stream is None:
1064+ stream = default_stream()
10211065 err , ptr = driver.cuMemAllocHost(size)
10221066 raise_if_driver_error(err )
1023- return Buffer._init(ptr , size , self )
1067+ return Buffer._init(ptr , size , self , stream )
10241068
1025- def deallocate(self , ptr: DevicePointerT , size_t size , stream: Stream = None ):
1069+ def deallocate(self , ptr: DevicePointerT , size_t size , stream: Stream ):
10261070 """ Deallocate a buffer previously allocated by this resource.
10271071
10281072 Parameters
@@ -1031,12 +1075,10 @@ class LegacyPinnedMemoryResource(MemoryResource):
10311075 The pointer or handle to the buffer to deallocate.
10321076 size : int
10331077 The size of the buffer to deallocate, in bytes.
1034- stream : Stream, optional
1035- The stream on which to perform the deallocation asynchronously.
1036- If None, no synchronization would happen.
1078+ stream : Stream
1079+ The stream on which to perform the deallocation synchronously.
10371080 """
1038- if stream:
1039- stream.sync()
1081+ stream.sync()
10401082 err, = driver.cuMemFreeHost(ptr)
10411083 raise_if_driver_error(err)
10421084
@@ -1064,13 +1106,13 @@ class _SynchronousMemoryResource(MemoryResource):
10641106 self ._dev_id = getattr (device_id, ' device_id' , device_id)
10651107
def allocate(self, size, stream=None) -> Buffer:
    """Allocate ``size`` bytes with ``cuMemAlloc`` and wrap them in a Buffer.

    Parameters
    ----------
    size : int
        Number of bytes to allocate.
    stream : Stream, optional
        Stream to associate with the allocation. When omitted,
        ``default_stream()`` is used.

    Returns
    -------
    Buffer
        Buffer owned by this memory resource.
    """
    if stream is None:
        stream = default_stream()
    err, ptr = driver.cuMemAlloc(size)
    raise_if_driver_error(err)
    # Forward the stream so the buffer records it as its allocation stream;
    # Buffer.close() hands that stream back to deallocate(), which syncs on it.
    # Previously the stream was resolved here but then dropped.
    return Buffer._init(ptr, size, self, stream)
10701114
def deallocate(self, ptr, size, stream):
    """Synchronize ``stream``, then free ``ptr`` with ``cuMemFree``.

    Parameters
    ----------
    ptr
        Pointer previously returned by this resource's ``allocate``.
    size
        Size of the allocation; not used by this implementation.
    stream
        Stream that is synchronized before the memory is freed.
    """
    stream.sync()
    (err,) = driver.cuMemFree(ptr)
    raise_if_driver_error(err)
0 commit comments