Skip to content

Commit 1746bcc

Browse files
committed
feat: Arrow-direct codegen dispatcher for Spark expressions and ScalaUDFs
1 parent 47ec2a6 commit 1746bcc

50 files changed

Lines changed: 5708 additions & 152 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/pr_benchmark_check.yml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,5 @@ jobs:
8484
${{ runner.os }}-benchmark-maven-
8585
8686
- name: Check Scala compilation and linting
87-
# Pin to spark-4.0 (Scala 2.13.16) because the default profile is now
88-
# spark-4.1 / Scala 2.13.17, and semanticdb-scalac_2.13.17 is not yet
89-
# published, which breaks `-Psemanticdb`. See pr_build_linux.yml for
90-
# the same exclusion in the main lint matrix.
9187
run: |
92-
./mvnw -B compile test-compile scalafix:scalafix -Dscalafix.mode=CHECK -Psemanticdb -Pspark-4.0 -DskipTests
88+
./mvnw -B compile test-compile scalafix:scalafix -Dscalafix.mode=CHECK -Psemanticdb -DskipTests
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.udf;
21+
22+
import org.apache.arrow.vector.FieldVector;
23+
import org.apache.arrow.vector.ValueVector;
24+
25+
/**
26+
* Abstract base extended by the Janino-compiled batch kernel emitted by {@code
27+
* CometBatchKernelCodegen}. Through this base the generated subclass is also a {@code CometInternalRow} (so Spark's
28+
* {@code BoundReference.genCode} can call {@code this.getUTF8String(ord)} directly) and carries
29+
* typed input fields baked at codegen time, one per input column. Expression evaluation plus Arrow
30+
* read/write fuse into one method per expression tree.
31+
*
32+
* <p>Input scope: any {@code ValueVector[]}; the generated subclass casts each slot to the concrete
33+
* Arrow type the compile-time schema specified. Output is a generic {@code FieldVector}; the
34+
* generated subclass casts to the concrete type matching the bound expression's {@code dataType}.
35+
* Widen input support by adding vector classes to the getter switch in {@code
36+
* CometBatchKernelCodegen.typedInputAccessors}; widen output support by adding cases in {@code
37+
* CometBatchKernelCodegen.allocateOutput} and {@code outputWriter}.
38+
*/
39+
public abstract class CometBatchKernel extends CometInternalRow {
40+
41+
protected final Object[] references;
42+
43+
protected CometBatchKernel(Object[] references) {
44+
this.references = references;
45+
}
46+
47+
/**
48+
* Process one batch.
49+
*
50+
* @param inputs Arrow input vectors; length and concrete classes must match the schema the kernel
51+
* was compiled against
52+
* @param output Arrow output vector; caller allocates to the expression's {@code dataType}
53+
* @param numRows number of rows in this batch
54+
*/
55+
public abstract void process(ValueVector[] inputs, FieldVector output, int numRows);
56+
57+
/**
58+
* Run partition-dependent initialization. The generated subclass overrides this to execute
59+
* statements collected via {@code CodegenContext.addPartitionInitializationStatement}, for
60+
* example reseeding {@code Rand}'s {@code XORShiftRandom} from {@code seed + partitionIndex}.
61+
* Deterministic expressions leave this as a no-op.
62+
*
63+
* <p>The caller must invoke this before the first {@code process} call of each partition. The
64+
* generated subclass is not thread-safe across concurrent {@code process} calls, so kernels are
65+
* allocated per dispatcher invocation and init is run once on the fresh instance.
66+
*/
67+
public void init(int partitionIndex) {}
68+
}

common/src/main/java/org/apache/comet/udf/CometUdfBridge.java

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919

2020
package org.apache.comet.udf;
2121

22-
import java.util.concurrent.ConcurrentHashMap;
22+
import java.util.LinkedHashMap;
23+
import java.util.Map;
2324

2425
import org.apache.arrow.c.ArrowArray;
2526
import org.apache.arrow.c.ArrowSchema;
@@ -35,10 +36,23 @@
3536
*/
3637
public class CometUdfBridge {
3738

38-
// Process-wide cache of UDF instances keyed by class name. CometUDF
39-
// implementations are required to be stateless (see CometUDF), so a
40-
// single shared instance per class is safe across native worker threads.
41-
private static final ConcurrentHashMap<String, CometUDF> INSTANCES = new ConcurrentHashMap<>();
39+
// Per-thread, bounded LRU of UDF instances keyed by class name. Comet
40+
// native execution threads (Tokio/DataFusion worker pool) are reused
41+
// across tasks within an executor, so the effective lifetime of cached
42+
// entries is the worker thread (i.e. the executor JVM). This is fine for
43+
// stateless UDFs like RegExpLikeUDF; future stateful UDFs would need
44+
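// explicit per-task isolation. Entries are keyed by class name only and the
// context classloader is consulted only on a cache miss, so a cached instance
// stays bound to whichever classloader first loaded it on this thread.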
45+
private static final int CACHE_CAPACITY = 64;
46+
47+
private static final ThreadLocal<LinkedHashMap<String, CometUDF>> INSTANCES =
48+
ThreadLocal.withInitial(
49+
() ->
50+
new LinkedHashMap<String, CometUDF>(CACHE_CAPACITY, 0.75f, true) {
51+
@Override
52+
protected boolean removeEldestEntry(Map.Entry<String, CometUDF> eldest) {
53+
return size() > CACHE_CAPACITY;
54+
}
55+
});
4256

4357
/**
4458
* Called from native via JNI.
@@ -48,30 +62,35 @@ public class CometUdfBridge {
4862
* @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input)
4963
* @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result
5064
* @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result
65+
* @param numRows number of rows in the current batch. Mirrors DataFusion's
66+
* {@code ScalarFunctionArgs.number_rows} and gives UDFs an explicit batch-size signal for
67+
* cases where no input arg is a batch-length array (e.g. a zero-arg non-deterministic
68+
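* ScalaUDF). UDFs that already read size from their input vectors can ignore it when sizing
* their work, but the returned vector must still contain exactly {@code numRows} values.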
5169
*/
5270
public static void evaluate(
5371
String udfClassName,
5472
long[] inputArrayPtrs,
5573
long[] inputSchemaPtrs,
5674
long outArrayPtr,
57-
long outSchemaPtr) {
58-
CometUDF udf =
59-
INSTANCES.computeIfAbsent(
60-
udfClassName,
61-
name -> {
62-
try {
63-
// Resolve via the executor's context classloader so user-supplied UDF jars
64-
// (added via spark.jars / --jars) are visible.
65-
ClassLoader cl = Thread.currentThread().getContextClassLoader();
66-
if (cl == null) {
67-
cl = CometUdfBridge.class.getClassLoader();
68-
}
69-
return (CometUDF)
70-
Class.forName(name, true, cl).getDeclaredConstructor().newInstance();
71-
} catch (ReflectiveOperationException e) {
72-
throw new RuntimeException("Failed to instantiate CometUDF: " + name, e);
73-
}
74-
});
75+
long outSchemaPtr,
76+
int numRows) {
77+
LinkedHashMap<String, CometUDF> cache = INSTANCES.get();
78+
CometUDF udf = cache.get(udfClassName);
79+
if (udf == null) {
80+
try {
81+
// Resolve via the executor's context classloader so user-supplied UDF jars
82+
// (added via spark.jars / --jars) are visible.
83+
ClassLoader cl = Thread.currentThread().getContextClassLoader();
84+
if (cl == null) {
85+
cl = CometUdfBridge.class.getClassLoader();
86+
}
87+
udf =
88+
(CometUDF) Class.forName(udfClassName, true, cl).getDeclaredConstructor().newInstance();
89+
} catch (ReflectiveOperationException e) {
90+
throw new RuntimeException("Failed to instantiate CometUDF: " + udfClassName, e);
91+
}
92+
cache.put(udfClassName, udf);
93+
}
7594

7695
BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator();
7796

@@ -84,23 +103,17 @@ public static void evaluate(
84103
inputs[i] = Data.importVector(allocator, inArr, inSch, null);
85104
}
86105

87-
result = udf.evaluate(inputs);
106+
result = udf.evaluate(inputs, numRows);
88107
if (!(result instanceof FieldVector)) {
89108
throw new RuntimeException(
90109
"CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName());
91110
}
92-
// Result length must match the longest input. Scalar (length-1) inputs
93-
// are allowed to be shorter, but a vector input bounds the output.
94-
int expectedLen = 0;
95-
for (ValueVector v : inputs) {
96-
expectedLen = Math.max(expectedLen, v.getValueCount());
97-
}
98-
if (result.getValueCount() != expectedLen) {
111+
if (result.getValueCount() != numRows) {
99112
throw new RuntimeException(
100113
"CometUDF.evaluate() returned "
101114
+ result.getValueCount()
102115
+ " rows, expected "
103-
+ expectedLen);
116+
+ numRows);
104117
}
105118
ArrowArray outArr = ArrowArray.wrap(outArrayPtr);
106119
ArrowSchema outSch = ArrowSchema.wrap(outSchemaPtr);

common/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,46 @@ object CometConf extends ShimCometConf {
380380
.booleanConf
381381
.createWithDefault(false)
382382

383+
val REGEXP_ENGINE_RUST = "rust"
384+
val REGEXP_ENGINE_JAVA = "java"
385+
386+
val COMET_REGEXP_ENGINE: ConfigEntry[String] =
387+
conf("spark.comet.exec.regexp.engine")
388+
.category(CATEGORY_EXEC)
389+
.doc(
390+
"Experimental. Selects the engine used to evaluate supported regular-expression " +
391+
s"expressions. `$REGEXP_ENGINE_RUST` uses the native DataFusion regexp engine. " +
392+
s"`$REGEXP_ENGINE_JAVA` routes through a JVM-side UDF (java.util.regex.Pattern) for " +
393+
"Spark-compatible semantics, at the cost of JNI roundtrips per batch. Expressions " +
394+
"routed when set to java: rlike, regexp_extract, regexp_extract_all, regexp_replace, " +
395+
"regexp_instr, and split.")
396+
.stringConf
397+
.transform(_.toLowerCase(Locale.ROOT))
398+
.checkValues(Set(REGEXP_ENGINE_RUST, REGEXP_ENGINE_JAVA))
399+
.createWithDefault(REGEXP_ENGINE_JAVA)
400+
401+
val CODEGEN_DISPATCH_AUTO = "auto"
402+
val CODEGEN_DISPATCH_DISABLED = "disabled"
403+
val CODEGEN_DISPATCH_FORCE = "force"
404+
405+
val COMET_CODEGEN_DISPATCH_MODE: ConfigEntry[String] =
406+
conf("spark.comet.exec.codegenDispatch.mode")
407+
.category(CATEGORY_EXEC)
408+
.doc("Controls whether Comet routes eligible scalar expressions through the Arrow-direct " +
409+
"codegen dispatcher (`CometCodegenDispatchUDF`) rather than through a native " +
410+
s"DataFusion implementation or a hand-coded JVM UDF. `$CODEGEN_DISPATCH_AUTO` lets " +
411+
"each expression's serde decide its preferred path based on measured evidence " +
412+
"(e.g. for regex, codegen is preferred when " +
413+
s"spark.comet.exec.regexp.engine=$REGEXP_ENGINE_JAVA). " +
414+
s"`$CODEGEN_DISPATCH_DISABLED` never uses codegen dispatch. `$CODEGEN_DISPATCH_FORCE` " +
415+
"inverts the chain: every serde tries codegen first and falls through to its next " +
416+
"preferred path only when `canHandle` rejects the expression. Useful for debugging " +
417+
"and benchmarking.")
418+
.stringConf
419+
.transform(_.toLowerCase(Locale.ROOT))
420+
.checkValues(Set(CODEGEN_DISPATCH_AUTO, CODEGEN_DISPATCH_DISABLED, CODEGEN_DISPATCH_FORCE))
421+
.createWithDefault(CODEGEN_DISPATCH_AUTO)
422+
383423
val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] =
384424
conf("spark.comet.native.shuffle.partitioning.hash.enabled")
385425
.category(CATEGORY_SHUFFLE)

0 commit comments

Comments
 (0)