Skip to content

Commit 6cf8f6a

Browse files
committed
feat: optimize array_position with typed array comparison and address review feedback
- Use typed array downcasting instead of ScalarValue for element comparison, improving performance from 0.4X to 0.7-0.8X of Spark - Add getSupportLevel override marking as Incompatible (NaN equality) - Add NaN edge case tests for float/double arrays - Add CometArrayExpressionBenchmark microbenchmark - Make spark_array_position function private - Update docs to mark array_position as supported
1 parent e4620e9 commit 6cf8f6a

6 files changed

Lines changed: 247 additions & 26 deletions

File tree

docs/spark_expressions_support.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@
9393
- [x] array_join
9494
- [x] array_max
9595
- [ ] array_min
96-
- [ ] array_position
96+
- [x] array_position
9797
- [x] array_remove
9898
- [x] array_repeat
9999
- [x] array_union

native/spark-expr/src/array_funcs/array_position.rs

Lines changed: 109 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,13 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{Array, ArrayRef, GenericListArray, Int64Array, OffsetSizeTrait};
19-
use arrow::datatypes::DataType;
18+
use arrow::array::{
19+
Array, ArrayRef, AsArray, BooleanArray, GenericListArray, Int64Array, OffsetSizeTrait,
20+
};
21+
use arrow::datatypes::{
22+
DataType, Date32Type, Decimal128Type, Float32Type, Float64Type, Int16Type, Int32Type,
23+
Int64Type, Int8Type, TimestampMicrosecondType,
24+
};
2025
use datafusion::common::{exec_err, DataFusionError, Result as DataFusionResult, ScalarValue};
2126
use datafusion::logical_expr::{
2227
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
@@ -26,7 +31,7 @@ use std::sync::Arc;
2631

2732
/// Spark array_position() function that returns the 1-based position of an element in an array.
2833
/// Returns 0 if the element is not found (Spark behavior differs from DataFusion which returns null).
29-
pub fn spark_array_position(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
34+
fn spark_array_position(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
3035
if args.len() != 2 {
3136
return exec_err!("array_position function takes exactly two arguments");
3237
}
@@ -63,6 +68,105 @@ fn array_position_inner(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError>
6368
}
6469
}
6570

71+
/// Find the 1-based position of `search_val` in a typed primitive array.
72+
/// Returns 0 if not found.
73+
macro_rules! find_position_primitive {
74+
($list_items:expr, $element:expr, $row_index:expr, $arrow_type:ty) => {{
75+
let items = $list_items.as_primitive::<$arrow_type>();
76+
let search = $element.as_primitive::<$arrow_type>();
77+
let search_val = search.value($row_index);
78+
let mut pos: i64 = 0;
79+
for i in 0..items.len() {
80+
if !items.is_null(i) && items.value(i) == search_val {
81+
pos = (i + 1) as i64;
82+
break;
83+
}
84+
}
85+
pos
86+
}};
87+
}
88+
89+
fn find_position_in_row(
90+
list_items: &ArrayRef,
91+
element: &ArrayRef,
92+
row_index: usize,
93+
) -> Result<i64, DataFusionError> {
94+
let pos = match list_items.data_type() {
95+
DataType::Boolean => {
96+
let items = list_items.as_any().downcast_ref::<BooleanArray>().unwrap();
97+
let search = element.as_any().downcast_ref::<BooleanArray>().unwrap();
98+
let search_val = search.value(row_index);
99+
let mut pos: i64 = 0;
100+
for i in 0..items.len() {
101+
if !items.is_null(i) && items.value(i) == search_val {
102+
pos = (i + 1) as i64;
103+
break;
104+
}
105+
}
106+
pos
107+
}
108+
DataType::Int8 => find_position_primitive!(list_items, element, row_index, Int8Type),
109+
DataType::Int16 => find_position_primitive!(list_items, element, row_index, Int16Type),
110+
DataType::Int32 => find_position_primitive!(list_items, element, row_index, Int32Type),
111+
DataType::Int64 => find_position_primitive!(list_items, element, row_index, Int64Type),
112+
DataType::Float32 => {
113+
find_position_primitive!(list_items, element, row_index, Float32Type)
114+
}
115+
DataType::Float64 => {
116+
find_position_primitive!(list_items, element, row_index, Float64Type)
117+
}
118+
DataType::Decimal128(_, _) => {
119+
find_position_primitive!(list_items, element, row_index, Decimal128Type)
120+
}
121+
DataType::Date32 => {
122+
find_position_primitive!(list_items, element, row_index, Date32Type)
123+
}
124+
DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, _) => {
125+
find_position_primitive!(list_items, element, row_index, TimestampMicrosecondType)
126+
}
127+
DataType::Utf8 => {
128+
let items = list_items.as_string::<i32>();
129+
let search = element.as_string::<i32>();
130+
let search_val = search.value(row_index);
131+
let mut pos: i64 = 0;
132+
for i in 0..items.len() {
133+
if !items.is_null(i) && items.value(i) == search_val {
134+
pos = (i + 1) as i64;
135+
break;
136+
}
137+
}
138+
pos
139+
}
140+
DataType::LargeUtf8 => {
141+
let items = list_items.as_string::<i64>();
142+
let search = element.as_string::<i64>();
143+
let search_val = search.value(row_index);
144+
let mut pos: i64 = 0;
145+
for i in 0..items.len() {
146+
if !items.is_null(i) && items.value(i) == search_val {
147+
pos = (i + 1) as i64;
148+
break;
149+
}
150+
}
151+
pos
152+
}
153+
// Fallback to ScalarValue for complex types (nested arrays, etc.)
154+
_ => {
155+
let element_scalar = ScalarValue::try_from_array(element, row_index)?;
156+
let mut pos: i64 = 0;
157+
for i in 0..list_items.len() {
158+
let item_scalar = ScalarValue::try_from_array(list_items, i)?;
159+
if !item_scalar.is_null() && element_scalar == item_scalar {
160+
pos = (i + 1) as i64;
161+
break;
162+
}
163+
}
164+
pos
165+
}
166+
};
167+
Ok(pos)
168+
}
169+
66170
fn generic_array_position<O: OffsetSizeTrait>(
67171
array: &ArrayRef,
68172
element: &ArrayRef,
@@ -75,30 +179,11 @@ fn generic_array_position<O: OffsetSizeTrait>(
75179
let mut data = Vec::with_capacity(list_array.len());
76180

77181
for row_index in 0..list_array.len() {
78-
if list_array.is_null(row_index) {
79-
// Null array returns null position (same as Spark)
80-
data.push(None);
81-
} else if element.is_null(row_index) {
82-
// Searching for null element returns null in Spark
182+
if list_array.is_null(row_index) || element.is_null(row_index) {
83183
data.push(None);
84184
} else {
85185
let list_array_row = list_array.value(row_index);
86-
87-
// Get the search element as a scalar
88-
let element_scalar = ScalarValue::try_from_array(element, row_index)?;
89-
90-
// Compare element to each item in the list
91-
let mut position: i64 = 0;
92-
for i in 0..list_array_row.len() {
93-
let list_item_scalar = ScalarValue::try_from_array(&list_array_row, i)?;
94-
95-
// null != anything in Spark array_position
96-
if !list_item_scalar.is_null() && element_scalar == list_item_scalar {
97-
position = (i + 1) as i64; // 1-indexed
98-
break;
99-
}
100-
}
101-
186+
let position = find_position_in_row(&list_array_row, element, row_index)?;
102187
data.push(Some(position));
103188
}
104189
}

native/spark-expr/src/array_funcs/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ mod list_extract;
2222
mod size;
2323

2424
pub use array_insert::ArrayInsert;
25-
pub use array_position::{spark_array_position, SparkArrayPositionFunc};
25+
pub use array_position::SparkArrayPositionFunc;
2626
pub use get_array_struct_fields::GetArrayStructFields;
2727
pub use list_extract::ListExtract;
2828
pub use size::{spark_size, SparkSizeFunc};

spark/src/main/scala/org/apache/comet/serde/arrays.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,8 @@ object CometSize extends CometExpressionSerde[Size] {
664664

665665
object CometArrayPosition extends CometExpressionSerde[ArrayPosition] with ArraysBase {
666666

667+
override def getSupportLevel(expr: ArrayPosition): SupportLevel = Incompatible(None)
668+
667669
override def convert(
668670
expr: ArrayPosition,
669671
inputs: Seq[Attribute],

spark/src/test/resources/sql-tests/expressions/array/array_position.sql

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
-- specific language governing permissions and limitations
1616
-- under the License.
1717

18+
-- Config: spark.comet.expression.ArrayPosition.allowIncompatible=true
1819
-- ConfigMatrix: parquet.enable.dictionary=false,true
1920

2021
statement
@@ -163,6 +164,33 @@ INSERT INTO test_ap_double VALUES
163164
query
164165
SELECT array_position(arr, val) FROM test_ap_double
165166

167+
-- NaN handling for float arrays (Spark treats NaN == NaN)
168+
query spark_answer_only
169+
SELECT array_position(array(cast('NaN' as float), cast(1.0 as float)), cast('NaN' as float))
170+
171+
query spark_answer_only
172+
SELECT array_position(array(cast(1.0 as float), cast('NaN' as float), cast(2.0 as float)), cast('NaN' as float))
173+
174+
-- NaN handling for double arrays (Spark treats NaN == NaN)
175+
query spark_answer_only
176+
SELECT array_position(array(cast('NaN' as double), 1.0), cast('NaN' as double))
177+
178+
query spark_answer_only
179+
SELECT array_position(array(1.0, cast('NaN' as double), 2.0), cast('NaN' as double))
180+
181+
-- NaN handling with column data
182+
statement
183+
CREATE TABLE test_ap_nan(arr array<float>, val float) USING parquet
184+
185+
statement
186+
INSERT INTO test_ap_nan VALUES
187+
(array(cast('NaN' as float), cast(1.0 as float)), cast('NaN' as float)),
188+
(array(cast(1.0 as float), cast('NaN' as float), cast(2.0 as float)), cast('NaN' as float)),
189+
(array(cast(1.0 as float), cast(2.0 as float)), cast('NaN' as float))
190+
191+
query ignore(NaN equality: IEEE 754 says NaN != NaN but Spark treats NaN == NaN)
192+
SELECT array_position(arr, val) FROM test_ap_nan
193+
166194
-- decimal arrays
167195
statement
168196
CREATE TABLE test_ap_decimal(arr array<decimal(10,2)>, val decimal(10,2)) USING parquet
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.spark.sql.benchmark
21+
22+
// spotless:off
23+
/**
24+
* Benchmark to measure performance of Comet array expressions. To run this benchmark:
25+
* {{{
26+
* SPARK_GENERATE_BENCHMARK_FILES=1 make benchmark-org.apache.spark.sql.benchmark.CometArrayExpressionBenchmark
27+
* }}}
28+
* Results will be written to "spark/benchmarks/CometArrayExpressionBenchmark-**results.txt".
29+
*/
30+
// spotless:on
31+
object CometArrayExpressionBenchmark extends CometBenchmarkBase {
32+
33+
def arrayPositionBenchmark(values: Int): Unit = {
34+
withTempPath { dir =>
35+
withTempTable("parquetV1Table") {
36+
// Create a table with int arrays of size 10 and a search value
37+
prepareTable(
38+
dir,
39+
spark.sql(s"""SELECT
40+
| array(
41+
| cast(value % 100 as int),
42+
| cast((value + 1) % 100 as int),
43+
| cast((value + 2) % 100 as int),
44+
| cast((value + 3) % 100 as int),
45+
| cast((value + 4) % 100 as int),
46+
| cast((value + 5) % 100 as int),
47+
| cast((value + 6) % 100 as int),
48+
| cast((value + 7) % 100 as int),
49+
| cast((value + 8) % 100 as int),
50+
| cast((value + 9) % 100 as int)
51+
| ) as int_arr,
52+
| cast((value + 5) % 100 as int) as search_val
53+
|FROM $tbl""".stripMargin))
54+
55+
val extraConfigs =
56+
Map("spark.comet.expression.ArrayPosition.allowIncompatible" -> "true")
57+
58+
runExpressionBenchmark(
59+
"array_position - int array",
60+
values,
61+
"SELECT array_position(int_arr, search_val) FROM parquetV1Table",
62+
extraConfigs)
63+
}
64+
}
65+
66+
withTempPath { dir =>
67+
withTempTable("parquetV1Table") {
68+
// Create a table with string arrays of size 10 and a search value
69+
prepareTable(
70+
dir,
71+
spark.sql(s"""SELECT
72+
| array(
73+
| cast(value % 100 as string),
74+
| cast((value + 1) % 100 as string),
75+
| cast((value + 2) % 100 as string),
76+
| cast((value + 3) % 100 as string),
77+
| cast((value + 4) % 100 as string),
78+
| cast((value + 5) % 100 as string),
79+
| cast((value + 6) % 100 as string),
80+
| cast((value + 7) % 100 as string),
81+
| cast((value + 8) % 100 as string),
82+
| cast((value + 9) % 100 as string)
83+
| ) as str_arr,
84+
| cast((value + 5) % 100 as string) as search_val
85+
|FROM $tbl""".stripMargin))
86+
87+
val extraConfigs =
88+
Map("spark.comet.expression.ArrayPosition.allowIncompatible" -> "true")
89+
90+
runExpressionBenchmark(
91+
"array_position - string array",
92+
values,
93+
"SELECT array_position(str_arr, search_val) FROM parquetV1Table",
94+
extraConfigs)
95+
}
96+
}
97+
}
98+
99+
override def runCometBenchmark(mainArgs: Array[String]): Unit = {
100+
val values = 1024 * 1024
101+
102+
runBenchmarkWithTable("ArrayPosition", values) { v =>
103+
arrayPositionBenchmark(v)
104+
}
105+
}
106+
}

0 commit comments

Comments
 (0)