1616// under the License.
1717
1818use std:: sync:: Arc ;
19+ use std:: { ffi:: c_void, ffi:: CStr } ;
1920
2021use crate :: errors:: PyDataFusionError ;
2122use crate :: utils:: wait_for_future;
23+ use arrow:: ffi:: { FFI_ArrowArray , FFI_ArrowSchema } ;
24+ use datafusion:: arrow:: array:: { Array , StructArray } ;
2225use datafusion:: arrow:: pyarrow:: ToPyArrow ;
2326use datafusion:: arrow:: record_batch:: RecordBatch ;
2427use datafusion:: physical_plan:: SendableRecordBatchStream ;
2528use futures:: StreamExt ;
2629use pyo3:: exceptions:: { PyStopAsyncIteration , PyStopIteration } ;
30+ use pyo3:: ffi;
2731use pyo3:: prelude:: * ;
32+ use pyo3:: types:: PyCapsule ;
2833use pyo3:: { pyclass, pymethods, PyObject , PyResult , Python } ;
2934use tokio:: sync:: Mutex ;
3035
36+ #[ allow( clippy:: manual_c_str_literals) ]
37+ static ARROW_ARRAY_NAME : & CStr = unsafe { CStr :: from_bytes_with_nul_unchecked ( b"arrow_array\0 " ) } ;
38+ #[ allow( clippy:: manual_c_str_literals) ]
39+ static ARROW_SCHEMA_NAME : & CStr = unsafe { CStr :: from_bytes_with_nul_unchecked ( b"arrow_schema\0 " ) } ;
40+
41+ unsafe extern "C" fn drop_array ( capsule : * mut ffi:: PyObject ) {
42+ if capsule. is_null ( ) {
43+ return ;
44+ }
45+
46+ if ffi:: PyCapsule_IsValid ( capsule, ARROW_ARRAY_NAME . as_ptr ( ) ) == 1 {
47+ let array_ptr =
48+ ffi:: PyCapsule_GetPointer ( capsule, ARROW_ARRAY_NAME . as_ptr ( ) ) as * mut FFI_ArrowArray ;
49+ if !array_ptr. is_null ( ) {
50+ drop ( Box :: from_raw ( array_ptr) ) ;
51+ }
52+ }
53+ ffi:: PyErr_Clear ( ) ;
54+ }
55+
56+ unsafe extern "C" fn drop_schema ( capsule : * mut ffi:: PyObject ) {
57+ if capsule. is_null ( ) {
58+ return ;
59+ }
60+
61+ if ffi:: PyCapsule_IsValid ( capsule, ARROW_SCHEMA_NAME . as_ptr ( ) ) == 1 {
62+ let schema_ptr =
63+ ffi:: PyCapsule_GetPointer ( capsule, ARROW_SCHEMA_NAME . as_ptr ( ) ) as * mut FFI_ArrowSchema ;
64+ if !schema_ptr. is_null ( ) {
65+ drop ( Box :: from_raw ( schema_ptr) ) ;
66+ }
67+ }
68+ ffi:: PyErr_Clear ( ) ;
69+ }
70+
3171#[ pyclass( name = "RecordBatch" , module = "datafusion" , subclass) ]
3272pub struct PyRecordBatch {
3373 batch : RecordBatch ,
@@ -38,6 +78,69 @@ impl PyRecordBatch {
3878 fn to_pyarrow ( & self , py : Python ) -> PyResult < PyObject > {
3979 self . batch . to_pyarrow ( py)
4080 }
81+
82+ #[ pyo3( signature = ( requested_schema=None ) ) ]
83+ fn __arrow_c_array__ < ' py > (
84+ & self ,
85+ py : Python < ' py > ,
86+ requested_schema : Option < Bound < ' py , PyCapsule > > ,
87+ ) -> PyResult < ( Bound < ' py , PyCapsule > , Bound < ' py , PyCapsule > ) > {
88+ // For now ignore requested_schema; future work could apply projection
89+ if let Some ( schema_capsule) = requested_schema {
90+ crate :: utils:: validate_pycapsule ( & schema_capsule, "arrow_schema" ) ?;
91+ }
92+
93+ let struct_array = StructArray :: from ( self . batch . clone ( ) ) ;
94+ let data = struct_array. to_data ( ) ;
95+ let array = FFI_ArrowArray :: new ( & data) ;
96+ let schema =
97+ FFI_ArrowSchema :: try_from ( data. data_type ( ) ) . map_err ( PyDataFusionError :: from) ?;
98+
99+ let array_ptr = Box :: into_raw ( Box :: new ( array) ) ;
100+ let schema_ptr = Box :: into_raw ( Box :: new ( schema) ) ;
101+
102+ unsafe {
103+ let schema_capsule = ffi:: PyCapsule_New (
104+ schema_ptr as * mut c_void ,
105+ ARROW_SCHEMA_NAME . as_ptr ( ) ,
106+ Some ( drop_schema) ,
107+ ) ;
108+ if schema_capsule. is_null ( ) {
109+ drop ( Box :: from_raw ( schema_ptr) ) ;
110+ drop ( Box :: from_raw ( array_ptr) ) ;
111+ return Err ( PyErr :: fetch ( py) ) ;
112+ }
113+
114+ let array_capsule = ffi:: PyCapsule_New (
115+ array_ptr as * mut c_void ,
116+ ARROW_ARRAY_NAME . as_ptr ( ) ,
117+ Some ( drop_array) ,
118+ ) ;
119+ if array_capsule. is_null ( ) {
120+ drop ( Box :: from_raw ( array_ptr) ) ;
121+ if ffi:: PyCapsule_IsValid ( schema_capsule, ARROW_SCHEMA_NAME . as_ptr ( ) ) == 1 {
122+ let schema_ptr =
123+ ffi:: PyCapsule_GetPointer ( schema_capsule, ARROW_SCHEMA_NAME . as_ptr ( ) )
124+ as * mut FFI_ArrowSchema ;
125+ if !schema_ptr. is_null ( ) {
126+ drop ( Box :: from_raw ( schema_ptr) ) ;
127+ }
128+ }
129+ ffi:: PyErr_Clear ( ) ;
130+ ffi:: Py_DECREF ( schema_capsule) ;
131+ return Err ( PyErr :: fetch ( py) ) ;
132+ }
133+
134+ let schema_capsule = Bound :: from_owned_ptr ( py, schema_capsule)
135+ . downcast_into :: < PyCapsule > ( )
136+ . unwrap ( ) ;
137+ let array_capsule = Bound :: from_owned_ptr ( py, array_capsule)
138+ . downcast_into :: < PyCapsule > ( )
139+ . unwrap ( ) ;
140+
141+ Ok ( ( schema_capsule, array_capsule) )
142+ }
143+ }
41144}
42145
43146impl From < RecordBatch > for PyRecordBatch {
0 commit comments