Skip to content

Commit 571a9d9

Browse files
committed
allow specifying positional info on exact signature
1 parent e3649df commit 571a9d9

5 files changed

Lines changed: 99 additions & 49 deletions

File tree

datafusion/expr/src/higher_order_function.rs

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ pub enum HigherOrderTypeSignature {
7373
VariadicAny,
7474
/// The specified number of lambdas or arguments with arbitrary types.
7575
Any(usize),
76-
/// Exactly the specified number of value arguments and lambda arguments, in any order,
77-
/// with arbitrary types. DataFusion will call [`HigherOrderUDF::coerce_value_types`]
78-
/// to prepare the value argument types.
79-
Exact { values: usize, lambdas: usize },
76+
/// Exactly the specified arguments in the given order, with arbitrary types.
77+
/// DataFusion will call [`HigherOrderUDF::coerce_value_types`] to prepare the value
78+
/// argument types.
79+
Exact(Vec<ValueOrLambda<(), ()>>),
8080
}
8181

8282
/// Provides information necessary for calling a higher order function.
@@ -142,12 +142,22 @@ impl HigherOrderSignature {
142142
}
143143
}
144144

145-
/// Exactly the specified number of value arguments and lambda arguments, with
146-
/// arbitrary types. DataFusion will call [`HigherOrderUDF::coerce_value_types`]
147-
/// to prepare the value argument types.
148-
pub fn exact(values: usize, lambdas: usize, volatility: Volatility) -> Self {
145+
/// Exactly the specified arguments in the given order, with arbitrary types.
146+
/// DataFusion will call [`HigherOrderUDF::coerce_value_types`] to prepare the value
147+
/// argument types.
148+
///
149+
/// # Example
150+
/// A function that takes one value argument followed by one lambda:
151+
/// ```
152+
/// # use datafusion_expr::{HigherOrderSignature, ValueOrLambda, Volatility};
153+
/// let sig = HigherOrderSignature::exact(
154+
/// vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
155+
/// Volatility::Immutable,
156+
/// );
157+
/// ```
158+
pub fn exact(args: Vec<ValueOrLambda<(), ()>>, volatility: Volatility) -> Self {
149159
Self {
150-
type_signature: HigherOrderTypeSignature::Exact { values, lambdas },
160+
type_signature: HigherOrderTypeSignature::Exact(args),
151161
volatility,
152162
coerce_values_for_lambdas: false,
153163
lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
@@ -422,7 +432,7 @@ pub struct HigherOrderReturnFieldArgs<'a> {
422432
}
423433

424434
/// An argument to a higher order function
425-
#[derive(Clone, Debug, PartialEq, Eq)]
435+
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Hash)]
426436
pub enum ValueOrLambda<V, L> {
427437
/// A value with associated data
428438
Value(V),

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -230,23 +230,37 @@ pub fn value_fields_with_higher_order_udf<L: Clone>(
230230

231231
Ok(current_fields.to_vec())
232232
}
233-
HigherOrderTypeSignature::Exact { values, lambdas } => {
234-
let actual_values = current_fields
235-
.iter()
236-
.filter(|f| matches!(f, ValueOrLambda::Value(_)))
237-
.count();
238-
let actual_lambdas = current_fields
239-
.iter()
240-
.filter(|f| matches!(f, ValueOrLambda::Lambda(_)))
241-
.count();
242-
243-
if actual_values != values || actual_lambdas != lambdas {
233+
HigherOrderTypeSignature::Exact(ref expected) => {
234+
if current_fields.len() != expected.len() {
244235
let name = func.name();
236+
let expected_len = expected.len();
237+
let actual_len = current_fields.len();
245238
return plan_err!(
246-
"The function '{name}' expected {values} value argument(s) and {lambdas} lambda(s) but received {actual_values} value argument(s) and {actual_lambdas} lambda(s)"
239+
"The function '{name}' expected {expected_len} argument(s) but received {actual_len}"
247240
);
248241
}
249242

243+
for (i, (actual, expected)) in
244+
current_fields.iter().zip(expected.iter()).enumerate()
245+
{
246+
match (actual, expected) {
247+
(ValueOrLambda::Value(_), ValueOrLambda::Value(_)) => {}
248+
(ValueOrLambda::Lambda(_), ValueOrLambda::Lambda(_)) => {}
249+
(ValueOrLambda::Value(_), ValueOrLambda::Lambda(_)) => {
250+
let name = func.name();
251+
return plan_err!(
252+
"The function '{name}' expected a lambda at position {i} but received a value"
253+
);
254+
}
255+
(ValueOrLambda::Lambda(_), ValueOrLambda::Value(_)) => {
256+
let name = func.name();
257+
return plan_err!(
258+
"The function '{name}' expected a value at position {i} but received a lambda"
259+
);
260+
}
261+
}
262+
}
263+
250264
let arg_types = current_fields
251265
.iter()
252266
.filter_map(|p| match p {
@@ -2088,7 +2102,10 @@ mod tests {
20882102
#[test]
20892103
fn test_higher_order_function_exact_signature() {
20902104
let fun = MockHigherOrderUDF {
2091-
signature: HigherOrderSignature::exact(1, 1, Volatility::Immutable),
2105+
signature: HigherOrderSignature::exact(
2106+
vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
2107+
Volatility::Immutable,
2108+
),
20922109
coerced_value_types: vec![DataType::new_large_list(DataType::Int32, false)],
20932110
};
20942111

@@ -2122,7 +2139,10 @@ mod tests {
21222139
#[test]
21232140
fn test_higher_order_function_exact_signature_wrong_value_count() {
21242141
let fun = MockHigherOrderUDF {
2125-
signature: HigherOrderSignature::exact(1, 1, Volatility::Immutable),
2142+
signature: HigherOrderSignature::exact(
2143+
vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
2144+
Volatility::Immutable,
2145+
),
21262146
coerced_value_types: vec![],
21272147
};
21282148

@@ -2134,14 +2154,17 @@ mod tests {
21342154

21352155
assert_contains!(
21362156
err.to_string(),
2137-
"expected 1 value argument(s) and 1 lambda(s) but received 0 value argument(s) and 2 lambda(s)"
2157+
"expected a value at position 0 but received a lambda"
21382158
);
21392159
}
21402160

21412161
#[test]
21422162
fn test_higher_order_function_exact_signature_wrong_lambda_count() {
21432163
let fun = MockHigherOrderUDF {
2144-
signature: HigherOrderSignature::exact(1, 1, Volatility::Immutable),
2164+
signature: HigherOrderSignature::exact(
2165+
vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
2166+
Volatility::Immutable,
2167+
),
21452168
coerced_value_types: vec![],
21462169
};
21472170

@@ -2156,7 +2179,7 @@ mod tests {
21562179

21572180
assert_contains!(
21582181
err.to_string(),
2159-
"expected 1 value argument(s) and 1 lambda(s) but received 2 value argument(s) and 0 lambda(s)"
2182+
"expected a lambda at position 1 but received a value"
21602183
);
21612184
}
21622185
}

datafusion/functions-nested/src/array_any_match.rs

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,10 @@ impl Default for ArrayAnyMatch {
8181
impl ArrayAnyMatch {
8282
pub fn new() -> Self {
8383
Self {
84-
signature: HigherOrderSignature::exact(1, 1, Volatility::Immutable),
84+
signature: HigherOrderSignature::exact(
85+
vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
86+
Volatility::Immutable,
87+
),
8588
aliases: vec![String::from("any_match"), String::from("list_any_match")],
8689
}
8790
}
@@ -118,7 +121,11 @@ impl HigherOrderUDF for ArrayAnyMatch {
118121

119122
fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
120123
let [list] = arg_types else {
121-
unreachable!("arity enforced by Exact signature")
124+
return plan_err!(
125+
"{} function requires 1 value argument, got {}",
126+
self.name(),
127+
arg_types.len()
128+
);
122129
};
123130

124131
let coerced = match list {
@@ -144,15 +151,15 @@ impl HigherOrderUDF for ArrayAnyMatch {
144151
_step: usize,
145152
fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
146153
) -> Result<LambdaParametersProgress> {
147-
let [list, _lambda] = take_function_args(self.name(), fields)?;
148-
149-
let field = match list {
150-
ValueOrLambda::Value(f) => match f.data_type() {
151-
DataType::List(field) => field,
152-
DataType::LargeList(field) => field,
153-
other => return plan_err!("expected list, got {other}"),
154-
},
155-
_ => return plan_err!("{} expected a value as first argument", self.name()),
154+
let [list, _] = take_function_args(self.name(), fields)?;
155+
let ValueOrLambda::Value(list) = list else {
156+
return plan_err!("{} expects a value as first argument", self.name());
157+
};
158+
159+
let field = match list.data_type() {
160+
DataType::List(field) => field,
161+
DataType::LargeList(field) => field,
162+
other => return plan_err!("expected list, got {other}"),
156163
};
157164

158165
Ok(LambdaParametersProgress::Complete(vec![vec![Arc::clone(
@@ -164,15 +171,18 @@ impl HigherOrderUDF for ArrayAnyMatch {
164171
&self,
165172
args: HigherOrderReturnFieldArgs,
166173
) -> Result<Arc<Field>> {
167-
let [list, _lambda] = take_function_args(self.name(), args.arg_fields)?;
168-
let nullable = matches!(list, ValueOrLambda::Value(f) if f.is_nullable());
174+
let [ValueOrLambda::Value(list), _] =
175+
take_function_args(self.name(), args.arg_fields)?
176+
else {
177+
return plan_err!("{} expects a value as first argument", self.name());
178+
};
179+
let nullable = list.is_nullable();
169180
Ok(Arc::new(Field::new("", DataType::Boolean, nullable)))
170181
}
171182

172183
fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue> {
173-
let [list, lambda] = take_function_args(self.name(), &args.args)?;
174-
175-
let (ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)) = (list, lambda)
184+
let [ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)] =
185+
take_function_args(self.name(), &args.args)?
176186
else {
177187
return exec_err!("{} expects a value followed by a lambda", self.name());
178188
};

datafusion/functions-nested/src/array_transform.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ impl Default for ArrayTransform {
7777
impl ArrayTransform {
7878
pub fn new() -> Self {
7979
Self {
80-
signature: HigherOrderSignature::exact(1, 1, Volatility::Immutable),
80+
signature: HigherOrderSignature::exact(
81+
vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
82+
Volatility::Immutable,
83+
),
8184
aliases: vec![String::from("list_transform")],
8285
}
8386
}
@@ -98,7 +101,11 @@ impl HigherOrderUDF for ArrayTransform {
98101

99102
fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
100103
let [list] = arg_types else {
101-
unreachable!("arity enforced by Exact signature")
104+
return plan_err!(
105+
"{} function requires 1 value argument, got {}",
106+
self.name(),
107+
arg_types.len()
108+
);
102109
};
103110

104111
let coerced = match list {
@@ -146,8 +153,8 @@ impl HigherOrderUDF for ArrayTransform {
146153
&self,
147154
args: HigherOrderReturnFieldArgs,
148155
) -> Result<Arc<Field>> {
149-
let [list, lambda] = take_function_args(self.name(), args.arg_fields)?;
150-
let (ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)) = (list, lambda)
156+
let [ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)] =
157+
take_function_args(self.name(), args.arg_fields)?
151158
else {
152159
return plan_err!("{} expects a value followed by a lambda", self.name());
153160
};

datafusion/sqllogictest/test_files/array/array_transform.slt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,13 +396,13 @@ physical_plan
396396
query error
397397
select array_transform();
398398
----
399-
DataFusion error: Error during planning: The function 'array_transform' expected 1 value argument(s) and 1 lambda(s) but received 0 value argument(s) and 0 lambda(s)
399+
DataFusion error: Error during planning: The function 'array_transform' expected 2 argument(s) but received 0
400400

401401

402402
query error DataFusion error: Error during planning: array_transform expected a list as first argument, got Int64
403403
select array_transform(1, v -> v*2);
404404

405-
query error DataFusion error: Error during planning: array_transform expects a value as first argument
405+
query error DataFusion error: Error during planning: The function 'array_transform' expected a value at position 0 but received a lambda
406406
select array_transform(v -> v*2, [1, 2]);
407407

408408
query error DataFusion error: Error during planning: lambda defined 3 params but UDF support only 1

0 commit comments

Comments
 (0)