@@ -55,6 +55,24 @@ import org.apache.comet.udf.CometUDF
5555 * - Per-partition: [[activeKernel]] for kernel mutable state (`Rand`'s `XORShiftRandom`,
5656 * `MonotonicallyIncreasingID`'s counter) that advances across batches in one partition and
5757 * resets across partitions.
58+ *
59+ * Concurrency: [[evaluate]] takes `this.synchronized` for the cache lookup + kernel allocation +
60+ * `process` call. A single Spark task can have multiple concurrent JNI callers into this
61+ * dispatcher because DataFusion operators like `HashJoinExec` pipeline build/probe via
62+ * `OnceAsync` (`tokio::spawn`) regardless of `target_partitions=1`, so different Tokio worker
63+ * threads poll sub-streams within one task and each calls back into Java. The generated kernel
64+ * keeps per-batch state (`col0`, `rowIdx`) in instance fields, so concurrent `process` calls on a
65+ * shared kernel would race; the lock serializes them.
66+ *
67+ * Performance: Spark's `BufferedRowIterator` is single-threaded per task by construction, so
68+ * Spark has no intra-task UDF parallelism to begin with. The lock gives up the intra-task
69+ * pipelining DataFusion would otherwise allow, but probe-side work (the bulk of UDF eval) is
70+ * serial in either model. Per-task throughput matches Spark's; cross-task parallelism is
71+ * unchanged.
72+ *
73+ * TODO(udf-codegen-pool): if intra-task UDF parallelism shows up as a bottleneck (e.g. large
74+ * build sides with heavy UDFs), replace the single `activeKernel` with a per-key pool of
75+ * instances and externalize per-partition stateful expression counters into the dispatcher.
5876 */
5977class CometScalaUDFCodegen extends CometUDF {
6078
@@ -65,13 +83,15 @@ class CometScalaUDFCodegen extends CometUDF {
6583 * across concurrent tasks running the same query. Compile work itself stays deduped JVM-wide
6684 * via `CodeGenerator.compile`'s internal source cache, so identical Janino source shares
6785 * bytecode across tasks; only the `boundExpr` Java object is per-task.
86+ *
87+ * Guarded by `this.synchronized` in [[evaluate]]; see the class-level Concurrency note.
6888 */
6989 private val kernelCache
7090 : mutable.Map [CometScalaUDFCodegen .CacheKey , CometScalaUDFCodegen .CacheEntry ] =
7191 mutable.HashMap .empty
7292
73- // Plain `var`s: this instance is per-task, Spark drives one partition per task, so
74- // [[ensureKernel]] is never concurrent .
93+ // Active kernel state. Guarded by `this.synchronized` in [[evaluate]]; see the class-level
94+ // Concurrency note.
7595 private var activeKernel : CometBatchKernel = _
7696 private var activeKey : CometScalaUDFCodegen .CacheKey = _
7797 private var activePartition : Int = - 1
@@ -105,26 +125,31 @@ class CometScalaUDFCodegen extends CometUDF {
105125 val specsSeq = specs.toIndexedSeq
106126
107127 val key = CometScalaUDFCodegen .CacheKey (ByteBuffer .wrap(bytes), specsSeq)
108- val entry = lookupOrCompile(key, bytes, specsSeq)
109-
110- val partitionId = CometScalaUDFCodegen .currentPartitionIndex()
111- val kernel = ensureKernel(entry.compiled, key, partitionId)
112-
113- val out = CometBatchKernelCodegen .allocateOutput(
114- entry.outputField,
115- n,
116- estimatedOutputBytes(entry.outputType, dataCols))
117- try {
118- kernel.process(dataCols, out, n)
119- out.setValueCount(n)
120- out
121- } catch {
122- case t : Throwable =>
123- try out.close()
124- catch {
125- case _ : Throwable => ()
126- }
127- throw t
128+
129+ // Cache lookup, kernel allocation, and `process` run under one lock: the generated kernel
130+ // keeps per-batch state (`col0`, `rowIdx`) in instance fields, so concurrent callers would
131+ // race. See the class-level Concurrency note.
132+ this .synchronized {
133+ val entry = lookupOrCompile(key, bytes, specsSeq)
134+ val partitionId = CometScalaUDFCodegen .currentPartitionIndex()
135+ val kernel = ensureKernel(entry.compiled, key, partitionId)
136+
137+ val out = CometBatchKernelCodegen .allocateOutput(
138+ entry.outputField,
139+ n,
140+ estimatedOutputBytes(entry.outputType, dataCols))
141+ try {
142+ kernel.process(dataCols, out, n)
143+ out.setValueCount(n)
144+ out
145+ } catch {
146+ case t : Throwable =>
147+ try out.close()
148+ catch {
149+ case _ : Throwable => ()
150+ }
151+ throw t
152+ }
128153 }
129154 }
130155
0 commit comments