Skip to content

Commit 53f2bb3

Browse files
committed
Set Substrait output type for functions
1 parent 6d7d41f commit 53f2bb3

1 file changed

Lines changed: 23 additions & 3 deletions

File tree

datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
use crate::logical_plan::producer::{
1919
SubstraitProducer, to_substrait_literal_expr, to_substrait_type,
2020
};
21+
use datafusion::arrow::datatypes::DataType;
2122
use datafusion::common::datatype::FieldExt;
2223
use datafusion::common::{
2324
DFSchemaRef, ScalarValue, internal_datafusion_err, not_impl_err, substrait_err,
@@ -34,7 +35,15 @@ pub fn from_scalar_function(
3435
fun: &expr::ScalarFunction,
3536
schema: &DFSchemaRef,
3637
) -> datafusion::common::Result<Expression> {
37-
from_function(producer, fun.name(), &fun.args, schema)
38+
let (_, output_field) = Expr::ScalarFunction(fun.clone()).to_field(schema)?;
39+
from_function(
40+
producer,
41+
fun.name(),
42+
&fun.args,
43+
output_field.data_type(),
44+
output_field.is_nullable(),
45+
schema,
46+
)
3847
}
3948

4049
pub fn from_higher_order_function(
@@ -104,12 +113,20 @@ pub fn from_higher_order_function(
104113
.collect::<datafusion::common::Result<_>>()?;
105114

106115
let function_anchor = producer.register_function(fun.name().to_string());
116+
117+
let (_, output_field) = Expr::HigherOrderFunction(fun.clone()).to_field(schema)?;
118+
let output_type = to_substrait_type(
119+
producer,
120+
output_field.data_type(),
121+
output_field.is_nullable(),
122+
)?;
123+
107124
#[expect(deprecated)]
108125
Ok(Expression {
109126
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
110127
function_reference: function_anchor,
111128
arguments,
112-
output_type: None,
129+
output_type: Some(output_type),
113130
options: vec![],
114131
args: vec![],
115132
})),
@@ -120,6 +137,8 @@ fn from_function(
120137
producer: &mut impl SubstraitProducer,
121138
name: &str,
122139
args: &[Expr],
140+
output_type: &DataType,
141+
output_nullability: bool,
123142
schema: &DFSchemaRef,
124143
) -> datafusion::common::Result<Expression> {
125144
let mut arguments: Vec<FunctionArgument> = vec![];
@@ -130,14 +149,15 @@ fn from_function(
130149
}
131150

132151
let arguments = custom_argument_handler(name, arguments);
152+
let output_type = to_substrait_type(producer, output_type, output_nullability)?;
133153

134154
let function_anchor = producer.register_function(name.to_string());
135155
#[expect(deprecated)]
136156
Ok(Expression {
137157
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
138158
function_reference: function_anchor,
139159
arguments,
140-
output_type: None,
160+
output_type: Some(output_type),
141161
options: vec![],
142162
args: vec![],
143163
})),

0 commit comments

Comments
 (0)