2929import org .apache .arrow .vector .ValueVector ;
3030import org .apache .spark .TaskContext ;
3131import org .apache .spark .comet .CometTaskContextShim ;
32+ import org .apache .spark .util .TaskCompletionListener ;
3233
3334/**
3435 * JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
3536 * pattern used by CometScalarSubquery so the native side can dispatch via
3637 * call_static_method_unchecked.
38+ *
39+ * <p>Cache invariants:
40+ *
41+ * <ol>
42+ * <li>For each live Spark task attempt there is at most one {@link CometUDF} instance per class
43+ * name.
44+ * <li>A {@link CometUDF} instance is visible only within the Spark task attempt that instantiated
45+ * it. Two task attempts observing the same class name receive distinct instances.
46+ * <li>At any instant at most one thread is inside {@code evaluate()} for a given {@code
47+ * taskAttemptId}. This follows from Spark executing one native future per partition and Tokio
48+ * polling one future per worker at a time.
49+ * <li>All instances for a task are dropped by the {@link TaskCompletionListener} registered on
50+ * the first cache miss for that task. No task-keyed cache entry outlives its task.
51+ * <li>When {@code taskContext} is {@code null} (unit tests, direct native driver) the fallback
52+ * key {@code -1L} is used; that bucket is never evicted because no task-completion event will
53+ * fire.
54+ * </ol>
55+ *
56+ * <p>Keying by {@code taskAttemptId} rather than by thread keeps the cache correct under Tokio
57+ * work-stealing: on the scan-free execution path the same Spark task can be polled by different
58+ * Tokio workers across batches, so a thread-local cache would lose per-task state on migration. The
59+ * task attempt ID is stable for the life of the task regardless of which worker is polling.
3760 */
3861public class CometUdfBridge {
3962
40- // Process-wide cache of UDF instances keyed by class name. CometUDF
41- // implementations are required to be stateless (see CometUDF), so a
42- // single shared instance per class is safe across native worker threads.
43- private static final ConcurrentHashMap <String , CometUDF > INSTANCES = new ConcurrentHashMap <>();
63+ /**
64+ * Task-scoped cache of {@link CometUDF} instances. Outer map keys are Spark task attempt IDs (or
65+ * {@code -1L} when no {@link TaskContext} is available). Inner maps hold one instance per UDF
66+ * class name for the task's lifetime. Entries are removed by the {@link TaskCompletionListener}
67+ * registered on the first cache miss per task; the {@code -1L} entry is never removed.
68+ */
69+ private static final ConcurrentHashMap <Long , ConcurrentHashMap <String , CometUDF >> INSTANCES =
70+ new ConcurrentHashMap <>();
71+
72+ /** Sentinel key for calls that carry no {@link TaskContext} (unit tests, direct driver). */
73+ private static final long NO_TASK_ID = -1L ;
4474
4575 /**
4676 * Called from native via JNI.
@@ -58,7 +88,9 @@ public class CometUdfBridge {
5888 * thread-local on entry, with the prior value (if any) saved and restored in {@code finally}.
5989 * Lets partition-sensitive built-ins ({@code Rand}, {@code Uuid}, {@code
6090 * MonotonicallyIncreasingID}) work from Tokio workers and avoids reusing a stale TaskContext
61- * left on a worker by a previous task.
91+ * left on a worker by a previous task. Its task attempt ID also keys the UDF-instance cache,
92+ * so a UDF holding per-task state in fields sees a consistent instance for every call within
93+ * the task regardless of which Tokio worker is polling.
6294 */
6395 public static void evaluate (
6496 String udfClassName ,
@@ -68,16 +100,33 @@ public static void evaluate(
68100 long outSchemaPtr ,
69101 int numRows ,
70102 TaskContext taskContext ) {
71- // Save-and-restore rather than only-install-if-null: the propagated context is the ground
72- // truth for this call. Any value already on the thread is either (a) the same object on a
73- // Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
103+ assert udfClassName != null && !udfClassName .isEmpty () : "udfClassName must be non-empty" ;
104+ assert inputArrayPtrs != null && inputSchemaPtrs != null
105+ : "input pointer arrays must be non-null" ;
106+ assert inputArrayPtrs .length == inputSchemaPtrs .length
107+ : "input array pointer count must equal schema pointer count" ;
108+ assert numRows >= 0 : "numRows must be non-negative" ;
109+ assert outArrayPtr != 0L : "outArrayPtr must be a valid FFI pointer" ;
110+ assert outSchemaPtr != 0L : "outSchemaPtr must be a valid FFI pointer" ;
111+
112+ // Save-and-restore rather than only-install-if-null: the propagated `taskContext` is the
113+ // ground truth for this call. Any value already on the thread is either (a) the same object
114+ // on a Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
74115 TaskContext prior = TaskContext .get ();
75116 if (taskContext != null ) {
76117 CometTaskContextShim .set (taskContext );
118+ assert TaskContext .get () == taskContext
119+ : "TaskContext install did not take effect on this thread" ;
77120 }
78121 try {
79122 evaluateInternal (
80- udfClassName , inputArrayPtrs , inputSchemaPtrs , outArrayPtr , outSchemaPtr , numRows );
123+ udfClassName ,
124+ inputArrayPtrs ,
125+ inputSchemaPtrs ,
126+ outArrayPtr ,
127+ outSchemaPtr ,
128+ numRows ,
129+ taskContext );
81130 } finally {
82131 if (taskContext != null ) {
83132 if (prior != null ) {
@@ -95,9 +144,34 @@ private static void evaluateInternal(
95144 long [] inputSchemaPtrs ,
96145 long outArrayPtr ,
97146 long outSchemaPtr ,
98- int numRows ) {
99- CometUDF udf =
147+ int numRows ,
148+ TaskContext taskContext ) {
149+ long taskAttemptId = (taskContext != null ) ? taskContext .taskAttemptId () : NO_TASK_ID ;
150+
151+ ConcurrentHashMap <String , CometUDF > perTask =
100152 INSTANCES .computeIfAbsent (
153+ taskAttemptId ,
154+ id -> {
155+ ConcurrentHashMap <String , CometUDF > fresh = new ConcurrentHashMap <>();
156+ if (taskContext != null ) {
157+ // computeIfAbsent runs this lambda at most once per key, so the listener is
158+ // registered exactly once per task attempt.
159+ taskContext .addTaskCompletionListener (
160+ (TaskCompletionListener )
161+ ctx -> {
162+ ConcurrentHashMap <String , CometUDF > removed = INSTANCES .remove (id );
163+ assert removed != null
164+ : "task-completion listener fired but cache already removed "
165+ + "entry for task "
166+ + id ;
167+ });
168+ }
169+ return fresh ;
170+ });
171+ assert perTask != null : "per-task cache must be non-null after computeIfAbsent" ;
172+
173+ CometUDF udf =
174+ perTask .computeIfAbsent (
101175 udfClassName ,
102176 name -> {
103177 try {
@@ -113,6 +187,7 @@ private static void evaluateInternal(
113187 throw new RuntimeException ("Failed to instantiate CometUDF: " + name , e );
114188 }
115189 });
190+ assert udf != null : "reflective instantiation returned null for " + udfClassName ;
116191
117192 BufferAllocator allocator = org .apache .comet .package$ .MODULE$ .CometArrowAllocator ();
118193
0 commit comments