Skip to content

Commit 34f797b

Browse files
committed
Set Substrait output type for binary expressions
1 parent 66be207 commit 34f797b

1 file changed

Lines changed: 56 additions & 5 deletions

File tree

datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,19 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::logical_plan::producer::{SubstraitProducer, to_substrait_literal_expr};
18+
use crate::logical_plan::producer::{
19+
SubstraitProducer, to_substrait_literal_expr, to_substrait_type,
20+
};
1921
use datafusion::common::datatype::FieldExt;
2022
use datafusion::common::{
2123
DFSchemaRef, ScalarValue, internal_datafusion_err, not_impl_err, substrait_err,
2224
};
23-
use datafusion::logical_expr::{Between, BinaryExpr, Expr, Like, Operator, expr};
25+
use datafusion::logical_expr::{
26+
Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, expr,
27+
};
2428
use substrait::proto::expression::{RexType, ScalarFunction};
2529
use substrait::proto::function_argument::ArgType;
26-
use substrait::proto::{Expression, FunctionArgument};
30+
use substrait::proto::{Expression, FunctionArgument, Type};
2731

2832
pub fn from_scalar_function(
2933
producer: &mut impl SubstraitProducer,
@@ -188,7 +192,19 @@ pub fn from_binary_expr(
188192
let BinaryExpr { left, op, right } = expr;
189193
let l = producer.handle_expr(left, schema)?;
190194
let r = producer.handle_expr(right, schema)?;
191-
Ok(make_binary_op_scalar_func(producer, &l, &r, *op))
195+
let (_, output_field) = Expr::BinaryExpr(expr.clone()).to_field(schema)?;
196+
let output_type = to_substrait_type(
197+
producer,
198+
output_field.data_type(),
199+
output_field.is_nullable(),
200+
)?;
201+
Ok(make_binary_op_scalar_func(
202+
producer,
203+
&l,
204+
&r,
205+
*op,
206+
&output_type,
207+
))
192208
}
193209

194210
pub fn from_like(
@@ -306,6 +322,7 @@ pub fn make_binary_op_scalar_func(
306322
lhs: &Expression,
307323
rhs: &Expression,
308324
op: Operator,
325+
output_type: &Type,
309326
) -> Expression {
310327
let function_anchor = producer.register_function(operator_to_name(op).to_string());
311328
#[expect(deprecated)]
@@ -320,7 +337,7 @@ pub fn make_binary_op_scalar_func(
320337
arg_type: Some(ArgType::Value(rhs.clone())),
321338
},
322339
],
323-
output_type: None,
340+
output_type: Some(output_type.clone()),
324341
args: vec![],
325342
options: vec![],
326343
})),
@@ -402,3 +419,37 @@ pub fn operator_to_name(op: Operator) -> &'static str {
402419
Operator::Colon => "colon",
403420
}
404421
}
422+
423+
#[cfg(test)]
424+
mod tests {
425+
use crate::logical_plan::producer::{
426+
DefaultSubstraitProducer, SubstraitProducer, to_substrait_type,
427+
};
428+
use datafusion::arrow::datatypes::DataType;
429+
use datafusion::common::{DFSchema, DFSchemaRef};
430+
use datafusion::execution::SessionStateBuilder;
431+
use datafusion::prelude::lit;
432+
use substrait::proto::Expression;
433+
use substrait::proto::expression::{RexType, ScalarFunction};
434+
435+
#[tokio::test]
436+
async fn binary_expr_output_type() -> datafusion::common::Result<()> {
437+
let state = SessionStateBuilder::default().build();
438+
let empty_schema = DFSchemaRef::new(DFSchema::empty());
439+
let mut producer = DefaultSubstraitProducer::new(&state);
440+
441+
let expr = lit(1i64) + lit(2i64);
442+
let substrait_expr = producer.handle_expr(&expr, &empty_schema)?;
443+
if let Expression {
444+
rex_type: Some(RexType::ScalarFunction(ScalarFunction { output_type, .. })),
445+
} = substrait_expr
446+
{
447+
let expected_type =
448+
to_substrait_type(&mut producer, &DataType::Int64, false)?;
449+
assert_eq!(output_type, Some(expected_type));
450+
Ok(())
451+
} else {
452+
panic!("Substrait ScalarFunction expected")
453+
}
454+
}
455+
}

0 commit comments

Comments
 (0)