Skip to content

Commit f50620c

Browse files
wlhjasonJason
authored andcommitted
Set Substrait output type for binary expressions
1 parent 1ab146a commit f50620c

2 files changed

Lines changed: 101 additions & 112 deletions

File tree

datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs

Lines changed: 68 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,16 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::logical_plan::producer::{SubstraitProducer, to_substrait_literal_expr};
18+
use crate::logical_plan::producer::{
19+
SubstraitProducer, to_substrait_literal_expr, to_substrait_type,
20+
};
1921
use datafusion::common::{DFSchemaRef, ScalarValue, not_impl_err};
20-
use datafusion::logical_expr::{Between, BinaryExpr, Expr, Like, Operator, expr};
22+
use datafusion::logical_expr::{
23+
Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, expr,
24+
};
2125
use substrait::proto::expression::{RexType, ScalarFunction};
2226
use substrait::proto::function_argument::ArgType;
23-
use substrait::proto::{Expression, FunctionArgument};
27+
use substrait::proto::{Expression, FunctionArgument, Type};
2428

2529
pub fn from_scalar_function(
2630
producer: &mut impl SubstraitProducer,
@@ -114,7 +118,19 @@ pub fn from_binary_expr(
114118
let BinaryExpr { left, op, right } = expr;
115119
let l = producer.handle_expr(left, schema)?;
116120
let r = producer.handle_expr(right, schema)?;
117-
Ok(make_binary_op_scalar_func(producer, &l, &r, *op))
121+
let (_, output_field) = Expr::BinaryExpr(expr.clone()).to_field(schema)?;
122+
let output_type = to_substrait_type(
123+
producer,
124+
output_field.data_type(),
125+
output_field.is_nullable(),
126+
)?;
127+
Ok(make_binary_op_scalar_func(
128+
producer,
129+
&l,
130+
&r,
131+
*op,
132+
&output_type,
133+
))
118134
}
119135

120136
pub fn from_like(
@@ -232,6 +248,7 @@ pub fn make_binary_op_scalar_func(
232248
lhs: &Expression,
233249
rhs: &Expression,
234250
op: Operator,
251+
output_type: &Type,
235252
) -> Expression {
236253
let function_anchor = producer.register_function(operator_to_name(op).to_string());
237254
#[expect(deprecated)]
@@ -246,7 +263,7 @@ pub fn make_binary_op_scalar_func(
246263
arg_type: Some(ArgType::Value(rhs.clone())),
247264
},
248265
],
249-
output_type: None,
266+
output_type: Some(output_type.clone()),
250267
args: vec![],
251268
options: vec![],
252269
})),
@@ -264,57 +281,21 @@ pub fn from_between(
264281
low,
265282
high,
266283
} = between;
267-
if *negated {
268-
// `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr)
269-
let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?;
270-
let substrait_low = producer.handle_expr(low.as_ref(), schema)?;
271-
let substrait_high = producer.handle_expr(high.as_ref(), schema)?;
272284

273-
let l_expr = make_binary_op_scalar_func(
274-
producer,
275-
&substrait_expr,
276-
&substrait_low,
277-
Operator::Lt,
278-
);
279-
let r_expr = make_binary_op_scalar_func(
280-
producer,
281-
&substrait_high,
282-
&substrait_expr,
283-
Operator::Lt,
284-
);
285-
286-
Ok(make_binary_op_scalar_func(
287-
producer,
288-
&l_expr,
289-
&r_expr,
290-
Operator::Or,
291-
))
285+
let expr = if *negated {
286+
// `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr)
287+
Expr::or(
288+
Expr::lt(*expr.clone(), *low.clone()),
289+
Expr::lt(*high.clone(), *expr.clone()),
290+
)
292291
} else {
293292
// `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high)
294-
let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?;
295-
let substrait_low = producer.handle_expr(low.as_ref(), schema)?;
296-
let substrait_high = producer.handle_expr(high.as_ref(), schema)?;
297-
298-
let l_expr = make_binary_op_scalar_func(
299-
producer,
300-
&substrait_low,
301-
&substrait_expr,
302-
Operator::LtEq,
303-
);
304-
let r_expr = make_binary_op_scalar_func(
305-
producer,
306-
&substrait_expr,
307-
&substrait_high,
308-
Operator::LtEq,
309-
);
310-
311-
Ok(make_binary_op_scalar_func(
312-
producer,
313-
&l_expr,
314-
&r_expr,
315-
Operator::And,
316-
))
317-
}
293+
Expr::and(
294+
Expr::lt_eq(*low.clone(), *expr.clone()),
295+
Expr::lt_eq(*expr.clone(), *high.clone()),
296+
)
297+
};
298+
producer.handle_expr(&expr, schema)
318299
}
319300

320301
pub fn operator_to_name(op: Operator) -> &'static str {
@@ -364,3 +345,37 @@ pub fn operator_to_name(op: Operator) -> &'static str {
364345
Operator::Colon => "colon",
365346
}
366347
}
348+
349+
#[cfg(test)]
350+
mod tests {
351+
use crate::logical_plan::producer::{
352+
DefaultSubstraitProducer, SubstraitProducer, to_substrait_type,
353+
};
354+
use datafusion::arrow::datatypes::DataType;
355+
use datafusion::common::{DFSchema, DFSchemaRef};
356+
use datafusion::execution::SessionStateBuilder;
357+
use datafusion::prelude::lit;
358+
use substrait::proto::Expression;
359+
use substrait::proto::expression::{RexType, ScalarFunction};
360+
361+
#[tokio::test]
362+
async fn binary_expr_output_type() -> datafusion::common::Result<()> {
363+
let state = SessionStateBuilder::default().build();
364+
let empty_schema = DFSchemaRef::new(DFSchema::empty());
365+
let mut producer = DefaultSubstraitProducer::new(&state);
366+
367+
let expr = lit(1i64) + lit(2i64);
368+
let substrait_expr = producer.handle_expr(&expr, &empty_schema)?;
369+
if let Expression {
370+
rex_type: Some(RexType::ScalarFunction(ScalarFunction { output_type, .. })),
371+
} = substrait_expr
372+
{
373+
let expected_type =
374+
to_substrait_type(&mut producer, &DataType::Int64, false)?;
375+
assert_eq!(output_type, Some(expected_type));
376+
Ok(())
377+
} else {
378+
panic!("Substrait ScalarFunction expected")
379+
}
380+
}
381+
}

datafusion/substrait/src/logical_plan/producer/rel/join.rs

Lines changed: 33 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -15,59 +15,38 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::logical_plan::producer::{SubstraitProducer, make_binary_op_scalar_func};
19-
use datafusion::common::{
20-
DFSchemaRef, JoinConstraint, JoinType, NullEquality, not_impl_err,
21-
};
18+
use crate::logical_plan::producer::SubstraitProducer;
19+
use datafusion::common::{JoinConstraint, JoinType, NullEquality, not_impl_err};
20+
use datafusion::logical_expr::utils::conjunction;
2221
use datafusion::logical_expr::{Expr, Join, Operator};
22+
use datafusion::prelude::binary_expr;
2323
use std::sync::Arc;
2424
use substrait::proto::rel::RelType;
25-
use substrait::proto::{Expression, JoinRel, Rel, join_rel};
25+
use substrait::proto::{JoinRel, Rel, join_rel};
2626

2727
pub fn from_join(
2828
producer: &mut impl SubstraitProducer,
2929
join: &Join,
3030
) -> datafusion::common::Result<Box<Rel>> {
31-
let left = producer.handle_plan(join.left.as_ref())?;
32-
let right = producer.handle_plan(join.right.as_ref())?;
33-
let join_type = to_substrait_jointype(join.join_type);
34-
// we only support basic joins so return an error for anything not yet supported
31+
// only ON constraints are supported right now
3532
match join.join_constraint {
3633
JoinConstraint::On => {}
3734
JoinConstraint::Using => return not_impl_err!("join constraint: `using`"),
3835
}
39-
let in_join_schema = Arc::new(join.left.schema().join(join.right.schema())?);
40-
41-
// convert filter if present
42-
let join_filter = match &join.filter {
43-
Some(filter) => Some(producer.handle_expr(filter, &in_join_schema)?),
44-
None => None,
45-
};
4636

47-
// map the left and right columns to binary expressions in the form `l = r`
48-
// build a single expression for the ON condition, such as `l.a = r.a AND l.b = r.b`
49-
let eq_op = match join.null_equality {
50-
NullEquality::NullEqualsNothing => Operator::Eq,
51-
NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom,
52-
};
53-
let join_on = to_substrait_join_expr(producer, &join.on, eq_op, &in_join_schema)?;
37+
let left = producer.handle_plan(join.left.as_ref())?;
38+
let right = producer.handle_plan(join.right.as_ref())?;
39+
let join_type = to_substrait_jointype(join.join_type);
5440

55-
// create conjunction between `join_on` and `join_filter` to embed all join conditions,
56-
// whether equal or non-equal in a single expression
57-
let join_expr = match &join_on {
58-
Some(on_expr) => match &join_filter {
59-
Some(filter) => Some(Box::new(make_binary_op_scalar_func(
60-
producer,
61-
on_expr,
62-
filter,
63-
Operator::And,
64-
))),
65-
None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist
66-
},
67-
None => match &join_filter {
68-
Some(_) => join_filter.map(Box::new), // the join expression will only contain `join_filter` if the `on` condition doesn't exist
69-
None => None,
70-
},
41+
let join_expr =
42+
to_substrait_join_expr(join.on.clone(), join.null_equality, join.filter.clone());
43+
let join_expression = match join_expr {
44+
Some(expr) => {
45+
let in_join_schema = Arc::new(join.left.schema().join(join.right.schema())?);
46+
let expression = producer.handle_expr(&expr, &in_join_schema)?;
47+
Some(Box::new(expression))
48+
}
49+
None => None,
7150
};
7251

7352
Ok(Box::new(Rel {
@@ -76,33 +55,28 @@ pub fn from_join(
7655
left: Some(left),
7756
right: Some(right),
7857
r#type: join_type as i32,
79-
expression: join_expr,
58+
expression: join_expression,
8059
post_join_filter: None,
8160
advanced_extension: None,
8261
}))),
8362
}))
8463
}
8564

8665
fn to_substrait_join_expr(
87-
producer: &mut impl SubstraitProducer,
88-
join_conditions: &Vec<(Expr, Expr)>,
89-
eq_op: Operator,
90-
join_schema: &DFSchemaRef,
91-
) -> datafusion::common::Result<Option<Expression>> {
92-
// Only support AND conjunction for each binary expression in join conditions
93-
let mut exprs: Vec<Expression> = vec![];
94-
for (left, right) in join_conditions {
95-
let l = producer.handle_expr(left, join_schema)?;
96-
let r = producer.handle_expr(right, join_schema)?;
97-
// AND with existing expression
98-
exprs.push(make_binary_op_scalar_func(producer, &l, &r, eq_op));
99-
}
100-
101-
let join_expr: Option<Expression> =
102-
exprs.into_iter().reduce(|acc: Expression, e: Expression| {
103-
make_binary_op_scalar_func(producer, &acc, &e, Operator::And)
104-
});
105-
Ok(join_expr)
66+
join_on: Vec<(Expr, Expr)>,
67+
null_equality: NullEquality,
68+
join_filter: Option<Expr>,
69+
) -> Option<Expr> {
70+
// Combine join on and filter conditions into a single Boolean expression (#7611)
71+
let eq_op = match null_equality {
72+
NullEquality::NullEqualsNothing => Operator::Eq,
73+
NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom,
74+
};
75+
let all_conditions = join_on
76+
.into_iter()
77+
.map(|(left, right)| binary_expr(left, eq_op, right))
78+
.chain(join_filter);
79+
conjunction(all_conditions)
10680
}
10781

10882
fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType {

0 commit comments

Comments
 (0)