Skip to content

Commit f3387fe

Browse files
authored
feat(experimental): ScalaUDF and Java UDF support via Janino codegen (#4267)
1 parent ce1b9d4 commit f3387fe

33 files changed

Lines changed: 5848 additions & 82 deletions

.github/workflows/pr_build_linux.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ jobs:
302302
org.apache.comet.CometFuzzAggregateSuite
303303
org.apache.comet.CometFuzzIcebergSuite
304304
org.apache.comet.CometFuzzMathSuite
305+
org.apache.comet.CometCodegenFuzzSuite
305306
org.apache.comet.DataGeneratorSuite
306307
- name: "shuffle"
307308
value: |
@@ -380,6 +381,9 @@ jobs:
380381
org.apache.comet.expressions.conditional.CometIfSuite
381382
org.apache.comet.expressions.conditional.CometCoalesceSuite
382383
org.apache.comet.expressions.conditional.CometCaseWhenSuite
384+
org.apache.comet.CometCodegenSuite
385+
org.apache.comet.CometCodegenSourceSuite
386+
org.apache.comet.CometCodegenHOFSuite
383387
- name: "sql"
384388
value: |
385389
org.apache.spark.sql.CometToPrettyStringSuite

.github/workflows/pr_build_macos.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ jobs:
155155
org.apache.comet.CometFuzzAggregateSuite
156156
org.apache.comet.CometFuzzIcebergSuite
157157
org.apache.comet.CometFuzzMathSuite
158+
org.apache.comet.CometCodegenFuzzSuite
158159
org.apache.comet.DataGeneratorSuite
159160
- name: "shuffle"
160161
value: |
@@ -232,6 +233,9 @@ jobs:
232233
org.apache.comet.expressions.conditional.CometIfSuite
233234
org.apache.comet.expressions.conditional.CometCoalesceSuite
234235
org.apache.comet.expressions.conditional.CometCaseWhenSuite
236+
org.apache.comet.CometCodegenSuite
237+
org.apache.comet.CometCodegenSourceSuite
238+
org.apache.comet.CometCodegenHOFSuite
235239
- name: "sql"
236240
value: |
237241
org.apache.spark.sql.CometToPrettyStringSuite

docs/source/user-guide/latest/iceberg.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,24 @@ The following scenarios will fall back to Spark's native Iceberg reader:
146146
- Dynamic Partition Pruning under Adaptive Query Execution (non-AQE DPP is supported);
147147
see [#3510](https://github.com/apache/datafusion-comet/issues/3510)
148148

149+
### Iceberg UDFs
150+
151+
Iceberg ships several `ScalaUDF`s that surface in user queries and maintenance actions:
152+
153+
- `IcebergSpark.registerBucketUDF` and `registerTruncateUDF` register `bucket(N, col)` and
154+
`truncate(W, col)` for use in `SELECT` / `JOIN` / `WHERE` predicates that align with hidden
155+
partitioning.
156+
- `RewriteDataFiles` with `sort-strategy=zorder` builds a tree of per-type ordered-bytes UDFs
157+
(`INT_ORDERED_BYTES`, `LONG_ORDERED_BYTES`, ..., `INTERLEAVE_BYTES`) over the sort key columns
158+
during compaction.
159+
160+
By default these UDFs cause the enclosing operator to fall back to Spark, which forces a
161+
columnar-to-row roundtrip and demotes the surrounding shuffle from `CometExchange` to
162+
`CometColumnarExchange`. Enabling the experimental
163+
[Scala UDF and Java UDF Support](scala_java_udfs.md) feature
164+
(`spark.comet.exec.scalaUDF.codegen.enabled=true`) routes these UDFs through native execution so
165+
the project, exchange, and sort operators around them stay on the Comet path end-to-end.
166+
149167
### Task input metrics
150168

151169
The native Iceberg reader populates Spark's task-level `inputMetrics.bytesRead` (visible in the Spark UI Stages tab) using the `bytes_read` counter from iceberg-rust's `ScanMetrics`. This counter includes bytes read from both data files and delete files.

docs/source/user-guide/latest/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ to read more.
4343
Supported Data Types <datatypes>
4444
Supported Operators <operators>
4545
Supported Expressions <expressions>
46+
ScalaUDF and Java UDF Support <scala_java_udfs>
4647
Configuration Settings <configs>
4748
Compatibility Guide <compatibility/index>
4849
Understanding Comet Plans <understanding-comet-plans>
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
<!---
2+
Licensed to the Apache Software Foundation (ASF) under one
3+
or more contributor license agreements. See the NOTICE file
4+
distributed with this work for additional information
5+
regarding copyright ownership. The ASF licenses this file
6+
to you under the Apache License, Version 2.0 (the
7+
"License"); you may not use this file except in compliance
8+
with the License. You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing,
13+
software distributed under the License is distributed on an
14+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
KIND, either express or implied. See the License for the
16+
specific language governing permissions and limitations
17+
under the License.
18+
-->
19+
20+
# Scala UDF and Java UDF Support
21+
22+
Comet executes Spark's Scala and Java [scalar user-defined functions (UDFs)](https://spark.apache.org/docs/latest/sql-ref-functions-udf-scalar.html) on the native Comet path. The presence of a UDF does not force the enclosing operator off the native path; surrounding native operators stay native.
23+
24+
This page covers Spark's `ScalaUDF` (Scala `udf(...)`, `spark.udf.register(...)` over Scala or Java functional interfaces, and SQL `CREATE FUNCTION ... AS 'com.example.MyUDF'`). Other UDF kinds (Python / Pandas, Hive, aggregate) are out of scope and continue to fall back to Spark.
25+
26+
This feature is experimental and disabled by default.
27+
28+
## Configuration
29+
30+
| Key | Default | Description |
31+
| ------------------------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------ |
32+
| `spark.comet.exec.scalaUDF.codegen.enabled` | `false` | When `true`, eligible `ScalaUDF`s run on the Comet path. When `false`, the enclosing operator falls back to Spark. |
33+
34+
## Supported
35+
36+
- User functions registered via `udf(...)`, `spark.udf.register(...)` (Scala or Java functional interfaces), or SQL `CREATE FUNCTION ... AS 'com.example.MyUDF'`.
37+
- Scalar input/output types: `Boolean`, `Byte`, `Short`, `Int`, `Long`, `Float`, `Double`, `Decimal`, `String`, `Binary`, `Date`, `Timestamp`, `TimestampNTZ`.
38+
- Complex input/output types with arbitrary nesting: `ArrayType`, `StructType`, `MapType`.
39+
- Composition with other Catalyst expressions inside the argument tree (e.g. `myUdf(upper(s))` runs as one native unit).
40+
- Higher-order functions (`transform`, `filter`, `exists`, `aggregate`, `zip_with`, `map_filter`, `map_zip_with`, etc.) inside the argument tree.
41+
42+
## Not supported
43+
44+
- Aggregate UDFs (`ScalaAggregator`, `TypedImperativeAggregate`, the legacy `UserDefinedAggregateFunction`).
45+
- Table UDFs and generators.
46+
- Python `@udf` and Pandas `@pandas_udf`.
47+
- Hive `GenericUDF` and `SimpleUDF`.
48+
- `CalendarIntervalType`, `NullType`, and `UserDefinedType` arguments and return types. UDT-typed columns fall back to Spark; for native execution, store and read the underlying representation directly (e.g. write MLlib `Vector` outputs as `Struct<type: Byte, size: Int, indices: Array<Int>, values: Array<Double>>` rather than `VectorUDT`).
49+
- Trees whose total nested-field count (output plus all input columns the UDF tree references) exceeds `spark.sql.codegen.maxFields` (default 100). Comet refuses these at plan time and the operator falls back to Spark.
50+
51+
When a UDF is rejected, the reason surfaces through Comet's standard fallback diagnostics; the query still runs on Spark.
52+
53+
## Behavior
54+
55+
- Non-deterministic expressions referenced from the argument tree (`rand`, `uuid`, `monotonically_increasing_id`) produce per-partition sequences consistent with Spark.
56+
- `TaskContext.get()` inside the user function returns the driving Spark task's context.
57+
- The user function must be closure-serializable; the same function that works with Spark's executor execution works here.
58+
59+
## Known limitations
60+
61+
- Each query containing a ScalaUDF pays a one-time codegen cost on its first batch and reuses the compiled kernel for subsequent batches, matching Spark's whole-stage codegen behavior. Bytecode is deduped JVM-wide via the same `CodeGenerator` cache, so structurally identical queries across a session share the compiled class.

native/core/src/execution/planner.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,9 @@ impl PhysicalPlanner {
211211
self
212212
}
213213

214-
/// Attach the Spark `TaskContext` global reference captured at `createPlan` time. Cloned
215-
/// into every `JvmScalarUdfExpr` the planner builds so the JNI bridge can install it as
216-
/// the thread-local on the Tokio worker driving the UDF.
214+
/// Attach a propagated Spark `TaskContext` global reference. Called by the JNI `executePlan`
215+
/// entry with whatever was captured at `createPlan` time. The planner clones this `Option`
216+
/// into every `JvmScalarUdfExpr` it builds.
217217
pub fn with_task_context(
218218
mut self,
219219
task_context: Option<Arc<Global<JObject<'static>>>>,
@@ -745,6 +745,13 @@ impl PhysicalPlanner {
745745
to_arrow_datatype(udf.return_type.as_ref().ok_or_else(|| {
746746
GeneralError("JvmScalarUdf missing return_type".to_string())
747747
})?);
748+
// Invariant: task_context is propagated for every JvmScalarUdfExpr built during
749+
// normal execution. The TEST_EXEC_CONTEXT_ID path is the only context in which
750+
// task_context may legitimately be None (unit tests, direct native driver runs).
751+
debug_assert!(
752+
self.task_context.is_some() || self.exec_context_id == TEST_EXEC_CONTEXT_ID,
753+
"task_context must be set for non-test execution"
754+
);
748755
Ok(Arc::new(JvmScalarUdfExpr::new(
749756
udf.class_name.clone(),
750757
args,

native/spark-expr/src/jvm_udf/mod.rs

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ impl JvmScalarUdfExpr {
5959
return_nullable: bool,
6060
task_context: Option<Arc<Global<JObject<'static>>>>,
6161
) -> Self {
62+
debug_assert!(
63+
!class_name.is_empty(),
64+
"JvmScalarUdfExpr requires a non-empty class name"
65+
);
6266
Self {
6367
class_name,
6468
args,
@@ -120,10 +124,10 @@ impl PhysicalExpr for JvmScalarUdfExpr {
120124
}
121125

122126
fn evaluate(&self, batch: &RecordBatch) -> DFResult<ColumnarValue> {
123-
// Step 1: evaluate child expressions to get Arrow arrays. Scalar children
124-
// (e.g. literal patterns) are sent as length-1 vectors rather than expanded
125-
// to batch-row count, so the JVM bridge does not pay an O(rows) copy for
126-
// values that never vary across the batch.
127+
// Scalar children (e.g. literal patterns) are sent as length-1 vectors rather than
128+
// expanded to batch-row count, so the JVM bridge does not pay an O(rows) copy for
129+
// values that never vary across the batch. The JVM side gets `numRows` directly via
130+
// the bridge so it doesn't need the scalar to carry batch length.
127131
let arrays: Vec<ArrayRef> = self
128132
.args
129133
.iter()
@@ -133,7 +137,6 @@ impl PhysicalExpr for JvmScalarUdfExpr {
133137
})
134138
.collect::<DFResult<_>>()?;
135139

136-
// Step 2: allocate FFI structs on the Rust heap and collect their raw pointers.
137140
// The JVM writes into the out_array/out_schema slots and reads from the in_ slots.
138141
let in_ffi_arrays: Vec<Box<FFI_ArrowArray>> = arrays
139142
.iter()
@@ -157,7 +160,13 @@ impl PhysicalExpr for JvmScalarUdfExpr {
157160
.map(|b| b.as_ref() as *const FFI_ArrowSchema as i64)
158161
.collect();
159162

160-
// Allocate output FFI slots.
163+
debug_assert!(!self.class_name.is_empty(), "class_name must not be empty");
164+
debug_assert_eq!(
165+
in_arr_ptrs.len(),
166+
in_sch_ptrs.len(),
167+
"input array and schema pointer counts must match"
168+
);
169+
161170
let mut out_array = Box::new(FFI_ArrowArray::empty());
162171
let mut out_schema = Box::new(FFI_ArrowSchema::empty());
163172
let out_arr_ptr = out_array.as_mut() as *mut FFI_ArrowArray as i64;
@@ -166,7 +175,6 @@ impl PhysicalExpr for JvmScalarUdfExpr {
166175
let class_name = self.class_name.clone();
167176
let n_args = arrays.len();
168177

169-
// Step 3: attach a JNI env for this thread and call the static bridge method.
170178
JVMClasses::with_env(|env| {
171179
let bridge = JVMClasses::get().comet_udf_bridge.as_ref().ok_or_else(|| {
172180
CometError::from(ExecutionError::GeneralError(
@@ -176,12 +184,10 @@ impl PhysicalExpr for JvmScalarUdfExpr {
176184
))
177185
})?;
178186

179-
// Build the JVM String for the class name.
180187
let jclass_name = env
181188
.new_string(&class_name)
182189
.map_err(|e| CometError::JNI { source: e })?;
183190

184-
// Build the long[] arrays for input pointers.
185191
let in_arr_java = env
186192
.new_long_array(n_args)
187193
.map_err(|e| CometError::JNI { source: e })?;
@@ -196,9 +202,10 @@ impl PhysicalExpr for JvmScalarUdfExpr {
196202
.set_region(env, 0, &in_sch_ptrs)
197203
.map_err(|e| CometError::JNI { source: e })?;
198204

199-
// Pass a null jobject when no TaskContext was propagated so the bridge's null-guard
200-
// leaves the worker thread's current TaskContext.get() in place. The borrow must
201-
// outlive `call_static_method_unchecked`.
205+
// Resolve the TaskContext reference once before building the arg array so the
206+
// borrow lives until `call_static_method_unchecked` returns. When no TaskContext
207+
// was propagated, pass a null object so the bridge's null-guard leaves the thread-
208+
// local alone.
202209
let null_task_context = JObject::null();
203210
let task_context_ref: &JObject = match &self.task_context {
204211
Some(gref) => gref.as_obj(),
@@ -229,7 +236,6 @@ impl PhysicalExpr for JvmScalarUdfExpr {
229236
Ok(())
230237
})?;
231238

232-
// Step 4: import the result from the FFI slots filled by the JVM.
233239
// SAFETY: `*out_array` moves the FFI_ArrowArray out of the Box (the heap
234240
// allocation is freed by the move), and `from_ffi` wraps it in an Arc that
235241
// keeps the JVM-installed release callback alive until the resulting
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.codegen;
21+
22+
import org.apache.arrow.vector.FieldVector;
23+
import org.apache.arrow.vector.ValueVector;
24+
25+
/**
26+
* Abstract base extended by the Janino-compiled batch kernel emitted by {@code
27+
* CometBatchKernelCodegen}. The generated subclass extends {@code CometInternalRow} (so Spark's
28+
* {@code BoundReference.genCode} can call {@code this.getUTF8String(ord)} directly) and carries
29+
* typed input fields baked at codegen time, one per input column. Expression evaluation plus Arrow
30+
* read/write fuse into one method per expression tree.
31+
*/
32+
public abstract class CometBatchKernel extends CometInternalRow {
33+
34+
protected final Object[] references;
35+
36+
protected CometBatchKernel(Object[] references) {
37+
this.references = references;
38+
}
39+
40+
/**
41+
* Run partition-dependent initialization. The generated subclass overrides this to execute
42+
* statements collected via {@code CodegenContext.addPartitionInitializationStatement}, e.g.
43+
* reseeding {@code Rand}'s {@code XORShiftRandom} from {@code seed + partitionIndex}.
44+
* Deterministic expressions leave this as a no-op.
45+
*
46+
* <p>The caller invokes this before the first {@code process} call of each partition. The
47+
* generated subclass is not thread-safe across concurrent {@code process} calls. The dispatcher
48+
* allocates one per partition and serializes calls.
49+
*/
50+
public void init(int partitionIndex) {}
51+
52+
/**
53+
* Process one batch.
54+
*
55+
* @param inputs Arrow input vectors. Length and concrete classes match the schema the kernel was
56+
* compiled against.
57+
* @param output Arrow output vector. Caller allocates to the expression's {@code dataType}.
58+
* @param numRows number of rows in this batch
59+
*/
60+
public abstract void process(ValueVector[] inputs, FieldVector output, int numRows);
61+
}

spark/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,17 @@ object CometConf extends ShimCometConf {
362362
.booleanConf
363363
.createWithDefault(false)
364364

365+
val COMET_SCALA_UDF_CODEGEN_ENABLED: ConfigEntry[Boolean] =
366+
conf("spark.comet.exec.scalaUDF.codegen.enabled")
367+
.category(CATEGORY_EXEC)
368+
.doc("Experimental. Whether to route Spark `ScalaUDF` expressions through Comet's " +
369+
"Arrow-direct codegen dispatcher. When enabled, a supported ScalaUDF is compiled into " +
370+
"a per-batch kernel that reads and writes Arrow vectors directly from native " +
371+
"execution. When disabled, plans containing a ScalaUDF fall back to Spark for the " +
372+
"enclosing operator.")
373+
.booleanConf
374+
.createWithDefault(false)
375+
365376
val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] =
366377
conf("spark.comet.native.shuffle.partitioning.hash.enabled")
367378
.category(CATEGORY_SHUFFLE)

0 commit comments

Comments
 (0)