Skip to content

Commit 80bc96a

Browse files
committed
refactor: enhance stream capsule management in PyDataFrame
1 parent 07641a7 commit 80bc96a

1 file changed

Lines changed: 48 additions & 4 deletions

File tree

src/dataframe.rs

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use std::collections::HashMap;
19-
use std::ffi::CString;
19+
use std::ffi::{c_void, CString};
2020
use std::sync::Arc;
2121

2222
use arrow::array::{new_null_array, RecordBatch, RecordBatchReader};
@@ -42,6 +42,7 @@ use pyo3::exceptions::PyValueError;
4242
use pyo3::prelude::*;
4343
use pyo3::pybacked::PyBackedStr;
4444
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
45+
use pyo3::PyErr;
4546

4647
use crate::catalog::PyTable;
4748
use crate::errors::{py_datafusion_err, PyDataFusionError};
@@ -962,10 +963,53 @@ impl PyDataFrame {
962963
};
963964
let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
964965

965-
let stream = Box::new(FFI_ArrowArrayStream::new(reader));
966+
// Create a stream and transfer it to a raw pointer. The capsule takes
967+
// ownership and is responsible for freeing the stream unless PyArrow
968+
// steals it. PyArrow will set the capsule's pointer to NULL when it
969+
// imports the stream, signaling that it now owns the resources.
970+
let raw_stream = Box::into_raw(Box::new(FFI_ArrowArrayStream::new(reader)));
971+
972+
// Name used both for capsule creation and lookup in the destructor.
973+
const STREAM_NAME: &[u8] = b"arrow_array_stream\0";
974+
975+
unsafe extern "C" fn drop_stream_capsule(capsule: *mut pyo3::ffi::PyObject) {
976+
// Attempt to recover the raw stream pointer. If PyArrow imported the
977+
// stream it will have set the capsule pointer to NULL, in which case
978+
// `PyCapsule_GetPointer` returns NULL and we simply clear the error.
979+
let ptr = pyo3::ffi::PyCapsule_GetPointer(capsule, STREAM_NAME.as_ptr() as *const _)
980+
as *mut FFI_ArrowArrayStream;
981+
982+
if ptr.is_null() {
983+
// Ignore any exception raised by `PyCapsule_GetPointer` when the
984+
// pointer is already NULL.
985+
pyo3::ffi::PyErr_Clear();
986+
} else {
987+
// Reconstruct the Box and drop it so resources are released.
988+
drop(Box::from_raw(ptr));
989+
}
990+
}
991+
992+
let capsule_ptr = unsafe {
993+
pyo3::ffi::PyCapsule_New(
994+
raw_stream as *mut c_void,
995+
STREAM_NAME.as_ptr() as *const std::ffi::c_char,
996+
Some(drop_stream_capsule),
997+
)
998+
};
966999

967-
let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
968-
Ok(PyCapsule::new(py, stream, Some(stream_capsule_name))?)
1000+
if capsule_ptr.is_null() {
1001+
// Reclaim ownership to avoid leaking on failure
1002+
unsafe {
1003+
drop(Box::from_raw(raw_stream));
1004+
}
1005+
return Err(PyErr::fetch(py).into());
1006+
}
1007+
1008+
// Safety: `capsule_ptr` is a new reference from `PyCapsule_New`
1009+
let capsule = unsafe {
1010+
Bound::from_owned_ptr(py, capsule_ptr).downcast_into_unchecked::<PyCapsule>()
1011+
};
1012+
Ok(capsule)
9691013
}
9701014

9711015
fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {

0 commit comments

Comments
 (0)