@@ -41,13 +41,14 @@ use crate::record_batch::PyRecordBatchStream;
4141use crate :: sql:: exceptions:: py_value_err;
4242use crate :: sql:: logical:: PyLogicalPlan ;
4343use crate :: store:: StorageContexts ;
44+ use crate :: table:: PyTableProvider ;
4445use crate :: udaf:: PyAggregateUDF ;
4546use crate :: udf:: PyScalarUDF ;
4647use crate :: udtf:: PyTableFunction ;
4748use crate :: udwf:: PyWindowUDF ;
4849use crate :: utils:: {
49- extract_table_provider , get_global_ctx, get_tokio_runtime, validate_pycapsule , wait_for_future ,
50- TABLE_OR_PROVIDER_EXPECTED_MESSAGE ,
50+ get_global_ctx, get_tokio_runtime, table_provider_from_pycapsule , validate_pycapsule ,
51+ wait_for_future , EXPECTED_PROVIDER_MSG ,
5152} ;
5253use datafusion:: arrow:: datatypes:: { DataType , Schema , SchemaRef } ;
5354use datafusion:: arrow:: pyarrow:: PyArrowType ;
@@ -609,13 +610,17 @@ impl PySessionContext {
609610 name : & str ,
610611 table_provider : Bound < ' _ , PyAny > ,
611612 ) -> PyDataFusionResult < ( ) > {
612- let provider = extract_table_provider ( & table_provider)
613- . map_err ( crate :: errors:: PyDataFusionError :: from) ?
614- . ok_or_else ( || {
615- crate :: errors:: PyDataFusionError :: Common (
616- TABLE_OR_PROVIDER_EXPECTED_MESSAGE . to_string ( ) ,
617- )
618- } ) ?;
613+ let provider = if let Ok ( py_table) = table_provider. extract :: < PyTable > ( ) {
614+ py_table. table ( )
615+ } else if let Ok ( py_provider) = table_provider. extract :: < PyTableProvider > ( ) {
616+ py_provider. into_inner ( )
617+ } else if let Some ( provider) = table_provider_from_pycapsule ( & table_provider) ? {
618+ provider
619+ } else {
620+ return Err ( crate :: errors:: PyDataFusionError :: Common (
621+ EXPECTED_PROVIDER_MSG . to_string ( ) ,
622+ ) ) ;
623+ } ;
619624
620625 self . ctx . register_table ( name, provider) ?;
621626 Ok ( ( ) )
0 commit comments