@@ -19,19 +19,21 @@ use vortex_cuda::CudaSession;
1919use vortex_cuda:: arrow:: ArrowDeviceArray ;
2020use vortex_cuda:: arrow:: ArrowDeviceArrayStream ;
2121use vortex_cuda:: arrow:: DeviceArrayExt ;
22+ use vortex_cuda:: arrow:: DeviceArrayStreamExt ;
2223use vortex_ffi:: try_or;
2324use vortex_ffi:: vx_array;
2425use vortex_ffi:: vx_array_ref;
2526use vortex_ffi:: vx_error;
2627use vortex_ffi:: vx_partition;
27- use vortex_ffi:: vx_partition_into_array_iter ;
28+ use vortex_ffi:: vx_partition_into_array_stream ;
2829use vortex_ffi:: vx_session;
2930use vortex_ffi:: vx_session_new_with;
3031use vortex_ffi:: vx_session_ref;
3132
3233const VX_CUDA_OK : c_int = 0 ;
3334const VX_CUDA_ERR : c_int = 1 ;
3435
36+ /// Return a session with CUDA state, adding default CUDA support when needed.
3537fn session_with_cuda ( session : & VortexSession ) -> VortexResult < VortexSession > {
3638 if session. get_opt :: < CudaSession > ( ) . is_some ( ) {
3739 return Ok ( session. clone ( ) ) ;
@@ -132,11 +134,9 @@ pub unsafe extern "C-unwind" fn vx_cuda_partition_scan_arrow_device_stream(
132134 vortex_ensure ! ( !partition. is_null( ) , "null vx_partition" ) ;
133135 vortex_ensure ! ( !out_stream. is_null( ) , "null ArrowDeviceArrayStream output" ) ;
134136
135- let ( dtype , array_iter ) = unsafe { vx_partition_into_array_iter ( partition) } ?;
137+ let array_stream = unsafe { vx_partition_into_array_stream ( partition) } ?;
136138 let session = session_with_cuda ( unsafe { vx_session_ref ( session) } ?) ?;
137- let ctx = CudaSession :: create_execution_ctx ( & session) ?;
138- let device_stream =
139- vortex_cuda:: export_device_array_stream_from_iter_with_ctx ( array_iter, dtype, ctx) ;
139+ let device_stream = array_stream. export_device_array_stream ( & session) ?;
140140
141141 unsafe { ptr:: write ( out_stream, device_stream) } ;
142142 Ok ( VX_CUDA_OK )
0 commit comments