Skip to content

Commit b76cde5

Browse files
committed
test: add garbage collection test for arrow C stream capsule
1 parent 61d0739 commit b76cde5

2 files changed

Lines changed: 43 additions & 4 deletions

File tree

python/tests/test_dataframe.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1833,6 +1833,20 @@ def test_arrow_c_stream_capsule_manual_destructor_noop(ctx):
18331833
gc.collect()
18341834

18351835

1836+
def test_arrow_c_stream_capsule_gc_without_consuming(ctx):
1837+
df = ctx.from_pydict({"a": [1]})
1838+
1839+
capsule = df.__arrow_c_stream__()
1840+
reader = pa.RecordBatchReader._import_from_c_capsule(capsule)
1841+
1842+
del capsule
1843+
gc.collect()
1844+
1845+
table = reader.read_all()
1846+
expected = pa.table({"a": [1]})
1847+
assert table.equals(expected)
1848+
1849+
18361850
def test_arrow_c_stream_context_drop_no_segfault():
18371851
"""Repeatedly create/drop SessionContext after __arrow_c_stream__."""
18381852
for _ in range(5):

src/dataframe.rs

Lines changed: 29 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -40,9 +40,11 @@ use datafusion::prelude::*;
4040
use datafusion_ffi::table_provider::FFI_TableProvider;
4141
use futures::{StreamExt, TryStreamExt};
4242
use pyo3::exceptions::PyValueError;
43+
use pyo3::ffi;
4344
use pyo3::prelude::*;
4445
use pyo3::pybacked::PyBackedStr;
4546
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
47+
use std::os::raw::c_void;
4648

4749
use crate::catalog::PyTable;
4850
use crate::errors::{py_datafusion_err, PyDataFusionError};
@@ -376,7 +378,7 @@ struct PartitionedDataFrameStreamReader {
376378
/// interface.
377379
struct StreamWithContext {
378380
reader: PartitionedDataFrameStreamReader,
379-
ctx: Arc<SessionContext>,
381+
_ctx: Arc<SessionContext>,
380382
}
381383

382384
impl Iterator for PartitionedDataFrameStreamReader {
@@ -452,6 +454,18 @@ impl Drop for StreamWithContext {
452454
}
453455
}
454456

457+
unsafe extern "C" fn stream_capsule_destructor(capsule: *mut ffi::PyObject) {
458+
let name = pyo3::ffi::c_str!("arrow_array_stream");
459+
unsafe {
460+
let ptr = ffi::PyCapsule_GetPointer(capsule, name.as_ptr());
461+
if ptr.is_null() {
462+
ffi::PyCapsule_SetDestructor(capsule, None);
463+
return;
464+
}
465+
drop(Box::from_raw(ptr.cast::<FFI_ArrowArrayStream>()));
466+
}
467+
}
468+
455469
#[pymethods]
456470
impl PyDataFrame {
457471
/// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]`
@@ -1022,14 +1036,25 @@ impl PyDataFrame {
10221036
projection,
10231037
current: 0,
10241038
};
1025-
let reader = StreamWithContext { reader, ctx };
1039+
let reader = StreamWithContext { reader, _ctx: ctx };
10261040
let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
10271041

10281042
let stream = FFI_ArrowArrayStream::new(reader);
10291043
let name = pyo3::ffi::c_str!("arrow_array_stream");
10301044

1031-
let capsule =
1032-
PyCapsule::new(py, stream, Some(name.to_owned())).map_err(py_datafusion_err)?;
1045+
let capsule_ptr = unsafe {
1046+
ffi::PyCapsule_New(
1047+
Box::into_raw(Box::new(stream)) as *mut c_void,
1048+
name.as_ptr(),
1049+
Some(stream_capsule_destructor),
1050+
)
1051+
};
1052+
if capsule_ptr.is_null() {
1053+
return Err(PyErr::fetch(py).into());
1054+
}
1055+
let capsule = unsafe {
1056+
Bound::from_owned_ptr(py, capsule_ptr).downcast_into_unchecked::<PyCapsule>()
1057+
};
10331058
Ok(capsule)
10341059
}
10351060

0 commit comments

Comments (0)