Skip to content

Commit d967143

Browse files
committed
fix(udf): scope the dispatcher's compile cache per task to isolate boundExpr mutable state
1 parent 0f6f68c commit d967143

5 files changed

Lines changed: 280 additions & 156 deletions

File tree

common/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import org.apache.arrow.vector._
2323
import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
2424
import org.apache.arrow.vector.types.pojo.Field
2525
import org.apache.spark.internal.Logging
26-
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, Unevaluable}
26+
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, HigherOrderFunction, LambdaFunction, Literal, NamedLambdaVariable, Unevaluable}
2727
import org.apache.spark.sql.catalyst.expressions.codegen._
2828
import org.apache.spark.sql.internal.SQLConf
2929
import org.apache.spark.sql.types._
@@ -139,39 +139,38 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
139139
}
140140
// Reject expressions that can't be safely compiled or cached:
141141
// - AggregateFunction / Generator: non-scalar bridge shape.
142-
// - CodegenFallback: opts out of `doGenCode`, which our compile path assumes works.
143-
// Passing one in would emit interpreted-eval glue that our kernel can't splice cleanly.
144-
// - Unevaluable: unresolved plan markers. Shouldn't reach a serde, but cheap to guard.
145-
// `isCodegenInertUnevaluable` lets the shim exclude version-specific leaves that are
146-
// `Unevaluable` but never touched by codegen (e.g. Spark 4.0's `ResolvedCollation`, which
147-
// lives in `Collate.collation` as a type marker; `Collate.genCode` delegates to its child).
142+
// - CodegenFallback (other than HOF / lambda nodes admitted below): opts out of
143+
// `doGenCode`. The kernel cannot splice the interpreted-eval glue cleanly.
144+
// - Unevaluable: unresolved plan markers. `isCodegenInertUnevaluable` lets the shim allow
145+
// version-specific leaves that are `Unevaluable` but never touched by codegen (e.g.
146+
// Spark 4.0's `ResolvedCollation` in `Collate.collation` as a type marker;
147+
// `Collate.genCode` delegates to its child).
148148
//
149-
// TODO(hof-lambdas): the `CodegenFallback` rule rejects `NamedLambdaVariable`, which flags
150-
// every higher-order function (`ArrayTransform`, `ArrayAggregate`, `ArrayExists`,
151-
// `ArrayFilter`, `ZipWith`, `MapFilter`, etc.) as unsupported. The variable is
152-
// `CodegenFallback` only in isolation; the surrounding HOF binds its `value` field inline
153-
// as part of its own
154-
// `doGenCode`, and the resulting Java compiles fine. Loosening this would unlock
155-
// element-iteration over `Array<Struct>` / `Array<Map>` which today have no fuzz path
156-
// (`array_max` doesn't apply to non-comparable elements, generators are blocked above). Plan:
157-
// allow `NamedLambdaVariable` / `LambdaFunction` in the rejection scan; verify the kernel
158-
// splices the HOF's emitted loop without ctx.references collisions on the lambda holder.
149+
// HOFs are `CodegenFallback` but admitted. `CodegenFallback.doGenCode` emits one
150+
// `((Expression) references[N]).eval(row)` call site; the kernel dispatches to the HOF's
151+
// interpreted `eval`, which mutates `NamedLambdaVariable.value` per element and reads the
152+
// input array through the kernel's typed Arrow getters. Correctness depends on per-task
153+
// `boundExpr` isolation in `CometScalaUDFCodegen.kernelCache`: concurrent partitions get
154+
// their own deserialized expression tree, so they cannot race on the lambda variable's
155+
// `AtomicReference`. See `CometCodegenHOFSuite`.
159156
//
160157
// Nondeterministic / stateful expressions are accepted: per-partition kernel allocation
161-
// (`CometScalaUDFCodegen.ensureKernel`) plus a single `init(partitionIndex)` call at
158+
// in `CometScalaUDFCodegen.ensureKernel` plus a single `init(partitionIndex)` call at
162159
// partition entry give `Rand` / `MonotonicallyIncreasingID` / etc. correct state across
163160
// batches and a clean reset across partitions.
164161
//
165162
// `ExecSubqueryExpression` (`ScalarSubquery`, `InSubqueryExec`) is accepted via a chain:
166163
// the surrounding Comet operator's inherited `SparkPlan.waitForSubqueries` populates the
167-
// subquery's mutable `result` field before evaluation; the closure serializer captures that
168-
// populated value into the arg-0 bytes; the dispatcher keys its compile cache on those
169-
// exact bytes, so distinct subquery results produce distinct cache entries with no
170-
// cross-query staleness. Refactors to the cache-key derivation, the transport, or any
171-
// Comet operator that bypasses `waitForSubqueries` would break this; preserve it.
164+
// subquery's mutable `result` field before evaluation; the closure serializer captures
165+
// that populated value into the arg-0 bytes; the dispatcher keys its compile cache on
166+
// those exact bytes, so distinct subquery results produce distinct cache entries with no
167+
// cross-query staleness. Comet operators that bypass `waitForSubqueries` would break this.
172168
boundExpr.find {
173169
case _: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction => true
174170
case _: org.apache.spark.sql.catalyst.expressions.Generator => true
171+
case _: HigherOrderFunction => false
172+
case _: LambdaFunction => false
173+
case _: NamedLambdaVariable => false
175174
case _: CodegenFallback => true
176175
case u: Unevaluable if isCodegenInertUnevaluable(u) => false
177176
case _: Unevaluable => true

common/src/main/scala/org/apache/comet/udf/codegen/CometScalaUDFCodegen.scala

Lines changed: 94 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020
package org.apache.comet.udf.codegen
2121

2222
import java.nio.ByteBuffer
23-
import java.util.{Collections, LinkedHashMap}
23+
import java.util.Collections
2424
import java.util.concurrent.atomic.AtomicLong
2525

26+
import scala.collection.mutable
27+
2628
import org.apache.arrow.vector._
2729
import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
2830
import org.apache.arrow.vector.types.pojo.Field
@@ -37,32 +39,39 @@ import org.apache.comet.udf.CometUDF
3739

3840
/**
3941
* Arrow-direct codegen dispatcher. For each (bound `Expression`, input Arrow schema) pair,
40-
* compiles a specialized [[CometBatchKernel]] on first encounter and caches it; subsequent
41-
* batches with the same shape reuse the compile.
42+
* compiles a specialized [[CometBatchKernel]] on first encounter and caches it.
4243
*
43-
* Arg 0 is a `VarBinaryVector` scalar carrying the closure-serialized bound Expression bytes.
44+
* Arg 0 is a `VarBinaryVector` scalar carrying the closure-serialized bound `Expression` bytes.
4445
* Args 1..N are the data columns the `BoundReference`s read, in ordinal order. The bytes
4546
* self-describe the expression so the path works in cluster mode without executor-side state.
4647
*
47-
* Three caches at different scopes: the JVM-wide compile cache (`kernelCache` on the companion);
48-
* the per-task UDF-instance cache in `CometUdfBridge.INSTANCES`; and per-partition kernel state
49-
* on this instance (`activeKernel`, `activeKey`, `activePartition`) managed by [[ensureKernel]].
50-
* Each layer covers a distinct lifetime: JVM (compiled bytecode, immutable), task (UDF instance,
51-
* isolated from worker reuse), partition (kernel mutable state for `Rand` /
52-
* `MonotonicallyIncreasingID` / etc.).
48+
* Three lifetime scopes:
49+
* - JVM-wide bytecode dedup: `CodeGenerator.compile`'s source-keyed Guava cache. Stateless.
50+
* - Per-task: this instance, lifetime managed by `CometUdfBridge.INSTANCES` keyed on
51+
* `taskAttemptId` and dropped via `TaskCompletionListener`. Holds [[kernelCache]], so the
52+
* deserialized `boundExpr` (which carries mutable state like `NamedLambdaVariable.value` for
53+
* HOFs) is not shared across concurrent tasks. Mirrors Spark's per-task closure-deserialize
54+
* model.
55+
* - Per-partition: [[activeKernel]] for kernel mutable state (`Rand`'s `XORShiftRandom`,
56+
* `MonotonicallyIncreasingID`'s counter) that advances across batches in one partition and
57+
* resets across partitions.
5358
*/
5459
class CometScalaUDFCodegen extends CometUDF {
5560

5661
/**
57-
* Per-partition kernel instance cache. The compile cache stores the compiled `GeneratedClass`;
58-
* the kernel '''instance''' holds per-row mutable state (`Rand`'s `XORShiftRandom`,
59-
* `MonotonicallyIncreasingID`'s counter, etc.) that must advance across batches in one
60-
* partition and reset across partitions. Allocating per partition gets that right.
61-
*
62-
* Plain `var`s are safe: this dispatcher is per-task (`CometUdfBridge.INSTANCES` keys by
63-
* `taskAttemptId`) and Spark drives one partition per task, so [[ensureKernel]] never sees
64-
* concurrent access. A different partition or expression triggers a fresh allocation.
62+
* Per-task `(serialized-bytes, specs) -> compiled kernel + bound expression`. Per-task scope is
63+
* load-bearing for HOF correctness: `ArrayTransform.eval` and other HOFs mutate
64+
* `NamedLambdaVariable.value`'s `AtomicReference` per element, and a JVM-wide cache would race
65+
* across concurrent tasks running the same query. Compile work itself stays deduped JVM-wide
66+
* via `CodeGenerator.compile`'s internal source cache, so identical Janino source shares
67+
* bytecode across tasks; only the `boundExpr` Java object is per-task.
6568
*/
69+
private val kernelCache
70+
: mutable.Map[CometScalaUDFCodegen.CacheKey, CometScalaUDFCodegen.CacheEntry] =
71+
mutable.HashMap.empty
72+
73+
// Plain `var`s: this instance is per-task, Spark drives one partition per task, so
74+
// `ensureKernel` is never called concurrently.
6675
private var activeKernel: CometBatchKernel = _
6776
private var activeKey: CometScalaUDFCodegen.CacheKey = _
6877
private var activePartition: Int = -1
@@ -96,7 +105,7 @@ class CometScalaUDFCodegen extends CometUDF {
96105
val specsSeq = specs.toIndexedSeq
97106

98107
val key = CometScalaUDFCodegen.CacheKey(ByteBuffer.wrap(bytes), specsSeq)
99-
val entry = CometScalaUDFCodegen.lookupOrCompile(key, bytes, specsSeq)
108+
val entry = lookupOrCompile(key, bytes, specsSeq)
100109

101110
val partitionId = CometScalaUDFCodegen.currentPartitionIndex()
102111
val kernel = ensureKernel(entry.compiled, key, partitionId)
@@ -132,6 +141,48 @@ class CometScalaUDFCodegen extends CometUDF {
132141
activeKernel
133142
}
134143

144+
private def lookupOrCompile(
145+
key: CometScalaUDFCodegen.CacheKey,
146+
bytes: Array[Byte],
147+
specs: IndexedSeq[ArrowColumnSpec]): CometScalaUDFCodegen.CacheEntry = {
148+
val existing = kernelCache.get(key)
149+
if (existing.isDefined) {
150+
CometScalaUDFCodegen.cacheHitCount.incrementAndGet()
151+
existing.get
152+
} else {
153+
val loader = Option(Thread.currentThread().getContextClassLoader)
154+
.getOrElse(classOf[Expression].getClassLoader)
155+
val rawExpr = SparkEnv.get.closureSerializer
156+
.newInstance()
157+
.deserialize[Expression](ByteBuffer.wrap(bytes), loader)
158+
val boundExpr = rewriteBoundReferences(rawExpr, specs)
159+
val compiled = CometBatchKernelCodegen.compile(boundExpr, specs)
160+
val outputField =
161+
Utils.toArrowField("codegen_result", boundExpr.dataType, nullable = true, "UTC")
162+
val entry = CometScalaUDFCodegen.CacheEntry(compiled, boundExpr.dataType, outputField)
163+
kernelCache.put(key, entry)
164+
CometScalaUDFCodegen.compileCount.incrementAndGet()
165+
CometScalaUDFCodegen.recordCompiledSignature(specs, boundExpr.dataType)
166+
entry
167+
}
168+
}
169+
170+
/**
171+
* Walk the bound expression tree and rewrite any `BoundReference(ord, dt, nullable=true)` to
172+
* `nullable=false` when the corresponding input column in `specs` is non-nullable for this
173+
* batch. Only tightens; never relaxes.
174+
*/
175+
private def rewriteBoundReferences(
176+
expr: Expression,
177+
specs: IndexedSeq[ArrowColumnSpec]): Expression = {
178+
expr.transform {
179+
case BoundReference(ord, dt, true)
180+
if ord >= 0 && ord < specs.length && !specs(ord).nullable =>
181+
BoundReference(ord, dt, nullable = false)
182+
case other => other
183+
}
184+
}
185+
135186
/**
136187
* Did any row in this batch set the null bit? Carried per column on the cache key, so batches
137188
* with different nullability map to different kernels (no correctness risk). The
@@ -213,103 +264,46 @@ class CometScalaUDFCodegen extends CometUDF {
213264

214265
object CometScalaUDFCodegen {
215266

216-
private val CacheCapacity: Int = 128
217-
private val kernelCache: java.util.Map[CacheKey, CacheEntry] =
218-
Collections.synchronizedMap(
219-
new LinkedHashMap[CacheKey, CacheEntry](CacheCapacity, 0.75f, true) {
220-
override def removeEldestEntry(
221-
eldest: java.util.Map.Entry[CacheKey, CacheEntry]): Boolean =
222-
size() > CacheCapacity
223-
})
224-
// Observability counters. Incremented under the `kernelCache.synchronized` block in
225-
// `lookupOrCompile` so counter increments and cache mutations cannot interleave. Read via
226-
// [[stats]]; reset via [[resetStats]] for tests.
267+
// JVM-wide counters aggregated across all per-task instances. Compile work itself is
268+
// deduplicated JVM-wide via `CodeGenerator.compile`'s source cache; these numbers track this
269+
// dispatcher's per-task cache activity.
227270
private val compileCount = new AtomicLong(0)
228271
private val cacheHitCount = new AtomicLong(0)
229272

230-
/** Returns a snapshot of cache counters and current size. Cheap; safe to call anytime. */
273+
// JVM-wide append-only set of distinct compiled-kernel signatures. Lets tests assert
274+
// specialization shape (which vector-class / dataType combinations the dispatcher emitted)
275+
// and that a composed subtree fuses into one kernel. Append-only because each per-task cache
276+
// is dropped on task completion, leaving no other place to observe the set across runs.
277+
private val compiledSignatures =
278+
Collections.synchronizedSet(
279+
new java.util.HashSet[(IndexedSeq[Class[_ <: ValueVector]], DataType)]())
280+
281+
/** Snapshot of JVM-wide counters and the distinct-signature count. Cheap. */
231282
def stats(): DispatcherStats =
232-
DispatcherStats(compileCount.get(), cacheHitCount.get(), kernelCache.size())
283+
DispatcherStats(compileCount.get(), cacheHitCount.get(), compiledSignatures.size())
233284

234-
/** Reset counters to zero. Leaves the compile cache intact. Intended for tests. */
285+
/** Reset counters. Leaves the signature set intact. Tests only. */
235286
def resetStats(): Unit = {
236287
compileCount.set(0)
237288
cacheHitCount.set(0)
238289
}
239290

240291
/**
241-
* Test-facing snapshot of compiled kernel signatures: `(input Arrow vector classes in ordinal
242-
* order, output Spark DataType)` per cache entry. Lets tests assert specialization shape, not
243-
* just result correctness. Drops `ArrowColumnSpec.nullable` so a single assertion matches both
244-
* `nullable=true` and `nullable=false` variants of the same expression.
292+
* Distinct compiled-kernel signatures: `(input Arrow vector classes in ordinal order, output
293+
* Spark DataType)`. Drops `ArrowColumnSpec.nullable` so a single assertion matches both
294+
* nullability variants of the same expression.
245295
*/
246296
def snapshotCompiledSignatures(): Set[(IndexedSeq[Class[_ <: ValueVector]], DataType)] = {
247-
kernelCache.synchronized {
248-
import scala.jdk.CollectionConverters._
249-
kernelCache
250-
.entrySet()
251-
.asScala
252-
.iterator
253-
.map { e =>
254-
(e.getKey.specs.map(_.vectorClass), e.getValue.outputType)
255-
}
256-
.toSet
297+
import scala.jdk.CollectionConverters._
298+
compiledSignatures.synchronized {
299+
compiledSignatures.iterator().asScala.toSet
257300
}
258301
}
259302

260-
private def lookupOrCompile(
261-
key: CacheKey,
262-
bytes: Array[Byte],
263-
specs: IndexedSeq[ArrowColumnSpec]): CacheEntry = {
264-
kernelCache.synchronized {
265-
val existing = kernelCache.get(key)
266-
if (existing != null) {
267-
cacheHitCount.incrementAndGet()
268-
existing
269-
} else {
270-
// Use a classloader that can see Spark classes. The Comet native runtime calls us on a
271-
// Tokio worker thread where the context classloader may not be set to Spark's task
272-
// loader, so fall back to the loader that loaded `Expression` itself if needed.
273-
val loader = Option(Thread.currentThread().getContextClassLoader)
274-
.getOrElse(classOf[Expression].getClassLoader)
275-
val rawExpr = SparkEnv.get.closureSerializer
276-
.newInstance()
277-
.deserialize[Expression](ByteBuffer.wrap(bytes), loader)
278-
// Tighten BoundReference.nullable based on the observed batch. The plan-time value is
279-
// conservative (the column may be null somewhere in the query's execution), but for
280-
// this specific batch we know. Rewriting lets Spark's `BoundReference.genCode` skip the
281-
// `isNull` branch at source level rather than leaving it to JIT constant-folding.
282-
// Correctness is preserved by the cache key: a later batch with nulls on this column has
283-
// a different `specs`, so it hits a different kernel compiled with nullable=true.
284-
val boundExpr = rewriteBoundReferences(rawExpr, specs)
285-
val compiled = CometBatchKernelCodegen.compile(boundExpr, specs)
286-
val outputField =
287-
Utils.toArrowField("codegen_result", boundExpr.dataType, nullable = true, "UTC")
288-
val entry = CacheEntry(compiled, boundExpr.dataType, outputField)
289-
kernelCache.put(key, entry)
290-
compileCount.incrementAndGet()
291-
entry
292-
}
293-
}
294-
}
295-
296-
/**
297-
* Walk the bound expression tree and rewrite any `BoundReference(ord, dt, nullable=true)` to
298-
* `nullable=false` when the corresponding input column in `specs` is non-nullable for this
299-
* batch. Only tightens; never relaxes. Expressions outside the `BoundReference` leaves are
300-
* unchanged.
301-
*/
302-
private def rewriteBoundReferences(
303-
expr: Expression,
304-
specs: IndexedSeq[ArrowColumnSpec]): Expression = {
305-
expr.transform {
306-
case BoundReference(ord, dt, true)
307-
if ord >= 0 && ord < specs.length && !specs(ord).nullable =>
308-
BoundReference(ord, dt, nullable = false)
309-
// Fall through unchanged: non-BoundReference nodes and BoundReferences that are already
310-
// non-nullable or point at a nullable column in this batch.
311-
case other => other
312-
}
303+
private[codegen] def recordCompiledSignature(
304+
specs: IndexedSeq[ArrowColumnSpec],
305+
outputType: DataType): Unit = {
306+
compiledSignatures.add((specs.map(_.vectorClass), outputType))
313307
}
314308

315309
/**

0 commit comments

Comments
 (0)