@@ -21,6 +21,7 @@ use std::sync::Arc;
2121use abi_stable:: StableAbi ;
2222use abi_stable:: std_types:: { RResult , RVec } ;
2323use datafusion_catalog:: { TableFunctionArgs , TableFunctionImpl , TableProvider } ;
24+ use datafusion_common:: DataFusionError ;
2425use datafusion_common:: error:: Result ;
2526use datafusion_execution:: TaskContext ;
2627use datafusion_proto:: logical_plan:: from_proto:: parse_exprs;
@@ -29,11 +30,13 @@ use datafusion_proto::logical_plan::{
2930 DefaultLogicalExtensionCodec , LogicalExtensionCodec ,
3031} ;
3132use datafusion_proto:: protobuf:: LogicalExprList ;
33+ use datafusion_session:: Session ;
3234use prost:: Message ;
3335use tokio:: runtime:: Handle ;
3436
3537use crate :: execution:: FFI_TaskContextProvider ;
3638use crate :: proto:: logical_extension_codec:: FFI_LogicalExtensionCodec ;
39+ use crate :: session:: { FFI_SessionRef , ForeignSession } ;
3740use crate :: table_provider:: FFI_TableProvider ;
3841use crate :: util:: FFIResult ;
3942use crate :: { df_result, rresult_return} ;
@@ -42,11 +45,18 @@ use crate::{df_result, rresult_return};
4245#[ repr( C ) ]
4346#[ derive( Debug , StableAbi ) ]
4447pub struct FFI_TableFunction {
45- /// Equivalent to the ` call` function of the TableFunctionImpl .
48+ /// Equivalent to the [`TableFunctionImpl:: call`] .
4649 /// The arguments are Expr passed as protobuf encoded bytes.
4750 pub call :
4851 unsafe extern "C" fn ( udtf : & Self , args : RVec < u8 > ) -> FFIResult < FFI_TableProvider > ,
4952
53+ /// Equivalent to the [`TableFunctionImpl::call_with_args`].
54+ call_with_args : unsafe extern "C" fn (
55+ udtf : & Self ,
56+ args : RVec < u8 > ,
57+ session : FFI_SessionRef ,
58+ ) -> FFIResult < FFI_TableProvider > ,
59+
5060 pub logical_codec : FFI_LogicalExtensionCodec ,
5161
5262 /// Used to create a clone on the provider of the udtf. This should
@@ -115,6 +125,48 @@ unsafe extern "C" fn call_fn_wrapper(
115125 ) )
116126}
117127
128+ unsafe extern "C" fn call_with_args_wrapper (
129+ udtf : & FFI_TableFunction ,
130+ args : RVec < u8 > ,
131+ session : FFI_SessionRef ,
132+ ) -> FFIResult < FFI_TableProvider > {
133+ let runtime = udtf. runtime ( ) ;
134+ let udtf_inner = udtf. inner ( ) ;
135+
136+ let ctx: Arc < TaskContext > =
137+ rresult_return ! ( ( & udtf. logical_codec. task_ctx_provider) . try_into( ) ) ;
138+ let codec: Arc < dyn LogicalExtensionCodec > = ( & udtf. logical_codec ) . into ( ) ;
139+
140+ let proto_filters = rresult_return ! ( LogicalExprList :: decode( args. as_ref( ) ) ) ;
141+
142+ let args = rresult_return ! ( parse_exprs(
143+ proto_filters. expr. iter( ) ,
144+ ctx. as_ref( ) ,
145+ codec. as_ref( )
146+ ) ) ;
147+
148+ let mut foreign_session = None ;
149+ let session = rresult_return ! (
150+ session
151+ . as_local( )
152+ . map( Ok :: <& ( dyn Session + Send + Sync ) , DataFusionError >)
153+ . unwrap_or_else( || {
154+ foreign_session = Some ( ForeignSession :: try_from( & session) ?) ;
155+ Ok ( foreign_session. as_ref( ) . unwrap( ) )
156+ } )
157+ ) ;
158+ let table_provider = rresult_return ! ( udtf_inner. call_with_args( TableFunctionArgs {
159+ args: & args,
160+ session
161+ } ) ) ;
162+ RResult :: ROk ( FFI_TableProvider :: new_with_ffi_codec (
163+ table_provider,
164+ false ,
165+ runtime,
166+ udtf. logical_codec . clone ( ) ,
167+ ) )
168+ }
169+
118170unsafe extern "C" fn release_fn_wrapper ( udtf : & mut FFI_TableFunction ) {
119171 unsafe {
120172 debug_assert ! ( !udtf. private_data. is_null( ) ) ;
@@ -170,6 +222,7 @@ impl FFI_TableFunction {
170222
171223 Self {
172224 call : call_fn_wrapper,
225+ call_with_args : call_with_args_wrapper,
173226 logical_codec,
174227 clone : clone_fn_wrapper,
175228 release : release_fn_wrapper,
@@ -209,12 +262,30 @@ impl From<FFI_TableFunction> for Arc<dyn TableFunctionImpl> {
209262
210263impl TableFunctionImpl for ForeignTableFunction {
211264 fn call_with_args ( & self , args : TableFunctionArgs ) -> Result < Arc < dyn TableProvider > > {
265+ let session =
266+ FFI_SessionRef :: new ( args. session , None , self . 0 . logical_codec . clone ( ) ) ;
212267 let codec: Arc < dyn LogicalExtensionCodec > = ( & self . 0 . logical_codec ) . into ( ) ;
213268 let expr_list = LogicalExprList {
214269 expr : serialize_exprs ( args. args , codec. as_ref ( ) ) ?,
215270 } ;
216271 let filters_serialized = expr_list. encode_to_vec ( ) . into ( ) ;
217272
273+ let table_provider =
274+ unsafe { ( self . 0 . call_with_args ) ( & self . 0 , filters_serialized, session) } ;
275+
276+ let table_provider = df_result ! ( table_provider) ?;
277+ let table_provider: Arc < dyn TableProvider > = ( & table_provider) . into ( ) ;
278+
279+ Ok ( table_provider)
280+ }
281+
282+ fn call ( & self , args : & [ datafusion_expr:: Expr ] ) -> Result < Arc < dyn TableProvider > > {
283+ let codec: Arc < dyn LogicalExtensionCodec > = ( & self . 0 . logical_codec ) . into ( ) ;
284+ let expr_list = LogicalExprList {
285+ expr : serialize_exprs ( args, codec. as_ref ( ) ) ?,
286+ } ;
287+ let filters_serialized = expr_list. encode_to_vec ( ) . into ( ) ;
288+
218289 let table_provider = unsafe { ( self . 0 . call ) ( & self . 0 , filters_serialized) } ;
219290
220291 let table_provider = df_result ! ( table_provider) ?;
@@ -340,7 +411,10 @@ mod tests {
340411
341412 let foreign_udf: Arc < dyn TableFunctionImpl > = local_udtf. into ( ) ;
342413
343- let table = foreign_udf. call ( & [ lit ( 6_u64 ) , lit ( "one" ) , lit ( 2.0 ) , lit ( 3_u64 ) ] ) ?;
414+ let table = foreign_udf. call_with_args ( TableFunctionArgs {
415+ args : & [ lit ( 6_u64 ) , lit ( "one" ) , lit ( 2.0 ) , lit ( 3_u64 ) ] ,
416+ session : & ctx. state ( ) ,
417+ } ) ?;
344418
345419 let _ = ctx. register_table ( "test-table" , table) ?;
346420
0 commit comments