1919
2020package org .apache .comet .udf ;
2121
22- import java .util .LinkedHashMap ;
23- import java .util .Map ;
22+ import java .util .concurrent .ConcurrentHashMap ;
2423
2524import org .apache .arrow .c .ArrowArray ;
2625import org .apache .arrow .c .ArrowSchema ;
3837 */
3938public class CometUdfBridge {
4039
41- // Per-thread, bounded LRU of UDF instances keyed by class name. Comet
42- // native execution threads (Tokio/DataFusion worker pool) are reused
43- // across tasks within an executor, so the effective lifetime of cached
44- // entries is the worker thread (i.e. the executor JVM). Fine for
45- // stateless UDFs; future stateful UDFs would need explicit per-task
46- // isolation.
47- private static final int CACHE_CAPACITY = 64 ;
48-
49- private static final ThreadLocal <LinkedHashMap <String , CometUDF >> INSTANCES =
50- ThreadLocal .withInitial (
51- () ->
52- new LinkedHashMap <String , CometUDF >(CACHE_CAPACITY , 0.75f , true ) {
53- @ Override
54- protected boolean removeEldestEntry (Map .Entry <String , CometUDF > eldest ) {
55- return size () > CACHE_CAPACITY ;
56- }
57- });
40+ // Process-wide cache of UDF instances keyed by class name. CometUDF
41+ // implementations are required to be stateless (see CometUDF), so a
42+ // single shared instance per class is safe across native worker threads.
43+ private static final ConcurrentHashMap <String , CometUDF > INSTANCES = new ConcurrentHashMap <>();
5844
5945 /**
6046 * Called from native via JNI.
@@ -64,19 +50,15 @@ protected boolean removeEldestEntry(Map.Entry<String, CometUDF> eldest) {
6450 * @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input)
6551 * @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result
6652 * @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result
67- * @param numRows number of rows in the current batch. Mirrors DataFusion's {@code
68- * ScalarFunctionArgs.number_rows} and gives UDFs an explicit batch-size signal for cases
69- * where no input arg is a batch-length array (e.g. a zero-arg non-deterministic ScalaUDF).
70- * UDFs that already read size from their input vectors can ignore it.
71- * @param taskContext Spark {@link TaskContext} captured on the driving Spark task thread and
72- * passed through from native. May be {@code null} when the bridge is invoked outside a Spark
73- * task (unit tests, direct native driver runs). When non-null and the current thread has no
74- * {@code TaskContext} of its own, the bridge installs it as the thread-local for the duration
75- * of the UDF call so the UDF body (including partition-sensitive built-ins like {@code Rand}
76- * / {@code Uuid} / {@code MonotonicallyIncreasingID} that read the partition index via {@code
77- * TaskContext.get().partitionId()}) sees the real context rather than null. The thread-local
78- * is cleared in a {@code finally} so Tokio workers don't leak a stale TaskContext across
79- * invocations.
53+ * @param numRows row count of the current batch. Mirrors DataFusion's {@code
54+ * ScalarFunctionArgs.number_rows}; the only batch-size signal a zero-input UDF (e.g. a
55+ * zero-arg non-deterministic ScalaUDF) ever sees.
56+ * @param taskContext propagated Spark {@link TaskContext} from the driving Spark task thread, or
57+ * {@code null} outside a Spark task. When non-null it is treated as ground truth for the call:
58+ * installed as the thread-local on entry; in {@code finally} the prior value is restored, or the
59+ * thread-local is cleared if there was none. When {@code null} the thread-local is left untouched.
60+ * Lets partition-sensitive built-ins ({@code Rand}, {@code Uuid}, {@code MonotonicallyIncreasingID})
61+ * work from Tokio workers and avoids reusing a stale TaskContext left on a worker by a previous task.
8062 */
8163 public static void evaluate (
8264 String udfClassName ,
@@ -86,17 +68,23 @@ public static void evaluate(
8668 long outSchemaPtr ,
8769 int numRows ,
8870 TaskContext taskContext ) {
89- boolean installedTaskContext = false ;
90- if (taskContext != null && TaskContext .get () == null ) {
71+ // Save-and-restore rather than only-install-if-null: the propagated context is the ground
72+ // truth for this call. Any value already on the thread is either (a) the same object on a
73+ // Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
74+ TaskContext prior = TaskContext .get ();
75+ if (taskContext != null ) {
9176 CometTaskContextShim .set (taskContext );
92- installedTaskContext = true ;
9377 }
9478 try {
9579 evaluateInternal (
9680 udfClassName , inputArrayPtrs , inputSchemaPtrs , outArrayPtr , outSchemaPtr , numRows );
9781 } finally {
98- if (installedTaskContext ) {
99- CometTaskContextShim .unset ();
82+ if (taskContext != null ) {
83+ if (prior != null ) {
84+ CometTaskContextShim .set (prior );
85+ } else {
86+ CometTaskContextShim .unset ();
87+ }
10088 }
10189 }
10290 }
@@ -108,23 +96,23 @@ private static void evaluateInternal(
10896 long outArrayPtr ,
10997 long outSchemaPtr ,
11098 int numRows ) {
111- LinkedHashMap < String , CometUDF > cache = INSTANCES . get ();
112- CometUDF udf = cache . get ( udfClassName );
113- if ( udf == null ) {
114- try {
115- // Resolve via the executor's context classloader so user-supplied UDF jars
116- // (added via spark.jars / --jars) are visible.
117- ClassLoader cl = Thread . currentThread (). getContextClassLoader ();
118- if ( cl == null ) {
119- cl = CometUdfBridge . class . getClassLoader ();
120- }
121- udf =
122- ( CometUDF ) Class . forName ( udfClassName , true , cl ). getDeclaredConstructor (). newInstance ();
123- } catch ( ReflectiveOperationException e ) {
124- throw new RuntimeException ( "Failed to instantiate CometUDF: " + udfClassName , e );
125- }
126- cache . put ( udfClassName , udf );
127- }
99+ CometUDF udf =
100+ INSTANCES . computeIfAbsent (
101+ udfClassName ,
102+ name -> {
103+ try {
104+ // Resolve via the executor's context classloader so user-supplied UDF jars
105+ // (added via spark.jars / --jars) are visible.
106+ ClassLoader cl = Thread . currentThread (). getContextClassLoader ();
107+ if ( cl == null ) {
108+ cl = CometUdfBridge . class . getClassLoader ();
109+ }
110+ return ( CometUDF )
111+ Class . forName ( name , true , cl ). getDeclaredConstructor (). newInstance ();
112+ } catch ( ReflectiveOperationException e ) {
113+ throw new RuntimeException ( "Failed to instantiate CometUDF: " + name , e );
114+ }
115+ });
128116
129117 BufferAllocator allocator = org .apache .comet .package$ .MODULE$ .CometArrowAllocator ();
130118
0 commit comments