Skip to content

Commit 75c7da5

Browse files
authored
Pass ConfigOptions to scalar UDFs via FFI (#20454)
## Which issue does this PR close?

- Closes #17035

## Rationale for this change

Now that we have a proper `FFI_ConfigOptions`, we can pass the config options to scalar UDFs via FFI.

## What changes are included in this PR?

Instead of passing default options, pass the config options converted from the input. Also includes a drive-by cleanup: switch to `FFI_ColumnarValue`, since it is now available.

## Are these changes tested?

Yes, a unit test was added.

## Are there any user-facing changes?

This is a breaking API change, but not one that users will interact with directly. It breaks the ABI for FFI libraries, which is currently unstable.
1 parent fd97799 commit 75c7da5

File tree

4 files changed

+110
-24
lines changed

4 files changed

+110
-24
lines changed

datafusion/ffi/src/tests/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ pub struct ForeignLibraryModule {
8484

8585
pub create_nullary_udf: extern "C" fn() -> FFI_ScalarUDF,
8686

87+
pub create_timezone_udf: extern "C" fn() -> FFI_ScalarUDF,
88+
8789
pub create_table_function:
8890
extern "C" fn(FFI_LogicalExtensionCodec) -> FFI_TableFunction,
8991

@@ -157,6 +159,7 @@ pub fn get_foreign_library_module() -> ForeignLibraryModuleRef {
157159
create_table_factory: construct_table_provider_factory,
158160
create_scalar_udf: create_ffi_abs_func,
159161
create_nullary_udf: create_ffi_random_func,
162+
create_timezone_udf: udf_udaf_udwf::create_timezone_func,
160163
create_table_function: create_ffi_table_func,
161164
create_sum_udaf: create_ffi_sum_func,
162165
create_stddev_udaf: create_ffi_stddev_func,

datafusion/ffi/src/tests/udf_udaf_udwf.rs

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ use std::sync::Arc;
2020

2121
use arrow_schema::DataType;
2222
use datafusion_catalog::TableFunctionImpl;
23+
use datafusion_common::ScalarValue;
2324
use datafusion_expr::{
2425
AggregateUDF, ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
25-
WindowUDF,
26+
Volatility, WindowUDF,
2627
};
2728
use datafusion_functions::math::abs::AbsFunc;
2829
use datafusion_functions::math::random::RandomFunc;
@@ -78,6 +79,47 @@ pub(crate) extern "C" fn create_ffi_random_func() -> FFI_ScalarUDF {
7879
udf.into()
7980
}
8081

82+
#[derive(Debug, PartialEq, Eq, Hash)]
83+
struct TimeZoneUDF {
84+
signature: Signature,
85+
}
86+
87+
impl ScalarUDFImpl for TimeZoneUDF {
88+
fn as_any(&self) -> &dyn Any {
89+
self
90+
}
91+
fn name(&self) -> &str {
92+
"TimeZoneUDF"
93+
}
94+
95+
fn signature(&self) -> &Signature {
96+
&self.signature
97+
}
98+
99+
fn return_type(
100+
&self,
101+
_arg_types: &[DataType],
102+
) -> datafusion_common::Result<DataType> {
103+
Ok(DataType::Utf8)
104+
}
105+
106+
fn invoke_with_args(
107+
&self,
108+
args: ScalarFunctionArgs,
109+
) -> datafusion_common::Result<ColumnarValue> {
110+
let tz = args.config_options.execution.time_zone.clone();
111+
Ok(ColumnarValue::Scalar(ScalarValue::from(tz)))
112+
}
113+
}
114+
115+
pub(crate) extern "C" fn create_timezone_func() -> FFI_ScalarUDF {
116+
let udf: Arc<ScalarUDF> = Arc::new(ScalarUDF::from(TimeZoneUDF {
117+
signature: Signature::uniform(1, vec![DataType::Utf8], Volatility::Stable),
118+
}));
119+
120+
udf.into()
121+
}
122+
81123
pub(crate) extern "C" fn create_ffi_table_func(
82124
codec: FFI_LogicalExtensionCodec,
83125
) -> FFI_TableFunction {

datafusion/ffi/src/udf/mod.rs

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ use std::hash::{Hash, Hasher};
2020
use std::sync::Arc;
2121

2222
use abi_stable::StableAbi;
23-
use abi_stable::std_types::{RResult, RString, RVec};
24-
use arrow::array::ArrayRef;
23+
use abi_stable::std_types::{RString, RVec};
24+
use arrow::array::Array;
2525
use arrow::datatypes::{DataType, Field};
2626
use arrow::error::ArrowError;
2727
use arrow::ffi::{FFI_ArrowSchema, from_ffi, to_ffi};
@@ -38,6 +38,8 @@ use return_type_args::{
3838
};
3939

4040
use crate::arrow_wrappers::{WrappedArray, WrappedSchema};
41+
use crate::config::FFI_ConfigOptions;
42+
use crate::expr::columnar_value::FFI_ColumnarValue;
4143
use crate::util::{
4244
FFIResult, rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped,
4345
};
@@ -73,7 +75,8 @@ pub struct FFI_ScalarUDF {
7375
arg_fields: RVec<WrappedSchema>,
7476
num_rows: usize,
7577
return_field: WrappedSchema,
76-
) -> FFIResult<WrappedArray>,
78+
config_options: FFI_ConfigOptions,
79+
) -> FFIResult<FFI_ColumnarValue>,
7780

7881
/// See [`ScalarUDFImpl`] for details on short_circuits
7982
pub short_circuits: bool,
@@ -159,7 +162,8 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper(
159162
arg_fields: RVec<WrappedSchema>,
160163
number_rows: usize,
161164
return_field: WrappedSchema,
162-
) -> FFIResult<WrappedArray> {
165+
config_options: FFI_ConfigOptions,
166+
) -> FFIResult<FFI_ColumnarValue> {
163167
unsafe {
164168
let args = args
165169
.into_iter()
@@ -181,28 +185,22 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper(
181185
})
182186
.collect::<Result<Vec<FieldRef>>>();
183187
let arg_fields = rresult_return!(arg_fields);
188+
let config_options = rresult_return!(ConfigOptions::try_from(config_options));
189+
let config_options = Arc::new(config_options);
184190

185191
let args = ScalarFunctionArgs {
186192
args,
187193
arg_fields,
188194
number_rows,
189195
return_field,
190-
// TODO: pass config options: https://github.com/apache/datafusion/issues/17035
191-
config_options: Arc::new(ConfigOptions::default()),
196+
config_options,
192197
};
193198

194-
let result = rresult_return!(
199+
rresult!(
195200
udf.inner()
196201
.invoke_with_args(args)
197-
.and_then(|r| r.to_array(number_rows))
198-
);
199-
200-
let (result_array, result_schema) = rresult_return!(to_ffi(&result.to_data()));
201-
202-
RResult::ROk(WrappedArray {
203-
array: result_array,
204-
schema: WrappedSchema(result_schema),
205-
})
202+
.and_then(FFI_ColumnarValue::try_from)
203+
)
206204
}
207205
}
208206

@@ -366,8 +364,7 @@ impl ScalarUDFImpl for ForeignScalarUDF {
366364
arg_fields,
367365
number_rows,
368366
return_field,
369-
// TODO: pass config options: https://github.com/apache/datafusion/issues/17035
370-
config_options: _config_options,
367+
config_options,
371368
} = invoke_args;
372369

373370
let args = args
@@ -396,6 +393,7 @@ impl ScalarUDFImpl for ForeignScalarUDF {
396393

397394
let return_field = return_field.as_ref().clone();
398395
let return_field = WrappedSchema(FFI_ArrowSchema::try_from(return_field)?);
396+
let config_options = config_options.as_ref().into();
399397

400398
let result = unsafe {
401399
(self.udf.invoke_with_args)(
@@ -404,13 +402,12 @@ impl ScalarUDFImpl for ForeignScalarUDF {
404402
arg_fields,
405403
number_rows,
406404
return_field,
405+
config_options,
407406
)
408407
};
409408

410409
let result = df_result!(result)?;
411-
let result_array: ArrayRef = result.try_into()?;
412-
413-
Ok(ColumnarValue::Array(result_array))
410+
result.try_into()
414411
}
415412

416413
fn aliases(&self) -> &[String] {

datafusion/ffi/tests/ffi_udf.rs

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,17 @@
1919
/// when the feature integration-tests is built
2020
#[cfg(feature = "integration-tests")]
2121
mod tests {
22-
use std::sync::Arc;
23-
22+
use arrow::array::{Array, AsArray};
2423
use arrow::datatypes::DataType;
2524
use datafusion::common::record_batch;
2625
use datafusion::error::{DataFusionError, Result};
2726
use datafusion::logical_expr::{ScalarUDF, ScalarUDFImpl};
2827
use datafusion::prelude::{SessionContext, col};
28+
use datafusion_execution::config::SessionConfig;
29+
use datafusion_expr::lit;
2930
use datafusion_ffi::tests::create_record_batch;
3031
use datafusion_ffi::tests::utils::get_module;
32+
use std::sync::Arc;
3133

3234
/// This test validates that we can load an external module and use a scalar
3335
/// udf defined in it via the foreign function interface. In this case we are
@@ -100,4 +102,46 @@ mod tests {
100102

101103
Ok(())
102104
}
105+
106+
#[tokio::test]
107+
async fn test_config_on_scalar_udf() -> Result<()> {
108+
let module = get_module()?;
109+
110+
let ffi_udf =
111+
module
112+
.create_timezone_udf()
113+
.ok_or(DataFusionError::NotImplemented(
114+
"External module failed to implement create_timezone_udf".to_string(),
115+
))?();
116+
let foreign_udf: Arc<dyn ScalarUDFImpl> = (&ffi_udf).into();
117+
118+
let udf = ScalarUDF::new_from_shared_impl(foreign_udf);
119+
120+
let ctx = SessionContext::default();
121+
122+
let df = ctx
123+
.read_empty()?
124+
.select(vec![udf.call(vec![lit("a")]).alias("a")])?;
125+
126+
let result = df.collect().await?;
127+
assert!(result[0].column(0).as_string::<i32>().is_null(0));
128+
129+
let mut config = SessionConfig::new();
130+
config.options_mut().execution.time_zone = Some("AEST".into());
131+
132+
let ctx = SessionContext::new_with_config(config);
133+
134+
let df = ctx
135+
.read_empty()?
136+
.select(vec![udf.call(vec![lit("a")]).alias("a")])?;
137+
138+
let result = df.collect().await?;
139+
140+
assert!(result.len() == 1);
141+
assert!(!result[0].column(0).as_string::<i32>().is_null(0));
142+
let result = result[0].column(0).as_string::<i32>().value(0);
143+
assert_eq!(result, "AEST");
144+
145+
Ok(())
146+
}
103147
}

0 commit comments

Comments
 (0)