Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions datafusion/expr/src/higher_order_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ pub enum HigherOrderTypeSignature {
VariadicAny,
/// The specified number of lambdas or arguments with arbitrary types.
Any(usize),
/// Exactly the specified number of value arguments and lambda arguments, in any order,
/// with arbitrary types. DataFusion will call [`HigherOrderUDF::coerce_value_types`]
/// to prepare the value argument types.
Exact { values: usize, lambdas: usize },
}

/// Provides information necessary for calling a higher order function.
Expand Down Expand Up @@ -138,6 +142,18 @@ impl HigherOrderSignature {
}
}

/// Exactly the specified number of value arguments and lambda arguments, with
/// arbitrary types. DataFusion will call [`HigherOrderUDF::coerce_value_types`]
/// to prepare the value argument types.
pub fn exact(values: usize, lambdas: usize, volatility: Volatility) -> Self {
Self {
type_signature: HigherOrderTypeSignature::Exact { values, lambdas },
volatility,
coerce_values_for_lambdas: false,
lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
}
}

/// Set [Self::coerce_values_for_lambdas] to true to indicate that [HigherOrderUDF::coerce_values_for_lambdas]
/// should be called
pub fn with_coerce_values_for_lambdas(mut self) -> Self {
Expand Down
133 changes: 133 additions & 0 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,64 @@ pub fn value_fields_with_higher_order_udf<L: Clone>(

Ok(current_fields.to_vec())
}
HigherOrderTypeSignature::Exact { values, lambdas } => {
Comment thread
gabotechs marked this conversation as resolved.
Outdated
let actual_values = current_fields
.iter()
.filter(|f| matches!(f, ValueOrLambda::Value(_)))
.count();
let actual_lambdas = current_fields
.iter()
.filter(|f| matches!(f, ValueOrLambda::Lambda(_)))
.count();

if actual_values != values || actual_lambdas != lambdas {
let name = func.name();
return plan_err!(
"The function '{name}' expected {values} value argument(s) and {lambdas} lambda(s) but received {actual_values} value argument(s) and {actual_lambdas} lambda(s)"
);
}

let arg_types = current_fields
.iter()
.filter_map(|p| match p {
ValueOrLambda::Value(field) => Some(field.data_type().clone()),
ValueOrLambda::Lambda(_) => None,
})
.collect::<Vec<_>>();

let coerced_types = func.coerce_value_types(&arg_types)?;

if coerced_types.len() != arg_types.len() {
return plan_err!(
"{} coerce_value_types should have returned {} items but returned {}",
func.name(),
arg_types.len(),
coerced_types.len()
);
}

let mut coerced_types = coerced_types.into_iter();

current_fields
.iter()
.map(|current_field| match current_field {
ValueOrLambda::Value(field) => {
let data_type = coerced_types.next().ok_or_else(|| {
internal_datafusion_err!(
"coerced_types len should have been checked above"
)
})?;

Ok(ValueOrLambda::Value(Arc::new(
field.as_ref().clone().with_data_type(data_type),
)))
}
ValueOrLambda::Lambda(lambda) => {
Ok(ValueOrLambda::Lambda(lambda.clone()))
}
})
.collect()
}
}
}

Expand Down Expand Up @@ -2026,4 +2084,79 @@ mod tests {
"The function 'mock_higher_order_function' expected 1 arguments but received 0"
);
}

#[test]
fn test_higher_order_function_exact_signature() {
let fun = MockHigherOrderUDF {
signature: HigherOrderSignature::exact(1, 1, Volatility::Immutable),
coerced_value_types: vec![DataType::new_large_list(DataType::Int32, false)],
};

let new_fields = value_fields_with_higher_order_udf(
&[
ValueOrLambda::Value(Arc::new(Field::new_list(
"",
Field::new_list_field(DataType::Int32, false),
false,
))),
ValueOrLambda::Lambda(()),
],
&fun,
)
.unwrap();

// type coercion applied: List(Int32) -> LargeList(Int32)
assert_eq!(
new_fields,
vec![
ValueOrLambda::Value(Arc::new(Field::new_large_list(
"",
Field::new_list_field(DataType::Int32, false),
false
))),
ValueOrLambda::Lambda(()),
]
)
}

#[test]
fn test_higher_order_function_exact_signature_wrong_value_count() {
let fun = MockHigherOrderUDF {
signature: HigherOrderSignature::exact(1, 1, Volatility::Immutable),
coerced_value_types: vec![],
};

let err = value_fields_with_higher_order_udf::<()>(
&[ValueOrLambda::Lambda(()), ValueOrLambda::Lambda(())],
&fun,
)
.unwrap_err();

assert_contains!(
err.to_string(),
"expected 1 value argument(s) and 1 lambda(s) but received 0 value argument(s) and 2 lambda(s)"
);
}

#[test]
fn test_higher_order_function_exact_signature_wrong_lambda_count() {
let fun = MockHigherOrderUDF {
signature: HigherOrderSignature::exact(1, 1, Volatility::Immutable),
coerced_value_types: vec![],
};

let err = value_fields_with_higher_order_udf::<()>(
&[
ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, false))),
ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, false))),
],
&fun,
)
.unwrap_err();

assert_contains!(
err.to_string(),
"expected 1 value argument(s) and 1 lambda(s) but received 2 value argument(s) and 0 lambda(s)"
);
}
}
12 changes: 3 additions & 9 deletions datafusion/functions-nested/src/array_any_match.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl Default for ArrayAnyMatch {
impl ArrayAnyMatch {
pub fn new() -> Self {
Self {
signature: HigherOrderSignature::user_defined(Volatility::Immutable),
signature: HigherOrderSignature::exact(1, 1, Volatility::Immutable),
aliases: vec![String::from("any_match"), String::from("list_any_match")],
}
}
Expand Down Expand Up @@ -117,14 +117,8 @@ impl HigherOrderUDF for ArrayAnyMatch {
}

fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let list = if arg_types.len() == 1 {
&arg_types[0]
} else {
return plan_err!(
"{} function requires 1 value argument, got {}",
self.name(),
arg_types.len()
);
let [list] = arg_types else {
unreachable!("arity enforced by Exact signature")
};

Comment thread
LiaCastaneda marked this conversation as resolved.
let coerced = match list {
Expand Down
45 changes: 17 additions & 28 deletions datafusion/functions-nested/src/array_transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl Default for ArrayTransform {
impl ArrayTransform {
pub fn new() -> Self {
Self {
signature: HigherOrderSignature::user_defined(Volatility::Immutable),
signature: HigherOrderSignature::exact(1, 1, Volatility::Immutable),
aliases: vec![String::from("list_transform")],
}
}
Expand All @@ -97,14 +97,8 @@ impl HigherOrderUDF for ArrayTransform {
}

fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let list = if arg_types.len() == 1 {
&arg_types[0]
} else {
return plan_err!(
"{} function requires 1 value arguments, got {}",
self.name(),
arg_types.len()
);
let [list] = arg_types else {
unreachable!("arity enforced by Exact signature")
};

Comment thread
LiaCastaneda marked this conversation as resolved.
let coerced = match list {
Expand All @@ -130,7 +124,10 @@ impl HigherOrderUDF for ArrayTransform {
_step: usize,
fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
) -> Result<LambdaParametersProgress> {
let (list, _lambda) = value_lambda_pair(self.name(), fields)?;
let [list, _] = take_function_args(self.name(), fields)?;
let ValueOrLambda::Value(list) = list else {
return plan_err!("{} expects a value as first argument", self.name());
};

let field = match list.data_type() {
DataType::List(field) => field,
Expand All @@ -149,7 +146,11 @@ impl HigherOrderUDF for ArrayTransform {
&self,
args: HigherOrderReturnFieldArgs,
) -> Result<Arc<Field>> {
let (list, lambda) = value_lambda_pair(self.name(), args.arg_fields)?;
let [list, lambda] = take_function_args(self.name(), args.arg_fields)?;
let (ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)) = (list, lambda)
else {
return plan_err!("{} expects a value followed by a lambda", self.name());
};

//TODO: should metadata be copied into the transformed array?

Expand All @@ -171,7 +172,11 @@ impl HigherOrderUDF for ArrayTransform {
}

fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue> {
let (list, lambda) = value_lambda_pair(self.name(), &args.args)?;
let [list, lambda] = take_function_args(self.name(), &args.args)?;
let (ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)) = (list, lambda)
else {
return plan_err!("{} expects a value followed by a lambda", self.name());
};

let list_array = list.to_array(args.number_rows)?;

Expand Down Expand Up @@ -265,22 +270,6 @@ impl HigherOrderUDF for ArrayTransform {
}
}

fn value_lambda_pair<'a, V: Debug, L: Debug>(
name: &str,
args: &'a [ValueOrLambda<V, L>],
) -> Result<(&'a V, &'a L)> {
let [value, lambda] = take_function_args(name, args)?;

let (ValueOrLambda::Value(value), ValueOrLambda::Lambda(lambda)) = (value, lambda)
else {
return plan_err!(
"{name} expects a value followed by a lambda, got {value:?} and {lambda:?}"
);
};

Ok((value, lambda))
}

#[cfg(test)]
mod tests {
use std::{collections::HashMap, sync::Arc};
Expand Down
Loading