Skip to content

Commit 3e79a38

Browse files
committed
Refine array replace scalar path tests
1 parent 1468803 commit 3e79a38

2 files changed

Lines changed: 55 additions & 123 deletions

File tree

datafusion/functions-nested/src/replace.rs

Lines changed: 1 addition & 117 deletions
Original file line number | Diff line number | Diff line change
@@ -607,7 +607,7 @@ fn replace_with_scalar_needle(
607607
to_array: &ArrayRef,
608608
arr_n: &[i64],
609609
) -> Result<ArrayRef> {
610-
if scalar_from.is_null() || scalar_from.data_type().is_nested() {
610+
if scalar_from.data_type().is_nested() {
611611
let from_array = scalar_from.to_array_of_size(list_array.len())?;
612612
return array_replace_internal(list_array, &from_array, to_array, arr_n);
613613
}
@@ -635,119 +635,3 @@ fn array_replace_internal(
635635
array_type => exec_err!("array_replace does not support type '{array_type}'."),
636636
}
637637
}
638-
639-
#[cfg(test)]
640-
mod tests {
641-
use super::{ArrayReplace, ArrayReplaceAll, ArrayReplaceN};
642-
use arrow::array::{ArrayRef, AsArray, ListArray};
643-
use arrow::datatypes::{DataType, Field, Int32Type};
644-
use datafusion_common::ScalarValue;
645-
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
646-
use std::sync::Arc;
647-
648-
fn int_list(values: Vec<Vec<i32>>) -> ArrayRef {
649-
Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(
650-
values
651-
.into_iter()
652-
.map(|row| Some(row.into_iter().map(Some))),
653-
))
654-
}
655-
656-
fn invoke_replace(
657-
udf: &dyn ScalarUDFImpl,
658-
args: Vec<ColumnarValue>,
659-
return_type: DataType,
660-
) -> ColumnarValue {
661-
let arg_fields = args
662-
.iter()
663-
.enumerate()
664-
.map(|(i, arg)| {
665-
Arc::new(Field::new(format!("arg{i}"), arg.data_type(), true))
666-
})
667-
.collect::<Vec<_>>();
668-
let number_rows = args
669-
.iter()
670-
.find_map(|arg| match arg {
671-
ColumnarValue::Array(array) => Some(array.len()),
672-
ColumnarValue::Scalar(_) => None,
673-
})
674-
.unwrap_or(1);
675-
let return_field = Arc::new(Field::new("result", return_type, true));
676-
677-
udf.invoke_with_args(ScalarFunctionArgs {
678-
args,
679-
arg_fields,
680-
number_rows,
681-
return_field,
682-
config_options: Arc::new(Default::default()),
683-
})
684-
.unwrap()
685-
}
686-
687-
#[test]
688-
fn array_replace_uses_scalar_arguments() {
689-
let input = int_list(vec![vec![1, 2, 2, 3], vec![2, 4, 2]]);
690-
let expected = int_list(vec![vec![1, 9, 2, 3], vec![9, 4, 2]]);
691-
let return_type = input.data_type().clone();
692-
693-
let result = invoke_replace(
694-
&ArrayReplace::new(),
695-
vec![
696-
ColumnarValue::Array(input),
697-
ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
698-
ColumnarValue::Scalar(ScalarValue::Int32(Some(9))),
699-
],
700-
return_type,
701-
);
702-
703-
let ColumnarValue::Array(result) = result else {
704-
panic!("expected array result");
705-
};
706-
assert_eq!(result.as_list::<i32>(), expected.as_list::<i32>());
707-
}
708-
709-
#[test]
710-
fn array_replace_n_uses_scalar_arguments() {
711-
let input = int_list(vec![vec![1, 2, 2, 3], vec![2, 4, 2]]);
712-
let expected = int_list(vec![vec![1, 9, 9, 3], vec![9, 4, 9]]);
713-
let return_type = input.data_type().clone();
714-
715-
let result = invoke_replace(
716-
&ArrayReplaceN::new(),
717-
vec![
718-
ColumnarValue::Array(input),
719-
ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
720-
ColumnarValue::Scalar(ScalarValue::Int32(Some(9))),
721-
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
722-
],
723-
return_type,
724-
);
725-
726-
let ColumnarValue::Array(result) = result else {
727-
panic!("expected array result");
728-
};
729-
assert_eq!(result.as_list::<i32>(), expected.as_list::<i32>());
730-
}
731-
732-
#[test]
733-
fn array_replace_all_uses_scalar_arguments() {
734-
let input = int_list(vec![vec![1, 2, 2, 3], vec![2, 4, 2]]);
735-
let expected = int_list(vec![vec![1, 9, 9, 3], vec![9, 4, 9]]);
736-
let return_type = input.data_type().clone();
737-
738-
let result = invoke_replace(
739-
&ArrayReplaceAll::new(),
740-
vec![
741-
ColumnarValue::Array(input),
742-
ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
743-
ColumnarValue::Scalar(ScalarValue::Int32(Some(9))),
744-
],
745-
return_type,
746-
);
747-
748-
let ColumnarValue::Array(result) = result else {
749-
panic!("expected array result");
750-
};
751-
assert_eq!(result.as_list::<i32>(), expected.as_list::<i32>());
752-
}
753-
}

datafusion/sqllogictest/test_files/array/array_replace.slt

Lines changed: 54 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -212,6 +212,33 @@ from large_nested_arrays_with_repeating_elements;
212212
[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [19, 20, 21], [28, 29, 30], [19, 20, 21], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[11, 12, 13], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]]
213213
[[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [37, 38, 39], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [[11, 12, 13], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]]
214214

215+
# array_replace scalar arguments over multiple input rows
216+
query ???
217+
select
218+
array_replace(column1, 2, 9),
219+
array_replace_n(column1, 2, 9, 2),
220+
array_replace_all(column1, 2, 9)
221+
from (
222+
values
223+
(make_array(1, 2, 2, 3)),
224+
(make_array(2, 4, 2))
225+
) as t(column1);
226+
----
227+
[1, 9, 2, 3] [1, 9, 9, 3] [1, 9, 9, 3]
228+
[9, 4, 2] [9, 4, 9] [9, 4, 9]
229+
230+
# array_replace_n scalar max exceeding matches over multiple input rows
231+
query ?
232+
select array_replace_n(column1, 2, 9, 10)
233+
from (
234+
values
235+
(make_array(1, 2, 2, 3)),
236+
(make_array(2, 4, 2))
237+
) as t(column1);
238+
----
239+
[1, 9, 9, 3]
240+
[9, 4, 9]
241+
215242
## array_replace_n (aliases: `list_replace_n`)
216243

217244
# array_replace_n scalar function #1
@@ -226,22 +253,35 @@ select
226253
----
227254
[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] [1, 4, 4] [0, 4, 0, 5]
228255

229-
query ????
256+
query ??????
230257
select
231258
array_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'LargeList(Int64)'), 2, 3, 2),
232259
array_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'LargeList(Int64)'), 4, 0, 2),
233260
array_replace_n(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 4, 0, 3),
234-
array_replace_n(arrow_cast(make_array(1, 4, 4), 'LargeList(Int64)'), 4, 0, 0);
261+
array_replace_n(arrow_cast(make_array(1, 4, 4), 'LargeList(Int64)'), 4, 0, 0),
262+
array_replace_n(arrow_cast(make_array(1, 4, 4), 'LargeList(Int64)'), 4, 0, -1),
263+
array_replace_n(arrow_cast(make_array(1, 4, 1, 5), 'LargeList(Int64)'), 1, 0, 10);
235264
----
236-
[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4]
265+
[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] [1, 4, 4] [0, 4, 0, 5]
237266

238-
query ???
267+
query ??????
239268
select
240269
array_replace_n(arrow_cast(make_array(1, 2, 3, 4), 'FixedSizeList(4, Int64)'), 2, 3, 2),
241270
array_replace_n(arrow_cast(make_array(1, 4, 4, 5, 4, 6, 7), 'FixedSizeList(7, Int64)'), 4, 0, 2),
242-
array_replace_n(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 4, 0, 3);
271+
array_replace_n(arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'), 4, 0, 3),
272+
array_replace_n(arrow_cast(make_array(1, 4, 4), 'FixedSizeList(3, Int64)'), 4, 0, 0),
273+
array_replace_n(arrow_cast(make_array(1, 4, 4), 'FixedSizeList(3, Int64)'), 4, 0, -1),
274+
array_replace_n(arrow_cast(make_array(1, 4, 1, 5), 'FixedSizeList(4, Int64)'), 1, 0, 10);
243275
----
244-
[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3]
276+
[1, 3, 3, 4] [1, 0, 0, 5, 4, 6, 7] [1, 2, 3] [1, 4, 4] [1, 4, 4] [0, 4, 0, 5]
277+
278+
# array_replace_n scalar max exceeding matches for empty arrays
279+
query ??
280+
select
281+
array_replace_n(arrow_cast(make_array(), 'List(Int64)'), 2, 9, 10),
282+
array_replace_n(arrow_cast(make_array(), 'LargeList(Int64)'), 2, 9, 10);
283+
----
284+
[] []
245285

246286
# array_replace_n scalar function #2 (element is list)
247287
query ??
@@ -657,6 +697,14 @@ select column1, column2, column3, column4, array_replace_n(column1, column2, col
657697
NULL 3 2 1 NULL
658698
[3, 1, 3] 3 NULL 1 [NULL, 1, 3]
659699

700+
query ???
701+
select
702+
array_replace(make_array(3, NULL, NULL), NULL, 5),
703+
array_replace_n(make_array(3, NULL, NULL), NULL, 5, 10),
704+
array_replace_all(make_array(3, NULL, NULL), NULL, 5);
705+
----
706+
[3, 5, NULL] [3, 5, 5] [3, 5, 5]
707+
660708

661709

662710
statement ok

0 commit comments

Comments
 (0)