Skip to content

Commit 77240f9

Browse files
authored
fix: Set Substrait output types for expressions (#20597)
## Which issue does this PR close?

<!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. -->

- Closes #15831.

## Rationale for this change

<!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. -->

The Substrait producer did not set the ScalarFunction `output_type` when converting binary expressions, which broke consumers relying on the `output_type`.

## What changes are included in this PR?

<!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. -->

* Refactor `from_join` and `from_between` to eliminate direct calls to `make_binary_op_scalar_func`
* Set the Substrait ScalarFunction `output_type` when converting several types of DataFusion expressions:
  * Binary expressions (`Expr::BinaryExpr`)
  * Unary expressions (like `Expr::Not`)
  * Scalar functions (`Expr::ScalarFunction`)

There are a few more places where the `output_type` has not been set, such as `from_like` and `from_in_list`, as mentioned in #15831. I've left these out of scope here as fixing them would require more substantial code changes.

## Are these changes tested?

<!-- We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? -->

Yes, via a new unit test.

## Are there any user-facing changes?

<!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. -->
<!-- If there are any breaking changes to public APIs, please add the `api change` label. -->

No, beyond the Substrait output fix.
1 parent d3983d3 commit 77240f9

2 files changed

Lines changed: 215 additions & 117 deletions

File tree

datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs

Lines changed: 100 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,35 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::logical_plan::producer::{SubstraitProducer, to_substrait_literal_expr};
18+
use crate::logical_plan::producer::{
19+
SubstraitProducer, to_substrait_literal_expr, to_substrait_type,
20+
};
21+
use datafusion::arrow::datatypes::DataType;
1922
use datafusion::common::datatype::FieldExt;
2023
use datafusion::common::{
2124
DFSchemaRef, ScalarValue, internal_datafusion_err, not_impl_err, substrait_err,
2225
};
23-
use datafusion::logical_expr::{Between, BinaryExpr, Expr, Like, Operator, expr};
26+
use datafusion::logical_expr::{
27+
Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, expr,
28+
};
2429
use substrait::proto::expression::{RexType, ScalarFunction};
2530
use substrait::proto::function_argument::ArgType;
26-
use substrait::proto::{Expression, FunctionArgument};
31+
use substrait::proto::{Expression, FunctionArgument, Type};
2732

2833
pub fn from_scalar_function(
2934
producer: &mut impl SubstraitProducer,
3035
fun: &expr::ScalarFunction,
3136
schema: &DFSchemaRef,
3237
) -> datafusion::common::Result<Expression> {
33-
from_function(producer, fun.name(), &fun.args, schema)
38+
let (_, output_field) = Expr::ScalarFunction(fun.clone()).to_field(schema)?;
39+
from_function(
40+
producer,
41+
fun.name(),
42+
&fun.args,
43+
output_field.data_type(),
44+
output_field.is_nullable(),
45+
schema,
46+
)
3447
}
3548

3649
pub fn from_higher_order_function(
@@ -100,12 +113,20 @@ pub fn from_higher_order_function(
100113
.collect::<datafusion::common::Result<_>>()?;
101114

102115
let function_anchor = producer.register_function(fun.name().to_string());
116+
117+
let (_, output_field) = Expr::HigherOrderFunction(fun.clone()).to_field(schema)?;
118+
let output_type = to_substrait_type(
119+
producer,
120+
output_field.data_type(),
121+
output_field.is_nullable(),
122+
)?;
123+
103124
#[expect(deprecated)]
104125
Ok(Expression {
105126
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
106127
function_reference: function_anchor,
107128
arguments,
108-
output_type: None,
129+
output_type: Some(output_type),
109130
options: vec![],
110131
args: vec![],
111132
})),
@@ -116,6 +137,8 @@ fn from_function(
116137
producer: &mut impl SubstraitProducer,
117138
name: &str,
118139
args: &[Expr],
140+
output_type: &DataType,
141+
output_nullability: bool,
119142
schema: &DFSchemaRef,
120143
) -> datafusion::common::Result<Expression> {
121144
let mut arguments: Vec<FunctionArgument> = vec![];
@@ -126,14 +149,15 @@ fn from_function(
126149
}
127150

128151
let arguments = custom_argument_handler(name, arguments);
152+
let output_type = to_substrait_type(producer, output_type, output_nullability)?;
129153

130154
let function_anchor = producer.register_function(name.to_string());
131155
#[expect(deprecated)]
132156
Ok(Expression {
133157
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
134158
function_reference: function_anchor,
135159
arguments,
136-
output_type: None,
160+
output_type: Some(output_type),
137161
options: vec![],
138162
args: vec![],
139163
})),
@@ -177,7 +201,13 @@ pub fn from_unary_expr(
177201
Expr::Negative(arg) => ("negate", arg),
178202
expr => not_impl_err!("Unsupported expression: {expr:?}")?,
179203
};
180-
to_substrait_unary_scalar_fn(producer, fn_name, arg, schema)
204+
let (_, output_field) = expr.to_field(schema)?;
205+
let output_type = to_substrait_type(
206+
producer,
207+
output_field.data_type(),
208+
output_field.is_nullable(),
209+
)?;
210+
to_substrait_unary_scalar_fn(producer, fn_name, arg, schema, &output_type)
181211
}
182212

183213
pub fn from_binary_expr(
@@ -188,7 +218,19 @@ pub fn from_binary_expr(
188218
let BinaryExpr { left, op, right } = expr;
189219
let l = producer.handle_expr(left, schema)?;
190220
let r = producer.handle_expr(right, schema)?;
191-
Ok(make_binary_op_scalar_func(producer, &l, &r, *op))
221+
let (_, output_field) = Expr::BinaryExpr(expr.clone()).to_field(schema)?;
222+
let output_type = to_substrait_type(
223+
producer,
224+
output_field.data_type(),
225+
output_field.is_nullable(),
226+
)?;
227+
Ok(make_binary_op_scalar_func(
228+
producer,
229+
&l,
230+
&r,
231+
*op,
232+
&output_type,
233+
))
192234
}
193235

194236
pub fn from_like(
@@ -283,6 +325,7 @@ fn to_substrait_unary_scalar_fn(
283325
fn_name: &str,
284326
arg: &Expr,
285327
schema: &DFSchemaRef,
328+
output_type: &Type,
286329
) -> datafusion::common::Result<Expression> {
287330
let function_anchor = producer.register_function(fn_name.to_string());
288331
let substrait_expr = producer.handle_expr(arg, schema)?;
@@ -293,7 +336,7 @@ fn to_substrait_unary_scalar_fn(
293336
arguments: vec![FunctionArgument {
294337
arg_type: Some(ArgType::Value(substrait_expr)),
295338
}],
296-
output_type: None,
339+
output_type: Some(output_type.clone()),
297340
options: vec![],
298341
..Default::default()
299342
})),
@@ -306,6 +349,7 @@ pub fn make_binary_op_scalar_func(
306349
lhs: &Expression,
307350
rhs: &Expression,
308351
op: Operator,
352+
output_type: &Type,
309353
) -> Expression {
310354
let function_anchor = producer.register_function(operator_to_name(op).to_string());
311355
#[expect(deprecated)]
@@ -320,7 +364,7 @@ pub fn make_binary_op_scalar_func(
320364
arg_type: Some(ArgType::Value(rhs.clone())),
321365
},
322366
],
323-
output_type: None,
367+
output_type: Some(output_type.clone()),
324368
args: vec![],
325369
options: vec![],
326370
})),
@@ -338,57 +382,21 @@ pub fn from_between(
338382
low,
339383
high,
340384
} = between;
341-
if *negated {
342-
// `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr)
343-
let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?;
344-
let substrait_low = producer.handle_expr(low.as_ref(), schema)?;
345-
let substrait_high = producer.handle_expr(high.as_ref(), schema)?;
346-
347-
let l_expr = make_binary_op_scalar_func(
348-
producer,
349-
&substrait_expr,
350-
&substrait_low,
351-
Operator::Lt,
352-
);
353-
let r_expr = make_binary_op_scalar_func(
354-
producer,
355-
&substrait_high,
356-
&substrait_expr,
357-
Operator::Lt,
358-
);
359385

360-
Ok(make_binary_op_scalar_func(
361-
producer,
362-
&l_expr,
363-
&r_expr,
364-
Operator::Or,
365-
))
386+
let expr = if *negated {
387+
// `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr)
388+
Expr::or(
389+
Expr::lt(*expr.clone(), *low.clone()),
390+
Expr::lt(*high.clone(), *expr.clone()),
391+
)
366392
} else {
367393
// `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high)
368-
let substrait_expr = producer.handle_expr(expr.as_ref(), schema)?;
369-
let substrait_low = producer.handle_expr(low.as_ref(), schema)?;
370-
let substrait_high = producer.handle_expr(high.as_ref(), schema)?;
371-
372-
let l_expr = make_binary_op_scalar_func(
373-
producer,
374-
&substrait_low,
375-
&substrait_expr,
376-
Operator::LtEq,
377-
);
378-
let r_expr = make_binary_op_scalar_func(
379-
producer,
380-
&substrait_expr,
381-
&substrait_high,
382-
Operator::LtEq,
383-
);
384-
385-
Ok(make_binary_op_scalar_func(
386-
producer,
387-
&l_expr,
388-
&r_expr,
389-
Operator::And,
390-
))
391-
}
394+
Expr::and(
395+
Expr::lt_eq(*low.clone(), *expr.clone()),
396+
Expr::lt_eq(*expr.clone(), *high.clone()),
397+
)
398+
};
399+
producer.handle_expr(&expr, schema)
392400
}
393401

394402
pub fn operator_to_name(op: Operator) -> &'static str {
@@ -438,3 +446,37 @@ pub fn operator_to_name(op: Operator) -> &'static str {
438446
Operator::Colon => "colon",
439447
}
440448
}
449+
450+
#[cfg(test)]
451+
mod tests {
452+
use crate::logical_plan::producer::{
453+
DefaultSubstraitProducer, SubstraitProducer, to_substrait_type,
454+
};
455+
use datafusion::arrow::datatypes::DataType;
456+
use datafusion::common::{DFSchema, DFSchemaRef};
457+
use datafusion::execution::SessionStateBuilder;
458+
use datafusion::prelude::lit;
459+
use substrait::proto::Expression;
460+
use substrait::proto::expression::{RexType, ScalarFunction};
461+
462+
#[tokio::test]
463+
async fn binary_expr_output_type() -> datafusion::common::Result<()> {
464+
let state = SessionStateBuilder::default().build();
465+
let empty_schema = DFSchemaRef::new(DFSchema::empty());
466+
let mut producer = DefaultSubstraitProducer::new(&state);
467+
468+
let expr = lit(1i64) + lit(2i64);
469+
let substrait_expr = producer.handle_expr(&expr, &empty_schema)?;
470+
if let Expression {
471+
rex_type: Some(RexType::ScalarFunction(ScalarFunction { output_type, .. })),
472+
} = substrait_expr
473+
{
474+
let expected_type =
475+
to_substrait_type(&mut producer, &DataType::Int64, false)?;
476+
assert_eq!(output_type, Some(expected_type));
477+
Ok(())
478+
} else {
479+
panic!("Substrait ScalarFunction expected")
480+
}
481+
}
482+
}

0 commit comments

Comments
 (0)