Skip to content

Commit d560e09

Browse files
committed
Set Substrait output type for functions
1 parent c3f48fe commit d560e09

1 file changed

Lines changed: 23 additions & 3 deletions

File tree

datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
use crate::logical_plan::producer::{
1919
SubstraitProducer, to_substrait_literal_expr, to_substrait_type,
2020
};
21+
use datafusion::arrow::datatypes::DataType;
2122
use datafusion::common::{DFSchemaRef, ScalarValue, not_impl_err};
2223
use datafusion::logical_expr::{
2324
Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, expr,
@@ -31,21 +32,39 @@ pub fn from_scalar_function(
3132
fun: &expr::ScalarFunction,
3233
schema: &DFSchemaRef,
3334
) -> datafusion::common::Result<Expression> {
34-
from_function(producer, fun.name(), &fun.args, schema)
35+
let (_, output_field) = Expr::ScalarFunction(fun.clone()).to_field(schema)?;
36+
from_function(
37+
producer,
38+
fun.name(),
39+
&fun.args,
40+
output_field.data_type(),
41+
output_field.is_nullable(),
42+
schema,
43+
)
3544
}
3645

3746
pub fn from_higher_order_function(
3847
producer: &mut impl SubstraitProducer,
3948
fun: &expr::HigherOrderFunction,
4049
schema: &DFSchemaRef,
4150
) -> datafusion::common::Result<Expression> {
42-
from_function(producer, fun.name(), &fun.args, schema)
51+
let (_, output_field) = Expr::HigherOrderFunction(fun.clone()).to_field(schema)?;
52+
from_function(
53+
producer,
54+
fun.name(),
55+
&fun.args,
56+
output_field.data_type(),
57+
output_field.is_nullable(),
58+
schema,
59+
)
4360
}
4461

4562
fn from_function(
4663
producer: &mut impl SubstraitProducer,
4764
name: &str,
4865
args: &[Expr],
66+
output_type: &DataType,
67+
output_nullability: bool,
4968
schema: &DFSchemaRef,
5069
) -> datafusion::common::Result<Expression> {
5170
let mut arguments: Vec<FunctionArgument> = vec![];
@@ -56,14 +75,15 @@ fn from_function(
5675
}
5776

5877
let arguments = custom_argument_handler(name, arguments);
78+
let output_type = to_substrait_type(producer, output_type, output_nullability)?;
5979

6080
let function_anchor = producer.register_function(name.to_string());
6181
#[expect(deprecated)]
6282
Ok(Expression {
6383
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
6484
function_reference: function_anchor,
6585
arguments,
66-
output_type: None,
86+
output_type: Some(output_type),
6787
options: vec![],
6888
args: vec![],
6989
})),

0 commit comments

Comments
 (0)