Skip to content

Commit fb26fd9

Browse files
authored
Add exact HigherOrderSignature (#22326)
## Which issue does this PR close? No issue — this is a follow-up to #21679. ## Rationale for this change In `ScalarUDF`, arity is enforced by the framework via `TypeSignature`. In `HigherOrderUDF`, functions with a fixed number of value and lambda arguments had to use `UserDefined` and manually validate arity inside `coerce_value_types`, which is boilerplate that every implementor has to repeat. ## What changes are included in this PR? Adds a `HigherOrderTypeSignature::Exact(Vec<ValueOrLambda<(), ()>>)` variant that enforces a fixed number and order of value and lambda arguments, so that `coerce_value_types` is called only for type coercion. ## Are these changes tested? Yes, I added some planning tests for the exact signature in `datafusion/expr/src/type_coercion/functions.rs`. ## Are there any user-facing changes? Yes, a new signature for `HigherOrderSignature` was added.
1 parent dc80bd7 commit fb26fd9

5 files changed

Lines changed: 227 additions & 45 deletions

File tree

datafusion/expr/src/higher_order_function.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ pub enum HigherOrderTypeSignature {
7373
VariadicAny,
7474
/// The specified number of lambdas or arguments with arbitrary types.
7575
Any(usize),
76+
/// Exactly the specified arguments in the given order, with arbitrary types.
77+
/// DataFusion will call [`HigherOrderUDF::coerce_value_types`] to prepare the value
78+
/// argument types.
79+
Exact(Vec<ValueOrLambda<(), ()>>),
7680
}
7781

7882
/// Provides information necessary for calling a higher order function.
@@ -138,6 +142,28 @@ impl HigherOrderSignature {
138142
}
139143
}
140144

145+
/// Exactly the specified arguments in the given order, with arbitrary types.
146+
/// DataFusion will call [`HigherOrderUDF::coerce_value_types`] to prepare the value
147+
/// argument types.
148+
///
149+
/// # Example
150+
/// A function that takes one value argument followed by one lambda:
151+
/// ```
152+
/// # use datafusion_expr::{HigherOrderSignature, ValueOrLambda, Volatility};
153+
/// let sig = HigherOrderSignature::exact(
154+
/// vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
155+
/// Volatility::Immutable,
156+
/// );
157+
/// ```
158+
pub fn exact(args: Vec<ValueOrLambda<(), ()>>, volatility: Volatility) -> Self {
159+
Self {
160+
type_signature: HigherOrderTypeSignature::Exact(args),
161+
volatility,
162+
coerce_values_for_lambdas: false,
163+
lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
164+
}
165+
}
166+
141167
/// Set [Self::coerce_values_for_lambdas] to true to indicate that [HigherOrderUDF::coerce_values_for_lambdas]
142168
/// should be called
143169
pub fn with_coerce_values_for_lambdas(mut self) -> Self {
@@ -406,7 +432,7 @@ pub struct HigherOrderReturnFieldArgs<'a> {
406432
}
407433

408434
/// An argument to a higher order function
409-
#[derive(Clone, Debug, PartialEq, Eq)]
435+
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Hash)]
410436
pub enum ValueOrLambda<V, L> {
411437
/// A value with associated data
412438
Value(V),

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,78 @@ pub fn value_fields_with_higher_order_udf<L: Clone>(
230230

231231
Ok(current_fields.to_vec())
232232
}
233+
HigherOrderTypeSignature::Exact(ref expected) => {
234+
if current_fields.len() != expected.len() {
235+
let name = func.name();
236+
let expected_len = expected.len();
237+
let actual_len = current_fields.len();
238+
return plan_err!(
239+
"The function '{name}' expected {expected_len} argument(s) but received {actual_len}"
240+
);
241+
}
242+
243+
for (i, (actual, expected)) in
244+
current_fields.iter().zip(expected.iter()).enumerate()
245+
{
246+
match (actual, expected) {
247+
(ValueOrLambda::Value(_), ValueOrLambda::Value(_)) => {}
248+
(ValueOrLambda::Lambda(_), ValueOrLambda::Lambda(_)) => {}
249+
(ValueOrLambda::Value(_), ValueOrLambda::Lambda(_)) => {
250+
let name = func.name();
251+
return plan_err!(
252+
"The function '{name}' expected a lambda at position {i} but received a value"
253+
);
254+
}
255+
(ValueOrLambda::Lambda(_), ValueOrLambda::Value(_)) => {
256+
let name = func.name();
257+
return plan_err!(
258+
"The function '{name}' expected a value at position {i} but received a lambda"
259+
);
260+
}
261+
}
262+
}
263+
264+
let arg_types = current_fields
265+
.iter()
266+
.filter_map(|p| match p {
267+
ValueOrLambda::Value(field) => Some(field.data_type().clone()),
268+
ValueOrLambda::Lambda(_) => None,
269+
})
270+
.collect::<Vec<_>>();
271+
272+
let coerced_types = func.coerce_value_types(&arg_types)?;
273+
274+
if coerced_types.len() != arg_types.len() {
275+
return plan_err!(
276+
"{} coerce_value_types should have returned {} items but returned {}",
277+
func.name(),
278+
arg_types.len(),
279+
coerced_types.len()
280+
);
281+
}
282+
283+
let mut coerced_types = coerced_types.into_iter();
284+
285+
current_fields
286+
.iter()
287+
.map(|current_field| match current_field {
288+
ValueOrLambda::Value(field) => {
289+
let data_type = coerced_types.next().ok_or_else(|| {
290+
internal_datafusion_err!(
291+
"coerced_types len should have been checked above"
292+
)
293+
})?;
294+
295+
Ok(ValueOrLambda::Value(Arc::new(
296+
field.as_ref().clone().with_data_type(data_type),
297+
)))
298+
}
299+
ValueOrLambda::Lambda(lambda) => {
300+
Ok(ValueOrLambda::Lambda(lambda.clone()))
301+
}
302+
})
303+
.collect()
304+
}
233305
}
234306
}
235307

@@ -2026,4 +2098,88 @@ mod tests {
20262098
"The function 'mock_higher_order_function' expected 1 arguments but received 0"
20272099
);
20282100
}
2101+
2102+
#[test]
2103+
fn test_higher_order_function_exact_signature() {
2104+
let fun = MockHigherOrderUDF {
2105+
signature: HigherOrderSignature::exact(
2106+
vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
2107+
Volatility::Immutable,
2108+
),
2109+
coerced_value_types: vec![DataType::new_large_list(DataType::Int32, false)],
2110+
};
2111+
2112+
let new_fields = value_fields_with_higher_order_udf(
2113+
&[
2114+
ValueOrLambda::Value(Arc::new(Field::new_list(
2115+
"",
2116+
Field::new_list_field(DataType::Int32, false),
2117+
false,
2118+
))),
2119+
ValueOrLambda::Lambda(()),
2120+
],
2121+
&fun,
2122+
)
2123+
.unwrap();
2124+
2125+
// type coercion applied: List(Int32) -> LargeList(Int32)
2126+
assert_eq!(
2127+
new_fields,
2128+
vec![
2129+
ValueOrLambda::Value(Arc::new(Field::new_large_list(
2130+
"",
2131+
Field::new_list_field(DataType::Int32, false),
2132+
false
2133+
))),
2134+
ValueOrLambda::Lambda(()),
2135+
]
2136+
)
2137+
}
2138+
2139+
#[test]
2140+
fn test_higher_order_function_exact_signature_wrong_value_count() {
2141+
let fun = MockHigherOrderUDF {
2142+
signature: HigherOrderSignature::exact(
2143+
vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
2144+
Volatility::Immutable,
2145+
),
2146+
coerced_value_types: vec![],
2147+
};
2148+
2149+
let err = value_fields_with_higher_order_udf::<()>(
2150+
&[ValueOrLambda::Lambda(()), ValueOrLambda::Lambda(())],
2151+
&fun,
2152+
)
2153+
.unwrap_err();
2154+
2155+
assert_contains!(
2156+
err.to_string(),
2157+
"expected a value at position 0 but received a lambda"
2158+
);
2159+
}
2160+
2161+
#[test]
2162+
fn test_higher_order_function_exact_signature_wrong_lambda_count() {
2163+
let fun = MockHigherOrderUDF {
2164+
signature: HigherOrderSignature::exact(
2165+
vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
2166+
Volatility::Immutable,
2167+
),
2168+
coerced_value_types: vec![],
2169+
};
2170+
2171+
let err = value_fields_with_higher_order_udf::<()>(
2172+
&[
2173+
ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, false))),
2174+
ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, false))),
2175+
],
2176+
&fun,
2177+
)
2178+
.unwrap_err();
2179+
2180+
assert_contains!(
2181+
err.to_string(),
2182+
"expected a lambda at position 1 but received a value"
2183+
);
2184+
}
20292185
}

datafusion/functions-nested/src/array_any_match.rs

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,10 @@ impl Default for ArrayAnyMatch {
8181
impl ArrayAnyMatch {
8282
pub fn new() -> Self {
8383
Self {
84-
signature: HigherOrderSignature::user_defined(Volatility::Immutable),
84+
signature: HigherOrderSignature::exact(
85+
vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
86+
Volatility::Immutable,
87+
),
8588
aliases: vec![String::from("any_match"), String::from("list_any_match")],
8689
}
8790
}
@@ -117,9 +120,7 @@ impl HigherOrderUDF for ArrayAnyMatch {
117120
}
118121

119122
fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
120-
let list = if arg_types.len() == 1 {
121-
&arg_types[0]
122-
} else {
123+
let [list] = arg_types else {
123124
return plan_err!(
124125
"{} function requires 1 value argument, got {}",
125126
self.name(),
@@ -150,15 +151,15 @@ impl HigherOrderUDF for ArrayAnyMatch {
150151
_step: usize,
151152
fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
152153
) -> Result<LambdaParametersProgress> {
153-
let [list, _lambda] = take_function_args(self.name(), fields)?;
154-
155-
let field = match list {
156-
ValueOrLambda::Value(f) => match f.data_type() {
157-
DataType::List(field) => field,
158-
DataType::LargeList(field) => field,
159-
other => return plan_err!("expected list, got {other}"),
160-
},
161-
_ => return plan_err!("{} expected a value as first argument", self.name()),
154+
let [list, _] = take_function_args(self.name(), fields)?;
155+
let ValueOrLambda::Value(list) = list else {
156+
return plan_err!("{} expects a value as first argument", self.name());
157+
};
158+
159+
let field = match list.data_type() {
160+
DataType::List(field) => field,
161+
DataType::LargeList(field) => field,
162+
other => return plan_err!("expected list, got {other}"),
162163
};
163164

164165
Ok(LambdaParametersProgress::Complete(vec![vec![Arc::clone(
@@ -170,15 +171,18 @@ impl HigherOrderUDF for ArrayAnyMatch {
170171
&self,
171172
args: HigherOrderReturnFieldArgs,
172173
) -> Result<Arc<Field>> {
173-
let [list, _lambda] = take_function_args(self.name(), args.arg_fields)?;
174-
let nullable = matches!(list, ValueOrLambda::Value(f) if f.is_nullable());
174+
let [ValueOrLambda::Value(list), _] =
175+
take_function_args(self.name(), args.arg_fields)?
176+
else {
177+
return plan_err!("{} expects a value as first argument", self.name());
178+
};
179+
let nullable = list.is_nullable();
175180
Ok(Arc::new(Field::new("", DataType::Boolean, nullable)))
176181
}
177182

178183
fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue> {
179-
let [list, lambda] = take_function_args(self.name(), &args.args)?;
180-
181-
let (ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)) = (list, lambda)
184+
let [ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)] =
185+
take_function_args(self.name(), &args.args)?
182186
else {
183187
return exec_err!("{} expects a value followed by a lambda", self.name());
184188
};

datafusion/functions-nested/src/array_transform.rs

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ impl Default for ArrayTransform {
7777
impl ArrayTransform {
7878
pub fn new() -> Self {
7979
Self {
80-
signature: HigherOrderSignature::user_defined(Volatility::Immutable),
80+
signature: HigherOrderSignature::exact(
81+
vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
82+
Volatility::Immutable,
83+
),
8184
aliases: vec![String::from("list_transform")],
8285
}
8386
}
@@ -97,11 +100,9 @@ impl HigherOrderUDF for ArrayTransform {
97100
}
98101

99102
fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
100-
let list = if arg_types.len() == 1 {
101-
&arg_types[0]
102-
} else {
103+
let [list] = arg_types else {
103104
return plan_err!(
104-
"{} function requires 1 value arguments, got {}",
105+
"{} function requires 1 value argument, got {}",
105106
self.name(),
106107
arg_types.len()
107108
);
@@ -130,7 +131,10 @@ impl HigherOrderUDF for ArrayTransform {
130131
_step: usize,
131132
fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
132133
) -> Result<LambdaParametersProgress> {
133-
let (list, _lambda) = value_lambda_pair(self.name(), fields)?;
134+
let [list, _] = take_function_args(self.name(), fields)?;
135+
let ValueOrLambda::Value(list) = list else {
136+
return plan_err!("{} expects a value as first argument", self.name());
137+
};
134138

135139
let field = match list.data_type() {
136140
DataType::List(field) => field,
@@ -149,7 +153,11 @@ impl HigherOrderUDF for ArrayTransform {
149153
&self,
150154
args: HigherOrderReturnFieldArgs,
151155
) -> Result<Arc<Field>> {
152-
let (list, lambda) = value_lambda_pair(self.name(), args.arg_fields)?;
156+
let [ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)] =
157+
take_function_args(self.name(), args.arg_fields)?
158+
else {
159+
return plan_err!("{} expects a value followed by a lambda", self.name());
160+
};
153161

154162
//TODO: should metadata be copied into the transformed array?
155163

@@ -171,7 +179,11 @@ impl HigherOrderUDF for ArrayTransform {
171179
}
172180

173181
fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue> {
174-
let (list, lambda) = value_lambda_pair(self.name(), &args.args)?;
182+
let [list, lambda] = take_function_args(self.name(), &args.args)?;
183+
let (ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)) = (list, lambda)
184+
else {
185+
return plan_err!("{} expects a value followed by a lambda", self.name());
186+
};
175187

176188
let list_array = list.to_array(args.number_rows)?;
177189

@@ -265,22 +277,6 @@ impl HigherOrderUDF for ArrayTransform {
265277
}
266278
}
267279

268-
fn value_lambda_pair<'a, V: Debug, L: Debug>(
269-
name: &str,
270-
args: &'a [ValueOrLambda<V, L>],
271-
) -> Result<(&'a V, &'a L)> {
272-
let [value, lambda] = take_function_args(name, args)?;
273-
274-
let (ValueOrLambda::Value(value), ValueOrLambda::Lambda(lambda)) = (value, lambda)
275-
else {
276-
return plan_err!(
277-
"{name} expects a value followed by a lambda, got {value:?} and {lambda:?}"
278-
);
279-
};
280-
281-
Ok((value, lambda))
282-
}
283-
284280
#[cfg(test)]
285281
mod tests {
286282
use std::{collections::HashMap, sync::Arc};

datafusion/sqllogictest/test_files/array/array_transform.slt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,13 +396,13 @@ physical_plan
396396
query error
397397
select array_transform();
398398
----
399-
DataFusion error: Error during planning: array_transform function requires 1 value arguments, got 0
399+
DataFusion error: Error during planning: The function 'array_transform' expected 2 argument(s) but received 0
400400

401401

402402
query error DataFusion error: Error during planning: array_transform expected a list as first argument, got Int64
403403
select array_transform(1, v -> v*2);
404404

405-
query error DataFusion error: Error during planning: array_transform expects a value followed by a lambda, got Lambda\(None\) and Value\(Field \{ name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ data_type: Int64, nullable: true \}\), nullable: true \}\)
405+
query error DataFusion error: Error during planning: The function 'array_transform' expected a value at position 0 but received a lambda
406406
select array_transform(v -> v*2, [1, 2]);
407407

408408
query error DataFusion error: Error during planning: lambda defined 3 params but UDF support only 1

0 commit comments

Comments
 (0)