Skip to content

Commit bf052a9

Browse files
committed
feat: add array_scale scalar function
Adds `array_scale(array, scalar)`, which returns a new array with each element multiplied by a scalar. Aliased as `list_scale`. Part of the per-function split sequence on tracking issue #21536, following the pattern of the already-merged PRs in this series.

Semantics:
- NULL row in array -> NULL row out
- NULL element at position i in array -> NULL element at i out (per-element propagation)
- NULL scalar -> NULL row out (whole-row: the scalar applies uniformly to every element, so the entire operation is undefined)
- Empty array -> empty array

First argument is List/LargeList/FixedSizeList of any numeric type. Second argument is a numeric scalar. Both coerce to Float64. List-like inputs follow the same widening rules as the binary-op siblings: LargeList wins, FixedSizeList coerces to List.
1 parent 4a41173 commit bf052a9

4 files changed

Lines changed: 449 additions & 0 deletions

File tree

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`ScalarUDFImpl`] definitions for array_scale function.
19+
20+
use crate::utils::make_scalar_function;
21+
use arrow::array::{
22+
Array, ArrayRef, Float64Array, GenericListArray, OffsetBufferBuilder, OffsetSizeTrait,
23+
};
24+
use arrow::buffer::NullBuffer;
25+
use arrow::datatypes::{
26+
DataType,
27+
DataType::{FixedSizeList, LargeList, List, Null},
28+
Field,
29+
};
30+
use datafusion_common::cast::{as_float64_array, as_generic_list_array};
31+
use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
32+
use datafusion_common::{Result, internal_err, plan_err, utils::take_function_args};
33+
use datafusion_expr::{
34+
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
35+
Volatility,
36+
};
37+
use datafusion_macros::user_doc;
38+
use std::sync::Arc;
39+
40+
// Registers the `ArrayScale` UDF with the crate's helper macro. The macro is
// defined elsewhere in the crate; it presumably generates the `array_scale`
// expression builder (taking the `array` and `scalar` arguments) and the
// `array_scale_udf` accessor named below — confirm against the macro source.
make_udf_expr_and_func!(
    ArrayScale,
    array_scale,
    array scalar,
    "scales each element of a numeric array by a scalar.",
    array_scale_udf
);
47+
48+
#[user_doc(
49+
doc_section(label = "Array Functions"),
50+
description = "Returns a new array with each element of the input array multiplied by a scalar value, computed as `array[i] * scalar`. Returns NULL if the input row is NULL or the scalar is NULL. If a NULL element appears in the input array at position `i`, the result element at position `i` is NULL. Returns an empty array for an empty input array.",
51+
syntax_example = "array_scale(array, scalar)",
52+
sql_example = r#"```sql
53+
> select array_scale([1.0, 2.0, 3.0], 2.0);
54+
+----------------------------------+
55+
| array_scale(List([1.0,2.0,3.0]),Float64(2.0)) |
56+
+----------------------------------+
57+
| [2.0, 4.0, 6.0] |
58+
+----------------------------------+
59+
```"#,
60+
argument(
61+
name = "array",
62+
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
63+
),
64+
argument(
65+
name = "scalar",
66+
description = "Numeric scalar to multiply each element by. Can be a constant or column expression."
67+
)
68+
)]
69+
#[derive(Debug, PartialEq, Eq, Hash)]
70+
pub struct ArrayScale {
71+
signature: Signature,
72+
aliases: Vec<String>,
73+
}
74+
75+
impl Default for ArrayScale {
76+
fn default() -> Self {
77+
Self::new()
78+
}
79+
}
80+
81+
impl ArrayScale {
82+
pub fn new() -> Self {
83+
Self {
84+
signature: Signature::user_defined(Volatility::Immutable),
85+
aliases: vec!["list_scale".to_string()],
86+
}
87+
}
88+
}
89+
90+
impl ScalarUDFImpl for ArrayScale {
91+
fn name(&self) -> &str {
92+
"array_scale"
93+
}
94+
95+
fn signature(&self) -> &Signature {
96+
&self.signature
97+
}
98+
99+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
100+
// After `coerce_types`, `arg_types[0]` is one of List(Float64) or LargeList(Float64).
101+
Ok(arg_types[0].clone())
102+
}
103+
104+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
105+
let [array_type, scalar_type] = take_function_args(self.name(), arg_types)?;
106+
let coercion = Some(&ListCoercion::FixedSizedListToList);
107+
108+
if !matches!(
109+
array_type,
110+
Null | List(_) | LargeList(_) | FixedSizeList(..)
111+
) {
112+
return plan_err!(
113+
"{} first argument must be a list type, got {array_type}",
114+
self.name()
115+
);
116+
}
117+
118+
if !scalar_type.is_numeric() && !matches!(scalar_type, Null) {
119+
return plan_err!(
120+
"{} second argument must be numeric, got {scalar_type}",
121+
self.name()
122+
);
123+
}
124+
125+
let coerced_array = if matches!(array_type, Null) {
126+
List(Arc::new(Field::new_list_field(DataType::Float64, true)))
127+
} else {
128+
coerced_type_with_base_type_only(array_type, &DataType::Float64, coercion)
129+
};
130+
131+
Ok(vec![coerced_array, DataType::Float64])
132+
}
133+
134+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
135+
make_scalar_function(array_scale_inner)(&args.args)
136+
}
137+
138+
fn aliases(&self) -> &[String] {
139+
&self.aliases
140+
}
141+
142+
fn documentation(&self) -> Option<&Documentation> {
143+
self.doc()
144+
}
145+
}
146+
147+
fn array_scale_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
148+
let [array, scalar] = take_function_args("array_scale", args)?;
149+
match array.data_type() {
150+
List(_) => general_array_scale::<i32>(array, scalar),
151+
LargeList(_) => general_array_scale::<i64>(array, scalar),
152+
arg_type => internal_err!(
153+
"array_scale received unexpected type after coercion: {arg_type}"
154+
),
155+
}
156+
}
157+
158+
fn general_array_scale<O: OffsetSizeTrait>(
159+
array: &ArrayRef,
160+
scalar: &ArrayRef,
161+
) -> Result<ArrayRef> {
162+
let list_array = as_generic_list_array::<O>(array)?;
163+
let scalar_array = as_float64_array(scalar)?;
164+
165+
let values = as_float64_array(list_array.values())?;
166+
let offsets = list_array.value_offsets();
167+
168+
// A row is null whenever either input row is null. The scalar applies
169+
// uniformly across the array, so a null scalar makes the whole row
170+
// undefined; union the two row-level null buffers in a single pass
171+
// rather than tracking row nulls inside the value loop.
172+
let row_nulls = NullBuffer::union(list_array.nulls(), scalar_array.nulls());
173+
174+
let mut value_builder = Float64Array::builder(values.len());
175+
let mut new_offsets = OffsetBufferBuilder::<O>::new(list_array.len());
176+
177+
for row in 0..list_array.len() {
178+
if list_array.is_null(row) || scalar_array.is_null(row) {
179+
new_offsets.push_length(0);
180+
continue;
181+
}
182+
183+
let start = offsets[row].as_usize();
184+
let end = offsets[row + 1].as_usize();
185+
let len = end - start;
186+
let scalar_val = scalar_array.value(row);
187+
188+
let slice = values.slice(start, len);
189+
190+
// Per-element NULL propagation for NULL elements inside the array.
191+
for i in 0..len {
192+
if slice.is_null(i) {
193+
value_builder.append_null();
194+
} else {
195+
value_builder.append_value(slice.value(i) * scalar_val);
196+
}
197+
}
198+
199+
new_offsets.push_length(len);
200+
}
201+
202+
let values_array = Arc::new(value_builder.finish());
203+
204+
// Preserve the inner field from the input array (including any user
205+
// metadata). After `coerce_types` the inner type is Float64, but the
206+
// input may still carry field-level annotations worth keeping.
207+
let field = match list_array.data_type() {
208+
List(f) | LargeList(f) => Arc::clone(f),
209+
other => {
210+
return internal_err!("array_scale unexpected list type: {other}");
211+
}
212+
};
213+
214+
Ok(Arc::new(GenericListArray::<O>::try_new(
215+
field,
216+
new_offsets.finish(),
217+
values_array,
218+
row_nulls,
219+
)?))
220+
}

datafusion/functions-nested/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ pub mod array_compact;
4747
pub mod array_filter;
4848
pub mod array_has;
4949
pub mod array_normalize;
50+
pub mod array_scale;
5051
pub mod array_transform;
5152
pub mod arrays_zip;
5253
pub mod cardinality;
@@ -96,6 +97,7 @@ pub mod expr_fn {
9697
pub use super::array_has::array_has_all;
9798
pub use super::array_has::array_has_any;
9899
pub use super::array_normalize::array_normalize;
100+
pub use super::array_scale::array_scale;
99101
pub use super::array_transform::array_transform;
100102
pub use super::arrays_zip::arrays_zip;
101103
pub use super::cardinality::cardinality;
@@ -171,6 +173,7 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
171173
empty::array_empty_udf(),
172174
length::array_length_udf(),
173175
array_normalize::array_normalize_udf(),
176+
array_scale::array_scale_udf(),
174177
cosine_distance::cosine_distance_udf(),
175178
inner_product::inner_product_udf(),
176179
distance::array_distance_udf(),

0 commit comments

Comments
 (0)