Skip to content

Commit 58d433c

Browse files
committed
input scalar output scalar
1 parent 58d42be commit 58d433c

1 file changed

Lines changed: 16 additions & 3 deletions

File tree

native/spark-expr/src/datetime_funcs/extract_date_part.rs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
// under the License.
1717

1818
use crate::utils::array_with_timezone;
19+
use arrow::array::{Array, Int32Array};
1920
use arrow::compute::{date_part, DatePart};
2021
use arrow::datatypes::{DataType, TimeUnit::Microsecond};
21-
use datafusion::common::internal_datafusion_err;
22+
use datafusion::common::{internal_datafusion_err, ScalarValue};
2223
use datafusion::logical_expr::{
2324
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
2425
};
@@ -90,7 +91,8 @@ macro_rules! extract_date_part {
9091
// When Spark's ConstantFolding is disabled, literal-only expressions like
9192
// hour can reach the native engine as scalar inputs.
9293
// Instead of failing and requiring JVM folding, we evaluate the scalar
93-
// natively by broadcasting it to a single-element array.
94+
// natively by broadcasting it to a single-element array and then
95+
// converting the result back to a scalar.
9496
let array = scalar.clone().to_array_of_size(1)?;
9597
let array = array_with_timezone(
9698
array,
@@ -101,7 +103,18 @@ macro_rules! extract_date_part {
101103
)),
102104
)?;
103105
let result = date_part(&array, DatePart::$date_part_variant)?;
104-
Ok(ColumnarValue::Array(result))
106+
let result_arr = result
107+
.as_any()
108+
.downcast_ref::<Int32Array>()
109+
.expect(concat!($fn_name, " should return Int32Array"));
110+
111+
let scalar_result = if result_arr.is_null(0) {
112+
ScalarValue::Int32(None)
113+
} else {
114+
ScalarValue::Int32(Some(result_arr.value(0)))
115+
};
116+
117+
Ok(ColumnarValue::Scalar(scalar_result))
105118
}
106119
}
107120
}

0 commit comments

Comments
 (0)