Skip to content

Commit e4c3fbc

Browse files
committed
Add Python CUDA bridge CI and buffer handoff ABI
Add explicit GPU-runner CI coverage for the Python CUDA bridge through the vortex-data[cuda] optional-extra path. Extend the private metadata bridge to carry host buffer-export capsules instead of only a buffer count. The base Python package exports repr(C) VortexBufferExport descriptors, and vortex-python-cuda imports them into local BufferHandles before deserializing arrays through its own VortexSession. Tests now cover primitive, nullable, bool, and struct arrays across the bridge, plus the existing CUDA Arrow Device smoke path. Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 49f300d commit e4c3fbc

9 files changed

Lines changed: 339 additions & 27 deletions

File tree

.github/workflows/ci.yml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,30 @@ jobs:
132132
uv run --all-packages make html
133133
working-directory: docs/
134134

135+
python-cuda-test:
136+
name: "Python CUDA (test)"
137+
if: github.repository == 'vortex-data/vortex'
138+
runs-on: >-
139+
${{ format('runs-on={0}/runner=gpu/tag=python-cuda-test', github.run_id) }}
140+
timeout-minutes: 30
141+
env:
142+
RUST_LOG: "info,maturin=off,uv=debug"
143+
MATURIN_PEP517_ARGS: "--profile ci"
144+
steps:
145+
- uses: runs-on/action@v2
146+
with:
147+
sccache: s3
148+
- uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7
149+
- uses: ./.github/actions/setup-prebuild
150+
151+
- name: Pytest - PyVortex CUDA bridge
152+
run: |
153+
uv run --extra cuda \
154+
--reinstall-package vortex-data \
155+
--reinstall-package vortex-data-cuda \
156+
pytest --benchmark-disable ../vortex-python-cuda/test/test_native_bridge.py
157+
working-directory: vortex-python/
158+
135159
rust-docs:
136160
name: "Rust (docs)"
137161
needs: duckdb-ready

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-python-cuda/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ extension-module = []
2828

2929
[dependencies]
3030
arrow-schema = { workspace = true }
31+
bytes = { workspace = true }
3132
pyo3 = { workspace = true, features = ["abi3", "abi3-py311"] }
3233
vortex = { workspace = true }
3334
vortex-cuda = { workspace = true }
Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
# pyright: reportMissingModuleSource=false, reportPrivateUsage=false
34

4-
from ._lib import ( # pyright: ignore[reportMissingModuleSource]
5-
_debug_array_metadata_dtype as _debug_array_metadata_dtype, # pyright: ignore[reportPrivateUsage]
6-
)
7-
from ._lib import ( # pyright: ignore[reportMissingModuleSource]
8-
cuda_available,
9-
export_device_array,
10-
)
5+
from . import _lib
6+
7+
_debug_array_metadata_dtype = _lib._debug_array_metadata_dtype
8+
_debug_array_metadata_scalar_values = _lib._debug_array_metadata_scalar_values
9+
cuda_available = _lib.cuda_available
10+
export_device_array = _lib.export_device_array
1111

1212
__all__ = ["cuda_available", "export_device_array"]

vortex-python-cuda/python/vortex_cuda/_lib.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
def _debug_array_metadata_dtype(array: object) -> str: ...
5+
def _debug_array_metadata_scalar_values(array: object) -> list[str]: ...
56
def cuda_available() -> bool: ...
67
def export_device_array(
78
array: object, requested_schema: object | None = None, **kwargs: object

vortex-python-cuda/src/lib.rs

Lines changed: 139 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ use vortex::buffer::ByteBuffer;
3535
use vortex::dtype::DType;
3636
use vortex::error::VortexError;
3737
use vortex::error::VortexResult;
38-
use vortex::error::vortex_bail;
3938
use vortex::error::vortex_ensure;
4039
use vortex::error::vortex_err;
4140
use vortex::flatbuffers::FlatBuffer;
@@ -54,6 +53,117 @@ const ARROW_SCHEMA_CAPSULE_NAME: &CStr = c_str!("arrow_schema");
5453
const USED_ARROW_SCHEMA_CAPSULE_NAME: &CStr = c_str!("used_arrow_schema");
5554
const ARROW_DEVICE_ARRAY_CAPSULE_NAME: &CStr = c_str!("arrow_device_array");
5655
const USED_ARROW_DEVICE_ARRAY_CAPSULE_NAME: &CStr = c_str!("used_arrow_device_array");
56+
const BUFFER_EXPORT_CAPSULE_NAME: &CStr = c_str!("vortex_buffer_export");
57+
58+
/// Private C-ABI struct matching the layout in `vortex-python/src/arrays/mod.rs`.
59+
/// Both crates define this identically; the exact-version pin ensures layout agreement.
60+
#[repr(C)]
61+
struct VortexBufferExport {
62+
version: u32,
63+
kind: u32,
64+
ptr: *const u8,
65+
len: usize,
66+
alignment: usize,
67+
device_id: i32,
68+
sync_event: *mut c_void,
69+
private_data: *mut c_void,
70+
release: Option<unsafe extern "C" fn(*mut VortexBufferExport)>,
71+
}
72+
73+
const VORTEX_BUFFER_HOST: u32 = 0;
74+
75+
struct BufferExportGuard {
76+
export: NonNull<VortexBufferExport>,
77+
}
78+
79+
impl BufferExportGuard {
80+
fn export(&self) -> &VortexBufferExport {
81+
unsafe { self.export.as_ref() }
82+
}
83+
}
84+
85+
impl AsRef<[u8]> for BufferExportGuard {
86+
fn as_ref(&self) -> &[u8] {
87+
let export = self.export();
88+
if export.len == 0 {
89+
&[]
90+
} else {
91+
unsafe { std::slice::from_raw_parts(export.ptr, export.len) }
92+
}
93+
}
94+
}
95+
96+
impl Drop for BufferExportGuard {
97+
fn drop(&mut self) {
98+
let mut export = unsafe { Box::from_raw(self.export.as_ptr()) };
99+
if let Some(release) = export.release.take() {
100+
unsafe { release(&raw mut *export) };
101+
}
102+
}
103+
}
104+
105+
// The guard is moved into `Bytes::from_owner`, which requires `Send + Sync`. After import we disable
106+
// the source capsule destructor and own the boxed C export until this guard is dropped.
107+
unsafe impl Send for BufferExportGuard {}
108+
unsafe impl Sync for BufferExportGuard {}
109+
110+
fn import_buffer_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyResult<BufferHandle> {
111+
let export_ptr = capsule
112+
.pointer_checked(Some(BUFFER_EXPORT_CAPSULE_NAME))?
113+
.cast::<VortexBufferExport>();
114+
let export = unsafe { export_ptr.as_ref() };
115+
116+
if export.version != 1 {
117+
return Err(PyValueError::new_err(format!(
118+
"unsupported VortexBufferExport version {}",
119+
export.version
120+
)));
121+
}
122+
if export.kind != VORTEX_BUFFER_HOST {
123+
return Err(PyValueError::new_err(format!(
124+
"unsupported buffer kind {} (only host buffers are supported in metadata bridge)",
125+
export.kind
126+
)));
127+
}
128+
129+
if export.len != 0 && export.ptr.is_null() {
130+
return Err(PyValueError::new_err(
131+
"non-empty VortexBufferExport has null data pointer",
132+
));
133+
}
134+
135+
let len = export.len;
136+
let alignment = vortex::buffer::Alignment::try_from(
137+
u32::try_from(export.alignment)
138+
.map_err(|_| PyValueError::new_err("buffer alignment exceeds u32"))?,
139+
)
140+
.map_err(|e| PyValueError::new_err(e.to_string()))?;
141+
142+
if len != 0 && !alignment.is_ptr_aligned(export.ptr) {
143+
return Err(PyValueError::new_err(format!(
144+
"buffer pointer is not aligned to requested alignment {alignment}"
145+
)));
146+
}
147+
148+
// Transfer ownership of the boxed VortexBufferExport from the producer capsule into the Bytes
149+
// owner below. Otherwise the producer capsule could be dropped before the reconstructed
150+
// BufferHandle, leaving the Bytes owner with a dangling export pointer.
151+
unsafe { ffi::PyCapsule_SetDestructor(capsule.as_ptr(), None) };
152+
if PyErr::occurred(capsule.py()) {
153+
return Err(PyErr::fetch(capsule.py()));
154+
}
155+
156+
let guard = BufferExportGuard { export: export_ptr };
157+
158+
let byte_buffer = if len == 0 {
159+
drop(guard);
160+
ByteBuffer::empty_aligned(alignment)
161+
} else {
162+
ByteBuffer::from(bytes::Bytes::from_owner(guard)).aligned(alignment)
163+
};
164+
165+
Ok(BufferHandle::new_host(byte_buffer))
166+
}
57167

58168
struct ExportedDeviceArray(ArrowDeviceArrayWithSchema);
59169

@@ -101,7 +211,7 @@ struct ArrayMetadata {
101211
dtype: Vec<u8>,
102212
len: usize,
103213
metadata: Vec<u8>,
104-
buffer_count: usize,
214+
buffers: Vec<BufferHandle>,
105215
children: Vec<ArrayMetadata>,
106216
}
107217

@@ -147,6 +257,16 @@ fn parse_array_metadata(value: &Bound<'_, PyAny>) -> PyResult<ArrayMetadata> {
147257
)));
148258
}
149259

260+
let buffers = tuple
261+
.get_item(4)?
262+
.cast::<PyList>()?
263+
.iter()
264+
.map(|item| {
265+
let capsule: Bound<'_, PyCapsule> = item.extract()?;
266+
import_buffer_from_capsule(&capsule)
267+
})
268+
.collect::<PyResult<Vec<_>>>()?;
269+
150270
let children = tuple
151271
.get_item(5)?
152272
.cast::<PyList>()?
@@ -159,7 +279,7 @@ fn parse_array_metadata(value: &Bound<'_, PyAny>) -> PyResult<ArrayMetadata> {
159279
dtype: tuple.get_item(1)?.extract()?,
160280
len: tuple.get_item(2)?.extract()?,
161281
metadata: tuple.get_item(3)?.extract()?,
162-
buffer_count: tuple.get_item(4)?.extract()?,
282+
buffers,
163283
children,
164284
})
165285
}
@@ -173,14 +293,6 @@ fn deserialize_metadata_tree(
173293
metadata: &ArrayMetadata,
174294
session: &VortexSession,
175295
) -> VortexResult<ArrayRef> {
176-
if metadata.buffer_count != 0 {
177-
vortex_bail!(
178-
"metadata-only bridge cannot deserialize array {} with {} buffers yet",
179-
metadata.encoding_id,
180-
metadata.buffer_count
181-
);
182-
}
183-
184296
let dtype = dtype_from_metadata(metadata, session)?;
185297
let children = metadata
186298
.children
@@ -194,12 +306,11 @@ fn deserialize_metadata_tree(
194306
.registry()
195307
.find(&encoding_id)
196308
.ok_or_else(|| vortex_err!("Unknown array encoding: {}", metadata.encoding_id))?;
197-
let buffers: &[BufferHandle] = &[];
198309
let decoded = plugin.deserialize(
199310
&dtype,
200311
metadata.len,
201312
&metadata.metadata,
202-
buffers,
313+
&metadata.buffers,
203314
&children,
204315
session,
205316
)?;
@@ -246,6 +357,20 @@ fn _debug_array_metadata_dtype(array: Bound<'_, PyAny>) -> PyResult<String> {
246357
Ok(array.dtype().to_string())
247358
}
248359

360+
/// Return scalar display strings after crossing the private vtable-metadata bridge.
361+
#[pyfunction]
362+
fn _debug_array_metadata_scalar_values(array: Bound<'_, PyAny>) -> PyResult<Vec<String>> {
363+
let metadata = extract_array_metadata(&array)?;
364+
let array = deserialize_metadata_tree(&metadata, &METADATA_SESSION).map_err(to_py_err)?;
365+
(0..array.len())
366+
.map(|index| {
367+
#[expect(deprecated)]
368+
array.scalar_at(index).map(|scalar| scalar.to_string())
369+
})
370+
.collect::<VortexResult<Vec<_>>>()
371+
.map_err(to_py_err)
372+
}
373+
249374
/// Export a PyVortex array as Arrow C Device schema and array PyCapsules.
250375
#[pyfunction]
251376
#[pyo3(signature = (array, requested_schema = None, **kwargs))]
@@ -461,6 +586,7 @@ unsafe extern "C" fn release_device_array_capsule(capsule: *mut ffi::PyObject) {
461586
fn _lib(m: &Bound<PyModule>) -> PyResult<()> {
462587
m.add_function(wrap_pyfunction!(cuda_available, m)?)?;
463588
m.add_function(wrap_pyfunction!(_debug_array_metadata_dtype, m)?)?;
589+
m.add_function(wrap_pyfunction!(_debug_array_metadata_scalar_values, m)?)?;
464590
m.add_function(wrap_pyfunction!(export_device_array, m)?)?;
465591
Ok(())
466592
}

vortex-python-cuda/test/test_native_bridge.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,58 @@ def test_debug_array_metadata_dtype_reads_base_vortex_array():
1313
assert vortex_cuda._debug_array_metadata_dtype(array) == str(array.dtype) # pyright: ignore[reportPrivateUsage]
1414

1515

16-
def test_metadata_bridge_reports_arrays_that_need_buffer_handoff():
16+
def test_metadata_bridge_primitive_array():
1717
array = vortex.array([1, 2, 3])
1818

19-
with pytest.raises(RuntimeError, match="metadata-only bridge.*buffers"):
20-
_ = vortex_cuda._debug_array_metadata_dtype(array) # pyright: ignore[reportPrivateUsage]
19+
assert vortex_cuda._debug_array_metadata_dtype(array) == str(array.dtype) # pyright: ignore[reportPrivateUsage]
20+
assert vortex_cuda._debug_array_metadata_scalar_values(array) == [ # pyright: ignore[reportPrivateUsage]
21+
"1i64",
22+
"2i64",
23+
"3i64",
24+
]
25+
26+
27+
def test_metadata_bridge_nullable_array():
28+
array = vortex.array([1, None, 3])
29+
30+
assert vortex_cuda._debug_array_metadata_dtype(array) == str(array.dtype) # pyright: ignore[reportPrivateUsage]
31+
assert vortex_cuda._debug_array_metadata_scalar_values(array) == [ # pyright: ignore[reportPrivateUsage]
32+
"1i64",
33+
"null",
34+
"3i64",
35+
]
36+
37+
38+
def test_metadata_bridge_bool_array():
39+
array = vortex.array([True, False, True])
40+
41+
assert vortex_cuda._debug_array_metadata_dtype(array) == str(array.dtype) # pyright: ignore[reportPrivateUsage]
42+
assert vortex_cuda._debug_array_metadata_scalar_values(array) == [ # pyright: ignore[reportPrivateUsage]
43+
"true",
44+
"false",
45+
"true",
46+
]
47+
48+
49+
def test_metadata_bridge_struct_with_children():
50+
import pyarrow as pa
51+
52+
arrow_table = pa.table({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
53+
struct_array = vortex.Array.from_arrow(
54+
pa.StructArray.from_arrays( # pyright: ignore[reportUnknownMemberType]
55+
[arrow_table.column("a").combine_chunks(), arrow_table.column("b").combine_chunks()],
56+
names=["a", "b"],
57+
)
58+
)
59+
60+
assert vortex_cuda._debug_array_metadata_dtype(struct_array) == str( # pyright: ignore[reportPrivateUsage]
61+
struct_array.dtype
62+
)
63+
assert vortex_cuda._debug_array_metadata_scalar_values(struct_array) == [ # pyright: ignore[reportPrivateUsage]
64+
"{a: 1i64, b: 4f64}",
65+
"{a: 2i64, b: 5f64}",
66+
"{a: 3i64, b: 6f64}",
67+
]
2168

2269

2370
def test_export_device_array_returns_capsules_or_clean_cuda_error():

vortex-python/python/vortex/_lib/arrays.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class Array:
2626
@staticmethod
2727
def from_range(obj: range, *, dtype: DType | None = None) -> Array: ...
2828
def to_arrow_array(self) -> pa.Array[pa.Scalar[pa.DataType]]: ...
29-
def __vortex_array_metadata__(self) -> tuple[str, bytes, int, bytes, int, list[object]]: ...
29+
def __vortex_array_metadata__(self) -> tuple[str, bytes, int, bytes, list[object], list[object]]: ...
3030
@property
3131
def id(self) -> str: ...
3232
@property

0 commit comments

Comments
 (0)