Skip to content

Commit 8119b1e

Browse files
authored
feat: add JVM UDF framework for native execution (#4232)
1 parent f45229a commit 8119b1e

12 files changed

Lines changed: 567 additions & 5 deletions

File tree

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.udf;
21+
22+
import java.util.concurrent.ConcurrentHashMap;
23+
24+
import org.apache.arrow.c.ArrowArray;
25+
import org.apache.arrow.c.ArrowSchema;
26+
import org.apache.arrow.c.Data;
27+
import org.apache.arrow.memory.BufferAllocator;
28+
import org.apache.arrow.vector.FieldVector;
29+
import org.apache.arrow.vector.ValueVector;
30+
31+
/**
32+
* JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
33+
* pattern used by CometScalarSubquery so the native side can dispatch via
34+
* call_static_method_unchecked.
35+
*/
36+
public class CometUdfBridge {
37+
38+
// Process-wide cache of UDF instances keyed by class name. CometUDF
39+
// implementations are required to be stateless (see CometUDF), so a
40+
// single shared instance per class is safe across native worker threads.
41+
private static final ConcurrentHashMap<String, CometUDF> INSTANCES = new ConcurrentHashMap<>();
42+
43+
/**
44+
* Called from native via JNI.
45+
*
46+
* @param udfClassName fully-qualified class name implementing CometUDF
47+
* @param inputArrayPtrs addresses of pre-allocated FFI_ArrowArray structs (one per input)
48+
* @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input)
49+
* @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result
50+
* @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result
51+
*/
52+
public static void evaluate(
53+
String udfClassName,
54+
long[] inputArrayPtrs,
55+
long[] inputSchemaPtrs,
56+
long outArrayPtr,
57+
long outSchemaPtr) {
58+
CometUDF udf =
59+
INSTANCES.computeIfAbsent(
60+
udfClassName,
61+
name -> {
62+
try {
63+
// Resolve via the executor's context classloader so user-supplied UDF jars
64+
// (added via spark.jars / --jars) are visible.
65+
ClassLoader cl = Thread.currentThread().getContextClassLoader();
66+
if (cl == null) {
67+
cl = CometUdfBridge.class.getClassLoader();
68+
}
69+
return (CometUDF)
70+
Class.forName(name, true, cl).getDeclaredConstructor().newInstance();
71+
} catch (ReflectiveOperationException e) {
72+
throw new RuntimeException("Failed to instantiate CometUDF: " + name, e);
73+
}
74+
});
75+
76+
BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator();
77+
78+
ValueVector[] inputs = new ValueVector[inputArrayPtrs.length];
79+
ValueVector result = null;
80+
try {
81+
for (int i = 0; i < inputArrayPtrs.length; i++) {
82+
ArrowArray inArr = ArrowArray.wrap(inputArrayPtrs[i]);
83+
ArrowSchema inSch = ArrowSchema.wrap(inputSchemaPtrs[i]);
84+
inputs[i] = Data.importVector(allocator, inArr, inSch, null);
85+
}
86+
87+
result = udf.evaluate(inputs);
88+
if (!(result instanceof FieldVector)) {
89+
throw new RuntimeException(
90+
"CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName());
91+
}
92+
// Result length must match the longest input. Scalar (length-1) inputs
93+
// are allowed to be shorter, but a vector input bounds the output.
94+
int expectedLen = 0;
95+
for (ValueVector v : inputs) {
96+
expectedLen = Math.max(expectedLen, v.getValueCount());
97+
}
98+
if (result.getValueCount() != expectedLen) {
99+
throw new RuntimeException(
100+
"CometUDF.evaluate() returned "
101+
+ result.getValueCount()
102+
+ " rows, expected "
103+
+ expectedLen);
104+
}
105+
ArrowArray outArr = ArrowArray.wrap(outArrayPtr);
106+
ArrowSchema outSch = ArrowSchema.wrap(outSchemaPtr);
107+
Data.exportVector(allocator, (FieldVector) result, null, outArr, outSch);
108+
} finally {
109+
for (ValueVector v : inputs) {
110+
if (v != null) {
111+
try {
112+
v.close();
113+
} catch (RuntimeException ignored) {
114+
// do not mask the original throwable
115+
}
116+
}
117+
}
118+
if (result != null) {
119+
try {
120+
result.close();
121+
} catch (RuntimeException ignored) {
122+
// do not mask the original throwable
123+
}
124+
}
125+
}
126+
}
127+
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.udf
21+
22+
import java.util.UUID
23+
import java.util.concurrent.ConcurrentHashMap
24+
25+
import org.apache.spark.sql.catalyst.expressions.Expression
26+
27+
/**
 * Thread-safe lookup table that carries Spark lambda expressions from plan time to execution
 * time. The serde layer stores an expression during planning and receives a unique key; at
 * execution time the UDF resolves the expression again from that key (passed as a scalar
 * argument).
 */
object CometLambdaRegistry {

  private val entries = new ConcurrentHashMap[String, Expression]()

  /** Stores `expression` under a freshly generated random UUID string and returns that key. */
  def register(expression: Expression): String = {
    val id = UUID.randomUUID().toString
    entries.put(id, expression)
    id
  }

  /**
   * Returns the expression stored under `key`.
   *
   * Throws `IllegalStateException` when nothing is stored under the key.
   */
  def get(key: String): Expression =
    Option(entries.get(key)).getOrElse {
      throw new IllegalStateException(
        s"Lambda expression not found in registry for key: $key. " +
          "This indicates a lifecycle issue between plan creation and execution.")
    }

  /** Drops the entry for `key`; does nothing when the key is not present. */
  def remove(key: String): Unit = {
    entries.remove(key)
    ()
  }

  // Visible for testing
  def size(): Int = entries.size()
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.udf
21+
22+
import org.apache.arrow.vector.ValueVector
23+
24+
/**
25+
* Scalar UDF invoked from native execution via JNI. Receives Arrow vectors as input and returns
26+
* an Arrow vector.
27+
*
28+
* - Vector arguments arrive at the row count of the current batch.
29+
* - Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0.
30+
* - The returned vector's length must match the longest input.
31+
*
32+
* Implementations must have a public no-arg constructor and must be stateless: a single instance
33+
* per class is cached and shared across native worker threads for the lifetime of the JVM.
34+
*/
35+
trait CometUDF {
36+
/**
 * Computes the UDF result for one batch.
 *
 * @param inputs
 *   argument vectors, one per UDF argument. CometUdfBridge closes every input vector once this
 *   call has returned.
 * @return
 *   the result vector. CometUdfBridge rejects a result that is not a FieldVector, or whose value
 *   count differs from the largest value count among the inputs, and closes the returned vector
 *   after exporting it.
 */
def evaluate(inputs: Array[ValueVector]): ValueVector
37+
}

native/Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

native/core/src/execution/planner.rs

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,10 @@ use datafusion_comet_proto::{
122122
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
123123
};
124124
use datafusion_comet_spark_expr::{
125-
ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct,
126-
DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, ListExtract,
127-
NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance,
128-
WideDecimalBinaryExpr, WideDecimalOp,
125+
jvm_udf::JvmScalarUdfExpr, ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation,
126+
Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields,
127+
GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal,
128+
ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp,
129129
};
130130
use itertools::Itertools;
131131
use jni::objects::{Global, JObject};
@@ -720,6 +720,23 @@ impl PhysicalPlanner {
720720
expr.names.clone(),
721721
)))
722722
}
723+
ExprStruct::JvmScalarUdf(udf) => {
724+
let args = udf
725+
.args
726+
.iter()
727+
.map(|e| self.create_expr(e, Arc::clone(&input_schema)))
728+
.collect::<Result<Vec<_>, _>>()?;
729+
let return_type =
730+
to_arrow_datatype(udf.return_type.as_ref().ok_or_else(|| {
731+
GeneralError("JvmScalarUdf missing return_type".to_string())
732+
})?);
733+
Ok(Arc::new(JvmScalarUdfExpr::new(
734+
udf.class_name.clone(),
735+
args,
736+
return_type,
737+
udf.return_nullable,
738+
)))
739+
}
723740
expr => Err(GeneralError(format!("Not implemented: {expr:?}"))),
724741
}
725742
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use jni::{
19+
errors::Result as JniResult,
20+
objects::{JClass, JStaticMethodID},
21+
signature::{Primitive, ReturnType},
22+
strings::JNIString,
23+
Env,
24+
};
25+
26+
/// JNI handle for the JVM `org.apache.comet.udf.CometUdfBridge` class.
27+
/// Mirrors the static-method pattern in `comet_exec.rs` (`CometScalarSubquery`).
28+
#[allow(dead_code)] // class field is held to keep JStaticMethodID alive
29+
pub struct CometUdfBridge<'a> {
30+
pub class: JClass<'a>,
31+
pub method_evaluate: JStaticMethodID,
32+
pub method_evaluate_ret: ReturnType,
33+
}
34+
35+
impl<'a> CometUdfBridge<'a> {
36+
pub const JVM_CLASS: &'static str = "org/apache/comet/udf/CometUdfBridge";
37+
38+
pub fn new(env: &mut Env<'a>) -> JniResult<CometUdfBridge<'a>> {
39+
let class = env.find_class(JNIString::new(Self::JVM_CLASS))?;
40+
Ok(CometUdfBridge {
41+
method_evaluate: env.get_static_method_id(
42+
JNIString::new(Self::JVM_CLASS),
43+
jni::jni_str!("evaluate"),
44+
jni::jni_sig!("(Ljava/lang/String;[J[JJJ)V"),
45+
)?,
46+
method_evaluate_ret: ReturnType::Primitive(Primitive::Void),
47+
class,
48+
})
49+
}
50+
}

native/jni-bridge/src/lib.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,11 +192,13 @@ pub use comet_exec::*;
192192
mod batch_iterator;
193193
mod comet_metric_node;
194194
mod comet_task_memory_manager;
195+
mod comet_udf_bridge;
195196
mod shuffle_block_iterator;
196197

197198
use batch_iterator::CometBatchIterator;
198199
pub use comet_metric_node::*;
199200
pub use comet_task_memory_manager::*;
201+
use comet_udf_bridge::CometUdfBridge;
200202
use shuffle_block_iterator::CometShuffleBlockIterator;
201203

202204
/// The JVM classes that are used in the JNI calls.
@@ -228,6 +230,9 @@ pub struct JVMClasses<'a> {
228230
/// The CometTaskMemoryManager used for interacting with JVM side to
229231
/// acquire & release native memory.
230232
pub comet_task_memory_manager: CometTaskMemoryManager<'a>,
233+
/// The CometUdfBridge class used to dispatch JVM scalar UDFs.
234+
/// `None` if the class is not on the classpath.
235+
pub comet_udf_bridge: Option<CometUdfBridge<'a>>,
231236
}
232237

233238
unsafe impl Send for JVMClasses<'_> {}
@@ -298,6 +303,13 @@ impl JVMClasses<'_> {
298303
comet_batch_iterator: CometBatchIterator::new(env).unwrap(),
299304
comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(),
300305
comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(),
306+
comet_udf_bridge: {
307+
let bridge = CometUdfBridge::new(env).ok();
308+
if env.exception_check() {
309+
env.exception_clear();
310+
}
311+
bridge
312+
},
301313
}
302314
});
303315
}

native/proto/src/proto/expr.proto

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ message Expr {
9090
ToCsv to_csv = 67;
9191
HoursTransform hours_transform = 68;
9292
ArraysZip arrays_zip = 69;
93+
JvmScalarUdf jvm_scalar_udf = 70;
9394
}
9495

9596
// Optional QueryContext for error reporting (contains SQL text and position)
@@ -514,3 +515,18 @@ message ArraysZip {
514515
repeated Expr values = 1;
515516
repeated string names = 2;
516517
}
518+
519+
// Scalar UDF dispatched to the JVM via JNI. Native side exports input arrays
520+
// through Arrow C Data Interface, calls CometUdfBridge.evaluate, and imports
521+
// the result.
522+
message JvmScalarUdf {
523+
// Fully-qualified Java/Scala class name implementing
524+
// org.apache.comet.udf.CometUDF (must have a public no-arg constructor).
// The JVM bridge instantiates the class reflectively and caches one instance
// per class name.
525+
string class_name = 1;
526+
// Argument expressions, evaluated by the native side before invocation.
527+
repeated Expr args = 2;
528+
// Expected return type. Used to import the result FFI_ArrowArray.
// The native planner fails with "JvmScalarUdf missing return_type" when unset.
529+
DataType return_type = 3;
530+
// Whether the result column may contain nulls.
531+
bool return_nullable = 4;
532+
}

native/spark-expr/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ regex = { workspace = true }
3636
# preserve_order: needed for get_json_object to match Spark's JSON key ordering
3737
serde_json = { version = "1.0", features = ["preserve_order"] }
3838
datafusion-comet-common = { workspace = true }
39+
datafusion-comet-jni-bridge = { workspace = true }
40+
jni = "0.22.4"
3941
futures = { workspace = true }
4042
twox-hash = "2.1.2"
4143
rand = { workspace = true }

0 commit comments

Comments
 (0)