Skip to content

Commit f1ece6c

Browse files
committed
feat: add CometUDF framework and array_exists lambda support
Adds a new JVM UDF bridge framework that allows Spark expressions to be evaluated on the JVM side via Arrow C Data Interface, while keeping the native execution pipeline intact. Includes array_exists as the first lambda-based expression using this framework.
1 parent 3fc301a commit f1ece6c

16 files changed

Lines changed: 998 additions & 123 deletions

File tree

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.udf;
21+
22+
import java.util.LinkedHashMap;
23+
import java.util.Map;
24+
25+
import org.apache.arrow.c.ArrowArray;
26+
import org.apache.arrow.c.ArrowSchema;
27+
import org.apache.arrow.c.Data;
28+
import org.apache.arrow.memory.BufferAllocator;
29+
import org.apache.arrow.vector.FieldVector;
30+
import org.apache.arrow.vector.ValueVector;
31+
32+
/**
33+
* JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
34+
* pattern used by CometScalarSubquery so the native side can dispatch via
35+
* call_static_method_unchecked.
36+
*/
37+
public class CometUdfBridge {
38+
39+
// Per-thread, bounded LRU of UDF instances keyed by class name. Comet
40+
// native execution threads (Tokio/DataFusion worker pool) are reused
41+
// across tasks within an executor, so the effective lifetime of cached
42+
// entries is the worker thread (i.e. the executor JVM). This is fine for
43+
// stateless UDFs like ArrayExistsUDF; future stateful UDFs would need
44+
// explicit per-task isolation.
45+
private static final int CACHE_CAPACITY = 64;
46+
47+
private static final ThreadLocal<LinkedHashMap<String, CometUDF>> INSTANCES =
48+
ThreadLocal.withInitial(
49+
() ->
50+
new LinkedHashMap<String, CometUDF>(CACHE_CAPACITY, 0.75f, true) {
51+
@Override
52+
protected boolean removeEldestEntry(Map.Entry<String, CometUDF> eldest) {
53+
return size() > CACHE_CAPACITY;
54+
}
55+
});
56+
57+
/**
58+
* Called from native via JNI.
59+
*
60+
* @param udfClassName fully-qualified class name implementing CometUDF
61+
* @param inputArrayPtrs addresses of pre-allocated FFI_ArrowArray structs (one per input)
62+
* @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input)
63+
* @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result
64+
* @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result
65+
*/
66+
public static void evaluate(
67+
String udfClassName,
68+
long[] inputArrayPtrs,
69+
long[] inputSchemaPtrs,
70+
long outArrayPtr,
71+
long outSchemaPtr) {
72+
LinkedHashMap<String, CometUDF> cache = INSTANCES.get();
73+
CometUDF udf = cache.get(udfClassName);
74+
if (udf == null) {
75+
try {
76+
// Resolve via the executor's context classloader so user-supplied UDF jars
77+
// (added via spark.jars / --jars) are visible.
78+
ClassLoader cl = Thread.currentThread().getContextClassLoader();
79+
if (cl == null) {
80+
cl = CometUdfBridge.class.getClassLoader();
81+
}
82+
udf =
83+
(CometUDF) Class.forName(udfClassName, true, cl).getDeclaredConstructor().newInstance();
84+
} catch (ReflectiveOperationException e) {
85+
throw new RuntimeException("Failed to instantiate CometUDF: " + udfClassName, e);
86+
}
87+
cache.put(udfClassName, udf);
88+
}
89+
90+
BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator();
91+
92+
ValueVector[] inputs = new ValueVector[inputArrayPtrs.length];
93+
ValueVector result = null;
94+
try {
95+
for (int i = 0; i < inputArrayPtrs.length; i++) {
96+
ArrowArray inArr = ArrowArray.wrap(inputArrayPtrs[i]);
97+
ArrowSchema inSch = ArrowSchema.wrap(inputSchemaPtrs[i]);
98+
inputs[i] = Data.importVector(allocator, inArr, inSch, null);
99+
}
100+
101+
result = udf.evaluate(inputs);
102+
if (!(result instanceof FieldVector)) {
103+
throw new RuntimeException(
104+
"CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName());
105+
}
106+
// Result length must match the longest input. Scalar (length-1) inputs
107+
// are allowed to be shorter, but a vector input bounds the output.
108+
int expectedLen = 0;
109+
for (ValueVector v : inputs) {
110+
expectedLen = Math.max(expectedLen, v.getValueCount());
111+
}
112+
if (result.getValueCount() != expectedLen) {
113+
throw new RuntimeException(
114+
"CometUDF.evaluate() returned "
115+
+ result.getValueCount()
116+
+ " rows, expected "
117+
+ expectedLen);
118+
}
119+
ArrowArray outArr = ArrowArray.wrap(outArrayPtr);
120+
ArrowSchema outSch = ArrowSchema.wrap(outSchemaPtr);
121+
Data.exportVector(allocator, (FieldVector) result, null, outArr, outSch);
122+
} finally {
123+
for (ValueVector v : inputs) {
124+
if (v != null) {
125+
try {
126+
v.close();
127+
} catch (RuntimeException ignored) {
128+
// do not mask the original throwable
129+
}
130+
}
131+
}
132+
if (result != null) {
133+
try {
134+
result.close();
135+
} catch (RuntimeException ignored) {
136+
// do not mask the original throwable
137+
}
138+
}
139+
}
140+
}
141+
}
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.udf
21+
22+
import java.nio.charset.StandardCharsets
23+
24+
import org.apache.arrow.vector._
25+
import org.apache.arrow.vector.complex.ListVector
26+
import org.apache.spark.sql.catalyst.expressions.{ArrayExists, LambdaFunction, NamedLambdaVariable}
27+
import org.apache.spark.sql.types._
28+
import org.apache.spark.unsafe.types.UTF8String
29+
30+
import org.apache.comet.CometArrowAllocator
31+
32+
/**
33+
* JVM UDF implementing Spark's `exists(array, x -> predicate(x))` higher-order function.
34+
*
35+
* Inputs:
36+
* - inputs(0): ListVector (the array column)
37+
* - inputs(1): VarCharVector length-1 scalar (registry key for the lambda expression)
38+
*
39+
* Output: BitVector (nullable boolean), same length as the input array vector.
40+
*
41+
* Implements Spark's three-valued logic:
42+
* - true if any element satisfies the predicate
43+
* - null if no element satisfies but the predicate returned null for at least one element
44+
* - false if all elements produce false
45+
*/
46+
class ArrayExistsUDF extends CometUDF {
47+
48+
override def evaluate(inputs: Array[ValueVector]): ValueVector = {
49+
require(inputs.length == 2, s"ArrayExistsUDF expects 2 inputs, got ${inputs.length}")
50+
val listVec = inputs(0).asInstanceOf[ListVector]
51+
val keyVec = inputs(1).asInstanceOf[VarCharVector]
52+
require(
53+
keyVec.getValueCount >= 1 && !keyVec.isNull(0),
54+
"ArrayExistsUDF requires a non-null scalar registry key")
55+
56+
val registryKey = new String(keyVec.get(0), StandardCharsets.UTF_8)
57+
val arrayExistsExpr = CometLambdaRegistry.get(registryKey).asInstanceOf[ArrayExists]
58+
59+
val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = arrayExistsExpr.function
60+
val body = arrayExistsExpr.functionForEval
61+
val followThreeValuedLogic = arrayExistsExpr.followThreeValuedLogic
62+
val elementType = elementVar.dataType
63+
64+
val dataVec = listVec.getDataVector
65+
val n = listVec.getValueCount
66+
val out = new BitVector("exists_result", CometArrowAllocator)
67+
out.allocateNew(n)
68+
69+
var i = 0
70+
while (i < n) {
71+
if (listVec.isNull(i)) {
72+
out.setNull(i)
73+
} else {
74+
val startIdx = listVec.getElementStartIndex(i)
75+
val endIdx = listVec.getElementEndIndex(i)
76+
var exists = false
77+
var foundNull = false
78+
var j = startIdx
79+
while (j < endIdx && !exists) {
80+
if (dataVec.isNull(j)) {
81+
elementVar.value.set(null)
82+
val ret = body.eval(null)
83+
if (ret == null) foundNull = true
84+
else if (ret.asInstanceOf[Boolean]) exists = true
85+
} else {
86+
val elem = getSparkValue(dataVec, j, elementType)
87+
elementVar.value.set(elem)
88+
val ret = body.eval(null)
89+
if (ret == null) foundNull = true
90+
else if (ret.asInstanceOf[Boolean]) exists = true
91+
}
92+
j += 1
93+
}
94+
if (exists) {
95+
out.set(i, 1)
96+
} else if (followThreeValuedLogic && foundNull) {
97+
out.setNull(i)
98+
} else {
99+
out.set(i, 0)
100+
}
101+
}
102+
i += 1
103+
}
104+
out.setValueCount(n)
105+
out
106+
}
107+
108+
private def getSparkValue(vec: ValueVector, index: Int, sparkType: DataType): Any = {
109+
sparkType match {
110+
case BooleanType =>
111+
vec.asInstanceOf[BitVector].get(index) == 1
112+
case ByteType =>
113+
vec.asInstanceOf[TinyIntVector].get(index).toByte
114+
case ShortType =>
115+
vec.asInstanceOf[SmallIntVector].get(index).toShort
116+
case IntegerType =>
117+
vec.asInstanceOf[IntVector].get(index)
118+
case LongType =>
119+
vec.asInstanceOf[BigIntVector].get(index)
120+
case FloatType =>
121+
vec.asInstanceOf[Float4Vector].get(index)
122+
case DoubleType =>
123+
vec.asInstanceOf[Float8Vector].get(index)
124+
case StringType =>
125+
val bytes = vec.asInstanceOf[VarCharVector].get(index)
126+
UTF8String.fromBytes(bytes)
127+
case BinaryType =>
128+
vec.asInstanceOf[VarBinaryVector].get(index)
129+
case _: DecimalType =>
130+
val dt = sparkType.asInstanceOf[DecimalType]
131+
val decimal = vec.asInstanceOf[DecimalVector].getObject(index)
132+
Decimal(decimal, dt.precision, dt.scale)
133+
case DateType =>
134+
vec.asInstanceOf[DateDayVector].get(index)
135+
case TimestampType =>
136+
vec.asInstanceOf[TimeStampMicroTZVector].get(index)
137+
case _ =>
138+
throw new UnsupportedOperationException(
139+
s"ArrayExistsUDF does not yet support element type: $sparkType")
140+
}
141+
}
142+
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.udf
21+
22+
import java.util.UUID
23+
import java.util.concurrent.ConcurrentHashMap
24+
25+
import org.apache.spark.sql.catalyst.expressions.Expression
26+
27+
/**
28+
* Thread-safe registry bridging plan-time Spark expressions to execution-time UDF lookup. At plan
29+
* time the serde layer registers a lambda expression under a unique key; at execution time the
30+
* UDF retrieves it by that key (passed as a scalar argument).
31+
*/
32+
object CometLambdaRegistry {
33+
34+
private val registry = new ConcurrentHashMap[String, Expression]()
35+
36+
def register(expression: Expression): String = {
37+
val key = UUID.randomUUID().toString
38+
registry.put(key, expression)
39+
key
40+
}
41+
42+
def get(key: String): Expression = {
43+
val expr = registry.get(key)
44+
if (expr == null) {
45+
throw new IllegalStateException(
46+
s"Lambda expression not found in registry for key: $key. " +
47+
"This indicates a lifecycle issue between plan creation and execution.")
48+
}
49+
expr
50+
}
51+
52+
def remove(key: String): Unit = {
53+
registry.remove(key)
54+
}
55+
56+
// Visible for testing
57+
def size(): Int = registry.size()
58+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.udf
21+
22+
import org.apache.arrow.vector.ValueVector
23+
24+
/**
25+
* Scalar UDF invoked from native execution via JNI. Receives Arrow vectors as input and returns
26+
* an Arrow vector.
27+
*
28+
* - Vector arguments arrive at the row count of the current batch.
29+
* - Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0.
30+
* - The returned vector's length must match the longest input.
31+
*
32+
* Implementations must have a public no-arg constructor and should be stateless: instances are
33+
* cached per executor thread for the lifetime of the JVM.
34+
*/
35+
trait CometUDF {
36+
def evaluate(inputs: Array[ValueVector]): ValueVector
37+
}

0 commit comments

Comments
 (0)