Skip to content

Commit 24318d3

Browse files
committed
fix: handle NaN equality in array_position to match Spark semantics
Treat NaN == NaN in float/double comparisons, matching Spark's ordering.equiv() behavior. This makes array_position Compatible rather than Incompatible.
1 parent fad07b9 commit 24318d3

4 files changed

Lines changed: 27 additions & 22 deletions

File tree

native/spark-expr/src/array_funcs/array_position.rs

Lines changed: 23 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,27 @@ macro_rules! find_position_primitive {
8686
}};
8787
}
8888

89+
/// Float-aware position lookup that treats NaN == NaN (matching Spark's `ordering.equiv()`
/// semantics).
///
/// Expands to an `i64` expression: the 1-based position of the first non-null entry of
/// `$list_items` that equals the search value read from `$element` at `$row_index`, or `0`
/// when no entry matches.
///
/// * `$list_items` - array holding the list entries to scan; viewed as a primitive array of
///   `$arrow_type`.
/// * `$element` - array holding the search values; viewed as a primitive array of `$arrow_type`.
/// * `$row_index` - index into `$element` of the value to search for.
/// * `$arrow_type` - primitive type whose native value provides `is_nan()` and `==`.
///
/// Null list entries never match. The search slot is read without a null check
/// (NOTE(review): presumably the caller filters out null search elements — confirm).
macro_rules! find_position_float {
    ($list_items:expr, $element:expr, $row_index:expr, $arrow_type:ty) => {{
        let items = $list_items.as_primitive::<$arrow_type>();
        let search = $element.as_primitive::<$arrow_type>();
        let search_val = search.value($row_index);
        let search_is_nan = search_val.is_nan();
        (0..items.len())
            .find(|&i| {
                if items.is_null(i) {
                    return false;
                }
                let item_val = items.value(i);
                // IEEE 754 `==` can never match NaN, so a NaN search is answered by `is_nan()`.
                if search_is_nan {
                    item_val.is_nan()
                } else {
                    item_val == search_val
                }
            })
            // Positions are 1-based; 0 is the "not found" result.
            .map_or(0_i64, |i| (i + 1) as i64)
    }};
}
109+
89110
fn find_position_in_row(
90111
list_items: &ArrayRef,
91112
element: &ArrayRef,
@@ -109,12 +130,8 @@ fn find_position_in_row(
109130
DataType::Int16 => find_position_primitive!(list_items, element, row_index, Int16Type),
110131
DataType::Int32 => find_position_primitive!(list_items, element, row_index, Int32Type),
111132
DataType::Int64 => find_position_primitive!(list_items, element, row_index, Int64Type),
112-
DataType::Float32 => {
113-
find_position_primitive!(list_items, element, row_index, Float32Type)
114-
}
115-
DataType::Float64 => {
116-
find_position_primitive!(list_items, element, row_index, Float64Type)
117-
}
133+
DataType::Float32 => find_position_float!(list_items, element, row_index, Float32Type),
134+
DataType::Float64 => find_position_float!(list_items, element, row_index, Float64Type),
118135
DataType::Decimal128(_, _) => {
119136
find_position_primitive!(list_items, element, row_index, Decimal128Type)
120137
}

spark/src/main/scala/org/apache/comet/serde/arrays.scala

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -664,10 +664,7 @@ object CometSize extends CometExpressionSerde[Size] {
664664

665665
object CometArrayPosition extends CometExpressionSerde[ArrayPosition] with ArraysBase {
666666

667-
override def getSupportLevel(expr: ArrayPosition): SupportLevel =
668-
Incompatible(Some(
669-
"element comparison uses IEEE 754 equality where NaN != NaN, " +
670-
"but Spark treats NaN as equal to NaN"))
667+
override def getSupportLevel(expr: ArrayPosition): SupportLevel = Compatible()
671668

672669
override def convert(
673670
expr: ArrayPosition,

spark/src/test/resources/sql-tests/expressions/array/array_position.sql

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515
-- specific language governing permissions and limitations
1616
-- under the License.
1717

18-
-- Config: spark.comet.expression.ArrayPosition.allowIncompatible=true
1918
-- ConfigMatrix: parquet.enable.dictionary=false,true
2019

2120
statement
@@ -188,7 +187,7 @@ INSERT INTO test_ap_nan VALUES
188187
(array(cast(1.0 as float), cast('NaN' as float), cast(2.0 as float)), cast('NaN' as float)),
189188
(array(cast(1.0 as float), cast(2.0 as float)), cast('NaN' as float))
190189

191-
query ignore(NaN equality: IEEE 754 says NaN != NaN but Spark treats NaN == NaN)
190+
query
192191
SELECT array_position(arr, val) FROM test_ap_nan
193192

194193
-- decimal arrays

spark/src/test/scala/org/apache/spark/sql/benchmark/CometArrayExpressionBenchmark.scala

Lines changed: 2 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -52,14 +52,10 @@ object CometArrayExpressionBenchmark extends CometBenchmarkBase {
5252
| cast((value + 5) % 100 as int) as search_val
5353
|FROM $tbl""".stripMargin))
5454

55-
val extraConfigs =
56-
Map("spark.comet.expression.ArrayPosition.allowIncompatible" -> "true")
57-
5855
runExpressionBenchmark(
5956
"array_position - int array",
6057
values,
61-
"SELECT array_position(int_arr, search_val) FROM parquetV1Table",
62-
extraConfigs)
58+
"SELECT array_position(int_arr, search_val) FROM parquetV1Table")
6359
}
6460
}
6561

@@ -84,14 +80,10 @@ object CometArrayExpressionBenchmark extends CometBenchmarkBase {
8480
| cast((value + 5) % 100 as string) as search_val
8581
|FROM $tbl""".stripMargin))
8682

87-
val extraConfigs =
88-
Map("spark.comet.expression.ArrayPosition.allowIncompatible" -> "true")
89-
9083
runExpressionBenchmark(
9184
"array_position - string array",
9285
values,
93-
"SELECT array_position(str_arr, search_val) FROM parquetV1Table",
94-
extraConfigs)
86+
"SELECT array_position(str_arr, search_val) FROM parquetV1Table")
9587
}
9688
}
9789
}

0 commit comments

Comments
 (0)