@@ -11,8 +11,8 @@ use cudarc::driver::CudaSlice;
1111use cudarc:: driver:: CudaStream ;
1212use cudarc:: driver:: DeviceRepr ;
1313use cudarc:: driver:: result:: stream;
14+ use futures:: channel:: oneshot;
1415use futures:: future:: BoxFuture ;
15- use kanal:: Sender ;
1616use tracing:: warn;
1717use vortex:: array:: buffer:: BufferHandle ;
1818use vortex:: error:: VortexResult ;
@@ -132,9 +132,9 @@ impl VortexCudaStream {
132132pub ( crate ) async fn await_stream_callback ( stream : & CudaStream ) -> VortexResult < ( ) > {
133133 let rx = register_stream_callback ( stream) ?;
134134
135- rx. recv ( )
136- . await
137- . map_err ( |e| vortex_err ! ( "CUDA stream callback channel closed unexpectedly: {}" , e ) )
135+ rx. await . map_err ( |oneshot :: Canceled | {
136+ vortex_err ! ( "CUDA stream callback channel closed unexpectedly: channel canceled" )
137+ } )
138138}
139139
140140/// Registers a host function callback on the stream.
@@ -147,8 +147,8 @@ pub(crate) async fn await_stream_callback(stream: &CudaStream) -> VortexResult<(
147147/// # Errors
148148///
149149/// Returns an error if registering the host callback function fails.
150- fn register_stream_callback ( stream : & CudaStream ) -> VortexResult < kanal :: AsyncReceiver < ( ) > > {
151- let ( tx, rx) = kanal :: bounded :: < ( ) > ( 1 ) ;
150+ fn register_stream_callback ( stream : & CudaStream ) -> VortexResult < oneshot :: Receiver < ( ) > > {
151+ let ( tx, rx) = oneshot :: channel :: < ( ) > ( ) ;
152152
153153 let tx_ptr = Box :: into_raw ( Box :: new ( tx) ) ;
154154
@@ -161,7 +161,7 @@ fn register_stream_callback(stream: &CudaStream) -> VortexResult<kanal::AsyncRec
161161 unsafe extern "C" fn callback ( user_data : * mut std:: ffi:: c_void ) {
162162 // SAFETY: The memory of `tx` is manually managed has not been freed
163163 // before. We have unique ownership and can therefore free it.
164- let tx = unsafe { Box :: from_raw ( user_data as * mut Sender < ( ) > ) } ;
164+ let tx = unsafe { Box :: from_raw ( user_data as * mut oneshot :: Sender < ( ) > ) } ;
165165
166166 // Blocking send as we're in a callback invoked by the CUDA driver.
167167 // NOTE: send can fail if the CudaEvent is dropped by the caller, in which case the receiver
@@ -189,5 +189,5 @@ fn register_stream_callback(stream: &CudaStream) -> VortexResult<kanal::AsyncRec
189189 } ) ?;
190190 }
191191
192- Ok ( rx. to_async ( ) )
192+ Ok ( rx)
193193}
0 commit comments