Skip to content

Commit 7d862d6

Browse files
feat: adds array_add function (#22459)
## Which issue does this PR close?

- Part of [#21536](#21536) (`array_add` — first PR in the vector math series).

## Rationale for this change

## What changes are included in this PR?

## Are these changes tested?

Yes, via SLT only (`array_add.slt`). Coverage:

- Happy paths: basic, negative components, single-element, empty, multi-row.
- NULL propagation: whole-row NULL on each side / both sides; element-level NULL on each side / both sides at same and different positions.
- Type / variant: integer literals, mixed int+float, `LargeList`×`LargeList`, mixed `List`+`LargeList`, `FixedSizeList` → `List` coercion, `Float32` leaf, `Int64` leaf.
- Decimal handling: `Decimal128` / `Decimal256` rejected at planning; explicit `cast to DOUBLE` opt-in works.
- Error paths: per-row length mismatch (exec), unsupported non-list input (plan), non-numeric leaf (plan), boolean leaf (plan), nested list (plan), wrong arg count.
- Aliases: `list_add` single-row + multi-row.
- Composition: `array_add(array_add(...), ...)` chained — single-row, with element NULLs propagating across both layers, and multi-row with row-level NULL.

## Are there any user-facing changes?

Yes — two new functions:

- `array_add(array1, array2) → List<Float64>` / `LargeList<Float64>`
- `list_add(...)` alias

Both are exposed via `expr_fn` and registered in `all_default_nested_functions()`. Documented inline via `#[user_doc]` (description, syntax, SQL example, argument descriptions). No breaking API changes.
1 parent 8a9653c commit 7d862d6

5 files changed

Lines changed: 529 additions & 2 deletions

File tree

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`ScalarUDFImpl`] definitions for array_add function.
19+
20+
use crate::utils::{coerce_array_math_arg_types, make_scalar_function};
21+
use arrow::array::{
22+
Array, ArrayRef, Float64Array, GenericListArray, NullBufferBuilder,
23+
OffsetBufferBuilder, OffsetSizeTrait,
24+
};
25+
use arrow::buffer::NullBuffer;
26+
use arrow::datatypes::{
27+
DataType,
28+
DataType::{LargeList, List},
29+
Field,
30+
};
31+
use datafusion_common::cast::{as_float64_array, as_generic_list_array};
32+
use datafusion_common::{Result, exec_err, utils::take_function_args};
33+
use datafusion_expr::{
34+
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
35+
Volatility,
36+
};
37+
use datafusion_macros::user_doc;
38+
use std::sync::Arc;
39+
40+
// Declares the public entry points for `ArrayAdd`. The crate root re-exports
// `array_add` from `expr_fn` and registers `array_add_udf()` in
// `all_default_nested_functions()`, so both names are presumably generated by
// this macro invocation.
make_udf_expr_and_func!(
    ArrayAdd,
    array_add,
    array1 array2,
    "returns the element-wise sum of two numeric arrays.",
    array_add_udf
);
47+
48+
#[user_doc(
49+
doc_section(label = "Array Functions"),
50+
description = "Returns the element-wise sum of two numeric arrays of equal length, computed as `array1[i] + array2[i]` per position. NULL is propagated per element: if either input element at position `i` is NULL, the corresponding output element is NULL (positions are preserved). Returns NULL if either entire input array is NULL. Errors if the per-row lengths differ. Returns an empty array if both inputs are empty.",
51+
syntax_example = "array_add(array1, array2)",
52+
sql_example = r#"```sql
53+
> select array_add([1.0, 2.0, 3.0], [10.0, 20.0, 30.0]);
54+
+---------------------------------------------------------+
55+
| array_add(List([1.0,2.0,3.0]),List([10.0,20.0,30.0])) |
56+
+---------------------------------------------------------+
57+
| [11.0, 22.0, 33.0] |
58+
+---------------------------------------------------------+
59+
```"#,
60+
argument(
61+
name = "array1",
62+
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
63+
),
64+
argument(
65+
name = "array2",
66+
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
67+
)
68+
)]
69+
#[derive(Debug, PartialEq, Eq, Hash)]
70+
pub struct ArrayAdd {
71+
signature: Signature,
72+
aliases: Vec<String>,
73+
}
74+
75+
impl Default for ArrayAdd {
76+
fn default() -> Self {
77+
Self::new()
78+
}
79+
}
80+
81+
impl ArrayAdd {
82+
pub fn new() -> Self {
83+
Self {
84+
signature: Signature::user_defined(Volatility::Immutable),
85+
aliases: vec!["list_add".to_string()],
86+
}
87+
}
88+
}
89+
90+
impl ScalarUDFImpl for ArrayAdd {
91+
fn name(&self) -> &str {
92+
"array_add"
93+
}
94+
95+
fn signature(&self) -> &Signature {
96+
&self.signature
97+
}
98+
99+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
100+
// After `coerce_types`, both args share the same List/LargeList<Float64> shape.
101+
Ok(arg_types[0].clone())
102+
}
103+
104+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
105+
let [_, _] = take_function_args(self.name(), arg_types)?;
106+
coerce_array_math_arg_types(self.name(), arg_types)
107+
}
108+
109+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
110+
make_scalar_function(array_add_inner)(&args.args)
111+
}
112+
113+
fn aliases(&self) -> &[String] {
114+
&self.aliases
115+
}
116+
117+
fn documentation(&self) -> Option<&Documentation> {
118+
self.doc()
119+
}
120+
}
121+
122+
fn array_add_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
123+
let [array1, array2] = take_function_args("array_add", args)?;
124+
match (array1.data_type(), array2.data_type()) {
125+
(List(_), List(_)) => general_array_add::<i32>(array1, array2),
126+
(LargeList(_), LargeList(_)) => general_array_add::<i64>(array1, array2),
127+
(arg_type1, arg_type2) => exec_err!(
128+
"array_add received unexpected types after coercion: {arg_type1} and {arg_type2}"
129+
),
130+
}
131+
}
132+
133+
fn general_array_add<O: OffsetSizeTrait>(
134+
lhs: &ArrayRef,
135+
rhs: &ArrayRef,
136+
) -> Result<ArrayRef> {
137+
let lhs = as_generic_list_array::<O>(lhs)?;
138+
let rhs = as_generic_list_array::<O>(rhs)?;
139+
140+
let lhs_values = as_float64_array(lhs.values())?;
141+
let rhs_values = as_float64_array(rhs.values())?;
142+
let lhs_offsets = lhs.value_offsets();
143+
let rhs_offsets = rhs.value_offsets();
144+
145+
// Row-level validity: a row is valid iff both sides are valid at that row.
146+
let row_nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
147+
148+
let mut out_values: Vec<f64> = Vec::with_capacity(lhs_values.len());
149+
let mut out_inner_nulls = NullBufferBuilder::new(lhs_values.len());
150+
let mut out_offsets = OffsetBufferBuilder::<O>::new(lhs.len());
151+
152+
for row in 0..lhs.len() {
153+
// Whole-row NULL on either side -> NULL output row, no elements.
154+
if row_nulls.as_ref().is_some_and(|nb| nb.is_null(row)) {
155+
out_offsets.push_length(0);
156+
continue;
157+
}
158+
159+
let start1 = lhs_offsets[row].as_usize();
160+
let len1 = lhs.value_length(row).as_usize();
161+
let start2 = rhs_offsets[row].as_usize();
162+
let len2 = rhs.value_length(row).as_usize();
163+
164+
if len1 != len2 {
165+
return exec_err!(
166+
"array_add requires both list inputs to have the same length per row, got {len1} and {len2} at row {row}"
167+
);
168+
}
169+
170+
let l_slice = lhs_values.slice(start1, len1);
171+
let r_slice = rhs_values.slice(start2, len2);
172+
173+
let l_vals = l_slice.values();
174+
let r_vals = r_slice.values();
175+
176+
for i in 0..len1 {
177+
out_values.push(l_vals[i] + r_vals[i]);
178+
}
179+
180+
// Per-element validity: position `i` is valid iff both lhs[i] and rhs[i]
181+
// are valid. `NullBuffer::union` returns `None` when both sides are
182+
// entirely valid.
183+
match NullBuffer::union(l_slice.nulls(), r_slice.nulls()) {
184+
Some(nb) => out_inner_nulls.append_buffer(&nb),
185+
None => out_inner_nulls.append_n_non_nulls(len1),
186+
}
187+
188+
out_offsets.push_length(len1);
189+
}
190+
191+
let values_array = Arc::new(Float64Array::new(
192+
out_values.into(),
193+
out_inner_nulls.finish(),
194+
));
195+
let field = Arc::new(Field::new_list_field(DataType::Float64, true));
196+
197+
Ok(Arc::new(GenericListArray::<O>::try_new(
198+
field,
199+
out_offsets.finish(),
200+
values_array,
201+
row_nulls,
202+
)?))
203+
}

datafusion/functions-nested/src/lib.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,8 @@ pub mod macros;
4040
#[macro_use]
4141
pub mod macros_lambda;
4242

43+
pub mod array_add;
4344
pub mod array_any_match;
44-
pub(crate) mod lambda_utils;
45-
4645
pub mod array_compact;
4746
pub mod array_filter;
4847
pub mod array_has;
@@ -61,6 +60,7 @@ pub mod expr_ext;
6160
pub mod extract;
6261
pub mod flatten;
6362
pub mod inner_product;
63+
pub(crate) mod lambda_utils;
6464
pub mod length;
6565
pub mod make_array;
6666
pub mod map;
@@ -90,6 +90,7 @@ use std::sync::Arc;
9090

9191
/// Fluent-style API for creating `Expr`s
9292
pub mod expr_fn {
93+
pub use super::array_add::array_add;
9394
pub use super::array_any_match::array_any_match;
9495
pub use super::array_compact::array_compact;
9596
pub use super::array_filter::array_filter;
@@ -173,6 +174,7 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
173174
empty::array_empty_udf(),
174175
length::array_length_udf(),
175176
array_normalize::array_normalize_udf(),
177+
array_add::array_add_udf(),
176178
array_scale::array_scale_udf(),
177179
cosine_distance::cosine_distance_udf(),
178180
inner_product::inner_product_udf(),

datafusion/functions-nested/src/utils.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,57 @@ pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> {
276276
}
277277
}
278278

279+
/// Shared `coerce_types` impl for array-math UDFs whose kernels expect
280+
/// `List<Float64>` / `LargeList<Float64>` (e.g. `array_add`, `cosine_distance`,
281+
/// `inner_product`, `array_normalize`).
282+
///
283+
/// Each input must be `Null`, `List`, `LargeList`, or `FixedSizeList`; otherwise
284+
/// returns a plan error naming `name`. `FixedSizeList` is widened to `List`,
285+
/// `Null` is coerced to a list of `Float64`, and if any input is `LargeList`
286+
/// the rest are widened to `LargeList` so the runtime sees a homogeneous pair.
287+
pub(crate) fn coerce_array_math_arg_types(
288+
name: &str,
289+
arg_types: &[DataType],
290+
) -> Result<Vec<DataType>> {
291+
use DataType::{FixedSizeList, LargeList, List, Null};
292+
use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
293+
294+
let coercion = Some(&ListCoercion::FixedSizedListToList);
295+
296+
for arg_type in arg_types {
297+
if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
298+
return plan_err!("{name} does not support type {arg_type}");
299+
}
300+
}
301+
302+
// If any input is `LargeList`, both sides must be widened to `LargeList`
303+
// so the runtime dispatch in `inner_product_inner` sees a homogeneous
304+
// pair. Follows the pattern in `ArrayConcat::coerce_types`.
305+
let any_large_list = arg_types.iter().any(|t| matches!(t, LargeList(_)));
306+
307+
let coerced = arg_types
308+
.iter()
309+
.map(|arg_type| {
310+
if matches!(arg_type, Null) {
311+
let field = Arc::new(Field::new_list_field(DataType::Float64, true));
312+
return if any_large_list {
313+
LargeList(field)
314+
} else {
315+
List(field)
316+
};
317+
}
318+
let coerced =
319+
coerced_type_with_base_type_only(arg_type, &DataType::Float64, coercion);
320+
match coerced {
321+
List(field) if any_large_list => LargeList(field),
322+
other => other,
323+
}
324+
})
325+
.collect();
326+
327+
Ok(coerced)
328+
}
329+
279330
#[cfg(test)]
280331
mod tests {
281332
use super::*;

0 commit comments

Comments
 (0)