|
17 | 17 |
|
18 | 18 | //! [`ScalarUDFImpl`] definitions for array_add function. |
19 | 19 |
|
20 | | -use crate::utils::make_scalar_function; |
| 20 | +use crate::utils::{coerce_array_math_arg_types, make_scalar_function}; |
21 | 21 | use arrow::array::{ |
22 | 22 | Array, ArrayRef, Float64Array, GenericListArray, NullBufferBuilder, |
23 | 23 | OffsetBufferBuilder, OffsetSizeTrait, |
24 | 24 | }; |
25 | 25 | use arrow::buffer::NullBuffer; |
26 | 26 | use arrow::datatypes::{ |
27 | 27 | DataType, |
28 | | - DataType::{FixedSizeList, LargeList, List, Null}, |
| 28 | + DataType::{LargeList, List}, |
29 | 29 | Field, |
30 | 30 | }; |
31 | 31 | use datafusion_common::cast::{as_float64_array, as_generic_list_array}; |
32 | | -use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only}; |
33 | | -use datafusion_common::{ |
34 | | - Result, exec_err, not_impl_err, plan_err, utils::take_function_args, |
35 | | -}; |
| 32 | +use datafusion_common::{Result, exec_err, utils::take_function_args}; |
36 | 33 | use datafusion_expr::{ |
37 | 34 | ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, |
38 | 35 | Volatility, |
@@ -106,67 +103,7 @@ impl ScalarUDFImpl for ArrayAdd { |
106 | 103 |
|
107 | 104 | fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { |
108 | 105 | let [_, _] = take_function_args(self.name(), arg_types)?; |
109 | | - let coercion = Some(&ListCoercion::FixedSizedListToList); |
110 | | - |
111 | | - for arg_type in arg_types { |
112 | | - if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { |
113 | | - return plan_err!("{} does not support type {arg_type}", self.name()); |
114 | | - } |
115 | | - // Only flat lists of numeric leaves are supported. Peel exactly |
116 | | - // one layer so that `List<List<_>>` is rejected as non-numeric |
117 | | - // rather than passing through and failing opaquely in the kernel. |
118 | | - let element_type = match arg_type { |
119 | | - List(field) | LargeList(field) | FixedSizeList(field, _) => { |
120 | | - field.data_type() |
121 | | - } |
122 | | - other => other, |
123 | | - }; |
124 | | - if matches!( |
125 | | - element_type, |
126 | | - DataType::Decimal128(_, _) | DataType::Decimal256(_, _) |
127 | | - ) { |
128 | | - return not_impl_err!( |
129 | | - "{} does not yet support decimal element types ({element_type}); \ |
130 | | - cast to DOUBLE explicitly to opt into lossy float arithmetic", |
131 | | - self.name() |
132 | | - ); |
133 | | - } |
134 | | - if !matches!(element_type, Null) && !element_type.is_numeric() { |
135 | | - return plan_err!( |
136 | | - "{} requires numeric array elements, got list of {element_type}", |
137 | | - self.name() |
138 | | - ); |
139 | | - } |
140 | | - } |
141 | | - |
142 | | - // If either side is `LargeList`, widen both to `LargeList` so the runtime |
143 | | - // dispatch sees a homogeneous pair. |
144 | | - let any_large_list = arg_types.iter().any(|t| matches!(t, LargeList(_))); |
145 | | - |
146 | | - let coerced = arg_types |
147 | | - .iter() |
148 | | - .map(|arg_type| { |
149 | | - if matches!(arg_type, Null) { |
150 | | - let field = Arc::new(Field::new_list_field(DataType::Float64, true)); |
151 | | - return if any_large_list { |
152 | | - LargeList(field) |
153 | | - } else { |
154 | | - List(field) |
155 | | - }; |
156 | | - } |
157 | | - let coerced = coerced_type_with_base_type_only( |
158 | | - arg_type, |
159 | | - &DataType::Float64, |
160 | | - coercion, |
161 | | - ); |
162 | | - match coerced { |
163 | | - List(field) if any_large_list => LargeList(field), |
164 | | - other => other, |
165 | | - } |
166 | | - }) |
167 | | - .collect(); |
168 | | - |
169 | | - Ok(coerced) |
| 106 | + coerce_array_math_arg_types(self.name(), arg_types) |
170 | 107 | } |
171 | 108 |
|
172 | 109 | fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { |
|
0 commit comments