Skip to content

Commit bfdc0f6

Browse files
authored
fix: Correct GetArrayItem null handling for dynamic indices and re-enable native execution (#3709)
* Correct GetArrayItem null handling for dynamic indices * query default mode
1 parent f27c4c3 commit bfdc0f6

6 files changed

Lines changed: 69 additions & 41 deletions

File tree

native/spark-expr/src/array_funcs/list_extract.rs

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -277,33 +277,38 @@ fn list_extract<O: OffsetSizeTrait>(
277277

278278
let mut mutable = MutableArrayData::new(vec![&data, &default_data], true, index_array.len());
279279

280-
for (row, (offset_window, index)) in offsets.windows(2).zip(index_array.values()).enumerate() {
280+
for (row, (offset_window, index)) in offsets.windows(2).zip(index_array.iter()).enumerate() {
281281
let start = offset_window[0].as_usize();
282282
let len = offset_window[1].as_usize() - start;
283283

284-
if let Some(i) = adjust_index(*index, len)? {
285-
mutable.extend(0, start + i, start + i + 1);
286-
} else if list_array.is_null(row) {
284+
if list_array.is_null(row) {
287285
mutable.extend_nulls(1);
288-
} else if fail_on_error {
289-
// Throw appropriate error based on whether this is element_at (one_based=true)
290-
// or GetArrayItem (one_based=false)
291-
let error = if one_based {
292-
// element_at function
293-
SparkError::InvalidElementAtIndex {
294-
index_value: *index,
295-
array_size: len as i32,
296-
}
286+
} else if let Some(index) = index {
287+
if let Some(i) = adjust_index(index, len)? {
288+
mutable.extend(0, start + i, start + i + 1);
289+
} else if fail_on_error {
290+
// Throw appropriate error based on whether this is element_at (one_based=true)
291+
// or GetArrayItem (one_based=false)
292+
let error = if one_based {
293+
// element_at function
294+
SparkError::InvalidElementAtIndex {
295+
index_value: index,
296+
array_size: len as i32,
297+
}
298+
} else {
299+
// GetArrayItem (arr[index])
300+
SparkError::InvalidArrayIndex {
301+
index_value: index,
302+
array_size: len as i32,
303+
}
304+
};
305+
return Err(error_wrapper(error));
297306
} else {
298-
// GetArrayItem (arr[index])
299-
SparkError::InvalidArrayIndex {
300-
index_value: *index,
301-
array_size: len as i32,
302-
}
303-
};
304-
return Err(error_wrapper(error));
307+
mutable.extend(1, 0, 1);
308+
}
305309
} else {
306-
mutable.extend(1, 0, 1);
310+
// index is NULL → result is NULL
311+
mutable.extend_nulls(1);
307312
}
308313
}
309314

@@ -382,4 +387,40 @@ mod test {
382387
);
383388
Ok(())
384389
}
390+
391+
#[test]
392+
fn test_list_extract_null_index() -> Result<()> {
393+
// GetArrayItem returns incorrect results with dynamic (column) index containing nulls
394+
let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
395+
Some(vec![Some(10), Some(20), Some(30)]),
396+
Some(vec![Some(10), Some(20), Some(30)]),
397+
Some(vec![Some(10), Some(20), Some(30)]),
398+
Some(vec![Some(1)]),
399+
None,
400+
Some(vec![Some(10), Some(20)]),
401+
]);
402+
let indices = Int32Array::from(vec![Some(0), Some(1), Some(2), Some(0), Some(0), None]);
403+
404+
let null_default = ScalarValue::Int32(None);
405+
let error_wrapper = |error: SparkError| DataFusionError::from(error);
406+
407+
let ColumnarValue::Array(result) = list_extract(
408+
&list,
409+
&indices,
410+
&null_default,
411+
false,
412+
false,
413+
|idx, len| zero_based_index(idx, len, &error_wrapper),
414+
&error_wrapper,
415+
)?
416+
else {
417+
unreachable!()
418+
};
419+
420+
assert_eq!(
421+
&result.to_data(),
422+
&Int32Array::from(vec![Some(10), Some(20), Some(30), Some(1), None, None]).to_data()
423+
);
424+
Ok(())
425+
}
385426
}

spark/src/main/scala/org/apache/comet/serde/arrays.scala

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -486,13 +486,6 @@ object CometCreateArray extends CometExpressionSerde[CreateArray] {
486486

487487
object CometGetArrayItem extends CometExpressionSerde[GetArrayItem] {
488488

489-
override def getSupportLevel(expr: GetArrayItem): SupportLevel =
490-
Incompatible(
491-
Some(
492-
"Known correctness issues with index handling" +
493-
" (https://github.com/apache/datafusion-comet/issues/3330," +
494-
" https://github.com/apache/datafusion-comet/issues/3332)"))
495-
496489
override def convert(
497490
expr: GetArrayItem,
498491
inputs: Seq[Attribute],

spark/src/test/resources/sql-tests/expressions/array/get_array_item.sql

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ CREATE TABLE test_get_array_item(arr array<int>, idx int) USING parquet
2323
statement
2424
INSERT INTO test_get_array_item VALUES (array(10, 20, 30), 0), (array(10, 20, 30), 1), (array(10, 20, 30), 2), (array(1), 0), (NULL, 0), (array(10, 20), NULL)
2525

26-
query spark_answer_only
26+
query
2727
SELECT arr[0], arr[1], arr[2] FROM test_get_array_item
2828

29-
query ignore(https://github.com/apache/datafusion-comet/issues/3332)
29+
query
3030
SELECT arr[idx] FROM test_get_array_item
3131

3232
-- literal arguments
33-
query spark_answer_only
33+
query
3434
SELECT array(10, 20, 30)[0], array(10, 20, 30)[2], array()[0]

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.scalatest.Tag
2828

2929
import org.apache.hadoop.fs.Path
3030
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
31-
import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Ceil, Floor, FromUnixTime, GetArrayItem, Literal, StructsToJson, Tan, TruncDate, TruncTimestamp}
31+
import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Ceil, Floor, FromUnixTime, Literal, StructsToJson, Tan, TruncDate, TruncTimestamp}
3232
import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps
3333
import org.apache.spark.sql.comet.CometProjectExec
3434
import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
@@ -2587,8 +2587,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
25872587
withSQLConf(
25882588
SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString(),
25892589
// Prevent the optimizer from collapsing an extract value of a create array
2590-
SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> SimplifyExtractValueOps.ruleName,
2591-
CometConf.getExprAllowIncompatConfigKey(classOf[GetArrayItem]) -> "true") {
2590+
SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> SimplifyExtractValueOps.ruleName) {
25922591
val df = spark.read.parquet(path.toString)
25932592

25942593
val stringArray = df.select(array(col("_8"), col("_8"), lit(null)).alias("arr"))

spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
3131
import org.apache.comet.CometConf
3232

3333
class CometNativeReaderSuite extends CometTestBase with AdaptiveSparkPlanHelper {
34-
import org.apache.spark.sql.catalyst.expressions.GetArrayItem
3534

3635
override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
3736
pos: Position): Unit = {
@@ -42,8 +41,7 @@ class CometNativeReaderSuite extends CometTestBase with AdaptiveSparkPlanHelper
4241
SQLConf.USE_V1_SOURCE_LIST.key -> "parquet",
4342
CometConf.COMET_ENABLED.key -> "true",
4443
CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "false",
45-
CometConf.COMET_NATIVE_SCAN_IMPL.key -> scan,
46-
CometConf.getExprAllowIncompatConfigKey(classOf[GetArrayItem]) -> "true") {
44+
CometConf.COMET_NATIVE_SCAN_IMPL.key -> scan) {
4745
testFun
4846
}
4947
})

spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ import org.apache.parquet.example.data.simple.SimpleGroup
3535
import org.apache.parquet.schema.MessageTypeParser
3636
import org.apache.spark.SparkException
3737
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
38-
import org.apache.spark.sql.catalyst.expressions.GetArrayItem
3938
import org.apache.spark.sql.catalyst.util.DateTimeUtils
4039
import org.apache.spark.sql.comet.{CometNativeScanExec, CometScanExec}
4140
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -1506,9 +1505,7 @@ class ParquetReadV1Suite extends ParquetReadSuite with AdaptiveSparkPlanHelper {
15061505
withParquetTable(path.toUri.toString, "complex_types") {
15071506
Seq(CometConf.SCAN_NATIVE_DATAFUSION, CometConf.SCAN_NATIVE_ICEBERG_COMPAT).foreach(
15081507
scanMode => {
1509-
withSQLConf(
1510-
CometConf.COMET_NATIVE_SCAN_IMPL.key -> scanMode,
1511-
CometConf.getExprAllowIncompatConfigKey(classOf[GetArrayItem]) -> "true") {
1508+
withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> scanMode) {
15121509
checkSparkAnswerAndOperator(sql("select * from complex_types"))
15131510
// First level
15141511
checkSparkAnswerAndOperator(sql(

0 commit comments

Comments
 (0)