1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use crate :: logical_plan:: producer:: { SubstraitProducer , to_substrait_literal_expr} ;
18+ use crate :: logical_plan:: producer:: {
19+ SubstraitProducer , to_substrait_literal_expr, to_substrait_type,
20+ } ;
1921use datafusion:: common:: datatype:: FieldExt ;
2022use datafusion:: common:: {
2123 DFSchemaRef , ScalarValue , internal_datafusion_err, not_impl_err, substrait_err,
2224} ;
23- use datafusion:: logical_expr:: { Between , BinaryExpr , Expr , Like , Operator , expr} ;
25+ use datafusion:: logical_expr:: {
26+ Between , BinaryExpr , Expr , ExprSchemable , Like , Operator , expr,
27+ } ;
2428use substrait:: proto:: expression:: { RexType , ScalarFunction } ;
2529use substrait:: proto:: function_argument:: ArgType ;
26- use substrait:: proto:: { Expression , FunctionArgument } ;
30+ use substrait:: proto:: { Expression , FunctionArgument , Type } ;
2731
2832pub fn from_scalar_function (
2933 producer : & mut impl SubstraitProducer ,
@@ -188,7 +192,19 @@ pub fn from_binary_expr(
188192 let BinaryExpr { left, op, right } = expr;
189193 let l = producer. handle_expr ( left, schema) ?;
190194 let r = producer. handle_expr ( right, schema) ?;
191- Ok ( make_binary_op_scalar_func ( producer, & l, & r, * op) )
195+ let ( _, output_field) = Expr :: BinaryExpr ( expr. clone ( ) ) . to_field ( schema) ?;
196+ let output_type = to_substrait_type (
197+ producer,
198+ output_field. data_type ( ) ,
199+ output_field. is_nullable ( ) ,
200+ ) ?;
201+ Ok ( make_binary_op_scalar_func (
202+ producer,
203+ & l,
204+ & r,
205+ * op,
206+ & output_type,
207+ ) )
192208}
193209
194210pub fn from_like (
@@ -306,6 +322,7 @@ pub fn make_binary_op_scalar_func(
306322 lhs : & Expression ,
307323 rhs : & Expression ,
308324 op : Operator ,
325+ output_type : & Type ,
309326) -> Expression {
310327 let function_anchor = producer. register_function ( operator_to_name ( op) . to_string ( ) ) ;
311328 #[ expect( deprecated) ]
@@ -320,7 +337,7 @@ pub fn make_binary_op_scalar_func(
320337 arg_type: Some ( ArgType :: Value ( rhs. clone( ) ) ) ,
321338 } ,
322339 ] ,
323- output_type : None ,
340+ output_type : Some ( output_type . clone ( ) ) ,
324341 args : vec ! [ ] ,
325342 options : vec ! [ ] ,
326343 } ) ) ,
@@ -402,3 +419,37 @@ pub fn operator_to_name(op: Operator) -> &'static str {
402419 Operator :: Colon => "colon" ,
403420 }
404421}
422+
423+ #[ cfg( test) ]
424+ mod tests {
425+ use crate :: logical_plan:: producer:: {
426+ DefaultSubstraitProducer , SubstraitProducer , to_substrait_type,
427+ } ;
428+ use datafusion:: arrow:: datatypes:: DataType ;
429+ use datafusion:: common:: { DFSchema , DFSchemaRef } ;
430+ use datafusion:: execution:: SessionStateBuilder ;
431+ use datafusion:: prelude:: lit;
432+ use substrait:: proto:: Expression ;
433+ use substrait:: proto:: expression:: { RexType , ScalarFunction } ;
434+
435+ #[ tokio:: test]
436+ async fn binary_expr_output_type ( ) -> datafusion:: common:: Result < ( ) > {
437+ let state = SessionStateBuilder :: default ( ) . build ( ) ;
438+ let empty_schema = DFSchemaRef :: new ( DFSchema :: empty ( ) ) ;
439+ let mut producer = DefaultSubstraitProducer :: new ( & state) ;
440+
441+ let expr = lit ( 1i64 ) + lit ( 2i64 ) ;
442+ let substrait_expr = producer. handle_expr ( & expr, & empty_schema) ?;
443+ if let Expression {
444+ rex_type : Some ( RexType :: ScalarFunction ( ScalarFunction { output_type, .. } ) ) ,
445+ } = substrait_expr
446+ {
447+ let expected_type =
448+ to_substrait_type ( & mut producer, & DataType :: Int64 , false ) ?;
449+ assert_eq ! ( output_type, Some ( expected_type) ) ;
450+ Ok ( ( ) )
451+ } else {
452+ panic ! ( "Substrait ScalarFunction expected" )
453+ }
454+ }
455+ }
0 commit comments