@@ -80,19 +80,15 @@ public class CometUdfBridge {
8080 * @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input)
8181 * @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result
8282 * @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result
83- * @param numRows number of rows in the current batch. Mirrors DataFusion's {@code
84- * ScalarFunctionArgs.number_rows} and gives UDFs an explicit batch-size signal for cases
85- * where no input arg is a batch-length array (e.g. a zero-arg non-deterministic ScalaUDF).
86- * UDFs that already read size from their input vectors can ignore it.
87- * @param taskContext Spark {@link TaskContext} captured on the driving Spark task thread and
88- * passed through from native. May be {@code null} when the bridge is invoked outside a Spark
89- * task (unit tests, direct native driver runs). When non-null and the current thread has no
90- * {@code TaskContext} of its own, the bridge installs it as the thread-local for the duration
91- * of the UDF call so the UDF body (including partition-sensitive built-ins like {@code Rand}
92- * / {@code Uuid} / {@code MonotonicallyIncreasingID} that read the partition index via {@code
93- * TaskContext.get().partitionId()}) sees the real context rather than null. The thread-local
94- * is cleared in a {@code finally} so Tokio workers don't leak a stale TaskContext across
95- * invocations. The task attempt ID drawn from this context also keys the UDF-instance cache,
83+ * @param numRows row count of the current batch. Mirrors DataFusion's {@code
84+ * ScalarFunctionArgs.number_rows}; the only batch-size signal a zero-input UDF (e.g. a
85+ * zero-arg non-deterministic ScalaUDF) ever sees.
86+ * @param taskContext propagated Spark {@link TaskContext} from the driving Spark task thread, or
87+ * {@code null} outside a Spark task. Treated as ground truth for the call: installed as the
88+ * thread-local on entry, with the prior value (if any) saved and restored in {@code finally}.
89+ * Lets partition-sensitive built-ins ({@code Rand}, {@code Uuid}, {@code
90+ * MonotonicallyIncreasingID}) work from Tokio workers and avoids reusing a stale TaskContext
91+ * left on a worker by a previous task. Its task attempt ID also keys the UDF-instance cache,
9692 * so a UDF holding per-task state in fields sees a consistent instance for every call within
9793 * the task regardless of which Tokio worker is polling.
9894 */
@@ -113,10 +109,12 @@ public static void evaluate(
113109 assert outArrayPtr != 0L : "outArrayPtr must be a valid FFI pointer" ;
114110 assert outSchemaPtr != 0L : "outSchemaPtr must be a valid FFI pointer" ;
115111
116- boolean installedTaskContext = false ;
117- if (taskContext != null && TaskContext .get () == null ) {
112+ // Save-and-restore rather than only-install-if-null: the propagated `taskContext` is the
113+ // ground truth for this call. Any value already on the thread is either (a) the same object
114+ // on a Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
115+ TaskContext prior = TaskContext .get ();
116+ if (taskContext != null ) {
118117 CometTaskContextShim .set (taskContext );
119- installedTaskContext = true ;
120118 assert TaskContext .get () == taskContext
121119 : "TaskContext install did not take effect on this thread" ;
122120 }
@@ -130,8 +128,12 @@ public static void evaluate(
130128 numRows ,
131129 taskContext );
132130 } finally {
133- if (installedTaskContext ) {
134- CometTaskContextShim .unset ();
131+ if (taskContext != null ) {
132+ if (prior != null ) {
133+ CometTaskContextShim .set (prior );
134+ } else {
135+ CometTaskContextShim .unset ();
136+ }
135137 }
136138 }
137139 }
0 commit comments