Skip to content

Commit 8986da2

Browse files
committed
fix: address review feedback for array_position
- Compute combined null buffer upfront via NullBuffer::union and use Vec<i64> with Int64Array::new() instead of Vec<Option<i64>>, avoiding per-row null tracking overhead in all typed paths
- Use TypeSignature::Any(2) instead of variadic_any for precise arity
- Replace .unwrap() on downcast with .ok_or_else() for safer error handling
- Add nested array test cases to exercise position_fallback code path
1 parent 0f2e41e commit 8986da2

2 files changed

Lines changed: 50 additions & 40 deletions

File tree

native/spark-expr/src/array_funcs/array_position.rs

Lines changed: 43 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818
use arrow::array::{
1919
Array, ArrayRef, AsArray, BooleanArray, GenericListArray, Int64Array, OffsetSizeTrait,
2020
};
21+
use arrow::buffer::{NullBuffer, ScalarBuffer};
2122
use arrow::datatypes::{
2223
ArrowPrimitiveType, DataType, Date32Type, Decimal128Type, Float32Type, Float64Type, Int16Type,
2324
Int32Type, Int64Type, Int8Type, TimestampMicrosecondType,
2425
};
2526
use datafusion::common::{exec_err, DataFusionError, Result as DataFusionResult, ScalarValue};
2627
use datafusion::logical_expr::{
27-
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
28+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
2829
};
2930
use num::Float;
3031
use std::any::Any;
@@ -77,7 +78,7 @@ fn generic_array_position<O: OffsetSizeTrait>(
7778
let list_array = array
7879
.as_any()
7980
.downcast_ref::<GenericListArray<O>>()
80-
.unwrap();
81+
.ok_or_else(|| DataFusionError::Internal("expected list array".into()))?;
8182

8283
let values = list_array.values();
8384
let offsets = list_array.offsets();
@@ -107,6 +108,16 @@ fn generic_array_position<O: OffsetSizeTrait>(
107108
}
108109
}
109110

111+
/// Compute the combined null buffer from list array and element nulls.
112+
fn combined_nulls(list_array_nulls: Option<&NullBuffer>, element_nulls: Option<&NullBuffer>) -> Option<NullBuffer> {
113+
match (list_array_nulls, element_nulls) {
114+
(Some(a), Some(b)) => NullBuffer::union(Some(a), Some(b)),
115+
(Some(a), None) => Some(a.clone()),
116+
(None, Some(b)) => Some(b.clone()),
117+
(None, None) => None,
118+
}
119+
}
120+
110121
/// Fast path for primitive types: downcast once, iterate using offsets into the flat buffer.
111122
fn position_primitive<O: OffsetSizeTrait, T: ArrowPrimitiveType>(
112123
list_array: &GenericListArray<O>,
@@ -120,27 +131,25 @@ where
120131
let values_typed = values.as_primitive::<T>();
121132
let element_typed = element.as_primitive::<T>();
122133
let num_rows = list_array.len();
123-
let mut result = Vec::with_capacity(num_rows);
134+
let nulls = combined_nulls(list_array.nulls(), element.nulls());
135+
let mut result = vec![0i64; num_rows];
124136

125137
for (row_index, w) in offsets.windows(2).enumerate() {
126-
if list_array.is_null(row_index) || element.is_null(row_index) {
127-
result.push(None);
138+
if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) {
128139
continue;
129140
}
130141
let start = w[0].as_usize();
131142
let end = w[1].as_usize();
132143
let search_val = element_typed.value(row_index);
133-
let mut pos: i64 = 0;
134144
for i in start..end {
135145
if !values_typed.is_null(i) && values_typed.value(i) == search_val {
136-
pos = (i - start + 1) as i64;
146+
result[row_index] = (i - start + 1) as i64;
137147
break;
138148
}
139149
}
140-
result.push(Some(pos));
141150
}
142151

143-
Ok(Arc::new(Int64Array::from(result)))
152+
Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls)))
144153
}
145154

146155
/// Float path: same as primitive but treats NaN == NaN (Spark's ordering.equiv() semantics).
@@ -156,31 +165,29 @@ where
156165
let values_typed = values.as_primitive::<T>();
157166
let element_typed = element.as_primitive::<T>();
158167
let num_rows = list_array.len();
159-
let mut result = Vec::with_capacity(num_rows);
168+
let nulls = combined_nulls(list_array.nulls(), element.nulls());
169+
let mut result = vec![0i64; num_rows];
160170

161171
for (row_index, w) in offsets.windows(2).enumerate() {
162-
if list_array.is_null(row_index) || element.is_null(row_index) {
163-
result.push(None);
172+
if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) {
164173
continue;
165174
}
166175
let start = w[0].as_usize();
167176
let end = w[1].as_usize();
168177
let search_val = element_typed.value(row_index);
169178
let search_is_nan = search_val.is_nan();
170-
let mut pos: i64 = 0;
171179
for i in start..end {
172180
if !values_typed.is_null(i) {
173181
let v = values_typed.value(i);
174182
if (search_is_nan && v.is_nan()) || v == search_val {
175-
pos = (i - start + 1) as i64;
183+
result[row_index] = (i - start + 1) as i64;
176184
break;
177185
}
178186
}
179187
}
180-
result.push(Some(pos));
181188
}
182189

183-
Ok(Arc::new(Int64Array::from(result)))
190+
Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls)))
184191
}
185192

186193
/// Boolean path.
@@ -190,30 +197,30 @@ fn position_boolean<O: OffsetSizeTrait>(
190197
values: &ArrayRef,
191198
element: &ArrayRef,
192199
) -> Result<ArrayRef, DataFusionError> {
193-
let values_typed = values.as_any().downcast_ref::<BooleanArray>().unwrap();
194-
let element_typed = element.as_any().downcast_ref::<BooleanArray>().unwrap();
200+
let values_typed = values.as_any().downcast_ref::<BooleanArray>()
201+
.ok_or_else(|| DataFusionError::Internal("expected boolean array".into()))?;
202+
let element_typed = element.as_any().downcast_ref::<BooleanArray>()
203+
.ok_or_else(|| DataFusionError::Internal("expected boolean array".into()))?;
195204
let num_rows = list_array.len();
196-
let mut result = Vec::with_capacity(num_rows);
205+
let nulls = combined_nulls(list_array.nulls(), element.nulls());
206+
let mut result = vec![0i64; num_rows];
197207

198208
for (row_index, w) in offsets.windows(2).enumerate() {
199-
if list_array.is_null(row_index) || element.is_null(row_index) {
200-
result.push(None);
209+
if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) {
201210
continue;
202211
}
203212
let start = w[0].as_usize();
204213
let end = w[1].as_usize();
205214
let search_val = element_typed.value(row_index);
206-
let mut pos: i64 = 0;
207215
for i in start..end {
208216
if !values_typed.is_null(i) && values_typed.value(i) == search_val {
209-
pos = (i - start + 1) as i64;
217+
result[row_index] = (i - start + 1) as i64;
210218
break;
211219
}
212220
}
213-
result.push(Some(pos));
214221
}
215222

216-
Ok(Arc::new(Int64Array::from(result)))
223+
Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls)))
217224
}
218225

219226
/// String path: downcast once, iterate using offsets into the flat string buffer.
@@ -226,27 +233,25 @@ fn position_string<O: OffsetSizeTrait, S: OffsetSizeTrait>(
226233
let values_typed = values.as_string::<S>();
227234
let element_typed = element.as_string::<S>();
228235
let num_rows = list_array.len();
229-
let mut result = Vec::with_capacity(num_rows);
236+
let nulls = combined_nulls(list_array.nulls(), element.nulls());
237+
let mut result = vec![0i64; num_rows];
230238

231239
for (row_index, w) in offsets.windows(2).enumerate() {
232-
if list_array.is_null(row_index) || element.is_null(row_index) {
233-
result.push(None);
240+
if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) {
234241
continue;
235242
}
236243
let start = w[0].as_usize();
237244
let end = w[1].as_usize();
238245
let search_val = element_typed.value(row_index);
239-
let mut pos: i64 = 0;
240246
for i in start..end {
241247
if !values_typed.is_null(i) && values_typed.value(i) == search_val {
242-
pos = (i - start + 1) as i64;
248+
result[row_index] = (i - start + 1) as i64;
243249
break;
244250
}
245251
}
246-
result.push(Some(pos));
247252
}
248253

249-
Ok(Arc::new(Int64Array::from(result)))
254+
Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls)))
250255
}
251256

252257
/// Fallback for complex types (nested arrays, structs, etc.) using ScalarValue comparison.
@@ -257,30 +262,28 @@ fn position_fallback<O: OffsetSizeTrait>(
257262
element: &ArrayRef,
258263
) -> Result<ArrayRef, DataFusionError> {
259264
let num_rows = list_array.len();
260-
let mut result = Vec::with_capacity(num_rows);
265+
let nulls = combined_nulls(list_array.nulls(), element.nulls());
266+
let mut result = vec![0i64; num_rows];
261267

262268
for (row_index, w) in offsets.windows(2).enumerate() {
263-
if list_array.is_null(row_index) || element.is_null(row_index) {
264-
result.push(None);
269+
if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) {
265270
continue;
266271
}
267272
let start = w[0].as_usize();
268273
let end = w[1].as_usize();
269274
let search_scalar = ScalarValue::try_from_array(element, row_index)?;
270-
let mut pos: i64 = 0;
271275
for i in start..end {
272276
if !values.is_null(i) {
273277
let item_scalar = ScalarValue::try_from_array(values, i)?;
274278
if search_scalar == item_scalar {
275-
pos = (i - start + 1) as i64;
279+
result[row_index] = (i - start + 1) as i64;
276280
break;
277281
}
278282
}
279283
}
280-
result.push(Some(pos));
281284
}
282285

283-
Ok(Arc::new(Int64Array::from(result)))
286+
Ok(Arc::new(Int64Array::new(ScalarBuffer::from(result), nulls)))
284287
}
285288

286289
#[derive(Debug, Hash, Eq, PartialEq)]
@@ -297,7 +300,7 @@ impl Default for SparkArrayPositionFunc {
297300
impl SparkArrayPositionFunc {
298301
pub fn new() -> Self {
299302
Self {
300-
signature: Signature::variadic_any(Volatility::Immutable),
303+
signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
301304
}
302305
}
303306
}

spark/src/test/resources/sql-tests/expressions/array/array_position.sql

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,13 @@ INSERT INTO test_ap_date VALUES
218218
query
219219
SELECT array_position(arr, val) FROM test_ap_date
220220

221+
-- nested array (exercises position_fallback code path)
222+
query spark_answer_only
223+
SELECT array_position(array(array(1, 2), array(3, 4)), array(1, 2))
224+
225+
query spark_answer_only
226+
SELECT array_position(array(array(1, 2), array(3, 4)), array(5, 6))
227+
221228
-- timestamp arrays
222229
statement
223230
CREATE TABLE test_ap_ts(arr array<timestamp>, val timestamp) USING parquet

0 commit comments

Comments
 (0)