Skip to content

Commit b7ddad1

Browse files
committed
perf: use flat values buffer and offsets for array_position
Avoid per-row subarray allocation from list_array.value(row_index). Instead, downcast the flat values buffer once and iterate over it using the offset ranges directly. Improves performance from 0.7-0.8X to 0.9X of Spark.
1 parent 24318d3 commit b7ddad1

1 file changed

Lines changed: 195 additions & 112 deletions

File tree

native/spark-expr/src/array_funcs/array_position.rs

Lines changed: 195 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@ use arrow::array::{
1919
Array, ArrayRef, AsArray, BooleanArray, GenericListArray, Int64Array, OffsetSizeTrait,
2020
};
2121
use arrow::datatypes::{
22-
DataType, Date32Type, Decimal128Type, Float32Type, Float64Type, Int16Type, Int32Type,
23-
Int64Type, Int8Type, TimestampMicrosecondType,
22+
ArrowPrimitiveType, DataType, Date32Type, Decimal128Type, Float32Type, Float64Type, Int16Type,
23+
Int32Type, Int64Type, Int8Type, TimestampMicrosecondType,
2424
};
2525
use datafusion::common::{exec_err, DataFusionError, Result as DataFusionResult, ScalarValue};
2626
use datafusion::logical_expr::{
2727
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
2828
};
29+
use num::Float;
2930
use std::any::Any;
3031
use std::sync::Arc;
3132

@@ -36,7 +37,6 @@ fn spark_array_position(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFus
3637
return exec_err!("array_position function takes exactly two arguments");
3738
}
3839

39-
// Convert all arguments to arrays for consistent processing
4040
let len = args
4141
.iter()
4242
.fold(Option::<usize>::None, |acc, arg| match arg {
@@ -68,144 +68,227 @@ fn array_position_inner(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError>
6868
}
6969
}
7070

71-
/// Find the 1-based position of `search_val` in a typed primitive array.
72-
/// Returns 0 if not found.
73-
macro_rules! find_position_primitive {
74-
($list_items:expr, $element:expr, $row_index:expr, $arrow_type:ty) => {{
75-
let items = $list_items.as_primitive::<$arrow_type>();
76-
let search = $element.as_primitive::<$arrow_type>();
77-
let search_val = search.value($row_index);
71+
/// Searches for an element in a list array using the flat values buffer and offsets directly,
72+
/// avoiding per-row subarray allocation. Dispatches to typed fast paths by element data type.
73+
fn generic_array_position<O: OffsetSizeTrait>(
74+
array: &ArrayRef,
75+
element: &ArrayRef,
76+
) -> Result<ArrayRef, DataFusionError> {
77+
let list_array = array
78+
.as_any()
79+
.downcast_ref::<GenericListArray<O>>()
80+
.unwrap();
81+
82+
let values = list_array.values();
83+
let offsets = list_array.offsets();
84+
let elem_type = values.data_type().clone();
85+
86+
match &elem_type {
87+
DataType::Boolean => {
88+
position_boolean::<O>(list_array, offsets, values, element)
89+
}
90+
DataType::Int8 => position_primitive::<O, Int8Type>(list_array, offsets, values, element),
91+
DataType::Int16 => position_primitive::<O, Int16Type>(list_array, offsets, values, element),
92+
DataType::Int32 => position_primitive::<O, Int32Type>(list_array, offsets, values, element),
93+
DataType::Int64 => position_primitive::<O, Int64Type>(list_array, offsets, values, element),
94+
DataType::Float32 => {
95+
position_float::<O, Float32Type>(list_array, offsets, values, element)
96+
}
97+
DataType::Float64 => {
98+
position_float::<O, Float64Type>(list_array, offsets, values, element)
99+
}
100+
DataType::Decimal128(_, _) => {
101+
position_primitive::<O, Decimal128Type>(list_array, offsets, values, element)
102+
}
103+
DataType::Date32 => {
104+
position_primitive::<O, Date32Type>(list_array, offsets, values, element)
105+
}
106+
DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, _) => {
107+
position_primitive::<O, TimestampMicrosecondType>(
108+
list_array, offsets, values, element,
109+
)
110+
}
111+
DataType::Utf8 => position_string::<O, i32>(list_array, offsets, values, element),
112+
DataType::LargeUtf8 => position_string::<O, i64>(list_array, offsets, values, element),
113+
// Fallback to ScalarValue for complex types (nested arrays, etc.)
114+
_ => position_fallback::<O>(list_array, offsets, values, element),
115+
}
116+
}
117+
118+
/// Fast path for primitive types: downcast once, iterate using offsets into the flat buffer.
119+
fn position_primitive<O: OffsetSizeTrait, T: ArrowPrimitiveType>(
120+
list_array: &GenericListArray<O>,
121+
offsets: &arrow::buffer::OffsetBuffer<O>,
122+
values: &ArrayRef,
123+
element: &ArrayRef,
124+
) -> Result<ArrayRef, DataFusionError>
125+
where
126+
T::Native: PartialEq,
127+
{
128+
let values_typed = values.as_primitive::<T>();
129+
let element_typed = element.as_primitive::<T>();
130+
let num_rows = list_array.len();
131+
let mut result = Vec::with_capacity(num_rows);
132+
133+
for (row_index, w) in offsets.windows(2).enumerate() {
134+
if list_array.is_null(row_index) || element.is_null(row_index) {
135+
result.push(None);
136+
continue;
137+
}
138+
let start = w[0].as_usize();
139+
let end = w[1].as_usize();
140+
let search_val = element_typed.value(row_index);
78141
let mut pos: i64 = 0;
79-
for i in 0..items.len() {
80-
if !items.is_null(i) && items.value(i) == search_val {
81-
pos = (i + 1) as i64;
142+
for i in start..end {
143+
if !values_typed.is_null(i) && values_typed.value(i) == search_val {
144+
pos = (i - start + 1) as i64;
82145
break;
83146
}
84147
}
85-
pos
86-
}};
148+
result.push(Some(pos));
149+
}
150+
151+
Ok(Arc::new(Int64Array::from(result)))
87152
}
88153

89-
/// Float-aware comparison that treats NaN == NaN (matching Spark's ordering.equiv() semantics).
90-
macro_rules! find_position_float {
91-
($list_items:expr, $element:expr, $row_index:expr, $arrow_type:ty) => {{
92-
let items = $list_items.as_primitive::<$arrow_type>();
93-
let search = $element.as_primitive::<$arrow_type>();
94-
let search_val = search.value($row_index);
154+
/// Float path: same as primitive but treats NaN == NaN (Spark's ordering.equiv() semantics).
155+
fn position_float<O: OffsetSizeTrait, T: ArrowPrimitiveType>(
156+
list_array: &GenericListArray<O>,
157+
offsets: &arrow::buffer::OffsetBuffer<O>,
158+
values: &ArrayRef,
159+
element: &ArrayRef,
160+
) -> Result<ArrayRef, DataFusionError>
161+
where
162+
T::Native: PartialEq + num::Float,
163+
{
164+
let values_typed = values.as_primitive::<T>();
165+
let element_typed = element.as_primitive::<T>();
166+
let num_rows = list_array.len();
167+
let mut result = Vec::with_capacity(num_rows);
168+
169+
for (row_index, w) in offsets.windows(2).enumerate() {
170+
if list_array.is_null(row_index) || element.is_null(row_index) {
171+
result.push(None);
172+
continue;
173+
}
174+
let start = w[0].as_usize();
175+
let end = w[1].as_usize();
176+
let search_val = element_typed.value(row_index);
95177
let search_is_nan = search_val.is_nan();
96178
let mut pos: i64 = 0;
97-
for i in 0..items.len() {
98-
if !items.is_null(i) {
99-
let item_val = items.value(i);
100-
if (search_is_nan && item_val.is_nan()) || item_val == search_val {
101-
pos = (i + 1) as i64;
179+
for i in start..end {
180+
if !values_typed.is_null(i) {
181+
let v = values_typed.value(i);
182+
if (search_is_nan && v.is_nan()) || v == search_val {
183+
pos = (i - start + 1) as i64;
102184
break;
103185
}
104186
}
105187
}
106-
pos
107-
}};
188+
result.push(Some(pos));
189+
}
190+
191+
Ok(Arc::new(Int64Array::from(result)))
108192
}
109193

110-
fn find_position_in_row(
111-
list_items: &ArrayRef,
194+
/// Boolean path.
195+
fn position_boolean<O: OffsetSizeTrait>(
196+
list_array: &GenericListArray<O>,
197+
offsets: &arrow::buffer::OffsetBuffer<O>,
198+
values: &ArrayRef,
112199
element: &ArrayRef,
113-
row_index: usize,
114-
) -> Result<i64, DataFusionError> {
115-
let pos = match list_items.data_type() {
116-
DataType::Boolean => {
117-
let items = list_items.as_any().downcast_ref::<BooleanArray>().unwrap();
118-
let search = element.as_any().downcast_ref::<BooleanArray>().unwrap();
119-
let search_val = search.value(row_index);
120-
let mut pos: i64 = 0;
121-
for i in 0..items.len() {
122-
if !items.is_null(i) && items.value(i) == search_val {
123-
pos = (i + 1) as i64;
124-
break;
125-
}
126-
}
127-
pos
128-
}
129-
DataType::Int8 => find_position_primitive!(list_items, element, row_index, Int8Type),
130-
DataType::Int16 => find_position_primitive!(list_items, element, row_index, Int16Type),
131-
DataType::Int32 => find_position_primitive!(list_items, element, row_index, Int32Type),
132-
DataType::Int64 => find_position_primitive!(list_items, element, row_index, Int64Type),
133-
DataType::Float32 => find_position_float!(list_items, element, row_index, Float32Type),
134-
DataType::Float64 => find_position_float!(list_items, element, row_index, Float64Type),
135-
DataType::Decimal128(_, _) => {
136-
find_position_primitive!(list_items, element, row_index, Decimal128Type)
137-
}
138-
DataType::Date32 => {
139-
find_position_primitive!(list_items, element, row_index, Date32Type)
200+
) -> Result<ArrayRef, DataFusionError> {
201+
let values_typed = values.as_any().downcast_ref::<BooleanArray>().unwrap();
202+
let element_typed = element.as_any().downcast_ref::<BooleanArray>().unwrap();
203+
let num_rows = list_array.len();
204+
let mut result = Vec::with_capacity(num_rows);
205+
206+
for (row_index, w) in offsets.windows(2).enumerate() {
207+
if list_array.is_null(row_index) || element.is_null(row_index) {
208+
result.push(None);
209+
continue;
140210
}
141-
DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, _) => {
142-
find_position_primitive!(list_items, element, row_index, TimestampMicrosecondType)
143-
}
144-
DataType::Utf8 => {
145-
let items = list_items.as_string::<i32>();
146-
let search = element.as_string::<i32>();
147-
let search_val = search.value(row_index);
148-
let mut pos: i64 = 0;
149-
for i in 0..items.len() {
150-
if !items.is_null(i) && items.value(i) == search_val {
151-
pos = (i + 1) as i64;
152-
break;
153-
}
154-
}
155-
pos
156-
}
157-
DataType::LargeUtf8 => {
158-
let items = list_items.as_string::<i64>();
159-
let search = element.as_string::<i64>();
160-
let search_val = search.value(row_index);
161-
let mut pos: i64 = 0;
162-
for i in 0..items.len() {
163-
if !items.is_null(i) && items.value(i) == search_val {
164-
pos = (i + 1) as i64;
165-
break;
166-
}
211+
let start = w[0].as_usize();
212+
let end = w[1].as_usize();
213+
let search_val = element_typed.value(row_index);
214+
let mut pos: i64 = 0;
215+
for i in start..end {
216+
if !values_typed.is_null(i) && values_typed.value(i) == search_val {
217+
pos = (i - start + 1) as i64;
218+
break;
167219
}
168-
pos
169220
}
170-
// Fallback to ScalarValue for complex types (nested arrays, etc.)
171-
_ => {
172-
let element_scalar = ScalarValue::try_from_array(element, row_index)?;
173-
let mut pos: i64 = 0;
174-
for i in 0..list_items.len() {
175-
let item_scalar = ScalarValue::try_from_array(list_items, i)?;
176-
if !item_scalar.is_null() && element_scalar == item_scalar {
177-
pos = (i + 1) as i64;
178-
break;
179-
}
221+
result.push(Some(pos));
222+
}
223+
224+
Ok(Arc::new(Int64Array::from(result)))
225+
}
226+
227+
/// String path: downcast once, iterate using offsets into the flat string buffer.
228+
fn position_string<O: OffsetSizeTrait, S: OffsetSizeTrait>(
229+
list_array: &GenericListArray<O>,
230+
offsets: &arrow::buffer::OffsetBuffer<O>,
231+
values: &ArrayRef,
232+
element: &ArrayRef,
233+
) -> Result<ArrayRef, DataFusionError> {
234+
let values_typed = values.as_string::<S>();
235+
let element_typed = element.as_string::<S>();
236+
let num_rows = list_array.len();
237+
let mut result = Vec::with_capacity(num_rows);
238+
239+
for (row_index, w) in offsets.windows(2).enumerate() {
240+
if list_array.is_null(row_index) || element.is_null(row_index) {
241+
result.push(None);
242+
continue;
243+
}
244+
let start = w[0].as_usize();
245+
let end = w[1].as_usize();
246+
let search_val = element_typed.value(row_index);
247+
let mut pos: i64 = 0;
248+
for i in start..end {
249+
if !values_typed.is_null(i) && values_typed.value(i) == search_val {
250+
pos = (i - start + 1) as i64;
251+
break;
180252
}
181-
pos
182253
}
183-
};
184-
Ok(pos)
254+
result.push(Some(pos));
255+
}
256+
257+
Ok(Arc::new(Int64Array::from(result)))
185258
}
186259

187-
fn generic_array_position<O: OffsetSizeTrait>(
188-
array: &ArrayRef,
260+
/// Fallback for complex types (nested arrays, structs, etc.) using ScalarValue comparison.
261+
fn position_fallback<O: OffsetSizeTrait>(
262+
list_array: &GenericListArray<O>,
263+
offsets: &arrow::buffer::OffsetBuffer<O>,
264+
values: &ArrayRef,
189265
element: &ArrayRef,
190266
) -> Result<ArrayRef, DataFusionError> {
191-
let list_array = array
192-
.as_any()
193-
.downcast_ref::<GenericListArray<O>>()
194-
.unwrap();
195-
196-
let mut data = Vec::with_capacity(list_array.len());
267+
let num_rows = list_array.len();
268+
let mut result = Vec::with_capacity(num_rows);
197269

198-
for row_index in 0..list_array.len() {
270+
for (row_index, w) in offsets.windows(2).enumerate() {
199271
if list_array.is_null(row_index) || element.is_null(row_index) {
200-
data.push(None);
201-
} else {
202-
let list_array_row = list_array.value(row_index);
203-
let position = find_position_in_row(&list_array_row, element, row_index)?;
204-
data.push(Some(position));
272+
result.push(None);
273+
continue;
274+
}
275+
let start = w[0].as_usize();
276+
let end = w[1].as_usize();
277+
let search_scalar = ScalarValue::try_from_array(element, row_index)?;
278+
let mut pos: i64 = 0;
279+
for i in start..end {
280+
if !values.is_null(i) {
281+
let item_scalar = ScalarValue::try_from_array(values, i)?;
282+
if search_scalar == item_scalar {
283+
pos = (i - start + 1) as i64;
284+
break;
285+
}
286+
}
205287
}
288+
result.push(Some(pos));
206289
}
207290

208-
Ok(Arc::new(Int64Array::from(data)))
291+
Ok(Arc::new(Int64Array::from(result)))
209292
}
210293

211294
#[derive(Debug, Hash, Eq, PartialEq)]

0 commit comments

Comments
 (0)