Skip to content

Commit b48e2a6

Browse files
committed
Set Substrait output type for binary expressions
1 parent a8c01af commit b48e2a6

1 file changed

Lines changed: 56 additions & 5 deletions

File tree

datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs

Lines changed: 56 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -15,12 +15,16 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::logical_plan::producer::{SubstraitProducer, to_substrait_literal_expr};
18+
use crate::logical_plan::producer::{
19+
SubstraitProducer, to_substrait_literal_expr, to_substrait_type,
20+
};
1921
use datafusion::common::{DFSchemaRef, ScalarValue, not_impl_err};
20-
use datafusion::logical_expr::{Between, BinaryExpr, Expr, Like, Operator, expr};
22+
use datafusion::logical_expr::{
23+
Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, expr,
24+
};
2125
use substrait::proto::expression::{RexType, ScalarFunction};
2226
use substrait::proto::function_argument::ArgType;
23-
use substrait::proto::{Expression, FunctionArgument};
27+
use substrait::proto::{Expression, FunctionArgument, Type};
2428

2529
pub fn from_scalar_function(
2630
producer: &mut impl SubstraitProducer,
@@ -114,7 +118,19 @@ pub fn from_binary_expr(
114118
let BinaryExpr { left, op, right } = expr;
115119
let l = producer.handle_expr(left, schema)?;
116120
let r = producer.handle_expr(right, schema)?;
117-
Ok(make_binary_op_scalar_func(producer, &l, &r, *op))
121+
let (_, output_field) = Expr::BinaryExpr(expr.clone()).to_field(schema)?;
122+
let output_type = to_substrait_type(
123+
producer,
124+
output_field.data_type(),
125+
output_field.is_nullable(),
126+
)?;
127+
Ok(make_binary_op_scalar_func(
128+
producer,
129+
&l,
130+
&r,
131+
*op,
132+
&output_type,
133+
))
118134
}
119135

120136
pub fn from_like(
@@ -232,6 +248,7 @@ pub fn make_binary_op_scalar_func(
232248
lhs: &Expression,
233249
rhs: &Expression,
234250
op: Operator,
251+
output_type: &Type,
235252
) -> Expression {
236253
let function_anchor = producer.register_function(operator_to_name(op).to_string());
237254
#[expect(deprecated)]
@@ -246,7 +263,7 @@ pub fn make_binary_op_scalar_func(
246263
arg_type: Some(ArgType::Value(rhs.clone())),
247264
},
248265
],
249-
output_type: None,
266+
output_type: Some(output_type.clone()),
250267
args: vec![],
251268
options: vec![],
252269
})),
@@ -328,3 +345,37 @@ pub fn operator_to_name(op: Operator) -> &'static str {
328345
Operator::Colon => "colon",
329346
}
330347
}
348+
349+
// Test module checking the `output_type` carried by the Substrait scalar
// function produced for a binary expression.
#[cfg(test)]
350+
mod tests {
351+
use crate::logical_plan::producer::{
352+
DefaultSubstraitProducer, SubstraitProducer, to_substrait_type,
353+
};
354+
use datafusion::arrow::datatypes::DataType;
355+
use datafusion::common::{DFSchema, DFSchemaRef};
356+
use datafusion::execution::SessionStateBuilder;
357+
use datafusion::prelude::lit;
358+
use substrait::proto::Expression;
359+
use substrait::proto::expression::{RexType, ScalarFunction};
360+
361+
/// Converts `lit(1i64) + lit(2i64)` through the producer and asserts that the
/// resulting Substrait `ScalarFunction` has `output_type` equal to the type
/// `to_substrait_type` returns for `DataType::Int64` with the nullable flag
/// set to `false`.
///
/// Panics if the converted expression is not a `RexType::ScalarFunction`.
// NOTE(review): no `.await` appears in this body, so a plain `#[test]` may be
// enough here — confirm before changing.
#[tokio::test]
362+
async fn binary_expr_output_type() -> datafusion::common::Result<()> {
363+
let state = SessionStateBuilder::default().build();
364+
// The expression below is built from literals only and is converted against
// an empty schema.
let empty_schema = DFSchemaRef::new(DFSchema::empty());
365+
let mut producer = DefaultSubstraitProducer::new(&state);
366+
367+
let expr = lit(1i64) + lit(2i64);
368+
let substrait_expr = producer.handle_expr(&expr, &empty_schema)?;
369+
// Pull `output_type` out of the scalar function; any other expression shape
// falls through to the panic in the `else` branch.
if let Expression {
370+
rex_type: Some(RexType::ScalarFunction(ScalarFunction { output_type, .. })),
371+
} = substrait_expr
372+
{
373+
// Build the expected type with the same producer that converted the expression.
let expected_type =
374+
to_substrait_type(&mut producer, &DataType::Int64, false)?;
375+
assert_eq!(output_type, Some(expected_type));
376+
Ok(())
377+
} else {
378+
panic!("Substrait ScalarFunction expected")
379+
}
380+
}
381+
}

0 commit comments

Comments
 (0)