Skip to content

Commit 0fc58c6

Browse files
Return List of Nulls for Null input
1 parent 2d6254e commit 0fc58c6

File tree

2 files changed

+57
-31
lines changed
  • datafusion

2 files changed

+57
-31
lines changed

datafusion/spark/src/function/array/slice.rs

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,11 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{Array, ArrayRef, Int64Builder};
18+
use arrow::array::{Array, ArrayData, ArrayRef, Int64Builder, ListArray};
1919
use arrow::datatypes::{DataType, Field, FieldRef};
2020
use datafusion_common::cast::{as_int64_array, as_list_array};
2121
use datafusion_common::utils::ListCoercion;
22-
use datafusion_common::{
23-
Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
24-
};
22+
use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
2523
use datafusion_expr::{
2624
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs,
2725
ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
@@ -85,21 +83,26 @@ impl ScalarUDFImpl for SparkSlice {
8583
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
8684
let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
8785

88-
Ok(Arc::new(Field::new(
89-
"slice",
90-
args.arg_fields[0].data_type().clone(),
91-
nullable,
92-
)))
86+
let data_type = match args.arg_fields[0].data_type() {
87+
DataType::Null => {
88+
DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)))
89+
}
90+
dt => dt.clone(),
91+
};
92+
93+
Ok(Arc::new(Field::new("slice", data_type, nullable)))
9394
}
9495

9596
fn invoke_with_args(
9697
&self,
9798
mut func_args: ScalarFunctionArgs,
9899
) -> Result<ColumnarValue> {
99-
if func_args.args[0].data_type() == DataType::Null
100-
&& let Some(result) = check_null_types(&func_args.args[0])
101-
{
102-
return Ok(result);
100+
if func_args.args[0].data_type() == DataType::Null {
101+
let len = match &func_args.args[0] {
102+
ColumnarValue::Array(a) => a.len(),
103+
ColumnarValue::Scalar(_) => func_args.number_rows,
104+
};
105+
return Ok(ColumnarValue::Array(list_null_array(len)));
103106
}
104107

105108
let array_len = func_args
@@ -136,14 +139,9 @@ impl ScalarUDFImpl for SparkSlice {
136139
}
137140
}
138141

139-
fn check_null_types(cv: &ColumnarValue) -> Option<ColumnarValue> {
140-
match cv {
141-
ColumnarValue::Scalar(ScalarValue::Null) => {
142-
Some(ColumnarValue::create_null_array(1))
143-
}
144-
ColumnarValue::Array(_) => Some(cv.clone()),
145-
_ => None,
146-
}
142+
fn list_null_array(len: usize) -> ArrayRef {
143+
let list_type = DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)));
144+
Arc::new(ListArray::from(ArrayData::new_null(&list_type, len)))
147145
}
148146

149147
fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> {
@@ -193,9 +191,30 @@ fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> {
193191
mod tests {
194192
use super::*;
195193
use arrow::array::NullArray;
196-
use arrow::datatypes::DataType::List;
197194
use arrow::datatypes::Field;
198195
use datafusion_common::ScalarValue;
196+
use datafusion_common::cast::as_list_array;
197+
use datafusion_expr::ReturnFieldArgs;
198+
199+
#[test]
200+
fn test_spark_slice_function_when_input_is_null() {
201+
let slice = SparkSlice::new();
202+
let arg_fields: Vec<Arc<Field>> = vec![
203+
Arc::new(Field::new("a", DataType::Null, true)),
204+
Arc::new(Field::new("s", DataType::Int64, true)),
205+
Arc::new(Field::new("l", DataType::Int64, true)),
206+
];
207+
let out = slice
208+
.return_field_from_args(ReturnFieldArgs {
209+
arg_fields: &arg_fields,
210+
scalar_arguments: &[],
211+
})
212+
.unwrap();
213+
assert_eq!(
214+
out.data_type(),
215+
&DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)))
216+
);
217+
}
199218

200219
#[test]
201220
fn test_spark_slice_function_when_input_array_is_null() {
@@ -207,21 +226,23 @@ mod tests {
207226

208227
let args = ScalarFunctionArgs {
209228
args: input_args,
210-
arg_fields: vec![Arc::new(Field::new(
211-
"item",
212-
List(FieldRef::new(Field::new("f", DataType::Int64, true))),
213-
false,
214-
))],
229+
arg_fields: vec![Arc::new(Field::new("item", DataType::Null, true))],
215230
number_rows: 1,
216231
return_field: Arc::new(Field::new(
217-
"item",
218-
List(FieldRef::new(Field::new_list_field(DataType::Int64, true))),
219-
false,
232+
"slice",
233+
DataType::List(Arc::new(Field::new_list_field(DataType::Null, true))),
234+
true,
220235
)),
221236
config_options: Arc::new(Default::default()),
222237
};
223238
let slice = SparkSlice::new();
224239
let result = slice.invoke_with_args(args).unwrap();
225-
assert_eq!(*result.to_array(1).unwrap(), *Arc::new(NullArray::new(1)));
240+
let arr = result.to_array(1).unwrap();
241+
let list = as_list_array(&arr).unwrap();
242+
assert_eq!(
243+
arr.data_type(),
244+
&DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)))
245+
);
246+
assert!(list.is_null(0));
226247
}
227248
}

datafusion/sqllogictest/test_files/spark/array/slice.slt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,8 @@ query ?
132132
SELECT slice(slice(NULL, 1, 2), 1, 2)
133133
----
134134
NULL
135+
136+
query ?
137+
SELECT slice(slice(make_array(NULL), 1, 2), 1, 2)
138+
----
139+
[NULL]

0 commit comments

Comments (0)