Skip to content

Commit 581d260

Browse files
committed
Document CUDA stream helpers
Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 0b5fea3 commit 581d260

5 files changed

Lines changed: 114 additions & 73 deletions

File tree

vortex-cuda/ffi/src/lib.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,21 @@ use vortex_cuda::CudaSession;
1919
use vortex_cuda::arrow::ArrowDeviceArray;
2020
use vortex_cuda::arrow::ArrowDeviceArrayStream;
2121
use vortex_cuda::arrow::DeviceArrayExt;
22+
use vortex_cuda::arrow::DeviceArrayStreamExt;
2223
use vortex_ffi::try_or;
2324
use vortex_ffi::vx_array;
2425
use vortex_ffi::vx_array_ref;
2526
use vortex_ffi::vx_error;
2627
use vortex_ffi::vx_partition;
27-
use vortex_ffi::vx_partition_into_array_iter;
28+
use vortex_ffi::vx_partition_into_array_stream;
2829
use vortex_ffi::vx_session;
2930
use vortex_ffi::vx_session_new_with;
3031
use vortex_ffi::vx_session_ref;
3132

3233
const VX_CUDA_OK: c_int = 0;
3334
const VX_CUDA_ERR: c_int = 1;
3435

36+
/// Return a session with CUDA state, adding default CUDA support when needed.
3537
fn session_with_cuda(session: &VortexSession) -> VortexResult<VortexSession> {
3638
if session.get_opt::<CudaSession>().is_some() {
3739
return Ok(session.clone());
@@ -132,11 +134,9 @@ pub unsafe extern "C-unwind" fn vx_cuda_partition_scan_arrow_device_stream(
132134
vortex_ensure!(!partition.is_null(), "null vx_partition");
133135
vortex_ensure!(!out_stream.is_null(), "null ArrowDeviceArrayStream output");
134136

135-
let (dtype, array_iter) = unsafe { vx_partition_into_array_iter(partition) }?;
137+
let array_stream = unsafe { vx_partition_into_array_stream(partition) }?;
136138
let session = session_with_cuda(unsafe { vx_session_ref(session) }?)?;
137-
let ctx = CudaSession::create_execution_ctx(&session)?;
138-
let device_stream =
139-
vortex_cuda::export_device_array_stream_from_iter_with_ctx(array_iter, dtype, ctx);
139+
let device_stream = array_stream.export_device_array_stream(&session)?;
140140

141141
unsafe { ptr::write(out_stream, device_stream) };
142142
Ok(VX_CUDA_OK)

0 commit comments

Comments
 (0)