Skip to content

Commit ed923e5

Browse files
committed
Provide session to the udtf call (review fixes)
This patch passes the current session to the UDTF call. This makes it possible to implement session-dependent functions, for example, a function that returns the list of registered tables.
1 parent 3c2d718 commit ed923e5

5 files changed

Lines changed: 101 additions & 19 deletions

File tree

datafusion-examples/README.md

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -207,13 +207,14 @@ cargo run --example dataframe -- dataframe
207207

208208
#### Category: Single Process
209209

210-
| Subcommand | File Path | Description |
211-
| ---------- | ------------------------------------------------------- | ----------------------------------------------- |
212-
| adv_udaf | [`udf/advanced_udaf.rs`](examples/udf/advanced_udaf.rs) | Advanced User Defined Aggregate Function (UDAF) |
213-
| adv_udf | [`udf/advanced_udf.rs`](examples/udf/advanced_udf.rs) | Advanced User Defined Scalar Function (UDF) |
214-
| adv_udwf | [`udf/advanced_udwf.rs`](examples/udf/advanced_udwf.rs) | Advanced User Defined Window Function (UDWF) |
215-
| async_udf | [`udf/async_udf.rs`](examples/udf/async_udf.rs) | Asynchronous User Defined Scalar Function |
216-
| udaf | [`udf/simple_udaf.rs`](examples/udf/simple_udaf.rs) | Simple UDAF example |
217-
| udf | [`udf/simple_udf.rs`](examples/udf/simple_udf.rs) | Simple UDF example |
218-
| udtf | [`udf/simple_udtf.rs`](examples/udf/simple_udtf.rs) | Simple UDTF example |
219-
| udwf | [`udf/simple_udwf.rs`](examples/udf/simple_udwf.rs) | Simple UDWF example |
210+
| Subcommand | File Path | Description |
211+
| --------------- | ----------------------------------------------------------- | ----------------------------------------------- |
212+
| adv_udaf | [`udf/advanced_udaf.rs`](examples/udf/advanced_udaf.rs) | Advanced User Defined Aggregate Function (UDAF) |
213+
| adv_udf | [`udf/advanced_udf.rs`](examples/udf/advanced_udf.rs) | Advanced User Defined Scalar Function (UDF) |
214+
| adv_udwf | [`udf/advanced_udwf.rs`](examples/udf/advanced_udwf.rs) | Advanced User Defined Window Function (UDWF) |
215+
| async_udf | [`udf/async_udf.rs`](examples/udf/async_udf.rs) | Asynchronous User Defined Scalar Function |
216+
| udaf | [`udf/simple_udaf.rs`](examples/udf/simple_udaf.rs) | Simple UDAF example |
217+
| udf | [`udf/simple_udf.rs`](examples/udf/simple_udf.rs) | Simple UDF example |
218+
| udtf | [`udf/simple_udtf.rs`](examples/udf/simple_udtf.rs) | Simple UDTF example |
219+
| udwf | [`udf/simple_udwf.rs`](examples/udf/simple_udwf.rs) | Simple UDWF example |
220+
| table_list_udtf | [`udf/table_list_udtf.rs`](examples/udf/table_list_udtf.rs) | Session-aware UDTF table list example |

datafusion-examples/examples/udf/main.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@
5050
//!
5151
//! - `udwf`
5252
//! (file: simple_udwf.rs, desc: Simple UDWF example)
53+
//!
54+
//! - `table_list_udtf`
55+
//! (file: table_list_udtf.rs, desc: Session-aware UDTF table list example)
5356
5457
mod advanced_udaf;
5558
mod advanced_udf;

datafusion-examples/examples/udf/table_list_udtf.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ use datafusion::{
2828
prelude::SessionContext,
2929
};
3030
use datafusion_common::{DataFusionError, plan_err};
31-
use futures::executor::block_on;
31+
use tokio::{runtime::Handle, task::block_in_place};
3232

3333
const FUNCTION_NAME: &str = "table_list";
3434

3535
// This example shows how to create a UDTF that depends on the session state.
36-
// There is `table_list` UDTF is defined which returns list of tables within session.
36+
// Defines a `table_list` UDTF that returns a list of tables within the provided session.
3737

3838
pub async fn table_list_udtf() -> Result<()> {
3939
let ctx = SessionContext::new();
@@ -96,7 +96,10 @@ impl TableFunctionImpl for TableListUdtf {
9696
continue;
9797
};
9898
for table_name in schema.table_names() {
99-
let Some(provider) = block_on(schema.table(&table_name))? else {
99+
let Some(provider) = block_in_place(|| {
100+
Handle::current().block_on(schema.table(&table_name))
101+
})?
102+
else {
100103
continue;
101104
};
102105
catalogs.push(catalog_name.clone());

datafusion/ffi/src/udtf.rs

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use std::sync::Arc;
2121
use abi_stable::StableAbi;
2222
use abi_stable::std_types::{RResult, RVec};
2323
use datafusion_catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider};
24+
use datafusion_common::DataFusionError;
2425
use datafusion_common::error::Result;
2526
use datafusion_execution::TaskContext;
2627
use datafusion_proto::logical_plan::from_proto::parse_exprs;
@@ -29,11 +30,13 @@ use datafusion_proto::logical_plan::{
2930
DefaultLogicalExtensionCodec, LogicalExtensionCodec,
3031
};
3132
use datafusion_proto::protobuf::LogicalExprList;
33+
use datafusion_session::Session;
3234
use prost::Message;
3335
use tokio::runtime::Handle;
3436

3537
use crate::execution::FFI_TaskContextProvider;
3638
use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
39+
use crate::session::{FFI_SessionRef, ForeignSession};
3740
use crate::table_provider::FFI_TableProvider;
3841
use crate::util::FFIResult;
3942
use crate::{df_result, rresult_return};
@@ -42,11 +45,18 @@ use crate::{df_result, rresult_return};
4245
#[repr(C)]
4346
#[derive(Debug, StableAbi)]
4447
pub struct FFI_TableFunction {
45-
/// Equivalent to the `call` function of the TableFunctionImpl.
48+
/// Equivalent to [`TableFunctionImpl::call`].
4649
/// The arguments are Expr passed as protobuf encoded bytes.
4750
pub call:
4851
unsafe extern "C" fn(udtf: &Self, args: RVec<u8>) -> FFIResult<FFI_TableProvider>,
4952

53+
/// Equivalent to [`TableFunctionImpl::call_with_args`].
54+
call_with_args: unsafe extern "C" fn(
55+
udtf: &Self,
56+
args: RVec<u8>,
57+
session: FFI_SessionRef,
58+
) -> FFIResult<FFI_TableProvider>,
59+
5060
pub logical_codec: FFI_LogicalExtensionCodec,
5161

5262
/// Used to create a clone on the provider of the udtf. This should
@@ -115,6 +125,48 @@ unsafe extern "C" fn call_fn_wrapper(
115125
))
116126
}
117127

128+
unsafe extern "C" fn call_with_args_wrapper(
129+
udtf: &FFI_TableFunction,
130+
args: RVec<u8>,
131+
session: FFI_SessionRef,
132+
) -> FFIResult<FFI_TableProvider> {
133+
let runtime = udtf.runtime();
134+
let udtf_inner = udtf.inner();
135+
136+
let ctx: Arc<TaskContext> =
137+
rresult_return!((&udtf.logical_codec.task_ctx_provider).try_into());
138+
let codec: Arc<dyn LogicalExtensionCodec> = (&udtf.logical_codec).into();
139+
140+
let proto_filters = rresult_return!(LogicalExprList::decode(args.as_ref()));
141+
142+
let args = rresult_return!(parse_exprs(
143+
proto_filters.expr.iter(),
144+
ctx.as_ref(),
145+
codec.as_ref()
146+
));
147+
148+
let mut foreign_session = None;
149+
let session = rresult_return!(
150+
session
151+
.as_local()
152+
.map(Ok::<&(dyn Session + Send + Sync), DataFusionError>)
153+
.unwrap_or_else(|| {
154+
foreign_session = Some(ForeignSession::try_from(&session)?);
155+
Ok(foreign_session.as_ref().unwrap())
156+
})
157+
);
158+
let table_provider = rresult_return!(udtf_inner.call_with_args(TableFunctionArgs {
159+
args: &args,
160+
session
161+
}));
162+
RResult::ROk(FFI_TableProvider::new_with_ffi_codec(
163+
table_provider,
164+
false,
165+
runtime,
166+
udtf.logical_codec.clone(),
167+
))
168+
}
169+
118170
unsafe extern "C" fn release_fn_wrapper(udtf: &mut FFI_TableFunction) {
119171
unsafe {
120172
debug_assert!(!udtf.private_data.is_null());
@@ -170,6 +222,7 @@ impl FFI_TableFunction {
170222

171223
Self {
172224
call: call_fn_wrapper,
225+
call_with_args: call_with_args_wrapper,
173226
logical_codec,
174227
clone: clone_fn_wrapper,
175228
release: release_fn_wrapper,
@@ -209,12 +262,30 @@ impl From<FFI_TableFunction> for Arc<dyn TableFunctionImpl> {
209262

210263
impl TableFunctionImpl for ForeignTableFunction {
211264
fn call_with_args(&self, args: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> {
265+
let session =
266+
FFI_SessionRef::new(args.session, None, self.0.logical_codec.clone());
212267
let codec: Arc<dyn LogicalExtensionCodec> = (&self.0.logical_codec).into();
213268
let expr_list = LogicalExprList {
214269
expr: serialize_exprs(args.args, codec.as_ref())?,
215270
};
216271
let filters_serialized = expr_list.encode_to_vec().into();
217272

273+
let table_provider =
274+
unsafe { (self.0.call_with_args)(&self.0, filters_serialized, session) };
275+
276+
let table_provider = df_result!(table_provider)?;
277+
let table_provider: Arc<dyn TableProvider> = (&table_provider).into();
278+
279+
Ok(table_provider)
280+
}
281+
282+
fn call(&self, args: &[datafusion_expr::Expr]) -> Result<Arc<dyn TableProvider>> {
283+
let codec: Arc<dyn LogicalExtensionCodec> = (&self.0.logical_codec).into();
284+
let expr_list = LogicalExprList {
285+
expr: serialize_exprs(args, codec.as_ref())?,
286+
};
287+
let filters_serialized = expr_list.encode_to_vec().into();
288+
218289
let table_provider = unsafe { (self.0.call)(&self.0, filters_serialized) };
219290

220291
let table_provider = df_result!(table_provider)?;
@@ -340,7 +411,10 @@ mod tests {
340411

341412
let foreign_udf: Arc<dyn TableFunctionImpl> = local_udtf.into();
342413

343-
let table = foreign_udf.call(&[lit(6_u64), lit("one"), lit(2.0), lit(3_u64)])?;
414+
let table = foreign_udf.call_with_args(TableFunctionArgs {
415+
args: &[lit(6_u64), lit("one"), lit(2.0), lit(3_u64)],
416+
session: &ctx.state(),
417+
})?;
344418

345419
let _ = ctx.register_table("test-table", table)?;
346420

docs/source/library-user-guide/functions/adding-udfs.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,15 +1388,16 @@ in the CLI to read the metadata from a Parquet file.
13881388

13891389
The simple UDTF used here takes a single `Int64` argument and returns a table with a single column with the value of the
13901390
argument. To create a function in DataFusion, you need to implement the `TableFunctionImpl` trait. This trait has a
1391-
single method, `call`, that takes a slice of `Expr`s and returns a `Result<Arc<dyn TableProvider>>`.
1391+
single method, `call_with_args`, that takes a `TableFunctionArgs` struct and returns a `Result<Arc<dyn TableProvider>>`.
1392+
The passed struct includes the function arguments as a slice of `Expr`s.
13921393

1393-
In the `call` method, you parse the input `Expr`s and return a `TableProvider`. You might also want to do some
1394+
In the `call_with_args` method, you parse the input `Expr`s and return a `TableProvider`. You might also want to do some
13941395
validation of the input `Expr`s, e.g. checking that the number of arguments is correct.
13951396

13961397
```rust
13971398
use std::sync::Arc;
13981399
use datafusion::common::{plan_err, ScalarValue, Result};
1399-
use datafusion::catalog::{TableFunctionImpl, TableProvider};
1400+
use datafusion::catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider};
14001401
use datafusion::arrow::array::{ArrayRef, Int64Array};
14011402
use datafusion::datasource::memory::MemTable;
14021403
use arrow::record_batch::RecordBatch;
@@ -1438,7 +1439,7 @@ With the UDTF implemented, you can register it with the `SessionContext`:
14381439
```rust
14391440
# use std::sync::Arc;
14401441
# use datafusion::common::{plan_err, ScalarValue, Result};
1441-
# use datafusion::catalog::{TableFunctionImpl, TableProvider};
1442+
# use datafusion::catalog::{TableFunctionArgs, TableFunctionImpl, TableProvider};
14421443
# use datafusion::arrow::array::{ArrayRef, Int64Array};
14431444
# use datafusion::datasource::memory::MemTable;
14441445
# use arrow::record_batch::RecordBatch;

0 commit comments

Comments
 (0)