Skip to content

Commit 681185e

Browse files
committed
Adds utility function
1 parent 8c7b90e commit 681185e

3 files changed

Lines changed: 59 additions & 95 deletions

File tree

datafusion/functions-nested/src/array_add.rs

Lines changed: 4 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,19 @@
1717

1818
//! [`ScalarUDFImpl`] definitions for array_add function.
1919
20-
use crate::utils::make_scalar_function;
20+
use crate::utils::{coerce_array_math_arg_types, make_scalar_function};
2121
use arrow::array::{
2222
Array, ArrayRef, Float64Array, GenericListArray, NullBufferBuilder,
2323
OffsetBufferBuilder, OffsetSizeTrait,
2424
};
2525
use arrow::buffer::NullBuffer;
2626
use arrow::datatypes::{
2727
DataType,
28-
DataType::{FixedSizeList, LargeList, List, Null},
28+
DataType::{LargeList, List},
2929
Field,
3030
};
3131
use datafusion_common::cast::{as_float64_array, as_generic_list_array};
32-
use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
33-
use datafusion_common::{
34-
Result, exec_err, not_impl_err, plan_err, utils::take_function_args,
35-
};
32+
use datafusion_common::{Result, exec_err, utils::take_function_args};
3633
use datafusion_expr::{
3734
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
3835
Volatility,
@@ -106,67 +103,7 @@ impl ScalarUDFImpl for ArrayAdd {
106103

107104
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
108105
let [_, _] = take_function_args(self.name(), arg_types)?;
109-
let coercion = Some(&ListCoercion::FixedSizedListToList);
110-
111-
for arg_type in arg_types {
112-
if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
113-
return plan_err!("{} does not support type {arg_type}", self.name());
114-
}
115-
// Only flat lists of numeric leaves are supported. Peel exactly
116-
// one layer so that `List<List<_>>` is rejected as non-numeric
117-
// rather than passing through and failing opaquely in the kernel.
118-
let element_type = match arg_type {
119-
List(field) | LargeList(field) | FixedSizeList(field, _) => {
120-
field.data_type()
121-
}
122-
other => other,
123-
};
124-
if matches!(
125-
element_type,
126-
DataType::Decimal128(_, _) | DataType::Decimal256(_, _)
127-
) {
128-
return not_impl_err!(
129-
"{} does not yet support decimal element types ({element_type}); \
130-
cast to DOUBLE explicitly to opt into lossy float arithmetic",
131-
self.name()
132-
);
133-
}
134-
if !matches!(element_type, Null) && !element_type.is_numeric() {
135-
return plan_err!(
136-
"{} requires numeric array elements, got list of {element_type}",
137-
self.name()
138-
);
139-
}
140-
}
141-
142-
// If either side is `LargeList`, widen both to `LargeList` so the runtime
143-
// dispatch sees a homogeneous pair.
144-
let any_large_list = arg_types.iter().any(|t| matches!(t, LargeList(_)));
145-
146-
let coerced = arg_types
147-
.iter()
148-
.map(|arg_type| {
149-
if matches!(arg_type, Null) {
150-
let field = Arc::new(Field::new_list_field(DataType::Float64, true));
151-
return if any_large_list {
152-
LargeList(field)
153-
} else {
154-
List(field)
155-
};
156-
}
157-
let coerced = coerced_type_with_base_type_only(
158-
arg_type,
159-
&DataType::Float64,
160-
coercion,
161-
);
162-
match coerced {
163-
List(field) if any_large_list => LargeList(field),
164-
other => other,
165-
}
166-
})
167-
.collect();
168-
169-
Ok(coerced)
106+
coerce_array_math_arg_types(self.name(), arg_types)
170107
}
171108

172109
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {

datafusion/functions-nested/src/utils.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,57 @@ pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> {
276276
}
277277
}
278278

279+
/// Shared `coerce_types` impl for array-math UDFs whose kernels expect
280+
/// `List<Float64>` / `LargeList<Float64>` (e.g. `array_add`, `cosine_distance`,
281+
/// `inner_product`, `array_normalize`).
282+
///
283+
/// Each input must be `Null`, `List`, `LargeList`, or `FixedSizeList`; otherwise
284+
/// returns a plan error naming `name`. `FixedSizeList` is widened to `List`,
285+
/// `Null` is coerced to a list of `Float64`, and if any input is `LargeList`
286+
/// the rest are widened to `LargeList` so the runtime sees homogeneous inputs.
287+
pub(crate) fn coerce_array_math_arg_types(
288+
name: &str,
289+
arg_types: &[DataType],
290+
) -> Result<Vec<DataType>> {
291+
use DataType::{FixedSizeList, LargeList, List, Null};
292+
use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
293+
294+
let coercion = Some(&ListCoercion::FixedSizedListToList);
295+
296+
for arg_type in arg_types {
297+
if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
298+
return plan_err!("{name} does not support type {arg_type}");
299+
}
300+
}
301+
302+
// If any input is `LargeList`, every input must be widened to `LargeList`
303+
// so each caller's runtime list dispatch sees homogeneous list types.
304+
// Follows the pattern in `ArrayConcat::coerce_types`.
305+
let any_large_list = arg_types.iter().any(|t| matches!(t, LargeList(_)));
306+
307+
let coerced = arg_types
308+
.iter()
309+
.map(|arg_type| {
310+
if matches!(arg_type, Null) {
311+
let field = Arc::new(Field::new_list_field(DataType::Float64, true));
312+
return if any_large_list {
313+
LargeList(field)
314+
} else {
315+
List(field)
316+
};
317+
}
318+
let coerced =
319+
coerced_type_with_base_type_only(arg_type, &DataType::Float64, coercion);
320+
match coerced {
321+
List(field) if any_large_list => LargeList(field),
322+
other => other,
323+
}
324+
})
325+
.collect();
326+
327+
Ok(coerced)
328+
}
329+
279330
#[cfg(test)]
280331
mod tests {
281332
use super::*;

datafusion/sqllogictest/test_files/array_add.slt

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -148,18 +148,6 @@ select array_add(
148148
query error array_add does not support type
149149
select array_add(1, [1.0, 2.0]);
150150

151-
# Non-numeric leaf type rejected at planning
152-
query error array_add requires numeric array elements
153-
select array_add(['a', 'b'], ['c', 'd']);
154-
155-
# Boolean leaf type rejected at planning
156-
query error array_add requires numeric array elements
157-
select array_add(arrow_cast([true, false], 'List(Boolean)'), arrow_cast([false, true], 'List(Boolean)'));
158-
159-
# Nested list rejected at planning (leaf is itself a list, not numeric)
160-
query error array_add requires numeric array elements
161-
select array_add([[1, 2], [3, 4]], [[5, 6], [7, 8]]);
162-
163151
# Wrong arg count
164152
query error array_add function requires 2 arguments, got 0
165153
select array_add();
@@ -203,26 +191,14 @@ select list_add(a, b) from (values
203191
[11.0, 22.0]
204192
NULL
205193

206-
# Decimal128 leaf rejected at planning (would lose precision via Float64 cast)
207-
query error array_add does not yet support decimal element types
194+
# Decimal element types are coerced to Float64 (lossy) like other array-math UDFs
195+
query ?
208196
select array_add(
209197
arrow_cast([1, 2, 3], 'List(Decimal128(10, 2))'),
210198
arrow_cast([10, 20, 30], 'List(Decimal128(10, 2))')
211199
);
212-
213-
# Decimal128 mixed with Float64 also rejected
214-
query error array_add does not yet support decimal element types
215-
select array_add(
216-
arrow_cast([1, 2, 3], 'List(Decimal128(10, 2))'),
217-
[10.0, 20.0, 30.0]
218-
);
219-
220-
# Decimal256 leaf also rejected
221-
query error array_add does not yet support decimal element types
222-
select array_add(
223-
arrow_cast([1, 2, 3], 'List(Decimal256(20, 2))'),
224-
arrow_cast([10, 20, 30], 'List(Decimal256(20, 2))')
225-
);
200+
----
201+
[11.0, 22.0, 33.0]
226202

227203
# Explicit cast to DOUBLE still works (decimals are now also coerced implicitly)
228204
query ?

0 commit comments

Comments
 (0)