1818use crate :: logical_plan:: producer:: {
1919 SubstraitProducer , to_substrait_literal_expr, to_substrait_type,
2020} ;
21+ use datafusion:: arrow:: datatypes:: DataType ;
2122use datafusion:: common:: datatype:: FieldExt ;
2223use datafusion:: common:: {
2324 DFSchemaRef , ScalarValue , internal_datafusion_err, not_impl_err, substrait_err,
@@ -34,7 +35,15 @@ pub fn from_scalar_function(
3435 fun : & expr:: ScalarFunction ,
3536 schema : & DFSchemaRef ,
3637) -> datafusion:: common:: Result < Expression > {
37- from_function ( producer, fun. name ( ) , & fun. args , schema)
38+ let ( _, output_field) = Expr :: ScalarFunction ( fun. clone ( ) ) . to_field ( schema) ?;
39+ from_function (
40+ producer,
41+ fun. name ( ) ,
42+ & fun. args ,
43+ output_field. data_type ( ) ,
44+ output_field. is_nullable ( ) ,
45+ schema,
46+ )
3847}
3948
4049pub fn from_higher_order_function (
@@ -104,12 +113,20 @@ pub fn from_higher_order_function(
104113 . collect :: < datafusion:: common:: Result < _ > > ( ) ?;
105114
106115 let function_anchor = producer. register_function ( fun. name ( ) . to_string ( ) ) ;
116+
117+ let ( _, output_field) = Expr :: HigherOrderFunction ( fun. clone ( ) ) . to_field ( schema) ?;
118+ let output_type = to_substrait_type (
119+ producer,
120+ output_field. data_type ( ) ,
121+ output_field. is_nullable ( ) ,
122+ ) ?;
123+
107124 #[ expect( deprecated) ]
108125 Ok ( Expression {
109126 rex_type : Some ( RexType :: ScalarFunction ( ScalarFunction {
110127 function_reference : function_anchor,
111128 arguments,
112- output_type : None ,
129+ output_type : Some ( output_type ) ,
113130 options : vec ! [ ] ,
114131 args : vec ! [ ] ,
115132 } ) ) ,
@@ -120,6 +137,8 @@ fn from_function(
120137 producer : & mut impl SubstraitProducer ,
121138 name : & str ,
122139 args : & [ Expr ] ,
140+ output_type : & DataType ,
141+ output_nullability : bool ,
123142 schema : & DFSchemaRef ,
124143) -> datafusion:: common:: Result < Expression > {
125144 let mut arguments: Vec < FunctionArgument > = vec ! [ ] ;
@@ -130,14 +149,15 @@ fn from_function(
130149 }
131150
132151 let arguments = custom_argument_handler ( name, arguments) ;
152+ let output_type = to_substrait_type ( producer, output_type, output_nullability) ?;
133153
134154 let function_anchor = producer. register_function ( name. to_string ( ) ) ;
135155 #[ expect( deprecated) ]
136156 Ok ( Expression {
137157 rex_type : Some ( RexType :: ScalarFunction ( ScalarFunction {
138158 function_reference : function_anchor,
139159 arguments,
140- output_type : None ,
160+ output_type : Some ( output_type ) ,
141161 options : vec ! [ ] ,
142162 args : vec ! [ ] ,
143163 } ) ) ,
0 commit comments