Skip to content

Commit c84216c

Browse files
committed
Support fixed size list array
1 parent f9b422e commit c84216c

1 file changed

Lines changed: 113 additions & 5 deletions

File tree

datafusion/functions-nested/src/transform.rs

Lines changed: 113 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@
1717

1818
//! [`ScalarUDFImpl`] definition for array_transform function.
1919
20-
use arrow::array::{Array, ArrayRef, GenericListArray, GenericListViewArray};
20+
use arrow::array::{
21+
Array, ArrayRef, FixedSizeListArray, GenericListArray, GenericListViewArray,
22+
};
2123
use arrow::datatypes::{DataType, Field, FieldRef};
2224
use datafusion_common::cast::{
23-
as_large_list_array, as_large_list_view_array, as_list_array, as_list_view_array,
25+
as_fixed_size_list_array, as_large_list_array, as_large_list_view_array,
26+
as_list_array, as_list_view_array,
2427
};
2528
use datafusion_common::{exec_datafusion_err, exec_err, plan_err, Result};
2629
use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf;
@@ -120,6 +123,37 @@ macro_rules! invoke_by_list_type {
120123
}
121124
};
122125
}
126+
/// Generates a method that applies the wrapped function to the child values of a
/// fixed-size list array and rebuilds a fixed-size list from the result.
///
/// * `$fn_name` - name of the generated method.
/// * `$downcast_fn` - fallible downcast applied to `replacement_array`.
/// * `$return_type` - array type rebuilt with `try_new(field, size, values, nulls)`.
///
/// The generated method replaces the argument at `self.argument_index` with the
/// list's child values, invokes `self.function`, and wraps the returned array using
/// the input's list size and null buffer. If the wrapped function does not return an
/// array, its result is passed through unchanged.
///
/// # Errors
///
/// Propagates errors from the downcast, from the wrapped function, and from
/// `<$return_type>::try_new`.
macro_rules! invoke_by_fixed_size_type {
    ($fn_name:ident, $downcast_fn:ident, $return_type:ty) => {
        fn $fn_name(
            &self,
            replacement_array: ArrayRef,
            mut args: ScalarFunctionArgs,
            return_field: FieldRef,
        ) -> Result<ColumnarValue> {
            let array = $downcast_fn(&replacement_array)?;
            // Captured up front so the output list has the same shape and
            // validity as the input list.
            let size = array.value_length();
            let nulls = array.nulls().cloned();

            let values = array.values();

            // The wrapped function is evaluated over the flattened child values,
            // so the row count it is told about must be the child length rather
            // than the number of outer lists.
            args.number_rows = values.len();
            args.args[self.argument_index] = ColumnarValue::Array(Arc::clone(values));

            let results = self.function.invoke_with_args(args)?;

            // Only array results can be re-wrapped; anything else is returned as is.
            let ColumnarValue::Array(result_array) = results else {
                return Ok(results);
            };

            Ok(ColumnarValue::Array(Arc::new(<$return_type>::try_new(
                return_field,
                size,
                result_array,
                nulls,
            )?) as ArrayRef))
        }
    };
}
123157
macro_rules! invoke_by_list_view_type {
124158
($fn_name:ident, $downcast_fn:ident, $return_type:ty) => {
125159
fn $fn_name(
@@ -171,6 +205,11 @@ impl ArrayTransform {
171205
as_large_list_view_array,
172206
GenericListViewArray<i64>
173207
);
208+
invoke_by_fixed_size_type!(
209+
invoke_fixed_size_list,
210+
as_fixed_size_list_array,
211+
FixedSizeListArray
212+
);
174213
}
175214

176215
impl ScalarUDFImpl for ArrayTransform {
@@ -228,7 +267,8 @@ impl ScalarUDFImpl for ArrayTransform {
228267
DataType::List(field)
229268
| DataType::LargeList(field)
230269
| DataType::ListView(field)
231-
| DataType::LargeListView(field) => Ok(Arc::clone(field)),
270+
| DataType::LargeListView(field)
271+
| DataType::FixedSizeList(field, _) => Ok(Arc::clone(field)),
232272
arg_type => plan_err!("{} does not support type {arg_type}", self.name()),
233273
}?;
234274

@@ -264,6 +304,11 @@ impl ScalarUDFImpl for ArrayTransform {
264304
DataType::LargeListView(inner_return),
265305
true,
266306
))),
307+
DataType::FixedSizeList(_, size) => Ok(Arc::new(Field::new(
308+
name,
309+
DataType::FixedSizeList(inner_return, *size),
310+
true,
311+
))),
267312
_ => unreachable!(),
268313
}
269314
}
@@ -324,7 +369,8 @@ impl ScalarUDFImpl for ArrayTransform {
324369
DataType::List(field)
325370
| DataType::LargeList(field)
326371
| DataType::ListView(field)
327-
| DataType::LargeListView(field) => Arc::clone(field),
372+
| DataType::LargeListView(field)
373+
| DataType::FixedSizeList(field, _) => Arc::clone(field),
328374
_ => {
329375
return exec_err!(
330376
"Unexpected return field for array_transform. Expected list data type."
@@ -357,6 +403,9 @@ impl ScalarUDFImpl for ArrayTransform {
357403
DataType::LargeListView(_) => {
358404
self.invoke_large_list_view(replacement_array, args, return_field)
359405
}
406+
DataType::FixedSizeList(_, _) => {
407+
self.invoke_fixed_size_list(replacement_array, args, return_field)
408+
}
360409
arg_type => {
361410
exec_err!("array_transform does not support type {arg_type}")
362411
}
@@ -378,7 +427,8 @@ impl ScalarUDFImpl for ArrayTransform {
378427
mod tests {
379428
use super::array_transform_udf;
380429
use arrow::array::{
381-
create_array, Array, ArrayRef, GenericListArray, GenericListViewArray, Int32Array,
430+
create_array, Array, ArrayRef, FixedSizeListArray, GenericListArray,
431+
GenericListViewArray, Int32Array,
382432
};
383433
use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
384434
use arrow::datatypes::{DataType, Field};
@@ -546,6 +596,64 @@ mod tests {
546596
i64
547597
);
548598

599+
/// Applies `abs` through `array_transform` to a `FixedSizeListArray` holding three
/// rows of three `Int32` values, where the last row is null, and checks that the
/// list size, row validity and transformed child values are preserved.
#[test]
fn test_array_transform_fixed_size_list_array_test() -> Result<(), DataFusionError> {
    let udf = array_transform_udf(abs(), 0);

    let field = Arc::new(Field::new_list_field(DataType::Int32, true));
    let values = Int32Array::from(vec![
        Some(0),
        Some(-1),
        Some(-2),
        None,
        Some(4),
        Some(-5),
        Some(-6),
        Some(7),
        None,
    ]);
    // Rows 0 and 1 are valid; row 2 is a null list.
    let nulls = NullBuffer::from(vec![true, true, false]);
    let data = FixedSizeListArray::new(field, 3, Arc::new(values), Some(nulls));

    let data = Arc::new(data) as ArrayRef;
    let input_field = Arc::new(Field::new(
        "a",
        DataType::FixedSizeList(Field::new("b", DataType::Int32, true).into(), 3),
        true,
    ));
    let return_field = udf.return_field_from_args(ReturnFieldArgs {
        arg_fields: &[Arc::clone(&input_field)],
        scalar_arguments: &[None],
    })?;

    let args = ScalarFunctionArgs {
        args: vec![ColumnarValue::Array(data)],
        arg_fields: vec![input_field],
        number_rows: 3,
        return_field,
        config_options: Arc::new(Default::default()),
    };

    let ColumnarValue::Array(result) = udf.invoke_with_args(args)? else {
        return exec_err!("Invalid return type");
    };
    let list_array = result
        .as_any()
        .downcast_ref::<FixedSizeListArray>()
        .expect("a FixedSizeList input must produce a FixedSizeListArray");

    // The output keeps the input's shape: three rows of three elements each.
    assert_eq!(list_array.len(), 3);
    assert_eq!(list_array.value_length(), 3);

    assert!(list_array.is_valid(0));
    let expected = create_array!(Int32, [Some(0), Some(1), Some(2)]) as ArrayRef;
    assert_eq!(&list_array.value(0), &expected);

    // Row 1 is a valid list that merely contains a null element.
    assert!(list_array.is_valid(1));
    let expected = create_array!(Int32, [None, Some(4), Some(5)]) as ArrayRef;
    assert_eq!(&list_array.value(1), &expected);

    assert!(list_array.is_null(2));

    Ok(())
}
656+
549657
#[test]
550658
fn test_array_transform_test_argument_index() -> Result<(), DataFusionError> {
551659
let udf = array_transform_udf(round(), 1);

0 commit comments

Comments
 (0)