Skip to content

Commit a4bf90a

Browse files
authored
Fix: array contains null handling (#3372)
1 parent 7b61b30 commit a4bf90a

3 files changed

Lines changed: 90 additions & 2 deletions

File tree

spark/src/main/scala/org/apache/comet/serde/arrays.scala

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ package org.apache.comet.serde
2222
import scala.annotation.tailrec
2323

2424
import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size}
25+
import org.apache.spark.sql.catalyst.util.GenericArrayData
2526
import org.apache.spark.sql.internal.SQLConf
2627
import org.apache.spark.sql.types._
2728

@@ -134,7 +135,34 @@ object CometArrayContains extends CometExpressionSerde[ArrayContains] {
134135

135136
val arrayContainsScalarExpr =
136137
scalarFunctionExprToProto("array_has", arrayExprProto, keyExprProto)
137-
optExprWithInfo(arrayContainsScalarExpr, expr, expr.children: _*)
138+
139+
// Handle NULL array input - return NULL if array is NULL (matching Spark's behavior)
140+
val isNotNullExpr = createUnaryExpr(
141+
expr,
142+
expr.children.head,
143+
inputs,
144+
binding,
145+
(builder, unaryExpr) => builder.setIsNotNull(unaryExpr))
146+
147+
val nullLiteralProto = exprToProto(Literal(null, BooleanType), Seq.empty)
148+
149+
if (arrayContainsScalarExpr.isDefined && isNotNullExpr.isDefined &&
150+
nullLiteralProto.isDefined) {
151+
val caseWhenExpr = ExprOuterClass.CaseWhen
152+
.newBuilder()
153+
.addWhen(isNotNullExpr.get)
154+
.addThen(arrayContainsScalarExpr.get)
155+
.setElseExpr(nullLiteralProto.get)
156+
.build()
157+
Some(
158+
ExprOuterClass.Expr
159+
.newBuilder()
160+
.setCaseWhen(caseWhenExpr)
161+
.build())
162+
} else {
163+
withInfo(expr, expr.children: _*)
164+
None
165+
}
138166
}
139167
}
140168

@@ -395,6 +423,15 @@ object CometCreateArray extends CometExpressionSerde[CreateArray] {
395423
inputs: Seq[Attribute],
396424
binding: Boolean): Option[ExprOuterClass.Expr] = {
397425
val children = expr.children
426+
427+
// Handle empty array: return literal directly to avoid DataFusion coerce_types bug
428+
// when make_array is called with 0 arguments (issue #3338)
429+
if (children.isEmpty) {
430+
val emptyArrayLiteral =
431+
Literal.create(new GenericArrayData(Array.empty[Any]), expr.dataType)
432+
return exprToProtoInternal(emptyArrayLiteral, inputs, binding)
433+
}
434+
398435
val childExprs = children.map(exprToProtoInternal(_, inputs, binding))
399436

400437
if (childExprs.forall(_.isDefined)) {

spark/src/test/resources/sql-tests/expressions/array/array_contains.sql

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,24 @@ query spark_answer_only
3535
SELECT array_contains(array(1, 2, 3), val) FROM test_array_contains
3636

3737
-- literal + literal
38-
query ignore(https://github.com/apache/datafusion-comet/issues/3345)
38+
-- Note: array_contains(array(), 1) still has a bug (issue #3346) so we use spark_answer_only
39+
-- The NULL array case (cast(NULL as array<int>)) was fixed in issue #3345
40+
query spark_answer_only
3941
SELECT array_contains(array(1, 2, 3), 2), array_contains(array(1, 2, 3), 4), array_contains(array(), 1), array_contains(cast(NULL as array<int>), 1)
42+
43+
-- Additional NULL array tests (issue #3345 fix verification)
44+
-- NULL array with integer value
45+
query
46+
SELECT array_contains(cast(NULL as array<int>), 1)
47+
48+
-- NULL array with string value
49+
query
50+
SELECT array_contains(cast(NULL as array<string>), 'test')
51+
52+
-- NULL array with NULL value
53+
query
54+
SELECT array_contains(cast(NULL as array<int>), cast(NULL as int))
55+
56+
-- NULL array with column value
57+
query
58+
SELECT array_contains(cast(NULL as array<int>), val) FROM test_array_contains

spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,38 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
325325
}
326326
}
327327

328+
test("array_contains - NULL array returns NULL") {
329+
// Test that array_contains returns NULL when the array argument is NULL
330+
// This matches Spark's SQL three-valued logic behavior
331+
withTempDir { dir =>
332+
withTempView("t1") {
333+
val path = new Path(dir.toURI.toString, "test.parquet")
334+
makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 100)
335+
spark.read.parquet(path.toString).createOrReplaceTempView("t1")
336+
337+
// Test NULL array with non-null value
338+
checkSparkAnswerAndOperator(
339+
sql("SELECT array_contains(cast(null as array<int>), 1) FROM t1"))
340+
checkSparkAnswerAndOperator(
341+
sql("SELECT array_contains(cast(null as array<string>), 'test') FROM t1"))
342+
checkSparkAnswerAndOperator(
343+
sql("SELECT array_contains(cast(null as array<double>), 1.5) FROM t1"))
344+
345+
// Test NULL array with NULL value
346+
checkSparkAnswerAndOperator(
347+
sql("SELECT array_contains(cast(null as array<int>), cast(null as int)) FROM t1"))
348+
349+
// Test NULL array with column value
350+
checkSparkAnswerAndOperator(
351+
sql("SELECT array_contains(cast(null as array<int>), _2) FROM t1"))
352+
353+
// Test non-null array with values (to ensure fix doesn't break normal operation)
354+
checkSparkAnswerAndOperator(sql("SELECT array_contains(array(1, 2, 3), 2) FROM t1"))
355+
checkSparkAnswerAndOperator(sql("SELECT array_contains(array(1, 2, 3), 5) FROM t1"))
356+
}
357+
}
358+
}
359+
328360
test("array_contains - test all types (convert from Parquet)") {
329361
withTempDir { dir =>
330362
val path = new Path(dir.toURI.toString, "test.parquet")

0 commit comments

Comments
 (0)