Skip to content

Commit 9f8aa07

Browse files
committed
fix after merging in upstream/main.
1 parent 6fcd81c commit 9f8aa07

5 files changed

Lines changed: 117 additions & 82 deletions

File tree

common/src/main/java/org/apache/comet/udf/CometUdfBridge.java

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919

2020
package org.apache.comet.udf;
2121

22-
import java.util.concurrent.ConcurrentHashMap;
22+
import java.util.LinkedHashMap;
23+
import java.util.Map;
2324

2425
import org.apache.arrow.c.ArrowArray;
2526
import org.apache.arrow.c.ArrowSchema;
@@ -37,10 +38,23 @@
3738
*/
3839
public class CometUdfBridge {
3940

40-
// Process-wide cache of UDF instances keyed by class name. CometUDF
41-
// implementations are required to be stateless (see CometUDF), so a
42-
// single shared instance per class is safe across native worker threads.
43-
private static final ConcurrentHashMap<String, CometUDF> INSTANCES = new ConcurrentHashMap<>();
41+
// Per-thread, bounded LRU of UDF instances keyed by class name. Comet
42+
// native execution threads (Tokio/DataFusion worker pool) are reused
43+
// across tasks within an executor, so the effective lifetime of cached
44+
// entries is the worker thread (i.e. the executor JVM). Fine for
45+
// stateless UDFs; future stateful UDFs would need explicit per-task
46+
// isolation.
47+
private static final int CACHE_CAPACITY = 64;
48+
49+
private static final ThreadLocal<LinkedHashMap<String, CometUDF>> INSTANCES =
50+
ThreadLocal.withInitial(
51+
() ->
52+
new LinkedHashMap<String, CometUDF>(CACHE_CAPACITY, 0.75f, true) {
53+
@Override
54+
protected boolean removeEldestEntry(Map.Entry<String, CometUDF> eldest) {
55+
return size() > CACHE_CAPACITY;
56+
}
57+
});
4458

4559
/**
4660
* Called from native via JNI.
@@ -50,15 +64,19 @@ public class CometUdfBridge {
5064
* @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input)
5165
* @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result
5266
* @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result
53-
* @param numRows row count of the current batch. Mirrors DataFusion's {@code
54-
* ScalarFunctionArgs.number_rows}; the only batch-size signal a zero-input UDF (e.g. a
55-
* zero-arg non-deterministic ScalaUDF) ever sees.
56-
* @param taskContext propagated Spark {@link TaskContext} from the driving Spark task thread, or
57-
* {@code null} outside a Spark task. Treated as ground truth for the call: installed as the
58-
* thread-local on entry, with the prior value (if any) saved and restored in {@code finally}.
59-
* Lets partition-sensitive built-ins ({@code Rand}, {@code Uuid}, {@code
60-
* MonotonicallyIncreasingID}) work from Tokio workers and avoids reusing a stale TaskContext
61-
* left on a worker by a previous task.
67+
* @param numRows number of rows in the current batch. Mirrors DataFusion's {@code
68+
* ScalarFunctionArgs.number_rows} and gives UDFs an explicit batch-size signal for cases
69+
* where no input arg is a batch-length array (e.g. a zero-arg non-deterministic ScalaUDF).
70+
* UDFs that already read size from their input vectors can ignore it.
71+
* @param taskContext Spark {@link TaskContext} captured on the driving Spark task thread and
72+
* passed through from native. May be {@code null} when the bridge is invoked outside a Spark
73+
* task (unit tests, direct native driver runs). When non-null and the current thread has no
74+
* {@code TaskContext} of its own, the bridge installs it as the thread-local for the duration
75+
* of the UDF call so the UDF body (including partition-sensitive built-ins like {@code Rand}
76+
* / {@code Uuid} / {@code MonotonicallyIncreasingID} that read the partition index via {@code
77+
* TaskContext.get().partitionId()}) sees the real context rather than null. The thread-local
78+
* is cleared in a {@code finally} so Tokio workers don't leak a stale TaskContext across
79+
* invocations.
6280
*/
6381
public static void evaluate(
6482
String udfClassName,
@@ -68,23 +86,17 @@ public static void evaluate(
6886
long outSchemaPtr,
6987
int numRows,
7088
TaskContext taskContext) {
71-
// Save-and-restore rather than only-install-if-null: the propagated context is the ground
72-
// truth for this call. Any value already on the thread is either (a) the same object on a
73-
// Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
74-
TaskContext prior = TaskContext.get();
75-
if (taskContext != null) {
89+
boolean installedTaskContext = false;
90+
if (taskContext != null && TaskContext.get() == null) {
7691
CometTaskContextShim.set(taskContext);
92+
installedTaskContext = true;
7793
}
7894
try {
7995
evaluateInternal(
8096
udfClassName, inputArrayPtrs, inputSchemaPtrs, outArrayPtr, outSchemaPtr, numRows);
8197
} finally {
82-
if (taskContext != null) {
83-
if (prior != null) {
84-
CometTaskContextShim.set(prior);
85-
} else {
86-
CometTaskContextShim.unset();
87-
}
98+
if (installedTaskContext) {
99+
CometTaskContextShim.unset();
88100
}
89101
}
90102
}
@@ -96,23 +108,23 @@ private static void evaluateInternal(
96108
long outArrayPtr,
97109
long outSchemaPtr,
98110
int numRows) {
99-
CometUDF udf =
100-
INSTANCES.computeIfAbsent(
101-
udfClassName,
102-
name -> {
103-
try {
104-
// Resolve via the executor's context classloader so user-supplied UDF jars
105-
// (added via spark.jars / --jars) are visible.
106-
ClassLoader cl = Thread.currentThread().getContextClassLoader();
107-
if (cl == null) {
108-
cl = CometUdfBridge.class.getClassLoader();
109-
}
110-
return (CometUDF)
111-
Class.forName(name, true, cl).getDeclaredConstructor().newInstance();
112-
} catch (ReflectiveOperationException e) {
113-
throw new RuntimeException("Failed to instantiate CometUDF: " + name, e);
114-
}
115-
});
111+
LinkedHashMap<String, CometUDF> cache = INSTANCES.get();
112+
CometUDF udf = cache.get(udfClassName);
113+
if (udf == null) {
114+
try {
115+
// Resolve via the executor's context classloader so user-supplied UDF jars
116+
// (added via spark.jars / --jars) are visible.
117+
ClassLoader cl = Thread.currentThread().getContextClassLoader();
118+
if (cl == null) {
119+
cl = CometUdfBridge.class.getClassLoader();
120+
}
121+
udf =
122+
(CometUDF) Class.forName(udfClassName, true, cl).getDeclaredConstructor().newInstance();
123+
} catch (ReflectiveOperationException e) {
124+
throw new RuntimeException("Failed to instantiate CometUDF: " + udfClassName, e);
125+
}
126+
cache.put(udfClassName, udf);
127+
}
116128

117129
BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator();
118130

native/core/src/execution/jni_api.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,8 +462,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
462462
};
463463

464464
// Capture the driving Spark task's TaskContext as a JNI global reference when
465-
// non-null. The `Arc<Global<JObject>>` releases its global ref on drop, so cleanup
466-
// is automatic when the ExecutionContext drops.
465+
// non-null. The `Arc<Global<JObject>>` releases its global ref on drop, so
466+
// cleanup is automatic when the ExecutionContext drops.
467467
let task_context = if !task_context_obj.is_null() {
468468
Some(Arc::new(jni_new_global_ref!(env, task_context_obj)?))
469469
} else {

native/core/src/execution/planner.rs

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,11 @@ pub struct PhysicalPlanner {
183183
partition: i32,
184184
session_ctx: Arc<SessionContext>,
185185
query_context_registry: Arc<datafusion_comet_spark_expr::QueryContextMap>,
186-
/// Captured at `createPlan` time on `ExecutionContext`; see that struct for the
187-
/// propagation rationale. `None` when no driving Spark task is available.
186+
/// Spark `TaskContext` captured on the driving Spark task thread and stashed on the
187+
/// [`ExecutionContext`] at `createPlan` time. Threaded into every [`JvmScalarUdfExpr`] the
188+
/// planner builds so the JNI bridge can install it as the thread-local `TaskContext` on
189+
/// the Tokio worker that drives the UDF. `None` when no driving Spark task is available
190+
/// (unit tests, direct native driver runs).
188191
task_context: Option<Arc<Global<JObject<'static>>>>,
189192
}
190193

@@ -205,20 +208,27 @@ impl PhysicalPlanner {
205208
}
206209
}
207210

208-
pub fn with_exec_id(mut self, exec_context_id: i64) -> Self {
209-
self.exec_context_id = exec_context_id;
210-
self
211+
pub fn with_exec_id(self, exec_context_id: i64) -> Self {
212+
Self {
213+
exec_context_id,
214+
partition: self.partition,
215+
session_ctx: Arc::clone(&self.session_ctx),
216+
query_context_registry: Arc::clone(&self.query_context_registry),
217+
task_context: self.task_context,
218+
}
211219
}
212220

213-
/// Attach the Spark `TaskContext` global reference captured at `createPlan` time. Cloned
214-
/// into every `JvmScalarUdfExpr` the planner builds so the JNI bridge can install it as
215-
/// the thread-local on the Tokio worker driving the UDF.
216-
pub fn with_task_context(
217-
mut self,
218-
task_context: Option<Arc<Global<JObject<'static>>>>,
219-
) -> Self {
220-
self.task_context = task_context;
221-
self
221+
/// Attach a propagated Spark `TaskContext` global reference. Called by the JNI `executePlan`
222+
/// entry with whatever was captured at `createPlan` time. The planner clones this `Option`
223+
/// into every `JvmScalarUdfExpr` it builds.
224+
pub fn with_task_context(self, task_context: Option<Arc<Global<JObject<'static>>>>) -> Self {
225+
Self {
226+
exec_context_id: self.exec_context_id,
227+
partition: self.partition,
228+
session_ctx: self.session_ctx,
229+
query_context_registry: self.query_context_registry,
230+
task_context,
231+
}
222232
}
223233

224234
/// Return session context of this planner.

native/spark-expr/src/jvm_udf/mod.rs

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,16 @@ pub struct JvmScalarUdfExpr {
4141
args: Vec<Arc<dyn PhysicalExpr>>,
4242
return_type: DataType,
4343
return_nullable: bool,
44-
/// Captured at `createPlan` time and threaded here by the planner. Passed through the
45-
/// JNI bridge so `CometUdfBridge.evaluate` can install it as the Tokio worker's
46-
/// thread-local `TaskContext`. Without this, partition-sensitive built-ins inside a UDF
47-
/// tree (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, user code reading
48-
/// `TaskContext.get()`) see `null` and seed / branch incorrectly. `None` when no driving
49-
/// Spark task is available; the bridge then leaves whatever `TaskContext.get()` already
50-
/// returns in place.
44+
/// Spark `TaskContext` captured on the driving Spark task thread, stashed in the
45+
/// [`ExecutionContext`] at `createPlan` time, and threaded here by the planner. Passed
46+
/// through the JNI bridge so [`CometUdfBridge.evaluate`] can install it as the
47+
/// thread-local `TaskContext` on the Tokio worker that drives the UDF call. Without this,
48+
/// partition-sensitive built-ins inside a user UDF tree (`Rand`, `Uuid`,
49+
/// `MonotonicallyIncreasingID`, custom UDF code that reads
50+
/// `TaskContext.get().partitionId()`) see a null `TaskContext` and seed / branch
51+
/// incorrectly. `None` means the surrounding driver had no `TaskContext` to propagate
52+
/// (unit tests, direct native driver runs); the bridge then leaves whatever
53+
/// `TaskContext.get()` already returns in place.
5154
task_context: Option<Arc<Global<JObject<'static>>>>,
5255
}
5356

@@ -120,10 +123,10 @@ impl PhysicalExpr for JvmScalarUdfExpr {
120123
}
121124

122125
fn evaluate(&self, batch: &RecordBatch) -> DFResult<ColumnarValue> {
123-
// Step 1: evaluate child expressions to get Arrow arrays. Scalar children
124-
// (e.g. literal patterns) are sent as length-1 vectors rather than expanded
125-
// to batch-row count, so the JVM bridge does not pay an O(rows) copy for
126-
// values that never vary across the batch.
126+
// Scalar children (e.g. literal patterns) are sent as length-1 vectors rather than
127+
// expanded to batch-row count, so the JVM bridge does not pay an O(rows) copy for
128+
// values that never vary across the batch. The JVM side gets `numRows` directly via
129+
// the bridge so it doesn't need the scalar to carry batch length.
127130
let arrays: Vec<ArrayRef> = self
128131
.args
129132
.iter()
@@ -133,7 +136,6 @@ impl PhysicalExpr for JvmScalarUdfExpr {
133136
})
134137
.collect::<DFResult<_>>()?;
135138

136-
// Step 2: allocate FFI structs on the Rust heap and collect their raw pointers.
137139
// The JVM writes into the out_array/out_schema slots and reads from the in_ slots.
138140
let in_ffi_arrays: Vec<Box<FFI_ArrowArray>> = arrays
139141
.iter()
@@ -157,7 +159,6 @@ impl PhysicalExpr for JvmScalarUdfExpr {
157159
.map(|b| b.as_ref() as *const FFI_ArrowSchema as i64)
158160
.collect();
159161

160-
// Allocate output FFI slots.
161162
let mut out_array = Box::new(FFI_ArrowArray::empty());
162163
let mut out_schema = Box::new(FFI_ArrowSchema::empty());
163164
let out_arr_ptr = out_array.as_mut() as *mut FFI_ArrowArray as i64;
@@ -166,22 +167,20 @@ impl PhysicalExpr for JvmScalarUdfExpr {
166167
let class_name = self.class_name.clone();
167168
let n_args = arrays.len();
168169

169-
// Step 3: attach a JNI env for this thread and call the static bridge method.
170170
JVMClasses::with_env(|env| {
171171
let bridge = JVMClasses::get().comet_udf_bridge.as_ref().ok_or_else(|| {
172172
CometError::from(ExecutionError::GeneralError(
173173
"JVM UDF bridge unavailable: org.apache.comet.udf.CometUdfBridge \
174-
class was not found on the JVM classpath."
174+
class was not found on the JVM classpath. Set \
175+
spark.comet.exec.regexp.engine=rust to disable this path."
175176
.to_string(),
176177
))
177178
})?;
178179

179-
// Build the JVM String for the class name.
180180
let jclass_name = env
181181
.new_string(&class_name)
182182
.map_err(|e| CometError::JNI { source: e })?;
183183

184-
// Build the long[] arrays for input pointers.
185184
let in_arr_java = env
186185
.new_long_array(n_args)
187186
.map_err(|e| CometError::JNI { source: e })?;
@@ -196,9 +195,10 @@ impl PhysicalExpr for JvmScalarUdfExpr {
196195
.set_region(env, 0, &in_sch_ptrs)
197196
.map_err(|e| CometError::JNI { source: e })?;
198197

199-
// Pass a null jobject when no TaskContext was propagated so the bridge's null-guard
200-
// leaves the worker thread's current TaskContext.get() in place. The borrow must
201-
// outlive `call_static_method_unchecked`.
198+
// Resolve the TaskContext reference once before building the arg array so the
199+
// borrow lives until `call_static_method_unchecked` returns. When no TaskContext
200+
// was propagated, pass a null object so the bridge's null-guard leaves the thread-
201+
// local alone.
202202
let null_task_context = JObject::null();
203203
let task_context_ref: &JObject = match &self.task_context {
204204
Some(gref) => gref.as_obj(),
@@ -229,15 +229,26 @@ impl PhysicalExpr for JvmScalarUdfExpr {
229229
Ok(())
230230
})?;
231231

232-
// Step 4: import the result from the FFI slots filled by the JVM.
233232
// SAFETY: `*out_array` moves the FFI_ArrowArray out of the Box (the heap
234233
// allocation is freed by the move), and `from_ffi` wraps it in an Arc that
235234
// keeps the JVM-installed release callback alive until the resulting
236235
// ArrayData drops. `out_schema` is borrowed; its release callback runs
237236
// exactly once when the Box drops at end of scope.
238237
let result_data = unsafe { from_ffi(*out_array, &out_schema) }
239238
.map_err(|e| CometError::Arrow { source: e })?;
240-
Ok(ColumnarValue::Array(make_array(result_data)))
239+
let result_array = make_array(result_data);
240+
241+
// The JVM may produce arrays with different field names (e.g. Arrow Java's
242+
// ListVector uses "$data$" for child fields) than what DataFusion expects
243+
// (e.g. "item"). Cast to the declared return_type to normalize schema.
244+
let result_array = if result_array.data_type() != &self.return_type {
245+
arrow::compute::cast(&result_array, &self.return_type)
246+
.map_err(|e| CometError::Arrow { source: e })?
247+
} else {
248+
result_array
249+
};
250+
251+
Ok(ColumnarValue::Array(result_array))
241252
}
242253

243254
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {

spark/src/main/scala/org/apache/comet/CometExecIterator.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,10 @@ class CometExecIterator(
128128
taskAttemptId,
129129
taskCPUs,
130130
keyUnwrapper,
131-
// Propagated to Tokio workers running JVM UDFs so they see this Spark task's
132-
// TaskContext. See CometUdfBridge.evaluate.
131+
// Capture the Spark task thread's TaskContext at `createPlan` time. Stashed native-side
132+
// in the ExecutionContext and passed through the JVM UDF bridge so that Tokio workers
133+
// running JVM UDFs see the real `TaskContext` via their thread-local. See
134+
// `CometUdfBridge.evaluate` and `CometTaskContextShim` for the receive side.
133135
TaskContext.get())
134136
}
135137

0 commit comments

Comments
 (0)