@@ -41,13 +41,16 @@ pub struct JvmScalarUdfExpr {
4141 args : Vec < Arc < dyn PhysicalExpr > > ,
4242 return_type : DataType ,
4343 return_nullable : bool ,
44- /// Captured at `createPlan` time and threaded here by the planner. Passed through the
45- /// JNI bridge so `CometUdfBridge.evaluate` can install it as the Tokio worker's
46- /// thread-local `TaskContext`. Without this, partition-sensitive built-ins inside a UDF
47- /// tree (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, user code reading
48- /// `TaskContext.get()`) see `null` and seed / branch incorrectly. `None` when no driving
49- /// Spark task is available; the bridge then leaves whatever `TaskContext.get()` already
50- /// returns in place.
44+ /// Spark `TaskContext` captured on the driving Spark task thread, stashed in the
45+ /// [`ExecutionContext`] at `createPlan` time, and threaded here by the planner. Passed
46+ /// through the JNI bridge so [`CometUdfBridge.evaluate`] can install it as the
47+ /// thread-local `TaskContext` on the Tokio worker that drives the UDF call. Without this,
48+ /// partition-sensitive built-ins inside a user UDF tree (`Rand`, `Uuid`,
49+ /// `MonotonicallyIncreasingID`, custom UDF code that reads
50+ /// `TaskContext.get().partitionId()`) see a null `TaskContext` and seed / branch
51+ /// incorrectly. `None` means the surrounding driver had no `TaskContext` to propagate
52+ /// (unit tests, direct native driver runs); the bridge then leaves whatever
53+ /// `TaskContext.get()` already returns in place.
5154 task_context : Option < Arc < Global < JObject < ' static > > > > ,
5255}
5356
@@ -120,10 +123,10 @@ impl PhysicalExpr for JvmScalarUdfExpr {
120123 }
121124
122125 fn evaluate ( & self , batch : & RecordBatch ) -> DFResult < ColumnarValue > {
123- // Step 1: evaluate child expressions to get Arrow arrays. Scalar children
124- // (e.g. literal patterns) are sent as length-1 vectors rather than expanded
125- // to batch-row count, so the JVM bridge does not pay an O(rows) copy for
126- // values that never vary across the batch.
126+ // Scalar children (e.g. literal patterns) are sent as length-1 vectors rather than
127+ // expanded to batch-row count, so the JVM bridge does not pay an O(rows) copy for
128+ // values that never vary across the batch. The JVM side gets `numRows` directly via
129+ // the bridge so it doesn't need the scalar to carry batch length .
127130 let arrays: Vec < ArrayRef > = self
128131 . args
129132 . iter ( )
@@ -133,7 +136,6 @@ impl PhysicalExpr for JvmScalarUdfExpr {
133136 } )
134137 . collect :: < DFResult < _ > > ( ) ?;
135138
136- // Step 2: allocate FFI structs on the Rust heap and collect their raw pointers.
137139 // The JVM writes into the out_array/out_schema slots and reads from the in_ slots.
138140 let in_ffi_arrays: Vec < Box < FFI_ArrowArray > > = arrays
139141 . iter ( )
@@ -157,7 +159,6 @@ impl PhysicalExpr for JvmScalarUdfExpr {
157159 . map ( |b| b. as_ref ( ) as * const FFI_ArrowSchema as i64 )
158160 . collect ( ) ;
159161
160- // Allocate output FFI slots.
161162 let mut out_array = Box :: new ( FFI_ArrowArray :: empty ( ) ) ;
162163 let mut out_schema = Box :: new ( FFI_ArrowSchema :: empty ( ) ) ;
163164 let out_arr_ptr = out_array. as_mut ( ) as * mut FFI_ArrowArray as i64 ;
@@ -166,22 +167,20 @@ impl PhysicalExpr for JvmScalarUdfExpr {
166167 let class_name = self . class_name . clone ( ) ;
167168 let n_args = arrays. len ( ) ;
168169
169- // Step 3: attach a JNI env for this thread and call the static bridge method.
170170 JVMClasses :: with_env ( |env| {
171171 let bridge = JVMClasses :: get ( ) . comet_udf_bridge . as_ref ( ) . ok_or_else ( || {
172172 CometError :: from ( ExecutionError :: GeneralError (
173173 "JVM UDF bridge unavailable: org.apache.comet.udf.CometUdfBridge \
174- class was not found on the JVM classpath."
174+ class was not found on the JVM classpath. Set \
175+ spark.comet.exec.regexp.engine=rust to disable this path."
175176 . to_string ( ) ,
176177 ) )
177178 } ) ?;
178179
179- // Build the JVM String for the class name.
180180 let jclass_name = env
181181 . new_string ( & class_name)
182182 . map_err ( |e| CometError :: JNI { source : e } ) ?;
183183
184- // Build the long[] arrays for input pointers.
185184 let in_arr_java = env
186185 . new_long_array ( n_args)
187186 . map_err ( |e| CometError :: JNI { source : e } ) ?;
@@ -196,9 +195,10 @@ impl PhysicalExpr for JvmScalarUdfExpr {
196195 . set_region ( env, 0 , & in_sch_ptrs)
197196 . map_err ( |e| CometError :: JNI { source : e } ) ?;
198197
199- // Pass a null jobject when no TaskContext was propagated so the bridge's null-guard
200- // leaves the worker thread's current TaskContext.get() in place. The borrow must
201- // outlive `call_static_method_unchecked`.
198+ // Resolve the TaskContext reference once before building the arg array so the
199+ // borrow lives until `call_static_method_unchecked` returns. When no TaskContext
200+ // was propagated, pass a null object so the bridge's null-guard leaves the thread-
201+ // local alone.
202202 let null_task_context = JObject :: null ( ) ;
203203 let task_context_ref: & JObject = match & self . task_context {
204204 Some ( gref) => gref. as_obj ( ) ,
@@ -229,15 +229,26 @@ impl PhysicalExpr for JvmScalarUdfExpr {
229229 Ok ( ( ) )
230230 } ) ?;
231231
232- // Step 4: import the result from the FFI slots filled by the JVM.
233232 // SAFETY: `*out_array` moves the FFI_ArrowArray out of the Box (the heap
234233 // allocation is freed by the move), and `from_ffi` wraps it in an Arc that
235234 // keeps the JVM-installed release callback alive until the resulting
236235 // ArrayData drops. `out_schema` is borrowed; its release callback runs
237236 // exactly once when the Box drops at end of scope.
238237 let result_data = unsafe { from_ffi ( * out_array, & out_schema) }
239238 . map_err ( |e| CometError :: Arrow { source : e } ) ?;
240- Ok ( ColumnarValue :: Array ( make_array ( result_data) ) )
239+ let result_array = make_array ( result_data) ;
240+
241+ // The JVM may produce arrays with different field names (e.g. Arrow Java's
242+ // ListVector uses "$data$" for child fields) than what DataFusion expects
243+ // (e.g. "item"). Cast to the declared return_type to normalize schema.
244+ let result_array = if result_array. data_type ( ) != & self . return_type {
245+ arrow:: compute:: cast ( & result_array, & self . return_type )
246+ . map_err ( |e| CometError :: Arrow { source : e } ) ?
247+ } else {
248+ result_array
249+ } ;
250+
251+ Ok ( ColumnarValue :: Array ( result_array) )
241252 }
242253
243254 fn children ( & self ) -> Vec < & Arc < dyn PhysicalExpr > > {
0 commit comments