1919
2020package org .apache .comet .udf ;
2121
22- import java .util .LinkedHashMap ;
23- import java .util .Map ;
22+ import java .util .concurrent .ConcurrentHashMap ;
2423
2524import org .apache .arrow .c .ArrowArray ;
2625import org .apache .arrow .c .ArrowSchema ;
3029import org .apache .arrow .vector .ValueVector ;
3130import org .apache .spark .TaskContext ;
3231import org .apache .spark .comet .CometTaskContextShim ;
32+ import org .apache .spark .util .TaskCompletionListener ;
3333
3434/**
3535 * JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
3636 * pattern used by CometScalarSubquery so the native side can dispatch via
3737 * call_static_method_unchecked.
38+ *
39+ * <p>Cache invariants:
40+ *
41+ * <ol>
42+ * <li>For each live Spark task attempt there is at most one {@link CometUDF} instance per class
43+ * name.
44+ * <li>A {@link CometUDF} instance is visible only within the Spark task attempt that instantiated
45+ * it. Two task attempts observing the same class name receive distinct instances.
46+ * <li>At any instant at most one thread is inside {@code evaluate()} for a given {@code
47+ * taskAttemptId}. This follows from Spark executing one native future per partition and Tokio
48+ * polling one future per worker at a time.
49+ * <li>All instances for a task are dropped by the {@link TaskCompletionListener} registered on
50+ * the first cache miss for that task. No cache entry outlives its task.
51+ * <li>When {@code taskContext} is {@code null} (unit tests, direct native driver) the fallback
52+ * key {@code -1L} is used; that bucket is never evicted because no task-completion event will
53+ * fire.
54+ * </ol>
55+ *
56+ * <p>Keying by {@code taskAttemptId} rather than by thread keeps the cache correct under Tokio
57+ * work-stealing: on the scan-free execution path the same Spark task can be polled by different
58+ * Tokio workers across batches, so a thread-local cache would lose per-task state on migration. The
59+ * task attempt ID is stable for the life of the task regardless of which worker is polling.
3860 */
3961public class CometUdfBridge {
4062
41- // Per-thread, bounded LRU of UDF instances keyed by class name. Comet
42- // native execution threads (Tokio/DataFusion worker pool) are reused
43- // across tasks within an executor, so the effective lifetime of cached
44- // entries is the worker thread (i.e. the executor JVM). Fine for
45- // stateless UDFs; future stateful UDFs would need explicit per-task
46- // isolation.
47- private static final int CACHE_CAPACITY = 64 ;
63+ /**
64+ * Task-scoped cache of {@link CometUDF} instances. Outer map keys are Spark task attempt IDs (or
65+ * {@code -1L} when no {@link TaskContext} is available). Inner maps hold one instance per UDF
66+ * class name for the task's lifetime. Entries are removed by the {@link TaskCompletionListener}
67+ * registered on the first cache miss per task.
68+ */
69+ private static final ConcurrentHashMap <Long , ConcurrentHashMap <String , CometUDF >> INSTANCES =
70+ new ConcurrentHashMap <>();
4871
49- private static final ThreadLocal <LinkedHashMap <String , CometUDF >> INSTANCES =
50- ThreadLocal .withInitial (
51- () ->
52- new LinkedHashMap <String , CometUDF >(CACHE_CAPACITY , 0.75f , true ) {
53- @ Override
54- protected boolean removeEldestEntry (Map .Entry <String , CometUDF > eldest ) {
55- return size () > CACHE_CAPACITY ;
56- }
57- });
72+ /** Sentinel key for calls that carry no {@link TaskContext} (unit tests, direct driver). */
73+ private static final long NO_TASK_ID = -1L ;
5874
5975 /**
6076 * Called from native via JNI.
@@ -76,7 +92,9 @@ protected boolean removeEldestEntry(Map.Entry<String, CometUDF> eldest) {
7692 * / {@code Uuid} / {@code MonotonicallyIncreasingID} that read the partition index via {@code
7793 * TaskContext.get().partitionId()}) sees the real context rather than null. The thread-local
7894 * is cleared in a {@code finally} so Tokio workers don't leak a stale TaskContext across
79- * invocations.
95+ * invocations. The task attempt ID drawn from this context also keys the UDF-instance cache,
96+ * so a UDF holding per-task state in fields sees a consistent instance for every call within
97+ * the task regardless of which Tokio worker is polling.
8098 */
8199 public static void evaluate (
82100 String udfClassName ,
@@ -86,14 +104,31 @@ public static void evaluate(
86104 long outSchemaPtr ,
87105 int numRows ,
88106 TaskContext taskContext ) {
107+ assert udfClassName != null && !udfClassName .isEmpty () : "udfClassName must be non-empty" ;
108+ assert inputArrayPtrs != null && inputSchemaPtrs != null
109+ : "input pointer arrays must be non-null" ;
110+ assert inputArrayPtrs .length == inputSchemaPtrs .length
111+ : "input array pointer count must equal schema pointer count" ;
112+ assert numRows >= 0 : "numRows must be non-negative" ;
113+ assert outArrayPtr != 0L : "outArrayPtr must be a valid FFI pointer" ;
114+ assert outSchemaPtr != 0L : "outSchemaPtr must be a valid FFI pointer" ;
115+
89116 boolean installedTaskContext = false ;
90117 if (taskContext != null && TaskContext .get () == null ) {
91118 CometTaskContextShim .set (taskContext );
92119 installedTaskContext = true ;
120+ assert TaskContext .get () == taskContext
121+ : "TaskContext install did not take effect on this thread" ;
93122 }
94123 try {
95124 evaluateInternal (
96- udfClassName , inputArrayPtrs , inputSchemaPtrs , outArrayPtr , outSchemaPtr , numRows );
125+ udfClassName ,
126+ inputArrayPtrs ,
127+ inputSchemaPtrs ,
128+ outArrayPtr ,
129+ outSchemaPtr ,
130+ numRows ,
131+ taskContext );
97132 } finally {
98133 if (installedTaskContext ) {
99134 CometTaskContextShim .unset ();
@@ -107,24 +142,50 @@ private static void evaluateInternal(
107142 long [] inputSchemaPtrs ,
108143 long outArrayPtr ,
109144 long outSchemaPtr ,
110- int numRows ) {
111- LinkedHashMap <String , CometUDF > cache = INSTANCES .get ();
112- CometUDF udf = cache .get (udfClassName );
113- if (udf == null ) {
114- try {
115- // Resolve via the executor's context classloader so user-supplied UDF jars
116- // (added via spark.jars / --jars) are visible.
117- ClassLoader cl = Thread .currentThread ().getContextClassLoader ();
118- if (cl == null ) {
119- cl = CometUdfBridge .class .getClassLoader ();
120- }
121- udf =
122- (CometUDF ) Class .forName (udfClassName , true , cl ).getDeclaredConstructor ().newInstance ();
123- } catch (ReflectiveOperationException e ) {
124- throw new RuntimeException ("Failed to instantiate CometUDF: " + udfClassName , e );
125- }
126- cache .put (udfClassName , udf );
127- }
145+ int numRows ,
146+ TaskContext taskContext ) {
147+ long taskAttemptId = (taskContext != null ) ? taskContext .taskAttemptId () : NO_TASK_ID ;
148+
149+ ConcurrentHashMap <String , CometUDF > perTask =
150+ INSTANCES .computeIfAbsent (
151+ taskAttemptId ,
152+ id -> {
153+ ConcurrentHashMap <String , CometUDF > fresh = new ConcurrentHashMap <>();
154+ if (taskContext != null ) {
155+ // computeIfAbsent runs this lambda at most once per key, so the listener is
156+ // registered exactly once per task attempt.
157+ taskContext .addTaskCompletionListener (
158+ (TaskCompletionListener )
159+ ctx -> {
160+ ConcurrentHashMap <String , CometUDF > removed = INSTANCES .remove (id );
161+ assert removed != null
162+ : "task-completion listener fired but cache already removed "
163+ + "entry for task "
164+ + id ;
165+ });
166+ }
167+ return fresh ;
168+ });
169+ assert perTask != null : "per-task cache must be non-null after computeIfAbsent" ;
170+
171+ CometUDF udf =
172+ perTask .computeIfAbsent (
173+ udfClassName ,
174+ name -> {
175+ try {
176+ // Resolve via the executor's context classloader so user-supplied UDF jars
177+ // (added via spark.jars / --jars) are visible.
178+ ClassLoader cl = Thread .currentThread ().getContextClassLoader ();
179+ if (cl == null ) {
180+ cl = CometUdfBridge .class .getClassLoader ();
181+ }
182+ return (CometUDF )
183+ Class .forName (name , true , cl ).getDeclaredConstructor ().newInstance ();
184+ } catch (ReflectiveOperationException e ) {
185+ throw new RuntimeException ("Failed to instantiate CometUDF: " + name , e );
186+ }
187+ });
188+ assert udf != null : "reflective instantiation returned null for " + udfClassName ;
128189
129190 BufferAllocator allocator = org .apache .comet .package$ .MODULE$ .CometArrowAllocator ();
130191
@@ -138,6 +199,9 @@ private static void evaluateInternal(
138199 }
139200
140201 result = udf .evaluate (inputs , numRows );
202+ assert result instanceof FieldVector
203+ : "CometUDF implementations must return FieldVector; got "
204+ + (result == null ? "null" : result .getClass ().getName ());
141205 if (!(result instanceof FieldVector )) {
142206 throw new RuntimeException (
143207 "CometUDF.evaluate() must return a FieldVector, got: " + result .getClass ().getName ());
0 commit comments