Commit b161169

synchronize per-task UDF evaluation
1 parent 23df354 commit b161169

1 file changed

Lines changed: 47 additions & 22 deletions

common/src/main/scala/org/apache/comet/udf/codegen/CometScalaUDFCodegen.scala

@@ -55,6 +55,24 @@ import org.apache.comet.udf.CometUDF
  * - Per-partition: [[activeKernel]] for kernel mutable state (`Rand`'s `XORShiftRandom`,
  * `MonotonicallyIncreasingID`'s counter) that advances across batches in one partition and
  * resets across partitions.
+ *
+ * Concurrency: [[evaluate]] takes `this.synchronized` for the cache lookup + kernel allocation +
+ * `process` call. A single Spark task can have multiple concurrent JNI callers into this
+ * dispatcher because DataFusion operators like `HashJoinExec` pipeline build/probe via
+ * `OnceAsync` (`tokio::spawn`) regardless of `target_partitions=1`, so different Tokio worker
+ * threads poll sub-streams within one task and each calls back into Java. The generated kernel
+ * keeps per-batch state (`col0`, `rowIdx`) in instance fields, so concurrent `process` calls on a
+ * shared kernel would race; the lock serializes them.
+ *
+ * Performance: Spark's `BufferedRowIterator` is single-threaded per task by construction, so
+ * Spark has no intra-task UDF parallelism to begin with. The lock gives up the intra-task
+ * pipelining DataFusion would otherwise allow, but probe-side work (the bulk of UDF eval) is
+ * serial in either model. Per-task throughput matches Spark's; cross-task parallelism is
+ * unchanged.
+ *
+ * TODO(udf-codegen-pool): if intra-task UDF parallelism shows up as a bottleneck (e.g. large
+ * build sides with heavy UDFs), replace the single `activeKernel` with a per-key pool of
+ * instances and externalize per-partition stateful expression counters into the dispatcher.
  */
 class CometScalaUDFCodegen extends CometUDF {

@@ -65,13 +83,15 @@ class CometScalaUDFCodegen extends CometUDF {
    * across concurrent tasks running the same query. Compile work itself stays deduped JVM-wide
    * via `CodeGenerator.compile`'s internal source cache, so identical Janino source shares
    * bytecode across tasks; only the `boundExpr` Java object is per-task.
+   *
+   * Guarded by `this.synchronized` in [[evaluate]]; see the class-level Concurrency note.
    */
   private val kernelCache
       : mutable.Map[CometScalaUDFCodegen.CacheKey, CometScalaUDFCodegen.CacheEntry] =
     mutable.HashMap.empty

-  // Plain `var`s: this instance is per-task, Spark drives one partition per task, so
-  // [[ensureKernel]] is never concurrent.
+  // Active kernel state. Guarded by `this.synchronized` in [[evaluate]]; see the class-level
+  // Concurrency note.
   private var activeKernel: CometBatchKernel = _
   private var activeKey: CometScalaUDFCodegen.CacheKey = _
   private var activePartition: Int = -1
@@ -105,26 +125,31 @@ class CometScalaUDFCodegen extends CometUDF {
     val specsSeq = specs.toIndexedSeq

     val key = CometScalaUDFCodegen.CacheKey(ByteBuffer.wrap(bytes), specsSeq)
-    val entry = lookupOrCompile(key, bytes, specsSeq)
-
-    val partitionId = CometScalaUDFCodegen.currentPartitionIndex()
-    val kernel = ensureKernel(entry.compiled, key, partitionId)
-
-    val out = CometBatchKernelCodegen.allocateOutput(
-      entry.outputField,
-      n,
-      estimatedOutputBytes(entry.outputType, dataCols))
-    try {
-      kernel.process(dataCols, out, n)
-      out.setValueCount(n)
-      out
-    } catch {
-      case t: Throwable =>
-        try out.close()
-        catch {
-          case _: Throwable => ()
-        }
-        throw t
+
+    // Cache lookup, kernel allocation, and `process` run under one lock: the generated kernel
+    // keeps per-batch state (`col0`, `rowIdx`) in instance fields, so concurrent callers would
+    // race. See the class-level Concurrency note.
+    this.synchronized {
+      val entry = lookupOrCompile(key, bytes, specsSeq)
+      val partitionId = CometScalaUDFCodegen.currentPartitionIndex()
+      val kernel = ensureKernel(entry.compiled, key, partitionId)
+
+      val out = CometBatchKernelCodegen.allocateOutput(
+        entry.outputField,
+        n,
+        estimatedOutputBytes(entry.outputType, dataCols))
+      try {
+        kernel.process(dataCols, out, n)
+        out.setValueCount(n)
+        out
+      } catch {
+        case t: Throwable =>
+          try out.close()
+          catch {
+            case _: Throwable => ()
+          }
+          throw t
+      }
     }
   }
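
The locking pattern in this change can be illustrated in isolation. The following is a minimal sketch, not the Comet implementation: `ToyKernel`, `ToyDispatcher`, and `SynchronizedEvaluateDemo` are hypothetical names, the "kernel" just doubles integers, and plain JVM threads stand in for the Tokio worker threads that call back over JNI. It only shows why a kernel that keeps per-batch state in instance fields needs its `process` call serialized when one dispatcher instance is shared by several callers.

```scala
import java.util.concurrent.{Callable, Executors}

// Hypothetical stand-in for a generated kernel: per-batch state (`col0`, `rowIdx`)
// lives in instance fields, so overlapping `process` calls on one instance would
// overwrite each other's input column and row cursor.
final class ToyKernel {
  private var col0: Array[Int] = _
  private var rowIdx: Int = 0

  def process(input: Array[Int], out: Array[Int], n: Int): Unit = {
    col0 = input
    rowIdx = 0
    while (rowIdx < n) {
      out(rowIdx) = col0(rowIdx) * 2
      rowIdx += 1
    }
  }
}

// Hypothetical dispatcher: one shared kernel, with kernel allocation and the
// `process` call under a single lock, mirroring the shape of the change above.
final class ToyDispatcher {
  private var activeKernel: ToyKernel = _ // guarded by this.synchronized

  def evaluate(input: Array[Int]): Array[Int] = this.synchronized {
    if (activeKernel == null) activeKernel = new ToyKernel
    val out = new Array[Int](input.length)
    activeKernel.process(input, out, input.length)
    out
  }
}

object SynchronizedEvaluateDemo {
  // Runs `threads` concurrent callers against one dispatcher and reports whether
  // every batch came back with the values expected for its own input.
  def run(threads: Int, batchesPerThread: Int): Boolean = {
    val dispatcher = new ToyDispatcher
    val pool = Executors.newFixedThreadPool(threads)
    try {
      val futures = (0 until threads).map { t =>
        pool.submit(new Callable[Boolean] {
          override def call(): Boolean =
            (0 until batchesPerThread).forall { b =>
              val input = Array.tabulate(256)(i => t * 1000003 + b * 257 + i)
              val out = dispatcher.evaluate(input)
              out.indices.forall(i => out(i) == input(i) * 2)
            }
        })
      }
      futures.forall(_.get())
    } finally {
      pool.shutdown()
    }
  }
}
```

Calling `SynchronizedEvaluateDemo.run(4, 200)` exercises four concurrent callers against one dispatcher. Removing `this.synchronized` from `ToyDispatcher.evaluate` lets callers interleave inside `process`, which is the kind of race the commit's Concurrency note describes; whether a given run actually observes corruption depends on scheduling.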
