Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 100 additions & 58 deletions datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,35 @@
// specific language governing permissions and limitations
// under the License.

use crate::logical_plan::producer::{SubstraitProducer, to_substrait_literal_expr};
use crate::logical_plan::producer::{
SubstraitProducer, to_substrait_literal_expr, to_substrait_type,
};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::datatype::FieldExt;
use datafusion::common::{
DFSchemaRef, ScalarValue, internal_datafusion_err, not_impl_err, substrait_err,
};
use datafusion::logical_expr::{Between, BinaryExpr, Expr, Like, Operator, expr};
use datafusion::logical_expr::{
Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, expr,
};
use substrait::proto::expression::{RexType, ScalarFunction};
use substrait::proto::function_argument::ArgType;
use substrait::proto::{Expression, FunctionArgument};
use substrait::proto::{Expression, FunctionArgument, Type};

/// Converts a DataFusion scalar function call into a Substrait scalar function
/// expression.
///
/// The output field of the call is resolved against `schema`, and its data type
/// and nullability are forwarded to `from_function` together with the function
/// name and arguments.
///
/// # Errors
///
/// Returns an error if the output field cannot be resolved against `schema`, or
/// if `from_function` fails.
pub fn from_scalar_function(
    producer: &mut impl SubstraitProducer,
    fun: &expr::ScalarFunction,
    schema: &DFSchemaRef,
) -> datafusion::common::Result<Expression> {
    // `to_field` is invoked on an `Expr`, so the call is cloned and wrapped in
    // `Expr::ScalarFunction` to resolve its output field.
    let (_, output_field) = Expr::ScalarFunction(fun.clone()).to_field(schema)?;
    from_function(
        producer,
        fun.name(),
        &fun.args,
        output_field.data_type(),
        output_field.is_nullable(),
        schema,
    )
}

pub fn from_higher_order_function(
Expand Down Expand Up @@ -100,12 +113,20 @@ pub fn from_higher_order_function(
.collect::<datafusion::common::Result<_>>()?;

let function_anchor = producer.register_function(fun.name().to_string());

let (_, output_field) = Expr::HigherOrderFunction(fun.clone()).to_field(schema)?;
let output_type = to_substrait_type(
producer,
output_field.data_type(),
output_field.is_nullable(),
)?;

#[expect(deprecated)]
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
arguments,
output_type: None,
output_type: Some(output_type),
options: vec![],
args: vec![],
})),
Expand All @@ -116,6 +137,8 @@ fn from_function(
producer: &mut impl SubstraitProducer,
name: &str,
args: &[Expr],
output_type: &DataType,
output_nullability: bool,
schema: &DFSchemaRef,
) -> datafusion::common::Result<Expression> {
let mut arguments: Vec<FunctionArgument> = vec![];
Expand All @@ -126,14 +149,15 @@ fn from_function(
}

let arguments = custom_argument_handler(name, arguments);
let output_type = to_substrait_type(producer, output_type, output_nullability)?;

let function_anchor = producer.register_function(name.to_string());
#[expect(deprecated)]
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
arguments,
output_type: None,
output_type: Some(output_type),
options: vec![],
args: vec![],
})),
Expand Down Expand Up @@ -177,7 +201,13 @@ pub fn from_unary_expr(
Expr::Negative(arg) => ("negate", arg),
expr => not_impl_err!("Unsupported expression: {expr:?}")?,
};
to_substrait_unary_scalar_fn(producer, fn_name, arg, schema)
let (_, output_field) = expr.to_field(schema)?;
let output_type = to_substrait_type(
producer,
output_field.data_type(),
output_field.is_nullable(),
)?;
to_substrait_unary_scalar_fn(producer, fn_name, arg, schema, &output_type)
}

pub fn from_binary_expr(
Expand All @@ -188,7 +218,19 @@ pub fn from_binary_expr(
let BinaryExpr { left, op, right } = expr;
let l = producer.handle_expr(left, schema)?;
let r = producer.handle_expr(right, schema)?;
Ok(make_binary_op_scalar_func(producer, &l, &r, *op))
let (_, output_field) = Expr::BinaryExpr(expr.clone()).to_field(schema)?;
let output_type = to_substrait_type(
producer,
output_field.data_type(),
output_field.is_nullable(),
)?;
Ok(make_binary_op_scalar_func(
producer,
&l,
&r,
*op,
&output_type,
))
}

pub fn from_like(
Expand Down Expand Up @@ -283,6 +325,7 @@ fn to_substrait_unary_scalar_fn(
fn_name: &str,
arg: &Expr,
schema: &DFSchemaRef,
output_type: &Type,
) -> datafusion::common::Result<Expression> {
let function_anchor = producer.register_function(fn_name.to_string());
let substrait_expr = producer.handle_expr(arg, schema)?;
Expand All @@ -293,7 +336,7 @@ fn to_substrait_unary_scalar_fn(
arguments: vec![FunctionArgument {
arg_type: Some(ArgType::Value(substrait_expr)),
}],
output_type: None,
output_type: Some(output_type.clone()),
options: vec![],
..Default::default()
})),
Expand All @@ -306,6 +349,7 @@ pub fn make_binary_op_scalar_func(
lhs: &Expression,
rhs: &Expression,
op: Operator,
output_type: &Type,
) -> Expression {
let function_anchor = producer.register_function(operator_to_name(op).to_string());
#[expect(deprecated)]
Expand All @@ -320,7 +364,7 @@ pub fn make_binary_op_scalar_func(
arg_type: Some(ArgType::Value(rhs.clone())),
},
],
output_type: None,
output_type: Some(output_type.clone()),
args: vec![],
options: vec![],
})),
Expand All @@ -338,57 +382,21 @@ pub fn from_between(
low,
high,
} = between;
if *negated {
// `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr)
let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?;
let substrait_low = producer.handle_expr(low.as_ref(), schema)?;
let substrait_high = producer.handle_expr(high.as_ref(), schema)?;

let l_expr = make_binary_op_scalar_func(
producer,
&substrait_expr,
&substrait_low,
Operator::Lt,
);
let r_expr = make_binary_op_scalar_func(
producer,
&substrait_high,
&substrait_expr,
Operator::Lt,
);

Ok(make_binary_op_scalar_func(
producer,
&l_expr,
&r_expr,
Operator::Or,
))
let expr = if *negated {
// `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr)
Expr::or(
Expr::lt(*expr.clone(), *low.clone()),
Expr::lt(*high.clone(), *expr.clone()),
)
Comment thread
wlhjason marked this conversation as resolved.
} else {
// `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high)
let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?;
let substrait_low = producer.handle_expr(low.as_ref(), schema)?;
let substrait_high = producer.handle_expr(high.as_ref(), schema)?;

let l_expr = make_binary_op_scalar_func(
producer,
&substrait_low,
&substrait_expr,
Operator::LtEq,
);
let r_expr = make_binary_op_scalar_func(
producer,
&substrait_expr,
&substrait_high,
Operator::LtEq,
);

Ok(make_binary_op_scalar_func(
producer,
&l_expr,
&r_expr,
Operator::And,
))
}
Expr::and(
Expr::lt_eq(*low.clone(), *expr.clone()),
Expr::lt_eq(*expr.clone(), *high.clone()),
)
};
producer.handle_expr(&expr, schema)
}

pub fn operator_to_name(op: Operator) -> &'static str {
Expand Down Expand Up @@ -438,3 +446,37 @@ pub fn operator_to_name(op: Operator) -> &'static str {
Operator::Colon => "colon",
}
}

#[cfg(test)]
mod tests {
    use crate::logical_plan::producer::{
        DefaultSubstraitProducer, SubstraitProducer, to_substrait_type,
    };
    use datafusion::arrow::datatypes::DataType;
    use datafusion::common::{DFSchema, DFSchemaRef};
    use datafusion::execution::SessionStateBuilder;
    use datafusion::prelude::lit;
    use substrait::proto::Expression;
    use substrait::proto::expression::{RexType, ScalarFunction};

    /// Checks that converting `1i64 + 2i64` yields a Substrait scalar function
    /// whose `output_type` equals the Substrait type built from
    /// `DataType::Int64` with nullability `false`.
    #[tokio::test]
    async fn binary_expr_output_type() -> datafusion::common::Result<()> {
        let session_state = SessionStateBuilder::default().build();
        let schema = DFSchemaRef::new(DFSchema::empty());
        let mut producer = DefaultSubstraitProducer::new(&session_state);

        let sum = lit(1i64) + lit(2i64);

        // Anything other than a scalar function expression is a test failure.
        let Expression {
            rex_type: Some(RexType::ScalarFunction(ScalarFunction { output_type, .. })),
        } = producer.handle_expr(&sum, &schema)?
        else {
            panic!("Substrait ScalarFunction expected")
        };

        let expected = to_substrait_type(&mut producer, &DataType::Int64, false)?;
        assert_eq!(output_type, Some(expected));
        Ok(())
    }
}
Loading
Loading