Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 9 additions & 26 deletions datafusion/expr/src/higher_order_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ pub struct HigherOrderSignature {
pub type_signature: HigherOrderTypeSignature,
/// The volatility of the function. See [Volatility] for more information.
pub volatility: Volatility,
/// Whether [HigherOrderUDF::coerce_values_for_lambdas] should be called
pub coerce_values_for_lambdas: bool,
/// The max number of times to call [HigherOrderUDF::lambda_parameters] before raising an error.
/// Used to guard against implementations that causes an infinite loop by endlessly returning
/// [LambdaParametersProgress::Partial]. Defaults to 256
Expand All @@ -90,7 +88,6 @@ impl HigherOrderSignature {
HigherOrderSignature {
type_signature,
volatility,
coerce_values_for_lambdas: false,
lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
}
}
Expand All @@ -100,7 +97,6 @@ impl HigherOrderSignature {
Self {
type_signature: HigherOrderTypeSignature::UserDefined,
volatility,
coerce_values_for_lambdas: false,
lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
}
}
Expand All @@ -110,7 +106,6 @@ impl HigherOrderSignature {
Self {
type_signature: HigherOrderTypeSignature::VariadicAny,
volatility,
coerce_values_for_lambdas: false,
lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
}
}
Expand All @@ -120,18 +115,9 @@ impl HigherOrderSignature {
Self {
type_signature: HigherOrderTypeSignature::Any(arg_count),
volatility,
coerce_values_for_lambdas: false,
lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
}
}

/// Set [Self::coerce_values_for_lambdas] to true to indicate that [HigherOrderUDF::coerce_values_for_lambdas]
/// should be called
pub fn with_coerce_values_for_lambdas(mut self) -> Self {
self.coerce_values_for_lambdas = true;

self
}
}

impl PartialEq for dyn HigherOrderUDF {
Expand Down Expand Up @@ -518,7 +504,7 @@ pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any {
///
/// The implementation can assume that some other part of the code has coerced
/// the actual argument types to match [`Self::signature`], except the coercion defined by
/// [Self::coerce_values_for_lambdas], if applicable.
/// [Self::coerce_values_for_lambdas].
///
/// [`HigherOrderFunction`]: crate::expr::HigherOrderFunction
/// [`HigherOrderFunction::lambda_parameters`]: crate::expr::HigherOrderFunction::lambda_parameters
Expand All @@ -531,8 +517,7 @@ pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any {
/// Coerce value arguments of a function call to types that the function can evaluate also taking into
/// account the *output type of it's lambdas*. This differs from [HigherOrderUDF::coerce_value_types]
/// that only has access to the type of it's value arguments because it's called before the output type
/// of lambdas are known. So that this method is called, the function must have it's
/// [HigherOrderSignature::coerce_values_for_lambdas] set to true
/// of lambdas are known.
///
/// See the [type coercion module](crate::type_coercion)
/// documentation for more details on type coercion
Expand All @@ -541,29 +526,27 @@ pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any {
/// * `fields`: The argument types of the value arguments of this function, or the output type of lambdas
///
/// # Return value
/// A Vec with the same number of [ValueOrLambda::Value] in `fields`. DataFusion will `CAST` the
/// function call arguments to these specific types.
/// If `Some`, contains a Vec with the same number of [ValueOrLambda::Value] in `fields`.
/// DataFusion will `CAST` the function call arguments to these specific types. If `None`, no
/// coercion will be applied beyond the one defined by the function signature.
///
Comment on lines 671 to 675
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be updated now?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, thank you 030c2fb

/// For example, a flexible array_reduce implementation (see [Self::lambda_parameters] docs), when working
/// with the expression below, may want to coerce it's initial value argument, the *integer* `0`,
/// to match the output it's merge function, which is a *float*:
/// to match the output of it's merge function, which is a *float*:
///
/// `array_reduce([1.2, 2.1], 0, (acc, v) -> acc + v + 1.5, v -> v > 2.0)`
fn coerce_values_for_lambdas(
&self,
_fields: &[ValueOrLambda<DataType, DataType>],
) -> Result<Vec<DataType>> {
not_impl_err!(
"{} coerce_values_for_lambdas is not implemented",
self.name()
)
) -> Result<Option<Vec<DataType>>> {
Ok(None)
}

/// What type will be returned by this function, given the arguments?
///
/// The implementation can assume that some other part of the code has coerced
/// the actual argument types to match [`Self::signature`], including the coercion
/// defined by [Self::coerce_values_for_lambdas], if applicable.
/// defined by [Self::coerce_values_for_lambdas].
///
/// # Example creating `Field`
///
Expand Down
25 changes: 12 additions & 13 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,16 +251,16 @@ pub fn value_fields_with_higher_order_udf_and_lambdas(
) -> Result<Vec<ValueOrLambda<FieldRef, FieldRef>>> {
let mut new_fields = value_fields_with_higher_order_udf(current_fields, func)?;

if func.signature().coerce_values_for_lambdas {
let new_types = new_fields
.iter()
.map(|f| match f {
ValueOrLambda::Value(f) => ValueOrLambda::Value(f.data_type().clone()),
ValueOrLambda::Lambda(f) => ValueOrLambda::Lambda(f.data_type().clone()),
})
.collect::<Vec<_>>();
let new_types = new_fields
.iter()
.map(|f| match f {
ValueOrLambda::Value(f) => ValueOrLambda::Value(f.data_type().clone()),
ValueOrLambda::Lambda(f) => ValueOrLambda::Lambda(f.data_type().clone()),
})
.collect::<Vec<_>>();

let mut new_value_types = func.coerce_values_for_lambdas(&new_types)?.into_iter();
if let Some(new_value_types) = func.coerce_values_for_lambdas(&new_types)? {
let mut new_value_types = new_value_types.into_iter();

let value_types_count = new_types
.iter()
Expand Down Expand Up @@ -1851,7 +1851,7 @@ mod tests {
fn coerce_values_for_lambdas(
&self,
fields: &[ValueOrLambda<DataType, DataType>],
) -> Result<Vec<DataType>> {
) -> Result<Option<Vec<DataType>>> {
// thoerical impl of array_reduce without finish
let [
ValueOrLambda::Value(list),
Expand All @@ -1862,7 +1862,7 @@ mod tests {
unreachable!()
};

Ok(vec![list.clone(), merge.clone()])
Ok(Some(vec![list.clone(), merge.clone()]))
}

fn lambda_parameters(
Expand Down Expand Up @@ -1925,8 +1925,7 @@ mod tests {
#[test]
fn test_higher_order_function_coerce_values_for_lambdas() {
let fun = MockHigherOrderUDF {
signature: HigherOrderSignature::variadic_any(Volatility::Immutable)
.with_coerce_values_for_lambdas(),
signature: HigherOrderSignature::variadic_any(Volatility::Immutable),
coerced_value_types: vec![],
};

Expand Down
Loading