Skip to content

Commit 6d0fbdc

Browse files
committed
Set Substrait output type for functions
1 parent 22786f2 commit 6d0fbdc

1 file changed

Lines changed: 12 additions & 3 deletions

File tree

datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
use crate::logical_plan::producer::{
1919
SubstraitProducer, to_substrait_literal_expr, to_substrait_type,
2020
};
21+
use datafusion::arrow::datatypes::Field;
2122
use datafusion::common::{DFSchemaRef, ScalarValue, not_impl_err};
2223
use datafusion::logical_expr::{
2324
Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, expr,
@@ -31,21 +32,24 @@ pub fn from_scalar_function(
3132
fun: &expr::ScalarFunction,
3233
schema: &DFSchemaRef,
3334
) -> datafusion::common::Result<Expression> {
34-
from_function(producer, fun.name(), &fun.args, schema)
35+
let (_, output_field) = Expr::ScalarFunction(fun.clone()).to_field(schema)?;
36+
from_function(producer, fun.name(), &fun.args, &output_field, schema)
3537
}
3638

3739
pub fn from_higher_order_function(
3840
producer: &mut impl SubstraitProducer,
3941
fun: &expr::HigherOrderFunction,
4042
schema: &DFSchemaRef,
4143
) -> datafusion::common::Result<Expression> {
42-
from_function(producer, fun.name(), &fun.args, schema)
44+
let (_, output_field) = Expr::HigherOrderFunction(fun.clone()).to_field(schema)?;
45+
from_function(producer, fun.name(), &fun.args, &output_field, schema)
4346
}
4447

4548
fn from_function(
4649
producer: &mut impl SubstraitProducer,
4750
name: &str,
4851
args: &[Expr],
52+
output_field: &Field,
4953
schema: &DFSchemaRef,
5054
) -> datafusion::common::Result<Expression> {
5155
let mut arguments: Vec<FunctionArgument> = vec![];
@@ -56,14 +60,19 @@ fn from_function(
5660
}
5761

5862
let arguments = custom_argument_handler(name, arguments);
63+
let output_type = to_substrait_type(
64+
producer,
65+
output_field.data_type(),
66+
output_field.is_nullable(),
67+
)?;
5968

6069
let function_anchor = producer.register_function(name.to_string());
6170
#[expect(deprecated)]
6271
Ok(Expression {
6372
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
6473
function_reference: function_anchor,
6574
arguments,
66-
output_type: None,
75+
output_type: Some(output_type),
6776
options: vec![],
6877
args: vec![],
6978
})),

0 commit comments

Comments
 (0)