Skip to content

Commit 5ff80e4

Browse files
authored
Provide session to the udtf call (#20222)
## Rationale for this change In our project, we have several UDTFs that depend on the session from which they are called -- for example, functions to list views, tables and functions that wrap scan of another table (like composition). To implement them, we need a way to provide the current session state to the UDTF. It would be nice to add the session as an argument to the UDTF call. ## What changes are included in this PR? 1. Introduce `TableFunctionImpl::call_with_args`: it takes struct `TableFunctionArgs`, doing arguments extendable without breaking backward compatibility. 2. Deprecate `TableFunctionImpl::call`. 3. Switch to `call_with_args` for each UDTF in DF. 4. Add an example of the UDTF `table_list()` that depends on the session state.
1 parent f830ee3 commit 5ff80e4

File tree

11 files changed

+330
-42
lines changed

11 files changed

+330
-42
lines changed

datafusion-cli/src/functions.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer};
3131
use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef, TimeUnit};
3232
use arrow::record_batch::RecordBatch;
3333
use arrow::util::pretty::pretty_format_batches;
34-
use datafusion::catalog::{Session, TableFunctionImpl};
34+
use datafusion::catalog::{Session, TableFunctionArgs, TableFunctionImpl};
3535
use datafusion::common::{Column, plan_err};
3636
use datafusion::datasource::TableProvider;
3737
use datafusion::datasource::memory::MemorySourceConfig;
@@ -326,7 +326,8 @@ fn fixed_len_byte_array_to_string(val: Option<&FixedLenByteArray>) -> Option<Str
326326
pub struct ParquetMetadataFunc {}
327327

328328
impl TableFunctionImpl for ParquetMetadataFunc {
329-
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
329+
fn call_with_args(&self, args: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> {
330+
let exprs = args.exprs();
330331
let filename = match exprs.first() {
331332
Some(Expr::Literal(ScalarValue::Utf8(Some(s)), _)) => s, // single quote: parquet_metadata('x.parquet')
332333
Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet")
@@ -517,7 +518,8 @@ impl MetadataCacheFunc {
517518
}
518519

519520
impl TableFunctionImpl for MetadataCacheFunc {
520-
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
521+
fn call_with_args(&self, args: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> {
522+
let exprs = args.exprs();
521523
if !exprs.is_empty() {
522524
return plan_err!("metadata_cache should have no arguments");
523525
}
@@ -635,7 +637,8 @@ impl StatisticsCacheFunc {
635637
}
636638

637639
impl TableFunctionImpl for StatisticsCacheFunc {
638-
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
640+
fn call_with_args(&self, args: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> {
641+
let exprs = args.exprs();
639642
if !exprs.is_empty() {
640643
return plan_err!("statistics_cache should have no arguments");
641644
}
@@ -770,7 +773,8 @@ impl ListFilesCacheFunc {
770773
}
771774

772775
impl TableFunctionImpl for ListFilesCacheFunc {
773-
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
776+
fn call_with_args(&self, args: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> {
777+
let exprs = args.exprs();
774778
if !exprs.is_empty() {
775779
return plan_err!("list_files_cache should have no arguments");
776780
}

datafusion-examples/README.md

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -218,13 +218,14 @@ cargo run --example dataframe -- dataframe
218218

219219
#### Category: Single Process
220220

221-
| Subcommand | File Path | Description |
222-
| ---------- | ------------------------------------------------------- | ----------------------------------------------- |
223-
| adv_udaf | [`udf/advanced_udaf.rs`](examples/udf/advanced_udaf.rs) | Advanced User Defined Aggregate Function (UDAF) |
224-
| adv_udf | [`udf/advanced_udf.rs`](examples/udf/advanced_udf.rs) | Advanced User Defined Scalar Function (UDF) |
225-
| adv_udwf | [`udf/advanced_udwf.rs`](examples/udf/advanced_udwf.rs) | Advanced User Defined Window Function (UDWF) |
226-
| async_udf | [`udf/async_udf.rs`](examples/udf/async_udf.rs) | Asynchronous User Defined Scalar Function |
227-
| udaf | [`udf/simple_udaf.rs`](examples/udf/simple_udaf.rs) | Simple UDAF example |
228-
| udf | [`udf/simple_udf.rs`](examples/udf/simple_udf.rs) | Simple UDF example |
229-
| udtf | [`udf/simple_udtf.rs`](examples/udf/simple_udtf.rs) | Simple UDTF example |
230-
| udwf | [`udf/simple_udwf.rs`](examples/udf/simple_udwf.rs) | Simple UDWF example |
221+
| Subcommand | File Path | Description |
222+
| --------------- | ----------------------------------------------------------- | ----------------------------------------------- |
223+
| adv_udaf | [`udf/advanced_udaf.rs`](examples/udf/advanced_udaf.rs) | Advanced User Defined Aggregate Function (UDAF) |
224+
| adv_udf | [`udf/advanced_udf.rs`](examples/udf/advanced_udf.rs) | Advanced User Defined Scalar Function (UDF) |
225+
| adv_udwf | [`udf/advanced_udwf.rs`](examples/udf/advanced_udwf.rs) | Advanced User Defined Window Function (UDWF) |
226+
| async_udf | [`udf/async_udf.rs`](examples/udf/async_udf.rs) | Asynchronous User Defined Scalar Function |
227+
| udaf | [`udf/simple_udaf.rs`](examples/udf/simple_udaf.rs) | Simple UDAF example |
228+
| udf | [`udf/simple_udf.rs`](examples/udf/simple_udf.rs) | Simple UDF example |
229+
| udtf | [`udf/simple_udtf.rs`](examples/udf/simple_udtf.rs) | Simple UDTF example |
230+
| udwf | [`udf/simple_udwf.rs`](examples/udf/simple_udwf.rs) | Simple UDWF example |
231+
| table_list_udtf | [`udf/table_list_udtf.rs`](examples/udf/table_list_udtf.rs) | Session-aware UDTF table list example |

datafusion-examples/examples/udf/main.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
//!
2222
//! ## Usage
2323
//! ```bash
24-
//! cargo run --example udf -- [all|adv_udaf|adv_udf|adv_udwf|async_udf|udaf|udf|udtf|udwf]
24+
//! cargo run --example udf -- [all|adv_udaf|adv_udf|adv_udwf|async_udf|udaf|udf|udtf|udwf|table_list_udtf]
2525
//! ```
2626
//!
2727
//! Each subcommand runs a corresponding example:
@@ -50,6 +50,9 @@
5050
//!
5151
//! - `udwf`
5252
//! (file: simple_udwf.rs, desc: Simple UDWF example)
53+
//!
54+
//! - `table_list_udtf`
55+
//! (file: table_list_udtf.rs, desc: Session-aware UDTF table list example)
5356
5457
mod advanced_udaf;
5558
mod advanced_udf;
@@ -59,6 +62,7 @@ mod simple_udaf;
5962
mod simple_udf;
6063
mod simple_udtf;
6164
mod simple_udwf;
65+
mod table_list_udtf;
6266

6367
use datafusion::error::{DataFusionError, Result};
6468
use strum::{IntoEnumIterator, VariantNames};
@@ -76,6 +80,7 @@ enum ExampleKind {
7680
Udaf,
7781
Udwf,
7882
Udtf,
83+
TableListUdtf,
7984
}
8085

8186
impl ExampleKind {
@@ -101,6 +106,7 @@ impl ExampleKind {
101106
ExampleKind::Udf => simple_udf::simple_udf().await?,
102107
ExampleKind::Udtf => simple_udtf::simple_udtf().await?,
103108
ExampleKind::Udwf => simple_udwf::simple_udwf().await?,
109+
ExampleKind::TableListUdtf => table_list_udtf::table_list_udtf().await?,
104110
}
105111

106112
Ok(())

datafusion-examples/examples/udf/simple_udtf.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use arrow::csv::reader::Format;
2727
use async_trait::async_trait;
2828
use datafusion::arrow::datatypes::SchemaRef;
2929
use datafusion::arrow::record_batch::RecordBatch;
30-
use datafusion::catalog::{Session, TableFunctionImpl};
30+
use datafusion::catalog::{Session, TableFunctionArgs, TableFunctionImpl};
3131
use datafusion::common::{ScalarValue, plan_err};
3232
use datafusion::datasource::TableProvider;
3333
use datafusion::datasource::memory::MemorySourceConfig;
@@ -135,7 +135,8 @@ impl TableProvider for LocalCsvTable {
135135
struct LocalCsvTableFunc {}
136136

137137
impl TableFunctionImpl for LocalCsvTableFunc {
138-
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
138+
fn call_with_args(&self, args: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> {
139+
let exprs = args.exprs();
139140
let Some(Expr::Literal(ScalarValue::Utf8(Some(path)), _)) = exprs.first() else {
140141
return plan_err!("read_csv requires at least one string argument");
141142
};
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! See `main.rs` for how to run it.
19+
20+
use std::sync::{Arc, LazyLock};
21+
22+
use arrow::array::{RecordBatch, StringBuilder};
23+
use arrow_schema::{DataType, Field, Schema, SchemaRef};
24+
use datafusion::{
25+
catalog::{MemTable, TableFunctionArgs, TableFunctionImpl, TableProvider},
26+
common::Result,
27+
execution::SessionState,
28+
prelude::SessionContext,
29+
};
30+
use datafusion_common::{DataFusionError, plan_err};
31+
use tokio::{runtime::Handle, task::block_in_place};
32+
33+
const FUNCTION_NAME: &str = "table_list";
34+
35+
// The example shows, how to create UDTF that depends on the session state.
36+
// Defines a `table_list` UDTF that returns a list of tables within the provided session.
37+
38+
pub async fn table_list_udtf() -> Result<()> {
39+
let ctx = SessionContext::new();
40+
ctx.register_udtf(FUNCTION_NAME, Arc::new(TableListUdtf));
41+
42+
// Register different kinds of tables.
43+
ctx.sql("create view v as select 1")
44+
.await?
45+
.collect()
46+
.await?;
47+
ctx.sql("create table t(a int)").await?.collect().await?;
48+
49+
// Print results.
50+
ctx.sql("select * from table_list()").await?.show().await?;
51+
52+
Ok(())
53+
}
54+
55+
#[derive(Debug, Default)]
56+
struct TableListUdtf;
57+
58+
static SCHEMA: LazyLock<SchemaRef> = LazyLock::new(|| {
59+
SchemaRef::new(Schema::new(vec![
60+
Field::new("catalog", DataType::Utf8, false),
61+
Field::new("schema", DataType::Utf8, false),
62+
Field::new("table", DataType::Utf8, false),
63+
Field::new("type", DataType::Utf8, false),
64+
]))
65+
});
66+
67+
impl TableFunctionImpl for TableListUdtf {
68+
fn call_with_args(&self, args: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> {
69+
if !args.exprs().is_empty() {
70+
return plan_err!(
71+
"{}: unexpected number of arguments: {}, expected: 0",
72+
FUNCTION_NAME,
73+
args.exprs().len()
74+
);
75+
}
76+
let state = args
77+
.session()
78+
.as_any()
79+
.downcast_ref::<SessionState>()
80+
.ok_or_else(|| {
81+
DataFusionError::Internal("failed to downcast state".into())
82+
})?;
83+
84+
let mut catalogs = StringBuilder::new();
85+
let mut schemas = StringBuilder::new();
86+
let mut tables = StringBuilder::new();
87+
let mut types = StringBuilder::new();
88+
89+
let catalog_list = state.catalog_list();
90+
for catalog_name in catalog_list.catalog_names() {
91+
let Some(catalog) = catalog_list.catalog(&catalog_name) else {
92+
continue;
93+
};
94+
for schema_name in catalog.schema_names() {
95+
let Some(schema) = catalog.schema(&schema_name) else {
96+
continue;
97+
};
98+
for table_name in schema.table_names() {
99+
let Some(provider) = block_in_place(|| {
100+
Handle::current().block_on(schema.table(&table_name))
101+
})?
102+
else {
103+
continue;
104+
};
105+
catalogs.append_value(catalog_name.clone());
106+
schemas.append_value(schema_name.clone());
107+
tables.append_value(table_name.clone());
108+
types.append_value(provider.table_type().to_string())
109+
}
110+
}
111+
}
112+
113+
let batch = RecordBatch::try_new(
114+
Arc::clone(&SCHEMA),
115+
vec![
116+
Arc::new(catalogs.finish()),
117+
Arc::new(schemas.finish()),
118+
Arc::new(tables.finish()),
119+
Arc::new(types.finish()),
120+
],
121+
)?;
122+
123+
Ok(Arc::new(MemTable::try_new(
124+
batch.schema(),
125+
vec![vec![batch]],
126+
)?))
127+
}
128+
}

datafusion/catalog/src/table.rs

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ use std::sync::Arc;
2323
use crate::session::Session;
2424
use arrow::datatypes::SchemaRef;
2525
use async_trait::async_trait;
26-
use datafusion_common::Result;
2726
use datafusion_common::{Constraints, Statistics, not_impl_err};
27+
use datafusion_common::{Result, internal_err};
2828
use datafusion_expr::Expr;
2929

3030
use datafusion_expr::dml::InsertOp;
@@ -507,10 +507,49 @@ pub trait TableProviderFactory: Debug + Sync + Send {
507507
) -> Result<Arc<dyn TableProvider>>;
508508
}
509509

510+
/// Describes arguments provided to the table function call.
511+
pub struct TableFunctionArgs<'e, 's> {
512+
/// Call arguments.
513+
exprs: &'e [Expr],
514+
/// Session within which the function is called.
515+
session: &'s dyn Session,
516+
}
517+
518+
impl<'e, 's> TableFunctionArgs<'e, 's> {
519+
/// Make a new [`TableFunctionArgs`].
520+
pub fn new(exprs: &'e [Expr], session: &'s dyn Session) -> Self {
521+
Self { exprs, session }
522+
}
523+
524+
/// Get expressions passed as the called function arguments.
525+
pub fn exprs(&self) -> &'e [Expr] {
526+
self.exprs
527+
}
528+
529+
/// Get a session where the table function is called.
530+
pub fn session(&self) -> &'s dyn Session {
531+
self.session
532+
}
533+
}
534+
510535
/// A trait for table function implementations
511536
pub trait TableFunctionImpl: Debug + Sync + Send + Any {
512537
/// Create a table provider
513-
fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>>;
538+
#[deprecated(
539+
since = "53.0.0",
540+
note = "Implement `TableFunctionImpl::call_with_args` instead"
541+
)]
542+
fn call(&self, _exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
543+
internal_err!(
544+
"TableFunctionImpl::call is not implemented. Implement TableFunctionImpl::call_with_args instead."
545+
)
546+
}
547+
548+
/// Create a table provider
549+
fn call_with_args(&self, args: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> {
550+
#[expect(deprecated)]
551+
self.call(args.exprs)
552+
}
514553
}
515554

516555
/// A table that uses a function to generate data
@@ -539,7 +578,20 @@ impl TableFunction {
539578
}
540579

541580
/// Get the function implementation and generate a table
581+
#[deprecated(
582+
since = "53.0.0",
583+
note = "Use `TableFunction::create_table_provider_with_args` instead"
584+
)]
542585
pub fn create_table_provider(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
586+
#[expect(deprecated)]
543587
self.fun.call(args)
544588
}
589+
590+
/// Get the function implementation and generate a table
591+
pub fn create_table_provider_with_args(
592+
&self,
593+
args: TableFunctionArgs,
594+
) -> Result<Arc<dyn TableProvider>> {
595+
self.fun.call_with_args(args)
596+
}
545597
}

datafusion/core/src/execution/session_state.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1862,6 +1862,8 @@ impl ContextProvider for SessionContextProvider<'_> {
18621862
name: &str,
18631863
args: Vec<Expr>,
18641864
) -> datafusion_common::Result<Arc<dyn TableSource>> {
1865+
use datafusion_catalog::TableFunctionArgs;
1866+
18651867
let tbl_func = self
18661868
.state
18671869
.table_functions
@@ -1884,7 +1886,8 @@ impl ContextProvider for SessionContextProvider<'_> {
18841886
.and_then(|e| simplifier.simplify(e))
18851887
})
18861888
.collect::<datafusion_common::Result<Vec<_>>>()?;
1887-
let provider = tbl_func.create_table_provider(&args)?;
1889+
let provider = tbl_func
1890+
.create_table_provider_with_args(TableFunctionArgs::new(&args, self.state))?;
18881891

18891892
Ok(provider_as_source(provider))
18901893
}

datafusion/core/tests/user_defined/user_defined_table_functions.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ use datafusion::error::Result;
3333
use datafusion::execution::TaskContext;
3434
use datafusion::physical_plan::{ExecutionPlan, collect};
3535
use datafusion::prelude::SessionContext;
36-
use datafusion_catalog::Session;
3736
use datafusion_catalog::TableFunctionImpl;
37+
use datafusion_catalog::{Session, TableFunctionArgs};
3838
use datafusion_common::{DFSchema, ScalarValue};
3939
use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableType};
4040

@@ -200,7 +200,8 @@ impl SimpleCsvTable {
200200
struct SimpleCsvTableFunc {}
201201

202202
impl TableFunctionImpl for SimpleCsvTableFunc {
203-
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
203+
fn call_with_args(&self, args: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> {
204+
let exprs = args.exprs();
204205
let mut new_exprs = vec![];
205206
let mut filepath = String::new();
206207
for expr in exprs {
@@ -231,7 +232,7 @@ async fn test_udtf_type_coercion() -> Result<()> {
231232
struct NoOpTableFunc;
232233

233234
impl TableFunctionImpl for NoOpTableFunc {
234-
fn call(&self, _: &[Expr]) -> Result<Arc<dyn TableProvider>> {
235+
fn call_with_args(&self, _: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> {
235236
let schema = Arc::new(arrow::datatypes::Schema::empty());
236237
Ok(Arc::new(MemTable::try_new(schema, vec![vec![]])?))
237238
}

0 commit comments

Comments
 (0)