diff --git a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs index e7dd2af13f9ca..75720395aae7c 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs @@ -15,22 +15,35 @@ // specific language governing permissions and limitations // under the License. -use crate::logical_plan::producer::{SubstraitProducer, to_substrait_literal_expr}; +use crate::logical_plan::producer::{ + SubstraitProducer, to_substrait_literal_expr, to_substrait_type, +}; +use datafusion::arrow::datatypes::DataType; use datafusion::common::datatype::FieldExt; use datafusion::common::{ DFSchemaRef, ScalarValue, internal_datafusion_err, not_impl_err, substrait_err, }; -use datafusion::logical_expr::{Between, BinaryExpr, Expr, Like, Operator, expr}; +use datafusion::logical_expr::{ + Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, expr, +}; use substrait::proto::expression::{RexType, ScalarFunction}; use substrait::proto::function_argument::ArgType; -use substrait::proto::{Expression, FunctionArgument}; +use substrait::proto::{Expression, FunctionArgument, Type}; pub fn from_scalar_function( producer: &mut impl SubstraitProducer, fun: &expr::ScalarFunction, schema: &DFSchemaRef, ) -> datafusion::common::Result { - from_function(producer, fun.name(), &fun.args, schema) + let (_, output_field) = Expr::ScalarFunction(fun.clone()).to_field(schema)?; + from_function( + producer, + fun.name(), + &fun.args, + output_field.data_type(), + output_field.is_nullable(), + schema, + ) } pub fn from_higher_order_function( @@ -100,12 +113,20 @@ pub fn from_higher_order_function( .collect::>()?; let function_anchor = producer.register_function(fun.name().to_string()); + + let (_, output_field) = Expr::HigherOrderFunction(fun.clone()).to_field(schema)?; + let output_type = to_substrait_type( + producer, + output_field.data_type(), + output_field.is_nullable(), + )?; + #[expect(deprecated)] Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, arguments, - output_type: None, + output_type: Some(output_type), options: vec![], args: vec![], })), @@ -116,6 +137,8 @@ fn from_function( producer: &mut impl SubstraitProducer, name: &str, args: &[Expr], + output_type: &DataType, + output_nullability: bool, schema: &DFSchemaRef, ) -> datafusion::common::Result { let mut arguments: Vec = vec![]; @@ -126,6 +149,7 @@ fn from_function( } let arguments = custom_argument_handler(name, arguments); + let output_type = to_substrait_type(producer, output_type, output_nullability)?; let function_anchor = producer.register_function(name.to_string()); #[expect(deprecated)] @@ -133,7 +157,7 @@ fn from_function( rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, arguments, - output_type: None, + output_type: Some(output_type), options: vec![], args: vec![], })), @@ -177,7 +201,13 @@ pub fn from_unary_expr( Expr::Negative(arg) => ("negate", arg), expr => not_impl_err!("Unsupported expression: {expr:?}")?, }; - to_substrait_unary_scalar_fn(producer, fn_name, arg, schema) + let (_, output_field) = expr.to_field(schema)?; + let output_type = to_substrait_type( + producer, + output_field.data_type(), + output_field.is_nullable(), + )?; + to_substrait_unary_scalar_fn(producer, fn_name, arg, schema, &output_type) } pub fn from_binary_expr( @@ -188,7 +218,19 @@ pub fn from_binary_expr( let BinaryExpr { left, op, right } = expr; let l = producer.handle_expr(left, schema)?; let r = producer.handle_expr(right, schema)?; - Ok(make_binary_op_scalar_func(producer, &l, &r, *op)) + let (_, output_field) = Expr::BinaryExpr(expr.clone()).to_field(schema)?; + let output_type = to_substrait_type( + producer, + output_field.data_type(), + output_field.is_nullable(), + )?; + Ok(make_binary_op_scalar_func( + producer, + &l, + &r, + *op, + &output_type, + )) } pub fn from_like( @@ -283,6 +325,7 @@ fn to_substrait_unary_scalar_fn( fn_name: &str, arg: &Expr, schema: &DFSchemaRef, + output_type: &Type, ) -> datafusion::common::Result { let function_anchor = producer.register_function(fn_name.to_string()); let substrait_expr = producer.handle_expr(arg, schema)?; @@ -293,7 +336,7 @@ fn to_substrait_unary_scalar_fn( arguments: vec![FunctionArgument { arg_type: Some(ArgType::Value(substrait_expr)), }], - output_type: None, + output_type: Some(output_type.clone()), options: vec![], ..Default::default() })), @@ -306,6 +349,7 @@ pub fn make_binary_op_scalar_func( lhs: &Expression, rhs: &Expression, op: Operator, + output_type: &Type, ) -> Expression { let function_anchor = producer.register_function(operator_to_name(op).to_string()); #[expect(deprecated)] @@ -320,7 +364,7 @@ pub fn make_binary_op_scalar_func( arg_type: Some(ArgType::Value(rhs.clone())), }, ], - output_type: None, + output_type: Some(output_type.clone()), args: vec![], options: vec![], })), @@ -338,57 +382,21 @@ pub fn from_between( low, high, } = between; - if *negated { - // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) - let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; - let substrait_low = producer.handle_expr(low.as_ref(), schema)?; - let substrait_high = producer.handle_expr(high.as_ref(), schema)?; - - let l_expr = make_binary_op_scalar_func( - producer, - &substrait_expr, - &substrait_low, - Operator::Lt, - ); - let r_expr = make_binary_op_scalar_func( - producer, - &substrait_high, - &substrait_expr, - Operator::Lt, - ); - Ok(make_binary_op_scalar_func( - producer, - &l_expr, - &r_expr, - Operator::Or, - )) + let expr = if *negated { + // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) + Expr::or( + Expr::lt(*expr.clone(), *low.clone()), + Expr::lt(*high.clone(), *expr.clone()), + ) } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) - let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?; - let substrait_low = producer.handle_expr(low.as_ref(), schema)?; - let substrait_high = producer.handle_expr(high.as_ref(), schema)?; - - let l_expr = make_binary_op_scalar_func( - producer, - &substrait_low, - &substrait_expr, - Operator::LtEq, - ); - let r_expr = make_binary_op_scalar_func( - producer, - &substrait_expr, - &substrait_high, - Operator::LtEq, - ); - - Ok(make_binary_op_scalar_func( - producer, - &l_expr, - &r_expr, - Operator::And, - )) - } + Expr::and( + Expr::lt_eq(*low.clone(), *expr.clone()), + Expr::lt_eq(*expr.clone(), *high.clone()), + ) + }; + producer.handle_expr(&expr, schema) } pub fn operator_to_name(op: Operator) -> &'static str { @@ -438,3 +446,37 @@ pub fn operator_to_name(op: Operator) -> &'static str { Operator::Colon => "colon", } } + +#[cfg(test)] +mod tests { + use crate::logical_plan::producer::{ + DefaultSubstraitProducer, SubstraitProducer, to_substrait_type, + }; + use datafusion::arrow::datatypes::DataType; + use datafusion::common::{DFSchema, DFSchemaRef}; + use datafusion::execution::SessionStateBuilder; + use datafusion::prelude::lit; + use substrait::proto::Expression; + use substrait::proto::expression::{RexType, ScalarFunction}; + + #[tokio::test] + async fn binary_expr_output_type() -> datafusion::common::Result<()> { + let state = SessionStateBuilder::default().build(); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + let mut producer = DefaultSubstraitProducer::new(&state); + + let expr = lit(1i64) + lit(2i64); + let substrait_expr = producer.handle_expr(&expr, &empty_schema)?; + if let Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { output_type, .. })), + } = substrait_expr + { + let expected_type = + to_substrait_type(&mut producer, &DataType::Int64, false)?; + assert_eq!(output_type, Some(expected_type)); + Ok(()) + } else { + panic!("Substrait ScalarFunction expected") + } + } +} diff --git a/datafusion/substrait/src/logical_plan/producer/rel/join.rs b/datafusion/substrait/src/logical_plan/producer/rel/join.rs index cbf5593ffc86c..9094774780e10 100644 --- a/datafusion/substrait/src/logical_plan/producer/rel/join.rs +++ b/datafusion/substrait/src/logical_plan/producer/rel/join.rs @@ -15,59 +15,38 @@ // specific language governing permissions and limitations // under the License. -use crate::logical_plan::producer::{SubstraitProducer, make_binary_op_scalar_func}; -use datafusion::common::{ - DFSchemaRef, JoinConstraint, JoinType, NullEquality, not_impl_err, -}; +use crate::logical_plan::producer::SubstraitProducer; +use datafusion::common::{JoinConstraint, JoinType, NullEquality, not_impl_err}; +use datafusion::logical_expr::utils::conjunction; use datafusion::logical_expr::{Expr, Join, Operator}; +use datafusion::prelude::binary_expr; use std::sync::Arc; use substrait::proto::rel::RelType; -use substrait::proto::{Expression, JoinRel, Rel, join_rel}; +use substrait::proto::{JoinRel, Rel, join_rel}; pub fn from_join( producer: &mut impl SubstraitProducer, join: &Join, ) -> datafusion::common::Result> { - let left = producer.handle_plan(join.left.as_ref())?; - let right = producer.handle_plan(join.right.as_ref())?; - let join_type = to_substrait_jointype(join.join_type); - // we only support basic joins so return an error for anything not yet supported + // only ON constraints are supported right now match join.join_constraint { JoinConstraint::On => {} JoinConstraint::Using => return not_impl_err!("join constraint: `using`"), } - let in_join_schema = Arc::new(join.left.schema().join(join.right.schema())?); - - // convert filter if present - let join_filter = match &join.filter { - Some(filter) => Some(producer.handle_expr(filter, &in_join_schema)?), - None => None, - }; - // map the left and right columns to binary expressions in the form `l = r` - // build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b` - let eq_op = match join.null_equality { - NullEquality::NullEqualsNothing => Operator::Eq, - NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom, - }; - let join_on = to_substrait_join_expr(producer, &join.on, eq_op, &in_join_schema)?; + let left = producer.handle_plan(join.left.as_ref())?; + let right = producer.handle_plan(join.right.as_ref())?; + let join_type = to_substrait_jointype(join.join_type); - // create conjunction between `join_on` and `join_filter` to embed all join conditions, - // whether equal or non-equal in a single expression - let join_expr = match &join_on { - Some(on_expr) => match &join_filter { - Some(filter) => Some(Box::new(make_binary_op_scalar_func( - producer, - on_expr, - filter, - Operator::And, - ))), - None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist - }, - None => match &join_filter { - Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist - None => None, - }, + let join_expr = + to_substrait_join_expr(join.on.clone(), join.null_equality, join.filter.clone()); + let join_expression = match join_expr { + Some(expr) => { + let in_join_schema = Arc::new(join.left.schema().join(join.right.schema())?); + let expression = producer.handle_expr(&expr, &in_join_schema)?; + Some(Box::new(expression)) + } + None => None, }; Ok(Box::new(Rel { @@ -76,7 +55,7 @@ pub fn from_join( left: Some(left), right: Some(right), r#type: join_type as i32, - expression: join_expr, + expression: join_expression, post_join_filter: None, advanced_extension: None, }))), @@ -84,25 +63,20 @@ pub fn from_join( } fn to_substrait_join_expr( - producer: &mut impl SubstraitProducer, - join_conditions: &Vec<(Expr, Expr)>, - eq_op: Operator, - join_schema: &DFSchemaRef, -) -> datafusion::common::Result> { - // Only support AND conjunction for each binary expression in join conditions - let mut exprs: Vec = vec![]; - for (left, right) in join_conditions { - let l = producer.handle_expr(left, join_schema)?; - let r = producer.handle_expr(right, join_schema)?; - // AND with existing expression - exprs.push(make_binary_op_scalar_func(producer, &l, &r, eq_op)); - } - - let join_expr: Option = - exprs.into_iter().reduce(|acc: Expression, e: Expression| { - make_binary_op_scalar_func(producer, &acc, &e, Operator::And) - }); - Ok(join_expr) + join_on: Vec<(Expr, Expr)>, + null_equality: NullEquality, + join_filter: Option, +) -> Option { + // Combine join on and filter conditions into a single Boolean expression (#7611) + let eq_op = match null_equality { + NullEquality::NullEqualsNothing => Operator::Eq, + NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom, + }; + let all_conditions = join_on + .into_iter() + .map(|(left, right)| binary_expr(left, eq_op, right)) + .chain(join_filter); + conjunction(all_conditions) } fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType { @@ -119,3 +93,85 @@ fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType { JoinType::RightSemi => join_rel::JoinType::RightSemi, } } + +#[cfg(test)] +mod tests { + use crate::logical_plan::producer::{ + DefaultSubstraitProducer, SubstraitProducer, to_substrait_type, + }; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::common::{JoinConstraint, JoinType, NullEquality}; + use datafusion::execution::SessionStateBuilder; + use datafusion::logical_expr::utils::conjunction; + use datafusion::logical_expr::{Join, col, table_scan}; + use std::sync::Arc; + use substrait::proto::expression::{RexType, ScalarFunction}; + use substrait::proto::rel::RelType; + use substrait::proto::{Expression, JoinRel, Rel, join_rel}; + + #[test] + fn test_from_join() -> datafusion::common::Result<()> { + let state = SessionStateBuilder::default().build(); + let mut producer = DefaultSubstraitProducer::new(&state); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ]); + let left_scan = table_scan(Some("t1"), &schema, None)?.build()?; + let right_scan = table_scan(Some("t2"), &schema, None)?.build()?; + let join = Join::try_new( + Arc::new(left_scan.clone()), + Arc::new(right_scan.clone()), + vec![(col("t1.a"), col("t2.a")), (col("t1.b"), col("t2.b"))], + Some(col("t1.c").gt(col("t2.c"))), + JoinType::Inner, + JoinConstraint::On, + NullEquality::NullEqualsNothing, + false, + )?; + let join_expr = producer.handle_join(&join)?; + + let in_join_schema = Arc::new(join.left.schema().join(join.right.schema())?); + let expected_join_expr = conjunction(vec![ + // Join on + col("t1.a").eq(col("t2.a")), + col("t1.b").eq(col("t2.b")), + // Join filter + col("t1.c").gt(col("t2.c")), + ]) + .unwrap(); + let expected_join_expression = + producer.handle_expr(&expected_join_expr, &in_join_schema)?; + + assert_eq!( + join_expr, + Box::new(Rel { + rel_type: Some(RelType::Join(Box::new(JoinRel { + common: None, + left: Some(producer.handle_plan(&left_scan)?), + right: Some(producer.handle_plan(&right_scan)?), + r#type: join_rel::JoinType::Inner as i32, + expression: Some(Box::new(expected_join_expression.clone())), + post_join_filter: None, + advanced_extension: None, + }))) + }) + ); + + // Check that the join_expression has the expected output_type + if let Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { output_type, .. })), + } = expected_join_expression + { + let expected_type = + to_substrait_type(&mut producer, &DataType::Boolean, false)?; + assert_eq!(output_type, Some(expected_type)); + } else { + panic!("Substrait ScalarFunction expected") + } + + Ok(()) + } +}