Skip to content

Commit a29be46

Browse files
committed
feat: add array_scale scalar function
Adds `array_scale(array, scalar)` returning a new array with each element multiplied by a scalar. Aliased as `list_scale`. Part of the per-function split sequence on tracking issue #21536, following the pattern of the already-merged PRs in this series. Semantics: - NULL row in array -> NULL row out - NULL element at position i in array -> NULL element at i out (per-element propagation) - NULL scalar -> NULL row out (whole-row, because the scalar applies uniformly to every element; the entire operation is undefined) - Empty array -> empty array First argument is List/LargeList/FixedSizeList of any numeric type. Second argument is a numeric scalar. Both coerce to Float64. List-like inputs follow the same widening rules as the binary-op siblings: LargeList wins, FixedSizeList coerces to List.
1 parent 4a41173 commit a29be46

4 files changed

Lines changed: 450 additions & 0 deletions

File tree

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`ScalarUDFImpl`] definitions for array_scale function.
19+
20+
use crate::utils::make_scalar_function;
21+
use arrow::array::{
22+
Array, ArrayRef, Float64Array, GenericListArray, OffsetBufferBuilder,
23+
OffsetSizeTrait,
24+
};
25+
use arrow::buffer::NullBuffer;
26+
use arrow::datatypes::{
27+
DataType,
28+
DataType::{FixedSizeList, LargeList, List, Null},
29+
Field,
30+
};
31+
use datafusion_common::cast::{as_float64_array, as_generic_list_array};
32+
use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
33+
use datafusion_common::{Result, internal_err, plan_err, utils::take_function_args};
34+
use datafusion_expr::{
35+
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
36+
Volatility,
37+
};
38+
use datafusion_macros::user_doc;
39+
use std::sync::Arc;
40+
41+
make_udf_expr_and_func!(
42+
ArrayScale,
43+
array_scale,
44+
array scalar,
45+
"scales each element of a numeric array by a scalar.",
46+
array_scale_udf
47+
);
48+
49+
#[user_doc(
50+
doc_section(label = "Array Functions"),
51+
description = "Returns a new array with each element of the input array multiplied by a scalar value, computed as `array[i] * scalar`. Returns NULL if the input row is NULL or the scalar is NULL. If a NULL element appears in the input array at position `i`, the result element at position `i` is NULL. Returns an empty array for an empty input array.",
52+
syntax_example = "array_scale(array, scalar)",
53+
sql_example = r#"```sql
54+
> select array_scale([1.0, 2.0, 3.0], 2.0);
55+
+----------------------------------+
56+
| array_scale(List([1.0,2.0,3.0]),Float64(2.0)) |
57+
+----------------------------------+
58+
| [2.0, 4.0, 6.0] |
59+
+----------------------------------+
60+
```"#,
61+
argument(
62+
name = "array",
63+
description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
64+
),
65+
argument(
66+
name = "scalar",
67+
description = "Numeric scalar to multiply each element by. Can be a constant or column expression."
68+
)
69+
)]
70+
#[derive(Debug, PartialEq, Eq, Hash)]
71+
pub struct ArrayScale {
72+
signature: Signature,
73+
aliases: Vec<String>,
74+
}
75+
76+
impl Default for ArrayScale {
77+
fn default() -> Self {
78+
Self::new()
79+
}
80+
}
81+
82+
impl ArrayScale {
83+
pub fn new() -> Self {
84+
Self {
85+
signature: Signature::user_defined(Volatility::Immutable),
86+
aliases: vec!["list_scale".to_string()],
87+
}
88+
}
89+
}
90+
91+
impl ScalarUDFImpl for ArrayScale {
92+
fn name(&self) -> &str {
93+
"array_scale"
94+
}
95+
96+
fn signature(&self) -> &Signature {
97+
&self.signature
98+
}
99+
100+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
101+
// After `coerce_types`, `arg_types[0]` is one of List(Float64) or LargeList(Float64).
102+
Ok(arg_types[0].clone())
103+
}
104+
105+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
106+
let [array_type, scalar_type] = take_function_args(self.name(), arg_types)?;
107+
let coercion = Some(&ListCoercion::FixedSizedListToList);
108+
109+
if !matches!(
110+
array_type,
111+
Null | List(_) | LargeList(_) | FixedSizeList(..)
112+
) {
113+
return plan_err!(
114+
"{} first argument must be a list type, got {array_type}",
115+
self.name()
116+
);
117+
}
118+
119+
if !scalar_type.is_numeric() && !matches!(scalar_type, Null) {
120+
return plan_err!(
121+
"{} second argument must be numeric, got {scalar_type}",
122+
self.name()
123+
);
124+
}
125+
126+
let coerced_array = if matches!(array_type, Null) {
127+
List(Arc::new(Field::new_list_field(DataType::Float64, true)))
128+
} else {
129+
coerced_type_with_base_type_only(array_type, &DataType::Float64, coercion)
130+
};
131+
132+
Ok(vec![coerced_array, DataType::Float64])
133+
}
134+
135+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
136+
make_scalar_function(array_scale_inner)(&args.args)
137+
}
138+
139+
fn aliases(&self) -> &[String] {
140+
&self.aliases
141+
}
142+
143+
fn documentation(&self) -> Option<&Documentation> {
144+
self.doc()
145+
}
146+
}
147+
148+
fn array_scale_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
149+
let [array, scalar] = take_function_args("array_scale", args)?;
150+
match array.data_type() {
151+
List(_) => general_array_scale::<i32>(array, scalar),
152+
LargeList(_) => general_array_scale::<i64>(array, scalar),
153+
arg_type => internal_err!(
154+
"array_scale received unexpected type after coercion: {arg_type}"
155+
),
156+
}
157+
}
158+
159+
fn general_array_scale<O: OffsetSizeTrait>(
160+
array: &ArrayRef,
161+
scalar: &ArrayRef,
162+
) -> Result<ArrayRef> {
163+
let list_array = as_generic_list_array::<O>(array)?;
164+
let scalar_array = as_float64_array(scalar)?;
165+
166+
let values = as_float64_array(list_array.values())?;
167+
let offsets = list_array.value_offsets();
168+
169+
// A row is null whenever either input row is null. The scalar applies
170+
// uniformly across the array, so a null scalar makes the whole row
171+
// undefined; union the two row-level null buffers in a single pass
172+
// rather than tracking row nulls inside the value loop.
173+
let row_nulls = NullBuffer::union(list_array.nulls(), scalar_array.nulls());
174+
175+
let mut value_builder = Float64Array::builder(values.len());
176+
let mut new_offsets = OffsetBufferBuilder::<O>::new(list_array.len());
177+
178+
for row in 0..list_array.len() {
179+
if list_array.is_null(row) || scalar_array.is_null(row) {
180+
new_offsets.push_length(0);
181+
continue;
182+
}
183+
184+
let start = offsets[row].as_usize();
185+
let end = offsets[row + 1].as_usize();
186+
let len = end - start;
187+
let scalar_val = scalar_array.value(row);
188+
189+
let slice = values.slice(start, len);
190+
191+
// Per-element NULL propagation for NULL elements inside the array.
192+
for i in 0..len {
193+
if slice.is_null(i) {
194+
value_builder.append_null();
195+
} else {
196+
value_builder.append_value(slice.value(i) * scalar_val);
197+
}
198+
}
199+
200+
new_offsets.push_length(len);
201+
}
202+
203+
let values_array = Arc::new(value_builder.finish());
204+
205+
// Preserve the inner field from the input array (including any user
206+
// metadata). After `coerce_types` the inner type is Float64, but the
207+
// input may still carry field-level annotations worth keeping.
208+
let field = match list_array.data_type() {
209+
List(f) | LargeList(f) => Arc::clone(f),
210+
other => {
211+
return internal_err!("array_scale unexpected list type: {other}");
212+
}
213+
};
214+
215+
Ok(Arc::new(GenericListArray::<O>::try_new(
216+
field,
217+
new_offsets.finish(),
218+
values_array,
219+
row_nulls,
220+
)?))
221+
}

datafusion/functions-nested/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ pub mod array_compact;
4747
pub mod array_filter;
4848
pub mod array_has;
4949
pub mod array_normalize;
50+
pub mod array_scale;
5051
pub mod array_transform;
5152
pub mod arrays_zip;
5253
pub mod cardinality;
@@ -96,6 +97,7 @@ pub mod expr_fn {
9697
pub use super::array_has::array_has_all;
9798
pub use super::array_has::array_has_any;
9899
pub use super::array_normalize::array_normalize;
100+
pub use super::array_scale::array_scale;
99101
pub use super::array_transform::array_transform;
100102
pub use super::arrays_zip::arrays_zip;
101103
pub use super::cardinality::cardinality;
@@ -171,6 +173,7 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
171173
empty::array_empty_udf(),
172174
length::array_length_udf(),
173175
array_normalize::array_normalize_udf(),
176+
array_scale::array_scale_udf(),
174177
cosine_distance::cosine_distance_udf(),
175178
inner_product::inner_product_udf(),
176179
distance::array_distance_udf(),

0 commit comments

Comments
 (0)