1717
1818use crate :: dataset:: Dataset ;
1919use crate :: errors:: { py_datafusion_err, to_datafusion_err, PyDataFusionError , PyDataFusionResult } ;
20- use crate :: table:: pyany_to_table_provider ;
20+ use crate :: table:: PyTableProvider ;
2121use crate :: utils:: { validate_pycapsule, wait_for_future} ;
2222use async_trait:: async_trait;
2323use datafusion:: catalog:: { MemoryCatalogProvider , MemorySchemaProvider } ;
@@ -28,6 +28,7 @@ use datafusion::{
2828 datasource:: { TableProvider , TableType } ,
2929} ;
3030use datafusion_ffi:: schema_provider:: { FFI_SchemaProvider , ForeignSchemaProvider } ;
31+ use datafusion_ffi:: table_provider:: { FFI_TableProvider , ForeignTableProvider } ;
3132use pyo3:: exceptions:: PyKeyError ;
3233use pyo3:: prelude:: * ;
3334use pyo3:: types:: PyCapsule ;
@@ -196,7 +197,29 @@ impl PySchema {
196197 }
197198
198199 fn register_table ( & self , name : & str , table_provider : Bound < ' _ , PyAny > ) -> PyResult < ( ) > {
199- let provider = pyany_to_table_provider ( & table_provider) ?;
200+ let provider = if table_provider. hasattr ( "__datafusion_table_provider__" ) ? {
201+ let capsule = table_provider
202+ . getattr ( "__datafusion_table_provider__" ) ?
203+ . call0 ( ) ?;
204+ let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
205+ validate_pycapsule ( capsule, "datafusion_table_provider" ) ?;
206+
207+ let provider = unsafe { capsule. reference :: < FFI_TableProvider > ( ) } ;
208+ let provider: ForeignTableProvider = provider. into ( ) ;
209+ Arc :: new ( provider) as Arc < dyn TableProvider + Send >
210+ } else {
211+ match table_provider. extract :: < PyTable > ( ) {
212+ Ok ( py_table) => py_table. table ,
213+ Err ( _) => match table_provider. extract :: < PyTableProvider > ( ) {
214+ Ok ( py_provider) => py_provider. into_inner ( ) ,
215+ Err ( _) => {
216+ let py = table_provider. py ( ) ;
217+ let provider = Dataset :: new ( & table_provider, py) ?;
218+ Arc :: new ( provider) as Arc < dyn TableProvider + Send >
219+ }
220+ } ,
221+ }
222+ } ;
200223
201224 let _ = self
202225 . schema
@@ -285,7 +308,34 @@ impl RustWrappedPySchemaProvider {
285308 return Ok ( None ) ;
286309 }
287310
288- pyany_to_table_provider ( & py_table) . map ( Some )
311+ if py_table. hasattr ( "__datafusion_table_provider__" ) ? {
312+ let capsule = py_table. getattr ( "__datafusion_table_provider__" ) ?. call0 ( ) ?;
313+ let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
314+ validate_pycapsule ( capsule, "datafusion_table_provider" ) ?;
315+
316+ let provider = unsafe { capsule. reference :: < FFI_TableProvider > ( ) } ;
317+ let provider: ForeignTableProvider = provider. into ( ) ;
318+
319+ Ok ( Some ( Arc :: new ( provider) as Arc < dyn TableProvider + Send > ) )
320+ } else {
321+ if let Ok ( inner_table) = py_table. getattr ( "table" ) {
322+ if let Ok ( inner_table) = inner_table. extract :: < PyTable > ( ) {
323+ return Ok ( Some ( inner_table. table ) ) ;
324+ }
325+ }
326+
327+ if let Ok ( py_provider) = py_table. extract :: < PyTableProvider > ( ) {
328+ return Ok ( Some ( py_provider. into_inner ( ) ) ) ;
329+ }
330+
331+ match py_table. extract :: < PyTable > ( ) {
332+ Ok ( py_table) => Ok ( Some ( py_table. table ) ) ,
333+ Err ( _) => {
334+ let ds = Dataset :: new ( & py_table, py) . map_err ( py_datafusion_err) ?;
335+ Ok ( Some ( Arc :: new ( ds) as Arc < dyn TableProvider + Send > ) )
336+ }
337+ }
338+ }
289339 } )
290340 }
291341}
0 commit comments