Skip to content

Commit 2a158f4

Browse files
committed
strengthen tests for composed expressions
1 parent a82e160 commit 2a158f4

1 file changed

Lines changed: 43 additions & 8 deletions

File tree

spark/src/test/scala/org/apache/comet/CometCodegenDispatchSmokeSuite.scala

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,35 @@ class CometCodegenDispatchSmokeSuite extends CometTestBase with AdaptiveSparkPla
116116
s"expected codegen dispatcher activity, got $after")
117117
}
118118

119+
/**
120+
* Stronger form of [[assertCodegenDidWork]] for composition tests. Asserts that the full
121+
* expression subtree compiled into at most one kernel. The "one JNI crossing per nesting level"
122+
* alternative (the PR description's foil) would produce one `(bytes, specs)` cache entry per
123+
* nested sub-expression, so `compileCount` would be N and the cache would grow by N after the
124+
* first batch. Asserting `compileCount <= 1` and `cacheSize` growth `<= 1` directly falsifies
125+
* that shape.
126+
*
127+
* Uses `<=` rather than `==` because the compile cache is JVM-wide and shared across tests; a
128+
* prior test that already compiled the same `(expression bytes, input schema)` pair will make
129+
* this run a cache hit (`compileCount == 0`). The dispatcher-activity check guards against a
130+
* silent fallback where the query runs through Spark and the first two assertions pass
131+
* vacuously.
132+
*/
133+
private def assertOneKernelForSubtree(f: => Unit): Unit = {
134+
CometCodegenDispatchUDF.resetStats()
135+
val sizeBefore = CometCodegenDispatchUDF.stats().cacheSize
136+
f
137+
val after = CometCodegenDispatchUDF.stats()
138+
assert(
139+
after.compileCount <= 1,
140+
s"expected <= 1 compile for the composed subtree, got $after")
141+
val grew = after.cacheSize - sizeBefore
142+
assert(grew <= 1, s"expected cache to grow by <= 1 entry, grew by $grew; stats=$after")
143+
assert(
144+
after.compileCount + after.cacheHitCount >= 1,
145+
s"expected codegen dispatcher activity, got $after")
146+
}
147+
119148
/**
120149
* Assert that the dispatcher's compile cache contains a kernel compiled for the given input
121150
* Arrow vector classes (in ordinal order) and output Spark `DataType`. This is a specialization
@@ -473,14 +502,18 @@ class CometCodegenDispatchSmokeSuite extends CometTestBase with AdaptiveSparkPla
473502
}
474503

475504
test("codegen: three-deep ScalaUDF composition lvl3(lvl2(lvl1(s)))") {
476-
// Three user UDFs stacked in one tree: String -> String -> String -> Int. Single Janino
477-
// compile, three `ctx.addReferenceObj` calls in the fused method. Verifies the dispatcher
478-
// doesn't flatten or reorder the chain.
505+
// Three user UDFs stacked in one tree: String -> String -> String -> Int. The fused kernel
506+
// carries three `ctx.addReferenceObj` calls. `assertOneKernelForSubtree` asserts that the
507+
// whole chain collapses into a single compile rather than one per nesting level.
508+
// Input rows intentionally exclude nulls: per-batch nullability is a cache-key dimension
509+
// (`nullable()` reads `getNullCount != 0`), so a null-present batch compiles a second kernel
510+
// specialized for `nullable=true`. Null handling through composed UDFs is covered by the
511+
// other composition tests above.
479512
spark.udf.register("lvl1", (s: String) => if (s == null) null else s.toUpperCase)
480513
spark.udf.register("lvl2", (s: String) => if (s == null) null else s.reverse)
481514
spark.udf.register("lvl3", (s: String) => if (s == null) -1 else s.length)
482-
withSubjects("abc", null, "hello world", "x") {
483-
assertCodegenDidWork {
515+
withSubjects("abc", "hello world", "x") {
516+
assertOneKernelForSubtree {
484517
checkSparkAnswerAndOperator(sql("SELECT lvl3(lvl2(lvl1(s))) FROM t"))
485518
}
486519
assertKernelSignaturePresent(Seq(classOf[VarCharVector]), IntegerType)
@@ -490,14 +523,16 @@ class CometCodegenDispatchSmokeSuite extends CometTestBase with AdaptiveSparkPla
490523
test("codegen: multi-column ScalaUDF composition join(upperU(c1), lowerU(c2))") {
491524
// One multi-arg user UDF consuming two other user UDFs, each on a different input column.
492525
// The bound tree has two BoundReferences, and the kernel is specialized on two VarCharVector
493-
// columns. Proves multi-column composition of pure user UDFs works with zero Spark helpers.
526+
// columns. `assertOneKernelForSubtree` asserts that the two-branch composition fuses into a
527+
// single kernel rather than one per branch or one per UDF.
528+
// Input rows intentionally exclude nulls (see note on the three-deep test above).
494529
spark.udf.register("upperU", (s: String) => if (s == null) null else s.toUpperCase)
495530
spark.udf.register("lowerU", (s: String) => if (s == null) null else s.toLowerCase)
496531
spark.udf.register(
497532
"joinU",
498533
(a: String, b: String) => if (a == null || b == null) null else s"$a-$b")
499-
withTwoStringCols(("Abc", "XYZ"), ("Foo", null), (null, "Bar"), ("Hi", "Lo")) {
500-
assertCodegenDidWork {
534+
withTwoStringCols(("Abc", "XYZ"), ("Foo", "bar"), ("baz", "Bar"), ("Hi", "Lo")) {
535+
assertOneKernelForSubtree {
501536
checkSparkAnswerAndOperator(sql("SELECT joinU(upperU(c1), lowerU(c2)) FROM t"))
502537
}
503538
assertKernelSignaturePresent(

0 commit comments

Comments
 (0)