Skip to content

Commit 06932fa

Browse files
committed
feat: implement array_exists with lambda support via JVM UDF bridge
Adds support for Spark's `exists(array, x -> predicate(x))` higher-order function using the CometUDF framework from #4170. This is the first lambda-based expression accelerated by Comet. Experimental: scope is intentionally narrow (single-argument lambdas referencing only the array element, primitive + string element types).
1 parent ce01339 commit 06932fa

5 files changed

Lines changed: 325 additions & 2 deletions

File tree

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.udf
21+
22+
import java.nio.charset.StandardCharsets
23+
24+
import org.apache.arrow.vector._
25+
import org.apache.arrow.vector.complex.ListVector
26+
import org.apache.spark.sql.catalyst.expressions.{ArrayExists, LambdaFunction, NamedLambdaVariable}
27+
import org.apache.spark.sql.types._
28+
import org.apache.spark.unsafe.types.UTF8String
29+
30+
import org.apache.comet.CometArrowAllocator
31+
32+
/**
33+
* JVM UDF implementing Spark's `exists(array, x -> predicate(x))` higher-order function.
34+
*
35+
* Inputs:
36+
* - inputs(0): ListVector (the array column)
37+
* - inputs(1): VarCharVector length-1 scalar (registry key for the lambda expression)
38+
*
39+
* Output: BitVector (nullable boolean), same length as the input array vector.
40+
*
41+
* Implements Spark's three-valued logic:
42+
* - true if any element satisfies the predicate
43+
* - null if no element satisfies but the predicate returned null for at least one element
44+
* - false if all elements produce false
45+
*/
46+
class ArrayExistsUDF extends CometUDF {
47+
48+
override def evaluate(inputs: Array[ValueVector]): ValueVector = {
49+
require(inputs.length == 2, s"ArrayExistsUDF expects 2 inputs, got ${inputs.length}")
50+
val listVec = inputs(0).asInstanceOf[ListVector]
51+
val keyVec = inputs(1).asInstanceOf[VarCharVector]
52+
require(
53+
keyVec.getValueCount >= 1 && !keyVec.isNull(0),
54+
"ArrayExistsUDF requires a non-null scalar registry key")
55+
56+
val registryKey = new String(keyVec.get(0), StandardCharsets.UTF_8)
57+
val arrayExistsExpr = CometLambdaRegistry.get(registryKey).asInstanceOf[ArrayExists]
58+
59+
val LambdaFunction(body, Seq(elementVar: NamedLambdaVariable), _) = arrayExistsExpr.function
60+
val followThreeValuedLogic = arrayExistsExpr.followThreeValuedLogic
61+
val elementType = elementVar.dataType
62+
63+
val dataVec = listVec.getDataVector
64+
val n = listVec.getValueCount
65+
val out = new BitVector("exists_result", CometArrowAllocator)
66+
out.allocateNew(n)
67+
68+
var i = 0
69+
while (i < n) {
70+
if (listVec.isNull(i)) {
71+
out.setNull(i)
72+
} else {
73+
val startIdx = listVec.getElementStartIndex(i)
74+
val endIdx = listVec.getElementEndIndex(i)
75+
var exists = false
76+
var foundNull = false
77+
var j = startIdx
78+
while (j < endIdx && !exists) {
79+
if (dataVec.isNull(j)) {
80+
elementVar.value.set(null)
81+
val ret = body.eval(null)
82+
if (ret == null) foundNull = true
83+
else if (ret.asInstanceOf[Boolean]) exists = true
84+
} else {
85+
val elem = getSparkValue(dataVec, j, elementType)
86+
elementVar.value.set(elem)
87+
val ret = body.eval(null)
88+
if (ret == null) foundNull = true
89+
else if (ret.asInstanceOf[Boolean]) exists = true
90+
}
91+
j += 1
92+
}
93+
if (exists) {
94+
out.set(i, 1)
95+
} else if (followThreeValuedLogic && foundNull) {
96+
out.setNull(i)
97+
} else {
98+
out.set(i, 0)
99+
}
100+
}
101+
i += 1
102+
}
103+
out.setValueCount(n)
104+
out
105+
}
106+
107+
private def getSparkValue(vec: ValueVector, index: Int, sparkType: DataType): Any = {
108+
sparkType match {
109+
case BooleanType =>
110+
vec.asInstanceOf[BitVector].get(index) == 1
111+
case ByteType =>
112+
vec.asInstanceOf[TinyIntVector].get(index).toByte
113+
case ShortType =>
114+
vec.asInstanceOf[SmallIntVector].get(index).toShort
115+
case IntegerType =>
116+
vec.asInstanceOf[IntVector].get(index)
117+
case LongType =>
118+
vec.asInstanceOf[BigIntVector].get(index)
119+
case FloatType =>
120+
vec.asInstanceOf[Float4Vector].get(index)
121+
case DoubleType =>
122+
vec.asInstanceOf[Float8Vector].get(index)
123+
case StringType =>
124+
val bytes = vec.asInstanceOf[VarCharVector].get(index)
125+
UTF8String.fromBytes(bytes)
126+
case BinaryType =>
127+
vec.asInstanceOf[VarBinaryVector].get(index)
128+
case _: DecimalType =>
129+
val dt = sparkType.asInstanceOf[DecimalType]
130+
val decimal = vec.asInstanceOf[DecimalVector].getObject(index)
131+
Decimal(decimal, dt.precision, dt.scale)
132+
case DateType =>
133+
vec.asInstanceOf[DateDayVector].get(index)
134+
case TimestampType =>
135+
vec.asInstanceOf[TimeStampMicroTZVector].get(index)
136+
case _ =>
137+
throw new UnsupportedOperationException(
138+
s"ArrayExistsUDF does not yet support element type: $sparkType")
139+
}
140+
}
141+
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.udf
21+
22+
import java.util.UUID
23+
import java.util.concurrent.ConcurrentHashMap
24+
25+
import org.apache.spark.sql.catalyst.expressions.Expression
26+
27+
/**
28+
* Thread-safe registry bridging plan-time Spark expressions to execution-time UDF lookup. At plan
29+
* time the serde layer registers a lambda expression under a unique key; at execution time the
30+
* UDF retrieves it by that key (passed as a scalar argument).
31+
*/
32+
object CometLambdaRegistry {
33+
34+
private val registry = new ConcurrentHashMap[String, Expression]()
35+
36+
def register(expression: Expression): String = {
37+
val key = UUID.randomUUID().toString
38+
registry.put(key, expression)
39+
key
40+
}
41+
42+
def get(key: String): Expression = {
43+
val expr = registry.get(key)
44+
if (expr == null) {
45+
throw new IllegalStateException(
46+
s"Lambda expression not found in registry for key: $key. " +
47+
"This indicates a lifecycle issue between plan creation and execution.")
48+
}
49+
expr
50+
}
51+
52+
def remove(key: String): Unit = {
53+
registry.remove(key)
54+
}
55+
56+
// Visible for testing
57+
def size(): Int = registry.size()
58+
}

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
6969
classOf[Flatten] -> CometFlatten,
7070
classOf[GetArrayItem] -> CometGetArrayItem,
7171
classOf[Size] -> CometSize,
72-
classOf[ArraysZip] -> CometArraysZip)
72+
classOf[ArraysZip] -> CometArraysZip,
73+
classOf[ArrayExists] -> CometArrayExists)
7374

7475
private val conditionalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] =
7576
Map(classOf[CaseWhen] -> CometCaseWhen, classOf[If] -> CometIf)

spark/src/main/scala/org/apache/comet/serde/arrays.scala

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ package org.apache.comet.serde
2222
import scala.annotation.tailrec
2323
import scala.jdk.CollectionConverters._
2424

25-
import org.apache.spark.sql.catalyst.expressions.{And, ArrayAppend, ArrayContains, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraysOverlap, ArraysZip, ArrayUnion, Attribute, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, SortArray}
25+
import org.apache.spark.sql.catalyst.expressions.{And, ArrayAppend, ArrayContains, ArrayExcept, ArrayExists, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraysOverlap, ArraysZip, ArrayUnion, Attribute, AttributeReference, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, LambdaFunction, Literal, NamedLambdaVariable, Reverse, Size, SortArray}
2626
import org.apache.spark.sql.catalyst.util.GenericArrayData
2727
import org.apache.spark.sql.internal.SQLConf
2828
import org.apache.spark.sql.types._
@@ -31,6 +31,7 @@ import org.apache.comet.CometConf
3131
import org.apache.comet.CometSparkSessionExtensions.withInfo
3232
import org.apache.comet.serde.QueryPlanSerde._
3333
import org.apache.comet.shims.{CometExprShim, CometTypeShim}
34+
import org.apache.comet.udf.CometLambdaRegistry
3435

3536
object CometArrayRemove
3637
extends CometExpressionSerde[ArrayRemove]
@@ -812,3 +813,84 @@ trait ArraysBase {
812813
}
813814
}
814815
}
816+
817+
object CometArrayExists extends CometExpressionSerde[ArrayExists] {
818+
819+
private val supportedElementTypes: Set[Class[_]] = Set(
820+
classOf[BooleanType],
821+
classOf[ByteType],
822+
classOf[ShortType],
823+
classOf[IntegerType],
824+
classOf[LongType],
825+
classOf[FloatType],
826+
classOf[DoubleType],
827+
classOf[DecimalType],
828+
classOf[DateType],
829+
classOf[TimestampType],
830+
classOf[StringType])
831+
832+
private def isElementTypeSupported(dt: DataType): Boolean = dt match {
833+
case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
834+
_: DecimalType | DateType | TimestampType | StringType =>
835+
true
836+
case _ => false
837+
}
838+
839+
override def getSupportLevel(expr: ArrayExists): SupportLevel = {
840+
val ArrayType(elementType, _) = expr.argument.dataType
841+
if (!isElementTypeSupported(elementType)) {
842+
return Unsupported(Some(s"Unsupported array element type: $elementType"))
843+
}
844+
// Only support lambdas that reference the lambda variable alone (no captured columns)
845+
expr.function match {
846+
case LambdaFunction(body, Seq(_: NamedLambdaVariable), _) =>
847+
val capturedRefs = body.collect { case a: AttributeReference => a }
848+
if (capturedRefs.nonEmpty) {
849+
Unsupported(Some("Lambda references columns outside the array element"))
850+
} else {
851+
Compatible()
852+
}
853+
case _ =>
854+
Unsupported(Some("Only single-argument lambda functions are supported"))
855+
}
856+
}
857+
858+
override def convert(
859+
expr: ArrayExists,
860+
inputs: Seq[Attribute],
861+
binding: Boolean): Option[ExprOuterClass.Expr] = {
862+
val arrayProto = exprToProtoInternal(expr.argument, inputs, binding)
863+
if (arrayProto.isEmpty) {
864+
withInfo(expr, "Failed to serialize array argument")
865+
return None
866+
}
867+
868+
val registryKey = CometLambdaRegistry.register(expr)
869+
val keyLiteral = Literal(registryKey)
870+
val keyProto = exprToProtoInternal(keyLiteral, inputs, binding)
871+
if (keyProto.isEmpty) {
872+
CometLambdaRegistry.remove(registryKey)
873+
withInfo(expr, "Failed to serialize registry key literal")
874+
return None
875+
}
876+
877+
val returnType = serializeDataType(BooleanType).getOrElse {
878+
CometLambdaRegistry.remove(registryKey)
879+
return None
880+
}
881+
882+
val udfBuilder = ExprOuterClass.JvmScalarUdf
883+
.newBuilder()
884+
.setClassName("org.apache.comet.udf.ArrayExistsUDF")
885+
.addArgs(arrayProto.get)
886+
.addArgs(keyProto.get)
887+
.setReturnType(returnType)
888+
.setReturnNullable(expr.nullable)
889+
890+
Some(
891+
ExprOuterClass.Expr
892+
.newBuilder()
893+
.setJvmScalarUdf(udfBuilder.build())
894+
.build())
895+
}
896+
}

spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,4 +1085,45 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
10851085
}
10861086
}
10871087
}
1088+
1089+
test("array_exists - integer predicate") {
1090+
withTable("t") {
1091+
sql("CREATE TABLE t (arr ARRAY<INT>) USING parquet")
1092+
sql("INSERT INTO t VALUES (array(1, 2, 3)), (array(4, 5, 6)), (array(-1, -2)), (NULL)")
1093+
checkSparkAnswerAndOperator(sql("SELECT exists(arr, x -> x > 2) FROM t"))
1094+
}
1095+
}
1096+
1097+
test("array_exists - string predicate") {
1098+
withTable("t") {
1099+
sql("CREATE TABLE t (arr ARRAY<STRING>) USING parquet")
1100+
sql(
1101+
"INSERT INTO t VALUES (array('hello', 'world')), (array('foo')), (array(NULL, 'bar')), (NULL)")
1102+
checkSparkAnswerAndOperator(sql("SELECT exists(arr, x -> x = 'world') FROM t"))
1103+
}
1104+
}
1105+
1106+
test("array_exists - null elements with three-valued logic") {
1107+
withTable("t") {
1108+
sql("CREATE TABLE t (arr ARRAY<INT>) USING parquet")
1109+
sql("INSERT INTO t VALUES (array(1, NULL, 3)), (array(NULL, NULL)), (array(4, 5))")
1110+
checkSparkAnswerAndOperator(sql("SELECT exists(arr, x -> x > 10) FROM t"))
1111+
}
1112+
}
1113+
1114+
test("array_exists - all elements match") {
1115+
withTable("t") {
1116+
sql("CREATE TABLE t (arr ARRAY<INT>) USING parquet")
1117+
sql("INSERT INTO t VALUES (array(10, 20, 30)), (array(1, 2, 3))")
1118+
checkSparkAnswerAndOperator(sql("SELECT exists(arr, x -> x > 0) FROM t"))
1119+
}
1120+
}
1121+
1122+
test("array_exists - empty array") {
1123+
withTable("t") {
1124+
sql("CREATE TABLE t (arr ARRAY<INT>) USING parquet")
1125+
sql("INSERT INTO t VALUES (array()), (array(1))")
1126+
checkSparkAnswerAndOperator(sql("SELECT exists(arr, x -> x > 0) FROM t"))
1127+
}
1128+
}
10881129
}

0 commit comments

Comments
 (0)