Skip to content

Commit 93c8619

Browse files
committed
chore: test CUDA Arrow Device capsule exports
Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 79c6b0f commit 93c8619

5 files changed

Lines changed: 211 additions & 0 deletions

File tree

vortex-python-cuda/python/vortex_cuda/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
_debug_array_metadata_dtype = _lib._debug_array_metadata_dtype
88
_debug_array_metadata_display_values = _lib._debug_array_metadata_display_values
9+
_debug_arrow_device_array_capsule_summary = _lib._debug_arrow_device_array_capsule_summary
10+
_debug_consume_arrow_device_array_capsules = _lib._debug_consume_arrow_device_array_capsules
911
cuda_available = _lib.cuda_available
1012
export_device_array = _lib.export_device_array
1113

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
def _debug_array_metadata_dtype(array: object) -> str: ...
5+
def _debug_array_metadata_display_values(array: object) -> str: ...
6+
def _debug_arrow_device_array_capsule_summary(schema: object, device_array: object) -> dict[str, object]: ...
7+
def _debug_consume_arrow_device_array_capsules(
8+
schema: object, device_array: object
9+
) -> tuple[bool, bool, bool, bool, bool, bool]: ...
10+
def cuda_available() -> bool: ...
11+
def export_device_array(
12+
array: object, requested_schema: object | None = None, **kwargs: object
13+
) -> tuple[object, object]: ...
14+
15+
__all__: list[str]

vortex-python-cuda/python/vortex_cuda/_lib.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33

44
def _debug_array_metadata_dtype(array: object) -> str: ...
55
def _debug_array_metadata_display_values(array: object) -> str: ...
6+
def _debug_arrow_device_array_capsule_summary(schema: object, device_array: object) -> dict[str, object]: ...
7+
def _debug_consume_arrow_device_array_capsules(
8+
schema: object, device_array: object
9+
) -> tuple[bool, bool, bool, bool, bool, bool]: ...
610
def cuda_available() -> bool: ...
711
def export_device_array(
812
array: object, requested_schema: object | None = None, **kwargs: object

vortex-python-cuda/src/lib.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,88 @@ fn release_exported(exported: &mut ArrowDeviceArrayWithSchema) {
461461
release_device_array(&mut exported.array);
462462
}
463463

464+
/// Return non-owning details from Arrow Device capsules for Python-side smoke consumers.
465+
#[pyfunction]
466+
fn _debug_arrow_device_array_capsule_summary<'py>(
467+
py: Python<'py>,
468+
schema: Bound<'py, PyCapsule>,
469+
device_array: Bound<'py, PyCapsule>,
470+
) -> PyResult<Bound<'py, PyDict>> {
471+
let schema = unsafe {
472+
schema
473+
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
474+
.cast::<FFI_ArrowSchema>()
475+
.as_ref()
476+
};
477+
let device_array = unsafe {
478+
device_array
479+
.pointer_checked(Some(ARROW_DEVICE_ARRAY_CAPSULE_NAME))?
480+
.cast::<ArrowDeviceArray>()
481+
.as_ref()
482+
};
483+
484+
let summary = PyDict::new(py);
485+
summary.set_item("schema_live", schema.release.is_some())?;
486+
summary.set_item("array_live", device_array.array.release.is_some())?;
487+
summary.set_item("is_cuda", device_array.device_type == ARROW_DEVICE_CUDA)?;
488+
summary.set_item("device_type", device_array.device_type)?;
489+
summary.set_item("device_id", device_array.device_id)?;
490+
summary.set_item("length", device_array.array.length)?;
491+
summary.set_item("null_count", device_array.array.null_count)?;
492+
summary.set_item("n_buffers", device_array.array.n_buffers)?;
493+
summary.set_item("n_children", device_array.array.n_children)?;
494+
Ok(summary)
495+
}
496+
497+
/// Simulate a Python Arrow Device consumer taking ownership from the returned capsules.
498+
#[pyfunction]
499+
fn _debug_consume_arrow_device_array_capsules(
500+
schema: Bound<'_, PyCapsule>,
501+
device_array: Bound<'_, PyCapsule>,
502+
) -> PyResult<(bool, bool, bool, bool, bool, bool)> {
503+
let mut schema_ptr = schema
504+
.pointer_checked(Some(ARROW_SCHEMA_CAPSULE_NAME))?
505+
.cast::<FFI_ArrowSchema>();
506+
let mut device_array_ptr = device_array
507+
.pointer_checked(Some(ARROW_DEVICE_ARRAY_CAPSULE_NAME))?
508+
.cast::<ArrowDeviceArray>();
509+
510+
let schema_ref = unsafe { schema_ptr.as_mut() };
511+
let device_array_ref = unsafe { device_array_ptr.as_mut() };
512+
let schema_had_release = schema_ref.release.is_some();
513+
let array_had_release = device_array_ref.array.release.is_some();
514+
515+
release_schema(schema_ref);
516+
release_device_array(device_array_ref);
517+
518+
let schema_release_cleared = schema_ref.release.is_none();
519+
let array_release_cleared = device_array_ref.array.release.is_none();
520+
521+
set_capsule_name(&schema, USED_ARROW_SCHEMA_CAPSULE_NAME)?;
522+
set_capsule_name(&device_array, USED_ARROW_DEVICE_ARRAY_CAPSULE_NAME)?;
523+
524+
Ok((
525+
schema_had_release,
526+
array_had_release,
527+
schema_release_cleared,
528+
array_release_cleared,
529+
capsule_is_valid(&schema, USED_ARROW_SCHEMA_CAPSULE_NAME),
530+
capsule_is_valid(&device_array, USED_ARROW_DEVICE_ARRAY_CAPSULE_NAME),
531+
))
532+
}
533+
534+
fn set_capsule_name(capsule: &Bound<'_, PyCapsule>, name: &CStr) -> PyResult<()> {
535+
let result = unsafe { ffi::PyCapsule_SetName(capsule.as_ptr(), name.as_ptr()) };
536+
if result != 0 {
537+
return Err(PyErr::fetch(capsule.py()));
538+
}
539+
Ok(())
540+
}
541+
542+
fn capsule_is_valid(capsule: &Bound<'_, PyCapsule>, name: &CStr) -> bool {
543+
unsafe { ffi::PyCapsule_IsValid(capsule.as_ptr(), name.as_ptr()) == 1 }
544+
}
545+
464546
fn schema_capsule<'py>(
465547
py: Python<'py>,
466548
schema: FFI_ArrowSchema,
@@ -573,6 +655,14 @@ fn _lib(m: &Bound<PyModule>) -> PyResult<()> {
573655
m.add_function(wrap_pyfunction!(cuda_available, m)?)?;
574656
m.add_function(wrap_pyfunction!(_debug_array_metadata_dtype, m)?)?;
575657
m.add_function(wrap_pyfunction!(_debug_array_metadata_display_values, m)?)?;
658+
m.add_function(wrap_pyfunction!(
659+
_debug_arrow_device_array_capsule_summary,
660+
m
661+
)?)?;
662+
m.add_function(wrap_pyfunction!(
663+
_debug_consume_arrow_device_array_capsules,
664+
m
665+
)?)?;
576666
m.add_function(wrap_pyfunction!(export_device_array, m)?)?;
577667
Ok(())
578668
}

vortex-python-cuda/test/test_native_bridge.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,44 @@
22
# SPDX-FileCopyrightText: Copyright the Vortex contributors
33
# pyright: reportPrivateUsage=false
44

5+
import gc
6+
from typing import cast
7+
58
import pytest
69
import vortex_cuda
710

811
import vortex
912

1013

14+
def _require_cuda() -> None:
15+
if not vortex_cuda.cuda_available():
16+
pytest.skip("CUDA device is not available")
17+
18+
19+
def _assert_exported_device_array(
20+
array: object, *, length: int, null_count: int, n_children: int
21+
) -> tuple[object, object]:
22+
schema, device_array = vortex_cuda.export_device_array(array)
23+
summary = cast(
24+
dict[str, object],
25+
vortex_cuda._debug_arrow_device_array_capsule_summary( # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType]
26+
schema, device_array
27+
),
28+
)
29+
30+
assert summary["schema_live"] is True
31+
assert summary["array_live"] is True
32+
assert summary["is_cuda"] is True
33+
assert summary["length"] == length
34+
assert summary["null_count"] == null_count
35+
assert summary["n_children"] == n_children
36+
n_buffers = summary["n_buffers"]
37+
assert isinstance(n_buffers, int)
38+
assert n_buffers >= 0
39+
40+
return schema, device_array
41+
42+
1143
def test_debug_array_metadata_dtype_reads_base_vortex_array():
1244
array = vortex.Array.from_range(range(0, 3))
1345

@@ -64,3 +96,71 @@ def test_export_device_array_returns_capsules_or_clean_cuda_error():
6496
schema, device_array = vortex_cuda.export_device_array(array)
6597
assert type(schema).__name__ == "PyCapsule"
6698
assert type(device_array).__name__ == "PyCapsule"
99+
100+
101+
def test_arrow_device_export_primitive_array():
102+
_require_cuda()
103+
104+
_ = _assert_exported_device_array(vortex.array([1, 2, 3]), length=3, null_count=0, n_children=0)
105+
106+
107+
def test_arrow_device_export_nullable_primitive_array():
108+
_require_cuda()
109+
110+
_ = _assert_exported_device_array(vortex.array([1, None, 3]), length=3, null_count=1, n_children=0)
111+
112+
113+
def test_arrow_device_export_nullable_bool_array():
114+
_require_cuda()
115+
116+
_ = _assert_exported_device_array(vortex.array([True, None, False]), length=3, null_count=1, n_children=0)
117+
118+
119+
def test_arrow_device_export_string_array():
120+
_require_cuda()
121+
122+
_ = _assert_exported_device_array(
123+
vortex.array(["alpha", "beta", "a longer string that should use the varbin data buffer"]),
124+
length=3,
125+
null_count=0,
126+
n_children=0,
127+
)
128+
129+
130+
def test_arrow_device_export_struct_array():
131+
import pyarrow as pa
132+
133+
_require_cuda()
134+
135+
arrow_table = pa.table({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
136+
struct_array = vortex.Array.from_arrow(
137+
pa.StructArray.from_arrays( # pyright: ignore[reportUnknownMemberType]
138+
[arrow_table.column("a").combine_chunks(), arrow_table.column("b").combine_chunks()],
139+
names=["a", "b"],
140+
)
141+
)
142+
143+
_ = _assert_exported_device_array(struct_array, length=3, null_count=0, n_children=2)
144+
145+
146+
def test_arrow_device_capsules_drop_unconsumed():
147+
_require_cuda()
148+
149+
schema, device_array = _assert_exported_device_array(vortex.array([1, 2, 3]), length=3, null_count=0, n_children=0)
150+
del schema, device_array
151+
_ = gc.collect()
152+
153+
154+
def test_arrow_device_capsules_consumer_release_and_used_names():
155+
_require_cuda()
156+
157+
schema, device_array = _assert_exported_device_array(vortex.array([1, 2, 3]), length=3, null_count=0, n_children=0)
158+
consume_result = cast(
159+
tuple[bool, bool, bool, bool, bool, bool],
160+
vortex_cuda._debug_consume_arrow_device_array_capsules( # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType]
161+
schema, device_array
162+
),
163+
)
164+
assert consume_result == (True, True, True, True, True, True)
165+
del schema, device_array
166+
_ = gc.collect()

0 commit comments

Comments
 (0)