2020package org .apache .comet .udf .codegen
2121
2222import java .nio .ByteBuffer
23- import java .util .{ Collections , LinkedHashMap }
23+ import java .util .Collections
2424import java .util .concurrent .atomic .AtomicLong
2525
26+ import scala .collection .mutable
27+
2628import org .apache .arrow .vector ._
2729import org .apache .arrow .vector .complex .{ListVector , MapVector , StructVector }
2830import org .apache .arrow .vector .types .pojo .Field
@@ -37,32 +39,39 @@ import org.apache.comet.udf.CometUDF
3739
3840/**
3941 * Arrow-direct codegen dispatcher. For each (bound `Expression`, input Arrow schema) pair,
40- * compiles a specialized [[CometBatchKernel ]] on first encounter and caches it; subsequent
41- * batches with the same shape reuse the compile.
42+ * compiles a specialized [[CometBatchKernel ]] on first encounter and caches it.
4243 *
43- * Arg 0 is a `VarBinaryVector` scalar carrying the closure-serialized bound Expression bytes.
44+ * Arg 0 is a `VarBinaryVector` scalar carrying the closure-serialized bound ` Expression` bytes.
4445 * Args 1..N are the data columns the `BoundReference`s read, in ordinal order. The bytes
4546 * self-describe the expression so the path works in cluster mode without executor-side state.
4647 *
47- * Three caches at different scopes: the JVM-wide compile cache (`kernelCache` on the companion);
48- * the per-task UDF-instance cache in `CometUdfBridge.INSTANCES`; and per-partition kernel state
49- * on this instance (`activeKernel`, `activeKey`, `activePartition`) managed by [[ensureKernel ]].
50- * Each layer covers a distinct lifetime: JVM (compiled bytecode, immutable), task (UDF instance,
51- * isolated from worker reuse), partition (kernel mutable state for `Rand` /
52- * `MonotonicallyIncreasingID` / etc.).
48+ * Three lifetime scopes:
49+ * - JVM-wide bytecode dedup: `CodeGenerator.compile`'s source-keyed Guava cache. Stateless.
50+ * - Per-task: this instance, lifetime managed by `CometUdfBridge.INSTANCES` keyed on
51+ * `taskAttemptId` and dropped via `TaskCompletionListener`. Holds [[kernelCache ]], so the
52+ * deserialized `boundExpr` (which carries mutable state like `NamedLambdaVariable.value` for
53+ * HOFs) is not shared across concurrent tasks. Mirrors Spark's per-task closure-deserialize
54+ * model.
55+ * - Per-partition: [[activeKernel ]] for kernel mutable state (`Rand`'s `XORShiftRandom`,
56+ * `MonotonicallyIncreasingID`'s counter) that advances across batches in one partition and
57+ * resets across partitions.
5358 */
5459class CometScalaUDFCodegen extends CometUDF {
5560
5661 /**
57- * Per-partition kernel instance cache. The compile cache stores the compiled `GeneratedClass`;
58- * the kernel '''instance''' holds per-row mutable state (`Rand`'s `XORShiftRandom`,
59- * `MonotonicallyIncreasingID`'s counter, etc.) that must advance across batches in one
60- * partition and reset across partitions. Allocating per partition gets that right.
61- *
62- * Plain `var`s are safe: this dispatcher is per-task (`CometUdfBridge.INSTANCES` keys by
63- * `taskAttemptId`) and Spark drives one partition per task, so [[ensureKernel ]] never sees
64- * concurrent access. A different partition or expression triggers a fresh allocation.
62+ * Per-task `(serialized-bytes, specs) -> compiled kernel + bound expression`. Per-task scope is
63+ * load-bearing for HOF correctness: `ArrayTransform.eval` and other HOFs mutate
64+ * `NamedLambdaVariable.value`'s `AtomicReference` per element, and a JVM-wide cache would race
65+ * across concurrent tasks running the same query. Compile work itself stays deduped JVM-wide
66+ * via `CodeGenerator.compile`'s internal source cache, so identical Janino source shares
67+ * bytecode across tasks; only the `boundExpr` Java object is per-task.
6568 */
69+ private val kernelCache
70+ : mutable.Map [CometScalaUDFCodegen .CacheKey , CometScalaUDFCodegen .CacheEntry ] =
71+ mutable.HashMap .empty
72+
73+ // Plain `var`s: this instance is per-task, Spark drives one partition per task, so
74+ // [[ensureKernel]] is never concurrent.
6675 private var activeKernel : CometBatchKernel = _
6776 private var activeKey : CometScalaUDFCodegen .CacheKey = _
6877 private var activePartition : Int = - 1
@@ -96,7 +105,7 @@ class CometScalaUDFCodegen extends CometUDF {
96105 val specsSeq = specs.toIndexedSeq
97106
98107 val key = CometScalaUDFCodegen .CacheKey (ByteBuffer .wrap(bytes), specsSeq)
99- val entry = CometScalaUDFCodegen . lookupOrCompile(key, bytes, specsSeq)
108+ val entry = lookupOrCompile(key, bytes, specsSeq)
100109
101110 val partitionId = CometScalaUDFCodegen .currentPartitionIndex()
102111 val kernel = ensureKernel(entry.compiled, key, partitionId)
@@ -132,6 +141,48 @@ class CometScalaUDFCodegen extends CometUDF {
132141 activeKernel
133142 }
134143
144+ private def lookupOrCompile (
145+ key : CometScalaUDFCodegen .CacheKey ,
146+ bytes : Array [Byte ],
147+ specs : IndexedSeq [ArrowColumnSpec ]): CometScalaUDFCodegen .CacheEntry = {
148+ val existing = kernelCache.get(key)
149+ if (existing.isDefined) {
150+ CometScalaUDFCodegen .cacheHitCount.incrementAndGet()
151+ existing.get
152+ } else {
153+ val loader = Option (Thread .currentThread().getContextClassLoader)
154+ .getOrElse(classOf [Expression ].getClassLoader)
155+ val rawExpr = SparkEnv .get.closureSerializer
156+ .newInstance()
157+ .deserialize[Expression ](ByteBuffer .wrap(bytes), loader)
158+ val boundExpr = rewriteBoundReferences(rawExpr, specs)
159+ val compiled = CometBatchKernelCodegen .compile(boundExpr, specs)
160+ val outputField =
161+ Utils .toArrowField(" codegen_result" , boundExpr.dataType, nullable = true , " UTC" )
162+ val entry = CometScalaUDFCodegen .CacheEntry (compiled, boundExpr.dataType, outputField)
163+ kernelCache.put(key, entry)
164+ CometScalaUDFCodegen .compileCount.incrementAndGet()
165+ CometScalaUDFCodegen .recordCompiledSignature(specs, boundExpr.dataType)
166+ entry
167+ }
168+ }
169+
170+ /**
171+ * Walk the bound expression tree and rewrite any `BoundReference(ord, dt, nullable=true)` to
172+ * `nullable=false` when the corresponding input column in `specs` is non-nullable for this
173+ * batch. Only tightens; never relaxes.
174+ */
175+ private def rewriteBoundReferences (
176+ expr : Expression ,
177+ specs : IndexedSeq [ArrowColumnSpec ]): Expression = {
178+ expr.transform {
179+ case BoundReference (ord, dt, true )
180+ if ord >= 0 && ord < specs.length && ! specs(ord).nullable =>
181+ BoundReference (ord, dt, nullable = false )
182+ case other => other
183+ }
184+ }
185+
135186 /**
136187 * Did any row in this batch set the null bit? Carried per column on the cache key, so batches
137188 * with different nullability map to different kernels (no correctness risk). The
@@ -213,103 +264,46 @@ class CometScalaUDFCodegen extends CometUDF {
213264
214265object CometScalaUDFCodegen {
215266
216- private val CacheCapacity : Int = 128
217- private val kernelCache : java.util.Map [CacheKey , CacheEntry ] =
218- Collections .synchronizedMap(
219- new LinkedHashMap [CacheKey , CacheEntry ](CacheCapacity , 0.75f , true ) {
220- override def removeEldestEntry (
221- eldest : java.util.Map .Entry [CacheKey , CacheEntry ]): Boolean =
222- size() > CacheCapacity
223- })
224- // Observability counters. Incremented under the `kernelCache.synchronized` block in
225- // `lookupOrCompile` so counter increments and cache mutations cannot interleave. Read via
226- // [[stats]]; reset via [[resetStats]] for tests.
267+ // JVM-wide counters aggregated across all per-task instances. Compile work itself is
268+ // deduplicated JVM-wide via `CodeGenerator.compile`'s source cache; these numbers track this
269+ // dispatcher's per-task cache activity.
227270 private val compileCount = new AtomicLong (0 )
228271 private val cacheHitCount = new AtomicLong (0 )
229272
230- /** Returns a snapshot of cache counters and current size. Cheap; safe to call anytime. */
273+ // JVM-wide append-only set of distinct compiled-kernel signatures. Lets tests assert
274+ // specialization shape (which vector-class / dataType combinations the dispatcher emitted)
275+ // and that a composed subtree fuses into one kernel. Append-only because each per-task cache
276+ // is dropped on task completion, leaving no other place to observe the set across runs.
277+ private val compiledSignatures =
278+ Collections .synchronizedSet(
279+ new java.util.HashSet [(IndexedSeq [Class [_ <: ValueVector ]], DataType )]())
280+
281+ /** Snapshot of JVM-wide counters and the distinct-signature count. Cheap. */
231282 def stats (): DispatcherStats =
232- DispatcherStats (compileCount.get(), cacheHitCount.get(), kernelCache .size())
283+ DispatcherStats (compileCount.get(), cacheHitCount.get(), compiledSignatures .size())
233284
234- /** Reset counters to zero . Leaves the compile cache intact. Intended for tests . */
285+ /** Reset counters. Leaves the signature set intact. Tests only . */
235286 def resetStats (): Unit = {
236287 compileCount.set(0 )
237288 cacheHitCount.set(0 )
238289 }
239290
240291 /**
241- * Test-facing snapshot of compiled kernel signatures: `(input Arrow vector classes in ordinal
242- * order, output Spark DataType)` per cache entry. Lets tests assert specialization shape, not
243- * just result correctness. Drops `ArrowColumnSpec.nullable` so a single assertion matches both
244- * `nullable=true` and `nullable=false` variants of the same expression.
292+ * Distinct compiled-kernel signatures: `(input Arrow vector classes in ordinal order, output
293+ * Spark DataType)`. Drops `ArrowColumnSpec.nullable` so a single assertion matches both
294+ * nullability variants of the same expression.
245295 */
246296 def snapshotCompiledSignatures (): Set [(IndexedSeq [Class [_ <: ValueVector ]], DataType )] = {
247- kernelCache.synchronized {
248- import scala .jdk .CollectionConverters ._
249- kernelCache
250- .entrySet()
251- .asScala
252- .iterator
253- .map { e =>
254- (e.getKey.specs.map(_.vectorClass), e.getValue.outputType)
255- }
256- .toSet
297+ import scala .jdk .CollectionConverters ._
298+ compiledSignatures.synchronized {
299+ compiledSignatures.iterator().asScala.toSet
257300 }
258301 }
259302
260- private def lookupOrCompile (
261- key : CacheKey ,
262- bytes : Array [Byte ],
263- specs : IndexedSeq [ArrowColumnSpec ]): CacheEntry = {
264- kernelCache.synchronized {
265- val existing = kernelCache.get(key)
266- if (existing != null ) {
267- cacheHitCount.incrementAndGet()
268- existing
269- } else {
270- // Use a classloader that can see Spark classes. The Comet native runtime calls us on a
271- // Tokio worker thread where the context classloader may not be set to Spark's task
272- // loader, so fall back to the loader that loaded `Expression` itself if needed.
273- val loader = Option (Thread .currentThread().getContextClassLoader)
274- .getOrElse(classOf [Expression ].getClassLoader)
275- val rawExpr = SparkEnv .get.closureSerializer
276- .newInstance()
277- .deserialize[Expression ](ByteBuffer .wrap(bytes), loader)
278- // Tighten BoundReference.nullable based on the observed batch. The plan-time value is
279- // conservative (the column may be null somewhere in the query's execution), but for
280- // this specific batch we know. Rewriting lets Spark's `BoundReference.genCode` skip the
281- // `isNull` branch at source level rather than leaving it to JIT constant-folding.
282- // Correctness is preserved by the cache key: a later batch with nulls on this column has
283- // a different `specs`, so it hits a different kernel compiled with nullable=true.
284- val boundExpr = rewriteBoundReferences(rawExpr, specs)
285- val compiled = CometBatchKernelCodegen .compile(boundExpr, specs)
286- val outputField =
287- Utils .toArrowField(" codegen_result" , boundExpr.dataType, nullable = true , " UTC" )
288- val entry = CacheEntry (compiled, boundExpr.dataType, outputField)
289- kernelCache.put(key, entry)
290- compileCount.incrementAndGet()
291- entry
292- }
293- }
294- }
295-
296- /**
297- * Walk the bound expression tree and rewrite any `BoundReference(ord, dt, nullable=true)` to
298- * `nullable=false` when the corresponding input column in `specs` is non-nullable for this
299- * batch. Only tightens; never relaxes. Expressions outside the `BoundReference` leaves are
300- * unchanged.
301- */
302- private def rewriteBoundReferences (
303- expr : Expression ,
304- specs : IndexedSeq [ArrowColumnSpec ]): Expression = {
305- expr.transform {
306- case BoundReference (ord, dt, true )
307- if ord >= 0 && ord < specs.length && ! specs(ord).nullable =>
308- BoundReference (ord, dt, nullable = false )
309- // Fall through unchanged: non-BoundReference nodes and BoundReferences that are already
310- // non-nullable or point at a nullable column in this batch.
311- case other => other
312- }
303+ private [codegen] def recordCompiledSignature (
304+ specs : IndexedSeq [ArrowColumnSpec ],
305+ outputType : DataType ): Unit = {
306+ compiledSignatures.add((specs.map(_.vectorClass), outputType))
313307 }
314308
315309 /**
0 commit comments