1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use crate :: logical_plan:: producer:: { SubstraitProducer , to_substrait_literal_expr} ;
18+ use crate :: logical_plan:: producer:: {
19+ SubstraitProducer , to_substrait_literal_expr, to_substrait_type,
20+ } ;
1921use datafusion:: common:: { DFSchemaRef , ScalarValue , not_impl_err} ;
20- use datafusion:: logical_expr:: { Between , BinaryExpr , Expr , Like , Operator , expr} ;
22+ use datafusion:: logical_expr:: {
23+ Between , BinaryExpr , Expr , ExprSchemable , Like , Operator , expr,
24+ } ;
2125use substrait:: proto:: expression:: { RexType , ScalarFunction } ;
2226use substrait:: proto:: function_argument:: ArgType ;
23- use substrait:: proto:: { Expression , FunctionArgument } ;
27+ use substrait:: proto:: { Expression , FunctionArgument , Type } ;
2428
2529pub fn from_scalar_function (
2630 producer : & mut impl SubstraitProducer ,
@@ -114,7 +118,19 @@ pub fn from_binary_expr(
114118 let BinaryExpr { left, op, right } = expr;
115119 let l = producer. handle_expr ( left, schema) ?;
116120 let r = producer. handle_expr ( right, schema) ?;
117- Ok ( make_binary_op_scalar_func ( producer, & l, & r, * op) )
121+ let ( _, output_field) = Expr :: BinaryExpr ( expr. clone ( ) ) . to_field ( schema) ?;
122+ let output_type = to_substrait_type (
123+ producer,
124+ output_field. data_type ( ) ,
125+ output_field. is_nullable ( ) ,
126+ ) ?;
127+ Ok ( make_binary_op_scalar_func (
128+ producer,
129+ & l,
130+ & r,
131+ * op,
132+ & output_type,
133+ ) )
118134}
119135
120136pub fn from_like (
@@ -232,6 +248,7 @@ pub fn make_binary_op_scalar_func(
232248 lhs : & Expression ,
233249 rhs : & Expression ,
234250 op : Operator ,
251+ output_type : & Type ,
235252) -> Expression {
236253 let function_anchor = producer. register_function ( operator_to_name ( op) . to_string ( ) ) ;
237254 #[ expect( deprecated) ]
@@ -246,7 +263,7 @@ pub fn make_binary_op_scalar_func(
246263 arg_type: Some ( ArgType :: Value ( rhs. clone( ) ) ) ,
247264 } ,
248265 ] ,
249- output_type : None ,
266+ output_type : Some ( output_type . clone ( ) ) ,
250267 args : vec ! [ ] ,
251268 options : vec ! [ ] ,
252269 } ) ) ,
@@ -328,3 +345,37 @@ pub fn operator_to_name(op: Operator) -> &'static str {
328345 Operator :: Colon => "colon" ,
329346 }
330347}
348+
349+ #[ cfg( test) ]
350+ mod tests {
351+ use crate :: logical_plan:: producer:: {
352+ DefaultSubstraitProducer , SubstraitProducer , to_substrait_type,
353+ } ;
354+ use datafusion:: arrow:: datatypes:: DataType ;
355+ use datafusion:: common:: { DFSchema , DFSchemaRef } ;
356+ use datafusion:: execution:: SessionStateBuilder ;
357+ use datafusion:: prelude:: lit;
358+ use substrait:: proto:: Expression ;
359+ use substrait:: proto:: expression:: { RexType , ScalarFunction } ;
360+
361+ #[ tokio:: test]
362+ async fn binary_expr_output_type ( ) -> datafusion:: common:: Result < ( ) > {
363+ let state = SessionStateBuilder :: default ( ) . build ( ) ;
364+ let empty_schema = DFSchemaRef :: new ( DFSchema :: empty ( ) ) ;
365+ let mut producer = DefaultSubstraitProducer :: new ( & state) ;
366+
367+ let expr = lit ( 1i64 ) + lit ( 2i64 ) ;
368+ let substrait_expr = producer. handle_expr ( & expr, & empty_schema) ?;
369+ if let Expression {
370+ rex_type : Some ( RexType :: ScalarFunction ( ScalarFunction { output_type, .. } ) ) ,
371+ } = substrait_expr
372+ {
373+ let expected_type =
374+ to_substrait_type ( & mut producer, & DataType :: Int64 , false ) ?;
375+ assert_eq ! ( output_type, Some ( expected_type) ) ;
376+ Ok ( ( ) )
377+ } else {
378+ panic ! ( "Substrait ScalarFunction expected" )
379+ }
380+ }
381+ }
0 commit comments