Skip to content

Commit 9c76e87

Browse files
authored
feat: support stateful CometUDFs (apache#4345)
1 parent 43d5fb9 commit 9c76e87

2 files changed

Lines changed: 97 additions & 16 deletions

File tree

common/src/main/java/org/apache/comet/udf/CometUdfBridge.java

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,48 @@
2929
import org.apache.arrow.vector.ValueVector;
3030
import org.apache.spark.TaskContext;
3131
import org.apache.spark.comet.CometTaskContextShim;
32+
import org.apache.spark.util.TaskCompletionListener;
3233

3334
/**
3435
* JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
3536
* pattern used by CometScalarSubquery so the native side can dispatch via
3637
* call_static_method_unchecked.
38+
*
39+
* <p>Cache invariants:
40+
*
41+
* <ol>
42+
* <li>For each live Spark task attempt there is at most one {@link CometUDF} instance per class
43+
* name.
44+
* <li>A {@link CometUDF} instance is visible only within the Spark task attempt that instantiated
45+
* it. Two task attempts observing the same class name receive distinct instances.
46+
* <li>At any instant at most one thread is inside {@code evaluate()} for a given {@code
47+
* taskAttemptId}. This follows from Spark executing one native future per partition and Tokio
48+
* polling one future per worker at a time.
49+
* <li>All instances for a task are dropped by the {@link TaskCompletionListener} registered on
50+
* the first cache miss for that task. No cache entry outlives its task.
51+
* <li>When {@code taskContext} is {@code null} (unit tests, direct native driver) the fallback
52+
* key {@code -1L} is used; that bucket is never evicted because no task-completion event will
53+
* fire.
54+
* </ol>
55+
*
56+
* <p>Keying by {@code taskAttemptId} rather than by thread keeps the cache correct under Tokio
57+
* work-stealing: on the scan-free execution path the same Spark task can be polled by different
58+
* Tokio workers across batches, so a thread-local cache would lose per-task state on migration. The
59+
* task attempt ID is stable for the life of the task regardless of which worker is polling.
3760
*/
3861
public class CometUdfBridge {
3962

40-
// Process-wide cache of UDF instances keyed by class name. CometUDF
41-
// implementations are required to be stateless (see CometUDF), so a
42-
// single shared instance per class is safe across native worker threads.
43-
private static final ConcurrentHashMap<String, CometUDF> INSTANCES = new ConcurrentHashMap<>();
63+
/**
64+
* Task-scoped cache of {@link CometUDF} instances. Outer map keys are Spark task attempt IDs (or
65+
* {@code -1L} when no {@link TaskContext} is available). Inner maps hold one instance per UDF
66+
* class name for the task's lifetime. Entries are removed by the {@link TaskCompletionListener}
67+
* registered on the first cache miss per task.
68+
*/
69+
private static final ConcurrentHashMap<Long, ConcurrentHashMap<String, CometUDF>> INSTANCES =
70+
new ConcurrentHashMap<>();
71+
72+
/** Sentinel key for calls that carry no {@link TaskContext} (unit tests, direct driver). */
73+
private static final long NO_TASK_ID = -1L;
4474

4575
/**
4676
* Called from native via JNI.
@@ -58,7 +88,9 @@ public class CometUdfBridge {
5888
* thread-local on entry, with the prior value (if any) saved and restored in {@code finally}.
5989
* Lets partition-sensitive built-ins ({@code Rand}, {@code Uuid}, {@code
6090
* MonotonicallyIncreasingID}) work from Tokio workers and avoids reusing a stale TaskContext
61-
* left on a worker by a previous task.
91+
* left on a worker by a previous task. Its task attempt ID also keys the UDF-instance cache,
92+
* so a UDF holding per-task state in fields sees a consistent instance for every call within
93+
* the task regardless of which Tokio worker is polling.
6294
*/
6395
public static void evaluate(
6496
String udfClassName,
@@ -68,16 +100,33 @@ public static void evaluate(
68100
long outSchemaPtr,
69101
int numRows,
70102
TaskContext taskContext) {
71-
// Save-and-restore rather than only-install-if-null: the propagated context is the ground
72-
// truth for this call. Any value already on the thread is either (a) the same object on a
73-
// Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
103+
assert udfClassName != null && !udfClassName.isEmpty() : "udfClassName must be non-empty";
104+
assert inputArrayPtrs != null && inputSchemaPtrs != null
105+
: "input pointer arrays must be non-null";
106+
assert inputArrayPtrs.length == inputSchemaPtrs.length
107+
: "input array pointer count must equal schema pointer count";
108+
assert numRows >= 0 : "numRows must be non-negative";
109+
assert outArrayPtr != 0L : "outArrayPtr must be a valid FFI pointer";
110+
assert outSchemaPtr != 0L : "outSchemaPtr must be a valid FFI pointer";
111+
112+
// Save-and-restore rather than only-install-if-null: the propagated `taskContext` is the
113+
// ground truth for this call. Any value already on the thread is either (a) the same object
114+
// on a Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
74115
TaskContext prior = TaskContext.get();
75116
if (taskContext != null) {
76117
CometTaskContextShim.set(taskContext);
118+
assert TaskContext.get() == taskContext
119+
: "TaskContext install did not take effect on this thread";
77120
}
78121
try {
79122
evaluateInternal(
80-
udfClassName, inputArrayPtrs, inputSchemaPtrs, outArrayPtr, outSchemaPtr, numRows);
123+
udfClassName,
124+
inputArrayPtrs,
125+
inputSchemaPtrs,
126+
outArrayPtr,
127+
outSchemaPtr,
128+
numRows,
129+
taskContext);
81130
} finally {
82131
if (taskContext != null) {
83132
if (prior != null) {
@@ -95,9 +144,34 @@ private static void evaluateInternal(
95144
long[] inputSchemaPtrs,
96145
long outArrayPtr,
97146
long outSchemaPtr,
98-
int numRows) {
99-
CometUDF udf =
147+
int numRows,
148+
TaskContext taskContext) {
149+
long taskAttemptId = (taskContext != null) ? taskContext.taskAttemptId() : NO_TASK_ID;
150+
151+
ConcurrentHashMap<String, CometUDF> perTask =
100152
INSTANCES.computeIfAbsent(
153+
taskAttemptId,
154+
id -> {
155+
ConcurrentHashMap<String, CometUDF> fresh = new ConcurrentHashMap<>();
156+
if (taskContext != null) {
157+
// computeIfAbsent runs this lambda at most once per key, so the listener is
158+
// registered exactly once per task attempt.
159+
taskContext.addTaskCompletionListener(
160+
(TaskCompletionListener)
161+
ctx -> {
162+
ConcurrentHashMap<String, CometUDF> removed = INSTANCES.remove(id);
163+
assert removed != null
164+
: "task-completion listener fired but cache already removed "
165+
+ "entry for task "
166+
+ id;
167+
});
168+
}
169+
return fresh;
170+
});
171+
assert perTask != null : "per-task cache must be non-null after computeIfAbsent";
172+
173+
CometUDF udf =
174+
perTask.computeIfAbsent(
101175
udfClassName,
102176
name -> {
103177
try {
@@ -113,6 +187,7 @@ private static void evaluateInternal(
113187
throw new RuntimeException("Failed to instantiate CometUDF: " + name, e);
114188
}
115189
});
190+
assert udf != null : "reflective instantiation returned null for " + udfClassName;
116191

117192
BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator();
118193

common/src/main/scala/org/apache/comet/udf/CometUDF.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,18 @@ import org.apache.arrow.vector.ValueVector
3030
* - The returned vector's length must match `numRows`.
3131
*
3232
* `numRows` mirrors DataFusion's `ScalarFunctionArgs.number_rows` and is the batch row count.
33-
* UDFs that always have at least one batch-length input can derive length from the inputs and
34-
* ignore `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF)
35-
* need `numRows` to know how many rows to produce.
33+
* UDFs that always have at least one batch-length input can read length from it and ignore
34+
* `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF through the
35+
* codegen dispatcher) need `numRows` to know how many rows to produce.
3636
*
37-
* Implementations must have a public no-arg constructor and must be stateless: a single instance
38-
* per class is cached and shared across native worker threads for the lifetime of the JVM.
37+
* Implementations must have a public no-arg constructor. A fresh instance is created per Spark
38+
* task attempt per class and reused for every call within that task. Instances may hold per-task
39+
* state in fields (counters, compiled patterns, scratch buffers); instances are dropped at task
40+
* completion. Do not hold state that must persist across tasks.
41+
*
42+
* At most one thread calls `evaluate` on a given instance at a time: Spark runs one native future
43+
* per partition and Tokio polls one future per worker, so the per-task instance is never touched
44+
* concurrently even if the task's future migrates between Tokio workers across batches.
3945
*/
4046
trait CometUDF {
4147
def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector

0 commit comments

Comments
 (0)