Skip to content

Commit 7aba470

Browse files
committed
feat: add array_subtract scalar function
1 parent d3983d3 commit 7aba470

5 files changed

Lines changed: 493 additions & 4 deletions

File tree

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`ScalarUDFImpl`] definitions for array_subtract function.
19+
20+
use crate::utils::{
21+
array_math_binary_op, coerce_array_math_arg_types, make_scalar_function,
22+
};
23+
use arrow::array::ArrayRef;
24+
use arrow::datatypes::{
25+
DataType,
26+
DataType::{LargeList, List},
27+
};
28+
use datafusion_common::{Result, exec_err, utils::take_function_args};
29+
use datafusion_expr::{
30+
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
31+
Volatility,
32+
};
33+
use datafusion_macros::user_doc;
34+
35+
make_udf_expr_and_func!(
36+
ArraySubtract,
37+
array_subtract,
38+
array1 array2,
39+
"returns the element-wise difference of two numeric arrays.",
40+
array_subtract_udf
41+
);
42+
43+
#[user_doc(
44+
doc_section(label = "Array Functions"),
45+
description = "Returns the element-wise difference of two numeric arrays of equal length, computed as `array1[i] - array2[i]` per position. NULL is propagated per element: if either input element at position `i` is NULL, the corresponding output element is NULL (positions are preserved). Returns NULL if either entire input array is NULL. Errors if the per-row lengths differ. Returns an empty array if both inputs are empty.",
46+
syntax_example = "array_subtract(array1, array2)",
47+
sql_example = r#"```sql
48+
> select array_subtract([10.0, 20.0, 30.0], [1.0, 2.0, 3.0]);
49+
+--------------------------------------------------------------+
50+
| array_subtract(List([10.0,20.0,30.0]),List([1.0,2.0,3.0])) |
51+
+--------------------------------------------------------------+
52+
| [9.0, 18.0, 27.0] |
53+
+--------------------------------------------------------------+
54+
```"#,
55+
argument(
56+
name = "array1",
57+
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
58+
),
59+
argument(
60+
name = "array2",
61+
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
62+
)
63+
)]
64+
#[derive(Debug, PartialEq, Eq, Hash)]
65+
pub struct ArraySubtract {
66+
signature: Signature,
67+
aliases: Vec<String>,
68+
}
69+
70+
impl Default for ArraySubtract {
71+
fn default() -> Self {
72+
Self::new()
73+
}
74+
}
75+
76+
impl ArraySubtract {
77+
pub fn new() -> Self {
78+
Self {
79+
signature: Signature::user_defined(Volatility::Immutable),
80+
aliases: vec!["list_subtract".to_string()],
81+
}
82+
}
83+
}
84+
85+
impl ScalarUDFImpl for ArraySubtract {
86+
fn name(&self) -> &str {
87+
"array_subtract"
88+
}
89+
90+
fn signature(&self) -> &Signature {
91+
&self.signature
92+
}
93+
94+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
95+
Ok(arg_types[0].clone())
96+
}
97+
98+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
99+
let [_, _] = take_function_args(self.name(), arg_types)?;
100+
coerce_array_math_arg_types(self.name(), arg_types)
101+
}
102+
103+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
104+
make_scalar_function(array_subtract_inner)(&args.args)
105+
}
106+
107+
fn aliases(&self) -> &[String] {
108+
&self.aliases
109+
}
110+
111+
fn documentation(&self) -> Option<&Documentation> {
112+
self.doc()
113+
}
114+
}
115+
116+
fn array_subtract_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
117+
let [array1, array2] = take_function_args("array_subtract", args)?;
118+
let sub = |a: f64, b: f64| a - b;
119+
match (array1.data_type(), array2.data_type()) {
120+
(List(_), List(_)) => {
121+
array_math_binary_op::<i32, _>("array_subtract", array1, array2, sub)
122+
}
123+
(LargeList(_), LargeList(_)) => {
124+
array_math_binary_op::<i64, _>("array_subtract", array1, array2, sub)
125+
}
126+
(arg_type1, arg_type2) => exec_err!(
127+
"array_subtract received unexpected types after coercion: {arg_type1} and {arg_type2}"
128+
),
129+
}
130+
}

datafusion/functions-nested/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ pub mod array_filter;
4747
pub mod array_has;
4848
pub mod array_normalize;
4949
pub mod array_scale;
50+
pub mod array_subtract;
5051
pub mod array_transform;
5152
pub mod arrays_zip;
5253
pub mod cardinality;
@@ -99,6 +100,7 @@ pub mod expr_fn {
99100
pub use super::array_has::array_has_any;
100101
pub use super::array_normalize::array_normalize;
101102
pub use super::array_scale::array_scale;
103+
pub use super::array_subtract::array_subtract;
102104
pub use super::array_transform::array_transform;
103105
pub use super::arrays_zip::arrays_zip;
104106
pub use super::cardinality::cardinality;
@@ -176,6 +178,7 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
176178
array_normalize::array_normalize_udf(),
177179
array_add::array_add_udf(),
178180
array_scale::array_scale_udf(),
181+
array_subtract::array_subtract_udf(),
179182
cosine_distance::cosine_distance_udf(),
180183
inner_product::inner_product_udf(),
181184
distance::array_distance_udf(),

datafusion/functions-nested/src/utils.rs

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,13 @@ use std::sync::Arc;
2222
use arrow::datatypes::{DataType, Field, Fields};
2323

2424
use arrow::array::{
25-
Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar,
25+
Array, ArrayRef, BooleanArray, Float64Array, GenericListArray, NullBufferBuilder,
26+
OffsetBufferBuilder, OffsetSizeTrait, Scalar,
2627
};
27-
use arrow::buffer::OffsetBuffer;
28+
use arrow::buffer::{NullBuffer, OffsetBuffer};
2829
use datafusion_common::cast::{
29-
as_fixed_size_list_array, as_large_list_array, as_large_list_view_array,
30-
as_list_array, as_list_view_array,
30+
as_fixed_size_list_array, as_float64_array, as_generic_list_array,
31+
as_large_list_array, as_large_list_view_array, as_list_array, as_list_view_array,
3132
};
3233
use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err};
3334

@@ -327,6 +328,90 @@ pub(crate) fn coerce_array_math_arg_types(
327328
Ok(coerced)
328329
}
329330

331+
/// Element-wise binary operation kernel for two `Float64` lists of equal per-row
332+
/// length. The caller is responsible for type-dispatching on `O` (`i32` for
333+
/// `List`, `i64` for `LargeList`).
334+
///
335+
/// Semantics:
336+
/// - whole-row NULL on either side → NULL output row, length 0
337+
/// - per-element NULL on either side → NULL at that output position
338+
/// - per-row length mismatch → exec error tagged with `op_name`
339+
///
340+
/// `op_name` flows into the error message; `op` is the per-element scalar op
341+
/// (e.g. `|a, b| a + b` for `array_add`, `|a, b| a - b` for `array_subtract`).
342+
pub(crate) fn array_math_binary_op<O, F>(
343+
op_name: &str,
344+
lhs: &ArrayRef,
345+
rhs: &ArrayRef,
346+
op: F,
347+
) -> Result<ArrayRef>
348+
where
349+
O: OffsetSizeTrait,
350+
F: Fn(f64, f64) -> f64,
351+
{
352+
let lhs = as_generic_list_array::<O>(lhs)?;
353+
let rhs = as_generic_list_array::<O>(rhs)?;
354+
355+
let lhs_values = as_float64_array(lhs.values())?;
356+
let rhs_values = as_float64_array(rhs.values())?;
357+
let lhs_offsets = lhs.value_offsets();
358+
let rhs_offsets = rhs.value_offsets();
359+
360+
let row_nulls = NullBuffer::union(lhs.nulls(), rhs.nulls());
361+
362+
let mut out_values: Vec<f64> = Vec::with_capacity(lhs_values.len());
363+
let mut out_inner_nulls = NullBufferBuilder::new(lhs_values.len());
364+
let mut out_offsets = OffsetBufferBuilder::<O>::new(lhs.len());
365+
366+
for row in 0..lhs.len() {
367+
if row_nulls.as_ref().is_some_and(|nb| nb.is_null(row)) {
368+
out_offsets.push_length(0);
369+
continue;
370+
}
371+
372+
let start1 = lhs_offsets[row].as_usize();
373+
let len1 = lhs.value_length(row).as_usize();
374+
let start2 = rhs_offsets[row].as_usize();
375+
let len2 = rhs.value_length(row).as_usize();
376+
377+
if len1 != len2 {
378+
return exec_err!(
379+
"{op_name} requires both list inputs to have the same length per row, got {len1} and {len2} at row {row}"
380+
);
381+
}
382+
383+
let l_slice = lhs_values.slice(start1, len1);
384+
let r_slice = rhs_values.slice(start2, len2);
385+
386+
let l_vals = l_slice.values();
387+
let r_vals = r_slice.values();
388+
389+
for i in 0..len1 {
390+
out_values.push(op(l_vals[i], r_vals[i]));
391+
}
392+
393+
match NullBuffer::union(l_slice.nulls(), r_slice.nulls()) {
394+
Some(nb) => out_inner_nulls.append_buffer(&nb),
395+
None => out_inner_nulls.append_n_non_nulls(len1),
396+
}
397+
398+
out_offsets.push_length(len1);
399+
}
400+
401+
let values_array = Arc::new(Float64Array::new(
402+
out_values.into(),
403+
out_inner_nulls.finish(),
404+
));
405+
let field = Arc::new(Field::new_list_field(DataType::Float64, true));
406+
407+
Ok(Arc::new(GenericListArray::<O>::try_new(
408+
field,
409+
out_offsets.finish(),
410+
values_array,
411+
row_nulls,
412+
)?))
413+
}
414+
330415
#[cfg(test)]
331416
mod tests {
332417
use super::*;

0 commit comments

Comments
 (0)