Skip to content

Commit 17b2714

Browse files
committed
switch to taskid-keyed state for CometUDFs.
1 parent 9f8aa07 commit 17b2714

5 files changed

Lines changed: 148 additions & 40 deletions

File tree

common/src/main/java/org/apache/comet/udf/CometUdfBridge.java

Lines changed: 102 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919

2020
package org.apache.comet.udf;
2121

22-
import java.util.LinkedHashMap;
23-
import java.util.Map;
22+
import java.util.concurrent.ConcurrentHashMap;
2423

2524
import org.apache.arrow.c.ArrowArray;
2625
import org.apache.arrow.c.ArrowSchema;
@@ -30,31 +29,48 @@
3029
import org.apache.arrow.vector.ValueVector;
3130
import org.apache.spark.TaskContext;
3231
import org.apache.spark.comet.CometTaskContextShim;
32+
import org.apache.spark.util.TaskCompletionListener;
3333

3434
/**
3535
* JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
3636
* pattern used by CometScalarSubquery so the native side can dispatch via
3737
* call_static_method_unchecked.
38+
*
39+
* <p>Cache invariants:
40+
*
41+
* <ol>
42+
* <li>For each live Spark task attempt there is at most one {@link CometUDF} instance per class
43+
* name.
44+
* <li>A {@link CometUDF} instance is visible only within the Spark task attempt that instantiated
45+
* it. Two task attempts observing the same class name receive distinct instances.
46+
* <li>At any instant at most one thread is inside {@code evaluate()} for a given {@code
47+
* taskAttemptId}. This follows from Spark executing one native future per partition and Tokio
48+
* polling one future per worker at a time.
49+
* <li>All instances for a task are dropped by the {@link TaskCompletionListener} registered on
50+
* the first cache miss for that task. No cache entry outlives its task.
51+
* <li>When {@code taskContext} is {@code null} (unit tests, direct native driver) the fallback
52+
* key {@code -1L} is used; that bucket is never evicted because no task-completion event will
53+
* fire.
54+
* </ol>
55+
*
56+
* <p>Keying by {@code taskAttemptId} rather than by thread keeps the cache correct under Tokio
57+
* work-stealing: on the scan-free execution path the same Spark task can be polled by different
58+
* Tokio workers across batches, so a thread-local cache would lose per-task state on migration. The
59+
* task attempt ID is stable for the life of the task regardless of which worker is polling.
3860
*/
3961
public class CometUdfBridge {
4062

41-
// Per-thread, bounded LRU of UDF instances keyed by class name. Comet
42-
// native execution threads (Tokio/DataFusion worker pool) are reused
43-
// across tasks within an executor, so the effective lifetime of cached
44-
// entries is the worker thread (i.e. the executor JVM). Fine for
45-
// stateless UDFs; future stateful UDFs would need explicit per-task
46-
// isolation.
47-
private static final int CACHE_CAPACITY = 64;
63+
/**
64+
* Task-scoped cache of {@link CometUDF} instances. Outer map keys are Spark task attempt IDs (or
65+
* {@code -1L} when no {@link TaskContext} is available). Inner maps hold one instance per UDF
66+
* class name for the task's lifetime. Entries are removed by the {@link TaskCompletionListener}
67+
* registered on the first cache miss per task.
68+
*/
69+
private static final ConcurrentHashMap<Long, ConcurrentHashMap<String, CometUDF>> INSTANCES =
70+
new ConcurrentHashMap<>();
4871

49-
private static final ThreadLocal<LinkedHashMap<String, CometUDF>> INSTANCES =
50-
ThreadLocal.withInitial(
51-
() ->
52-
new LinkedHashMap<String, CometUDF>(CACHE_CAPACITY, 0.75f, true) {
53-
@Override
54-
protected boolean removeEldestEntry(Map.Entry<String, CometUDF> eldest) {
55-
return size() > CACHE_CAPACITY;
56-
}
57-
});
72+
/** Sentinel key for calls that carry no {@link TaskContext} (unit tests, direct driver). */
73+
private static final long NO_TASK_ID = -1L;
5874

5975
/**
6076
* Called from native via JNI.
@@ -76,7 +92,9 @@ protected boolean removeEldestEntry(Map.Entry<String, CometUDF> eldest) {
7692
* / {@code Uuid} / {@code MonotonicallyIncreasingID} that read the partition index via {@code
7793
* TaskContext.get().partitionId()}) sees the real context rather than null. The thread-local
7894
* is cleared in a {@code finally} so Tokio workers don't leak a stale TaskContext across
79-
* invocations.
95+
* invocations. The task attempt ID drawn from this context also keys the UDF-instance cache,
96+
* so a UDF holding per-task state in fields sees a consistent instance for every call within
97+
* the task regardless of which Tokio worker is polling.
8098
*/
8199
public static void evaluate(
82100
String udfClassName,
@@ -86,14 +104,31 @@ public static void evaluate(
86104
long outSchemaPtr,
87105
int numRows,
88106
TaskContext taskContext) {
107+
assert udfClassName != null && !udfClassName.isEmpty() : "udfClassName must be non-empty";
108+
assert inputArrayPtrs != null && inputSchemaPtrs != null
109+
: "input pointer arrays must be non-null";
110+
assert inputArrayPtrs.length == inputSchemaPtrs.length
111+
: "input array pointer count must equal schema pointer count";
112+
assert numRows >= 0 : "numRows must be non-negative";
113+
assert outArrayPtr != 0L : "outArrayPtr must be a valid FFI pointer";
114+
assert outSchemaPtr != 0L : "outSchemaPtr must be a valid FFI pointer";
115+
89116
boolean installedTaskContext = false;
90117
if (taskContext != null && TaskContext.get() == null) {
91118
CometTaskContextShim.set(taskContext);
92119
installedTaskContext = true;
120+
assert TaskContext.get() == taskContext
121+
: "TaskContext install did not take effect on this thread";
93122
}
94123
try {
95124
evaluateInternal(
96-
udfClassName, inputArrayPtrs, inputSchemaPtrs, outArrayPtr, outSchemaPtr, numRows);
125+
udfClassName,
126+
inputArrayPtrs,
127+
inputSchemaPtrs,
128+
outArrayPtr,
129+
outSchemaPtr,
130+
numRows,
131+
taskContext);
97132
} finally {
98133
if (installedTaskContext) {
99134
CometTaskContextShim.unset();
@@ -107,24 +142,50 @@ private static void evaluateInternal(
107142
long[] inputSchemaPtrs,
108143
long outArrayPtr,
109144
long outSchemaPtr,
110-
int numRows) {
111-
LinkedHashMap<String, CometUDF> cache = INSTANCES.get();
112-
CometUDF udf = cache.get(udfClassName);
113-
if (udf == null) {
114-
try {
115-
// Resolve via the executor's context classloader so user-supplied UDF jars
116-
// (added via spark.jars / --jars) are visible.
117-
ClassLoader cl = Thread.currentThread().getContextClassLoader();
118-
if (cl == null) {
119-
cl = CometUdfBridge.class.getClassLoader();
120-
}
121-
udf =
122-
(CometUDF) Class.forName(udfClassName, true, cl).getDeclaredConstructor().newInstance();
123-
} catch (ReflectiveOperationException e) {
124-
throw new RuntimeException("Failed to instantiate CometUDF: " + udfClassName, e);
125-
}
126-
cache.put(udfClassName, udf);
127-
}
145+
int numRows,
146+
TaskContext taskContext) {
147+
long taskAttemptId = (taskContext != null) ? taskContext.taskAttemptId() : NO_TASK_ID;
148+
149+
ConcurrentHashMap<String, CometUDF> perTask =
150+
INSTANCES.computeIfAbsent(
151+
taskAttemptId,
152+
id -> {
153+
ConcurrentHashMap<String, CometUDF> fresh = new ConcurrentHashMap<>();
154+
if (taskContext != null) {
155+
// computeIfAbsent runs this lambda at most once per key, so the listener is
156+
// registered exactly once per task attempt.
157+
taskContext.addTaskCompletionListener(
158+
(TaskCompletionListener)
159+
ctx -> {
160+
ConcurrentHashMap<String, CometUDF> removed = INSTANCES.remove(id);
161+
assert removed != null
162+
: "task-completion listener fired but cache already removed "
163+
+ "entry for task "
164+
+ id;
165+
});
166+
}
167+
return fresh;
168+
});
169+
assert perTask != null : "per-task cache must be non-null after computeIfAbsent";
170+
171+
CometUDF udf =
172+
perTask.computeIfAbsent(
173+
udfClassName,
174+
name -> {
175+
try {
176+
// Resolve via the executor's context classloader so user-supplied UDF jars
177+
// (added via spark.jars / --jars) are visible.
178+
ClassLoader cl = Thread.currentThread().getContextClassLoader();
179+
if (cl == null) {
180+
cl = CometUdfBridge.class.getClassLoader();
181+
}
182+
return (CometUDF)
183+
Class.forName(name, true, cl).getDeclaredConstructor().newInstance();
184+
} catch (ReflectiveOperationException e) {
185+
throw new RuntimeException("Failed to instantiate CometUDF: " + name, e);
186+
}
187+
});
188+
assert udf != null : "reflective instantiation returned null for " + udfClassName;
128189

129190
BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator();
130191

@@ -138,6 +199,9 @@ private static void evaluateInternal(
138199
}
139200

140201
result = udf.evaluate(inputs, numRows);
202+
assert result instanceof FieldVector
203+
: "CometUDF implementations must return FieldVector; got "
204+
+ (result == null ? "null" : result.getClass().getName());
141205
if (!(result instanceof FieldVector)) {
142206
throw new RuntimeException(
143207
"CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName());

common/src/main/scala/org/apache/comet/udf/CometUDF.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,14 @@ import org.apache.arrow.vector.ValueVector
3434
* `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF through the
3535
* codegen dispatcher) need `numRows` to know how many rows to produce.
3636
*
37-
* Implementations must have a public no-arg constructor and should be stateless: instances are
38-
* cached per executor thread for the lifetime of the JVM.
37+
* Implementations must have a public no-arg constructor. A fresh instance is created per Spark
38+
* task attempt per class and reused for every call within that task. Instances may hold per-task
39+
* state in fields (counters, compiled patterns, scratch buffers); instances are dropped at task
40+
* completion. Do not hold state that must persist across tasks.
41+
*
42+
* At most one thread calls `evaluate` on a given instance at a time: Spark runs one native future
43+
* per partition and Tokio polls one future per worker, so the per-task instance is never touched
44+
* concurrently even if the task's future migrates between Tokio workers across batches.
3945
*/
4046
trait CometUDF {
4147
def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector

native/core/src/execution/planner.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,13 @@ impl PhysicalPlanner {
751751
to_arrow_datatype(udf.return_type.as_ref().ok_or_else(|| {
752752
GeneralError("JvmScalarUdf missing return_type".to_string())
753753
})?);
754+
// Invariant: task_context is propagated for every JvmScalarUdfExpr built during
755+
// normal execution. The TEST_EXEC_CONTEXT_ID path is the only context in which
756+
// task_context may legitimately be None (unit tests, direct native driver runs).
757+
debug_assert!(
758+
self.task_context.is_some() || self.exec_context_id == TEST_EXEC_CONTEXT_ID,
759+
"task_context must be set for non-test execution"
760+
);
754761
Ok(Arc::new(JvmScalarUdfExpr::new(
755762
udf.class_name.clone(),
756763
args,

native/spark-expr/src/jvm_udf/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ impl JvmScalarUdfExpr {
6262
return_nullable: bool,
6363
task_context: Option<Arc<Global<JObject<'static>>>>,
6464
) -> Self {
65+
debug_assert!(
66+
!class_name.is_empty(),
67+
"JvmScalarUdfExpr requires a non-empty class name"
68+
);
6569
Self {
6670
class_name,
6771
args,
@@ -159,6 +163,13 @@ impl PhysicalExpr for JvmScalarUdfExpr {
159163
.map(|b| b.as_ref() as *const FFI_ArrowSchema as i64)
160164
.collect();
161165

166+
debug_assert!(!self.class_name.is_empty(), "class_name must not be empty");
167+
debug_assert_eq!(
168+
in_arr_ptrs.len(),
169+
in_sch_ptrs.len(),
170+
"input array and schema pointer counts must match"
171+
);
172+
162173
let mut out_array = Box::new(FFI_ArrowArray::empty());
163174
let mut out_schema = Box::new(FFI_ArrowSchema::empty());
164175
let out_arr_ptr = out_array.as_mut() as *mut FFI_ArrowArray as i64;

spark/src/test/scala/org/apache/comet/CometCodegenDispatchSmokeSuite.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,26 @@ class CometCodegenDispatchSmokeSuite extends CometTestBase with AdaptiveSparkPla
439439
}
440440
}
441441

442+
test("per-task cache isolates UDF state across sequential task runs in one session") {
443+
// Regression guard for the cache-scoping invariant on CometUdfBridge: instances live for
444+
// exactly one Spark task and are dropped on task completion, so a stateful kernel sees a
445+
// fresh instance per task. Running the same `monotonically_increasing_id()`-carrying query
446+
// twice in one session must produce identical results each run. Under a cache that outlived
447+
// a task and got reused by the next one, the counter would continue from the previous run's
448+
// final value and the second run's IDs would diverge. Under a cache that was keyed by Tokio
449+
// worker thread rather than task attempt ID, worker reuse across tasks would cause the same
450+
// leak whenever the second task happened to be polled by the same worker.
451+
val rows = (0 until 2048).map(i => s"row_$i")
452+
withSubjects(rows: _*) {
453+
val q = "SELECT s, monotonically_increasing_id() AS mid FROM t"
454+
val first = sql(q).collect().map(r => (r.getString(0), r.getLong(1))).toSeq
455+
val second = sql(q).collect().map(r => (r.getString(0), r.getLong(1))).toSeq
456+
assert(
457+
first == second,
458+
s"per-task cache leaked state across runs: first=${first.take(5)} second=${second.take(5)}")
459+
}
460+
}
461+
442462
/**
443463
* Scalar ScalaUDF smoke tests. These prove that user-registered UDFs route through the codegen
444464
* dispatcher rather than forcing a whole-plan Spark fallback. Spark's `ScalaUDF.doGenCode`

0 commit comments

Comments
 (0)