Skip to content

Commit a841b13

Browse files
fix: Resolve Null type to GenericListArray casting issue in Spark slice function
1 parent a257c29 commit a841b13

File tree

2 files changed

+62
-1
lines changed
  • datafusion

2 files changed

+62
-1
lines changed

datafusion/spark/src/function/array/slice.rs

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ use arrow::array::{Array, ArrayRef, Int64Builder};
1919
use arrow::datatypes::{DataType, Field, FieldRef};
2020
use datafusion_common::cast::{as_int64_array, as_list_array};
2121
use datafusion_common::utils::ListCoercion;
22-
use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
22+
use datafusion_common::{
23+
DataFusionError, Result, exec_err, internal_err, utils::take_function_args,
24+
};
2325
use datafusion_expr::{
2426
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs,
2527
ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
@@ -94,6 +96,10 @@ impl ScalarUDFImpl for SparkSlice {
9496
&self,
9597
mut func_args: ScalarFunctionArgs,
9698
) -> Result<ColumnarValue> {
99+
if func_args.args[0].data_type() == DataType::Null {
100+
return Ok::<ColumnarValue, DataFusionError>(func_args.args[0].clone());
101+
};
102+
97103
let array_len = func_args
98104
.args
99105
.iter()
@@ -170,3 +176,40 @@ fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> {
170176

171177
Ok((Arc::new(adjusted_start.finish()), Arc::new(end.finish())))
172178
}
179+
180+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::NullArray;
    use arrow::datatypes::DataType::List;
    use arrow::datatypes::Field;
    use datafusion_common::ScalarValue;

    /// Regression test for the `DataType::Null` guard at the top of
    /// `SparkSlice::invoke_with_args`: when the first argument has the Null
    /// type, the function must hand it back unchanged.
    #[test]
    fn test_spark_slice_function_when_input_array_is_null() {
        // Single source of truth for the batch size so the input array,
        // `number_rows`, and the expected output cannot drift apart.
        let number_rows = 1;

        let input_args = vec![
            ColumnarValue::Array(Arc::new(NullArray::new(number_rows))),
            ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
            ColumnarValue::Scalar(ScalarValue::Int64(Some(3))),
        ];

        // One field per argument, with the type taken from the value itself,
        // so the declared fields always agree with the data being passed.
        let arg_fields = input_args
            .iter()
            .enumerate()
            .map(|(idx, arg)| {
                Arc::new(Field::new(format!("arg_{idx}"), arg.data_type(), true))
            })
            .collect();

        let args = ScalarFunctionArgs {
            args: input_args,
            arg_fields,
            number_rows,
            // The Null guard returns before the return field is consulted,
            // so this value is kept exactly as in the original test.
            return_field: Arc::new(Field::new(
                "item",
                List(FieldRef::new(Field::new_list_field(DataType::Int64, true))),
                false,
            )),
            config_options: Arc::new(Default::default()),
        };

        let result = SparkSlice::new().invoke_with_args(args).unwrap();
        let actual = result.to_array(number_rows).unwrap();
        let expected: ArrayRef = Arc::new(NullArray::new(number_rows));

        // `assert_eq!` prints both sides on failure, unlike `assert!(a == b)`.
        assert_eq!(actual.data_type(), &DataType::Null);
        assert_eq!(&actual, &expected);
    }
}

datafusion/sqllogictest/test_files/spark/array/slice.slt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,21 @@ query ?
114114
SELECT slice([1, 2, 3, 4], CAST('2' AS INT), 4);
115115
----
116116
[2, 3, 4]
117+
118+
query ?
119+
SELECT slice(column1, column2, column3)
120+
FROM VALUES
121+
(NULL, 1, 2),
122+
(NULL, 1, -2),
123+
(NULL, -1, 2),
124+
(NULL, 0, 2);
125+
----
126+
NULL
127+
NULL
128+
NULL
129+
NULL
130+
131+
query ?
132+
SELECT slice(slice(NULL, 1, 2), 1, 2)
133+
----
134+
NULL

0 commit comments

Comments
 (0)