forked from apache/datafusion-comet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCometUdfBridge.java
More file actions
127 lines (118 loc) · 4.85 KB
/
CometUdfBridge.java
File metadata and controls
127 lines (118 loc) · 4.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.comet.udf;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.Data;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.ValueVector;
/**
 * JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
 * pattern used by CometScalarSubquery so the native side can dispatch via
 * call_static_method_unchecked.
 */
public class CometUdfBridge {
  // Process-wide cache of UDF instances keyed by class name. CometUDF
  // implementations are required to be stateless (see CometUDF), so a
  // single shared instance per class is safe across native worker threads.
  // NOTE(review): the key is the class name only, so whichever context
  // classloader first resolves a given name determines the cached instance.
  private static final ConcurrentHashMap<String, CometUDF> INSTANCES = new ConcurrentHashMap<>();

  /**
   * Called from native via JNI.
   *
   * <p>Imports each input through the Arrow C Data Interface, invokes the (cached) UDF, validates
   * that the result is a {@link FieldVector} whose row count equals the longest input, and exports
   * it into the caller-provided output structs. The Java-side input and result vectors are closed on
   * every path, whether or not an exception is thrown.
   *
   * <p>Because the expected row count is the maximum input length, a UDF invoked with zero inputs
   * must return an empty vector.
   *
   * @param udfClassName fully-qualified class name implementing CometUDF
   * @param inputArrayPtrs addresses of pre-allocated FFI_ArrowArray structs (one per input)
   * @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input)
   * @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result
   * @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result
   * @throws IllegalArgumentException if the two input pointer arrays have different lengths
   * @throws RuntimeException if the UDF class cannot be instantiated or does not implement
   *     CometUDF, or if the UDF returns null, a non-FieldVector, or the wrong number of rows
   */
  public static void evaluate(
      String udfClassName,
      long[] inputArrayPtrs,
      long[] inputSchemaPtrs,
      long outArrayPtr,
      long outSchemaPtr) {
    // Each input needs both an array and a schema struct; reject a malformed call up front
    // instead of failing with an index error partway through importing.
    if (inputArrayPtrs.length != inputSchemaPtrs.length) {
      throw new IllegalArgumentException(
          "Mismatched CometUDF input pointers: "
              + inputArrayPtrs.length
              + " array pointers vs "
              + inputSchemaPtrs.length
              + " schema pointers");
    }
    CometUDF udf = INSTANCES.computeIfAbsent(udfClassName, CometUdfBridge::instantiate);
    BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator();
    ValueVector[] inputs = new ValueVector[inputArrayPtrs.length];
    ValueVector result = null;
    try {
      for (int i = 0; i < inputArrayPtrs.length; i++) {
        ArrowArray inArr = ArrowArray.wrap(inputArrayPtrs[i]);
        ArrowSchema inSch = ArrowSchema.wrap(inputSchemaPtrs[i]);
        inputs[i] = Data.importVector(allocator, inArr, inSch, null);
      }
      result = udf.evaluate(inputs);
      FieldVector output = checkResult(result, inputs);
      ArrowArray outArr = ArrowArray.wrap(outArrayPtr);
      ArrowSchema outSch = ArrowSchema.wrap(outSchemaPtr);
      Data.exportVector(allocator, output, null, outArr, outSch);
    } finally {
      for (ValueVector v : inputs) {
        closeQuietly(v);
      }
      // A pass-through UDF may hand back one of its inputs, which was already closed above.
      if (!isOneOf(result, inputs)) {
        closeQuietly(result);
      }
    }
  }

  /** Loads and instantiates the named CometUDF class via its public-or-declared no-arg ctor. */
  private static CometUDF instantiate(String name) {
    // Resolve via the executor's context classloader so user-supplied UDF jars
    // (added via spark.jars / --jars) are visible.
    ClassLoader cl = Thread.currentThread().getContextClassLoader();
    if (cl == null) {
      cl = CometUdfBridge.class.getClassLoader();
    }
    try {
      Class<?> clazz = Class.forName(name, true, cl);
      // Report a clear error instead of an unexplained ClassCastException from a blind cast.
      if (!CometUDF.class.isAssignableFrom(clazz)) {
        throw new RuntimeException(
            "Failed to instantiate CometUDF: "
                + name
                + " does not implement "
                + CometUDF.class.getName());
      }
      return clazz.asSubclass(CometUDF.class).getDeclaredConstructor().newInstance();
    } catch (ReflectiveOperationException e) {
      throw new RuntimeException("Failed to instantiate CometUDF: " + name, e);
    }
  }

  /**
   * Validates a UDF result: it must be a non-null FieldVector whose row count equals the longest
   * input. Scalar (length-1) inputs are allowed to be shorter, but a vector input bounds the output.
   */
  private static FieldVector checkResult(ValueVector result, ValueVector[] inputs) {
    if (result == null) {
      throw new RuntimeException("CometUDF.evaluate() must return a FieldVector, got: null");
    }
    if (!(result instanceof FieldVector)) {
      throw new RuntimeException(
          "CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName());
    }
    int expectedLen = 0;
    for (ValueVector v : inputs) {
      if (v != null) {
        expectedLen = Math.max(expectedLen, v.getValueCount());
      }
    }
    if (result.getValueCount() != expectedLen) {
      throw new RuntimeException(
          "CometUDF.evaluate() returned "
              + result.getValueCount()
              + " rows, expected "
              + expectedLen);
    }
    return (FieldVector) result;
  }

  /** Closes {@code vector} if non-null, ignoring close failures. */
  private static void closeQuietly(ValueVector vector) {
    if (vector == null) {
      return;
    }
    try {
      vector.close();
    } catch (RuntimeException ignored) {
      // do not mask the original throwable
    }
  }

  /** Returns true if {@code candidate} is non-null and identical (==) to an element of vectors. */
  private static boolean isOneOf(ValueVector candidate, ValueVector[] vectors) {
    if (candidate == null) {
      return false;
    }
    for (ValueVector v : vectors) {
      if (v == candidate) {
        return true;
      }
    }
    return false;
  }
}