Skip to content

Commit 5310988

Browse files
committed
Refactor ArrowArrayStream handling to use PyCapsule with destructor for improved memory management
1 parent a146330 commit 5310988

2 files changed

Lines changed: 26 additions & 15 deletions

File tree

python/datafusion/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1127,7 +1127,7 @@ def __iter__(self) -> Iterator[pa.RecordBatch]:
11271127
"""
11281128
import pyarrow as pa
11291129

1130-
reader = pa.RecordBatchReader._import_from_c(self.__arrow_c_stream__())
1130+
reader = pa.RecordBatchReader._import_from_c_capsule(self.__arrow_c_stream__())
11311131
yield from reader
11321132

11331133
def transform(self, func: Callable[..., DataFrame], *args: Any) -> DataFrame:

src/dataframe.rs

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use std::collections::HashMap;
1919
use std::ffi::{c_void, CString};
20-
use std::sync::Arc;
20+
use std::sync::{Arc, OnceLock};
2121

2222
use arrow::array::{new_null_array, RecordBatch, RecordBatchReader};
2323
use arrow::compute::can_cast_types;
@@ -39,6 +39,7 @@ use datafusion::prelude::*;
3939
use datafusion_ffi::table_provider::FFI_TableProvider;
4040
use futures::{StreamExt, TryStreamExt};
4141
use pyo3::exceptions::PyValueError;
42+
use pyo3::ffi;
4243
use pyo3::prelude::*;
4344
use pyo3::pybacked::PyBackedStr;
4445
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
@@ -58,6 +59,19 @@ use crate::{
5859
expr::{sort_expr::PySortExpr, PyExpr},
5960
};
6061

62+
/// Capsule name shared by the code that creates the stream capsule and by
/// `drop_stream`. Each user lazily initialises it to `"arrow_array_stream"`
/// through `get_or_init`, so whichever runs first performs the allocation.
static ARROW_STREAM_NAME: OnceLock<CString> = OnceLock::new();
63+
64+
/// Destructor callback registered with `PyCapsule_New` for the
/// `FFI_ArrowArrayStream` capsule.
///
/// Looks up the pointer stored in `capsule` under the name held in
/// `ARROW_STREAM_NAME` and, when that pointer is non-null, rebuilds the
/// owning `Box<FFI_ArrowArrayStream>` and drops it.
///
/// # Safety
///
/// `capsule` must be null or point to a live capsule object. The pointer the
/// capsule stores must be valid to hand to `Box::from_raw` exactly once —
/// presumably it came from `Box::into_raw` at the capsule's creation site,
/// which is not visible in this hunk; confirm against that code.
unsafe extern "C" fn drop_stream(capsule: *mut ffi::PyObject) {
65+
// Defensive guard: nothing can be looked up on a null capsule.
if capsule.is_null() {
66+
return;
67+
}
68+
let name = ARROW_STREAM_NAME.get_or_init(|| CString::new("arrow_array_stream").unwrap());
69+
// The lookup is keyed by the capsule name, so it only succeeds for a capsule
// carrying the same name that is used when the capsule is created.
// NOTE(review): per the CPython C API docs a failed `PyCapsule_GetPointer`
// returns NULL *and* sets an exception; nothing here clears or reports it.
// Confirm that is acceptable inside a destructor.
let stream_ptr = ffi::PyCapsule_GetPointer(capsule, name.as_ptr()) as *mut FFI_ArrowArrayStream;
70+
if !stream_ptr.is_null() {
71+
// Take ownership back from the raw pointer; the explicit `drop` frees it.
drop(Box::from_raw(stream_ptr));
72+
}
73+
}
74+
6175
// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
6276
// - we have not decided on the table_provider approach yet
6377
// this is an interim implementation
@@ -958,20 +972,17 @@ impl PyDataFrame {
958972
!stream_ptr.is_null(),
959973
"ArrowArrayStream pointer should never be null"
960974
);
961-
let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
962-
unsafe {
963-
PyCapsule::new_bound_with_destructor(
964-
py,
965-
stream_ptr,
966-
Some(stream_capsule_name),
967-
|ptr: *mut FFI_ArrowArrayStream, _| {
968-
if !ptr.is_null() {
969-
unsafe { Box::from_raw(ptr) };
970-
}
971-
},
972-
)
975+
let name = ARROW_STREAM_NAME.get_or_init(|| CString::new("arrow_array_stream").unwrap());
976+
let capsule = unsafe {
977+
ffi::PyCapsule_New(stream_ptr as *mut c_void, name.as_ptr(), Some(drop_stream))
978+
};
979+
if capsule.is_null() {
980+
unsafe { drop(Box::from_raw(stream_ptr)) };
981+
Err(PyErr::fetch(py).into())
982+
} else {
983+
let any = unsafe { Bound::from_owned_ptr(py, capsule) };
984+
Ok(any.downcast_into::<PyCapsule>().unwrap())
973985
}
974-
.map_err(PyDataFusionError::from)
975986
}
976987

977988
fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {

0 commit comments

Comments
 (0)