Skip to content

Commit 07926a0

Browse files
committed
simplify
Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 581d260 commit 07926a0

2 files changed

Lines changed: 90 additions & 104 deletions

File tree

vortex-cuda/src/arrow/mod.rs

Lines changed: 90 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -275,14 +275,14 @@ fn device_stream_runtime() -> &'static CurrentThreadRuntime {
275275
&DEVICE_STREAM_RUNTIME
276276
}
277277

278-
#[derive(Clone, Debug, PartialEq)]
278+
#[derive(Debug, PartialEq)]
279279
enum ArrowDeviceStreamSchema {
280280
Schema(Schema),
281281
Field(Field),
282282
}
283283

284284
impl ArrowDeviceStreamSchema {
285-
/// Interpret an Arrow C schema as the stream schema shape for `dtype`.
285+
/// Convert an Arrow C schema into the stream schema shape for `dtype`.
286286
fn from_ffi(schema: &FFI_ArrowSchema, dtype: &DType) -> VortexResult<Self> {
287287
if matches!(dtype, DType::Struct(..)) {
288288
Ok(Self::Schema(Schema::try_from(schema)?))
@@ -291,7 +291,12 @@ impl ArrowDeviceStreamSchema {
291291
}
292292
}
293293

294-
/// Build the Arrow stream schema for an empty stream with the given dtype.
294+
/// Convert a Vortex dtype into a stream schema when no batch is available.
295+
///
296+
/// This uses only the logical dtype, so it can differ from a non-empty stream's first-batch
297+
/// schema for encodings the dtype does not capture: a dictionary column reports a plain field
298+
/// here but `DataType::Dictionary` once a concrete batch is seen. This is harmless because an
299+
/// empty stream carries no data.
295300
fn from_dtype(dtype: &DType, ctx: &mut CudaExecutionCtx) -> VortexResult<Self> {
296301
let dtype = arrow_device_export_dtype(dtype);
297302
if let DType::Struct(struct_dtype, _) = &dtype {
@@ -331,25 +336,34 @@ impl DeviceArrayStreamPrivateData {
331336
}
332337

333338
/// Store the last stream error and return the Arrow callback error code.
339+
///
340+
/// Interior NUL bytes are replaced with spaces so the message always converts to a C string;
341+
/// `get_last_error` is therefore never null while a non-zero status is reported.
334342
fn set_error(&mut self, error: impl ToString) -> c_int {
335-
self.last_error = CString::new(error.to_string()).ok();
343+
let message = error.to_string().replace('\0', " ");
344+
self.last_error = Some(CString::new(message).unwrap_or_default());
336345
ARROW_STREAM_EIO
337346
}
338347

339-
/// Initialize and return the stream schema, exporting the first batch if needed.
348+
/// Set the schema from the dtype alone, so an empty stream still has a schema to report.
349+
fn set_empty_schema(&mut self) -> VortexResult<()> {
350+
if self.schema.is_none() {
351+
self.schema = Some(ArrowDeviceStreamSchema::from_dtype(
352+
&self.dtype,
353+
&mut self.ctx,
354+
)?);
355+
}
356+
Ok(())
357+
}
358+
359+
/// Return the stream schema, exporting the first batch to derive it if needed.
360+
///
361+
/// A first batch is held in `pending_array` so the following `get_next` returns it.
340362
fn ensure_schema(&mut self) -> VortexResult<&ArrowDeviceStreamSchema> {
341363
if self.schema.is_none() {
342364
match self.array_iter.next() {
343-
Some(array) => {
344-
let array = self.export_batch(array?)?;
345-
self.pending_array = Some(array);
346-
}
347-
None => {
348-
self.schema = Some(ArrowDeviceStreamSchema::from_dtype(
349-
&self.dtype,
350-
&mut self.ctx,
351-
)?);
352-
}
365+
Some(array) => self.pending_array = Some(self.export_batch(array?)?),
366+
None => self.set_empty_schema()?,
353367
}
354368
}
355369

@@ -364,20 +378,16 @@ impl DeviceArrayStreamPrivateData {
364378
return Ok(Some(array));
365379
}
366380

367-
let Some(array) = self.array_iter.next() else {
368-
if self.schema.is_none() {
369-
self.schema = Some(ArrowDeviceStreamSchema::from_dtype(
370-
&self.dtype,
371-
&mut self.ctx,
372-
)?);
381+
match self.array_iter.next() {
382+
Some(array) => self.export_batch(array?).map(Some),
383+
None => {
384+
self.set_empty_schema()?;
385+
Ok(None)
373386
}
374-
return Ok(None);
375-
};
376-
377-
self.export_batch(array?).map(Some)
387+
}
378388
}
379389

380-
/// Export one Vortex stream batch and validate it against the stream schema and device.
390+
/// Export one Vortex batch as a device array, validating it against the stream.
381391
fn export_batch(&mut self, array: ArrayRef) -> VortexResult<ArrowDeviceArray> {
382392
vortex_ensure!(
383393
array.dtype() == &self.dtype,
@@ -386,54 +396,55 @@ impl DeviceArrayStreamPrivateData {
386396
array.dtype()
387397
);
388398

389-
let exported = device_stream_runtime()
390-
.block_on(array.export_device_array_with_schema(&mut self.ctx))?;
391399
let ArrowDeviceArrayWithSchema {
392-
schema: mut ffi_schema,
400+
mut schema,
393401
mut array,
394-
} = exported;
395-
let batch_schema = ArrowDeviceStreamSchema::from_ffi(&ffi_schema, &self.dtype);
396-
release_schema(&mut ffi_schema);
397-
let batch_schema = match batch_schema {
398-
Ok(batch_schema) => batch_schema,
399-
Err(error) => {
400-
release_device_array(&mut array);
401-
return Err(error);
402-
}
403-
};
402+
} = device_stream_runtime()
403+
.block_on(array.export_device_array_with_schema(&mut self.ctx))?;
404404

405-
let validation = (|| -> VortexResult<()> {
406-
if let Some(stream_schema) = &self.schema {
405+
// Always release the schema; it is no longer needed. On failure, also release the unreturned array.
406+
let checked = self.check_batch(&schema, &array);
407+
release_schema(&mut schema);
408+
if let Err(error) = checked {
409+
release_device_array(&mut array);
410+
return Err(error);
411+
}
412+
Ok(array)
413+
}
414+
415+
/// Check that a freshly exported batch matches the stream schema and CUDA device.
416+
fn check_batch(
417+
&mut self,
418+
schema: &FFI_ArrowSchema,
419+
array: &ArrowDeviceArray,
420+
) -> VortexResult<()> {
421+
vortex_ensure!(
422+
array.device_type == ARROW_DEVICE_CUDA,
423+
"stream batch exported on non-CUDA device type {}",
424+
array.device_type
425+
);
426+
vortex_ensure!(
427+
array.device_id == self.device_id,
428+
"stream batch moved from CUDA device {} to {}",
429+
self.device_id,
430+
array.device_id
431+
);
432+
433+
// Commit the schema only after the batch is fully accepted, so a rejected first batch
434+
// never becomes the schema later reported by `get_schema`.
435+
let batch_schema = ArrowDeviceStreamSchema::from_ffi(schema, &self.dtype)?;
436+
match &self.schema {
437+
Some(stream_schema) => {
407438
vortex_ensure!(
408439
stream_schema == &batch_schema,
409440
"stream batch Arrow schema changed from {:?} to {:?}",
410441
stream_schema,
411442
batch_schema
412443
);
413-
} else {
414-
self.schema = Some(batch_schema);
415444
}
416-
417-
vortex_ensure!(
418-
array.device_type == ARROW_DEVICE_CUDA,
419-
"stream batch exported on non-CUDA device type {}",
420-
array.device_type
421-
);
422-
vortex_ensure!(
423-
array.device_id == self.device_id,
424-
"stream batch moved from CUDA device {} to {}",
425-
self.device_id,
426-
array.device_id
427-
);
428-
Ok(())
429-
})();
430-
431-
if let Err(error) = validation {
432-
release_device_array(&mut array);
433-
return Err(error);
445+
None => self.schema = Some(batch_schema),
434446
}
435-
436-
Ok(array)
447+
Ok(())
437448
}
438449
}
439450

@@ -454,64 +465,41 @@ pub trait DeviceArrayStreamExt {
454465
/// context's CUDA device at construction time, and each `get_next` verifies that the produced
455466
/// [`ArrowDeviceArray`] is CUDA-resident on that same device. The returned C stream owns the
456467
/// Vortex stream and must be released through its embedded `release` callback.
468+
///
469+
/// Per the Arrow C stream contract, drive the returned stream from a single thread; its
470+
/// callbacks must not be invoked concurrently.
457471
fn export_device_array_stream(
458472
self,
459473
session: &VortexSession,
460474
) -> VortexResult<ArrowDeviceArrayStream>;
461475
}
462476

463477
impl DeviceArrayStreamExt for SendableArrayStream {
464-
/// Export this stream by driving it on the shared Arrow Device stream runtime.
478+
/// Drive this stream on the shared Arrow Device stream runtime and export it.
465479
fn export_device_array_stream(
466480
self,
467481
session: &VortexSession,
468482
) -> VortexResult<ArrowDeviceArrayStream> {
469483
let dtype = self.dtype().clone();
470-
export_device_array_stream_from_iter(
471-
device_stream_runtime().block_on_stream(self),
472-
dtype,
473-
session,
474-
)
475-
}
476-
}
477-
478-
/// Export a blocking Vortex array iterator as an [`ArrowDeviceArrayStream`].
479-
///
480-
/// The iterator is advanced by the Arrow stream callbacks. Use this helper when the stream must be
481-
/// driven by a specific runtime or executor before crossing the Arrow C Device stream boundary.
482-
/// Each yielded array must have `dtype`; every exported batch is validated to stay on the CUDA
483-
/// device selected by the session's CUDA execution context.
484-
pub fn export_device_array_stream_from_iter(
485-
array_iter: impl Iterator<Item = VortexResult<ArrayRef>> + 'static,
486-
dtype: DType,
487-
session: &VortexSession,
488-
) -> VortexResult<ArrowDeviceArrayStream> {
489-
let ctx = crate::CudaSession::create_execution_ctx(session)?;
490-
Ok(export_device_array_stream_from_iter_with_ctx(
491-
array_iter, dtype, ctx,
492-
))
493-
}
494-
495-
/// Export a blocking Vortex array iterator as an [`ArrowDeviceArrayStream`] using an existing CUDA
496-
/// execution context.
497-
///
498-
/// Use this helper when the caller has already selected the CUDA execution context that must drive
499-
/// the exported stream. Each yielded array must have `dtype`; every exported batch is validated to
500-
/// stay on the CUDA device selected by `ctx`.
501-
pub fn export_device_array_stream_from_iter_with_ctx(
502-
array_iter: impl Iterator<Item = VortexResult<ArrayRef>> + 'static,
484+
let ctx = crate::CudaSession::create_execution_ctx(session)?;
485+
let array_iter = Box::new(device_stream_runtime().block_on_stream(self));
486+
Ok(device_array_stream(array_iter, dtype, ctx))
487+
}
488+
}
489+
490+
/// Build the Arrow C Device stream that owns `array_iter` and exports its batches through `ctx`.
491+
fn device_array_stream(
492+
array_iter: ArrayStreamIterator,
503493
dtype: DType,
504494
ctx: CudaExecutionCtx,
505495
) -> ArrowDeviceArrayStream {
506-
let device_id = ctx.stream().context().ordinal() as i64;
507-
508496
let private_data = Box::new(DeviceArrayStreamPrivateData {
509-
array_iter: Box::new(array_iter),
497+
device_id: ctx.stream().context().ordinal() as i64,
498+
array_iter,
510499
ctx,
511500
dtype,
512501
schema: None,
513502
pending_array: None,
514-
device_id,
515503
last_error: None,
516504
});
517505

@@ -650,7 +638,7 @@ unsafe extern "C" fn device_stream_get_last_error(
650638
.map_or(ptr::null(), |error| error.as_ptr())
651639
}
652640

653-
/// Release the stream state and clear callbacks so release is idempotent.
641+
/// Free the stream state and null its callbacks. The null `release` makes a second call a no-op.
654642
unsafe extern "C" fn device_stream_release(stream: *mut ArrowDeviceArrayStream) {
655643
let Some(stream_ref) = (unsafe { stream.as_mut() }) else {
656644
return;

vortex-cuda/src/lib.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ pub use arrow::ArrowDeviceArrayWithSchema;
2727
pub use arrow::DeviceArrayExt;
2828
pub use arrow::DeviceArrayStreamExt;
2929
pub use arrow::ExportDeviceArray;
30-
pub use arrow::export_device_array_stream_from_iter;
31-
pub use arrow::export_device_array_stream_from_iter_with_ctx;
3230
pub use canonical::CanonicalCudaExt;
3331
pub use device_buffer::CudaBufferExt;
3432
pub use device_buffer::CudaDeviceBuffer;

0 commit comments

Comments
 (0)